How to access the metadata of a tensor in PyTorch?


The size (or shape) of a tensor and the number of elements it contains are referred to as the tensor's metadata. To access the size of a tensor, we use the .size() method; to access its shape, we use the .shape attribute.
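The following minimal sketch reuses the 4x3 tensor from Example 1 to show the difference between calling the .size() method and reading the .shape attribute. It also shows that .size() accepts an optional dimension index and that the returned torch.Size object is a subclass of Python's tuple, so it supports indexing and tuple comparison.

```python
import torch

# Same 4x3 values as in Example 1, stored as 32-bit floats
T = torch.tensor([[1, 2, 3], [2, 1, 3], [2, 3, 5], [5, 6, 4]], dtype=torch.float32)

# .size() is a method call; .shape is an attribute (no parentheses)
print(T.size())   # torch.Size([4, 3])
print(T.shape)    # torch.Size([4, 3])

# Passing a dimension index to .size() returns that dimension's length as an int
print(T.size(0))  # 4
print(T.size(1))  # 3

# torch.Size is a subclass of tuple, so it can be indexed and compared with tuples
print(T.shape[0], T.shape[-1])  # 4 3
print(T.shape == (4, 3))        # True
```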

Both .size() and .shape return the same torch.Size object. To find the total number of elements in the tensor, we use the torch.numel() function.
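As a rough illustration, the element count returned by torch.numel() equals the product of all dimension lengths. The sketch below compares the function form, the equivalent Tensor.numel() method, and an explicit product computed with math.prod (which assumes Python 3.8 or later). It also checks two edge cases: a 0-dimensional tensor and a tensor with a zero-length dimension.

```python
import math
import torch

T = torch.randn(4, 3, 2)

# Three equivalent ways to count the elements
n_func = torch.numel(T)          # function form
n_method = T.numel()             # method form
n_prod = math.prod(T.shape)      # product of all dimension lengths (Python 3.8+)
print(n_func, n_method, n_prod)  # 24 24 24

# Edge cases: a 0-dimensional (scalar) tensor has one element,
# and a tensor with a zero-length dimension has none.
scalar = torch.tensor(3.0)
empty = torch.zeros(0, 5)
print(scalar.shape, torch.numel(scalar))  # torch.Size([]) 1
print(empty.shape, torch.numel(empty))    # torch.Size([0, 5]) 0
```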

Steps

  • Import the required library, torch. Make sure that torch is installed.

  • Define a PyTorch tensor.

  • Find the metadata of the tensor. Use .size() and .shape to access the size and shape of the tensor. Use torch.numel() to access the number of elements in the tensor.

  • Print the tensor and the metadata for better understanding.
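The steps above can be wrapped in a small helper. This is only a sketch: the function name tensor_metadata is hypothetical and not part of PyTorch, and it collects just the three values discussed in this article.

```python
import torch

def tensor_metadata(t: torch.Tensor) -> dict:
    """Hypothetical helper: gather the size, shape, and element count of a tensor."""
    return {
        "size": t.size(),
        "shape": t.shape,
        "numel": torch.numel(t),
    }

# Define a tensor, then print it together with its metadata
T = torch.tensor([[1.0, 2.0, 3.0], [2.0, 1.0, 3.0]])
print("T:\n", T)
for name, value in tensor_metadata(T).items():
    print(f"{name}: {value}")
```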

Example 1

# Python Program to access meta-data of a Tensor
# import necessary libraries
import torch

# Create a tensor of size 4x3
T = torch.Tensor([[1,2,3],[2,1,3],[2,3,5],[5,6,4]])
print("T:\n", T)

# Find the meta-data of tensor
# Find the size of the above tensor "T"
size_T = T.size()
print("size of tensor T:\n", size_T)

# Alternative: get the same size through the .shape attribute
print("Shape of tensor:\n", T.shape)

# Find the number of elements in the tensor "T"
num_T = torch.numel(T)
print("Number of elements in tensor T:\n", num_T)

Output

When you run the above Python 3 code, it will produce the following output.

T:
tensor([[1., 2., 3.],
        [2., 1., 3.],
        [2., 3., 5.],
        [5., 6., 4.]])
size of tensor T:
torch.Size([4, 3])
Shape of tensor:
torch.Size([4, 3])
Number of elements in tensor T:
12
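The count of 12 in this output is simply 4 rows times 3 columns. Because torch.Size behaves like a tuple, a minimal sketch can unpack it into named dimensions and confirm this; len() of the shape and the Tensor.dim() method both report the number of dimensions.

```python
import torch

T = torch.Tensor([[1, 2, 3], [2, 1, 3], [2, 3, 5], [5, 6, 4]])

# Unpack the shape of a 2-D tensor into its two dimension lengths
rows, cols = T.shape
print(rows, cols)                      # 4 3
print(rows * cols == torch.numel(T))   # True

# Number of dimensions, from the shape or from the tensor itself
print(len(T.shape), T.dim())           # 2 2
```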

Example 2

# Python Program to access meta-data of a Tensor
# import the libraries
import torch

# Create a 4x3x2 tensor of random numbers drawn from a standard normal distribution
T = torch.randn(4,3,2)
print("T:\n", T)

# Find the meta-data of tensor
# Find the size of the above tensor "T"
size_T = T.size()
print("size of tensor T:\n", size_T)

# Alternative: get the same size through the .shape attribute
print("Shape of tensor:\n", T.shape)

# Find the number of elements in the tensor "T"
num_T = torch.numel(T)
print("Number of elements in tensor T:\n", num_T)

Output

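When you run the above Python 3 code, it will produce output similar to the following. Because torch.randn() draws random values, the tensor elements will differ on each run, but the size, shape, and number of elements will be the same.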

T:
tensor([[[-1.1806,  0.5569],
         [ 2.2237,  0.9709],
         [ 0.4775, -0.2491]],

        [[-0.9703,  1.9916],
         [ 0.1998, -0.6501],
         [-0.7489, -1.3013]],

        [[ 1.3191,  2.0049],
         [-0.1195,  0.1860],
         [-0.6061, -1.2451]],

        [[-0.6044,  0.6153],
         [-2.2473, -0.1531],
         [ 0.5341,  1.3697]]])
size of tensor T:
torch.Size([4, 3, 2])
Shape of tensor:
torch.Size([4, 3, 2])
Number of elements in tensor T:
24
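Since the values in Example 2 are random, a minimal sketch can help separate what changes between runs from what does not. Two torch.randn() calls generally return different values but always share the requested shape and element count. Seeding the generator with torch.manual_seed() makes the values themselves repeat (checked here with torch.equal on the CPU).

```python
import torch

# Two tensors drawn with torch.randn generally hold different values...
A = torch.randn(4, 3, 2)
B = torch.randn(4, 3, 2)

# ...but their metadata depends only on the requested size, not on the values
print(A.shape == B.shape)               # True
print(torch.numel(A), torch.numel(B))   # 24 24

# Re-seeding before each draw makes the values reproducible
torch.manual_seed(0)
C = torch.randn(4, 3, 2)
torch.manual_seed(0)
D = torch.randn(4, 3, 2)
print(torch.equal(C, D))                # True
```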
raja
Updated on 06-Nov-2021 09:39:31
