amd-mla-decode

Deadline

10 days 3 hours (2025-06-02 00:00 UTC)

Language

Python

GPU Type

MI300

Description

You will implement a custom MLA decode kernel optimized for MI300. A few simplifications apply:

1. Q, K, and V use the bfloat16 data type.
2. Decode only, with a pre-allocated, non-paged latent KV cache.
3. Return the updated KV cache along with the MLA output.

The outer and inner dimensions of all tensors come from DeepSeek-R1, with the number of attention heads split so that the workload fits on one GPU. To be explicit, you will be given a tuple of tensors:

```yml
input       [bs, sq, dim]
attn_output [bs, n_heads, sq, v_head_dim]
kv_cache    [bs, sq, kv_lora_rank + qk_rope_head_dim]
```

where

0. bs::128 # batch size
1. prefill::[512, 2048, 4096, 6144] # kv length
2. sq::1 # only decoding is considered
3. dim::7168 # hidden size of DeepSeek-V3
4. kv_lora_rank::[512] # kv lora rank of DeepSeek-V3
5. qk_rope_head_dim::[64] # rope embedding dimension
6. v_head_dim::128 # head size
7. n_heads::128 # number of attention heads

The ranking criterion is the geometric mean of the benchmark results. For the grand prize, your kernel will be evaluated against the speed-of-light analysis, and the solution closest to the speed of light will be awarded the grand prize. The speed-of-light analysis is:

| bs | prefill | sq | dtype | roofline time (μs) |
|---|---|---|---|---|
| 128 | 512 | 1 | bf16 | 54.62 |
| 128 | 2048 | 1 | bf16 | 141.16 |
| 128 | 4096 | 1 | bf16 | 210.75 |
| 128 | 6144 | 1 | bf16 | 280.87 |
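For intuition about why the roofline times grow with prefill length, here is a rough sketch (illustrative only, not the organizers' roofline model, and ignoring weight reads) of the minimum latent-cache traffic a single decode step must stream at each prefill length; this read is a first-order bandwidth constraint on any kernel.

```python
# Rough lower bound on per-step latent KV-cache traffic (illustrative only).
bs = 128                                  # batch size
kv_lora_rank, qk_rope_head_dim = 512, 64  # latent + rope widths per cached token
bytes_per_elem = 2                        # bf16

for prefill in (512, 2048, 4096, 6144):
    cache_bytes = bs * prefill * (kv_lora_rank + qk_rope_head_dim) * bytes_per_elem
    print(f"prefill={prefill:5d}: latent cache read >= {cache_bytes / 1e6:.1f} MB per decode step")
```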

Reference Implementation

import math
from dataclasses import dataclass
import torch
from torch import nn
import torch.nn.functional as F
from task import input_t, output_t
from utils import make_match_reference

class RoPE(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        theta = 10000 ** (-torch.arange(0, d_model//2,dtype=torch.bfloat16) / (d_model//2))
        self.register_buffer("theta", theta)

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
        seq_len = x.size(-2)
        d_model = x.size(-1)
        assert d_model == self.d_model
        seq_idx = torch.arange(start_pos, start_pos + seq_len, device=x.device)
        idx_theta = torch.einsum('s,d->sd', seq_idx, self.theta)
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=-1)
        cos = idx_theta2.cos().to(torch.bfloat16)
        sin = idx_theta2.sin().to(torch.bfloat16)
        return x * cos + self.rotate_half(x) * sin

class KVCache(nn.Module):
    def __init__(self, kv_cache_shape: tuple, **kwargs) -> None:
        super().__init__(**kwargs)
        self.register_buffer('data', torch.zeros(kv_cache_shape, dtype=torch.bfloat16))
        self.seq_len = 0
        self.zero()

    def zero(self) -> None:
        self.data.zero_()
    
    def get_data(self) -> torch.Tensor:
        return self.data

    def forward(self, c_kv: torch.Tensor) -> torch.Tensor:
        assert self.seq_len + c_kv.size(1) <= self.data.size(1), "KV Cache Exceeded"

        self.data = self.data.to(c_kv.dtype)
        self.data[
            :, self.seq_len : self.seq_len + c_kv.size(1), :
        ] = c_kv
        self.seq_len += c_kv.size(1)

        return self.data[:, :self.seq_len], self.seq_len
    
@dataclass
class Config:
    batch_size: int
    dim: int
    n_heads: int
    q_lora_rank: int 
    kv_lora_rank: int
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    v_head_dim: int
    seq_len: int
    max_seq_len: int
    kv_cache_shape: tuple
    Q_proj_down_weight: torch.Tensor
    Q_proj_up_weight: torch.Tensor
    KV_proj_down_weight: torch.Tensor
    KV_proj_up_weight: torch.Tensor
    wo_weight: torch.Tensor

class MLA(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.dim = config.dim
        self.n_heads = config.n_heads
        self.q_lora_rank = config.q_lora_rank
        self.kv_lora_rank = config.kv_lora_rank
        self.nope_head_dim = config.qk_nope_head_dim
        self.rope_head_dim = config.qk_rope_head_dim
        self.v_head_dim = config.v_head_dim
        # Down-projection matrices
        self.Q_proj_down = nn.Linear(self.dim, self.q_lora_rank, dtype=torch.bfloat16, bias=False)
        self.KV_proj_down = nn.Linear(self.dim, self.kv_lora_rank + self.rope_head_dim, dtype=torch.bfloat16, bias=False)

        # Up-projection and rope projection matrices
        self.Q_proj_up = nn.Linear(self.q_lora_rank, (self.nope_head_dim + self.rope_head_dim) * self.n_heads, dtype=torch.bfloat16, bias=False)
        self.KV_proj_up = nn.Linear(self.kv_lora_rank, (self.nope_head_dim + self.v_head_dim) * self.n_heads, dtype=torch.bfloat16, bias=False)

        # RoPE on half embeddings
        self.q_rope = RoPE(self.rope_head_dim)
        self.k_rope = RoPE(self.rope_head_dim)

        # Output projection
        self.wo = nn.Linear(self.v_head_dim * self.n_heads, self.dim, dtype=torch.bfloat16, bias=False)
        self.eps = 1e-6
   
    def forward(self, x: torch.Tensor, kv_cache: KVCache) -> torch.Tensor:
        # seq_len = 1 always here
        batch_size, seq_len, model_dim = x.size()

        ################################################################################
        #                 Step 1: Handle down-projection + KV cache                    #
        ################################################################################
        q_lora = self.Q_proj_down(x)
        kv_lora = self.KV_proj_down(x)
        kv_lora, kv_len = kv_cache(kv_lora)
        query_pos = kv_len - 1

        ################################################################################
        #                  Step 2: Up-project and prepare NoPE + RoPE                  #
        ################################################################################

        # Handle queries Q first
        q_nope_and_rope = self.Q_proj_up(q_lora).view(
            batch_size, seq_len, self.n_heads, self.nope_head_dim + self.rope_head_dim)
        q_nope, q_rope = torch.split(q_nope_and_rope, [self.nope_head_dim, self.rope_head_dim], dim=-1)

        # Handle keys and values K/V. V does not need RoPE
        kv_nope, k_rope = torch.split(kv_lora, [self.kv_lora_rank, self.rope_head_dim], dim=-1)
        kv_nope = self.KV_proj_up(kv_nope).view(
            batch_size, kv_len, self.n_heads, self.nope_head_dim + self.v_head_dim)
        k_nope, v = torch.split(kv_nope, [self.nope_head_dim, self.v_head_dim], dim=-1)

        ################################################################################
        #                    Step 3: Handle RoPE Stream                                #
        ################################################################################

        # Compute RoPE for queries and combine with no-RoPE part
        q_rope = q_rope.reshape(batch_size, self.n_heads, seq_len, self.rope_head_dim)
        q_rope = self.q_rope(q_rope, start_pos=query_pos)
        q_rope = q_rope.reshape(batch_size, seq_len, self.n_heads, self.rope_head_dim)
        q = torch.concat([q_nope, q_rope], dim=-1)

        # Compute RoPE for keys and combine with no-RoPE part
        k_rope = k_rope[:, :, None, :]
        k_rope = self.k_rope(k_rope).expand(-1,-1,self.n_heads,-1)
        k = torch.concat([k_nope, k_rope], dim=-1)
                
        ################################################################################
        #                        Compute Multi-head Attention                          #
        ################################################################################
        q = q.reshape(batch_size, self.n_heads, 1, self.rope_head_dim + self.nope_head_dim)
        k = k.reshape(batch_size, self.n_heads, kv_len, self.rope_head_dim + self.nope_head_dim)
        v = v.reshape(batch_size, self.n_heads, -1, self.v_head_dim)
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.rope_head_dim + self.nope_head_dim)
        attn = F.softmax(scores, dim=-1).to(torch.bfloat16)
        y = torch.matmul(attn, v).view(batch_size, 1, -1)
        y = self.wo(y)

        return y, kv_cache.get_data()
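
# ------------------------------------------------------------------------------
# Illustrative sketch only (not part of the reference; all names below are
# hypothetical). The same decode step can be evaluated entirely in the latent
# space by "absorbing" KV_proj_up into the query and output sides, so each
# cached token is read once at width kv_lora_rank + qk_rope_head_dim instead of
# being expanded to per-head K/V. It assumes q_nope / q_rope are already
# up-projected and RoPE-rotated, and that k_rope holds the rope columns of the
# cache with the same rotation the reference applies.
def absorbed_mla_attention(q_nope, q_rope, c_kv, k_rope, kv_up_weight,
                           n_heads=128, nope_dim=128, v_head_dim=128, lora_rank=512):
    # q_nope: [bs, n_heads, nope_dim]    q_rope: [bs, n_heads, rope_dim]
    # c_kv:   [bs, kv_len, lora_rank]    k_rope: [bs, kv_len, rope_dim]
    # kv_up_weight: [(nope_dim + v_head_dim) * n_heads, lora_rank], i.e. KV_proj_up.weight
    w = kv_up_weight.view(n_heads, nope_dim + v_head_dim, lora_rank)
    w_uk, w_uv = w[:, :nope_dim, :], w[:, nope_dim:, :]
    q_abs = torch.einsum('bhd,hdr->bhr', q_nope, w_uk)        # fold W_UK into the query
    scores = torch.einsum('bhr,btr->bht', q_abs, c_kv)        # NoPE scores on the latents
    scores = scores + torch.einsum('bhd,btd->bht', q_rope, k_rope)
    scores = scores / math.sqrt(nope_dim + q_rope.size(-1))
    attn = scores.softmax(dim=-1).to(c_kv.dtype)
    ctx = torch.einsum('bht,btr->bhr', attn, c_kv)            # attend over the latents
    return torch.einsum('bhr,hvr->bhv', ctx, w_uv)            # expand to v_head_dim per head
# ------------------------------------------------------------------------------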

def generate_input(batchsize, dim, dq, prefill, seed):
    # Sizes derived from: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
    gen = torch.Generator(device='cuda')
    gen.manual_seed(seed)
    
    # Generate weights for linear layers
    Q_proj_down_weight = torch.randn((dq, dim), dtype=torch.bfloat16, generator=gen, device='cuda') / math.sqrt(dim)
    KV_proj_down_weight = torch.randn((512 + 64, dim), dtype=torch.bfloat16, generator=gen, device='cuda') / math.sqrt(dim)
    Q_proj_up_weight = torch.randn(((128 + 64) * 128, dq), dtype=torch.bfloat16, generator=gen, device='cuda') / math.sqrt(dq)
    KV_proj_up_weight = torch.randn(((128 + 128) * 128, 512), dtype=torch.bfloat16, generator=gen, device='cuda') / math.sqrt(512)
    wo_weight = torch.randn((dim, 128 * 128), dtype=torch.bfloat16, generator=gen, device='cuda') / math.sqrt(dim)

    config = Config(
        batch_size=batchsize,
        dim=dim,
        q_lora_rank=dq,
        n_heads=128,
        kv_lora_rank=512,
        qk_nope_head_dim=128,
        qk_rope_head_dim=64,
        v_head_dim=128,
        seq_len=1,
        max_seq_len=8192,
        kv_cache_shape=(batchsize, 8192, 512 + 64),
        Q_proj_down_weight=Q_proj_down_weight,
        Q_proj_up_weight=Q_proj_up_weight,
        KV_proj_down_weight=KV_proj_down_weight,
        KV_proj_up_weight=KV_proj_up_weight,
        wo_weight=wo_weight,
    )
    x = torch.randn((config.batch_size, 1, config.dim), dtype=torch.bfloat16, generator=gen, device='cuda')
    
    # Pre-fill KV cache
    kv_cache = KVCache((config.batch_size, config.max_seq_len, config.kv_lora_rank + config.qk_rope_head_dim)).to('cuda')
    pre_filled_cache = torch.randn((config.batch_size, prefill, config.kv_lora_rank + config.qk_rope_head_dim), 
                                 dtype=torch.bfloat16, generator=gen, device='cuda')
    kv_cache(pre_filled_cache)

    return config, x, kv_cache

def ref_kernel(data: input_t) -> output_t:
    config, x, kv_cache = data

    # Load in model weights
    model = MLA(config).to('cuda')
    model.Q_proj_down.weight = nn.Parameter(config.Q_proj_down_weight)
    model.Q_proj_up.weight = nn.Parameter(config.Q_proj_up_weight)
    model.KV_proj_down.weight = nn.Parameter(config.KV_proj_down_weight)
    model.KV_proj_up.weight = nn.Parameter(config.KV_proj_up_weight)
    model.wo.weight = nn.Parameter(config.wo_weight)

    output, kv_cache = model(x, kv_cache)
    return output, kv_cache

check_implementation = make_match_reference(ref_kernel, rtol=2e-02, atol=1e-03)  
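
# ------------------------------------------------------------------------------
# Hypothetical submission skeleton (the entry-point name and the fallback are
# assumptions, not part of the reference): a submission receives the same
# (config, x, kv_cache) tuple as ref_kernel and must return the attention
# output together with the updated KV cache buffer.
def custom_kernel(data: input_t) -> output_t:
    # Replace this reference fallback with a fused MLA decode implementation.
    return ref_kernel(data)
# ------------------------------------------------------------------------------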


def time_mla(model, x, kv_cache, num_warmup=3, num_trials=5):

    # Warmup runs
    for _ in range(num_warmup):
        output, _ = model(x, kv_cache)
        torch.cuda.synchronize()

    # Timed runs
    times = []
    for _ in range(num_trials):
        kv_cache = KVCache(tuple(kv_cache.get_data().shape)).to('cuda')  # fresh cache per timed run
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        
        start.record()
        output, updated_kv = model(x, kv_cache)
        end.record()
        
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    avg_time = sum(times) / len(times)
    return output, updated_kv, avg_time, times

if __name__ == "__main__":
    # Generate test input
    batchsize = 128
    dim = 7168 
    dq = 1536
    prefill = 512
    seed = 97

    # Create model and inputs
    config, x, kv_cache = generate_input(batchsize, dim, dq, prefill, seed)
    model = MLA(config).to('cuda')

    # Run model with timing
    output, updated_kv, avg_time, times = time_mla(model, x, kv_cache)

    # Test reference kernel
    ref_output, ref_kv = ref_kernel((config, x, kv_cache))
    print("
Reference kernel output:")
    print(f"Output shape: {ref_output.shape}")
    print(f"KV cache shape: {ref_kv.shape}")
    print("
First few values of reference output:")
    print(ref_output[0, :10])

    # Compare outputs
    print("
Output difference:")
    print(f"Max absolute difference: {torch.max(torch.abs(output - ref_output))}")
    print(f"Mean absolute difference: {torch.mean(torch.abs(output - ref_output))}")

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Updated KV cache shape: {updated_kv.shape}")
    print("
First few values of output:")
    print(output[0, :10])
    print(f"
Timing results over {len(times)} runs (ms):")
    print(f"Average: {avg_time:.2f}")
    print(f"Individual times: {[f'{t:.2f}' for t in times]}")

Rankings

MI300

| User | Score | Delta to previous | File |
|---|---|---|---|
| 🥇 kfz | 112904.773μs | | mla_decode_v0.py |
| 🥈 Shinsato Masumi | 113041.088μs | +136.315μs | submission-mla.py |
| 🥉 myy1966 | 113100.929μs | +59.841μs | submission.py |
| rt11 | 1061545.559μs | +948444.629μs | v1.py |
| az | 1061624.320μs | +78.761μs | submission.py |
| intrinsicmode | 1239468.056μs | +177843.736μs | submission.py |
| fanwenjie | 1249841.716μs | +10373.660μs | mla-decode.py |
| gowtham_tupili | 7351013.915μs | +6101172.199μs | 3.py |
| legendary_fawn_56575 | 8605724.184μs | +1254710.269μs | submission.py |
| Dhanshre | 8874763.829μs | +269039.644μs | amd-mla-decode.py |
| Ding | 9299385.466μs | +424621.637μs | submission.py |