Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions src/nncf/experimental/torch/fx/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,15 @@ def _map_fx_unique_metatypes(node: torch.fx.Node, metatype: om.OperatorMetatype)
:param model: Target GraphModule instance.
:return: Correct FX metatype of the given node if it is exist or the original node metatype otherwise.
"""
if metatype in [om.PTEmbeddingMetatype]:
weight_node = node.args[0]
PT_METATYPE_TO_FX_METATYPE_MAPPING = {
om.PTEmbeddingMetatype: om.PTAtenEmbeddingMetatype,
om.PTEmbeddingBagMetatype: om.PTAtenEmbeddingBagMetatype,
}
if metatype in PT_METATYPE_TO_FX_METATYPE_MAPPING:
fx_metatype = PT_METATYPE_TO_FX_METATYPE_MAPPING[metatype]
weight_node = node.args[fx_metatype.weight_port_ids[0]]
if weight_node.op == "get_attr":
return om.PTAtenEmbeddingMetatype
return fx_metatype

return metatype

Expand Down Expand Up @@ -137,6 +142,7 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
for source_node in model.graph.nodes:
node_type, node_metatype = GraphConverter.get_node_type_and_metatype(source_node, model)
node_metatype = GraphConverter._map_fx_unique_metatypes(source_node, node_metatype)

is_shared_node = source_node.op in ("get_attr",) and (
const_targets_counter[source_node.target] > 1 or len(source_node.users) > 1
)
Expand Down Expand Up @@ -190,7 +196,16 @@ def get_edge_params(
source_node.meta["val"], (tuple, list)
):
tensor = source_node.meta["val"][0]
elif source_nncf_node.metatype in [om.PTSplitMetatype, om.PTMaxMetatype, om.PTMinMetatype]:
elif source_nncf_node.metatype in [
om.PTSplitMetatype,
om.PTMaxMetatype,
om.PTMinMetatype,
om.PTMedianMetatype,
om.PTAdaptiveMaxPool1dMetatype,
om.PTAdaptiveMaxPool2dMetatype,
om.PTAdaptiveMaxPool3dMetatype,
om.PTAtenEmbeddingBagMetatype,
] and isinstance(source_node.meta["val"], (tuple, list)):
tensor = source_node.meta["val"][output_idx]
# Assume every outputs corresponds to an unique output_port_id
output_port_id = output_idx
Expand Down
26 changes: 26 additions & 0 deletions src/nncf/torch/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,12 @@ class PTMeanMetatype(PTOperatorMetatype):
hw_config_names = [HWConfigOpName.REDUCEMEAN]


@PT_OPERATOR_METATYPES.register()
class PTMedianMetatype(PTOperatorMetatype):
name = "MedianOp"
module_to_function_names = {NamespaceTarget.ATEN: ["median"]}


@PT_OPERATOR_METATYPES.register()
class PTRoundMetatype(PTOperatorMetatype):
name = "RoundOp"
Expand Down Expand Up @@ -745,6 +751,16 @@ class PTBatchNormMetatype(PTOperatorMetatype):
bias_port_id = 2


@PT_OPERATOR_METATYPES.register()
class PTAvgPool1dMetatype(PTOperatorMetatype):
name = "AvgPool1DOp"
module_to_function_names = {
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["avg_pool1d", "adaptive_avg_pool1d"],
NamespaceTarget.ATEN: ["adaptive_avg_pool1d"],
}
hw_config_names = [HWConfigOpName.AVGPOOL]


@PT_OPERATOR_METATYPES.register()
class PTAvgPool2dMetatype(PTOperatorMetatype):
name = "AvgPool2DOp"
Expand All @@ -759,6 +775,7 @@ class PTAvgPool3dMetatype(PTOperatorMetatype):
hw_config_names = [HWConfigOpName.AVGPOOL]


@PT_OPERATOR_METATYPES.register()
class PTAdaptiveMaxPool1dMetatype(PTOperatorMetatype):
name = "AdaptiveMaxPool1DOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["adaptive_max_pool1d"]}
Expand Down Expand Up @@ -965,6 +982,14 @@ class PTEmbeddingBagMetatype(PTOperatorMetatype):
weight_port_ids = [1]


@FX_OPERATOR_METATYPES.register()
class PTAtenEmbeddingBagMetatype(OperatorMetatype):
name = "EmbeddingBagOp"
module_to_function_names = {NamespaceTarget.ATEN: ["embedding_bag"]}
hw_config_names = [HWConfigOpName.EMBEDDINGBAG]
weight_port_ids = [0]


@PT_OPERATOR_METATYPES.register()
class PTSoftmaxMetatype(PTOperatorMetatype):
name = "SoftmaxOp"
Expand Down Expand Up @@ -1222,6 +1247,7 @@ def get_operator_metatypes() -> list[type[OperatorMetatype]]:
OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS = [
PTEmbeddingMetatype,
PTEmbeddingBagMetatype,
PTAtenEmbeddingBagMetatype,
PTModuleEmbeddingBagMetatype,
PTModuleEmbeddingMetatype,
]
Expand Down
28 changes: 28 additions & 0 deletions tests/torch2/data/fx/embedding_bag_model.dot
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
strict digraph {
"0 embedding_weight" [id=0, type="get_attr"];
"1 embeddingbag_weight" [id=1, type="get_attr"];
"2 x" [id=2, type=input];
"3 embedding" [id=3, type=embedding];
"4 arange" [id=4, type=arange];
"5 reshape" [id=5, type=reshape];
"6 embedding_bag" [id=6, type="embedding_bag"];
"7 getitem" [id=7, type="__getitem__"];
"8 getitem_1" [id=8, type="__getitem__"];
"9 getitem_2" [id=9, type="__getitem__"];
"10 getitem_3" [id=10, type="__getitem__"];
"11 add" [id=11, type=add];
"12 output" [id=12, type=output];
"0 embedding_weight" -> "3 embedding" [style=solid, label="(10, 10)"];
"1 embeddingbag_weight" -> "6 embedding_bag" [style=solid, label="(10, 10)"];
"2 x" -> "3 embedding" [style=solid, label="(1, 1)"];
"2 x" -> "5 reshape" [style=solid, label="(1, 1)"];
"3 embedding" -> "11 add" [style=solid, label="(1, 1, 10)"];
"4 arange" -> "6 embedding_bag" [style=solid, label="(1,)"];
"5 reshape" -> "6 embedding_bag" [style=solid, label="(1,)"];
"6 embedding_bag" -> "7 getitem" [style=solid, label="(1, 10)"];
"6 embedding_bag" -> "8 getitem_1" [style=solid, label="(1,)"];
"6 embedding_bag" -> "9 getitem_2" [style=solid, label="(1,)"];
"6 embedding_bag" -> "10 getitem_3" [style=solid, label="(1,)"];
"7 getitem" -> "11 add" [style=solid, label="(1, 10)"];
"11 add" -> "12 output" [style=solid, label="(1, 1, 10)"];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"embedding_weight": "PTConstNoopMetatype",
"embeddingbag_weight": "PTConstNoopMetatype",
"x": "PTInputNoopMetatype",
"embedding": "PTAtenEmbeddingMetatype",
"arange": "UnknownMetatype",
"reshape": "PTReshapeMetatype",
"embedding_bag": "PTAtenEmbeddingBagMetatype",
"getitem": "PTGatherMetatype",
"getitem_1": "PTGatherMetatype",
"getitem_2": "PTGatherMetatype",
"getitem_3": "PTGatherMetatype",
"add": "PTAddMetatype",
"output": "PTOutputNoopMetatype"
}
4 changes: 3 additions & 1 deletion tests/torch2/fx/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from tests.cross_fw.shared.nx_graph import compare_nx_graph_with_reference
from tests.cross_fw.shared.paths import TEST_ROOT
from tests.torch import test_models
from tests.torch.test_models.synthetic import EmbeddingSumModel
from tests.torch.test_models.synthetic import MultiBranchesConnectedModel
from tests.torch.test_models.synthetic import ShortTransformer
from tests.torch.test_models.synthetic import YOLO11N_SDPABlock
Expand Down Expand Up @@ -75,6 +76,7 @@ def torchvision_model_case(model_id: str, input_shape: tuple[int,]):
ModelCase(test_models.UNet, "unet", [1, 3, 224, 224]),
ModelCase(partial(ShortTransformer, 5, 10), "synthetic_transformer", [5]),
ModelCase(YOLO11N_SDPABlock, "yolo11n_sdpa_block", YOLO11N_SDPABlock.INPUT_SIZE),
ModelCase(EmbeddingSumModel, "embedding_bag_model", [1, 1]),
)


Expand Down Expand Up @@ -121,7 +123,7 @@ def test_model(test_case: ModelCase, regen_ref_data: bool):
model = test_case.model_builder()
model.to(device)

dtype = torch.int32 if test_case.model_id == "synthetic_transformer" else torch.float32
dtype = torch.int32 if test_case.model_id in ["synthetic_transformer", "embedding_bag_model"] else torch.float32
ex_input = torch.ones(test_case.input_shape, dtype=dtype)
exported_model = get_torch_fx_model(model, ex_input)
nncf_graph = GraphConverter.create_nncf_graph(exported_model)
Expand Down