Skip to content

Commit ed3585a

Browse files
authored
Merge pull request #2624 from devitocodes/compiler-speed
compiler: Address various compiler hotspots with operators containing large expression counts
2 parents b3408d8 + 613f1fe commit ed3585a

File tree

5 files changed

+62
-14
lines changed

5 files changed

+62
-14
lines changed

devito/ir/support/basic.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
q_constant, q_comp_acc, q_affine, q_routine, search,
1212
uxreplace)
1313
from devito.tools import (Tag, as_mapper, as_tuple, is_integer, filter_sorted,
14-
flatten, memoized_meth, memoized_generator)
14+
flatten, memoized_meth, memoized_generator, smart_gt,
15+
smart_lt)
1516
from devito.types import (ComponentAccess, Dimension, DimensionTuple, Fence,
1617
CriticalRegion, Function, Symbol, Temp, TempArray,
1718
TBArray)
@@ -364,11 +365,12 @@ def distance(self, other):
364365
# trip count. E.g. it ranges from 0 to 3; `other` performs a
365366
# constant access at 4
366367
for v in (self[n], other[n]):
367-
try:
368-
if bool(v < sit.symbolic_min or v > sit.symbolic_max):
369-
return Vector(S.ImaginaryUnit)
370-
except TypeError:
371-
pass
368+
# Note: Uses smart_ comparisons to avoid evaluating expensive
369+
# symbolic Lt or Gt operations.
370+
# Note: Boolean is split to make the conditional short
371+
# circuit more frequently for a mild speedup.
372+
if smart_lt(v, sit.symbolic_min) or smart_gt(v, sit.symbolic_max):
373+
return Vector(S.ImaginaryUnit)
372374

373375
# Case 2: `sit` is an IterationInterval over a local SubDimension
374376
# and `other` performs a constant access

devito/passes/clusters/cse.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections import defaultdict
1+
from collections import defaultdict, Counter
22
from functools import cached_property, singledispatch
33

44
import numpy as np
@@ -13,6 +13,7 @@
1313
from devito.finite_differences.differentiable import IndexDerivative
1414
from devito.ir import Cluster, Scope, cluster_pass
1515
from devito.symbolics import estimate_cost, q_leaf, q_terminal
16+
from devito.symbolics.search import search
1617
from devito.symbolics.manipulation import _uxreplace
1718
from devito.tools import DAG, as_list, as_tuple, frozendict, extract_dtype
1819
from devito.types import Eq, Symbol, Temp
@@ -25,9 +26,15 @@ class CTemp(Temp):
2526
"""
2627
A cluster-level Temp, similar to Temp but guaranteed to have a different priority
2728
"""
29+
2830
ordering_of_classes.insert(ordering_of_classes.index('Temp') + 1, 'CTemp')
2931

3032

33+
def retrieve_ctemps(exprs, mode='all'):
    """
    Collect the CTemp objects found in `exprs`.

    Thin wrapper around `search`: the match predicate is an
    `isinstance(..., CTemp)` check, `mode` is forwarded unchanged and 'dfs'
    is passed as the final argument.
    """
    def is_ctemp(node):
        return isinstance(node, CTemp)

    return search(exprs, is_ctemp, mode, 'dfs')
