How to join tensors in PyTorch?


We can join two or more tensors using torch.cat() and torch.stack(). torch.cat() concatenates tensors, whereas torch.stack() stacks them. Both functions let you choose the dimension along which to join, such as dimension 0 or dimension -1 (the last dimension).

Both torch.cat() and torch.stack() join tensors. So what is the basic difference between these two methods?

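  • torch.cat() concatenates a sequence of tensors along an existing dimension, so the number of dimensions does not change. The input tensors must have the same size in every dimension except the one being joined.

  • torch.stack() stacks the tensors along a new dimension, so the result has one more dimension than the inputs. All input tensors must have exactly the same shape.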
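A quick way to see the difference is to compare shapes. The sketch below uses two small 1D tensors (the names a and b and their values are arbitrary, chosen only for illustration) and also checks that torch.stack() gives the same result as first adding a new leading dimension with unsqueeze() and then calling torch.cat().

```python
import torch

# two 1D tensors of shape (3,); the values are arbitrary
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])

# torch.cat() joins along an existing dimension: still 1D, length 3 + 3
cat_result = torch.cat((a, b))
print(cat_result.shape)    # torch.Size([6])

# torch.stack() inserts a new dimension 0: now 2D, shape (2, 3)
stack_result = torch.stack((a, b))
print(stack_result.shape)  # torch.Size([2, 3])

# stacking is equivalent to unsqueezing each input and then concatenating
manual = torch.cat((a.unsqueeze(0), b.unsqueeze(0)))
print(torch.equal(stack_result, manual))  # True
```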

Steps

  • Import the required library. In all the following examples, the required Python library is torch. Make sure you have already installed it.

  • Create two or more PyTorch tensors and print them.

  • Use torch.cat() or torch.stack() to join the tensors created above. Optionally pass a dimension, e.g., 0 or -1, to join them along a particular dimension; if it is omitted, both functions use dimension 0.

  • Finally, print the concatenated or stacked tensors.
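When creating the tensors, keep the shape rules in mind: torch.cat() accepts tensors whose sizes differ only in the dimension being joined, while torch.stack() rejects any shape mismatch. The sketch below checks this with a 2×2 and a 1×2 tensor (the names x and y and their values are illustrative). The exact wording of the error message varies between PyTorch versions, so only the exception type is relied on.

```python
import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # shape (2, 2)
y = torch.tensor([[5.0, 6.0]])              # shape (1, 2)

# cat along dim 0: only dim 0 differs (2 vs 1), so this works -> shape (3, 2)
joined = torch.cat((x, y), 0)
print(joined.shape)  # torch.Size([3, 2])

# stack requires identical shapes, so (2, 2) and (1, 2) are rejected
try:
    torch.stack((x, y))
    stack_failed = False
except RuntimeError as err:
    stack_failed = True
    print("stack failed:", err)
```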

Example 1

# Python program to join tensors in PyTorch
# import necessary library
import torch

# create tensors
T1 = torch.Tensor([1,2,3,4])
T2 = torch.Tensor([0,3,4,1])
T3 = torch.Tensor([4,3,2,5])

# print above created tensors
print("T1:", T1)
print("T2:", T2)
print("T3:", T3)

# join (concatenate) above tensors using torch.cat()
T = torch.cat((T1,T2,T3))
# print final tensor after concatenation
print("T:",T)

Output

When you run the above Python 3 code, it will produce the following output.

T1: tensor([1., 2., 3., 4.])
T2: tensor([0., 3., 4., 1.])
T3: tensor([4., 3., 2., 5.])
T: tensor([1., 2., 3., 4., 0., 3., 4., 1., 4., 3., 2., 5.])
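A 1D tensor has only one dimension, so in Example 1 the dimensions 0 and -1 refer to the same axis and give the same result. torch.cat() also accepts a list instead of a tuple. A short check, reusing the tensors from Example 1 (the variable names default and last are illustrative):

```python
import torch

# same tensors as in Example 1
T1 = torch.Tensor([1, 2, 3, 4])
T2 = torch.Tensor([0, 3, 4, 1])
T3 = torch.Tensor([4, 3, 2, 5])

default = torch.cat((T1, T2, T3))       # dim defaults to 0
last = torch.cat([T1, T2, T3], dim=-1)  # a list works too; -1 is the last (and only) dim

print(torch.equal(default, last))  # True
print(default.shape)               # torch.Size([12])
```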

Example 2

# import necessary library
import torch

# create tensors
T1 = torch.Tensor([[1,2],[3,4]])
T2 = torch.Tensor([[0,3],[4,1]])
T3 = torch.Tensor([[4,3],[2,5]])

# print above created tensors
print("T1:\n", T1)
print("T2:\n", T2)
print("T3:\n", T3)

print("join(concatenate) tensors in the 0 dimension")
T = torch.cat((T1,T2,T3), 0)
print("T:\n", T)

print("join(concatenate) tensors in the -1 dimension")
T = torch.cat((T1,T2,T3), -1)
print("T:\n", T)

Output

When you run the above Python 3 code, it will produce the following output.

T1:
tensor([[1., 2.],
         [3., 4.]])
T2:
tensor([[0., 3.],
         [4., 1.]])
T3:
tensor([[4., 3.],
         [2., 5.]])
join(concatenate) tensors in the 0 dimension
T:
tensor([[1., 2.],
         [3., 4.],
         [0., 3.],
         [4., 1.],
         [4., 3.],
         [2., 5.]])
join(concatenate) tensors in the -1 dimension
T:
tensor([[1., 2., 0., 3., 4., 3.],
         [3., 4., 4., 1., 2., 5.]])

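In the above example, 2D tensors are concatenated along dimensions 0 and -1. Concatenating along dimension 0 increases the number of rows and leaves the number of columns unchanged; concatenating along dimension -1 (the last dimension, i.e., the columns of a 2D tensor) increases the number of columns and leaves the number of rows unchanged.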
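The shape arithmetic can be checked directly. This sketch reuses the tensors from Example 2 (the names rows and cols are illustrative), confirms that dimension -1 is the same axis as dimension 1 for a 2D tensor, and compares against torch.vstack() and torch.hstack(), which are available in recent PyTorch releases and, for 2D inputs, match concatenation along dimensions 0 and 1.

```python
import torch

# same 2x2 tensors as in Example 2
T1 = torch.Tensor([[1, 2], [3, 4]])
T2 = torch.Tensor([[0, 3], [4, 1]])
T3 = torch.Tensor([[4, 3], [2, 5]])

rows = torch.cat((T1, T2, T3), 0)   # 2 + 2 + 2 rows, still 2 columns
cols = torch.cat((T1, T2, T3), -1)  # still 2 rows, 2 + 2 + 2 columns

print(rows.shape)  # torch.Size([6, 2])
print(cols.shape)  # torch.Size([2, 6])

# for a 2D tensor, dim -1 is the same axis as dim 1
print(torch.equal(cols, torch.cat((T1, T2, T3), 1)))  # True

# for 2D inputs, vstack/hstack give the same results as cat along dims 0 and 1
print(torch.equal(rows, torch.vstack((T1, T2, T3))))  # True
print(torch.equal(cols, torch.hstack((T1, T2, T3))))  # True
```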

Example 3

# Python program to join tensors in PyTorch
# import necessary library
import torch

# create tensors
T1 = torch.Tensor([1,2,3,4])
T2 = torch.Tensor([0,3,4,1])
T3 = torch.Tensor([4,3,2,5])

# print above created tensors
print("T1:", T1)
print("T2:", T2)
print("T3:", T3)

# join (stack) the above tensors using torch.stack(); dim defaults to 0
print("join(stack) tensors")
T = torch.stack((T1,T2,T3))
# print final tensor after stacking
print("T:\n",T)

# stack explicitly along dimension 0 (same result as the default)
print("join(stack) tensors in the 0 dimension")
T = torch.stack((T1,T2,T3), 0)
print("T:\n", T)

# stack along dimension -1 (the new last dimension)
print("join(stack) tensors in the -1 dimension")
T = torch.stack((T1,T2,T3), -1)
print("T:\n", T)

Output

When you run the above Python 3 code, it will produce the following output.

T1: tensor([1., 2., 3., 4.])
T2: tensor([0., 3., 4., 1.])
T3: tensor([4., 3., 2., 5.])
join(stack) tensors
T:
tensor([[1., 2., 3., 4.],
         [0., 3., 4., 1.],
         [4., 3., 2., 5.]])
join(stack) tensors in the 0 dimension
T:
tensor([[1., 2., 3., 4.],
         [0., 3., 4., 1.],
         [4., 3., 2., 5.]])
join(stack) tensors in the -1 dimension
T:
tensor([[1., 0., 4.],
         [2., 3., 3.],
         [3., 4., 2.],
         [4., 1., 5.]])

In the above example, the 1D tensors are stacked into a 2D tensor. With dimension 0 (the default), each input becomes a row of the result; with dimension -1, each input becomes a column.
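For 1D inputs, stacking along dimension -1 is simply the transpose of stacking along dimension 0. This sketch reuses the tensors from Example 3 (the names by_rows and by_cols are illustrative) and checks that relationship, plus the fact that column 1 of the dimension -1 result is T2.

```python
import torch

# same 1D tensors as in Example 3
T1 = torch.Tensor([1, 2, 3, 4])
T2 = torch.Tensor([0, 3, 4, 1])
T3 = torch.Tensor([4, 3, 2, 5])

by_rows = torch.stack((T1, T2, T3), 0)   # shape (3, 4): each input becomes a row
by_cols = torch.stack((T1, T2, T3), -1)  # shape (4, 3): each input becomes a column

print(by_rows.shape, by_cols.shape)      # torch.Size([3, 4]) torch.Size([4, 3])
print(torch.equal(by_cols, by_rows.T))   # True
print(torch.equal(by_cols[:, 1], T2))    # True: column 1 is T2
```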

Example 4

# import necessary library
import torch

# create tensors
T1 = torch.Tensor([[1,2],[3,4]])
T2 = torch.Tensor([[0,3],[4,1]])
T3 = torch.Tensor([[4,3],[2,5]])

# print above created tensors
print("T1:\n", T1)
print("T2:\n", T2)
print("T3:\n", T3)

print("Join (stack) tensors in the 0 dimension")
T = torch.stack((T1,T2,T3), 0)
print("T:\n", T)

print("Join (stack) tensors in the -1 dimension")
T = torch.stack((T1,T2,T3), -1)
print("T:\n", T)

Output

When you run the above Python 3 code, it will produce the following output.

T1:
tensor([[1., 2.],
         [3., 4.]])
T2:
tensor([[0., 3.],
         [4., 1.]])
T3:
tensor([[4., 3.],
         [2., 5.]])
Join (stack) tensors in the 0 dimension
T:
tensor([[[1., 2.],
         [3., 4.]],

        [[0., 3.],
         [4., 1.]],

        [[4., 3.],
         [2., 5.]]])
Join (stack) tensors in the -1 dimension
T:
tensor([[[1., 0., 4.],
         [2., 3., 3.]],

        [[3., 4., 2.],
         [4., 1., 5.]]])

In the above example, the three 2×2 tensors are stacked into a 3D tensor, with shape (3, 2, 2) when stacked along dimension 0 and shape (2, 2, 3) when stacked along dimension -1.
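The position of the new dimension is exactly the dim argument. This sketch reuses the tensors from Example 4, prints the resulting shape for dimensions 0, 1 and -1, and shows that with dimension -1 each original tensor can be recovered by indexing the last axis.

```python
import torch

# same 2x2 tensors as in Example 4
T1 = torch.Tensor([[1, 2], [3, 4]])
T2 = torch.Tensor([[0, 3], [4, 1]])
T3 = torch.Tensor([[4, 3], [2, 5]])

# the new axis of size 3 is inserted at position `dim`
for dim in (0, 1, -1):
    T = torch.stack((T1, T2, T3), dim)
    print(dim, tuple(T.shape))
# 0 (3, 2, 2)
# 1 (2, 3, 2)
# -1 (2, 2, 3)

# with dim=-1, the original tensors can be recovered along the last axis
T = torch.stack((T1, T2, T3), -1)
print(torch.equal(T[..., 0], T1), torch.equal(T[..., 2], T3))  # True True
```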

Updated on: 14-Sep-2023
