Commit 4953edc7 authored by Sietse Ringers's avatar Sietse Ringers
Browse files

feat: support POSTing []byte in transport

parent e7954d65
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/base64"
"encoding/json" "encoding/json"
"io" "io"
"io/ioutil" "io/ioutil"
...@@ -96,7 +97,7 @@ func (transport *HTTPTransport) SetHeader(name, val string) { ...@@ -96,7 +97,7 @@ func (transport *HTTPTransport) SetHeader(name, val string) {
} }
func (transport *HTTPTransport) request( func (transport *HTTPTransport) request(
url string, method string, reader io.Reader, isstr bool, url string, method string, reader io.Reader, contenttype string,
) (response *http.Response, err error) { ) (response *http.Response, err error) {
var req retryablehttp.Request var req retryablehttp.Request
req.Request, err = http.NewRequest(method, transport.Server+url, reader) req.Request, err = http.NewRequest(method, transport.Server+url, reader)
...@@ -105,12 +106,8 @@ func (transport *HTTPTransport) request( ...@@ -105,12 +106,8 @@ func (transport *HTTPTransport) request(
} }
req.Header.Set("User-Agent", "irmago") req.Header.Set("User-Agent", "irmago")
if reader != nil { if reader != nil && contenttype != "" {
if isstr { req.Header.Set("Content-Type", contenttype)
req.Header.Set("Content-Type", "text/plain; charset=UTF-8")
} else {
req.Header.Set("Content-Type", "application/json; charset=UTF-8")
}
} }
for name, val := range transport.headers { for name, val := range transport.headers {
req.Header.Set(name, val) req.Header.Set(name, val)
...@@ -131,24 +128,30 @@ func (transport *HTTPTransport) jsonRequest(url string, method string, result in ...@@ -131,24 +128,30 @@ func (transport *HTTPTransport) jsonRequest(url string, method string, result in
panic("Cannot GET and also post an object") panic("Cannot GET and also post an object")
} }
var isstr bool
var reader io.Reader var reader io.Reader
var contenttype string
if object != nil { if object != nil {
var objstr string switch o := object.(type) {
if objstr, isstr = object.(string); isstr { case []byte:
Logger.Trace("transport: body: ", objstr) Logger.Trace("transport: body (base64): ", base64.StdEncoding.EncodeToString(o))
reader = bytes.NewBuffer([]byte(objstr)) contenttype = "application/octet-stream"
} else { reader = bytes.NewBuffer(o)
case string:
Logger.Trace("transport: body: ", o)
contenttype = "text/plain; charset=UTF-8"
reader = bytes.NewBuffer([]byte(o))
default:
marshaled, err := json.Marshal(object) marshaled, err := json.Marshal(object)
if err != nil { if err != nil {
return &SessionError{ErrorType: ErrorSerialization, Err: err} return &SessionError{ErrorType: ErrorSerialization, Err: err}
} }
Logger.Trace("transport: body: ", string(marshaled)) Logger.Trace("transport: body: ", string(marshaled))
contenttype = "application/json; charset=UTF-8"
reader = bytes.NewBuffer(marshaled) reader = bytes.NewBuffer(marshaled)
} }
} }
res, err := transport.request(url, method, reader, isstr) res, err := transport.request(url, method, reader, contenttype)
if err != nil { if err != nil {
return err return err
} }
...@@ -171,6 +174,9 @@ func (transport *HTTPTransport) jsonRequest(url string, method string, result in ...@@ -171,6 +174,9 @@ func (transport *HTTPTransport) jsonRequest(url string, method string, result in
} }
Logger.Tracef("transport: response: %s", string(body)) Logger.Tracef("transport: response: %s", string(body))
if result == nil { // caller doesn't care about server response
return nil
}
if _, resultstr := result.(*string); resultstr { if _, resultstr := result.(*string); resultstr {
*result.(*string) = string(body) *result.(*string) = string(body)
} else { } else {
...@@ -184,7 +190,7 @@ func (transport *HTTPTransport) jsonRequest(url string, method string, result in ...@@ -184,7 +190,7 @@ func (transport *HTTPTransport) jsonRequest(url string, method string, result in
} }
func (transport *HTTPTransport) GetBytes(url string) ([]byte, error) { func (transport *HTTPTransport) GetBytes(url string) ([]byte, error) {
res, err := transport.request(url, http.MethodGet, nil, false) res, err := transport.request(url, http.MethodGet, nil, "")
if err != nil { if err != nil {
return nil, &SessionError{ErrorType: ErrorTransport, Err: err} return nil, &SessionError{ErrorType: ErrorTransport, Err: err}
} }
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment