amd-fp8-mm

Deadline

65 days 18 hours remaining (2025-09-02 00:00 UTC)

Language

Python

GPU Type

MI300

Description

You will implement a custom fp8-blockwise matmul kernel optimized for MI300. You will be given single-precision scaling factors for your matrices. The shapes of all outer and inner tensor dimensions are taken from DeepSeek-R1. To be explicit, you will be given a tuple of tensors:

```
(a, b, a_scale, b_scale, c)
```

where `a` and `b` are the input matrices, `a_scale` and `b_scale` are the scaling factors for `a` and `b` respectively, and `c` is the output matrix:

* `a` is M x K in column-major order in e4m3fnuz
* `b` is N x K in column-major order in e4m3fnuz
* `a_scale` is M x K // 128 in column-major order in fp32
* `b_scale` is N // 128 x K // 128 in column-major order in fp32
* `c` is M x N in ROW-major order in bf16

Matrix sizes `m` and `n` are divisible by 64, and `k` is divisible by 128.

The ranking criterion is the geometric mean of the benchmark results. For the grand prize, your kernel will be evaluated against the speed-of-light analysis, and the solution closest to the speed of light will be awarded the grand prize.

The speed-of-light analysis is:

```
   M     N     K   time[us]
1024  1536  7168       8.63
1024  4608  7168      25.89
6144  1536  7168      51.78
6144  4608  7168     155.30
1024  7168   256       3.17
6144  7168   256      17.27
```
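
For orientation, a minimal submission skeleton might look like the sketch below. The `custom_kernel` entry-point name is an assumption taken from the competition template; the body only unpacks the input tuple and documents the layouts described above.

```python
import torch
from task import input_t, output_t


def custom_kernel(data: input_t) -> output_t:
    # Hypothetical entry point; the exact name comes from the submission template.
    a, b, a_scale, b_scale, c = data

    m, k = a.shape   # a: [m, k] e4m3fnuz, column-major (a.stride() == (1, m))
    n = b.shape[0]   # b: [n, k] e4m3fnuz, column-major (b.stride() == (1, n))
    # a_scale: [m, k // 128] fp32, one scale per (row, K-block)
    # b_scale: [n // 128, k // 128] fp32, one scale per (N-block, K-block)
    # c:       [m, n] bf16, row-major, pre-allocated output

    # ... launch a HIP/Triton kernel here and write the bf16 result into c ...
    return c
```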

Reference Implementation

import torch
from task import input_t, output_t
from utils import make_match_reference


block_shape = (128, 128)

def generate_input(m: int, n: int, k: int, seed: int) -> input_t:
    """
    Generate random inputs and weights for a blockwise W8A8 matmul with FP32 scaling factors.
    
    Returns:
        Tuple of (
            a: torch.Tensor[float8_e4m3fnuz] of shape [m, k], 
            b: torch.Tensor[float8_e4m3fnuz] of shape [n, k], 
            a_scale: torch.Tensor[float32] of shape [m, k // 128], 
            b_scale: torch.Tensor[float32] of shape [n // 128, k // 128], 
            c: torch.Tensor[bfloat16] of shape [m, n]
        )
    """
    gen = torch.Generator(device='cuda')
    gen.manual_seed(seed)
    block_shape_n, block_shape_k = block_shape
    scale_n = (n + block_shape_n - 1) // block_shape_n
    scale_k = (k + block_shape_k - 1) // block_shape_k

    # Generate random inputs with FP8 quantization
    a = (torch.randn((k, m), dtype=torch.bfloat16, device="cuda", generator=gen)).to(torch.float8_e4m3fnuz)
    b = (torch.randn((k, n), dtype=torch.bfloat16, device="cuda", generator=gen)).to(torch.float8_e4m3fnuz)

    # Generate FP32 scaling factors
    a_scale = torch.randn([scale_k, m], dtype=torch.float32, device="cuda", generator=gen)
    b_scale = torch.randn([scale_k, scale_n], dtype=torch.float32, device="cuda", generator=gen)

    c = torch.zeros((m, n), dtype=torch.bfloat16, device="cuda")
    return (a.T, b.T, a_scale.T, b_scale.T, c)
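

# Illustrative sanity check (not part of the reference): the .T views above hand
# the kernel column-major inputs, which you can confirm from the tensor strides.
def _inspect_layouts() -> None:
    a, b, a_scale, b_scale, c = generate_input(m=1024, n=1536, k=7168, seed=0)
    print(a.shape, a.stride())   # torch.Size([1024, 7168]) (1, 1024) -> column-major
    print(b.shape, b.stride())   # torch.Size([1536, 7168]) (1, 1536) -> column-major
    print(c.shape, c.stride())   # torch.Size([1024, 1536]) (1536, 1) -> row-major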


def ref_kernel(data: input_t) -> output_t:
    """
    Highly inefficient torch reference implementation of FP8 GEMM.
    You can use this as a reference / starting template for your implementation.
    """
    # c: [m, n] is pre-allocated memory to help remove allocation overhead.
    a, b, a_scale, b_scale, c = data

    # a is M x K in column-major order; we convert it to row-major here for simplicity.
    a = a.contiguous()
    a_scale = a_scale.contiguous()
    b_scale = b_scale.contiguous()

    # constants
    m = a.shape[0]
    n = b.shape[0]
    k = a.shape[1]
    block_shape_n = 128
    block_shape_k = 128
    scale_n = b_scale.shape[0]
    scale_k = b_scale.shape[1]

    # Apply blockwise scaling to input 'a'
    a_scale = a_scale.unsqueeze(-1).repeat(1, 1, block_shape_k)  # Shape: [m, scale_k, block_shape_k]
    a_scale = a_scale.reshape(m, scale_k * block_shape_k) 
    a_scale = a_scale[:, :k]

    # Dequantize 'a'; in your own implementation you should do this at the end.
    a = a.to(a_scale.dtype) * a_scale 

    # Apply blockwise scaling to input 'b'
    b_scale = (
        b_scale.view(-1, 1)
        .repeat(1, block_shape_n * block_shape_k)
        .view(scale_n, scale_k, block_shape_n, block_shape_k)
        .permute(0, 2, 1, 3)  # Reorder dimensions: [scale_n, blk_n, scale_k, blk_k]
        .reshape(scale_n * block_shape_n, scale_k * block_shape_k)
    )
    b_scale = b_scale[:n, :k]

    # Dequantize 'b'; in your own implementation you should do this at the end.
    b = b.to(b_scale.dtype) * b_scale 

    # Compute FP8 GEMM and write to 'c'. 
    c[...] = (a @ b.T).to(torch.bfloat16)
    return c


check_implementation = make_match_reference(ref_kernel, rtol=2e-02, atol=1e-03)
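
The reference dequantizes both matrices up front, which is the opposite of what a fast kernel should do. The sketch below is an illustrative (not authoritative) reformulation of the same math: because each scale is constant within a 128-wide K-block, the scales can be applied once per block to the fp32 partial products and accumulated, matching the reference result up to floating-point rounding.

```python
import torch
from task import input_t, output_t


def blockwise_sketch(data: input_t) -> output_t:
    # Illustrative only: apply the scales once per 128-wide K-block and
    # accumulate in fp32, mirroring how a real kernel defers scaling to the
    # end of each block instead of dequantizing the full matrices.
    a, b, a_scale, b_scale, c = data
    m, k = a.shape
    n = b.shape[0]
    block_k = 128

    acc = torch.zeros((m, n), dtype=torch.float32, device=a.device)
    for j in range(k // block_k):
        a_blk = a[:, j * block_k:(j + 1) * block_k].to(torch.float32)  # [m, 128]
        b_blk = b[:, j * block_k:(j + 1) * block_k].to(torch.float32)  # [n, 128]
        # One scale per row of 'a' and one per 128-row block of 'b' for this K-block.
        scale = a_scale[:, j].unsqueeze(1) * b_scale[:, j].repeat_interleave(128)[:n].unsqueeze(0)
        acc += (a_blk @ b_blk.T) * scale

    c[...] = acc.to(torch.bfloat16)
    return c
```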

Rankings

MI300

ColorsWind 🥇 103.392μs submission.py
Seb 🥈 114.878μs   +11.487μs baguette_gemm_6.py
Snektron 🥉 115.734μs   +0.855μs solution.py
Shinsato Masumi 119.649μs   +3.916μs submission.py
nicholaswilde_08140 130.542μs   +10.893μs amd-fp8-mm-submission.py
nicolaswilde 130.829μs   +0.287μs amd-fp8-mm-submission.py
Cong 147.475μs   +16.646μs hipmode.py
fanwenjie 154.468μs   +6.993μs finish-2.py
Emilio 160.907μs   +6.439μs submission.py
pengcuo 167.990μs   +7.082μs submission.py
hatoo 181.638μs   +13.648μs hip.py
Qwesh157 181.983μs   +0.346μs template-hip.py
mtnielsen 192.304μs   +10.321μs fp8.py
myy1966 194.836μs   +2.532μs submission_fp8-mm_version_32.py
gau.nernst 199.032μs   +4.196μs triton_v5c.py
colorswind 208.686μs   +9.654μs submission.py
guizili0 214.998μs   +6.312μs fp8t_v4.py
rt11 218.335μs   +3.337μs v8.py
intrinsicmode 222.671μs   +4.337μs triton_v06.py
Karan Jakhar 226.358μs   +3.687μs submission_1.py
ryshaw 237.268μs   +10.910μs learning_triton.py
kfz 243.681μs   +6.413μs v1.py
DStoPyA 251.168μs   +7.487μs submission.py
bobmarleybiceps 254.113μs   +2.944μs submission-triton.py
Arseni Ivanov 254.979μs   +0.866μs submission_transposed.py
Azmuth 256.945μs   +1.966μs amd-fp8-mm.py
Tecahens 260.649μs   +3.704μs fp8-v4.py
Xavier Init 267.681μs   +7.032μs triton.py
LuiZzz 275.053μs   +7.372μs deepgemm.py
hyochan 278.066μs   +3.012μs go_gpu_mode.py
Austin Liu 279.345μs   +1.280μs submission.py
ALI 280.572μs   +1.227μs submission.py
t_cc 284.212μs   +3.641μs fp8_mm.py
shikhar 288.569μs   +4.356μs lossfunk_v5.py
JackAtlas101 293.381μs   +4.813μs triton_submission.py
luojiehao. 300.386μs   +7.004μs triton_v04_32x32.py
yiakwy-xpu-ml-framework-team 307.264μs   +6.879μs flashFloat_fp8_w8a8_ckasm_ref_submission.py
cudawarped 308.319μs   +1.054μs submission_amd_fp8_mm_naive.py
jefflyu_47387 310.945μs   +2.626μs submission.py
Matthias 314.755μs   +3.810μs test.py
Cunxiao 314.766μs   +0.011μs submission_fp8.py
alex__ 316.661μs   +1.896μs submission.py
gftytkklt 318.398μs   +1.736μs submission_triton1.py
youchunbo_62981_72548 322.134μs   +3.737μs submission_2.py
nyl199310 323.129μs   +0.995μs submission.py
tangzhengju_49570 325.485μs   +2.356μs 2_submission.py
mreso 326.574μs   +1.089μs test.py
_kernelfolw_ 328.918μs   +2.343μs submission.py
whatdhack_ 329.891μs   +0.973μs 03-matrix-multiplication.py
summergift0941 337.925μs   +8.034μs block_scale_fp8.py
rosehulman. 342.447μs   +4.522μs kernel.py
terrencezwy3795 343.669μs   +1.221μs submission.py
arseni_ivanov 344.121μs   +0.452μs submission.py
legendary_fawn_56575 347.820μs   +3.699μs template.py
Ding 348.328μs   +0.508μs template.py
Shivam 366.055μs   +17.727μs submission.py
Shlok 386.380μs   +20.324μs TKGEMM_355.py
Quantizr 422.452μs   +36.072μs triton.py
mdda123 545.679μs   +123.227μs hip.py
DizzleRama 589.097μs   +43.418μs gpu_mode_fp88_8_for_testing.py
Bexboy 656.024μs   +66.927μs lossfunk_residual_v2.py
nick_lc700x 658.464μs   +2.439μs sub12.py
fforange 696.125μs   +37.661μs lol2.py
Henry 713.800μs   +17.675μs amd_fp8_mm.py
wildman 756.135μs   +42.334μs template-hip2.py
hezhexi2002 760.774μs   +4.640μs custom-triton.py
agokrani 763.792μs   +3.018μs lossfunk_v2.py
Itssshikhar 777.123μs   +13.331μs lossfunk_v3.py
yia_perf 783.819μs   +6.696μs my_custom_kernel.py
osborn0016 793.867μs   +10.048μs submission.py
timppa 797.106μs   +3.239μs submission.py
_hui_xu 798.102μs   +0.995μs amd-fp8-mm.py
sirmisscriesalot 824.405μs   +26.304μs submission2.py
siclait 824.919μs   +0.514μs submission3.py
truco5798 849.944μs   +25.025μs demo.py
niki10 851.498μs   +1.554μs ref.py
sridharnandigam 852.235μs   +0.738μs ref.py
teddy_chou 852.455μs   +0.220μs fp8_ref.py
demosright 852.897μs   +0.442μs submission.py
BirdBoss 853.392μs   +0.495μs submission.py
D++ 854.249μs   +0.856μs submission.py
AkatsukiChiri 856.079μs   +1.830μs submission.py
fxfxfxfxfxfxfxfx 858.094μs   +2.016μs submission.py
SummerGift 859.649μs   +1.555μs submission.py
ShiyuWang 861.854μs   +2.205μs submission.py
blurbird 861.884μs   +0.030μs amd-fp8-mm.py
stepinto. 861.930μs   +0.046μs submission.py
shengwenliang. 862.303μs   +0.372μs amd.py
GnSight 864.084μs   +1.781μs submission.py
amd_bob 864.138μs   +0.054μs ref.py
lollipopkit 866.930μs   +2.792μs submission.py
getapiurl_29979_86903 868.026μs   +1.096μs submission.py
zhubenzhu 869.584μs   +1.558μs submission_2.py
Raghu 870.521μs   +0.937μs submission.py
Mr.Wang 871.191μs   +0.670μs triton_kernel.py
legendary_pony_17025 872.026μs   +0.836μs submission.py
lucifer050296 872.848μs   +0.822μs submission.py
siro 873.042μs   +0.194μs submission.py
stylish_kiwi_61275 873.084μs   +0.042μs submission.py
solahome 873.661μs   +0.577μs submission.py
Hoyoun Jung 876.246μs   +2.585μs submission.py
fyr233 877.574μs   +1.328μs submission-default.py
chess 877.719μs   +0.145μs fp8mm.py
Rik - OCI 878.902μs   +1.183μs ref.py
viranchee 881.558μs   +2.656μs fp8_ref_kernel.py
Shravan 881.965μs   +0.407μs submission.py
Turtle 882.397μs   +0.432μs submission.py
spacekkkk_64980 882.401μs   +0.004μs submission.py
cymtrick 882.951μs   +0.551μs check.py
Phil Butler 883.022μs   +0.071μs no-change.py
parrotsky 884.484μs   +1.461μs sub1.py
tomaszki 884.679μs   +0.195μs fp8.py
tendazeal 886.044μs   +1.365μs submission.py
killTheHostage 887.927μs   +1.883μs submission.py
Erik S. 890.743μs   +2.816μs BASELINE.py
savik 897.788μs   +7.045μs amd-fp8-mm_python.py
xsun2001 904.710μs   +6.922μs submission.py
beetle0315 904.725μs   +0.015μs submission.py
Rinkia_Ke_Papa 1133.229μs   +228.504μs lol1.py
Sagar 1176.737μs   +43.508μs amd_fp8_k6.py
puzzledpikachu 1270.936μs   +94.199μs submission_pytorch_4.py
Ling Qin 1720.549μs   +449.613μs fp8_mm.py
Taylor 1760.344μs   +39.795μs submission.py
__seal 1776.785μs   +16.441μs submission-hip.py
gowtham_tupili 2372.085μs   +595.300μs amd-fp8-mm_2.py
Lucky@CHR 2423.662μs   +51.577μs submission.py
Rocky Singh 2432.328μs   +8.666μs amd-fp8-mm_2.py
stmnk 2546.855μs   +114.527μs submission.py
kap272 2600.381μs   +53.526μs fp8_mm.py
truk@PLT 2643.077μs   +42.696μs submission.py
Merle 3430.114μs   +787.037μs template-hip.py
pongtsu 4227.150μs   +797.036μs v2.py
dkennetz 4615.243μs   +388.093μs hip_submission.py
harrison03042014 4737.853μs   +122.610μs HipGemmFp8.py
sam 5130.954μs   +393.102μs ref_hip.py
salad 5157.419μs   +26.464μs ref_v2.py
anirudh9616 5167.363μs   +9.944μs amd-gemm-submission-default.py
mqq 5167.421μs   +0.059μs amd-fp8-mm.py
ff-0xff 5173.870μs   +6.449μs v1.py
iron_bound 5175.590μs   +1.719μs template-amd-v2.py
LeviAckerman 5191.602μs   +16.012μs amd-fp8-mm.py
ff0xff5161 5195.828μs   +4.226μs template-hip.py
sophismparadox 5196.328μs   +0.500μs test.py
ph 5198.533μs   +2.206μs hip.py
liujiali_34506 5203.008μs   +4.475μs template-hip.py
Yash! 5204.934μs   +1.926μs original.py
DD 5210.840μs   +5.906μs reference.py
gabteni 5230.547μs   +19.707μs template-hip.py
cloudysky123_18954 5234.445μs   +3.897μs template-hip.py
opdroid1234 5242.287μs   +7.843μs reference.py
eclouder 5245.847μs   +3.560μs template-hip.py
blackcola 5247.494μs   +1.647μs test.py
Haw 5248.982μs   +1.488μs template-hip.py
shauheen 5254.196μs   +5.213μs amd-fp8-mm-2.py
gauravgokhale 5254.517μs   +0.321μs test.py
Howard 5280.150μs   +25.633μs amd-fp8-mm.py
jpy794 5283.286μs   +3.136μs template-hip.py
demon_36401 5288.557μs   +5.271μs test.py
gxtzhuxi 5312.887μs   +24.330μs template-hip.py
Dhanshre 5317.595μs   +4.708μs amd-fp8-mm.py
samkg 5323.384μs   +5.789μs template-hip.py
langzs335_31673 5327.104μs   +3.720μs template-hip.py
sYnfo 5331.620μs   +4.516μs amd-fp8-mm.py
_wizard_5 5395.874μs   +64.254μs template-hip.py