crypto: Match signature of Encrypt() to Decrypt()

This commit is contained in:
Alexander Neumann 2015-04-12 20:58:41 +02:00
parent 7e6acfe44d
commit f8e1043ad3
5 changed files with 63 additions and 35 deletions

View File

@ -205,29 +205,33 @@ func (k *EncryptionKey) UnmarshalJSON(data []byte) error {
} }
// Encrypt encrypts and signs data. Stored in ciphertext is IV || Ciphertext || // Encrypt encrypts and signs data. Stored in ciphertext is IV || Ciphertext ||
// MAC. Encrypt returns the ciphertext's length. // MAC. Encrypt returns the new ciphertext slice, which is extended when
func Encrypt(ks *Key, ciphertext, plaintext []byte) (int, error) { // necessary. ciphertext and plaintext may point to the same slice.
if cap(ciphertext) < len(plaintext)+ivSize+macSize { func Encrypt(ks *Key, ciphertext, plaintext []byte) ([]byte, error) {
return 0, ErrBufferTooSmall // extend ciphertext slice if necessary
if cap(ciphertext) < len(plaintext)+Extension {
ext := len(plaintext) + Extension - cap(ciphertext)
n := len(ciphertext)
ciphertext = append(ciphertext, make([]byte, ext)...)
ciphertext = ciphertext[:n]
} }
iv := newIV() iv := newIV()
copy(ciphertext, iv[:])
c, err := aes.NewCipher(ks.Encrypt[:]) c, err := aes.NewCipher(ks.Encrypt[:])
if err != nil { if err != nil {
panic(fmt.Sprintf("unable to create cipher: %v", err)) panic(fmt.Sprintf("unable to create cipher: %v", err))
} }
e := cipher.NewCTR(c, ciphertext[:ivSize]) e := cipher.NewCTR(c, iv[:])
e.XORKeyStream(ciphertext[ivSize:cap(ciphertext)], plaintext) e.XORKeyStream(ciphertext[ivSize:cap(ciphertext)], plaintext)
copy(ciphertext, iv[:])
ciphertext = ciphertext[:ivSize+len(plaintext)] ciphertext = ciphertext[:ivSize+len(plaintext)]
mac := poly1305_sign(ciphertext[ivSize:], ciphertext[:ivSize], &ks.Sign) mac := poly1305_sign(ciphertext[ivSize:], ciphertext[:ivSize], &ks.Sign)
ciphertext = append(ciphertext, mac...) ciphertext = append(ciphertext, mac...)
return len(ciphertext), nil return ciphertext, nil
} }
// Decrypt verifies and decrypts the ciphertext. Ciphertext must be in the form // Decrypt verifies and decrypts the ciphertext. Ciphertext must be in the form

View File

