@@ -73,16 +73,19 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
7373
7474 Value rhs = adaptor.getOther ();
7575
76- RankedTensorType resultType =
77- cast<RankedTensorType>( OpConversionPattern<AtenOpT>::getTypeConverter ()
78- -> convertType ( op.getType ()));
76+ RankedTensorType resultType = cast<RankedTensorType>(
77+ OpConversionPattern<AtenOpT>::getTypeConverter ()-> convertType (
78+ op.getType ()));
7979
8080 if (!lhsType || !resultType)
8181 return rewriter.notifyMatchFailure (
8282 op, " Only Ranked Tensor types are supported in TCP" );
8383
84- auto inputAType = dyn_cast<torch::Torch::ValueTensorType>(op.getSelf ().getType ()).getDtype ();
85- auto outputType = dyn_cast<torch::Torch::ValueTensorType>(op.getType ()).getDtype ();
84+ auto inputAType =
85+ dyn_cast<torch::Torch::ValueTensorType>(op.getSelf ().getType ())
86+ .getDtype ();
87+ auto outputType =
88+ dyn_cast<torch::Torch::ValueTensorType>(op.getType ()).getDtype ();
8689
8790 if (isa<AtenAddScalarOp>(op) || isa<AtenSubScalarOp>(op)) {
8891 rhs = convertScalarOperandToTensor (rewriter, op, op.getOther (),
@@ -91,7 +94,9 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
9194 if (!rhs)
9295 return rewriter.notifyMatchFailure (op, " Unsupported rhs data type" );
9396 } else {
94- auto inputBType = dyn_cast<torch::Torch::ValueTensorType>(op.getOther ().getType ()).getDtype ();
97+ auto inputBType =
98+ dyn_cast<torch::Torch::ValueTensorType>(op.getOther ().getType ())
99+ .getDtype ();
95100 rhs = torch_to_tcp::castTensorToDtype (rewriter, inputBType, outputType,
96101 rhs, resultType.getElementType ());
97102 }
@@ -130,16 +135,19 @@ class ConvertAtenMulOp : public OpConversionPattern<AtenOpT> {
130135
131136 Value rhs = adaptor.getOther ();
132137
133- RankedTensorType resultType =
134- cast<RankedTensorType>( OpConversionPattern<AtenOpT>::getTypeConverter ()
135- -> convertType ( op.getType ()));
138+ RankedTensorType resultType = cast<RankedTensorType>(
139+ OpConversionPattern<AtenOpT>::getTypeConverter ()-> convertType (
140+ op.getType ()));
136141
137142 if (!lhsType || !resultType)
138143 return rewriter.notifyMatchFailure (
139144 op, " Only Ranked Tensor types are supported in TCP" );
140145
141- auto inputAType = dyn_cast<torch::Torch::ValueTensorType>(op.getSelf ().getType ()).getDtype ();
142- auto outputType = dyn_cast<torch::Torch::ValueTensorType>(op.getType ()).getDtype ();
146+ auto inputAType =
147+ dyn_cast<torch::Torch::ValueTensorType>(op.getSelf ().getType ())
148+ .getDtype ();
149+ auto outputType =
150+ dyn_cast<torch::Torch::ValueTensorType>(op.getType ()).getDtype ();
143151
144152 if (isa<AtenMulScalarOp>(op)) {
145153 rhs = convertScalarOperandToTensor (rewriter, op, op.getOther (),
@@ -148,7 +156,9 @@ class ConvertAtenMulOp : public OpConversionPattern<AtenOpT> {
148156 if (!rhs)
149157 return rewriter.notifyMatchFailure (op, " Unsupported rhs data type" );
150158 } else {
151- auto inputBType = dyn_cast<torch::Torch::ValueTensorType>(op.getOther ().getType ()).getDtype ();
159+ auto inputBType =
160+ dyn_cast<torch::Torch::ValueTensorType>(op.getOther ().getType ())
161+ .getDtype ();
152162 rhs = torch_to_tcp::castTensorToDtype (rewriter, inputBType, outputType,
153163 rhs, resultType.getElementType ());
154164 }
@@ -276,16 +286,19 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
276286
277287 Value rhs = adaptor.getOther ();
278288
279- RankedTensorType resultType =
280- cast<RankedTensorType>( OpConversionPattern<AtenOpT>::getTypeConverter ()
281- -> convertType ( op.getType ()));
289+ RankedTensorType resultType = cast<RankedTensorType>(
290+ OpConversionPattern<AtenOpT>::getTypeConverter ()-> convertType (
291+ op.getType ()));
282292
283293 if (!lhsType || !resultType)
284294 return rewriter.notifyMatchFailure (
285295 op, " Only Ranked Tensor types are supported in TCP" );
286296
287- auto inputAType = dyn_cast<torch::Torch::ValueTensorType>(op.getSelf ().getType ()).getDtype ();
288- auto outputType = dyn_cast<torch::Torch::ValueTensorType>(op.getType ()).getDtype ();
297+ auto inputAType =
298+ dyn_cast<torch::Torch::ValueTensorType>(op.getSelf ().getType ())
299+ .getDtype ();
300+ auto outputType =
301+ dyn_cast<torch::Torch::ValueTensorType>(op.getType ()).getDtype ();
289302
290303 Type inputBType = nullptr ;
291304 if (isa<AtenDivScalarOp>(op)) {
@@ -297,7 +310,9 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
297310 if (!rhs)
298311 return rewriter.notifyMatchFailure (op, " Unsupported rhs data type" );
299312 } else {
300- inputBType = dyn_cast<torch::Torch::ValueTensorType>(op.getOther ().getType ()).getDtype ();
313+ inputBType =
314+ dyn_cast<torch::Torch::ValueTensorType>(op.getOther ().getType ())
315+ .getDtype ();
301316 rhs = torch_to_tcp::castTensorToDtype (rewriter, inputBType, outputType,
302317 rhs, resultType.getElementType ());
303318 }
@@ -452,8 +467,9 @@ class ConvertAtenSqrtOp : public OpConversionPattern<AtenSqrtOp> {
452467
453468 Value newInput = input;
454469 if (isa<mlir::IntegerType>(elementType)) {
455- auto inputDType =
456- dyn_cast<torch::Torch::ValueTensorType>(op.getSelf ().getType ()).getDtype ();
470+ auto inputDType =
471+ dyn_cast<torch::Torch::ValueTensorType>(op.getSelf ().getType ())
472+ .getDtype ();
457473 auto outputDType =
458474 dyn_cast<torch::Torch::ValueTensorType>(op.getType ()).getDtype ();
459475 newInput =
@@ -519,9 +535,9 @@ class ConvertAtenUnaryIntOrFpOp : public OpConversionPattern<AtenOpT> {
519535 return rewriter.notifyMatchFailure (
520536 op, " Input tensor must have integer or floating-point datatype" );
521537
522- RankedTensorType resultType =
523- cast<RankedTensorType>( OpConversionPattern<AtenOpT>::getTypeConverter ()
524- -> convertType ( op.getType ()));
538+ RankedTensorType resultType = cast<RankedTensorType>(
539+ OpConversionPattern<AtenOpT>::getTypeConverter ()-> convertType (
540+ op.getType ()));
525541
526542 rewriter.replaceOpWithNewOp <TcpOpT>(op, resultType, input);
527543 return success ();
@@ -578,10 +594,12 @@ class ConvertAtenAtan2Op : public OpConversionPattern<AtenAtan2Op> {
578594 return rewriter.notifyMatchFailure (
579595 op, " Input tensors must have floating-point datatype" );
580596
581- auto inputAType =
582- dyn_cast<torch::Torch::ValueTensorType>(op.getSelf ().getType ()).getDtype ();
583- auto inputBType =
584- dyn_cast<torch::Torch::ValueTensorType>(op.getOther ().getType ()).getDtype ();
597+ auto inputAType =
598+ dyn_cast<torch::Torch::ValueTensorType>(op.getSelf ().getType ())
599+ .getDtype ();
600+ auto inputBType =
601+ dyn_cast<torch::Torch::ValueTensorType>(op.getOther ().getType ())
602+ .getDtype ();
585603 auto outputType =
586604 dyn_cast<torch::Torch::ValueTensorType>(op.getType ()).getDtype ();
587605
0 commit comments