# How to define a simple artificial neural network in PyTorch?


To define a simple artificial neural network (ANN), you can follow these steps:

### Steps

• First, import the required libraries and packages. We implement a simple ANN in PyTorch, so in all the following examples the required Python library is torch. Make sure you have already installed it.

```python
import torch
import torch.nn as nn
```
• The next step is to build a simple ANN model. Here, we use the nn package to implement our model. For this, we define a class MyNetwork that inherits from nn.Module, that is, we pass nn.Module as its base class.

```python
class MyNetwork(nn.Module):
    ...  # the layers and forward() are defined in the next step
```
• We need to define two methods inside the class to get our model ready: `__init__()` and `forward()`. Within `__init__()`, we call `super().__init__()` and define the layers. Within `forward()`, we describe how an input tensor passes through those layers and return the result.

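• We need to instantiate the class before we can train it on a dataset. Instantiating the class runs `__init__()`, which creates the layers; `forward()` is executed later, each time the model is called on an input tensor, as in `model(x)`.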

```python
model = MyNetwork()
```
• Print the model to see the different layers.

```python
print(model)
```
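Putting the steps above together, here is a minimal sketch of a complete `MyNetwork`. The layer names and sizes (4 inputs, 8 hidden units, 1 output) and the random input batch are illustrative assumptions, not values given in the steps. The last lines show that `forward()` runs when the model instance is called on a tensor, not when the class is instantiated.

```python
import torch
import torch.nn as nn


class MyNetwork(nn.Module):
    def __init__(self):
        # Initialise nn.Module's internal state before registering layers
        super().__init__()
        # Illustrative (assumed) sizes: 4 inputs -> 8 hidden units -> 1 output
        self.hidden = nn.Linear(4, 8)
        self.relu = nn.ReLU()
        self.output = nn.Linear(8, 1)

    def forward(self, x):
        # Describe how an input tensor flows through the layers
        x = self.relu(self.hidden(x))
        return self.output(x)


# Instantiation runs __init__(), which creates the layers
model = MyNetwork()
print(model)

# Calling the model runs forward() on a batch of 3 random samples
x = torch.randn(3, 4)
y = model(x)
print(y.shape)  # torch.Size([3, 1])
```

Calling `model(x)` rather than `model.forward(x)` directly is the usual idiom, because `nn.Module.__call__` also runs any registered hooks around `forward()`.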

## Example 1

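In the following example, we create a simple artificial neural network with four layers (two linear layers, each followed by an activation) using `nn.Sequential`. Because `nn.Sequential` applies its modules in order, we do not need to write a `forward()` function.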

```python
# Import the required libraries
import torch
from torch import nn

# define a simple sequential model
model = nn.Sequential(
    nn.Linear(32, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.Sigmoid()
)

# print the model
print(model)
```

### Output

```
Sequential(
  (0): Linear(in_features=32, out_features=128, bias=True)
  (1): ReLU()
  (2): Linear(in_features=128, out_features=10, bias=True)
  (3): Sigmoid()
)
```
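To see the `Sequential` model in action, you can pass a batch of data through it. This is a minimal sketch that assumes a hypothetical batch of 5 random samples with 32 features each, matching the first `nn.Linear(32, 128)`. Each sample produces 10 output values, and because the last layer is `nn.Sigmoid()`, every value lies between 0 and 1.

```python
import torch
from torch import nn

# Same architecture as Example 1
model = nn.Sequential(
    nn.Linear(32, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.Sigmoid()
)

# Hypothetical input: 5 samples, 32 features each
x = torch.randn(5, 32)

# nn.Sequential supplies forward(): each module is applied in order
y = model(x)
print(y.shape)  # torch.Size([5, 10])

# Sigmoid maps every output into the range [0, 1]
in_range = bool(((y >= 0) & (y <= 1)).all())
print(in_range)  # True

# Parameters = weights + biases: (32*128 + 128) + (128*10 + 10) = 5514
num_params = sum(p.numel() for p in model.parameters())
print(num_params)  # 5514
```

The activation modules (`nn.ReLU`, `nn.Sigmoid`) have no parameters, so the count comes entirely from the two `nn.Linear` layers.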

## Example 2

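The following Python program shows a different way to build a simple neural network: subclassing `nn.Module` and defining the layers in `__init__()` and the computation in `forward()`.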

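```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        # Four fully connected layers: 4 -> 8 -> 16 -> 4 -> 1
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 16)
        self.fc3 = nn.Linear(16, 4)
        self.fc4 = nn.Linear(4, 1)

    def forward(self, x):
        # Apply ReLU after each hidden layer
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        # The final layer produces the output, which forward() must return
        return self.fc4(x)


model = MyNet()
print(model)
```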

### Output

```
MyNet(
  (fc1): Linear(in_features=4, out_features=8, bias=True)
  (fc2): Linear(in_features=8, out_features=16, bias=True)
  (fc3): Linear(in_features=16, out_features=4, bias=True)
  (fc4): Linear(in_features=4, out_features=1, bias=True)
)
```
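Note that `print(model)` lists every layer registered in `__init__()`, in assignment order, whether or not `forward()` actually uses it. A forward pass is therefore a better check that the layers are wired correctly. This is a minimal sketch that assumes a hypothetical batch of 2 random samples with 4 features each. It repeats the `MyNet` definition so the block runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 16)
        self.fc3 = nn.Linear(16, 4)
        self.fc4 = nn.Linear(4, 1)

    def forward(self, x):
        # ReLU after each hidden layer; the last layer is left linear
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)


model = MyNet()

# Hypothetical batch: 2 samples with 4 features each
x = torch.randn(2, 4)
y = model(x)
print(y.shape)  # torch.Size([2, 1])

# Parameters per layer = in*out weights + out biases:
# 40 + 144 + 68 + 5 = 257
num_params = sum(p.numel() for p in model.parameters())
print(num_params)  # 257
```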
Updated on 25-Jan-2022 08:39:11