Skip to content

Add single-dispatch layer-by-layer multi-head attention#91

Draft
andrej wants to merge 8 commits intoamd:develfrom
andrej:mha-lxl-sd
Draft

Add single-dispatch layer-by-layer multi-head attention#91
andrej wants to merge 8 commits intoamd:develfrom
andrej:mha-lxl-sd

Conversation

@andrej
Copy link
Copy Markdown
Collaborator

@andrej andrej commented Apr 6, 2026

"Naive" alternative implementation for multi-head attention from the currently checked-in data-flow design. This is a simple layer-by-layer implementation, but it uses the single-dispatch mechanism to fuse it all into one MLIR file and save on CPU roundtrips and XRT overheads.

Includes two variants:

  1. "core": Only does the core matmuls and softmax; assumes projected and repeated inputs Q, K, V. This matches the functionality of the checked-in dataflow MHA.
  2. "projected": Performs the Q, K, V projections, applies a RoPE positional embedding and repeats K and V matrices for grouped-query attention. Takes an embedding vector and RoPE angles as input.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we reuse the reference from the existing mha? (Note: does not include RoPE and Q, K, V projections, but some code reuse should be possible.)

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 7, 2026

📊 Test Results for Test Example Applications

1d87fe8 (2026_04_07_21_05_39)

IRONCLAD

Tested on 2026_04_07_21_05_39 at commit 1d87fe8.

Test Checks TTFT (mean)TPS (mean)
llama_3.2_1b_prompt_1024_tokens_1 ✅ 5/5 2.13 n/a
llama_3.2_1b_prompt_1024_tokens_40 ✅ 5/5 2.18 4.31
llama_3.2_1b_prompt_13_tokens_1 ✅ 5/5 2.09 n/a
llama_3.2_1b_prompt_13_tokens_40 ✅ 5/5 2.09 4.31
📈 Trends (vs main branch) for Test Example Applications

1d87fe8 (2026_04_07_21_05_39)

IRONCLAD Trends

llama_3.2_1b

Commit/Date Num Tokens (max)Num Tokens (mean)Num Tokens (median)Num Tokens (min)Num Tokens (stddev)TPS (max)TPS (mean)TPS (median)TPS (min)TPS (stddev)TTFT (max)TTFT (mean)TTFT (median)TTFT (min)TTFT (stddev)Total (max)Total (mean)Total (median)Total (min)Total (stddev)
130b6ea — 2025-12-05 21:33:1240.00 (+0.00%)40.00 (+0.00%)40.00 (+0.00%)40.00 (+0.00%)0.00 (n/a)4.71 (-0.42%)4.64 (-0.09%)4.64 (+0.65%)4.55 (-0.22%)0.05 (-17.66%)4.41 (-0.34%)4.39 (-0.19%)4.38 (-0.33%)4.37 (-0.15%)0.01 (-25.90%)12.96 (-0.00%)12.80 (+0.07%)12.80 (-0.23%)12.67 (+0.44%)0.09 (-21.12%)
0a6c11c — 2025-12-03 23:35:1540.00 (n/a)40.00 (n/a)40.00 (n/a)40.00 (n/a)0.00 (n/a)4.73 (n/a)4.64 (n/a)4.61 (n/a)4.56 (n/a)0.06 (n/a)4.42 (n/a)4.40 (n/a)4.40 (n/a)4.37 (n/a)0.02 (n/a)12.96 (n/a)12.79 (n/a)12.83 (n/a)12.62 (n/a)0.12 (n/a)

llama_3.2_1b_prompt_1024_tokens_1

Commit/Date TTFT (max)TTFT (mean)TTFT (median)TTFT (min)TTFT (stddev)
1d87fe8 — 2026-04-07 21:00:002.15 (+0.09%)2.13 (+0.08%)2.13 (-0.42%)2.12 (+0.62%)0.01 (-31.21%)
912e6bc — 2026-04-07 19:08:432.15 (n/a)2.13 (n/a)2.13 (n/a)2.11 (n/a)0.02 (n/a)

llama_3.2_1b_prompt_1024_tokens_40

