Skip to content

Commit

Permalink
Merge pull request koding#20 from jlhawn/copy_response
Browse files Browse the repository at this point in the history
Copy response to client on failed handshake
  • Loading branch information
rjeczalik committed Jul 16, 2018
2 parents 5fdfb40 + 53b8c5c commit 0fa3f99
Showing 1 changed file with 30 additions and 2 deletions.
32 changes: 30 additions & 2 deletions websocketproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package websocketproxy

import (
"fmt"
"io"
"log"
"net"
"net/http"
Expand Down Expand Up @@ -133,7 +134,17 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// http://tools.ietf.org/html/draft-ietf-hybi-websocket-multiplexing-01
connBackend, resp, err := dialer.Dial(backendURL.String(), requestHeader)
if err != nil {
log.Printf("websocketproxy: couldn't dial to remote backend url %s\n", err)
log.Printf("websocketproxy: couldn't dial to remote backend url %s", err)
if resp != nil {
// If the WebSocket handshake fails, ErrBadHandshake is returned
// along with a non-nil *http.Response so that callers can handle
// redirects, authentication, etcetera.
if err := copyResponse(rw, resp); err != nil {
log.Printf("websocketproxy: couldn't write response after failed remote backend handshake: %s", err)
}
} else {
http.Error(rw, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)
}
return
}
defer connBackend.Close()
Expand All @@ -156,7 +167,7 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Also pass the header that we gathered from the Dial handshake.
connPub, err := upgrader.Upgrade(rw, req, upgradeHeader)
if err != nil {
log.Printf("websocketproxy: couldn't upgrade %s\n", err)
log.Printf("websocketproxy: couldn't upgrade %s", err)
return
}
defer connPub.Close()
Expand Down Expand Up @@ -200,3 +211,20 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
log.Printf(message, err)
}
}

func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}

func copyResponse(rw http.ResponseWriter, resp *http.Response) error {
copyHeader(rw.Header(), resp.Header)
rw.WriteHeader(resp.StatusCode)
defer resp.Body.Close()

_, err := io.Copy(rw, resp.Body)
return err
}

0 comments on commit 0fa3f99

Please sign in to comment.