summaryrefslogtreecommitdiffhomepage
path: root/ipn
diff options
context:
space:
mode:
Diffstat (limited to 'ipn')
-rw-r--r--ipn/ipnlocal/cert.go86
-rw-r--r--ipn/ipnlocal/cert_js.go2
-rw-r--r--ipn/ipnlocal/cert_test.go9
-rw-r--r--ipn/ipnlocal/serve.go4
-rw-r--r--ipn/localapi/cert.go2
5 files changed, 63 insertions, 40 deletions
diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go
index f5384276c..627cc7872 100644
--- a/ipn/ipnlocal/cert.go
+++ b/ipn/ipnlocal/cert.go
@@ -53,8 +53,8 @@ var (
// populate the on-disk cache and the rest should use that.
acmeMu sync.Mutex
- renewMu sync.Mutex // lock order: don't hold acmeMu and renewMu at the same time
- lastRenewCheck = map[string]time.Time{}
+ renewMu sync.Mutex // lock order: acmeMu before renewMu
+ renewCertAt = map[string]time.Time{}
)
// certDir returns (creating if needed) the directory in which cached
@@ -80,9 +80,15 @@ func (b *LocalBackend) certDir() (string, error) {
var acmeDebug = envknob.RegisterBool("TS_DEBUG_ACME")
-// getCertPEM gets the KeyPair for domain, either from cache, via the ACME
-// process, or from cache and kicking off an async ACME renewal.
-func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) {
+// GetCertPEM gets the TLSCertKeyPair for domain, either from cache or via the
+// ACME process. ACME process is used for new domain certs, existing expired
+// certs or existing certs that should get renewed due to upcoming expiry.
+//
+// syncRenewal changes renewal behavior for existing certs that are still valid
+// but need renewal. When syncRenewal is set, the method blocks until a new
+// cert is issued. When syncRenewal is not set, existing cert is returned right
+// away and renewal is kicked off in a background goroutine.
+func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) {
if !validLookingCertDomain(domain) {
return nil, errors.New("invalid domain")
}
@@ -105,12 +111,15 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK
shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair)
if err != nil {
logf("error checking for certificate renewal: %v", err)
- } else if shouldRenew {
+ } else if !shouldRenew {
+ return pair, nil
+ }
+ if !syncRenewal {
logf("starting async renewal")
// Start renewal in the background.
go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now)
}
- return pair, nil
+ // Synchronous renewal happens below.
}
pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now)
@@ -124,37 +133,43 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK
func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, now time.Time, pair *TLSCertKeyPair) (bool, error) {
renewMu.Lock()
defer renewMu.Unlock()
- if last, ok := lastRenewCheck[domain]; ok && now.Sub(last) < time.Minute {
- // We checked very recently. Don't bother reparsing &
- // validating the x509 cert.
- return false, nil
+ if renewAt, ok := renewCertAt[domain]; ok {
+ return now.After(renewAt), nil
}
- lastRenewCheck[domain] = now
- renew, err := b.shouldStartDomainRenewalByARI(cs, now, pair)
+ renewTime, err := b.domainRenewalTimeByARI(cs, pair)
if err != nil {
// Log any ARI failure and fall back to checking for renewal by expiry.
b.logf("acme: ARI check failed: %v; falling back to expiry-based check", err)
- } else {
- return renew, nil
+ renewTime, err = b.domainRenewalTimeByExpiry(pair)
+ if err != nil {
+ return false, err
+ }
}
- return b.shouldStartDomainRenewalByExpiry(now, pair)
+ renewCertAt[domain] = renewTime
+ return now.After(renewTime), nil
}
-func (b *LocalBackend) shouldStartDomainRenewalByExpiry(now time.Time, pair *TLSCertKeyPair) (bool, error) {
+func (b *LocalBackend) domainRenewed(domain string) {
+ renewMu.Lock()
+ defer renewMu.Unlock()
+ delete(renewCertAt, domain)
+}
+
+func (b *LocalBackend) domainRenewalTimeByExpiry(pair *TLSCertKeyPair) (time.Time, error) {
block, _ := pem.Decode(pair.CertPEM)
if block == nil {
- return false, fmt.Errorf("parsing certificate PEM")
+ return time.Time{}, fmt.Errorf("parsing certificate PEM")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
- return false, fmt.Errorf("parsing certificate: %w", err)
+ return time.Time{}, fmt.Errorf("parsing certificate: %w", err)
}
certLifetime := cert.NotAfter.Sub(cert.NotBefore)
if certLifetime < 0 {
- return false, fmt.Errorf("negative certificate lifetime %v", certLifetime)
+ return time.Time{}, fmt.Errorf("negative certificate lifetime %v", certLifetime)
}
// Per https://github.com/tailscale/tailscale/issues/8204, check
@@ -163,36 +178,32 @@ func (b *LocalBackend) shouldStartDomainRenewalByExpiry(now time.Time, pair *TLS
// Encrypt.
renewalDuration := certLifetime * 2 / 3
renewAt := cert.NotBefore.Add(renewalDuration)
-
- if now.After(renewAt) {
- return true, nil
- }
- return false, nil
+ return renewAt, nil
}
-func (b *LocalBackend) shouldStartDomainRenewalByARI(cs certStore, now time.Time, pair *TLSCertKeyPair) (bool, error) {
+func (b *LocalBackend) domainRenewalTimeByARI(cs certStore, pair *TLSCertKeyPair) (time.Time, error) {
var blocks []*pem.Block
rest := pair.CertPEM
for len(rest) > 0 {
var block *pem.Block
block, rest = pem.Decode(rest)
if block == nil {
- return false, fmt.Errorf("parsing certificate PEM")
+ return time.Time{}, fmt.Errorf("parsing certificate PEM")
}
blocks = append(blocks, block)
}
if len(blocks) < 2 {
- return false, fmt.Errorf("could not parse certificate chain from certStore, got %d PEM block(s)", len(blocks))
+ return time.Time{}, fmt.Errorf("could not parse certificate chain from certStore, got %d PEM block(s)", len(blocks))
}
ac, err := acmeClient(cs)
if err != nil {
- return false, err
+ return time.Time{}, err
}
ctx, cancel := context.WithTimeout(b.ctx, 5*time.Second)
defer cancel()
ri, err := ac.FetchRenewalInfo(ctx, blocks[0].Bytes, blocks[1].Bytes)
if err != nil {
- return false, fmt.Errorf("failed to fetch renewal info from ACME server: %w", err)
+ return time.Time{}, fmt.Errorf("failed to fetch renewal info from ACME server: %w", err)
}
if acmeDebug() {
b.logf("acme: ARI response: %+v", ri)
@@ -203,7 +214,7 @@ func (b *LocalBackend) shouldStartDomainRenewalByARI(cs certStore, now time.Time
// https://datatracker.ietf.org/doc/draft-ietf-acme-ari/
start, end := ri.SuggestedWindow.Start, ri.SuggestedWindow.End
renewTime := start.Add(time.Duration(insecurerand.Int63n(int64(end.Sub(start)))))
- return now.After(renewTime), nil
+ return renewTime, nil
}
// certStore provides a way to perist and retrieve TLS certificates.
@@ -371,8 +382,18 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger
acmeMu.Lock()
defer acmeMu.Unlock()
+ // In case this method was triggered multiple times in parallel (when
+ // serving incoming requests), check whether one of the other goroutines
+ // already renewed the cert before us.
if p, err := getCertPEMCached(cs, domain, now); err == nil {
- return p, nil
+ // shouldStartDomainRenewal caches its result so it's OK to call this
+ // frequently.
+ shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, p)
+ if err != nil {
+ logf("error checking for certificate renewal: %v", err)
+ } else if !shouldRenew {
+ return p, nil
+ }
} else if !errors.Is(err, ipn.ErrStateNotExist) && !errors.Is(err, errCertExpired) {
return nil, err
}
@@ -509,6 +530,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger
if err := cs.WriteCert(domain, certPEM.Bytes()); err != nil {
return nil, err
}
+ b.domainRenewed(domain)
return &TLSCertKeyPair{CertPEM: certPEM.Bytes(), KeyPEM: privPEM.Bytes()}, nil
}
diff --git a/ipn/ipnlocal/cert_js.go b/ipn/ipnlocal/cert_js.go
index a5fdfc4ba..24defb47b 100644
--- a/ipn/ipnlocal/cert_js.go
+++ b/ipn/ipnlocal/cert_js.go
@@ -12,6 +12,6 @@ type TLSCertKeyPair struct {
CertPEM, KeyPEM []byte
}
-func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) {
+func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) {
return nil, errors.New("not implemented for js/wasm")
}
diff --git a/ipn/ipnlocal/cert_test.go b/ipn/ipnlocal/cert_test.go
index 52ba13453..66d942032 100644
--- a/ipn/ipnlocal/cert_test.go
+++ b/ipn/ipnlocal/cert_test.go
@@ -112,7 +112,7 @@ func TestShouldStartDomainRenewal(t *testing.T) {
reset := func() {
renewMu.Lock()
defer renewMu.Unlock()
- maps.Clear(lastRenewCheck)
+ maps.Clear(renewCertAt)
}
mustMakePair := func(template *x509.Certificate) *TLSCertKeyPair {
@@ -178,7 +178,7 @@ func TestShouldStartDomainRenewal(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
reset()
- ret, err := b.shouldStartDomainRenewalByExpiry(now, mustMakePair(&x509.Certificate{
+ ret, err := b.domainRenewalTimeByExpiry(mustMakePair(&x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: subject,
NotBefore: tt.notBefore,
@@ -192,8 +192,9 @@ func TestShouldStartDomainRenewal(t *testing.T) {
t.Errorf("got err=%q, want %q", err.Error(), tt.wantErr)
}
} else {
- if ret != tt.want {
- t.Errorf("got ret=%v, want %v", ret, tt.want)
+ renew := now.After(ret)
+ if renew != tt.want {
+ t.Errorf("got renew=%v (ret=%v), want renew %v", renew, ret, tt.want)
}
}
})
diff --git a/ipn/ipnlocal/serve.go b/ipn/ipnlocal/serve.go
index aa2c1a605..99330309b 100644
--- a/ipn/ipnlocal/serve.go
+++ b/ipn/ipnlocal/serve.go
@@ -372,7 +372,7 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort)
GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
- pair, err := b.GetCertPEM(ctx, sni)
+ pair, err := b.GetCertPEM(ctx, sni, false)
if err != nil {
return nil, err
}
@@ -675,7 +675,7 @@ func (b *LocalBackend) getTLSServeCertForPort(port uint16) func(hi *tls.ClientHe
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
- pair, err := b.GetCertPEM(ctx, hi.ServerName)
+ pair, err := b.GetCertPEM(ctx, hi.ServerName, false)
if err != nil {
return nil, err
}
diff --git a/ipn/localapi/cert.go b/ipn/localapi/cert.go
index 447c3bc3c..e1704cb49 100644
--- a/ipn/localapi/cert.go
+++ b/ipn/localapi/cert.go
@@ -23,7 +23,7 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
http.Error(w, "internal handler config wired wrong", 500)
return
}
- pair, err := h.b.GetCertPEM(r.Context(), domain)
+ pair, err := h.b.GetCertPEM(r.Context(), domain, true)
if err != nil {
// TODO(bradfitz): 500 is a little lazy here. The errors returned from
// GetCertPEM (and everywhere) should carry info info to get whether