33from cached_property import cached_property
44import numpy as np
55
6- from devito .ir import (ROUNDABLE , DataSpace , IterationInstance , Interval ,
7- IntervalGroup , LabeledVector , detect_accesses , build_intervals )
6+ from devito .ir import (ROUNDABLE , DataSpace , IterationInstance , Interval , IntervalGroup ,
7+ LabeledVector , Scope , detect_accesses , build_intervals )
88from devito .passes .clusters .utils import cluster_pass , make_is_time_invariant
99from devito .symbolics import (compare_ops , estimate_cost , q_constant , q_leaf ,
1010 q_sum_of_product , q_terminalop , retrieve_indexed ,
@@ -88,42 +88,50 @@ def cire(cluster, template, mode, options, platform):
8888 min_storage = options ['min-storage' ]
8989
9090 # Setup callbacks
91- if mode == 'invariants' :
92- # Extraction rule
93- def extractor (context ):
94- return make_is_time_invariant (context )
95-
96- # Extraction model
91+ def callbacks_invariants (context , * args ):
92+ extractor = make_is_time_invariant (context )
9793 model = lambda e : estimate_cost (e , True ) >= MIN_COST_ALIAS_INV
98-
99- # Collection rule
10094 ignore_collected = lambda g : False
101-
102- # Selection rule
10395 selector = lambda c , n : c >= MIN_COST_ALIAS_INV and n >= 1
104-
105- elif mode == 'sops' :
106- # Extraction rule
107- def extractor (context ):
108- return q_sum_of_product
109-
110- # Extraction model
111- model = lambda e : not (q_leaf (e ) or q_terminalop (e ))
112-
113- # Collection rule
96+ return extractor , model , ignore_collected , selector
97+
98+ def callbacks_sops (context , n ):
99+ # The `depth` determines "how big" the extracted sum-of-products will be.
100+ # We observe that in typical FD codes:
101+ # add(mul, mul, ...) -> stems from first order derivative
102+ # add(mul(add(mul, mul, ...), ...), ...) -> stems from second order derivative
103+ # To catch the former, we would need `depth=1`; for the latter, `depth=3`
104+ depth = 2 * n + 1
105+
106+ extractor = lambda e : q_sum_of_product (e , depth )
107+ model = lambda e : not (q_leaf (e ) or q_terminalop (e , depth - 1 ))
114108 ignore_collected = lambda g : len (g ) <= 1
115-
116- # Selection rule
117109 selector = lambda c , n : c >= MIN_COST_ALIAS and n > 1
110+ return extractor , model , ignore_collected , selector
111+
112+ callbacks_mapper = {
113+ 'invariants' : callbacks_invariants ,
114+ 'sops' : callbacks_sops
115+ }
118116
119- # Actual CIRE
117+ # The main CIRE loop
120118 processed = []
121119 context = cluster .exprs
122- for _ in range (options ['cire-repeats' ][mode ]):
120+ for n in reversed (range (options ['cire-repeats' ][mode ])):
121+ # Get the callbacks
122+ extractor , model , ignore_collected , selector = callbacks_mapper [mode ](context , n )
123+
123124 # Extract potentially aliasing expressions
124- exprs , extracted = extract (cluster , extractor ( context ) , model , template )
125+ exprs , extracted = extract (cluster , extractor , model , template )
125126 if not extracted :
126127 # Do not waste time
128+ continue
129+
130+ # There can't be Dimension-dependent data dependences with any of
131+ # the `processed` Clusters, otherwise we would risk either OOB accesses
132+ # or reading from garbage uncomputed halo
133+ scope = Scope (exprs = flatten (c .exprs for c in processed ) + extracted )
134+ if not all (i .is_indep () for i in scope .d_all_gen ()):
127135 break
128136
129137 # Search aliasing expressions
@@ -133,7 +141,7 @@ def extractor(context):
133141 chosen , others = choose (exprs , aliases , selector )
134142 if not chosen :
135143 # Do not waste time
136- break
144+ continue
137145
138146 # Create Aliases and assign them to Clusters
139147 clusters , subs = process (cluster , chosen , aliases , template , platform )
@@ -341,7 +349,7 @@ def choose(exprs, aliases, selector):
341349def process (cluster , chosen , aliases , template , platform ):
342350 clusters = []
343351 subs = {}
344- for alias , writeto , aliaseds , distances in aliases .schedule (cluster .ispace ):
352+ for alias , writeto , aliaseds , distances in aliases .iter (cluster .ispace ):
345353 if all (i not in chosen for i in aliaseds ):
346354 continue
347355
@@ -412,7 +420,7 @@ def process(cluster, chosen, aliases, template, platform):
412420
413421 # Finally, build a new Cluster for `alias`
414422 built = cluster .rebuild (exprs = expression , ispace = ispace , dspace = dspace )
415- clusters .insert ( 0 , built )
423+ clusters .append ( built )
416424
417425 return clusters , subs
418426
@@ -433,6 +441,9 @@ def rebuild(cluster, others, aliases, subs):
433441 return cluster .rebuild (exprs = exprs , ispace = ispace , dspace = dspace )
434442
435443
444+ # Utilities
445+
446+
436447class Candidate (object ):
437448
438449 def __init__ (self , expr , indexeds , bases , offsets ):
@@ -488,9 +499,6 @@ def dimensions(self):
488499 return frozenset (i for i , _ in self .Toffsets )
489500
490501
491- # Utilities
492-
493-
494502class Group (tuple ):
495503
496504 """
@@ -688,7 +696,7 @@ def get(self, key):
688696 return aliaseds
689697 return []
690698
691- def schedule (self , ispace ):
699+ def iter (self , ispace ):
692700 """
693701 The aliases can be be scheduled in any order, but we privilege the one
694702 that minimizes storage while maximizing fusion.
@@ -710,8 +718,7 @@ def schedule(self, ispace):
710718 # use `<1>` which is the actual stamp used in the Cluster
711719 # from which the aliasing expressions were extracted
712720 assert i .stamp >= interval .stamp
713- while interval .stamp != i .stamp :
714- interval = interval .lift ()
721+ interval = interval .lift (i .stamp )
715722
716723 writeto .append (interval )
717724
0 commit comments