summaryrefslogtreecommitdiffhomepage
path: root/ssh/tailssh/tailssh_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'ssh/tailssh/tailssh_test.go')
-rw-r--r--ssh/tailssh/tailssh_test.go167
1 files changed, 76 insertions, 91 deletions
diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go
index 581c8be82..c750370a2 100644
--- a/ssh/tailssh/tailssh_test.go
+++ b/ssh/tailssh/tailssh_test.go
@@ -64,31 +64,20 @@ func TestMatchRule(t *testing.T) {
tests := []struct {
name string
rule *tailcfg.SSHRule
- ci *sshConnInfo
+ ci sshConnInfo
wantErr error
wantUser string
wantAcceptEnv []string
}{
{
- name: "invalid-conn",
- rule: &tailcfg.SSHRule{
- Action: someAction,
- Principals: []*tailcfg.SSHPrincipal{{Any: true}},
- SSHUsers: map[string]string{
- "*": "ubuntu",
- },
- },
- wantErr: errInvalidConn,
- },
- {
name: "nil-rule",
- ci: &sshConnInfo{},
+ ci: sshConnInfo{},
rule: nil,
wantErr: errNilRule,
},
{
name: "nil-action",
- ci: &sshConnInfo{},
+ ci: sshConnInfo{},
rule: &tailcfg.SSHRule{},
wantErr: errNilAction,
},
@@ -98,7 +87,7 @@ func TestMatchRule(t *testing.T) {
Action: someAction,
RuleExpires: ptr.To(time.Unix(100, 0)),
},
- ci: &sshConnInfo{},
+ ci: sshConnInfo{},
wantErr: errRuleExpired,
},
{
@@ -108,7 +97,7 @@ func TestMatchRule(t *testing.T) {
SSHUsers: map[string]string{
"*": "ubuntu",
}},
- ci: &sshConnInfo{},
+ ci: sshConnInfo{},
wantErr: errPrincipalMatch,
},
{
@@ -117,7 +106,7 @@ func TestMatchRule(t *testing.T) {
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
wantErr: errUserMatch,
},
{
@@ -129,7 +118,7 @@ func TestMatchRule(t *testing.T) {
"*": "ubuntu",
},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
wantUser: "ubuntu",
},
{
@@ -144,7 +133,7 @@ func TestMatchRule(t *testing.T) {
"*": "ubuntu",
},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
wantUser: "ubuntu",
},
{
@@ -157,7 +146,7 @@ func TestMatchRule(t *testing.T) {
"alice": "thealice",
},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
wantUser: "thealice",
},
{
@@ -171,7 +160,7 @@ func TestMatchRule(t *testing.T) {
},
AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
wantUser: "thealice",
wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
},
@@ -181,7 +170,7 @@ func TestMatchRule(t *testing.T) {
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
Action: &tailcfg.SSHAction{Reject: true},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
},
{
name: "match-principal-node-ip",
@@ -190,7 +179,7 @@ func TestMatchRule(t *testing.T) {
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "1.2.3.4"}},
SSHUsers: map[string]string{"*": "ubuntu"},
},
- ci: &sshConnInfo{src: netip.MustParseAddrPort("1.2.3.4:30343")},
+ ci: sshConnInfo{src: netip.MustParseAddrPort("1.2.3.4:30343")},
wantUser: "ubuntu",
},
{
@@ -200,7 +189,7 @@ func TestMatchRule(t *testing.T) {
Principals: []*tailcfg.SSHPrincipal{{Node: "some-node-ID"}},
SSHUsers: map[string]string{"*": "ubuntu"},
},
- ci: &sshConnInfo{node: (&tailcfg.Node{StableID: "some-node-ID"}).View()},
+ ci: sshConnInfo{node: (&tailcfg.Node{StableID: "some-node-ID"}).View()},
wantUser: "ubuntu",
},
{
@@ -210,7 +199,7 @@ func TestMatchRule(t *testing.T) {
Principals: []*tailcfg.SSHPrincipal{{UserLogin: "foo@bar.com"}},
SSHUsers: map[string]string{"*": "ubuntu"},
},
- ci: &sshConnInfo{uprof: tailcfg.UserProfile{LoginName: "foo@bar.com"}},
+ ci: sshConnInfo{uprof: tailcfg.UserProfile{LoginName: "foo@bar.com"}},
wantUser: "ubuntu",
},
{
@@ -222,7 +211,7 @@ func TestMatchRule(t *testing.T) {
"*": "=",
},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
wantUser: "alice",
},
}
@@ -254,7 +243,7 @@ func TestEvalSSHPolicy(t *testing.T) {
tests := []struct {
name string
policy *tailcfg.SSHPolicy
- ci *sshConnInfo
+ ci sshConnInfo
wantResult evalResult
wantUser string
wantAcceptEnv []string
@@ -298,7 +287,7 @@ func TestEvalSSHPolicy(t *testing.T) {
},
},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
wantUser: "thealice",
wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
wantResult: accepted,
@@ -308,7 +297,7 @@ func TestEvalSSHPolicy(t *testing.T) {
policy: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
wantUser: "",
wantAcceptEnv: nil,
wantResult: rejected,
@@ -349,7 +338,7 @@ func TestEvalSSHPolicy(t *testing.T) {
},
},
},
- ci: &sshConnInfo{sshUser: "alice"},
+ ci: sshConnInfo{sshUser: "alice"},
wantUser: "",
wantAcceptEnv: nil,
wantResult: rejectedUser,
@@ -1100,7 +1089,7 @@ func TestSSH(t *testing.T) {
t.Fatal(err)
}
sc.localUser = um
- sc.info = &sshConnInfo{
+ sc.info = sshConnInfo{
sshUser: "test",
src: netip.MustParseAddrPort("1.2.3.4:32342"),
dst: netip.MustParseAddrPort("1.2.3.5:22"),
@@ -1318,76 +1307,72 @@ func TestStdOsUserUserAssumptions(t *testing.T) {
}
}
-func TestOnPolicyChangeHandlesNilLocalUser(t *testing.T) {
- synctest.Test(t, func(t *testing.T) {
- srv := &server{
- logf: tstest.WhileTestRunningLogger(t),
- lb: &localState{
- sshEnabled: true,
- matchingRule: newSSHRule(&tailcfg.SSHAction{Accept: true}),
- },
- }
- c := &conn{
- srv: srv,
- info: &sshConnInfo{sshUser: "alice"},
- }
- srv.activeConns = map[*conn]bool{c: true}
-
- srv.OnPolicyChange()
-
- synctest.Wait()
- })
-}
-
-func TestRaceWriteAndReadConnInfoAndLocalUser(t *testing.T) {
- synctest.Test(t, func(t *testing.T) {
- srv := &server{
- logf: tstest.WhileTestRunningLogger(t),
- lb: &localState{
- sshEnabled: true,
- matchingRule: newSSHRule(&tailcfg.SSHAction{Accept: true}),
- },
- }
- c := &conn{
- srv: srv,
- info: &sshConnInfo{sshUser: "alice"},
- }
- srv.activeConns = map[*conn]bool{c: true}
+func TestOnPolicyChangeDefersValidationOnEmptyLocalUser(t *testing.T) {
+ tests := []struct {
+ name string
+ sshRule *tailcfg.SSHRule
+ wantCancelOnValidation bool
+ }{
+ {
+ name: "defer-then-accept-when-allowed",
+ sshRule: newSSHRule(&tailcfg.SSHAction{Accept: true}),
+ wantCancelOnValidation: false,
+ },
+ {
+ name: "defer-then-reject-when-not-allowed",
+ sshRule: newSSHRule(&tailcfg.SSHAction{Reject: true}),
+ wantCancelOnValidation: true,
+ },
+ }
- fakeClientAuth := func() {
- c.mu.Lock()
- c.info = &sshConnInfo{sshUser: "alice"}
- c.mu.Unlock()
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
- c.mu.Lock()
- c.localUser = &userMeta{User: user.User{Username: currentUser}}
- c.mu.Unlock()
- }
+ synctest.Test(t, func(t *testing.T) {
+ srv := &server{
+ logf: tstest.WhileTestRunningLogger(t),
+ lb: &localState{
+ sshEnabled: true,
+ matchingRule: tt.sshRule,
+ },
+ }
+ c := &conn{
+ srv: srv,
+ info: sshConnInfo{sshUser: "alice"},
+ }
+ srv.activeConns = map[*conn]bool{c: true}
+ ctx, cancel := context.WithCancelCause(context.Background())
+ ss := &sshSession{ctx: ctx, cancelCtx: cancel}
+ c.sessions = []*sshSession{ss}
- // Simulate a race between clientAuth() writing and OnPolicyChange reading a connection's info and localUser.
- done := make(chan struct{})
- go func() {
- for i := 0; i < 100; i++ {
+ srv.OnPolicyChange()
+ synctest.Wait()
select {
- case <-done:
- return
+ case <-ctx.Done():
+ t.Fatal("expected deferral of cancellation decision while localUser unset but session got canceled")
default:
- fakeClientAuth()
}
- }
- }()
- go func() {
- for i := 0; i < 100; i++ {
+ c.mu.Lock()
+ c.info.isSet = true
+ c.localUser = userMeta{User: user.User{Username: currentUser}}
+ c.mu.Unlock()
+
+ srv.OnPolicyChange()
+ synctest.Wait()
select {
- case <-done:
- return
+ case <-ctx.Done():
+ if !tt.wantCancelOnValidation {
+ t.Fatal("valid session shouldn't have been canceled")
+ }
default:
- srv.OnPolicyChange()
+ if tt.wantCancelOnValidation {
+ t.Fatal("invalid session should have been canceled but it wasn't")
+ }
}
- }
- }()
- })
+ })
+ })
+ }
}
func mockRecordingServer(t *testing.T, handleRecord http.HandlerFunc) *httptest.Server {