This task is intended to measure your basic deep learning skills. We provide a dataset containing images of 131 different fruits. Download it from this link and upload it to your copy of this file.
You have to implement a Convolutional Neural Network that classifies an input fruit image. For this task, you must follow these rules:
# Switch to the dataset root; the './fruits-360/...' paths used further down
# are relative to this directory.
%cd /kaggle/input/fruits
# List the contents of the current directory.
!ls
/kaggle/input/fruits fruits-360
# Load and transform the dataset
import os
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms, ToTensor
from torch.utils.data import random_split, DataLoader
# Train the model
import copy
import time
import torch
import torch.nn as nn
from torchvision import models
# Evaluate the model
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
# Apply the model
from random import choice
import torch.nn.functional as nnf
# Dataset locations, relative to the current working directory.
train_directory = './fruits-360/Training'
test_directory = './fruits-360/Test'
# Preprocessing applied to every image: resize to 224x224, then convert to a tensor.
# NOTE(review): no normalisation step is applied here although a pre-trained
# backbone is used later — confirm this is intentional.
transform = transforms.Compose([
transforms.Resize([224, 224]), # 224x224 is image size used in pre-trained models
transforms.ToTensor()
])
# ImageFolder treats each sub-directory of the given root as one class.
train_dataset = ImageFolder(train_directory, transform=transform)
print('Training dataset size:', len(train_dataset))
test_dataset = ImageFolder(test_directory, transform=transform)
print('Test dataset size:', len(test_dataset))
Training dataset size: 67692 Test dataset size: 22688
# Split training and validation data
# NOTE(review): random_split is called without a generator/seed, so the split
# presumably differs from run to run — confirm whether reproducibility matters.
validation_size = int(0.1 * len(train_dataset)) # Validation represents 10% of the training data
training_size = len(train_dataset) - validation_size
# train_dataset is rebound here: from this point on it is the training subset
# returned by random_split, no longer the full ImageFolder.
train_dataset, validation_dataset = random_split(train_dataset,
[training_size, validation_size])
# Notebook cell output: sizes of the two subsets.
len(train_dataset), len(validation_dataset)
(60923, 6769)
# Get dataset classes
# ImageFolder numbers the class sub-directories in sorted order, whereas
# os.listdir returns entries in arbitrary order. Sorting the directory names
# (and ignoring stray files) makes classes[label] the class that a label index
# actually stands for.
classes = sorted(
    entry.name for entry in os.scandir(train_directory) if entry.is_dir()
)
# Notebook cell output: number of classes.
len(classes)
131
# Create data loaders from datasets
batch_size = 32
# Only the training loader shuffles; test and validation keep dataset order.
training_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
validation_loader = DataLoader(dataset=validation_dataset, batch_size=batch_size, shuffle=False)
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def training_step(model, criterion, optimizer, dataloader_train):
    """Run one optimisation pass over every batch of ``dataloader_train``.

    Each batch is moved to the module-level ``device``, pushed through
    ``model``, scored with ``criterion`` and used for one optimiser update.

    Args:
        model: Network being trained; called as ``model(inputs)``.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimiser whose gradients are zeroed before, and stepped
            after, every batch.
        dataloader_train: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        The sum of the ``criterion`` values of all batches, as a float. It is
        one value per batch added up, not a per-sample total.
    """
    total_loss = 0.0

    for batch_inputs, batch_labels in dataloader_train:
        batch_inputs, batch_labels = batch_inputs.to(device), batch_labels.to(device)

        # Start every batch from clean gradients.
        optimizer.zero_grad()

        # Autograd is switched on explicitly for the forward/backward pass.
        with torch.set_grad_enabled(True):
            batch_loss = criterion(model(batch_inputs), batch_labels)
            batch_loss.backward()
            optimizer.step()

        total_loss += batch_loss.item()

    return total_loss
