Commit 4dd63fbe authored by Sietse Ringers's avatar Sietse Ringers
Browse files

refactor: reduce code duplication in internal/keysharecore

parent aa24f59c
......@@ -27,33 +27,12 @@ var (
// Generate a new keyshare secret, secured with the given pin
func (c *Core) GenerateKeyshareSecret(pinRaw string) (EncryptedKeysharePacket, error) {
pin, err := padPin(pinRaw)
if err != nil {
return EncryptedKeysharePacket{}, err
}
keyshareSecret, err := gabi.NewKeyshareSecret()
secret, err := gabi.NewKeyshareSecret()
if err != nil {
return EncryptedKeysharePacket{}, err
}
var id [32]byte
_, err = rand.Read(id[:])
if err != nil {
return EncryptedKeysharePacket{}, err
}
// Build unencrypted packet
var p unencryptedKeysharePacket
p.setPin(pin)
err = p.setKeyshareSecret(keyshareSecret)
if err != nil {
return EncryptedKeysharePacket{}, err
}
p.setID(id)
// And encrypt
return c.encryptPacket(p)
return c.DangerousBuildKeyshareSecret(pinRaw, secret)
}
func (c *Core) DangerousBuildKeyshareSecret(pinRaw string, secret *big.Int) (EncryptedKeysharePacket, error) {
......@@ -68,6 +47,7 @@ func (c *Core) DangerousBuildKeyshareSecret(pinRaw string, secret *big.Int) (Enc
return EncryptedKeysharePacket{}, err
}
// Build unencrypted packet
var p unencryptedKeysharePacket
p.setPin(pin)
err = p.setKeyshareSecret(secret)
......@@ -76,29 +56,18 @@ func (c *Core) DangerousBuildKeyshareSecret(pinRaw string, secret *big.Int) (Enc
}
p.setID(id)
// And encrypt
return c.encryptPacket(p)
}
// Check pin for validity, and generate jwt for future access
// userid is an extra field added to the jwt for
func (c *Core) ValidatePin(ep EncryptedKeysharePacket, pin string, userID string) (string, error) {
paddedPin, err := padPin(pin)
if err != nil {
return "", err
}
// decrypt
p, err := c.decryptPacket(ep)
p, err := c.decryptPacketIfPinOK(ep, pin)
if err != nil {
return "", err
}
// verify pin
refPin := p.pin()
if !hmac.Equal(refPin[:], paddedPin[:]) {
return "", ErrInvalidPin
}
// Generate jwt token
id := p.id()
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
......@@ -121,28 +90,16 @@ func (c *Core) ValidateJWT(ep EncryptedKeysharePacket, jwt string) error {
// Change pin in an encrypted keyshare packet to a new value, after validating that the old value is known by caller.
func (c *Core) ChangePin(ep EncryptedKeysharePacket, oldpinRaw, newpinRaw string) (EncryptedKeysharePacket, error) {
oldpin, err := padPin(oldpinRaw)
if err != nil {
return EncryptedKeysharePacket{}, err
}
newpin, err := padPin(newpinRaw)
p, err := c.decryptPacketIfPinOK(ep, oldpinRaw)
if err != nil {
return EncryptedKeysharePacket{}, err
}
// decrypt
p, err := c.decryptPacket(ep)
newpin, err := padPin(newpinRaw)
if err != nil {
return EncryptedKeysharePacket{}, err
}
// verify
refPin := p.pin()
// use hmac equal to make this constant time
if !hmac.Equal(refPin[:], oldpin[:]) {
return EncryptedKeysharePacket{}, ErrInvalidPin
}
// change and reencrypt
var id [32]byte
_, err = rand.Read(id[:])
......
......@@ -3,6 +3,7 @@ package keysharecore
import (
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"encoding/binary"
......@@ -85,11 +86,7 @@ func (c *Core) encryptPacket(p unencryptedKeysharePacket) (EncryptedKeysharePack
}
// Encrypt packet
keyedAes, err := aes.NewCipher(c.encryptionKey[:])
if err != nil {
return EncryptedKeysharePacket{}, err
}
gcm, err := cipher.NewGCM(keyedAes)
gcm, err := newGCM(c.encryptionKey)
if err != nil {
return EncryptedKeysharePacket{}, err
}
......@@ -109,11 +106,7 @@ func (c *Core) decryptPacket(p EncryptedKeysharePacket) (unencryptedKeysharePack
}
// try and decrypt packet
keyedAes, err := aes.NewCipher(key[:])
if err != nil {
return unencryptedKeysharePacket{}, err
}
gcm, err := cipher.NewGCM(keyedAes)
gcm, err := newGCM(key)
if err != nil {
return unencryptedKeysharePacket{}, err
}
......@@ -124,3 +117,34 @@ func (c *Core) decryptPacket(p EncryptedKeysharePacket) (unencryptedKeysharePack
}
return result, nil
}
func (c *Core) decryptPacketIfPinOK(ep EncryptedKeysharePacket, pin string) (unencryptedKeysharePacket, error) {
paddedPin, err := padPin(pin)
if err != nil {
return unencryptedKeysharePacket{}, err
}
p, err := c.decryptPacket(ep)
if err != nil {
return unencryptedKeysharePacket{}, err
}
// Check pins in constant time
refPin := p.pin()
if !hmac.Equal(refPin[:], paddedPin[:]) {
return unencryptedKeysharePacket{}, ErrInvalidPin
}
return p, nil
}
func newGCM(key AesKey) (cipher.AEAD, error) {
keyedAes, err := aes.NewCipher(key[:])
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(keyedAes)
if err != nil {
return nil, err
}
return gcm, nil
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment