How to slice a 3D Tensor in PyTorch?

A 3D tensor in PyTorch is a three-dimensional array that you can picture as a stack of matrices, just as 1D and 2D tensors represent vectors and matrices respectively. PyTorch lets you slice 3D tensors with indexing operations and with built-in functions such as torch.split().
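To make the vector/matrix/stack distinction concrete, here is a minimal sketch that checks the number of dimensions with dim(). The variable names are illustrative only, and torch.zeros is used just to get tensors of a known shape.

```python
import torch

# A 1D tensor (vector), a 2D tensor (matrix) and a 3D tensor (a stack of matrices)
vector = torch.zeros(4)
matrix = torch.zeros(3, 4)
stack = torch.zeros(2, 3, 4)

# dim() returns the number of dimensions; shape gives the size of each one
print(vector.dim(), matrix.dim(), stack.dim())  # 1 2 3
print(stack.shape)                              # torch.Size([2, 3, 4])

# Indexing the first dimension of a 3D tensor yields one of its matrices
print(stack[0].shape)                           # torch.Size([3, 4])
```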

Basic Tensor Slicing Syntax

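PyTorch uses NumPy-style indexing with the format tensor[dim1, dim2, dim3], where each position takes either an integer index or Python slice notation (start:stop:step). The example uses torch.randn, so the printed values will differ on every run: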

import torch

# Create a 3D tensor with shape (2, 3, 4)
tensor_3d = torch.randn(2, 3, 4)
print("Original tensor shape:", tensor_3d.shape)
print(tensor_3d)
Output:
Original tensor shape: torch.Size([2, 3, 4])
tensor([[[-0.5234,  1.2341, -0.8765,  0.4321],
         [ 0.7890, -1.3456,  0.9876, -0.2345],
         [-0.6789,  0.3456,  1.2345, -0.7890]],

        [[ 1.4567, -0.2345,  0.8765, -1.3456],
         [-0.4321,  0.7890,  1.2345, -0.9876],
         [ 0.3456, -0.6789,  0.2345,  1.4567]]])
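Because random values make slices hard to check by eye, the sketch below uses torch.arange (an illustrative choice, not part of the original example) to fill a (2, 3, 4) tensor with 0 to 23 and then exercises the step, negative-index and ellipsis forms of the syntax. One difference from NumPy worth knowing: PyTorch rejects negative steps such as ::-1, so reversing a dimension is done with torch.flip instead.

```python
import torch

# torch.arange gives predictable values 0..23, reshaped to shape (2, 3, 4)
t = torch.arange(24).reshape(2, 3, 4)

# start:stop:step works in every position; ::2 takes every other column
every_other_col = t[:, :, ::2]
print(every_other_col.shape)               # torch.Size([2, 3, 2])

# Negative indices count from the end, as with Python lists
print(t[:, -1, :].tolist())                # [[8, 9, 10, 11], [20, 21, 22, 23]]

# An ellipsis (...) stands in for all remaining leading dimensions
print(torch.equal(t[..., 0], t[:, :, 0]))  # True

# Negative steps such as t[:, :, ::-1] raise an error in PyTorch;
# torch.flip reverses a dimension instead (it returns a copy)
reversed_cols = torch.flip(t, dims=[2])
print(reversed_cols[0, 0].tolist())        # [3, 2, 1, 0]
```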

Slicing Along Different Dimensions

You can slice along each dimension independently using colon notation. A bare : keeps the whole dimension:

import torch

tensor_3d = torch.randn(2, 3, 4)

# Index the first dimension with 0 - keeps only the first matrix and drops that dimension
first_matrix = tensor_3d[0, :, :]
print("First matrix shape:", first_matrix.shape)

# Slice second dimension - keep first two rows
first_two_rows = tensor_3d[:, 0:2, :]
print("First two rows shape:", first_two_rows.shape)

# Slice third dimension - keep last two columns
last_two_cols = tensor_3d[:, :, -2:]
print("Last two columns shape:", last_two_cols.shape)
Output:
First matrix shape: torch.Size([3, 4])
First two rows shape: torch.Size([2, 2, 4])
Last two columns shape: torch.Size([2, 3, 2])
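The shapes above hint at two behaviors worth seeing directly: an integer index removes its dimension while a slice keeps it, and basic slices are views that share memory with the original tensor. This is a minimal sketch on an illustrative torch.arange tensor; the names view and copy are hypothetical.

```python
import torch

t = torch.arange(24).reshape(2, 3, 4)

# An integer index removes that dimension ...
print(t[0].shape)    # torch.Size([3, 4])
# ... while a length-1 slice keeps it
print(t[0:1].shape)  # torch.Size([1, 3, 4])

# Basic slices are views: writing through the slice changes the original
view = t[:, 0:2, :]
view[0, 0, 0] = 100
print(t[0, 0, 0])    # tensor(100)

# Call .clone() when you need an independent copy
copy = t[:, 0:2, :].clone()
copy[0, 0, 0] = -1
print(t[0, 0, 0])    # still tensor(100)
```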

Using torch.split() Method

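The torch.split() function divides a tensor into chunks along a specified dimension. When the split size is an integer, each chunk has that many elements along the dimension (only the last chunk may be smaller):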

import torch

# Create a 3D tensor with shape (2, 3, 6)
data = torch.rand(2, 3, 6)
print("Original shape:", data.shape)

# Split along dimension 2 (last dimension) into chunks of size 2
chunks = torch.split(data, 2, dim=2)
print("Number of chunks:", len(chunks))
print("Each chunk shape:", chunks[0].shape)
Output:
Original shape: torch.Size([2, 3, 6])
Number of chunks: 3
Each chunk shape: torch.Size([2, 3, 2])
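The example above divides evenly. The sketch below shows what happens when it does not, how to pass explicit section sizes as a list, and the related torch.chunk, which takes a number of chunks rather than a chunk size. The variable names are illustrative.

```python
import torch

data = torch.rand(2, 3, 6)

# If the chunk size does not divide the dimension evenly, the last chunk is smaller
uneven = torch.split(data, 4, dim=2)
print([c.shape[2] for c in uneven])   # [4, 2]

# A list of sizes gives explicit section lengths (they must sum to the dimension size)
parts = torch.split(data, [1, 2, 3], dim=2)
print([c.shape[2] for c in parts])    # [1, 2, 3]

# torch.chunk takes the number of chunks instead of their size
halves = torch.chunk(data, 2, dim=2)
print([c.shape[2] for c in halves])   # [3, 3]

# The pieces can be joined back with torch.cat along the same dimension
print(torch.equal(torch.cat(parts, dim=2), data))  # True
```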

Indexing Specific Elements

Access specific matrices (or, with more indices, rows and individual elements) by providing exact indices:

import torch

# Create a tensor with known values
tensor_data = torch.tensor([[[10, 20, 30, 40],
                            [50, 60, 70, 80],
                            [1, 2, 3, 4]],
                           [[13, 14, 15, 16],
                            [21, 22, 23, 24],
                            [3, 4, 5, 6]]])

# Access individual matrices
matrix_0 = tensor_data[0]
matrix_1 = tensor_data[1]

print("First matrix:")
print(matrix_0)
print("\nSecond matrix:")
print(matrix_1)
Output:
First matrix:
tensor([[10, 20, 30, 40],
        [50, 60, 70, 80],
        [ 1,  2,  3,  4]])

Second matrix:
tensor([[13, 14, 15, 16],
        [21, 22, 23, 24],
        [ 3,  4,  5,  6]])
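Going one level deeper, supplying all three indices selects a single element, and mixing integers with slices selects rows or columns across matrices. This sketch reuses the same tensor_data values as the example above; .item() converts a one-element tensor to a plain Python number.

```python
import torch

tensor_data = torch.tensor([[[10, 20, 30, 40],
                             [50, 60, 70, 80],
                             [1, 2, 3, 4]],
                            [[13, 14, 15, 16],
                             [21, 22, 23, 24],
                             [3, 4, 5, 6]]])

# Three integer indices select a single element (a 0-dimensional tensor)
element = tensor_data[1, 2, 3]
print(element)                # tensor(6)
print(element.item())         # 6 (plain Python int)

# Mix integers and slices: row 1, column 2 from every matrix
print(tensor_data[:, 1, 2])   # tensor([70, 23])

# Two indices select one row of one matrix
print(tensor_data[0, 1])      # tensor([50, 60, 70, 80])
```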

Common Slicing Operations

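| Operation | Syntax | Result shape for a (2, 3, 4) tensor | Description |
|---|---|---|---|
| All elements | tensor[:, :, :] | (2, 3, 4) | The complete tensor |
| First matrix | tensor[0, :, :] | (3, 4) | First 2D slice along dimension 0 |
| First row of all matrices | tensor[:, 0, :] | (2, 4) | Row 0 of every matrix (index 0 along dimension 1) |
| Last column | tensor[:, :, -1] | (2, 3) | Last column of every matrix |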
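The shapes in the table can be checked with a short script. As an addition beyond the original table, the sketch also shows two function forms that PyTorch provides: Tensor.narrow(dim, start, length), which matches a start:stop slice, and Tensor.index_select(dim, index), which picks arbitrary positions along one dimension and returns a copy. The tensor values come from torch.arange purely for illustration.

```python
import torch

tensor = torch.arange(24).reshape(2, 3, 4)

# Result shapes of the operations in the table
print(tensor[:, :, :].shape)      # torch.Size([2, 3, 4])
print(tensor[0, :, :].shape)      # torch.Size([3, 4])
print(tensor[:, 0, :].shape)      # torch.Size([2, 4])
print(tensor[:, :, -1].shape)     # torch.Size([2, 3])
print(tensor[:, :, -1].tolist())  # [[3, 7, 11], [15, 19, 23]]

# narrow(dim, start, length) gives the same result as tensor[:, :, 1:3]
print(torch.equal(tensor.narrow(2, 1, 2), tensor[:, :, 1:3]))  # True

# index_select picks columns 0 and 3 along dimension 2
cols = tensor.index_select(2, torch.tensor([0, 3]))
print(cols.shape)                 # torch.Size([2, 3, 2])
```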

Conclusion

PyTorch tensor slicing extends Python's colon notation to every dimension. Use tensor[dim1, dim2, dim3] syntax for basic slicing, and torch.split() to divide a tensor into chunks of a given size along a specific dimension (the last chunk is smaller when the size does not divide the dimension evenly).

Updated on: 2026-03-27T07:46:11+05:30
