summaryrefslogtreecommitdiffhomepage
path: root/util/codegen/codegen.go
blob: 4265806ee3ed47019055065094d948220573994a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package codegen contains shared utilities for generating code.
package codegen

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/token"
	"go/types"
	"os"

	"golang.org/x/tools/go/packages"
)

// WriteFormatted writes code to path.
// It runs gofmt on it before writing;
// if gofmt fails, it writes code unchanged.
// Errors can include I/O errors and gofmt errors.
//
// The advantage of always writing code to path,
// even if gofmt fails, is that it makes debugging easier.
// The code can be long, but you need it in order to debug.
// It is nicer to work with it in a file than a terminal.
// It is also easier to interpret gofmt errors
// with an editor providing file and line numbers.
//
// If both writing and formatting fail, only the I/O error is returned.
// A gofmt error is returned prefixed with path and wraps the original
// go/format error, so callers can inspect it with errors.Is/errors.As.
func WriteFormatted(code []byte, path string) error {
	out, fmterr := format.Source(code)
	if fmterr != nil {
		out = code
	}
	ioerr := os.WriteFile(path, out, 0644)
	// Prefer I/O errors. They're usually easier to fix,
	// and until they're fixed you can't do much else.
	if ioerr != nil {
		return ioerr
	}
	if fmterr != nil {
		// path is joined to the error with no space, presumably so that a
		// positional formatter error reads as "path:line:col: ...".
		return fmt.Errorf("%s:%w", path, fmterr)
	}
	return nil
}

// NamedTypes returns all named types in pkg, keyed by their type name.
//
// Only top-level type declarations in pkg.Syntax are examined; types
// declared inside function bodies are not included. If pkg has no
// TypesInfo, or a type declaration has no object recorded in
// pkg.TypesInfo.Defs, that declaration is skipped instead of causing a
// nil-pointer panic.
func NamedTypes(pkg *packages.Package) map[string]*types.Named {
	nt := make(map[string]*types.Named)
	if pkg.TypesInfo == nil {
		return nt
	}
	for _, file := range pkg.Syntax {
		for _, d := range file.Decls {
			decl, ok := d.(*ast.GenDecl)
			if !ok || decl.Tok != token.TYPE {
				continue
			}
			for _, s := range decl.Specs {
				spec, ok := s.(*ast.TypeSpec)
				if !ok {
					continue
				}
				obj := pkg.TypesInfo.Defs[spec.Name]
				if obj == nil {
					// No type-checker object for this name; nothing to report.
					continue
				}
				named, ok := obj.Type().(*types.Named)
				if !ok {
					continue
				}
				nt[spec.Name.Name] = named
			}
		}
	}
	return nt
}

// AssertStructUnchanged generates code that asserts at compile time that type t is unchanged.
// thisPkg is the package containing t.
// tname is the named type corresponding to t.
// ctx is a single-word context for this assertion, such as "Clone".
// If non-nil, AssertStructUnchanged will add elements to imports
// for each package path that the caller must import for the returned code to compile.
//
// The generated code converts an empty struct literal with t's field list
// to tname. Struct conversion requires matching field names, field types and
// embeddedness, so any change to t's fields breaks compilation of the output.
func AssertStructUnchanged(t *types.Struct, thisPkg *types.Package, tname, ctx string, imports map[string]struct{}) []byte {
	buf := new(bytes.Buffer)
	w := func(format string, args ...interface{}) {
		fmt.Fprintf(buf, format+"\n", args...)
	}
	// qualify renders type names relative to thisPkg and records every
	// foreign package it is asked about. A single field type such as
	// map[a.K]b.V can reference several packages, and all must be imported.
	qualify := func(pkg *types.Package) string {
		if pkg == thisPkg {
			return ""
		}
		if imports != nil && pkg.Path() != "" {
			imports[pkg.Path()] = struct{}{}
		}
		return pkg.Name()
	}

	w("// A compilation failure here means this code must be regenerated, with the command at the top of this file.")
	w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname)

	for i := 0; i < t.NumFields(); i++ {
		field := t.Field(i)
		qname := types.TypeString(field.Type(), qualify)
		if field.Embedded() {
			// An embedded field must be emitted as embedded: a named field of
			// the same type is a different struct type and would not convert.
			w("\t%s", qname)
			continue
		}
		w("\t%s %s", field.Name(), qname)
	}

	w("}{})\n")
	return buf.Bytes()
}

// importedName returns the source representation of t as seen from thisPkg.
// Types declared in thisPkg are unqualified; types from other packages are
// qualified by their package name.
//
// importPkg is the import path of a package other than thisPkg that t
// references, or "" if there is none. If t references more than one other
// package (for example map[a.K]b.V), only the last package encountered while
// rendering t is reported.
func importedName(t types.Type, thisPkg *types.Package) (qualifiedName, importPkg string) {
	qual := func(pkg *types.Package) string {
		if thisPkg == pkg {
			return ""
		}
		importPkg = pkg.Path()
		return pkg.Name()
	}
	// Render t in its own statement before reading importPkg. qual sets
	// importPkg as a side effect. Within a single return statement, Go does
	// not specify whether a plain variable operand is read before or after a
	// function call in the same statement.
	qualifiedName = types.TypeString(t, qual)
	return qualifiedName, importPkg
}

// ContainsPointers reports whether typ contains any pointers,
// either explicitly or implicitly.
// It has special handling for some types that contain pointers
// that we know are free from memory aliasing/mutation concerns.
//
// Arrays and structs are searched recursively. Channels, interfaces, maps,
// pointers, slices and unsafe.Pointer count as pointers. Other basic types,
// including strings, and func types are reported as pointer-free.
func ContainsPointers(typ types.Type) bool {
	switch typ.String() {
	case "time.Time":
		// time.Time contains a pointer that does not need copying
		return false
	case "inet.af/netaddr.IP":
		// NOTE(review): treated as pointer-free like time.Time; presumably
		// any internal pointer refers to immutable data — confirm.
		return false
	}
	switch ft := typ.Underlying().(type) {
	case *types.Array:
		return ContainsPointers(ft.Elem())
	case *types.Basic:
		// unsafe.Pointer is a raw pointer. Strings are reported as
		// pointer-free, presumably because they are immutable.
		return ft.Kind() == types.UnsafePointer
	case *types.Chan, *types.Map, *types.Pointer, *types.Slice:
		return true
	case *types.Interface:
		return true // a little too broad
	case *types.Struct:
		for i := 0; i < ft.NumFields(); i++ {
			if ContainsPointers(ft.Field(i).Type()) {
				return true
			}
		}
	}
	// NOTE(review): func types (*types.Signature) also land here and are
	// reported as pointer-free; confirm that is intended.
	return false
}