Skip to content

Commit 0ee4fac

Browse files
authored
Merge pull request #1235 from devitocodes/fix-aliases-scheduling
Fix scheduling of CIRE-detected aliasing expressions
2 parents a5b05e6 + 6be23ec commit 0ee4fac

File tree

11 files changed

+171
-115
lines changed

11 files changed

+171
-115
lines changed

devito/core/cpu.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
from devito.exceptions import InvalidOperator
55
from devito.ir.clusters import Toposort
66
from devito.passes.clusters import (Blocking, Lift, cire, cse, eliminate_arrays,
7-
extract_increments, factorize, fuse, optimize_pows,
8-
scalarize)
7+
extract_increments, factorize, fuse, optimize_pows)
98
from devito.passes.iet import (DataManager, Ompizer, avoid_denormals, mpiize,
109
optimize_halospots, loop_wrapping, hoist_prodders,
1110
relax_incr_dimensions)
@@ -111,7 +110,6 @@ def _specialize_clusters(cls, clusters, **kwargs):
111110
# turn may enable further optimizations
112111
clusters = fuse(clusters)
113112
clusters = eliminate_arrays(clusters, template)
114-
clusters = scalarize(clusters, template)
115113

116114
return clusters
117115

@@ -225,7 +223,6 @@ def _specialize_clusters(cls, clusters, **kwargs):
225223
# turn may enable further optimizations
226224
clusters = fuse(clusters)
227225
clusters = eliminate_arrays(clusters, template)
228-
clusters = scalarize(clusters, template)
229226

230227
# Blocking to improve data locality
231228
clusters = Blocking(options).process(clusters)

devito/core/gpu_openmp.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from devito.logger import warning
1111
from devito.mpi.routines import CopyBuffer, SendRecv, HaloUpdate
1212
from devito.passes.clusters import (Lift, cire, cse, eliminate_arrays, extract_increments,
13-
factorize, fuse, optimize_pows, scalarize)
13+
factorize, fuse, optimize_pows)
1414
from devito.passes.iet import (DataManager, Storage, Ompizer, ParallelIteration,
1515
ParallelTree, optimize_halospots, mpiize, hoist_prodders,
1616
iet_pass)
@@ -278,7 +278,6 @@ def _specialize_clusters(cls, clusters, **kwargs):
278278
# further optimizations
279279
clusters = fuse(clusters)
280280
clusters = eliminate_arrays(clusters, template)
281-
clusters = scalarize(clusters, template)
282281

283282
return clusters
284283

devito/finite_differences/differentiable.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,8 @@ def _has(self, pattern):
254254

255255
class DifferentiableOp(Differentiable):
256256

257+
__sympy_class__ = None
258+
257259
def __new__(cls, *args, **kwargs):
258260
obj = cls.__base__.__new__(cls, *args, **kwargs)
259261

@@ -294,18 +296,22 @@ def _eval_is_zero(self):
294296

295297

296298
class Add(DifferentiableOp, sympy.Add):
299+
__sympy_class__ = sympy.Add
297300
__new__ = DifferentiableOp.__new__
298301

299302

300303
class Mul(DifferentiableOp, sympy.Mul):
304+
__sympy_class__ = sympy.Mul
301305
__new__ = DifferentiableOp.__new__
302306

303307

304308
class Pow(DifferentiableOp, sympy.Pow):
309+
__sympy_class__ = sympy.Pow
305310
__new__ = DifferentiableOp.__new__
306311

307312

308313
class Mod(DifferentiableOp, sympy.Mod):
314+
__sympy_class__ = sympy.Mod
309315
__new__ = DifferentiableOp.__new__
310316

311317

@@ -369,6 +375,31 @@ def _(obj):
369375
return obj.__class__
370376

371377

378+
def diff2sympy(expr):
379+
"""
380+
Translate a Differentiable expression into a SymPy expression.
381+
"""
382+
383+
def _diff2sympy(obj):
384+
flag = False
385+
args = []
386+
for a in obj.args:
387+
ax, af = _diff2sympy(a)
388+
args.append(ax)
389+
flag |= af
390+
try:
391+
return obj.__sympy_class__(*args, evaluate=False), True
392+
except AttributeError:
393+
# Not of type DifferentiableOp
394+
pass
395+
if flag:
396+
return obj.func(*args, evaluate=False), True
397+
else:
398+
return obj, False
399+
400+
return _diff2sympy(expr)[0]
401+
402+
372403
# Make sure `sympy.evalf` knows how to evaluate the inherited classes
373404
# Without these, `evalf` would rely on a much slower, much more generic, and
374405
# thus much more time-inefficient fallback routine. This would hit us

