PyTorch Tutorial: Build an Image Classifier for Flowers

Flowers have captivated humanity for centuries with their beauty and diversity. But can we teach machines to identify different types of flowers? Using PyTorch, one of the most popular deep learning libraries, we can create an image classifier to distinguish between daisies, roses, tulips, and other flower types.

In this tutorial, we'll guide you step by step through building a flower image classifier with PyTorch, training it on the Flowers Dataset from Kaggle. By the end, you'll have a lightweight app that can predict a flower's type from a single image.


Why PyTorch for Image Classification?

PyTorch is a versatile and powerful deep learning framework, ideal for tasks like:

  • Building custom neural networks.
  • Handling large datasets with ease.
  • Leveraging pre-trained models for faster development.

What You'll Need

Prerequisites

  • Basic Python knowledge.
  • Familiarity with PyTorch basics (helpful but not required).

Required Libraries

Install the following Python libraries:

bash
pip install torch torchvision matplotlib numpy

You can download the Flowers Dataset from Kaggle.


Step 1: Load and Explore the Dataset

The Flowers Dataset contains images of various flower types. Organize the dataset into training and validation folders, categorized by flower type.
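For example, the layout might look something like this (the class folder names are placeholders; use whatever categories your download contains):

  flowers/
    train/
      daisy/
      rose/
      tulip/
    val/
      daisy/
      rose/
      tulip/

The code below defines the transforms and loads both splits with ImageFolder: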

python
import os
import torch
from torchvision import datasets, transforms

# Define data transformations (augmentation for training, plain resize for validation)
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}

# Load datasets
data_dir = 'flowers'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}

# Create dataloaders
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32, shuffle=True)
               for x in ['train', 'val']}
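
To confirm everything loaded correctly, it helps to inspect the class names and dataset sizes before training. A quick sketch (the folder names become the class labels automatically):

python
# Quick sanity check on the loaded datasets
class_names = image_datasets['train'].classes
print("Classes:", class_names)
print("Training images:", len(image_datasets['train']))
print("Validation images:", len(image_datasets['val']))

# Peek at one batch of tensors
images, labels = next(iter(dataloaders['train']))
print("Batch shape:", images.shape)  # e.g. torch.Size([32, 3, 224, 224])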

Step 2: Build the Model

Use a pre-trained model like ResNet to speed up training.

python
import torch
import torch.nn as nn
from torchvision import models

# Load pre-trained ResNet model
model = models.resnet18(pretrained=True)

# Replace the final layer to match the number of flower classes
num_classes = len(image_datasets['train'].classes)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
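
Because the optimizer in the next step only updates the new `model.fc` layer, you may also want to freeze the pre-trained backbone explicitly. This is an optional sketch, not part of the core recipe; it skips gradient computation for the frozen layers, which saves memory and time:

python
# Optional: freeze everything except the new classification head
for name, param in model.named_parameters():
    if not name.startswith('fc'):
        param.requires_grad = False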

Step 3: Train the Model

Define the loss function and optimizer, and train the model.

python
import torch.optim as optim

# Define loss function and optimizer (only the new final layer is optimized)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

# Training loop
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10):
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            correct = 0
            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                correct += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(image_datasets[phase])
            epoch_acc = correct.double() / len(image_datasets[phase])
            print(f"{phase.capitalize()} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

train_model(model, dataloaders, criterion, optimizer)
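
If accuracy plateaus, a learning-rate scheduler is a common next step. Here is a minimal, optional sketch using StepLR (not part of the recipe above); you would call `scheduler.step()` once at the end of each epoch inside `train_model`:

python
# Optional: decay the learning rate by 10x every 5 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)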

Step 4: Evaluate and Save the Model

Test the model on unseen data and save it for deployment.

python
# Save the trained model weights
torch.save(model.state_dict(), 'flower_classifier.pth')

# Load the weights back for inference
model.load_state_dict(torch.load('flower_classifier.pth', map_location=device))
model.eval()
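
To actually test on unseen data, you can measure accuracy over the validation loader. A minimal sketch using the objects defined earlier:

python
# Measure accuracy on the validation set
correct, total = 0, 0
with torch.no_grad():
    for inputs, labels in dataloaders['val']:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f"Validation accuracy: {correct / total:.4f}")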

Step 5: Build the Lightweight Prediction App

Use a simple script to load an image and predict the flower type.

python
from PIL import Image

def predict_flower(image_path, model, class_names):
    model.eval()
    # Convert to RGB so grayscale or RGBA images don't break the transforms
    image = Image.open(image_path).convert('RGB')
    transform = data_transforms['val']
    image = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(image)
        _, preds = torch.max(outputs, 1)
    return class_names[preds[0]]

# Test the app
class_names = image_datasets['train'].classes
flower = predict_flower('test_flower.jpg', model, class_names)
print(f"This is a {flower}!")

Step 6: Deploy the Model

You can deploy the model as a lightweight web app using Flask or Streamlit.
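
As one possible route, here is a minimal Streamlit sketch. The file name `app.py` and the hard-coded class list are assumptions; replace the class names with the ones printed from your own dataset, and keep `flower_classifier.pth` next to the script:

python
# app.py -- minimal Streamlit front end (file name and class list are assumptions)
import streamlit as st
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

# Rebuild the architecture and load the saved weights
class_names = ['daisy', 'rose', 'tulip']  # replace with your dataset's classes
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, len(class_names))
model.load_state_dict(torch.load('flower_classifier.pth', map_location='cpu'))
model.eval()

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

st.title("Flower Classifier")
uploaded = st.file_uploader("Upload a flower photo", type=['jpg', 'jpeg', 'png'])
if uploaded is not None:
    image = Image.open(uploaded).convert('RGB')
    st.image(image, caption="Uploaded image")
    with torch.no_grad():
        outputs = model(transform(image).unsqueeze(0))
        pred = outputs.argmax(1).item()
    st.write(f"Prediction: **{class_names[pred]}**")

Run it with `streamlit run app.py`; a Flask version would follow the same pattern, just with a route that accepts an uploaded file.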


Key Takeaways

  1. Pre-trained Models: Leveraging ResNet accelerates development.
  2. PyTorch's Flexibility: Allows customizations and advanced deep learning features.
  3. Real-World Application: Build and deploy an app to classify flower images instantly.

Conclusion

Building an image classifier with PyTorch is an excellent way to learn deep learning concepts and apply them to real-world problems. By training on the Flowers Dataset, we’ve created a robust classifier that can distinguish between various flower types.