attention
Attention
Bases: torch.autograd.Function
forward(ctx, q, k, v, output, sm_scale, is_causal, attention_mask=None)
staticmethod
Computes attention. FP32 input and output are not supported by the Triton kernel (https://github.com/openai/triton/issues/674). This is not an issue in practice because the function is decorated with @custom_fwd(cast_inputs=torch.float16), so inputs are cast to float16 before the kernel runs.
@param ctx: context for autograd
@param q: Query matrix, size (batch, head_size, m_size, dhead)
@param k: Key matrix, size (batch, head_size, n_size, dhead)
@param v: Value matrix, size (batch, head_size, n_size, dhead)
@param output: Output matrix, size (batch, head_size, m_size, dhead)
@param sm_scale: softmax scaling factor applied to Q•K^T just before the softmax
@param is_causal: whether to apply causal (autoregressive decoder) masking
@param attention_mask: attention mask, broadcastable to (batch, head_size, m_size, n_size)
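Because Attention subclasses torch.autograd.Function, it is invoked through Attention.apply, which supplies ctx automatically. The sketch below is a minimal usage illustration, assuming Attention is importable from this module; the tensor sizes and the 1/sqrt(dhead) scaling are common choices, not values mandated by the source.

```python
import math
import torch
# Attention is assumed to be imported from the module documented on this page.

# Hypothetical sizes; only the (batch, head_size, seq, dhead) layout comes from
# the documented signature.
batch, heads, m_size, n_size, dhead = 2, 8, 128, 128, 64
q = torch.randn(batch, heads, m_size, dhead, device="cuda", dtype=torch.float16)
k = torch.randn(batch, heads, n_size, dhead, device="cuda", dtype=torch.float16)
v = torch.randn(batch, heads, n_size, dhead, device="cuda", dtype=torch.float16)
output = torch.empty_like(q)

# 1/sqrt(dhead) is the usual softmax scaling; the kernel takes it explicitly.
sm_scale = 1.0 / math.sqrt(dhead)

# autograd.Function subclasses are called through .apply, which injects ctx.
# Thanks to @custom_fwd(cast_inputs=torch.float16), float32 inputs would be
# cast to float16 before the Triton kernel runs.
Attention.apply(q, k, v, output, sm_scale, True, None)  # is_causal=True, no extra mask
```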
attention_reference(q, k, v, output, sm_scale, is_causal, attention_mask)
Reference implementation for attention.
@param q: Query matrix, size (batch, heads, m_size, BLOCK_DHEAD)
@param k: Key matrix, size (batch, heads, n_size, BLOCK_DHEAD)
@param v: Value matrix, size (batch, heads, n_size, BLOCK_DHEAD)
@param output: Output matrix, size (batch, heads, m_size, BLOCK_DHEAD)
@param sm_scale: softmax scaling factor applied to Q•K^T just before the softmax
@param is_causal: whether to apply causal attention
@param attention_mask: attention mask, broadcastable to (batch, heads, m_size, n_size). Warning: this is not a binary keep/drop mask; it is added directly to Q•K^T before the softmax.
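For context, a minimal PyTorch sketch with the semantics described above (additive mask, optional causal masking, softmax computed in float32) could look like the following. It is an illustration of the documented behavior, not the module's actual attention_reference.

```python
import torch

def attention_reference_sketch(q, k, v, output, sm_scale, is_causal, attention_mask=None):
    # Scaled attention scores, shape (batch, heads, m_size, n_size).
    scores = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
    if attention_mask is not None:
        # Additive mask (e.g. 0 for kept positions, large negative for masked ones),
        # added directly to Q•K^T as described in the docstring.
        scores = scores + attention_mask
    if is_causal:
        # Simple causal mask; assumes query and key positions are aligned (m_size == n_size).
        m_size, n_size = scores.shape[-2], scores.shape[-1]
        causal = torch.tril(torch.ones(m_size, n_size, dtype=torch.bool, device=scores.device))
        scores = scores.masked_fill(~causal, float("-inf"))
    probs = torch.softmax(scores.float(), dim=-1).to(v.dtype)
    output.copy_(torch.matmul(probs, v))
    return output
```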
closest_power_of_2(n, min_range=16, max_range=128)
Return the closest power of 2 for n, in the 16-128 range.
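One plausible reading of this helper, shown below as a sketch: round n to the nearest power of two, then clamp the result to [min_range, max_range]. The real helper may behave differently (for example, it could return several candidate block sizes rather than a single value).

```python
def closest_power_of_2_sketch(n, min_range=16, max_range=128):
    # Powers of two on either side of n.
    lower = 1 << (max(n, 1).bit_length() - 1)
    upper = lower * 2
    closest = lower if (n - lower) <= (upper - n) else upper
    # Clamp to the supported block-size range.
    return min(max(closest, min_range), max_range)

assert closest_power_of_2_sketch(20) == 16
assert closest_power_of_2_sketch(100) == 128
assert closest_power_of_2_sketch(1000) == 128
```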
prune(configs, named_args)
Remove block shapes unlikely to provide an optimal speedup.
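The (configs, named_args) signature matches the shape of an early_config_prune hook in Triton's autotuner, which receives the candidate configs and the kernel's named arguments. A hedged sketch of that pattern is shown below; the block-size keys and argument names are hypothetical, used only to illustrate the idea of dropping tiles larger than the problem.

```python
def prune_sketch(configs, named_args):
    # Drop autotuning configs whose block sizes exceed the problem dimensions:
    # oversized tiles waste work and are unlikely to be the fastest choice.
    # "BLOCK_M_SIZE"/"BLOCK_N_SIZE" and "m_size"/"n_size" are illustrative names only.
    kept = []
    for config in configs:
        block_m = config.kwargs.get("BLOCK_M_SIZE", 0)
        block_n = config.kwargs.get("BLOCK_N_SIZE", 0)
        if block_m <= named_args["m_size"] and block_n <= named_args["n_size"]:
            kept.append(config)
    return kept
```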