sort

Deadline

2025-06-30 00:00 UTC

Language

Python

GPU Types

A100, H100, L4, T4

Description

Implement a sort kernel that matches the reference implementation. It must sort the input array in ascending order, and you may use any sort algorithm you choose. The inputs are random float32 values laid out as a roughly square matrix: row i is drawn from a normal distribution with mean seed + i and standard deviation 1. The matrix is then flattened into a 1D array and trimmed to the requested size.
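Any algorithm is accepted. A common choice for GPU sorting is an LSD radix sort. For float32 data it first needs a bit transform that turns each float into an unsigned integer with the same ordering. The following is a minimal CPU sketch in plain Python that shows those steps. It is not a GPU kernel and not the method of any ranked submission. The helper names (float_to_key, key_to_float, radix_sort_f32) are hypothetical, and the sketch assumes the input contains no NaNs.

```python
import struct

SIGN = 0x80000000
ALL_ONES = 0xFFFFFFFF


def float_to_key(x: float) -> int:
    """Map a float32 to a uint32 whose unsigned order matches the float order.

    Negative floats have every bit flipped (larger magnitude -> smaller key);
    non-negative floats only get the sign bit set, placing them above all
    negatives. Under this mapping -0.0 orders just below +0.0.
    """
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits ^ ALL_ONES if bits & SIGN else bits | SIGN


def key_to_float(key: int) -> float:
    """Inverse of float_to_key."""
    bits = key ^ SIGN if key & SIGN else key ^ ALL_ONES
    (x,) = struct.unpack("<f", struct.pack("<I", bits))
    return x


def radix_sort_f32(values, digit_bits: int = 8) -> list[float]:
    """LSD radix sort on float32 values (CPU sketch, NaN-free input assumed)."""
    keys = [float_to_key(v) for v in values]
    radix = 1 << digit_bits
    mask = radix - 1
    for shift in range(0, 32, digit_bits):
        # 1) Histogram of the current digit.
        counts = [0] * radix
        for k in keys:
            counts[(k >> shift) & mask] += 1
        # 2) Exclusive prefix sum: where each digit's bucket starts.
        offsets, running = [0] * radix, 0
        for d in range(radix):
            offsets[d], running = running, running + counts[d]
        # 3) Stable scatter, preserving the order of keys within a bucket.
        out = [0] * len(keys)
        for k in keys:
            d = (k >> shift) & mask
            out[offsets[d]] = k
            offsets[d] += 1
        keys = out
    return [key_to_float(k) for k in keys]


print(radix_sort_f32([3.5, -1.25, 0.0, 2.0, -7.75]))  # prints [-7.75, -1.25, 0.0, 2.0, 3.5]
```

Each pass is a histogram, an exclusive prefix sum and a stable scatter. With 8-bit digits a 32-bit key takes four passes. On a GPU the histogram and scatter are typically split across thread blocks, with a scan combining the per-block counts.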

Reference Implementation

from utils import make_match_reference
import torch
from task import input_t, output_t


def ref_kernel(data: input_t) -> output_t:
    """
    Reference implementation of sort using PyTorch.
    Args:
        data: Input tensor to be sorted
    Returns:
        Sorted tensor
    """
    return torch.sort(data)[0]


def generate_input(size: int, seed: int) -> torch.Tensor:
    """
    Generates a random input tensor whose rows are drawn from normal distributions with different means.
    
    Args:
        size: Total size of the final 1D tensor
        seed: Base seed for random generation
    
    Returns:
        1D tensor of size `size` containing flattened values from different distributions
    """
    # Calculate dimensions for a roughly square 2D matrix
    rows = int(size ** 0.5)  # Square root for roughly square shape
    cols = (size + rows - 1) // rows  # Ceiling division to ensure total size >= requested size
    
    gen = torch.Generator(device='cuda')
    result = torch.empty((rows, cols), device='cuda', dtype=torch.float32)
    
    # Different seed for each row!
    for i in range(rows):
        row_seed = seed + i
        gen.manual_seed(row_seed)
        
        # Generate values for this row with mean=row_seed
        result[i, :] = torch.randn(cols, device='cuda', dtype=torch.float32, generator=gen) + row_seed
    
    # Flatten and trim to exact size requested
    return result.flatten()[:size].contiguous()


check_implementation = make_match_reference(ref_kernel)
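The shape arithmetic in generate_input can be checked by hand. For size = 1000, rows = int(31.62...) = 31 and cols = ceil(1000 / 31) = 33, so 31 × 33 = 1023 values are generated and the last 23 are trimmed. A minimal CPU sketch of the same layout follows. The names input_shape and generate_input_py are hypothetical, and the sketch uses Python's random module instead of torch.Generator, so its values will not match the real inputs; only the structure is the same.

```python
import random


def input_shape(size: int) -> tuple[int, int]:
    """Same row/column arithmetic as generate_input."""
    rows = int(size ** 0.5)            # roughly square
    cols = (size + rows - 1) // rows   # ceiling division, so rows * cols >= size
    return rows, cols


def generate_input_py(size: int, seed: int) -> list[float]:
    """Hypothetical CPU analogue of generate_input.

    Mirrors the layout (one seed per row, row mean = seed + i, flatten, trim)
    but uses Python's random module, so the values differ from the torch ones.
    """
    rows, cols = input_shape(size)
    values: list[float] = []
    for i in range(rows):
        row_seed = seed + i
        rng = random.Random(row_seed)  # reseed per row, like gen.manual_seed(row_seed)
        values.extend(rng.gauss(0.0, 1.0) + row_seed for _ in range(cols))
    return values[:size]  # drop the rows * cols - size padding values


rows, cols = input_shape(1000)
print(rows, cols, rows * cols - 1000)  # prints: 31 33 23
```

Because row i has mean seed + i and unit variance, later rows tend to hold larger values, so the flattened input is already loosely ordered block by block, although neighbouring rows overlap heavily.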

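Rankings

Times are in microseconds (μs). The + value is how much slower an entry is than the entry ranked directly above it on the same GPU.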

L4

Rank  User     Time (μs)     Gap (μs)       File
🥇 1  Nader    15687.478                    submission.py
🥈 2  ajhinh   1175091.112   +1159403.634   l4.py

T4

Rank  User     Time (μs)     Gap (μs)       File
🥇 1  Nader    25029.484                    submission.py
🥈 2  ajhinh   1153206.609   +1128177.125   t4.py

A100

Rank  User     Time (μs)     Gap (μs)       File
🥇 1  Nader    3627.491                     submission.py
🥈 2  mancala  15801.310     +12173.819     submission.py
🥉 3  ajhinh   202852.265    +187050.955    a100.py

H100

Rank  User     Time (μs)     Gap (μs)       File
🥇 1  Nader    2275.541                     submission.py
🥈 2  sajy     6582.328      +4306.786      submission.py
🥉 3  mancala  7159.787      +577.459       submission.py
   4  ajhinh   139032.024    +131872.237    h100.py
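The gap values can be reproduced from the times: each is the entry's time minus the time of the entry ranked directly above it. A small sketch using the H100 numbers follows; the gaps helper is hypothetical. Recomputed from the displayed times, the sajy gap comes out as 4306.787 rather than the listed 4306.786, presumably because the page subtracts unrounded times.

```python
# H100 times in microseconds, copied from the ranking above (fastest first).
h100 = [
    ("Nader", 2275.541),
    ("sajy", 6582.328),
    ("mancala", 7159.787),
    ("ajhinh", 139032.024),
]


def gaps(entries):
    """Return each entry's gap to the entry ranked directly above it.

    First place has no entry above it, so its gap is None.
    """
    result = [None]
    for (_, faster), (_, slower) in zip(entries, entries[1:]):
        result.append(slower - faster)
    return result


for (name, time_us), gap in zip(h100, gaps(h100)):
    label = "" if gap is None else f"+{gap:.3f}"
    print(f"{name:8} {time_us:>12.3f} {label}")
```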