Commit 60e4f166 authored by Sietse Ringers's avatar Sietse Ringers
Browse files

Move all functions and global state to structs and methods

parent 7a812230
......@@ -2,45 +2,14 @@ package sessiontest
import (
"encoding/json"
"net/http"
"path/filepath"
"testing"
"github.com/privacybydesign/irmago"
"github.com/privacybydesign/irmago/internal/test"
"github.com/privacybydesign/irmago/server"
"github.com/privacybydesign/irmago/server/irmarequestor"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)
var irmaServer *http.Server
func StartIrmaClientServer(t *testing.T) {
testdata := test.FindTestdataFolder(t)
logger := logrus.New()
logger.Level = logrus.WarnLevel
logger.Formatter = &logrus.TextFormatter{}
require.NoError(t, irmarequestor.Initialize(&server.Configuration{
URL: "http://localhost:48680",
Logger: logger,
SchemesPath: filepath.Join(testdata, "irma_configuration"),
IssuerPrivateKeysPath: filepath.Join(testdata, "privatekeys"),
}))
mux := http.NewServeMux()
mux.HandleFunc("/", irmarequestor.HttpHandlerFunc())
irmaServer = &http.Server{Addr: ":48680", Handler: mux}
go func() {
_ = irmaServer.ListenAndServe()
}()
}
func StopIrmaClientServer() {
_ = irmaServer.Close()
}
func requestorSessionHelper(t *testing.T, request irma.SessionRequest) *server.SessionResult {
StartIrmaClientServer(t)
defer StopIrmaClientServer()
......@@ -51,7 +20,7 @@ func requestorSessionHelper(t *testing.T, request irma.SessionRequest) *server.S
clientChan := make(chan *SessionResult)
serverChan := make(chan *server.SessionResult)
qr, token, err := irmarequestor.StartSession(request, func(result *server.SessionResult) {
qr, token, err := irmaServer.StartSession(request, func(result *server.SessionResult) {
serverChan <- result
})
require.NoError(t, err)
......
package sessiontest
import (
"net/http"
"path/filepath"
"testing"
"time"
"github.com/privacybydesign/irmago/internal/test"
"github.com/privacybydesign/irmago/server"
"github.com/privacybydesign/irmago/server/irmarequestor"
"github.com/privacybydesign/irmago/server/irmaserver"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)
var (
httpServer *http.Server
irmaServer *irmarequestor.Server
combinedServer *irmaserver.Server
logger = logrus.New()
testdata = test.FindTestdataFolder(nil)
)
func init() {
logger.Level = logrus.WarnLevel
logger.Level = logrus.ErrorLevel
logger.Formatter = &logrus.TextFormatter{}
}
func StartIrmaServer(configuration *irmaserver.Configuration) {
go func() {
err := irmaserver.Start(configuration)
if err != nil {
var err error
if combinedServer, err = irmaserver.New(configuration); err != nil {
panic(err)
}
if err = combinedServer.Start(configuration); err != nil {
panic("Starting server failed: " + err.Error())
}
}()
......@@ -31,7 +42,36 @@ func StartIrmaServer(configuration *irmaserver.Configuration) {
}
func StopIrmaServer() {
_ = irmaserver.Stop()
_ = combinedServer.Stop()
}
func StartIrmaClientServer(t *testing.T) {
testdata := test.FindTestdataFolder(t)
logger := logrus.New()
logger.Level = logrus.WarnLevel
logger.Formatter = &logrus.TextFormatter{}
var err error
irmaServer, err = irmarequestor.New(&server.Configuration{
URL: "http://localhost:48680",
Logger: logger,
SchemesPath: filepath.Join(testdata, "irma_configuration"),
IssuerPrivateKeysPath: filepath.Join(testdata, "privatekeys"),
})
require.NoError(t, err)
mux := http.NewServeMux()
mux.HandleFunc("/", irmaServer.HttpHandlerFunc())
httpServer = &http.Server{Addr: ":48680", Handler: mux}
go func() {
_ = httpServer.ListenAndServe()
}()
}
func StopIrmaClientServer() {
_ = httpServer.Close()
}
var IrmaServerConfiguration = &irmaserver.Configuration{
......
......@@ -13,6 +13,7 @@ import (
"strings"
"github.com/go-errors/errors"
"github.com/jasonlvhit/gocron"
"github.com/privacybydesign/gabi"
"github.com/privacybydesign/gabi/big"
"github.com/privacybydesign/irmago"
......@@ -20,68 +21,89 @@ import (
"github.com/sirupsen/logrus"
)
func Initialize(configuration *server.Configuration) error {
conf = configuration
type Server struct {
conf *server.Configuration
sessions sessionStore
scheduler *gocron.Scheduler
}
func New(conf *server.Configuration) (*Server, error) {
s := &Server{
conf: conf,
scheduler: gocron.NewScheduler(),
sessions: &memorySessionStore{
m: make(map[string]*session),
conf: conf,
},
}
s.scheduler.Every(10).Seconds().Do(func() {
s.sessions.deleteExpired()
})
s.scheduler.Start()
return s, s.verifyConfiguration(s.conf)
}
if conf.Logger == nil {
conf.Logger = logrus.New()
conf.Logger.Level = logrus.DebugLevel
conf.Logger.Formatter = &logrus.TextFormatter{}
func (s *Server) verifyConfiguration(configuration *server.Configuration) error {
if s.conf.Logger == nil {
s.conf.Logger = logrus.New()
s.conf.Logger.Level = logrus.DebugLevel
s.conf.Logger.Formatter = &logrus.TextFormatter{}
}
server.Logger = conf.Logger
irma.Logger = conf.Logger
server.Logger = s.conf.Logger
irma.Logger = s.conf.Logger
if conf.IrmaConfiguration == nil {
if s.conf.IrmaConfiguration == nil {
var err error
if conf.SchemesAssetsPath == "" {
conf.IrmaConfiguration, err = irma.NewConfiguration(conf.SchemesPath)
if s.conf.SchemesAssetsPath == "" {
s.conf.IrmaConfiguration, err = irma.NewConfiguration(s.conf.SchemesPath)
} else {
conf.IrmaConfiguration, err = irma.NewConfigurationFromAssets(conf.SchemesPath, conf.SchemesAssetsPath)
s.conf.IrmaConfiguration, err = irma.NewConfigurationFromAssets(s.conf.SchemesPath, s.conf.SchemesAssetsPath)
}
if err != nil {
return server.LogError(err)
}
if err = conf.IrmaConfiguration.ParseFolder(); err != nil {
if err = s.conf.IrmaConfiguration.ParseFolder(); err != nil {
return server.LogError(err)
}
}
if len(conf.IrmaConfiguration.SchemeManagers) == 0 {
if conf.DownloadDefaultSchemes {
if err := conf.IrmaConfiguration.DownloadDefaultSchemes(); err != nil {
if len(s.conf.IrmaConfiguration.SchemeManagers) == 0 {
if s.conf.DownloadDefaultSchemes {
if err := s.conf.IrmaConfiguration.DownloadDefaultSchemes(); err != nil {
return server.LogError(err)
}
} else {
return server.LogError(errors.New("no schemes found in irma_configuration folder " + conf.IrmaConfiguration.Path))
return server.LogError(errors.New("no schemes found in irma_configuration folder " + s.conf.IrmaConfiguration.Path))
}
}
if conf.SchemeUpdateInterval != 0 {
conf.IrmaConfiguration.AutoUpdateSchemes(uint(conf.SchemeUpdateInterval))
if s.conf.SchemeUpdateInterval != 0 {
s.conf.IrmaConfiguration.AutoUpdateSchemes(uint(s.conf.SchemeUpdateInterval))
}
if conf.IssuerPrivateKeys == nil {
conf.IssuerPrivateKeys = make(map[irma.IssuerIdentifier]*gabi.PrivateKey)
if s.conf.IssuerPrivateKeys == nil {
s.conf.IssuerPrivateKeys = make(map[irma.IssuerIdentifier]*gabi.PrivateKey)
}
if conf.IssuerPrivateKeysPath != "" {
files, err := ioutil.ReadDir(conf.IssuerPrivateKeysPath)
if s.conf.IssuerPrivateKeysPath != "" {
files, err := ioutil.ReadDir(s.conf.IssuerPrivateKeysPath)
if err != nil {
return server.LogError(err)
}
for _, file := range files {
filename := file.Name()
issid := irma.NewIssuerIdentifier(strings.TrimSuffix(filename, filepath.Ext(filename))) // strip .xml
if _, ok := conf.IrmaConfiguration.Issuers[issid]; !ok {
if _, ok := s.conf.IrmaConfiguration.Issuers[issid]; !ok {
return server.LogError(errors.Errorf("Private key %s belongs to an unknown issuer", filename))
}
sk, err := gabi.NewPrivateKeyFromFile(filepath.Join(conf.IssuerPrivateKeysPath, filename))
sk, err := gabi.NewPrivateKeyFromFile(filepath.Join(s.conf.IssuerPrivateKeysPath, filename))
if err != nil {
return server.LogError(err)
}
conf.IssuerPrivateKeys[issid] = sk
s.conf.IssuerPrivateKeys[issid] = sk
}
}
for issid, sk := range conf.IssuerPrivateKeys {
pk, err := conf.IrmaConfiguration.PublicKey(issid, int(sk.Counter))
for issid, sk := range s.conf.IssuerPrivateKeys {
pk, err := s.conf.IrmaConfiguration.PublicKey(issid, int(sk.Counter))
if err != nil {
return server.LogError(err)
}
......@@ -93,18 +115,18 @@ func Initialize(configuration *server.Configuration) error {
}
}
if conf.URL != "" {
if !strings.HasSuffix(conf.URL, "/") {
conf.URL = conf.URL + "/"
if s.conf.URL != "" {
if !strings.HasSuffix(s.conf.URL, "/") {
s.conf.URL = s.conf.URL + "/"
}
} else {
conf.Logger.Warn("No url parameter specified in configuration; unless an url is elsewhere prepended in the QR, the IRMA client will not be able to connect")
s.conf.Logger.Warn("No url parameter specified in configuration; unless an url is elsewhere prepended in the QR, the IRMA client will not be able to connect")
}
return nil
}
func StartSession(req interface{}) (*irma.Qr, string, error) {
func (s *Server) StartSession(req interface{}) (*irma.Qr, string, error) {
rrequest, err := server.ParseSessionRequest(req)
if err != nil {
return nil, "", err
......@@ -113,44 +135,44 @@ func StartSession(req interface{}) (*irma.Qr, string, error) {
request := rrequest.SessionRequest()
action := request.Action()
if action == irma.ActionIssuing {
if err := validateIssuanceRequest(request.(*irma.IssuanceRequest)); err != nil {
if err := s.validateIssuanceRequest(request.(*irma.IssuanceRequest)); err != nil {
return nil, "", err
}
}
session := newSession(action, rrequest)
conf.Logger.WithFields(logrus.Fields{"action": action, "session": session.token}).Infof("Session started")
if conf.Logger.IsLevelEnabled(logrus.DebugLevel) {
conf.Logger.WithFields(logrus.Fields{"session": session.token}).Info("Session request: ", server.ToJson(rrequest))
session := s.newSession(action, rrequest)
s.conf.Logger.WithFields(logrus.Fields{"action": action, "session": session.token}).Infof("Session started")
if s.conf.Logger.IsLevelEnabled(logrus.DebugLevel) {
s.conf.Logger.WithFields(logrus.Fields{"session": session.token}).Info("Session request: ", server.ToJson(rrequest))
} else {
conf.Logger.WithFields(logrus.Fields{"session": session.token}).Info("Session request (purged of attribute values): ", server.ToJson(purgeRequest(rrequest)))
s.conf.Logger.WithFields(logrus.Fields{"session": session.token}).Info("Session request (purged of attribute values): ", server.ToJson(purgeRequest(rrequest)))
}
return &irma.Qr{
Type: action,
URL: conf.URL + session.token,
URL: s.conf.URL + session.token,
}, session.token, nil
}
func GetSessionResult(token string) *server.SessionResult {
session := sessions.get(token)
func (s *Server) GetSessionResult(token string) *server.SessionResult {
session := s.sessions.get(token)
if session == nil {
conf.Logger.Warn("Session result requested of unknown session ", token)
s.conf.Logger.Warn("Session result requested of unknown session ", token)
return nil
}
return session.result
}
func GetRequest(token string) irma.RequestorRequest {
session := sessions.get(token)
func (s *Server) GetRequest(token string) irma.RequestorRequest {
session := s.sessions.get(token)
if session == nil {
conf.Logger.Warn("Session request requested of unknown session ", token)
s.conf.Logger.Warn("Session request requested of unknown session ", token)
return nil
}
return session.rrequest
}
func CancelSession(token string) error {
session := sessions.get(token)
func (s *Server) CancelSession(token string) error {
session := s.sessions.get(token)
if session == nil {
return server.LogError(errors.Errorf("can't cancel unknown session %s", token))
}
......@@ -167,8 +189,8 @@ func ParsePath(path string) (string, string, error) {
return matches[1], matches[2], nil
}
func SubscribeServerSentEvents(w http.ResponseWriter, r *http.Request, token string) error {
session := sessions.get(token)
func (s *Server) SubscribeServerSentEvents(w http.ResponseWriter, r *http.Request, token string) error {
session := s.sessions.get(token)
if session == nil {
return server.LogError(errors.Errorf("can't subscribe to server sent events of unknown session %s", token))
}
......@@ -182,7 +204,7 @@ func SubscribeServerSentEvents(w http.ResponseWriter, r *http.Request, token str
return nil
}
func HandleProtocolMessage(
func (s *Server) HandleProtocolMessage(
path string,
method string,
headers map[string][]string,
......@@ -198,11 +220,11 @@ func HandleProtocolMessage(
}
}
conf.Logger.WithFields(logrus.Fields{"method": method, "path": path}).Debugf("Routing protocol message")
s.conf.Logger.WithFields(logrus.Fields{"method": method, "path": path}).Debugf("Routing protocol message")
if len(message) > 0 {
conf.Logger.Trace("POST body: ", string(message))
s.conf.Logger.Trace("POST body: ", string(message))
}
conf.Logger.Trace("HTTP headers: ", server.ToJson(headers))
s.conf.Logger.Trace("HTTP headers: ", server.ToJson(headers))
token, noun, err := ParsePath(path)
if err != nil {
status, output = server.JsonResponse(nil, server.RemoteError(server.ErrorUnsupported, ""))
......@@ -210,9 +232,9 @@ func HandleProtocolMessage(
}
// Fetch the session
session := sessions.get(token)
session := s.sessions.get(token)
if session == nil {
conf.Logger.Warnf("Session not found: %s", token)
s.conf.Logger.Warnf("Session not found: %s", token)
status, output = server.JsonResponse(nil, server.RemoteError(server.ErrorSessionUnknown, ""))
return
}
......
......@@ -12,8 +12,6 @@ import (
// Maintaining the session state is done here, as well as checking whether the session is in the
// appropriate status before handling the request.
var conf *server.Configuration
func (session *session) handleDelete() {
if session.status.Finished() {
return
......@@ -34,7 +32,7 @@ func (session *session) handleGetRequest(min, max *irma.ProtocolVersion) (irma.S
if session.version, err = chooseProtocolVersion(min, max); err != nil {
return nil, session.fail(server.ErrorProtocolVersion, "")
}
conf.Logger.WithFields(logrus.Fields{"session": session.token, "version": session.version.String()}).Debugf("Protocol version negotiated")
session.conf.Logger.WithFields(logrus.Fields{"session": session.token, "version": session.version.String()}).Debugf("Protocol version negotiated")
session.request.SetVersion(session.version)
session.setStatus(server.StatusConnected)
......@@ -55,7 +53,7 @@ func (session *session) handlePostSignature(signature *irma.SignedMessage) (*irm
var rerr *irma.RemoteError
session.result.Signature = signature
session.result.Disclosed, session.result.ProofStatus, err = signature.Verify(
conf.IrmaConfiguration, session.request.(*irma.SignatureRequest))
session.conf.IrmaConfiguration, session.request.(*irma.SignatureRequest))
if err == nil {
session.setStatus(server.StatusDone)
} else {
......@@ -77,7 +75,7 @@ func (session *session) handlePostDisclosure(disclosure irma.Disclosure) (*irma.
var err error
var rerr *irma.RemoteError
session.result.Disclosed, session.result.ProofStatus, err = disclosure.Verify(
conf.IrmaConfiguration, session.request.(*irma.DisclosureRequest))
session.conf.IrmaConfiguration, session.request.(*irma.DisclosureRequest))
if err == nil {
session.setStatus(server.StatusDone)
} else {
......@@ -104,13 +102,13 @@ func (session *session) handlePostCommitments(commitments *irma.IssueCommitmentM
// Compute list of public keys against which to verify the received proofs
disclosureproofs := irma.ProofList(commitments.Proofs[:discloseCount])
pubkeys, err := disclosureproofs.ExtractPublicKeys(conf.IrmaConfiguration)
pubkeys, err := disclosureproofs.ExtractPublicKeys(session.conf.IrmaConfiguration)
if err != nil {
return nil, session.fail(server.ErrorInvalidProofs, err.Error())
}
for _, cred := range request.Credentials {
iss := cred.CredentialTypeID.IssuerIdentifier()
pubkey, _ := conf.IrmaConfiguration.PublicKey(iss, cred.KeyCounter) // No error, already checked earlier
pubkey, _ := session.conf.IrmaConfiguration.PublicKey(iss, cred.KeyCounter) // No error, already checked earlier
pubkeys = append(pubkeys, pubkey)
}
......@@ -118,7 +116,7 @@ func (session *session) handlePostCommitments(commitments *irma.IssueCommitmentM
for i, proof := range commitments.Proofs {
pubkey := pubkeys[i]
schemeid := irma.NewIssuerIdentifier(pubkey.Issuer).SchemeManagerIdentifier()
if conf.IrmaConfiguration.SchemeManagers[schemeid].Distributed() {
if session.conf.IrmaConfiguration.SchemeManagers[schemeid].Distributed() {
proofP, err := session.getProofP(commitments, schemeid)
if err != nil {
return nil, session.fail(server.ErrorKeyshareProofMissing, err.Error())
......@@ -129,7 +127,7 @@ func (session *session) handlePostCommitments(commitments *irma.IssueCommitmentM
// Verify all proofs and check disclosed attributes, if any, against request
session.result.Disclosed, session.result.ProofStatus, err = commitments.Disclosure().VerifyAgainstDisjunctions(
conf.IrmaConfiguration, request.Disclose, request.Context, request.Nonce, pubkeys, false)
session.conf.IrmaConfiguration, request.Disclose, request.Context, request.Nonce, pubkeys, false)
if err != nil {
if err == irma.ErrorMissingPublicKey {
return nil, session.fail(server.ErrorUnknownPublicKey, "")
......@@ -148,11 +146,11 @@ func (session *session) handlePostCommitments(commitments *irma.IssueCommitmentM
var sigs []*gabi.IssueSignatureMessage
for i, cred := range request.Credentials {
id := cred.CredentialTypeID.IssuerIdentifier()
pk, _ := conf.IrmaConfiguration.PublicKey(id, cred.KeyCounter)
sk, _ := conf.PrivateKey(id)
pk, _ := session.conf.IrmaConfiguration.PublicKey(id, cred.KeyCounter)
sk, _ := session.conf.PrivateKey(id)
issuer := gabi.NewIssuer(sk, pk, one)
proof := commitments.Proofs[i+discloseCount].(*gabi.ProofU)
attributes, err := cred.AttributeList(conf.IrmaConfiguration, 0x03)
attributes, err := cred.AttributeList(session.conf.IrmaConfiguration, 0x03)
if err != nil {
return nil, session.fail(server.ErrorIssuanceFailed, err.Error())
}
......
......@@ -19,20 +19,20 @@ import (
func (session *session) markAlive() {
session.lastActive = time.Now()
conf.Logger.WithFields(logrus.Fields{"session": session.token}).Debugf("Session marked active, expiry delayed")
session.conf.Logger.WithFields(logrus.Fields{"session": session.token}).Debugf("Session marked active, expiry delayed")
}
func (session *session) setStatus(status server.Status) {
conf.Logger.WithFields(logrus.Fields{"session": session.token, "prevStatus": session.prevStatus, "status": status}).
session.conf.Logger.WithFields(logrus.Fields{"session": session.token, "prevStatus": session.prevStatus, "status": status}).
Info("Session status updated")
session.status = status
session.result.Status = status
sessions.update(session)
session.sessions.update(session)
}
func (session *session) onUpdate() {
if session.evtSource != nil {
conf.Logger.WithFields(logrus.Fields{"session": session.token, "status": session.status}).
session.conf.Logger.WithFields(logrus.Fields{"session": session.token, "status": session.status}).
Debug("Sending status to SSE listeners")
session.evtSource.SendEventMessage(string(session.status), "", "")
}
......@@ -47,18 +47,18 @@ func (session *session) fail(err server.Error, message string) *irma.RemoteError
// Issuance helpers
func validateIssuanceRequest(request *irma.IssuanceRequest) error {
func (s *Server) validateIssuanceRequest(request *irma.IssuanceRequest) error {
for _, cred := range request.Credentials {
// Check that we have the appropriate private key
iss := cred.CredentialTypeID.IssuerIdentifier()
privatekey, err := conf.PrivateKey(iss)
privatekey, err := s.conf.PrivateKey(iss)
if err != nil {
return err
}
if privatekey == nil {
return errors.Errorf("missing private key of issuer %s", iss.String())
}
pubkey, err := conf.IrmaConfiguration.PublicKey(iss, int(privatekey.Counter))
pubkey, err := s.conf.IrmaConfiguration.PublicKey(iss, int(privatekey.Counter))
if err != nil {
return err
}
......@@ -68,7 +68,7 @@ func validateIssuanceRequest(request *irma.IssuanceRequest) error {
cred.KeyCounter = int(privatekey.Counter)
// Check that the credential is consistent with irma_configuration
if err := cred.Validate(conf.IrmaConfiguration); err != nil {
if err := cred.Validate(s.conf.IrmaConfiguration); err != nil {
return err
}
......@@ -95,12 +95,12 @@ func (session *session) getProofP(commitments *irma.IssueCommitmentMessage, sche
if !contains {
return nil, errors.Errorf("no keyshare proof included for scheme %s", scheme.Name())
}
conf.Logger.Debug("Parsing keyshare ProofP JWT: ", str)
session.conf.Logger.Debug("Parsing keyshare ProofP JWT: ", str)
claims := &struct {
jwt.StandardClaims
ProofP *gabi.ProofP
}{}
token, err := jwt.ParseWithClaims(str, claims, conf.IrmaConfiguration.KeyshareServerKeyFunc(scheme))
token, err := jwt.ParseWithClaims(str, claims, session.conf.IrmaConfiguration.KeyshareServerKeyFunc(scheme))
if err != nil {
return nil, err
}
......@@ -120,7 +120,7 @@ func (session *session) eventSource() eventsource.EventSource {
return session.evtSource
}
conf.Logger.WithFields(logrus.Fields{"session": session.token}).Debug("Making server sent event source")
session.conf.Logger.WithFields(logrus.Fields{"session": session.token}).Debug("Making server sent event source")
session.evtSource = eventsource.New(nil, func(_ *http.Request) [][]byte { return eventHeaders })
return session.evtSource
}
......
......@@ -5,7 +5,6 @@ import (
"sync"
"time"
"github.com/jasonlvhit/gocron"
"github.com/privacybydesign/gabi"
"github.com/privacybydesign/gabi/big"
"github.com/privacybydesign/irmago"
......@@ -31,6 +30,9 @@ type session struct {
result *server.SessionResult
kssProofs map[irma.SchemeManagerIdentifier]*gabi.ProofP
conf *server.Configuration
sessions sessionStore
}
type sessionStore interface {
......@@ -42,7 +44,8 @@ type sessionStore interface {
type memorySessionStore struct {
sync.RWMutex
m map[string]*session
conf *server.Configuration
m map[string]*session
}
const (
......@@ -53,19 +56,10 @@ const (
var (
minProtocolVersion = irma.NewVersion(2, 4)
maxProtocolVersion = irma.NewVersion(2, 4)
sessions sessionStore = &memorySessionStore{
m: make(map[string]*session),
}
)
func init() {
rand.Seed(time.Now().UnixNano())
gocron.Every(10).Seconds().Do(func() {
sessions.deleteExpired()
})
gocron.Start()
}
func (s *memorySessionStore) get(token string) *session {
......@@ -99,11 +93,11 @@ func (s memorySessionStore) deleteExpired() {
if session.lastActive.Add(timeout).Before(time.Now()) {
if !session.status.Finished() {
conf.Logger.WithFields(logrus.Fields{"session": session.token}).Infof("Session expired")