devito/ir/equations/equation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sympy
33

44
from devito.ir.equations.algorithms import dimension_sort
5+
from devito.finite_differences.differentiable import diff2sympy
56
from devito.ir.support import (IterationSpace, DataSpace, Interval, IntervalGroup,
67
Stencil, detect_accesses, detect_oobs, detect_io,
78
build_intervals, build_iterators)
@@ -147,8 +148,11 @@ def __new__(cls, *args, **kwargs):
147148
for k, v in mapper.items() if k}
148149
dspace = DataSpace(dintervals, parts)
149150

151+
# Lower all Differentiable operations into SymPy operations
152+
rhs = diff2sympy(expr.rhs)
153+
150154
# Finally create the LoweredEq with all metadata attached
151-
expr = super(LoweredEq, cls).__new__(cls, expr.lhs, expr.rhs, evaluate=False)
155+
expr = super(LoweredEq, cls).__new__(cls, expr.lhs, rhs, evaluate=False)
152156

153157
expr._dspace = dspace
154158
expr._ispace = ispace

devito/ir/support/space.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,10 @@ def zero(self):
203203
def flip(self):
204204
return Interval(self.dim, self.upper, self.lower, self.stamp)
205205

206-
def lift(self):
207-
return Interval(self.dim, self.lower, self.upper, self.stamp + 1)
206+
def lift(self, v=None):
207+
if v is None:
208+
v = self.stamp + 1
209+
return Interval(self.dim, self.lower, self.upper, v)
208210

209211
def reset(self):
210212
return Interval(self.dim, self.lower, self.upper, 0)
@@ -373,9 +375,9 @@ def zero(self, d=None):
373375
return IntervalGroup([i.zero() if i.dim in d else i for i in self],
374376
relations=self.relations)
375377

