amd-fp8-mm

Deadline

8 days 4 hours (2025-05-27 00:00 UTC)

Language

Python

GPU Type

MI300

Description

You will implement a custom fp8-blockwise matmul kernel optimized for MI300. You will be given single-precision scaling factors for your matrices. The shapes of all outer and inner dimensions of the tensors are taken from DeepSeek-R1.

To be explicit, you will be given a tuple of tensors:

```
(a, b, a_scale, b_scale, c)
```

where `a` and `b` are the input matrices, `a_scale` and `b_scale` are the scaling factors for `a` and `b` respectively, and `c` is the output matrix:

* `a` is M x K in column-major order in e4m3fnuz
* `b` is N x K in column-major order in e4m3fnuz
* `a_scale` is M x K // 128 in column-major order in fp32
* `b_scale` is N // 128 x K // 128 in column-major order in fp32
* `c` is M x N in ROW-major order in bf16

Matrix sizes `m` and `n` are divisible by 64, and `k` is divisible by 128.

The ranking criterion is the geometric mean of the benchmark results. For the grand prize, your kernel will be evaluated against the speed-of-light analysis, and the solution closest to the speed of light will be awarded the grand prize.

The speed-of-light analysis is:

```
   M     N     K  time[us]
1024  1536  7168      8.63
1024  4608  7168     25.89
6144  1536  7168     51.78
6144  4608  7168    155.30
1024  7168   256      3.17
6144  7168   256     17.27
```
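To make the scale indexing concrete, here is a minimal sketch of the dequantization these layouts imply (the helper name `dequantize` and the `a_deq`/`b_deq` variables are ours, purely for illustration and not part of the task API):

```python
import torch

def dequantize(a, b, a_scale, b_scale):
    # Each element a[i, l] is scaled by a_scale[i, l // 128] (per row, per 128-wide K block).
    a_deq = a.to(torch.float32) * a_scale.repeat_interleave(128, dim=1)[:, :a.shape[1]]
    # Each element b[j, l] is scaled by b_scale[j // 128, l // 128] (per N block, per K block).
    b_deq = (b.to(torch.float32)
             * b_scale.repeat_interleave(128, dim=0)[:b.shape[0]]
                      .repeat_interleave(128, dim=1)[:, :b.shape[1]])
    # The expected output is then c = (a_deq @ b_deq.T) cast to bf16.
    return a_deq, b_deq
```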

Reference Implementation

import torch
from task import input_t, output_t
from utils import make_match_reference


block_shape = (128, 128)

def generate_input(m: int, n: int, k: int, seed: int) -> input_t:
    """
    Generate random inputs and weights for blockwise W8A8 matmul with FP32 scaling factors.
    
    Returns:
        Tuple of (
            a: torch.Tensor[float8_e4m3fnuz] of shape [m, k], 
            b: torch.Tensor[float8_e4m3fnuz] of shape [n, k], 
            a_scale: torch.Tensor[float32] of shape [m, k // 128], 
            b_scale: torch.Tensor[float32] of shape [n // 128, k // 128], 
            c: torch.Tensor[bfloat16] of shape [m, n]
        )
    """
    gen = torch.Generator(device='cuda')
    gen.manual_seed(seed)
    block_shape_n, block_shape_k = block_shape
    scale_n = (n + block_shape_n - 1) // block_shape_n
    scale_k = (k + block_shape_k - 1) // block_shape_k

    # Generate random inputs with FP8 quantization
    a = (torch.randn((k, m), dtype=torch.bfloat16, device="cuda", generator=gen)).to(torch.float8_e4m3fnuz)
    b = (torch.randn((k, n), dtype=torch.bfloat16, device="cuda", generator=gen)).to(torch.float8_e4m3fnuz)

    # Generate scaling factors with FP32
    a_scale = torch.randn([scale_k, m], dtype=torch.float32, device="cuda", generator=gen)
    b_scale = torch.randn([scale_k, scale_n], dtype=torch.float32, device="cuda", generator=gen)


    c = torch.zeros((m, n), dtype=torch.bfloat16, device="cuda")
    return (a.T, b.T, a_scale.T, b_scale.T, c)


def ref_kernel(data: input_t) -> output_t:
    """
    Highly inefficient torch reference implementation of FP8 GEMM.
    You can use this as a reference / starting template for your implementation.
    """
    # c: [m, n] is pre-allocated memory to help remove allocation overhead.
    a, b, a_scale, b_scale, c = data

    # a is M x K in column-major order; we make it contiguous (row-major) here for simplicity.
    a = a.contiguous()
    a_scale = a_scale.contiguous()
    b_scale = b_scale.contiguous()

    # constants
    m = a.shape[0]
    n = b.shape[0]
    k = a.shape[1]
    block_shape_n = 128
    block_shape_k = 128
    scale_n = b_scale.shape[0]
    scale_k = b_scale.shape[1]

    # Apply blockwise scaling to input 'a'
    a_scale = a_scale.unsqueeze(-1).repeat(1, 1, block_shape_k)  # Shape: [m, scale_k, block_shape_k]
    a_scale = a_scale.reshape(m, scale_k * block_shape_k) 
    a_scale = a_scale[:, :k]

    # Dequantize 'a'; in your own implementation you should do this at the end.
    a = a.to(a_scale.dtype) * a_scale 

    # Apply blockwise scaling to input 'b'
    b_scale = (
        b_scale.view(-1, 1)
        .repeat(1, block_shape_n * block_shape_k)
        .view(scale_n, scale_k, block_shape_n, block_shape_k)
        .permute(0, 2, 1, 3)  # Reorder dimensions: [scale_n, blk_n, scale_k, blk_k]
        .reshape(scale_n * block_shape_n, scale_k * block_shape_k)
    )
    b_scale = b_scale[:n, :k]

    # Dequantize 'b'; in your own implementation you should do this at the end.
    b = b.to(b_scale.dtype) * b_scale 

    # Compute FP8 GEMM and write to 'c'. 
    c[...] = (a @ b.T).to(torch.bfloat16)
    return c


check_implementation = make_match_reference(ref_kernel, rtol=2e-02, atol=1e-03)
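As a starting point beyond the reference, one way to avoid materializing fully dequantized copies of `a` and `b` is to accumulate each 128-wide K block's partial product in FP32 and apply the scales per block. Below is a minimal torch-only sketch of that idea (the function name `blockwise_torch_kernel` is ours; this is still far from the MI300 speed-of-light numbers above):

```python
import torch

def blockwise_torch_kernel(data):
    a, b, a_scale, b_scale, c = data
    m, k = a.shape
    n = b.shape[0]
    acc = torch.zeros((m, n), dtype=torch.float32, device=a.device)
    for kb in range(k // 128):
        # FP32 partial product of one 128-wide K block.
        a_blk = a[:, kb * 128:(kb + 1) * 128].to(torch.float32)   # [m, 128]
        b_blk = b[:, kb * 128:(kb + 1) * 128].to(torch.float32)   # [n, 128]
        partial = a_blk @ b_blk.T                                 # [m, n]
        # Per-row scale of 'a' times the per-128-row block scale of 'b' for this K block.
        scale = a_scale[:, kb].unsqueeze(1) * b_scale[:, kb].repeat_interleave(128)[:n].unsqueeze(0)
        acc += partial * scale
    c[...] = acc.to(torch.bfloat16)
    return c
```

In a competitive submission you would fuse this per-block scaling into a HIP or Triton GEMM kernel rather than looping over K blocks in Python.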

Rankings

MI300

Seb 🥇 121.819μs baguette_3.py
nicolaswilde 🥈 135.109μs   +13.290μs amd-fp8-mm-hip.py
Snektron 🥉 153.495μs   +18.386μs fp8_gemm.py
Cong 154.267μs   +0.772μs hipmode.py
fanwenjie 154.468μs   +0.201μs finish-2.py
Shinsato Masumi 174.240μs   +19.772μs submission.py
hatoo 181.638μs   +7.398μs hip.py
Qwesh157 181.983μs   +0.346μs template-hip.py
mtnielsen 192.304μs   +10.321μs fp8.py
gau.nernst 199.032μs   +6.728μs triton_v5c.py
ColorsWind 204.171μs   +5.139μs submission.py
colorswind 208.686μs   +4.515μs submission.py
guizili0 214.998μs   +6.312μs fp8t_v4.py
rt11 218.335μs   +3.337μs v8.py
intrinsicmode 222.671μs   +4.337μs triton_v06.py
Karan Jakhar 226.358μs   +3.687μs submission_1.py
ryshaw 237.268μs   +10.910μs learning_triton.py
kfz 243.681μs   +6.413μs v1.py
DStoPyA 251.168μs   +7.487μs submission.py
bobmarleybiceps 254.113μs   +2.944μs submission-triton.py
Arseni Ivanov 254.979μs   +0.866μs submission_transposed.py
Azmuth 256.945μs   +1.966μs amd-fp8-mm.py
myy1966 258.793μs   +1.848μs submission_fp8-mm_version_7.py
Tecahens 260.649μs   +1.856μs fp8-v4.py
Xavier Init 267.681μs   +7.032μs triton.py
LuiZzz 275.053μs   +7.372μs deepgemm.py
hyochan 278.066μs   +3.012μs go_gpu_mode.py
Austin Liu 279.345μs   +1.280μs submission.py
ALI 280.572μs   +1.227μs submission.py
t_cc 284.212μs   +3.641μs fp8_mm.py
shikhar 288.569μs   +4.356μs lossfunk_v5.py
pengcuo 297.899μs   +9.331μs submission.py
luojiehao. 300.386μs   +2.486μs triton_v04_32x32.py
cudawarped 308.319μs   +7.933μs submission_amd_fp8_mm_naive.py
Matthias 314.755μs   +6.436μs test.py
nyl199310 323.129μs   +8.374μs submission.py
mreso 326.574μs   +3.445μs test.py
_kernelfolw_ 328.918μs   +2.343μs submission.py
whatdhack_ 329.891μs   +0.973μs 03-matrix-multiplication.py
summergift0941 337.925μs   +8.034μs block_scale_fp8.py
rosehulman. 342.447μs   +4.522μs kernel.py
arseni_ivanov 344.121μs   +1.673μs submission.py
legendary_fawn_56575 347.820μs   +3.699μs template.py
Ding 348.328μs   +0.508μs template.py
yiakwy-xpu-ml-framework-team 364.170μs   +15.842μs flashFloat_fp8_w8a8_triton_ref_submission.py
Shivam 366.055μs   +1.885μs submission.py
Shlok 386.380μs   +20.324μs TKGEMM_355.py
Quantizr 422.452μs   +36.072μs triton.py
DizzleRama 589.097μs   +166.646μs gpu_mode_fp88_8_for_testing.py
Bexboy 656.024μs   +66.927μs lossfunk_residual_v2.py
nick_lc700x 658.464μs   +2.439μs sub12.py
Henry 713.800μs   +55.336μs amd_fp8_mm.py
hezhexi2002 760.774μs   +46.974μs custom-triton.py
agokrani 763.792μs   +3.018μs lossfunk_v2.py
Itssshikhar 777.123μs   +13.331μs lossfunk_v3.py
yia_perf 783.819μs   +6.696μs my_custom_kernel.py
osborn0016 793.867μs   +10.048μs submission.py
timppa 797.106μs   +3.239μs submission.py
_hui_xu 798.102μs   +0.995μs amd-fp8-mm.py
wildman 819.365μs   +21.263μs o4-amd-fp8-mm.py
sirmisscriesalot 824.405μs   +5.040μs submission2.py
siclait 824.919μs   +0.514μs submission3.py
mdda123 849.532μs   +24.613μs submission.py
truco5798 849.944μs   +0.412μs demo.py
niki10 851.498μs   +1.554μs ref.py
sridharnandigam 852.235μs   +0.738μs ref.py
teddy_chou 852.455μs   +0.220μs fp8_ref.py
demosright 852.897μs   +0.442μs submission.py
BirdBoss 853.392μs   +0.495μs submission.py
D++ 854.249μs   +0.856μs submission.py
AkatsukiChiri 856.079μs   +1.830μs submission.py
fxfxfxfxfxfxfxfx 858.094μs   +2.016μs submission.py
SummerGift 859.649μs   +1.555μs submission.py
ShiyuWang 861.854μs   +2.205μs submission.py
blurbird 861.884μs   +0.030μs amd-fp8-mm.py
stepinto. 861.930μs   +0.046μs submission.py
shengwenliang. 862.303μs   +0.372μs amd.py
GnSight 864.084μs   +1.781μs submission.py
amd_bob 864.138μs   +0.054μs ref.py
lollipopkit 866.930μs   +2.792μs submission.py
getapiurl_29979_86903 868.026μs   +1.096μs submission.py
zhubenzhu 869.584μs   +1.558μs submission_2.py
Raghu 870.521μs   +0.937μs submission.py
Mr.Wang 871.191μs   +0.670μs triton_kernel.py
legendary_pony_17025 872.026μs   +0.836μs submission.py
lucifer050296 872.848μs   +0.822μs submission.py
siro 873.042μs   +0.194μs submission.py
stylish_kiwi_61275 873.084μs   +0.042μs submission.py
solahome 873.661μs   +0.577μs submission.py
Hoyoun Jung 876.246μs   +2.585μs submission.py
fyr233 877.574μs   +1.328μs submission-default.py
chess 877.719μs   +0.145μs fp8mm.py
Rik - OCI 878.902μs   +1.183μs ref.py
Cunxiao 880.765μs   +1.863μs submission.py
viranchee 881.558μs   +0.793μs fp8_ref_kernel.py
Shravan 881.965μs   +0.407μs submission.py
Turtle 882.397μs   +0.432μs submission.py
spacekkkk_64980 882.401μs   +0.004μs submission.py
cymtrick 882.951μs   +0.551μs check.py
Phil Butler 883.022μs   +0.071μs no-change.py
parrotsky 884.484μs   +1.461μs sub1.py
tomaszki 884.679μs   +0.195μs fp8.py
tendazeal 886.044μs   +1.365μs submission.py
killTheHostage 887.927μs   +1.883μs submission.py
Erik S. 890.743μs   +2.816μs BASELINE.py
beetle0315 904.725μs   +13.982μs submission.py
Sagar 1176.737μs   +272.012μs amd_fp8_k6.py
puzzledpikachu 1270.936μs   +94.199μs submission_pytorch_4.py
Ling Qin 1720.549μs   +449.613μs fp8_mm.py
Taylor 1760.344μs   +39.795μs submission.py
__seal 1776.785μs   +16.441μs submission-hip.py
gowtham_tupili 2372.085μs   +595.300μs amd-fp8-mm_2.py
Lucky@CHR 2423.662μs   +51.577μs submission.py
Rocky Singh 2432.328μs   +8.666μs amd-fp8-mm_2.py
stmnk 2546.855μs   +114.527μs submission.py
truk@PLT 2643.077μs   +96.222μs submission.py
Merle 3430.114μs   +787.037μs template-hip.py
dkennetz 4615.243μs   +1185.129μs hip_submission.py
harrison03042014 4737.853μs   +122.610μs HipGemmFp8.py
sam 5130.954μs   +393.102μs ref_hip.py
salad 5157.419μs   +26.464μs ref_v2.py
anirudh9616 5167.363μs   +9.944μs amd-gemm-submission-default.py
mqq 5167.421μs   +0.059μs amd-fp8-mm.py
iron_bound 5175.590μs   +8.168μs template-amd-v2.py
LeviAckerman 5191.602μs   +16.012μs amd-fp8-mm.py
sophismparadox 5196.328μs   +4.726μs test.py
ph 5198.533μs   +2.206μs hip.py
liujiali_34506 5203.008μs   +4.475μs template-hip.py
Yash! 5204.934μs   +1.926μs original.py
gabteni 5230.547μs   +25.613μs template-hip.py
cloudysky123_18954 5234.445μs   +3.897μs template-hip.py
opdroid1234 5242.287μs   +7.843μs reference.py
eclouder 5245.847μs   +3.560μs template-hip.py
blackcola 5247.494μs   +1.647μs test.py
Haw 5248.982μs   +1.488μs template-hip.py
shauheen 5254.196μs   +5.213μs amd-fp8-mm-2.py
gauravgokhale 5254.517μs   +0.321μs test.py
Howard 5280.150μs   +25.633μs amd-fp8-mm.py
jpy794 5283.286μs   +3.136μs template-hip.py
demon_36401 5288.557μs   +5.271μs test.py
gxtzhuxi 5312.887μs   +24.330μs template-hip.py
Dhanshre 5317.595μs   +4.708μs amd-fp8-mm.py
samkg 5323.384μs   +5.789μs template-hip.py
langzs335_31673 5327.104μs   +3.720μs template-hip.py
sYnfo 5331.620μs   +4.516μs amd-fp8-mm.py
_wizard_5 5395.874μs   +64.254μs template-hip.py