
API Reference

This page contains the API reference for the Power Attention library.

High-Level Interface

The main entry point for using symmetric power attention in your models.

Main implementation of symmetric power attention, which generalizes linear transformers using symmetric power embeddings. Its cost grows linearly (O(n)) with sequence length n through an equivalent RNN formulation, which makes it efficient for long sequences.
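The linear-transformer connection can be checked numerically. This is a minimal sketch in plain PyTorch, not the library's kernels. It shows that causal attention with weights (q_i · k_j)^p equals a recurrence over a state S_t = S_{t-1} + φ(k_t) v_t^T read out as y_t = φ(q_t) S_t. Several parts are assumptions. φ here is the full flattened p-fold tensor power rather than the library's compressed symmetric embedding. The normalization and `scale` handling of the real implementation are omitted. The helper names `phi`, `power_attention_quadratic`, and `power_attention_recurrent` are hypothetical.

```python
import torch


def phi(x: torch.Tensor, deg: int) -> torch.Tensor:
    """Hypothetical feature map: flattened deg-fold tensor power of the last dim.

    Maps (..., d) -> (..., d**deg) so that phi(x) @ phi(y) == (x @ y) ** deg.
    """
    out = x
    for _ in range(deg - 1):
        out = (out.unsqueeze(-1) * x.unsqueeze(-2)).flatten(-2)
    return out


def power_attention_quadratic(Q, K, V, deg):
    """O(t^2) causal form: y_i = sum_{j <= i} (q_i . k_j)**deg * v_j (unnormalized)."""
    return torch.tril((Q @ K.T) ** deg) @ V


def power_attention_recurrent(Q, K, V, deg):
    """O(t) form: S_t = S_{t-1} + phi(k_t) v_t^T, then y_t = phi(q_t) S_t."""
    S = Q.new_zeros(Q.shape[-1] ** deg, V.shape[-1])
    outputs = []
    for q, k, v in zip(Q, K, V):
        S = S + torch.outer(phi(k, deg), v)  # fold the new key/value into the state
        outputs.append(phi(q, deg) @ S)       # read the state with the current query
    return torch.stack(outputs)


torch.manual_seed(0)
t, d = 16, 4
Q, K, V = (torch.randn(t, d, dtype=torch.float64) for _ in range(3))
print(torch.allclose(power_attention_quadratic(Q, K, V, deg=2),
                     power_attention_recurrent(Q, K, V, deg=2)))  # True
```

The recurrent form keeps a state of size d^p × d_v no matter how long the sequence is. This is why the work per token does not grow with t.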

Low-Level Components

These components implement the linear transformer with symmetric power embeddings. They're exposed for advanced usage and custom implementations.

Core Attention

Computes attention scores using symmetric power embeddings. This is equivalent to raising each query-key inner product to an even power.
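One standard way to build a symmetric power embedding is sketched below. It keeps one entry per multiset of indices and weights each entry by the square root of the number of orderings of that multiset, so inner products of embeddings reproduce (q · k)^p. This is an illustrative construction. The function name `sym_power_embedding` is hypothetical, and the library's kernels may order, scale, or never materialize these entries. Because p is even, every score (q · k)^p is non-negative.

```python
import itertools
import math

import torch


def sym_power_embedding(x: torch.Tensor, deg: int) -> torch.Tensor:
    """Symmetric power embedding of a vector x of shape (d,).

    Keeps one entry per multiset of indices i_1 <= ... <= i_deg, scaled by the
    square root of how many orderings that multiset has, so that
    sym_power_embedding(x) @ sym_power_embedding(y) == (x @ y) ** deg.
    """
    feats = []
    for idx in itertools.combinations_with_replacement(range(x.shape[-1]), deg):
        # Orderings of the multiset: deg! / prod(multiplicity!)
        mult = [idx.count(i) for i in set(idx)]
        n_orderings = math.factorial(deg) // math.prod(math.factorial(m) for m in mult)
        feats.append(math.sqrt(n_orderings) * x[list(idx)].prod())
    return torch.stack(feats)


torch.manual_seed(0)
q = torch.randn(8, dtype=torch.float64)
k = torch.randn(8, dtype=torch.float64)
phi_q, phi_k = sym_power_embedding(q, 4), sym_power_embedding(k, 4)
print(phi_q.shape)                                  # torch.Size([330]), i.e. C(11, 4) vs 8**4 = 4096
print(torch.allclose(phi_q @ phi_k, (q @ k) ** 4))  # True
```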

power_attention._attention

State Management

Functions for managing the RNN state representation. Because the state is built from symmetric tensors, only their unique entries need to be stored, which greatly reduces memory.

State Expansion

Computes expanded state vectors using symmetric power embeddings, reducing memory by up to 96% for deg=4 compared with storing the full tensor power.
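The size of the memory saving can be checked with simple counting. A d-dimensional vector's p-fold tensor power has d^p entries, but only C(d+p-1, p) of them are distinct. This sketch assumes the 96% figure compares those two counts for a head dimension d. The helper `embedding_sizes` is hypothetical.

```python
from math import comb


def embedding_sizes(d: int, deg: int) -> tuple[int, int]:
    """Entries in the full deg-fold tensor power vs. its unique symmetric entries."""
    return d ** deg, comb(d + deg - 1, deg)


for d in (32, 64, 128):
    full, sym = embedding_sizes(d, deg=4)
    print(f"d={d}: full={full}, symmetric={sym}, saved={1 - sym / full:.1%}")
# d=32: full=1048576, symmetric=52360, saved=95.0%
# d=64: full=16777216, symmetric=766480, saved=95.4%
# d=128: full=268435456, symmetric=11716640, saved=95.6%
```

As d grows, C(d+3, 4)/d^4 approaches 1/4! = 1/24. The saving for deg=4 therefore approaches 1 - 1/24 ≈ 95.8%, which rounds to the quoted 96%.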

power_attention._update_state

Query-State Interaction

Computes how queries interact with the compressed state representation from previous chunks.

power_attention._query_state

Recurrent Processing

Implements the RNN state update equations for O(n) processing of long sequences.
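The sketch below shows how the chunked pieces above fit together, in plain PyTorch. It is not the library's implementation. Each chunk's keys and values are folded into a state contribution. A cumulative sum over chunks with a leading zero state plays the role that `discumsum` plays. Queries then read the state from earlier chunks, and ordinary causal power attention handles positions inside the same chunk. The sketch makes several assumptions:

- It is ungated, so the cumulative sum is undiscounted. With gating, both the sum and the intra-chunk weights would need discounting.
- It is unnormalized.
- It uses the full tensor power instead of the symmetric embedding.
- The names `phi` and `chunked_power_attention` are hypothetical.

```python
import torch


def phi(x: torch.Tensor, deg: int) -> torch.Tensor:
    """Flattened deg-fold tensor power of the last dim: (..., d) -> (..., d**deg)."""
    out = x
    for _ in range(deg - 1):
        out = (out.unsqueeze(-1) * x.unsqueeze(-2)).flatten(-2)
    return out


def chunked_power_attention(Q, K, V, deg=2, chunk_size=4):
    """Ungated, unnormalized chunked power attention for (t, d) inputs."""
    t, d = Q.shape
    n = t // chunk_size  # assumes t is divisible by chunk_size
    Qc, Kc, Vc = (x.reshape(n, chunk_size, d) for x in (Q, K, V))
    # 1) State contribution of each chunk: sum_j phi(k_j) v_j^T -> (n, D, d)
    chunk_states = phi(Kc, deg).transpose(1, 2) @ Vc
    # 2) Cumulative sum over chunks with a leading zero state -> (n + 1, D, d);
    #    entry i holds the state built from chunks strictly before chunk i
    cum = torch.cat([torch.zeros_like(chunk_states[:1]), chunk_states.cumsum(dim=0)])
    # 3) Queries read the state from earlier chunks -> (n, chunk_size, d)
    inter = phi(Qc, deg) @ cum[:-1]
    # 4) Causal power attention inside each chunk
    intra = torch.tril((Qc @ Kc.transpose(1, 2)) ** deg) @ Vc
    return (inter + intra).reshape(t, d)


torch.manual_seed(0)
Q, K, V = (torch.randn(16, 4, dtype=torch.float64) for _ in range(3))
reference = torch.tril((Q @ K.T) ** 2) @ V  # quadratic O(t^2) form
print(torch.allclose(chunked_power_attention(Q, K, V), reference))  # True
```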

power_attention._discumsum

discumsum(X: torch.Tensor, log_G: torch.Tensor) -> torch.Tensor

Compute discounted cumulative sum for recurrent processing.

This function implements the discounted cumulative sum operation from [1]. It is a key component in transforming symmetric power attention into a linear-cost RNN, allowing efficient processing of long sequences through chunked computation.

The operation implements the state update equation of the RNN:

\[Y_t = X_t + \exp(\log G_t) \cdot Y_{t-1}\]

with initial condition \(Y_0 = 0\)

In the context of symmetric power attention:

  • \(X_t\) contains expanded state vectors from the current chunk
  • \(Y_t\) accumulates state information across chunks
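  • \(\exp(\log G_t)\) controls how much past information influences the current computation
  • The output has one extra leading timestep holding \(Y_0 = 0\), so output entry \(t\) (0-based) accumulates only input entries \(0\) through \(t-1\); reading entry \(t\) therefore gives a state built strictly from earlier steps, which keeps the RNN causal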

This formulation enables O(n) complexity instead of O(n²) for long sequences, while maintaining the expressivity of power attention through the expanded state representation.
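For checking shapes and semantics, the recurrence above can be written as a slow loop in plain PyTorch. This is a hypothetical reference (`discumsum_reference` is not part of the library), not the CUDA kernel. It follows the documented shapes: one extra leading timestep, with log_G broadcast over the feature dimensions. Unlike the kernel, it places no constraints on the time or feature sizes.

```python
import math

import torch


def discumsum_reference(X: torch.Tensor, log_G: torch.Tensor) -> torch.Tensor:
    """Slow reference for Y_t = X_t + exp(log G_t) * Y_{t-1} with Y_0 = 0.

    X: (batch, time, heads, *feature_dims); log_G: (batch, time, heads).
    Returns Y of shape (batch, time + 1, heads, *feature_dims), with Y[:, 0] == 0.
    """
    b, n, h, *ds = X.shape
    # Reshape the per-(batch, time, head) discount so it broadcasts over features
    G = log_G.exp().reshape(b, n, h, *([1] * len(ds)))
    state = X.new_zeros(b, h, *ds)  # Y_0
    outputs = [state]
    for t in range(n):
        state = X[:, t] + G[:, t] * state
        outputs.append(state)
    return torch.stack(outputs, dim=1)


# Constant discount G = 0.5 applied to a sequence of ones gives 0, 1, 1.5, 1.75
X = torch.ones(1, 3, 1, 1)
log_G = torch.full((1, 3, 1), math.log(0.5))
print(discumsum_reference(X, log_G).flatten())
```

With log_G = 0 (no discount), the output is a plain cumulative sum shifted by one step.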

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| X | Tensor | Input tensor of shape (batch_size, time, num_heads, *feature_dims). The tensor to be accumulated along the time dimension. | required |
| log_G | Tensor | Log discount factors of shape (batch_size, time, num_heads). Natural logarithm of the discount/gating factors. These are broadcasted along the feature dimensions. | required |

Returns:

| Name | Type | Description |
|------|------|-------------|
| Y | Tensor | Accumulated tensor of shape (batch_size, time+1, num_heads, *feature_dims). Note that the output has one more timestep than the input, with zeros at t=0. |

Note
  • Time dimension must be a multiple of 4
  • Product of feature dimensions must be a multiple of 8
  • The batch and heads dimensions are treated as independent batch dimensions
  • Initial state support is planned but not yet implemented
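  • In the RNN formulation, the state carried between steps has a fixed size, O(1) in sequence length; this function's output, however, stores one accumulated state per timestep and so grows with time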
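When inputs do not meet the size constraints above, one possible workaround is to zero-pad and then trim. This is a hypothetical wrapper, not part of the library. It takes the kernel as an argument, flattens the feature dims, and pads time to a multiple of 4 and features to a multiple of 8. Padding is safe for this recurrence: padded steps come after the real ones, and feature columns never mix. `strict_loop_discumsum` is a made-up stand-in that enforces the same constraints so the example runs on CPU.

```python
import torch
import torch.nn.functional as F


def discumsum_padded(discumsum_fn, X, log_G, time_multiple=4, feat_multiple=8):
    """Zero-pad X and log_G to satisfy the size constraints, call discumsum_fn, trim.

    Padded time steps come after the real ones, so the causal recurrence leaves
    output steps 0..n unchanged; each feature column is accumulated independently.
    """
    b, n, h, *ds = X.shape
    X = X.reshape(b, n, h, -1)  # flatten feature dims into one
    f = X.shape[-1]
    pad_t, pad_f = -n % time_multiple, -f % feat_multiple
    # F.pad lists padding from the last dim backwards: features, heads, time
    X = F.pad(X, (0, pad_f, 0, 0, 0, pad_t))
    log_G = F.pad(log_G, (0, 0, 0, pad_t))
    Y = discumsum_fn(X, log_G)  # (b, n + pad_t + 1, h, f + pad_f)
    return Y[:, : n + 1, :, :f].reshape(b, n + 1, h, *ds)


def strict_loop_discumsum(X, log_G):
    """Hypothetical stand-in kernel that enforces the same shape constraints."""
    assert X.shape[1] % 4 == 0 and X.shape[-1] % 8 == 0, "unsupported shape"
    state = torch.zeros_like(X[:, 0])
    outputs = [state]
    for t in range(X.shape[1]):
        state = X[:, t] + log_G[:, t, :, None].exp() * state
        outputs.append(state)
    return torch.stack(outputs, dim=1)


torch.manual_seed(0)
X = torch.randn(2, 6, 3, 5, dtype=torch.float64)  # time=6 and 5 features: both unsupported
log_G = F.logsigmoid(torch.randn(2, 6, 3, dtype=torch.float64))
Y = discumsum_padded(strict_loop_discumsum, X, log_G)
print(Y.shape)  # torch.Size([2, 7, 3, 5])
```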
References

[1] J. Buckman, C. Gelada, and S. Zhang, "Symmetric Power Transformers." Manifest AI, Aug. 15, 2024.

Source code in power_attention/_discumsum/cuda.py
```python
import torch

# discumsum_fwd, the forward function called below, is not shown in this excerpt.

@torch.library.custom_op("power_attention::discumsum", mutates_args=())
def discumsum(X : torch.Tensor, log_G : torch.Tensor) -> torch.Tensor:
    r"""Compute discounted cumulative sum for recurrent processing.

    This function implements the discounted cumulative sum operation from [1]. It is a key
    component in transforming symmetric power attention into a linear-cost RNN, allowing
    efficient processing of long sequences through chunked computation.

    The operation implements the state update equation of the RNN:

    $$Y_t = X_t + \exp(\log G_t) \cdot Y_{t-1}$$

    with initial condition $Y_0 = 0$

    In the context of symmetric power attention:

    - $X_t$ contains expanded state vectors from the current chunk
    - $Y_t$ accumulates state information across chunks
    - $\exp(\log G_t)$ controls how much past information influences the current computation
    - The output has one extra leading timestep holding $Y_0 = 0$, so output entry $t$
      (0-based) accumulates only input entries $0$ through $t-1$; this keeps the RNN causal

    This formulation enables O(n) complexity instead of O(n²) for long sequences, while
    maintaining the expressivity of power attention through the expanded state representation.

    Args:
        X: Input tensor of shape `(batch_size, time, num_heads, *feature_dims)`.
           The tensor to be accumulated along the time dimension.
        log_G: Log discount factors of shape `(batch_size, time, num_heads)`.
           Natural logarithm of the discount/gating factors.
           These are broadcasted along the feature dimensions.

    Returns:
        Y: Accumulated tensor of shape `(batch_size, time+1, num_heads, *feature_dims)`.
           Note that the output has one more timestep than the input, with zeros at t=0.

    Note:
        - Time dimension must be a multiple of 4
        - Product of feature dimensions must be a multiple of 8
        - The batch and heads dimensions are treated as independent batch dimensions
        - Initial state support is planned but not yet implemented
        - In the RNN formulation the carried state has a fixed size, O(1) in sequence
          length; this function's output, however, stores one state per timestep

    References:
        [1] J. Buckman, C. Gelada, and S. Zhang, "Symmetric Power Transformers." 
            Manifest AI, Aug. 15, 2024.
    """
    b, n, h, *ds = X.shape
    # Flatten trailing feature dims into one so the kernel receives a 4-D tensor
    if len(X.shape) > 4:
        X = X.view(*X.shape[:3], -1)
    cum_X = discumsum_fwd(X, log_G)
    # Restore the feature dims; the time axis has grown by one (leading Y_0 = 0)
    return cum_X.view(b, n+1, h, *ds)
```

Utility Functions

Helper functions for testing and benchmarking.

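Create random sample inputs for testing power attention. Q, K, and V are drawn from a standard normal distribution and scaled by 1/sqrt(d); when gating is enabled, log_G is the log-sigmoid of standard normal samples.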

power_attention.power_full

create_inputs(b=2, t=1024, h=8, d=32, qhead_ratio=1, dtype=torch.float16, device='cuda', gating=False, chunk_size=None, deg=2, requires_grad=False, seed=42, scale=1.0)

Source code in power_attention/create_inputs.py
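```python
import math

import torch
import torch.nn.functional as F
from torch.utils._pytree import tree_map  # assumed source of tree_map; the module's imports are not shown here


def create_inputs(b=2, t=1024, h=8, d=32, qhead_ratio=1, dtype=torch.float16, device='cuda', gating=False,
                  chunk_size=None, deg=2, requires_grad=False, seed=42, scale=1.0):
    torch.manual_seed(seed)
    # Q has h * qhead_ratio heads (grouped-query attention); K and V have h heads.
    # All three are standard normal samples scaled by 1/sqrt(d).
    Q = torch.randn(size=(b, t, h * qhead_ratio, d), dtype=dtype, device=device) / math.sqrt(d)
    K = torch.randn(size=(b, t, h, d), dtype=dtype, device=device) / math.sqrt(d)
    V = torch.randn(size=(b, t, h, d), dtype=dtype, device=device) / math.sqrt(d)
    if gating:
        # logsigmoid maps to negative values, so the discount exp(log_G) lies between 0 and 1
        log_G = F.logsigmoid(torch.randn(size=(b, t, h), dtype=torch.float32, device=device))
    else:
        log_G = None
    initial_state = None
    if requires_grad:
        # Enable gradients on every input that is not None
        Q, K, V, log_G, initial_state = tree_map(
            lambda x: x.requires_grad_(True) if x is not None else None, (Q, K, V, log_G, initial_state))
    return dict(Q=Q, K=K, V=V, log_G=log_G,
                initial_state=initial_state,
                return_final_state=False,
                deg=deg, scale=scale,
                chunk_size=chunk_size)
```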