Skip to content

Commit b6a718d

Browse files
committed
fix examples
1 parent 3aa6fcc commit b6a718d

File tree

7 files changed

+67
-67
lines changed

7 files changed

+67
-67
lines changed

projects/eudsl-python-extras/examples/cuda_e2e.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,10 @@
103103
"from mlir import _mlir_libs\n",
104104
"from mlir.extras.ast.canonicalize import canonicalize\n",
105105
"from mlir.extras.context import RAIIMLIRContext, ExplicitlyManagedModule\n",
106-
"from mlir.extras.dialects.ext import arith, memref, scf, gpu\n",
107-
"from mlir.extras.dialects.ext import linalg\n",
108-
"from mlir.extras.dialects.ext import transform\n",
109-
"from mlir.extras.dialects.ext.func import func\n",
106+
"from mlir.extras.dialects import arith, memref, scf, gpu\n",
107+
"from mlir.extras.dialects import linalg\n",
108+
"from mlir.extras.dialects import transform\n",
109+
"from mlir.extras.dialects.func import func\n",
110110
"from mlir.extras.runtime.passes import Pipeline, run_pipeline\n",
111111
"from mlir.extras.runtime.refbackend import LLVMJITBackend\n",
112112
"from mlir.extras.util import find_ops\n",

projects/eudsl-python-extras/examples/cuda_matmul_opt.py

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,17 @@
1111
mlir_mod_ctx,
1212
MLIRContext,
1313
)
14-
from mlir.extras.dialects.ext import arith, memref, gpu, scf, linalg, vector, nvgpu
15-
from mlir.extras.dialects.ext.gpu import (
14+
from mlir.extras.dialects import arith, memref, gpu, scf, linalg, vector, nvgpu
15+
from mlir.extras.dialects.gpu import (
1616
block_idx,
1717
thread_idx,
1818
block_dim,
1919
get_compile_object_bytes,
2020
smem_space,
2121
)
22-
from mlir.extras.dialects.ext.llvm import llvm_ptr_t
23-
from mlir.extras.dialects.ext.memref import S
24-
from mlir.extras.dialects.ext.scf import range_
22+
from mlir.extras.dialects.llvm import llvm_ptr_t
23+
from mlir.extras.dialects.memref import S
24+
from mlir.extras.dialects.scf import range_
2525
from mlir.extras.runtime.passes import Pipeline, run_pipeline
2626

2727
# noinspection PyUnresolvedReferences
@@ -139,9 +139,9 @@ def sgemm_naive[
139139
K,
140140
N,
141141
dtype,
142-
A_t: T.memref(M, K, dtype),
143-
B_t: T.memref(K, N, dtype),
144-
C_t: T.memref(M, N, dtype),
142+
A_t = T.memref(M, K, dtype),
143+
B_t = T.memref(K, N, dtype),
144+
C_t = T.memref(M, N, dtype),
145145
](A: A_t, B: B_t, C: C_t):
146146
one = arith.constant(1.0, type=dtype)
147147
tmp = arith.constant(0, type=dtype)
@@ -167,9 +167,9 @@ def sgemm_naive_row_order[
167167
K,
168168
N,
169169
dtype,
170-
A_t: T.memref(M, K, dtype),
171-
B_t: T.memref(K, N, dtype),
172-
C_t: T.memref(M, N, dtype),
170+
A_t = T.memref(M, K, dtype),
171+
B_t = T.memref(K, N, dtype),
172+
C_t = T.memref(M, N, dtype),
173173
](A: A_t, B: B_t, C: C_t):
174174
one = arith.constant(1.0, type=dtype)
175175
tmp = arith.constant(0, type=dtype)
@@ -193,10 +193,10 @@ def sgemm_coalesce[
193193
K,
194194
N,
195195
dtype,
196-
BLOCK_SIZE: 32,
197-
A_t: T.memref(M, K, dtype),
198-
B_t: T.memref(K, N, dtype),
199-
C_t: T.memref(M, N, dtype),
196+
BLOCK_SIZE = 32,
197+
A_t = T.memref(M, K, dtype),
198+
B_t = T.memref(K, N, dtype),
199+
C_t = T.memref(M, N, dtype),
200200
](A: A_t, B: B_t, C: C_t):
201201

202202
tid = gpu.thread_id()
@@ -259,10 +259,10 @@ def sgemm_coalesce_transpose_B[
259259
K,
260260
N,
261261
dtype,
262-
BLOCK_SIZE: 32,
263-
A_t: T.memref(M, K, dtype),
264-
B_t: T.memref(K, N, dtype),
265-
C_t: T.memref(M, N, dtype),
262+
BLOCK_SIZE = 32,
263+
A_t = T.memref(M, K, dtype),
264+
B_t = T.memref(K, N, dtype),
265+
C_t = T.memref(M, N, dtype),
266266
](A: A_t, B: B_t, C: C_t):
267267

268268
tid = gpu.thread_id()
@@ -288,10 +288,10 @@ def sgemm_shared_mem_block[
288288
K,
289289
N,
290290
dtype,
291-
BLOCK_SIZE: 32,
292-
A_t: T.memref(M, K, dtype),
293-
B_t: T.memref(K, N, dtype),
294-
C_t: T.memref(M, N, dtype),
291+
BLOCK_SIZE = 32,
292+
A_t = T.memref(M, K, dtype),
293+
B_t = T.memref(K, N, dtype),
294+
C_t = T.memref(M, N, dtype),
295295
](A: A_t, B: B_t, C: C_t):
296296
# allocate buffer for current block in fast shared mem
297297
# shared mem is shared between all threads in a block
@@ -394,9 +394,9 @@ def sgemm_shared_mem_1d_block_tiling[
394394
BN,
395395
BK,
396396
TM,
397-
A_t: T.memref(M, K, dtype),
398-
B_t: T.memref(K, N, dtype),
399-
C_t: T.memref(M, N, dtype),
397+
A_t = T.memref(M, K, dtype),
398+
B_t = T.memref(K, N, dtype),
399+
C_t = T.memref(M, N, dtype),
400400
](A: A_t, B: B_t, C: C_t):
401401
base = gpu.dynamic_shared_memory()
402402
A_shared = memref.view(base, (BM, BK), dtype=dtype)
@@ -455,9 +455,9 @@ def sgemm_shared_mem_2d_block_tiling[
455455
BK,
456456
TM,
457457
TN,
458-
A_t: T.memref(M, K, dtype),
459-
B_t: T.memref(K, N, dtype),
460-
C_t: T.memref(M, N, dtype),
458+
A_t = T.memref(M, K, dtype),
459+
B_t = T.memref(K, N, dtype),
460+
C_t = T.memref(M, N, dtype),
461461
](A: A_t, B: B_t, C: C_t):
462462
base = gpu.dynamic_shared_memory()
463463
A_shared = memref.view(base, (BM, BK), dtype=dtype)
@@ -542,9 +542,9 @@ def sgemm_shared_mem_2d_block_tiling_vectorize[
542542
BK,
543543
TM,
544544
TN,
545-
A_t: T.memref(M, K, dtype),
546-
B_t: T.memref(K, N, dtype),
547-
C_t: T.memref(M, N, dtype),
545+
A_t = T.memref(M, K, dtype),
546+
B_t = T.memref(K, N, dtype),
547+
C_t = T.memref(M, N, dtype),
548548
](A: A_t, B: B_t, C: C_t):
549549
VECTOR_WIDTH = 4
550550
DTYPE_WIDTH = dtype.width // 8
@@ -656,9 +656,9 @@ def sgemm_warp_tiling[
656656
TM,
657657
TN,
658658
NUM_THREADS,
659-
A_t: T.memref(M, K, dtype),
660-
B_t: T.memref(K, N, dtype),
661-
C_t: T.memref(M, N, dtype),
659+
A_t = T.memref(M, K, dtype),
660+
B_t = T.memref(K, N, dtype),
661+
C_t = T.memref(M, N, dtype),
662662
](A: A_t, B: B_t, C: C_t):
663663
VECTOR_WIDTH = 4
664664
DTYPE_WIDTH = dtype.width // 8
@@ -820,11 +820,11 @@ def sgemm_tensor_core[
820820
M,
821821
K,
822822
N,
823-
A_t: T.memref(M, K, T.f16()),
824-
B_t: T.memref(K, N, T.f16()),
825-
C_t: T.memref(M, N, T.f32()),
826-
a_tma_t: llvm_ptr_t(),
827-
b_tma_t: llvm_ptr_t(),
823+
A_t = T.memref(M, K, T.f16()),
824+
B_t = T.memref(K, N, T.f16()),
825+
C_t = T.memref(M, N, T.f32()),
826+
a_tma_t = llvm_ptr_t(),
827+
b_tma_t = llvm_ptr_t(),
828828
](A: A_t, B: B_t, C: C_t, a_tma: a_tma_t, b_tma: b_tma_t):
829829
a_tma = builtin.unrealized_conversion_cast(
830830
[
@@ -987,7 +987,7 @@ def prepare_warp_tiled_kernel(ctx: MLIRContext, kernel, M, K, N):
987987
def matmul_mod():
988988
kernel[M, K, N, dtype, BM, BN, BK, WM, WN, WNITER, TM, TN, NUM_THREADS].emit()
989989

990-
# print(ctx.module)
990+
print(ctx.module)
991991
assert ctx.module.operation.verify()
992992

993993
if cuda_bindings_not_installed():

projects/eudsl-python-extras/examples/flash_attention.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66

77
from mlir.extras.ast.canonicalize import canonicalize
88
from mlir.extras.context import RAIIMLIRContextModule
9-
from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm
9+
from mlir.extras.dialects import memref, scf, arith, gpu, llvm
1010
from mlir.dialects import math
1111

1212
# noinspection PyUnresolvedReferences
13-
from mlir.extras.dialects.ext.gpu import (
13+
from mlir.extras.dialects.gpu import (
1414
block_idx,
1515
thread_idx,
1616
grid_dim,
@@ -222,7 +222,7 @@ def flash_attention(
222222
ip.__exit__(None, None, None)
223223

224224
assert gpu_module.operation.verify()
225-
# print(gpu_module)
225+
print(gpu_module)
226226

227227
sram_size = 4 * Bc * d * np.float32().itemsize
228228

projects/eudsl-python-extras/examples/mlir_python_extras.ipynb

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,17 @@
6464
"import mlir.extras.types as T\n",
6565
"from mlir.extras.ast.canonicalize import canonicalize\n",
6666
"from mlir.extras.context import mlir_mod_ctx\n",
67-
"from mlir.extras.dialects.ext.arith import constant\n",
68-
"from mlir.extras.dialects.ext.memref import S\n",
69-
"from mlir.extras.dialects.ext.func import func\n",
70-
"from mlir.extras.dialects.ext.scf import canonicalizer as scf, range_\n",
67+
"from mlir.extras.dialects.arith import constant\n",
68+
"from mlir.extras.dialects.memref import S\n",
69+
"from mlir.extras.dialects.func import func\n",
70+
"from mlir.extras.dialects.scf import canonicalizer as scf, range_\n",
7171
"from mlir.extras.runtime.passes import Pipeline, run_pipeline\n",
7272
"from mlir.extras.runtime.refbackend import LLVMJITBackend\n",
7373
"from mlir.ir import StridedLayoutAttr\n",
7474
"\n",
7575
"# you need this to register the memref value caster\n",
7676
"# noinspection PyUnresolvedReferences\n",
77-
"import mlir.extras.dialects.ext.memref\n",
77+
"import mlir.extras.dialects.memref\n",
7878
"\n",
7979
"ctx_man = mlir_mod_ctx()\n",
8080
"ctx = ctx_man.__enter__()\n",
@@ -417,7 +417,7 @@
417417
"layout = StridedLayoutAttr.get(S, (K, 1))\n",
418418
"ranked_memref_dxd_f32 = T.memref(D, D, T.f32(), layout=layout)\n",
419419
"\n",
420-
"from mlir.extras.dialects.ext import linalg\n",
420+
"from mlir.extras.dialects import linalg\n",
421421
"\n",
422422
"@func(emit=True)\n",
423423
"@canonicalize(using=scf)\n",

projects/eudsl-python-extras/examples/mwe.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
# you need this to register the memref value caster
1313
# noinspection PyUnresolvedReferences
14-
import mlir.extras.dialects.ext.memref
14+
import mlir.extras.dialects.memref
1515
from mlir.extras.context import RAIIMLIRContext, ExplicitlyManagedModule
1616
from mlir.dialects.bufferization import LayoutMapOption
1717
from mlir.dialects.transform.vector import (
@@ -20,15 +20,15 @@
2020
VectorTransferSplit,
2121
VectorTransposeLowering,
2222
)
23-
from mlir.extras.dialects.ext import linalg
24-
from mlir.extras.dialects.ext.func import func
25-
from mlir.extras.dialects.ext.transform import (
23+
from mlir.extras.dialects import linalg
24+
from mlir.extras.dialects.func import func
25+
from mlir.extras.dialects.transform import (
2626
match,
2727
tile_to_scf_for,
2828
get_parent_op,
2929
transform_any_op_t,
3030
)
31-
from mlir.extras.dialects.ext import transform
31+
from mlir.extras.dialects import transform
3232
from mlir.extras.runtime.passes import Pipeline, run_pipeline
3333
from mlir.extras.runtime.refbackend import LLVMJITBackend
3434

projects/eudsl-python-extras/examples/rdna_matmul_opt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
from mlir.extras.ast.canonicalize import canonicalize
44
from mlir.extras.context import RAIIMLIRContextModule
5-
from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm
5+
from mlir.extras.dialects import memref, scf, arith, gpu, llvm
66
from mlir.dialects import index as index_dialect
77
from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr, Attribute
88
import mlir.extras.types as T
99

1010
# noinspection PyUnresolvedReferences
11-
from mlir.extras.dialects.ext.gpu import (
11+
from mlir.extras.dialects.gpu import (
1212
all_reduce,
1313
wait,
1414
thread_attr as thread,
@@ -721,7 +721,7 @@ def kernel5_lds_optim(
721721
)
722722

723723
assert simplified_module.operation.verify()
724-
# print(simplified_module)
724+
print(simplified_module)
725725

726726
lowered_module = run_pipeline(
727727
simplified_module,

projects/eudsl-python-extras/examples/vectorization_e2e.ipynb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
"\n",
7272
"# you need this to register the memref value caster\n",
7373
"# noinspection PyUnresolvedReferences\n",
74-
"import mlir.extras.dialects.ext.memref\n",
74+
"import mlir.extras.dialects.memref\n",
7575
"from mlir.extras.context import RAIIMLIRContext, ExplicitlyManagedModule\n",
7676
"from mlir.dialects.bufferization import LayoutMapOption\n",
7777
"from mlir.dialects.transform.vector import (\n",
@@ -80,15 +80,15 @@
8080
" VectorTransferSplit,\n",
8181
" VectorTransposeLowering,\n",
8282
")\n",
83-
"from mlir.extras.dialects.ext import linalg\n",
84-
"from mlir.extras.dialects.ext.func import func\n",
85-
"from mlir.extras.dialects.ext.transform import (\n",
83+
"from mlir.extras.dialects import linalg\n",
84+
"from mlir.extras.dialects.func import func\n",
85+
"from mlir.extras.dialects.transform import (\n",
8686
" match,\n",
8787
" tile_to_scf_for,\n",
8888
" get_parent_op,\n",
8989
" transform_any_op_t,\n",
9090
")\n",
91-
"from mlir.extras.dialects.ext import transform\n",
91+
"from mlir.extras.dialects import transform\n",
9292
"from mlir.extras.runtime.passes import Pipeline, run_pipeline\n",
9393
"from mlir.extras.runtime.refbackend import LLVMJITBackend\n"
9494
],

0 commit comments

Comments
 (0)