How to define a simple artificial neural network in PyTorch?


To define a simple artificial neural network (ANN) in PyTorch, you can follow these steps:

Steps

  • First, import the required libraries. In all the following examples, the only required Python library is torch (PyTorch), so make sure it is installed before you start.

import torch
import torch.nn as nn
  • Next, build the model using the nn package. To do this, define a class MyNetwork that subclasses nn.Module, that is, nn.Module is given as the base class in the class definition.

class MyNetwork(nn.Module):
   ...  # __init__() and forward() are defined in the next step
  • Define two methods inside the class: __init__() and forward(). In __init__(), first call super().__init__() and then define the layers as attributes of the class. In forward(), describe how an input tensor passes through those layers and return the result.

  • Instantiate the class so that the model can be trained on a dataset. Instantiating the class runs __init__(), which creates the layers. forward() runs later, whenever the model is called on an input tensor, for example model(x).

model = MyNetwork()
  • Print the model to see the different layers.

print(model)
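Putting these steps together, a complete MyNetwork class might look like the minimal sketch below. Several details are illustrative choices, not part of the original:

* The layer sizes (8 inputs, 16 hidden units, 2 outputs).
* The attribute names fc1, relu and fc2.
* The batch of 3 random samples.

The sketch also shows that forward() runs only when the model is called on a tensor, not when the class is instantiated.

```python
import torch
import torch.nn as nn


class MyNetwork(nn.Module):
    def __init__(self):
        # Initialise the nn.Module internals before registering any layers
        super().__init__()
        # Hypothetical layer sizes chosen for illustration only
        self.fc1 = nn.Linear(8, 16)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(16, 2)

    def forward(self, x):
        # Defines how an input tensor flows through the layers
        x = self.relu(self.fc1(x))
        return self.fc2(x)


# Instantiating the class runs __init__(), which creates the layers
model = MyNetwork()
print(model)

# Calling the model on a tensor runs forward(); here a batch of 3 samples with 8 features
x = torch.randn(3, 8)
out = model(x)
print(out.shape)  # torch.Size([3, 2])
```

Note that model(x) is preferred over calling model.forward(x) directly, because calling the module also runs any registered hooks.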

Example 1

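In the following example, we create a simple artificial neural network with nn.Sequential. It stacks four modules (two Linear layers, each followed by an activation). Because nn.Sequential applies the modules in order, no custom forward() function needs to be written.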

# Import the required libraries
import torch
from torch import nn

# define a simple sequential model
model = nn.Sequential(
   nn.Linear(32, 128),
   nn.ReLU(),
   nn.Linear(128, 10),
   nn.Sigmoid()
)

# print the model
print(model)

Output

Sequential(
   (0): Linear(in_features=32, out_features=128, bias=True)
   (1): ReLU()
   (2): Linear(in_features=128, out_features=10, bias=True)
   (3): Sigmoid()
)
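To check that the Sequential model from Example 1 actually runs, the sketch below passes a random batch through it. The batch size of 4 is an arbitrary assumption.

Because the last module is Sigmoid, every output value should fall in the range 0 to 1. The two Linear layers hold (32 * 128 + 128) + (128 * 10 + 10) = 5514 trainable parameters.

```python
import torch
from torch import nn

# Same model as in Example 1
model = nn.Sequential(
    nn.Linear(32, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.Sigmoid()
)

# A random batch of 4 samples with 32 features each (batch size chosen for illustration)
x = torch.randn(4, 32)
out = model(x)  # nn.Sequential calls each module in order

print(out.shape)  # torch.Size([4, 10])

# Sigmoid maps every output into the range [0, 1]
in_range = bool(((out >= 0) & (out <= 1)).all())
print(in_range)

# Count the trainable parameters (weights and biases of both Linear layers)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_params)  # 5514
```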

Example 2

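The following Python program shows a different way to build a simple neural network. It subclasses nn.Module, defines the layers in __init__(), and applies them with ReLU and sigmoid activations in forward().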

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyNet(nn.Module):
   def __init__(self):
      super(MyNet, self).__init__()
      # Four fully connected layers: 4 -> 8 -> 16 -> 4 -> 1
      self.fc1 = nn.Linear(4, 8)
      self.fc2 = nn.Linear(8, 16)
      self.fc3 = nn.Linear(16, 4)
      self.fc4 = nn.Linear(4, 1)

   def forward(self, x):
      # ReLU activation after each hidden layer
      x = F.relu(self.fc1(x))
      x = F.relu(self.fc2(x))
      x = F.relu(self.fc3(x))
      # Sigmoid maps the single output value into [0, 1]
      return torch.sigmoid(self.fc4(x))

model = MyNet()
print(model)

Output

MyNet(
   (fc1): Linear(in_features=4, out_features=8, bias=True)
   (fc2): Linear(in_features=8, out_features=16, bias=True)
   (fc3): Linear(in_features=16, out_features=4, bias=True)
   (fc4): Linear(in_features=4, out_features=1, bias=True)
)
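A forward pass through MyNet can be sketched as below. The batch of 5 random samples is an arbitrary assumption.

* Each sample has 4 input features to match fc1.
* fc4 followed by the sigmoid produces a single value per sample. In a binary classification setting, that value could be read as a probability.
* The parameter count follows from the layer sizes: (4 * 8 + 8) + (8 * 16 + 16) + (16 * 4 + 4) + (4 * 1 + 1) = 257.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyNet(nn.Module):
    # Same architecture as Example 2
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 16)
        self.fc3 = nn.Linear(16, 4)
        self.fc4 = nn.Linear(4, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return torch.sigmoid(self.fc4(x))


model = MyNet()

# Batch of 5 samples with 4 features each (batch size chosen for illustration)
x = torch.randn(5, 4)

# torch.no_grad() disables gradient tracking, which is enough for a quick inference check
with torch.no_grad():
    probs = model(x)

print(probs.shape)  # torch.Size([5, 1])

# Total number of trainable weights and biases across fc1..fc4
num_params = sum(p.numel() for p in model.parameters())
print(num_params)  # 257
```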

Updated on: 25-Jan-2022