376-
def lift(self, d):
378+
def lift(self, d, v=None):
377379
d = set(self.dimensions if d is None else as_tuple(d))
378-
return IntervalGroup([i.lift() if i.dim._defines & d else i for i in self],
380+
return IntervalGroup([i.lift(v) if i.dim._defines & d else i for i in self],
379381
relations=self.relations)
380382

381383
def reset(self):
@@ -700,13 +702,14 @@ def project(self, cond):
700702
func = lambda i: i in cond
701703

702704
intervals = [i for i in self.intervals if func(i.dim)]
703-
704705
sub_iterators = {k: v for k, v in self.sub_iterators.items() if func(k)}
705-
706706
directions = {k: v for k, v in self.directions.items() if func(k)}
707-
708707
return IterationSpace(intervals, sub_iterators, directions)
709708

709+
def lift(self, d=None, v=None):
710+
intervals = self.intervals.lift(d, v)
711+
return IterationSpace(intervals, self.sub_iterators, self.directions)
712+
710713
def is_compatible(self, other):
711714
"""
712715
A relaxed version of ``__eq__``, in which only non-derived dimensions

devito/passes/clusters/aliases.py

Lines changed: 43 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from cached_property import cached_property
44
import numpy as np
55

6-
from devito.ir import (ROUNDABLE, DataSpace, IterationInstance, Interval,
7-
IntervalGroup, LabeledVector, detect_accesses, build_intervals)
6+
from devito.ir import (ROUNDABLE, DataSpace, IterationInstance, Interval, IntervalGroup,
7+
LabeledVector, Scope, detect_accesses, build_intervals)
88
from devito.passes.clusters.utils import cluster_pass, make_is_time_invariant
99
from devito.symbolics import (compare_ops, estimate_cost, q_constant, q_leaf,
1010
q_sum_of_product, q_terminalop, retrieve_indexed,
@@ -88,42 +88,50 @@ def cire(cluster, template, mode, options, platform):
8888
min_storage = options['min-storage']
8989

9090
# Setup callbacks
91-
if mode == 'invariants':
92-
# Extraction rule
93-
def extractor(context):
94-
return make_is_time_invariant(context)
95-
96-
# Extraction model
91+
def callbacks_invariants(context, *args):
92+
extractor = make_is_time_invariant(context)
9793
model = lambda e: estimate_cost(e, True) >= MIN_COST_ALIAS_INV
98-
99-
# Collection rule
10094
ignore_collected = lambda g: False
101-
102-
# Selection rule
10395
selector = lambda c, n: c >= MIN_COST_ALIAS_INV and n >= 1
104-
105-
elif mode == 'sops':
106-
# Extraction rule
107-
def extractor(context):
108-
return q_sum_of_product
109-
110-
# Extraction model
111-
model = lambda e: not (q_leaf(e) or q_terminalop(e))
112-
113-
# Collection rule
96+
return extractor, model, ignore_collected, selector
97+
98+
def callbacks_sops(context, n):
99+
# The `depth` determines "how big" the extracted sum-of-products will be.
100+
# We observe that in typical FD codes:
101+
# add(mul, mul, ...) -> stems from first order derivative
102+
# add(mul(add(mul, mul, ...), ...), ...) -> stems from second order derivative
103+
# To catch the former, we would need `depth=1`; for the latter, `depth=3`
104+
depth = 2*n + 1
105+
106+
extractor = lambda e: q_sum_of_product(e, depth)
107+
model = lambda e: not (q_leaf(e) or q_terminalop(e, depth-1))
114108
ignore_collected = lambda g: len(g) <= 1
115-
116-
# Selection rule
117109
selector = lambda c, n: c >= MIN_COST_ALIAS and n > 1
110+
return extractor, model, ignore_collected, selector
111+
112+
callbacks_mapper = {
113+
'invariants': callbacks_invariants,
114+
'sops': callbacks_sops
115+
}
118116

119-
# Actual CIRE
117+
# The main CIRE loop
120118
processed = []
121119
context = cluster.exprs
122-
for _ in range(options['cire-repeats'][mode]):
120+
for n in reversed(range(options['cire-repeats'][mode])):
121+
# Get the callbacks
122+
extractor, model, ignore_collected, selector = callbacks_mapper[mode](context, n)
123+
123124
# Extract potentially aliasing expressions
124-
exprs, extracted = extract(cluster, extractor(context), model, template)
125+
exprs, extracted = extract(cluster, extractor, model, template)
125126
if not extracted:
126127
# Do not waste time
128+
continue
129+
130+
# There can't be Dimension-dependent data dependences with any of
131+
# the `processed` Clusters, otherwise we would risk either OOB accesses
132+
# or reading from garbage uncomputed halo
133+
scope = Scope(exprs=flatten(c.exprs for c in processed) + extracted)
134+
if not all(i.is_indep() for i in scope.d_all_gen()):
127135
break
128136

129137
# Search aliasing expressions
@@ -133,7 +141,7 @@ def extractor(context):
133141
chosen, others = choose(exprs, aliases, selector)
134142
if not chosen:
135143
# Do not waste time
136-
break
144+
continue
137145

138146
# Create Aliases and assign them to Clusters
139147
clusters, subs = process(cluster, chosen, aliases, template, platform)
@@ -341,7 +349,7 @@ def choose(exprs, aliases, selector):
341349
def process(cluster, chosen, aliases, template, platform):
342350
clusters = []
343351
subs = {}
344-
for alias, writeto, aliaseds, distances in aliases.schedule(cluster.ispace):
352+
for alias, writeto, aliaseds, distances in aliases.iter(cluster.ispace):
345353
if all(i not in chosen for i in aliaseds):
346354
continue
347355

@@ -412,7 +420,7 @@ def process(cluster, chosen, aliases, template, platform):
412420

413421
# Finally, build a new Cluster for `alias`
414422
built = cluster.rebuild(exprs=expression, ispace=ispace, dspace=dspace)
415-
clusters.insert(0, built)
423+
clusters.append(built)
416424

417425
return clusters, subs
418426

@@ -433,6 +441,9 @@ def rebuild(cluster, others, aliases, subs):
433441
return cluster.rebuild(exprs=exprs, ispace=ispace, dspace=dspace)
434442

435443

444+
# Utilities
445+
446+
436447
class Candidate(object):
437448

438449
def __init__(self, expr, indexeds, bases, offsets):
@@ -488,9 +499,6 @@ def dimensions(self):
488499
return frozenset(i for i, _ in self.Toffsets)
489500

490501

491-
# Utilities
492-
493-
494502
class Group(tuple):
495503

496504
"""
@@ -688,7 +696,7 @@ def get(self, key):
688696
return aliaseds
689697
return []
690698

691-
def schedule(self, ispace):
699+
def iter(self, ispace):
692700
"""
693701
The aliases can be be scheduled in any order, but we privilege the one
694702
that minimizes storage while maximizing fusion.
@@ -710,8 +718,7 @@ def schedule(self, ispace):
710718
# use `<1>` which is the actual stamp used in the Cluster
711719
# from which the aliasing expressions were extracted
712720
assert i.stamp >= interval.stamp
713-
while interval.stamp != i.stamp:
714-
interval = interval.lift()
721+
interval = interval.lift(i.stamp)
715722

716723
writeto.append(interval)
717724

devito/passes/clusters/factorization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ def collect_const(expr):
6464

6565
terms = []
6666
for k, v in inverse_mapper.items():
67-
if len(v) == 1:
68-
# We can actually evaluate everything to avoid, e.g., (-1)*a
67+
if len(v) == 1 and not v[0].is_Add:
68+
# Special case: avoid e.g. (-2)*a
6969
mul = Mul(k, *v)
7070
elif all(i.is_Mul and len(i.args) == 2 and i.args[0] == -1 for i in v):
7171
# Other special case: [-a, -b, -c ...]

devito/passes/clusters/misc.py

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
from devito.ir.clusters import Cluster, Queue
44
from devito.ir.support import TILABLE
55
from devito.passes.clusters.utils import cluster_pass
6-
from devito.symbolics import pow_to_mul, xreplace_indices, uxreplace
6+
from devito.symbolics import pow_to_mul, uxreplace
77
from devito.tools import filter_ordered, timed_pass
88
from devito.types import Scalar
99

10-
__all__ = ['Lift', 'fuse', 'scalarize', 'eliminate_arrays', 'optimize_pows',
11-
'extract_increments']
10+
__all__ = ['Lift', 'fuse', 'eliminate_arrays', 'optimize_pows', 'extract_increments']
1211

1312

1413
class Lift(Queue):
@@ -99,48 +98,6 @@ def fuse(clusters):
9998
return processed
10099

101100

102-
@timed_pass()
103-
def scalarize(clusters, template):
104-
"""
105-
Turn local "isolated" Arrays, that is Arrays appearing only in one Cluster,
106-
into Scalars.
107-
"""
108-
processed = []
109-
for c in clusters:
110-
# Get any Arrays appearing only in `c`
111-
impacted = set(clusters) - {c}
112-
arrays = {i for i in c.scope.writes if i.is_Array}
113-
arrays -= set().union(*[i.scope.reads for i in impacted])
114-
115-
# Turn them into scalars
116-
#
117-
# r[x,y,z] = g(b[x,y,z]) t0 = g(b[x,y,z])
118-
# ... = r[x,y,z] + r[x,y,z+1]` ----> t1 = g(b[x,y,z+1])
119-
# ... = t0 + t1
120-
mapper = {}
121-
exprs = []
122-
for n, e in enumerate(c.exprs):
123-
f = e.lhs.function
124-
if f in arrays:
125-
indexeds = [i.indexed for i in c.scope[f] if i.timestamp > n]
126-
for i in filter_ordered(indexeds):
127-
mapper[i] = Scalar(name=template(), dtype=f.dtype)
128-
129-
assert len(f.indices) == len(e.lhs.indices) == len(i.indices)
130-
shifting = {idx: idx + (o2 - o1) for idx, o1, o2 in
131-
zip(f.indices, e.lhs.indices, i.indices)}
132-
133-
handle = e.func(mapper[i], uxreplace(e.rhs, mapper))
134-
handle = xreplace_indices(handle, shifting)
135-
exprs.append(handle)
136-
else:
137-
exprs.append(e.func(e.lhs, uxreplace(e.rhs, mapper)))
138-
139-
processed.append(c.rebuild(exprs))
140-
141-
return processed
142-
143-
144101
@timed_pass()
145102
def eliminate_arrays(clusters, template):
146103
"""

0 commit comments

Comments
 (0)