From 47686533df1c3ba4944768dc1795595e4a9858de Mon Sep 17 00:00:00 2001 From: "oleksandr.yershov" Date: Thu, 27 Nov 2025 13:56:14 +0200 Subject: [PATCH 1/3] optimize: Enhance performance of various methods and add benchmarks Signed-off-by: oleksandr.yershov --- v1/ast/compile.go | 9 ++- v1/ast/optimization_bench_test.go | 129 ++++++++++++++++++++++++++++++ v1/ast/policy.go | 78 +++++++++++++----- v1/ast/term.go | 59 +++++++++++--- 4 files changed, 241 insertions(+), 34 deletions(-) create mode 100644 v1/ast/optimization_bench_test.go diff --git a/v1/ast/compile.go b/v1/ast/compile.go index e2492ecb80..e351be43ed 100644 --- a/v1/ast/compile.go +++ b/v1/ast/compile.go @@ -667,6 +667,9 @@ func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) { } func extractRules(s []any) []*Rule { + if len(s) == 0 { + return nil + } rules := make([]*Rule, len(s)) for i := range s { rules[i] = s[i].(*Rule) @@ -691,7 +694,6 @@ func extractRules(s []any) []*Rule { // GetRules("data.a.b.c") => [rule1, rule2] // GetRules("data.a.b.d") => nil func (c *Compiler) GetRules(ref Ref) (rules []*Rule) { - set := map[*Rule]struct{}{} for _, rule := range c.GetRulesForVirtualDocument(ref) { @@ -702,10 +704,13 @@ func (c *Compiler) GetRules(ref Ref) (rules []*Rule) { set[rule] = struct{}{} } + if len(set) == 0 { + return nil + } + rules = make([]*Rule, 0, len(set)) for rule := range set { rules = append(rules, rule) } - return rules } diff --git a/v1/ast/optimization_bench_test.go b/v1/ast/optimization_bench_test.go new file mode 100644 index 0000000000..b873872bf3 --- /dev/null +++ b/v1/ast/optimization_bench_test.go @@ -0,0 +1,129 @@ +// Copyright 2025 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "testing" +) + +// BenchmarkRefPtr benchmarks the optimized Ref.Ptr() method +func BenchmarkRefPtr(b *testing.B) { + ref := MustParseRef("data.foo.bar.baz.qux") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ref.Ptr() + } +} + +// BenchmarkArgsString benchmarks the optimized Args.String() method +func BenchmarkArgsString(b *testing.B) { + args := Args{ + StringTerm("arg1"), + StringTerm("arg2"), + StringTerm("arg3"), + NumberTerm("42"), + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = args.String() + } +} + +// BenchmarkBodyString benchmarks the optimized Body.String() method +func BenchmarkBodyString(b *testing.B) { + body := MustParseBody("x := 1; y := 2; z := x + y; a := z * 2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = body.String() + } +} + +// BenchmarkExprString benchmarks the optimized Expr.String() method +func BenchmarkExprString(b *testing.B) { + expr := MustParseExpr("x = y + z with input.foo as 42 with data.bar as \"test\"") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = expr.String() + } +} + +// BenchmarkSetDiff benchmarks the optimized set.Diff() method +func BenchmarkSetDiff(b *testing.B) { + s1 := NewSet() + s2 := NewSet() + for i := 0; i < 100; i++ { + s1.Add(IntNumberTerm(i)) + if i%2 == 0 { + s2.Add(IntNumberTerm(i)) + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = s1.Diff(s2) + } +} + +// BenchmarkSetIntersect benchmarks the optimized set.Intersect() method +func BenchmarkSetIntersect(b *testing.B) { + s1 := NewSet() + s2 := NewSet() + for i := 0; i < 100; i++ { + s1.Add(IntNumberTerm(i)) + if i%2 == 0 { + s2.Add(IntNumberTerm(i)) + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = s1.Intersect(s2) + } +} + +// BenchmarkObjectKeys benchmarks the optimized object.Keys() method +func BenchmarkObjectKeys(b *testing.B) { + obj := NewObject() + for i := 0; i < 50; i++ { + obj.Insert(StringTerm(string(rune('a'+i%26))+string(rune('0'+i/26))), IntNumberTerm(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = obj.Keys() + } +} + +// BenchmarkGetRules benchmarks the optimized Compiler.GetRules() method +func BenchmarkGetRules(b *testing.B) { + module := ` + package test + + p[x] { x := 1 } + p[x] { x := 2 } + q[x] { x := 3 } + r := 4 + ` + + c := NewCompiler() + c.Compile(map[string]*Module{ + "test.rego": MustParseModule(module), + }) + + if c.Failed() { + b.Fatal(c.Errors) + } + + ref := MustParseRef("data.test.p") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = c.GetRules(ref) + } +} diff --git a/v1/ast/policy.go b/v1/ast/policy.go index 62c82f51ec..85e0d94a9a 100644 --- a/v1/ast/policy.go +++ b/v1/ast/policy.go @@ -1095,11 +1095,21 @@ func (a Args) Copy() Args { } func (a Args) String() string { - buf := make([]string, 0, len(a)) - for _, t := range a { - buf = append(buf, t.String()) + if len(a) == 0 { + return "()" + } + sb := sbPool.Get() + defer sbPool.Put(sb) + sb.Grow(len(a) * 10) + sb.WriteByte('(') + for i, t := range a { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(t.String()) } - return "(" + strings.Join(buf, ", ") + ")" + sb.WriteByte(')') + return sb.String() } // Loc returns the Location of a. @@ -1232,11 +1242,22 @@ func (body Body) SetLoc(loc *Location) { } func (body Body) String() string { - buf := make([]string, 0, len(body)) - for _, v := range body { - buf = append(buf, v.String()) + if len(body) == 0 { + return "" + } + if len(body) == 1 { + return body[0].String() + } + sb := sbPool.Get() + defer sbPool.Put(sb) + sb.Grow(len(body) * 20) + for i, v := range body { + if i > 0 { + sb.WriteString("; ") + } + sb.WriteString(v.String()) } - return strings.Join(buf, "; ") + return sb.String() } // Vars returns a VarSet containing variables in body. The params can be set to @@ -1547,26 +1568,34 @@ func (expr *Expr) SetLoc(loc *Location) { } func (expr *Expr) String() string { - buf := make([]string, 0, 2+len(expr.With)) + sb := sbPool.Get() + defer sbPool.Put(sb) + sb.Grow(32 + len(expr.With)*20) + if expr.Negated { - buf = append(buf, "not") + sb.WriteString("not ") } switch t := expr.Terms.(type) { case []*Term: if expr.IsEquality() && validEqAssignArgCount(expr) { - buf = append(buf, fmt.Sprintf("%v %v %v", t[1], Equality.Infix, t[2])) + sb.WriteString(t[1].String()) + sb.WriteByte(' ') + sb.WriteString(Equality.Infix) + sb.WriteByte(' ') + sb.WriteString(t[2].String()) } else { - buf = append(buf, Call(t).String()) + sb.WriteString(Call(t).String()) } case fmt.Stringer: - buf = append(buf, t.String()) + sb.WriteString(t.String()) } for i := range expr.With { - buf = append(buf, expr.With[i].String()) + sb.WriteByte(' ') + sb.WriteString(expr.With[i].String()) } - return strings.Join(buf, " ") + return sb.String() } func (expr *Expr) MarshalJSON() ([]byte, error) { @@ -1666,11 +1695,22 @@ func (d *SomeDecl) String() string { } return "some " + call[1].String() + " in " + call[2].String() } - buf := make([]string, len(d.Symbols)) - for i := range buf { - buf[i] = d.Symbols[i].String() + if len(d.Symbols) == 0 { + return "some" + } + if len(d.Symbols) == 1 { + return "some " + d.Symbols[0].String() + } + sb := sbPool.Get() + defer sbPool.Put(sb) + sb.WriteString("some ") + for i := range d.Symbols { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(d.Symbols[i].String()) } - return "some " + strings.Join(buf, ", ") + return sb.String() } // SetLoc sets the Location on d. diff --git a/v1/ast/term.go b/v1/ast/term.go index 62779e63e1..d62b92eb8c 100644 --- a/v1/ast/term.go +++ b/v1/ast/term.go @@ -1158,15 +1158,23 @@ func (ref Ref) IsNested() bool { // contains non-string terms this function returns an error. Path // components are escaped. func (ref Ref) Ptr() (string, error) { - parts := make([]string, 0, len(ref)-1) - for _, term := range ref[1:] { + if len(ref) <= 1 { + return "", nil + } + sb := sbPool.Get() + defer sbPool.Put(sb) + sb.Grow(len(ref) * 8) // Estimate average component size + for i, term := range ref[1:] { if str, ok := term.Value.(String); ok { - parts = append(parts, url.PathEscape(string(str))) + if i > 0 { + sb.WriteByte('/') + } + sb.WriteString(url.PathEscape(string(str))) } else { return "", errors.New("invalid path value type") } } - return strings.Join(parts, "/"), nil + return sb.String(), nil } var varRegexp = regexp.MustCompile("^[[:alpha:]_][[:alpha:][:digit:]_]*$") @@ -1232,12 +1240,12 @@ func (ref Ref) OutputVars() VarSet { } func (ref Ref) toArray() *Array { - terms := make([]*Term, 0, len(ref)) - for _, term := range ref { + terms := make([]*Term, len(ref)) + for i, term := range ref { if _, ok := term.Value.(String); ok { - terms = append(terms, term) + terms[i] = term } else { - terms = append(terms, InternedTerm(term.Value.String())) + terms[i] = InternedTerm(term.Value.String()) } } return NewArray(terms...) @@ -1647,19 +1655,28 @@ func (s *set) Diff(other Set) Set { if s.Compare(other) == 0 { return NewSet() } + if s.Len() == 0 { + return NewSet() + } - terms := make([]*Term, 0, len(s.keys)) + terms := make([]*Term, 0, s.Len()) for _, term := range s.sortedKeys() { if !other.Contains(term) { terms = append(terms, term) } } + if len(terms) == 0 { + return NewSet() + } return NewSet(terms...) } // Intersect returns the set containing elements in both s and other. func (s *set) Intersect(other Set) Set { + if s.Len() == 0 || other.Len() == 0 { + return NewSet() + } o := other.(*set) n, m := s.Len(), o.Len() ss := s @@ -1677,6 +1694,9 @@ func (s *set) Intersect(other Set) Set { } } + if len(terms) == 0 { + return NewSet() + } return NewSet(terms...) } @@ -1718,7 +1738,10 @@ func (s *set) Foreach(f func(*Term)) { // Map returns a new Set obtained by applying f to each value in s. func (s *set) Map(f func(*Term) (*Term, error)) (Set, error) { - mapped := make([]*Term, 0, len(s.keys)) + if s.Len() == 0 { + return NewSet(), nil + } + mapped := make([]*Term, 0, s.Len()) for _, x := range s.sortedKeys() { term, err := f(x) if err != nil { @@ -2411,12 +2434,13 @@ func (obj *object) Map(f func(*Term, *Term) (*Term, *Term, error)) (Object, erro // Keys returns the keys of obj. func (obj *object) Keys() []*Term { - keys := make([]*Term, len(obj.keys)) - + if obj.Len() == 0 { + return nil + } + keys := make([]*Term, obj.Len()) for i, elem := range obj.sortedKeys() { keys[i] = elem.key } - return keys } @@ -2769,6 +2793,9 @@ func filterObject(o Value, filter Value) (Value, error) { case String, Number, Boolean, Null: return o, nil case *Array: + if v.Len() == 0 { + return v, nil + } values := NewArray() for i := range v.Len() { subFilter := filteredObj.Get(InternedIntegerString(i)) @@ -2782,6 +2809,9 @@ func filterObject(o Value, filter Value) (Value, error) { } return values, nil case Set: + if v.Len() == 0 { + return v, nil + } terms := make([]*Term, 0, v.Len()) for _, t := range v.Slice() { if filteredObj.Get(t) != nil { @@ -2792,6 +2822,9 @@ func filterObject(o Value, filter Value) (Value, error) { terms = append(terms, NewTerm(filteredValue)) } } + if len(terms) == 0 { + return NewSet(), nil + } return NewSet(terms...), nil case *object: values := NewObject() From 53669cde9721d9fdb7517f90a25923756b92075c Mon Sep 17 00:00:00 2001 From: "oleksandr.yershov" Date: Thu, 27 Nov 2025 14:35:25 +0200 Subject: [PATCH 2/3] fix: update benchmark loops to use Go 1.22+ integer range syntax Replace traditional for loops with range-based loops in optimization benchmarks to satisfy the intrange linter and improve code readability. Signed-off-by: oleksandr.yershov --- v1/ast/optimization_bench_test.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/v1/ast/optimization_bench_test.go b/v1/ast/optimization_bench_test.go index b873872bf3..58f0d87668 100644 --- a/v1/ast/optimization_bench_test.go +++ b/v1/ast/optimization_bench_test.go @@ -13,7 +13,7 @@ func BenchmarkRefPtr(b *testing.B) { ref := MustParseRef("data.foo.bar.baz.qux") b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { _, _ = ref.Ptr() } } @@ -28,7 +28,7 @@ func BenchmarkArgsString(b *testing.B) { } b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { _ = args.String() } } @@ -38,7 +38,7 @@ func BenchmarkBodyString(b *testing.B) { body := MustParseBody("x := 1; y := 2; z := x + y; a := z * 2") b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { _ = body.String() } } @@ -48,7 +48,7 @@ func BenchmarkExprString(b *testing.B) { expr := MustParseExpr("x = y + z with input.foo as 42 with data.bar as \"test\"") b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { _ = expr.String() } } @@ -57,7 +57,7 @@ func BenchmarkExprString(b *testing.B) { func BenchmarkSetDiff(b *testing.B) { s1 := NewSet() s2 := NewSet() - for i := 0; i < 100; i++ { + for i := range 100 { s1.Add(IntNumberTerm(i)) if i%2 == 0 { s2.Add(IntNumberTerm(i)) @@ -65,7 +65,7 @@ func BenchmarkSetDiff(b *testing.B) { } b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { _ = s1.Diff(s2) } } @@ -74,7 +74,7 @@ func BenchmarkSetDiff(b *testing.B) { func BenchmarkSetIntersect(b *testing.B) { s1 := NewSet() s2 := NewSet() - for i := 0; i < 100; i++ { + for i := range 100 { s1.Add(IntNumberTerm(i)) if i%2 == 0 { s2.Add(IntNumberTerm(i)) @@ -82,7 +82,7 @@ func BenchmarkSetIntersect(b *testing.B) { } b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { _ = s1.Intersect(s2) } } @@ -90,12 +90,12 @@ func BenchmarkSetIntersect(b *testing.B) { // BenchmarkObjectKeys benchmarks the optimized object.Keys() method func BenchmarkObjectKeys(b *testing.B) { obj := NewObject() - for i := 0; i < 50; i++ { + for i := range 50 { obj.Insert(StringTerm(string(rune('a'+i%26))+string(rune('0'+i/26))), IntNumberTerm(i)) } b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { _ = obj.Keys() } } @@ -123,7 +123,7 @@ func BenchmarkGetRules(b *testing.B) { ref := MustParseRef("data.test.p") b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { _ = c.GetRules(ref) } } From 7682dfeb461613beeed4116cb63e4791494d5b4b Mon Sep 17 00:00:00 2001 From: "oleksandr.yershov" Date: Fri, 28 Nov 2025 16:11:04 +0200 Subject: [PATCH 3/3] fix: add proper error handling in v1/ast optimization benchmarks Signed-off-by: oleksandr.yershov --- v1/ast/optimization_bench_test.go | 33 +++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/v1/ast/optimization_bench_test.go b/v1/ast/optimization_bench_test.go index 58f0d87668..511a39850b 100644 --- a/v1/ast/optimization_bench_test.go +++ b/v1/ast/optimization_bench_test.go @@ -10,7 +10,10 @@ import ( // BenchmarkRefPtr benchmarks the optimized Ref.Ptr() method func BenchmarkRefPtr(b *testing.B) { - ref := MustParseRef("data.foo.bar.baz.qux") + ref, err := ParseRef("data.foo.bar.baz.qux") + if err != nil { + b.Fatal(err) + } b.ResetTimer() for range b.N { @@ -35,7 +38,10 @@ func BenchmarkArgsString(b *testing.B) { // BenchmarkBodyString benchmarks the optimized Body.String() method func BenchmarkBodyString(b *testing.B) { - body := MustParseBody("x := 1; y := 2; z := x + y; a := z * 2") + body, err := ParseBody("x := 1; y := 2; z := x + y; a := z * 2") + if err != nil { + b.Fatal(err) + } b.ResetTimer() for range b.N { @@ -45,7 +51,10 @@ func BenchmarkBodyString(b *testing.B) { // BenchmarkExprString benchmarks the optimized Expr.String() method func BenchmarkExprString(b *testing.B) { - expr := MustParseExpr("x = y + z with input.foo as 42 with data.bar as \"test\"") + expr, err := ParseExpr("x = y + z with input.foo as 42 with data.bar as \"test\"") + if err != nil { + b.Fatal(err) + } b.ResetTimer() for range b.N { @@ -105,22 +114,30 @@ func BenchmarkGetRules(b *testing.B) { module := ` package test - p[x] { x := 1 } - p[x] { x := 2 } - q[x] { x := 3 } + p contains x if { x := 1 } + p contains x if { x := 2 } + q contains x if { x := 3 } r := 4 ` + mod, err := ParseModule("test.rego", module) + if err != nil { + b.Fatal(err) + } + c := NewCompiler() c.Compile(map[string]*Module{ - "test.rego": MustParseModule(module), + "test.rego": mod, }) if c.Failed() { b.Fatal(c.Errors) } - ref := MustParseRef("data.test.p") + ref, err := ParseRef("data.test.p") + if err != nil { + b.Fatal(err) + } b.ResetTimer() for range b.N {