stack function

TensorBuffer stack(
  List<TensorBuffer> tensors, {
  int dim = 0,
})

Stacks a sequence of tensors along a new dimension.

Unlike concat, which joins tensors along an existing dimension, stack creates a new dimension and stacks tensors along it.

The list must not be empty, and all input tensors must have the same shape and the same dtype.

Equivalent to torch.stack() in PyTorch.
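The difference from concat shows up directly in the output shape. The sketch below works on plain Dart shape lists rather than TensorBuffer. concatShape and stackShape are hypothetical helpers written for illustration, and concatShape assumes identically shaped inputs to keep it short.

```dart
/// Output shape of concatenating [count] tensors of shape [shape] along the
/// existing dimension [dim]. Assumes every input has exactly [shape].
List<int> concatShape(List<int> shape, int count, int dim) {
  final out = List<int>.of(shape);
  out[dim] = shape[dim] * count; // the existing dimension grows; rank unchanged
  return out;
}

/// Output shape of stacking [count] tensors of shape [shape] along a new
/// dimension inserted at [dim], where 0 <= dim <= shape.length.
List<int> stackShape(List<int> shape, int count, int dim) {
  // A new axis of size [count] is inserted, so the rank grows by one.
  return List<int>.of(shape)..insert(dim, count);
}
```

For three [2,3] inputs, concatShape([2, 3], 3, 0) gives [6, 3] while stackShape([2, 3], 3, 0) gives [3, 2, 3]: concat keeps the rank, stack raises it by one.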

Parameters

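  • tensors: List of tensors to stack. Must be non-empty, and all tensors must have the same shape and dtype.
  • dim: Index of the new dimension in the output. Defaults to 0, which prepends the new dimension. For input tensors of rank r, valid values are -(r+1) through r. Negative values count from the end of the output shape, so -1 appends the new dimension as the last axis.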
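The negative-index rule for dim can be sketched as a standalone function. It mirrors the normalization step in the implementation, but normalizeStackDim is a hypothetical helper, and it throws Dart's built-in RangeError instead of the library's InvalidParameterException.

```dart
/// Maps a possibly negative stack dimension onto the range [0, rank].
///
/// [rank] is the rank of each input tensor. The output has rank + 1
/// positions where the new axis can go, so valid [dim] values lie in
/// [-(rank + 1), rank].
int normalizeStackDim(int dim, int rank) {
  final outputRank = rank + 1;
  // Negative values count back from the end of the output shape.
  final normalized = dim < 0 ? dim + outputRank : dim;
  if (normalized < 0 || normalized > rank) {
    throw RangeError.range(dim, -outputRank, rank, 'dim');
  }
  return normalized;
}
```

For [2,3] inputs (rank 2), dim -1 maps to 2 and dim -3 maps to 0, while 3 and -4 are rejected.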

Examples

// a, b, and c are tensors of shape [2,3].

// Stack along dim 0 -> [3,2,3]
final stacked0 = stack([a, b, c], dim: 0);

// Stack along dim 1 -> [2,3,3]
final stacked1 = stack([a, b, c], dim: 1);

// Stack along dim -1 (equivalent to dim 2 here) -> [2,3,3]
final stackedLast = stack([a, b, c], dim: -1);
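Shapes alone hide where each element ends up. The sketch below stacks small 2-D matrices held as nested Dart lists to make that visible. stack2d is a hypothetical helper, not part of this library, and it only handles the 2-D case.

```dart
/// Stacks equally shaped 2-D matrices along a new dimension [dim] in {0, 1, 2}:
/// dim 0 -> out[k][i][j], dim 1 -> out[i][k][j], dim 2 -> out[i][j][k],
/// where k indexes the input matrix.
List<List<List<T>>> stack2d<T>(List<List<List<T>>> mats, int dim) {
  final rows = mats.first.length;
  final cols = mats.first.first.length;
  switch (dim) {
    case 0:
      // Shape [n, rows, cols]: each matrix becomes one slice.
      return [
        for (final m in mats) [for (final r in m) List<T>.of(r)]
      ];
    case 1:
      // Shape [rows, n, cols]: row i of every matrix is grouped together.
      return [
        for (var i = 0; i < rows; i++) [for (final m in mats) List<T>.of(m[i])]
      ];
    case 2:
      // Shape [rows, cols, n]: element (i, j) of every matrix is grouped together.
      return [
        for (var i = 0; i < rows; i++)
          [
            for (var j = 0; j < cols; j++) [for (final m in mats) m[i][j]]
          ]
      ];
    default:
      throw RangeError.range(dim, 0, 2, 'dim');
  }
}
```

With a = [[1, 2], [3, 4]] and b = [[5, 6], [7, 8]], dim 0 yields [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dim 1 yields [[[1, 2], [5, 6]], [[3, 4], [7, 8]]], and dim 2 (equivalently -1) yields [[[1, 5], [2, 6]], [[3, 7], [4, 8]]].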

Implementation

TensorBuffer stack(List<TensorBuffer> tensors, {int dim = 0}) {
  if (tensors.isEmpty) {
    throw InvalidParameterException(
      'tensors',
      'empty list',
      'Cannot stack empty list of tensors',
    );
  }

  if (tensors.length == 1) {
    // Stack single tensor by adding a dimension
    return tensors.first.unsqueeze(dim);
  }

  final firstTensor = tensors.first;
  final firstShape = firstTensor.shape;
  final rank = firstShape.length;

  // Normalize dimension (supports negative indexing)
  // For stack, valid range is [-(rank+1), rank] since we're adding a new dim
  final outputRank = rank + 1;
  final normalizedDim = dim < 0 ? dim + outputRank : dim;

  if (normalizedDim < 0 || normalizedDim > rank) {
    throw InvalidParameterException(
      'dim',
      dim.toString(),
      'Dimension $dim is out of bounds for stack with tensor rank $rank '
          '(expected a value in [${-(rank + 1)}, $rank])',
    );
  }

  // Validate all tensors have the same shape and dtype
  for (int i = 1; i < tensors.length; i++) {
    final tensor = tensors[i];
    final shape = tensor.shape;

    if (tensor.dtype != firstTensor.dtype) {
      throw InvalidParameterException(
        'dtype',
        'tensor $i has dtype ${tensor.dtype}, expected ${firstTensor.dtype}',
        'All tensors must have the same dtype',
      );
    }

    if (shape.length != rank) {
      throw ShapeMismatchException(
        actual: shape,
        message: 'Tensor $i has rank ${shape.length}, expected $rank',
      );
    }

    for (int d = 0; d < rank; d++) {
      if (shape[d] != firstShape[d]) {
        throw ShapeMismatchException(
          actual: shape,
          message: 'Tensor $i has shape $shape, expected $firstShape for stack',
        );
      }
    }
  }

  // Compute output shape: insert new dimension at normalizedDim
  final outputShape = <int>[];
  for (int d = 0; d < rank; d++) {
    if (d == normalizedDim) {
      outputShape.add(tensors.length);
    }
    outputShape.add(firstShape[d]);
  }
  if (normalizedDim == rank) {
    outputShape.add(tensors.length);
  }

  // Create output tensor
  final output = TensorBuffer.uninitialized(
    outputShape,
    dtype: firstTensor.dtype,
  );

  // Copy data from each tensor
  _stackTensors(tensors, output, normalizedDim);

  return output;
}
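The helper _stackTensors is not shown in this section. One common way to implement the copy, assuming each input is a contiguous row-major buffer (an assumption; the TensorBuffer memory layout is not documented here), is to view the output as [outer, n, inner]. Here outer is the product of the input dimensions before dim and inner is the product of those from dim onward, so each (outer block, input) pair is one contiguous copy. stackFlat is a hypothetical stand-in that uses plain List<double> buffers.

```dart
/// Stacks [inputs], each a flat row-major buffer of shape [shape], along a
/// new dimension [dim] (0 <= dim <= shape.length). Returns a flat buffer of
/// the output shape.
List<double> stackFlat(List<List<double>> inputs, List<int> shape, int dim) {
  final n = inputs.length;
  // Number of blocks before the new axis, and elements per block.
  final outer = shape.take(dim).fold<int>(1, (a, b) => a * b);
  final inner = shape.skip(dim).fold<int>(1, (a, b) => a * b);
  final out = List<double>.filled(outer * n * inner, 0);
  for (var o = 0; o < outer; o++) {
    for (var k = 0; k < n; k++) {
      // In the [outer, n, inner] view, block o of input k lands at (o, k).
      final dst = (o * n + k) * inner;
      out.setRange(dst, dst + inner, inputs[k], o * inner);
    }
  }
  return out;
}
```

For two [2,2] inputs [1, 2, 3, 4] and [5, 6, 7, 8], dim 1 gives [1, 2, 5, 6, 3, 4, 7, 8]. The loop performs one setRange per (outer block, input) pair, so placing the new dimension near the front produces fewer, longer copies.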