Commit c2fb315f authored by Sietse Ringers's avatar Sietse Ringers
Browse files

Use separate session tokens for requestor and client

parent 36157819
......@@ -33,8 +33,9 @@ func New(conf *server.Configuration) (*Server, error) {
conf: conf,
scheduler: gocron.NewScheduler(),
sessions: &memorySessionStore{
m: make(map[string]*session),
conf: conf,
requestor: make(map[string]*session),
client: make(map[string]*session),
conf: conf,
},
}
s.scheduler.Every(10).Seconds().Do(func() {
......@@ -174,7 +175,7 @@ func (s *Server) StartSession(req interface{}) (*irma.Qr, string, error) {
}
return &irma.Qr{
Type: action,
URL: s.conf.URL + session.token,
URL: s.conf.URL + session.clientToken,
}, session.token, nil
}
......@@ -214,8 +215,13 @@ func ParsePath(path string) (string, string, error) {
return matches[1], matches[2], nil
}
func (s *Server) SubscribeServerSentEvents(w http.ResponseWriter, r *http.Request, token string) error {
session := s.sessions.get(token)
func (s *Server) SubscribeServerSentEvents(w http.ResponseWriter, r *http.Request, token string, requestor bool) error {
var session *session
if requestor {
session = s.sessions.get(token)
} else {
session = s.sessions.clientGet(token)
}
if session == nil {
return server.LogError(errors.Errorf("can't subscribe to server sent events of unknown session %s", token))
}
......@@ -257,9 +263,9 @@ func (s *Server) HandleProtocolMessage(
}
// Fetch the session
session := s.sessions.get(token)
session := s.sessions.clientGet(token)
if session == nil {
s.conf.Logger.Warnf("Session not found: %s", token)
s.conf.Logger.WithField("clientToken", token).Warn("Session not found")
status, output = server.JsonResponse(nil, server.RemoteError(server.ErrorSessionUnknown, ""))
return
}
......
......@@ -16,11 +16,12 @@ import (
type session struct {
sync.Mutex
action irma.Action
token string
version *irma.ProtocolVersion
rrequest irma.RequestorRequest
request irma.SessionRequest
action irma.Action
token string
clientToken string
version *irma.ProtocolVersion
rrequest irma.RequestorRequest
request irma.SessionRequest
status server.Status
prevStatus server.Status
......@@ -37,7 +38,8 @@ type session struct {
type sessionStore interface {
get(token string) *session
add(token string, session *session)
clientGet(token string) *session
add(session *session)
update(session *session)
deleteExpired()
}
......@@ -45,7 +47,9 @@ type sessionStore interface {
type memorySessionStore struct {
sync.RWMutex
conf *server.Configuration
m map[string]*session
requestor map[string]*session
client map[string]*session
}
const (
......@@ -62,16 +66,23 @@ func init() {
rand.Seed(time.Now().UnixNano())
}
func (s *memorySessionStore) get(token string) *session {
func (s *memorySessionStore) get(t string) *session {
s.RLock()
defer s.RUnlock()
return s.requestor[t]
}
func (s *memorySessionStore) clientGet(t string) *session {
s.RLock()
defer s.RUnlock()
return s.m[token]
return s.client[t]
}
func (s *memorySessionStore) add(token string, session *session) {
func (s *memorySessionStore) add(session *session) {
s.Lock()
defer s.Unlock()
s.m[token] = session
s.requestor[session.token] = session
s.client[session.clientToken] = session
}
func (s *memorySessionStore) update(session *session) {
......@@ -82,8 +93,8 @@ func (s *memorySessionStore) deleteExpired() {
// First check which sessions have expired
// We don't need a write lock for this yet, so postpone that for actual deleting
s.RLock()
expired := make([]string, 0, len(s.m))
for token, session := range s.m {
expired := make([]string, 0, len(s.requestor))
for token, session := range s.requestor {
session.Lock()
timeout := maxSessionLifetime
......@@ -108,11 +119,11 @@ func (s *memorySessionStore) deleteExpired() {
// Using a write lock, delete the expired sessions
s.Lock()
for _, token := range expired {
session := s.m[token]
session := s.requestor[token]
if session.evtSource != nil {
session.evtSource.Close()
}
delete(s.m, token)
delete(s.requestor, token)
}
s.Unlock()
}
......@@ -121,16 +132,19 @@ var one *big.Int = big.NewInt(1)
func (s *Server) newSession(action irma.Action, request irma.RequestorRequest) *session {
token := newSessionToken()
clientToken := newSessionToken()
ses := &session{
action: action,
rrequest: request,
request: request.SessionRequest(),
lastActive: time.Now(),
token: token,
status: server.StatusInitialized,
prevStatus: server.StatusInitialized,
conf: s.conf,
sessions: s.sessions,
action: action,
rrequest: request,
request: request.SessionRequest(),
lastActive: time.Now(),
token: token,
clientToken: clientToken,
status: server.StatusInitialized,
prevStatus: server.StatusInitialized,
conf: s.conf,
sessions: s.sessions,
result: &server.SessionResult{
Token: token,
Type: action,
......@@ -142,7 +156,7 @@ func (s *Server) newSession(action irma.Action, request irma.RequestorRequest) *
nonce, _ := gabi.RandomBigInt(gabi.DefaultSystemParameters[2048].Lstatzk)
ses.request.SetNonce(nonce)
ses.request.SetContext(one)
s.sessions.add(token, ses)
s.sessions.add(ses)
return ses
}
......
......@@ -13,11 +13,14 @@ import (
irma "github.com/privacybydesign/irmago"
"github.com/privacybydesign/irmago/internal/test"
"github.com/privacybydesign/irmago/irmaclient"
"github.com/privacybydesign/irmago/server"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)
func init() {
irma.ForceHttps = false
irma.Logger.SetLevel(logrus.WarnLevel)
}
func TestMain(m *testing.M) {
......@@ -136,29 +139,35 @@ func getCombinedIssuanceRequest(id irma.AttributeTypeIdentifier) *irma.IssuanceR
var TestType = "irmaserver-jwt"
func startSession(t *testing.T, request irma.SessionRequest, sessiontype string) *irma.Qr {
var qr irma.Qr
var err error
var (
qr *irma.Qr = new(irma.Qr)
sesPkg *server.SessionPackage
err error
)
switch TestType {
case "apiserver":
url := "http://localhost:8088/irma_api_server/api/v2/" + sessiontype
err = irma.NewHTTPTransport(url).Post("", &qr, getJwt(t, request, sessiontype, jwt.SigningMethodNone))
err = irma.NewHTTPTransport(url).Post("", qr, getJwt(t, request, sessiontype, jwt.SigningMethodNone))
qr.URL = url + "/" + qr.URL
case "irmaserver-jwt":
url := "http://localhost:48682"
err = irma.NewHTTPTransport(url).Post("session", &qr, getJwt(t, request, sessiontype, jwt.SigningMethodRS256))
err = irma.NewHTTPTransport(url).Post("session", &sesPkg, getJwt(t, request, sessiontype, jwt.SigningMethodRS256))
qr = sesPkg.SessionPtr
case "irmaserver-hmac-jwt":
url := "http://localhost:48682"
err = irma.NewHTTPTransport(url).Post("session", &qr, getJwt(t, request, sessiontype, jwt.SigningMethodHS256))
err = irma.NewHTTPTransport(url).Post("session", &sesPkg, getJwt(t, request, sessiontype, jwt.SigningMethodHS256))
qr = sesPkg.SessionPtr
case "irmaserver":
url := "http://localhost:48682"
err = irma.NewHTTPTransport(url).Post("session", &qr, request)
err = irma.NewHTTPTransport(url).Post("session", &sesPkg, request)
qr = sesPkg.SessionPtr
default:
t.Fatal("Invalid TestType")
}
require.NoError(t, err)
return &qr
return qr
}
func getJwt(t *testing.T, request irma.SessionRequest, sessiontype string, alg jwt.SigningMethod) string {
......
......@@ -4,7 +4,6 @@ import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/go-errors/errors"
......@@ -142,16 +141,16 @@ func serverRequest(
func postRequest(serverurl string, request irma.RequestorRequest, name, authmethod, key string) (*irma.Qr, *irma.HTTPTransport, error) {
var (
err error
qr = &irma.Qr{}
pkg = &server.SessionPackage{}
transport = irma.NewHTTPTransport(serverurl)
)
switch authmethod {
case "none":
err = transport.Post("session", qr, request)
err = transport.Post("session", pkg, request)
case "token":
transport.SetHeader("Authorization", key)
err = transport.Post("session", qr, request)
err = transport.Post("session", pkg, request)
case "hmac", "rsa":
var jwtstr string
jwtstr, err = signRequest(request, name, authmethod, key)
......@@ -159,14 +158,14 @@ func postRequest(serverurl string, request irma.RequestorRequest, name, authmeth
return nil, nil, err
}
logger.Debug("Session request JWT: ", jwtstr)
err = transport.Post("session", qr, jwtstr)
err = transport.Post("session", pkg, jwtstr)
default:
return nil, nil, errors.New("Invalid authentication method (must be none, token, hmac or rsa)")
}
token := qr.URL[strings.LastIndex(qr.URL, "/")+1:]
token := pkg.Token
transport.Server += fmt.Sprintf("session/%s/", token)
return qr, transport, err
return pkg.SessionPtr, transport, err
}
// Configuration functions
......
......@@ -49,6 +49,11 @@ type Configuration struct {
Email string `json:"email" mapstructure:"email"`
}
type SessionPackage struct {
SessionPtr *irma.Qr `json:"sessionPtr"`
Token string `json:"token"`
}
// SessionResult contains session information such as the session status, type, possible errors,
// and disclosed attributes or attribute-based signature if appropriate to the session type.
type SessionResult struct {
......
......@@ -90,11 +90,11 @@ func (s *Server) CancelSession(token string) error {
// SubscribeServerSentEvents subscribes the HTTP client to server sent events on status updates
// of the specified IRMA session.
func SubscribeServerSentEvents(w http.ResponseWriter, r *http.Request, token string) error {
return s.SubscribeServerSentEvents(w, r, token)
func SubscribeServerSentEvents(w http.ResponseWriter, r *http.Request, token string, requestor bool) error {
return s.SubscribeServerSentEvents(w, r, token, requestor)
}
func (s *Server) SubscribeServerSentEvents(w http.ResponseWriter, r *http.Request, token string) error {
return s.Server.SubscribeServerSentEvents(w, r, token)
func (s *Server) SubscribeServerSentEvents(w http.ResponseWriter, r *http.Request, token string, requestor bool) error {
return s.Server.SubscribeServerSentEvents(w, r, token, requestor)
}
// HandlerFunc returns a http.HandlerFunc that handles the IRMA protocol
......@@ -120,7 +120,7 @@ func (s *Server) HandlerFunc() http.HandlerFunc {
token, noun, err := servercore.ParsePath(r.URL.Path)
if err == nil && noun == "statusevents" { // if err != nil we let it be handled by HandleProtocolMessage below
if err = s.SubscribeServerSentEvents(w, r, token); err != nil {
if err = s.SubscribeServerSentEvents(w, r, token, false); err != nil {
server.WriteError(w, server.ErrorUnexpectedRequest, err.Error())
}
return
......
......@@ -241,13 +241,16 @@ func (s *Server) handleCreate(w http.ResponseWriter, r *http.Request) {
}
// Everything is authenticated and parsed, we're good to go!
qr, _, err := s.irmaserv.StartSession(rrequest, s.doResultCallback)
qr, token, err := s.irmaserv.StartSession(rrequest, s.doResultCallback)
if err != nil {
server.WriteError(w, server.ErrorInvalidRequest, err.Error())
return
}
server.WriteJson(w, qr)
server.WriteJson(w, server.SessionPackage{
SessionPtr: qr,
Token: token,
})
}
func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
......@@ -262,7 +265,7 @@ func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
func (s *Server) handleStatusEvents(w http.ResponseWriter, r *http.Request) {
token := chi.URLParam(r, "token")
s.conf.Logger.WithFields(logrus.Fields{"session": token}).Debug("new client subscribed to server sent events")
if err := s.irmaserv.SubscribeServerSentEvents(w, r, token); err != nil {
if err := s.irmaserv.SubscribeServerSentEvents(w, r, token, true); err != nil {
server.WriteError(w, server.ErrorUnexpectedRequest, err.Error())
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment