Commit 7f6f20b4 authored by Sietse Ringers's avatar Sietse Ringers

refactor: remove superfluous revStorage interface

parent 4c99292c
......@@ -26,7 +26,7 @@ type (
// revokes an earlier issued credential.
RevocationStorage struct {
conf *Configuration
db revStorage
sqldb sqlRevStorage
memdb memRevStorage
sqlMode bool
settings map[CredentialTypeIdentifier]*RevocationSetting
......@@ -158,7 +158,7 @@ func (rs *RevocationStorage) EnableRevocation(typ CredentialTypeIdentifier, sk *
return err
}
if err = rs.addUpdate(rs.db, typ, update, true); err != nil {
if err = rs.addUpdate(rs.sqldb, typ, update, true); err != nil {
return err
}
return nil
......@@ -167,7 +167,7 @@ func (rs *RevocationStorage) EnableRevocation(typ CredentialTypeIdentifier, sk *
// Exists returns whether or not an accumulator exists in the database for the given credential type.
func (rs *RevocationStorage) Exists(typ CredentialTypeIdentifier, counter uint) (bool, error) {
// only requires sql implementation
return rs.db.Exists((*AccumulatorRecord)(nil), map[string]interface{}{"cred_type": typ, "pk_counter": counter})
return rs.sqldb.Exists((*AccumulatorRecord)(nil), map[string]interface{}{"cred_type": typ, "pk_counter": counter})
}
// Revocation update message methods
......@@ -178,7 +178,7 @@ func (rs *RevocationStorage) Exists(typ CredentialTypeIdentifier, counter uint)
func (rs *RevocationStorage) UpdateFrom(typ CredentialTypeIdentifier, pkcounter uint, index uint64) (*revocation.Update, error) {
// Only requires SQL implementation
var update *revocation.Update
if err := rs.db.Transaction(func(tx revStorage) error {
if err := rs.sqldb.Transaction(func(tx sqlRevStorage) error {
acc, err := rs.accumulator(tx, typ, pkcounter)
if err != nil {
return err
......@@ -199,7 +199,7 @@ func (rs *RevocationStorage) UpdateLatest(typ CredentialTypeIdentifier, count ui
// TODO what should this function and UpdateFrom return when no records are found?
if rs.sqlMode {
var update map[uint]*revocation.Update
if err := rs.db.Transaction(func(tx revStorage) error {
if err := rs.sqldb.Transaction(func(tx sqlRevStorage) error {
var (
records []*AccumulatorRecord
events []*EventRecord
......@@ -259,10 +259,10 @@ func (rs *RevocationStorage) newUpdate(acc *revocation.SignedAccumulator, events
}
func (rs *RevocationStorage) AddUpdate(typ CredentialTypeIdentifier, record *revocation.Update) error {
return rs.addUpdate(rs.db, typ, record, false)
return rs.addUpdate(rs.sqldb, typ, record, false)
}
func (rs *RevocationStorage) addUpdate(tx revStorage, typ CredentialTypeIdentifier, update *revocation.Update, create bool) error {
func (rs *RevocationStorage) addUpdate(tx sqlRevStorage, typ CredentialTypeIdentifier, update *revocation.Update, create bool) error {
// Unmarshal and verify the record against the appropriate public key
pk, err := rs.Keys.PublicKey(typ.IssuerIdentifier(), update.SignedAccumulator.PKCounter)
if err != nil {
......@@ -301,12 +301,12 @@ func (rs *RevocationStorage) addUpdate(tx revStorage, typ CredentialTypeIdentifi
// Issuance records
func (rs *RevocationStorage) AddIssuanceRecord(r *IssuanceRecord) error {
return rs.db.Insert(r)
return rs.sqldb.Insert(r)
}
func (rs *RevocationStorage) IssuanceRecord(typ CredentialTypeIdentifier, key string) (*IssuanceRecord, error) {
var r IssuanceRecord
err := rs.db.Last(&r, map[string]interface{}{"cred_type": typ, "revocationkey": key})
err := rs.sqldb.Last(&r, map[string]interface{}{"cred_type": typ, "revocationkey": key})
if err != nil {
return nil, err
}
......@@ -323,7 +323,7 @@ func (rs *RevocationStorage) Revoke(typ CredentialTypeIdentifier, key string) er
return errors.Errorf("cannot revoke %s", typ)
}
return rs.db.Transaction(func(tx revStorage) error {
return rs.sqldb.Transaction(func(tx sqlRevStorage) error {
var err error
issrecord, err := rs.IssuanceRecord(typ, key)
if err != nil {
......@@ -341,7 +341,7 @@ func (rs *RevocationStorage) Revoke(typ CredentialTypeIdentifier, key string) er
})
}
func (rs *RevocationStorage) revokeAttr(tx revStorage, typ CredentialTypeIdentifier, sk *revocation.PrivateKey, e *RevocationAttribute) error {
func (rs *RevocationStorage) revokeAttr(tx sqlRevStorage, typ CredentialTypeIdentifier, sk *revocation.PrivateKey, e *RevocationAttribute) error {
sacc, err := rs.accumulator(tx, typ, sk.Counter)
if err != nil {
return err
......@@ -351,7 +351,7 @@ func (rs *RevocationStorage) revokeAttr(tx revStorage, typ CredentialTypeIdentif
}
cur := sacc.Accumulator
var parent EventRecord
if err = rs.db.Last(&parent, map[string]interface{}{"cred_type": typ, "pk_counter": sk.Counter}); err != nil {
if err = rs.sqldb.Last(&parent, map[string]interface{}{"cred_type": typ, "pk_counter": sk.Counter}); err != nil {
return err
}
......@@ -370,11 +370,11 @@ func (rs *RevocationStorage) revokeAttr(tx revStorage, typ CredentialTypeIdentif
func (rs *RevocationStorage) Accumulator(typ CredentialTypeIdentifier, pkcounter uint) (
*revocation.SignedAccumulator, error,
) {
return rs.accumulator(rs.db, typ, pkcounter)
return rs.accumulator(rs.sqldb, typ, pkcounter)
}
// accumulator retrieves, verifies and deserializes the accumulator of the given type and key.
func (rs *RevocationStorage) accumulator(tx revStorage, typ CredentialTypeIdentifier, pkcounter uint) (
func (rs *RevocationStorage) accumulator(tx sqlRevStorage, typ CredentialTypeIdentifier, pkcounter uint) (
*revocation.SignedAccumulator, error,
) {
var err error
......@@ -465,7 +465,7 @@ func (rs *RevocationStorage) SaveIssuanceRecord(typ CredentialTypeIdentifier, re
// Misscelaneous methods
func (rs *RevocationStorage) updateAccumulatorTimes(types []CredentialTypeIdentifier) error {
return rs.db.Transaction(func(tx revStorage) error {
return rs.sqldb.Transaction(func(tx sqlRevStorage) error {
var err error
var records []AccumulatorRecord
Logger.Tracef("updating accumulator times")
......@@ -541,7 +541,7 @@ func (rs *RevocationStorage) Load(debug bool, dbtype, connstr string, settings m
if !rs.sqlMode {
return
}
if err := rs.db.Delete(IssuanceRecord{}, "valid_until < ?", time.Now().UnixNano()); err != nil {
if err := rs.sqldb.Delete(IssuanceRecord{}, "valid_until < ?", time.Now().UnixNano()); err != nil {
err = errors.WrapPrefix(err, "failed to delete expired issuance records", 0)
raven.CaptureError(err, nil)
}
......@@ -557,7 +557,7 @@ func (rs *RevocationStorage) Load(debug bool, dbtype, connstr string, settings m
if err != nil {
return err
}
rs.db = db
rs.sqldb = db
rs.sqlMode = true
}
if settings != nil {
......@@ -577,10 +577,7 @@ func (rs *RevocationStorage) Load(debug bool, dbtype, connstr string, settings m
}
func (rs *RevocationStorage) Close() error {
if rs.db != nil {
return rs.db.Close()
}
return nil
return rs.sqldb.Close()
}
// SetRevocationUpdates retrieves the latest revocation records from the database, and attaches
......
......@@ -11,28 +11,7 @@ import (
)
type (
revStorage interface {
// Transaction executes the given closure within a transaction.
Transaction(f func(tx revStorage) error) (err error)
// Insert a new record which must not yet exist.
Insert(o interface{}) error
// Save an existing record.
Save(o interface{}) error
// Last deserializes the last record into o.
Last(dest interface{}, query interface{}, args ...interface{}) error
// Exists checks whether records exist satisfying col = key.
Exists(typ interface{}, query interface{}, args ...interface{}) (bool, error)
// Delete records of the given type satisfying the query.
Delete(typ interface{}, query interface{}, args ...interface{}) error
// Find deserializes into o all records satisfying the specified query.
Find(dest interface{}, query interface{}, args ...interface{}) error
// Latest deserializes into o the last items; amount specified by count, ordered by col.
Latest(dest interface{}, count uint64, query interface{}, args ...interface{}) error
// Close the database.
Close() error
}
// sqlRevStorage implements the revStorage interface, storing any record type in a SQL database,
// sqlRevStorage is a wrapper around gorm, storing any record type in a SQL database,
// for use by revocation servers.
sqlRevStorage struct {
gorm *gorm.DB
......@@ -50,16 +29,16 @@ type (
}
)
func newSqlStorage(debug bool, dbtype, connstr string) (revStorage, error) {
func newSqlStorage(debug bool, dbtype, connstr string) (sqlRevStorage, error) {
switch dbtype {
case "postgres", "mysql":
default:
return nil, errors.New("unsupported database type")
return sqlRevStorage{}, errors.New("unsupported database type")
}
g, err := gorm.Open(dbtype, connstr)
if err != nil {
return nil, err
return sqlRevStorage{}, err
}
if debug {
......@@ -67,24 +46,27 @@ func newSqlStorage(debug bool, dbtype, connstr string) (revStorage, error) {
g.SetLogger(gorm.Logger{LogWriter: log.New(Logger.WriterLevel(logrus.DebugLevel), "db: ", 0)})
}
if g.AutoMigrate((*EventRecord)(nil)); g.Error != nil {
return nil, g.Error
return sqlRevStorage{}, g.Error
}
if g.AutoMigrate((*AccumulatorRecord)(nil)); g.Error != nil {
return nil, g.Error
return sqlRevStorage{}, g.Error
}
if g.AutoMigrate((*IssuanceRecord)(nil)); g.Error != nil {
return nil, g.Error
return sqlRevStorage{}, g.Error
}
return sqlRevStorage{gorm: g}, nil
}
func (s sqlRevStorage) Close() error {
if s.gorm == nil {
return nil
}
Logger.Debug("closing revocation sql database connection")
return s.gorm.Close()
}
func (s sqlRevStorage) Transaction(f func(tx revStorage) error) (err error) {
func (s sqlRevStorage) Transaction(f func(tx sqlRevStorage) error) (err error) {
tx := sqlRevStorage{gorm: s.gorm.Begin()}
defer func() {
if e := recover(); e != nil {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment