@@ -13,6 +13,8 @@ import (
1313 "sort"
1414 "strconv"
1515 "strings"
16+ "sync"
17+ "sync/atomic"
1618
1719 "github.com/open-policy-agent/opa/internal/debug"
1820 "github.com/open-policy-agent/opa/internal/gojsonschema"
@@ -1756,6 +1758,15 @@ func (c *Compiler) init() {
17561758 c .initialized = true
17571759}
17581760
1761+ func (c * Compiler ) forEachModule (f func (mod * Module )) {
1762+ wg := & sync.WaitGroup {}
1763+ for _ , name := range c .sorted {
1764+ wg .Go (func () { f (c .Modules [name ]) })
1765+ }
1766+ wg .Wait ()
1767+ }
1768+
1769+ // TODO(sr): Fix this, it's not concurrency-safe. And the panic is bad form, too.
17591770func (c * Compiler ) err (err * Error ) {
17601771 if c .maxErrs > 0 && len (c .Errors ) >= c .maxErrs {
17611772 c .Errors = append (c .Errors , errLimitReached )
@@ -1957,7 +1968,7 @@ func (c *Compiler) resolveAllRefs() {
19571968
19581969func (c * Compiler ) removeImports () {
19591970 c .imports = make (map [string ][]* Import , len (c .Modules ))
1960- for name := range c .Modules {
1971+ for name := range c .Modules { // Trivial. No fan-out for this.
19611972 c .imports [name ] = c .Modules [name ].Imports
19621973 c .Modules [name ].Imports = nil
19631974 }
@@ -1969,27 +1980,25 @@ func (c *Compiler) initLocalVarGen() {
19691980
19701981func (c * Compiler ) rewriteComprehensionTerms () {
19711982 f := newEqualityFactory (c .localvargen )
1972- for _ , name := range c .sorted {
1973- mod := c .Modules [name ]
1983+ c .forEachModule (func (mod * Module ) {
19741984 _ , _ = rewriteComprehensionTerms (f , mod ) // ignore error
1975- }
1985+ })
19761986}
19771987
19781988func (c * Compiler ) rewriteExprTerms () {
1979- for _ , name := range c .sorted {
1980- mod := c .Modules [name ]
1989+ c .forEachModule (func (mod * Module ) {
19811990 WalkRules (mod , func (rule * Rule ) bool {
19821991 rewriteExprTermsInHead (c .localvargen , rule )
19831992 rule .Body = rewriteExprTermsInBody (c .localvargen , rule .Body )
19841993 return false
19851994 })
1986- }
1995+ })
19871996}
19881997
19891998func (c * Compiler ) rewriteRuleHeadRefs () {
19901999 f := newEqualityFactory (c .localvargen )
1991- for _ , name := range c . sorted {
1992- WalkRules (c . Modules [ name ] , func (rule * Rule ) bool {
2000+ c . forEachModule ( func ( mod * Module ) {
2001+ WalkRules (mod , func (rule * Rule ) bool {
19932002
19942003 ref := rule .Head .Ref ()
19952004 // NOTE(sr): We're backfilling Refs here -- all parser code paths would have them, but
@@ -2042,16 +2051,15 @@ func (c *Compiler) rewriteRuleHeadRefs() {
20422051
20432052 return true
20442053 })
2045- }
2054+ })
20462055}
20472056
20482057func (c * Compiler ) checkVoidCalls () {
2049- for _ , name := range c .sorted {
2050- mod := c .Modules [name ]
2058+ c .forEachModule (func (mod * Module ) {
20512059 for _ , err := range checkVoidCalls (c .TypeEnv , mod ) {
20522060 c .err (err )
20532061 }
2054- }
2062+ })
20552063}
20562064
20572065func (c * Compiler ) rewritePrintCalls () {
@@ -2063,8 +2071,7 @@ func (c *Compiler) rewritePrintCalls() {
20632071 }
20642072 }
20652073 } else {
2066- for _ , name := range c .sorted {
2067- mod := c .Modules [name ]
2074+ c .forEachModule (func (mod * Module ) {
20682075 WalkRules (mod , func (r * Rule ) bool {
20692076 safe := r .Head .Args .Vars ()
20702077 safe .Update (ReservedVars )
@@ -2082,7 +2089,7 @@ func (c *Compiler) rewritePrintCalls() {
20822089 WalkBodies (r .Body , vis )
20832090 return false
20842091 })
2085- }
2092+ })
20862093 }
20872094 if modified {
20882095 c .Required .addBuiltinSorted (Print )
@@ -2295,8 +2302,7 @@ func isPrintCall(x *Expr) bool {
22952302// p[__local0__] { i < 100; __local0__ = {"foo": data.foo[i]} }
22962303func (c * Compiler ) rewriteRefsInHead () {
22972304 f := newEqualityFactory (c .localvargen )
2298- for _ , name := range c .sorted {
2299- mod := c .Modules [name ]
2305+ c .forEachModule (func (mod * Module ) {
23002306 WalkRules (mod , func (rule * Rule ) bool {
23012307 if requiresEval (rule .Head .Key ) {
23022308 expr := f .Generate (rule .Head .Key )
@@ -2317,27 +2323,27 @@ func (c *Compiler) rewriteRefsInHead() {
23172323 }
23182324 return false
23192325 })
2320- }
2326+ })
23212327}
23222328
23232329func (c * Compiler ) rewriteEquals () {
23242330 modified := false
2325- for _ , name := range c . sorted {
2326- modified = rewriteEquals (c . Modules [ name ] ) || modified
2327- }
2331+ c . forEachModule ( func ( mod * Module ) {
2332+ modified = rewriteEquals (mod ) || modified
2333+ })
23282334 if modified {
23292335 c .Required .addBuiltinSorted (Equal )
23302336 }
23312337}
23322338
23332339func (c * Compiler ) rewriteDynamicTerms () {
23342340 f := newEqualityFactory (c .localvargen )
2335- for _ , name := range c . sorted {
2336- WalkRules (c . Modules [ name ] , func (rule * Rule ) bool {
2341+ c . forEachModule ( func ( mod * Module ) {
2342+ WalkRules (mod , func (rule * Rule ) bool {
23372343 rule .Body = rewriteDynamics (f , rule .Body )
23382344 return false
23392345 })
2340- }
2346+ })
23412347}
23422348
23432349// rewriteTestRuleEqualities rewrites equality expressions in test rule bodies to create local vars for statements that would otherwise
@@ -2366,39 +2372,34 @@ func (c *Compiler) rewriteTestRuleEqualities() {
23662372 }
23672373
23682374 f := newEqualityFactory (c .localvargen )
2369- for _ , name := range c .sorted {
2370- mod := c .Modules [name ]
2375+ c .forEachModule (func (mod * Module ) {
23712376 WalkRules (mod , func (rule * Rule ) bool {
23722377 if strings .HasPrefix (string (rule .Head .Name ), "test_" ) {
23732378 rule .Body = rewriteTestEqualities (f , rule .Body )
23742379 }
23752380 return false
23762381 })
2377- }
2382+ })
23782383}
23792384
23802385func (c * Compiler ) parseMetadataBlocks () {
23812386 // Only parse annotations if rego.metadata built-ins are called
23822387 regoMetadataCalled := false
2383- for _ , name := range c .sorted {
2384- mod := c .Modules [name ]
2388+ c .forEachModule (func (mod * Module ) {
2389+ if regoMetadataCalled {
2390+ return
2391+ }
23852392 WalkExprs (mod , func (expr * Expr ) bool {
23862393 if isRegoMetadataChainCall (expr ) || isRegoMetadataRuleCall (expr ) {
23872394 regoMetadataCalled = true
23882395 }
23892396 return regoMetadataCalled
23902397 })
2391-
2392- if regoMetadataCalled {
2393- break
2394- }
2395- }
2398+ })
23962399
23972400 if regoMetadataCalled {
23982401 // NOTE: Possible optimization: only parse annotations for modules on the path of rego.metadata-calling module
2399- for _ , name := range c .sorted {
2400- mod := c .Modules [name ]
2401-
2402+ c .forEachModule (func (mod * Module ) {
24022403 if len (mod .Annotations ) == 0 {
24032404 var errs Errors
24042405 mod .Annotations , errs = parseAnnotations (mod .Comments )
@@ -2409,7 +2410,7 @@ func (c *Compiler) parseMetadataBlocks() {
24092410
24102411 attachRuleAnnotations (mod )
24112412 }
2412- }
2413+ })
24132414 }
24142415}
24152416
@@ -2419,9 +2420,7 @@ func (c *Compiler) rewriteRegoMetadataCalls() {
24192420 _ , chainFuncAllowed := c .builtins [RegoMetadataChain .Name ]
24202421 _ , ruleFuncAllowed := c .builtins [RegoMetadataRule .Name ]
24212422
2422- for _ , name := range c .sorted {
2423- mod := c .Modules [name ]
2424-
2423+ c .forEachModule (func (mod * Module ) {
24252424 WalkRules (mod , func (rule * Rule ) bool {
24262425 var firstChainCall * Expr
24272426 var firstRuleCall * Expr
@@ -2499,7 +2498,7 @@ func (c *Compiler) rewriteRegoMetadataCalls() {
24992498
25002499 return false
25012500 })
2502- }
2501+ })
25032502}
25042503
25052504func getPrimaryRuleAnnotations (as * AnnotationSet , rule * Rule ) * Annotations {
@@ -4333,27 +4332,27 @@ const LocalVarPrefix = "__local"
43334332type localVarGenerator struct {
43344333 exclude VarSet
43354334 suffix string
4336- next int
4335+ next * atomic. Int32
43374336}
43384337
43394338func newLocalVarGeneratorForModuleSet (sorted []string , modules map [string ]* Module ) * localVarGenerator {
43404339 vis := NewVarVisitor ()
43414340 for _ , key := range sorted {
43424341 vis .Walk (modules [key ])
43434342 }
4344- return & localVarGenerator {exclude : vis .vars , next : 0 }
4343+ return & localVarGenerator {exclude : vis .vars , next : & atomic. Int32 {} }
43454344}
43464345
43474346func newLocalVarGenerator (suffix string , node any ) * localVarGenerator {
43484347 vis := NewVarVisitor ()
43494348 vis .Walk (node )
4350- return & localVarGenerator {exclude : vis .vars , suffix : suffix , next : 0 }
4349+ return & localVarGenerator {exclude : vis .vars , suffix : suffix , next : & atomic. Int32 {} }
43514350}
43524351
43534352func (l * localVarGenerator ) Generate () Var {
43544353 for {
4355- result := Var ( LocalVarPrefix + l . suffix + strconv . Itoa ( l . next ) + "__" )
4356- l . next ++
4354+ next := l . next . Add ( 1 ) - 1 // we want the old number
4355+ result := Var ( LocalVarPrefix + l . suffix + strconv . Itoa ( int ( next )) + "__" )
43574356 if ! l .exclude .Contains (result ) {
43584357 return result
43594358 }
0 commit comments