summaryrefslogtreecommitdiffhomepage
path: root/cmd/fastjson/fastjson.go
diff options
context:
space:
mode:
Diffstat (limited to 'cmd/fastjson/fastjson.go')
-rw-r--r--cmd/fastjson/fastjson.go329
1 files changed, 329 insertions, 0 deletions
diff --git a/cmd/fastjson/fastjson.go b/cmd/fastjson/fastjson.go
new file mode 100644
index 000000000..1b017235c
--- /dev/null
+++ b/cmd/fastjson/fastjson.go
@@ -0,0 +1,329 @@
+package main
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/token"
+ "go/types"
+ "log"
+ "os"
+ "strconv"
+ "strings"
+
+ "golang.org/x/tools/go/packages"
+ "golang.org/x/tools/imports"
+ "tailscale.com/util/codegen"
+)
+
+var (
+ flagTypes = flag.String("type", "", "comma-separated list of types; required")
+ flagBuildTags = flag.String("tags", "", "compiler build tags to apply")
+)
+
+func main() {
+ log.SetFlags(0)
+ log.SetPrefix("cloner: ")
+ log.SetOutput(os.Stderr)
+ flag.Parse()
+ if len(*flagTypes) == 0 {
+ flag.Usage()
+ os.Exit(2)
+ }
+ typeNames := strings.Split(*flagTypes, ",")
+
+ pkg, namedTypes, err := loadTypes(".", *flagBuildTags)
+ if err != nil {
+ log.Fatal(err)
+ }
+ it := codegen.NewImportTracker(pkg.Types)
+ buf := new(bytes.Buffer)
+
+ for _, typeName := range typeNames {
+ typ, ok := namedTypes[typeName]
+ if !ok {
+ log.Fatalf("could not find type %s", typeName)
+ }
+ gen(buf, it, typ)
+ }
+
+ outBuf := new(bytes.Buffer)
+ outBuf.WriteString("// Code generated by TODO; DO NOT EDIT.\n")
+ outBuf.WriteString("\n")
+ fmt.Fprintf(outBuf, "package %s\n\n", pkg.Name)
+ it.Write(outBuf)
+ outBuf.Write(buf.Bytes())
+
+ // Best-effort gofmt the output
+ out := outBuf.Bytes()
+ out, err = imports.Process("/nonexistant/main.go", out, &imports.Options{
+ Comments: true,
+ TabIndent: true,
+ TabWidth: 8,
+ FormatOnly: true, // fancy gofmt only
+ })
+ if err != nil {
+ out = outBuf.Bytes()
+ }
+ fmt.Print(string(out))
+}
+
+func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) {
+ t, ok := typ.Underlying().(*types.Struct)
+ if !ok {
+ return
+ }
+
+ name := typ.Obj().Name()
+ fmt.Fprintf(buf, "// MarshalJSONInto marshals this %s into JSON in the provided buffer.\n", name)
+ fmt.Fprintf(buf, "func (self *%s) MarshalJSONInto(buf []byte) ([]byte, error) {\n", name)
+ fmt.Fprintf(buf, "\tvar err error\n")
+ fmt.Fprintf(buf, "\t_ = err\n")
+
+ g := &generator{
+ buf: buf,
+ it: it,
+ indentLevel: 1,
+ }
+
+ g.writef(`buf = append(buf, '{')`)
+ for i := 0; i < t.NumFields(); i++ {
+ fname := t.Field(i).Name()
+ ft := t.Field(i).Type()
+
+ g.writef("")
+ g.writef(`// Encode field %s of type %q`, fname, ft.String())
+
+ // Write the field name; we need to quote the field (for JSON)
+ // and then quote it again (for the generated Go code).
+ qfname := strconv.Quote(fname) + ":"
+ g.writef(`buf = append(buf, []byte(%q)...)`, qfname)
+
+ // Write the value
+ g.encode("self."+fname, ft)
+
+ if i < t.NumFields()-1 {
+ g.writef(`buf = append(buf, ',')`)
+ }
+ }
+ g.writef(`buf = append(buf, '}')`)
+
+ g.writef("return buf, nil")
+ fmt.Fprintf(buf, "}\n\n")
+}
+
+type generator struct {
+ buf *bytes.Buffer
+ it *codegen.ImportTracker
+ indentLevel int
+}
+
+func (g *generator) writef(format string, args ...any) {
+ fmt.Fprintf(g.buf, strings.Repeat("\t", g.indentLevel)+format+"\n", args...)
+}
+
+func (g *generator) indent() {
+ g.indentLevel++
+}
+
+func (g *generator) dedent() {
+ g.indentLevel--
+}
+
+func (g *generator) encode(accessor string, ft types.Type) {
+ switch ft := ft.Underlying().(type) {
+ case *types.Basic:
+ g.encodeBasicField(accessor, ft)
+ case *types.Slice:
+ g.encodeSlice(accessor, ft)
+ case *types.Map:
+ g.encodeMap(accessor, ft)
+ case *types.Struct:
+ g.encodeStruct(accessor)
+ case *types.Pointer:
+ g.encodePointer(accessor, ft)
+ default:
+ g.writef(`panic("TODO: %s (%T)")`, accessor, ft)
+ }
+}
+
+func (g *generator) encodePointer(accessor string, ft *types.Pointer) {
+ g.writef("if %s != nil {", accessor)
+ g.indent()
+ // Don't deref for a struct, since we're going to call a function
+ // anyway; otherwise, do.
+ if _, ok := ft.Elem().Underlying().(*types.Struct); ok {
+ g.encode(accessor, ft.Elem())
+ } else {
+ g.encode("(*"+accessor+")", ft.Elem())
+ }
+ g.dedent()
+ g.writef("} else {")
+ g.writef("\tbuf = append(buf, []byte(\"null\")...)")
+ g.writef("}")
+}
+
+func (g *generator) encodeMap(accessor string, ft *types.Map) {
+ kt := ft.Key().Underlying()
+ vt := ft.Elem().Underlying()
+
+ g.writef(`buf = append(buf, '{')`)
+
+ // Determine how we marshal our key type
+ marshalKey := func() {
+ g.encode("k", kt)
+ }
+
+ // Now check how we marshal our value
+ switch vt := vt.(type) {
+ case *types.Basic:
+ g.writef("for k, v := range %s {", accessor)
+ marshalKey()
+ g.writef("\tbuf = append(buf, ':')")
+ g.encodeBasicField("v", vt)
+ g.writef("}")
+ case *types.Struct:
+ g.writef("for k, v := range %s {", accessor)
+ marshalKey()
+ g.writef("\tbuf = append(buf, ':')")
+ g.encodeStruct("v")
+ g.writef("}")
+ default:
+ g.writef(`panic("TODO: %s (%T)")`, accessor, vt)
+ }
+
+ g.writef(`buf = append(buf, '}')`)
+}
+
+func (g *generator) encodeStruct(accessor string) {
+ // Assume that this struct also has a MarshalJSONInto method.
+ g.writef("buf, err = %s.MarshalJSONInto(buf)", accessor)
+ g.writef("if err != nil {")
+ g.writef("\treturn nil, err")
+ g.writef("}")
+}
+
+func (g *generator) encodeSlice(accessor string, sl *types.Slice) {
+ switch ft := sl.Elem().Underlying().(type) {
+ case *types.Basic:
+ // Slice of basic elements
+ switch ft.Kind() {
+ case types.Byte:
+ // base64-encode
+ g.it.Import("encoding/base64")
+
+ g.writef(`buf = append(buf, '"')`)
+ g.writef("{")
+
+ // buf = append(buf, make([]byte, N)...) is a fast way to grow the slice by N
+ g.writef("encodedLen := base64.StdEncoding.EncodedLen(len(%s))", accessor)
+ g.writef("offset := len(buf)")
+ g.writef("buf = append(buf, make([]byte, encodedLen)...)")
+ g.writef("base64.StdEncoding.Encode(buf[offset:], %s)", accessor)
+
+ g.writef("}")
+ g.writef(`buf = append(buf, '"')`)
+ default:
+ // All other basic elements are encoded
+ // one at a time via encodeBasicField
+ g.writef(`buf = append(buf, '[')`)
+ g.writef(`for i, elem := range %s {`, accessor)
+ g.writef("\tif i > 0 {")
+ g.writef("\t\tbuf = append(buf, ',')")
+ g.writef("\t}")
+ g.encodeBasicField("elem", ft)
+ g.writef(`}`)
+ g.writef(`buf = append(buf, ']')`)
+ }
+
+ case *types.Struct:
+ g.writef(`buf = append(buf, '[')`)
+ g.writef(`for i, elem := range %s {`, accessor)
+ g.writef("\tif i > 0 {")
+ g.writef("\t\tbuf = append(buf, ',')")
+ g.writef("\t}")
+ g.encodeStruct("elem")
+ g.writef(`}`)
+ g.writef(`buf = append(buf, ']')`)
+
+ default:
+ // TODO: if the type implements our interface,
+ // call that function for everything in the
+ // slice.
+ g.writef(`panic("TODO: %s (%T)")`, accessor, ft)
+ }
+}
+
+func (g *generator) encodeBasicField(accessor string, field *types.Basic) {
+ switch field.Kind() {
+ case types.Bool:
+ g.writef("if %s {", accessor)
+ g.writef(`buf = append(buf, []byte("true")...)`)
+ g.writef("} else {")
+ g.writef(`buf = append(buf, []byte("false")...)`)
+ g.writef("}")
+ case types.Int, types.Int8, types.Int16, types.Int32, types.Int64:
+ g.it.Import("strconv")
+ g.writef("buf = strconv.AppendInt(buf, int64(%s), 10)", accessor)
+ case types.Uint, types.Uint8, types.Uint16, types.Uint32, types.Uint64:
+ g.it.Import("strconv")
+ g.writef("buf = strconv.AppendUint(buf, uint64(%s), 10)", accessor)
+ case types.String:
+ g.it.Import("strconv")
+ g.writef("buf = strconv.AppendQuote(buf, %s)", accessor)
+ default:
+ g.writef(`panic("TODO: %s (%T)")`, accessor, field.Kind)
+ }
+}
+
+func loadTypes(pkgName, buildTags string) (*packages.Package, map[string]*types.Named, error) {
+ cfg := &packages.Config{
+ Mode: packages.NeedTypes |
+ packages.NeedTypesInfo |
+ packages.NeedSyntax |
+ packages.NeedName,
+ Tests: false,
+ }
+ if buildTags != "" {
+ cfg.BuildFlags = []string{"-tags=" + buildTags}
+ }
+
+ pkgs, err := packages.Load(cfg, pkgName)
+ if err != nil {
+ return nil, nil, err
+ }
+ if len(pkgs) != 1 {
+ return nil, nil, fmt.Errorf("wrong number of packages: %d", len(pkgs))
+ }
+ pkg := pkgs[0]
+ return pkg, namedTypes(pkg), nil
+}
+
+func namedTypes(pkg *packages.Package) map[string]*types.Named {
+ nt := make(map[string]*types.Named)
+ for _, file := range pkg.Syntax {
+ for _, d := range file.Decls {
+ decl, ok := d.(*ast.GenDecl)
+ if !ok || decl.Tok != token.TYPE {
+ continue
+ }
+ for _, s := range decl.Specs {
+ spec, ok := s.(*ast.TypeSpec)
+ if !ok {
+ continue
+ }
+ typeNameObj, ok := pkg.TypesInfo.Defs[spec.Name]
+ if !ok {
+ continue
+ }
+ typ, ok := typeNameObj.Type().(*types.Named)
+ if !ok {
+ continue
+ }
+ nt[spec.Name.Name] = typ
+ }
+ }
+ }
+ return nt
+}