Improve error handling with the ssh subprocess

This commit is contained in:
Alexander Neumann 2016-08-28 19:17:17 +02:00
parent 8de06bd453
commit 5dd137d53e
1 changed files with 60 additions and 9 deletions

View File

@ -9,6 +9,7 @@ import (
"os/exec" "os/exec"
"path" "path"
"strings" "strings"
"time"
"restic/backend" "restic/backend"
"restic/debug" "restic/debug"
@ -26,7 +27,8 @@ type SFTP struct {
c *sftp.Client c *sftp.Client
p string p string
cmd *exec.Cmd cmd *exec.Cmd
result <-chan error
} }
func startClient(program string, args ...string) (*SFTP, error) { func startClient(program string, args ...string) (*SFTP, error) {
@ -55,13 +57,21 @@ func startClient(program string, args ...string) (*SFTP, error) {
return nil, err return nil, err
} }
// wait in a different goroutine
ch := make(chan error, 1)
go func() {
err := cmd.Wait()
debug.Log("sftp.Wait", "ssh command exited, err %v", err)
ch <- err
}()
// open the SFTP session // open the SFTP session
client, err := sftp.NewClientPipe(rd, wr) client, err := sftp.NewClientPipe(rd, wr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &SFTP{c: client, cmd: cmd}, nil return &SFTP{c: client, cmd: cmd, result: ch}, nil
} }
func paths(dir string) []string { func paths(dir string) []string {
@ -76,6 +86,19 @@ func paths(dir string) []string {
} }
} }
// clientError returns an error if the client has exited. Otherwise, nil is
// returned immediately.
func (r *SFTP) clientError() error {
select {
case err := <-r.result:
debug.Log("sftp.clientError", "client has exited with err %v", err)
return err
default:
}
return nil
}
// Open opens an sftp backend. When the command is started via // Open opens an sftp backend. When the command is started via
// exec.Command, it is expected to speak sftp on stdin/stdout. The backend // exec.Command, it is expected to speak sftp on stdin/stdout. The backend
// is expected at the given path. `dir` must be delimited by forward slashes // is expected at the given path. `dir` must be delimited by forward slashes
@ -122,6 +145,7 @@ func OpenWithConfig(cfg Config) (*SFTP, error) {
// backend at dir. Afterwards a new config blob should be created. `dir` must // backend at dir. Afterwards a new config blob should be created. `dir` must
// be delimited by forward slashes ("/"), which is required by sftp. // be delimited by forward slashes ("/"), which is required by sftp.
func Create(dir string, program string, args ...string) (*SFTP, error) { func Create(dir string, program string, args ...string) (*SFTP, error) {
debug.Log("sftp.Create", "%v %v", program, args)
sftp, err := startClient(program, args...) sftp, err := startClient(program, args...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -141,12 +165,7 @@ func Create(dir string, program string, args ...string) (*SFTP, error) {
} }
} }
err = sftp.c.Close() err = sftp.Close()
if err != nil {
return nil, err
}
err = sftp.cmd.Wait()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -158,6 +177,7 @@ func Create(dir string, program string, args ...string) (*SFTP, error) {
// CreateWithConfig creates an sftp backend as described by the config by running // CreateWithConfig creates an sftp backend as described by the config by running
// "ssh" with the appropiate arguments. // "ssh" with the appropiate arguments.
func CreateWithConfig(cfg Config) (*SFTP, error) { func CreateWithConfig(cfg Config) (*SFTP, error) {
debug.Log("sftp.CreateWithConfig", "config %v", cfg)
return Create(cfg.Dir, "ssh", buildSSHCommand(cfg)...) return Create(cfg.Dir, "ssh", buildSSHCommand(cfg)...)
} }
@ -291,6 +311,10 @@ func (r *SFTP) dirname(t backend.Type, name string) string {
// Load returns the data stored in the backend for h at the given offset // Load returns the data stored in the backend for h at the given offset
// and saves it in p. Load has the same semantics as io.ReaderAt. // and saves it in p. Load has the same semantics as io.ReaderAt.
func (r *SFTP) Load(h backend.Handle, p []byte, off int64) (n int, err error) { func (r *SFTP) Load(h backend.Handle, p []byte, off int64) (n int, err error) {
if err := r.clientError(); err != nil {
return 0, err
}
if err := h.Valid(); err != nil { if err := h.Valid(); err != nil {
return 0, err return 0, err
} }
@ -323,6 +347,10 @@ func (r *SFTP) Load(h backend.Handle, p []byte, off int64) (n int, err error) {
// Save stores data in the backend at the handle. // Save stores data in the backend at the handle.
func (r *SFTP) Save(h backend.Handle, p []byte) (err error) { func (r *SFTP) Save(h backend.Handle, p []byte) (err error) {
if err := r.clientError(); err != nil {
return err
}
if err := h.Valid(); err != nil { if err := h.Valid(); err != nil {
return err return err
} }
@ -360,6 +388,10 @@ func (r *SFTP) Save(h backend.Handle, p []byte) (err error) {
// Stat returns information about a blob. // Stat returns information about a blob.
func (r *SFTP) Stat(h backend.Handle) (backend.BlobInfo, error) { func (r *SFTP) Stat(h backend.Handle) (backend.BlobInfo, error) {
if err := r.clientError(); err != nil {
return backend.BlobInfo{}, err
}
if err := h.Valid(); err != nil { if err := h.Valid(); err != nil {
return backend.BlobInfo{}, err return backend.BlobInfo{}, err
} }
@ -374,6 +406,10 @@ func (r *SFTP) Stat(h backend.Handle) (backend.BlobInfo, error) {
// Test returns true if a blob of the given type and name exists in the backend. // Test returns true if a blob of the given type and name exists in the backend.
func (r *SFTP) Test(t backend.Type, name string) (bool, error) { func (r *SFTP) Test(t backend.Type, name string) (bool, error) {
if err := r.clientError(); err != nil {
return false, err
}
_, err := r.c.Lstat(r.filename(t, name)) _, err := r.c.Lstat(r.filename(t, name))
if os.IsNotExist(err) { if os.IsNotExist(err) {
return false, nil return false, nil
@ -388,6 +424,10 @@ func (r *SFTP) Test(t backend.Type, name string) (bool, error) {
// Remove removes the content stored at name. // Remove removes the content stored at name.
func (r *SFTP) Remove(t backend.Type, name string) error { func (r *SFTP) Remove(t backend.Type, name string) error {
if err := r.clientError(); err != nil {
return err
}
return r.c.Remove(r.filename(t, name)) return r.c.Remove(r.filename(t, name))
} }
@ -459,6 +499,8 @@ func (r *SFTP) List(t backend.Type, done <-chan struct{}) <-chan string {
} }
var closeTimeout = 2 * time.Second
// Close closes the sftp connection and terminates the underlying command. // Close closes the sftp connection and terminates the underlying command.
func (r *SFTP) Close() error { func (r *SFTP) Close() error {
if r == nil { if r == nil {
@ -468,9 +510,18 @@ func (r *SFTP) Close() error {
err := r.c.Close() err := r.c.Close()
debug.Log("sftp.Close", "Close returned error %v", err) debug.Log("sftp.Close", "Close returned error %v", err)
// wait for closeTimeout before killing the process
select {
case err := <-r.result:
return err
case <-time.After(closeTimeout):
}
if err := r.cmd.Process.Kill(); err != nil { if err := r.cmd.Process.Kill(); err != nil {
return err return err
} }
return r.cmd.Wait() // get the error, but ignore it
<-r.result
return nil
} }