diff --git a/backends/arm/operators/op_tosa_identity.py b/backends/arm/operators/op_tosa_identity.py index 1b15e39154e..b1099116e12 100644 --- a/backends/arm/operators/op_tosa_identity.py +++ b/backends/arm/operators/op_tosa_identity.py @@ -3,42 +3,28 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, List - -import torch import tosa_serializer as ts -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, - validate_same_dtype, - validate_valid_dtype, +from executorch.backends.arm.operators.node_visitor import register_node_visitor +from executorch.backends.arm.operators.simple_node_visitor import ( + SimpleNodeVisitor, + SimpleNodeVisitorConfig, ) -from executorch.backends.arm.tosa.mapping import TosaArg @register_node_visitor -class IdentityVisitor(NodeVisitor): +class IdentityVisitor(SimpleNodeVisitor): """Lower the TOSA IDENTITY op.""" target = "tosa.IDENTITY.default" - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - validate_num_inputs(self.target, inputs, 1) - validate_same_dtype(self.target, [inputs[0], output], ts) - validate_valid_dtype( - self.target, - [inputs[0], output], - [ + @classmethod + def get_config(cls) -> SimpleNodeVisitorConfig: + return SimpleNodeVisitorConfig( + tosa_op=ts.Op.IDENTITY, + attr_method="IdentityAttribute", + num_inputs=1, + input_dtypes=[ ts.DType.BOOL, ts.DType.INT8, ts.DType.INT16, @@ -49,16 +35,4 @@ def define_node( ts.DType.FP8E4M3, ts.DType.FP8E5M2, ], - self.tosa_spec, - ) - - attr = ts.TosaSerializerAttribute() - attr.IdentityAttribute() - self._serialize_operator( - node, - tosa_graph, - ts.Op.IDENTITY, - [inputs[0].name], - [output.name], - attr, )