Leveraging Pre-trained Convolutional Neural Networks (CNNs) for Feature Extraction on the Flower Dataset with Grid Search Classifier Optimization
Introduction
In recent years, Convolutional Neural Networks (CNNs) have become ubiquitous in the fields of image processing and computer vision. These networks excel at feature extraction and modeling complex data representations, making them ideal for image classification tasks. A pre-trained CNN leverages previously learned features on large datasets (like ImageNet) to transfer that knowledge to new tasks—this is especially useful when working with smaller datasets, such as the Flowers-17 dataset. This article explores the process of using pre-trained CNNs for feature extraction from the Flower dataset, followed by employing a grid search to fine-tune a classifier for accurate predictions.
Step 1: Setup Environment
To initiate this task, it is essential to have the necessary libraries installed and imported into your Python environment, including PyTorch and scikit-learn:
import torch
import torchvision.transforms as transforms
from torchvision import models
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
Step 2: Load the Dataset
The Flower dataset can be easily loaded using the ImageFolder class from PyTorch. This will simplify the process of navigating through the data.
data_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
flower_dataset = ImageFolder(root='flower_data_folder/', transform=data_transform)
data_loader = DataLoader(flower_dataset, batch_size=32, shuffle=True)
Step 3: Feature Extraction
Next, we utilize a pre-trained model, such as ResNet18, to extract features from the images in the dataset. The model will be set to evaluation mode, and we will remove the final fully connected layer to get feature vectors.
model = models.resnet18(pretrained=True)
model = torch.nn.Sequential(*(list(model.children())[:-1])) # Removing last layer
model.eval()
features = []
labels = []
with torch.no_grad():
for inputs, label in data_loader:
output = model(inputs)
features.append(output.view(output.size(0), -1).numpy())
labels.append(label.numpy())
features = np.concatenate(features)
labels = np.concatenate(labels)
Step 4: Classifier Training with Grid Search
To optimize the hyperparameters of a classifier, we'll implement a Grid Search on Logistic Regression. This method evaluates all possible configurations to increase the model’s performance.
param_grid = {'C': [0.01, 0.1, 1, 10, 100], 'solver': ['liblinear', 'saga']}
clf = LogisticRegression(multi_class='auto', max_iter=200)
grid_search = GridSearchCV(clf, param_grid, cv=5, scoring='accuracy')
grid_search.fit(features, labels)
print("Best parameters found: ", grid_search.best_params_)
Step 5: Evaluation
Upon finding the best parameters, we can use the optimized model to predict the labels of our dataset and evaluate its performance.
best_clf = grid_search.best_estimator_
predictions = best_clf.predict(features)
print(classification_report(labels, predictions))
Proposed Graphs and Tables
- Graphs: Plot of the accuracy scores obtained during grid search for each parameter combination.
- Tables: A summary table to show the results of different hyperparameter settings used in the Grid Search, including their respective accuracy scores.
Conclusion
Using a pre-trained CNN for feature extraction greatly simplifies the model-building process on smaller datasets, such as the Flower dataset. Coupled with a systematic approach like grid search for hyperparameter optimization, we can achieve high classification accuracy without the need for an extensive training dataset.
References
- Goodfellow, I., Bengio, Y., & Courville, A. (2016). Deep Learning. MIT Press.
- Chollet, F. (2017). Deep Learning with Python. Manning Publications Company.
- Géron, A. (2019). Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow. O'Reilly Media.
- He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning for Image Recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (pp. 770-778).
This comprehensive guide integrates practical coding exercises with theoretical insights, thus providing a holistic approach to feature extraction and model optimization in machine learning.