How to get the data type of a tensor in PyTorch?


A PyTorch tensor is homogeneous, i.e., all of its elements have the same data type. We can access the data type of a tensor using its ".dtype" attribute, which returns a torch.dtype object such as torch.float32 or torch.int64.
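Because a tensor holds a single data type, PyTorch picks one dtype for the whole tensor when you build it from Python values. The sketch below assumes PyTorch's out-of-the-box default floating-point type (torch.float32). It shows that all-integer input becomes int64, that mixing integers and floats promotes every element to a floating-point type, and that ".dtype" returns a torch.dtype object rather than a string.

```python
import torch

# All-integer input: PyTorch infers a 64-bit integer type
ints = torch.tensor([1, 2, 3])

# Mixing an int with a float promotes every element to the default float type
mixed = torch.tensor([1, 2.5, 3])

# Booleans get their own dtype
flags = torch.tensor([True, False])

for name, t in [("ints", ints), ("mixed", mixed), ("flags", flags)]:
    # .dtype is a torch.dtype object, not a string
    print(name, t.dtype, isinstance(t.dtype, torch.dtype))
```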

Steps

  • Import the required library. In all the following Python examples, the required Python library is torch. Make sure you have already installed it.

  • Create a tensor and print it.

  • Access T.dtype, where T is the tensor whose data type we want. Because dtype is an attribute, not a method, no parentheses are needed.

  • Print the data type of the tensor.
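The steps above only read the data type. A common follow-up is to request a specific dtype when creating a tensor, or to convert an existing tensor, and then confirm the result with ".dtype". The minimal sketch below uses the standard dtype= argument of torch.zeros and the Tensor.to() method.

```python
import torch

# Request a 64-bit float tensor explicitly at creation time
T = torch.zeros(2, 3, dtype=torch.float64)
print("Created as:", T.dtype)

# Convert to a 32-bit integer tensor; .to() returns a new tensor
T_int = T.to(torch.int32)
print("After .to():", T_int.dtype)

# The original tensor keeps its dtype
print("Original still:", T.dtype)

# Comparing dtypes is a plain equality check
if T_int.dtype == torch.int32:
    print("T_int holds 32-bit integers")
```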

Example 1

The following Python program shows how to get the data type of a tensor.

# Import the library
import torch

# Create a tensor of random numbers of size 3x4
T = torch.randn(3,4)
print("Original Tensor T:\n", T)

# Get the data type of the above tensor
data_type = T.dtype

# Print the data type of the tensor
print("Data type of tensor T:\n", data_type)

Output

Original Tensor T:
tensor([[ 2.1768, -0.1328,  0.8155, -0.7967],
        [ 0.1194,  1.0465,  0.0779,  0.9103],
        [-0.1809,  1.8085,  0.8393, -0.2463]])
Data type of tensor T:
torch.float32
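In Example 1, torch.randn returns torch.float32 because that is PyTorch's default floating-point dtype. The sketch below reads the current default with torch.get_default_dtype() and shows that passing dtype= overrides it for a single call. It deliberately avoids torch.set_default_dtype() so the global setting is left unchanged.

```python
import torch

# The default floating-point dtype used by factories such as torch.randn
print("Default dtype:", torch.get_default_dtype())

# Without a dtype argument, randn follows the default
A = torch.randn(3, 4)
print("A:", A.dtype)

# Passing dtype= overrides the default for this call only
B = torch.randn(3, 4, dtype=torch.float64)
print("B:", B.dtype)
```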

Example 2

The following Python program gets the data type of a tensor created with the torch.Tensor constructor.

# Import the library
import torch

# Create a 1-D tensor from a list of integers with the torch.Tensor constructor
T = torch.Tensor([1,2,3,4])
print("Original Tensor T:\n", T)

# Get the data type of the above tensor
data_type = T.dtype

# Print the data type of the tensor
print("Data type of tensor T:\n", data_type)

Output

Original Tensor T:
tensor([1., 2., 3., 4.])
Data type of tensor T:
torch.float32
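Example 2 prints torch.float32 even though the input values are integers, because torch.Tensor is the legacy constructor that always uses the default tensor type (float32 unless it has been changed). The lowercase torch.tensor factory instead infers the dtype from the data. The short comparison below also calls Tensor.type(), which returns the older type name as a string; the names shown assume CPU tensors.

```python
import torch

# Legacy constructor: uses the default float type regardless of the input values
T_legacy = torch.Tensor([1, 2, 3, 4])

# Factory function: infers int64 from Python ints
T_inferred = torch.tensor([1, 2, 3, 4])

print("torch.Tensor:", T_legacy.dtype, T_legacy.type())
print("torch.tensor:", T_inferred.dtype, T_inferred.type())
```

For new code, torch.tensor (optionally with an explicit dtype= argument) is usually the clearer choice, since the resulting dtype follows the data you pass in.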
raja
Updated on 06-Nov-2021 10:06:20
