summaryrefslogtreecommitdiffhomepage
path: root/cmd
diff options
context:
space:
mode:
Diffstat (limited to 'cmd')
-rw-r--r--cmd/viewer/viewer.go216
1 files changed, 216 insertions, 0 deletions
diff --git a/cmd/viewer/viewer.go b/cmd/viewer/viewer.go
new file mode 100644
index 000000000..df6287c26
--- /dev/null
+++ b/cmd/viewer/viewer.go
@@ -0,0 +1,216 @@
+// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Viewer is a tool to automate the creation of a view type.
+//
+// The generated View method provides a readonly view of the struct.
+//
+// This tool makes lots of implicit assumptions about the types you feed it.
+// In particular, it can only write relatively "shallow" View methods.
+// That is, if a type contains another named struct type, viewer assumes that
+// named type will also have a View method.
+package main
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "go/types"
+ "log"
+ "os"
+ "strings"
+
+ "golang.org/x/tools/go/packages"
+ "tailscale.com/util/codegen"
+)
+
+var (
+ flagTypes = flag.String("type", "", "comma-separated list of types; required")
+ flagOutput = flag.String("output", "", "output file; required")
+ flagBuildTags = flag.String("tags", "", "compiler build tags to apply")
+)
+
+func main() {
+ log.SetFlags(0)
+ log.SetPrefix("viewer: ")
+ flag.Parse()
+ if len(*flagTypes) == 0 {
+ flag.Usage()
+ os.Exit(2)
+ }
+ typeNames := strings.Split(*flagTypes, ",")
+
+ cfg := &packages.Config{
+ Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName,
+ Tests: false,
+ }
+ if *flagBuildTags != "" {
+ cfg.BuildFlags = []string{"-tags=" + *flagBuildTags}
+ }
+ pkgs, err := packages.Load(cfg, ".")
+ if err != nil {
+ log.Fatal(err)
+ }
+ if len(pkgs) != 1 {
+ log.Fatalf("wrong number of packages: %d", len(pkgs))
+ }
+ pkg := pkgs[0]
+ buf := new(bytes.Buffer)
+ imports := make(map[string]struct{})
+ namedTypes := codegen.NamedTypes(pkg)
+ for _, typeName := range typeNames {
+ typ, ok := namedTypes[typeName]
+ if !ok {
+ log.Fatalf("could not find type %s", typeName)
+ }
+ gen(buf, imports, typ, pkg.Types)
+ }
+
+ contents := new(bytes.Buffer)
+ fmt.Fprintf(contents, header, *flagTypes, pkg.Name)
+ fmt.Fprintf(contents, "import (\n")
+ for s := range imports {
+ fmt.Fprintf(contents, "\t%q\n", s)
+ }
+ fmt.Fprintf(contents, ")\n\n")
+ contents.Write(buf.Bytes())
+
+ output := *flagOutput
+ if output == "" {
+ flag.Usage()
+ os.Exit(2)
+ }
+ if err := codegen.WriteFormatted(contents.Bytes(), output); err != nil {
+ log.Fatal(err)
+ }
+}
+
+const header = `// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Code generated by the following command; DO NOT EDIT.
+// tailscale.com/cmd/viewer -type %s
+
+package %s
+
+`
+
+func gen(buf *bytes.Buffer, imports map[string]struct{}, typ *types.Named, thisPkg *types.Package) {
+ pkgQual := func(pkg *types.Package) string {
+ if thisPkg == pkg {
+ return ""
+ }
+ imports[pkg.Path()] = struct{}{}
+ return pkg.Name()
+ }
+ importedName := func(t types.Type) string {
+ return types.TypeString(t, pkgQual)
+ }
+
+ t, ok := typ.Underlying().(*types.Struct)
+ if !ok {
+ return
+ }
+
+ name := typ.Obj().Name()
+ viewName := name + "View"
+ fmt.Fprintf(buf, "// View makes a readonly view of %s.\n", name)
+ fmt.Fprintf(buf, "func (src *%s) View() %s {\n", name, viewName)
+ fmt.Fprintf(buf, " return %s{src}\n", viewName)
+ fmt.Fprintf(buf, "}\n")
+
+ fmt.Fprintf(buf, "// %s is a readonly view of %s.\n", viewName, name)
+ fmt.Fprintf(buf, "type %s struct{ ж *%s }\n", viewName, name)
+ fmt.Fprintf(buf, "func (v %s) Valid() bool { return v.ж != nil }\n", viewName)
+
+ for i := 0; i < t.NumFields(); i++ {
+ fname := t.Field(i).Name()
+ ft := t.Field(i).Type()
+ if !codegen.ContainsPointers(ft) {
+ fmt.Fprintf(buf, "func (v %s) %s() %s { return v.ж.%s }\n", viewName, fname, importedName(ft), fname)
+ continue
+ }
+ if named, _ := ft.(*types.Named); named != nil && !hasBasicUnderlying(ft) {
+ genViewCall(buf, viewName, fname, importedName(ft))
+ continue
+ }
+ switch ft := ft.Underlying().(type) {
+ case *types.Slice:
+ if !codegen.ContainsPointers(ft.Elem()) {
+ // OK to return the slice as-is, since they can't modify the contents.
+ fmt.Fprintf(buf, "func (v %s) %s() %s { return v.ж.%s }\n", viewName, fname, importedName(ft), fname)
+ continue
+ }
+
+ n := importedName(ft.Elem())
+ if ptrTyp, isPtr := ft.Elem().(*types.Pointer); isPtr {
+ n = importedName(ptrTyp.Elem())
+ }
+
+ // Generate slice view.
+ styp := fmt.Sprintf("_%s_%s", viewName, fname)
+ fmt.Fprintf(buf, "type %s []%s\n", styp, importedName(ft.Elem()))
+ fmt.Fprintf(buf, "func (s %s) Len() int { return len(s) }\n", styp)
+ fmt.Fprintf(buf, "func (s %s) At(i int) %sView { return s[i].View() }\n", styp, n)
+
+ fmt.Fprintf(buf, "func (v %s) %s() interface { Len() int; At(int) %sView } {\n", viewName, fname, n)
+ fmt.Fprintf(buf, " return %s(v.ж.%s)\n", styp, fname)
+ fmt.Fprintf(buf, "}\n")
+ case *types.Pointer:
+ if named, _ := ft.Elem().(*types.Named); named != nil && codegen.ContainsPointers(ft.Elem()) {
+ genViewCall(buf, viewName, fname, importedName(named))
+ continue
+ }
+ if codegen.ContainsPointers(ft.Elem()) {
+ log.Fatalf("unhandled: pointers in pointers (%v)", ft)
+ }
+ n := importedName(ft.Elem())
+ fmt.Fprintf(buf, "func (v %s) %s() *%s {\n", viewName, fname, n)
+ fmt.Fprintf(buf, " ptr := v.ж.%s\n", fname)
+ fmt.Fprintf(buf, " if ptr == nil {\n")
+ fmt.Fprintf(buf, " return nil\n")
+ fmt.Fprintf(buf, " }\n")
+ fmt.Fprintf(buf, " cp := *ptr\n")
+ fmt.Fprintf(buf, " return &cp\n")
+ fmt.Fprintf(buf, "}\n")
+ case *types.Map:
+ // TODO: Generate map view, like the slice view.
+ // We need:
+ // * Len() int
+ // * Load(k) v
+ // * LoadOK(k) (v, bool)
+ // * Range(func(k, v) bool)
+ //
+ // Note that we need to handle a variety of elem types:
+ // basic types (float64), types with a View method,
+ // slices of the foregoing.
+ //
+ // This may require recursion to handle completely,
+ // or we can follow cloner's lead and just manually
+ // inline one level deep the code generation
+ // that we happen to need right now.
+ // (If we figure out recursion in this context,
+ // we might want to backport to cloner, too.)
+ log.Printf("TODO: Handle %s (%s)", name, ft)
+ default:
+ fmt.Fprintf(buf, `panic("TODO: %s (%T)")`, fname, ft)
+ }
+ }
+
+ buf.Write(codegen.AssertStructUnchanged(t, thisPkg, name, "View", imports))
+}
+
+func genViewCall(buf *bytes.Buffer, viewName, fieldName, importedName string) {
+ fmt.Fprintf(buf, "func (v %s) %s() %sView { return v.ж.%s.View() }\n", viewName, fieldName, importedName, fieldName)
+}
+
+func hasBasicUnderlying(typ types.Type) bool {
+ switch typ.Underlying().(type) {
+ case *types.Slice, *types.Map:
+ return true
+ default:
+ return false
+ }
+}