# How to support a new model

## How does Kernl optimize a model

### Overview
To optimize a model, Kernl uses the TorchDynamo JIT compiler and provides a custom backend where we replace parts of the Torch FX graph with optimized kernels.

The custom backend is defined in src/kernl/model_optimization.py:
```python
def compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    dynamo_backend_ofi(gm)
    return cuda_graphs_wrapper(gm, example_inputs, pool=pool)
```
The backend does two things:

- the first step applies the graph replacements
- the second step wraps the module with CUDA graphs

The second step eliminates most of the CPU overhead; we won't elaborate on it here and will instead focus on the first step, the graph replacements.
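For context, a TorchDynamo backend like this one is just a callable receiving the captured FX GraphModule. Kernl wires it up for you through `optimize_model`; purely for illustration, attaching a custom backend to a model could look like the following sketch (assuming PyTorch 2.x's `torch.compile` API, not Kernl's actual helper):

```python
from typing import List

import torch


def my_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # This is where a backend can inspect or rewrite the FX graph
    # before returning a callable used in place of the original code.
    print(gm.graph)
    return gm.forward


model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
compiled = torch.compile(model, backend=my_backend)
compiled(torch.randn(2, 8))  # first call triggers compilation, then runs
```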
### Inspecting the FX Graph
First, a few words about Torch FX. Torch FX is a toolkit for module-to-module transformations in PyTorch. It can trace the execution of a torch module (or a function); all the operations are then recorded into a graph of nodes. From this graph, Torch FX generates Python code matching the graph's semantics. Both the graph and the Python code are accessible from the Torch FX GraphModule, which is itself a torch module instance. Torch FX lets us manipulate FX graphs while staying at the torch module level.
With an FX GraphModule, we can inspect both the FX Graph and the generated code, with `.graph` and `.code` respectively. For better readability, one may want to use `graph_report(gm)`, which prints the FX Graph in a tabular way (similar to the Torch FX `print_tabular` method)[^1].
`graph_report(gm)` lists all the operations during the execution. Each line corresponds to a node from the FX Graph with the following information:

- `opcode` is the kind of operation
- `name` is the node name, usually the operation name
- `target` is the operation applied in this node
- `args` and `kwargs` are the arguments given to the operation
For more precise information on the kinds of nodes and their semantics, see the Torch FX Node documentation.
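The same per-node fields can also be read directly off the graph; here is a minimal sketch (not Kernl's actual `graph_report` implementation) that prints them:

```python
import torch
import torch.fx


def print_nodes(gm: torch.fx.GraphModule) -> None:
    # Each node exposes the fields described above: op, name, target, args, kwargs.
    for node in gm.graph.nodes:
        print(node.op, node.name, node.target, node.args, node.kwargs)
```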
For example, if we trace the following torch module into an FX GraphModule:
```python
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)


m = MyModule()
gm = torch.fx.symbolic_trace(m)
```
We can print the graph with the `graph_report(gm)` function:
```
opcode         name    target                   args        kwargs
-------------  ------  -----------------------  ----------  ------------------------
placeholder    x       x                        ()          {}
get_attr       param   param                    ()          {}
call_function  add     <built-in function add>  (x, param)  {}
call_module    linear  linear                   (add,)      {}
call_method    clamp   clamp                    (linear,)   {'min': 0.0, 'max': 1.0}
output         output  output                   (clamp,)    {}
----------
Used modules
----------
name    target_type
------  ------------------------------------------------
linear  Linear(in_features=4, out_features=5, bias=True)
```
We can see here every operation listed in computation order, from getting the forward function parameter to returning the output. One more useful thing `graph_report(gm)` does is print the list of torch modules used in the graph, since the node list only contains the torch module names and not the actual torch module types.
The generated code from this FX Graph is the following:
```python
def forward(self, x):
    param = self.param
    add = x + param; x = param = None
    linear = self.linear(add); add = None
    clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
    return clamp
```
More visually, we can draw the FX Graph to better see the computation. In this representation, edges represent the link between a node and the nodes in its arguments (we're discarding `args` and `kwargs` that are not previously defined nodes):
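If you want to produce such a drawing yourself, Torch FX ships a graph drawer; a small sketch (assuming `pydot` is installed and that this utility is available in your PyTorch version) could be:

```python
from torch.fx.passes.graph_drawer import FxGraphDrawer

# Render the traced GraphModule to an SVG file for visual inspection.
drawer = FxGraphDrawer(gm, "my_module")
drawer.get_dot_graph().write_svg("my_module_fx_graph.svg")
```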
### Replace part of the FX Graph
An FX Graph can be modified directly, but we'll instead use subgraph rewriting.

To rewrite the Torch FX Graph, Kernl uses the `replace_pattern()` function defined in src/kernl/utils/extended_matcher.py. It's the same `replace_pattern()` function as in Torch FX, but with some bugfixes (which should be integrated into PyTorch in the future).
The function takes a graph `gm` and two callables, `pattern` and `replacement`, which can each be either a torch module or a function. It converts `pattern` and `replacement` to FX Graphs and tries to replace the subgraphs of `gm` matching `pattern` with `replacement`.
For example, given this two-layer perceptron model, we'd like to replace the first layer's activation from `Tanh` to `ReLU`.
```python
class FeedForward(torch.nn.Module):
    def __init__(self, input_size, hidden_size):
        super(FeedForward, self).__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.tanh = torch.nn.Tanh()
        self.fc2 = torch.nn.Linear(hidden_size, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        hidden = self.fc1(x)
        tanh = self.tanh(hidden)
        output = self.fc2(tanh)
        output = self.sigmoid(output)
        return output


m = FeedForward(5, 10)
```
By tracing this module, we can print the FX Graph:
```
opcode       name     target   args        kwargs
-----------  -------  -------  ----------  --------
placeholder  x        x        ()          {}
call_module  fc1      fc1      (x,)        {}
call_module  tanh     tanh     (fc1,)      {}
call_module  fc2      fc2      (tanh,)     {}
call_module  sigmoid  sigmoid  (fc2,)      {}
output       output   output   (sigmoid,)  {}
----------
Used modules
----------
name     target_type
-------  -------------------------------------------------
fc1      Linear(in_features=5, out_features=10, bias=True)
tanh     Tanh()
fc2      Linear(in_features=10, out_features=1, bias=True)
sigmoid  Sigmoid()
```
We then write the pattern we want to match as a torch module:

```python
class Pattern(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)
        self.activation = torch.nn.Tanh()

    def forward(self, v):
        return self.activation(self.linear(v))
```
The corresponding FX Graph is the following:
```
opcode       name        target      args           kwargs
-----------  ----------  ----------  -------------  --------
placeholder  v           v           ()             {}
call_module  linear      linear      (v,)           {}
call_module  activation  activation  (linear,)      {}
output       output      output      (activation,)  {}
----------
Used modules
----------
name        target_type
----------  ------------------------------------------------
linear      Linear(in_features=1, out_features=1, bias=True)
activation  Tanh()
```
We don't need the node names to be the same as the ones in the graph we want to match; what matters is that we match the same node pattern. In our example, the node names differ (`fc1` and `tanh` in the graph, `linear` and `activation` in the pattern subgraph), but the modules called are identical (`Linear` and `Tanh`).
We have our pattern subgraph; we may now write our replacement subgraph with the ReLU activation function and display its FX Graph.
```python
class Replacement(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, v):
        return self.relu(self.linear(v))
```
```
opcode       name    target  args       kwargs
-----------  ------  ------  ---------  --------
placeholder  v       v       ()         {}
call_module  linear  linear  (v,)       {}
call_module  relu    relu    (linear,)  {}
output       output  output  (relu,)    {}
----------
Used modules
----------
name    target_type
------  ------------------------------------------------
linear  Linear(in_features=1, out_features=1, bias=True)
relu    ReLU()
```
Unlike the matching pattern, we must be a bit careful with the node names in the replacement pattern. If we want to reuse the nodes matched in the graph, we must use the same node names as in the pattern; otherwise, a new node is created in the graph. In our example, the `linear` and `v` nodes are kept from the nodes matched in the original graph, but the `relu` node is added to the graph.
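For reference, the whole rewrite could be driven like this (a minimal sketch, assuming the `FeedForward`, `Pattern` and `Replacement` classes above, and relying on `replace_pattern` accepting module instances as described earlier):

```python
import torch

from kernl.utils.extended_matcher import replace_pattern

m = FeedForward(5, 10)
gm = torch.fx.symbolic_trace(m)

# Replace every Linear -> Tanh subgraph with Linear -> ReLU.
replace_pattern(gm, Pattern(), Replacement())

print(gm.code)
```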
After applying the replacement, as in the sketch above, we can look at the resulting FX Graph:
```
opcode       name     target   args        kwargs
-----------  -------  -------  ----------  --------
placeholder  x        x        ()          {}
call_module  linear   fc1      (x,)        {}
call_module  relu     relu     (linear,)   {}
call_module  fc2      fc2      (relu,)     {}
call_module  sigmoid  sigmoid  (fc2,)      {}
output       output   output   (sigmoid,)  {}
----------
Used modules
----------
name     target_type
-------  -------------------------------------------------
fc1      Linear(in_features=5, out_features=10, bias=True)
relu     ReLU()
fc2      Linear(in_features=10, out_features=1, bias=True)
sigmoid  Sigmoid()
```
The resulting graph has switched from the `Tanh` activation to `ReLU`; the `fc1` node has been kept untouched.

When we don't need to match a call to a torch submodule, it's easier to write the pattern and the replacement as functions, as we'll see in our example with BERT attention.
There are some limitations to subgraph rewriting. When we use a function not covered by Torch FX, we have to use the Torch `wrap` function so that the call appears in the FX Graph without being traced through.
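As a small illustration of that mechanism (a sketch with a made-up `my_custom_op`, unrelated to Kernl's kernels):

```python
import torch
import torch.fx


def my_custom_op(x):
    # Imagine this calls code that symbolic tracing cannot go through.
    return x * 2


# Record calls to my_custom_op as a single node instead of tracing into it.
torch.fx.wrap("my_custom_op")


class M(torch.nn.Module):
    def forward(self, x):
        return my_custom_op(x) + 1


gm = torch.fx.symbolic_trace(M())
print(gm.graph)  # my_custom_op shows up as one call_function node
```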
## Example: replacing BERT Attention
In this example, we'll see how to replace the attention part of a BERT model with Kernl's optimized attention kernel.
### Understanding Attention
First, we need to look at how attention works; the original paper "Attention Is All You Need" is a good starting point. More specifically, we'll focus on the part of the paper where the attention function is defined:
> **Attention Is All You Need**
>
> An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.
>
> (...)
>
> We call our particular attention "Scaled Dot-Product Attention". The input consists of queries and keys of dimension \(d_k\), and values of dimension \(d_v\). We compute the dot products of the query with all keys, divide each by \(\sqrt{d_k}\), and apply a softmax function to obtain the weights on the values. In practice, we compute the attention function on a set of queries simultaneously, packed together into a matrix \(Q\). The keys and values are also packed together into matrices \(K\) and \(V\). We compute the matrix of outputs as:
>
> \[
> \operatorname{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
> \]
This function can be represented as a computation graph where the attention mask is added in the process:
This graph representation will be useful, as it is this graph that we'll try to replace to optimize a BERT model.
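For reference, the same computation written naively in PyTorch (a sketch ignoring multi-head and batching details, not the kernel Kernl actually uses) is:

```python
import math

import torch


def scaled_dot_product_attention(q, k, v, attention_mask=None):
    # q, k: (..., seq_len, d_k); v: (..., seq_len, d_v)
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if attention_mask is not None:
        # Additive mask: large negative values on positions to ignore.
        scores = scores + attention_mask
    weights = torch.nn.functional.softmax(scores, dim=-1)
    return torch.matmul(weights, v)
```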
### Find the Attention graph pattern
For our example, we'll replace the attention part of the "bert-base-uncased" pre-trained model from Hugging Face Transformers. If we look at the BERT implementation, we find the attention function as a torch module:
```python
class BertSelfAttention(nn.Module):

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        ...
        # Take the dot product between "query" and "key" to get the raw attention scores. # (1)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        ...
        attention_scores = attention_scores / math.sqrt(self.attention_head_size) # (2)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities. # (3)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        ...
        context_layer = torch.matmul(attention_probs, value_layer) # (4)
        ...
```
1. \(QK^T\)
2. \(\frac{QK^T}{\sqrt{d_k}}\)
3. \(\operatorname{softmax}(\frac{QK^T}{\sqrt{d_k}})\)
4. \(\operatorname{softmax}(\frac{QK^T}{\sqrt{d_k}})V\)
For reference, here is the full `forward` method, with the same four steps annotated:

```python
class BertSelfAttention(nn.Module):

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores. # (1)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size) # (2)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities. # (3)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer) # (4)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
```
1. \(QK^T\)
2. \(\frac{QK^T}{\sqrt{d_k}}\)
3. \(\operatorname{softmax}(\frac{QK^T}{\sqrt{d_k}})\)
4. \(\operatorname{softmax}(\frac{QK^T}{\sqrt{d_k}})V\)
We see that the Hugging Face implementation is close to the definition from the paper; we now want to find the attention pattern in this model.
To begin, we'll write a short script running the model with a dummy input:
```python
import torch
from transformers import AutoModel

from kernl.model_optimization import optimize_model

model = AutoModel.from_pretrained(pretrained_model_name_or_path="bert-base-uncased").eval().cuda()
optimize_model(model)

shape = (1, 128)

with torch.inference_mode(), torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True):
    inputs = {
        "input_ids": torch.randint(2, 10000, shape, device="cuda", dtype=torch.long),
        "attention_mask": torch.ones(shape, device="cuda", dtype=torch.long),
    }
    output = model(**inputs)
```
If we run `graph_report(gm)` and `gm.code` in the `dynamo_backend_ofi(gm)` function of the Kernl library, we can print the FX Graph and the Python code from the model computation. For our example, we'll keep the `normalize_operators` and `remove_dropout` functions as they simplify the model's graph a bit.
```python
def dynamo_backend_ofi(gm: torch.fx.GraphModule, assume_causal=False):
    normalize_operators(gm)
    remove_dropout(gm)

    print(graph_report(gm))
    print(gm.code)
    return gm
```
Below is the resulting output; we only show the beginning of the graph, up to the first attention layer.
Part of the FX Graph of BERT model
```
opcode name target args kwargs
------------- ------------------------------------------------- --------------------------------------------------------- --------------------------------------------------------------------------------------- ------------------------
placeholder input_ids input_ids () {}
placeholder attention_mask attention_mask () {}
get_attr self_embeddings_token_type_ids self_embeddings_token_type_ids () {}
call_function getitem <built-in function getitem> (self_embeddings_token_type_ids, (slice(None, None, None), slice(None, 128, None))) {}
call_method expand expand (getitem, 1, 128) {}
call_function getitem_1 <built-in function getitem> (attention_mask, (slice(None, None, None), None, None, slice(None, None, None))) {}
call_method to to (getitem_1,) {'dtype': torch.float32}
call_function sub <built-in function sub> (1.0, to) {}
call_function mul <built-in function mul> (sub, -3.4028234663852886e+38) {}
get_attr self_embeddings_position_ids self_embeddings_position_ids () {}
call_function getitem_2 <built-in function getitem> (self_embeddings_position_ids, (slice(None, None, None), slice(0, 128, None))) {}
call_module self_embeddings_word_embeddings self_embeddings_word_embeddings (input_ids,) {}
call_module self_embeddings_token_type_embeddings self_embeddings_token_type_embeddings (expand,) {}
call_function add <built-in function add> (self_embeddings_word_embeddings, self_embeddings_token_type_embeddings) {}
call_module self_embeddings_position_embeddings self_embeddings_position_embeddings (getitem_2,) {}
call_function add_37 <built-in method add of type object at 0x7f065046e4e0> (add, self_embeddings_position_embeddings) {}
call_module self_embeddings_layer_norm self_embeddings_LayerNorm (add_37,) {}
call_module self_encoder_layer_0_attention_self_query self_encoder_layer_0_attention_self_query (self_embeddings_layer_norm,) {}
call_module self_encoder_layer_0_attention_self_key self_encoder_layer_0_attention_self_key (self_embeddings_layer_norm,) {}
call_method view view (self_encoder_layer_0_attention_self_key, (1, 128, 12, 64)) {}
call_method permute permute (view, 0, 2, 1, 3) {}
call_module self_encoder_layer_0_attention_self_value self_encoder_layer_0_attention_self_value (self_embeddings_layer_norm,) {}
call_method view_1 view (self_encoder_layer_0_attention_self_value, (1, 128, 12, 64)) {}
call_method permute_1 permute (view_1, 0, 2, 1, 3) {}
call_method view_2 view (self_encoder_layer_0_attention_self_query, (1, 128, 12, 64)) {}
call_method permute_2 permute (view_2, 0, 2, 1, 3) {}
call_method transpose transpose (permute, -1, -2) {}
call_function matmul <built-in method matmul of type object at 0x7f065046e4e0> (permute_2, transpose) {}
call_function truediv <built-in function truediv> (matmul, 8.0) {}
call_function add_1 <built-in function add> (truediv, mul) {}
call_function softmax <function softmax at 0x7f05eca5f790> (add_1,) {'dim': -1}
call_function matmul_1 <built-in method matmul of type object at 0x7f065046e4e0> (softmax, permute_1) {}
call_method permute_3 permute (matmul_1, 0, 2, 1, 3) {}
call_method contiguous contiguous (permute_3,) {}
call_method view_3 view (contiguous, (1, 128, 768)) {}
call_module self_encoder_layer_0_attention_output_dense self_encoder_layer_0_attention_output_dense (view_3,) {}
```
```python
def forward(self, input_ids : torch.Tensor, attention_mask : torch.Tensor):
    self_embeddings_token_type_ids = self.self_embeddings_token_type_ids
    getitem = self_embeddings_token_type_ids[(slice(None, None, None), slice(None, 128, None))]; self_embeddings_token_type_ids = None
    expand = getitem.expand(1, 128); getitem = None
    getitem_1 = attention_mask[(slice(None, None, None), None, None, slice(None, None, None))]; attention_mask = None
    to = getitem_1.to(dtype = torch.float32); getitem_1 = None
    sub = 1.0 - to; to = None
    mul = sub * -3.4028234663852886e+38; sub = None
    self_embeddings_position_ids = self.self_embeddings_position_ids
    getitem_2 = self_embeddings_position_ids[(slice(None, None, None), slice(0, 128, None))]; self_embeddings_position_ids = None
    self_embeddings_word_embeddings = self.self_embeddings_word_embeddings(input_ids); input_ids = None
    self_embeddings_token_type_embeddings = self.self_embeddings_token_type_embeddings(expand); expand = None
    add = self_embeddings_word_embeddings + self_embeddings_token_type_embeddings; self_embeddings_word_embeddings = self_embeddings_token_type_embeddings = None
    self_embeddings_position_embeddings = self.self_embeddings_position_embeddings(getitem_2); getitem_2 = None
    add_37 = torch.add(add, self_embeddings_position_embeddings); add = self_embeddings_position_embeddings = None
    self_embeddings_layer_norm = self.self_embeddings_LayerNorm(add_37); add_37 = None
    self_encoder_layer_0_attention_self_query = self.self_encoder_layer_0_attention_self_query(self_embeddings_layer_norm)
    self_encoder_layer_0_attention_self_key = self.self_encoder_layer_0_attention_self_key(self_embeddings_layer_norm)
    view = self_encoder_layer_0_attention_self_key.view((1, 128, 12, 64)); self_encoder_layer_0_attention_self_key = None
    permute = view.permute(0, 2, 1, 3); view = None
    self_encoder_layer_0_attention_self_value = self.self_encoder_layer_0_attention_self_value(self_embeddings_layer_norm)
    view_1 = self_encoder_layer_0_attention_self_value.view((1, 128, 12, 64)); self_encoder_layer_0_attention_self_value = None
    permute_1 = view_1.permute(0, 2, 1, 3); view_1 = None
    view_2 = self_encoder_layer_0_attention_self_query.view((1, 128, 12, 64)); self_encoder_layer_0_attention_self_query = None
    permute_2 = view_2.permute(0, 2, 1, 3); view_2 = None
    transpose = permute.transpose(-1, -2); permute = None
    matmul = torch.matmul(permute_2, transpose); permute_2 = transpose = None
    truediv = matmul / 8.0; matmul = None
    add_1 = truediv + mul; truediv = None
    softmax = torch.nn.functional.softmax(add_1, dim = -1); add_1 = None
    matmul_1 = torch.matmul(softmax, permute_1); softmax = permute_1 = None
    permute_3 = matmul_1.permute(0, 2, 1, 3); matmul_1 = None
    contiguous = permute_3.contiguous(); permute_3 = None
    view_3 = contiguous.view((1, 128, 768)); contiguous = None
    self_encoder_layer_0_attention_output_dense = self.self_encoder_layer_0_attention_output_dense(view_3); view_3 = None
```
If we draw the FX Graph, we can identify in yellow the attention part:
Now, we look into the code to find which lines correspond to these nodes in the FX Graph.
```python
transpose = permute.transpose(-1, -2); permute = None
matmul = torch.matmul(permute_2, transpose); permute_2 = transpose = None
truediv = matmul / 8.0; matmul = None
add_1 = truediv + mul; truediv = None
softmax = torch.nn.functional.softmax(add_1, dim = -1); add_1 = None
matmul_1 = torch.matmul(softmax, permute_1); softmax = permute_1 = None
```
We now have our pattern to catch in the model. To make the pattern easier to read, we rename the following nodes:

- `permute` → `k`
- `permute_1` → `v`
- `permute_2` → `q`
- `mul` → `attention_mask`

and can write our pattern function:
```python
def pattern(q, k, attention_mask, v):
    transpose = k.transpose(-1, -2)
    matmul = torch.matmul(q, transpose)
    truediv = matmul / 8.0
    add_1 = truediv + attention_mask
    softmax = torch.nn.functional.softmax(add_1, dim=-1)
    matmul_1 = torch.matmul(softmax, v)
    return matmul_1
```
### Replace the Attention part
We now need to write our replace function that calls the optimized kernel. From the signature of `attention_forward` (imported from `kernl.implementations.attention`), we can see that the optimized attention kernel needs, in addition to `q`, `k`, `v` and `attention_mask`, an `output` tensor and an `sm_scale` parameter.
```python
def attention_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    output: torch.Tensor,
    sm_scale: float,
    is_causal: bool = False,
    attention_mask: Optional[torch.Tensor] = None,
):
```
The `output` parameter is simply the resulting tensor; we need to provide the tensor beforehand. The `sm_scale` parameter corresponds to the scale factor applied to the query-key compatibility function in the attention function. Defined as \(\frac{1}{\sqrt{d_k}}\), it corresponds to the `truediv` node in the FX Graph. In this case `sm_scale` is \(\frac{1}{8.0}\).
We can now write our replacement part by calling the optimized kernel:
```python
torch.fx.wrap("attention_forward")


def replace(q, k, attention_mask, v):
    output = torch.empty_like(q)
    output = attention_forward(q, k, v, output, 1 / 8.0, is_causal=False, attention_mask=attention_mask)
    return output
```
To wrap it up, we can define our `replace_attention` function:
```python
import torch

from kernl.implementations.attention import attention_forward
from kernl.utils.extended_matcher import replace_pattern

torch.fx.wrap("attention_forward")


def replace_attention(gm: torch.fx.GraphModule):
    def pattern(q, k, attention_mask, v):
        transpose = k.transpose(-1, -2)
        matmul = torch.matmul(q, transpose)
        truediv = matmul / 8.0
        add_1 = truediv + attention_mask
        softmax = torch.nn.functional.softmax(add_1, dim=-1)
        matmul_1 = torch.matmul(softmax, v)
        return matmul_1

    def replace(q, k, attention_mask, v):
        output = torch.empty_like(q)
        output = attention_forward(q, k, v, output, 1 / 8.0, is_causal=False, attention_mask=attention_mask)
        return output

    replace_pattern(gm, pattern, replace)
```
And use it in the TorchDynamo backend:

```python
def dynamo_backend_ofi(gm: torch.fx.GraphModule, assume_causal=False):
    normalize_operators(gm)
    remove_dropout(gm)
    replace_attention(gm)

    print(graph_report(gm))
    print(gm.code)
    return gm
```
If we print the FX Graph again after the graph replacement, we see that all the previous nodes from the attention part are now replaced by a call to the optimized kernel.
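Before relying on a new replacement, it's also worth checking that the optimized model still produces the same outputs as the baseline. A minimal sanity check could look like this (a sketch; the tolerances are illustrative, since the fp16 kernel won't match the baseline bit for bit):

```python
import torch
from transformers import AutoModel

from kernl.model_optimization import optimize_model

baseline = AutoModel.from_pretrained("bert-base-uncased").eval().cuda()
optimized = AutoModel.from_pretrained("bert-base-uncased").eval().cuda()
optimize_model(optimized)

shape = (1, 128)
inputs = {
    "input_ids": torch.randint(2, 10000, shape, device="cuda", dtype=torch.long),
    "attention_mask": torch.ones(shape, device="cuda", dtype=torch.long),
}

with torch.inference_mode(), torch.cuda.amp.autocast(enabled=True, dtype=torch.float16):
    expected = baseline(**inputs).last_hidden_state
    actual = optimized(**inputs).last_hidden_state

# Illustrative tolerances only: mixed-precision kernels introduce small numerical differences.
assert torch.allclose(expected.float(), actual.float(), atol=1e-1, rtol=1e-1)
```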
[^1]: Additionally, we can enable TorchDynamo's tracing with `torch._dynamo.config.log_level = logging.DEBUG` to display the compiled graph. Enabling `torch._dynamo.config.output_graph_code` displays the graph's code instead. See TorchDynamo's configuration for details.