Commit/Date TPS (max)TPS (mean)TPS (median)TPS (min)TPS (stddev)TTFT (max)TTFT (mean)TTFT (median)TTFT (min)TTFT (stddev)
1d87fe8 — 2026-04-07 21:00:004.33 (+2.90%)4.31 (+3.44%)4.31 (+3.58%)4.29 (+3.77%)0.01 (-46.93%)2.29 (+0.48%)2.18 (+0.83%)2.15 (+0.80%)2.13 (+0.61%)0.07 (-4.73%)
912e6bc — 2026-04-07 19:08:434.21 (n/a)4.17 (n/a)4.16 (n/a)4.14 (n/a)0.03 (n/a)2.28 (n/a)2.16 (n/a)2.13 (n/a)2.12 (n/a)0.07 (n/a)

llama_3.2_1b_prompt_13_tokens_1

Commit/Date TTFT (max)TTFT (mean)TTFT (median)TTFT (min)TTFT (stddev)
1d87fe8 — 2026-04-07 21:00:002.10 (-0.10%)2.09 (+0.11%)2.09 (+0.19%)2.09 (+0.00%)0.01 (+8.87%)
912e6bc — 2026-04-07 19:08:432.10 (n/a)2.09 (n/a)2.09 (n/a)2.09 (n/a)0.01 (n/a)

llama_3.2_1b_prompt_13_tokens_40

Commit/Date TPS (max)TPS (mean)TPS (median)TPS (min)TPS (stddev)TTFT (max)TTFT (mean)TTFT (median)TTFT (min)TTFT (stddev)
1d87fe8 — 2026-04-07 21:00:004.36 (+4.23%)4.31 (+3.57%)4.30 (+3.44%)4.29 (+3.23%)0.03 (+128.30%)2.09 (-0.38%)2.09 (-0.04%)2.09 (+0.00%)2.08 (+0.44%)0.01 (-34.93%)
912e6bc — 2026-04-07 19:08:434.18 (n/a)4.16 (n/a)4.16 (n/a)4.15 (n/a)0.01 (n/a)2.10 (n/a)2.09 (n/a)2.09 (n/a)2.07 (n/a)0.01 (n/a)

llama_3.2_1b_prompt_2048_tokens_1

Commit/Date Num_Tokens (max)Num_Tokens (mean)Num_Tokens (median)Num_Tokens (min)Num_Tokens (stddev)TPS (max)TPS (mean)TPS (median)TPS (min)TPS (stddev)TTFT (max)TTFT (mean)TTFT (median)TTFT (min)TTFT (stddev)
897d04e — 2026-03-06 22:56:071.00 (+0.00%)1.00 (+0.00%)1.00 (+0.00%)1.00 (+0.00%)0.00 (n/a)0.00 (n/a)0.00 (n/a)0.00 (n/a)0.00 (n/a)0.00 (n/a)2.68 (-1.06%)2.68 (-1.06%)2.68 (-1.06%)2.68 (-1.06%)0.00 (n/a)
84d3478 — 2026-02-17 23:16:231.00 (n/a)1.00 (n/a)1.00 (n/a)1.00 (n/a)0.00 (n/a)0.00 (n/a)0.00 (n/a)0.00 (n/a)0.00 (n/a)0.00 (n/a)2.70 (n/a)2.70 (n/a)2.70 (n/a)2.70 (n/a)0.00 (n/a)

llama_3.2_1b_prompt_2048_tokens_40

Commit/Date Num_Tokens (max)Num_Tokens (mean)Num_Tokens (median)Num_Tokens (min)Num_Tokens (stddev)TPS (max)TPS (mean)TPS (median)TPS (min)TPS (stddev)TTFT (max)TTFT (mean)TTFT (median)TTFT (min)TTFT (stddev)
897d04e — 2026-03-06 22:56:0740.00 (+0.00%)40.00 (+0.00%)40.00 (+0.00%)40.00 (+0.00%)0.00 (n/a)4.00 (-1.72%)4.00 (-1.72%)4.00 (-1.72%)4.00 (-1.72%)0.00 (n/a)2.70 (-0.44%)2.70 (-0.44%)2.70 (-0.44%)2.70 (-0.44%)0.00 (n/a)
84d3478 — 2026-02-17 23:16:2340.00 (n/a)40.00 (n/a)40.00 (n/a)40.00 (n/a)0.00 (n/a)4.07 (n/a)4.07 (n/a)4.07 (n/a)4.07 (n/a)0.00 (n/a)2.71 (n/a)2.71 (n/a)2.71 (n/a)2.71 (n/a)0.00 (n/a)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant