diff options
Diffstat (limited to 'control/controlclient/direct.go')
| -rw-r--r-- | control/controlclient/direct.go | 17 |
1 files changed, 11 insertions, 6 deletions
diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index c436bc8b1..d3167d6e3 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -6,6 +6,7 @@ package controlclient import ( "bufio" "bytes" + "cmp" "context" "encoding/binary" "encoding/json" @@ -53,7 +54,8 @@ import ( "tailscale.com/util/clientmetric" "tailscale.com/util/multierr" "tailscale.com/util/singleflight" - "tailscale.com/util/syspolicy" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/policyclient" "tailscale.com/util/systemd" "tailscale.com/util/testenv" "tailscale.com/util/zstdframe" @@ -76,6 +78,7 @@ type Direct struct { debugFlags []string skipIPForwardingCheck bool pinger Pinger + polc policyclient.Client // always non-nil popBrowser func(url string) // or nil c2nHandler http.Handler // or nil onClientVersion func(*tailcfg.ClientVersion) // or nil @@ -124,9 +127,10 @@ type Options struct { Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc DiscoPublicKey key.DiscoPublic Logf logger.Logf - HTTPTestClient *http.Client // optional HTTP client to use (for tests only) - NoiseTestClient *http.Client // optional HTTP client to use for noise RPCs (tests only) - DebugFlags []string // debug settings to send to control + PolicyClient policyclient.Client // or nil for none + HTTPTestClient *http.Client // optional HTTP client to use (for tests only) + NoiseTestClient *http.Client // optional HTTP client to use for noise RPCs (tests only) + DebugFlags []string // debug settings to send to control HealthTracker *health.Tracker PopBrowserURL func(url string) // optional func to open browser OnClientVersion func(*tailcfg.ClientVersion) // optional func to inform GUI of client version status @@ -296,6 +300,7 @@ func NewDirect(opts Options) (*Direct, error) { health: opts.HealthTracker, skipIPForwardingCheck: opts.SkipIPForwardingCheck, pinger: opts.Pinger, + polc: cmp.Or(opts.PolicyClient, policyclient.Client(policyclient.NoPolicyClient{})), popBrowser: opts.PopBrowserURL, onClientVersion: opts.OnClientVersion, onTailnetDefaultAutoUpdate: opts.OnTailnetDefaultAutoUpdate, @@ -606,7 +611,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new return regen, opt.URL, nil, err } - tailnet, err := syspolicy.GetString(syspolicy.Tailnet, "") + tailnet, err := c.polc.GetString(pkey.Tailnet, "") if err != nil { c.logf("unable to provide Tailnet field in register request. err: %v", err) } @@ -636,7 +641,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new AuthKey: authKey, } } - err = signRegisterRequest(&request, c.serverURL, c.serverLegacyKey, machinePrivKey.Public()) + err = signRegisterRequest(c.polc, &request, c.serverURL, c.serverLegacyKey, machinePrivKey.Public()) if err != nil { // If signing failed, clear all related fields request.SignatureType = tailcfg.SignatureNone |
