extended_matcher
SubgraphMatcher
__init__(pattern, match_output=False, match_placeholder=False, remove_overlapping_matches=True)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pattern | Graph | The targeted matching pattern, represented as an fx.Graph. | required |
match_output | bool | If True, the output node in the pattern graph is treated as part of the targeted pattern. If False, the output node is ignored during matching. | False |
match_placeholder | bool | If True, placeholder nodes in the pattern graph are treated as part of the targeted pattern. If False, placeholder nodes are used as wildcards. | False |
remove_overlapping_matches | bool | If True, in the case of overlapping matches, only the first match is returned. | True |
match(graph)
Returns:
Type | Description |
---|---|
List[InternalMatch] | The matched subgraphs. Each returned subgraph is fully self-contained, meaning that its nodes (except placeholders and nodes returned by the output) can only be consumed by nodes within the matched subgraph. |
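A minimal usage sketch of the constructor and `match` follows. It assumes that the `SubgraphMatcher` in `torch.fx.passes.utils.matcher_utils` behaves like the class documented here (the import path is an assumption, not taken from this page). The module `M` and the function `pattern` are illustrative.

```python
import torch
from torch.fx import symbolic_trace
# Assumption: torch.fx ships a SubgraphMatcher with the same constructor arguments.
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher


class M(torch.nn.Module):
    def forward(self, x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w1, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)


def pattern(w1, w2):
    return torch.cat([w1, w2]).sum()


# Both the target and the pattern are passed as fx.Graph objects.
target_graph = symbolic_trace(M()).graph
pattern_graph = symbolic_trace(pattern).graph

matcher = SubgraphMatcher(
    pattern_graph,
    match_output=False,               # ignore the pattern's output node
    match_placeholder=False,          # placeholders act as wildcards
    remove_overlapping_matches=True,  # keep only the first of any overlapping matches
)
matches = matcher.match(target_graph)

for m in matches:
    # nodes_map pairs each pattern node with the target node it matched.
    print([(p.name, t.name) for p, t in m.nodes_map.items()])
```

Here the pattern occurs twice in `forward`, so two matches are expected.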
The subgraph pattern matcher uses a backtracking approach, in the following steps:
1. We first identify all the anchor nodes in the pattern graph. The anchor nodes are the "sinks" (nodes with no user other than the output node) of the pattern graph. A pattern graph can have multiple anchors if it has multiple return values.
2. In the target graph, we identify the candidate nodes that can potentially be matched with each anchor. These anchor-candidate pairs are the starting points for pairwise per-node matching.
3. For each anchor-candidate pair, we simultaneously traverse backwards (DFS) in both the pattern and target graphs. Each pattern node along the traversal path is compared against the corresponding target node. If any comparison fails, the match for this anchor-candidate pair fails. A match is found when the DFS finishes traversing the graph. See `self._match_nodes` for more details.
4. In the case of multiple anchors, every anchor needs to find a match using step 3. In addition, the matches found for the different anchors need to have a common intersection node for the overall match to be valid. This is implemented with backtracking. See `backtracking` for more details.
Note
Graph traversal must be done in reverse order because a tensor can have multiple consumers but only a single producer. Only in reverse order can we jointly traverse the pattern and target graphs along a deterministic path.
Warning: In theory, this backtracking algorithm has exponential time complexity. However, in practice, it is unlikely to blow up.
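The single-anchor part of the procedure (steps 1 to 3) can be sketched in plain Python. This is a toy model under stated assumptions, not the library's implementation. `Node`, `match_backward`, and `find_matches` are hypothetical names. Multi-anchor backtracking and the self-containment check are left out.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass(eq=False)  # identity-based hashing, so nodes can be dict keys
class Node:
    """Toy stand-in for an fx.Node: an op name plus its ordered producers."""
    op: str
    inputs: List["Node"] = field(default_factory=list)


def match_backward(p: Node, t: Node, mapping: Dict[Node, Node]) -> bool:
    """Jointly walk pattern node p and target node t towards their producers."""
    if p in mapping:                 # pattern node already bound earlier
        return mapping[p] is t
    if p.op == "placeholder":        # placeholders act as wildcards
        mapping[p] = t
        return True
    if p.op != t.op or len(p.inputs) != len(t.inputs):
        return False
    mapping[p] = t
    # Each input has exactly one producer, so the i-th pattern input pairs
    # with the i-th target input and the walk is deterministic.
    return all(match_backward(pi, ti, mapping)
               for pi, ti in zip(p.inputs, t.inputs))


def find_matches(anchor: Node, target_nodes: List[Node]) -> List[Dict[Node, Node]]:
    """Try every target node as a candidate for the pattern's anchor."""
    found = []
    for candidate in target_nodes:
        mapping: Dict[Node, Node] = {}
        if match_backward(anchor, candidate, mapping):
            found.append(mapping)
    return found


# Pattern: sum(cat(w1, w2)); its only sink, p_sum, is the anchor.
w1, w2 = Node("placeholder"), Node("placeholder")
p_sum = Node("sum", [Node("cat", [w1, w2])])

# Target: two copies of sum(cat(a, b)) feeding an add.
a, b = Node("placeholder"), Node("placeholder")
s1 = Node("sum", [Node("cat", [a, b])])
s2 = Node("sum", [Node("cat", [a, b])])
out = Node("add", [s1, s2])
target_nodes = [a, b, s1.inputs[0], s1, s2.inputs[0], s2, out]

matches = find_matches(p_sum, target_nodes)
```

Walking from consumers to producers is what makes the pairing unambiguous. A forward walk would have to guess which of a node's several users corresponds to which pattern node.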
replace_pattern(gm, pattern, replacement)
Matches all possible non-overlapping sets of operators and their data dependencies (`pattern`) in the Graph of a GraphModule (`gm`), then replaces each of these matched subgraphs with another subgraph (`replacement`).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
gm | GraphModule | The GraphModule that wraps the Graph to operate on. | required |
pattern | Callable | The subgraph to match in `gm` for replacement. | required |
replacement | Callable | The subgraph to replace `pattern` with. | required |
Returns:
Type | Description |
---|---|
List[Match] | A list of `Match` objects representing the places in the original graph that `pattern` was matched to. |
Examples:

    import torch
    from torch.fx import symbolic_trace, subgraph_rewriter

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, w1, w2):
            m1 = torch.cat([w1, w2]).sum()
            m2 = torch.cat([w1, w2]).sum()
            return x + torch.max(m1) + torch.max(m2)

    def pattern(w1, w2):
        return torch.cat([w1, w2]).sum()

    def replacement(w1, w2):
        return torch.stack([w1, w2])

    traced_module = symbolic_trace(M())
    subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

The above code will first match `pattern` in the `forward` method of `traced_module`. Pattern-matching is done based on use-def relationships, not node names. For example, if you had `p = torch.cat([a, b])` in `pattern`, you could match `m = torch.cat([a, b])` in the original `forward` function, despite the variable names being different (`p` vs `m`).
The `return` statement in `pattern` is matched based on its value only; it may or may not match the `return` statement in the larger graph. In other words, the pattern doesn't have to extend to the end of the larger graph.
When the pattern is matched, it is removed from the larger function and replaced by `replacement`. If there are multiple matches for `pattern` in the larger function, each non-overlapping match is replaced. In the case of a match overlap, the first match found in the set of overlapping matches is replaced. ("First" here is defined as the first in a topological ordering of the Nodes' use-def relationships. In most cases, the first Node is the parameter that appears directly after `self`, while the last Node is whatever the function returns.)
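The "first match wins" rule for overlapping matches can be illustrated with a small greedy filter. This is a sketch of the idea only, not the library's code. `drop_overlapping` is a hypothetical helper, and matches are modelled simply as sets of node names listed in discovery order.

```python
from typing import List, Set


def drop_overlapping(matches: List[Set[str]]) -> List[Set[str]]:
    """Keep a match only if it shares no node with an earlier kept match."""
    kept: List[Set[str]] = []
    used: Set[str] = set()
    for nodes in matches:  # assumed to be listed in topological discovery order
        if used.isdisjoint(nodes):
            kept.append(nodes)
            used |= nodes
    return kept


# A two-node pattern (relu -> relu) matched against the chain
# relu_1 -> relu_2 -> relu_3 yields two candidates that share relu_2.
# A separate chain relu_4 -> relu_5 yields a third, independent candidate.
candidates = [{"relu_1", "relu_2"}, {"relu_2", "relu_3"}, {"relu_4", "relu_5"}]
non_overlapping = drop_overlapping(candidates)
```

The second candidate is dropped because `relu_2` is already claimed by the first. The third survives because it touches no claimed node.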
One important thing to note is that the parameters of the `pattern` Callable must be used in the Callable itself, and the parameters of the `replacement` Callable must match the pattern. The first rule is why, in the above code block, the `forward` function has parameters `x, w1, w2`, but the `pattern` function only has parameters `w1, w2`. `pattern` doesn't use `x`, so it shouldn't specify `x` as a parameter.
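As an example of the second rule, consider a `pattern` that takes two parameters, `x` and `y`, and a `replacement` that uses only `x`. In this case, `replacement` needs the same number of parameters as `pattern` (both `x` and `y`), even though the parameter `y` isn't used in `replacement`.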
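One possible pair of functions that illustrates the second rule is sketched below. The bodies of `pattern` and `replacement` and the module `Net` are assumptions chosen for illustration. Only the rule that both take `x` and `y` comes from the text.

```python
import torch
from torch.fx import symbolic_trace, subgraph_rewriter


class Net(torch.nn.Module):  # hypothetical module, for illustration only
    def forward(self, a, b):
        return torch.neg(a) + torch.relu(b)


def pattern(x, y):
    return torch.neg(x) + torch.relu(y)


def replacement(x, y):
    # y is unused here, but it stays in the signature so that the
    # parameter count matches that of pattern.
    return torch.relu(x)


gm = symbolic_trace(Net())
found = subgraph_rewriter.replace_pattern(gm, pattern, replacement)

# After the rewrite, the module computes relu(a) and ignores b.
result = gm(torch.tensor([-1.0, 2.0]), torch.tensor([3.0, 4.0]))
```

Dropping `y` from `replacement` would leave it with fewer parameters than `pattern`, which the rule above forbids.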
After calling `subgraph_rewriter.replace_pattern`, the generated Python code looks like this:
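To inspect the generated code yourself, one option is to print the `code` attribute of the rewritten GraphModule. The sketch below repeats the module, pattern, and replacement from the example above so that it runs on its own. The exact variable names in the printed output may differ between PyTorch versions, so none are shown here.

```python
import torch
from torch.fx import symbolic_trace, subgraph_rewriter


class M(torch.nn.Module):
    def forward(self, x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w1, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)


def pattern(w1, w2):
    return torch.cat([w1, w2]).sum()


def replacement(w1, w2):
    return torch.stack([w1, w2])


traced_module = symbolic_trace(M())
subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

# GraphModule.code holds the Python source regenerated from the rewritten graph.
generated = traced_module.code
print(generated)
```

Both occurrences of the `torch.cat(...).sum()` pattern should now appear as `torch.stack` calls in the printed `forward`.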