summaryrefslogtreecommitdiffhomepage
path: root/cmd/equaler
diff options
context:
space:
mode:
Diffstat (limited to 'cmd/equaler')
-rw-r--r--cmd/equaler/equaler.go177
1 files changed, 177 insertions, 0 deletions
diff --git a/cmd/equaler/equaler.go b/cmd/equaler/equaler.go
new file mode 100644
index 000000000..27412bb23
--- /dev/null
+++ b/cmd/equaler/equaler.go
@@ -0,0 +1,177 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Equaler is a tool to automate the creation of an Equals method.
+//
+// This tool assumes that if a type you give it contains another named struct
+// type, that type will also have an Equal method, and that all fields are
+// comparable unless explicitly excluded.
+package main
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "go/token"
+ "go/types"
+ "log"
+ "os"
+ "strings"
+
+ "golang.org/x/exp/slices"
+ "tailscale.com/util/codegen"
+)
+
+var (
+ flagTypes = flag.String("type", "", "comma-separated list of types; required")
+ flagBuildTags = flag.String("tags", "", "compiler build tags to apply")
+)
+
+func main() {
+ log.SetFlags(0)
+ log.SetPrefix("equaler: ")
+ flag.Parse()
+ if len(*flagTypes) == 0 {
+ flag.Usage()
+ os.Exit(2)
+ }
+ typeNames := strings.Split(*flagTypes, ",")
+
+ pkg, namedTypes, err := codegen.LoadTypes(*flagBuildTags, ".")
+ if err != nil {
+ log.Fatal(err)
+ }
+ it := codegen.NewImportTracker(pkg.Types)
+ buf := new(bytes.Buffer)
+ for _, typeName := range typeNames {
+ typ, ok := namedTypes[typeName]
+ if !ok {
+ log.Fatalf("could not find type %s", typeName)
+ }
+ gen(buf, it, typ, typeNames)
+ }
+
+ cloneOutput := pkg.Name + "_equal.go"
+ if err := codegen.WritePackageFile("tailscale.com/cmd/equaler", pkg, cloneOutput, it, buf); err != nil {
+ log.Fatal(err)
+ }
+}
+
+func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, typeNames []string) {
+ t, ok := typ.Underlying().(*types.Struct)
+ if !ok {
+ return
+ }
+
+ name := typ.Obj().Name()
+ fmt.Fprintf(buf, "// Equal reports whether a and b are equal.\n")
+ fmt.Fprintf(buf, "func (a *%s) Equal(b *%s) bool {\n", name, name)
+ writef := func(format string, args ...any) {
+ fmt.Fprintf(buf, "\t"+format+"\n", args...)
+ }
+ writef("if a == b {")
+ writef("\treturn true")
+ writef("}")
+
+ writef("return a != nil && b != nil &&")
+ for i := 0; i < t.NumFields(); i++ {
+ fname := t.Field(i).Name()
+ ft := t.Field(i).Type()
+
+ // Fields which are explicitly ignored are skipped.
+ if codegen.HasNoEqual(t.Tag(i)) {
+ writef("\t// Skipping %s because of codegen:noequal", fname)
+ continue
+ }
+
+ // Fields which are named types that have an Equal() method, get that method used
+ if named, _ := ft.(*types.Named); named != nil {
+ if implementsEqual(ft) || slices.Contains(typeNames, named.Obj().Name()) {
+ writef("\ta.%s.Equal(b.%s) &&", fname, fname)
+ continue
+ }
+ }
+
+ // Fields which are just values are directly compared, unless they have an Equal() method.
+ if !codegen.ContainsPointers(ft) {
+ writef("\ta.%s == b.%s &&", fname, fname)
+ continue
+ }
+
+ switch ft := ft.Underlying().(type) {
+ case *types.Pointer:
+ if named, _ := ft.Elem().(*types.Named); named != nil {
+ if slices.Contains(typeNames, named.Obj().Name()) || implementsEqual(ft) {
+ writef("\t((a.%s == nil) == (b.%s == nil)) && (a.%s == nil || a.%s.Equal(b.%s)) &&", fname, fname, fname, fname, fname)
+ continue
+ }
+ if implementsEqual(ft.Elem()) {
+ writef("\t((a.%s == nil) == (b.%s == nil)) && (a.%s == nil || a.%s.Equal(*b.%s)) &&", fname, fname, fname, fname, fname)
+ continue
+ }
+ }
+ if !codegen.ContainsPointers(ft.Elem()) {
+ writef("\t((a.%s == nil) == (b.%s == nil)) && (a.%s == nil || *a.%s == *b.%s) &&", fname, fname, fname, fname, fname)
+ continue
+ }
+ log.Fatalf("unimplemented: %s (%T)", fname, ft)
+ case *types.Slice:
+ // Empty slices and nil slices are different.
+ writef("\t((a.%s == nil) == (b.%s == nil)) &&", fname, fname)
+ if named, _ := ft.Elem().(*types.Named); named != nil {
+ if implementsEqual(ft.Elem()) {
+ it.Import("golang.org/x/exp/slices")
+ writef("\tslices.EqualFunc(a.%s, b.%s, func(aa %s, bb %s) bool {return aa.Equal(bb)}) &&", fname, fname, named.Obj().Name(), named.Obj().Name())
+ continue
+ }
+ if slices.Contains(typeNames, named.Obj().Name()) || implementsEqual(types.NewPointer(ft.Elem())) {
+ it.Import("golang.org/x/exp/slices")
+ writef("\tslices.EqualFunc(a.%s, b.%s, func(aa %s, bb %s) bool {return aa.Equal(&bb)}) &&", fname, fname, named.Obj().Name(), named.Obj().Name())
+ continue
+ }
+ }
+ if !codegen.ContainsPointers(ft.Elem()) {
+ it.Import("golang.org/x/exp/slices")
+ writef("\tslices.Equal(a.%s, b.%s) &&", fname, fname)
+ continue
+ }
+ log.Fatalf("unimplemented: %s (%T)", fname, ft)
+ case *types.Map:
+ if !codegen.ContainsPointers(ft.Elem()) {
+ it.Import("golang.org/x/exp/maps")
+ writef("\tmaps.Equal(a.%s, b.%s) &&", fname, fname)
+ continue
+ }
+ log.Fatalf("unimplemented: %s (%T)", fname, ft)
+ default:
+ log.Fatalf("unimplemented: %s (%T)", fname, ft)
+ }
+ }
+ writef("\ttrue")
+ fmt.Fprintf(buf, "}\n\n")
+
+ buf.Write(codegen.AssertStructUnchanged(t, name, "Equal", it))
+}
+
+// hasBasicUnderlying reports true when typ.Underlying() is a slice or a map.
+func hasBasicUnderlying(typ types.Type) bool {
+ switch typ.Underlying().(type) {
+ case *types.Slice, *types.Map:
+ return true
+ default:
+ return false
+ }
+}
+
+// implementsEqual reports whether typ has an Equal(typ) bool method.
+func implementsEqual(typ types.Type) bool {
+ return types.Implements(typ, types.NewInterfaceType(
+ []*types.Func{types.NewFunc(
+ token.NoPos, nil, "Equal", types.NewSignatureType(
+ types.NewVar(token.NoPos, nil, "a", typ),
+ nil, nil,
+ types.NewTuple(types.NewVar(token.NoPos, nil, "b", typ)),
+ types.NewTuple(types.NewVar(token.NoPos, nil, "", types.Typ[types.Bool])), false))},
+ []types.Type{},
+ ))
+}