linear_layer
LinearLayer
Bases: torch.autograd.Function
forward(ctx, x, weight, bias, activation, act_inputs)
staticmethod
Compute e = activation(x @ weight + bias).
This wrapper launches the kernel_fma Triton kernel.
:param ctx: context for autograd
:param x: input tensor
:param weight: weight matrix
:param bias: an optional bias tensor
:param activation: Activation name. Needs to be a Triton kernel.
:param act_inputs: an optional tensor to save the activation inputs (for backward)
:return: result tensor
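A minimal usage sketch, assuming the function is invoked through torch.autograd.Function.apply (which injects ctx) and that passing None for activation and allocating act_inputs as an (M, N) buffer are acceptable; the module import path and the float16/CUDA setup are assumptions, not part of the documented API:

```python
import torch
# from <package>.linear_layer import LinearLayer  # import path is an assumption

# Shapes follow the docstrings: x is (M, K), weight is (K, N), bias is (N,).
M, K, N = 128, 256, 512
x = torch.randn(M, K, device="cuda", dtype=torch.float16)
weight = torch.randn(K, N, device="cuda", dtype=torch.float16)
bias = torch.randn(N, device="cuda", dtype=torch.float16)

# Optional (M, N) buffer to save the activation inputs for the backward pass.
act_inputs = torch.empty(M, N, device="cuda", dtype=torch.float16)

# `activation` must name a Triton-kernel activation per the docstring;
# None is assumed here to mean "no activation".
out = LinearLayer.apply(x, weight, bias, None, act_inputs)
```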
kernel_fma(C, ACT_INPUTS, A, B, bias, M, N, K, CACHE_KEY_M, CACHE_KEY_N, CACHE_KEY_K, output_m_stride, output_n_stride, act_inputs_m_stride, act_inputs_n_stride, a_m_stride, a_k_stride, b_n_stride, b_k_stride, BLOCK_M, GROUP_M, BLOCK_N, BLOCK_K, SPLIT_K, K_LOAD_MASK_NEEDED, HAS_BIAS, SHOULD_SAVE_ACT_INPUTS, ACTIVATION)
Kernel for computing Out = activation(A x B + bias), where A is the input, B the weight, and bias the bias vector
- Input has shape (M, K)
- Weight has shape (K, N)
- Bias has shape (N,)
- Output has shape (M, N)
- ActInputs (optional) has shape (M, N)
'ActInputs' optionally saves the A x B + bias pre-activation intermediate for the backward pass
This kernel consolidates (reduces) over the K dimension
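As a reference for what the kernel computes (not the Triton implementation itself), a plain PyTorch sketch of the same math under the shapes listed above, with the optional pre-activation save; the function and argument names are illustrative only:

```python
import torch

def fma_reference(a, b, bias=None, activation=None, save_act_inputs=False):
    """Reference for Out = activation(A x B + bias).
    Shapes: A (M, K), B (K, N), bias (N,), Out and ActInputs (M, N)."""
    act_inputs = a @ b                      # consolidate (reduce) over K
    if bias is not None:
        act_inputs = act_inputs + bias      # bias broadcast over the M rows
    out = activation(act_inputs) if activation is not None else act_inputs
    # ActInputs is the pre-activation intermediate kept for the backward pass
    return (out, act_inputs) if save_act_inputs else (out, None)
```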