def evaluation_step(model, criterion, optimizer, dataloader_val):
    """Score ``model`` on every batch of ``dataloader_val`` without training it.

    The caller is responsible for putting the model in eval mode; this function
    does not change the model's mode.

    Args:
        model: Network to evaluate; called as ``model(inputs)``.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Unused; kept so the signature stays parallel to
            ``training_step``.
        dataloader_val: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        A ``(running_loss, running_corrects)`` tuple. ``running_loss`` is the
        float sum of the ``criterion`` values of all batches (one value per
        batch). ``running_corrects`` is a 0-dim integer tensor on ``device``
        counting the samples whose highest-scoring class equals the label; it
        is a tensor even when the loader yields no batches.
    """
    running_loss = 0.0
    # Start from a tensor so the result has the same type for an empty loader.
    running_corrects = torch.zeros((), dtype=torch.long, device=device)

    # No parameter is updated here, so skip autograd bookkeeping: it would only
    # keep activations alive and raise memory use during validation.
    with torch.no_grad():
        for inputs, labels in dataloader_val:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()

            # Predicted class = index of the largest score along dimension 1.
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels)

    return running_loss, running_corrects
def train_model(model, criterion, optimizer, scheduler, num_epochs=5):
    """Train ``model`` and restore the weights with the best validation accuracy.

    Every epoch runs ``training_step`` on the module-level ``training_loader``,
    steps ``scheduler`` once, and then runs ``evaluation_step`` on the
    module-level ``validation_loader``. A deep copy of the state dict is kept
    whenever the validation accuracy improves.

    Args:
        model: Network to train (modified in place).
        criterion: Loss callable handed to the step functions.
        optimizer: Optimiser handed to the step functions.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs: Number of epochs to run.

    Returns:
        A ``(model, history)`` tuple. ``model`` is the same object with the
        best weights loaded. ``history`` maps ``'train_loss'``,
        ``'validation_loss'`` and ``'validation_accuracy'`` to lists holding
        one plain Python float per epoch.
    """
    history = {
        'train_loss': list(),
        'validation_loss': list(),
        'validation_accuracy': list()
    }
    since = time.time()
    best_model = copy.deepcopy(model.state_dict())
    best_accuracy = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Training step
        model.train()
        running_loss = training_step(
            model, criterion, optimizer, training_loader
        )
        # The step functions add up ONE loss value per batch, so the epoch
        # average is taken over the number of batches. Dividing by the number
        # of samples would under-report the loss by roughly the batch size.
        train_epoch_loss = running_loss / len(training_loader)
        scheduler.step()

        # Evaluation step
        model.eval()
        running_loss, running_corrects = evaluation_step(
            model, criterion, optimizer, validation_loader
        )
        validation_epoch_loss = running_loss / len(validation_loader)
        # float() turns the (possibly CUDA) counter into a plain number, so the
        # history holds Python floats that can be plotted or serialised.
        epoch_accuracy = float(running_corrects) / len(validation_loader.dataset)

        print(f'Training Loss: {train_epoch_loss:.4f} '
              + f'Validation Loss: {validation_epoch_loss:.4f} '
              + f'Accuracy: {epoch_accuracy:.4f}\n')

        history['train_loss'].append(train_epoch_loss)
        history['validation_loss'].append(validation_epoch_loss)
        history['validation_accuracy'].append(epoch_accuracy)

        # Deep copy the model if it presents the best accuracy
        if epoch_accuracy > best_accuracy:
            best_accuracy = epoch_accuracy
            best_model = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    minutes = time_elapsed // 60
    seconds = time_elapsed % 60
    print(f'Training complete in {minutes:.0f}m {seconds:.0f}s')
    print(f'Best accuracy: {best_accuracy:4f}')

    # Load best model weights
    model.load_state_dict(best_model)
    return model, history
