1111 mlir_mod_ctx ,
1212 MLIRContext ,
1313)
14- from mlir .extras .dialects . ext import arith , memref , gpu , scf , linalg , vector , nvgpu
15- from mlir .extras .dialects .ext . gpu import (
14+ from mlir .extras .dialects import arith , memref , gpu , scf , linalg , vector , nvgpu
15+ from mlir .extras .dialects .gpu import (
1616 block_idx ,
1717 thread_idx ,
1818 block_dim ,
1919 get_compile_object_bytes ,
2020 smem_space ,
2121)
22- from mlir .extras .dialects .ext . llvm import llvm_ptr_t
23- from mlir .extras .dialects .ext . memref import S
24- from mlir .extras .dialects .ext . scf import range_
22+ from mlir .extras .dialects .llvm import llvm_ptr_t
23+ from mlir .extras .dialects .memref import S
24+ from mlir .extras .dialects .scf import range_
2525from mlir .extras .runtime .passes import Pipeline , run_pipeline
2626
2727# noinspection PyUnresolvedReferences
@@ -139,9 +139,9 @@ def sgemm_naive[
139139 K ,
140140 N ,
141141 dtype ,
142- A_t : T .memref (M , K , dtype ),
143- B_t : T .memref (K , N , dtype ),
144- C_t : T .memref (M , N , dtype ),
142+ A_t = T .memref (M , K , dtype ),
143+ B_t = T .memref (K , N , dtype ),
144+ C_t = T .memref (M , N , dtype ),
145145](A : A_t , B : B_t , C : C_t ):
146146 one = arith .constant (1.0 , type = dtype )
147147 tmp = arith .constant (0 , type = dtype )
@@ -167,9 +167,9 @@ def sgemm_naive_row_order[
167167 K ,
168168 N ,
169169 dtype ,
170- A_t : T .memref (M , K , dtype ),
171- B_t : T .memref (K , N , dtype ),
172- C_t : T .memref (M , N , dtype ),
170+ A_t = T .memref (M , K , dtype ),
171+ B_t = T .memref (K , N , dtype ),
172+ C_t = T .memref (M , N , dtype ),
173173](A : A_t , B : B_t , C : C_t ):
174174 one = arith .constant (1.0 , type = dtype )
175175 tmp = arith .constant (0 , type = dtype )
@@ -193,10 +193,10 @@ def sgemm_coalesce[
193193 K ,
194194 N ,
195195 dtype ,
196- BLOCK_SIZE : 32 ,
197- A_t : T .memref (M , K , dtype ),
198- B_t : T .memref (K , N , dtype ),
199- C_t : T .memref (M , N , dtype ),
196+ BLOCK_SIZE = 32 ,
197+ A_t = T .memref (M , K , dtype ),
198+ B_t = T .memref (K , N , dtype ),
199+ C_t = T .memref (M , N , dtype ),
200200](A : A_t , B : B_t , C : C_t ):
201201
202202 tid = gpu .thread_id ()
@@ -259,10 +259,10 @@ def sgemm_coalesce_transpose_B[
259259 K ,
260260 N ,
261261 dtype ,
262- BLOCK_SIZE : 32 ,
263- A_t : T .memref (M , K , dtype ),
264- B_t : T .memref (K , N , dtype ),
265- C_t : T .memref (M , N , dtype ),
262+ BLOCK_SIZE = 32 ,
263+ A_t = T .memref (M , K , dtype ),
264+ B_t = T .memref (K , N , dtype ),
265+ C_t = T .memref (M , N , dtype ),
266266](A : A_t , B : B_t , C : C_t ):
267267
268268 tid = gpu .thread_id ()
@@ -288,10 +288,10 @@ def sgemm_shared_mem_block[
288288 K ,
289289 N ,
290290 dtype ,
291- BLOCK_SIZE : 32 ,
292- A_t : T .memref (M , K , dtype ),
293- B_t : T .memref (K , N , dtype ),
294- C_t : T .memref (M , N , dtype ),
291+ BLOCK_SIZE = 32 ,
292+ A_t = T .memref (M , K , dtype ),
293+ B_t = T .memref (K , N , dtype ),
294+ C_t = T .memref (M , N , dtype ),
295295](A : A_t , B : B_t , C : C_t ):
296296 # allocate buffer for current block in fast shared mem
297297 # shared mem is shared between all threads in a block
@@ -394,9 +394,9 @@ def sgemm_shared_mem_1d_block_tiling[
394394 BN ,
395395 BK ,
396396 TM ,
397- A_t : T .memref (M , K , dtype ),
398- B_t : T .memref (K , N , dtype ),
399- C_t : T .memref (M , N , dtype ),
397+ A_t = T .memref (M , K , dtype ),
398+ B_t = T .memref (K , N , dtype ),
399+ C_t = T .memref (M , N , dtype ),
400400](A : A_t , B : B_t , C : C_t ):
401401 base = gpu .dynamic_shared_memory ()
402402 A_shared = memref .view (base , (BM , BK ), dtype = dtype )
@@ -455,9 +455,9 @@ def sgemm_shared_mem_2d_block_tiling[
455455 BK ,
456456 TM ,
457457 TN ,
458- A_t : T .memref (M , K , dtype ),
459- B_t : T .memref (K , N , dtype ),
460- C_t : T .memref (M , N , dtype ),
458+ A_t = T .memref (M , K , dtype ),
459+ B_t = T .memref (K , N , dtype ),
460+ C_t = T .memref (M , N , dtype ),
461461](A : A_t , B : B_t , C : C_t ):
462462 base = gpu .dynamic_shared_memory ()
463463 A_shared = memref .view (base , (BM , BK ), dtype = dtype )
@@ -542,9 +542,9 @@ def sgemm_shared_mem_2d_block_tiling_vectorize[
542542 BK ,
543543 TM ,
544544 TN ,
545- A_t : T .memref (M , K , dtype ),
546- B_t : T .memref (K , N , dtype ),
547- C_t : T .memref (M , N , dtype ),
545+ A_t = T .memref (M , K , dtype ),
546+ B_t = T .memref (K , N , dtype ),
547+ C_t = T .memref (M , N , dtype ),
548548](A : A_t , B : B_t , C : C_t ):
549549 VECTOR_WIDTH = 4
550550 DTYPE_WIDTH = dtype .width // 8
@@ -656,9 +656,9 @@ def sgemm_warp_tiling[
656656 TM ,
657657 TN ,
658658 NUM_THREADS ,
659- A_t : T .memref (M , K , dtype ),
660- B_t : T .memref (K , N , dtype ),
661- C_t : T .memref (M , N , dtype ),
659+ A_t = T .memref (M , K , dtype ),
660+ B_t = T .memref (K , N , dtype ),
661+ C_t = T .memref (M , N , dtype ),
662662](A : A_t , B : B_t , C : C_t ):
663663 VECTOR_WIDTH = 4
664664 DTYPE_WIDTH = dtype .width // 8
@@ -820,11 +820,11 @@ def sgemm_tensor_core[
820820 M ,
821821 K ,
822822 N ,
823- A_t : T .memref (M , K , T .f16 ()),
824- B_t : T .memref (K , N , T .f16 ()),
825- C_t : T .memref (M , N , T .f32 ()),
826- a_tma_t : llvm_ptr_t (),
827- b_tma_t : llvm_ptr_t (),
823+ A_t = T .memref (M , K , T .f16 ()),
824+ B_t = T .memref (K , N , T .f16 ()),
825+ C_t = T .memref (M , N , T .f32 ()),
826+ a_tma_t = llvm_ptr_t (),
827+ b_tma_t = llvm_ptr_t (),
828828](A : A_t , B : B_t , C : C_t , a_tma : a_tma_t , b_tma : b_tma_t ):
829829 a_tma = builtin .unrealized_conversion_cast (
830830 [
@@ -987,7 +987,7 @@ def prepare_warp_tiled_kernel(ctx: MLIRContext, kernel, M, K, N):
987987 def matmul_mod ():
988988 kernel [M , K , N , dtype , BM , BN , BK , WM , WN , WNITER , TM , TN , NUM_THREADS ].emit ()
989989
990- # print(ctx.module)
990+ print (ctx .module )
991991 assert ctx .module .operation .verify ()
992992
993993 if cuda_bindings_not_installed ():
0 commit comments