@ -98,6 +98,7 @@ func should_panic(f func()) (did_panic bool) {
} }
func TestCrypto(t *testing.T) { func TestCrypto(t *testing.T) {
msg := make([]byte, 0, 8*1024*1024) // use 8MiB for now
for _, tv := range test_values { for _, tv := range test_values {
// test encryption // test encryption
k := &Key{ k := &Key{
@ -105,12 +106,10 @@ func TestCrypto(t *testing.T) {
Sign: tv.skey, Sign: tv.skey,
} }
msg := make([]byte, 0, 8*1024*1024) // use 8MiB for now msg, err := Encrypt(k, msg, tv.plaintext)
n, err := Encrypt(k, msg, tv.plaintext)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
msg = msg[:n]
// decrypt message // decrypt message
_, err = Decrypt(k, []byte{}, msg) _, err = Decrypt(k, []byte{}, msg)

View File

@ -29,11 +29,10 @@ func TestEncryptDecrypt(t *testing.T) {
_, err := io.ReadFull(RandomReader(42, size), data) _, err := io.ReadFull(RandomReader(42, size), data)
OK(t, err) OK(t, err)
ciphertext := restic.GetChunkBuf("TestEncryptDecrypt") ciphertext, err := crypto.Encrypt(k, restic.GetChunkBuf("TestEncryptDecrypt"), data)
n, err := crypto.Encrypt(k, ciphertext, data)
OK(t, err) OK(t, err)
plaintext, err := crypto.Decrypt(k, nil, ciphertext[:n]) plaintext, err := crypto.Decrypt(k, nil, ciphertext)
OK(t, err) OK(t, err)
restic.FreeChunkBuf("TestEncryptDecrypt", ciphertext) restic.FreeChunkBuf("TestEncryptDecrypt", ciphertext)
@ -54,10 +53,40 @@ func TestSmallBuffer(t *testing.T) {
OK(t, err) OK(t, err)
ciphertext := make([]byte, size/2) ciphertext := make([]byte, size/2)
_, err = crypto.Encrypt(k, ciphertext, data) ciphertext, err = crypto.Encrypt(k, ciphertext, data)
// this must throw an error, since the target slice is too small // this must throw an error, since the target slice is too small
Assert(t, err != nil && err == crypto.ErrBufferTooSmall, Assert(t, cap(ciphertext) > size/2,
"expected restic.ErrBufferTooSmall, got %#v", err) "expected extended slice, but capacity is only %d bytes",
cap(ciphertext))
// check for the correct plaintext
plaintext, err := crypto.Decrypt(k, nil, ciphertext)
OK(t, err)
Assert(t, bytes.Equal(plaintext, data),
"wrong plaintext returned")
}
func TestSameBuffer(t *testing.T) {
k := crypto.NewKey()
size := 600
data := make([]byte, size)
f, err := os.Open("/dev/urandom")
OK(t, err)
_, err = io.ReadFull(f, data)
OK(t, err)
ciphertext := make([]byte, size)
copy(ciphertext, data)
ciphertext, err = crypto.Encrypt(k, ciphertext, ciphertext)
OK(t, err)
ciphertext, err = crypto.Decrypt(k, ciphertext, ciphertext)
OK(t, err)
Assert(t, bytes.Equal(ciphertext, data),
"wrong plaintext returned")
} }
func TestLargeEncrypt(t *testing.T) { func TestLargeEncrypt(t *testing.T) {
@ -75,11 +104,10 @@ func TestLargeEncrypt(t *testing.T) {
_, err = io.ReadFull(f, data) _, err = io.ReadFull(f, data)
OK(t, err) OK(t, err)
ciphertext := make([]byte, size+crypto.Extension) ciphertext, err := crypto.Encrypt(k, make([]byte, size+crypto.Extension), data)
n, err := crypto.Encrypt(k, ciphertext, data)
OK(t, err) OK(t, err)
plaintext, err := crypto.Decrypt(k, []byte{}, ciphertext[:n]) plaintext, err := crypto.Decrypt(k, []byte{}, ciphertext)
OK(t, err) OK(t, err)
Equals(t, plaintext, data) Equals(t, plaintext, data)
@ -183,14 +211,14 @@ func BenchmarkDecrypt(b *testing.B) {
plaintext := restic.GetChunkBuf("BenchmarkDecrypt") plaintext := restic.GetChunkBuf("BenchmarkDecrypt")
defer restic.FreeChunkBuf("BenchmarkDecrypt", plaintext) defer restic.FreeChunkBuf("BenchmarkDecrypt", plaintext)
n, err := crypto.Encrypt(k, ciphertext, data) ciphertext, err := crypto.Encrypt(k, ciphertext, data)
OK(b, err) OK(b, err)
b.ResetTimer() b.ResetTimer()
b.SetBytes(int64(size)) b.SetBytes(int64(size))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
plaintext, err = crypto.Decrypt(k, plaintext, ciphertext[:n]) plaintext, err = crypto.Decrypt(k, plaintext, ciphertext)
OK(b, err) OK(b, err)
} }
} }
@ -245,11 +273,11 @@ func TestDecryptStreamReader(t *testing.T) {
ciphertext := make([]byte, size+crypto.Extension) ciphertext := make([]byte, size+crypto.Extension)
// encrypt with default function // encrypt with default function
n, err := crypto.Encrypt(k, ciphertext, data) ciphertext, err = crypto.Encrypt(k, ciphertext, data)
OK(t, err) OK(t, err)
Assert(t, n == len(data)+crypto.Extension, Assert(t, len(ciphertext) == len(data)+crypto.Extension,
"wrong number of bytes returned after encryption: expected %d, got %d", "wrong number of bytes returned after encryption: expected %d, got %d",
len(data)+crypto.Extension, n) len(data)+crypto.Extension, len(ciphertext))
rd, err := crypto.DecryptFrom(k, bytes.NewReader(ciphertext)) rd, err := crypto.DecryptFrom(k, bytes.NewReader(ciphertext))
OK(t, err) OK(t, err)

9
key.go
View File

@ -196,9 +196,7 @@ func AddKey(s Server, password string, template *Key) (*Key, error) {
return nil, err return nil, err
} }
newkey.Data = GetChunkBuf("key") newkey.Data, err = crypto.Encrypt(newkey.user, GetChunkBuf("key"), buf)
n, err = crypto.Encrypt(newkey.user, newkey.Data, buf)
newkey.Data = newkey.Data[:n]
// dump as json // dump as json
buf, err = json.Marshal(newkey) buf, err = json.Marshal(newkey)
@ -234,8 +232,9 @@ func AddKey(s Server, password string, template *Key) (*Key, error) {
} }
// Encrypt encrypts and signs data with the master key. Stored in ciphertext is // Encrypt encrypts and signs data with the master key. Stored in ciphertext is
// IV || Ciphertext || MAC. Returns the ciphertext length. // IV || Ciphertext || MAC. Returns the ciphertext, which is extended if
func (k *Key) Encrypt(ciphertext, plaintext []byte) (int, error) { // necessary.
func (k *Key) Encrypt(ciphertext, plaintext []byte) ([]byte, error) {
return crypto.Encrypt(k.master, ciphertext, plaintext) return crypto.Encrypt(k.master, ciphertext, plaintext)
} }

View File

@ -172,13 +172,11 @@ func (s Server) Save(t backend.Type, data []byte, id backend.ID) (Blob, error) {
} }
// encrypt blob // encrypt blob
n, err := s.Encrypt(ciphertext, data) ciphertext, err := s.Encrypt(ciphertext, data)
if err != nil { if err != nil {
return Blob{}, err return Blob{}, err
} }
ciphertext = ciphertext[:n]
// compute ciphertext hash // compute ciphertext hash
sid := backend.Hash(ciphertext) sid := backend.Hash(ciphertext)
@ -309,9 +307,9 @@ func (s Server) Decrypt(ciphertext []byte) ([]byte, error) {
return s.key.Decrypt([]byte{}, ciphertext) return s.key.Decrypt([]byte{}, ciphertext)
} }
func (s Server) Encrypt(ciphertext, plaintext []byte) (int, error) { func (s Server) Encrypt(ciphertext, plaintext []byte) ([]byte, error) {
if s.key == nil { if s.key == nil {
return 0, errors.New("key for server not set") return nil, errors.New("key for server not set")
} }
return s.key.Encrypt(ciphertext, plaintext) return s.key.Encrypt(ciphertext, plaintext)