amd-mixture-of-experts

Deadline

8 days 5 hours remaining (closes 2025-05-27 00:00 UTC)

Language

Python

GPU Type

MI300

Description

For a more complete description, see: https://tinyurl.com/amd-comp-moe

Implement a DeepSeek-style Mixture of Experts (MoE) layer for efficient transformer models on a single MI300X device. MoE is a technique that scales model capacity without a proportional increase in computational cost: a routing mechanism selectively activates only a subset of parameters for each token.

Your task (a short PyTorch sketch of the routing scheme follows this description):

- Implement token routing using a simple softmax-based learned router
- Route tokens to the top-k experts based on router probabilities
- Process tokens through their assigned experts
- Combine expert outputs weighted by router probabilities
- Calculate appropriate auxiliary losses for training stability

Input:

- `data`: Tuple of (input: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
  - input: Input tensor of shape [bs, seq_len, d_hidden]
  - weights: Dictionary containing model weights
  - config: Dictionary containing model configuration parameters

Output:

- Tuple containing:
  - output: Processed tensor [bs, seq_len, d_hidden]
  - aux_data: Dictionary with auxiliary data such as router probabilities and losses
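As a starting point, the routing and combination steps can be written directly in PyTorch. The sketch below is illustrative only: the function names (route_and_combine, load_balancing_loss) and the Switch-Transformer-style balance loss are our own choices rather than part of the task, and the dense per-expert loop is deliberately naive. It assumes the shapes listed under Input.

import torch

def route_and_combine(x, router_weight, experts, top_k):
    # x: [bs, seq_len, d_hidden]; router_weight: [n_routed_experts, d_hidden]
    logits = x @ router_weight.t()                # [bs, seq_len, n_routed_experts]
    probs = logits.softmax(dim=-1)                # router probabilities
    topk_probs, topk_idx = torch.topk(probs, k=top_k, dim=-1)

    # Per-token combine weight for each expert: zero unless the expert is in
    # the token's top-k, in which case it is the router probability.
    combine = torch.zeros_like(probs).scatter_(-1, topk_idx, topk_probs)

    # Naive dense dispatch: run every expert on every token, then mask.
    out = torch.zeros_like(x)
    for e, expert in enumerate(experts):          # experts: list of callables
        out = out + combine[..., e:e + 1] * expert(x)
    return out, {"router_probs": probs}

def load_balancing_loss(probs, topk_idx, n_experts):
    # One common auxiliary loss (Switch-Transformer-style load balancing; the
    # task statement leaves the exact form open): pushes the fraction of
    # assignments each expert receives toward its mean router probability.
    one_hot = torch.nn.functional.one_hot(topk_idx, n_experts)   # [bs, seq, k, E]
    frac_assigned = one_hot.float().mean(dim=(0, 1, 2))          # f_i
    mean_probs = probs.float().mean(dim=(0, 1))                  # P_i
    return n_experts * (frac_assigned * mean_probs).sum()

A competitive submission replaces the dense loop with a gather/scatter dispatch such as moe_infer in the reference below, or with a fused kernel.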

Reference Implementation

from utils import make_match_reference
from task import input_t, output_t
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple, List, Optional
import math

# Reference code in PyTorch
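# Expert: a SwiGLU-style feed-forward block, W_down(SiLU(W_gate(x)) * W_up(x)).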
class Expert(nn.Module):
    def __init__(self, config: Dict, d_expert: Optional[int] = None):
        super().__init__()
        self.config = config
        self.act_fn = nn.SiLU()
        self.d_hidden: int = config["d_hidden"]
        self.d_expert: int = config["d_expert"] if d_expert is None else d_expert

        self.W_gate = nn.Linear(self.d_hidden, self.d_expert, bias=False)
        self.W_up = nn.Linear(self.d_hidden, self.d_expert, bias=False)
        self.W_down = nn.Linear(self.d_expert, self.d_hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = self.act_fn(self.W_gate(x))
        out = self.W_down(gate * self.W_up(x))
        return out


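# MoEGate: learned softmax router; scores every expert and keeps the top-k
# per token.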
class MoEGate(nn.Module):
    def __init__(self, config: Dict):
        super().__init__()
        self.top_k: int = config["n_experts_per_token"]
        self.num_experts: int = config["n_routed_experts"]
        self.d_hidden: int = config["d_hidden"]

        self.W_g = nn.Linear(self.d_hidden, self.num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        logits = self.W_g(x)
        scores = logits.softmax(dim=-1)
        topk_scores, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
        # Note: the top-k scores come straight from the full softmax and are
        # not renormalized to sum to one over the selected experts.

        return topk_indices, topk_scores


class MoE(nn.Module):
    def __init__(self, config: Dict):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            Expert(config)
            for _ in range(config["n_routed_experts"])
        ])
        self.gating_network = MoEGate(config)
        # The shared expert runs on every token (DeepSeek-style) and is sized
        # as n_shared_experts concatenated experts.
        shared_expert_dim = config["d_expert"] * config["n_shared_experts"]
        self.shared_expert = Expert(config=config, d_expert=shared_expert_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shared_output = self.shared_expert(x)
        expert_indices, expert_scores = self.gating_network(x)
        batch_size, seq_len, hidden_dim = x.shape
        orig_shape = x.shape
        x_flat = x.view(-1, hidden_dim)
        flat_expert_indices = expert_indices.view(-1)
        flat_expert_weights = expert_scores.view(-1, 1)
        routed_output_flat = self.moe_infer(x_flat,
                                            flat_expert_indices,
                                            flat_expert_weights)

        routed_output = routed_output_flat.view(*orig_shape)
        return routed_output + shared_output

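    # Inference-time dispatch: sort the flattened (token, expert) assignments
    # by expert id so each expert's tokens form one contiguous slice, run each
    # expert once on its slice, then scatter-add the weighted outputs back.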
    @torch.no_grad()
    def moe_infer(self,
                  x: torch.Tensor,
                  flat_expert_indices: torch.Tensor,
                  flat_expert_weights: torch.Tensor
                 ) -> torch.Tensor:
        expert_cache = torch.zeros_like(x)
        # Sort the flattened assignment slots by expert id so that each
        # expert's tokens occupy one contiguous range.
        idxs = flat_expert_indices.argsort()
        counts = flat_expert_indices.bincount().cpu().numpy()
        tokens_per_expert = counts.cumsum()
        num_per_tok = self.config["n_experts_per_token"]
        # Each token owns num_per_tok consecutive slots in the flattened
        # assignment tensor, so integer division maps a slot back to its token.
        token_idxs = idxs // num_per_tok
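        # Worked example (illustrative): with n_experts_per_token = 2 and
        # flat_expert_indices = [3, 1, 0, 3], argsort gives idxs = [2, 1, 0, 3]
        # and token_idxs = idxs // 2 = [1, 0, 0, 1].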
        for expert_id, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1]
            if start_idx == end_idx:
                continue

            expert = self.experts[expert_id]
            exp_token_idxs = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idxs]
            expert_out = expert(expert_tokens)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            expert_cache.scatter_reduce_(
                0,
                exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]),
                expert_out,
                reduce='sum'
            )

        return expert_cache


def ref_kernel(data: input_t) -> output_t:
    """
    Reference implementation of DeepSeek-style Mixture of Experts using PyTorch.
    
    Args:
        data: Tuple of (input: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
            - input: Input tensor of shape [batch_size, seq_len, hidden_dim]
            - weights: Dictionary containing model weights
            - config: Dictionary containing model configuration parameters
            
    Returns:
        output: Processed tensor [batch_size, seq_len, d_hidden].
            Note: this reference returns only the output tensor; it does not
            build the aux_data dictionary described in the task statement.
    """
    input_tensor, weights, config = data
    num_experts = config["n_routed_experts"]
    moe = MoE(config)

    # Fill in the given weights of the model
    moe.gating_network.W_g.weight = nn.Parameter(weights['router.weight'])

    for i in range(num_experts):
        gate_proj_weight = weights[f'experts.{i}.0.weight']
        up_proj_weight = weights[f'experts.{i}.1.weight']
        down_proj_weight = weights[f'experts.{i}.2.weight']

        # Transpose weights to match expected shape for nn.Linear
        moe.experts[i].W_gate.weight = nn.Parameter(gate_proj_weight.t())
        moe.experts[i].W_up.weight = nn.Parameter(up_proj_weight.t())
        moe.experts[i].W_down.weight = nn.Parameter(down_proj_weight.t())

    moe.shared_expert.W_gate.weight = nn.Parameter(weights['shared_experts.0.weight'].t())
    moe.shared_expert.W_up.weight = nn.Parameter(weights['shared_experts.1.weight'].t())
    moe.shared_expert.W_down.weight = nn.Parameter(weights['shared_experts.2.weight'].t())

    output = moe(input_tensor)

    return output


# Input generation for the reference code

def generate_input(
    dhidden: int,
    dexpert: int,
    nroutedexperts: int,
    nsharedexperts: int,
    nexpertspertoken: int,
    bs: int,
    seqlen: int,
    seed: int
) -> input_t:

    # Workaround: parameter names containing underscores are not parsed
    # correctly by the harness yet, so the benchmark passes them without
    # underscores and we restore the canonical names here.
    d_hidden = dhidden
    d_expert = dexpert
    n_routed_experts = nroutedexperts
    n_shared_experts = nsharedexperts
    n_experts_per_token = nexpertspertoken
    batch_size = bs
    seq_len = seqlen

    config = {
        "d_hidden": d_hidden,
        "d_expert": d_expert,
        "n_routed_experts": n_routed_experts,
        "n_shared_experts": n_shared_experts,
        "n_experts_per_token": n_experts_per_token,
        "batch_size": batch_size,
        "seq_len": seq_len,
    }

    gen = torch.Generator(device='cuda')
    gen.manual_seed(seed)

    num_experts = n_routed_experts
    expert_dim = d_expert
    weights = {}
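    # Weight naming convention: experts.{i}.{0,1,2} hold the gate/up/down
    # projections, stored as (in_features, out_features), i.e. transposed
    # relative to nn.Linear's (out_features, in_features) layout.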

    input_tensor = torch.randn(
        (batch_size, seq_len, d_hidden),
        device='cuda',
        dtype=torch.float16,
        generator=gen
    ).contiguous()

    # Initialize router weights
    weights['router.weight'] = torch.randn(
        (num_experts, d_hidden),
        device="cuda",
        dtype=torch.float16,
        generator=gen
    ) / math.sqrt(d_hidden)

    for i in range(num_experts):
        weights[f'experts.{i}.0.weight'] = torch.randn(
            (d_hidden, expert_dim),
            device='cuda',
            dtype=torch.float16,
            generator=gen
        ) / math.sqrt(expert_dim)

        weights[f'experts.{i}.1.weight'] = torch.randn(
            (d_hidden, expert_dim),
            device='cuda',
            dtype=torch.float16,
            generator=gen
        ) / math.sqrt(expert_dim)

        weights[f'experts.{i}.2.weight'] = torch.randn(
            (expert_dim, d_hidden),
            device='cuda',
            dtype=torch.float16,
            generator=gen
        ) / math.sqrt(d_hidden)
    
    weights['shared_experts.0.weight'] = torch.randn(
        (d_hidden, expert_dim * n_shared_experts),
        device='cuda',
        dtype=torch.float16,
        generator=gen
    ) / math.sqrt(expert_dim * n_shared_experts)
    weights['shared_experts.1.weight'] = torch.randn(
        (d_hidden, expert_dim * n_shared_experts),
        device='cuda',
        dtype=torch.float16,
        generator=gen
    ) / math.sqrt(expert_dim * n_shared_experts)
    weights['shared_experts.2.weight'] = torch.randn(
        (expert_dim * n_shared_experts, d_hidden),
        device='cuda',
        dtype=torch.float16,
        generator=gen
    ) / math.sqrt(d_hidden)

    return (input_tensor, weights, config)


check_implementation = make_match_reference(ref_kernel, rtol=1e-2, atol=1e-2)
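For orientation, a minimal local run of the pieces above might look like this (a sketch; the sizes are illustrative, and the actual harness selects the benchmark shapes and runs check_implementation against each submission):

# Hypothetical smoke test; all sizes are illustrative, not the benchmark's.
# Requires a CUDA device, since generate_input allocates on 'cuda'.
if __name__ == "__main__":
    data = generate_input(
        dhidden=1024, dexpert=512, nroutedexperts=8, nsharedexperts=1,
        nexpertspertoken=2, bs=2, seqlen=128, seed=0,
    )
    out = ref_kernel(data)
    print(out.shape)  # expected: torch.Size([2, 128, 1024])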

Rankings

MI300

Seb 🥇 9640.920μs baguette_moe.py
hatoo 🥈 10178.474μs   +537.554μs submission.py
myy1966 🥉 11127.494μs   +949.020μs submission.py
kfz 14747.525μs   +3620.031μs moe_v3_clean.py
Xavier Init 20047.587μs   +5300.062μs shakira.py
intrinsicmode 22259.831μs   +2212.244μs triton_v03.py
fanwenjie 25062.949μs   +2803.118μs moe.py
pengcuo 25562.669μs   +499.720μs submission.py
GnSight 30246.845μs   +4684.176μs submission.py
ALI 36829.672μs   +6582.827μs moe_t.py
rt11 314470.866μs   +277641.194μs v1.py
Arseni Ivanov 602567.079μs   +288096.213μs submission.py
nicolaswilde 1391905.579μs   +789338.500μs amd-fp8-moe-myself.py
Shinsato Masumi 4285228.471μs   +2893322.893μs submission.py
wangxun1010 6159661.076μs   +1874432.605μs submission.py
bobmarleybiceps 6178832.276μs   +19171.199μs sample.py
gowtham_tupili 6179881.328μs   +1049.052μs amd-mixture-of-experts.py
Tecahens 6209036.556μs   +29155.228μs template-moe.py
LuiZzz 6225617.760μs   +16581.203μs moe_ref.py
az 6228461.073μs   +2843.314μs submission.py
sahanp 6237320.428μs   +8859.355μs submission.py
summergift0941 6243765.991μs   +6445.563μs submission.py
osborn0016 6246501.007μs   +2735.016μs my_moe_kernel.py
Qwesh157 6279979.738μs   +33478.731μs submission.py
fxfxfxfxfxfxfxfx 6327254.849μs   +47275.110μs submission.py
DizzleRama 7274846.078μs   +947591.230μs amd-mixture-of-experts.py
whatdhack_ 7306845.777μs   +31999.698μs submission_ref_moe.py
Snektron 7378804.730μs   +71958.953μs submission.py
siro 7381670.727μs   +2865.997μs triton_v01.py
DUMBPANDABEAR 7406199.872μs   +24529.145μs submission.py
Shravan 7426457.767μs   +20257.895μs pytorch_moe_06_05_2025.py
Zixian Wang 7439205.211μs   +12747.445μs amd-mixture-of-experts.py
wildman 7456970.145μs   +17764.934μs submission.py
Austin Liu 7476397.058μs   +19426.913μs submission.py
luojiehao. 7487957.032μs   +11559.974μs submission.py
ryshaw 7489773.399μs   +1816.366μs moe.py
Ding 7500260.395μs   +10486.996μs submission.py
gxtzhuxi 7525681.155μs   +25420.760μs submission.py
legendary_fawn_56575 7564848.987μs   +39167.832μs submission.py
_hui_xu 7981381.535μs   +416532.548μs amd-mixture-of-experts.py