Commit 3119eb19 authored by Sietse Ringers's avatar Sietse Ringers
Browse files

refactor: fix code duplication in keyshare server starting

parent 0e3c9c54
package cmd package cmd
import ( import (
"context"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
irma "github.com/privacybydesign/irmago" irma "github.com/privacybydesign/irmago"
"github.com/privacybydesign/irmago/server" "github.com/privacybydesign/irmago/server"
"github.com/privacybydesign/irmago/server/keyshare/myirmaserver" "github.com/privacybydesign/irmago/server/keyshare/myirmaserver"
...@@ -21,55 +14,13 @@ var myirmadCmd = &cobra.Command{ ...@@ -21,55 +14,13 @@ var myirmadCmd = &cobra.Command{
Run: func(command *cobra.Command, args []string) { Run: func(command *cobra.Command, args []string) {
conf := configureMyirmad(command) conf := configureMyirmad(command)
// Determine full listening address.
fullAddr := fmt.Sprintf("%s:%d", viper.GetString("listen-addr"), viper.GetInt("port"))
// Load TLS configuration
TLSConfig := configureTLS()
// Create main server // Create main server
myirmaServer, err := myirmaserver.New(conf) myirmaServer, err := myirmaserver.New(conf)
if err != nil { if err != nil {
die("", err) die("", err)
} }
serv := &http.Server{ runServer(myirmaServer, conf.Logger)
Addr: fullAddr,
Handler: myirmaServer.Handler(),
TLSConfig: TLSConfig,
}
stopped := make(chan struct{})
interrupt := make(chan os.Signal, 1)
signal.Notify(interrupt, os.Interrupt, syscall.SIGTERM)
go func() {
if TLSConfig != nil {
err = serv.ListenAndServeTLS("", "")
} else {
err = serv.ListenAndServe()
}
conf.Logger.Debug("Server stopped")
stopped <- struct{}{}
}()
for {
select {
case <-interrupt:
conf.Logger.Debug("Caught interrupt")
err = serv.Shutdown(context.Background())
if err != nil {
_ = server.LogError(err)
}
myirmaServer.Stop()
conf.Logger.Debug("Sent stop signal to server")
case <-stopped:
conf.Logger.Info("Exiting")
close(stopped)
close(interrupt)
return
}
}
}, },
} }
......
package cmd package cmd
import "github.com/sietseringers/cobra" import (
"context"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"github.com/privacybydesign/irmago/server"
"github.com/sietseringers/cobra"
"github.com/sietseringers/viper"
"github.com/sirupsen/logrus"
)
var keyshareRoot = &cobra.Command{ var keyshareRoot = &cobra.Command{
Use: "keyshare", Use: "keyshare",
...@@ -10,3 +22,58 @@ var keyshareRoot = &cobra.Command{ ...@@ -10,3 +22,58 @@ var keyshareRoot = &cobra.Command{
func init() { func init() {
RootCmd.AddCommand(keyshareRoot) RootCmd.AddCommand(keyshareRoot)
} }
type stoppableServer interface {
Handler() http.Handler
Stop()
}
func runServer(serv stoppableServer, logger *logrus.Logger) {
// Determine full listening address.
fullAddr := fmt.Sprintf("%s:%d", viper.GetString("listen-addr"), viper.GetInt("port"))
// Load TLS configuration
TLSConfig := configureTLS()
httpServer := &http.Server{
Addr: fullAddr,
Handler: serv.Handler(),
TLSConfig: TLSConfig,
}
stopped := make(chan struct{})
interrupt := make(chan os.Signal, 1)
signal.Notify(interrupt, os.Interrupt, syscall.SIGTERM)
go func() {
var err error
if TLSConfig != nil {
err = server.FilterStopError(httpServer.ListenAndServeTLS("", ""))
} else {
err = server.FilterStopError(httpServer.ListenAndServe())
}
logger.Debug("Server stopped")
if err != nil {
_ = server.LogError(err)
}
stopped <- struct{}{}
}()
for {
select {
case <-interrupt:
logger.Debug("Caught interrupt")
err := httpServer.Shutdown(context.Background())
if err != nil {
_ = server.LogError(err)
}
serv.Stop()
logger.Debug("Sent stop signal to server")
case <-stopped:
logger.Info("Exiting")
close(stopped)
close(interrupt)
return
}
}
}
package cmd package cmd
import ( import (
"context"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
irma "github.com/privacybydesign/irmago" irma "github.com/privacybydesign/irmago"
"github.com/privacybydesign/irmago/server" "github.com/privacybydesign/irmago/server"
"github.com/privacybydesign/irmago/server/keyshare/keyshareserver" "github.com/privacybydesign/irmago/server/keyshare/keyshareserver"
...@@ -21,55 +14,13 @@ var keysharedCmd = &cobra.Command{ ...@@ -21,55 +14,13 @@ var keysharedCmd = &cobra.Command{
Run: func(command *cobra.Command, args []string) { Run: func(command *cobra.Command, args []string) {
conf := configureKeyshared(command) conf := configureKeyshared(command)
// Determine full listening address.
fullAddr := fmt.Sprintf("%s:%d", viper.GetString("listen-addr"), viper.GetInt("port"))
// Load TLS configuration
TLSConfig := configureTLS()
// Create main server // Create main server
keyshareServer, err := keyshareserver.New(conf) keyshareServer, err := keyshareserver.New(conf)
if err != nil { if err != nil {
die("", err) die("", err)
} }
serv := &http.Server{ runServer(keyshareServer, conf.Logger)
Addr: fullAddr,
Handler: keyshareServer.Handler(),
TLSConfig: TLSConfig,
}
stopped := make(chan struct{})
interrupt := make(chan os.Signal, 1)
signal.Notify(interrupt, os.Interrupt, syscall.SIGTERM)
go func() {
if TLSConfig != nil {
err = serv.ListenAndServeTLS("", "")
} else {
err = serv.ListenAndServe()
}
conf.Logger.Debug("Server stopped")
stopped <- struct{}{}
}()
for {
select {
case <-interrupt:
conf.Logger.Debug("Caught interrupt")
err = serv.Shutdown(context.Background())
if err != nil {
_ = server.LogError(err)
}
keyshareServer.Stop()
conf.Logger.Debug("Sent stop signal to server")
case <-stopped:
conf.Logger.Info("Exiting")
close(stopped)
close(interrupt)
return
}
}
}, },
} }
......
...@@ -549,3 +549,10 @@ func ParseBody(w http.ResponseWriter, r *http.Request, input interface{}) error ...@@ -549,3 +549,10 @@ func ParseBody(w http.ResponseWriter, r *http.Request, input interface{}) error
} }
return nil return nil
} }
func FilterStopError(err error) error {
if err == http.ErrServerClosed {
return nil
}
return err
}
...@@ -117,19 +117,12 @@ func (s *Server) startServer(handler http.Handler, name, addr string, port int, ...@@ -117,19 +117,12 @@ func (s *Server) startServer(handler http.Handler, name, addr string, port int,
if tlsConf != nil { if tlsConf != nil {
s.conf.Logger.Info(name, " TLS enabled") s.conf.Logger.Info(name, " TLS enabled")
return filterStopError(serv.ListenAndServeTLS("", "")) return server.FilterStopError(serv.ListenAndServeTLS("", ""))
} else { } else {
return filterStopError(serv.ListenAndServe()) return server.FilterStopError(serv.ListenAndServe())
} }
} }
func filterStopError(err error) error {
if err == http.ErrServerClosed {
return nil
}
return err
}
func (s *Server) Stop() { func (s *Server) Stop() {
s.irmaserv.Stop() s.irmaserv.Stop()
s.stop <- struct{}{} s.stop <- struct{}{}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment