diff --git a/src/restic/archiver/archive_reader.go b/src/restic/archiver/archive_reader.go index 08a7fb09f..22158fbe3 100644 --- a/src/restic/archiver/archive_reader.go +++ b/src/restic/archiver/archive_reader.go @@ -25,7 +25,7 @@ func ArchiveReader(repo restic.Repository, p *restic.Progress, rd io.Reader, nam chnker := chunker.New(rd, repo.Config().ChunkerPolynomial) - var ids restic.IDs + ids := restic.IDs{} var fileSize uint64 for { diff --git a/src/restic/archiver/archive_reader_test.go b/src/restic/archiver/archive_reader_test.go index c24a0be5e..68fde3d03 100644 --- a/src/restic/archiver/archive_reader_test.go +++ b/src/restic/archiver/archive_reader_test.go @@ -2,9 +2,11 @@ package archiver import ( "bytes" + "errors" "io" "math/rand" "restic" + "restic/checker" "restic/repository" "testing" ) @@ -89,6 +91,68 @@ func TestArchiveReader(t *testing.T) { t.Logf("snapshot saved as %v, tree is %v", id.Str(), sn.Tree.Str()) checkSavedFile(t, repo, *sn.Tree, "fakefile", fakeFile(t, seed, size)) + + checker.TestCheckRepo(t, repo) +} + +func TestArchiveReaderNull(t *testing.T) { + repo, cleanup := repository.TestRepository(t) + defer cleanup() + + sn, id, err := ArchiveReader(repo, nil, bytes.NewReader(nil), "fakefile", nil) + if err != nil { + t.Fatalf("ArchiveReader() returned error %v", err) + } + + if id.IsNull() { + t.Fatalf("ArchiveReader() returned null ID") + } + + t.Logf("snapshot saved as %v, tree is %v", id.Str(), sn.Tree.Str()) + + checker.TestCheckRepo(t, repo) +} + +type errReader string + +func (e errReader) Read([]byte) (int, error) { + return 0, errors.New(string(e)) +} + +func countSnapshots(t testing.TB, repo restic.Repository) int { + done := make(chan struct{}) + defer close(done) + + snapshots := 0 + for range repo.List(restic.SnapshotFile, done) { + snapshots++ + } + return snapshots +} + +func TestArchiveReaderError(t *testing.T) { + repo, cleanup := repository.TestRepository(t) + defer cleanup() + + sn, id, err := ArchiveReader(repo, nil, errReader("error returned by reading stdin"), "fakefile", nil) + if err == nil { + t.Errorf("expected error not returned") + } + + if sn != nil { + t.Errorf("Snapshot should be nil, but isn't") + } + + if !id.IsNull() { + t.Errorf("id should be null, but %v returned", id.Str()) + } + + n := countSnapshots(t, repo) + if n > 0 { + t.Errorf("expected zero snapshots, but got %d", n) + } + + checker.TestCheckRepo(t, repo) } func BenchmarkArchiveReader(t *testing.B) {