How to slice a 3D Tensor in PyTorch?
A 3D tensor in PyTorch is a three-dimensional array that can be viewed as a stack of matrices, just as 1D and 2D tensors represent vectors and matrices respectively. PyTorch lets you slice 3D tensors with standard indexing operations and with built-in functions such as torch.split().
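To make the "stack of matrices" picture concrete, here is a minimal sketch that builds a 3D tensor by stacking two 3×4 matrices with torch.stack(). The matrix values (from torch.arange) are arbitrary and chosen only for illustration.

```python
import torch

# Two arbitrary 3x4 matrices (illustrative values only)
m0 = torch.arange(12).reshape(3, 4)
m1 = torch.arange(12, 24).reshape(3, 4)

# torch.stack adds a new leading dimension, giving shape (2, 3, 4)
stacked = torch.stack([m0, m1])
print(stacked.shape)   # torch.Size([2, 3, 4])

# Indexing the first dimension gives back each original matrix
print(torch.equal(stacked[0], m0), torch.equal(stacked[1], m1))   # True True
```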
Basic Tensor Slicing Syntax
PyTorch uses NumPy-style indexing with the format tensor[dim1, dim2, dim3], where each position can be an integer index or a slice written in colon notation:
import torch
# Create a 3D tensor with shape (2, 3, 4) filled with random values;
# the numbers printed below will differ on each run
tensor_3d = torch.randn(2, 3, 4)
print("Original tensor shape:", tensor_3d.shape)
print(tensor_3d)
Original tensor shape: torch.Size([2, 3, 4])
tensor([[[-0.5234,  1.2341, -0.8765,  0.4321],
         [ 0.7890, -1.3456,  0.9876, -0.2345],
         [-0.6789,  0.3456,  1.2345, -0.7890]],

        [[ 1.4567, -0.2345,  0.8765, -1.3456],
         [-0.4321,  0.7890,  1.2345, -0.9876],
         [ 0.3456, -0.6789,  0.2345,  1.4567]]])
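One detail worth knowing about this syntax: an integer index removes that dimension, while a slice of length one keeps it. The sketch below uses a fixed-value tensor built with torch.arange (an illustrative choice) so the results are reproducible.

```python
import torch

t = torch.arange(24).reshape(2, 3, 4)

# Integer index on dim 0: that dimension is dropped
a = t[0, :, :]
print(a.shape)   # torch.Size([3, 4])

# Slice 0:1 on dim 0: the dimension is kept with size 1
b = t[0:1, :, :]
print(b.shape)   # torch.Size([1, 3, 4])

# Integer indices on all three dimensions yield a 0-dim tensor
c = t[1, 2, 3]
print(c.shape, c.item())   # torch.Size([]) 23
```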
Slicing Along Different Dimensions
You can slice along each dimension independently using colon notation:
import torch
tensor_3d = torch.randn(2, 3, 4)
# Slice first dimension - keep first matrix only
first_matrix = tensor_3d[0, :, :]
print("First matrix shape:", first_matrix.shape)
# Slice second dimension - keep first two rows
first_two_rows = tensor_3d[:, 0:2, :]
print("First two rows shape:", first_two_rows.shape)
# Slice third dimension - keep last two columns
last_two_cols = tensor_3d[:, :, -2:]
print("Last two columns shape:", last_two_cols.shape)
First matrix shape: torch.Size([3, 4])
First two rows shape: torch.Size([2, 2, 4])
Last two columns shape: torch.Size([2, 3, 2])
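Basic slices like these return views that share memory with the original tensor, so writing into a slice also changes the source. This minimal sketch shows that behavior, uses .clone() to get an independent copy, and adds a step (::2) to the slice; the zero-filled tensor is an illustrative choice.

```python
import torch

t = torch.zeros(2, 3, 4)

# A basic slice is a view: writing to it modifies t as well
view = t[:, 0:2, :]
view[0, 0, 0] = 5.0
print(t[0, 0, 0].item())   # 5.0

# .clone() produces an independent copy; filling it leaves t unchanged
copy = t[:, :, -2:].clone()
copy.fill_(1.0)
print(t[:, :, -2:].sum().item())   # 0.0

# A step of 2 selects every second column along the last dimension
every_other = t[:, :, ::2]
print(every_other.shape)   # torch.Size([2, 3, 2])
```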
Using torch.split() Method
The torch.split() function divides a tensor into chunks along a specified dimension:
import torch
# Create a 3D tensor with shape (2, 3, 6)
data = torch.rand(2, 3, 6)
print("Original shape:", data.shape)
# Split along dimension 2 (last dimension) into chunks of size 2
chunks = torch.split(data, 2, dim=2)
print("Number of chunks:", len(chunks))
print("Each chunk shape:", chunks[0].shape)
Original shape: torch.Size([2, 3, 6])
Number of chunks: 3
Each chunk shape: torch.Size([2, 3, 2])
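torch.split() also accepts a list of section sizes, and when an integer chunk size does not divide the dimension evenly, the last chunk is smaller. This sketch prints shapes only, since the values from torch.rand() are random.

```python
import torch

data = torch.rand(2, 3, 6)

# Chunk size 4 on a dimension of length 6: chunks of size 4 and 2
uneven = torch.split(data, 4, dim=2)
print([c.shape[2] for c in uneven])   # [4, 2]

# A list gives explicit sizes that must sum to the dimension length
parts = torch.split(data, [1, 2, 3], dim=2)
print([tuple(p.shape) for p in parts])   # [(2, 3, 1), (2, 3, 2), (2, 3, 3)]

# torch.cat along the same dimension reassembles the original tensor
print(torch.equal(torch.cat(parts, dim=2), data))   # True
```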
Indexing Specific Elements
Access specific matrices or elements by providing exact indices:
import torch
# Create a tensor with known values
tensor_data = torch.tensor([[[10, 20, 30, 40],
[50, 60, 70, 80],
[1, 2, 3, 4]],
[[13, 14, 15, 16],
[21, 22, 23, 24],
[3, 4, 5, 6]]])
# Access individual matrices
matrix_0 = tensor_data[0]
matrix_1 = tensor_data[1]
print("First matrix:")
print(matrix_0)
print("\nSecond matrix:")
print(matrix_1)
First matrix:
tensor([[10, 20, 30, 40],
        [50, 60, 70, 80],
        [ 1,  2,  3,  4]])

Second matrix:
tensor([[13, 14, 15, 16],
        [21, 22, 23, 24],
        [ 3,  4,  5,  6]])
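Combining integer indices with slices reaches single elements or whole rows and columns. The sketch below reuses the tensor_data values above; .item() converts a one-element tensor to a Python number, and .tolist() converts a tensor to nested Python lists.

```python
import torch

tensor_data = torch.tensor([[[10, 20, 30, 40],
                             [50, 60, 70, 80],
                             [1, 2, 3, 4]],
                            [[13, 14, 15, 16],
                             [21, 22, 23, 24],
                             [3, 4, 5, 6]]])

# Single element: matrix 0, row 1, column 2
print(tensor_data[0, 1, 2].item())   # 70

# First column of the second matrix
print(tensor_data[1, :, 0].tolist())   # [13, 21, 3]

# Second row of every matrix
print(tensor_data[:, 1, :].tolist())   # [[50, 60, 70, 80], [21, 22, 23, 24]]
```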
Common Slicing Operations
| Operation | Syntax | Description |
|---|---|---|
| All elements | tensor[:, :, :] | Returns the complete tensor |
| First matrix | tensor[0, :, :] | First 2D slice along dimension 0 |
| First row of all matrices | tensor[:, 0, :] | Row 0 of every matrix (indexes dimension 1) |
| Last column | tensor[:, :, -1] | Last column of every matrix |
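As a quick check, this sketch prints the result shape of each table row on a tensor of shape (2, 3, 4) built with torch.arange (an illustrative choice). It also shows the Ellipsis (...) shorthand, which stands for all dimensions not written explicitly, so tensor[..., -1] selects the same values as tensor[:, :, -1].

```python
import torch

t = torch.arange(24).reshape(2, 3, 4)

ops = {
    "tensor[:, :, :]": t[:, :, :],
    "tensor[0, :, :]": t[0, :, :],
    "tensor[:, 0, :]": t[:, 0, :],
    "tensor[:, :, -1]": t[:, :, -1],
}
for expr, result in ops.items():
    print(expr, tuple(result.shape))
# tensor[:, :, :] (2, 3, 4)
# tensor[0, :, :] (3, 4)
# tensor[:, 0, :] (2, 4)
# tensor[:, :, -1] (2, 3)

# Ellipsis expands to the omitted leading dimensions
print(torch.equal(t[..., -1], t[:, :, -1]))   # True
```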
Conclusion
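PyTorch tensor slicing uses NumPy-style indexing, with an integer or a colon-notation slice for each dimension. Use tensor[dim1, dim2, dim3] for basic slicing and torch.split() to divide a tensor into chunks along a specific dimension; with an integer chunk size, all chunks are equal except possibly the last, which is smaller when the size does not divide the dimension evenly.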
