How to flatten an input tensor by reshaping it in PyTorch?


A tensor can be flattened into a one-dimensional tensor by reshaping it with the function torch.flatten(). This function supports both real- and complex-valued input tensors. It takes a torch tensor as input and, with the default arguments, returns a tensor flattened into one dimension.
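Here is a minimal sketch of the default behaviour, assuming a recent PyTorch release. With no extra arguments every dimension is merged, and a complex-valued tensor is handled the same way as a real one. The tensors real_t and complex_t are illustrative values, not taken from the examples below.

```python
import torch

# A real-valued 2 x 3 tensor
real_t = torch.tensor([[1, 2, 3],
                       [4, 5, 6]])
flat_real = torch.flatten(real_t)
print(flat_real)        # tensor([1, 2, 3, 4, 5, 6])
print(flat_real.shape)  # torch.Size([6])

# A complex-valued 2 x 2 tensor is flattened the same way
complex_t = torch.tensor([[1 + 1j, 2 - 1j],
                          [3 + 0j, 0 + 4j]])
flat_complex = torch.flatten(complex_t)
print(flat_complex.shape)  # torch.Size([4])
```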

It takes two optional parameters, start_dim and end_dim. If these parameters are passed, only the dimensions from start_dim through end_dim (both inclusive) are merged into a single dimension; all other dimensions keep their sizes.
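To make the inclusive range concrete, here is a short sketch on a hypothetical 4-D tensor of zeros (x is an illustrative name). The merged dimension's size is the product of the sizes it replaces, and the dimensions outside the range are untouched.

```python
import torch

x = torch.zeros(2, 3, 4, 5)

# Merge dimensions 1 through 2 (inclusive): 3 * 4 = 12
y = torch.flatten(x, start_dim=1, end_dim=2)
print(y.shape)  # torch.Size([2, 12, 5])

# Merge dimensions 0 through 1 (inclusive): 2 * 3 = 6
z = torch.flatten(x, start_dim=0, end_dim=1)
print(z.shape)  # torch.Size([6, 4, 5])
```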

The order of elements in the input tensor is not changed. This function may return the original object, a view, or a copy of the input. The following examples show flattening both with and without start_dim and end_dim.
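The remark about views and copies can be checked by comparing storage addresses with Tensor.data_ptr(). In this sketch, which assumes default CPU tensors, flattening a contiguous tensor gives a view that shares memory with the input. A non-contiguous tensor, produced here by a transpose, has to be copied instead.

```python
import torch

# Contiguous input: the flattened result shares memory with t
t = torch.arange(6).reshape(2, 3)
flat = torch.flatten(t)
print(flat.data_ptr() == t.data_ptr())  # True

# Writing through the view changes the original tensor
flat[0] = 100
print(t[0, 0].item())  # 100

# Non-contiguous input (a transpose): the result is a new copy
tt = torch.arange(6).reshape(2, 3).T
flat_tt = torch.flatten(tt)
print(flat_tt.data_ptr() == tt.data_ptr())  # False
print(flat_tt)  # tensor([0, 3, 1, 4, 2, 5])
```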

Syntax

torch.flatten(input, start_dim=0, end_dim=-1)

Parameters

  • input - the torch tensor to flatten.

  • start_dim - the first dimension to flatten. It is optional and defaults to 0.

  • end_dim - the last dimension to flatten (inclusive). It is optional and defaults to -1, i.e. the last dimension.
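Here is a short sketch of how the parameters interact, assuming a recent PyTorch release. Calling torch.flatten() with no arguments is the same as passing the defaults explicitly. Negative indices count from the last dimension, and the method form Tensor.flatten() behaves identically. Passing a start_dim that comes after end_dim raises an error.

```python
import torch

t = torch.arange(24).reshape(2, 3, 4)

# No arguments is the same as the defaults start_dim=0, end_dim=-1
a = torch.flatten(t)
b = torch.flatten(t, start_dim=0, end_dim=-1)
print(torch.equal(a, b))  # True

# Negative indices count from the end: -2 is dimension 1 of a 3-D tensor
c = torch.flatten(t, start_dim=-2)
print(c.shape)  # torch.Size([2, 12])

# The method form t.flatten(...) gives the same result
print(torch.equal(c, t.flatten(start_dim=1)))  # True

# start_dim must not come after end_dim
try:
    torch.flatten(t, start_dim=2, end_dim=1)
except RuntimeError as err:
    print("RuntimeError:", err)
```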

Steps

  • Import the required library. In all the following examples, the required Python library is torch. Make sure you have already installed it.

import torch
  • Create a PyTorch tensor and print the tensor.

t = torch.tensor([[[1, 2, 3],
                   [4, 5, 6]],
                  [[7, 8, 9],
                   [10, 11, 12]]])
print("Tensor:\n", t)
  • Flatten the above tensor using the above-defined syntax and optionally assign the value to a new variable.

flatten_t = torch.flatten(t, start_dim=0, end_dim=1)
  • Print the flattened tensor.

print("Flattened Tensor:\n", flatten_t)
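Because flattening amounts to a reshape, the steps above can be cross-checked against an explicit Tensor.reshape() call. This sketch reuses the same 2 × 2 × 3 tensor. Merging dimensions 0 and 1 gives a 4 × 3 tensor, which is also what reshape(-1, 3) produces.

```python
import torch

t = torch.tensor([[[1, 2, 3],
                   [4, 5, 6]],
                  [[7, 8, 9],
                   [10, 11, 12]]])

flatten_t = torch.flatten(t, start_dim=0, end_dim=1)
reshaped_t = t.reshape(-1, 3)  # merge the first two dimensions by hand

print(flatten_t.shape)                     # torch.Size([4, 3])
print(torch.equal(flatten_t, reshaped_t))  # True
```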

Example 1

In this program, we flatten a tensor into a one-dimensional tensor. We also flatten the tensor with start_dim.

# import the required library
import torch

# define a torch tensor
t = torch.tensor([[[1, 2, 3],
                   [4, 5, 6]],
                  [[7, 8, 9],
                   [10, 11, 12]]])
print("Tensor:\n", t)
print("Size of Tensor:", t.size())

# flatten the above tensor using different start_dim values
flatten_t = torch.flatten(t)
flatten_t0 = torch.flatten(t, start_dim=0)
flatten_t1 = torch.flatten(t, start_dim=1)
flatten_t2 = torch.flatten(t, start_dim=2)

# print the flattened tensors
print("Flatten tensor:\n", flatten_t)
print("Flatten tensor (start_dim=0):\n", flatten_t0)
print("Flatten tensor (start_dim=1):\n", flatten_t1)
print("Flatten tensor (start_dim=2):\n", flatten_t2)

Output

Tensor:
   tensor([[[ 1, 2, 3],
      [ 4, 5, 6]],
      [[ 7, 8, 9],
      [10, 11, 12]]])
Size of Tensor: torch.Size([2, 2, 3])
Flatten tensor:
   tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
Flatten tensor (start_dim=0):
   tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
Flatten tensor (start_dim=1):
   tensor([[ 1, 2, 3, 4, 5, 6],
      [ 7, 8, 9, 10, 11, 12]])
Flatten tensor (start_dim=2):
   tensor([[[ 1, 2, 3],
      [ 4, 5, 6]],
      [[ 7, 8, 9],
      [10, 11, 12]]])
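In model code, start_dim=1 is a common choice because it keeps the first dimension, which is typically the batch dimension. As a sketch, torch.nn.Flatten defaults to start_dim=1 and end_dim=-1, so it should give the same result as the start_dim=1 call above. Treating dimension 0 as a batch dimension is an assumption made for illustration.

```python
import torch
import torch.nn as nn

t = torch.tensor([[[1, 2, 3],
                   [4, 5, 6]],
                  [[7, 8, 9],
                   [10, 11, 12]]])

# nn.Flatten defaults to start_dim=1, end_dim=-1, so dimension 0 is kept
flatten_layer = nn.Flatten()
out = flatten_layer(t)

print(out.shape)                                        # torch.Size([2, 6])
print(torch.equal(out, torch.flatten(t, start_dim=1)))  # True
```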

Example 2

In this program, we flatten a tensor into a one-dimensional tensor. We also flatten the tensor with end_dim.

import torch

t = torch.tensor([[[1, 2, 3],
                   [4, 5, 6]],
                  [[7, 8, 9],
                   [10, 11, 12]]])
print("Tensor:\n", t)
print("Size of Tensor:", t.size())

# flatten the above tensor using different end_dim values
flatten_t = torch.flatten(t)
flatten_t0 = torch.flatten(t, end_dim=0)
flatten_t1 = torch.flatten(t, end_dim=1)
flatten_t2 = torch.flatten(t, end_dim=2)

# print the flattened tensors
print("Flatten tensor:\n", flatten_t)
print("Flatten tensor (end_dim=0):\n", flatten_t0)
print("Flatten tensor (end_dim=1):\n", flatten_t1)
print("Flatten tensor (end_dim=2):\n", flatten_t2)

Output

Tensor:
   tensor([[[ 1, 2, 3],
      [ 4, 5, 6]],

      [[ 7, 8, 9],
      [10, 11, 12]]])
Size of Tensor: torch.Size([2, 2, 3])
Flatten tensor:
   tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
Flatten tensor (end_dim=0):
   tensor([[[ 1, 2, 3],
      [ 4, 5, 6]],

      [[ 7, 8, 9],
      [10, 11, 12]]])
Flatten tensor (end_dim=1):
   tensor([[ 1, 2, 3],
      [ 4, 5, 6],
      [ 7, 8, 9],
      [10, 11, 12]])
Flatten tensor (end_dim=2):
   tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
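Since flattening keeps the element order, it can be undone. This sketch assumes PyTorch 1.13 or later, where Tensor.unflatten() accepts an integer dimension and a tuple of sizes. It splits the 4 × 3 result of end_dim=1 back into 2 × 2 × 3 and compares it with the original tensor.

```python
import torch

t = torch.tensor([[[1, 2, 3],
                   [4, 5, 6]],
                  [[7, 8, 9],
                   [10, 11, 12]]])

flatten_t1 = torch.flatten(t, end_dim=1)    # shape [4, 3]
restored = flatten_t1.unflatten(0, (2, 2))  # split dimension 0 back into 2 x 2

print(restored.shape)            # torch.Size([2, 2, 3])
print(torch.equal(restored, t))  # True
```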

Example 3

In this program, we flatten only the last two dimensions of a 4-D tensor by passing both start_dim and end_dim. The tensor is filled with random integers, so the values you get will differ from the output shown.

import torch
t = torch.empty(2, 2, 3, 3).random_(30)
print("Tensor:\n", t)
print("Size of Tensor:", t.size())
# flatten only dimensions 2 and 3 of the tensor
flatten_t0 = torch.flatten(t, start_dim=2, end_dim=3)
# print the flattened tensor
print("Flatten tensor (start_dim=2,end_dim=3):\n", flatten_t0)

Output

Tensor:
   tensor([[[[27., 13., 29.],
      [ 1., 23., 15.],
      [15., 7., 19.]],

      [[ 4., 14., 24.],
      [ 6., 4., 7.],
      [ 6., 18., 11.]]],

      [[[ 0., 27., 3.],
      [25., 12., 25.],
      [10., 23., 9.]],

      [[ 3., 1., 28.],
      [19., 7., 28.],
      [23., 14., 21.]]]])
Size of Tensor: torch.Size([2, 2, 3, 3])
Flatten tensor (start_dim=2,end_dim=3):
   tensor([[[27., 13., 29., 1., 23., 15., 15., 7., 19.],
      [ 4., 14., 24., 6., 4., 7., 6., 18., 11.]],

      [[ 0., 27., 3., 25., 12., 25., 10., 23., 9.],
      [ 3., 1., 28., 19., 7., 28., 23., 14., 21.]]])
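To make this example reproducible, the seed can be fixed with torch.manual_seed(); the value 0 used here is an arbitrary choice. The sketch also checks that each 3 × 3 matrix t[i, j] becomes one row of 9 values, read in row-major order.

```python
import torch

torch.manual_seed(0)  # arbitrary seed so repeated runs give the same tensor

t = torch.empty(2, 2, 3, 3).random_(30)  # integers 0..29 stored as floats
flatten_t0 = torch.flatten(t, start_dim=2, end_dim=3)
print(flatten_t0.shape)  # torch.Size([2, 2, 9])

# Each 3 x 3 matrix t[i, j] becomes one row of 9 values, in row-major order
for i in range(2):
    for j in range(2):
        assert torch.equal(flatten_t0[i, j], t[i, j].reshape(-1))
print("row-major order preserved")
```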

Updated on: 20-Jan-2022
