prefixsum_v2

Deadline

136 days 12 hours remaining (2025-12-30 00:00 UTC)

Language

Python

GPU Types

A100, B200, H100, L4

Description

Implement an inclusive prefix sum (scan) kernel that matches the reference implementation. For each position `i`, the kernel computes the cumulative sum of all elements up to and including position `i`. Because floating-point rounding error accumulates over the scan, the comparison tolerance is scaled by the square root of the input size.

Input:
- `data`: a 1D tensor of size `n`

Output:
- `output`: a 1D tensor of size `n`
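
As a quick illustration (a minimal sketch, not part of the task files), `torch.cumsum` computes exactly this inclusive scan:

import torch

# Inclusive scan: output[i] = data[0] + data[1] + ... + data[i]
data = torch.tensor([3.0, 1.0, 4.0, 1.0, 5.0])
print(torch.cumsum(data, dim=0))  # tensor([ 3.,  4.,  8.,  9., 14.])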

Reference Implementation

from utils import match_reference, DeterministicContext
import torch
from task import input_t, output_t


def ref_kernel(data: input_t) -> output_t:
    """
    Reference implementation of inclusive prefix sum using PyTorch.
    Args:
        data: Tuple of (input tensor, preallocated output tensor)
    Returns:
        Tensor containing the inclusive prefix sum
    """
    with DeterministicContext():
        # The harness passes an (input, output) pair; the preallocated
        # output buffer is unused by the reference.
        data, _ = data
        # Accumulate in float64 for accuracy over long scans.
        return torch.cumsum(data.to(torch.float64), dim=0)
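
The float64 accumulation is deliberate: float32 partial sums drift as the scan gets longer, which is exactly the error the sqrt(n)-scaled tolerance accounts for. A minimal CPU-only sketch of that drift (illustrative, not part of the task files):

import torch

# Compare float32 and float64 accumulation over a long scan.
x = torch.randn(1 << 20, dtype=torch.float32)
s32 = torch.cumsum(x, dim=0).to(torch.float64)
s64 = torch.cumsum(x.to(torch.float64), dim=0)
print(f"max abs drift over {x.numel()} elements: {(s32 - s64).abs().max().item():.3e}")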


def generate_input(size: int, seed: int) -> input_t:
    """
    Generates random input tensor.
    Returns:
        Tensor to compute prefix sum on
    """
    gen = torch.Generator(device="cuda")
    gen.manual_seed(seed)
    x = torch.randn(
        size, device="cuda", dtype=torch.float32, generator=gen
    ).contiguous()
    y = torch.empty(size, device="cuda", dtype=torch.float32).contiguous()
    return x, y
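
Since `generate_input` returns an `(input, output)` pair, a minimal correct (if unoptimized) submission could look like the sketch below. The `custom_kernel` entry-point name is an assumption about the harness convention, not something stated on this page:

import torch
from task import input_t, output_t


def custom_kernel(data: input_t) -> output_t:  # entry-point name is an assumption
    # Unpack the (input, output) pair and write the inclusive scan
    # directly into the preallocated float32 output buffer.
    x, y = data
    torch.cumsum(x, dim=0, out=y)
    return y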


# Floating-point error in a prefix sum grows with the input size,
# so the comparison tolerance is scaled by the square root of n
def check_implementation(data: input_t, output: output_t) -> str:
    # `data` is the (input, output) pair; scale the tolerance by the input's size
    n = data[0].numel()

    scale_factor = n ** 0.5  # Square root of input size
    rtol = 1e-5 * scale_factor
    atol = 1e-5 * scale_factor

    return match_reference(data, output, reference=ref_kernel, rtol=rtol, atol=atol)
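
For a sense of scale (hypothetical sizes plugged into the formula above), the allowed tolerance grows slowly with `n`:

for n in (1 << 10, 1 << 20, 1 << 26):
    print(f"n = {n:>9}: rtol = atol = {1e-5 * n ** 0.5:.2e}")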
