summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorBrad Fitzpatrick <bradfitz@tailscale.com>2024-06-21 09:24:08 -0700
committerBrad Fitzpatrick <bradfitz@tailscale.com>2024-06-21 10:02:13 -0700
commitdc38eaf5512e4630983fe64fd38d89c43b8f5d67 (patch)
tree3969c7459c8fed686b7615df135912f8921c6f40
parent5ec01bf3ce6c01841bfc6d17736b5b35df06d2a3 (diff)
downloadtailscale-bradfitz/resume.tar.xz
tailscale-bradfitz/resume.zip
clientupdate/distsign: resume partial downloadsbradfitz/resume
Fixes #12573 Change-Id: If2684a2987ec95b893ba9b7c71dc0cb850765b18 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
-rw-r--r--clientupdate/distsign/distsign.go50
-rw-r--r--clientupdate/distsign/distsign_test.go75
2 files changed, 106 insertions, 19 deletions
diff --git a/clientupdate/distsign/distsign.go b/clientupdate/distsign/distsign.go
index eba4b9267..2b3f2a97f 100644
--- a/clientupdate/distsign/distsign.go
+++ b/clientupdate/distsign/distsign.go
@@ -229,8 +229,6 @@ func (c *Client) Download(ctx context.Context, srcPath, dstPath string) error {
c.logf("Downloading %q", sigURL)
sig, err := fetch(sigURL, signatureSizeLimit)
if err != nil {
- // Best-effort clean up of downloaded package.
- os.Remove(dstPathUnverified)
return err
}
msg := binary.LittleEndian.AppendUint64(hash, uint64(len))
@@ -326,6 +324,8 @@ func fetch(url string, limit int64) ([]byte, error) {
return io.ReadAll(io.LimitReader(resp.Body, limit))
}
+var onResponseForTest = func(*http.Response) {}
+
// download writes the response body of url into a local file at dst, up to
// limit bytes. On success, the returned value is a BLAKE2s hash of the file.
func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) {
@@ -349,31 +349,61 @@ func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]
return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength)
}
c.logf("Download size: %v", res.ContentLength)
+ h := NewPackageHash()
dlReq := must.Get(http.NewRequestWithContext(ctx, httpm.GET, url, nil))
+
+ var skip int64
+ if fi, err := os.Stat(dst); err == nil {
+ if fi.Size() == res.ContentLength {
+ // Assume it got corrupted previously and the earlier attempt failed
+ // the checksum. Delete it and start over.
+ if err := os.Remove(dst); err != nil {
+ return nil, 0, fmt.Errorf("error deleting previous assumed-bad download: %w", err)
+ }
+ } else if fi.Size() > 0 && fi.Size() < res.ContentLength {
+ c.logf("Existing file size: %v", fi.Size())
+ skip = fi.Size()
+ dlReq.Header.Add("Range", fmt.Sprintf("bytes=%d-", skip))
+ }
+ }
+
dlRes, err := hc.Do(dlReq)
if err != nil {
return nil, 0, err
}
+ onResponseForTest(dlRes)
defer dlRes.Body.Close()
- // TODO(bradfitz): resume from existing partial file on disk
- if dlRes.StatusCode != http.StatusOK {
+
+ var of *os.File
+ wantResponseLength := res.ContentLength
+ switch dlRes.StatusCode {
+ case http.StatusOK:
+ if skip > 0 {
+ os.Remove(dst) // best effort; the Create will fail anyway if this would
+ }
+ of, err = os.Create(dst)
+ case http.StatusPartialContent:
+ wantResponseLength = res.ContentLength - skip
+ of, err = os.OpenFile(dst, os.O_CREATE|os.O_APPEND|os.O_RDWR, 0644)
+ if err == nil {
+ // Re-hash the previously downloaded chunk.
+ _, err = io.Copy(h, io.NewSectionReader(of, 0, skip))
+ }
+ default:
return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status)
}
-
- of, err := os.Create(dst)
if err != nil {
return nil, 0, err
}
defer of.Close()
- pw := &progressWriter{total: res.ContentLength, logf: c.logf}
- h := NewPackageHash()
+ pw := &progressWriter{total: res.ContentLength, done: skip, logf: c.logf}
n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit))
if err != nil {
return nil, n, err
}
- if n != res.ContentLength {
- return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, res.ContentLength)
+ if n != wantResponseLength {
+ return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, wantResponseLength)
}
if err := dlRes.Body.Close(); err != nil {
return nil, n, err
diff --git a/clientupdate/distsign/distsign_test.go b/clientupdate/distsign/distsign_test.go
index 09a701f49..8439f471b 100644
--- a/clientupdate/distsign/distsign_test.go
+++ b/clientupdate/distsign/distsign_test.go
@@ -5,6 +5,7 @@ package distsign
import (
"bytes"
+ "cmp"
"context"
"crypto/ed25519"
"net/http"
@@ -12,10 +13,13 @@ import (
"net/url"
"os"
"path/filepath"
+ "reflect"
"strings"
"testing"
+ "time"
"golang.org/x/crypto/blake2s"
+ "tailscale.com/tstest"
)
func TestDownload(t *testing.T) {
@@ -23,11 +27,13 @@ func TestDownload(t *testing.T) {
c := srv.client(t)
tests := []struct {
- desc string
- before func(*testing.T)
- src string
- want []byte
- wantErr bool
+ desc string
+ before func(*testing.T)
+ existing []byte // optional existing data on disk to resume from
+ src string
+ want []byte
+ wantErr bool
+ wantCode int // HTTP status code of download to expect; 0 means http.StatusOK
}{
{
desc: "missing file",
@@ -44,6 +50,45 @@ func TestDownload(t *testing.T) {
want: []byte("world"),
},
{
+ desc: "success-resume",
+ before: func(*testing.T) {
+ srv.addSigned("hello", []byte("world"))
+ },
+ src: "hello",
+ existing: []byte("wo"),
+ want: []byte("world"),
+ wantCode: http.StatusPartialContent,
+ },
+ {
+ desc: "success-resume-ignore-matching-size",
+ before: func(*testing.T) {
+ srv.addSigned("hello", []byte("world"))
+ },
+ src: "hello",
+ existing: []byte("WORLD"), // same size as world
+ want: []byte("world"),
+ wantCode: http.StatusOK,
+ },
+ {
+ desc: "success-resume-ignore-existing-too-big",
+ before: func(*testing.T) {
+ srv.addSigned("hello", []byte("world"))
+ },
+ src: "hello",
+ existing: []byte("longer-than-world"), // len greater than len("world")
+ want: []byte("world"),
+ wantCode: http.StatusOK,
+ },
+ {
+ desc: "resume-corrupt",
+ before: func(*testing.T) {
+ srv.addSigned("hello", []byte("world"))
+ },
+ src: "hello",
+ existing: []byte("WO"), // previous download was bad
+ wantErr: true,
+ },
+ {
desc: "no signature",
before: func(*testing.T) {
srv.add("hello", []byte("world"))
@@ -94,10 +139,17 @@ func TestDownload(t *testing.T) {
srv.reset()
tt.before(t)
- dst := filepath.Join(t.TempDir(), tt.src)
- t.Cleanup(func() {
- os.Remove(dst)
+ var gotCodes []int
+ tstest.Replace(t, &onResponseForTest, func(res *http.Response) {
+ gotCodes = append(gotCodes, res.StatusCode)
})
+
+ dst := filepath.Join(t.TempDir(), tt.src)
+ if len(tt.existing) > 0 {
+ if err := os.WriteFile(dst+".unverified", tt.existing, 0644); err != nil {
+ t.Fatal(err)
+ }
+ }
err := c.Download(context.Background(), tt.src, dst)
if err != nil {
if tt.wantErr {
@@ -107,6 +159,11 @@ func TestDownload(t *testing.T) {
}
if tt.wantErr {
t.Fatalf("Download(%q) succeeded, expected an error", tt.src)
+ } else {
+ wantCodes := []int{cmp.Or(tt.wantCode, http.StatusOK)}
+ if !reflect.DeepEqual(gotCodes, wantCodes) {
+ t.Errorf("HTTP response status code = %v; want %v", gotCodes, wantCodes)
+ }
}
got, err := os.ReadFile(dst)
if err != nil {
@@ -486,7 +543,7 @@ func (s *testServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
return
}
- w.Write(data)
+ http.ServeContent(w, r, path, time.Time{}, bytes.NewReader(data))
}
func (s *testServer) addSigned(name string, data []byte) {