
Accelerate Deep Learning models using custom Kernels

By Zonunfeli Ralte

Introduction
Computer Vision has rapidly evolved from handcrafted features to powerful architectures such as CNNs, Transformers, and multimodal vision-language systems. Along this journey, fine-grained classification has emerged as one of the most challenging tasks — distinguishing subtle differences between highly similar categories (e.g., bird species, car models, or medical images).

In my recent book, I dedicate an entire chapter to Fine-grained Bilinear CNNs, showcasing how these techniques fit into the broader evolution of vision models — from traditional pipelines to the latest generative AI–driven approaches.

This post builds directly on that theme, taking the bilinear pooling idea from theory to practice. Instead of relying on heavy C++/cuDNN implementations, we’ll see how modern tooling lets us accelerate deep learning models with custom kernels, all in Python.

Specifically, we’ll explore how to:

  • Implement bilinear pooling for fine-grained classification.
  • Use PyTorch, PyTorch Lightning, and Lightning AI for modular training.
  • Leverage Triton from OpenAI to write efficient GPU kernels in Python that can approach the performance of hand-tuned CUDA.

By the end, you’ll see how to bridge concepts from the book into practice, building efficient, scalable models for real-world fine-grained recognition tasks. Along the way, this post compares three implementations of a bilinear pooling convolutional neural network (CNN) for EMNIST character classification. The implementations progress through three stages:

  • Native PyTorch: A manual approach for foundational understanding.
  • PyTorch Lightning: Abstraction for cleaner, scalable code.
  • Triton: Custom GPU kernels for performance optimization.

As deep learning practitioners, we often start with basic implementations before adopting frameworks or optimizations. This post follows the same CNN model through those stages, highlighting what changes in code structure and runtime performance at each step.

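The task is to classify EMNIST handwritten characters using bilinear pooling. Bilinear pooling combines two convolutional feature streams: at each spatial location it takes the outer product of the two streams' feature vectors, then sums these outer products over all locations. The resulting matrix captures pairwise interactions between features, which enriches the model’s expressive power.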
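As a quick check of that description, the sketch below compares an explicit sum of per-location outer products with a single batched matrix product, x1 @ x2ᵀ, which is how the models in this post compute it. The tensor sizes (64 channels, 14 x 14 = 196 locations) are illustrative, and the function names are hypothetical.

```python
import torch


def bilinear_pool_loop(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Explicit form: sum over locations l of outer(x1[b, :, l], x2[b, :, l]).

    x1: (B, C, L) and x2: (B, D, L), flattened over L spatial locations.
    Returns a (B, C, D) tensor.
    """
    B, C, L = x1.shape
    out = torch.zeros(B, C, x2.shape[1])
    for b in range(B):
        for l in range(L):
            out[b] += torch.outer(x1[b, :, l], x2[b, :, l])
    return out


def bilinear_pool_bmm(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Same quantity as one batched matrix product: (B, C, L) @ (B, L, D)."""
    return torch.bmm(x1, x2.transpose(1, 2))


torch.manual_seed(0)
x1 = torch.randn(2, 64, 196)  # illustrative: 64 channels, 14 * 14 locations
x2 = torch.randn(2, 64, 196)
print(torch.allclose(bilinear_pool_loop(x1, x2), bilinear_pool_bmm(x1, x2), atol=1e-3))
```

The batched product does the same summation over locations inside one optimized matrix multiply, which is why the model code can use torch.bmm() directly.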

1. Native PyTorch Implementation

The first implementation uses raw PyTorch with manual loops and standard operations. While flexible, it involves significant boilerplate code.

Key Code Blocks

Data Preparation

The torchvision.datasets.EMNIST dataset is prepared using transforms for resizing, normalizing, and converting images to tensors.

import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

# Shared preprocessing. EMNIST images are already 28x28, so Resize is only a safeguard;
# Normalize((0.5,), (0.5,)) maps pixel values from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 'byclass' split: 62 classes (digits plus upper- and lowercase letters)
train_dataset = datasets.EMNIST(
    root='/content/sample_data',
    split='byclass',
    train=True,
    download=True,
    transform=transform
)
test_dataset = datasets.EMNIST(
    root='/content/sample_data',
    split='byclass',
    train=False,
    download=True,
    transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)
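The 62 outputs of the model correspond to EMNIST's byclass labels. In the mapping file distributed with EMNIST, labels 0 to 9 are the digits, 10 to 35 the uppercase letters and 36 to 61 the lowercase letters. The helper below is a minimal sketch, assuming that ordering, for turning predicted indices back into characters; `emnist_byclass_char` is a hypothetical name, not a torchvision function.

```python
import string

# Assumed EMNIST 'byclass' ordering: digits, then uppercase, then lowercase letters.
BYCLASS_CHARS = string.digits + string.ascii_uppercase + string.ascii_lowercase


def emnist_byclass_char(label: int) -> str:
    """Map a 'byclass' label index (0-61) to its character (hypothetical helper)."""
    if not 0 <= label < len(BYCLASS_CHARS):
        raise ValueError(f"label must be in [0, {len(BYCLASS_CHARS) - 1}], got {label}")
    return BYCLASS_CHARS[label]


print(len(BYCLASS_CHARS))  # 62
print([emnist_byclass_char(i) for i in (0, 9, 10, 35, 36, 61)])  # ['0', '9', 'A', 'Z', 'a', 'z']
```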
  

Model Definition
The CNN uses two convolutional streams, followed by bilinear pooling implemented with torch.bmm().

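import torch
import torch.nn as nn

class BilinearCNN(nn.Module):
    def __init__(self, num_classes=62):
        super().__init__()
        # Two independent convolutional streams over the same input:
        # (B, 1, 28, 28) -> conv -> (B, 64, 28, 28) -> max-pool -> (B, 64, 14, 14)
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # The bilinear matrix is 64 x 64, flattened to 4096 features
        self.fc = nn.Linear(64 * 64, num_classes)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        # Flatten spatial dimensions: (B, 64, 14, 14) -> (B, 64, 196)
        x1 = x1.view(x1.size(0), 64, -1)
        x2 = x2.view(x2.size(0), 64, -1)
        # Bilinear pooling: sum of outer products over the 196 locations -> (B, 64, 64)
        bilinear_output = torch.bmm(x1, x2.transpose(1, 2))
        bilinear_output = bilinear_output.view(bilinear_output.size(0), -1)
        return self.fc(bilinear_output)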
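The model above feeds the raw bilinear matrix straight into the classifier. Bilinear CNNs commonly also average over spatial locations and apply a signed square root followed by L2 normalization before the linear layer, which compresses the large dynamic range of outer-product features. The function below is a sketch of that optional step, not part of the original model; `normalize_bilinear` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def normalize_bilinear(bilinear: torch.Tensor, num_locations: int, eps: float = 1e-10) -> torch.Tensor:
    """Average-pool, signed-sqrt and L2-normalize bilinear features (optional step).

    bilinear: (B, C1, C2) output of torch.bmm(x1, x2.transpose(1, 2)).
    num_locations: number of spatial positions that were summed over (H * W).
    Returns a (B, C1 * C2) tensor with unit L2 norm per sample.
    """
    x = bilinear.flatten(1) / num_locations              # sum over locations -> average
    x = torch.sign(x) * torch.sqrt(torch.abs(x) + eps)   # signed square root
    return F.normalize(x, p=2, dim=1)                    # per-sample L2 normalization
```

In `forward`, it would replace the final `view` on `bilinear_output`, for example `bilinear_output = normalize_bilinear(bilinear_output, num_locations=x1.size(2))`, where `x1.size(2)` is 196 for 28x28 inputs. Whether it helps on EMNIST is something to verify experimentally.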
  

Training and Validation
A manual loop handles epoch-wise training and validation.

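import torch
import torch.nn as nn
import torch.optim as optim

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BilinearCNN(num_classes=62).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(3):
    # Training pass
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Validation pass: eval mode, no gradient tracking
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    print(f'Epoch {epoch + 1}: last train loss {loss.item():.4f}, val acc {correct / total:.4f}')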
  

2. PyTorch Lightning Implementation

The PyTorch Lightning version abstracts the repetitive code, simplifying training and validation logic.

Key Code Blocks

Model Definition

The model is defined as a subclass of pl.LightningModule, encapsulating training and validation steps.

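import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl

class BilinearCNN(pl.LightningModule):
    def __init__(self, num_classes=62, lr=0.001):
        super().__init__()
        self.lr = lr
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc = nn.Linear(64 * 64, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        # Same forward pass as the native PyTorch model
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x1 = x1.view(x1.size(0), 64, -1)
        x2 = x2.view(x2.size(0), 64, -1)
        bilinear_output = torch.bmm(x1, x2.transpose(1, 2))
        bilinear_output = bilinear_output.view(bilinear_output.size(0), -1)
        return self.fc(bilinear_output)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # Lightning switches to eval mode and disables gradients for this step
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        acc = (outputs.argmax(dim=1) == labels).float().mean()
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)


Training with Lightning Trainer
The Trainer runs the training and validation loops and handles device placement, so scaling to more GPUs is mostly a matter of changing Trainer arguments such as devices.

trainer = pl.Trainer(max_epochs=3, devices=1, accelerator='gpu')
model = BilinearCNN(num_classes=62)
# Use the test loader as the validation set so validation_step runs after each epoch
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=test_loader)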
  

3. Triton Kernel Optimization

The Triton implementation optimizes bilinear pooling with a custom GPU kernel.
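For a sense of scale, the bilinear step in this model multiplies two 64 x 196 matrices per image (64 channels, 14 x 14 locations after pooling). The short calculation below, assuming the batch size of 256 used in the data loaders, counts the multiply-adds a custom kernel has to organize across GPU threads; `bilinear_macs` is a hypothetical helper name.

```python
def bilinear_macs(channels: int, height: int, width: int, batch: int = 1) -> int:
    """Multiply-adds for out = X1 @ X2^T with X1, X2 of shape (channels, height * width)."""
    locations = height * width
    # each of the channels * channels outputs is a dot product over `locations` terms
    return batch * channels * channels * locations


print(bilinear_macs(64, 14, 14))             # 802816 per image
print(bilinear_macs(64, 14, 14, batch=256))  # 205520896 per batch
```

In the kernel that follows, a single program performs all of an image's multiply-adds with two scalar loads each, so its speed depends heavily on how Triton compiles those loops. Triton's matrix-multiplication tutorial instead loads 2-D tiles and uses `tl.dot`, which maps the work onto the GPU's matrix units.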

Key Code Block: Triton Kernel

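A Triton kernel replaces torch.bmm() in bilinear pooling. It launches one program per image and computes the 64 x 64 bilinear matrix with explicit loops over channels and spatial locations.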

import triton
import triton.language as tl

@triton.jit
def bilinear_pooling_kernel(x1, x2, output, num_features: tl.constexpr, feature_dim: tl.constexpr):
    # Launched with grid = (B,): one program per image.
    # x1, x2: contiguous (B, num_features, feature_dim) float32 tensors;
    # output: contiguous (B, num_features, num_features) float32 tensor.
    batch_idx = tl.program_id(0)
    x1_ptr = x1 + batch_idx * num_features * feature_dim
    x2_ptr = x2 + batch_idx * num_features * feature_dim
    output_ptr = output + batch_idx * num_features * num_features
    # output[i, j] = sum_k x1[i, k] * x2[j, k], i.e. x1 @ x2^T, one scalar at a time
    for i in range(num_features):
        for j in range(num_features):
            acc = 0.0
            for k in range(feature_dim):
                acc += tl.load(x1_ptr + i * feature_dim + k) * tl.load(x2_ptr + j * feature_dim + k)
            tl.store(output_ptr + i * num_features + j, acc)
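The kernel above needs a host-side launcher, which the post does not show. As an alternative worth benchmarking, the sketch below swaps in a tiled variant in the style of Triton's matrix-multiplication tutorial: each program loads 2-D tiles and accumulates with `tl.dot`, and the launcher falls back to `torch.bmm` on CPU or when Triton is not installed. The names `bilinear_tiled_kernel` and `bilinear_pool`, the tile sizes, and the assumption of contiguous float32 (B, C, L) inputs with equal shapes are illustrative. On recent GPUs `tl.dot` may use TF32 for float32 inputs, so expect small numerical differences from `torch.bmm`, and measure any speedup rather than assume it.

```python
import torch

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:  # lets the torch.bmm fallback run where Triton is not installed
    HAS_TRITON = False

if HAS_TRITON:
    @triton.jit
    def bilinear_tiled_kernel(x1_ptr, x2_ptr, out_ptr, C, L,
                              BLOCK_C: tl.constexpr, BLOCK_L: tl.constexpr):
        # One program per image: out[b] = x1[b] @ x2[b]^T, where x1[b] and x2[b]
        # are contiguous (C, L) float32 matrices and C <= BLOCK_C.
        b = tl.program_id(0)
        offs_c = tl.arange(0, BLOCK_C)
        offs_l = tl.arange(0, BLOCK_L)
        acc = tl.zeros((BLOCK_C, BLOCK_C), dtype=tl.float32)
        for l0 in range(0, L, BLOCK_L):
            cols = l0 + offs_l
            mask = (offs_c[:, None] < C) & (cols[None, :] < L)
            idx = b * C * L + offs_c[:, None] * L + cols[None, :]
            a = tl.load(x1_ptr + idx, mask=mask, other=0.0)  # (BLOCK_C, BLOCK_L) tile of x1
            c = tl.load(x2_ptr + idx, mask=mask, other=0.0)  # matching tile of x2
            acc += tl.dot(a, tl.trans(c))                    # tile-level matrix multiply
        out_idx = b * C * C + offs_c[:, None] * C + offs_c[None, :]
        out_mask = (offs_c[:, None] < C) & (offs_c[None, :] < C)
        tl.store(out_ptr + out_idx, acc, mask=out_mask)


def bilinear_pool(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """(B, C, L), (B, C, L) -> (B, C, C). The Triton path is forward only (no autograd)."""
    if not (HAS_TRITON and x1.is_cuda):
        return torch.bmm(x1, x2.transpose(1, 2))  # CPU / no-Triton fallback
    if x1.shape != x2.shape:
        raise ValueError("this sketch assumes both streams have the same shape")
    x1, x2 = x1.contiguous().float(), x2.contiguous().float()
    B, C, L = x1.shape
    out = torch.empty((B, C, C), device=x1.device, dtype=torch.float32)
    block_c = max(16, triton.next_power_of_2(C))  # tl.dot needs tile dims >= 16
    bilinear_tiled_kernel[(B,)](x1, x2, out, C, L, BLOCK_C=block_c, BLOCK_L=32)
    return out
```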
  

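In the Triton version, this kernel takes the place of torch.bmm() inside the Lightning model's forward(). The kernel computes only the forward pass, so training through it requires wrapping it in a torch.autograd.Function that supplies the gradients. Because each program works through its image's multiply-adds one scalar at a time, any speedup over torch.bmm() should be confirmed by benchmarking on the target GPU.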
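The sketch below shows how a custom forward kernel can be paired with an explicit backward through torch.autograd.Function. For Y = X1 X2ᵀ the gradients are dX1 = dY X2 and dX2 = dYᵀ X1, which can themselves be computed with torch.bmm. The names `BilinearPoolFn`, `bmm_forward` and `bilinear_pool_autograd` are hypothetical; `forward_fn` defaults to torch.bmm so the example runs anywhere, and a Triton launcher with the same (B, C, L) to (B, C, D) contract could be passed in its place.

```python
import torch


def bmm_forward(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Reference forward: (B, C, L), (B, D, L) -> (B, C, D)."""
    return torch.bmm(x1, x2.transpose(1, 2))


class BilinearPoolFn(torch.autograd.Function):
    """Y = X1 @ X2^T per sample, with a hand-written backward.

    forward_fn can be any callable with the same contract as bmm_forward,
    for example a launcher for a custom Triton kernel.
    """

    @staticmethod
    def forward(ctx, x1, x2, forward_fn):
        ctx.save_for_backward(x1, x2)
        return forward_fn(x1, x2)

    @staticmethod
    def backward(ctx, grad_out):
        x1, x2 = ctx.saved_tensors
        grad_x1 = torch.bmm(grad_out, x2)                  # dL/dX1 = dL/dY @ X2
        grad_x2 = torch.bmm(grad_out.transpose(1, 2), x1)  # dL/dX2 = (dL/dY)^T @ X1
        return grad_x1, grad_x2, None                      # no gradient for forward_fn


def bilinear_pool_autograd(x1, x2, forward_fn=bmm_forward):
    return BilinearPoolFn.apply(x1, x2, forward_fn)


# Numerical gradient check in double precision
x1 = torch.randn(2, 4, 5, dtype=torch.double, requires_grad=True)
x2 = torch.randn(2, 3, 5, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(bilinear_pool_autograd, (x1, x2)))  # True
```

In the model's forward(), the torch.bmm(...) line would then become a call to `bilinear_pool_autograd(x1, x2, forward_fn=...)`, passing whichever kernel launcher is being tested.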

Results

Conclusion

Each implementation builds upon the previous, balancing simplicity, scalability, and performance. Native PyTorch is ideal for beginners, PyTorch Lightning for scalable research, and Triton for production-grade performance.
