diff options
| author | Andrew Dunham <andrew@du.nham.ca> | 2023-03-13 14:07:07 -0400 |
|---|---|---|
| committer | Andrew Dunham <andrew@du.nham.ca> | 2023-03-13 14:56:38 -0400 |
| commit | b080e6bb2d520c09a93a83302cd3d94363163862 (patch) | |
| tree | 654c4b319971a1ecf17b851e34c881286b3c8212 /cmd/fastjson/fastjson.go | |
| parent | 223713d4a1fc79a5c4f61d3655b173f7cb1e2409 (diff) | |
| download | tailscale-andrew/fastjson.tar.xz tailscale-andrew/fastjson.zip | |
Change-Id: I5295d47102d879f29f0a6818481e8b65eafd02dd
Diffstat (limited to 'cmd/fastjson/fastjson.go')
| -rw-r--r-- | cmd/fastjson/fastjson.go | 329 |
1 files changed, 329 insertions, 0 deletions
diff --git a/cmd/fastjson/fastjson.go b/cmd/fastjson/fastjson.go new file mode 100644 index 000000000..1b017235c --- /dev/null +++ b/cmd/fastjson/fastjson.go @@ -0,0 +1,329 @@ +package main + +import ( + "bytes" + "flag" + "fmt" + "go/ast" + "go/token" + "go/types" + "log" + "os" + "strconv" + "strings" + + "golang.org/x/tools/go/packages" + "golang.org/x/tools/imports" + "tailscale.com/util/codegen" +) + +var ( + flagTypes = flag.String("type", "", "comma-separated list of types; required") + flagBuildTags = flag.String("tags", "", "compiler build tags to apply") +) + +func main() { + log.SetFlags(0) + log.SetPrefix("cloner: ") + log.SetOutput(os.Stderr) + flag.Parse() + if len(*flagTypes) == 0 { + flag.Usage() + os.Exit(2) + } + typeNames := strings.Split(*flagTypes, ",") + + pkg, namedTypes, err := loadTypes(".", *flagBuildTags) + if err != nil { + log.Fatal(err) + } + it := codegen.NewImportTracker(pkg.Types) + buf := new(bytes.Buffer) + + for _, typeName := range typeNames { + typ, ok := namedTypes[typeName] + if !ok { + log.Fatalf("could not find type %s", typeName) + } + gen(buf, it, typ) + } + + outBuf := new(bytes.Buffer) + outBuf.WriteString("// Code generated by TODO; DO NOT EDIT.\n") + outBuf.WriteString("\n") + fmt.Fprintf(outBuf, "package %s\n\n", pkg.Name) + it.Write(outBuf) + outBuf.Write(buf.Bytes()) + + // Best-effort gofmt the output + out := outBuf.Bytes() + out, err = imports.Process("/nonexistant/main.go", out, &imports.Options{ + Comments: true, + TabIndent: true, + TabWidth: 8, + FormatOnly: true, // fancy gofmt only + }) + if err != nil { + out = outBuf.Bytes() + } + fmt.Print(string(out)) +} + +func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { + t, ok := typ.Underlying().(*types.Struct) + if !ok { + return + } + + name := typ.Obj().Name() + fmt.Fprintf(buf, "// MarshalJSONInto marshals this %s into JSON in the provided buffer.\n", name) + fmt.Fprintf(buf, "func (self *%s) MarshalJSONInto(buf []byte) ([]byte, error) {\n", name) + fmt.Fprintf(buf, "\tvar err error\n") + fmt.Fprintf(buf, "\t_ = err\n") + + g := &generator{ + buf: buf, + it: it, + indentLevel: 1, + } + + g.writef(`buf = append(buf, '{')`) + for i := 0; i < t.NumFields(); i++ { + fname := t.Field(i).Name() + ft := t.Field(i).Type() + + g.writef("") + g.writef(`// Encode field %s of type %q`, fname, ft.String()) + + // Write the field name; we need to quote the field (for JSON) + // and then quote it again (for the generated Go code). + qfname := strconv.Quote(fname) + ":" + g.writef(`buf = append(buf, []byte(%q)...)`, qfname) + + // Write the value + g.encode("self."+fname, ft) + + if i < t.NumFields()-1 { + g.writef(`buf = append(buf, ',')`) + } + } + g.writef(`buf = append(buf, '}')`) + + g.writef("return buf, nil") + fmt.Fprintf(buf, "}\n\n") +} + +type generator struct { + buf *bytes.Buffer + it *codegen.ImportTracker + indentLevel int +} + +func (g *generator) writef(format string, args ...any) { + fmt.Fprintf(g.buf, strings.Repeat("\t", g.indentLevel)+format+"\n", args...) +} + +func (g *generator) indent() { + g.indentLevel++ +} + +func (g *generator) dedent() { + g.indentLevel-- +} + +func (g *generator) encode(accessor string, ft types.Type) { + switch ft := ft.Underlying().(type) { + case *types.Basic: + g.encodeBasicField(accessor, ft) + case *types.Slice: + g.encodeSlice(accessor, ft) + case *types.Map: + g.encodeMap(accessor, ft) + case *types.Struct: + g.encodeStruct(accessor) + case *types.Pointer: + g.encodePointer(accessor, ft) + default: + g.writef(`panic("TODO: %s (%T)")`, accessor, ft) + } +} + +func (g *generator) encodePointer(accessor string, ft *types.Pointer) { + g.writef("if %s != nil {", accessor) + g.indent() + // Don't deref for a struct, since we're going to call a function + // anyway; otherwise, do. + if _, ok := ft.Elem().Underlying().(*types.Struct); ok { + g.encode(accessor, ft.Elem()) + } else { + g.encode("(*"+accessor+")", ft.Elem()) + } + g.dedent() + g.writef("} else {") + g.writef("\tbuf = append(buf, []byte(\"null\")...)") + g.writef("}") +} + +func (g *generator) encodeMap(accessor string, ft *types.Map) { + kt := ft.Key().Underlying() + vt := ft.Elem().Underlying() + + g.writef(`buf = append(buf, '{')`) + + // Determine how we marshal our key type + marshalKey := func() { + g.encode("k", kt) + } + + // Now check how we marshal our value + switch vt := vt.(type) { + case *types.Basic: + g.writef("for k, v := range %s {", accessor) + marshalKey() + g.writef("\tbuf = append(buf, ':')") + g.encodeBasicField("v", vt) + g.writef("}") + case *types.Struct: + g.writef("for k, v := range %s {", accessor) + marshalKey() + g.writef("\tbuf = append(buf, ':')") + g.encodeStruct("v") + g.writef("}") + default: + g.writef(`panic("TODO: %s (%T)")`, accessor, vt) + } + + g.writef(`buf = append(buf, '}')`) +} + +func (g *generator) encodeStruct(accessor string) { + // Assume that this struct also has a MarshalJSONInto method. + g.writef("buf, err = %s.MarshalJSONInto(buf)", accessor) + g.writef("if err != nil {") + g.writef("\treturn nil, err") + g.writef("}") +} + +func (g *generator) encodeSlice(accessor string, sl *types.Slice) { + switch ft := sl.Elem().Underlying().(type) { + case *types.Basic: + // Slice of basic elements + switch ft.Kind() { + case types.Byte: + // base64-encode + g.it.Import("encoding/base64") + + g.writef(`buf = append(buf, '"')`) + g.writef("{") + + // buf = append(buf, make([]byte, N)...) is a fast way to grow the slice by N + g.writef("encodedLen := base64.StdEncoding.EncodedLen(len(%s))", accessor) + g.writef("offset := len(buf)") + g.writef("buf = append(buf, make([]byte, encodedLen)...)") + g.writef("base64.StdEncoding.Encode(buf[offset:], %s)", accessor) + + g.writef("}") + g.writef(`buf = append(buf, '"')`) + default: + // All other basic elements are encoded + // one at a time via encodeBasicField + g.writef(`buf = append(buf, '[')`) + g.writef(`for i, elem := range %s {`, accessor) + g.writef("\tif i > 0 {") + g.writef("\t\tbuf = append(buf, ',')") + g.writef("\t}") + g.encodeBasicField("elem", ft) + g.writef(`}`) + g.writef(`buf = append(buf, ']')`) + } + + case *types.Struct: + g.writef(`buf = append(buf, '[')`) + g.writef(`for i, elem := range %s {`, accessor) + g.writef("\tif i > 0 {") + g.writef("\t\tbuf = append(buf, ',')") + g.writef("\t}") + g.encodeStruct("elem") + g.writef(`}`) + g.writef(`buf = append(buf, ']')`) + + default: + // TODO: if the type implements our interface, + // call that function for everything in the + // slice. + g.writef(`panic("TODO: %s (%T)")`, accessor, ft) + } +} + +func (g *generator) encodeBasicField(accessor string, field *types.Basic) { + switch field.Kind() { + case types.Bool: + g.writef("if %s {", accessor) + g.writef(`buf = append(buf, []byte("true")...)`) + g.writef("} else {") + g.writef(`buf = append(buf, []byte("false")...)`) + g.writef("}") + case types.Int, types.Int8, types.Int16, types.Int32, types.Int64: + g.it.Import("strconv") + g.writef("buf = strconv.AppendInt(buf, int64(%s), 10)", accessor) + case types.Uint, types.Uint8, types.Uint16, types.Uint32, types.Uint64: + g.it.Import("strconv") + g.writef("buf = strconv.AppendUint(buf, uint64(%s), 10)", accessor) + case types.String: + g.it.Import("strconv") + g.writef("buf = strconv.AppendQuote(buf, %s)", accessor) + default: + g.writef(`panic("TODO: %s (%T)")`, accessor, field.Kind) + } +} + +func loadTypes(pkgName, buildTags string) (*packages.Package, map[string]*types.Named, error) { + cfg := &packages.Config{ + Mode: packages.NeedTypes | + packages.NeedTypesInfo | + packages.NeedSyntax | + packages.NeedName, + Tests: false, + } + if buildTags != "" { + cfg.BuildFlags = []string{"-tags=" + buildTags} + } + + pkgs, err := packages.Load(cfg, pkgName) + if err != nil { + return nil, nil, err + } + if len(pkgs) != 1 { + return nil, nil, fmt.Errorf("wrong number of packages: %d", len(pkgs)) + } + pkg := pkgs[0] + return pkg, namedTypes(pkg), nil +} + +func namedTypes(pkg *packages.Package) map[string]*types.Named { + nt := make(map[string]*types.Named) + for _, file := range pkg.Syntax { + for _, d := range file.Decls { + decl, ok := d.(*ast.GenDecl) + if !ok || decl.Tok != token.TYPE { + continue + } + for _, s := range decl.Specs { + spec, ok := s.(*ast.TypeSpec) + if !ok { + continue + } + typeNameObj, ok := pkg.TypesInfo.Defs[spec.Name] + if !ok { + continue + } + typ, ok := typeNameObj.Type().(*types.Named) + if !ok { + continue + } + nt[spec.Name.Name] = typ + } + } + } + return nt +} |
