How to narrow down a tensor in PyTorch?
The torch.narrow() method narrows a PyTorch tensor: it returns a tensor that keeps only a range of consecutive indices along one dimension of the input tensor. The result is a view that shares the same underlying storage as the input, so no data is copied.
For example, a tensor of size [4, 3] can be narrowed to size [2, 3] (along dimension 0) or to size [4, 2] (along dimension 1). A single call narrows only one dimension, so one call cannot turn a [4, 3] tensor into a [2, 2] tensor; that takes two successive calls. The method form, Tensor.narrow(), performs the same operation.
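To reach a [2, 2] result, the two narrowing steps can be chained, each one shrinking a different dimension. A minimal sketch, using an illustrative 4x3 tensor built with torch.arange (the tensor and the chosen ranges are assumptions for demonstration only):

```python
import torch

# Illustrative 4x3 tensor: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
t = torch.arange(12).reshape(4, 3)

# First call narrows dimension 0: keep rows 1 and 2 -> size [2, 3]
rows = torch.narrow(t, 0, 1, 2)
# Second call narrows dimension 1: keep columns 0 and 1 -> size [2, 2]
block = torch.narrow(rows, 1, 0, 2)

print(block.size())                    # torch.Size([2, 2])
print(torch.equal(block, t[1:3, 0:2]))  # True, same as two-dimensional slicing
```

The final check shows that the chained calls select the same elements as the slice t[1:3, 0:2].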
Syntax
torch.narrow(input, dim, start, length)
Tensor.narrow(dim, start, length)
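The two forms differ only in how the tensor is passed: the function takes it as its first argument, while the method is called on the tensor itself. A small check, using an illustrative 3x3 tensor:

```python
import torch

t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Function form: the tensor is the first argument
a = torch.narrow(t, 1, 0, 2)
# Method form: the tensor is the object the method is called on
b = t.narrow(1, 0, 2)

print(torch.equal(a, b))  # True
print(a)
```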
Parameters
input – the PyTorch tensor to narrow.
dim – the dimension along which to narrow input.
start – the index along dim at which the narrowed range begins.
length – the number of elements to keep along dim, counting from start. start + length must not exceed the size of dim.
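In other words, narrow(dim, start, length) keeps the indices start through start + length - 1 along dim, which is the same selection as a basic slice. The sketch below checks this on an illustrative tensor and shows one difference: slicing clips an out-of-range end index, whereas narrow() raises an error. The except clause catches both RuntimeError and IndexError because the exact exception type is an assumption that may vary between PyTorch versions.

```python
import torch

t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
start, length = 1, 2

# Along dim 0, narrow() keeps rows start, ..., start + length - 1
assert torch.equal(t.narrow(0, start, length), t[start:start + length])
# Along dim 1, it keeps the same range of columns
assert torch.equal(t.narrow(1, start, length), t[:, start:start + length])

# t[2:4] silently clips to a single row, but narrow() rejects
# a range whose end (start + length) exceeds the size of dim
raised = False
try:
    t.narrow(0, 2, 2)  # 2 + 2 > 3 rows
except (RuntimeError, IndexError) as err:  # exact type assumed to vary by version
    raised = True
    print("narrow() failed:", err)
print(raised)  # True
```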
Steps
Import the torch library. Make sure you have already installed it.
import torch
Create a PyTorch tensor and print the tensor and its size.
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print("Tensor:\n", t)
print("Size of tensor:", t.size()) # size 3x3
Compute torch.narrow(input, dim, start, length) and assign the value to a variable.
t1 = torch.narrow(t, 0, 1, 2)
Print the resultant tensor and its size, after narrowing.
print("Tensor after Narrowing:\n", t1)
print("Size after Narrowing:", t1.size())
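Because the narrowed tensor is a view, writing to it also changes the original tensor. The sketch below demonstrates this with the same 3x3 tensor and shows clone() as one way to get an independent copy when the original must stay untouched:

```python
import torch

t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
t1 = torch.narrow(t, 0, 1, 2)

# t1 shares t's storage: t1[0, 0] is the same element as t[1, 0]
t1[0, 0] = 100
print(t[1, 0])  # tensor(100); the original tensor changed too

# clone() copies the data, so edits no longer reach t
t2 = torch.narrow(t, 0, 1, 2).clone()
t2[0, 0] = -1
print(t[1, 0])  # still tensor(100)
```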
Example 1
In the following Python code, the input tensor has size [3, 3]. First, dim = 0, start = 1 and length = 2 narrow the tensor along dimension 0, keeping rows 1 and 2 and returning a tensor of size [2, 3]. Then dim = 1, start = 1 and length = 2 narrow it along dimension 1, keeping columns 1 and 2 and returning a tensor of size [3, 2].
Notice that in each case only the size of the narrowed dimension changes (to length = 2); the other dimension keeps its original size.
# import the library
import torch
# create a tensor
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# print the created tensor
print("Tensor:\n", t)
print("Size of Tensor:", t.size())
# Narrow-down the tensor in dimension 0
t1 = torch.narrow(t, 0, 1, 2)
print("Tensor after Narrowing:\n", t1)
print("Size after Narrowing:", t1.size())
# Narrow down the tensor in dimension 1
t2 = torch.narrow(t, 1, 1, 2)
print("Tensor after Narrowing:\n", t2)
print("Size after Narrowing:", t2.size())
Output
Tensor:
 tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
Size of Tensor: torch.Size([3, 3])
Tensor after Narrowing:
 tensor([[4, 5, 6],
        [7, 8, 9]])
Size after Narrowing: torch.Size([2, 3])
Tensor after Narrowing:
 tensor([[2, 3],
        [5, 6],
        [8, 9]])
Size after Narrowing: torch.Size([3, 2])
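Example 1 uses a 2-D tensor, but narrow() works the same way on tensors of any rank, and dim may be negative to count from the last dimension. A short sketch with an illustrative 2x3x4 tensor:

```python
import torch

# Illustrative 2 x 3 x 4 tensor
x = torch.arange(24).reshape(2, 3, 4)

# dim=-1 is the last dimension (size 4); keep its indices 1 and 2
y = torch.narrow(x, -1, 1, 2)

print(y.size())                     # torch.Size([2, 3, 2])
print(torch.equal(y, x[..., 1:3]))  # True
```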
Example 2
The following program performs the same narrow operation with the method form, Tensor.narrow(), on a tensor of size [4, 3].
# import required library
import torch
# create a tensor
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
# print the above created tensor
print("Tensor:\n", t)
print("Size of Tensor:", t.size())
# Narrow-down the tensor in dimension 0
t1 = t.narrow(0, 1, 2)
print("Tensor after Narrowing:\n", t1)
print("Size after Narrowing:", t1.size())
# Narrow down the tensor in dimension 1
t2 = t.narrow(1, 0, 2)
print("Tensor after Narrowing:\n", t2)
print("Size after Narrowing:", t2.size())
Output
Tensor:
 tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])
Size of Tensor: torch.Size([4, 3])
Tensor after Narrowing:
 tensor([[4, 5, 6],
        [7, 8, 9]])
Size after Narrowing: torch.Size([2, 3])
Tensor after Narrowing:
 tensor([[ 1,  2],
        [ 4,  5],
        [ 7,  8],
        [10, 11]])
Size after Narrowing: torch.Size([4, 2])
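A common use of Tensor.narrow() is walking through a tensor in fixed-size chunks along one dimension. The sketch below assumes a chunk size of 2 rows (an illustrative choice) and clamps the last length with min() so an uneven final chunk does not exceed the dimension size; it then compares the result with torch.split, which produces the same blocks along dim 0:

```python
import torch

t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
chunk = 2  # rows per chunk; illustrative choice

# Consecutive row blocks; the last one is shortened if fewer rows remain
blocks = [t.narrow(0, i, min(chunk, t.size(0) - i))
          for i in range(0, t.size(0), chunk)]

# torch.split yields the same blocks along dim 0
expected = torch.split(t, chunk, dim=0)
print(all(torch.equal(a, b) for a, b in zip(blocks, expected)))  # True
for b in blocks:
    print(b.size())  # torch.Size([2, 3]) for each block here
```

Each block is a view of t, so this loop does not copy any data.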