diff --git a/internal/backend/backend_error.go b/internal/backend/backend_error.go index 2c9a616cc..98d6c0125 100644 --- a/internal/backend/backend_error.go +++ b/internal/backend/backend_error.go @@ -3,6 +3,7 @@ package backend import ( "context" "io" + "io/ioutil" "math/rand" "sync" @@ -13,9 +14,10 @@ import ( // ErrorBackend is used to induce errors into various function calls and test // the retry functions. type ErrorBackend struct { - FailSave float32 - FailLoad float32 - FailStat float32 + FailSave float32 + FailSaveRead float32 + FailLoad float32 + FailStat float32 restic.Backend r *rand.Rand @@ -48,6 +50,15 @@ func (be *ErrorBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) return errors.Errorf("Save(%v) random error induced", h) } + if be.fail(be.FailSaveRead) { + _, err := io.CopyN(ioutil.Discard, rd, be.r.Int63n(1000)) + if err != nil { + return err + } + + return errors.Errorf("Save(%v) random error with partial read induced", h) + } + return be.Backend.Save(ctx, h, rd) } diff --git a/internal/backend/backend_retry.go b/internal/backend/backend_retry.go index 66db447ca..fae70cd9d 100644 --- a/internal/backend/backend_retry.go +++ b/internal/backend/backend_retry.go @@ -7,6 +7,7 @@ import ( "time" "github.com/cenkalti/backoff" + "github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/restic" ) @@ -44,7 +45,26 @@ func (be *RetryBackend) retry(msg string, f func() error) error { // Save stores the data in the backend under the given handle. func (be *RetryBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error { + seeker, ok := rd.(io.Seeker) + if !ok { + return errors.Errorf("reader %T is not a seeker", rd) + } + + pos, err := seeker.Seek(0, io.SeekCurrent) + if err != nil { + return errors.Wrap(err, "Seek") + } + + if pos != 0 { + return errors.Errorf("reader is not at the beginning (pos %v)", pos) + } + return be.retry(fmt.Sprintf("Save(%v)", h), func() error { + _, err := seeker.Seek(0, io.SeekStart) + if err != nil { + return err + } + return be.Backend.Save(ctx, h, rd) }) } diff --git a/internal/backend/backend_retry_test.go b/internal/backend/backend_retry_test.go new file mode 100644 index 000000000..2a4d26c44 --- /dev/null +++ b/internal/backend/backend_retry_test.go @@ -0,0 +1,90 @@ +package backend + +import ( + "bytes" + "context" + "io" + "io/ioutil" + "testing" + + "github.com/restic/restic/internal/errors" + "github.com/restic/restic/internal/mock" + "github.com/restic/restic/internal/restic" + "github.com/restic/restic/internal/test" +) + +func TestBackendRetrySeeker(t *testing.T) { + be := &mock.Backend{ + SaveFn: func(ctx context.Context, h restic.Handle, rd io.Reader) error { + return nil + }, + } + + retryBackend := RetryBackend{ + Backend: be, + } + + data := test.Random(24, 23*14123) + + type wrapReader struct { + io.Reader + } + + var rd io.Reader + rd = wrapReader{bytes.NewReader(data)} + + err := retryBackend.Save(context.TODO(), restic.Handle{}, rd) + if err == nil { + t.Fatal("did not get expected error for retry backend with non-seeker reader") + } + + rd = bytes.NewReader(data) + _, err = io.CopyN(ioutil.Discard, rd, 5) + if err != nil { + t.Fatal(err) + } + + err = retryBackend.Save(context.TODO(), restic.Handle{}, rd) + if err == nil { + t.Fatal("did not get expected error for partial reader") + } +} + +func TestBackendSaveRetry(t *testing.T) { + buf := bytes.NewBuffer(nil) + errcount := 0 + be := &mock.Backend{ + SaveFn: func(ctx context.Context, h restic.Handle, rd io.Reader) error { + if errcount == 0 { + errcount++ + _, err := io.CopyN(ioutil.Discard, rd, 120) + if err != nil { + return err + } + + return errors.New("injected error") + } + + _, err := io.Copy(buf, rd) + return err + }, + } + + retryBackend := RetryBackend{ + Backend: be, + } + + data := test.Random(23, 5*1024*1024+11241) + err := retryBackend.Save(context.TODO(), restic.Handle{}, bytes.NewReader(data)) + if err != nil { + t.Fatal(err) + } + + if len(data) != buf.Len() { + t.Errorf("wrong number of bytes written: want %d, got %d", len(data), buf.Len()) + } + + if !bytes.Equal(data, buf.Bytes()) { + t.Errorf("wrong data written to backend") + } +} diff --git a/internal/cache/backend.go b/internal/cache/backend.go index 89aae51dc..9f545f309 100644 --- a/internal/cache/backend.go +++ b/internal/cache/backend.go @@ -5,6 +5,7 @@ import ( "io" "sync" + "github.com/pkg/errors" "github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/restic" ) @@ -43,52 +44,50 @@ func (b *Backend) Remove(ctx context.Context, h restic.Handle) error { return b.Cache.Remove(h) } -type teeReader struct { - rd io.Reader - wr io.Writer - err error -} - -func (t *teeReader) Read(p []byte) (n int, err error) { - n, err = t.rd.Read(p) - if t.err == nil && n > 0 { - _, t.err = t.wr.Write(p[:n]) - } - - return n, err -} - var autoCacheTypes = map[restic.FileType]struct{}{ restic.IndexFile: struct{}{}, restic.SnapshotFile: struct{}{}, } -// Save stores a new file is the backend and the cache. +// Save stores a new file in the backend and the cache. func (b *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) { if _, ok := autoCacheTypes[h.Type]; !ok { return b.Backend.Save(ctx, h, rd) } debug.Log("Save(%v): auto-store in the cache", h) - wr, err := b.Cache.SaveWriter(h) - if err != nil { - debug.Log("unable to save %v to cache: %v", h, err) - return b.Backend.Save(ctx, h, rd) + + seeker, ok := rd.(io.Seeker) + if !ok { + return errors.New("reader is not a seeker") } - tr := &teeReader{rd: rd, wr: wr} - err = b.Backend.Save(ctx, h, tr) + pos, err := seeker.Seek(0, io.SeekCurrent) + if err != nil { + return errors.Wrapf(err, "Seek") + } + + if pos != 0 { + return errors.Errorf("reader is not rewind (pos %d)", pos) + } + + err = b.Backend.Save(ctx, h, rd) if err != nil { - wr.Close() - b.Cache.Remove(h) return err } - err = wr.Close() + _, err = seeker.Seek(pos, io.SeekStart) if err != nil { - debug.Log("cache writer returned error: %v", err) - _ = b.Cache.Remove(h) + return errors.Wrapf(err, "Seek") } + + err = b.Cache.Save(h, rd) + if err != nil { + debug.Log("unable to save %v to cache: %v", h, err) + _ = b.Cache.Remove(h) + return nil + } + return nil }