Skip to content
9 changes: 7 additions & 2 deletions v1/ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,9 @@ func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) {
}

func extractRules(s []any) []*Rule {
if len(s) == 0 {
return nil
}
rules := make([]*Rule, len(s))
for i := range s {
rules[i] = s[i].(*Rule)
Expand All @@ -697,7 +700,6 @@ func extractRules(s []any) []*Rule {
// GetRules("data.a.b.c") => [rule1, rule2]
// GetRules("data.a.b.d") => nil
func (c *Compiler) GetRules(ref Ref) (rules []*Rule) {

set := map[*Rule]struct{}{}

for _, rule := range c.GetRulesForVirtualDocument(ref) {
Expand All @@ -708,10 +710,13 @@ func (c *Compiler) GetRules(ref Ref) (rules []*Rule) {
set[rule] = struct{}{}
}

if len(set) == 0 {
return nil
}
rules = make([]*Rule, 0, len(set))
for rule := range set {
rules = append(rules, rule)
}

return rules
}

Expand Down
146 changes: 146 additions & 0 deletions v1/ast/optimization_bench_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
// Copyright 2025 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.

package ast

import (
"testing"
)

// BenchmarkRefPtr benchmarks the optimized Ref.Ptr() method
func BenchmarkRefPtr(b *testing.B) {
ref, err := ParseRef("data.foo.bar.baz.qux")
if err != nil {
b.Fatal(err)
}

b.ResetTimer()
for range b.N {
_, _ = ref.Ptr()
}
}

// BenchmarkArgsString benchmarks the optimized Args.String() method
func BenchmarkArgsString(b *testing.B) {
args := Args{
StringTerm("arg1"),
StringTerm("arg2"),
StringTerm("arg3"),
NumberTerm("42"),
}

b.ResetTimer()
for range b.N {
_ = args.String()
}
}

// BenchmarkBodyString benchmarks the optimized Body.String() method
func BenchmarkBodyString(b *testing.B) {
body, err := ParseBody("x := 1; y := 2; z := x + y; a := z * 2")
if err != nil {
b.Fatal(err)
}

b.ResetTimer()
for range b.N {
_ = body.String()
}
}

// BenchmarkExprString benchmarks the optimized Expr.String() method
func BenchmarkExprString(b *testing.B) {
expr, err := ParseExpr("x = y + z with input.foo as 42 with data.bar as \"test\"")
if err != nil {
b.Fatal(err)
}

b.ResetTimer()
for range b.N {
_ = expr.String()
}
}

// BenchmarkSetDiff benchmarks the optimized set.Diff() method
func BenchmarkSetDiff(b *testing.B) {
s1 := NewSet()
s2 := NewSet()
for i := range 100 {
s1.Add(IntNumberTerm(i))
if i%2 == 0 {
s2.Add(IntNumberTerm(i))
}
}

b.ResetTimer()
for range b.N {
_ = s1.Diff(s2)
}
}

// BenchmarkSetIntersect benchmarks the optimized set.Intersect() method
func BenchmarkSetIntersect(b *testing.B) {
s1 := NewSet()
s2 := NewSet()
for i := range 100 {
s1.Add(IntNumberTerm(i))
if i%2 == 0 {
s2.Add(IntNumberTerm(i))
}
}

b.ResetTimer()
for range b.N {
_ = s1.Intersect(s2)
}
}

// BenchmarkObjectKeys benchmarks the optimized object.Keys() method
func BenchmarkObjectKeys(b *testing.B) {
obj := NewObject()
for i := range 50 {
obj.Insert(StringTerm(string(rune('a'+i%26))+string(rune('0'+i/26))), IntNumberTerm(i))
}

b.ResetTimer()
for range b.N {
_ = obj.Keys()
}
}

// BenchmarkGetRules benchmarks the optimized Compiler.GetRules() method
func BenchmarkGetRules(b *testing.B) {
module := `
package test

p contains x if { x := 1 }
p contains x if { x := 2 }
q contains x if { x := 3 }
r := 4
`

mod, err := ParseModule("test.rego", module)
if err != nil {
b.Fatal(err)
}

c := NewCompiler()
c.Compile(map[string]*Module{
"test.rego": mod,
})

if c.Failed() {
b.Fatal(c.Errors)
}

ref, err := ParseRef("data.test.p")
if err != nil {
b.Fatal(err)
}

b.ResetTimer()
for range b.N {
_ = c.GetRules(ref)
}
}
78 changes: 59 additions & 19 deletions v1/ast/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -1102,11 +1102,21 @@ func (a Args) Copy() Args {
}

func (a Args) String() string {
buf := make([]string, 0, len(a))
for _, t := range a {
buf = append(buf, t.String())
if len(a) == 0 {
return "()"
}
sb := sbPool.Get()
defer sbPool.Put(sb)
sb.Grow(len(a) * 10)
sb.WriteByte('(')
for i, t := range a {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(t.String())
}
return "(" + strings.Join(buf, ", ") + ")"
sb.WriteByte(')')
return sb.String()
}

// Loc returns the Location of a.
Expand Down Expand Up @@ -1239,11 +1249,22 @@ func (body Body) SetLoc(loc *Location) {
}

func (body Body) String() string {
buf := make([]string, 0, len(body))
for _, v := range body {
buf = append(buf, v.String())
if len(body) == 0 {
return ""
}
if len(body) == 1 {
return body[0].String()
}
sb := sbPool.Get()
defer sbPool.Put(sb)
sb.Grow(len(body) * 20)
for i, v := range body {
if i > 0 {
sb.WriteString("; ")
}
sb.WriteString(v.String())
}
return strings.Join(buf, "; ")
return sb.String()
}

// Vars returns a VarSet containing variables in body. The params can be set to
Expand Down Expand Up @@ -1554,26 +1575,34 @@ func (expr *Expr) SetLoc(loc *Location) {
}

func (expr *Expr) String() string {
buf := make([]string, 0, 2+len(expr.With))
sb := sbPool.Get()
defer sbPool.Put(sb)
sb.Grow(32 + len(expr.With)*20)

if expr.Negated {
buf = append(buf, "not")
sb.WriteString("not ")
}
switch t := expr.Terms.(type) {
case []*Term:
if expr.IsEquality() && validEqAssignArgCount(expr) {
buf = append(buf, fmt.Sprintf("%v %v %v", t[1], Equality.Infix, t[2]))
sb.WriteString(t[1].String())
sb.WriteByte(' ')
sb.WriteString(Equality.Infix)
sb.WriteByte(' ')
sb.WriteString(t[2].String())
} else {
buf = append(buf, Call(t).String())
sb.WriteString(Call(t).String())
}
case fmt.Stringer:
buf = append(buf, t.String())
sb.WriteString(t.String())
}

for i := range expr.With {
buf = append(buf, expr.With[i].String())
sb.WriteByte(' ')
sb.WriteString(expr.With[i].String())
}

return strings.Join(buf, " ")
return sb.String()
}

func (expr *Expr) MarshalJSON() ([]byte, error) {
Expand Down Expand Up @@ -1673,11 +1702,22 @@ func (d *SomeDecl) String() string {
}
return "some " + call[1].String() + " in " + call[2].String()
}
buf := make([]string, len(d.Symbols))
for i := range buf {
buf[i] = d.Symbols[i].String()
if len(d.Symbols) == 0 {
return "some"
}
if len(d.Symbols) == 1 {
return "some " + d.Symbols[0].String()
}
sb := sbPool.Get()
defer sbPool.Put(sb)
sb.WriteString("some ")
for i := range d.Symbols {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(d.Symbols[i].String())
}
return "some " + strings.Join(buf, ", ")
return sb.String()
}

// SetLoc sets the Location on d.
Expand Down
Loading