summaryrefslogtreecommitdiffhomepage
path: root/ssh
diff options
context:
space:
mode:
authorBrad Fitzpatrick <bradfitz@tailscale.com>2022-04-01 12:57:12 -0700
committerBrad Fitzpatrick <bradfitz@tailscale.com>2022-04-01 12:57:12 -0700
commit79483a1e5e86ddfe9c59d760809b0fea830ddc84 (patch)
treecdba04175d03017edec27e0776177d188477ad7d /ssh
parent9f604f2bd3b48fc4464727b7ac6f5ee56d02413c (diff)
downloadtailscale-bradfitz/ssh_policy_earlier.tar.xz
tailscale-bradfitz/ssh_policy_earlier.zip
tailcfg, ssh/tailssh: optionally support SSH public keys in wire policybradfitz/ssh_policy_earlier
Updates #3802 Change-Id: I756dc2d579a16757537142283d791f1d0319f4f0 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Diffstat (limited to 'ssh')
-rw-r--r--ssh/tailssh/tailssh.go164
-rw-r--r--ssh/tailssh/tailssh_test.go5
2 files changed, 148 insertions, 21 deletions
diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go
index 3fcd87d09..64e10e132 100644
--- a/ssh/tailssh/tailssh.go
+++ b/ssh/tailssh/tailssh.go
@@ -9,8 +9,10 @@
package tailssh
import (
+ "bytes"
"context"
"crypto/rand"
+ "encoding/base64"
"encoding/json"
"errors"
"fmt"
@@ -77,9 +79,15 @@ func (srv *server) newSSHServer() (*ssh.Server, error) {
Version: "SSH-2.0-Tailscale",
LocalPortForwardingCallback: srv.mayForwardLocalPortTo,
NoClientAuthCallback: func(m gossh.ConnMetadata) (*gossh.Permissions, error) {
- srv.logf("SSH connection from %v for %q; client ver %q", m.RemoteAddr(), m.User(), m.ClientVersion())
+ if srv.askForCert(m.User(), m.LocalAddr(), m.RemoteAddr()) {
+ return nil, errors.New("cert required") // any non-nil error will do
+ }
return nil, nil
},
+ PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
+ srv.logf("SSH public key %T %#v", key, key)
+ return true // rejected later, after accepting connections
+ },
}
for k, v := range ssh.DefaultRequestHandlers {
ss.RequestHandlers[k] = v
@@ -124,6 +132,31 @@ func (srv *server) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string
return ss.action.AllowLocalPortForwarding
}
+// askForCert reports whether the SSH server, during the auth negotiation phase,
+// should requires that the client send an SSH cert.
+func (srv *server) askForCert(sshUser string, localAddr, remoteAddr net.Addr) bool {
+ pol, ok := srv.sshPolicy()
+ if !ok {
+ return false
+ }
+ a, ci, _, err := srv.evaluatePolicy(sshUser, localAddr, remoteAddr, nil)
+ if err == nil && (a.Accept || a.HoldAndDelegate != "") {
+ // Policy doesn't require a cert.
+ return false
+ }
+
+ // Is there any rule that looks like it'd require a cert for
+ // this sshUser?
+ for _, r := range pol.Rules {
+ for _, p := range r.Principals {
+ if principalMatchesTailscaleIdentity(p, ci) && len(p.Certs) > 0 {
+ return true
+ }
+ }
+ }
+ return false
+}
+
// sshPolicy returns the SSHPolicy for current node.
// If there is no SSHPolicy in the netmap, it returns a debugPolicy
// if one is defined.
@@ -170,7 +203,7 @@ func asTailscaleIPPort(a net.Addr) (netaddr.IPPort, error) {
// evaluatePolicy returns the SSHAction, sshConnInfo and localUser
// after evaluating the sshUser and remoteAddr against the SSHPolicy.
// The remoteAddr and localAddr params must be Tailscale IPs.
-func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr net.Addr) (_ *tailcfg.SSHAction, _ *sshConnInfo, localUser string, _ error) {
+func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr net.Addr, pubKey ssh.PublicKey) (_ *tailcfg.SSHAction, _ *sshConnInfo, localUser string, _ error) {
logf := srv.logf
lb := srv.lb
logf("Handling SSH from %v for user %v", remoteAddr, sshUser)
@@ -194,12 +227,14 @@ func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr net.Addr
}
ci := &sshConnInfo{
- now: time.Now(),
- sshUser: sshUser,
- src: srcIPP,
- dst: dstIPP,
- node: node,
- uprof: &uprof,
+ now: time.Now(),
+ fetchPublicKeysURL: srv.fetchPublicKeysURL,
+ sshUser: sshUser,
+ src: srcIPP,
+ dst: dstIPP,
+ node: node,
+ uprof: &uprof,
+ pubKey: pubKey,
}
a, localUser, ok := evalSSHPolicy(pol, ci)
if !ok {
@@ -208,12 +243,36 @@ func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr net.Addr
return a, ci, localUser, nil
}
+func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
+ if !strings.HasPrefix(url, "https://") {
+ return nil, errors.New("invalid URL scheme")
+ }
+ // TODO(bradfitz): add caching
+
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+ if err != nil {
+ return nil, err
+ }
+ res, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer res.Body.Close()
+ if res.StatusCode != http.StatusOK {
+ return nil, errors.New(res.Status)
+ }
+ all, err := io.ReadAll(io.LimitReader(res.Body, 1<<10))
+ return strings.Split(string(all), "\n"), err
+}
+
// handleSSH is invoked when a new SSH connection attempt is made.
func (srv *server) handleSSH(s ssh.Session) {
logf := srv.logf
sshUser := s.User()
- action, ci, localUser, err := srv.evaluatePolicy(sshUser, s.LocalAddr(), s.RemoteAddr())
+ action, ci, localUser, err := srv.evaluatePolicy(sshUser, s.LocalAddr(), s.RemoteAddr(), s.PublicKey())
if err != nil {
logf(err.Error())
s.Exit(1)
@@ -609,6 +668,10 @@ type sshConnInfo struct {
// now is the time to consider the present moment for the
// purposes of rule evaluation.
now time.Time
+ // fetchPublicKeysURL, if non-nil, is a func to fetch the public
+ // keys of a URL. The strings are in the the typical public
+ // key "type base64-string [comment]" format seen at e.g. https://github.com/USER.keys
+ fetchPublicKeysURL func(url string) ([]string, error)
// sshUser is the requested local SSH username ("root", "alice", etc).
sshUser string
@@ -624,6 +687,11 @@ type sshConnInfo struct {
// uprof is node's UserProfile.
uprof *tailcfg.UserProfile
+
+ // pubKey is the public key presented by the client, or nil
+ // if they haven't yet sent one (as in the early "none" phase
+ // of authentication negotiation).
+ pubKey ssh.PublicKey
}
func evalSSHPolicy(pol *tailcfg.SSHPolicy, ci *sshConnInfo) (a *tailcfg.SSHAction, localUser string, ok bool) {
@@ -654,15 +722,15 @@ func matchRule(r *tailcfg.SSHRule, ci *sshConnInfo) (a *tailcfg.SSHAction, local
if r.RuleExpires != nil && ci.now.After(*r.RuleExpires) {
return nil, "", errRuleExpired
}
- if !matchesPrincipal(r.Principals, ci) {
- return nil, "", errPrincipalMatch
- }
if !r.Action.Reject || r.SSHUsers != nil {
localUser = mapLocalUser(r.SSHUsers, ci.sshUser)
if localUser == "" {
return nil, "", errUserMatch
}
}
+ if !anyPrincipalMatches(r.Principals, ci) {
+ return nil, "", errPrincipalMatch
+ }
return r.Action, localUser, nil
}
@@ -677,29 +745,85 @@ func mapLocalUser(ruleSSHUsers map[string]string, reqSSHUser string) (localUser
return v
}
-func matchesPrincipal(ps []*tailcfg.SSHPrincipal, ci *sshConnInfo) bool {
+func anyPrincipalMatches(ps []*tailcfg.SSHPrincipal, ci *sshConnInfo) bool {
for _, p := range ps {
if p == nil {
continue
}
- if p.Any {
+ if principalMatches(p, ci) {
return true
}
- if !p.Node.IsZero() && ci.node != nil && p.Node == ci.node.StableID {
+ }
+ return false
+}
+
+func principalMatches(p *tailcfg.SSHPrincipal, ci *sshConnInfo) bool {
+ return principalMatchesTailscaleIdentity(p, ci) &&
+ principalMatchesCert(p, ci)
+}
+
+// principalMatchesTailscaleIdentity reports whether one of p's four fields
+// that match the Tailscale identity match (Node, NodeIP, UserLogin, Any).
+// This function does not consider Certs.
+func principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal, ci *sshConnInfo) bool {
+ if p.Any {
+ return true
+ }
+ if !p.Node.IsZero() && ci.node != nil && p.Node == ci.node.StableID {
+ return true
+ }
+ if p.NodeIP != "" {
+ if ip, _ := netaddr.ParseIP(p.NodeIP); ip == ci.src.IP() {
return true
}
- if p.NodeIP != "" {
- if ip, _ := netaddr.ParseIP(p.NodeIP); ip == ci.src.IP() {
- return true
- }
+ }
+ if p.UserLogin != "" && ci.uprof != nil && ci.uprof.LoginName == p.UserLogin {
+ return true
+ }
+ return false
+}
+
+func principalMatchesCert(p *tailcfg.SSHPrincipal, ci *sshConnInfo) bool {
+ if len(p.Certs) == 0 {
+ return true
+ }
+ if ci.pubKey == nil {
+ return false
+ }
+ certs := p.Certs
+ if len(certs) == 1 && strings.HasPrefix(certs[0], "https://") {
+ if ci.fetchPublicKeysURL == nil {
+ // TODO: log?
+ return false
}
- if p.UserLogin != "" && ci.uprof != nil && ci.uprof.LoginName == p.UserLogin {
+ var err error
+ certs, err = ci.fetchPublicKeysURL(certs[0])
+ if err != nil {
+ // TODO: log?
+ return false
+ }
+ }
+ for _, cert := range certs {
+ if pubKeyMatchesAuthorizedKey(ci.pubKey, cert) {
return true
}
}
return false
}
+func pubKeyMatchesAuthorizedKey(pubKey ssh.PublicKey, wantKey string) bool {
+ wantKeyType, rest, ok := strings.Cut(wantKey, " ")
+ if !ok {
+ return false
+ }
+ if pubKey.Type() != wantKeyType {
+ return false
+ }
+ wantKeyB64, _, _ := strings.Cut(rest, " ")
+ wantKeyData, _ := base64.StdEncoding.DecodeString(wantKeyB64)
+ return len(wantKeyData) > 0 && bytes.Equal(pubKey.Marshal(), wantKeyData)
+}
+
func randBytes(n int) []byte {
b := make([]byte, n)
if _, err := rand.Read(b); err != nil {
diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go
index afae4984a..23e2540a3 100644
--- a/ssh/tailssh/tailssh_test.go
+++ b/ssh/tailssh/tailssh_test.go
@@ -63,7 +63,10 @@ func TestMatchRule(t *testing.T) {
name: "no-principal",
rule: &tailcfg.SSHRule{
Action: someAction,
- },
+ SSHUsers: map[string]string{
+ "*": "ubuntu",
+ }},
+ ci: &sshConnInfo{},
wantErr: errPrincipalMatch,
},
{