Skip to content

Commit 1adc259

Browse files
committed
update
1 parent 8a69ae2 commit 1adc259

File tree

17 files changed

+224
-57
lines changed

17 files changed

+224
-57
lines changed

BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ cc_library(
150150
name = "TcpDialectPasses",
151151
srcs = [
152152
"lib/Dialect/Transforms/DropSymbolicShapeOpsPass.cpp",
153+
"lib/Dialect/Transforms/EliminateUnusedTorchOpsPass.cpp",
153154
"lib/Dialect/Transforms/FuseTcpOpsPass.cpp",
154155
"lib/Dialect/Transforms/FusionPatterns.cpp",
155156
"lib/Dialect/Transforms/IsolateGroupOpsPass.cpp",
@@ -160,6 +161,7 @@ cc_library(
160161
],
161162
hdrs = [
162163
"include/mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h",
164+
"include/mlir-tcp/Dialect/Transforms/EliminateUnusedTorchOpsPass.h",
163165
"include/mlir-tcp/Dialect/Transforms/FuseTcpOpsPass.h",
164166
"include/mlir-tcp/Dialect/Transforms/FusionPatterns.h",
165167
"include/mlir-tcp/Dialect/Transforms/IsolateGroupOpsPass.h",
@@ -175,6 +177,7 @@ cc_library(
175177
"@llvm-project//mlir:TensorDialect",
176178
"@llvm-project//mlir:TensorTransforms",
177179
"@llvm-project//mlir:Transforms",
180+
"@torch-mlir//:TorchMLIRTorchDialect",
178181
],
179182
)
180183

deps.bzl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ def third_party_deps():
2020
path = local_llvm_repo_path(),
2121
)
2222
else:
23-
LLVM_COMMIT = "eda3e96b401a9b86132e39432e41e2000d1ab382"
24-
LLVM_SHA256 = "26c4060f19982482d57f1a47945f3f7613b7659415f0482c4bac63769366b501"
23+
LLVM_COMMIT = "b231e5ff504295641b0f580ceefa2e1048011614"
24+
LLVM_SHA256 = "88dfa59052730710cb48fa20b00a4344144edd1c3cb524c06d983899835e491a"
2525
http_archive(
2626
name = "llvm-raw",
2727
build_file_content = "# empty",
@@ -42,7 +42,7 @@ def third_party_deps():
4242
http_archive(
4343
name = "torch-mlir-raw",
4444
build_file_content = "# empty",
45-
patches = ["//third_party/patches:torch-mlir-bazel-build.1.patch"],
45+
patches = ["//third_party/patches:torch-mlir-bazel-build.1.patch", "//third_party/patches:torch-mlir-bazel-build.2.patch"],
4646
sha256 = TORCH_MLIR_SHA256,
4747
strip_prefix = "torch-mlir-" + TORCH_MLIR_COMMIT,
4848
urls = ["https://github.com/llvm/torch-mlir/archive/{commit}.tar.gz".format(commit = TORCH_MLIR_COMMIT)],
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===------------------------------------------------------------*- C++ -*-===//
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// Also available under a BSD-style license. See LICENSE.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#pragma once
11+
12+
#include "mlir/IR/BuiltinOps.h"
13+
#include "mlir/Pass/Pass.h"
14+
#include <memory>
15+
16+
namespace mlir::tcp {
17+
18+
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
19+
createEliminateUnusedTorchOpsPass();
20+
21+
} // namespace mlir::tcp

include/mlir-tcp/Dialect/Transforms/Passes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,10 @@ def DropSymbolicShapeOps : Pass<"drop-symbolic-shape-ops", "func::FuncOp"> {
4545
let constructor = "mlir::tcp::createDropSymbolicShapeOpsPass()";
4646
}
4747

48+
// \brief This pass removes unused torch ops.
49+
def EliminateUnusedTorchOps : Pass<"eliminate-unused-torch-ops", "ModuleOp"> {
50+
let summary = "Removes unused/unnecessary torch ops";
51+
let constructor = "mlir::tcp::createEliminateUnusedTorchOpsPass()";
52+
}
53+
4854
#endif // TCP_PASSES

lib/Conversion/TcpToLinalg/DataMovement.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ class ConvertGatherOp : public OpConversionPattern<GatherOp> {
3636
matchAndRewrite(GatherOp op, OpAdaptor adaptor,
3737
ConversionPatternRewriter &rewriter) const override {
3838
Location loc = op->getLoc();
39-
auto resultTensorType = cast<RankedTensorType>(getTypeConverter()
40-
->convertType(op.getOut().getType()));
39+
auto resultTensorType = cast<RankedTensorType>(
40+
getTypeConverter()->convertType(op.getOut().getType()));
4141

4242
auto inputTensor = adaptor.getInput();
4343
auto indicesTensor = adaptor.getIndices();
@@ -109,8 +109,8 @@ class ConvertGatherNDOp : public OpConversionPattern<GatherNDOp> {
109109
matchAndRewrite(GatherNDOp op, OpAdaptor adaptor,
110110
ConversionPatternRewriter &rewriter) const override {
111111
Location loc = op->getLoc();
112-
auto resultTensorType = cast<RankedTensorType>(getTypeConverter()
113-
->convertType(op.getOut().getType()));
112+
auto resultTensorType = cast<RankedTensorType>(
113+
getTypeConverter()->convertType(op.getOut().getType()));
114114

115115
auto inputTensor = adaptor.getInput();
116116
auto indicesTensor = adaptor.getIndices();

lib/Conversion/TcpToLinalg/Elementwise.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,9 @@ class ConvertElementwiseOp : public OpConversionPattern<TcpOpT> {
318318
matchAndRewrite(TcpOpT op, OpAdaptor adaptor,
319319
ConversionPatternRewriter &rewriter) const override {
320320
Location loc = op->getLoc();
321-
auto resultTensorType = cast<RankedTensorType>(OpConversionPattern<TcpOpT>::getTypeConverter()
322-
->convertType(op->getResult(0).getType()));
321+
auto resultTensorType = cast<RankedTensorType>(
322+
OpConversionPattern<TcpOpT>::getTypeConverter()->convertType(
323+
op->getResult(0).getType()));
323324
auto tensorOperands = llvm::to_vector<6>(
324325
llvm::make_filter_range(adaptor.getOperands(), [](Value v) {
325326
return isa<RankedTensorType>(v.getType());

lib/Conversion/TcpToLinalg/Misc.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ class ConvertBroadcastOp : public OpConversionPattern<BroadcastOp> {
4040
LogicalResult matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
4141
ConversionPatternRewriter &b) const override {
4242
Location loc = op->getLoc();
43-
auto resultTensorType = cast<RankedTensorType>(getTypeConverter()
44-
->convertType(op->getResult(0).getType()));
43+
auto resultTensorType = cast<RankedTensorType>(
44+
getTypeConverter()->convertType(op->getResult(0).getType()));
4545
auto inputTensor = op->getOperands()[0];
4646

4747
SmallVector<int64_t> axes = getValuesFromIndexArrayAttribute(op.getAxes());

lib/Conversion/TorchToTcp/DataMovement.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
4040
SmallVector<Value> &strides) {
4141
Location loc = op.getLoc();
4242
auto input = adaptor.getSelf();
43-
RankedTensorType inputType =
44-
cast<RankedTensorType>(input.getType());
43+
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
4544

4645
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
4746
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -185,8 +184,8 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> {
185184
return failure();
186185

187186
auto input = adaptor.getSelf();
188-
RankedTensorType resultType = cast<RankedTensorType>(getTypeConverter()
189-
->convertType(op->getResult(0).getType()));
187+
RankedTensorType resultType = cast<RankedTensorType>(
188+
getTypeConverter()->convertType(op->getResult(0).getType()));
190189

191190
SmallVector<Value> resultShape;
192191
SmallVector<Value> offsets;
@@ -212,8 +211,8 @@ class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
212211
ConversionPatternRewriter &rewriter) const override {
213212
auto input = adaptor.getSelf();
214213
auto indices = adaptor.getIndex();
215-
RankedTensorType resultType = cast<RankedTensorType>(getTypeConverter()
216-
->convertType(op->getResult(0).getType()));
214+
RankedTensorType resultType = cast<RankedTensorType>(
215+
getTypeConverter()->convertType(op->getResult(0).getType()));
217216

218217
int64_t dim = 0;
219218
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))

lib/Conversion/TorchToTcp/Elementwise.cpp

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -73,16 +73,19 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
7373

7474
Value rhs = adaptor.getOther();
7575

76-
RankedTensorType resultType =
77-
cast<RankedTensorType>(OpConversionPattern<AtenOpT>::getTypeConverter()
78-
->convertType(op.getType()));
76+
RankedTensorType resultType = cast<RankedTensorType>(
77+
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
78+
op.getType()));
7979

8080
if (!lhsType || !resultType)
8181
return rewriter.notifyMatchFailure(
8282
op, "Only Ranked Tensor types are supported in TCP");
8383

84-
auto inputAType = dyn_cast<torch::Torch::ValueTensorType>(op.getSelf().getType()).getDtype();
85-
auto outputType = dyn_cast<torch::Torch::ValueTensorType>(op.getType()).getDtype();
84+
auto inputAType =
85+
dyn_cast<torch::Torch::ValueTensorType>(op.getSelf().getType())
86+
.getDtype();
87+
auto outputType =
88+
dyn_cast<torch::Torch::ValueTensorType>(op.getType()).getDtype();
8689

8790
if (isa<AtenAddScalarOp>(op) || isa<AtenSubScalarOp>(op)) {
8891
rhs = convertScalarOperandToTensor(rewriter, op, op.getOther(),
@@ -91,7 +94,9 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
9194
if (!rhs)
9295
return rewriter.notifyMatchFailure(op, "Unsupported rhs data type");
9396
} else {
94-
auto inputBType = dyn_cast<torch::Torch::ValueTensorType>(op.getOther().getType()).getDtype();
97+
auto inputBType =
98+
dyn_cast<torch::Torch::ValueTensorType>(op.getOther().getType())
99+
.getDtype();
95100
rhs = torch_to_tcp::castTensorToDtype(rewriter, inputBType, outputType,
96101
rhs, resultType.getElementType());
97102
}
@@ -130,16 +135,19 @@ class ConvertAtenMulOp : public OpConversionPattern<AtenOpT> {
130135

131136
Value rhs = adaptor.getOther();
132137

133-
RankedTensorType resultType =
134-
cast<RankedTensorType>(OpConversionPattern<AtenOpT>::getTypeConverter()
135-
->convertType(op.getType()));
138+
RankedTensorType resultType = cast<RankedTensorType>(
139+
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
140+
op.getType()));
136141

137142
if (!lhsType || !resultType)
138143
return rewriter.notifyMatchFailure(
139144
op, "Only Ranked Tensor types are supported in TCP");
140145

141-
auto inputAType = dyn_cast<torch::Torch::ValueTensorType>(op.getSelf().getType()).getDtype();
142-
auto outputType = dyn_cast<torch::Torch::ValueTensorType>(op.getType()).getDtype();
146+
auto inputAType =
147+
dyn_cast<torch::Torch::ValueTensorType>(op.getSelf().getType())
148+
.getDtype();
149+
auto outputType =
150+
dyn_cast<torch::Torch::ValueTensorType>(op.getType()).getDtype();
143151

144152
if (isa<AtenMulScalarOp>(op)) {
145153
rhs = convertScalarOperandToTensor(rewriter, op, op.getOther(),
@@ -148,7 +156,9 @@ class ConvertAtenMulOp : public OpConversionPattern<AtenOpT> {
148156
if (!rhs)
149157
return rewriter.notifyMatchFailure(op, "Unsupported rhs data type");
150158
} else {
151-
auto inputBType = dyn_cast<torch::Torch::ValueTensorType>(op.getOther().getType()).getDtype();
159+
auto inputBType =
160+
dyn_cast<torch::Torch::ValueTensorType>(op.getOther().getType())
161+
.getDtype();
152162
rhs = torch_to_tcp::castTensorToDtype(rewriter, inputBType, outputType,
153163
rhs, resultType.getElementType());
154164
}
@@ -276,16 +286,19 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
276286

277287
Value rhs = adaptor.getOther();
278288

279-
RankedTensorType resultType =
280-
cast<RankedTensorType>(OpConversionPattern<AtenOpT>::getTypeConverter()
281-
->convertType(op.getType()));
289+
RankedTensorType resultType = cast<RankedTensorType>(
290+
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
291+
op.getType()));
282292

283293
if (!lhsType || !resultType)
284294
return rewriter.notifyMatchFailure(
285295
op, "Only Ranked Tensor types are supported in TCP");
286296

287-
auto inputAType = dyn_cast<torch::Torch::ValueTensorType>(op.getSelf().getType()).getDtype();
288-
auto outputType = dyn_cast<torch::Torch::ValueTensorType>(op.getType()).getDtype();
297+
auto inputAType =
298+
dyn_cast<torch::Torch::ValueTensorType>(op.getSelf().getType())
299+
.getDtype();
300+
auto outputType =
301+
dyn_cast<torch::Torch::ValueTensorType>(op.getType()).getDtype();
289302

290303
Type inputBType = nullptr;
291304
if (isa<AtenDivScalarOp>(op)) {
@@ -297,7 +310,9 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
297310
if (!rhs)
298311
return rewriter.notifyMatchFailure(op, "Unsupported rhs data type");
299312
} else {
300-
inputBType = dyn_cast<torch::Torch::ValueTensorType>(op.getOther().getType()).getDtype();
313+
inputBType =
314+
dyn_cast<torch::Torch::ValueTensorType>(op.getOther().getType())
315+
.getDtype();
301316
rhs = torch_to_tcp::castTensorToDtype(rewriter, inputBType, outputType,
302317
rhs, resultType.getElementType());
303318
}
@@ -452,8 +467,9 @@ class ConvertAtenSqrtOp : public OpConversionPattern<AtenSqrtOp> {
452467

453468
Value newInput = input;
454469
if (isa<mlir::IntegerType>(elementType)) {
455-
auto inputDType =
456-
dyn_cast<torch::Torch::ValueTensorType>(op.getSelf().getType()).getDtype();
470+
auto inputDType =
471+
dyn_cast<torch::Torch::ValueTensorType>(op.getSelf().getType())
472+
.getDtype();
457473
auto outputDType =
458474
dyn_cast<torch::Torch::ValueTensorType>(op.getType()).getDtype();
459475
newInput =
@@ -519,9 +535,9 @@ class ConvertAtenUnaryIntOrFpOp : public OpConversionPattern<AtenOpT> {
519535
return rewriter.notifyMatchFailure(
520536
op, "Input tensor must have integer or floating-point datatype");
521537

522-
RankedTensorType resultType =
523-
cast<RankedTensorType>(OpConversionPattern<AtenOpT>::getTypeConverter()
524-
->convertType(op.getType()));
538+
RankedTensorType resultType = cast<RankedTensorType>(
539+
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
540+
op.getType()));
525541

526542
rewriter.replaceOpWithNewOp<TcpOpT>(op, resultType, input);
527543
return success();
@@ -578,10 +594,12 @@ class ConvertAtenAtan2Op : public OpConversionPattern<AtenAtan2Op> {
578594
return rewriter.notifyMatchFailure(
579595
op, "Input tensors must have floating-point datatype");
580596

581-
auto inputAType =
582-
dyn_cast<torch::Torch::ValueTensorType>(op.getSelf().getType()).getDtype();
583-
auto inputBType =
584-
dyn_cast<torch::Torch::ValueTensorType>(op.getOther().getType()).getDtype();
597+
auto inputAType =
598+
dyn_cast<torch::Torch::ValueTensorType>(op.getSelf().getType())
599+
.getDtype();
600+
auto inputBType =
601+
dyn_cast<torch::Torch::ValueTensorType>(op.getOther().getType())
602+
.getDtype();
585603
auto outputType =
586604
dyn_cast<torch::Torch::ValueTensorType>(op.getType()).getDtype();
587605

lib/Conversion/TorchToTcp/Misc.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,9 @@ class ConvertAtenBroadcastLikeOps : public OpConversionPattern<AtenOpT> {
133133
rewriter.replaceOp(op, input);
134134
return success();
135135
}
136-
RankedTensorType resultType =
137-
cast<RankedTensorType>(OpConversionPattern<AtenOpT>::getTypeConverter()
138-
->convertType(op->getResult(0).getType()));
136+
RankedTensorType resultType = cast<RankedTensorType>(
137+
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
138+
op->getResult(0).getType()));
139139
auto axesAttr = rewriter.getI64ArrayAttr(axes);
140140
rewriter.replaceOpWithNewOp<tcp::BroadcastOp>(op, resultType, input,
141141
resultShape, axesAttr);
@@ -218,8 +218,9 @@ class ConvertAtenZerosOnesOp : public OpConversionPattern<AtenOpT> {
218218
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
219219
ConversionPatternRewriter &rewriter) const override {
220220

221-
auto outType = dyn_cast<RankedTensorType>(OpConversionPattern<AtenOpT>::getTypeConverter()
222-
->convertType(op.getType()));
221+
auto outType = dyn_cast<RankedTensorType>(
222+
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
223+
op.getType()));
223224
Type outElemTy = outType.getElementType();
224225

225226
if (!checkZerosOnesOpAttributes<AtenOpT>(op, outType)) {
@@ -263,8 +264,9 @@ class ConvertAtenZerosOnesLikeOp : public OpConversionPattern<AtenOpT> {
263264
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
264265
ConversionPatternRewriter &rewriter) const override {
265266
Value input = adaptor.getSelf();
266-
auto outType = dyn_cast<RankedTensorType>(OpConversionPattern<AtenOpT>::getTypeConverter()
267-
->convertType(op.getType()));
267+
auto outType = dyn_cast<RankedTensorType>(
268+
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
269+
op.getType()));
268270
Type outElemTy = outType.getElementType();
269271

270272
// TODO: Check the attribute for input vtensor

0 commit comments

Comments (0)