trimul

Deadline

90 days 19 hours remaining (2025-09-30 00:00 UTC)

Language

Python

GPU Types

A100, B200, H100, MI300

Description

For a more complete description, see: https://tinyurl.com/gpumode-trimul

You will be implementing a Triangle Multiplicative Update (TriMul) module, a core operation in AlphaFold3, Chai, Protenix, and other protein structure prediction models in BioML. The TriMul operator acts on a 4D tensor of shape [B, N, N, C].

Your task:
- Implement the "outgoing" version of the TriMul operator from the AlphaFold3 paper.
- You do not need to compute or store gradients; only the forward pass is required.

Input:
- `data`: Tuple of (input: torch.Tensor, mask: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
  - input: Input tensor of shape [bs, seq_len, seq_len, dim]
  - mask: Mask tensor of shape [bs, seq_len, seq_len]
  - weights: Dictionary containing model weights
  - config: Dictionary containing model configuration parameters

Output:
- output: Processed tensor of shape [bs, seq_len, seq_len, dim]
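
A minimal sketch of the expected entry point, assuming the usual GPU MODE convention of a `custom_kernel(data: input_t) -> output_t` function (the function name and the `my_trimul_forward` helper are placeholders/assumptions here, not part of the task spec; a functional sketch of such a helper appears alongside the reference implementation below):

from task import input_t, output_t

def custom_kernel(data: input_t) -> output_t:
    # Unpack the tuple described above.
    input_tensor, mask, weights, config = data
    # Hypothetical helper holding your TriMul forward pass.
    return my_trimul_forward(input_tensor, mask, weights, config)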

Reference Implementation

from utils import make_match_reference, DisableCuDNNTF32
from task import input_t, output_t

import torch
from torch import nn, einsum
import math

# Reference code in PyTorch
class TriMul(nn.Module):
    # Based on https://github.com/lucidrains/triangle-multiplicative-module/blob/main/triangle_multiplicative_module/triangle_multiplicative_module.py
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ):
        super().__init__()

        self.norm = nn.LayerNorm(dim)

        self.left_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.right_proj = nn.Linear(dim, hidden_dim, bias=False)

        self.left_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.right_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.out_gate = nn.Linear(dim, hidden_dim, bias=False)

        self.to_out_norm = nn.LayerNorm(hidden_dim)
        self.to_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        x: [bs, seq_len, seq_len, dim]
        mask: [bs, seq_len, seq_len]

        Returns:
            output: [bs, seq_len, seq_len, dim]
        """
        batch_size, seq_len, _, dim = x.shape

        x = self.norm(x)

        left = self.left_proj(x)
        right = self.right_proj(x)

        mask = mask.unsqueeze(-1)
        left = left * mask
        right = right * mask

        left_gate = self.left_gate(x).sigmoid()
        right_gate = self.right_gate(x).sigmoid()
        out_gate = self.out_gate(x).sigmoid()

        left = left * left_gate
        right = right * right_gate

        out = einsum('... i k d, ... j k d -> ... i j d', left, right)
        # This einsum is the same as the following:
        # out = torch.zeros(batch_size, seq_len, seq_len, dim, device=x.device)
        
        # # Compute using nested loops
        # for b in range(batch_size):
        #     for i in range(seq_len):
        #         for j in range(seq_len):
        #             # Compute each output element
        #             for k in range(seq_len):
        #                 out[b, i, j] += left[b, i, k, :] * right[b, j, k, :]
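        # An equivalent batched-matmul view of the same contraction (a possible
        # starting point for an optimized kernel; not used by the reference path):
        # out = torch.matmul(
        #     left.permute(0, 3, 1, 2),                     # [b, d, i, k]
        #     right.permute(0, 3, 1, 2).transpose(-1, -2),  # [b, d, k, j]
        # ).permute(0, 2, 3, 1)                             # -> [b, i, j, d]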

        out = self.to_out_norm(out)
        out = out * out_gate
        return self.to_out(out)


def ref_kernel(data: input_t) -> output_t:
    """
    Reference implementation of TriMul using PyTorch.
    
    Args:
        data: Tuple of (input: torch.Tensor, mask: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
            - input: Input tensor of shape [batch_size, seq_len, seq_len, dim]
            - mask: Mask tensor of shape [batch_size, seq_len, seq_len]
            - weights: Dictionary containing model weights
            - config: Dictionary containing model configuration parameters

    Returns:
        output: Output tensor of shape [batch_size, seq_len, seq_len, dim]
    """

    # Use deterministic kernels and disable TF32 for accuracy
    with DisableCuDNNTF32():
        input_tensor, mask, weights, config = data
        trimul = TriMul(dim=config["dim"], hidden_dim=config["hidden_dim"]).to(input_tensor.device)

        # Fill in the given weights of the model
        trimul.norm.weight = nn.Parameter(weights['norm.weight'])
        trimul.norm.bias = nn.Parameter(weights['norm.bias'])
        trimul.left_proj.weight = nn.Parameter(weights['left_proj.weight'])
        trimul.right_proj.weight = nn.Parameter(weights['right_proj.weight'])
        trimul.left_gate.weight = nn.Parameter(weights['left_gate.weight'])
        trimul.right_gate.weight = nn.Parameter(weights['right_gate.weight'])
        trimul.out_gate.weight = nn.Parameter(weights['out_gate.weight'])
        trimul.to_out_norm.weight = nn.Parameter(weights['to_out_norm.weight'])
        trimul.to_out_norm.bias = nn.Parameter(weights['to_out_norm.bias'])
        trimul.to_out.weight = nn.Parameter(weights['to_out.weight'])

        output = trimul(input_tensor, mask)

        return output
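

# For contrast with the module-based reference above, the hypothetical
# `my_trimul_forward` helper sketched in the Description can be written
# functionally against the weights dictionary, without instantiating TriMul.
# This is an unoptimized fp32 sketch, not a tuned kernel.
import torch.nn.functional as F

def my_trimul_forward(x, mask, w, config):
    x = F.layer_norm(x, (config["dim"],), w["norm.weight"], w["norm.bias"])
    m = mask.unsqueeze(-1)
    left = (x @ w["left_proj.weight"].T) * m * torch.sigmoid(x @ w["left_gate.weight"].T)
    right = (x @ w["right_proj.weight"].T) * m * torch.sigmoid(x @ w["right_gate.weight"].T)
    out_gate = torch.sigmoid(x @ w["out_gate.weight"].T)
    # out[b, i, j, d] = sum_k left[b, i, k, d] * right[b, j, k, d]
    out = torch.einsum('bikd,bjkd->bijd', left, right)
    out = F.layer_norm(out, (config["hidden_dim"],), w["to_out_norm.weight"], w["to_out_norm.bias"])
    return (out * out_gate) @ w["to_out.weight"].T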


# Input generation for the reference code
def generate_input(
    seqlen: int,
    bs: int,
    dim: int,
    hiddendim: int,
    seed: int,
    nomask: bool,
    distribution: str,
) -> input_t:
    """
    Generate a random (input, mask, weights, config) tuple for benchmarking.

    `distribution` selects the input distribution ("normal" or "cauchy");
    `nomask` replaces the random 0/1 mask with an all-ones mask.
    """

    # The benchmark's argument parser does not handle underscores in parameter
    # names, so they are passed without underscores and renamed here.
    batch_size = bs
    seq_len = seqlen
    hidden_dim = hiddendim
    no_mask = nomask

    config = {
        "hidden_dim": hidden_dim,
        "dim": dim,
    }

    gen = torch.Generator(device='cuda')
    gen.manual_seed(seed)

    weights = {}

    # Generate input tensor based on distribution
    if distribution == "cauchy":
        # Heavier-tailed distribution. Note: Cauchy.sample() does not accept a
        # generator, so this branch is not seeded by `gen`.
        input_tensor = torch.distributions.Cauchy(0, 2).sample(
            (batch_size, seq_len, seq_len, dim)
        ).to(device='cuda', dtype=torch.float32)
    else:  # normal distribution
        input_tensor = torch.randn(
            (batch_size, seq_len, seq_len, dim),
            device='cuda',
            dtype=torch.float32,
            generator=gen
        ).contiguous()

    if no_mask:
        mask = torch.ones(batch_size, seq_len, seq_len, device=input_tensor.device)
    else:
        mask = torch.randint(0, 2, (batch_size, seq_len, seq_len), device=input_tensor.device, generator=gen)

    # Initialize model weights: LayerNorm parameters are standard normal,
    # and linear weights are scaled by 1/sqrt(out_features)
    weights["norm.weight"] = torch.randn(dim, device="cuda", dtype=torch.float32)
    weights["norm.bias"] = torch.randn(dim, device="cuda", dtype=torch.float32)
    weights["left_proj.weight"] = torch.randn(hidden_dim, dim, device="cuda", dtype=torch.float32) / math.sqrt(hidden_dim)
    weights["right_proj.weight"] = torch.randn(hidden_dim, dim, device="cuda", dtype=torch.float32) / math.sqrt(hidden_dim)
    weights["left_gate.weight"] = torch.randn(hidden_dim, dim, device="cuda", dtype=torch.float32) / math.sqrt(hidden_dim)
    weights["right_gate.weight"] = torch.randn(hidden_dim, dim, device="cuda", dtype=torch.float32) / math.sqrt(hidden_dim)
    weights["out_gate.weight"] = torch.randn(hidden_dim, dim, device="cuda", dtype=torch.float32) / math.sqrt(hidden_dim)
    weights["to_out_norm.weight"] = torch.randn(hidden_dim, device="cuda", dtype=torch.float32)
    weights["to_out.weight"] = torch.randn(dim, hidden_dim, device="cuda", dtype=torch.float32) / math.sqrt(dim)
    weights["to_out_norm.bias"] = torch.randn(hidden_dim, device="cuda", dtype=torch.float32)

    return (input_tensor, mask, weights, config)


check_implementation = make_match_reference(ref_kernel, rtol=2e-2, atol=2e-2)
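
A quick local sanity check against the reference (the shapes below are arbitrary, and the functional `my_trimul_forward` sketch above stands in for a real submission):

if __name__ == "__main__":
    data = generate_input(seqlen=128, bs=1, dim=128, hiddendim=128,
                          seed=0, nomask=False, distribution="normal")
    ref_out = ref_kernel(data)
    test_out = my_trimul_forward(*data)
    print(torch.allclose(ref_out, test_out, rtol=2e-2, atol=2e-2))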

Rankings

The "+…μs" value shown for an entry is its gap to the next-faster submission above it.

A100

bobmarleybiceps 🥇 13006.406μs functional_submission.py
Dante 🥈 16399.012μs   +3392.606μs submission.py
Haw 🥉 19027.829μs   +2628.817μs submission.py
az 20011.399μs   +983.570μs submission.py
gau.nernst 23150.423μs   +3139.024μs v0.py

B200

Dante 🥇 6712.503μs submission.py
siro 🥈 7924.450μs   +1211.947μs trimul.py
az 🥉 8204.245μs   +279.795μs submission.py

H100

Dante 🥇 9206.993μs submission.py

MI300

Dante 🥇 7827.614μs submission.py
Haw 🥈 10982.959μs   +3155.345μs baseline.py