amd-mla-decode

Deadline

74 days 14 hours remaining (2025-09-02 00:00 UTC)

Language

Python

GPU Type

MI300

Description

You will implement a custom MLA decode kernel optimized for MI300. A few things are simplified here:

1. Q, K, V use the bfloat16 data type.
2. Decode only, with a pre-allocated, non-paged latent KV cache.
3. Return the updated KV cache along with the MLA output.

The shapes of all outer and inner dimensions of the tensors come from DeepSeek-R1, with the number of heads split so that the problem fits on one GPU. To be explicit, you will be given a tuple of tensors:

```yml
input       [bs, sq, dim]
attn_output [bs, n_heads, sq, v_head_dim]
kv_cache    [bs, sq, kv_lora_rank + qk_rope_head_dim]
```

where

0. bs::128 # batch size
1. prefill::[512, 2048, 4096, 6144] # as kv length
2. sq::1 # as only consider decoding
3. dim::7168 # hidden size of deepseek v3
4. kv_lora_rank::[512] # kv lora rank of deepseek v3
5. qk_rope_head_dim::[64] # rope embedding dimension
6. v_head_dim::128 # head size
7. n_heads::128 # num of attn heads

The ranking criterion is the geometric mean of the benchmark results. For the grand prize, your kernel will be evaluated against the speed-of-light analysis, and the solution closest to the speed of light will be awarded the grand prize. The speed-of-light analysis is:

| bs | prefill | sq | dtype | roofline time (µs) |
|---|---|---|---|---|
| 128 | 512 | 1 | bf16 | 54.62 |
| 128 | 2048 | 1 | bf16 | 141.16 |
| 128 | 4096 | 1 | bf16 | 210.75 |
| 128 | 6144 | 1 | bf16 | 280.87 |
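
For orientation, a submission consumes the same `(config, x, kv_cache)` tuple that `ref_kernel` below unpacks, and must return the attention output together with the updated cache. Here is a minimal interface sketch, assuming the harness invokes an entry point with the reference's `input_t -> output_t` signature; the name `custom_kernel` and the placeholder body are illustrative assumptions, not something this page specifies:

```python
import torch
from task import input_t, output_t  # provided by the competition harness

def custom_kernel(data: input_t) -> output_t:  # entry-point name is an assumption
    config, x, kv_cache = data  # x: [bs, 1, dim] bf16; kv_cache: pre-filled KVCache module
    # ... fused MLA decode goes here: append the new latent KV row to the cache,
    # attend over the prefilled positions plus the new token, then out-project ...
    output = torch.empty((config.batch_size, 1, config.dim),
                         dtype=torch.bfloat16, device=x.device)
    return output, kv_cache.get_data()  # mirrors ref_kernel: output plus updated cache
```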

Reference Implementation

import math
from dataclasses import dataclass
import torch
from torch import nn
import torch.nn.functional as F
from task import input_t, output_t
from utils import make_match_reference

class RoPE(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        theta = 10000 ** (-torch.arange(0, d_model//2,dtype=torch.bfloat16) / (d_model//2))
        self.register_buffer("theta", theta)

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
        seq_len = x.size(-2)
        d_model = x.size(-1)
        assert d_model == self.d_model
        seq_idx = torch.arange(start_pos, start_pos + seq_len, device=x.device)
        idx_theta = torch.einsum('s,d->sd', seq_idx, self.theta)
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=-1)
        cos = idx_theta2.cos().to(torch.bfloat16)
        sin = idx_theta2.sin().to(torch.bfloat16)
        return x * cos + self.rotate_half(x) * sin

class KVCache(nn.Module):
    def __init__(self, kv_cache_shape: tuple, **kwargs) -> None:
        super().__init__(**kwargs)
        self.register_buffer('data', torch.zeros(kv_cache_shape, dtype=torch.bfloat16))
        self.seq_len = 0
        self.zero()

    def zero(self) -> None:
        self.data.zero_()
    
    def get_data(self) -> torch.Tensor:
        return self.data

    def forward(self, c_kv: torch.Tensor) -> torch.Tensor:
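        # Append c_kv along the sequence dimension starting at the current write
        # position, then return the populated prefix of the cache and its length.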
        assert self.seq_len + c_kv.size(1) <= self.data.size(1), "KV Cache Exceeded"

        self.data = self.data.to(c_kv.dtype)
        self.data[
            :, self.seq_len : self.seq_len + c_kv.size(1), :
        ] = c_kv
        self.seq_len += c_kv.size(1)

        return self.data[:, :self.seq_len], self.seq_len
    
@dataclass
class Config:
    batch_size: int
    dim: int
    n_heads: int
    q_lora_rank: int 
    kv_lora_rank: int
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    v_head_dim: int
    seq_len: int
    max_seq_len: int
    kv_cache_shape: tuple
    Q_proj_down_weight: torch.Tensor
    Q_proj_up_weight: torch.Tensor
    KV_proj_down_weight: torch.Tensor
    KV_proj_up_weight: torch.Tensor
    wo_weight: torch.Tensor

class MLA(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.dim = config.dim
        self.n_heads = config.n_heads
        self.q_lora_rank = config.q_lora_rank
        self.kv_lora_rank = config.kv_lora_rank
        self.nope_head_dim = config.qk_nope_head_dim
        self.rope_head_dim = config.qk_rope_head_dim
        self.v_head_dim = config.v_head_dim
        # Down-projection matrices
        self.Q_proj_down = nn.Linear(self.dim, self.q_lora_rank, dtype=torch.bfloat16, bias=False)
        self.KV_proj_down = nn.Linear(self.dim, self.kv_lora_rank + self.rope_head_dim, dtype=torch.bfloat16, bias=False)

        # Up-projection and rope projection matrices
        self.Q_proj_up = nn.Linear(self.q_lora_rank, (self.nope_head_dim + self.rope_head_dim) * self.n_heads, dtype=torch.bfloat16, bias=False)
        self.KV_proj_up = nn.Linear(self.kv_lora_rank, (self.nope_head_dim + self.v_head_dim) * self.n_heads, dtype=torch.bfloat16, bias=False)

        # RoPE on half embeddings
        self.q_rope = RoPE(self.rope_head_dim)
        self.k_rope = RoPE(self.rope_head_dim)

        # Output projection
        self.wo = nn.Linear(self.v_head_dim * self.n_heads, self.dim, dtype=torch.bfloat16, bias=False)
        self.eps = 1e-6
   
    def forward(self, x: torch.Tensor, kv_cache: KVCache) -> torch.Tensor:
        # seq_len = 1 always here
        batch_size, seq_len, model_dim = x.size()

        ################################################################################
        #                 Step 1: Handle down-projection + KV cache                    #
        ################################################################################
        q_lora = self.Q_proj_down(x)
        kv_lora = self.KV_proj_down(x)
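        # Appending this step's latent KV row returns the full cached history:
        # kv_lora is now [bs, kv_len, kv_lora_rank + rope_head_dim].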
        kv_lora, kv_len = kv_cache(kv_lora)
        query_pos = kv_len - 1

        ################################################################################
        #                  Step 2: Up-project and prepare NoPE + RoPE                  #
        ################################################################################

        # Handle queries Q first
        q_nope_and_rope = self.Q_proj_up(q_lora).view(
            batch_size, seq_len, self.n_heads, self.nope_head_dim + self.rope_head_dim)
        q_nope, q_rope = torch.split(q_nope_and_rope, [self.nope_head_dim, self.rope_head_dim], dim=-1)

        # Handle keys and values K/V. V does not need RoPE
        kv_nope, k_rope = torch.split(kv_lora, [self.kv_lora_rank, self.rope_head_dim], dim=-1)
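        # The reference up-projects every cached position on each decode step,
        # so this matmul's cost grows linearly with kv_len.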
        kv_nope = self.KV_proj_up(kv_nope).view(
            batch_size, kv_len, self.n_heads, self.nope_head_dim + self.v_head_dim)
        k_nope, v = torch.split(kv_nope, [self.nope_head_dim, self.v_head_dim], dim=-1)

        ################################################################################
        #                    Step 3: Handle RoPE Stream                                #
        ################################################################################

        # Compute RoPE for queries and combine with no-RoPE part
        q_rope = q_rope.permute(0, 2, 1, 3) # bs x n_heads x seq_len x rope_head_dim
        q_rope = self.q_rope(q_rope, start_pos=query_pos)

        q_nope = q_nope.permute(0, 2, 1, 3) # bs x n_heads x seq_len x nope_head_dim
        q = torch.concat([q_nope, q_rope], dim=-1)


        # Compute RoPE for keys and combine with no-RoPE part
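        # The RoPE key slice is a single shared head; expand() broadcasts it
        # across all n_heads without copying memory.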
        k_rope = k_rope[:, None, :, :]
        k_rope = self.k_rope(k_rope).expand(-1,self.n_heads,-1,-1)
        k_nope = k_nope.permute(0, 2, 1, 3) # bs x n_heads x kv_len x nope_head_dim
        k = torch.concat([k_nope, k_rope], dim=-1)
                
        ################################################################################
        #                        Compute Multi-head Attention                          #
        ################################################################################
        v = v.permute(0, 2, 1, 3) # bs x n_heads x kv_len x v_head_dim
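        # q: [bs, n_heads, 1, nope+rope] @ k^T: [bs, n_heads, nope+rope, kv_len]
        # -> scores: [bs, n_heads, 1, kv_len]; softmax runs over the kv_len axis.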
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.rope_head_dim + self.nope_head_dim)
        attn = F.softmax(scores, dim=-1).to(torch.bfloat16)
        y = torch.matmul(attn, v).view(batch_size, 1, -1)
        y = self.wo(y)

        return y, kv_cache.get_data()

def generate_input(batchsize, dim, dq, prefill, seed):
    # Sizes derived from: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
    gen = torch.Generator(device='cuda')
    gen.manual_seed(seed)
    
    # Generate weights for linear layers
    Q_proj_down_weight = torch.randn((dq, dim), dtype=torch.bfloat16, generator=gen, device='cuda') / math.sqrt(dim)
    KV_proj_down_weight = torch.randn((512 + 64, dim), dtype=torch.bfloat16, generator=gen, device='cuda') / math.sqrt(dim)
    Q_proj_up_weight = torch.randn(((128 + 64) * 128, dq), dtype=torch.bfloat16, generator=gen, device='cuda') / math.sqrt(dq)
    KV_proj_up_weight = torch.randn(((128 + 128) * 128, 512), dtype=torch.bfloat16, generator=gen, device='cuda') / math.sqrt(512)
    wo_weight = torch.randn((dim, 128 * 128), dtype=torch.bfloat16, generator=gen, device='cuda') / math.sqrt(128 * 128)

    config = Config(
        batch_size=batchsize,
        dim=dim,
        q_lora_rank=dq,
        n_heads=128,
        kv_lora_rank=512,
        qk_nope_head_dim=128,
        qk_rope_head_dim=64,
        v_head_dim=128,
        seq_len=1,
        max_seq_len=8192,
        kv_cache_shape=(batchsize, 8192, 512 + 64),
        Q_proj_down_weight=Q_proj_down_weight,
        Q_proj_up_weight=Q_proj_up_weight,
        KV_proj_down_weight=KV_proj_down_weight,
        KV_proj_up_weight=KV_proj_up_weight,
        wo_weight=wo_weight,
    )
    x = torch.randn((config.batch_size, 1, config.dim), dtype=torch.bfloat16, generator=gen, device='cuda')
    
    # Pre-fill KV cache
    kv_cache = KVCache((config.batch_size, config.max_seq_len, config.kv_lora_rank + config.qk_rope_head_dim)).to('cuda')
    pre_filled_cache = torch.randn((config.batch_size, prefill, config.kv_lora_rank + config.qk_rope_head_dim), 
                                 dtype=torch.bfloat16, generator=gen, device='cuda')
    kv_cache(pre_filled_cache)

    return config, x, kv_cache

def ref_kernel(data: input_t) -> output_t:
    config, x, kv_cache = data

    # Load in model weights
    model = MLA(config).to('cuda')
    model.Q_proj_down.weight = nn.Parameter(config.Q_proj_down_weight)
    model.Q_proj_up.weight = nn.Parameter(config.Q_proj_up_weight)
    model.KV_proj_down.weight = nn.Parameter(config.KV_proj_down_weight)
    model.KV_proj_up.weight = nn.Parameter(config.KV_proj_up_weight)
    model.wo.weight = nn.Parameter(config.wo_weight)

    output, kv_cache = model(x, kv_cache)
    return output, kv_cache

check_implementation = make_match_reference(ref_kernel, rtol=2e-02, atol=8e-03)  


def time_mla(model, x, kv_cache, config, num_warmup=3, num_trials=5):

    # Warmup runs
    for _ in range(num_warmup):
        output, _ = model(x, kv_cache)
        torch.cuda.synchronize()

    # Timed runs
    times = []
    for _ in range(num_trials):
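        # Use a fresh cache for every trial so seq_len does not keep growing
        # across timed runs (each forward call appends one position).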
        kv_cache = KVCache((config.batch_size, config.max_seq_len, config.kv_lora_rank + config.qk_rope_head_dim)).to('cuda')
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        
        start.record()
        output, updated_kv = model(x, kv_cache)
        end.record()
        
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    avg_time = sum(times) / len(times)
    return output, updated_kv, avg_time, times

if __name__ == "__main__":
    # Generate test input
    batchsize = 128
    dim = 7168 
    dq = 1536
    prefill = 512
    seed = 97

    # Create model and inputs
    config, x, kv_cache = generate_input(batchsize, dim, dq, prefill, seed)
    model = MLA(config).to('cuda')

    # Run model with timing
    output, updated_kv, avg_time, times = time_mla(model, x, kv_cache, config)

    # Test reference kernel
    ref_output, ref_kv = ref_kernel((config, x, kv_cache))
    print("
Reference kernel output:")
    print(f"Output shape: {ref_output.shape}")
    print(f"KV cache shape: {ref_kv.shape}")
    print("
First few values of reference output:")
    print(ref_output[0, :10])

    # Compare outputs
    print("
Output difference:")
    print(f"Max absolute difference: {torch.max(torch.abs(output - ref_output))}")
    print(f"Mean absolute difference: {torch.mean(torch.abs(output - ref_output))}")

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Updated KV cache shape: {updated_kv.shape}")
    print("
First few values of output:")
    print(output[0, :10])
    print(f"
Timing results over {len(times)} runs (ms):")
    print(f"Average: {avg_time:.2f}")
    print(f"Individual times: {[f'{t:.2f}' for t in times]}")

Rankings

MI300

nicholaswilde_08140 🥇 1787.457μs amd-mla-submission.py
nicolaswilde 🥈 1797.110μs   +9.654μs amd-mla-submission.py
Seb 🥉 1823.031μs   +25.921μs baguette_mla_6.py
ColorsWind 1838.267μs   +15.236μs submission.py
kfz 2140.750μs   +302.482μs mla_decode_v1.py
myy1966 2616.056μs   +475.306μs submission_mla_version_4.py
fanwenjie 3577.430μs   +961.374μs mla.py
pengcuo 6693.859μs   +3116.429μs submission.py
Cunxiao 9762.319μs   +3068.460μs mla_opt_sub.py
alex__ 9997.176μs   +234.857μs mla_opt_submission.py
tangzhengju_49570 10372.336μs   +375.160μs mla_opt_submission.py
jpy794 21076.284μs   +10703.948μs submission.py
Shinsato Masumi 25349.400μs   +4273.116μs submission.py
rt11 117673.407μs   +92324.007μs v0.py
hatoo 131384.772μs   +13711.365μs submission.py
Rinkia_Ke_Papa 135315.712μs   +3930.940μs mla3.py
az 1061624.320μs   +926308.608μs submission.py
tendazeal 1062621.752μs   +997.432μs amd-mla-decode.py
Tecahens 1063132.139μs   +510.387μs mle.py
Dhanshre 1068451.212μs   +5319.073μs amd-mla-decode.py
Shravan 1072388.420μs   +3937.208μs mla_31_05_2025.py
bobmarleybiceps 1073426.880μs   +1038.460μs submission.py
gftytkklt 1077296.751μs   +3869.871μs submission.py
summergift0941 1078153.400μs   +856.649μs submission.py
DStoPyA 1083647.347μs   +5493.947μs amd-mla-decode(1).py
Ling Qin 1144455.379μs   +60808.032μs submission.py
intrinsicmode 1239468.056μs   +95012.677μs submission.py
_kernelfolw_ 1253612.980μs   +14144.924μs submission.py
Arseni Ivanov 1253844.111μs   +231.131μs reference_orig.py
Xerxes 1312405.126μs   +58561.015μs mla_v2.py
gowtham_tupili 7351013.915μs   +6038608.790μs 3.py
legendary_fawn_56575 8605724.184μs   +1254710.269μs submission.py
Ding 9299385.466μs   +693661.282μs submission.py