Skip to content

How to Train a Convolutional Neural Network on the Flower Dataset using PyTorch

Introduction

The Oxford 102 Flower dataset is a benchmark dataset for image classification tasks in computer vision, containing 102 flower categories with a total of 8,189 images (Nilsback & Zisserman, 2008). This dataset is valuable for evaluating model performance and can be instrumental in training Convolutional Neural Networks (CNNs). This guide will demonstrate how to load the dataset, preprocess the images, train a CNN model using PyTorch, and evaluate its performance.

Importance in Computer Vision

CNNs are inherently well-suited for image classification tasks due to their ability to automatically learn spatial hierarchies of features from images (LeCun, Bottou, Bengio, & Haffner, 1998). By applying CNNs to the flower dataset, we can explore their ability to generalize across different classes and extract crucial features pertinent to identifying floral species.

Step 1: Loading the Flower Dataset with PyTorch

To effectively load and preprocess the flower dataset, we can utilize PyTorch’s torchvision library. Below is code that illustrates how to load the dataset, define transformations, and visualize some samples:

import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Defining the transformation for image preprocessing
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# Loading the dataset
flower_data = datasets.ImageFolder(root='path_to_flower_dataset', transform=transform)
train_loader = torch.utils.data.DataLoader(flower_data, batch_size=32, shuffle=True)

# Visualizing samples from the dataset
def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    plt.imshow(img.permute(1, 2, 0))
    plt.axis('off')
    plt.show()

# Displaying 5 samples
for images, labels in train_loader:
    grid = torchvision.utils.make_grid(images[:5])
    imshow(grid)
    break

This code not only loads the dataset but also applies transformations to enhance model training efficiency by resizing images to 128x128 pixels and converting them to tensors.

Step 2: Defining the CNN Model

Next, we need to define the architecture of our CNN. Below is a simple architecture suitable for flower classification:

import torch.nn as nn
import torch.nn.functional as F

class FlowerCNN(nn.Module):
    def __init__(self):
        super(FlowerCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 32 * 32, 128)
        self.fc2 = nn.Linear(128, 102)  # 102 classes

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 32 * 32 * 32)  # Flatten
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Initialize the model
model = FlowerCNN()

This CNN consists of two convolutional layers followed by max pooling, a fully connected layer, and an output layer corresponding to the number of flower classes (102).

Step 3: Training the Model

Now, we define the loss function and the optimizer and carry out the training process:

import torch.optim as optim

# Setting the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(10): # Number of epochs
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch+1}/10], Loss: {loss.item():.4f}')

Step 4: Evaluating the Model

After training, we can evaluate the performance of the model on a validation set and measure accuracy, precision, and recall. This can be done using the following code:

# Define a function to evaluate the model on the validation dataset
def evaluate(model, validation_loader):
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():  # No need to track gradients
        for images, labels in validation_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f'Accuracy of the model on the validation set: {accuracy:.2f}%')

# Call the evaluate function (assuming validation_loader is defined)
# evaluate(model, validation_loader)

Summary

In this guide, we covered how to train a convolutional neural network using the Oxford 102 Flower Dataset in PyTorch. By following these steps, one can effectively train a robust image classification model while also understanding the core components such as data loading, model definition, training, and evaluation.

References

  • LeCun, Y., Bottou, L., Bengio, Y., & Haffner, P. (1998). Gradient-Based Learning Applied to Document Recognition. Proceedings of the IEEE, 86(11), 2278-2324.
  • Nilsback, M.-E., & Zisserman, A. (2008). Automated Flower Classification over a Large Number of Classes. Proceedings of the Indian Conference on Computer Vision, Graphics and Image Processing, 722-727.