stack function
Stacks a sequence of tensors along a new dimension.
Unlike concat, which joins tensors along an existing dimension, stack creates a new dimension and stacks tensors along it.
All input tensors must have the same shape and the same dtype.
Equivalent to torch.stack() in PyTorch.
Parameters
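tensors: List of tensors to stack. Must be non-empty, and all tensors must have the same shape and dtype.
dim: The index at which the new dimension is inserted in the output. Default is 0 (the new dimension becomes the first one). Supports negative indexing relative to the output rank (e.g., -1 means the new dimension becomes the last one); for inputs of rank r, the valid range is [-(r+1), r].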
Examples
// Stack [2,3] tensors along dim 0 -> [3,2,3]
final stacked = stack([a, b, c], dim: 0);
// Stack [2,3] tensors along dim 1 -> [2,3,3]
final stacked = stack([a, b, c], dim: 1);
// Stack [2,3] tensors along dim -1 -> [2,3,3]
final stacked = stack([a, b, c], dim: -1);
Implementation
/// Stacks [tensors] along a new dimension inserted at index [dim].
///
/// Unlike concatenation, which joins along an existing dimension, this adds
/// one dimension of size `tensors.length` to the output. For inputs of rank
/// `r`, the output has rank `r + 1`.
///
/// [dim] may be negative. It is normalized against the output rank, so the
/// valid range is `[-(r + 1), r]`.
///
/// Throws [InvalidParameterException] if [tensors] is empty, if [dim] is out
/// of bounds, or if any tensor's dtype differs from the first tensor's.
/// Throws [ShapeMismatchException] if any tensor's rank or shape differs from
/// the first tensor's.
TensorBuffer stack(List<TensorBuffer> tensors, {int dim = 0}) {
  if (tensors.isEmpty) {
    throw InvalidParameterException(
      'tensors',
      'empty list',
      'Cannot stack empty list of tensors',
    );
  }

  final firstTensor = tensors.first;
  final firstShape = firstTensor.shape;
  final rank = firstShape.length;

  // Normalize dimension (supports negative indexing).
  // The output gains one dimension, so the new axis may sit at any index in
  // [0, rank]; negative values count from the end of the output shape.
  final outputRank = rank + 1;
  final normalizedDim = dim < 0 ? dim + outputRank : dim;
  if (normalizedDim < 0 || normalizedDim > rank) {
    throw InvalidParameterException(
      'dim',
      dim.toString(),
      'Dimension $dim is out of bounds for stack with tensor rank $rank',
    );
  }

  // Stack a single tensor by adding a dimension. The dim check above runs
  // first, so an invalid dim is reported the same way no matter how many
  // tensors are passed. Passing the normalized (non-negative) index keeps
  // the call independent of how unsqueeze treats negative dims.
  if (tensors.length == 1) {
    return firstTensor.unsqueeze(normalizedDim);
  }

  // Validate that all tensors have the same dtype and shape as the first.
  for (int i = 1; i < tensors.length; i++) {
    final tensor = tensors[i];
    final shape = tensor.shape;
    if (tensor.dtype != firstTensor.dtype) {
      throw InvalidParameterException(
        'dtype',
        'tensor $i has dtype ${tensor.dtype}, expected ${firstTensor.dtype}',
        'All tensors must have the same dtype',
      );
    }
    if (shape.length != rank) {
      throw ShapeMismatchException(
        actual: shape,
        message: 'Tensor $i has rank ${shape.length}, expected $rank',
      );
    }
    for (int d = 0; d < rank; d++) {
      if (shape[d] != firstShape[d]) {
        throw ShapeMismatchException(
          actual: shape,
          message: 'Tensor $i has shape $shape, expected $firstShape for stack',
        );
      }
    }
  }

  // Output shape is the common input shape with the stack axis inserted.
  // List.insert accepts index == length, which covers normalizedDim == rank.
  final outputShape = List<int>.of(firstShape)
    ..insert(normalizedDim, tensors.length);

  // Create the output tensor; _stackTensors fills it from the inputs.
  final output = TensorBuffer.uninitialized(
    outputShape,
    dtype: firstTensor.dtype,
  );
  _stackTensors(tensors, output, normalizedDim);
  return output;
}