amd-fp8-mm
Deadline: 15 days 18 hours (2025-05-27 00:00 UTC)
Language: Python
GPU Type: MI300
Description
You will implement a custom fp8-blockwise matmul kernel optimized for MI300. You will be given single-precision scaling factors for your matrices. The shapes of all outer and inner tensor dimensions are taken from DeepSeek-R1. To be explicit, you will be given a tuple of tensors:

```
(a, b, a_scale, b_scale, c)
```

where `a` and `b` are the input matrices, `a_scale` and `b_scale` are the scaling factors for `a` and `b` respectively, and `c` is the output matrix:

* `a` is M x K in column-major order in e4m3fnuz
* `b` is N x K in column-major order in e4m3fnuz
* `a_scale` is M x K // 128 in column-major order in fp32
* `b_scale` is N // 128 x K // 128 in column-major order in fp32
* `c` is M x N in ROW-major order in bf16

Matrix sizes `m` and `n` are divisible by 64, and `k` is divisible by 128.

The ranking criterion is the geometric mean of the benchmark results. For the grand prize, your kernel will be evaluated against the speed-of-light analysis, and the solution closest to the speed of light will be awarded the grand prize.

The speed-of-light analysis is:

```
   M     N     K   time[us]
1024  1536  7168       8.63
1024  4608  7168      25.89
6144  1536  7168      51.78
6144  4608  7168     155.30
1024  7168   256       3.17
6144  7168   256      17.27
```
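To get a feel for how tight these targets are: a GEMM performs 2·M·N·K floating-point operations, so the first shape costs 2 · 1024 · 1536 · 7168 ≈ 22.5 GFLOP, and finishing it in 8.63 µs means roughly 2.6 PFLOP/s, on the order of the MI300X's peak dense fp8 throughput; the small-K shapes are instead limited by memory bandwidth. A minimal sketch (shapes and times copied from the table above) that converts the table into achieved TFLOP/s:

```
# Convert the speed-of-light table into achieved TFLOP/s.
shapes = [
    (1024, 1536, 7168, 8.63),
    (1024, 4608, 7168, 25.89),
    (6144, 1536, 7168, 51.78),
    (6144, 4608, 7168, 155.30),
    (1024, 7168, 256, 3.17),
    (6144, 7168, 256, 17.27),
]
for m, n, k, t_us in shapes:
    flops = 2 * m * n * k                  # one multiply-add = 2 FLOPs
    tflops = flops / (t_us * 1e-6) / 1e12
    print(f"M={m:<5} N={n:<5} K={k:<5} -> {tflops:8.1f} TFLOP/s")
```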
Reference Implementation
```
import torch
from task import input_t, output_t
from utils import make_match_reference

block_shape = (128, 128)

def generate_input(m: int, n: int, k: int, seed: int) -> input_t:
    """
    Generate random input and weights for blockwise W8A8 matmul with FP32 scaling factors.
    Returns:
        Tuple of (
            a: torch.Tensor[float8_e4m3fnuz] of shape [m, k],
            b: torch.Tensor[float8_e4m3fnuz] of shape [n, k],
            a_scale: torch.Tensor[float32] of shape [m, k // 128],
            b_scale: torch.Tensor[float32] of shape [n // 128, k // 128],
            c: torch.Tensor[bfloat16] of shape [m, n]
        )
    """
    gen = torch.Generator(device='cuda')
    gen.manual_seed(seed)
    block_shape_n, block_shape_k = block_shape
    scale_n = (n + block_shape_n - 1) // block_shape_n
    scale_k = (k + block_shape_k - 1) // block_shape_k

    # Generate random inputs with FP8 quantization
    a = (torch.randn((k, m), dtype=torch.bfloat16, device="cuda", generator=gen)).to(torch.float8_e4m3fnuz)
    b = (torch.randn((k, n), dtype=torch.bfloat16, device="cuda", generator=gen)).to(torch.float8_e4m3fnuz)

    # Generate scaling factors with FP32
    a_scale = torch.randn([scale_k, m], dtype=torch.float32, device="cuda", generator=gen)
    b_scale = torch.randn([scale_k, scale_n], dtype=torch.float32, device="cuda", generator=gen)

    c = torch.zeros((m, n), dtype=torch.bfloat16, device="cuda")

    # The tensors are generated contiguous with the dimensions swapped, so the
    # transposed views below are column-major.
    return (a.T, b.T, a_scale.T, b_scale.T, c)
```
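Why those transposes produce column-major tensors: `a` is generated as a contiguous `(k, m)` tensor, so `a.T` is an `(m, k)` view with unit stride along dim 0, i.e. column-major. A quick check (a sketch, assuming a ROCm build of PyTorch where `float8_e4m3fnuz` is available):

```
import torch

x = torch.randn(256, 64, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fnuz)
a = x.T                      # (64, 256) view of a contiguous (256, 64) tensor
print(a.shape, a.stride())   # torch.Size([64, 256]) (1, 64): unit stride along dim 0, column-major
print(a.is_contiguous())     # False: not row-major
```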
```
def ref_kernel(data: input_t) -> output_t:
    """
    Highly inefficient torch reference implementation of FP8 GEMM.
    You can use this as a reference / starting template for your implementation.
    """
    # c: [m, n] is pre-allocated memory to help remove allocation overhead.
    a, b, a_scale, b_scale, c = data

    # a is M x K in column-major order; we convert here for simplicity.
    a = a.contiguous()
    a_scale = a_scale.contiguous()
    b_scale = b_scale.contiguous()

    # constants
    m = a.shape[0]
    n = b.shape[0]
    k = a.shape[1]
    block_shape_n = 128
    block_shape_k = 128
    scale_n = b_scale.shape[0]
    scale_k = b_scale.shape[1]

    # Apply blockwise scaling to input 'a': each scale covers a 1 x 128 slice
    # of a row, so broadcast it across its k-block.
    a_scale = a_scale.unsqueeze(-1).repeat(1, 1, block_shape_k)  # Shape: [m, scale_k, block_shape_k]
    a_scale = a_scale.reshape(m, scale_k * block_shape_k)
    a_scale = a_scale[:, :k]

    # Dequantize 'a'; in your implementation you should do this at the end.
    a = a.to(a_scale.dtype) * a_scale

    # Apply blockwise scaling to input 'b': each scale covers a 128 x 128 block.
    b_scale = (
        b_scale.view(-1, 1)
        .repeat(1, block_shape_n * block_shape_k)
        .view(scale_n, scale_k, block_shape_n, block_shape_k)
        .permute(0, 2, 1, 3)  # Reorder dimensions: [scale_n, blk_n, scale_k, blk_k]
        .reshape(scale_n * block_shape_n, scale_k * block_shape_k)
    )
    b_scale = b_scale[:n, :k]

    # Dequantize 'b'; in your implementation you should do this at the end.
    b = b.to(b_scale.dtype) * b_scale

    # Compute FP8 GEMM and write to 'c'.
    c[...] = (a @ b.T).to(torch.bfloat16)
    return c

check_implementation = make_match_reference(ref_kernel, rtol=2e-02, atol=1e-03)
```
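To illustrate the "do this at the end" advice above, here is a hedged pure-torch sketch of a submission entry point (assuming the harness calls a `custom_kernel(data: input_t) -> output_t` function, as the reference's signature suggests; this is a correctness baseline, not a fast kernel). It accumulates one 128-wide k-block at a time in fp32, applies both scales to each partial product, and casts to bf16 only once:

```
import torch
from task import input_t, output_t

def custom_kernel(data: input_t) -> output_t:
    """Sketch of a pure-torch baseline: accumulate per 128-wide k-block in
    fp32, scale each partial product, and cast to bf16 at the very end."""
    a, b, a_scale, b_scale, c = data
    m, k = a.shape
    n = b.shape[0]
    block_k = 128
    scale_k = k // block_k

    # Expand b_scale from one value per 128 x 128 block to one value per row of b.
    b_row_scale = b_scale.repeat_interleave(block_k, dim=0)[:n]  # [n, scale_k]

    acc = torch.zeros((m, n), dtype=torch.float32, device=a.device)
    for kb in range(scale_k):
        ks = slice(kb * block_k, (kb + 1) * block_k)
        # Partial product of one k-block, computed in fp32.
        partial = a[:, ks].to(torch.float32) @ b[:, ks].to(torch.float32).T
        # Scale rows by a's per-(row, k-block) factor and columns by b's factor.
        acc += partial * a_scale[:, kb : kb + 1] * b_row_scale[:, kb][None, :]
    c[...] = acc.to(torch.bfloat16)
    return c
```

Scaling each partial product is mathematically equivalent to dequantizing both inputs up front, but it mirrors what a fused kernel would do in registers: the fp8-to-fp32 conversion and the scale application stay inside the k-loop, and nothing is rounded to bf16 until the full sum is formed.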
Rankings
MI300
Seb 🥇 | 121.819 μs | baguette_3.py
Snektron 🥈 | 153.495 μs (+31.676 μs) | fp8_gemm.py
Shinsato Masumi 🥉 | 174.561 μs (+21.067 μs) | Submission.py