diff options
Diffstat (limited to 'ssh/tailssh/tailssh_test.go')
| -rw-r--r-- | ssh/tailssh/tailssh_test.go | 167 |
1 files changed, 76 insertions, 91 deletions
diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 581c8be82..c750370a2 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -64,31 +64,20 @@ func TestMatchRule(t *testing.T) { tests := []struct { name string rule *tailcfg.SSHRule - ci *sshConnInfo + ci sshConnInfo wantErr error wantUser string wantAcceptEnv []string }{ { - name: "invalid-conn", - rule: &tailcfg.SSHRule{ - Action: someAction, - Principals: []*tailcfg.SSHPrincipal{{Any: true}}, - SSHUsers: map[string]string{ - "*": "ubuntu", - }, - }, - wantErr: errInvalidConn, - }, - { name: "nil-rule", - ci: &sshConnInfo{}, + ci: sshConnInfo{}, rule: nil, wantErr: errNilRule, }, { name: "nil-action", - ci: &sshConnInfo{}, + ci: sshConnInfo{}, rule: &tailcfg.SSHRule{}, wantErr: errNilAction, }, @@ -98,7 +87,7 @@ func TestMatchRule(t *testing.T) { Action: someAction, RuleExpires: ptr.To(time.Unix(100, 0)), }, - ci: &sshConnInfo{}, + ci: sshConnInfo{}, wantErr: errRuleExpired, }, { @@ -108,7 +97,7 @@ func TestMatchRule(t *testing.T) { SSHUsers: map[string]string{ "*": "ubuntu", }}, - ci: &sshConnInfo{}, + ci: sshConnInfo{}, wantErr: errPrincipalMatch, }, { @@ -117,7 +106,7 @@ func TestMatchRule(t *testing.T) { Action: someAction, Principals: []*tailcfg.SSHPrincipal{{Any: true}}, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, wantErr: errUserMatch, }, { @@ -129,7 +118,7 @@ func TestMatchRule(t *testing.T) { "*": "ubuntu", }, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, wantUser: "ubuntu", }, { @@ -144,7 +133,7 @@ func TestMatchRule(t *testing.T) { "*": "ubuntu", }, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, wantUser: "ubuntu", }, { @@ -157,7 +146,7 @@ func TestMatchRule(t *testing.T) { "alice": "thealice", }, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, wantUser: "thealice", }, { @@ -171,7 +160,7 @@ func TestMatchRule(t *testing.T) { }, AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"}, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, wantUser: "thealice", wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"}, }, @@ -181,7 +170,7 @@ func TestMatchRule(t *testing.T) { Principals: []*tailcfg.SSHPrincipal{{Any: true}}, Action: &tailcfg.SSHAction{Reject: true}, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, }, { name: "match-principal-node-ip", @@ -190,7 +179,7 @@ func TestMatchRule(t *testing.T) { Principals: []*tailcfg.SSHPrincipal{{NodeIP: "1.2.3.4"}}, SSHUsers: map[string]string{"*": "ubuntu"}, }, - ci: &sshConnInfo{src: netip.MustParseAddrPort("1.2.3.4:30343")}, + ci: sshConnInfo{src: netip.MustParseAddrPort("1.2.3.4:30343")}, wantUser: "ubuntu", }, { @@ -200,7 +189,7 @@ func TestMatchRule(t *testing.T) { Principals: []*tailcfg.SSHPrincipal{{Node: "some-node-ID"}}, SSHUsers: map[string]string{"*": "ubuntu"}, }, - ci: &sshConnInfo{node: (&tailcfg.Node{StableID: "some-node-ID"}).View()}, + ci: sshConnInfo{node: (&tailcfg.Node{StableID: "some-node-ID"}).View()}, wantUser: "ubuntu", }, { @@ -210,7 +199,7 @@ func TestMatchRule(t *testing.T) { Principals: []*tailcfg.SSHPrincipal{{UserLogin: "foo@bar.com"}}, SSHUsers: map[string]string{"*": "ubuntu"}, }, - ci: &sshConnInfo{uprof: tailcfg.UserProfile{LoginName: "foo@bar.com"}}, + ci: sshConnInfo{uprof: tailcfg.UserProfile{LoginName: "foo@bar.com"}}, wantUser: "ubuntu", }, { @@ -222,7 +211,7 @@ func TestMatchRule(t *testing.T) { "*": "=", }, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, wantUser: "alice", }, } @@ -254,7 +243,7 @@ func TestEvalSSHPolicy(t *testing.T) { tests := []struct { name string policy *tailcfg.SSHPolicy - ci *sshConnInfo + ci sshConnInfo wantResult evalResult wantUser string wantAcceptEnv []string @@ -298,7 +287,7 @@ func TestEvalSSHPolicy(t *testing.T) { }, }, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, wantUser: "thealice", wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"}, wantResult: accepted, @@ -308,7 +297,7 @@ func TestEvalSSHPolicy(t *testing.T) { policy: &tailcfg.SSHPolicy{ Rules: []*tailcfg.SSHRule{}, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, wantUser: "", wantAcceptEnv: nil, wantResult: rejected, @@ -349,7 +338,7 @@ func TestEvalSSHPolicy(t *testing.T) { }, }, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, wantUser: "", wantAcceptEnv: nil, wantResult: rejectedUser, @@ -1100,7 +1089,7 @@ func TestSSH(t *testing.T) { t.Fatal(err) } sc.localUser = um - sc.info = &sshConnInfo{ + sc.info = sshConnInfo{ sshUser: "test", src: netip.MustParseAddrPort("1.2.3.4:32342"), dst: netip.MustParseAddrPort("1.2.3.5:22"), @@ -1318,76 +1307,72 @@ func TestStdOsUserUserAssumptions(t *testing.T) { } } -func TestOnPolicyChangeHandlesNilLocalUser(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - srv := &server{ - logf: tstest.WhileTestRunningLogger(t), - lb: &localState{ - sshEnabled: true, - matchingRule: newSSHRule(&tailcfg.SSHAction{Accept: true}), - }, - } - c := &conn{ - srv: srv, - info: &sshConnInfo{sshUser: "alice"}, - } - srv.activeConns = map[*conn]bool{c: true} - - srv.OnPolicyChange() - - synctest.Wait() - }) -} - -func TestRaceWriteAndReadConnInfoAndLocalUser(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - srv := &server{ - logf: tstest.WhileTestRunningLogger(t), - lb: &localState{ - sshEnabled: true, - matchingRule: newSSHRule(&tailcfg.SSHAction{Accept: true}), - }, - } - c := &conn{ - srv: srv, - info: &sshConnInfo{sshUser: "alice"}, - } - srv.activeConns = map[*conn]bool{c: true} +func TestOnPolicyChangeDefersValidationOnEmptyLocalUser(t *testing.T) { + tests := []struct { + name string + sshRule *tailcfg.SSHRule + wantCancelOnValidation bool + }{ + { + name: "defer-then-accept-when-allowed", + sshRule: newSSHRule(&tailcfg.SSHAction{Accept: true}), + wantCancelOnValidation: false, + }, + { + name: "defer-then-reject-when-not-allowed", + sshRule: newSSHRule(&tailcfg.SSHAction{Reject: true}), + wantCancelOnValidation: true, + }, + } - fakeClientAuth := func() { - c.mu.Lock() - c.info = &sshConnInfo{sshUser: "alice"} - c.mu.Unlock() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { - c.mu.Lock() - c.localUser = &userMeta{User: user.User{Username: currentUser}} - c.mu.Unlock() - } + synctest.Test(t, func(t *testing.T) { + srv := &server{ + logf: tstest.WhileTestRunningLogger(t), + lb: &localState{ + sshEnabled: true, + matchingRule: tt.sshRule, + }, + } + c := &conn{ + srv: srv, + info: sshConnInfo{sshUser: "alice"}, + } + srv.activeConns = map[*conn]bool{c: true} + ctx, cancel := context.WithCancelCause(context.Background()) + ss := &sshSession{ctx: ctx, cancelCtx: cancel} + c.sessions = []*sshSession{ss} - // Simulate a race between clientAuth() writing and OnPolicyChange reading a connection's info and localUser. - done := make(chan struct{}) - go func() { - for i := 0; i < 100; i++ { + srv.OnPolicyChange() + synctest.Wait() select { - case <-done: - return + case <-ctx.Done(): + t.Fatal("expected deferral of cancellation decision while localUser unset but session got canceled") default: - fakeClientAuth() } - } - }() - go func() { - for i := 0; i < 100; i++ { + c.mu.Lock() + c.info.isSet = true + c.localUser = userMeta{User: user.User{Username: currentUser}} + c.mu.Unlock() + + srv.OnPolicyChange() + synctest.Wait() select { - case <-done: - return + case <-ctx.Done(): + if !tt.wantCancelOnValidation { + t.Fatal("valid session shouldn't have been canceled") + } default: - srv.OnPolicyChange() + if tt.wantCancelOnValidation { + t.Fatal("invalid session should have been canceled but it wasn't") + } } - } - }() - }) + }) + }) + } } func mockRecordingServer(t *testing.T, handleRecord http.HandlerFunc) *httptest.Server { |
