Skip to content

Commit 51e6c10

Browse files
committed
experiment: per-module fan-out in compiler stages
Signed-off-by: Stephan Renatus <[email protected]>
1 parent 3188e04 commit 51e6c10

File tree

2 files changed

+48
-49
lines changed

2 files changed

+48
-49
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module github.com/open-policy-agent/opa
22

3-
go 1.24.6
3+
go 1.25
44

55
require (
66
github.com/agnivade/levenshtein v1.2.1

v1/ast/compile.go

Lines changed: 47 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"sort"
1414
"strconv"
1515
"strings"
16+
"sync"
17+
"sync/atomic"
1618

1719
"github.com/open-policy-agent/opa/internal/debug"
1820
"github.com/open-policy-agent/opa/internal/gojsonschema"
@@ -1756,6 +1758,15 @@ func (c *Compiler) init() {
17561758
c.initialized = true
17571759
}
17581760

1761+
func (c *Compiler) forEachModule(f func(mod *Module)) {
1762+
wg := &sync.WaitGroup{}
1763+
for _, name := range c.sorted {
1764+
wg.Go(func() { f(c.Modules[name]) })
1765+
}
1766+
wg.Wait()
1767+
}
1768+
1769+
// TODO(sr): Fix this, it's not concurrency-safe. And the panic is bad form, too.
17591770
func (c *Compiler) err(err *Error) {
17601771
if c.maxErrs > 0 && len(c.Errors) >= c.maxErrs {
17611772
c.Errors = append(c.Errors, errLimitReached)
@@ -1957,7 +1968,7 @@ func (c *Compiler) resolveAllRefs() {
19571968

19581969
func (c *Compiler) removeImports() {
19591970
c.imports = make(map[string][]*Import, len(c.Modules))
1960-
for name := range c.Modules {
1971+
for name := range c.Modules { // Trivial. No fan-out for this.
19611972
c.imports[name] = c.Modules[name].Imports
19621973
c.Modules[name].Imports = nil
19631974
}
@@ -1969,27 +1980,25 @@ func (c *Compiler) initLocalVarGen() {
19691980

19701981
func (c *Compiler) rewriteComprehensionTerms() {
19711982
f := newEqualityFactory(c.localvargen)
1972-
for _, name := range c.sorted {
1973-
mod := c.Modules[name]
1983+
c.forEachModule(func(mod *Module) {
19741984
_, _ = rewriteComprehensionTerms(f, mod) // ignore error
1975-
}
1985+
})
19761986
}
19771987

19781988
func (c *Compiler) rewriteExprTerms() {
1979-
for _, name := range c.sorted {
1980-
mod := c.Modules[name]
1989+
c.forEachModule(func(mod *Module) {
19811990
WalkRules(mod, func(rule *Rule) bool {
19821991
rewriteExprTermsInHead(c.localvargen, rule)
19831992
rule.Body = rewriteExprTermsInBody(c.localvargen, rule.Body)
19841993
return false
19851994
})
1986-
}
1995+
})
19871996
}
19881997

19891998
func (c *Compiler) rewriteRuleHeadRefs() {
19901999
f := newEqualityFactory(c.localvargen)
1991-
for _, name := range c.sorted {
1992-
WalkRules(c.Modules[name], func(rule *Rule) bool {
2000+
c.forEachModule(func(mod *Module) {
2001+
WalkRules(mod, func(rule *Rule) bool {
19932002

19942003
ref := rule.Head.Ref()
19952004
// NOTE(sr): We're backfilling Refs here -- all parser code paths would have them, but
@@ -2042,16 +2051,15 @@ func (c *Compiler) rewriteRuleHeadRefs() {
20422051

20432052
return true
20442053
})
2045-
}
2054+
})
20462055
}
20472056

20482057
func (c *Compiler) checkVoidCalls() {
2049-
for _, name := range c.sorted {
2050-
mod := c.Modules[name]
2058+
c.forEachModule(func(mod *Module) {
20512059
for _, err := range checkVoidCalls(c.TypeEnv, mod) {
20522060
c.err(err)
20532061
}
2054-
}
2062+
})
20552063
}
20562064

20572065
func (c *Compiler) rewritePrintCalls() {
@@ -2063,8 +2071,7 @@ func (c *Compiler) rewritePrintCalls() {
20632071
}
20642072
}
20652073
} else {
2066-
for _, name := range c.sorted {
2067-
mod := c.Modules[name]
2074+
c.forEachModule(func(mod *Module) {
20682075
WalkRules(mod, func(r *Rule) bool {
20692076
safe := r.Head.Args.Vars()
20702077
safe.Update(ReservedVars)
@@ -2082,7 +2089,7 @@ func (c *Compiler) rewritePrintCalls() {
20822089
WalkBodies(r.Body, vis)
20832090
return false
20842091
})
2085-
}
2092+
})
20862093
}
20872094
if modified {
20882095
c.Required.addBuiltinSorted(Print)
@@ -2295,8 +2302,7 @@ func isPrintCall(x *Expr) bool {
22952302
// p[__local0__] { i < 100; __local0__ = {"foo": data.foo[i]} }
22962303
func (c *Compiler) rewriteRefsInHead() {
22972304
f := newEqualityFactory(c.localvargen)
2298-
for _, name := range c.sorted {
2299-
mod := c.Modules[name]
2305+
c.forEachModule(func(mod *Module) {
23002306
WalkRules(mod, func(rule *Rule) bool {
23012307
if requiresEval(rule.Head.Key) {
23022308
expr := f.Generate(rule.Head.Key)
@@ -2317,27 +2323,27 @@ func (c *Compiler) rewriteRefsInHead() {
23172323
}
23182324
return false
23192325
})
2320-
}
2326+
})
23212327
}
23222328

23232329
func (c *Compiler) rewriteEquals() {
23242330
modified := false
2325-
for _, name := range c.sorted {
2326-
modified = rewriteEquals(c.Modules[name]) || modified
2327-
}
2331+
c.forEachModule(func(mod *Module) {
2332+
modified = rewriteEquals(mod) || modified
2333+
})
23282334
if modified {
23292335
c.Required.addBuiltinSorted(Equal)
23302336
}
23312337
}
23322338

23332339
func (c *Compiler) rewriteDynamicTerms() {
23342340
f := newEqualityFactory(c.localvargen)
2335-
for _, name := range c.sorted {
2336-
WalkRules(c.Modules[name], func(rule *Rule) bool {
2341+
c.forEachModule(func(mod *Module) {
2342+
WalkRules(mod, func(rule *Rule) bool {
23372343
rule.Body = rewriteDynamics(f, rule.Body)
23382344
return false
23392345
})
2340-
}
2346+
})
23412347
}
23422348

23432349
// rewriteTestRuleEqualities rewrites equality expressions in test rule bodies to create local vars for statements that would otherwise
@@ -2366,39 +2372,34 @@ func (c *Compiler) rewriteTestRuleEqualities() {
23662372
}
23672373

23682374
f := newEqualityFactory(c.localvargen)
2369-
for _, name := range c.sorted {
2370-
mod := c.Modules[name]
2375+
c.forEachModule(func(mod *Module) {
23712376
WalkRules(mod, func(rule *Rule) bool {
23722377
if strings.HasPrefix(string(rule.Head.Name), "test_") {
23732378
rule.Body = rewriteTestEqualities(f, rule.Body)
23742379
}
23752380
return false
23762381
})
2377-
}
2382+
})
23782383
}
23792384

23802385
func (c *Compiler) parseMetadataBlocks() {
23812386
// Only parse annotations if rego.metadata built-ins are called
23822387
regoMetadataCalled := false
2383-
for _, name := range c.sorted {
2384-
mod := c.Modules[name]
2388+
c.forEachModule(func(mod *Module) {
2389+
if regoMetadataCalled {
2390+
return
2391+
}
23852392
WalkExprs(mod, func(expr *Expr) bool {
23862393
if isRegoMetadataChainCall(expr) || isRegoMetadataRuleCall(expr) {
23872394
regoMetadataCalled = true
23882395
}
23892396
return regoMetadataCalled
23902397
})
2391-
2392-
if regoMetadataCalled {
2393-
break
2394-
}
2395-
}
2398+
})
23962399

23972400
if regoMetadataCalled {
23982401
// NOTE: Possible optimization: only parse annotations for modules on the path of rego.metadata-calling module
2399-
for _, name := range c.sorted {
2400-
mod := c.Modules[name]
2401-
2402+
c.forEachModule(func(mod *Module) {
24022403
if len(mod.Annotations) == 0 {
24032404
var errs Errors
24042405
mod.Annotations, errs = parseAnnotations(mod.Comments)
@@ -2409,7 +2410,7 @@ func (c *Compiler) parseMetadataBlocks() {
24092410

24102411
attachRuleAnnotations(mod)
24112412
}
2412-
}
2413+
})
24132414
}
24142415
}
24152416

@@ -2419,9 +2420,7 @@ func (c *Compiler) rewriteRegoMetadataCalls() {
24192420
_, chainFuncAllowed := c.builtins[RegoMetadataChain.Name]
24202421
_, ruleFuncAllowed := c.builtins[RegoMetadataRule.Name]
24212422

2422-
for _, name := range c.sorted {
2423-
mod := c.Modules[name]
2424-
2423+
c.forEachModule(func(mod *Module) {
24252424
WalkRules(mod, func(rule *Rule) bool {
24262425
var firstChainCall *Expr
24272426
var firstRuleCall *Expr
@@ -2499,7 +2498,7 @@ func (c *Compiler) rewriteRegoMetadataCalls() {
24992498

25002499
return false
25012500
})
2502-
}
2501+
})
25032502
}
25042503

25052504
func getPrimaryRuleAnnotations(as *AnnotationSet, rule *Rule) *Annotations {
@@ -4333,27 +4332,27 @@ const LocalVarPrefix = "__local"
43334332
type localVarGenerator struct {
43344333
exclude VarSet
43354334
suffix string
4336-
next int
4335+
next *atomic.Int32
43374336
}
43384337

43394338
func newLocalVarGeneratorForModuleSet(sorted []string, modules map[string]*Module) *localVarGenerator {
43404339
vis := NewVarVisitor()
43414340
for _, key := range sorted {
43424341
vis.Walk(modules[key])
43434342
}
4344-
return &localVarGenerator{exclude: vis.vars, next: 0}
4343+
return &localVarGenerator{exclude: vis.vars, next: &atomic.Int32{}}
43454344
}
43464345

43474346
func newLocalVarGenerator(suffix string, node any) *localVarGenerator {
43484347
vis := NewVarVisitor()
43494348
vis.Walk(node)
4350-
return &localVarGenerator{exclude: vis.vars, suffix: suffix, next: 0}
4349+
return &localVarGenerator{exclude: vis.vars, suffix: suffix, next: &atomic.Int32{}}
43514350
}
43524351

43534352
func (l *localVarGenerator) Generate() Var {
43544353
for {
4355-
result := Var(LocalVarPrefix + l.suffix + strconv.Itoa(l.next) + "__")
4356-
l.next++
4354+
next := l.next.Add(1) - 1 // we want the old number
4355+
result := Var(LocalVarPrefix + l.suffix + strconv.Itoa(int(next)) + "__")
43574356
if !l.exclude.Contains(result) {
43584357
return result
43594358
}

0 commit comments

Comments
 (0)