36+
37+
3138
@cluster_pass
3239
def cse(cluster, sregistry=None, options=None, **kwargs):
3340
"""
@@ -225,8 +232,15 @@ def _compact(exprs, exclude):
225232

226233
mapper = {e.lhs: e.rhs for e in candidates if q_leaf(e.rhs)}
227234

228-
mapper.update({e.lhs: e.rhs for e in candidates
229-
if sum([i.rhs.count(e.lhs) for i in exprs]) == 1})
235+
# Find all the CTemps in expression right-hand-sides without removing duplicates
236+
ctemps = retrieve_ctemps(e.rhs for e in exprs)
237+
238+
# If there are ctemps in the expressions, then add any that only appear once to
239+
# the mapper
240+
if ctemps:
241+
ctemp_count = Counter(ctemps)
242+
mapper.update({e.lhs: e.rhs for e in candidates
243+
if ctemp_count[e.lhs] == 1})
230244

231245
processed = []
232246
for e in exprs:

devito/passes/clusters/misc.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,7 @@ def is_cross(source, sink):
352352
v = len(cg0.exprs)
353353
return t0 < v <= t1 or t1 < v <= t0
354354

355-
for cg1 in cgroups[n+1:]:
356-
n1 = cgroups.index(cg1)
355+
for n1, cg1 in enumerate(cgroups[n+1:], start=n+1):
357356

358357
# A Scope to compute all cross-ClusterGroup anti-dependences
359358
scope = Scope(exprs=cg0.exprs + cg1.exprs, rules=is_cross)

devito/symbolics/search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def retrieve_functions(exprs, mode='all', deep=False):
155155

156156

157157
def retrieve_symbols(exprs, mode='all'):
158-
"""Shorthand to retrieve the Scalar in ``exprs``."""
158+
"""Shorthand to retrieve the Scalar in `exprs`."""
159159
return search(exprs, q_symbol, mode, 'dfs')
160160

161161

devito/tools/utils.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import OrderedDict
22
from collections.abc import Iterable
3-
from functools import reduce
3+
from functools import reduce, wraps
44
from itertools import chain, combinations, groupby, product, zip_longest
55
from operator import attrgetter, mul
66
import types
@@ -12,7 +12,8 @@
1212
'roundm', 'powerset', 'invert', 'flatten', 'single_or', 'filter_ordered',
1313
'as_mapper', 'filter_sorted', 'pprint', 'sweep', 'all_equal', 'as_list',
1414
'indices_to_slices', 'indices_to_sections', 'transitive_closure',
15-
'humanbytes', 'contains_val', 'sorted_priority', 'as_set', 'is_number']
15+
'humanbytes', 'contains_val', 'sorted_priority', 'as_set', 'is_number',
16+
'smart_lt', 'smart_gt']
1617

1718

1819
def prod(iterable, initial=1):
@@ -346,3 +347,35 @@ def key(i):
346347
return (v, str(type(i)))
347348

348349
return sorted(items, key=key, reverse=True)
350+
351+
352+
def avoid_symbolic(default=None):
    """
    Decorator factory to avoid calling a function where doing so would result
    in symbolic computation being performed. For use if symbolic computation
    may be slow. If any argument is symbolic, give up and return a default
    value instead of calling the function.

    Parameters
    ----------
    default : object, optional
        Value returned by the decorated function whenever at least one of its
        arguments (positional or keyword) is a `sympy.Basic` instance.
        Defaults to None.

    Returns
    -------
    callable
        A decorator. The function it produces returns `default` if any
        argument is a `sympy.Basic` instance, and otherwise returns the
        result of calling the wrapped function with the same arguments.
    """
    def _avoid_symbolic(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Keyword values are inspected as well as positional ones, so a
            # symbolic operand cannot slip through just because the caller
            # passed it by name
            operands = (*args, *kwargs.values())
            if any(isinstance(expr, sympy.Basic) for expr in operands):
                # An argument is symbolic, so give up and assume default
                return default

            return func(*args, **kwargs)

        return wrapper

    return _avoid_symbolic
370+
371+
372+
@avoid_symbolic(default=False)
def smart_lt(a, b):
    """
    A less-than comparison that skips symbolic work.

    Returns `a < b`. The `avoid_symbolic(default=False)` decorator short-cuts
    the call and returns False when an argument is a `sympy.Basic` instance.
    """
    outcome = a < b
    return outcome
376+
377+
378+
@avoid_symbolic(default=False)
def smart_gt(a, b):
    """
    A greater-than comparison that skips symbolic work.

    Returns `a > b`. The `avoid_symbolic(default=False)` decorator short-cuts
    the call and returns False when an argument is a `sympy.Basic` instance.
    """
    outcome = a > b
    return outcome

0 commit comments

Comments
 (0)