 
How to flatten an input tensor by reshaping it in PyTorch?
A tensor can be flattened into a one-dimensional tensor by reshaping it using the method torch.flatten(). This method supports both real and complex-valued input tensors. It takes a torch tensor as its input and returns a torch tensor flattened into one dimension.
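As a small illustration of the complex-valued case mentioned above, the sketch below flattens a 2x2 complex tensor. The tensor values and the variable names z and flat_z are made up for this example.

```python
import torch

# A 2x2 complex-valued tensor (values chosen only for illustration).
z = torch.tensor([[1 + 2j, 3 - 1j],
                  [0 + 1j, 2 + 0j]])

# Flattening keeps the complex dtype and the row-major element order.
flat_z = torch.flatten(z)

print("Flattened shape:", flat_z.shape)
print("Is complex:", flat_z.is_complex())
print("Flattened tensor:\n", flat_z)
```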
torch.flatten() accepts two optional parameters, start_dim and end_dim. When they are passed, only the dimensions from start_dim through end_dim (inclusive) are flattened.
The order of elements in the input tensor is not changed. This function may return the original object (when no dimensions are flattened), a view (when the input can be viewed as the flattened shape), or a copy (otherwise). In the following examples, we cover flattening a tensor with and without start_dim and end_dim.
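One way to observe the view-versus-copy behavior is to compare memory addresses with data_ptr(). The sketch below assumes a contiguous tensor yields a view, while a transposed (non-contiguous) tensor forces a copy. The variable names are illustrative only.

```python
import torch

# A contiguous tensor: flattening can reuse the same memory (a view).
t = torch.arange(6).reshape(2, 3)
flat_view = torch.flatten(t)
shares_memory = flat_view.data_ptr() == t.data_ptr()

# A transposed tensor is non-contiguous, so flattening has to copy the data.
t_transposed = t.t()
flat_copy = torch.flatten(t_transposed)
is_copy = flat_copy.data_ptr() != t_transposed.data_ptr()

# Writing through the view changes t; writing to the copy does not.
flat_view[0] = 100
flat_copy[1] = -1

print("View shares memory with input:", shares_memory)
print("Flatten of transposed input is a copy:", is_copy)
print("Original tensor after the writes:\n", t)
```

Because the result may share memory with the input, call clone() on the result if you need to modify it independently of the original tensor.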
Syntax
torch.flatten(input, start_dim=0, end_dim=-1)
Parameters
- input - It's a torch tensor to flatten. 
- start_dim - It's the first dimension to flatten. It's an optional parameter. Default is set to 0. 
- end_dim - It's the last dimension to flatten. It's an optional parameter. Default is set to -1. 
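To make the effect of these parameters concrete, the sketch below checks only the resulting shapes. The input shape (2, 3, 4, 5) and the variable names are arbitrary choices for illustration. The flattened dimensions are replaced by a single dimension whose size is the product of their sizes.

```python
import torch

x = torch.zeros(2, 3, 4, 5)

# Defaults (start_dim=0, end_dim=-1) collapse every dimension: 2*3*4*5 = 120.
shape_all = tuple(torch.flatten(x).shape)

# Only dimensions 1 through 2 are merged: 3*4 = 12.
shape_mid = tuple(torch.flatten(x, start_dim=1, end_dim=2).shape)

# Negative indices count from the end, so -2 refers to dimension 2 here: 4*5 = 20.
shape_tail = tuple(torch.flatten(x, start_dim=-2).shape)

print(shape_all)   # (120,)
print(shape_mid)   # (2, 12, 5)
print(shape_tail)  # (2, 3, 20)
```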
Steps
- Import the required library. In all the following examples, the required Python library is torch. Make sure you have already installed it. 
import torch
- Create a PyTorch tensor and print the tensor. 
t = torch.tensor([[[1, 2, 3],
                   [4, 5, 6]],
                  [[7, 8, 9],
                   [10, 11, 12]]])
print("Tensor:\n", t)
- Flatten the above tensor using the above-defined syntax and optionally assign the value to a new variable. 
flatten_t = torch.flatten(t, start_dim=0, end_dim=1)
- Print the flattened tensor. 
print("Flattened Tensor:\n", flatten_t)
Example 1
In this program, we flatten a tensor into a one-dimensional tensor. We also flatten the tensor with start_dim.
# import the required library
import torch
# define a torch tensor
t = torch.tensor([[[1, 2, 3],
                   [4, 5, 6]],
                  [[7, 8, 9],
                   [10, 11, 12]]])
print("Tensor:\n", t)
print("Size of Tensor:", t.size())
# flatten the above tensor using start_dim
flatten_t = torch.flatten(t)
flatten_t0 = torch.flatten(t, start_dim=0)
flatten_t1 = torch.flatten(t, start_dim=1)
flatten_t2 = torch.flatten(t, start_dim=2)
# print the flattened tensors
print("Flatten tensor:\n", flatten_t)
print("Flatten tensor (start_dim=0):\n", flatten_t0)
print("Flatten tensor (start_dim=1):\n", flatten_t1)
print("Flatten tensor (start_dim=2):\n", flatten_t2)
Output
Tensor:
 tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])
Size of Tensor: torch.Size([2, 2, 3])
Flatten tensor:
 tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
Flatten tensor (start_dim=0):
 tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
Flatten tensor (start_dim=1):
 tensor([[ 1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12]])
Flatten tensor (start_dim=2):
 tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])
Example 2
In this program, we flatten a tensor into a one-dimensional tensor. We also flatten the tensor with end_dim.
import torch
t = torch.tensor([[[1, 2, 3],
                   [4, 5, 6]],
                  [[7, 8, 9],
                   [10, 11, 12]]])
print("Tensor:\n", t)
print("Size of Tensor:", t.size())
# flatten the above tensor using end_dim
flatten_t = torch.flatten(t)
flatten_t0 = torch.flatten(t, end_dim=0)
flatten_t1 = torch.flatten(t, end_dim=1)
flatten_t2 = torch.flatten(t, end_dim=2)
# print the flattened tensors
print("Flatten tensor:\n", flatten_t)
print("Flatten tensor (end_dim=0):\n", flatten_t0)
print("Flatten tensor (end_dim=1):\n", flatten_t1)
print("Flatten tensor (end_dim=2):\n", flatten_t2)
Output
Tensor:
 tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])
Size of Tensor: torch.Size([2, 2, 3])
Flatten tensor:
 tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
Flatten tensor (end_dim=0):
 tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])
Flatten tensor (end_dim=1):
 tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])
Flatten tensor (end_dim=2):
 tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
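Since the article describes flattening as a form of reshaping, it may help to see that the end_dim=1 result above matches an explicit reshape. The sketch below is an illustrative comparison, not part of the original example. It also uses the Tensor.flatten() method form, which takes the same start_dim and end_dim arguments.

```python
import torch

# Same values as the example tensor, built with arange for brevity.
t = torch.arange(1, 13).reshape(2, 2, 3)

# Merge dimensions 0 and 1 with torch.flatten().
flat_fn = torch.flatten(t, end_dim=1)

# The same result written as an explicit reshape: (2*2, 3) = (4, 3).
flat_reshape = t.reshape(-1, 3)

# The method form of flatten with positional start_dim and end_dim.
flat_method = t.flatten(0, 1)

all_equal = torch.equal(flat_fn, flat_reshape) and torch.equal(flat_fn, flat_method)
print("Shape:", tuple(flat_fn.shape))
print("All three results are equal:", all_equal)
```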
Example 3
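In this program, we flatten only the last two dimensions of a four-dimensional tensor by passing both start_dim and end_dim. Because the tensor is filled with random values, your output will differ from the one shown below.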
import torch
t = torch.empty(2,2,3,3).random_(30)
print("Tensor:\n", t)
print("Size of Tensor:", t.size())
# flatten the above tensor using start_dim and end_dim
flatten_t0 = torch.flatten(t, start_dim=2, end_dim=3)
# print the flattened tensor
print("Flatten tensor (start_dim=2,end_dim=3):\n", flatten_t0)
Output
Tensor:
 tensor([[[[27., 13., 29.],
          [ 1., 23., 15.],
          [15.,  7., 19.]],

         [[ 4., 14., 24.],
          [ 6.,  4.,  7.],
          [ 6., 18., 11.]]],


        [[[ 0., 27.,  3.],
          [25., 12., 25.],
          [10., 23.,  9.]],

         [[ 3.,  1., 28.],
          [19.,  7., 28.],
          [23., 14., 21.]]]])
Size of Tensor: torch.Size([2, 2, 3, 3])
Flatten tensor (start_dim=2,end_dim=3):
 tensor([[[27., 13., 29.,  1., 23., 15., 15.,  7., 19.],
         [ 4., 14., 24.,  6.,  4.,  7.,  6., 18., 11.]],

        [[ 0., 27.,  3., 25., 12., 25., 10., 23.,  9.],
         [ 3.,  1., 28., 19.,  7., 28., 23., 14., 21.]]])
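Because Example 3 uses random values, its output cannot be reproduced exactly. The sketch below is a deterministic variant that assumes the same shape (2, 2, 3, 3) but fills the tensor with torch.arange, so the resulting shape and element order can be checked directly. The variable names are illustrative.

```python
import torch

# Deterministic stand-in for the random tensor: values 0..35 in row-major order.
t = torch.arange(36).reshape(2, 2, 3, 3)

# Merge the last two dimensions (3 x 3) into one dimension of size 9.
flat = torch.flatten(t, start_dim=2, end_dim=3)

# Each 3x3 block becomes one row of 9 consecutive values.
block = flat[1, 0]

print("Size before:", tuple(t.shape))    # (2, 2, 3, 3)
print("Size after:", tuple(flat.shape))  # (2, 2, 9)
print("Row flat[1, 0]:", block.tolist())
```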