# Notebook cell output: display the device selected above.
device
device(type='cuda', index=0)
# Load ResNet50 as backbone
# NOTE(review): newer torchvision releases reportedly replace `pretrained=`
# with a `weights=` argument — confirm against the installed version.
model_base = models.resnet50(pretrained=True)
# Replace the final fully connected layer so the network produces one output
# per entry of `classes`.
num_features = model_base.fc.in_features
model_base.fc = nn.Linear(num_features, len(classes))
# Load model to GPU
model_base.to(device)
# Hyper-parameters
# step_size and gamma configure the StepLR scheduler created below.
# NOTE(review): train_model steps the scheduler once per epoch and the training
# call in this file uses num_epochs=3, which is fewer than step_size — so the
# learning-rate decay presumably never triggers in that run.
step_size = 7
gamma = 0.1
criterion = nn.CrossEntropyLoss()
# The optimiser receives every parameter of the network, not only the new head.
optimizer_ft = torch.optim.Adam(model_base.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft, step_size=step_size, gamma=gamma)
Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth
# Train for 3 epochs. train_model returns the model it was given (with the
# best-validation-accuracy weights loaded) together with the per-epoch history.
model, history = train_model(model_base, criterion, optimizer_ft, scheduler, num_epochs=3)
Epoch 0/2 ---------- Training Loss: 0.0167 Validation Loss: 0.0020 Accuracy: 0.9787 Epoch 1/2 ---------- Training Loss: 0.0023 Validation Loss: 0.0006 Accuracy: 0.9934 Epoch 2/2 ---------- Training Loss: 0.0012 Validation Loss: 0.0040 Accuracy: 0.9601 Training complete in 33m 49s Best accuracy: 0.993352
def plot_accuracy(history):
    """Plot the validation accuracy recorded for each epoch.

    Args:
        history: Mapping with a ``"validation_accuracy"`` entry holding one
            value per epoch. Values may be plain numbers or 0-dim tensors.
    """
    # float() extracts the scalar from tensors as well (including ones living
    # on a GPU), which matplotlib cannot convert to NumPy by itself.
    accuracy = [float(value) for value in history["validation_accuracy"]]
    epochs = range(len(accuracy))

    plt.plot(accuracy, "-x")
    plt.xticks(epochs)  # one tick per epoch index
    plt.title("Accuracy x Number of Epochs")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")


plot_accuracy(history)
def plot_losses(history):
    """Draw the training and validation loss curves on the same axes.

    Args:
        history: Mapping with ``"train_loss"`` and ``"validation_loss"``
            entries, each holding one value per epoch.
    """
    # (values, matplotlib format string) per curve, in drawing order.
    loss_curves = [
        (history["train_loss"], "-rx"),       # red line, x markers
        (history["validation_loss"], "-bx"),  # blue line, x markers
    ]
    epoch_ticks = range(len(loss_curves[0][0]))

    for values, line_format in loss_curves:
        plt.plot(values, line_format)

    plt.xticks(epoch_ticks)
    plt.legend(["Training loss", "Validation loss"])
    plt.xlabel("Epochs")
    plt.ylabel("Losses (%)")


plot_losses(history)
def evaluate(model, data_loader):
    """Compute the accuracy of ``model`` on ``data_loader`` and collect labels.

    The model is switched to eval mode and left there.

    Args:
        model: Network to evaluate; called as ``model(inputs)``.
        data_loader: Loader yielding ``(inputs, labels)`` batches and exposing
            ``.dataset`` (its length is the accuracy denominator).

    Returns:
        An ``(accuracy, y_true, y_hat)`` tuple. ``accuracy`` is a 0-dim double
        tensor. ``y_true`` and ``y_hat`` are lists with one tensor per batch
        holding the true and the predicted labels, in loader order.
    """
    correct_predictions = 0
    model.eval()
    y_hat = list()
    y_true = list()

    # Inference only: without no_grad every forward pass would build an
    # autograd graph that is never used.
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            y_hat.append(preds)
            # Record the ground truth next to the prediction so both lists stay
            # aligned batch by batch (y_true used to be returned empty).
            y_true.append(labels)
            correct_predictions += torch.sum(preds == labels)

    # as_tensor keeps this working when the loader yielded no batch at all and
    # the counter is still the initial Python int.
    accuracy = torch.as_tensor(correct_predictions).double() / len(data_loader.dataset)
    return accuracy, y_true, y_hat


