Commit d01ee210 authored by Jacob Vosmaer (GitLab)'s avatar Jacob Vosmaer (GitLab)

Merge branch 'nick-kube-proxy' into 'master'

Handle environments/:id/terminal.ws, providing access to terminal websockets

A GitLab environment may expose a terminal connection for out-of-band access. Workhorse is responsible for providing a websocket connection to the terminal if present. 

It authenticates the user and retrieves connection details from GitLab using the environments/:id/terminal.ws/authorize endpoint, and sets up a proxy to the terminal provider, converting from the remote's subprotocol to a common format.

Authentication is periodically re-done, and the connection will be broken if it fails, or if the connection details change in any way.

This MR includes support for the `channel.k8s.io` websocket subprotocol, allowing connections to be made to Kubernetes / OpenShift containers. 

Based on top of (some of) !72 

Related to https://gitlab.com/gitlab-org/gitlab-ce/issues/22864

See merge request !83
parents f2d0435b fe16eae3
{
"ImportPath": "gitlab.com/gitlab-org/gitlab-workhorse",
"GoVersion": "go1.7",
"GoVersion": "go1.5",
"GodepVersion": "v74",
"Packages": [
"./..."
],
"Deps": [
{
"ImportPath": "github.com/beorn7/perks/quantile",
......@@ -29,6 +32,11 @@
"ImportPath": "github.com/golang/protobuf/proto",
"Rev": "8ee79997227bf9b34611aee7946ae64735e6fd93"
},
{
"ImportPath": "github.com/gorilla/websocket",
"Comment": "v1.0.0-39-ge8f0f8a",
"Rev": "e8f0f8aaa98dfb6586cbdf2978d511e3199a960a"
},
{
"ImportPath": "github.com/matttproud/golang_protobuf_extensions/pbutil",
"Comment": "v1.0.0-2-gc12348c",
......
......@@ -18,6 +18,8 @@ push/pull and Git archive downloads.
when handling a Git LFS upload Workhorse first asks permission from
Rails, then it stores the request body in a tempfile, then it sends
a modified request containing the tempfile path to Rails.
- Workhorse can manage long-lived WebSocket connections for Rails.
Example: handling the terminal websocket for environments.
- Workhorse does not connect to Redis or Postgres, only to Rails.
- We assume that all requests that reach Workhorse pass through an
upstream proxy such as NGINX or Apache first.
......
# Terminal support
In some cases, GitLab can provide in-browser terminal access to an
environment (which is a running server or container, onto which a
project has been deployed) through a WebSocket. Workhorse manages
the WebSocket upgrade and long-lived connection to the terminal for
the environment, which frees up GitLab to process other requests.
This document outlines the architecture of these connections.
## Introduction to WebSockets
A websocket is an "upgraded" HTTP/1.1 request. Their purpose is to
permit bidirectional communication between a client and a server.
**Websockets are not HTTP**. Clients can send messages (known as
frames) to the server at any time, and vice-versa. Client messages
are not necessarily requests, and server messages are not necessarily
responses. WebSocket URLs have schemes like `ws://` (unencrypted) or
`wss://` (TLS-secured).
When requesting an upgrade to WebSocket, the browser sends a HTTP/1.1
request that looks like this:
```
GET /path.ws HTTP/1.1
Connection: upgrade
Upgrade: websocket
Sec-WebSocket-Protocol: terminal.gitlab.com
# More headers, including security measures
```
At this point, the connection is still HTTP, so this is a request and
the server can send a normal HTTP response, including `404 Not Found`,
`500 Internal Server Error`, etc.
If the server decides to permit the upgrade, it will send a HTTP
`101 Switching Protocols` response. From this point, the connection
is no longer HTTP. It is a WebSocket and frames, not HTTP requests,
will flow over it. The connection will persist until the client or
server closes the connection.
In addition to the subprotocol, individual websocket frames may
also specify a message type - examples include `BinaryMessage`,
`TextMessage`, `Ping`, `Pong` or `Close`. Only binary frames can
contain arbitrary data - other frames are expected to be valid
UTF-8 strings, in addition to any subprotocol expectations.
## Browser to Workhorse
GitLab serves a JavaScript terminal emulator to the browser on
a URL like `https://gitlab.com/group/project/environments/1/terminal`.
This opens a websocket connection to, e.g.,
`wss://gitlab.com/group/project/environments/1/terminal.ws`,
This endpoint doesn't exist in GitLab - only in Workhorse.
When receiving the connection, Workhorse first checks that the
client is authorized to access the requested terminal. It does
this by performing a "preauthentication" request to GitLab.
If the client has the appropriate permissions and the terminal
exists, GitLab responds with a successful response that includes
details of the terminal that the client should be connected to.
Otherwise, it returns an appropriate HTTP error response.
Errors are passed back to the client as HTTP responses, but if
GitLab returns valid terminal details to Workhorse, it will
connect to the specified terminal, upgrade the browser to a
WebSocket, and proxy between the two connections for as long
as the browser's credentials are valid. Workhorse will also
send regular `PingMessage` control frames to the browser, to
keep intervening proxies from terminating the connection
while the browser is present.
The browser must request an upgrade with a specific subprotocol:
### `terminal.gitlab.com`
This subprotocol considers `TextMessage` frames to be invalid.
Control frames, such as `PingMessage` or `CloseMessage`, have
their usual meanings.
`BinaryMessage` frames sent from the browser to the server are
arbitrary terminal input.
`BinaryMessage` frames sent from the server to the browser are
arbitrary terminal output.
These frames are expected to contain ANSI terminal control codes
and may be in any encoding.
### `base64.terminal.gitlab.com`
This subprotocol considers `BinaryMessage` frames to be invalid.
Control frames, such as `PingMessage` or `CloseMessage`, have
their usual meanings.
`TextMessage` frames sent from the browser to the server are
base64-encoded arbitrary terminal input (so the server must
base64-decode them before inputting them).
`TextMessage` frames sent from the server to the browser are
base64-encoded arbitrary terminal output (so the browser must
base64-decode them before outputting them).
In their base64-encoded form, these frames are expected to
contain ANSI terminal control codes, and may be in any encoding.
## Workhorse to GitLab
Before upgrading the browser, Workhorse sends a normal HTTP
request to GitLab on a URL like
`https://gitlab.com/group/project/environments/1/terminal.ws/authorize`.
This returns a JSON response containing details of where the
terminal can be found, and how to connect it. In particular,
the following details are returned in case of success:
* WebSocket URL to **connect** to, e.g.: `wss://example.com/terminals/1.ws?tty=1`
* WebSocket subprotocols to support, e.g.: `["channel.k8s.io"]`
* Headers to send, e.g.: `Authorization: Token xxyyz..`
* Certificate authority to verify `wss` connections with (optional)
Workhorse periodically re-checks this endpoint, and if it gets an
error response, or the details of the terminal change, it will
terminate the websocket session.
## Workhorse to Terminal
In GitLab, environments may have a deployment service (e.g.,
`KubernetesService`) associated with them. This service knows
where the terminals for an environment may be found, and these
details are returned to Workhorse by GitLab.
These URLs are *also* WebSocket URLs, and GitLab tells Workhorse
which subprotocols to speak over the connection, along with any
authentication details required by the remote end.
Before upgrading the browser's connection to a websocket,
Workhorse opens a HTTP client connection, according to the
details given to it by Workhorse, and attempts to upgrade
that connection to a websocket. If it fails, an error
response is sent to the browser; otherwise, the browser is
also upgraded.
Workhorse now has two websocket connections, albeit with
differing subprotocols. It decodes incoming frames from the
browser, re-encodes them to the terminal's subprotocol, and
sends them to the terminal. Similarly, it decodes incoming
frames from the terminal, re-encodes them to the browser's
subprotocol, and sends them to the browser.
When either connection closes or enters an error state,
Workhorse detects the error and closes the other connection,
terminating the terminal session. If the browser is the
connection that has disconnected, Workhorse will send an ANSI
`End of Transmission` control code (the `0x04` byte) to the
terminal, encoded according to the appropriate subprotocol.
Workhorse will automatically reply to any websocket ping frame
sent by the terminal, to avoid being disconnected.
Currently, Workhorse only supports the following subprotocols.
Supporting new deployment services will require new subprotocols
to be supported:
### `channel.k8s.io`
Used by Kubernetes, this subprotocol defines a simple multiplexed
channel.
Control frames have their usual meanings. `TextMessage` frames are
invalid. `BinaryMessage` frames represent I/O to a specific file
descriptor.
The first byte of each `BinaryMessage` frame represents the file
descriptor (fd) number, as a `uint8` (so the value `0x00` corresponds
to fd 0, `STDIN`, while `0x01` corresponds to fd 1, `STDOUT`).
The remaining bytes represent arbitrary data. For frames received
from the server, they are bytes that have been received from that
fd. For frames sent to the server, they are bytes that should be
written to that fd.
### `base64.channel.k8s.io`
Also used by Kubernetes, this subprotocol defines a similar multiplexed
channel to `channel.k8s.io`. The main differences are:
* `TextMessage` frames are valid, rather than `BinaryMessage` frames.
* The first byte of each `TextMessage` frame represents the file
descriptor as a numeric UTF-8 character, so the character `U+0030`,
or "0", is fd 0, STDIN).
* The remaining bytes represent base64-encoded arbitrary data.
......@@ -59,6 +59,8 @@ type Response struct {
Archive string `json:"archive"`
// Entry is a filename inside the archive point to file that needs to be extracted
Entry string `json:"entry"`
// Used to communicate terminal session details
Terminal *TerminalSettings
}
// singleJoiningSlash is taken from reverseproxy.go:NewSingleHostReverseProxy
......@@ -143,23 +145,60 @@ func (api *API) newRequest(r *http.Request, body io.Reader, suffix string) (*htt
return authReq, nil
}
func (api *API) PreAuthorizeHandler(h HandleFunc, suffix string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Perform a pre-authorization check against the API for the given HTTP request
//
// If `outErr` is set, the other fields will be nil and it should be treated as
// a 500 error.
//
// If httpResponse is present, the caller is responsible for closing its body
//
// authResponse will only be present if the authorization check was successful
func (api *API) PreAuthorize(suffix string, r *http.Request) (httpResponse *http.Response, authResponse *Response, outErr error) {
authReq, err := api.newRequest(r, nil, suffix)
if err != nil {
helper.Fail500(w, r, fmt.Errorf("preAuthorizeHandler newUpstreamRequest: %v", err))
return
return nil, nil, fmt.Errorf("preAuthorizeHandler newUpstreamRequest: %v", err)
}
httpResponse, err = api.Client.Do(authReq)
if err != nil {
return nil, nil, fmt.Errorf("preAuthorizeHandler: do request: %v", err)
}
defer func() {
if outErr != nil {
httpResponse.Body.Close()
httpResponse = nil
}
}()
if httpResponse.StatusCode != http.StatusOK {
return httpResponse, nil, nil
}
if contentType := httpResponse.Header.Get("Content-Type"); contentType != ResponseContentType {
return httpResponse, nil, fmt.Errorf("preAuthorizeHandler: API responded with wrong content type: %v", contentType)
}
authResponse = &Response{}
// The auth backend validated the client request and told us additional
// request metadata. We must extract this information from the auth
// response body.
if err := json.NewDecoder(httpResponse.Body).Decode(authResponse); err != nil {
return httpResponse, nil, fmt.Errorf("preAuthorizeHandler: decode authorization response: %v", err)
}
authResponse, err := api.Client.Do(authReq)
return httpResponse, authResponse, nil
}
func (api *API) PreAuthorizeHandler(next HandleFunc, suffix string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
httpResponse, authResponse, err := api.PreAuthorize(suffix, r)
if err != nil {
helper.Fail500(w, r, fmt.Errorf("preAuthorizeHandler: do request: %v", err))
helper.Fail500(w, r, err)
return
}
defer authResponse.Body.Close()
if authResponse.StatusCode != 200 {
for k, v := range authResponse.Header {
if httpResponse.StatusCode != http.StatusOK {
for k, v := range httpResponse.Header {
// Accomodate broken clients that do case-sensitive header lookup
if k == "Www-Authenticate" {
w.Header()["WWW-Authenticate"] = v
......@@ -167,36 +206,25 @@ func (api *API) PreAuthorizeHandler(h HandleFunc, suffix string) http.Handler {
w.Header()[k] = v
}
}
w.WriteHeader(authResponse.StatusCode)
io.Copy(w, authResponse.Body)
return
}
w.WriteHeader(httpResponse.StatusCode)
io.Copy(w, httpResponse.Body)
httpResponse.Body.Close()
if contentType := authResponse.Header.Get("Content-Type"); contentType != ResponseContentType {
helper.Fail500(w, r, fmt.Errorf("preAuthorizeHandler: API responded with wrong content type: %v", contentType))
return
}
a := &Response{}
// The auth backend validated the client request and told us additional
// request metadata. We must extract this information from the auth
// response body.
if err := json.NewDecoder(authResponse.Body).Decode(a); err != nil {
helper.Fail500(w, r, fmt.Errorf("preAuthorizeHandler: decode authorization response: %v", err))
return
}
// Don't hog a TCP connection in CLOSE_WAIT, we can already close it now
authResponse.Body.Close()
// Close the body immediately, rather than waiting for the next handler
// to complete
httpResponse.Body.Close()
// Negotiate authentication (Kerberos) may need to return a WWW-Authenticate
// header to the client even in case of success as per RFC4559.
for k, v := range authResponse.Header {
for k, v := range httpResponse.Header {
// Case-insensitive comparison as per RFC7230
if strings.EqualFold(k, "WWW-Authenticate") {
w.Header()[k] = v
}
}
h(w, r, a)
next(w, r, authResponse)
})
}
package api
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net/http"
"net/url"
"github.com/gorilla/websocket"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
)
type TerminalSettings struct {
// The terminal provider may require use of a particular subprotocol. If so,
// it must be specified here, and Workhorse must have a matching codec.
Subprotocols []string
// The websocket URL to connect to.
Url string
// Any headers (e.g., Authorization) to send with the websocket request
Header http.Header
// The CA roots to validate the remote endpoint with, for wss:// URLs. The
// system-provided CA pool will be used if this is blank. PEM-encoded data.
CAPem string
}
func (t *TerminalSettings) URL() (*url.URL, error) {
return url.Parse(t.Url)
}
func (t *TerminalSettings) Dialer() *websocket.Dialer {
dialer := &websocket.Dialer{
Subprotocols: t.Subprotocols,
}
if len(t.CAPem) > 0 {
pool := x509.NewCertPool()
pool.AppendCertsFromPEM([]byte(t.CAPem))
dialer.TLSClientConfig = &tls.Config{RootCAs: pool}
}
return dialer
}
func (t *TerminalSettings) Clone() *TerminalSettings {
// Doesn't clone the strings, but that's OK as strings are immutable in go
cloned := *t
cloned.Header = helper.HeaderClone(t.Header)
return &cloned
}
func (t *TerminalSettings) Dial() (*websocket.Conn, *http.Response, error) {
return t.Dialer().Dial(t.Url, t.Header)
}
func (t *TerminalSettings) Validate() error {
if t == nil {
return fmt.Errorf("Terminal details not specified")
}
if len(t.Subprotocols) == 0 {
return fmt.Errorf("No subprotocol specified")
}
parsedURL, err := t.URL()
if err != nil {
return fmt.Errorf("Invalid URL")
}
if parsedURL.Scheme != "ws" && parsedURL.Scheme != "wss" {
return fmt.Errorf("Invalid websocket scheme: %q", parsedURL.Scheme)
}
return nil
}
func (t *TerminalSettings) IsEqual(other *TerminalSettings) bool {
if t == nil && other == nil {
return true
}
if t == nil || other == nil {
return false
}
if len(t.Subprotocols) != len(other.Subprotocols) {
return false
}
for i, subprotocol := range t.Subprotocols {
if other.Subprotocols[i] != subprotocol {
return false
}
}
if len(t.Header) != len(other.Header) {
return false
}
for header, values := range t.Header {
if len(values) != len(other.Header[header]) {
return false
}
for i, value := range values {
if other.Header[header][i] != value {
return false
}
}
}
return t.Url == other.Url && t.CAPem == other.CAPem
}
package api
import (
"net/http"
"testing"
)
func terminal(url string, subprotocols ...string) *TerminalSettings {
return &TerminalSettings{
Url: url,
Subprotocols: subprotocols,
}
}
func ca(term *TerminalSettings) *TerminalSettings {
term = term.Clone()
term.CAPem = "Valid CA data"
return term
}
func header(term *TerminalSettings, values ...string) *TerminalSettings {
if len(values) == 0 {
values = []string{"Dummy Value"}
}
term = term.Clone()
term.Header = http.Header{
"Header": values,
}
return term
}
func TestClone(t *testing.T) {
a := ca(header(terminal("ws:", "", "")))
b := a.Clone()
if a == b {
t.Fatalf("Address of cloned terminal didn't change")
}
if &a.Subprotocols == &b.Subprotocols {
t.Fatalf("Address of cloned subprotocols didn't change")
}
if &a.Header == &b.Header {
t.Fatalf("Address of cloned header didn't change")
}
}
func TestValidate(t *testing.T) {
for i, tc := range []struct {
terminal *TerminalSettings
valid bool
msg string
}{
{nil, false, "nil terminal"},
{terminal("", ""), false, "empty URL"},
{terminal("ws:"), false, "empty subprotocols"},
{terminal("ws:", "foo"), true, "any subprotocol"},
{terminal("ws:", "foo", "bar"), true, "multiple subprotocols"},
{terminal("ws:", ""), true, "websocket URL"},
{terminal("wss:", ""), true, "secure websocket URL"},
{terminal("http:", ""), false, "HTTP URL"},
{terminal("https:", ""), false, " HTTPS URL"},
{ca(terminal("ws:", "")), true, "any CA pem"},
{header(terminal("ws:", "")), true, "any headers"},
{ca(header(terminal("ws:", ""))), true, "PEM and headers"},
} {
if err := tc.terminal.Validate(); (err != nil) == tc.valid {
t.Fatalf("test case %d: "+tc.msg+": valid=%v: %s: %+v", i, tc.valid, err, tc.terminal)
}
}
}
func TestDialer(t *testing.T) {
terminal := terminal("ws:", "foo")
dialer := terminal.Dialer()
if len(dialer.Subprotocols) != len(terminal.Subprotocols) {
t.Fatalf("Subprotocols don't match: %+v vs. %+v", terminal.Subprotocols, dialer.Subprotocols)
}
for i, subprotocol := range terminal.Subprotocols {
if dialer.Subprotocols[i] != subprotocol {
t.Fatalf("Subprotocols don't match: %+v vs. %+v", terminal.Subprotocols, dialer.Subprotocols)
}
}
if dialer.TLSClientConfig != nil {
t.Fatalf("Unexpected TLSClientConfig: %+v", dialer)
}
terminal = ca(terminal)
dialer = terminal.Dialer()
if dialer.TLSClientConfig == nil || dialer.TLSClientConfig.RootCAs == nil {
t.Fatalf("Custom CA certificates not recognised!")
}
}
func TestIsEqual(t *testing.T) {
term := terminal("ws:", "foo")
term_header2 := header(term, "extra")
term_header3 := header(term)
term_header3.Header.Add("Extra", "extra")
term_ca2 := ca(term)
term_ca2.CAPem = "other value"
for i, tc := range []struct {
termA *TerminalSettings
termB *TerminalSettings
expected bool
}{
{nil, nil, true},
{term, nil, false},
{nil, term, false},
{term, term, true},
{term.Clone(), term.Clone(), true},
{term, terminal("foo:"), false},
{term, terminal(term.Url), false},
{header(term), header(term), true},
{term_header2, term_header2, true},
{term_header3, term_header3, true},
{header(term), term_header2, false},
{header(term), term_header3, false},
{header(term), term, false},
{term, header(term), false},
{ca(term), ca(term), true},
{ca(term), term, false},
{term, ca(term), false},
{ca(header(term)), ca(header(term)), true},
{term_ca2, ca(term), false},
} {
if actual := tc.termA.IsEqual(tc.termB); tc.expected != actual {
t.Fatalf(
"test case %d: Comparison:\n-%+v\n+%+v\nexpected=%v: actual=%v",
i, tc.termA, tc.termB, tc.expected, actual,
)
}
}
}
package helper
import (
"bufio"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"strconv"
......@@ -43,26 +45,49 @@ func registerPrometheusMetrics() {
prometheus.MustRegister(requestsTotal)
}
type LoggingResponseWriter struct {
type LoggingResponseWriter interface {
http.ResponseWriter
Log(r *http.Request)
}
type loggingResponseWriter struct {
rw http.ResponseWriter
status int
written int64
started time.Time
}
type hijackingResponseWriter struct {
loggingResponseWriter
}
func NewLoggingResponseWriter(rw http.ResponseWriter) LoggingResponseWriter {
sessionsActive.Inc()
return LoggingResponseWriter{
out := loggingResponseWriter{
rw: rw,
started: time.Now(),
}
if _, ok := rw.(http.Hijacker); ok {
return &hijackingResponseWriter{out}
}
return &out
}
func (l *hijackingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// The only way to gethere is through NewLoggingResponseWriter(), which
// checks that this cast will be valid.
hijacker := l.rw.(http.Hijacker)
return hijacker.Hijack()
}
func (l *LoggingResponseWriter) Header() http.Header {
func (l *loggingResponseWriter) Header() http.Header {
return l.rw.Header()
}
func (l *LoggingResponseWriter) Write(data []byte) (n int, err error) {
func (l *loggingResponseWriter) Write(data []byte) (n int, err error) {
if l.status == 0 {
l.WriteHeader(http.StatusOK)
}
......@@ -71,7 +96,7 @@ func (l *LoggingResponseWriter) Write(data []byte) (n int, err error) {
return
}
func (l *LoggingResponseWriter) WriteHeader(status int) {
func (l *loggingResponseWriter) WriteHeader(status int) {
if l.status != 0 {
return
}
......@@ -80,7 +105,7 @@ func (l *LoggingResponseWriter) WriteHeader(status int) {
l.rw.WriteHeader(status)
}
func (l *LoggingResponseWriter) Log(r *http.Request) {
func (l *loggingResponseWriter) Log(r *http.Request) {
duration := time.Since(l.started)
responseLogger.Printf("%s %s - - [%s] %q %d %d %q %q %f\n",
r.Host, r.RemoteAddr, l.started,
......
package terminal
import (
"errors"
"net/http"
"time"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
)
type AuthCheckerFunc func() *api.TerminalSettings
// Regularly checks that authorization is still valid for a terminal, outputting
// to the stopper when it isn't
type AuthChecker struct {
Checker AuthCheckerFunc
Template *api.TerminalSettings
StopCh chan error
Done chan struct{}
Count int64
}
var ErrAuthChanged = errors.New("Connection closed: authentication changed or endpoint unavailable.")
func NewAuthChecker(f AuthCheckerFunc, template *api.TerminalSettings, stopCh chan error) *AuthChecker {
return &AuthChecker{
Checker: f,
Template: template,
StopCh: stopCh,
Done: make(chan struct{}),
}
}
func (c *AuthChecker) Loop(interval time.Duration) {
for {
select {
case <-time.After(interval):
settings := c.Checker()
if !c.Template.IsEqual(settings) {
c.StopCh <- ErrAuthChanged
return
}
c.Count = c.Count + 1
case <-c.Done:
return
}
}
}
func (c *AuthChecker) Close() error {
close(c.Done)
return nil
}
// Generates a CheckerFunc from an *api.API + request needing authorization
func authCheckFunc(myAPI *api.API, r *http.Request, suffix string) AuthCheckerFunc {
return func() *api.TerminalSettings {
httpResponse, authResponse, err := myAPI.PreAuthorize(suffix, r)
if err != nil {
return nil
}
defer httpResponse.Body.Close()
if httpResponse.StatusCode != http.StatusOK || authResponse == nil {
return nil
}
return authResponse.Terminal
}
}
package terminal
import (
"testing"
"time"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
)
func checkerSeries(values ...*api.TerminalSettings) AuthCheckerFunc {
return func() *api.TerminalSettings {
if len(values) == 0 {
return nil
}
out := values[0]
values = values[1:]
return out
}
}
func TestAuthCheckerStopsWhenAuthFails(t *testing.T) {
template := &api.TerminalSettings{Url: "ws://example.com"}
stopCh := make(chan error)
series := checkerSeries(template, template, template)
ac := NewAuthChecker(series, template, stopCh)
go ac.Loop(1 * time.Millisecond)
if err := <-stopCh; err != ErrAuthChanged {
t.Fatalf("Expected ErrAuthChanged, got %v", err)
}
if ac.Count != 3 {
t.Fatalf("Expected 3 successful checks, got %v", ac.Count)
}
}
func TestAuthCheckerStopsWhenAuthChanges(t *testing.T) {
template := &api.TerminalSettings{Url: "ws://example.com"}
changed := template.Clone()
changed.Url = "wss://example.com"
stopCh := make(chan error)
series := checkerSeries(template, changed, template)
ac := NewAuthChecker(series, template, stopCh)
go ac.Loop(1 * time.Millisecond)
if err := <-stopCh; err != ErrAuthChanged {
t.Fatalf("Expected ErrAuthChanged, got %v", err)
}
if ac.Count != 1 {
t.Fatalf("Expected 1 successful check, got %v", ac.Count)
}
}
package terminal
import (
"fmt"
"net"
"time"
"github.com/gorilla/websocket"
)
// ANSI "end of terminal" code
var eot = []byte{0x04}
// An abstraction of gorilla's *websocket.Conn
type Connection interface {
UnderlyingConn() net.Conn
ReadMessage() (int, []byte, error)
WriteMessage(int, []byte) error
WriteControl(int, []byte, time.Time) error
}
type Proxy struct {
StopCh chan error
}
// stoppers is the number of goroutines that may attempt to call Stop()
func NewProxy(stoppers int) *Proxy {
return &Proxy{
StopCh: make(chan error, stoppers+2), // each proxy() call is a stopper
}
}
func (p *Proxy) Serve(upstream, downstream Connection, upstreamAddr, downstreamAddr string) error {
// This signals the upstream terminal to kill the exec'd process
defer upstream.WriteMessage(websocket.BinaryMessage, eot)
go p.proxy(upstream, downstream, upstreamAddr, downstreamAddr)
go p.proxy(downstream, upstream, downstreamAddr, upstreamAddr)
return <-p.StopCh
}
func (p *Proxy) proxy(to, from Connection, toAddr, fromAddr string) {
for {
messageType, data, err := from.ReadMessage()
if err != nil {
p.StopCh <- fmt.Errorf("reading from %s: %s", fromAddr, err)
break
}
if err := to.WriteMessage(messageType, data); err != nil {
p.StopCh <- fmt.Errorf("writing to %s: %s", toAddr, err)
break
}
}
}
package terminal
import (
"log"
"net"
"net/http"
"strings"
"time"
"github.com/gorilla/websocket"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
)
var (
// See doc/terminal.md for documentation of this subprotocol
subprotocols = []string{"terminal.gitlab.com", "base64.terminal.gitlab.com"}
upgrader = &websocket.Upgrader{Subprotocols: subprotocols}
ReauthenticationInterval = 5 * time.Minute
BrowserPingInterval = 30 * time.Second
)
func Handler(myAPI *api.API) http.Handler {
return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
if err := a.Terminal.Validate(); err != nil {
helper.Fail500(w, r, err)
return
}
proxy := NewProxy(1) // one stopper: auth checker
checker := NewAuthChecker(
authCheckFunc(myAPI, r, "authorize"),
a.Terminal,
proxy.StopCh,
)
defer checker.Close()
go checker.Loop(ReauthenticationInterval)
ProxyTerminal(w, r, a.Terminal, proxy)
}, "authorize")
}
func ProxyTerminal(w http.ResponseWriter, r *http.Request, terminal *api.TerminalSettings, proxy *Proxy) {
server, err := connectToServer(terminal, r)
if err != nil {
helper.Fail500(w, r, err)
log.Printf("Terminal: connecting to server failed: %s", err)
return
}
defer server.UnderlyingConn().Close()
serverAddr := server.UnderlyingConn().RemoteAddr().String()
client, err := upgradeClient(w, r)
if err != nil {
log.Printf("Terminal: upgrading client to websocket failed: %s", err)
return
}
// Regularly send ping messages to the browser to keep the websocket from
// being timed out by intervening proxies.
go pingLoop(client)
defer client.UnderlyingConn().Close()
clientAddr := getClientAddr(r) // We can't know the port with confidence
log.Printf("Terminal: started proxying from %s to %s", clientAddr, serverAddr)
defer log.Printf("Terminal: finished proxying from %s to %s", clientAddr, serverAddr)
if err := proxy.Serve(server, client, serverAddr, clientAddr); err != nil {
log.Printf("Terminal: error proxying from %s to %s: %s", clientAddr, serverAddr, err)
}
}
// In the future, we might want to look at X-Client-Ip or X-Forwarded-For
func getClientAddr(r *http.Request) string {
return r.RemoteAddr
}
func upgradeClient(w http.ResponseWriter, r *http.Request) (Connection, error) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return nil, err
}
return Wrap(conn, conn.Subprotocol()), nil
}
func pingLoop(conn Connection) {
for {
time.Sleep(BrowserPingInterval)
deadline := time.Now().Add(5 * time.Second)
if err := conn.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
// Either the connection was already closed so no further pings are
// needed, or this connection is now dead and no further pings can
// be sent.
break
}
}
}
func connectToServer(terminal *api.TerminalSettings, r *http.Request) (Connection, error) {
terminal = terminal.Clone()
// Pass along X-Forwarded-For, appending request.RemoteAddr, to the server
// we're connecting to.
if ip, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
if chains, ok := r.Header["X-Forwarded-For"]; ok {
terminal.Header.Set("X-Forwarded-For", strings.Join(chains, ", ")+", "+ip)
} else {
terminal.Header.Set("X-Forwarded-For", ip)
}
}
conn, _, err := terminal.Dial()
if err != nil {
return nil, err
}
return Wrap(conn, conn.Subprotocol()), nil
}
package terminal
import (
"encoding/base64"
"net"
"time"
"github.com/gorilla/websocket"
)
func Wrap(conn Connection, subprotocol string) Connection {
switch subprotocol {
case "channel.k8s.io":
return &kubeWrapper{base64: false, conn: conn}
case "base64.channel.k8s.io":
return &kubeWrapper{base64: true, conn: conn}
case "terminal.gitlab.com":
return &gitlabWrapper{base64: false, conn: conn}
case "base64.terminal.gitlab.com":
return &gitlabWrapper{base64: true, conn: conn}
}
return conn
}
type kubeWrapper struct {
base64 bool
conn Connection
}
type gitlabWrapper struct {
base64 bool
conn Connection
}
func (w *gitlabWrapper) ReadMessage() (int, []byte, error) {
mt, data, err := w.conn.ReadMessage()
if err != nil {
return mt, data, err
}
if isData(mt) {
mt = websocket.BinaryMessage
if w.base64 {
data, err = decodeBase64(data)
}
}
return mt, data, err
}
func (w *gitlabWrapper) WriteMessage(mt int, data []byte) error {
if isData(mt) {
if w.base64 {
mt = websocket.TextMessage
data = encodeBase64(data)
} else {
mt = websocket.BinaryMessage
}
}
return w.conn.WriteMessage(mt, data)
}
func (w *gitlabWrapper) WriteControl(mt int, data []byte, deadline time.Time) error {
return w.conn.WriteControl(mt, data, deadline)
}
func (w *gitlabWrapper) UnderlyingConn() net.Conn {
return w.conn.UnderlyingConn()
}
// Coalesces all wsstreams into a single stream. In practice, we should only
// receive data on stream 1.
func (w *kubeWrapper) ReadMessage() (int, []byte, error) {
mt, data, err := w.conn.ReadMessage()
if err != nil {
return mt, data, err
}
if isData(mt) {
mt = websocket.BinaryMessage
// Remove the WSStream channel number, decode to raw
if len(data) > 0 {
data = data[1:]
if w.base64 {
data, err = decodeBase64(data)
}
}
}
return mt, data, err
}
// Always sends to wsstream 0
func (w *kubeWrapper) WriteMessage(mt int, data []byte) error {
if isData(mt) {
if w.base64 {
mt = websocket.TextMessage
data = append([]byte{'0'}, encodeBase64(data)...)
} else {
mt = websocket.BinaryMessage
data = append([]byte{0}, data...)
}
}
return w.conn.WriteMessage(mt, data)
}
func (w *kubeWrapper) WriteControl(mt int, data []byte, deadline time.Time) error {
return w.conn.WriteControl(mt, data, deadline)
}
func (w *kubeWrapper) UnderlyingConn() net.Conn {
return w.conn.UnderlyingConn()
}
func isData(mt int) bool {
return mt == websocket.BinaryMessage || mt == websocket.TextMessage
}
func encodeBase64(data []byte) []byte {
buf := make([]byte, base64.StdEncoding.EncodedLen(len(data)))
base64.StdEncoding.Encode(buf, data)
return buf
}
func decodeBase64(data []byte) ([]byte, error) {
buf := make([]byte, base64.StdEncoding.DecodedLen(len(data)))
n, err := base64.StdEncoding.Decode(buf, data)
return buf[:n], err
}
package terminal
import (
"bytes"
"errors"
"net"
"testing"
"time"
"github.com/gorilla/websocket"
)
type testcase struct {
input *fakeConn
expected *fakeConn
}
type fakeConn struct {
// WebSocket message type
mt int
data []byte
err error
}
func (f *fakeConn) ReadMessage() (int, []byte, error) {
return f.mt, f.data, f.err
}
func (f *fakeConn) WriteMessage(mt int, data []byte) error {
f.mt = mt
f.data = data
return f.err
}
func (f *fakeConn) WriteControl(mt int, data []byte, _ time.Time) error {
f.mt = mt
f.data = data
return f.err
}
func (f *fakeConn) UnderlyingConn() net.Conn {
return nil
}
func fake(mt int, data []byte, err error) *fakeConn {
return &fakeConn{mt: mt, data: []byte(data), err: err}
}
var (
msg = []byte("foo bar")
msgBase64 = []byte("Zm9vIGJhcg==")
kubeMsg = append([]byte{0}, msg...)
kubeMsgBase64 = append([]byte{'0'}, msgBase64...)
fakeErr = errors.New("fake error")
text = websocket.TextMessage
binary = websocket.BinaryMessage
other = 999
fakeOther = fake(other, []byte("foo"), nil)
)
func assertEqual(t *testing.T, expected, actual *fakeConn, msg string, args ...interface{}) {
if expected.mt != actual.mt {
t.Logf("messageType expected to be %v but was %v", expected.mt, actual.mt)
t.Fatalf(msg, args...)
}
if bytes.Compare(expected.data, actual.data) != 0 {
t.Logf("data expected to be %q but was %q: ", expected.data, actual.data)
t.Fatalf(msg, args...)
}
if expected.err != actual.err {
t.Logf("error expected to be %v but was %v", expected.err, actual.err)
t.Fatalf(msg, args...)
}
}
func TestReadMessage(t *testing.T) {
testCases := map[string][]testcase{
"channel.k8s.io": {
{fake(binary, kubeMsg, fakeErr), fake(binary, kubeMsg, fakeErr)},
{fake(binary, kubeMsg, nil), fake(binary, msg, nil)},
{fake(text, kubeMsg, nil), fake(binary, msg, nil)},
{fakeOther, fakeOther},
},
"base64.channel.k8s.io": {
{fake(text, kubeMsgBase64, fakeErr), fake(text, kubeMsgBase64, fakeErr)},
{fake(text, kubeMsgBase64, nil), fake(binary, msg, nil)},
{fake(binary, kubeMsgBase64, nil), fake(binary, msg, nil)},
{fakeOther, fakeOther},
},
"terminal.gitlab.com": {
{fake(binary, msg, fakeErr), fake(binary, msg, fakeErr)},
{fake(binary, msg, nil), fake(binary, msg, nil)},
{fake(text, msg, nil), fake(binary, msg, nil)},
{fakeOther, fakeOther},
},
"base64.terminal.gitlab.com": {
{fake(text, msgBase64, fakeErr), fake(text, msgBase64, fakeErr)},
{fake(text, msgBase64, nil), fake(binary, msg, nil)},
{fake(binary, msgBase64, nil), fake(binary, msg, nil)},
{fakeOther, fakeOther},
},
}
for subprotocol, cases := range testCases {
for i, tc := range cases {
conn := Wrap(tc.input, subprotocol)
mt, data, err := conn.ReadMessage()
actual := fake(mt, data, err)
assertEqual(t, tc.expected, actual, "%s test case %v", subprotocol, i)
}
}
}
func TestWriteMessage(t *testing.T) {
testCases := map[string][]testcase{
"channel.k8s.io": {
{fake(binary, msg, fakeErr), fake(binary, kubeMsg, fakeErr)},
{fake(binary, msg, nil), fake(binary, kubeMsg, nil)},
{fake(text, msg, nil), fake(binary, kubeMsg, nil)},
{fakeOther, fakeOther},
},
"base64.channel.k8s.io": {
{fake(binary, msg, fakeErr), fake(text, kubeMsgBase64, fakeErr)},
{fake(binary, msg, nil), fake(text, kubeMsgBase64, nil)},
{fake(text, msg, nil), fake(text, kubeMsgBase64, nil)},
{fakeOther, fakeOther},
},
"terminal.gitlab.com": {
{fake(binary, msg, fakeErr), fake(binary, msg, fakeErr)},
{fake(binary, msg, nil), fake(binary, msg, nil)},
{fake(text, msg, nil), fake(binary, msg, nil)},
{fakeOther, fakeOther},
},
"base64.terminal.gitlab.com": {
{fake(binary, msg, fakeErr), fake(text, msgBase64, fakeErr)},
{fake(binary, msg, nil), fake(text, msgBase64, nil)},
{fake(text, msg, nil), fake(text, msgBase64, nil)},
{fakeOther, fakeOther},
},
}
for subprotocol, cases := range testCases {
for i, tc := range cases {
actual := fake(0, nil, tc.input.err)
conn := Wrap(actual, subprotocol)
actual.err = conn.WriteMessage(tc.input.mt, tc.input.data)
assertEqual(t, tc.expected, actual, "%s test case %v", subprotocol, i)
}
}
}
......@@ -4,21 +4,28 @@ import (
"net/http"
"regexp"
"github.com/gorilla/websocket"
apipkg "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/artifacts"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/git"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/lfs"
proxypkg "gitlab.com/gitlab-org/gitlab-workhorse/internal/proxy"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/queueing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/sendfile"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/staticpages"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/terminal"
)
type route struct {
type matcherFunc func(*http.Request) bool
type routeEntry struct {
method string
regex *regexp.Regexp
handler http.Handler
matchers []matcherFunc
}
const projectPattern = `^/[^/]+/[^/]+/`
......@@ -30,6 +37,52 @@ const apiPattern = `^/api/`
const projectsAPIPattern = `^/api/v3/projects/((\d+)|([^/]+/[^/]+))/`
const ciAPIPattern = `^/ci/api/`
func compileRegexp(regexpStr string) *regexp.Regexp {
if len(regexpStr) == 0 {
return nil
}
return regexp.MustCompile(regexpStr)
}
func route(method, regexpStr string, handler http.Handler, matchers ...matcherFunc) routeEntry {
return routeEntry{
method: method,
regex: compileRegexp(regexpStr),
handler: denyWebsocket(handler),
matchers: matchers,
}
}
func wsRoute(regexpStr string, handler http.Handler, matchers ...matcherFunc) routeEntry {
return routeEntry{
method: "GET",
regex: compileRegexp(regexpStr),
handler: handler,
matchers: append(matchers, websocket.IsWebSocketUpgrade),
}
}
func (ro *routeEntry) isMatch(cleanedPath string, req *http.Request) bool {
if ro.method != "" && req.Method != ro.method {
return false
}
if ro.regex != nil && !ro.regex.MatchString(cleanedPath) {
return false
}
ok := true
for _, matcher := range ro.matchers {
ok = matcher(req)
if !ok {
break
}
}
return ok
}
// Routing table
// We match against URI not containing the relativeUrlRoot:
// see upstream.ServeHTTP
......@@ -58,47 +111,61 @@ func (u *Upstream) configureRoutes() {
)
ciAPIProxyQueue := queueing.QueueRequests(proxy, u.APILimit, u.APIQueueLimit, u.APIQueueTimeout)
u.Routes = []route{
u.Routes = []routeEntry{
// Git Clone
route{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), git.GetInfoRefs(api)},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), contentEncodingHandler(git.PostRPC(api))},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), contentEncodingHandler(git.PostRPC(api))},
route{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfs.PutStore(api, proxy)},
route("GET", gitProjectPattern+`info/refs\z`, git.GetInfoRefs(api)),
route("POST", gitProjectPattern+`git-upload-pack\z`, contentEncodingHandler(git.PostRPC(api))),
route("POST", gitProjectPattern+`git-receive-pack\z`, contentEncodingHandler(git.PostRPC(api))),
route("PUT", gitProjectPattern+`gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`, lfs.PutStore(api, proxy)),
// CI Artifacts
route{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), contentEncodingHandler(artifacts.UploadArtifacts(api, proxy))},
route("POST", ciAPIPattern+`v1/builds/[0-9]+/artifacts\z`, contentEncodingHandler(artifacts.UploadArtifacts(api, proxy))),
// Terminal websocket
wsRoute(projectPattern+`environments/[0-9]+/terminal.ws\z`, terminal.Handler(api)),
// Limit capacity given to builds/register.json
route{"", regexp.MustCompile(ciAPIPattern + `v1/builds/register.json\z`), ciAPIProxyQueue},
route("", ciAPIPattern+`v1/builds/register.json\z`, ciAPIProxyQueue),
// Explicitly proxy API requests
route{"", regexp.MustCompile(apiPattern), proxy},
route{"", regexp.MustCompile(ciAPIPattern), proxy},
route("", apiPattern, proxy),
route("", ciAPIPattern, proxy),
// Serve assets
route{"", regexp.MustCompile(`^/assets/`),
static.ServeExisting(u.URLPrefix, staticpages.CacheExpireMax,
NotFoundUnless(u.DevelopmentMode,
proxy,
route(
"", `^/assets/`,
static.ServeExisting(
u.URLPrefix,
staticpages.CacheExpireMax,
NotFoundUnless(u.DevelopmentMode, proxy),
),
),
},
// For legacy reasons, user uploads are stored under the document root.
// To prevent anybody who knows/guesses the URL of a user-uploaded file
// from downloading it we make sure requests to /uploads/ do _not_ pass
// through static.ServeExisting.
route{"", regexp.MustCompile(`^/uploads/`), static.ErrorPagesUnless(u.DevelopmentMode, proxy)},
route("", `^/uploads/`, static.ErrorPagesUnless(u.DevelopmentMode, proxy)),
// Serve static files or forward the requests
route{"", nil,
static.ServeExisting(u.URLPrefix, staticpages.CacheDisabled,
static.DeployPage(
static.ErrorPagesUnless(u.DevelopmentMode,
proxy,
),
route(
"", "",
static.ServeExisting(
u.URLPrefix,
staticpages.CacheDisabled,
static.DeployPage(static.ErrorPagesUnless(u.DevelopmentMode, proxy)),
),
),
},
}
}
func denyWebsocket(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if websocket.IsWebSocketUpgrade(r) {
helper.HTTPError(w, r, "websocket upgrade not allowed", http.StatusBadRequest)
return
}
next.ServeHTTP(w, r)
})
}
......@@ -36,7 +36,7 @@ type Config struct {
type Upstream struct {
Config
URLPrefix urlprefix.Prefix
Routes []route
Routes []routeEntry
RoundTripper *badgateway.RoundTripper
}
......@@ -65,17 +65,17 @@ func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
w := helper.NewLoggingResponseWriter(ow)
defer w.Log(r)
helper.DisableResponseBuffering(&w)
helper.DisableResponseBuffering(w)
// Drop WebSocket connection and CONNECT method
// Drop RequestURI == "*" (FIXME: why?)
if r.RequestURI == "*" {
helper.HTTPError(&w, r, "Connection upgrade not allowed", http.StatusBadRequest)
helper.HTTPError(w, r, "Connection upgrade not allowed", http.StatusBadRequest)
return
}
// Disallow connect
if r.Method == "CONNECT" {
helper.HTTPError(&w, r, "CONNECT not allowed", http.StatusBadRequest)
helper.HTTPError(w, r, "CONNECT not allowed", http.StatusBadRequest)
return
}
......@@ -83,29 +83,25 @@ func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
URIPath := urlprefix.CleanURIPath(r.URL.Path)
prefix := u.URLPrefix
if !prefix.Match(URIPath) {
helper.HTTPError(&w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound)
helper.HTTPError(w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound)
return
}
// Look for a matching Git service
var ro route
foundService := false
for _, ro = range u.Routes {
if ro.method != "" && r.Method != ro.method {
continue
}
if ro.regex == nil || ro.regex.MatchString(prefix.Strip(URIPath)) {
foundService = true
// Look for a matching route
var route *routeEntry
for _, ro := range u.Routes {
if ro.isMatch(prefix.Strip(URIPath), r) {
route = &ro
break
}
}
if !foundService {
if route == nil {
// The protocol spec in git/Documentation/technical/http-protocol.txt
// says we must return 403 if no matching service is found.
helper.HTTPError(&w, r, "Forbidden", http.StatusForbidden)
helper.HTTPError(w, r, "Forbidden", http.StatusForbidden)
return
}
ro.handler.ServeHTTP(&w, r)
route.handler.ServeHTTP(w, r)
}
package main
import (
"bytes"
"encoding/pem"
"fmt"
"log"
"net"
"net/http"
"net/http/httptest"
"net/url"
"path"
"strings"
"testing"
"github.com/gorilla/websocket"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
)
var terminalPath = fmt.Sprintf("%s/environments/1/terminal.ws", testProject)
type connWithReq struct {
conn *websocket.Conn
req *http.Request
}
func TestTerminalHappyPath(t *testing.T) {
serverConns, clientURL, close := wireupTerminal(nil, "channel.k8s.io")
defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
if err != nil {
t.Fatal(err)
}
server := (<-serverConns).conn
defer server.Close()
message := "test message"
// channel.k8s.io: server writes to channel 1, STDOUT
if err := say(server, "\x01"+message); err != nil {
t.Fatal(err)
}
assertReadMessage(t, client, websocket.BinaryMessage, message)
if err := say(client, message); err != nil {
t.Fatal(err)
}
// channel.k8s.io: client writes get put on channel 0, STDIN
assertReadMessage(t, server, websocket.BinaryMessage, "\x00"+message)
// Closing the client should send an EOT signal to the server's STDIN
client.Close()
assertReadMessage(t, server, websocket.BinaryMessage, "\x00\x04")
}
func TestTerminalBadTLS(t *testing.T) {
_, clientURL, close := wireupTerminal(badCA, "channel.k8s.io")
defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
if err != websocket.ErrBadHandshake {
t.Fatalf("Expected connection to fail ErrBadHandshake, got: %v", err)
}
if err == nil {
log.Println("TLS negotiation should have failed!")
defer client.Close()
}
}
func TestTerminalProxyForwardsHeadersFromUpstream(t *testing.T) {
hdr := make(http.Header)
hdr.Set("Random-Header", "Value")
serverConns, clientURL, close := wireupTerminal(setHeader(hdr), "channel.k8s.io")
defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
if err != nil {
t.Fatal(err)
}
defer client.Close()
sc := <-serverConns
defer sc.conn.Close()
if sc.req.Header.Get("Random-Header") != "Value" {
t.Fatal("Header specified by upstream not sent to remote")
}
}
func TestTerminalProxyForwardsXForwardedForFromClient(t *testing.T) {
serverConns, clientURL, close := wireupTerminal(nil, "channel.k8s.io")
defer close()
hdr := make(http.Header)
hdr.Set("X-Forwarded-For", "127.0.0.2")
client, _, err := dialWebsocket(clientURL, hdr, "terminal.gitlab.com")
if err != nil {
t.Fatal(err)
}
defer client.Close()
clientIP, _, err := net.SplitHostPort(client.LocalAddr().String())
if err != nil {
t.Fatal(err)
}
sc := <-serverConns
defer sc.conn.Close()
if xff := sc.req.Header.Get("X-Forwarded-For"); xff != "127.0.0.2, "+clientIP {
t.Fatalf("X-Forwarded-For from client not sent to remote: %+v", xff)
}
}
func wireupTerminal(modifier func(*api.Response), subprotocols ...string) (chan connWithReq, string, func()) {
serverConns, remote := startWebsocketServer(subprotocols...)
authResponse := terminalOkBody(remote, nil, subprotocols...)
if modifier != nil {
modifier(authResponse)
}
upstream := testAuthServer(nil, 200, authResponse)
workhorse := startWorkhorseServer(upstream.URL)
return serverConns, websocketURL(workhorse.URL, terminalPath), func() {
workhorse.Close()
upstream.Close()
remote.Close()
}
}
func startWebsocketServer(subprotocols ...string) (chan connWithReq, *httptest.Server) {
upgrader := &websocket.Upgrader{Subprotocols: subprotocols}
connCh := make(chan connWithReq, 1)
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Println("WEBSOCKET", r.Method, r.URL, r.Header)
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println("WEBSOCKET", r.Method, r.URL, "Upgrade failed", err)
return
}
connCh <- connWithReq{conn, r}
// The connection has been hijacked so it's OK to end here
}))
return connCh, server
}
func terminalOkBody(remote *httptest.Server, header http.Header, subprotocols ...string) *api.Response {
out := &api.Response{
Terminal: &api.TerminalSettings{
Url: websocketURL(remote.URL),
Header: header,
Subprotocols: subprotocols,
},
}
if len(remote.TLS.Certificates) > 0 {
data := bytes.NewBuffer(nil)
pem.Encode(data, &pem.Block{Type: "CERTIFICATE", Bytes: remote.TLS.Certificates[0].Certificate[0]})
out.Terminal.CAPem = data.String()
}
return out
}
func badCA(authResponse *api.Response) {
authResponse.Terminal.CAPem = "Bad CA"
}
func setHeader(hdr http.Header) func(*api.Response) {
return func(authResponse *api.Response) {
authResponse.Terminal.Header = hdr
}
}
func dialWebsocket(url string, header http.Header, subprotocols ...string) (*websocket.Conn, *http.Response, error) {
dialer := &websocket.Dialer{
Subprotocols: subprotocols,
}
return dialer.Dial(url, header)
}
func websocketURL(httpURL string, suffix ...string) string {
url, err := url.Parse(httpURL)
if err != nil {
panic(err)
}
switch url.Scheme {
case "http":
url.Scheme = "ws"
case "https":
url.Scheme = "wss"
default:
panic("Unknown scheme: " + url.Scheme)
}
url.Path = path.Join(url.Path, strings.Join(suffix, "/"))
return url.String()
}
func say(conn *websocket.Conn, message string) error {
return conn.WriteMessage(websocket.TextMessage, []byte(message))
}
func assertReadMessage(t *testing.T, conn *websocket.Conn, expectedMessageType int, expectedData string) {
messageType, data, err := conn.ReadMessage()
if err != nil {
t.Fatal(err)
}
if messageType != expectedMessageType {
t.Fatalf("Expected message, %d, got %d", expectedMessageType, messageType)
}
if string(data) != expectedData {
t.Fatalf("Message was mangled in transit. Expected %q, got %q", expectedData, string(data))
}
}
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
.idea/
*.iml
\ No newline at end of file
language: go
sudo: false
matrix:
include:
- go: 1.4
- go: 1.5
- go: 1.6
- go: 1.7
- go: tip
allow_failures:
- go: tip
script:
- go get -t -v ./...
- diff -u <(echo -n) <(gofmt -d .)
- go vet $(go list ./... | grep -v /vendor/)
- go test -v -race ./...
# This is the official list of Gorilla WebSocket authors for copyright
# purposes.
#
# Please keep the list sorted.
Gary Burd <gary@beagledreams.com>
Joachim Bauch <mail@joachim-bauch.de>
Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Gorilla WebSocket
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
[![Build Status](https://travis-ci.org/gorilla/websocket.svg?branch=master)](https://travis-ci.org/gorilla/websocket)
[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket)
### Documentation
* [API Reference](http://godoc.org/github.com/gorilla/websocket)
* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat)
* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command)
* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo)
* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch)
### Status
The Gorilla WebSocket package provides a complete and tested implementation of
the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. The
package API is stable.
### Installation
go get github.com/gorilla/websocket
### Protocol Compliance
The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](http://autobahn.ws/testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
### Gorilla WebSocket compared with other packages
<table>
<tr>
<th></th>
<th><a href="http://godoc.org/github.com/gorilla/websocket">github.com/gorilla</a></th>
<th><a href="http://godoc.org/golang.org/x/net/websocket">golang.org/x/net</a></th>
</tr>
<tr>
<tr><td colspan="3"><a href="http://tools.ietf.org/html/rfc6455">RFC 6455</a> Features</td></tr>
<tr><td>Passes <a href="http://autobahn.ws/testsuite/">Autobahn Test Suite</a></td><td><a href="https://github.com/gorilla/websocket/tree/master/examples/autobahn">Yes</a></td><td>No</td></tr>
<tr><td>Receive <a href="https://tools.ietf.org/html/rfc6455#section-5.4">fragmented</a> message<td>Yes</td><td><a href="https://code.google.com/p/go/issues/detail?id=7632">No</a>, see note 1</td></tr>
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.1">close</a> message</td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td><a href="https://code.google.com/p/go/issues/detail?id=4588">No</a></td></tr>
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.2">pings</a> and receive <a href="https://tools.ietf.org/html/rfc6455#section-5.5.3">pongs</a></td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td>No</td></tr>
<tr><td>Get the <a href="https://tools.ietf.org/html/rfc6455#section-5.6">type</a> of a received data message</td><td>Yes</td><td>Yes, see note 2</td></tr>
<tr><td colspan="3">Other Features</tr></td>
<tr><td><a href="https://tools.ietf.org/html/rfc7692">Compression Extensions</a></td><td>Experimental</td><td>No</td></tr>
<tr><td>Read message using io.Reader</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextReader">Yes</a></td><td>No, see note 3</td></tr>
<tr><td>Write message using io.WriteCloser</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextWriter">Yes</a></td><td>No, see note 3</td></tr>
</table>
Notes:
1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html).
2. The application can get the type of a received data message by implementing
a [Codec marshal](http://godoc.org/golang.org/x/net/websocket#Codec.Marshal)
function.
3. The go.net io.Reader and io.Writer operate across WebSocket frame boundaries.
Read returns when the input buffer is full or a frame boundary is
encountered. Each call to Write sends a single frame message. The Gorilla
io.Reader and io.WriteCloser operate on a single WebSocket message.
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"bytes"
"crypto/tls"
"encoding/base64"
"errors"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"strings"
"time"
)
// ErrBadHandshake is returned when the server response to opening handshake is
// invalid.
var ErrBadHandshake = errors.New("websocket: bad handshake")
var errInvalidCompression = errors.New("websocket: invalid compression negotiation")
// NewClient creates a new client connection using the given net connection.
// The URL u specifies the host and request URI. Use requestHeader to specify
// the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies
// (Cookie). Use the response.Header to get the selected subprotocol
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
//
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
// non-nil *http.Response so that callers can handle redirects, authentication,
// etc.
//
// Deprecated: Use Dialer instead.
func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
d := Dialer{
ReadBufferSize: readBufSize,
WriteBufferSize: writeBufSize,
NetDial: func(net, addr string) (net.Conn, error) {
return netConn, nil
},
}
return d.Dial(u.String(), requestHeader)
}
// A Dialer contains options for connecting to WebSocket server.
type Dialer struct {
// NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dial is used.
NetDial func(network, addr string) (net.Conn, error)
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
// If Proxy is nil or returns a nil *URL, no proxy is used.
Proxy func(*http.Request) (*url.URL, error)
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
// If nil, the default configuration is used.
TLSClientConfig *tls.Config
// HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration
// Input and output buffer sizes. If the buffer size is zero, then a
// default value of 4096 is used.
ReadBufferSize, WriteBufferSize int
// Subprotocols specifies the client's requested subprotocols.
Subprotocols []string
// EnableCompression specifies if the client should attempt to negotiate
// per message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool
// Jar specifies the cookie jar.
// If Jar is nil, cookies are not sent in requests and ignored
// in responses.
Jar http.CookieJar
}
var errMalformedURL = errors.New("malformed ws or wss URL")
// parseURL parses the URL.
//
// This function is a replacement for the standard library url.Parse function.
// In Go 1.4 and earlier, url.Parse loses information from the path.
func parseURL(s string) (*url.URL, error) {
// From the RFC:
//
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
var u url.URL
switch {
case strings.HasPrefix(s, "ws://"):
u.Scheme = "ws"
s = s[len("ws://"):]
case strings.HasPrefix(s, "wss://"):
u.Scheme = "wss"
s = s[len("wss://"):]
default:
return nil, errMalformedURL
}
if i := strings.Index(s, "?"); i >= 0 {
u.RawQuery = s[i+1:]
s = s[:i]
}
if i := strings.Index(s, "/"); i >= 0 {
u.Opaque = s[i:]
s = s[:i]
} else {
u.Opaque = "/"
}
u.Host = s
if strings.Contains(u.Host, "@") {
// Don't bother parsing user information because user information is
// not allowed in websocket URIs.
return nil, errMalformedURL
}
return &u, nil
}
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
hostPort = u.Host
hostNoPort = u.Host
if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
hostNoPort = hostNoPort[:i]
} else {
switch u.Scheme {
case "wss":
hostPort += ":443"
case "https":
hostPort += ":443"
default:
hostPort += ":80"
}
}
return hostPort, hostNoPort
}
// DefaultDialer is a dialer with all fields set to the default zero values.
var DefaultDialer = &Dialer{
Proxy: http.ProxyFromEnvironment,
}
// Dial creates a new client connection. Use requestHeader to specify the
// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
// Use the response.Header to get the selected subprotocol
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
//
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
// non-nil *http.Response so that callers can handle redirects, authentication,
// etcetera. The response body may not contain the entire response and does not
// need to be closed by the application.
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
if d == nil {
d = &Dialer{
Proxy: http.ProxyFromEnvironment,
}
}
challengeKey, err := generateChallengeKey()
if err != nil {
return nil, nil, err
}
u, err := parseURL(urlStr)
if err != nil {
return nil, nil, err
}
switch u.Scheme {
case "ws":
u.Scheme = "http"
case "wss":
u.Scheme = "https"
default:
return nil, nil, errMalformedURL
}
if u.User != nil {
// User name and password are not allowed in websocket URIs.
return nil, nil, errMalformedURL
}
req := &http.Request{
Method: "GET",
URL: u,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
Host: u.Host,
}
// Set the cookies present in the cookie jar of the dialer
if d.Jar != nil {
for _, cookie := range d.Jar.Cookies(u) {
req.AddCookie(cookie)
}
}
// Set the request headers using the capitalization for names and values in
// RFC examples. Although the capitalization shouldn't matter, there are
// servers that depend on it. The Header.Set method is not used because the
// method canonicalizes the header names.
req.Header["Upgrade"] = []string{"websocket"}
req.Header["Connection"] = []string{"Upgrade"}
req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
req.Header["Sec-WebSocket-Version"] = []string{"13"}
if len(d.Subprotocols) > 0 {
req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
}
for k, vs := range requestHeader {
switch {
case k == "Host":
if len(vs) > 0 {
req.Host = vs[0]
}
case k == "Upgrade" ||
k == "Connection" ||
k == "Sec-Websocket-Key" ||
k == "Sec-Websocket-Version" ||
k == "Sec-Websocket-Extensions" ||
(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
default:
req.Header[k] = vs
}
}
if d.EnableCompression {
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
}
hostPort, hostNoPort := hostPortNoPort(u)
var proxyURL *url.URL
// Check wether the proxy method has been configured
if d.Proxy != nil {
proxyURL, err = d.Proxy(req)
}
if err != nil {
return nil, nil, err
}
var targetHostPort string
if proxyURL != nil {
targetHostPort, _ = hostPortNoPort(proxyURL)
} else {
targetHostPort = hostPort
}
var deadline time.Time
if d.HandshakeTimeout != 0 {
deadline = time.Now().Add(d.HandshakeTimeout)
}
netDial := d.NetDial
if netDial == nil {
netDialer := &net.Dialer{Deadline: deadline}
netDial = netDialer.Dial
}
netConn, err := netDial("tcp", targetHostPort)
if err != nil {
return nil, nil, err
}
defer func() {
if netConn != nil {
netConn.Close()
}
}()
if err := netConn.SetDeadline(deadline); err != nil {
return nil, nil, err
}
if proxyURL != nil {
connectHeader := make(http.Header)
if user := proxyURL.User; user != nil {
proxyUser := user.Username()
if proxyPassword, passwordSet := user.Password(); passwordSet {
credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
connectHeader.Set("Proxy-Authorization", "Basic "+credential)
}
}
connectReq := &http.Request{
Method: "CONNECT",
URL: &url.URL{Opaque: hostPort},
Host: hostPort,
Header: connectHeader,
}
connectReq.Write(netConn)
// Read response.
// Okay to use and discard buffered reader here, because
// TLS server will not speak until spoken to.
br := bufio.NewReader(netConn)
resp, err := http.ReadResponse(br, connectReq)
if err != nil {
return nil, nil, err
}
if resp.StatusCode != 200 {
f := strings.SplitN(resp.Status, " ", 2)
return nil, nil, errors.New(f[1])
}
}
if u.Scheme == "https" {
cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" {
cfg.ServerName = hostNoPort
}
tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn
if err := tlsConn.Handshake(); err != nil {
return nil, nil, err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return nil, nil, err
}
}
}
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize)
if err := req.Write(netConn); err != nil {
return nil, nil, err
}
resp, err := http.ReadResponse(conn.br, req)
if err != nil {
return nil, nil, err
}
if d.Jar != nil {
if rc := resp.Cookies(); len(rc) > 0 {
d.Jar.SetCookies(u, rc)
}
}
if resp.StatusCode != 101 ||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
// Before closing the network connection on return from this
// function, slurp up some of the response to aid application
// debugging.
buf := make([]byte, 1024)
n, _ := io.ReadFull(resp.Body, buf)
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
return nil, resp, ErrBadHandshake
}
for _, ext := range parseExtensions(req.Header) {
if ext[""] != "permessage-deflate" {
continue
}
_, snct := ext["server_no_context_takeover"]
_, cnct := ext["client_no_context_takeover"]
if !snct || !cnct {
return nil, resp, errInvalidCompression
}
conn.newCompressionWriter = compressNoContextTakeover
conn.newDecompressionReader = decompressNoContextTakeover
break
}
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
netConn.SetDeadline(time.Time{})
netConn = nil // to avoid close in defer.
return conn, resp, nil
}
// cloneTLSConfig clones all public fields except the fields
// SessionTicketsDisabled and SessionTicketKey. This avoids copying the
// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a
// config in active use.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
}
}
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"compress/flate"
"errors"
"io"
"strings"
)
func decompressNoContextTakeover(r io.Reader) io.Reader {
const tail =
// Add four bytes as specified in RFC
"\x00\x00\xff\xff" +
// Add final block to squelch unexpected EOF error from flate reader.
"\x01\x00\x00\xff\xff"
return flate.NewReader(io.MultiReader(r, strings.NewReader(tail)))
}
func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
tw := &truncWriter{w: w}
fw, err := flate.NewWriter(tw, 3)
return &flateWrapper{fw: fw, tw: tw}, err
}
// truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to another io.Writer.
type truncWriter struct {
w io.WriteCloser
n int
p [4]byte
}
func (w *truncWriter) Write(p []byte) (int, error) {
n := 0
// fill buffer first for simplicity.
if w.n < len(w.p) {
n = copy(w.p[w.n:], p)
p = p[n:]
w.n += n
if len(p) == 0 {
return n, nil
}
}
m := len(p)
if m > len(w.p) {
m = len(w.p)
}
if nn, err := w.w.Write(w.p[:m]); err != nil {
return n + nn, err
}
copy(w.p[:], w.p[m:])
copy(w.p[len(w.p)-m:], p[len(p)-m:])
nn, err := w.w.Write(p[:len(p)-m])
return n + nn, err
}
type flateWrapper struct {
fw *flate.Writer
tw *truncWriter
}
func (w *flateWrapper) Write(p []byte) (int, error) {
return w.fw.Write(p)
}
func (w *flateWrapper) Close() error {
err1 := w.fw.Flush()
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
}
err2 := w.tw.w.Close()
if err1 != nil {
return err1
}
return err2
}
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"encoding/binary"
"errors"
"io"
"io/ioutil"
"net"
"strconv"
"sync"
"time"
"unicode/utf8"
)
const (
// Frame header byte 0 bits from Section 5.2 of RFC 6455
finalBit = 1 << 7
rsv1Bit = 1 << 6
rsv2Bit = 1 << 5
rsv3Bit = 1 << 4
// Frame header byte 1 bits from Section 5.2 of RFC 6455
maskBit = 1 << 7
maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
maxControlFramePayloadSize = 125
writeWait = time.Second
defaultReadBufferSize = 4096
defaultWriteBufferSize = 4096
continuationFrame = 0
noFrame = -1
)
// Close codes defined in RFC 6455, section 11.7.
const (
CloseNormalClosure = 1000
CloseGoingAway = 1001
CloseProtocolError = 1002
CloseUnsupportedData = 1003
CloseNoStatusReceived = 1005
CloseAbnormalClosure = 1006
CloseInvalidFramePayloadData = 1007
ClosePolicyViolation = 1008
CloseMessageTooBig = 1009
CloseMandatoryExtension = 1010
CloseInternalServerErr = 1011
CloseServiceRestart = 1012
CloseTryAgainLater = 1013
CloseTLSHandshake = 1015
)
// The message types are defined in RFC 6455, section 11.8.
const (
// TextMessage denotes a text data message. The text message payload is
// interpreted as UTF-8 encoded text data.
TextMessage = 1
// BinaryMessage denotes a binary data message.
BinaryMessage = 2
// CloseMessage denotes a close control message. The optional message
// payload contains a numeric code and text. Use the FormatCloseMessage
// function to format a close message payload.
CloseMessage = 8
// PingMessage denotes a ping control message. The optional message payload
// is UTF-8 encoded text.
PingMessage = 9
// PongMessage denotes a ping control message. The optional message payload
// is UTF-8 encoded text.
PongMessage = 10
)
// ErrCloseSent is returned when the application writes a message to the
// connection after sending a close message.
var ErrCloseSent = errors.New("websocket: close sent")
// ErrReadLimit is returned when reading a message that is larger than the
// read limit set for the connection.
var ErrReadLimit = errors.New("websocket: read limit exceeded")
// netError satisfies the net Error interface.
type netError struct {
msg string
temporary bool
timeout bool
}
func (e *netError) Error() string { return e.msg }
func (e *netError) Temporary() bool { return e.temporary }
func (e *netError) Timeout() bool { return e.timeout }
// CloseError represents close frame.
type CloseError struct {
// Code is defined in RFC 6455, section 11.7.
Code int
// Text is the optional text payload.
Text string
}
func (e *CloseError) Error() string {
s := []byte("websocket: close ")
s = strconv.AppendInt(s, int64(e.Code), 10)
switch e.Code {
case CloseNormalClosure:
s = append(s, " (normal)"...)
case CloseGoingAway:
s = append(s, " (going away)"...)
case CloseProtocolError:
s = append(s, " (protocol error)"...)
case CloseUnsupportedData:
s = append(s, " (unsupported data)"...)
case CloseNoStatusReceived:
s = append(s, " (no status)"...)
case CloseAbnormalClosure:
s = append(s, " (abnormal closure)"...)
case CloseInvalidFramePayloadData:
s = append(s, " (invalid payload data)"...)
case ClosePolicyViolation:
s = append(s, " (policy violation)"...)
case CloseMessageTooBig:
s = append(s, " (message too big)"...)
case CloseMandatoryExtension:
s = append(s, " (mandatory extension missing)"...)
case CloseInternalServerErr:
s = append(s, " (internal server error)"...)
case CloseTLSHandshake:
s = append(s, " (TLS handshake error)"...)
}
if e.Text != "" {
s = append(s, ": "...)
s = append(s, e.Text...)
}
return string(s)
}
// IsCloseError returns boolean indicating whether the error is a *CloseError
// with one of the specified codes.
func IsCloseError(err error, codes ...int) bool {
if e, ok := err.(*CloseError); ok {
for _, code := range codes {
if e.Code == code {
return true
}
}
}
return false
}
// IsUnexpectedCloseError returns boolean indicating whether the error is a
// *CloseError with a code not in the list of expected codes.
func IsUnexpectedCloseError(err error, expectedCodes ...int) bool {
if e, ok := err.(*CloseError); ok {
for _, code := range expectedCodes {
if e.Code == code {
return false
}
}
return true
}
return false
}
var (
errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true, temporary: true}
errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()}
errBadWriteOpCode = errors.New("websocket: bad write message type")
errWriteClosed = errors.New("websocket: write closed")
errInvalidControlFrame = errors.New("websocket: invalid control frame")
)
func hideTempErr(err error) error {
if e, ok := err.(net.Error); ok && e.Temporary() {
err = &netError{msg: e.Error(), timeout: e.Timeout()}
}
return err
}
func isControl(frameType int) bool {
return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage
}
func isData(frameType int) bool {
return frameType == TextMessage || frameType == BinaryMessage
}
var validReceivedCloseCodes = map[int]bool{
// see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
CloseNormalClosure: true,
CloseGoingAway: true,
CloseProtocolError: true,
CloseUnsupportedData: true,
CloseNoStatusReceived: false,
CloseAbnormalClosure: false,
CloseInvalidFramePayloadData: true,
ClosePolicyViolation: true,
CloseMessageTooBig: true,
CloseMandatoryExtension: true,
CloseInternalServerErr: true,
CloseServiceRestart: true,
CloseTryAgainLater: true,
CloseTLSHandshake: false,
}
func isValidReceivedCloseCode(code int) bool {
return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
}
type Conn struct {
conn net.Conn
isServer bool
subprotocol string
// Write fields
mu chan bool // used as mutex to protect write to conn
writeBuf []byte // frame is constructed in this buffer.
writeDeadline time.Time
writer io.WriteCloser // the current writer returned to the application
isWriting bool // for best-effort concurrent write detection
writeErrMu sync.Mutex
writeErr error
enableWriteCompression bool
newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error)
// Read fields
readErr error
br *bufio.Reader
readRemaining int64 // bytes remaining in current frame.
readFinal bool // true the current message has more frames.
readLength int64 // Message size.
readLimit int64 // Maximum message size.
readMaskPos int
readMaskKey [4]byte
handlePong func(string) error
handlePing func(string) error
handleClose func(int, string) error
readErrCount int
messageReader *messageReader // the current low-level reader
readDecompress bool // whether last read frame had RSV1 set
newDecompressionReader func(io.Reader) io.Reader
}
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
mu := make(chan bool, 1)
mu <- true
if readBufferSize == 0 {
readBufferSize = defaultReadBufferSize
}
if readBufferSize < maxControlFramePayloadSize {
readBufferSize = maxControlFramePayloadSize
}
if writeBufferSize == 0 {
writeBufferSize = defaultWriteBufferSize
}
c := &Conn{
isServer: isServer,
br: bufio.NewReaderSize(conn, readBufferSize),
conn: conn,
mu: mu,
readFinal: true,
writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize),
enableWriteCompression: true,
}
c.SetCloseHandler(nil)
c.SetPingHandler(nil)
c.SetPongHandler(nil)
return c
}
// Subprotocol returns the negotiated protocol for the connection.
func (c *Conn) Subprotocol() string {
return c.subprotocol
}
// Close closes the underlying network connection without sending or waiting for a close frame.
func (c *Conn) Close() error {
return c.conn.Close()
}
// LocalAddr returns the local network address.
func (c *Conn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
// RemoteAddr returns the remote network address.
func (c *Conn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
// Write methods
func (c *Conn) writeFatal(err error) error {
err = hideTempErr(err)
c.writeErrMu.Lock()
if c.writeErr == nil {
c.writeErr = err
}
c.writeErrMu.Unlock()
return err
}
func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
<-c.mu
defer func() { c.mu <- true }()
c.writeErrMu.Lock()
err := c.writeErr
c.writeErrMu.Unlock()
if err != nil {
return err
}
c.conn.SetWriteDeadline(deadline)
for _, buf := range bufs {
if len(buf) > 0 {
_, err := c.conn.Write(buf)
if err != nil {
return c.writeFatal(err)
}
}
}
if frameType == CloseMessage {
c.writeFatal(ErrCloseSent)
}
return nil
}
// WriteControl writes a control message with the given deadline. The allowed
// message types are CloseMessage, PingMessage and PongMessage.
func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
if !isControl(messageType) {
return errBadWriteOpCode
}
if len(data) > maxControlFramePayloadSize {
return errInvalidControlFrame
}
b0 := byte(messageType) | finalBit
b1 := byte(len(data))
if !c.isServer {
b1 |= maskBit
}
buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize)
buf = append(buf, b0, b1)
if c.isServer {
buf = append(buf, data...)
} else {
key := newMaskKey()
buf = append(buf, key[:]...)
buf = append(buf, data...)
maskBytes(key, 0, buf[6:])
}
d := time.Hour * 1000
if !deadline.IsZero() {
d = deadline.Sub(time.Now())
if d < 0 {
return errWriteTimeout
}
}
timer := time.NewTimer(d)
select {
case <-c.mu:
timer.Stop()
case <-timer.C:
return errWriteTimeout
}
defer func() { c.mu <- true }()
c.writeErrMu.Lock()
err := c.writeErr
c.writeErrMu.Unlock()
if err != nil {
return err
}
c.conn.SetWriteDeadline(deadline)
_, err = c.conn.Write(buf)
if err != nil {
return c.writeFatal(err)
}
if messageType == CloseMessage {
c.writeFatal(ErrCloseSent)
}
return err
}
// NextWriter returns a writer for the next message to send. The writer's Close
// method flushes the complete message to the network.
//
// There can be at most one open writer on a connection. NextWriter closes the
// previous writer if the application has not already done so.
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
// Close previous writer if not already closed by the application. It's
// probably better to return an error in this situation, but we cannot
// change this without breaking existing applications.
if c.writer != nil {
c.writer.Close()
c.writer = nil
}
if !isControl(messageType) && !isData(messageType) {
return nil, errBadWriteOpCode
}
c.writeErrMu.Lock()
err := c.writeErr
c.writeErrMu.Unlock()
if err != nil {
return nil, err
}
mw := &messageWriter{
c: c,
frameType: messageType,
pos: maxFrameHeaderSize,
}
c.writer = mw
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
w, err := c.newCompressionWriter(c.writer)
if err != nil {
c.writer = nil
return nil, err
}
mw.compress = true
c.writer = w
}
return c.writer, nil
}
type messageWriter struct {
c *Conn
compress bool // whether next call to flushFrame should set RSV1
pos int // end of data in writeBuf.
frameType int // type of the current frame.
err error
}
func (w *messageWriter) fatal(err error) error {
if w.err != nil {
w.err = err
w.c.writer = nil
}
return err
}
// flushFrame writes buffered data and extra as a frame to the network. The
// final argument indicates that this is the last frame in the message.
func (w *messageWriter) flushFrame(final bool, extra []byte) error {
c := w.c
length := w.pos - maxFrameHeaderSize + len(extra)
// Check for invalid control frames.
if isControl(w.frameType) &&
(!final || length > maxControlFramePayloadSize) {
return w.fatal(errInvalidControlFrame)
}
b0 := byte(w.frameType)
if final {
b0 |= finalBit
}
if w.compress {
b0 |= rsv1Bit
}
w.compress = false
b1 := byte(0)
if !c.isServer {
b1 |= maskBit
}
// Assume that the frame starts at beginning of c.writeBuf.
framePos := 0
if c.isServer {
// Adjust up if mask not included in the header.
framePos = 4
}
switch {
case length >= 65536:
c.writeBuf[framePos] = b0
c.writeBuf[framePos+1] = b1 | 127
binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length))
case length > 125:
framePos += 6
c.writeBuf[framePos] = b0
c.writeBuf[framePos+1] = b1 | 126
binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length))
default:
framePos += 8
c.writeBuf[framePos] = b0
c.writeBuf[framePos+1] = b1 | byte(length)
}
if !c.isServer {
key := newMaskKey()
copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
if len(extra) > 0 {
return c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))
}
}
// Write the buffers to the connection with best-effort detection of
// concurrent writes. See the concurrency section in the package
// documentation for more info.
if c.isWriting {
panic("concurrent write to websocket connection")
}
c.isWriting = true
err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra)
if !c.isWriting {
panic("concurrent write to websocket connection")
}
c.isWriting = false
if err != nil {
return w.fatal(err)
}
if final {
c.writer = nil
return nil
}
// Setup for next frame.
w.pos = maxFrameHeaderSize
w.frameType = continuationFrame
return nil
}
func (w *messageWriter) ncopy(max int) (int, error) {
n := len(w.c.writeBuf) - w.pos
if n <= 0 {
if err := w.flushFrame(false, nil); err != nil {
return 0, err
}
n = len(w.c.writeBuf) - w.pos
}
if n > max {
n = max
}
return n, nil
}
func (w *messageWriter) Write(p []byte) (int, error) {
if w.err != nil {
return 0, w.err
}
if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
// Don't buffer large messages.
err := w.flushFrame(false, p)
if err != nil {
return 0, err
}
return len(p), nil
}
nn := len(p)
for len(p) > 0 {
n, err := w.ncopy(len(p))
if err != nil {
return 0, err
}
copy(w.c.writeBuf[w.pos:], p[:n])
w.pos += n
p = p[n:]
}
return nn, nil
}
func (w *messageWriter) WriteString(p string) (int, error) {
if w.err != nil {
return 0, w.err
}
nn := len(p)
for len(p) > 0 {
n, err := w.ncopy(len(p))
if err != nil {
return 0, err
}
copy(w.c.writeBuf[w.pos:], p[:n])
w.pos += n
p = p[n:]
}
return nn, nil
}
func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
if w.err != nil {
return 0, w.err
}
for {
if w.pos == len(w.c.writeBuf) {
err = w.flushFrame(false, nil)
if err != nil {
break
}
}
var n int
n, err = r.Read(w.c.writeBuf[w.pos:])
w.pos += n
nn += int64(n)
if err != nil {
if err == io.EOF {
err = nil
}
break
}
}
return nn, err
}
func (w *messageWriter) Close() error {
if w.err != nil {
return w.err
}
if err := w.flushFrame(true, nil); err != nil {
return err
}
w.err = errWriteClosed
return nil
}
// WriteMessage is a helper method for getting a writer using NextWriter,
// writing the message and closing the writer.
func (c *Conn) WriteMessage(messageType int, data []byte) error {
w, err := c.NextWriter(messageType)
if err != nil {
return err
}
if mw, ok := w.(*messageWriter); ok && c.isServer {
// Optimize write as a single frame.
n := copy(c.writeBuf[mw.pos:], data)
mw.pos += n
data = data[n:]
err = mw.flushFrame(true, data)
return err
}
if _, err = w.Write(data); err != nil {
return err
}
return w.Close()
}
// SetWriteDeadline sets the write deadline on the underlying network
// connection. After a write has timed out, the websocket state is corrupt and
// all future writes will return an error. A zero value for t means writes will
// not time out.
func (c *Conn) SetWriteDeadline(t time.Time) error {
c.writeDeadline = t
return nil
}
// Read methods
func (c *Conn) advanceFrame() (int, error) {
// 1. Skip remainder of previous frame.
if c.readRemaining > 0 {
if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil {
return noFrame, err
}
}
// 2. Read and parse first two bytes of frame header.
p, err := c.read(2)
if err != nil {
return noFrame, err
}
final := p[0]&finalBit != 0
frameType := int(p[0] & 0xf)
mask := p[1]&maskBit != 0
c.readRemaining = int64(p[1] & 0x7f)
c.readDecompress = false
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
c.readDecompress = true
p[0] &^= rsv1Bit
}
if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
}
switch frameType {
case CloseMessage, PingMessage, PongMessage:
if c.readRemaining > maxControlFramePayloadSize {
return noFrame, c.handleProtocolError("control frame length > 125")
}
if !final {
return noFrame, c.handleProtocolError("control frame not final")
}
case TextMessage, BinaryMessage:
if !c.readFinal {
return noFrame, c.handleProtocolError("message start before final message frame")
}
c.readFinal = final
case continuationFrame:
if c.readFinal {
return noFrame, c.handleProtocolError("continuation after final message frame")
}
c.readFinal = final
default:
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
}
// 3. Read and parse frame length.
switch c.readRemaining {
case 126:
p, err := c.read(2)
if err != nil {
return noFrame, err
}
c.readRemaining = int64(binary.BigEndian.Uint16(p))
case 127:
p, err := c.read(8)
if err != nil {
return noFrame, err
}
c.readRemaining = int64(binary.BigEndian.Uint64(p))
}
// 4. Handle frame masking.
if mask != c.isServer {
return noFrame, c.handleProtocolError("incorrect mask flag")
}
if mask {
c.readMaskPos = 0
p, err := c.read(len(c.readMaskKey))
if err != nil {
return noFrame, err
}
copy(c.readMaskKey[:], p)
}
// 5. For text and binary messages, enforce read limit and return.
if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
c.readLength += c.readRemaining
if c.readLimit > 0 && c.readLength > c.readLimit {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
return noFrame, ErrReadLimit
}
return frameType, nil
}
// 6. Read control frame payload.
var payload []byte
if c.readRemaining > 0 {
payload, err = c.read(int(c.readRemaining))
c.readRemaining = 0
if err != nil {
return noFrame, err
}
if c.isServer {
maskBytes(c.readMaskKey, 0, payload)
}
}
// 7. Process control frame payload.
switch frameType {
case PongMessage:
if err := c.handlePong(string(payload)); err != nil {
return noFrame, err
}
case PingMessage:
if err := c.handlePing(string(payload)); err != nil {
return noFrame, err
}
case CloseMessage:
closeCode := CloseNoStatusReceived
closeText := ""
if len(payload) >= 2 {
closeCode = int(binary.BigEndian.Uint16(payload))
if !isValidReceivedCloseCode(closeCode) {
return noFrame, c.handleProtocolError("invalid close code")
}
closeText = string(payload[2:])
if !utf8.ValidString(closeText) {
return noFrame, c.handleProtocolError("invalid utf8 payload in close frame")
}
}
if err := c.handleClose(closeCode, closeText); err != nil {
return noFrame, err
}
return noFrame, &CloseError{Code: closeCode, Text: closeText}
}
return frameType, nil
}
func (c *Conn) handleProtocolError(message string) error {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
return errors.New("websocket: " + message)
}
// NextReader returns the next data message received from the peer. The
// returned messageType is either TextMessage or BinaryMessage.
//
// There can be at most one open reader on a connection. NextReader discards
// the previous message if the application has not already consumed it.
//
// Applications must break out of the application's read loop when this method
// returns a non-nil error value. Errors returned from this method are
// permanent. Once this method returns a non-nil error, all subsequent calls to
// this method return the same error.
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
c.messageReader = nil
c.readLength = 0
for c.readErr == nil {
frameType, err := c.advanceFrame()
if err != nil {
c.readErr = hideTempErr(err)
break
}
if frameType == TextMessage || frameType == BinaryMessage {
c.messageReader = &messageReader{c}
var r io.Reader = c.messageReader
if c.readDecompress {
r = c.newDecompressionReader(r)
}
return frameType, r, nil
}
}
// Applications that do handle the error returned from this method spin in
// tight loop on connection failure. To help application developers detect
// this error, panic on repeated reads to the failed connection.
c.readErrCount++
if c.readErrCount >= 1000 {
panic("repeated read on failed websocket connection")
}
return noFrame, nil, c.readErr
}
type messageReader struct{ c *Conn }
func (r *messageReader) Read(b []byte) (int, error) {
c := r.c
if c.messageReader != r {
return 0, io.EOF
}
for c.readErr == nil {
if c.readRemaining > 0 {
if int64(len(b)) > c.readRemaining {
b = b[:c.readRemaining]
}
n, err := c.br.Read(b)
c.readErr = hideTempErr(err)
if c.isServer {
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
}
c.readRemaining -= int64(n)
if c.readRemaining > 0 && c.readErr == io.EOF {
c.readErr = errUnexpectedEOF
}
return n, c.readErr
}
if c.readFinal {
c.messageReader = nil
return 0, io.EOF
}
frameType, err := c.advanceFrame()
switch {
case err != nil:
c.readErr = hideTempErr(err)
case frameType == TextMessage || frameType == BinaryMessage:
c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
}
}
err := c.readErr
if err == io.EOF && c.messageReader == r {
err = errUnexpectedEOF
}
return 0, err
}
// ReadMessage is a helper method for getting a reader using NextReader and
// reading from that reader to a buffer.
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
var r io.Reader
messageType, r, err = c.NextReader()
if err != nil {
return messageType, nil, err
}
p, err = ioutil.ReadAll(r)
return messageType, p, err
}
// SetReadDeadline sets the read deadline on the underlying network connection.
// After a read has timed out, the websocket connection state is corrupt and
// all future reads will return an error. A zero value for t means reads will
// not time out.
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
// SetReadLimit sets the maximum size for a message read from the peer. If a
// message exceeds the limit, the connection sends a close frame to the peer
// and returns ErrReadLimit to the application.
func (c *Conn) SetReadLimit(limit int64) {
c.readLimit = limit
}
// CloseHandler returns the current close handler
func (c *Conn) CloseHandler() func(code int, text string) error {
return c.handleClose
}
// SetCloseHandler sets the handler for close messages received from the peer.
// The code argument to h is the received close code or CloseNoStatusReceived
// if the close message is empty. The default close handler sends a close frame
// back to the peer.
func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
if h == nil {
h = func(code int, text string) error {
message := []byte{}
if code != CloseNoStatusReceived {
message = FormatCloseMessage(code, "")
}
c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
return nil
}
}
c.handleClose = h
}
// PingHandler returns the current ping handler
func (c *Conn) PingHandler() func(appData string) error {
return c.handlePing
}
// SetPingHandler sets the handler for ping messages received from the peer.
// The appData argument to h is the PING frame application data. The default
// ping handler sends a pong to the peer.
func (c *Conn) SetPingHandler(h func(appData string) error) {
if h == nil {
h = func(message string) error {
err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
if err == ErrCloseSent {
return nil
} else if e, ok := err.(net.Error); ok && e.Temporary() {
return nil
}
return err
}
}
c.handlePing = h
}
// PongHandler returns the current pong handler
func (c *Conn) PongHandler() func(appData string) error {
return c.handlePong
}
// SetPongHandler sets the handler for pong messages received from the peer.
// The appData argument to h is the PONG frame application data. The default
// pong handler does nothing.
func (c *Conn) SetPongHandler(h func(appData string) error) {
if h == nil {
h = func(string) error { return nil }
}
c.handlePong = h
}
// UnderlyingConn returns the internal net.Conn. This can be used to further
// modifications to connection specific flags.
func (c *Conn) UnderlyingConn() net.Conn {
return c.conn
}
// EnableWriteCompression enables and disables write compression of
// subsequent text and binary messages. This function is a noop if
// compression was not negotiated with the peer.
func (c *Conn) EnableWriteCompression(enable bool) {
c.enableWriteCompression = enable
}
// FormatCloseMessage formats closeCode and text as a WebSocket close message.
func FormatCloseMessage(closeCode int, text string) []byte {
buf := make([]byte, 2+len(text))
binary.BigEndian.PutUint16(buf, uint16(closeCode))
copy(buf[2:], text)
return buf
}
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.5
package websocket
import "io"
func (c *Conn) read(n int) ([]byte, error) {
p, err := c.br.Peek(n)
if err == io.EOF {
err = errUnexpectedEOF
}
c.br.Discard(len(p))
return p, err
}
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.5
package websocket
import "io"
func (c *Conn) read(n int) ([]byte, error) {
p, err := c.br.Peek(n)
if err == io.EOF {
err = errUnexpectedEOF
}
if len(p) > 0 {
// advance over the bytes just read
io.ReadFull(c.br, p)
}
return p, err
}
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package websocket implements the WebSocket protocol defined in RFC 6455.
//
// Overview
//
// The Conn type represents a WebSocket connection. A server application uses
// the Upgrade function from an Upgrader object with a HTTP request handler
// to get a pointer to a Conn:
//
// var upgrader = websocket.Upgrader{
// ReadBufferSize: 1024,
// WriteBufferSize: 1024,
// }
//
// func handler(w http.ResponseWriter, r *http.Request) {
// conn, err := upgrader.Upgrade(w, r, nil)
// if err != nil {
// log.Println(err)
// return
// }
// ... Use conn to send and receive messages.
// }
//
// Call the connection's WriteMessage and ReadMessage methods to send and
// receive messages as a slice of bytes. This snippet of code shows how to echo
// messages using these methods:
//
// for {
// messageType, p, err := conn.ReadMessage()
// if err != nil {
// return
// }
// if err = conn.WriteMessage(messageType, p); err != nil {
// return err
// }
// }
//
// In above snippet of code, p is a []byte and messageType is an int with value
// websocket.BinaryMessage or websocket.TextMessage.
//
// An application can also send and receive messages using the io.WriteCloser
// and io.Reader interfaces. To send a message, call the connection NextWriter
// method to get an io.WriteCloser, write the message to the writer and close
// the writer when done. To receive a message, call the connection NextReader
// method to get an io.Reader and read until io.EOF is returned. This snippet
// shows how to echo messages using the NextWriter and NextReader methods:
//
// for {
// messageType, r, err := conn.NextReader()
// if err != nil {
// return
// }
// w, err := conn.NextWriter(messageType)
// if err != nil {
// return err
// }
// if _, err := io.Copy(w, r); err != nil {
// return err
// }
// if err := w.Close(); err != nil {
// return err
// }
// }
//
// Data Messages
//
// The WebSocket protocol distinguishes between text and binary data messages.
// Text messages are interpreted as UTF-8 encoded text. The interpretation of
// binary messages is left to the application.
//
// This package uses the TextMessage and BinaryMessage integer constants to
// identify the two data message types. The ReadMessage and NextReader methods
// return the type of the received message. The messageType argument to the
// WriteMessage and NextWriter methods specifies the type of a sent message.
//
// It is the application's responsibility to ensure that text messages are
// valid UTF-8 encoded text.
//
// Control Messages
//
// The WebSocket protocol defines three types of control messages: close, ping
// and pong. Call the connection WriteControl, WriteMessage or NextWriter
// methods to send a control message to the peer.
//
// Connections handle received close messages by sending a close message to the
// peer and returning a *CloseError from the the NextReader, ReadMessage or the
// message Read method.
//
// Connections handle received ping and pong messages by invoking callback
// functions set with SetPingHandler and SetPongHandler methods. The callback
// functions are called from the NextReader, ReadMessage and the message Read
// methods.
//
// The default ping handler sends a pong to the peer. The application's reading
// goroutine can block for a short time while the handler writes the pong data
// to the connection.
//
// The application must read the connection to process ping, pong and close
// messages sent from the peer. If the application is not otherwise interested
// in messages from the peer, then the application should start a goroutine to
// read and discard messages from the peer. A simple example is:
//
// func readLoop(c *websocket.Conn) {
// for {
// if _, _, err := c.NextReader(); err != nil {
// c.Close()
// break
// }
// }
// }
//
// Concurrency
//
// Connections support one concurrent reader and one concurrent writer.
//
// Applications are responsible for ensuring that no more than one goroutine
// calls the write methods (NextWriter, SetWriteDeadline, WriteMessage,
// WriteJSON) concurrently and that no more than one goroutine calls the read
// methods (NextReader, SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler,
// SetPingHandler) concurrently.
//
// The Close and WriteControl methods can be called concurrently with all other
// methods.
//
// Origin Considerations
//
// Web browsers allow Javascript applications to open a WebSocket connection to
// any host. It's up to the server to enforce an origin policy using the Origin
// request header sent by the browser.
//
// The Upgrader calls the function specified in the CheckOrigin field to check
// the origin. If the CheckOrigin function returns false, then the Upgrade
// method fails the WebSocket handshake with HTTP status 403.
//
// If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail
// the handshake if the Origin request header is present and not equal to the
// Host request header.
//
// An application can allow connections from any origin by specifying a
// function that always returns true:
//
// var upgrader = websocket.Upgrader{
// CheckOrigin: func(r *http.Request) bool { return true },
// }
//
// The deprecated Upgrade function does not enforce an origin policy. It's the
// application's responsibility to check the Origin header before calling
// Upgrade.
//
// Compression [Experimental]
//
// Per message compression extensions (RFC 7692) are experimentally supported
// by this package in a limited capacity. Setting the EnableCompression option
// to true in Dialer or Upgrader will attempt to negotiate per message deflate
// support. If compression was successfully negotiated with the connection's
// peer, any message received in compressed form will be automatically
// decompressed. All Read methods will return uncompressed bytes.
//
// Per message compression of messages written to a connection can be enabled
// or disabled by calling the corresponding Conn method:
//
// conn.EnableWriteCompression(true)
//
// Currently this package does not support compression with "context takeover".
// This means that messages must be compressed and decompressed in isolation,
// without retaining sliding window or dictionary state across messages. For
// more details refer to RFC 7692.
//
// Use of compression is experimental and may result in decreased performance.
package websocket
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"encoding/json"
"io"
)
// WriteJSON is deprecated, use c.WriteJSON instead.
func WriteJSON(c *Conn, v interface{}) error {
return c.WriteJSON(v)
}
// WriteJSON writes the JSON encoding of v to the connection.
//
// See the documentation for encoding/json Marshal for details about the
// conversion of Go values to JSON.
func (c *Conn) WriteJSON(v interface{}) error {
w, err := c.NextWriter(TextMessage)
if err != nil {
return err
}
err1 := json.NewEncoder(w).Encode(v)
err2 := w.Close()
if err1 != nil {
return err1
}
return err2
}
// ReadJSON is deprecated, use c.ReadJSON instead.
func ReadJSON(c *Conn, v interface{}) error {
return c.ReadJSON(v)
}
// ReadJSON reads the next JSON-encoded message from the connection and stores
// it in the value pointed to by v.
//
// See the documentation for the encoding/json Unmarshal function for details
// about the conversion of JSON to a Go value.
func (c *Conn) ReadJSON(v interface{}) error {
_, r, err := c.NextReader()
if err != nil {
return err
}
err = json.NewDecoder(r).Decode(v)
if err == io.EOF {
// One value is expected in the message.
err = io.ErrUnexpectedEOF
}
return err
}
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of
// this source code is governed by a BSD-style license that can be found in the
// LICENSE file.
package websocket
import (
"math/rand"
"unsafe"
)
const wordSize = int(unsafe.Sizeof(uintptr(0)))
func newMaskKey() [4]byte {
n := rand.Uint32()
return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}
}
func maskBytes(key [4]byte, pos int, b []byte) int {
// Mask one byte at a time for small buffers.
if len(b) < 2*wordSize {
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}
// Mask one byte at a time to word boundary.
if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 {
n = wordSize - n
for i := range b[:n] {
b[i] ^= key[pos&3]
pos++
}
b = b[n:]
}
// Create aligned word size key.
var k [wordSize]byte
for i := range k {
k[i] = key[(pos+i)&3]
}
kw := *(*uintptr)(unsafe.Pointer(&k))
// Mask one word at a time.
n := (len(b) / wordSize) * wordSize
for i := 0; i < n; i += wordSize {
*(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw
}
// Mask one byte at a time for remaining bytes.
b = b[n:]
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"errors"
"net"
"net/http"
"net/url"
"strings"
"time"
)
// HandshakeError describes an error with the handshake from the peer.
type HandshakeError struct {
message string
}
func (e HandshakeError) Error() string { return e.message }
// Upgrader specifies parameters for upgrading an HTTP connection to a
// WebSocket connection.
type Upgrader struct {
// HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
// size is zero, then a default value of 4096 is used. The I/O buffer sizes
// do not limit the size of the messages that can be sent or received.
ReadBufferSize, WriteBufferSize int
// Subprotocols specifies the server's supported protocols in order of
// preference. If this field is set, then the Upgrade method negotiates a
// subprotocol by selecting the first match in this list with a protocol
// requested by the client.
Subprotocols []string
// Error specifies the function for generating HTTP error responses. If Error
// is nil, then http.Error is used to generate the HTTP response.
Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
// CheckOrigin returns true if the request Origin header is acceptable. If
// CheckOrigin is nil, the host in the Origin header must not be set or
// must match the host of the request.
CheckOrigin func(r *http.Request) bool
// EnableCompression specify if the server should attempt to negotiate per
// message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool
}
func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
err := HandshakeError{reason}
if u.Error != nil {
u.Error(w, r, status, err)
} else {
w.Header().Set("Sec-Websocket-Version", "13")
http.Error(w, http.StatusText(status), status)
}
return nil, err
}
// checkSameOrigin returns true if the origin is not set or is equal to the request host.
func checkSameOrigin(r *http.Request) bool {
origin := r.Header["Origin"]
if len(origin) == 0 {
return true
}
u, err := url.Parse(origin[0])
if err != nil {
return false
}
return u.Host == r.Host
}
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
if u.Subprotocols != nil {
clientProtocols := Subprotocols(r)
for _, serverProtocol := range u.Subprotocols {
for _, clientProtocol := range clientProtocols {
if clientProtocol == serverProtocol {
return clientProtocol
}
}
}
} else if responseHeader != nil {
return responseHeader.Get("Sec-Websocket-Protocol")
}
return ""
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// application negotiated subprotocol (Sec-Websocket-Protocol).
//
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response.
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
if r.Method != "GET" {
return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET")
}
if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific Sec-Websocket-Extensions headers are unsupported")
}
if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13")
}
if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find connection header with token 'upgrade'")
}
if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find upgrade header with token 'websocket'")
}
checkOrigin := u.CheckOrigin
if checkOrigin == nil {
checkOrigin = checkSameOrigin
}
if !checkOrigin(r) {
return u.returnError(w, r, http.StatusForbidden, "websocket: origin not allowed")
}
challengeKey := r.Header.Get("Sec-Websocket-Key")
if challengeKey == "" {
return u.returnError(w, r, http.StatusBadRequest, "websocket: key missing or blank")
}
subprotocol := u.selectSubprotocol(r, responseHeader)
// Negotiate PMCE
var compress bool
if u.EnableCompression {
for _, ext := range parseExtensions(r.Header) {
if ext[""] != "permessage-deflate" {
continue
}
compress = true
break
}
}
var (
netConn net.Conn
br *bufio.Reader
err error
)
h, ok := w.(http.Hijacker)
if !ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
}
var rw *bufio.ReadWriter
netConn, rw, err = h.Hijack()
if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
}
br = rw.Reader
if br.Buffered() > 0 {
netConn.Close()
return nil, errors.New("websocket: client sent data before handshake is complete")
}
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
c.subprotocol = subprotocol
if compress {
c.newCompressionWriter = compressNoContextTakeover
c.newDecompressionReader = decompressNoContextTakeover
}
p := c.writeBuf[:0]
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
p = append(p, computeAcceptKey(challengeKey)...)
p = append(p, "\r\n"...)
if c.subprotocol != "" {
p = append(p, "Sec-Websocket-Protocol: "...)
p = append(p, c.subprotocol...)
p = append(p, "\r\n"...)
}
if compress {
p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
}
for k, vs := range responseHeader {
if k == "Sec-Websocket-Protocol" {
continue
}
for _, v := range vs {
p = append(p, k...)
p = append(p, ": "...)
for i := 0; i < len(v); i++ {
b := v[i]
if b <= 31 {
// prevent response splitting.
b = ' '
}
p = append(p, b)
}
p = append(p, "\r\n"...)
}
}
p = append(p, "\r\n"...)
// Clear deadlines set by HTTP server.
netConn.SetDeadline(time.Time{})
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
}
if _, err = netConn.Write(p); err != nil {
netConn.Close()
return nil, err
}
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Time{})
}
return c, nil
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// This function is deprecated, use websocket.Upgrader instead.
//
// The application is responsible for checking the request origin before
// calling Upgrade. An example implementation of the same origin policy is:
//
// if req.Header.Get("Origin") != "http://"+req.Host {
// http.Error(w, "Origin not allowed", 403)
// return
// }
//
// If the endpoint supports subprotocols, then the application is responsible
// for negotiating the protocol used on the connection. Use the Subprotocols()
// function to get the subprotocols requested by the client. Use the
// Sec-Websocket-Protocol response header to specify the subprotocol selected
// by the application.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// negotiated subprotocol (Sec-Websocket-Protocol).
//
// The connection buffers IO to the underlying network connection. The
// readBufSize and writeBufSize parameters specify the size of the buffers to
// use. Messages can be larger than the buffers.
//
// If the request is not a valid WebSocket handshake, then Upgrade returns an
// error of type HandshakeError. Applications should handle this error by
// replying to the client with an HTTP error response.
func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) {
u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize}
u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) {
// don't return errors to maintain backwards compatibility
}
u.CheckOrigin = func(r *http.Request) bool {
// allow all connections by default
return true
}
return u.Upgrade(w, r, responseHeader)
}
// Subprotocols returns the subprotocols requested by the client in the
// Sec-Websocket-Protocol header.
func Subprotocols(r *http.Request) []string {
h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol"))
if h == "" {
return nil
}
protocols := strings.Split(h, ",")
for i := range protocols {
protocols[i] = strings.TrimSpace(protocols[i])
}
return protocols
}
// IsWebSocketUpgrade returns true if the client requested upgrade to the
// WebSocket protocol.
func IsWebSocketUpgrade(r *http.Request) bool {
return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
tokenListContainsValue(r.Header, "Upgrade", "websocket")
}
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"io"
"net/http"
"strings"
)
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func computeAcceptKey(challengeKey string) string {
h := sha1.New()
h.Write([]byte(challengeKey))
h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}
func generateChallengeKey() (string, error) {
p := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, p); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(p), nil
}
// Octet types from RFC 2616.
var octetTypes [256]byte
const (
isTokenOctet = 1 << iota
isSpaceOctet
)
func init() {
// From RFC 2616
//
// OCTET = <any 8-bit sequence of data>
// CHAR = <any US-ASCII character (octets 0 - 127)>
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
// CR = <US-ASCII CR, carriage return (13)>
// LF = <US-ASCII LF, linefeed (10)>
// SP = <US-ASCII SP, space (32)>
// HT = <US-ASCII HT, horizontal-tab (9)>
// <"> = <US-ASCII double-quote mark (34)>
// CRLF = CR LF
// LWS = [CRLF] 1*( SP | HT )
// TEXT = <any OCTET except CTLs, but including LWS>
// separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <">
// | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT
// token = 1*<any CHAR except CTLs or separators>
// qdtext = <any TEXT except <">>
for c := 0; c < 256; c++ {
var t byte
isCtl := c <= 31 || c == 127
isChar := 0 <= c && c <= 127
isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0
if strings.IndexRune(" \t\r\n", rune(c)) >= 0 {
t |= isSpaceOctet
}
if isChar && !isCtl && !isSeparator {
t |= isTokenOctet
}
octetTypes[c] = t
}
}
func skipSpace(s string) (rest string) {
i := 0
for ; i < len(s); i++ {
if octetTypes[s[i]]&isSpaceOctet == 0 {
break
}
}
return s[i:]
}
func nextToken(s string) (token, rest string) {
i := 0
for ; i < len(s); i++ {
if octetTypes[s[i]]&isTokenOctet == 0 {
break
}
}
return s[:i], s[i:]
}
func nextTokenOrQuoted(s string) (value string, rest string) {
if !strings.HasPrefix(s, "\"") {
return nextToken(s)
}
s = s[1:]
for i := 0; i < len(s); i++ {
switch s[i] {
case '"':
return s[:i], s[i+1:]
case '\\':
p := make([]byte, len(s)-1)
j := copy(p, s[:i])
escape := true
for i = i + 1; i < len(s); i++ {
b := s[i]
switch {
case escape:
escape = false
p[j] = b
j += 1
case b == '\\':
escape = true
case b == '"':
return string(p[:j]), s[i+1:]
default:
p[j] = b
j += 1
}
}
return "", ""
}
}
return "", ""
}
// tokenListContainsValue returns true if the 1#token header with the given
// name contains token.
func tokenListContainsValue(header http.Header, name string, value string) bool {
headers:
for _, s := range header[name] {
for {
var t string
t, s = nextToken(skipSpace(s))
if t == "" {
continue headers
}
s = skipSpace(s)
if s != "" && s[0] != ',' {
continue headers
}
if strings.EqualFold(t, value) {
return true
}
if s == "" {
continue headers
}
s = s[1:]
}
}
return false
}
// parseExtensiosn parses WebSocket extensions from a header.
func parseExtensions(header http.Header) []map[string]string {
// From RFC 6455:
//
// Sec-WebSocket-Extensions = extension-list
// extension-list = 1#extension
// extension = extension-token *( ";" extension-param )
// extension-token = registered-token
// registered-token = token
// extension-param = token [ "=" (token | quoted-string) ]
// ;When using the quoted-string syntax variant, the value
// ;after quoted-string unescaping MUST conform to the
// ;'token' ABNF.
var result []map[string]string
headers:
for _, s := range header["Sec-Websocket-Extensions"] {
for {
var t string
t, s = nextToken(skipSpace(s))
if t == "" {
continue headers
}
ext := map[string]string{"": t}
for {
s = skipSpace(s)
if !strings.HasPrefix(s, ";") {
break
}
var k string
k, s = nextToken(skipSpace(s[1:]))
if k == "" {
continue headers
}
s = skipSpace(s)
var v string
if strings.HasPrefix(s, "=") {
v, s = nextTokenOrQuoted(skipSpace(s[1:]))
s = skipSpace(s)
}
if s != "" && s[0] != ',' && s[0] != ';' {
continue headers
}
ext[k] = v
}
if s != "" && s[0] != ',' {
continue headers
}
result = append(result, ext)
if s == "" {
continue headers
}
s = s[1:]
}
}
return result
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment