sort_v2
Deadline
2025-12-30 00:00 UTC
Language
Python
GPU Types
A100, B200, H100, L4
Description
Implement a sort kernel that matches the reference implementation. The kernel should sort the input array in ascending order using a sorting algorithm of your choice. Inputs are random float32 values: each row of a roughly square matrix is drawn from a normal distribution with unit standard deviation and a mean that differs per row (row i uses mean seed + i). The matrix is then flattened into a 1D array and trimmed to the requested size.
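Because the choice of algorithm is left open, one option often used for floating-point keys on GPUs is a least-significant-digit radix sort over bit-transformed keys. The following is a minimal CPU-only sketch of that idea in plain Python, assuming float32 inputs with no NaNs. The names `float_to_key`, `key_to_float`, and `radix_sort_floats` are hypothetical and are not part of the task API; an actual submission would implement the same passes as GPU kernels.

```python
import struct


def float_to_key(x: float) -> int:
    """Map a float32 value to a uint32 whose unsigned order matches float order."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    # Negative floats: flip every bit. Non-negative floats: set the sign bit.
    return bits ^ 0xFFFFFFFF if bits & 0x80000000 else bits | 0x80000000


def key_to_float(key: int) -> float:
    """Inverse of float_to_key."""
    bits = key & 0x7FFFFFFF if key & 0x80000000 else key ^ 0xFFFFFFFF
    (x,) = struct.unpack("<f", struct.pack("<I", bits))
    return x


def radix_sort_floats(values, bits_per_pass=8):
    """Sort float32-representable values ascending with an LSD radix sort."""
    keys = [float_to_key(v) for v in values]
    mask = (1 << bits_per_pass) - 1
    for shift in range(0, 32, bits_per_pass):
        # Histogram of the current digit.
        counts = [0] * (mask + 1)
        for k in keys:
            counts[(k >> shift) & mask] += 1
        # Exclusive prefix sum: start offset of each bucket.
        offsets, total = [], 0
        for c in counts:
            offsets.append(total)
            total += c
        # Stable scatter into the output buffer.
        out = [0] * len(keys)
        for k in keys:
            d = (k >> shift) & mask
            out[offsets[d]] = k
            offsets[d] += 1
        keys = out
    return [key_to_float(k) for k in keys]
```

The histogram, prefix-sum, and scatter steps are the parts that would typically be parallelized on a GPU. The key transform is what lets an integer radix sort order negative and positive floats correctly.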
Reference Implementation
from utils import make_match_reference, DeterministicContext
import torch
from task import input_t, output_t


def ref_kernel(data: input_t) -> output_t:
    """
    Reference implementation of sort using PyTorch.
    Args:
        data: Tuple of (input tensor to be sorted, preallocated output tensor)
    Returns:
        Output tensor filled with the sorted values
    """
    with DeterministicContext():
        data, output = data
        output[...] = torch.sort(data)[0]
        return output


def generate_input(size: int, seed: int) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Generates random input tensor where elements are drawn from different distributions.
    Args:
        size: Total size of the final 1D tensor
        seed: Base seed for random generation
    Returns:
        Tuple of (1D input tensor of size `size` containing flattened values
        from different distributions, uninitialized output tensor of the same shape)
    """
    # Calculate dimensions for a roughly square 2D matrix
    rows = int(size**0.5)  # Square root for roughly square shape
    cols = (size + rows - 1) // rows  # Ceiling division so rows * cols >= size
    gen = torch.Generator(device="cuda")
    result = torch.empty((rows, cols), device="cuda", dtype=torch.float32)
    # Different seed for each row!
    for i in range(rows):
        row_seed = seed + i
        gen.manual_seed(row_seed)
        # Generate values for this row with mean=row_seed
        result[i, :] = (
            torch.randn(cols, device="cuda", dtype=torch.float32, generator=gen)
            + row_seed
        )
    # Flatten and trim to exact size requested
    input_tensor = result.flatten()[:size].contiguous()
    output_tensor = torch.empty_like(
        input_tensor, device="cuda", dtype=torch.float32
    ).contiguous()
    return input_tensor, output_tensor


check_implementation = make_match_reference(ref_kernel)
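
To make the input layout easier to reason about without a GPU, the shape arithmetic and per-row means can be mirrored in plain Python. This is only an illustrative sketch: `matrix_shape` and `generate_input_cpu` are hypothetical helper names, the sketch assumes `size >= 1`, and it uses Python's `random` module, so it does not reproduce the values produced by the CUDA generator in the reference code.

```python
import random


def matrix_shape(size: int) -> tuple[int, int]:
    """Return (rows, cols) of the roughly square matrix used for `size` elements."""
    rows = int(size ** 0.5)
    cols = (size + rows - 1) // rows  # ceiling division, so rows * cols >= size
    return rows, cols


def generate_input_cpu(size: int, seed: int) -> list[float]:
    """CPU analogue of the input layout: row i is drawn from N(seed + i, 1)."""
    rows, cols = matrix_shape(size)
    flat = []
    for i in range(rows):
        row_seed = seed + i
        rng = random.Random(row_seed)  # reseed per row, as the reference does
        flat.extend(rng.gauss(row_seed, 1.0) for _ in range(cols))
    return flat[:size]  # trim the padding added by the ceiling division
```

For example, `matrix_shape(10)` gives 3 rows of 4 columns, and the last 2 of the 12 generated values are dropped. Since consecutive row means differ by 1 and each row has unit standard deviation, values from neighbouring rows overlap in range.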