accuracy, y_true, y_hat = evaluate(model, test_loader)
print(f'Accuracy: {accuracy:.4f}')
Accuracy: 0.9819
def _flatten_label_batches(batches):
    """Flatten an iterable of label batches into one flat list of ints."""
    flat_labels = []
    for batch in batches:
        flat_labels.extend(torch.as_tensor(batch).reshape(-1).cpu().tolist())
    return flat_labels


def evaluate_classes(data_loader, y_true, y_hat):
    """Build a per-class report and draw the confusion matrix as a heatmap.

    Args:
        data_loader: Loader yielding ``(inputs, labels)`` batches. It is only
            iterated when ``y_true`` is empty or ``None``.
        y_true: Sequence of true-label batches aligned with ``y_hat``. When it
            is empty or ``None`` the labels are read from ``data_loader``
            instead, which is only correct if that loader does not shuffle.
        y_hat: Sequence of predicted-label batches.

    Returns:
        The text produced by ``classification_report``, using the module-level
        ``classes`` as target names (so ``classes`` must be ordered by label
        index).

    Side effects:
        Sets the global matplotlib figure size to 20x20 and draws the heatmap
        on the current figure.
    """
    if y_true is not None and len(y_true) > 0:
        # Use the labels collected during prediction: no second pass over the
        # images, and no dependence on the loader's iteration order.
        true_labels = _flatten_label_batches(y_true)
    else:
        # Fallback for callers that pass no labels: read them from the loader.
        true_labels = _flatten_label_batches(
            labels for _, labels in data_loader
        )
    predicted_labels = _flatten_label_batches(y_hat)

    report = classification_report(
        true_labels, predicted_labels, target_names=classes, zero_division=0
    )

    plt.rcParams["figure.figsize"] = (20, 20)
    sns.heatmap(confusion_matrix(true_labels, predicted_labels))
    return report


