API Reference
This page contains the API reference for the Power Attention library.
High-Level Interface
The main entry point for using symmetric power attention in your models.
Main implementation of symmetric power attention, which generalizes linear transformers using symmetric power embeddings. Provides O(n) complexity for long sequences through an efficient RNN formulation.
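A minimal usage sketch is shown below. The entry-point name `power_full`, its keyword arguments, and the tensor layout are assumptions for illustration; check the installed package for the exact signature.

```python
import torch
from power_attention import power_full  # assumed entry point; verify against your installed version

# Assumed layout: (batch, time, heads, head_dim)
q = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)

# deg is the (even) power applied to the attention weights; the kernel processes
# the sequence in chunks to achieve O(n) cost on long contexts.
out = power_full(q, k, v, deg=2)
```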
Low-Level Components
These components implement the linear transformer with symmetric power embeddings. They're exposed for advanced usage and custom implementations.
Core Attention
Computes attention scores using symmetric power embeddings, equivalent to raising attention weights to an even power.
power_attention._attention
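For intuition, the quadratic-cost computation this kernel corresponds to can be sketched in plain PyTorch. This is a conceptual reference with an assumed (batch, time, heads, head_dim) layout, not the library's interface or exact numerics:

```python
import torch

def power_attention_reference(q, k, v, deg=2):
    """Naive O(n^2) reference: causal attention with weights raised to an even power.

    q, k, v: (batch, time, heads, head_dim). Conceptual sketch only; the real
    kernel works through symmetric power embeddings and chunked computation.
    """
    scores = torch.einsum('bthd,bshd->bhts', q, k)                 # raw dot products
    causal = torch.tril(torch.ones(scores.shape[-2:], dtype=torch.bool, device=q.device))
    weights = torch.where(causal, scores ** deg, torch.zeros_like(scores))  # even power + causal mask
    weights = weights / weights.sum(dim=-1, keepdim=True).clamp_min(1e-6)   # normalize per query
    return torch.einsum('bhts,bshd->bthd', weights, v)
```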
State Management
Functions for managing the RNN state representation, which achieves massive memory savings through symmetric tensors.
State Expansion
Computes expanded state vectors using symmetric power embeddings, achieving up to 96% memory reduction for deg=4.
power_attention._update_state
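The saving comes from storing only the unique entries of the degree-`deg` symmetric tensor power of each key, which has \(\binom{d + deg - 1}{deg}\) components instead of \(d^{deg}\). A quick check of the figure quoted above, assuming a head dimension of 64:

```python
from math import comb

d, deg = 64, 4                       # head dimension is an assumption; deg=4 as cited above
full = d ** deg                      # entries of the full degree-4 tensor power
symmetric = comb(d + deg - 1, deg)   # unique entries of the symmetric tensor power
print(1 - symmetric / full)          # ~0.95, consistent with the "up to 96%" figure
```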
Query-State Interaction
Computes how queries interact with the compressed state representation from previous chunks.
power_attention._query_state
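Conceptually, each query is passed through the same symmetric power embedding and contracted against the accumulated state. A sketch of that contraction, with illustrative layouts rather than the library's kernel:

```python
import torch

def query_state_reference(phi_q, state):
    """Contract embedded queries with the compressed state from previous chunks.

    Assumed layouts (for illustration only):
      phi_q: (batch, time, heads, D)      -- queries after the symmetric power embedding
      state: (batch, heads, D, head_dim)  -- accumulated sum of phi(k) outer-product v
    Returns (batch, time, heads, head_dim).
    """
    return torch.einsum('bthD,bhDd->bthd', phi_q, state)
```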
Recurrent Processing
Implements the RNN state update equations for O(n) processing of long sequences.
power_attention._discumsum
discumsum(X: torch.Tensor, log_G: torch.Tensor) -> torch.Tensor
Compute discounted cumulative sum for recurrent processing.
This function implements the discounted cumulative sum operation from [1]. It is a key component in transforming symmetric power attention into a linear-cost RNN, allowing efficient processing of long sequences through chunked computation.
The operation implements the state update equation of the RNN:

\[ Y_t = \exp(\log G_t)\, Y_{t-1} + X_t \]

with initial condition \(Y_0 = 0\).
In the context of symmetric power attention:
- \(X_t\) contains expanded state vectors from the current chunk
- \(Y_t\) accumulates state information across chunks
- \(\exp(\log G_t)\) controls how much past information influences the current computation
- The +1 in the output time dimension allows for proper causality in the RNN
This formulation enables O(n) complexity instead of O(n²) for long sequences, while maintaining the expressivity of power attention through the expanded state representation.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
X | Tensor | Input tensor of shape | required |
log_G | Tensor | Log discount factors of shape | required |
Returns:

Name | Type | Description |
---|---|---|
Y | Tensor | Accumulated tensor of shape |
Note
- Time dimension must be a multiple of 4
- Product of feature dimensions must be a multiple of 8
- The batch and heads dimensions are treated as independent batch dimensions
- Initial state support is planned but not yet implemented
- The RNN formulation maintains O(1) memory per layer regardless of sequence length
References
[1] J. Buckman, C. Gelada, and S. Zhang, "Symmetric Power Transformers." Manifest AI, Aug. 15, 2024.
Source code in power_attention/_discumsum/cuda.py
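For intuition, the semantics described above can be written as a short PyTorch loop. This is a readability-oriented sketch of the recurrence, not the CUDA kernel; it assumes X has shape (batch, time, heads, *feature_dims) and log_G has shape (batch, time, heads), and it ignores the alignment requirements listed in the note.

```python
import torch

def discumsum_reference(X, log_G):
    """Reference semantics: Y_t = exp(log_G_t) * Y_{t-1} + X_t with Y_0 = 0.

    The output gains one leading time step (the "+1" mentioned above), so it has
    shape (batch, time + 1, heads, *feature_dims).
    """
    out = [torch.zeros_like(X[:, 0])]                 # Y_0 = 0
    for t in range(X.shape[1]):
        gate = torch.exp(log_G[:, t]).to(X.dtype)     # discount factor for step t
        while gate.dim() < out[-1].dim():             # broadcast over feature dims
            gate = gate.unsqueeze(-1)
        out.append(gate * out[-1] + X[:, t])          # state update
    return torch.stack(out, dim=1)
```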
Utility Functions
Helper functions for testing and benchmarking.
Create sample inputs for testing power attention, with appropriate initialization for stable training.
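A sketch of what such a helper might produce; the name `make_sample_inputs`, the shapes, and the gating scheme are illustrative assumptions, not the library's actual utility:

```python
import torch

def make_sample_inputs(batch=2, time=1024, heads=8, head_dim=64,
                       device='cuda', dtype=torch.float16):
    """Illustrative factory for power-attention test/benchmark inputs."""
    shape = (batch, time, heads, head_dim)
    # Scale queries and keys so dot products stay O(1); with even-power attention
    # weights this keeps magnitudes numerically stable.
    q = torch.randn(shape, device=device, dtype=dtype) / head_dim ** 0.5
    k = torch.randn(shape, device=device, dtype=dtype) / head_dim ** 0.5
    v = torch.randn(shape, device=device, dtype=dtype)
    # Log discount factors in (-inf, 0]; log-sigmoid of a random pre-gate is one choice.
    log_G = torch.nn.functional.logsigmoid(torch.randn(batch, time, heads, device=device))
    return q, k, v, log_G
```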