diff options
| author | Josh Bleecher Snyder <josh@tailscale.com> | 2021-09-16 16:50:31 -0700 |
|---|---|---|
| committer | Josh Bleecher Snyder <josh@tailscale.com> | 2021-09-17 16:47:00 -0700 |
| commit | c7b75465871eb911df4e1ff91a57b9c91c279111 (patch) | |
| tree | 0b7975a92a5f53ad27ee78525ae4e5660d34c7fe /cmd | |
| parent | b14db5d943b84be3c9f3a909c18c9af4012523dd (diff) | |
| download | tailscale-josh/immutable-views.tar.xz tailscale-josh/immutable-views.zip | |
WIP snapshotjosh/immutable-views
Next up: view support for maps, etc.
Diffstat (limited to 'cmd')
| -rw-r--r-- | cmd/viewer/viewer.go | 216 |
1 files changed, 216 insertions, 0 deletions
diff --git a/cmd/viewer/viewer.go b/cmd/viewer/viewer.go new file mode 100644 index 000000000..df6287c26 --- /dev/null +++ b/cmd/viewer/viewer.go @@ -0,0 +1,216 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Viewer is a tool to automate the creation of a view type. +// +// The generated View method provides a readonly view of the struct. +// +// This tool makes lots of implicit assumptions about the types you feed it. +// In particular, it can only write relatively "shallow" View methods. +// That is, if a type contains another named struct type, viewer assumes that +// named type will also have a View method. +package main + +import ( + "bytes" + "flag" + "fmt" + "go/types" + "log" + "os" + "strings" + + "golang.org/x/tools/go/packages" + "tailscale.com/util/codegen" +) + +var ( + flagTypes = flag.String("type", "", "comma-separated list of types; required") + flagOutput = flag.String("output", "", "output file; required") + flagBuildTags = flag.String("tags", "", "compiler build tags to apply") +) + +func main() { + log.SetFlags(0) + log.SetPrefix("viewer: ") + flag.Parse() + if len(*flagTypes) == 0 { + flag.Usage() + os.Exit(2) + } + typeNames := strings.Split(*flagTypes, ",") + + cfg := &packages.Config{ + Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName, + Tests: false, + } + if *flagBuildTags != "" { + cfg.BuildFlags = []string{"-tags=" + *flagBuildTags} + } + pkgs, err := packages.Load(cfg, ".") + if err != nil { + log.Fatal(err) + } + if len(pkgs) != 1 { + log.Fatalf("wrong number of packages: %d", len(pkgs)) + } + pkg := pkgs[0] + buf := new(bytes.Buffer) + imports := make(map[string]struct{}) + namedTypes := codegen.NamedTypes(pkg) + for _, typeName := range typeNames { + typ, ok := namedTypes[typeName] + if !ok { + log.Fatalf("could not find type %s", typeName) + } + gen(buf, imports, typ, pkg.Types) + } + + contents := new(bytes.Buffer) + fmt.Fprintf(contents, header, *flagTypes, pkg.Name) + fmt.Fprintf(contents, "import (\n") + for s := range imports { + fmt.Fprintf(contents, "\t%q\n", s) + } + fmt.Fprintf(contents, ")\n\n") + contents.Write(buf.Bytes()) + + output := *flagOutput + if output == "" { + flag.Usage() + os.Exit(2) + } + if err := codegen.WriteFormatted(contents.Bytes(), output); err != nil { + log.Fatal(err) + } +} + +const header = `// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Code generated by the following command; DO NOT EDIT. +// tailscale.com/cmd/viewer -type %s + +package %s + +` + +func gen(buf *bytes.Buffer, imports map[string]struct{}, typ *types.Named, thisPkg *types.Package) { + pkgQual := func(pkg *types.Package) string { + if thisPkg == pkg { + return "" + } + imports[pkg.Path()] = struct{}{} + return pkg.Name() + } + importedName := func(t types.Type) string { + return types.TypeString(t, pkgQual) + } + + t, ok := typ.Underlying().(*types.Struct) + if !ok { + return + } + + name := typ.Obj().Name() + viewName := name + "View" + fmt.Fprintf(buf, "// View makes a readonly view of %s.\n", name) + fmt.Fprintf(buf, "func (src *%s) View() %s {\n", name, viewName) + fmt.Fprintf(buf, " return %s{src}\n", viewName) + fmt.Fprintf(buf, "}\n") + + fmt.Fprintf(buf, "// %s is a readonly view of %s.\n", viewName, name) + fmt.Fprintf(buf, "type %s struct{ ж *%s }\n", viewName, name) + fmt.Fprintf(buf, "func (v %s) Valid() bool { return v.ж != nil }\n", viewName) + + for i := 0; i < t.NumFields(); i++ { + fname := t.Field(i).Name() + ft := t.Field(i).Type() + if !codegen.ContainsPointers(ft) { + fmt.Fprintf(buf, "func (v %s) %s() %s { return v.ж.%s }\n", viewName, fname, importedName(ft), fname) + continue + } + if named, _ := ft.(*types.Named); named != nil && !hasBasicUnderlying(ft) { + genViewCall(buf, viewName, fname, importedName(ft)) + continue + } + switch ft := ft.Underlying().(type) { + case *types.Slice: + if !codegen.ContainsPointers(ft.Elem()) { + // OK to return the slice as-is, since they can't modify the contents. + fmt.Fprintf(buf, "func (v %s) %s() %s { return v.ж.%s }\n", viewName, fname, importedName(ft), fname) + continue + } + + n := importedName(ft.Elem()) + if ptrTyp, isPtr := ft.Elem().(*types.Pointer); isPtr { + n = importedName(ptrTyp.Elem()) + } + + // Generate slice view. + styp := fmt.Sprintf("_%s_%s", viewName, fname) + fmt.Fprintf(buf, "type %s []%s\n", styp, importedName(ft.Elem())) + fmt.Fprintf(buf, "func (s %s) Len() int { return len(s) }\n", styp) + fmt.Fprintf(buf, "func (s %s) At(i int) %sView { return s[i].View() }\n", styp, n) + + fmt.Fprintf(buf, "func (v %s) %s() interface { Len() int; At(int) %sView } {\n", viewName, fname, n) + fmt.Fprintf(buf, " return %s(v.ж.%s)\n", styp, fname) + fmt.Fprintf(buf, "}\n") + case *types.Pointer: + if named, _ := ft.Elem().(*types.Named); named != nil && codegen.ContainsPointers(ft.Elem()) { + genViewCall(buf, viewName, fname, importedName(named)) + continue + } + if codegen.ContainsPointers(ft.Elem()) { + log.Fatalf("unhandled: pointers in pointers (%v)", ft) + } + n := importedName(ft.Elem()) + fmt.Fprintf(buf, "func (v %s) %s() *%s {\n", viewName, fname, n) + fmt.Fprintf(buf, " ptr := v.ж.%s\n", fname) + fmt.Fprintf(buf, " if ptr == nil {\n") + fmt.Fprintf(buf, " return nil\n") + fmt.Fprintf(buf, " }\n") + fmt.Fprintf(buf, " cp := *ptr\n") + fmt.Fprintf(buf, " return &cp\n") + fmt.Fprintf(buf, "}\n") + case *types.Map: + // TODO: Generate map view, like the slice view. + // We need: + // * Len() int + // * Load(k) v + // * LoadOK(k) (v, bool) + // * Range(func(k, v) bool) + // + // Note that we need to handle a variety of elem types: + // basic types (float64), types with a View method, + // slices of the foregoing. + // + // This may require recursion to handle completely, + // or we can follow cloner's lead and just manually + // inline one level deep the code generation + // that we happen to need right now. + // (If we figure out recursion in this context, + // we might want to backport to cloner, too.) + log.Printf("TODO: Handle %s (%s)", name, ft) + default: + fmt.Fprintf(buf, `panic("TODO: %s (%T)")`, fname, ft) + } + } + + buf.Write(codegen.AssertStructUnchanged(t, thisPkg, name, "View", imports)) +} + +func genViewCall(buf *bytes.Buffer, viewName, fieldName, importedName string) { + fmt.Fprintf(buf, "func (v %s) %s() %sView { return v.ж.%s.View() }\n", viewName, fieldName, importedName, fieldName) +} + +func hasBasicUnderlying(typ types.Type) bool { + switch typ.Underlying().(type) { + case *types.Slice, *types.Map: + return true + default: + return false + } +} |