report = evaluate_classes(test_loader, y_true, y_hat)
print(report)
precision recall f1-score support
Quince 0.93 1.00 0.96 164
Grapefruit White 0.99 1.00 1.00 148
Granadilla 1.00 1.00 1.00 160
Orange 1.00 1.00 1.00 164
Apple Red 3 1.00 1.00 1.00 161
Grape White 2 1.00 1.00 1.00 164
Corn Husk 1.00 1.00 1.00 152
Tamarillo 1.00 1.00 1.00 164
Banana Red 1.00 0.96 0.98 164
Nectarine Flat 1.00 1.00 1.00 144
Pepper Yellow 0.96 1.00 0.98 166
Nut Forest 1.00 1.00 1.00 164
Pear Monster 1.00 1.00 1.00 219
Fig 1.00 1.00 1.00 164
Tomato Heart 1.00 1.00 1.00 143
Onion Red Peeled 1.00 1.00 1.00 166
Lemon Meyer 1.00 1.00 1.00 166
Onion Red 1.00 1.00 1.00 152
Passion Fruit 0.92 1.00 0.96 166
Cucumber Ripe 1.00 0.93 0.96 150
Cactus fruit 1.00 1.00 1.00 154
Tomato not Ripened 1.00 1.00 1.00 166
Mango Red 1.00 1.00 1.00 164
Apple Pink Lady 1.00 0.99 1.00 164
Pomegranate 1.00 0.94 0.97 166
Plum 1.00 1.00 1.00 234
Pineapple 1.00 1.00 1.00 164
Tomato 1 0.95 1.00 0.97 246
Cherry 2 1.00 1.00 1.00 246
Apple Red 2 1.00 1.00 1.00 164
Avocado ripe 1.00 1.00 1.00 164
Dates 0.78 1.00 0.88 164
Maracuja 0.95 0.95 0.95 153
Papaya 1.00 1.00 1.00 166
Nut Pecan 0.87 1.00 0.93 166
Pear Stone 0.99 0.88 0.93 150
Cherry Wax Yellow 1.00 0.66 0.80 154
Eggplant 1.00 1.00 1.00 130
Apple Golden 2 1.00 1.00 1.00 156
Guava 1.00 1.00 1.00 166
Beetroot 1.00 0.85 0.92 156
Tomato Maroon 1.00 1.00 1.00 234
Potato Red 1.00 0.93 0.96 99
Apple Red Delicious 1.00 1.00 1.00 166
Cherry Wax Red 1.00 1.00 1.00 328
Kiwi 1.00 1.00 1.00 164
Cherry Wax Black 0.87 1.00 0.93 166
Limes 0.99 1.00 1.00 166
Cantaloupe 2 1.00 1.00 1.00 164
Apple Braeburn 1.00 1.00 1.00 158
Pear 1.00 1.00 1.00 166
Carambula 1.00 1.00 1.00 164
Tomato 3 1.00 1.00 1.00 166
Onion White 0.98 1.00 0.99 157
Cherry 1 1.00 1.00 1.00 166
Strawberry 1.00 1.00 1.00 166
Lychee 1.00 1.00 1.00 156
Redcurrant 0.99 1.00 1.00 157
Rambutan 1.00 1.00 1.00 166
Potato Red Washed 1.00 1.00 1.00 164
Tomato 4 1.00 1.00 1.00 166
Hazelnut 1.00 1.00 1.00 166
Tomato Yellow 1.00 1.00 1.00 166
Plum 3 1.00 1.00 1.00 166
Grape White 1.00 1.00 1.00 166
Pineapple Mini 0.97 0.99 0.98 142
Mulberry 0.98 1.00 0.99 102
Grape Blue 1.00 1.00 1.00 166
Pear Abate 0.97 1.00 0.98 246
Melon Piel de Sapo 1.00 1.00 1.00 164
Pepper Orange 1.00 1.00 1.00 164
Cauliflower 1.00 1.00 1.00 160
Nectarine 1.00 1.00 1.00 218
Salak 0.99 0.96 0.97 178
Cocos 1.00 1.00 1.00 150
Chestnut 1.00 1.00 1.00 155
Blueberry 0.98 1.00 0.99 146
Apple Granny Smith 1.00 1.00 1.00 160
Banana Lady Finger 1.00 1.00 1.00 164
Apricot 1.00 1.00 1.00 166
Walnut 1.00 1.00 1.00 164
Apple Crimson Snow 0.96 1.00 0.98 246
Grapefruit Pink 0.98 1.00 0.99 164
Tangelo 1.00 1.00 1.00 164
Peach Flat 1.00 1.00 1.00 232
Pear Forelle 1.00 0.90 0.95 166
Pepper Red 1.00 1.00 1.00 234
Tomato Cherry Red 1.00 1.00 1.00 102
Pear Williams 1.00 1.00 1.00 166
Clementine 0.98 1.00 0.99 222
Apple Golden 3 0.88 1.00 0.94 237
Apple Red 1 1.00 1.00 1.00 166
Pear 2 1.00 1.00 1.00 166
Plum 2 0.81 1.00 0.89 148
Cantaloupe 1 1.00 0.34 0.51 234
Lemon 1.00 1.00 1.00 222
Physalis with Husk 0.68 1.00 0.81 222
Peach 2 1.00 1.00 1.00 164
Pepino 0.96 0.99 0.98 164
Huckleberry 1.00 0.95 0.98 166
Potato White 1.00 1.00 1.00 163
Pitahaya Red 1.00 1.00 1.00 166
Apple Golden 1 1.00 1.00 1.00 151
Pomelo Sweetie 1.00 1.00 1.00 142
Cherry Rainier 1.00 1.00 1.00 304
Avocado 1.00 0.88 0.94 164
Apple Red Yellow 2 1.00 1.00 1.00 153
Raspberry 0.95 0.97 0.96 150
Mangostan 1.00 0.97 0.98 151
Strawberry Wedge 0.93 0.99 0.96 150
Kaki 0.96 1.00 0.98 150
Mandarine 0.94 1.00 0.97 166
Potato Sweet 1.00 0.91 0.95 164
Cucumber Ripe 2 1.00 1.00 1.00 166
Kumquats 1.00 1.00 1.00 164
Pear Red 1.00 1.00 1.00 162
Ginger Root 1.00 1.00 1.00 164
Physalis 1.00 0.99 0.99 246
Pear Kaiser 1.00 1.00 1.00 166
Peach 1.00 1.00 1.00 166
Corn 1.00 1.00 1.00 246
Grape White 3 1.00 1.00 1.00 225
Apple Red Yellow 1 1.00 1.00 1.00 246
Grape Pink 1.00 1.00 1.00 160
Banana 1.00 1.00 1.00 164
Grape White 4 1.00 0.98 0.99 228
Kohlrabi 0.97 1.00 0.98 127
Pepper Green 1.00 1.00 1.00 153
Watermelon 1.00 0.78 0.88 158
Mango 1.00 1.00 1.00 249
Tomato 2 0.99 1.00 0.99 157
accuracy 0.98 22688
macro avg 0.99 0.98 0.98 22688
weighted avg 0.98 0.98 0.98 22688
def predict(image_input, model):
    """Return every class with its predicted probability, most likely first.

    Args:
        image_input: A single image tensor without a batch dimension; one is
            added with ``unsqueeze(0)`` before the forward pass.
        model: ``torch.nn.Module`` to run. It is evaluated in eval mode and its
            previous train/eval mode is restored afterwards.

    Returns:
        A list of ``(class_name, probability)`` tuples sorted by descending
        probability. Class names come from the module-level ``classes``; each
        probability is a 0-dim CPU tensor detached from autograd.
    """
    was_training = model.training
    # Eval mode keeps layers such as dropout and batch norm deterministic for a
    # single-image batch.
    model.eval()
    try:
        # Pure inference: do not record an autograd graph, otherwise the
        # returned probabilities would keep it alive.
        with torch.no_grad():
            image = image_input.unsqueeze(0).to(device)
            output = model(image)
            # Softmax over the class dimension; [0] drops the batch dimension.
            # Moving to the CPU once avoids a device round trip per comparison
            # while sorting.
            probabilities = nnf.softmax(output, dim=1)[0].cpu()
    finally:
        model.train(was_training)

    probabilities_classes = [
        (classes[index], probability)
        for index, probability in enumerate(probabilities)
    ]

    # Order by highest probability
    ordered_probabilities = sorted(
        probabilities_classes,
        key=lambda probability_class: probability_class[1],
        reverse=True,
    )
    return ordered_probabilities
