How to load a Computer Vision dataset in PyTorch?
Many datasets for computer vision tasks are available in PyTorch. torch.utils.data.Dataset is the abstract class that represents a dataset. The torchvision.datasets module provides many ready-to-use image and video datasets, all implemented as subclasses of torch.utils.data.Dataset. PyTorch also provides torch.utils.data.DataLoader, which wraps a dataset and loads samples from it in batches, optionally shuffling them and loading them in parallel worker processes.
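To make the relationship between Dataset and DataLoader concrete, here is a minimal sketch of a map-style dataset. The class name RandomImageDataset and its random tensors are hypothetical stand-ins for real images. A map-style dataset only has to implement __len__ and __getitem__; the DataLoader then handles batching and shuffling.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class RandomImageDataset(Dataset):
    """Hypothetical dataset of random 3x32x32 "images" with integer labels."""

    def __init__(self, num_samples: int, num_classes: int = 10):
        # Pre-generate the data so that dataset[i] always returns the same sample
        generator = torch.Generator().manual_seed(0)
        self.images = torch.rand(num_samples, 3, 32, 32, generator=generator)
        self.labels = torch.randint(0, num_classes, (num_samples,), generator=generator)

    def __len__(self) -> int:
        # DataLoader uses this to know which indices it can draw
        return len(self.labels)

    def __getitem__(self, index: int):
        # Return one (image, label) pair; DataLoader stacks these into batches
        return self.images[index], self.labels[index]


dataset = RandomImageDataset(num_samples=10)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

for images, labels in loader:
    print(images.shape, labels.shape)
```

The torchvision datasets used below follow the same protocol, so they can be passed to DataLoader in exactly the same way.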
Steps
You can use the following steps to load a computer vision dataset:
Import the required libraries. In all the following examples, the required Python libraries are torch, Matplotlib, and torchvision. Make sure you have already installed them.
import torch
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
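Load the CIFAR10 training and test datasets using datasets.CIFAR10(), with train=True for the training dataset and train=False for the test dataset. Here root is the directory where the data is stored, download=True downloads the files if they are not already present, and transform=ToTensor() converts each image to a tensor.
training_data = datasets.CIFAR10(root="data", train=True, download=True, transform=ToTensor())
test_data = datasets.CIFAR10(root="data", train=False, download=True, transform=ToTensor())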
Define a training dataloader (trainloader) and a test dataloader (testloader), and specify the batch_size. Set shuffle=True to get the images in random order. Also access the class label names through the dataset's classes attribute.
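A minimal sketch of what batch_size and shuffle control, using a small hypothetical TensorDataset instead of CIFAR10 so that nothing has to be downloaded. With N samples the loader yields ceil(N / batch_size) batches (the last one may be smaller unless drop_last=True is passed), and shuffle=True draws a new random order on every pass.

```python
import math
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical stand-in for CIFAR10: 10 blank "images" labelled 0..9 in order
images = torch.zeros(10, 3, 32, 32)
labels = torch.arange(10)
dataset = TensorDataset(images, labels)

batch_size = 4
ordered_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
shuffled_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Number of batches per epoch: ceil(10 / 4) = 3, the last batch holds 2 samples
print(len(ordered_loader), math.ceil(len(dataset) / batch_size))  # 3 3

# Without shuffling, labels come back in dataset order
print([batch_labels.tolist() for _, batch_labels in ordered_loader])
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# With shuffling, every label still appears exactly once per epoch
seen = sorted(label for _, batch_labels in shuffled_loader for label in batch_labels.tolist())
print(seen)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

Shuffling is usually enabled for training so that batches differ between epochs, and disabled for testing, where the order does not matter.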
Get a batch of images and labels from the training (or test) dataloader.
dataiter = iter(trainloader)
images, labels = next(dataiter)
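Calling next() on iter(trainloader) fetches only one batch. To go over the whole dataset, for example once per training epoch, the usual pattern is a for loop over the loader. A minimal sketch, again with a hypothetical in-memory TensorDataset standing in for CIFAR10:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical data: 10 random 3x32x32 images with labels in 0..9
dataset = TensorDataset(torch.rand(10, 3, 32, 32), torch.randint(0, 10, (10,)))
loader = DataLoader(dataset, batch_size=4, shuffle=True)

# One batch: use the built-in next() on an iterator over the loader
images, labels = next(iter(loader))
print(images.shape, labels.shape)  # torch.Size([4, 3, 32, 32]) torch.Size([4])

# A full pass (one epoch) over the dataset
num_seen = 0
for batch_images, batch_labels in loader:
    num_seen += batch_images.size(0)
print(num_seen)  # 10
```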
Visualize the obtained images with the labels.
Example 1
In the following Python program, we load CIFAR10 training and test datasets.
# Import the required libraries
import torch
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor

# define batch size
batch_size = 4

# download CIFAR10 training and test datasets
training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

# define train and test dataloader
trainloader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)

# access names of the labels
label_names = training_data.classes

# display details about the dataset
print("label_names:\n", label_names)
print("class label name to index:\n", training_data.class_to_idx)
print("Shape of training data:\n", training_data.data.shape)
print("Shape of test data:\n", test_data.data.shape)
Output
Files already downloaded and verified
Files already downloaded and verified
label_names:
 ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
class label name to index:
 {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
Shape of training data:
 (50000, 32, 32, 3)
Shape of test data:
 (10000, 32, 32, 3)
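The classes attribute lists the names in label-index order, and class_to_idx maps each name to its index. Inverting that dictionary gives the index-to-name lookup needed to turn a batch of integer labels back into readable names. A minimal sketch using the CIFAR10 class list from the output above; the labels tensor is a hypothetical batch.

```python
import torch

# CIFAR10 class names, in the order reported by training_data.classes
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
class_to_idx = {name: idx for idx, name in enumerate(classes)}

# Invert class_to_idx to look up a name from an integer label
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

# Hypothetical batch of labels, as returned by a DataLoader
labels = torch.tensor([3, 8, 8, 0])
names = [idx_to_class[label.item()] for label in labels]
print(names)  # ['cat', 'ship', 'ship', 'airplane']
```

Because classes is already ordered by index, indexing it directly (classes[label]) gives the same result.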
Example 2
In this Python program, we load the CIFAR10 dataset. We also visualize some random images with their label names.
import torch
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

batch_size = 4

# the main guard is needed when num_workers > 0 on platforms that start
# worker processes with "spawn" (Windows and macOS)
if __name__ == "__main__":
    training_data = datasets.CIFAR10(root="data", train=True, download=True, transform=ToTensor())
    test_data = datasets.CIFAR10(root="data", train=False, download=True, transform=ToTensor())

    # shuffle=True so that each run shows a different random batch
    trainloader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=2)
    testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)
    label_names = training_data.classes

    # get some random training images
    dataiter = iter(trainloader)
    images, labels = next(dataiter)

    # display random images
    # define figure
    fig = plt.figure(figsize=(8, 5))
    columns, rows = batch_size, 1

    # visualize these random images
    for i in range(1, columns*rows + 1):
        fig.add_subplot(rows, columns, i)
        # imshow expects (height, width, channels), so move channels last
        plt.imshow(images[i-1].numpy().transpose(1, 2, 0))
        plt.xticks([])
        plt.yticks([])
        plt.title(label_names[labels[i-1].item()])
    plt.show()
Output
Files already downloaded and verified
Files already downloaded and verified
