How to flatten an input tensor by reshaping it in PyTorch?
A tensor can be flattened into a one-dimensional tensor by reshaping it with the method torch.flatten(). This method supports both real and complex-valued input tensors. It takes a torch tensor as its input and returns a torch tensor with the selected dimensions flattened.
It takes two optional parameters, start_dim and end_dim. If these parameters are passed, only the dimensions from start_dim through end_dim (inclusive) are flattened; all other dimensions are left as they are.
The order of elements in the input tensor is not changed. The function may return the original object, a view, or a copy: if no dimensions are flattened, the original object is returned; if the input can be viewed as the flattened shape, a view is returned; otherwise the data is copied. The following examples cover flattening a tensor with and without start_dim and end_dim.
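The difference between a view and a copy can be checked directly. The following is a minimal sketch, assuming that comparing data_ptr() values is an acceptable way to tell whether two tensors start at the same memory; the variable names are illustrative only.

```python
import torch

# A contiguous tensor: flattening can reuse the same memory (a view).
x = torch.arange(6).reshape(2, 3)
flat_view = torch.flatten(x)
shares_memory = flat_view.data_ptr() == x.data_ptr()

# Writing through the view also changes the original tensor.
flat_view[0] = 100

# A transposed tensor is non-contiguous, so flattening has to copy the data.
xt = torch.arange(6).reshape(2, 3).t()
flat_copy = torch.flatten(xt)
copied = flat_copy.data_ptr() != xt.data_ptr()

print(shares_memory, x[0, 0].item())
print(copied, flat_copy.tolist())
```

Because a view shares memory with its input, call clone() on the result if it needs to be modified independently of the original tensor.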
Syntax
torch.flatten(input, start_dim=0, end_dim=-1)
Parameters
input - It's a torch tensor to flatten.
start_dim - It's the first dimension to flatten. It's an optional parameter. Default is set to 0.
end_dim - It's the last dimension to flatten. It's an optional parameter. Default is set to -1.
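To make the effect of these parameters concrete, the sketch below predicts the resulting shape with a small helper and compares it with what torch.flatten() returns. The helper flattened_shape is hypothetical (it is not part of PyTorch) and assumes a tensor with at least one dimension and start_dim not greater than end_dim.

```python
import math
import torch

def flattened_shape(shape, start_dim=0, end_dim=-1):
    """Hypothetical helper: predict the shape torch.flatten would return."""
    ndim = len(shape)
    # Negative indices count from the last dimension, as in Python lists.
    start = start_dim % ndim
    end = end_dim % ndim
    # The dimensions start..end (inclusive) are merged into a single one.
    merged = math.prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])

t = torch.zeros(2, 3, 4, 5)
for start_dim, end_dim in [(0, -1), (1, 2), (1, -1), (-2, -1)]:
    actual = tuple(torch.flatten(t, start_dim, end_dim).shape)
    expected = flattened_shape(t.shape, start_dim, end_dim)
    print(start_dim, end_dim, actual, actual == expected)
```

For the (2, 3, 4, 5) input, start_dim=1 with end_dim=2 merges the sizes 3 and 4 into 12, giving (2, 12, 5).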
Steps
Import the required library. In all the following examples, the required Python library is torch. Make sure you have already installed it.
import torch
Create a PyTorch tensor and print the tensor.
t = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
print("Tensor:\n", t)
Flatten the above tensor using the above-defined syntax and optionally assign the value to a new variable.
flatten_t = torch.flatten(t, start_dim=0, end_dim=1)
Print the flattened tensor.
print("Flattened Tensor:\n", flatten_t)
Example 1
In this program, we flatten a tensor of size [2, 2, 3] into a one-dimensional tensor. We also flatten it with different values of start_dim, which flattens only the dimensions from start_dim onward.
# Import the required library
import torch

# define a torch tensor
t = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
print("Tensor:\n", t)
print("Size of Tensor:", t.size())

# flatten the above tensor using start_dim
flatten_t = torch.flatten(t)
flatten_t0 = torch.flatten(t, start_dim=0)
flatten_t1 = torch.flatten(t, start_dim=1)
flatten_t2 = torch.flatten(t, start_dim=2)

# print the flattened tensors
print("Flatten tensor:\n", flatten_t)
print("Flatten tensor (start_dim=0):\n", flatten_t0)
print("Flatten tensor (start_dim=1):\n", flatten_t1)
print("Flatten tensor (start_dim=2):\n", flatten_t2)
Output
Tensor:
 tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])
Size of Tensor: torch.Size([2, 2, 3])
Flatten tensor:
 tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
Flatten tensor (start_dim=0):
 tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
Flatten tensor (start_dim=1):
 tensor([[ 1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12]])
Flatten tensor (start_dim=2):
 tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])
Example 2
In this program, we flatten a tensor of size [2, 2, 3] into a one-dimensional tensor. We also flatten it with different values of end_dim, which flattens only the dimensions from 0 through end_dim.
import torch

t = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
print("Tensor:\n", t)
print("Size of Tensor:", t.size())

# flatten the above tensor using end_dim
flatten_t = torch.flatten(t)
flatten_t0 = torch.flatten(t, end_dim=0)
flatten_t1 = torch.flatten(t, end_dim=1)
flatten_t2 = torch.flatten(t, end_dim=2)

# print the flattened tensors
print("Flatten tensor:\n", flatten_t)
print("Flatten tensor (end_dim=0):\n", flatten_t0)
print("Flatten tensor (end_dim=1):\n", flatten_t1)
print("Flatten tensor (end_dim=2):\n", flatten_t2)
Output
Tensor:
 tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])
Size of Tensor: torch.Size([2, 2, 3])
Flatten tensor:
 tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
Flatten tensor (end_dim=0):
 tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])
Flatten tensor (end_dim=1):
 tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])
Flatten tensor (end_dim=2):
 tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
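Since flattening is a special case of reshaping, the end_dim=1 result above can also be produced in other ways. The sketch below compares the function form with the method form t.flatten() and with an explicit reshape() call on the same example tensor; the variable names flatten_m and reshaped are illustrative only.

```python
import torch

t = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])

# Function form: merge dimensions 0 and 1.
flatten_t = torch.flatten(t, start_dim=0, end_dim=1)

# Method form of the same call.
flatten_m = t.flatten(start_dim=0, end_dim=1)

# Explicit reshape: -1 lets PyTorch infer the merged size (2 * 2 = 4).
reshaped = t.reshape(-1, 3)

print(flatten_t.shape)
print(torch.equal(flatten_t, flatten_m), torch.equal(flatten_t, reshaped))
```

With reshape() the target shape has to be spelled out, whereas flatten() only needs the range of dimensions to merge, which is convenient when the sizes are not known in advance.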
Example 3
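In this program, we flatten only part of a tensor by passing both start_dim and end_dim. The input has size [2, 2, 3, 3]; flattening dimensions 2 through 3 merges the last two dimensions and gives a tensor of size [2, 2, 9]. The tensor is filled with random values, so your output will differ.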
import torch

t = torch.empty(2, 2, 3, 3).random_(30)
print("Tensor:\n", t)
print("Size of Tensor:", t.size())

# flatten the above tensor using start_dim and end_dim
flatten_t0 = torch.flatten(t, start_dim=2, end_dim=3)

# print the flattened tensor
print("Flatten tensor (start_dim=2,end_dim=3):\n", flatten_t0)
Output
Tensor:
 tensor([[[[27., 13., 29.],
          [ 1., 23., 15.],
          [15.,  7., 19.]],

         [[ 4., 14., 24.],
          [ 6.,  4.,  7.],
          [ 6., 18., 11.]]],


        [[[ 0., 27.,  3.],
          [25., 12., 25.],
          [10., 23.,  9.]],

         [[ 3.,  1., 28.],
          [19.,  7., 28.],
          [23., 14., 21.]]]])
Size of Tensor: torch.Size([2, 2, 3, 3])
Flatten tensor (start_dim=2,end_dim=3):
 tensor([[[27., 13., 29.,  1., 23., 15., 15.,  7., 19.],
         [ 4., 14., 24.,  6.,  4.,  7.,  6., 18., 11.]],

        [[ 0., 27.,  3., 25., 12., 25., 10., 23.,  9.],
         [ 3.,  1., 28., 19.,  7., 28., 23., 14., 21.]]])
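The examples above all use real-valued tensors. Since torch.flatten() is described as supporting complex-valued inputs as well, here is a minimal sketch with a small complex tensor; the values are arbitrary and chosen only for illustration.

```python
import torch

# A 2 x 2 complex-valued tensor built from Python complex numbers.
z = torch.tensor([[1 + 2j, 3 - 1j], [0 + 1j, 2 + 0j]])

# Flattening keeps the complex dtype and the row-major element order.
flat_z = torch.flatten(z)

print(flat_z.dtype == z.dtype, flat_z.shape)
print(flat_z)
```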