# Randomly pick one (image, label) sample from the test set
image, label = choice(test_dataset)

# Show the picture together with the name of its true class
plt.figure(dpi=80, figsize=(4, 3))
# Move the first axis last before handing the tensor to imshow
# (assumes the image tensor is channels-first — TODO confirm).
plt.imshow(image.permute(1, 2, 0))
plt.show()
print(f"Label: {classes[label]}\n")

# Rank all classes for this image and print the complete table
probabilities = predict(image, model)
print("Probability scores")
for fruit_class, probability in probabilities:
    print(f"{fruit_class:30s} {probability:.4f}")
Label: Apple Crimson Snow Probability scores Apple Crimson Snow 1.0000 Dates 0.0000 Potato Red Washed 0.0000 Lychee 0.0000 Cantaloupe 2 0.0000 Kaki 0.0000 Blueberry 0.0000 Lemon Meyer 0.0000 Pear Stone 0.0000 Limes 0.0000 Apple Golden 2 0.0000 Apple Red Yellow 2 0.0000 Tomato not Ripened 0.0000 Corn Husk 0.0000 Pomegranate 0.0000 Banana Red 0.0000 Tangelo 0.0000 Plum 3 0.0000 Pepino 0.0000 Pear Monster 0.0000 Physalis with Husk 0.0000 Peach 2 0.0000 Strawberry Wedge 0.0000 Nectarine 0.0000 Kohlrabi 0.0000 Cherry Wax Black 0.0000 Grapefruit Pink 0.0000 Apple Braeburn 0.0000 Grape Blue 0.0000 Granadilla 0.0000 Grape White 3 0.0000 Pear Abate 0.0000 Potato Red 0.0000 Banana Lady Finger 0.0000 Orange 0.0000 Peach Flat 0.0000 Carambula 0.0000 Mango 0.0000 Mango Red 0.0000 Tomato 3 0.0000 Mangostan 0.0000 Apple Red 3 0.0000 Onion Red 0.0000 Walnut 0.0000 Pepper Red 0.0000 Grape White 2 0.0000 Raspberry 0.0000 Pear 2 0.0000 Clementine 0.0000 Cherry Rainier 0.0000 Apple Red Yellow 1 0.0000 Maracuja 0.0000 Onion White 0.0000 Pineapple Mini 0.0000 Apple Pink Lady 0.0000 Potato White 0.0000 Cantaloupe 1 0.0000 Grape Pink 0.0000 Apple Granny Smith 0.0000 Kumquats 0.0000 Cherry 2 0.0000 Cucumber Ripe 2 0.0000 Plum 0.0000 Cocos 0.0000 Onion Red Peeled 0.0000 Tomato 4 0.0000 Pear Williams 0.0000 Cherry 1 0.0000 Pear 0.0000 Tomato 2 0.0000 Grape White 0.0000 Kiwi 0.0000 Nut Forest 0.0000 Pepper Yellow 0.0000 Banana 0.0000 Pepper Green 0.0000 Eggplant 0.0000 Chestnut 0.0000 Apple Red Delicious 0.0000 Peach 0.0000 Potato Sweet 0.0000 Pear Forelle 0.0000 Cherry Wax Red 0.0000 Apricot 0.0000 Nectarine Flat 0.0000 Rambutan 0.0000 Salak 0.0000 Melon Piel de Sapo 0.0000 Fig 0.0000 Physalis 0.0000 Apple Red 1 0.0000 Mandarine 0.0000 Lemon 0.0000 Hazelnut 0.0000 Tomato Cherry Red 0.0000 Redcurrant 0.0000 Tamarillo 0.0000 Cauliflower 0.0000 Pepper Orange 0.0000 Pineapple 0.0000 Mulberry 0.0000 Guava 0.0000 Strawberry 0.0000 Corn 0.0000 Tomato 1 0.0000 Pear Red 0.0000 Pitahaya Red 0.0000 
Tomato Maroon 0.0000 Watermelon 0.0000 Nut Pecan 0.0000 Avocado ripe 0.0000 Plum 2 0.0000 Pear Kaiser 0.0000 Ginger Root 0.0000 Avocado 0.0000 Beetroot 0.0000 Cactus fruit 0.0000 Apple Red 2 0.0000 Huckleberry 0.0000 Cucumber Ripe 0.0000 Apple Golden 3 0.0000 Quince 0.0000 Grapefruit White 0.0000 Passion Fruit 0.0000 Pomelo Sweetie 0.0000 Apple Golden 1 0.0000 Grape White 4 0.0000 Cherry Wax Yellow 0.0000 Tomato Yellow 0.0000 Papaya 0.0000 Tomato Heart 0.0000