# How to flatten an input tensor by reshaping it in PyTorch?


A tensor can be flattened into a one-dimensional tensor by reshaping it with torch.flatten(). The function supports both real and complex-valued input tensors. It takes a torch tensor as input and returns a torch tensor flattened into one dimension.

It takes two optional parameters, start_dim and end_dim. If they are passed, only the dimensions from start_dim through end_dim (inclusive) are flattened; the remaining dimensions are left as they are.

The order of elements in the input tensor is not changed. Depending on the input, the function may return the original object, a view, or a copy. The following examples cover flattening a tensor with and without start_dim and end_dim.
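The difference between a view and a copy can be checked directly. The following is a minimal sketch, assuming a small integer tensor built with torch.arange; the variable names are illustrative only. It compares data pointers to see whether the flattened result shares memory with its input.

```python
import torch

# A contiguous tensor: flattening can reuse the same memory (a view).
t = torch.arange(6).reshape(2, 3)
flat_view = torch.flatten(t)
shares_memory = flat_view.data_ptr() == t.data_ptr()

# Writing through the view is visible in the original tensor.
flat_view[0] = 100
first_element = t[0, 0].item()

# A transposed tensor is non-contiguous, so flattening has to copy the data.
t_transposed = t.t()
flat_copy = torch.flatten(t_transposed)
is_copy = flat_copy.data_ptr() != t_transposed.data_ptr()

print("Shares memory with input:", shares_memory)
print("First element of t after writing to the view:", first_element)
print("Flatten of transposed tensor is a copy:", is_copy)
```

Because the result may or may not share memory with the input, it is safer not to rely on in-place changes to the flattened tensor propagating back to the original.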

## Syntax

torch.flatten(input, start_dim=0, end_dim=-1)

## Parameters

• input - The torch tensor to flatten.

• start_dim - The first dimension to flatten. Optional; the default is 0.

• end_dim - The last dimension to flatten. Optional; the default is -1, which refers to the last dimension of the tensor.
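To see how the parameters interact, it helps to look only at the resulting shapes. The sketch below assumes a hypothetical four-dimensional tensor of zeros named x; the shapes in the comments follow from multiplying the sizes of the merged dimensions.

```python
import torch

# Hypothetical 4-D tensor used only to inspect shapes.
x = torch.zeros(2, 3, 4, 5)

all_dims = torch.flatten(x)                            # merges every dimension
from_dim1 = torch.flatten(x, start_dim=1)              # keeps dim 0, merges dims 1..3
middle = torch.flatten(x, start_dim=1, end_dim=2)      # merges only dims 1 and 2
up_to_second_last = torch.flatten(x, end_dim=-2)       # negative index: merges dims 0..2

print(all_dims.shape)           # torch.Size([120])
print(from_dim1.shape)          # torch.Size([2, 60])
print(middle.shape)             # torch.Size([2, 12, 5])
print(up_to_second_last.shape)  # torch.Size([24, 5])
```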

## Steps

• Import the required library. In all the following examples, the required Python library is torch. Make sure you have already installed it.

import torch
• Create a PyTorch tensor and print the tensor.

t = torch.tensor([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]]])
print("Tensor:\n", t)
• Flatten the above tensor using the above-defined syntax and optionally assign the value to a new variable.

flatten_t = torch.flatten(t, start_dim=0, end_dim=1)
• Print the flattened tensor.

print("Flattened Tensor:\n", flatten_t)
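Since flattening is a form of reshaping, the result of the steps above can be cross-checked against Tensor.reshape. This is a small sketch rather than a required step; the names reshaped_t and method_t are introduced here only for illustration.

```python
import torch

t = torch.tensor([[[1, 2, 3],
                   [4, 5, 6]],
                  [[7, 8, 9],
                   [10, 11, 12]]])

# Merge the first two dimensions: shape [2, 2, 3] becomes [4, 3].
flatten_t = torch.flatten(t, start_dim=0, end_dim=1)

# The same result expressed as an explicit reshape.
reshaped_t = t.reshape(-1, t.shape[2])

# The method form of the same call.
method_t = t.flatten(0, 1)

print("Shape after flatten:", flatten_t.shape)
print("Equal to reshape:", torch.equal(flatten_t, reshaped_t))
print("Equal to method form:", torch.equal(flatten_t, method_t))
```

The advantage of torch.flatten() over an explicit reshape is that the merged size does not have to be computed or spelled out by hand.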

## Example 1

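In this program, we flatten a tensor into a one-dimensional tensor and then flatten it again with different values of start_dim. Note that start_dim=2 refers to the last dimension of this three-dimensional tensor, so nothing is merged and the shape stays the same.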

# import the required library
import torch
# define a torch tensor
t = torch.tensor([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]]])
print("Tensor:\n", t)
print("Size of Tensor:", t.size())

# flatten the above tensor using different values of start_dim
flatten_t = torch.flatten(t)
flatten_t0 = torch.flatten(t, start_dim=0)
flatten_t1 = torch.flatten(t, start_dim=1)
flatten_t2 = torch.flatten(t, start_dim=2)

# print the flattened tensors
print("Flatten tensor:\n", flatten_t)
print("Flatten tensor (start_dim=0):\n", flatten_t0)
print("Flatten tensor (start_dim=1):\n", flatten_t1)
print("Flatten tensor (start_dim=2):\n", flatten_t2)

## Output

Tensor:
tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]])
Size of Tensor: torch.Size([2, 2, 3])
Flatten tensor:
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
Flatten tensor (start_dim=0):
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
Flatten tensor (start_dim=1):
tensor([[ 1, 2, 3, 4, 5, 6],
[ 7, 8, 9, 10, 11, 12]])
Flatten tensor (start_dim=2):
tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]])

## Example 2

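In this program, we flatten a tensor into a one-dimensional tensor and then flatten it again with different values of end_dim. With end_dim=0 only the first dimension is selected, so the tensor is returned unchanged.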

import torch
t = torch.tensor([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]]])
print("Tensor:\n", t)
print("Size of Tensor:", t.size())

# flatten the above tensor using different values of end_dim
flatten_t = torch.flatten(t)
flatten_t0 = torch.flatten(t, end_dim=0)
flatten_t1 = torch.flatten(t, end_dim=1)
flatten_t2 = torch.flatten(t, end_dim=2)
# print the flattened tensors
print("Flatten tensor:\n", flatten_t)
print("Flatten tensor (end_dim=0):\n", flatten_t0)
print("Flatten tensor (end_dim=1):\n", flatten_t1)
print("Flatten tensor (end_dim=2):\n", flatten_t2)

## Output

Tensor:
tensor([[[ 1, 2, 3],
[ 4, 5, 6]],

[[ 7, 8, 9],
[10, 11, 12]]])
Size of Tensor: torch.Size([2, 2, 3])
Flatten tensor:
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
Flatten tensor (end_dim=0):
tensor([[[ 1, 2, 3],
[ 4, 5, 6]],

[[ 7, 8, 9],
[10, 11, 12]]])
Flatten tensor (end_dim=1):
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
Flatten tensor (end_dim=2):
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])

## Example 3

In this program, we flatten only part of a four-dimensional tensor by passing both start_dim and end_dim. The last two dimensions are merged, so the shape changes from [2, 2, 3, 3] to [2, 2, 9].

import torch
t = torch.empty(2,2,3,3).random_(30)
print("Tensor:\n", t)
print("Size of Tensor:", t.size())

# flatten the above tensor using start_dim and end_dim
flatten_t0 = torch.flatten(t, start_dim=2, end_dim=3)

# print the flattened tensor
print("Flatten tensor (start_dim=2,end_dim=3):\n", flatten_t0)

## Output

Tensor:
tensor([[[[27., 13., 29.],
[ 1., 23., 15.],
[15., 7., 19.]],

[[ 4., 14., 24.],
[ 6., 4., 7.],
[ 6., 18., 11.]]],

[[[ 0., 27., 3.],
[25., 12., 25.],
[10., 23., 9.]],

[[ 3., 1., 28.],
[19., 7., 28.],
[23., 14., 21.]]]])
Size of Tensor: torch.Size([2, 2, 3, 3])
Flatten tensor (start_dim=2,end_dim=3):
tensor([[[27., 13., 29., 1., 23., 15., 15., 7., 19.],
[ 4., 14., 24., 6., 4., 7., 6., 18., 11.]],

[[ 0., 27., 3., 25., 12., 25., 10., 23., 9.],
[ 3., 1., 28., 19., 7., 28., 23., 14., 21.]]])
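Because the tensor in Example 3 is filled with random values, its output differs from run to run. The sketch below repeats the same call on a deterministic tensor (built with torch.arange, an assumption made here for reproducibility) and also shows that negative indices select the same dimensions.

```python
import torch

# Deterministic stand-in for the random tensor: values 0..35 in shape [2, 2, 3, 3].
t = torch.arange(36).reshape(2, 2, 3, 3)

# Merge the last two dimensions.
flat = torch.flatten(t, start_dim=2, end_dim=3)

# start_dim=-2 is the same dimension as start_dim=2 for a 4-D tensor;
# end_dim is left at its default of -1.
flat_negative = torch.flatten(t, start_dim=-2)

print("Shape:", flat.shape)
print("Same result with negative index:", torch.equal(flat, flat_negative))
print("First row:", flat[0, 0])
```

Each 3x3 block of the input becomes one row of nine elements, read in the original row-by-row order.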
Updated on 20-Jan-2022 08:08:43