diff --git a/v1/ast/compile.go b/v1/ast/compile.go index 3084e970fb..94e857e5f1 100644 --- a/v1/ast/compile.go +++ b/v1/ast/compile.go @@ -673,6 +673,9 @@ func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) { } func extractRules(s []any) []*Rule { + if len(s) == 0 { + return nil + } rules := make([]*Rule, len(s)) for i := range s { rules[i] = s[i].(*Rule) @@ -697,7 +700,6 @@ func extractRules(s []any) []*Rule { // GetRules("data.a.b.c") => [rule1, rule2] // GetRules("data.a.b.d") => nil func (c *Compiler) GetRules(ref Ref) (rules []*Rule) { - set := map[*Rule]struct{}{} for _, rule := range c.GetRulesForVirtualDocument(ref) { @@ -708,10 +710,13 @@ func (c *Compiler) GetRules(ref Ref) (rules []*Rule) { set[rule] = struct{}{} } + if len(set) == 0 { + return nil + } + rules = make([]*Rule, 0, len(set)) for rule := range set { rules = append(rules, rule) } - return rules } diff --git a/v1/ast/optimization_bench_test.go b/v1/ast/optimization_bench_test.go new file mode 100644 index 0000000000..511a39850b --- /dev/null +++ b/v1/ast/optimization_bench_test.go @@ -0,0 +1,146 @@ +// Copyright 2025 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "testing" +) + +// BenchmarkRefPtr benchmarks the optimized Ref.Ptr() method +func BenchmarkRefPtr(b *testing.B) { + ref, err := ParseRef("data.foo.bar.baz.qux") + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for range b.N { + _, _ = ref.Ptr() + } +} + +// BenchmarkArgsString benchmarks the optimized Args.String() method +func BenchmarkArgsString(b *testing.B) { + args := Args{ + StringTerm("arg1"), + StringTerm("arg2"), + StringTerm("arg3"), + NumberTerm("42"), + } + + b.ResetTimer() + for range b.N { + _ = args.String() + } +} + +// BenchmarkBodyString benchmarks the optimized Body.String() method +func BenchmarkBodyString(b *testing.B) { + body, err := ParseBody("x := 1; y := 2; z := x + y; a := z * 2") + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for range b.N { + _ = body.String() + } +} + +// BenchmarkExprString benchmarks the optimized Expr.String() method +func BenchmarkExprString(b *testing.B) { + expr, err := ParseExpr("x = y + z with input.foo as 42 with data.bar as \"test\"") + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for range b.N { + _ = expr.String() + } +} + +// BenchmarkSetDiff benchmarks the optimized set.Diff() method +func BenchmarkSetDiff(b *testing.B) { + s1 := NewSet() + s2 := NewSet() + for i := range 100 { + s1.Add(IntNumberTerm(i)) + if i%2 == 0 { + s2.Add(IntNumberTerm(i)) + } + } + + b.ResetTimer() + for range b.N { + _ = s1.Diff(s2) + } +} + +// BenchmarkSetIntersect benchmarks the optimized set.Intersect() method +func BenchmarkSetIntersect(b *testing.B) { + s1 := NewSet() + s2 := NewSet() + for i := range 100 { + s1.Add(IntNumberTerm(i)) + if i%2 == 0 { + s2.Add(IntNumberTerm(i)) + } + } + + b.ResetTimer() + for range b.N { + _ = s1.Intersect(s2) + } +} + +// BenchmarkObjectKeys benchmarks the optimized object.Keys() method +func BenchmarkObjectKeys(b *testing.B) { + obj := NewObject() + for i := range 50 { + obj.Insert(StringTerm(string(rune('a'+i%26))+string(rune('0'+i/26))), IntNumberTerm(i)) + } + + b.ResetTimer() + for range b.N { + _ = obj.Keys() + } +} + +// BenchmarkGetRules benchmarks the optimized Compiler.GetRules() method +func BenchmarkGetRules(b *testing.B) { + module := ` + package test + + p contains x if { x := 1 } + p contains x if { x := 2 } + q contains x if { x := 3 } + r := 4 + ` + + mod, err := ParseModule("test.rego", module) + if err != nil { + b.Fatal(err) + } + + c := NewCompiler() + c.Compile(map[string]*Module{ + "test.rego": mod, + }) + + if c.Failed() { + b.Fatal(c.Errors) + } + + ref, err := ParseRef("data.test.p") + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for range b.N { + _ = c.GetRules(ref) + } +} diff --git a/v1/ast/policy.go b/v1/ast/policy.go index 8d34f3011b..8e54546ae5 100644 --- a/v1/ast/policy.go +++ b/v1/ast/policy.go @@ -1102,11 +1102,21 @@ func (a Args) Copy() Args { } func (a Args) String() string { - buf := make([]string, 0, len(a)) - for _, t := range a { - buf = append(buf, t.String()) + if len(a) == 0 { + return "()" + } + sb := sbPool.Get() + defer sbPool.Put(sb) + sb.Grow(len(a) * 10) + sb.WriteByte('(') + for i, t := range a { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(t.String()) } - return "(" + strings.Join(buf, ", ") + ")" + sb.WriteByte(')') + return sb.String() } // Loc returns the Location of a. @@ -1239,11 +1249,22 @@ func (body Body) SetLoc(loc *Location) { } func (body Body) String() string { - buf := make([]string, 0, len(body)) - for _, v := range body { - buf = append(buf, v.String()) + if len(body) == 0 { + return "" + } + if len(body) == 1 { + return body[0].String() + } + sb := sbPool.Get() + defer sbPool.Put(sb) + sb.Grow(len(body) * 20) + for i, v := range body { + if i > 0 { + sb.WriteString("; ") + } + sb.WriteString(v.String()) } - return strings.Join(buf, "; ") + return sb.String() } // Vars returns a VarSet containing variables in body. The params can be set to @@ -1554,26 +1575,34 @@ func (expr *Expr) SetLoc(loc *Location) { } func (expr *Expr) String() string { - buf := make([]string, 0, 2+len(expr.With)) + sb := sbPool.Get() + defer sbPool.Put(sb) + sb.Grow(32 + len(expr.With)*20) + if expr.Negated { - buf = append(buf, "not") + sb.WriteString("not ") } switch t := expr.Terms.(type) { case []*Term: if expr.IsEquality() && validEqAssignArgCount(expr) { - buf = append(buf, fmt.Sprintf("%v %v %v", t[1], Equality.Infix, t[2])) + sb.WriteString(t[1].String()) + sb.WriteByte(' ') + sb.WriteString(Equality.Infix) + sb.WriteByte(' ') + sb.WriteString(t[2].String()) } else { - buf = append(buf, Call(t).String()) + sb.WriteString(Call(t).String()) } case fmt.Stringer: - buf = append(buf, t.String()) + sb.WriteString(t.String()) } for i := range expr.With { - buf = append(buf, expr.With[i].String()) + sb.WriteByte(' ') + sb.WriteString(expr.With[i].String()) } - return strings.Join(buf, " ") + return sb.String() } func (expr *Expr) MarshalJSON() ([]byte, error) { @@ -1673,11 +1702,22 @@ func (d *SomeDecl) String() string { } return "some " + call[1].String() + " in " + call[2].String() } - buf := make([]string, len(d.Symbols)) - for i := range buf { - buf[i] = d.Symbols[i].String() + if len(d.Symbols) == 0 { + return "some" + } + if len(d.Symbols) == 1 { + return "some " + d.Symbols[0].String() + } + sb := sbPool.Get() + defer sbPool.Put(sb) + sb.WriteString("some ") + for i := range d.Symbols { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(d.Symbols[i].String()) } - return "some " + strings.Join(buf, ", ") + return sb.String() } // SetLoc sets the Location on d. diff --git a/v1/ast/term.go b/v1/ast/term.go index b6dec8da5c..d81471ffb7 100644 --- a/v1/ast/term.go +++ b/v1/ast/term.go @@ -1136,15 +1136,23 @@ func (ref Ref) IsNested() bool { // contains non-string terms this function returns an error. Path // components are escaped. func (ref Ref) Ptr() (string, error) { - parts := make([]string, 0, len(ref)-1) - for _, term := range ref[1:] { + if len(ref) <= 1 { + return "", nil + } + sb := sbPool.Get() + defer sbPool.Put(sb) + sb.Grow(len(ref) * 8) // Estimate average component size + for i, term := range ref[1:] { if str, ok := term.Value.(String); ok { - parts = append(parts, url.PathEscape(string(str))) + if i > 0 { + sb.WriteByte('/') + } + sb.WriteString(url.PathEscape(string(str))) } else { return "", errors.New("invalid path value type") } } - return strings.Join(parts, "/"), nil + return sb.String(), nil } var varRegexp = regexp.MustCompile("^[[:alpha:]_][[:alpha:][:digit:]_]*$") @@ -1220,12 +1228,12 @@ func (ref Ref) OutputVars() VarSet { } func (ref Ref) toArray() *Array { - terms := make([]*Term, 0, len(ref)) - for _, term := range ref { + terms := make([]*Term, len(ref)) + for i, term := range ref { if _, ok := term.Value.(String); ok { - terms = append(terms, term) + terms[i] = term } else { - terms = append(terms, InternedTerm(term.Value.String())) + terms[i] = InternedTerm(term.Value.String()) } } return NewArray(terms...) @@ -1627,19 +1635,28 @@ func (s *set) Diff(other Set) Set { if s.Compare(other) == 0 { return NewSet() } + if s.Len() == 0 { + return NewSet() + } - terms := make([]*Term, 0, len(s.keys)) + terms := make([]*Term, 0, s.Len()) for _, term := range s.sortedKeys() { if !other.Contains(term) { terms = append(terms, term) } } + if len(terms) == 0 { + return NewSet() + } return NewSet(terms...) } // Intersect returns the set containing elements in both s and other. func (s *set) Intersect(other Set) Set { + if s.Len() == 0 || other.Len() == 0 { + return NewSet() + } o := other.(*set) n, m := s.Len(), o.Len() ss := s @@ -1657,6 +1674,9 @@ func (s *set) Intersect(other Set) Set { } } + if len(terms) == 0 { + return NewSet() + } return NewSet(terms...) } @@ -1698,7 +1718,10 @@ func (s *set) Foreach(f func(*Term)) { // Map returns a new Set obtained by applying f to each value in s. func (s *set) Map(f func(*Term) (*Term, error)) (Set, error) { - mapped := make([]*Term, 0, len(s.keys)) + if s.Len() == 0 { + return NewSet(), nil + } + mapped := make([]*Term, 0, s.Len()) for _, x := range s.sortedKeys() { term, err := f(x) if err != nil { @@ -2262,12 +2285,13 @@ func (obj *object) Map(f func(*Term, *Term) (*Term, *Term, error)) (Object, erro // Keys returns the keys of obj. func (obj *object) Keys() []*Term { - keys := make([]*Term, len(obj.keys)) - + if obj.Len() == 0 { + return nil + } + keys := make([]*Term, obj.Len()) for i, elem := range obj.sortedKeys() { keys[i] = elem.key } - return keys } @@ -2444,6 +2468,9 @@ func filterObject(o Value, filter Value) (Value, error) { case String, Number, Boolean, Null: return o, nil case *Array: + if v.Len() == 0 { + return v, nil + } values := NewArray() for i := range v.Len() { subFilter := filteredObj.Get(InternedIntegerString(i)) @@ -2457,6 +2484,9 @@ func filterObject(o Value, filter Value) (Value, error) { } return values, nil case Set: + if v.Len() == 0 { + return v, nil + } terms := make([]*Term, 0, v.Len()) for _, t := range v.Slice() { if filteredObj.Get(t) != nil { @@ -2467,6 +2497,9 @@ func filterObject(o Value, filter Value) (Value, error) { terms = append(terms, NewTerm(filteredValue)) } } + if len(terms) == 0 { + return NewSet(), nil + } return NewSet(terms...), nil case *object: values := NewObject()