by Adam

Writing High-Performance AI Kernels in Mojo: 10x Faster Than PyTorch

How I built custom Mojo kernels for AI research that outperform PyTorch on consumer hardware

mojo performance simd gpu pytorch optimization

Why Custom Kernels?

PyTorch is fantastic for prototyping, but when you’re running the same operation billions of times, every microsecond counts. For every token, my SipIt algorithm computes the L2 distance from a target vector to each of 32,000 vocabulary candidates, each with 4,096 dimensions.

With 100+ tokens to recover, that’s over 3.2 million distance computations, or roughly 13 billion squared differences. PyTorch’s generic implementations leave performance on the table.

The Bottleneck: L2 Distance

The core operation in SipIt token recovery:

# PyTorch implementation
import torch

def l2_distances(target, candidates):
    """
    target: [d_model]
    candidates: [vocab_size, d_model]
    returns: [vocab_size]
    """
    # Broadcasting materializes a full [vocab_size, d_model] temporary
    diff = candidates - target.unsqueeze(0)
    return torch.norm(diff, dim=-1)
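Before writing any kernels, it helps to keep a slow but obviously correct reference to validate them against. The following is a minimal pure-Python sketch that mirrors the PyTorch function above; the names `l2_distances_ref` and `max_abs_error` are illustrative, not part of the SipIt code.

```python
import math

def l2_distances_ref(target, candidates):
    """Reference L2 distances: one distance per candidate row.

    target: list of d_model floats
    candidates: list of rows, each a list of d_model floats
    """
    d_model = len(target)
    out = []
    for row in candidates:
        if len(row) != d_model:
            raise ValueError("candidate row length must equal d_model")
        # Sum of squared differences, then square root
        out.append(math.sqrt(sum((c - t) ** 2 for c, t in zip(row, target))))
    return out

def max_abs_error(expected, actual):
    """Largest element-wise deviation between two distance lists."""
    return max(abs(e - a) for e, a in zip(expected, actual))

# Usage: compare a kernel's output against the reference
target = [1.0, 2.0, 3.0]
candidates = [[1.0, 2.0, 3.0], [4.0, 6.0, 3.0], [1.0, 2.0, 5.0]]
print(l2_distances_ref(target, candidates))  # [0.0, 5.0, 2.0]
```

When checking float32 kernels against it, compare with a tolerance rather than exact equality, since a different summation order changes rounding.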

For Mistral-7B, that is 32,000 × 4,096 ≈ 131 million squared differences per position, each needing a subtract, a multiply, and an add.
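The same numbers worked out explicitly, plus the bytes each position has to stream. This sketch assumes float32 candidates, 3 FLOPs per element (subtract, multiply, add, ignoring the one `sqrt` per row), and that the candidate matrix is re-read for every position:

```python
# Workload size for one SipIt position, using the numbers from the text
vocab_size = 32_000
d_model = 4_096
bytes_per_float32 = 4
tokens = 100

elements_per_position = vocab_size * d_model        # squared differences
flops_per_position = 3 * elements_per_position       # subtract, multiply, add
candidate_bytes = elements_per_position * bytes_per_float32
distances_total = vocab_size * tokens                # one distance per candidate per token

print(f"{elements_per_position:,}")       # 131,072,000
print(f"{flops_per_position:,}")          # 393,216,000
print(f"{candidate_bytes / 1e6:.0f} MB")  # 524 MB
print(f"{distances_total:,}")             # 3,200,000
```

At about 3 FLOPs per 4 bytes of candidate data (0.75 FLOP/byte), the operation is memory-bound: its speed is set mostly by how fast the 524 MB candidate matrix can be streamed, not by arithmetic throughput.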

PyTorch time: ~25ms per batch (all 32,000 candidates) on an RTX 3090

Mojo Kernel: CPU Implementation

from memory import UnsafePointer
from algorithm import parallelize
from math import sqrt

fn compute_l2_distances_parallel(
    target: UnsafePointer[Float32],
    candidates: UnsafePointer[Float32],
    distances: UnsafePointer[Float32],
    batch_size: Int,
    d_model: Int
):
    """
    Compute L2 distances from target to all candidates.
    Uses all CPU cores via parallelization.
    """
    @parameter
    fn compute_single_distance(batch_idx: Int):
        var sum: Float32 = 0.0
        var offset = batch_idx * d_model

        # Sequential access pattern for cache efficiency
        for d in range(d_model):
            var diff = target[d] - candidates[offset + d]
            sum += diff * diff

        # Each work item writes only its own output slot, so no locking is needed
        distances[batch_idx] = sqrt(sum)

    # Parallelize across all CPU cores
    parallelize[compute_single_distance](batch_size)
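Mojo's `parallelize[func](n)` runs `func(i)` for every `i` in `range(n)` as parallel sub-tasks. The kernel above is safe to split this way because each work item only reads shared inputs and writes its own `distances[batch_idx]` slot. A Python sketch of the same decomposition follows; the helper `parallel_for` is hypothetical, and because of CPython's GIL it illustrates the structure rather than the speedup.

```python
import math
from concurrent.futures import ThreadPoolExecutor

def parallel_for(func, num_work_items, max_workers=4):
    """Call func(i) for i in range(num_work_items) on a thread pool.

    Hypothetical analog of Mojo's parallelize: work items must not
    write to shared state other than their own output slot.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # list() waits for completion and re-raises any worker exception
        list(pool.map(func, range(num_work_items)))

def compute_l2_distances_parallel(target, candidates_flat, batch_size, d_model):
    distances = [0.0] * batch_size

    def compute_single_distance(batch_idx):
        offset = batch_idx * d_model
        total = 0.0
        # Walk one row-major candidate row sequentially
        for d in range(d_model):
            diff = target[d] - candidates_flat[offset + d]
            total += diff * diff
        distances[batch_idx] = math.sqrt(total)  # each item owns one slot

    parallel_for(compute_single_distance, batch_size)
    return distances

print(compute_l2_distances_parallel([0.0, 0.0], [3.0, 4.0, 6.0, 8.0], 2, 2))  # [5.0, 10.0]
```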

Mojo CPU time: ~8ms per batch (about 3x faster than PyTorch on the GPU)

Mojo Kernel: GPU Implementation

from gpu.host import DeviceContext  # host-side launch API (not used inside the kernel)
from gpu import thread_idx, block_idx, block_dim, barrier
from gpu.memory import AddressSpace
from memory import UnsafePointer, stack_allocation
from math import sqrt

alias BLOCK_SIZE = 256
alias TILE_SIZE = 256  # must not exceed BLOCK_SIZE: one loading thread per tile element

fn l2_distance_gpu_kernel(
    target: UnsafePointer[Float32],
    candidates: UnsafePointer[Float32],
    distances: UnsafePointer[Float32],
    batch_size: Int,
    d_model: Int
):
    """GPU kernel with shared memory optimization (one thread per candidate)."""

    # Thread identification
    var tid = Int(thread_idx.x)
    var global_idx = Int(block_idx.x) * Int(block_dim.x) + tid

    # No early return: every thread must reach each barrier() and
    # help load the shared tile, even past the end of the batch
    var in_range = global_idx < batch_size

    # Shared memory for one tile of the target vector
    var shared_target = stack_allocation[
        TILE_SIZE, Float32, address_space = AddressSpace.SHARED
    ]()

    var sum: Float32 = 0.0
    var candidate_offset = global_idx * d_model

    # Process in tiles so each target element is read from global memory once per block
    for tile_start in range(0, d_model, TILE_SIZE):
        # Cooperative loading of target tile
        if tid < TILE_SIZE and tile_start + tid < d_model:
            shared_target[tid] = target[tile_start + tid]

        barrier()

        # Compute partial distances using shared memory.
        # Note: adjacent threads read rows d_model floats apart (uncoalesced).
        if in_range:
            var tile_end = min(TILE_SIZE, d_model - tile_start)
            for i in range(tile_end):
                var diff = shared_target[i] - candidates[candidate_offset + tile_start + i]
                sum += diff * diff

        # Wait before the next iteration overwrites the tile
        barrier()

    if in_range:
        distances[global_idx] = sqrt(sum)

Mojo GPU time: ~2ms per batch (about 12x faster than PyTorch on the same GPU)
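The tiling logic can be checked on the CPU before touching a GPU. Below is a plain-Python simulation of one launch of the tiled kernel (the function `simulate_tiled_l2` and its small default sizes are hypothetical): each block copies a tile of `target` into a buffer standing in for shared memory, then every in-range thread accumulates over that tile. Running a block's threads in two lock-step loops plays the role of the barriers; it also shows why out-of-range threads must stay alive through the tile loop, since they still help load the tile.

```python
import math

def simulate_tiled_l2(target, candidates_flat, batch_size, d_model,
                      block_size=4, tile_size=4):
    """CPU simulation of the tiled GPU kernel's indexing (not a performance model)."""
    assert block_size >= tile_size, "each tile element needs a loading thread"
    distances = [0.0] * batch_size
    num_blocks = (batch_size + block_size - 1) // block_size  # ceil division

    for bid in range(num_blocks):
        sums = [0.0] * block_size        # one accumulator per thread
        shared = [0.0] * tile_size       # stands in for shared memory
        for tile_start in range(0, d_model, tile_size):
            # Phase 1: every thread, in range or not, helps load the tile
            for tid in range(block_size):
                if tid < tile_size and tile_start + tid < d_model:
                    shared[tid] = target[tile_start + tid]
            # (barrier) Phase 2: in-range threads consume the tile
            tile_end = min(tile_size, d_model - tile_start)
            for tid in range(block_size):
                global_idx = bid * block_size + tid
                if global_idx < batch_size:
                    offset = global_idx * d_model + tile_start
                    for i in range(tile_end):
                        diff = shared[i] - candidates_flat[offset + i]
                        sums[tid] += diff * diff
            # (barrier) before the next tile overwrites `shared`
        for tid in range(block_size):
            global_idx = bid * block_size + tid
            if global_idx < batch_size:
                distances[global_idx] = math.sqrt(sums[tid])
    return distances
```

Calling it with `batch_size=5`, `d_model=6`, `block_size=4`, `tile_size=4` exercises both edge cases at once: a partially filled last block and a partial last tile.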

SIMD Vectorization

For even more CPU performance, use explicit SIMD:

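from sys.info import simdwidthof
from algorithm import vectorize
from memory import UnsafePointer
from math import sqrt

fn compute_l2_simd(
    target: UnsafePointer[Float32],
    candidate: UnsafePointer[Float32],
    d_model: Int
) -> Float32:
    """SIMD-vectorized L2 distance for a single candidate."""

    alias simd_width = simdwidthof[Float32]()
    # Per-lane partial sums, reduced once at the end
    var sum = SIMD[DType.float32, simd_width](0)
    # vectorize calls process_chunk with a smaller width for the tail when
    # d_model is not a multiple of simd_width; those elements go here
    var tail_sum: Float32 = 0.0

    @parameter
    fn process_chunk[width: Int](offset: Int):
        var t = target.load[width=width](offset)
        var c = candidate.load[width=width](offset)
        var diff = t - c
        @parameter
        if width == simd_width:
            sum += rebind[SIMD[DType.float32, simd_width]](diff * diff)
        else:
            tail_sum += (diff * diff).reduce_add()

    vectorize[process_chunk, simd_width](d_model)

    # Horizontal sum across lanes, plus the scalar tail
    return sqrt(sum.reduce_add() + tail_sum)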

SIMD speedup: Additional 2-4x on supported CPUs.
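Conceptually, a vectorized reduction keeps one partial sum per SIMD lane, processes `d_model` in full-width chunks, handles leftover elements separately, and does a single horizontal add at the end. A plain-Python sketch of that schedule follows; `simd_l2` and `width=8` are illustrative, and Mojo's `vectorize` makes its own choices about tail handling.

```python
import math

def simd_l2(target, candidate, width=8):
    """Emulate a width-lane SIMD L2 distance: lane accumulators plus a scalar tail."""
    n = len(target)
    lanes = [0.0] * width                  # one partial sum per SIMD lane
    main = n - n % width                   # elements covered by full chunks
    for offset in range(0, main, width):
        for lane in range(width):          # one SIMD subtract/multiply/add step
            diff = target[offset + lane] - candidate[offset + lane]
            lanes[lane] += diff * diff
    tail = 0.0
    for i in range(main, n):               # leftover elements, one at a time
        diff = target[i] - candidate[i]
        tail += diff * diff
    return math.sqrt(sum(lanes) + tail)    # single horizontal sum at the end
```

Because lanes are summed in a different order than a sequential loop, float32 results can differ from the scalar kernel in the last bits.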

Fidelity Kernel

Beyond L2 distance, I also need fast fidelity scoring:

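from memory import UnsafePointer
from collections import List

fn compute_fidelity(
    hidden_states: UnsafePointer[Float32],
    lm_head: UnsafePointer[Float32],
    tokens: UnsafePointer[Int32],
    seq_len: Int,
    vocab_size: Int,
    d_model: Int,
    top_k: Int
) -> Float32:
    """
    Compute what fraction of tokens appear in top-k predictions.
    """
    if seq_len <= 0 or top_k <= 0:
        return 0.0

    var correct: Int = 0

    for pos in range(seq_len):
        var hidden = hidden_states + pos * d_model

        # Top-k logits and their token ids, kept sorted in descending order
        var top_logits = List[Float32]()
        var top_ids = List[Int]()

        for v in range(vocab_size):
            var logit: Float32 = 0.0
            var weight = lm_head + v * d_model

            # Dot product
            for d in range(d_model):
                logit += hidden[d] * weight[d]

            # Track top-k: skip logits that cannot enter a full list
            if len(top_logits) == top_k and logit <= top_logits[top_k - 1]:
                continue

            # Insert and maintain sorted order (sorted insertion, fine for small k)
            top_logits.append(logit)
            top_ids.append(v)
            var i = len(top_logits) - 1
            while i > 0 and top_logits[i - 1] < top_logits[i]:
                var tmp_logit = top_logits[i - 1]
                top_logits[i - 1] = top_logits[i]
                top_logits[i] = tmp_logit
                var tmp_id = top_ids[i - 1]
                top_ids[i - 1] = top_ids[i]
                top_ids[i] = tmp_id
                i -= 1
            if len(top_logits) > top_k:
                _ = top_logits.pop()
                _ = top_ids.pop()

        # Check if true token in top-k
        var true_token = Int(tokens[pos])
        for j in range(len(top_ids)):
            if top_ids[j] == true_token:
                correct += 1
                break

    return Float32(correct) / Float32(seq_len)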
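For validating the kernel, a compact Python reference of the same metric may help. This sketch uses nested lists instead of raw pointers and `heapq.nlargest` to pick the top-k token ids; the function name `fidelity` is illustrative.

```python
import heapq

def fidelity(hidden_states, lm_head, tokens, top_k):
    """Fraction of positions whose true token is among the top-k logits.

    hidden_states: seq_len rows of d_model floats
    lm_head: vocab_size rows of d_model floats
    tokens: seq_len true token ids
    """
    correct = 0
    for hidden, true_token in zip(hidden_states, tokens):
        # logits[v] = dot(hidden, lm_head[v])
        logits = [sum(h * w for h, w in zip(hidden, row)) for row in lm_head]
        # Select the top-k indices by logit without sorting the whole vocabulary
        top = heapq.nlargest(top_k, range(len(logits)), key=logits.__getitem__)
        if true_token in top:
            correct += 1
    return correct / len(tokens)

# Toy example: identity "lm_head" so logits equal the hidden state
lm_head = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
hidden = [[0.9, 0.1, 0.0], [0.1, 0.2, 0.7]]
print(fidelity(hidden, lm_head, [0, 1], top_k=1))  # 0.5
print(fidelity(hidden, lm_head, [0, 1], top_k=2))  # 1.0
```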

Benchmark Comparison

L2 Distance (32K x 4096)

Implementation          Time    Speedup (vs. PyTorch CPU)
PyTorch (CPU)           45ms    1.0x
PyTorch (GPU)           25ms    1.8x
Mojo (CPU, parallel)    8ms     5.6x
Mojo (CPU, SIMD)        4ms     11.2x
Mojo (GPU, tiled)       2ms     22.5x

Full SipIt Recovery (100 tokens)

Implementation          Time
PyTorch baseline        ~42 minutes
Mojo kernels            ~3.5 minutes
Speedup                 12x

Integration with Python

Mojo compiles to shared libraries callable from Python:

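# Load Mojo kernel
import ctypes

import numpy as np
import torch

lib = ctypes.CDLL("./mojo_kernels.so")

# Define function signature. The symbol name must match the exported Mojo
# function (the CPU kernel above is named compute_l2_distances_parallel).
# Mojo's Int is 64-bit on 64-bit platforms, so sizes are passed as int64.
lib.compute_l2_distances.argtypes = [
    ctypes.POINTER(ctypes.c_float),  # target
    ctypes.POINTER(ctypes.c_float),  # candidates
    ctypes.POINTER(ctypes.c_float),  # distances
    ctypes.c_int64,                  # batch_size
    ctypes.c_int64                   # d_model
]
lib.compute_l2_distances.restype = None

def fast_l2_distances(target_tensor, candidates_tensor):
    """Python wrapper for Mojo kernel."""
    # The kernel reads raw float32 memory: pass contiguous CPU NumPy arrays
    # (torch tensors have no .ctypes attribute)
    target = np.ascontiguousarray(target_tensor.detach().cpu().numpy(), dtype=np.float32)
    candidates = np.ascontiguousarray(candidates_tensor.detach().cpu().numpy(), dtype=np.float32)
    batch_size, d_model = candidates.shape

    # Allocate output
    distances = np.zeros(batch_size, dtype=np.float32)

    # Call Mojo kernel
    float_ptr = ctypes.POINTER(ctypes.c_float)
    lib.compute_l2_distances(
        target.ctypes.data_as(float_ptr),
        candidates.ctypes.data_as(float_ptr),
        distances.ctypes.data_as(float_ptr),
        batch_size,
        d_model
    )

    return torch.from_numpy(distances)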
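The Python side of this bridge can be tested before the Mojo library is built by standing in a ctypes callback with the same C signature. In this sketch, `L2KernelType`, `_stand_in_kernel`, and `call_l2_kernel` are hypothetical names, and the signature `void(float*, float*, float*, int64, int64)` is an assumption about how the kernel is exported.

```python
import ctypes
import math

FloatPtr = ctypes.POINTER(ctypes.c_float)

# Assumed C signature of the exported kernel:
# void compute_l2_distances(float*, float*, float*, int64_t, int64_t)
L2KernelType = ctypes.CFUNCTYPE(None, FloatPtr, FloatPtr, FloatPtr,
                                ctypes.c_int64, ctypes.c_int64)

def _stand_in_kernel(target, candidates, distances, batch_size, d_model):
    """Hypothetical pure-Python stand-in with the same signature."""
    for b in range(batch_size):
        total = 0.0
        for d in range(d_model):
            diff = target[d] - candidates[b * d_model + d]
            total += diff * diff
        distances[b] = math.sqrt(total)  # writes through the float* pointer

stand_in_kernel = L2KernelType(_stand_in_kernel)  # keep a reference alive

def call_l2_kernel(kernel, target, candidates):
    """Marshal Python lists into C float buffers and call the kernel."""
    batch_size, d_model = len(candidates), len(target)
    t_buf = (ctypes.c_float * d_model)(*target)
    c_buf = (ctypes.c_float * (batch_size * d_model))(
        *[x for row in candidates for x in row])  # row-major flatten
    out = (ctypes.c_float * batch_size)()         # zero-initialized output
    kernel(t_buf, c_buf, out, batch_size, d_model)
    return list(out)

print(call_l2_kernel(stand_in_kernel, [0.0, 0.0], [[3.0, 4.0], [1.0, 0.0]]))  # [5.0, 1.0]
```

Swapping `stand_in_kernel` for the real `lib.compute_l2_distances` should then need no changes on the calling side, provided the exported signature really matches.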

MAX Engine Integration

For production deployment, use Modular’s MAX Engine:

from max import engine

# Load optimized model
session = engine.InferenceSession()
model = session.load("mistral-7b-optimized.max")

# Run inference with automatic kernel selection
output = model.execute(input_ids=tokens)

MAX automatically selects optimal kernels based on hardware.

Memory Patterns Matter

Cache-Friendly Access

# Good: Sequential memory access
for d in range(d_model):
    sum += target[d] * candidate[d]

# Bad: Strided access (loop order swapped)
for d in range(d_model):
    for batch_idx in range(batch_size):
        # Jumps by d_model floats each iteration - cache misses
        sum += candidates[batch_idx * d_model + d]
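To see the difference concretely, this small Python sketch (the `access_strides` helper is illustrative) lists the gap between consecutive element indices for each loop order over a row-major `[batch_size, d_model]` array. The good order always steps by 1; the swapped order steps by `d_model`, so with `d_model = 4096` each float32 access lands 16 KiB past the previous one.

```python
def access_strides(batch_size, d_model, d_innermost=True):
    """Gaps (in elements) between consecutive accesses to a row-major
    [batch_size, d_model] array for the two loop orders."""
    if d_innermost:
        # Outer loop over rows, inner loop over d: walks memory in order
        order = [b * d_model + d for b in range(batch_size) for d in range(d_model)]
    else:
        # Outer loop over d, inner loop over rows: jumps a whole row each step
        order = [b * d_model + d for d in range(d_model) for b in range(batch_size)]
    return [j - i for i, j in zip(order, order[1:])]

print(set(access_strides(3, 4, d_innermost=True)))             # {1}
print(sorted(set(access_strides(3, 4, d_innermost=False))))    # [-7, 4]
```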

GPU Coalescing

# Good: Adjacent threads access adjacent memory
var idx = thread_idx.x
data[idx]  # Coalesced

# Bad: Strided access
data[idx * stride]  # Uncoalesced, slow
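A simplified model makes the cost visible. Assume a warp of 32 threads, each loading one float32 from element `idx * stride`, with memory served in aligned 128-byte segments (real GPUs track 32-byte sectors within 128-byte lines, so treat this as an approximation). The helper `segments_touched` is hypothetical:

```python
def segments_touched(stride, warp_size=32, elem_bytes=4, segment_bytes=128):
    """Count distinct aligned memory segments one warp touches when
    thread i reads element i * stride (simplified coalescing model)."""
    return len({(i * stride * elem_bytes) // segment_bytes for i in range(warp_size)})

for stride in (1, 2, 32, 4096):
    print(stride, segments_touched(stride))
# 1 1
# 2 2
# 32 32
# 4096 32
```

A stride of 4096 is the pattern a one-thread-per-candidate kernel produces when adjacent threads read rows that are `d_model` floats apart: every thread's load becomes its own transaction.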

Lessons Learned

  1. Profile First: Not all operations need custom kernels
  2. Memory > Compute: Memory access patterns dominate performance
  3. Batch Operations: Amortize kernel launch overhead
  4. Use Shared Memory: GPU shared memory has roughly 100x lower latency than uncached global memory
  5. Parallelize Early: Mojo’s parallelize is nearly free

Future Optimizations

  1. Tensor Cores: Use FP16 for 2x throughput on RTX 30/40 series
  2. Kernel Fusion: Combine L2 distance + argmin in single kernel
  3. Async Execution: Pipeline CPU preparation with GPU execution
  4. Multi-GPU: Distribute vocabulary across GPUs
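Item 2 can be sketched in a few lines of Python (the function name `fused_l2_argmin` is illustrative): track the running minimum while computing distances, so the per-candidate distance array never has to be written out and read back. Comparing squared distances is safe because `sqrt` is monotonic, so skipping it does not change which candidate wins.

```python
def fused_l2_argmin(target, candidates):
    """Single pass over candidates returning (best_index, best_squared_distance)."""
    best_idx, best_sq = -1, float("inf")
    for idx, row in enumerate(candidates):
        sq = 0.0
        for t, c in zip(target, row):
            diff = t - c
            sq += diff * diff
        if sq < best_sq:  # strict < keeps the first index on ties
            best_idx, best_sq = idx, sq
    return best_idx, best_sq

print(fused_l2_argmin([1.0, 1.0], [[3.0, 3.0], [1.0, 2.0], [0.0, 0.0]]))  # (1, 1.0)
```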

Reproducing Results

# Build kernels
mojo build l2_distance.mojo -o l2_benchmark

# Run benchmark
./l2_benchmark

# Expected output:
# Batch size: 32000, d_model: 4096
# Total time: 8.4ms
# Per candidate: 0.26us

When PyTorch isn’t fast enough, write your own kernels. Mojo makes it almost as easy as Python.