summaryrefslogtreecommitdiffhomepage
path: root/cmd/viewer/viewer_test.go
blob: 1e24b705069d79ff491ce7623a2bdf9498ece865 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package main

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"go/types"
	"testing"

	"tailscale.com/util/codegen"
)

// TestViewerImports checks which imports genView records on the import
// tracker when generating views for simple struct types.
//
// For each case, content is parsed as a file in package "test" and
// type-checked, genView is invoked for every name in typeNames, and the
// tracker is then queried with tracker.Has for each pair in wantImports
// (the two elements of a pair are passed as Has's two arguments).
func TestViewerImports(t *testing.T) {
	tests := []struct {
		name        string
		content     string
		typeNames   []string
		wantImports [][2]string
	}{
		{
			name:        "Map",
			content:     `type Test struct { Map map[string]int }`,
			typeNames:   []string{"Test"},
			wantImports: [][2]string{{"", "tailscale.com/types/views"}},
		},
		{
			name:        "Slice",
			content:     `type Test struct { Slice []int }`,
			typeNames:   []string{"Test"},
			wantImports: [][2]string{{"", "tailscale.com/types/views"}},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			fset := token.NewFileSet()
			f, err := parser.ParseFile(fset, "test.go", "package test\n\n"+tt.content, 0)
			if err != nil {
				// A fixture that does not parse must fail the subtest; printing
				// and returning would report a pass without checking anything.
				t.Fatal(fmt.Errorf("parsing test source: %w", err))
			}

			info := &types.Info{
				Types: make(map[ast.Expr]types.TypeAndValue),
			}

			conf := types.Config{}
			pkg, err := conf.Check("", fset, []*ast.File{f}, info)
			if err != nil {
				t.Fatal(err)
			}
			var fieldComments map[fieldNameKey]string // don't need it for this test.

			var output bytes.Buffer
			tracker := codegen.NewImportTracker(pkg)
			for _, name := range tt.typeNames {
				typeName, ok := pkg.Scope().Lookup(name).(*types.TypeName)
				if !ok {
					t.Fatalf("type %q does not exist", name)
				}
				namedType, ok := typeName.Type().(*types.Named)
				if !ok {
					t.Fatalf("%q is not a named type", name)
				}
				genView(&output, tracker, namedType, fieldComments)
			}

			// Report every missing import rather than stopping at the first.
			for _, want := range tt.wantImports {
				if !tracker.Has(want[0], want[1]) {
					t.Errorf("missing import %q", want)
				}
			}
		})
	}
}