prefixsum_v2
Deadline
136 days 12 hours remaining (2025-12-30 00:00 UTC)
Language
Python
GPU Types
A100, B200, H100, L4
Description
Implement an inclusive prefix sum (scan) kernel that matches the reference implementation. The kernel computes, at each position, the cumulative sum of all elements up to and including that position; for example, input [1, 2, 3] yields output [1, 3, 6]. Because floating-point accumulation error grows with the number of additions, the tolerance is scaled by the square root of the input size.

Input:
- `data`: a 1D tensor of size `n`

Output:
- `output`: a 1D tensor of size `n`
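The harness entry point is not shown on this page; the sketch below assumes a `custom_kernel` function (a hypothetical name) that receives the same `(input, output)` tuple that `ref_kernel` and `generate_input` use in the reference code below. It illustrates the classic Hillis-Steele inclusive scan with vectorized PyTorch ops; a competitive submission would replace this with a hand-written GPU kernel.

import torch
from task import input_t, output_t

def custom_kernel(data: input_t) -> output_t:
    """Sketch of an inclusive scan via Hillis-Steele (log2(n) vectorized steps)."""
    x, output = data
    output.copy_(x)
    shift = 1
    while shift < output.numel():
        # Add the value `shift` positions to the left. The right-hand side is
        # materialized before the slice assignment, so the overlapping views are safe.
        output[shift:] = output[shift:] + output[:-shift]
        shift *= 2
    return output

Note that this sketch accumulates in float32; the accumulated rounding error this introduces is exactly what the checker's sqrt(n) tolerance scaling is meant to absorb.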
Reference Implementation
from utils import match_reference, DeterministicContext
import torch
from task import input_t, output_t
def ref_kernel(data: input_t) -> output_t:
    """
    Reference implementation of inclusive prefix sum using PyTorch.
    Args:
        data: Tuple of (input tensor, preallocated output tensor)
    Returns:
        Tensor containing the inclusive prefix sum
    """
    with DeterministicContext():
        data, output = data
        # Accumulate in float64 to keep the reference numerically stable.
        output = torch.cumsum(data.to(torch.float64), dim=0)
        return output
def generate_input(size: int, seed: int) -> input_t:
    """
    Generates a random input tensor and a preallocated output buffer.
    Returns:
        Tuple of (tensor to compute the prefix sum on, output tensor)
    """
    gen = torch.Generator(device="cuda")
    gen.manual_seed(seed)
    x = torch.randn(
        size, device="cuda", dtype=torch.float32, generator=gen
    ).contiguous()
    y = torch.empty(size, device="cuda", dtype=torch.float32).contiguous()
    return x, y
# This algorithm is very sensitive to the tolerance: the accumulated error is
# magnified by the input size, so the tolerance is scaled by its square root.
def check_implementation(data: input_t, output: output_t) -> str:
    # Get the input size for scaling the tolerance; `data` is an (x, y) tuple.
    n = data[0].numel()
    scale_factor = n ** 0.5  # Square root of input size
    rtol = 1e-5 * scale_factor
    atol = 1e-5 * scale_factor
    return match_reference(data, output, reference=ref_kernel, rtol=rtol, atol=atol)
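For a sense of scale: with n = 2^20 elements, scale_factor = 1024, so rtol = atol ≈ 1e-2. A minimal sketch of how the pieces fit together, assuming a CUDA device and the hypothetical `custom_kernel` entry point sketched above (`match_reference` comes from utils and is not shown on this page):

data = generate_input(size=1 << 20, seed=42)  # (x, output) pair on the GPU
result = custom_kernel(data)                  # any kernel with ref_kernel's signature
print(check_implementation(data, result))     # match_reference's verdict as a string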