matmul

Deadline

2025-06-30 00:00 UTC

Language

Python

GPU Types

A100, H100, L4, T4

Description

Implement a custom matmul function that matches the reference implementation. The function receives a tuple of input tensors (a, b) and must return their matrix product a @ b. All outer and inner dimensions of the tensors (m, n, and k) are multiples of 16.
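Because every dimension is a multiple of 16, a blocked (tiled) matmul can split the output into 16 x 16 tiles and walk the shared k dimension in 16-wide steps without ever meeting a partial tile at the edges. The sketch below illustrates only that loop structure, in plain Python on lists of lists; it is not a GPU kernel or a competitive submission, and the names matmul_tiled, matmul_naive, and TILE are hypothetical.

```python
import random

# Hypothetical tile size, chosen to match the task's "multiples of 16" guarantee.
TILE = 16


def matmul_naive(a, b):
    """Reference triple loop: a is m x k, b is k x n (lists of lists)."""
    m, k, n = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]


def matmul_tiled(a, b, tile=TILE):
    """Compute a @ b one (tile x tile) output block at a time."""
    m, k, n = len(a), len(b), len(b[0])
    # The divisibility guarantee is what lets the loops skip edge handling.
    if m % tile or n % tile or k % tile:
        raise ValueError("m, n and k must all be multiples of the tile size")
    c = [[0.0] * n for _ in range(m)]
    for i0 in range(0, m, tile):              # output row block
        for j0 in range(0, n, tile):          # output column block
            for p0 in range(0, k, tile):      # step along the shared k dimension
                for i in range(i0, i0 + tile):
                    row_a, row_c = a[i], c[i]
                    for p in range(p0, p0 + tile):
                        a_ip, row_b = row_a[p], b[p]
                        for j in range(j0, j0 + tile):
                            row_c[j] += a_ip * row_b[j]
    return c


# Small integer inputs keep every partial sum exact, so both results compare equal.
random.seed(0)
a = [[random.randint(-3, 3) for _ in range(48)] for _ in range(32)]  # 32 x 48
b = [[random.randint(-3, 3) for _ in range(16)] for _ in range(48)]  # 48 x 16
print(matmul_tiled(a, b) == matmul_naive(a, b))  # True
```

On a GPU the same idea is typically expressed by assigning output tiles to thread blocks; the Python version only shows why no boundary checks are needed under the stated shape guarantee.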

Reference Implementation

import torch
from task import input_t, output_t
from utils import make_match_reference


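def generate_input(m: int, n: int, k: int, seed: int) -> input_t:
    # Seeded CUDA generator so the same seed always yields the same inputs
    gen = torch.Generator(device='cuda')
    gen.manual_seed(seed)
    # a: (m, k) fp16, entries drawn uniformly from [0, 1)
    a = torch.empty(m, k, device='cuda', dtype=torch.float16)
    a.uniform_(0, 1, generator=gen)
    # b: (k, n) fp16, same distribution
    b = torch.empty(k, n, device='cuda', dtype=torch.float16)
    b.uniform_(0, 1, generator=gen)
    return (a, b)


def ref_kernel(data: input_t) -> output_t:
    a, b = data
    # Plain PyTorch fp16 matmul; the result has shape (m, n)
    return a @ b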


check_implementation = make_match_reference(ref_kernel)
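The reference relies on make_match_reference from the harness's utils module, whose comparison rules are not shown on this page. A custom kernel that sums the k products in a different order, or rounds at different points, may not reproduce a @ b bit for bit in fp16, so such a check is presumably tolerance-based rather than exact. The sketch below is a plain-Python illustration of an allclose-style comparison under that assumption; matches_reference, to_fp16, RTOL, and ATOL are hypothetical names and values, not the harness's API.

```python
import struct

# Hypothetical tolerances for illustration only; the values actually used by
# make_match_reference are not given here.
RTOL = 1e-2
ATOL = 1e-2


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def matches_reference(out, ref, rtol=RTOL, atol=ATOL) -> bool:
    """allclose-style rule: |out - ref| <= atol + rtol * |ref| for every element."""
    if len(out) != len(ref) or any(len(ro) != len(rr) for ro, rr in zip(out, ref)):
        return False  # shape mismatch
    return all(
        abs(o - r) <= atol + rtol * abs(r)
        for ro, rr in zip(out, ref)
        for o, r in zip(ro, rr)
    )


ref = [[0.1, 12.345], [3.0, 250.75]]
out = [[to_fp16(x) for x in row] for row in ref]  # values as fp16 would store them
print(out == ref)                   # False: 0.1 and 12.345 are not exact in fp16
print(matches_reference(out, ref))  # True: fp16 rounding error is well inside the tolerance
```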

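Rankings

Lower time ranks higher. Each line lists the user, their time, the gap to the entry directly above (+μs), and the submitted file.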

L4

nikhilap 🥇 2272.513μs matmul.py
ajhinh 🥈 2344.244μs   +71.732μs l4.py
symlon 🥉 49335.120μs   +46990.876μs submission.py

T4

siclait 🥇 5959.204μs submission.py
hatoo 🥈 6340.941μs   +381.737μs matmul.py
siro 🥉 6795.622μs   +454.681μs submission.py
salad 6881.867μs   +86.246μs ref.py
Shlok 6907.271μs   +25.404μs matmul.py
DizzleRama 7136.877μs   +229.606μs submission_cuda_inline_matmul.py
ajhinh 7938.621μs   +801.744μs t4.py

A100

ajhinh 🥇 661.865μs a100.py
salad 🥈 724.330μs   +62.466μs matmul_triton.py
AM 🥉 728.589μs   +4.259μs my_submission.py
mashisong 731.574μs   +2.985μs matmul_triton.py
dtad 745.251μs   +13.677μs matmul_triton.py
mobicham 810.114μs   +64.862μs gemm_a100_v1.py
Nadav Timor 838.150μs   +28.036μs matmul.py
Bob 8157.397μs   +7319.247μs matmul.py

H100

tomaszki 🥇 249.341μs submission.py
👾ubeai 🥈 251.279μs   +1.938μs matmul.py
ajhinh 🥉 258.918μs   +7.639μs h100.py
guoliang 289.560μs   +30.643μs dp-mm.py
AM 306.678μs   +17.118μs my_submission.py
salad 306.699μs   +0.021μs matmul_triton.py
dtad 343.679μs   +36.979μs matmul_triton.py
Nadav Timor 1823.825μs   +1480.146μs matmul.py
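Recomputing the H100 column from consecutive times shows that each +μs figure is the gap to the entry ranked directly above, not the gap to first place. In the sketch below the variable names are illustrative. Two recomputed values (30.642 and 36.980) differ from the listed ones by 0.001 μs, presumably because the site subtracts unrounded times before displaying them.

```python
# H100 times from the table above, best first (microseconds).
h100_us = [249.341, 251.279, 258.918, 289.560, 306.678, 306.699, 343.679, 1823.825]
# "+" values as listed, one per entry after first place.
listed_gaps = [1.938, 7.639, 30.643, 17.118, 0.021, 36.979, 1480.146]

# Gap to the entry directly above: differences of consecutive times.
gaps = [round(later - earlier, 3) for earlier, later in zip(h100_us, h100_us[1:])]

# Largest disagreement between recomputed and listed gaps.
max_diff = max(abs(g - l) for g, l in zip(gaps, listed_gaps))
print(gaps)
print(f"largest mismatch: {max_diff:.3f} us")
```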