Accelerate Deep Learning models using custom Kernels
By Zonunfeli Ralte
Introduction
Computer Vision has rapidly evolved from handcrafted features to powerful architectures such as CNNs, Transformers, and multimodal vision-language systems. Along this journey, fine-grained classification has emerged as one of the most challenging tasks — distinguishing subtle differences between highly similar categories (e.g., bird species, car models, or medical images).
In my recent book, I dedicate an entire chapter to Fine-grained Bilinear CNNs, showcasing how these techniques fit into the broader evolution of vision models — from traditional pipelines to the latest generative AI–driven approaches.
This blog builds directly on that theme: taking the bilinear pooling idea from theory to practice. Instead of relying on heavy C++/cuDNN implementations, we’ll see how modern tooling enables us to supercharge deep learning models with custom kernels, all in Python.
Specifically, we’ll explore how to:
- Implement bilinear pooling for fine-grained classification.
- Use PyTorch, PyTorch Lightning, and Lightning AI for modular training.
- Leverage Triton from OpenAI to write efficient GPU kernels that rival hand-tuned CUDA.
By the end, you’ll see how to bridge concepts from the book into practice, building efficient, scalable models for real-world fine-grained recognition tasks. This blog also provides an in-depth comparison of three implementations of a bilinear pooling convolutional neural network (CNN) for EMNIST image classification. The implementations progress through:
- Native PyTorch: A manual approach for foundational understanding.
- PyTorch Lightning: Abstraction for cleaner, scalable code.
- Triton: Custom GPU kernels for performance optimization.
Through this journey, we aim to understand how code efficiency and performance improve with each step.
As deep learning practitioners, we often start with basic implementations before adopting frameworks or optimizations. This blog demonstrates how the same CNN model evolves, highlighting changes in code structure, efficiency, and performance.
The task involves classifying EMNIST characters using bilinear pooling, which pools the outer products of two feature representations over spatial locations, enriching the model’s expressive power.
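Before walking through the implementations, here is a minimal, self-contained sketch of the operation (the tensor shapes are hypothetical but match the models below): for each sample, two feature maps of shape (channels, locations) are combined with a batched matrix product, which sums the per-location outer products into a channels-by-channels descriptor.

import torch

# Hypothetical shapes: batch of 8, 64 channels, 14 x 14 = 196 spatial locations per map.
f1 = torch.randn(8, 64, 196)
f2 = torch.randn(8, 64, 196)

# Batched matrix product = sum of outer products over spatial locations.
pooled = torch.bmm(f1, f2.transpose(1, 2))  # shape (8, 64, 64)
descriptor = pooled.view(8, -1)             # flattened (8, 4096) vector fed to the classifier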
1. Native PyTorch Implementation
The first implementation uses raw PyTorch with manual loops and standard operations. While flexible, it involves significant boilerplate code.
Key Code Blocks
Data Preparation
The torchvision.datasets.EMNIST dataset is prepared using transforms for resizing, normalizing, and converting images to tensors.
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
train_dataset = datasets.EMNIST(
    root='/content/sample_data',
    split='byclass',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
)
test_dataset = datasets.EMNIST(
    root='/content/sample_data',
    split='byclass',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)
Model Definition
The CNN uses two convolutional streams whose outputs are combined by bilinear pooling via torch.bmm().
import torch.nn as nn
class BilinearCNN(nn.Module):
    def __init__(self, num_classes=62):
        super(BilinearCNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc = nn.Linear(64 * 64, num_classes)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x1 = x1.view(x1.size(0), 64, -1)
        x2 = x2.view(x2.size(0), 64, -1)
        bilinear_output = torch.bmm(x1, x2.transpose(1, 2))
        bilinear_output = bilinear_output.view(bilinear_output.size(0), -1)
        return self.fc(bilinear_output)
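As a quick sanity check (illustrative only, not part of the original listing), a dummy batch confirms the 64 x 64 = 4096-dimensional bilinear descriptor and the 62-way output:

model = BilinearCNN(num_classes=62)
dummy = torch.randn(2, 1, 28, 28)  # two grayscale 28x28 images
logits = model(dummy)
print(logits.shape)                # torch.Size([2, 62])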
Training and Validation
A manual loop handles epoch-wise training; an evaluation pass on the test set is sketched after the listing.
import torch.optim as optim
model = BilinearCNN(num_classes=62).to('cuda')
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(3):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to('cuda'), labels.to('cuda')
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
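The loop above covers only the training pass. As an illustrative addition (not part of the original listing), a simple evaluation pass over the test loader reports accuracy once training finishes:

model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to('cuda'), labels.to('cuda')
        outputs = model(images)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
print(f"Test accuracy: {correct / total:.4f}")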
2. PyTorch Lightning Implementation
The PyTorch Lightning version abstracts the repetitive code, simplifying training and validation logic.
Key Code Blocks
Model Definition
The model is defined as a subclass of pl.LightningModule, encapsulating the training step and optimizer configuration; a validation step can be added in the same way, as sketched after the listing.
import pytorch_lightning as pl
class BilinearCNN(pl.LightningModule):
    def __init__(self, num_classes=62, lr=0.001):
        super(BilinearCNN, self).__init__()
        self.lr = lr
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc = nn.Linear(64 * 64, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x1 = x1.view(x1.size(0), 64, -1)
        x2 = x2.view(x2.size(0), 64, -1)
        bilinear_output = torch.bmm(x1, x2.transpose(1, 2))
        bilinear_output = bilinear_output.view(bilinear_output.size(0), -1)
        return self.fc(bilinear_output)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)
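The listing above defines only the training step. A validation step follows the same pattern; the sketch below is an illustrative addition (not from the original post) that reuses the same criterion and logs loss and accuracy:

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        acc = (outputs.argmax(dim=1) == labels).float().mean()
        self.log('val_loss', loss, prog_bar=True)  # aggregated per epoch by Lightning
        self.log('val_acc', acc, prog_bar=True)
        return loss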
Training with Lightning Trainer
The Trainer automates training and validation loops, improving scalability.
trainer = pl.Trainer(max_epochs=3, devices=1, accelerator='gpu')
model = BilinearCNN(num_classes=62)
trainer.fit(model, train_dataloaders=train_loader)
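If the validation step sketched above is added, the corresponding loader can be passed to fit as well (here the test loader stands in for a held-out validation split):

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=test_loader)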
3. Triton Kernel Optimization
The Triton implementation optimizes bilinear pooling with a custom GPU kernel.
Key Code Block: Triton Kernel
A Triton kernel replaces torch.bmm() for faster bilinear pooling.
import triton
import triton.language as tl
@triton.jit
def bilinear_pooling_kernel(x1, x2, output, num_features: tl.constexpr, feature_dim: tl.constexpr):
    batch_idx = tl.program_id(0)
    x1_ptr = x1 + batch_idx * num_features * feature_dim
    x2_ptr = x2 + batch_idx * num_features * feature_dim
    output_ptr = output + batch_idx * num_features * num_features
    for i in range(num_features):
        for j in range(num_features):
            acc = 0.0
            for k in range(feature_dim):
                acc += tl.load(x1_ptr + i * feature_dim + k) * tl.load(x2_ptr + j * feature_dim + k)
            tl.store(output_ptr + i * num_features + j, acc)
This kernel is integrated into the Lightning model, enabling faster GPU computation.
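The glue code is not shown above, so the wrapper below is a minimal sketch with hypothetical names: it assumes contiguous float32 CUDA tensors, launches one kernel instance per batch element (matching tl.program_id(0) in the kernel), and returns the flattened descriptor so it can stand in for the torch.bmm call in the model's forward.

def bilinear_pooling_triton(x1, x2):
    # x1, x2: (batch, num_features, feature_dim) contiguous float32 CUDA tensors.
    batch, num_features, feature_dim = x1.shape
    output = torch.empty(batch, num_features, num_features, device=x1.device, dtype=x1.dtype)
    # One Triton program per batch element; each computes a num_features x num_features block.
    bilinear_pooling_kernel[(batch,)](
        x1, x2, output,
        num_features=num_features, feature_dim=feature_dim
    )
    return output.view(batch, -1)

Two caveats apply to this sketch: the kernel as written uses scalar loads inside Python-level loops, so a blocked version built on tl.arange and tl.dot is usually needed before it can actually outpace torch.bmm; and because a raw Triton kernel defines no gradient, the call must be wrapped in a torch.autograd.Function with a matching backward (or used for inference only) if the convolutional streams are to keep receiving gradients during training.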
Results
Conclusion
Each implementation builds upon the previous, balancing simplicity, scalability, and performance. Native PyTorch is ideal for beginners, PyTorch Lightning for scalable research, and Triton for production-grade performance.
