amd-fp8-mm

Deadline

15 days 18 hours (2025-05-27 00:00 UTC)

Language

Python

GPU Type

MI300

Description

You will implement a custom fp8-blockwise matmul kernel optimized for MI300. You will be given single-precision scaling factors for your matrices. The shapes of all outer and inner dimensions of tensors are from DeepSeek-R1. To be explicit, you will be given a tuple of tensors:

```
(a, b, a_scale, b_scale, c)
```

where `a` and `b` are the input matrices, `a_scale` and `b_scale` are the scaling factors for `a` and `b` respectively, and `c` is the output matrix:

* `a` is M x K in column-major order in e4m3fnuz
* `b` is N x K in column-major order in e4m3fnuz
* `a_scale` is M x K // 128 in column-major order in fp32
* `b_scale` is N // 128 x K // 128 in column-major order in fp32
* `c` is M x N in ROW-major order in bf16

Matrix sizes `m` and `n` are divisible by 64, and `k` is divisible by 128.

The ranking criterion is the geometric mean of the benchmark results. For the grand prize, your kernel will be evaluated against the speed-of-light analysis, and the solution closest to the speed of light will be awarded the grand prize.

The speed-of-light analysis is:

```
   M     N     K   time[us]
1024  1536  7168       8.63
1024  4608  7168      25.89
6144  1536  7168      51.78
6144  4608  7168     155.30
1024  7168   256       3.17
6144  7168   256      17.27
```
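To make the blockwise-scaling semantics concrete, here is a purely illustrative sketch of the value each output element must hold after dequantization: every element of `a` is scaled by its row's per-K-block factor, and every element of `b` by its 128x128 block's factor. The helper `dequant_matmul_elementwise` is hypothetical (not part of the task files) and far too slow for the real shapes; it only spells out the math.

```
import torch

def dequant_matmul_elementwise(a, b, a_scale, b_scale):
    # Illustration only: O(M*N*K) Python loops, do not use for real shapes.
    # a: [m, k] fp8,  a_scale: [m, k // 128] fp32
    # b: [n, k] fp8,  b_scale: [n // 128, k // 128] fp32
    m, k = a.shape
    n = b.shape[0]
    c = torch.zeros(m, n, dtype=torch.float32, device=a.device)
    for i in range(m):
        for j in range(n):
            for kk in range(k):
                c[i, j] += (a[i, kk].float() * a_scale[i, kk // 128]
                            * b[j, kk].float() * b_scale[j // 128, kk // 128])
    return c.to(torch.bfloat16)  # c is stored row-major in bf16
```

A competitive kernel would keep `a` and `b` in fp8 for the matrix instructions and fold the scales into the per-block accumulation rather than dequantizing element by element, which is consistent with the reference implementation's hint to apply scaling at the end.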

Reference Implementation

import torch
from task import input_t, output_t
from utils import make_match_reference


block_shape = (128, 128)

def generate_input(m: int, n: int, k: int, seed: int) -> input_t:
    """
    Generate random inputs and weights for blockwise W8A8 matmul with FP32 scaling factors.
    
    Returns:
        Tuple of (
            a: torch.Tensor[float8_e4m3fnuz] of shape [m, k], 
            b: torch.Tensor[float8_e4m3fnuz] of shape [n, k], 
            a_scale: torch.Tensor[float32] of shape [m, k // 128], 
            b_scale: torch.Tensor[float32] of shape [n // 128, k // 128], 
            c: torch.Tensor[bfloat16] of shape [m, n]
        )
    """
    gen = torch.Generator(device='cuda')
    gen.manual_seed(seed)
    block_shape_n, block_shape_k = block_shape
    scale_n = (n + block_shape_n - 1) // block_shape_n
    scale_k = (k + block_shape_k - 1) // block_shape_k

    # Generate random inputs with FP8 quantization
    a = (torch.randn((k, m), dtype=torch.bfloat16, device="cuda", generator=gen)).to(torch.float8_e4m3fnuz)
    b = (torch.randn((k, n), dtype=torch.bfloat16, device="cuda", generator=gen)).to(torch.float8_e4m3fnuz)

    # Generate scaling factors with FP32
    a_scale = torch.randn([scale_k, m], dtype=torch.float32, device="cuda", generator=gen)
    b_scale = torch.randn([scale_k, scale_n], dtype=torch.float32, device="cuda", generator=gen)


    c = torch.zeros((m, n), dtype=torch.bfloat16, device="cuda")
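    # The transposes below turn the row-major (k, m) / (k, n) buffers into the
    # column-major M x K / N x K views described in the task; c stays row-major.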
    return (a.T, b.T, a_scale.T, b_scale.T, c)


def ref_kernel(data: input_t) -> output_t:
    """
    Highly inefficient torch reference implementation of FP8 GEMM.
    You can use this as a reference / starting template for your implementation.
    """
    # c: [m, n] is pre-allocated memory to help remove allocation overhead.
    a, b, a_scale, b_scale, c = data

    # a is M x K in column-major order; we convert it to row-major here for simplicity.
    a = a.contiguous()
    a_scale = a_scale.contiguous()
    b_scale = b_scale.contiguous()

    # constants
    m = a.shape[0]
    n = b.shape[0]
    k = a.shape[1]
    block_shape_n = 128
    block_shape_k = 128
    scale_n = b_scale.shape[0]
    scale_k = b_scale.shape[1]

    # Apply blockwise scaling to input 'a'
    a_scale = a_scale.unsqueeze(-1).repeat(1, 1, block_shape_k)  # Shape: [m, scale_k, block_shape_k]
    a_scale = a_scale.reshape(m, scale_k * block_shape_k) 
    a_scale = a_scale[:, :k]

    # Dequantize 'a'; in your implementation you should do this at the end.
    a = a.to(a_scale.dtype) * a_scale 

    # Apply blockwise scaling to input 'b'
    b_scale = (
        b_scale.view(-1, 1)
        .repeat(1, block_shape_n * block_shape_k)
        .view(scale_n, scale_k, block_shape_n, block_shape_k)
        .permute(0, 2, 1, 3)  # Reorder dimensions: [scale_n, blk_n, scale_k, blk_k]
        .reshape(scale_n * block_shape_n, scale_k * block_shape_k)
    )
    b_scale = b_scale[:n, :k]

    # Dequantize 'b'; in your implementation you should do this at the end.
    b = b.to(b_scale.dtype) * b_scale 

    # Compute FP8 GEMM and write to 'c'. 
    c[...] = (a @ b.T).to(torch.bfloat16)
    return c


check_implementation = make_match_reference(ref_kernel, rtol=2e-02, atol=1e-03)
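
If you want to sanity-check a candidate kernel locally before submitting, a minimal sketch along these lines can be used, assuming your kernel exposes the same `(data: input_t) -> output_t` signature as `ref_kernel`. The `local_check` helper and its default shape (taken from the speed-of-light table) are illustrative only, not part of the evaluation harness, which uses `check_implementation` above; the tolerances mirror the ones passed to `make_match_reference`.

```
import torch

def local_check(custom_kernel, m=1024, n=1536, k=7168, seed=0):
    # Run the reference first and keep a copy, since ref_kernel writes its
    # result into the shared pre-allocated output tensor 'c'.
    data = generate_input(m, n, k, seed)
    expected = ref_kernel(data).clone()

    # Reset 'c' and run the candidate kernel on the same inputs.
    a, b, a_scale, b_scale, c = data
    c.zero_()
    actual = custom_kernel((a, b, a_scale, b_scale, c))

    torch.testing.assert_close(actual.float(), expected.float(),
                               rtol=2e-2, atol=1e-3)
```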

Rankings

MI300

Seb 🥇 121.819μs baguette_3.py
Snektron 🥈 153.495μs   +31.676μs fp8_gemm.py
Shinsato Masumi 🥉 174.561μs   +21.067μs Submission.py
nicolaswilde 179.592μs   +5.031μs amd-fp8-mm-hip-128x128-32x64-swz-trans.py
Cong 183.095μs   +3.503μs test_K64_3.py
hatoo 190.472μs   +7.377μs hip.py
ColorsWind 204.171μs   +13.699μs submission.py
guizili0 214.998μs   +10.826μs fp8t_v4.py
intrinsicmode 222.671μs   +7.674μs triton_v06.py
rt11 223.051μs   +0.380μs v8.py
Karan Jakhar 226.358μs   +3.307μs submission_1.py
ryshaw 237.268μs   +10.910μs learning_triton.py
Qwesh157 241.725μs   +4.457μs template-hip.py
kfz 243.681μs   +1.956μs v1.py
DStoPyA 251.168μs   +7.487μs submission.py
bobmarleybiceps 254.113μs   +2.944μs submission-triton.py
Arseni Ivanov 254.979μs   +0.866μs submission_transposed.py
gau.nernst 258.741μs   +3.762μs triton_v1.py
hyochan 278.066μs   +19.324μs go_gpu_mode.py
ALI 280.572μs   +2.506μs submission.py
myy1966 286.333μs   +5.762μs submission_fp8_mm_version_2.py
shikhar 288.569μs   +2.235μs lossfunk_v5.py
Xavier Init 289.527μs   +0.959μs sub.py.py
pengcuo 297.899μs   +8.372μs submission.py
Austin Liu 309.845μs   +11.946μs submission.py
t_cc 313.878μs   +4.033μs fp8_mm.py
Matthias 314.755μs   +0.877μs test.py
nyl199310 323.129μs   +8.374μs submission.py
mreso 326.574μs   +3.445μs test.py
luojiehao. 332.458μs   +5.883μs triton_v04_fp8.py
whatdhack_ 340.170μs   +7.712μs 03-matrix-multiplication.py
arseni_ivanov 344.121μs   +3.951μs submission.py
yiakwy-xpu-ml-framework-team 364.170μs   +20.049μs flashFloat_fp8_w8a8_triton_ref_submission.py
Shivam 366.055μs   +1.885μs submission.py
Shlok 386.380μs   +20.324μs TKGEMM_355.py
rosehulman. 395.839μs   +9.460μs kernel.py
Quantizr 422.452μs   +26.612μs triton.py
fanwenjie 454.824μs   +32.372μs test_4.py
DizzleRama 589.097μs   +134.273μs gpu_mode_fp88_8_for_testing.py
Bexboy 656.024μs   +66.927μs lossfunk_residual_v2.py
nick_lc700x 658.464μs   +2.439μs sub12.py
Henry 713.800μs   +55.336μs amd_fp8_mm.py
hezhexi2002 760.774μs   +46.974μs custom-triton.py
agokrani 763.792μs   +3.018μs lossfunk_v2.py
Itssshikhar 777.123μs   +13.331μs lossfunk_v3.py
yia_perf 783.819μs   +6.696μs my_custom_kernel.py
osborn0016 793.867μs   +10.048μs submission.py
Ding 795.162μs   +1.295μs submission_829us.py
timppa 797.106μs   +1.944μs submission.py
_hui_xu 798.102μs   +0.995μs amd-fp8-mm.py
wildman 819.365μs   +21.263μs o4-amd-fp8-mm.py
sirmisscriesalot 824.405μs   +5.040μs submission2.py
siclait 824.919μs   +0.514μs submission3.py
mdda123 849.532μs   +24.613μs submission.py
truco5798 849.944μs   +0.412μs demo.py
niki10 851.498μs   +1.554μs ref.py
teddy_chou 852.455μs   +0.957μs fp8_ref.py
demosright 852.897μs   +0.442μs submission.py
BirdBoss 853.392μs   +0.495μs submission.py
D++ 854.249μs   +0.856μs submission.py
AkatsukiChiri 856.079μs   +1.830μs submission.py
fxfxfxfxfxfxfxfx 858.094μs   +2.016μs submission.py
SummerGift 859.649μs   +1.555μs submission.py
blurbird 861.884μs   +2.235μs amd-fp8-mm.py
stepinto. 861.930μs   +0.046μs submission.py
shengwenliang. 862.303μs   +0.372μs amd.py
GnSight 864.084μs   +1.781μs submission.py
amd_bob 864.138μs   +0.054μs ref.py
lollipopkit 866.930μs   +2.792μs submission.py
getapiurl_29979_86903 868.026μs   +1.096μs submission.py
zhubenzhu 869.584μs   +1.558μs submission_2.py
Raghu 870.521μs   +0.937μs submission.py
Mr.Wang 871.191μs   +0.670μs triton_kernel.py
legendary_pony_17025 872.026μs   +0.836μs submission.py
lucifer050296 872.848μs   +0.822μs submission.py
siro 873.042μs   +0.194μs submission.py
stylish_kiwi_61275 873.084μs   +0.042μs submission.py
solahome 873.661μs   +0.577μs submission.py
Hoyoun Jung 876.246μs   +2.585μs submission.py
fyr233 877.574μs   +1.328μs submission-default.py
chess 877.719μs   +0.145μs fp8mm.py
Rik - OCI 878.902μs   +1.183μs ref.py
Cunxiao 880.765μs   +1.863μs submission.py
viranchee 881.558μs   +0.793μs fp8_ref_kernel.py
Shravan 881.965μs   +0.407μs submission.py
Turtle 882.397μs   +0.432μs submission.py
spacekkkk_64980 882.401μs   +0.004μs submission.py
cymtrick 882.951μs   +0.551μs check.py
Phil Butler 883.022μs   +0.071μs no-change.py
parrotsky 884.484μs   +1.461μs sub1.py
tomaszki 884.679μs   +0.195μs fp8.py
tendazeal 886.044μs   +1.365μs submission.py
killTheHostage 887.927μs   +1.883μs submission.py
Erik S. 890.743μs   +2.816μs BASELINE.py
cudawarped 892.193μs   +1.450μs submission_amd_fp8_mm_base.py
Sagar 1176.737μs   +284.544μs amd_fp8_k6.py
puzzledpikachu 1270.936μs   +94.199μs submission_pytorch_4.py
Taylor 1760.344μs   +489.408μs submission.py
__seal 1776.785μs   +16.441μs submission-hip.py
gowtham_tupili 2372.085μs   +595.300μs amd-fp8-mm_2.py
Rocky Singh 2432.328μs   +60.243μs amd-fp8-mm_2.py
truk@PLT 2643.077μs   +210.749μs submission.py
Azmuth 2771.750μs   +128.673μs amd-fp8-mm_hip.py
Ling Qin 2863.392μs   +91.642μs fp8_mm.py
Merle 3430.114μs   +566.721μs template-hip.py
dkennetz 4615.243μs   +1185.129μs hip_submission.py
harrison03042014 4737.853μs   +122.610μs HipGemmFp8.py
sam 5130.954μs   +393.102μs ref_hip.py
salad 5157.419μs   +26.464μs ref_v2.py
anirudh9616 5167.363μs   +9.944μs amd-gemm-submission-default.py
iron_bound 5175.590μs   +8.227μs template-amd-v2.py
colorswind 5190.733μs   +15.143μs template-hip.py
LeviAckerman 5191.602μs   +0.869μs amd-fp8-mm.py
sophismparadox 5196.328μs   +4.726μs test.py
liujiali_34506 5203.008μs   +6.681μs template-hip.py
Yash! 5204.934μs   +1.926μs original.py
cloudysky123_18954 5234.445μs   +29.510μs template-hip.py
opdroid1234 5242.287μs   +7.843μs reference.py
eclouder 5245.847μs   +3.560μs template-hip.py
blackcola 5247.494μs   +1.647μs test.py
Haw 5248.982μs   +1.488μs template-hip.py
Howard 5280.150μs   +31.168μs amd-fp8-mm.py
jpy794 5283.286μs   +3.136μs template-hip.py
demon_36401 5288.557μs   +5.271μs test.py
samkg 5323.384μs   +34.827μs template-hip.py
langzs335_31673 5327.104μs   +3.720μs template-hip.py
sYnfo 5331.620μs   +4.516μs amd-fp8-mm.py
_wizard_5 5395.874μs   +64.254μs template-hip.py