//go:build !js
// +build !js
package websocket
import (
"bytes"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/textproto"
"net/url"
"path"
"strings"
"github.com/coder/websocket/internal/errd"
)
// AcceptOptions represents Accept's options.
type AcceptOptions struct {
// Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client.
// The empty subprotocol will always be negotiated as per RFC 6455. If you would like to
// reject it, close the connection when c.Subprotocol() == "".
Subprotocols []string
// InsecureSkipVerify is used to disable Accept's origin verification behaviour.
//
// You probably want to use OriginPatterns instead.
InsecureSkipVerify bool
// OriginPatterns lists the host patterns for authorized origins.
// The request host is always authorized.
// Use this to enable cross origin WebSockets.
//
// i.e javascript running on example.com wants to access a WebSocket server at chat.example.com.
// In such a case, example.com is the origin and chat.example.com is the request host.
// One would set this field to []string{"example.com"} to authorize example.com to connect.
//
// Each pattern is matched case insensitively against the request origin host
// with path.Match.
// See https://golang.org/pkg/path/#Match
//
// Please ensure you understand the ramifications of enabling this.
// If used incorrectly your WebSocket server will be open to CSRF attacks.
//
// Do not use * as a pattern to allow any origin, prefer to use InsecureSkipVerify instead
// to bring attention to the danger of such a setting.
OriginPatterns []string
// CompressionMode controls the compression mode.
// Defaults to CompressionDisabled.
//
// See docs on CompressionMode for details.
CompressionMode CompressionMode
// CompressionThreshold controls the minimum size of a message before compression is applied.
//
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
// for CompressionContextTakeover.
CompressionThreshold int
}
func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
var o AcceptOptions
if opts != nil {
o = *opts
}
return &o
}
// Accept accepts a WebSocket handshake from a client and upgrades the
// the connection to a WebSocket.
//
// Accept will not allow cross origin requests by default.
// See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests.
//
// Accept will write a response to w on all errors.
//
// Note that using the http.Request Context after Accept returns may lead to
// unexpected behavior (see http.Hijacker).
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
return accept(w, r, opts)
}
func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
defer errd.Wrap(&err, "failed to accept WebSocket connection")
errCode, err := verifyClientRequest(w, r)
if err != nil {
http.Error(w, err.Error(), errCode)
return nil, err
}
opts = opts.cloneWithDefaults()
if !opts.InsecureSkipVerify {
err = authenticateOrigin(r, opts.OriginPatterns)
if err != nil {
if errors.Is(err, path.ErrBadPattern) {
log.Printf("websocket: %v", err)
err = errors.New(http.StatusText(http.StatusForbidden))
}
http.Error(w, err.Error(), http.StatusForbidden)
return nil, err
}
}
hj, ok := hijacker(w)
if !ok {
err = errors.New("http.ResponseWriter does not implement http.Hijacker")
http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
return nil, err
}
w.Header().Set("Upgrade", "websocket")
w.Header().Set("Connection", "Upgrade")
key := r.Header.Get("Sec-WebSocket-Key")
w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
subproto := selectSubprotocol(r, opts.Subprotocols)
if subproto != "" {
w.Header().Set("Sec-WebSocket-Protocol", subproto)
}
copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode)
if ok {
w.Header().Set("Sec-WebSocket-Extensions", copts.String())
}
w.WriteHeader(http.StatusSwitchingProtocols)
// See https://github.com/nhooyr/websocket/issues/166
if ginWriter, ok := w.(interface {
WriteHeaderNow()
}); ok {
ginWriter.WriteHeaderNow()
}
netConn, brw, err := hj.Hijack()
if err != nil {
err = fmt.Errorf("failed to hijack connection: %w", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return nil, err
}
// https://github.com/golang/go/issues/32314
b, _ := brw.Reader.Peek(brw.Reader.Buffered())
brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
return newConn(connConfig{
subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
rwc: netConn,
client: false,
copts: copts,
flateThreshold: opts.CompressionThreshold,
br: brw.Reader,
bw: brw.Writer,
}), nil
}
func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
if !r.ProtoAtLeast(1, 1) {
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
}
if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
}
if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
}
if r.Method != "GET" {
return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method)
}
if r.Header.Get("Sec-WebSocket-Version") != "13" {
w.Header().Set("Sec-WebSocket-Version", "13")
return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
}
websocketSecKeys := r.Header.Values("Sec-WebSocket-Key")
if len(websocketSecKeys) == 0 {
return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
}
if len(websocketSecKeys) > 1 {
return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers")
}
// The RFC states to remove any leading or trailing whitespace.
websocketSecKey := strings.TrimSpace(websocketSecKeys[0])
if v, err := base64.StdEncoding.DecodeString(websocketSecKey); err != nil || len(v) != 16 {
return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", websocketSecKey)
}
return 0, nil
}
func authenticateOrigin(r *http.Request, originHosts []string) error {
origin := r.Header.Get("Origin")
if origin == "" {
return nil
}
u, err := url.Parse(origin)
if err != nil {
return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
}
if strings.EqualFold(r.Host, u.Host) {
return nil
}
for _, hostPattern := range originHosts {
matched, err := match(hostPattern, u.Host)
if err != nil {
return fmt.Errorf("failed to parse path pattern %q: %w", hostPattern, err)
}
if matched {
return nil
}
}
if u.Host == "" {
return fmt.Errorf("request Origin %q is not a valid URL with a host", origin)
}
return fmt.Errorf("request Origin %q is not authorized for Host %q", u.Host, r.Host)
}
func match(pattern, s string) (bool, error) {
return path.Match(strings.ToLower(pattern), strings.ToLower(s))
}
func selectSubprotocol(r *http.Request, subprotocols []string) string {
cps := headerTokens(r.Header, "Sec-WebSocket-Protocol")
for _, sp := range subprotocols {
for _, cp := range cps {
if strings.EqualFold(sp, cp) {
return cp
}
}
}
return ""
}
func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
if mode == CompressionDisabled {
return nil, false
}
for _, ext := range extensions {
switch ext.name {
// We used to implement x-webkit-deflate-frame too for Safari but Safari has bugs...
// See https://github.com/nhooyr/websocket/issues/218
case "permessage-deflate":
copts, ok := acceptDeflate(ext, mode)
if ok {
return copts, true
}
}
}
return nil, false
}
func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
copts := mode.opts()
for _, p := range ext.params {
switch p {
case "client_no_context_takeover":
copts.clientNoContextTakeover = true
continue
case "server_no_context_takeover":
copts.serverNoContextTakeover = true
continue
case "client_max_window_bits",
"server_max_window_bits=15":
continue
}
if strings.HasPrefix(p, "client_max_window_bits=") {
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
continue
}
return nil, false
}
return copts, true
}
func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
for _, t := range headerTokens(h, key) {
if strings.EqualFold(t, token) {
return true
}
}
return false
}
type websocketExtension struct {
name string
params []string
}
func websocketExtensions(h http.Header) []websocketExtension {
var exts []websocketExtension
extStrs := headerTokens(h, "Sec-WebSocket-Extensions")
for _, extStr := range extStrs {
if extStr == "" {
continue
}
vals := strings.Split(extStr, ";")
for i := range vals {
vals[i] = strings.TrimSpace(vals[i])
}
e := websocketExtension{
name: vals[0],
params: vals[1:],
}
exts = append(exts, e)
}
return exts
}
func headerTokens(h http.Header, key string) []string {
key = textproto.CanonicalMIMEHeaderKey(key)
var tokens []string
for _, v := range h[key] {
v = strings.TrimSpace(v)
for _, t := range strings.Split(v, ",") {
t = strings.TrimSpace(t)
tokens = append(tokens, t)
}
}
return tokens
}
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func secWebSocketAccept(secWebSocketKey string) string {
h := sha1.New()
h.Write([]byte(secWebSocketKey))
h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}
//go:build !js
// +build !js
package websocket
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"time"
"github.com/coder/websocket/internal/errd"
)
// StatusCode represents a WebSocket status code.
// https://tools.ietf.org/html/rfc6455#section-7.4
type StatusCode int
// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
//
// These are only the status codes defined by the protocol.
//
// You can define custom codes in the 3000-4999 range.
// The 3000-3999 range is reserved for use by libraries, frameworks and applications.
// The 4000-4999 range is reserved for private use.
const (
StatusNormalClosure StatusCode = 1000
StatusGoingAway StatusCode = 1001
StatusProtocolError StatusCode = 1002
StatusUnsupportedData StatusCode = 1003
// 1004 is reserved and so unexported.
statusReserved StatusCode = 1004
// StatusNoStatusRcvd cannot be sent in a close message.
// It is reserved for when a close message is received without
// a status code.
StatusNoStatusRcvd StatusCode = 1005
// StatusAbnormalClosure is exported for use only with Wasm.
// In non Wasm Go, the returned error will indicate whether the
// connection was closed abnormally.
StatusAbnormalClosure StatusCode = 1006
StatusInvalidFramePayloadData StatusCode = 1007
StatusPolicyViolation StatusCode = 1008
StatusMessageTooBig StatusCode = 1009
StatusMandatoryExtension StatusCode = 1010
StatusInternalError StatusCode = 1011
StatusServiceRestart StatusCode = 1012
StatusTryAgainLater StatusCode = 1013
StatusBadGateway StatusCode = 1014
// StatusTLSHandshake is only exported for use with Wasm.
// In non Wasm Go, the returned error will indicate whether there was
// a TLS handshake failure.
StatusTLSHandshake StatusCode = 1015
)
// CloseError is returned when the connection is closed with a status and reason.
//
// Use Go 1.13's errors.As to check for this error.
// Also see the CloseStatus helper.
type CloseError struct {
Code StatusCode
Reason string
}
func (ce CloseError) Error() string {
return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
}
// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
// the status code from a CloseError.
//
// -1 will be returned if the passed error is nil or not a CloseError.
func CloseStatus(err error) StatusCode {
var ce CloseError
if errors.As(err, &ce) {
return ce.Code
}
return -1
}
// Close performs the WebSocket close handshake with the given status code and reason.
//
// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for
// the peer to send a close frame.
// All data messages received from the peer during the close handshake will be discarded.
//
// The connection can only be closed once. Additional calls to Close
// are no-ops.
//
// The maximum length of reason must be 125 bytes. Avoid sending a dynamic reason.
//
// Close will unblock all goroutines interacting with the connection once
// complete.
func (c *Conn) Close(code StatusCode, reason string) (err error) {
defer errd.Wrap(&err, "failed to close WebSocket")
if !c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
}
return net.ErrClosed
}
defer func() {
if errors.Is(err, net.ErrClosed) {
err = nil
}
}()
err = c.closeHandshake(code, reason)
err2 := c.close()
if err == nil && err2 != nil {
err = err2
}
err2 = c.waitGoroutines()
if err == nil && err2 != nil {
err = err2
}
return err
}
// CloseNow closes the WebSocket connection without attempting a close handshake.
// Use when you do not want the overhead of the close handshake.
func (c *Conn) CloseNow() (err error) {
defer errd.Wrap(&err, "failed to immediately close WebSocket")
if !c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
}
return net.ErrClosed
}
defer func() {
if errors.Is(err, net.ErrClosed) {
err = nil
}
}()
err = c.close()
err2 := c.waitGoroutines()
if err == nil && err2 != nil {
err = err2
}
return err
}
func (c *Conn) closeHandshake(code StatusCode, reason string) error {
err := c.writeClose(code, reason)
if err != nil {
return err
}
err = c.waitCloseHandshake()
if CloseStatus(err) != code {
return err
}
return nil
}
func (c *Conn) writeClose(code StatusCode, reason string) error {
ce := CloseError{
Code: code,
Reason: reason,
}
var p []byte
var err error
if ce.Code != StatusNoStatusRcvd {
p, err = ce.bytes()
if err != nil {
return err
}
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
err = c.writeControl(ctx, opClose, p)
// If the connection closed as we're writing we ignore the error as we might
// have written the close frame, the peer responded and then someone else read it
// and closed the connection.
if err != nil && !errors.Is(err, net.ErrClosed) {
return err
}
return nil
}
func (c *Conn) waitCloseHandshake() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
err := c.readMu.lock(ctx)
if err != nil {
return err
}
defer c.readMu.unlock()
for i := int64(0); i < c.msgReader.payloadLength; i++ {
_, err := c.br.ReadByte()
if err != nil {
return err
}
}
for {
h, err := c.readLoop(ctx)
if err != nil {
return err
}
for i := int64(0); i < h.payloadLength; i++ {
_, err := c.br.ReadByte()
if err != nil {
return err
}
}
}
}
func (c *Conn) waitGoroutines() error {
t := time.NewTimer(time.Second * 15)
defer t.Stop()
select {
case <-c.timeoutLoopDone:
case <-t.C:
return errors.New("failed to wait for timeoutLoop goroutine to exit")
}
c.closeReadMu.Lock()
closeRead := c.closeReadCtx != nil
c.closeReadMu.Unlock()
if closeRead {
select {
case <-c.closeReadDone:
case <-t.C:
return errors.New("failed to wait for close read goroutine to exit")
}
}
select {
case <-c.closed:
case <-t.C:
return errors.New("failed to wait for connection to be closed")
}
return nil
}
func parseClosePayload(p []byte) (CloseError, error) {
if len(p) == 0 {
return CloseError{
Code: StatusNoStatusRcvd,
}, nil
}
if len(p) < 2 {
return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
}
ce := CloseError{
Code: StatusCode(binary.BigEndian.Uint16(p)),
Reason: string(p[2:]),
}
if !validWireCloseCode(ce.Code) {
return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code)
}
return ce, nil
}
// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
// and https://tools.ietf.org/html/rfc6455#section-7.4.1
func validWireCloseCode(code StatusCode) bool {
switch code {
case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
return false
}
if code >= StatusNormalClosure && code <= StatusBadGateway {
return true
}
if code >= 3000 && code <= 4999 {
return true
}
return false
}
func (ce CloseError) bytes() ([]byte, error) {
p, err := ce.bytesErr()
if err != nil {
err = fmt.Errorf("failed to marshal close frame: %w", err)
ce = CloseError{
Code: StatusInternalError,
}
p, _ = ce.bytesErr()
}
return p, err
}
const maxCloseReason = maxControlPayload - 2
func (ce CloseError) bytesErr() ([]byte, error) {
if len(ce.Reason) > maxCloseReason {
return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason))
}
if !validWireCloseCode(ce.Code) {
return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
}
buf := make([]byte, 2+len(ce.Reason))
binary.BigEndian.PutUint16(buf, uint16(ce.Code))
copy(buf[2:], ce.Reason)
return buf, nil
}
func (c *Conn) casClosing() bool {
c.closeMu.Lock()
defer c.closeMu.Unlock()
if !c.closing {
c.closing = true
return true
}
return false
}
func (c *Conn) isClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
//go:build !js
// +build !js
package websocket
import (
"compress/flate"
"io"
"sync"
)
// CompressionMode represents the modes available to the permessage-deflate extension.
// See https://tools.ietf.org/html/rfc7692
//
// Works in all modern browsers except Safari which does not implement the permessage-deflate extension.
//
// Compression is only used if the peer supports the mode selected.
type CompressionMode int
const (
// CompressionDisabled disables the negotiation of the permessage-deflate extension.
//
// This is the default. Do not enable compression without benchmarking for your particular use case first.
CompressionDisabled CompressionMode = iota
// CompressionContextTakeover compresses each message greater than 128 bytes reusing the 32 KB sliding window from
// previous messages. i.e compression context across messages is preserved.
//
// As most WebSocket protocols are text based and repetitive, this compression mode can be very efficient.
//
// The memory overhead is a fixed 32 KB sliding window, a fixed 1.2 MB flate.Writer and a sync.Pool of 40 KB flate.Reader's
// that are used when reading and then returned.
//
// Thus, it uses more memory than CompressionNoContextTakeover but compresses more efficiently.
//
// If the peer does not support CompressionContextTakeover then we will fall back to CompressionNoContextTakeover.
CompressionContextTakeover
// CompressionNoContextTakeover compresses each message greater than 512 bytes. Each message is compressed with
// a new 1.2 MB flate.Writer pulled from a sync.Pool. Each message is read with a 40 KB flate.Reader pulled from
// a sync.Pool.
//
// This means less efficient compression as the sliding window from previous messages will not be used but the
// memory overhead will be lower as there will be no fixed cost for the flate.Writer nor the 32 KB sliding window.
// Especially if the connections are long lived and seldom written to.
//
// Thus, it uses less memory than CompressionContextTakeover but compresses less efficiently.
//
// If the peer does not support CompressionNoContextTakeover then we will fall back to CompressionDisabled.
CompressionNoContextTakeover
)
func (m CompressionMode) opts() *compressionOptions {
return &compressionOptions{
clientNoContextTakeover: m == CompressionNoContextTakeover,
serverNoContextTakeover: m == CompressionNoContextTakeover,
}
}
type compressionOptions struct {
clientNoContextTakeover bool
serverNoContextTakeover bool
}
func (copts *compressionOptions) String() string {
s := "permessage-deflate"
if copts.clientNoContextTakeover {
s += "; client_no_context_takeover"
}
if copts.serverNoContextTakeover {
s += "; server_no_context_takeover"
}
return s
}
// These bytes are required to get flate.Reader to return.
// They are removed when sending to avoid the overhead as
// WebSocket framing tell's when the message has ended but then
// we need to add them back otherwise flate.Reader keeps
// trying to read more bytes.
const deflateMessageTail = "\x00\x00\xff\xff"
type trimLastFourBytesWriter struct {
w io.Writer
tail []byte
}
func (tw *trimLastFourBytesWriter) reset() {
if tw != nil && tw.tail != nil {
tw.tail = tw.tail[:0]
}
}
func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) {
if tw.tail == nil {
tw.tail = make([]byte, 0, 4)
}
extra := len(tw.tail) + len(p) - 4
if extra <= 0 {
tw.tail = append(tw.tail, p...)
return len(p), nil
}
// Now we need to write as many extra bytes as we can from the previous tail.
if extra > len(tw.tail) {
extra = len(tw.tail)
}
if extra > 0 {
_, err := tw.w.Write(tw.tail[:extra])
if err != nil {
return 0, err
}
// Shift remaining bytes in tail over.
n := copy(tw.tail, tw.tail[extra:])
tw.tail = tw.tail[:n]
}
// If p is less than or equal to 4 bytes,
// all of it is is part of the tail.
if len(p) <= 4 {
tw.tail = append(tw.tail, p...)
return len(p), nil
}
// Otherwise, only the last 4 bytes are.
tw.tail = append(tw.tail, p[len(p)-4:]...)
p = p[:len(p)-4]
n, err := tw.w.Write(p)
return n + 4, err
}
var flateReaderPool sync.Pool
func getFlateReader(r io.Reader, dict []byte) io.Reader {
fr, ok := flateReaderPool.Get().(io.Reader)
if !ok {
return flate.NewReaderDict(r, dict)
}
fr.(flate.Resetter).Reset(r, dict)
return fr
}
func putFlateReader(fr io.Reader) {
flateReaderPool.Put(fr)
}
var flateWriterPool sync.Pool
func getFlateWriter(w io.Writer) *flate.Writer {
fw, ok := flateWriterPool.Get().(*flate.Writer)
if !ok {
fw, _ = flate.NewWriter(w, flate.BestSpeed)
return fw
}
fw.Reset(w)
return fw
}
func putFlateWriter(w *flate.Writer) {
flateWriterPool.Put(w)
}
type slidingWindow struct {
buf []byte
}
var swPoolMu sync.RWMutex
var swPool = map[int]*sync.Pool{}
func slidingWindowPool(n int) *sync.Pool {
swPoolMu.RLock()
p, ok := swPool[n]
swPoolMu.RUnlock()
if ok {
return p
}
p = &sync.Pool{}
swPoolMu.Lock()
swPool[n] = p
swPoolMu.Unlock()
return p
}
func (sw *slidingWindow) init(n int) {
if sw.buf != nil {
return
}
if n == 0 {
n = 32768
}
p := slidingWindowPool(n)
sw2, ok := p.Get().(*slidingWindow)
if ok {
*sw = *sw2
} else {
sw.buf = make([]byte, 0, n)
}
}
func (sw *slidingWindow) close() {
sw.buf = sw.buf[:0]
swPoolMu.Lock()
swPool[cap(sw.buf)].Put(sw)
swPoolMu.Unlock()
}
func (sw *slidingWindow) write(p []byte) {
if len(p) >= cap(sw.buf) {
sw.buf = sw.buf[:cap(sw.buf)]
p = p[len(p)-cap(sw.buf):]
copy(sw.buf, p)
return
}
left := cap(sw.buf) - len(sw.buf)
if left < len(p) {
// We need to shift spaceNeeded bytes from the end to make room for p at the end.
spaceNeeded := len(p) - left
copy(sw.buf, sw.buf[spaceNeeded:])
sw.buf = sw.buf[:len(sw.buf)-spaceNeeded]
}
sw.buf = append(sw.buf, p...)
}
//go:build !js
// +build !js
package websocket
import (
"bufio"
"context"
"fmt"
"io"
"net"
"runtime"
"strconv"
"sync"
"sync/atomic"
)
// MessageType represents the type of a WebSocket message.
// See https://tools.ietf.org/html/rfc6455#section-5.6
type MessageType int
// MessageType constants.
const (
// MessageText is for UTF-8 encoded text messages like JSON.
MessageText MessageType = iota + 1
// MessageBinary is for binary messages like protobufs.
MessageBinary
)
// Conn represents a WebSocket connection.
// All methods may be called concurrently except for Reader and Read.
//
// You must always read from the connection. Otherwise control
// frames will not be handled. See Reader and CloseRead.
//
// Be sure to call Close on the connection when you
// are finished with it to release associated resources.
//
// On any error from any method, the connection is closed
// with an appropriate reason.
//
// This applies to context expirations as well unfortunately.
// See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220
type Conn struct {
noCopy noCopy
subprotocol string
rwc io.ReadWriteCloser
client bool
copts *compressionOptions
flateThreshold int
br *bufio.Reader
bw *bufio.Writer
readTimeout chan context.Context
writeTimeout chan context.Context
timeoutLoopDone chan struct{}
// Read state.
readMu *mu
readHeaderBuf [8]byte
readControlBuf [maxControlPayload]byte
msgReader *msgReader
// Write state.
msgWriter *msgWriter
writeFrameMu *mu
writeBuf []byte
writeHeaderBuf [8]byte
writeHeader header
closeReadMu sync.Mutex
closeReadCtx context.Context
closeReadDone chan struct{}
closed chan struct{}
closeMu sync.Mutex
closing bool
pingCounter atomic.Int64
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
}
type connConfig struct {
subprotocol string
rwc io.ReadWriteCloser
client bool
copts *compressionOptions
flateThreshold int
br *bufio.Reader
bw *bufio.Writer
}
func newConn(cfg connConfig) *Conn {
c := &Conn{
subprotocol: cfg.subprotocol,
rwc: cfg.rwc,
client: cfg.client,
copts: cfg.copts,
flateThreshold: cfg.flateThreshold,
br: cfg.br,
bw: cfg.bw,
readTimeout: make(chan context.Context),
writeTimeout: make(chan context.Context),
timeoutLoopDone: make(chan struct{}),
closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}),
}
c.readMu = newMu(c)
c.writeFrameMu = newMu(c)
c.msgReader = newMsgReader(c)
c.msgWriter = newMsgWriter(c)
if c.client {
c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
}
if c.flate() && c.flateThreshold == 0 {
c.flateThreshold = 128
if !c.msgWriter.flateContextTakeover() {
c.flateThreshold = 512
}
}
runtime.SetFinalizer(c, func(c *Conn) {
c.close()
})
go c.timeoutLoop()
return c
}
// Subprotocol returns the negotiated subprotocol.
// An empty string means the default protocol.
func (c *Conn) Subprotocol() string {
return c.subprotocol
}
func (c *Conn) close() error {
c.closeMu.Lock()
defer c.closeMu.Unlock()
if c.isClosed() {
return net.ErrClosed
}
runtime.SetFinalizer(c, nil)
close(c.closed)
// Have to close after c.closed is closed to ensure any goroutine that wakes up
// from the connection being closed also sees that c.closed is closed and returns
// closeErr.
err := c.rwc.Close()
// With the close of rwc, these become safe to close.
c.msgWriter.close()
c.msgReader.close()
return err
}
func (c *Conn) timeoutLoop() {
defer close(c.timeoutLoopDone)
readCtx := context.Background()
writeCtx := context.Background()
for {
select {
case <-c.closed:
return
case writeCtx = <-c.writeTimeout:
case readCtx = <-c.readTimeout:
case <-readCtx.Done():
c.close()
return
case <-writeCtx.Done():
c.close()
return
}
}
}
func (c *Conn) flate() bool {
return c.copts != nil
}
// Ping sends a ping to the peer and waits for a pong.
// Use this to measure latency or ensure the peer is responsive.
// Ping must be called concurrently with Reader as it does
// not read from the connection but instead waits for a Reader call
// to read the pong.
//
// TCP Keepalives should suffice for most use cases.
func (c *Conn) Ping(ctx context.Context) error {
p := c.pingCounter.Add(1)
err := c.ping(ctx, strconv.FormatInt(p, 10))
if err != nil {
return fmt.Errorf("failed to ping: %w", err)
}
return nil
}
func (c *Conn) ping(ctx context.Context, p string) error {
pong := make(chan struct{}, 1)
c.activePingsMu.Lock()
c.activePings[p] = pong
c.activePingsMu.Unlock()
defer func() {
c.activePingsMu.Lock()
delete(c.activePings, p)
c.activePingsMu.Unlock()
}()
err := c.writeControl(ctx, opPing, []byte(p))
if err != nil {
return err
}
select {
case <-c.closed:
return net.ErrClosed
case <-ctx.Done():
return fmt.Errorf("failed to wait for pong: %w", ctx.Err())
case <-pong:
return nil
}
}
type mu struct {
c *Conn
ch chan struct{}
}
func newMu(c *Conn) *mu {
return &mu{
c: c,
ch: make(chan struct{}, 1),
}
}
func (m *mu) forceLock() {
m.ch <- struct{}{}
}
func (m *mu) tryLock() bool {
select {
case m.ch <- struct{}{}:
return true
default:
return false
}
}
func (m *mu) lock(ctx context.Context) error {
select {
case <-m.c.closed:
return net.ErrClosed
case <-ctx.Done():
return fmt.Errorf("failed to acquire lock: %w", ctx.Err())
case m.ch <- struct{}{}:
// To make sure the connection is certainly alive.
// As it's possible the send on m.ch was selected
// over the receive on closed.
select {
case <-m.c.closed:
// Make sure to release.
m.unlock()
return net.ErrClosed
default:
}
return nil
}
}
func (m *mu) unlock() {
select {
case <-m.ch:
default:
}
}
type noCopy struct{}
func (*noCopy) Lock() {}
//go:build !js
// +build !js
package websocket
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/coder/websocket/internal/errd"
)
// DialOptions represents Dial's options.
type DialOptions struct {
// HTTPClient is used for the connection.
// Its Transport must return writable bodies for WebSocket handshakes.
// http.Transport does beginning with Go 1.12.
HTTPClient *http.Client
// HTTPHeader specifies the HTTP headers included in the handshake request.
HTTPHeader http.Header
// Host optionally overrides the Host HTTP header to send. If empty, the value
// of URL.Host will be used.
Host string
// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
Subprotocols []string
// CompressionMode controls the compression mode.
// Defaults to CompressionDisabled.
//
// See docs on CompressionMode for details.
CompressionMode CompressionMode
// CompressionThreshold controls the minimum size of a message before compression is applied.
//
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
// for CompressionContextTakeover.
CompressionThreshold int
}
func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) {
var cancel context.CancelFunc
var o DialOptions
if opts != nil {
o = *opts
}
if o.HTTPClient == nil {
o.HTTPClient = http.DefaultClient
}
if o.HTTPClient.Timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, o.HTTPClient.Timeout)
newClient := *o.HTTPClient
newClient.Timeout = 0
o.HTTPClient = &newClient
}
if o.HTTPHeader == nil {
o.HTTPHeader = http.Header{}
}
newClient := *o.HTTPClient
oldCheckRedirect := o.HTTPClient.CheckRedirect
newClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
switch req.URL.Scheme {
case "ws":
req.URL.Scheme = "http"
case "wss":
req.URL.Scheme = "https"
}
if oldCheckRedirect != nil {
return oldCheckRedirect(req, via)
}
return nil
}
o.HTTPClient = &newClient
return ctx, cancel, &o
}
// Dial performs a WebSocket handshake on url.
//
// The response is the WebSocket handshake response from the server.
// You never need to close resp.Body yourself.
//
// If an error occurs, the returned response may be non nil.
// However, you can only read the first 1024 bytes of the body.
//
// This function requires at least Go 1.12 as it uses a new feature
// in net/http to perform WebSocket handshakes.
// See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861
//
// URLs with http/https schemes will work and are interpreted as ws/wss.
func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) {
return dial(ctx, u, opts, nil)
}
func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) {
defer errd.Wrap(&err, "failed to WebSocket dial")
var cancel context.CancelFunc
ctx, cancel, opts = opts.cloneWithDefaults(ctx)
if cancel != nil {
defer cancel()
}
secWebSocketKey, err := secWebSocketKey(rand)
if err != nil {
return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
}
var copts *compressionOptions
if opts.CompressionMode != CompressionDisabled {
copts = opts.CompressionMode.opts()
}
resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
if err != nil {
return nil, resp, err
}
respBody := resp.Body
resp.Body = nil
defer func() {
if err != nil {
// We read a bit of the body for easier debugging.
r := io.LimitReader(respBody, 1024)
timer := time.AfterFunc(time.Second*3, func() {
respBody.Close()
})
defer timer.Stop()
b, _ := io.ReadAll(r)
respBody.Close()
resp.Body = io.NopCloser(bytes.NewReader(b))
}
}()
copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
if err != nil {
return nil, resp, err
}
rwc, ok := respBody.(io.ReadWriteCloser)
if !ok {
return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody)
}
return newConn(connConfig{
subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
rwc: rwc,
client: true,
copts: copts,
flateThreshold: opts.CompressionThreshold,
br: getBufioReader(rwc),
bw: getBufioWriter(rwc),
}), resp, nil
}
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
u, err := url.Parse(urls)
if err != nil {
return nil, fmt.Errorf("failed to parse url: %w", err)
}
switch u.Scheme {
case "ws":
u.Scheme = "http"
case "wss":
u.Scheme = "https"
case "http", "https":
default:
return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme)
}
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return nil, fmt.Errorf("failed to create new http request: %w", err)
}
if len(opts.Host) > 0 {
req.Host = opts.Host
}
req.Header = opts.HTTPHeader.Clone()
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Sec-WebSocket-Version", "13")
req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
if len(opts.Subprotocols) > 0 {
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
}
if copts != nil {
req.Header.Set("Sec-WebSocket-Extensions", copts.String())
}
resp, err := opts.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send handshake request: %w", err)
}
return resp, nil
}
func secWebSocketKey(rr io.Reader) (string, error) {
if rr == nil {
rr = rand.Reader
}
b := make([]byte, 16)
_, err := io.ReadFull(rr, b)
if err != nil {
return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
}
return base64.StdEncoding.EncodeToString(b), nil
}
func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
}
if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") {
return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
}
if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") {
return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
}
if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
resp.Header.Get("Sec-WebSocket-Accept"),
secWebSocketKey,
)
}
err := verifySubprotocol(opts.Subprotocols, resp)
if err != nil {
return nil, err
}
return verifyServerExtensions(copts, resp.Header)
}
func verifySubprotocol(subprotos []string, resp *http.Response) error {
proto := resp.Header.Get("Sec-WebSocket-Protocol")
if proto == "" {
return nil
}
for _, sp2 := range subprotos {
if strings.EqualFold(sp2, proto) {
return nil
}
}
return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
}
func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
exts := websocketExtensions(h)
if len(exts) == 0 {
return nil, nil
}
ext := exts[0]
if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
}
_copts := *copts
copts = &_copts
for _, p := range ext.params {
switch p {
case "client_no_context_takeover":
copts.clientNoContextTakeover = true
continue
case "server_no_context_takeover":
copts.serverNoContextTakeover = true
continue
}
if strings.HasPrefix(p, "server_max_window_bits=") {
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
continue
}
return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
}
return copts, nil
}
var bufioReaderPool sync.Pool
func getBufioReader(r io.Reader) *bufio.Reader {
br, ok := bufioReaderPool.Get().(*bufio.Reader)
if !ok {
return bufio.NewReader(r)
}
br.Reset(r)
return br
}
func putBufioReader(br *bufio.Reader) {
bufioReaderPool.Put(br)
}
var bufioWriterPool sync.Pool
func getBufioWriter(w io.Writer) *bufio.Writer {
bw, ok := bufioWriterPool.Get().(*bufio.Writer)
if !ok {
return bufio.NewWriter(w)
}
bw.Reset(w)
return bw
}
func putBufioWriter(bw *bufio.Writer) {
bufioWriterPool.Put(bw)
}
//go:build !js
package websocket
import (
"bufio"
"encoding/binary"
"fmt"
"io"
"math"
"github.com/coder/websocket/internal/errd"
)
// opcode represents a WebSocket opcode.
type opcode int
// https://tools.ietf.org/html/rfc6455#section-11.8.
const (
opContinuation opcode = iota
opText
opBinary
// 3 - 7 are reserved for further non-control frames.
_
_
_
_
_
opClose
opPing
opPong
// 11-16 are reserved for further control frames.
)
// header represents a WebSocket frame header.
// See https://tools.ietf.org/html/rfc6455#section-5.2.
type header struct {
fin bool
rsv1 bool
rsv2 bool
rsv3 bool
opcode opcode
payloadLength int64
masked bool
maskKey uint32
}
// readFrameHeader reads a header from the reader.
// See https://tools.ietf.org/html/rfc6455#section-5.2.
func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) {
defer errd.Wrap(&err, "failed to read frame header")
b, err := r.ReadByte()
if err != nil {
return header{}, err
}
h.fin = b&(1<<7) != 0
h.rsv1 = b&(1<<6) != 0
h.rsv2 = b&(1<<5) != 0
h.rsv3 = b&(1<<4) != 0
h.opcode = opcode(b & 0xf)
b, err = r.ReadByte()
if err != nil {
return header{}, err
}
h.masked = b&(1<<7) != 0
payloadLength := b &^ (1 << 7)
switch {
case payloadLength < 126:
h.payloadLength = int64(payloadLength)
case payloadLength == 126:
_, err = io.ReadFull(r, readBuf[:2])
h.payloadLength = int64(binary.BigEndian.Uint16(readBuf))
case payloadLength == 127:
_, err = io.ReadFull(r, readBuf)
h.payloadLength = int64(binary.BigEndian.Uint64(readBuf))
}
if err != nil {
return header{}, err
}
if h.payloadLength < 0 {
return header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength)
}
if h.masked {
_, err = io.ReadFull(r, readBuf[:4])
if err != nil {
return header{}, err
}
h.maskKey = binary.LittleEndian.Uint32(readBuf)
}
return h, nil
}
// maxControlPayload is the maximum length of a control frame payload.
// See https://tools.ietf.org/html/rfc6455#section-5.5.
const maxControlPayload = 125
// writeFrameHeader writes the bytes of the header to w.
// See https://tools.ietf.org/html/rfc6455#section-5.2
func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {
defer errd.Wrap(&err, "failed to write frame header")
var b byte
if h.fin {
b |= 1 << 7
}
if h.rsv1 {
b |= 1 << 6
}
if h.rsv2 {
b |= 1 << 5
}
if h.rsv3 {
b |= 1 << 4
}
b |= byte(h.opcode)
err = w.WriteByte(b)
if err != nil {
return err
}
lengthByte := byte(0)
if h.masked {
lengthByte |= 1 << 7
}
switch {
case h.payloadLength > math.MaxUint16:
lengthByte |= 127
case h.payloadLength > 125:
lengthByte |= 126
case h.payloadLength >= 0:
lengthByte |= byte(h.payloadLength)
}
err = w.WriteByte(lengthByte)
if err != nil {
return err
}
switch {
case h.payloadLength > math.MaxUint16:
binary.BigEndian.PutUint64(buf, uint64(h.payloadLength))
_, err = w.Write(buf)
case h.payloadLength > 125:
binary.BigEndian.PutUint16(buf, uint16(h.payloadLength))
_, err = w.Write(buf[:2])
}
if err != nil {
return err
}
if h.masked {
binary.LittleEndian.PutUint32(buf, h.maskKey)
_, err = w.Write(buf[:4])
if err != nil {
return err
}
}
return nil
}
//go:build !js
package websocket
import (
"net/http"
)
type rwUnwrapper interface {
Unwrap() http.ResponseWriter
}
// hijacker returns the Hijacker interface of the http.ResponseWriter.
// It follows the Unwrap method of the http.ResponseWriter if available,
// matching the behavior of http.ResponseController. If the Hijacker
// interface is not found, it returns false.
//
// Since the http.ResponseController is not available in Go 1.19, and
// does not support checking the presence of the Hijacker interface,
// this function is used to provide a consistent way to check for the
// Hijacker interface across Go versions.
func hijacker(rw http.ResponseWriter) (http.Hijacker, bool) {
for {
switch t := rw.(type) {
case http.Hijacker:
return t, true
case rwUnwrapper:
rw = t.Unwrap()
default:
return nil, false
}
}
}
package bpool
import (
"bytes"
"sync"
)
var bpool = sync.Pool{
New: func() any {
return &bytes.Buffer{}
},
}
// Get returns a buffer from the pool or creates a new one if
// the pool is empty.
func Get() *bytes.Buffer {
b := bpool.Get()
return b.(*bytes.Buffer)
}
// Put returns a buffer into the pool.
func Put(b *bytes.Buffer) {
b.Reset()
bpool.Put(b)
}
package errd
import (
"fmt"
)
// Wrap wraps err with fmt.Errorf if err is non nil.
// Intended for use with defer and a named error return.
// Inspired by https://github.com/golang/go/issues/32676.
func Wrap(err *error, f string, v ...interface{}) {
if *err != nil {
*err = fmt.Errorf(f+": %w", append(v, *err)...)
}
}
package assert
import (
"errors"
"fmt"
"reflect"
"strings"
"testing"
)
// Equal asserts exp == act.
func Equal(t testing.TB, name string, exp, got interface{}) {
t.Helper()
if !reflect.DeepEqual(exp, got) {
t.Fatalf("unexpected %v: expected %#v but got %#v", name, exp, got)
}
}
// Success asserts err == nil.
func Success(t testing.TB, err error) {
t.Helper()
if err != nil {
t.Fatal(err)
}
}
// Error asserts err != nil.
func Error(t testing.TB, err error) {
t.Helper()
if err == nil {
t.Fatal("expected error")
}
}
// Contains asserts the fmt.Sprint(v) contains sub.
func Contains(t testing.TB, v interface{}, sub string) {
t.Helper()
s := fmt.Sprint(v)
if !strings.Contains(s, sub) {
t.Fatalf("expected %q to contain %q", s, sub)
}
}
// ErrorIs asserts errors.Is(got, exp)
func ErrorIs(t testing.TB, exp, got error) {
t.Helper()
if !errors.Is(got, exp) {
t.Fatalf("expected %v but got %v", exp, got)
}
}
package wstest
import (
"bytes"
"context"
"fmt"
"io"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/test/xrand"
"github.com/coder/websocket/internal/xsync"
)
// EchoLoop echos every msg received from c until an error
// occurs or the context expires.
// The read limit is set to 1 << 30.
func EchoLoop(ctx context.Context, c *websocket.Conn) error {
defer c.Close(websocket.StatusInternalError, "")
c.SetReadLimit(1 << 30)
ctx, cancel := context.WithTimeout(ctx, time.Minute*5)
defer cancel()
b := make([]byte, 32<<10)
for {
typ, r, err := c.Reader(ctx)
if err != nil {
return err
}
w, err := c.Writer(ctx, typ)
if err != nil {
return err
}
_, err = io.CopyBuffer(w, r, b)
if err != nil {
return err
}
err = w.Close()
if err != nil {
return err
}
}
}
// Echo writes a message and ensures the same is sent back on c.
func Echo(ctx context.Context, c *websocket.Conn, max int) error {
expType := websocket.MessageBinary
if xrand.Bool() {
expType = websocket.MessageText
}
msg := randMessage(expType, xrand.Int(max))
writeErr := xsync.Go(func() error {
return c.Write(ctx, expType, msg)
})
actType, act, err := c.Read(ctx)
if err != nil {
return err
}
err = <-writeErr
if err != nil {
return err
}
if expType != actType {
return fmt.Errorf("unexpected message typ (%v): %v", expType, actType)
}
if !bytes.Equal(msg, act) {
return fmt.Errorf("unexpected msg read: %#v", act)
}
return nil
}
func randMessage(typ websocket.MessageType, n int) []byte {
if typ == websocket.MessageBinary {
return xrand.Bytes(n)
}
return []byte(xrand.String(n))
}
//go:build !js
// +build !js
package wstest
import (
"bufio"
"context"
"net"
"net/http"
"net/http/httptest"
"github.com/coder/websocket"
)
// Pipe is used to create an in memory connection
// between two websockets analogous to net.Pipe.
func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (clientConn, serverConn *websocket.Conn) {
tt := fakeTransport{
h: func(w http.ResponseWriter, r *http.Request) {
serverConn, _ = websocket.Accept(w, r, acceptOpts)
},
}
if dialOpts == nil {
dialOpts = &websocket.DialOptions{}
}
_dialOpts := *dialOpts
dialOpts = &_dialOpts
dialOpts.HTTPClient = &http.Client{
Transport: tt,
}
clientConn, _, _ = websocket.Dial(context.Background(), "ws://example.com", dialOpts)
return clientConn, serverConn
}
type fakeTransport struct {
h http.HandlerFunc
}
func (t fakeTransport) RoundTrip(r *http.Request) (*http.Response, error) {
clientConn, serverConn := net.Pipe()
hj := testHijacker{
ResponseRecorder: httptest.NewRecorder(),
serverConn: serverConn,
}
t.h.ServeHTTP(hj, r)
resp := hj.ResponseRecorder.Result()
if resp.StatusCode == http.StatusSwitchingProtocols {
resp.Body = clientConn
}
return resp, nil
}
type testHijacker struct {
*httptest.ResponseRecorder
serverConn net.Conn
}
var _ http.Hijacker = testHijacker{}
func (hj testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return hj.serverConn, bufio.NewReadWriter(bufio.NewReader(hj.serverConn), bufio.NewWriter(hj.serverConn)), nil
}
package xrand
import (
"crypto/rand"
"encoding/base64"
"fmt"
"math/big"
"strings"
)
// Bytes generates random bytes with length n.
func Bytes(n int) []byte {
b := make([]byte, n)
_, err := rand.Reader.Read(b)
if err != nil {
panic(fmt.Sprintf("failed to generate rand bytes: %v", err))
}
return b
}
// String generates a random string with length n.
func String(n int) string {
s := strings.ToValidUTF8(string(Bytes(n)), "_")
s = strings.ReplaceAll(s, "\x00", "_")
if len(s) > n {
return s[:n]
}
if len(s) < n {
// Pad with =
extra := n - len(s)
return s + strings.Repeat("=", extra)
}
return s
}
// Bool returns a randomly generated boolean.
func Bool() bool {
return Int(2) == 1
}
// Int returns a randomly generated integer between [0, max).
func Int(max int) int {
x, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
if err != nil {
panic(fmt.Sprintf("failed to get random int: %v", err))
}
return int(x.Int64())
}
// Base64 returns a randomly generated base64 string of length n.
func Base64(n int) string {
return base64.StdEncoding.EncodeToString(Bytes(n))
}
package util
// WriterFunc is used to implement one off io.Writers.
type WriterFunc func(p []byte) (int, error)
func (f WriterFunc) Write(p []byte) (int, error) {
return f(p)
}
// ReaderFunc is used to implement one off io.Readers.
type ReaderFunc func(p []byte) (int, error)
func (f ReaderFunc) Read(p []byte) (int, error) {
return f(p)
}
package xsync
import (
"fmt"
"runtime/debug"
)
// Go allows running a function in another goroutine
// and waiting for its error.
func Go(fn func() error) <-chan error {
errs := make(chan error, 1)
go func() {
defer func() {
r := recover()
if r != nil {
select {
case errs <- fmt.Errorf("panic in go fn: %v, %s", r, debug.Stack()):
default:
}
}
}()
errs <- fn()
}()
return errs
}
package websocket
import (
"encoding/binary"
"math/bits"
)
// maskGo applies the WebSocket masking algorithm to p
// with the given key.
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// The returned value is the correctly rotated key to
// to continue to mask/unmask the message.
//
// It is optimized for LittleEndian and expects the key
// to be in little endian.
//
// See https://github.com/golang/go/issues/31586
func maskGo(b []byte, key uint32) uint32 {
if len(b) >= 8 {
key64 := uint64(key)<<32 | uint64(key)
// At some point in the future we can clean these unrolled loops up.
// See https://github.com/golang/go/issues/31586#issuecomment-487436401
// Then we xor until b is less than 128 bytes.
for len(b) >= 128 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
v = binary.LittleEndian.Uint64(b[64:72])
binary.LittleEndian.PutUint64(b[64:72], v^key64)
v = binary.LittleEndian.Uint64(b[72:80])
binary.LittleEndian.PutUint64(b[72:80], v^key64)
v = binary.LittleEndian.Uint64(b[80:88])
binary.LittleEndian.PutUint64(b[80:88], v^key64)
v = binary.LittleEndian.Uint64(b[88:96])
binary.LittleEndian.PutUint64(b[88:96], v^key64)
v = binary.LittleEndian.Uint64(b[96:104])
binary.LittleEndian.PutUint64(b[96:104], v^key64)
v = binary.LittleEndian.Uint64(b[104:112])
binary.LittleEndian.PutUint64(b[104:112], v^key64)
v = binary.LittleEndian.Uint64(b[112:120])
binary.LittleEndian.PutUint64(b[112:120], v^key64)
v = binary.LittleEndian.Uint64(b[120:128])
binary.LittleEndian.PutUint64(b[120:128], v^key64)
b = b[128:]
}
// Then we xor until b is less than 64 bytes.
for len(b) >= 64 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
b = b[64:]
}
// Then we xor until b is less than 32 bytes.
for len(b) >= 32 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
b = b[32:]
}
// Then we xor until b is less than 16 bytes.
for len(b) >= 16 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
b = b[16:]
}
// Then we xor until b is less than 8 bytes.
for len(b) >= 8 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
b = b[8:]
}
}
// Then we xor until b is less than 4 bytes.
for len(b) >= 4 {
v := binary.LittleEndian.Uint32(b)
binary.LittleEndian.PutUint32(b, v^key)
b = b[4:]
}
// xor remaining bytes.
for i := range b {
b[i] ^= byte(key)
key = bits.RotateLeft32(key, -8)
}
return key
}
//go:build amd64 || arm64
package websocket
func mask(b []byte, key uint32) uint32 {
// TODO: Will enable in v1.9.0.
return maskGo(b, key)
/*
if len(b) > 0 {
return maskAsm(&b[0], len(b), key)
}
return key
*/
}
// @nhooyr: I am not confident that the amd64 or the arm64 implementations of this
// function are perfect. There are almost certainly missing optimizations or
// opportunities for simplification. I'm confident there are no bugs though.
// For example, the arm64 implementation doesn't align memory like the amd64.
// Or the amd64 implementation could use AVX512 instead of just AVX2.
// The AVX2 code I had to disable anyway as it wasn't performing as expected.
// See https://github.com/nhooyr/websocket/pull/326#issuecomment-1771138049
//
//go:noescape
//lint:ignore U1000 disabled till v1.9.0
func maskAsm(b *byte, len int, key uint32) uint32
package websocket
import (
"context"
"fmt"
"io"
"math"
"net"
"sync/atomic"
"time"
)
// NetConn converts a *websocket.Conn into a net.Conn.
//
// It's for tunneling arbitrary protocols over WebSockets.
// Few users of the library will need this but it's tricky to implement
// correctly and so provided in the library.
// See https://github.com/nhooyr/websocket/issues/100.
//
// Every Write to the net.Conn will correspond to a message write of
// the given type on *websocket.Conn.
//
// The passed ctx bounds the lifetime of the net.Conn. If cancelled,
// all reads and writes on the net.Conn will be cancelled.
//
// If a message is read that is not of the correct type, the connection
// will be closed with StatusUnsupportedData and an error will be returned.
//
// Close will close the *websocket.Conn with StatusNormalClosure.
//
// When a deadline is hit and there is an active read or write goroutine, the
// connection will be closed. This is different from most net.Conn implementations
// where only the reading/writing goroutines are interrupted but the connection
// is kept alive.
//
// The Addr methods will return the real addresses for connections obtained
// from websocket.Accept. But for connections obtained from websocket.Dial, a mock net.Addr
// will be returned that gives "websocket" for Network() and "websocket/unknown-addr" for
// String(). This is because websocket.Dial only exposes a io.ReadWriteCloser instead of the
// full net.Conn to us.
//
// When running as WASM, the Addr methods will always return the mock address described above.
//
// A received StatusNormalClosure or StatusGoingAway close frame will be translated to
// io.EOF when reading.
//
// Furthermore, the ReadLimit is set to -1 to disable it.
func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
c.SetReadLimit(-1)
nc := &netConn{
c: c,
msgType: msgType,
readMu: newMu(c),
writeMu: newMu(c),
}
nc.writeCtx, nc.writeCancel = context.WithCancel(ctx)
nc.readCtx, nc.readCancel = context.WithCancel(ctx)
nc.writeTimer = time.AfterFunc(math.MaxInt64, func() {
if !nc.writeMu.tryLock() {
// If the lock cannot be acquired, then there is an
// active write goroutine and so we should cancel the context.
nc.writeCancel()
return
}
defer nc.writeMu.unlock()
// Prevents future writes from writing until the deadline is reset.
nc.writeExpired.Store(1)
})
if !nc.writeTimer.Stop() {
<-nc.writeTimer.C
}
nc.readTimer = time.AfterFunc(math.MaxInt64, func() {
if !nc.readMu.tryLock() {
// If the lock cannot be acquired, then there is an
// active read goroutine and so we should cancel the context.
nc.readCancel()
return
}
defer nc.readMu.unlock()
// Prevents future reads from reading until the deadline is reset.
nc.readExpired.Store(1)
})
if !nc.readTimer.Stop() {
<-nc.readTimer.C
}
return nc
}
type netConn struct {
c *Conn
msgType MessageType
writeTimer *time.Timer
writeMu *mu
writeExpired atomic.Int64
writeCtx context.Context
writeCancel context.CancelFunc
readTimer *time.Timer
readMu *mu
readExpired atomic.Int64
readCtx context.Context
readCancel context.CancelFunc
readEOFed bool
reader io.Reader
}
var _ net.Conn = &netConn{}
func (nc *netConn) Close() error {
nc.writeTimer.Stop()
nc.writeCancel()
nc.readTimer.Stop()
nc.readCancel()
return nc.c.Close(StatusNormalClosure, "")
}
func (nc *netConn) Write(p []byte) (int, error) {
nc.writeMu.forceLock()
defer nc.writeMu.unlock()
if nc.writeExpired.Load() == 1 {
return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded)
}
err := nc.c.Write(nc.writeCtx, nc.msgType, p)
if err != nil {
return 0, err
}
return len(p), nil
}
func (nc *netConn) Read(p []byte) (int, error) {
nc.readMu.forceLock()
defer nc.readMu.unlock()
for {
n, err := nc.read(p)
if err != nil {
return n, err
}
if n == 0 {
continue
}
return n, nil
}
}
func (nc *netConn) read(p []byte) (int, error) {
if nc.readExpired.Load() == 1 {
return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded)
}
if nc.readEOFed {
return 0, io.EOF
}
if nc.reader == nil {
typ, r, err := nc.c.Reader(nc.readCtx)
if err != nil {
switch CloseStatus(err) {
case StatusNormalClosure, StatusGoingAway:
nc.readEOFed = true
return 0, io.EOF
}
return 0, err
}
if typ != nc.msgType {
err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ)
nc.c.Close(StatusUnsupportedData, err.Error())
return 0, err
}
nc.reader = r
}
n, err := nc.reader.Read(p)
if err == io.EOF {
nc.reader = nil
err = nil
}
return n, err
}
type websocketAddr struct {
}
func (a websocketAddr) Network() string {
return "websocket"
}
func (a websocketAddr) String() string {
return "websocket/unknown-addr"
}
func (nc *netConn) SetDeadline(t time.Time) error {
nc.SetWriteDeadline(t)
nc.SetReadDeadline(t)
return nil
}
func (nc *netConn) SetWriteDeadline(t time.Time) error {
nc.writeExpired.Store(0)
if t.IsZero() {
nc.writeTimer.Stop()
} else {
dur := time.Until(t)
if dur <= 0 {
dur = 1
}
nc.writeTimer.Reset(dur)
}
return nil
}
func (nc *netConn) SetReadDeadline(t time.Time) error {
nc.readExpired.Store(0)
if t.IsZero() {
nc.readTimer.Stop()
} else {
dur := time.Until(t)
if dur <= 0 {
dur = 1
}
nc.readTimer.Reset(dur)
}
return nil
}
//go:build !js
// +build !js
package websocket
import "net"
func (nc *netConn) RemoteAddr() net.Addr {
if unc, ok := nc.c.rwc.(net.Conn); ok {
return unc.RemoteAddr()
}
return websocketAddr{}
}
func (nc *netConn) LocalAddr() net.Addr {
if unc, ok := nc.c.rwc.(net.Conn); ok {
return unc.LocalAddr()
}
return websocketAddr{}
}
//go:build !js
// +build !js
package websocket
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
"strings"
"sync/atomic"
"time"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/util"
)
// Reader reads from the connection until there is a WebSocket
// data message to be read. It will handle ping, pong and close frames as appropriate.
//
// It returns the type of the message and an io.Reader to read it.
// The passed context will also bound the reader.
// Ensure you read to EOF otherwise the connection will hang.
//
// Call CloseRead if you do not expect any data messages from the peer.
//
// Only one Reader may be open at a time.
//
// If you need a separate timeout on the Reader call and the Read itself,
// use time.AfterFunc to cancel the context passed in.
// See https://github.com/nhooyr/websocket/issues/87#issue-451703332
// Most users should not need this.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
return c.reader(ctx)
}
// Read is a convenience method around Reader to read a single message
// from the connection.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
typ, r, err := c.Reader(ctx)
if err != nil {
return 0, nil, err
}
b, err := io.ReadAll(r)
return typ, b, err
}
// CloseRead starts a goroutine to read from the connection until it is closed
// or a data message is received.
//
// Once CloseRead is called you cannot read any messages from the connection.
// The returned context will be cancelled when the connection is closed.
//
// If a data message is received, the connection will be closed with StatusPolicyViolation.
//
// Call CloseRead when you do not expect to read any more messages.
// Since it actively reads from the connection, it will ensure that ping, pong and close
// frames are responded to. This means c.Ping and c.Close will still work as expected.
//
// This function is idempotent.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
c.closeReadMu.Lock()
ctx2 := c.closeReadCtx
if ctx2 != nil {
c.closeReadMu.Unlock()
return ctx2
}
ctx, cancel := context.WithCancel(ctx)
c.closeReadCtx = ctx
c.closeReadDone = make(chan struct{})
c.closeReadMu.Unlock()
go func() {
defer close(c.closeReadDone)
defer cancel()
defer c.close()
_, _, err := c.Reader(ctx)
if err == nil {
c.Close(StatusPolicyViolation, "unexpected data message")
}
}()
return ctx
}
// SetReadLimit sets the max number of bytes to read for a single message.
// It applies to the Reader and Read methods.
//
// By default, the connection has a message read limit of 32768 bytes.
//
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
//
// Set to -1 to disable.
func (c *Conn) SetReadLimit(n int64) {
if n >= 0 {
// We read one more byte than the limit in case
// there is a fin frame that needs to be read.
n++
}
c.msgReader.limitReader.limit.Store(n)
}
const defaultReadLimit = 32768
func newMsgReader(c *Conn) *msgReader {
mr := &msgReader{
c: c,
fin: true,
}
mr.readFunc = mr.read
mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1)
return mr
}
func (mr *msgReader) resetFlate() {
if mr.flateContextTakeover() {
if mr.dict == nil {
mr.dict = &slidingWindow{}
}
mr.dict.init(32768)
}
if mr.flateBufio == nil {
mr.flateBufio = getBufioReader(mr.readFunc)
}
if mr.flateContextTakeover() {
mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
} else {
mr.flateReader = getFlateReader(mr.flateBufio, nil)
}
mr.limitReader.r = mr.flateReader
mr.flateTail.Reset(deflateMessageTail)
}
func (mr *msgReader) putFlateReader() {
if mr.flateReader != nil {
putFlateReader(mr.flateReader)
mr.flateReader = nil
}
}
func (mr *msgReader) close() {
mr.c.readMu.forceLock()
mr.putFlateReader()
if mr.dict != nil {
mr.dict.close()
mr.dict = nil
}
if mr.flateBufio != nil {
putBufioReader(mr.flateBufio)
}
if mr.c.client {
putBufioReader(mr.c.br)
mr.c.br = nil
}
}
func (mr *msgReader) flateContextTakeover() bool {
if mr.c.client {
return !mr.c.copts.serverNoContextTakeover
}
return !mr.c.copts.clientNoContextTakeover
}
func (c *Conn) readRSV1Illegal(h header) bool {
// If compression is disabled, rsv1 is illegal.
if !c.flate() {
return true
}
// rsv1 is only allowed on data frames beginning messages.
if h.opcode != opText && h.opcode != opBinary {
return true
}
return false
}
func (c *Conn) readLoop(ctx context.Context) (header, error) {
for {
h, err := c.readFrameHeader(ctx)
if err != nil {
return header{}, err
}
if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
c.writeError(StatusProtocolError, err)
return header{}, err
}
if !c.client && !h.masked {
return header{}, errors.New("received unmasked frame from client")
}
switch h.opcode {
case opClose, opPing, opPong:
err = c.handleControl(ctx, h)
if err != nil {
// Pass through CloseErrors when receiving a close frame.
if h.opcode == opClose && CloseStatus(err) != -1 {
return header{}, err
}
return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
}
case opContinuation, opText, opBinary:
return h, nil
default:
err := fmt.Errorf("received unknown opcode %v", h.opcode)
c.writeError(StatusProtocolError, err)
return header{}, err
}
}
}
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
select {
case <-c.closed:
return header{}, net.ErrClosed
case c.readTimeout <- ctx:
}
h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
if err != nil {
select {
case <-c.closed:
return header{}, net.ErrClosed
case <-ctx.Done():
return header{}, ctx.Err()
default:
return header{}, err
}
}
select {
case <-c.closed:
return header{}, net.ErrClosed
case c.readTimeout <- context.Background():
}
return h, nil
}
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
select {
case <-c.closed:
return 0, net.ErrClosed
case c.readTimeout <- ctx:
}
n, err := io.ReadFull(c.br, p)
if err != nil {
select {
case <-c.closed:
return n, net.ErrClosed
case <-ctx.Done():
return n, ctx.Err()
default:
return n, fmt.Errorf("failed to read frame payload: %w", err)
}
}
select {
case <-c.closed:
return n, net.ErrClosed
case c.readTimeout <- context.Background():
}
return n, err
}
func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
c.writeError(StatusProtocolError, err)
return err
}
if !h.fin {
err := errors.New("received fragmented control frame")
c.writeError(StatusProtocolError, err)
return err
}
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
b := c.readControlBuf[:h.payloadLength]
_, err = c.readFramePayload(ctx, b)
if err != nil {
return err
}
if h.masked {
mask(b, h.maskKey)
}
switch h.opcode {
case opPing:
return c.writeControl(ctx, opPong, b)
case opPong:
c.activePingsMu.Lock()
pong, ok := c.activePings[string(b)]
c.activePingsMu.Unlock()
if ok {
select {
case pong <- struct{}{}:
default:
}
}
return nil
}
// opClose
ce, err := parseClosePayload(b)
if err != nil {
err = fmt.Errorf("received invalid close payload: %w", err)
c.writeError(StatusProtocolError, err)
return err
}
err = fmt.Errorf("received close frame: %w", ce)
c.writeClose(ce.Code, ce.Reason)
c.readMu.unlock()
c.close()
return err
}
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
defer errd.Wrap(&err, "failed to get reader")
err = c.readMu.lock(ctx)
if err != nil {
return 0, nil, err
}
defer c.readMu.unlock()
if !c.msgReader.fin {
return 0, nil, errors.New("previous message not read to completion")
}
h, err := c.readLoop(ctx)
if err != nil {
return 0, nil, err
}
if h.opcode == opContinuation {
err := errors.New("received continuation frame without text or binary frame")
c.writeError(StatusProtocolError, err)
return 0, nil, err
}
c.msgReader.reset(ctx, h)
return MessageType(h.opcode), c.msgReader, nil
}
type msgReader struct {
c *Conn
ctx context.Context
flate bool
flateReader io.Reader
flateBufio *bufio.Reader
flateTail strings.Reader
limitReader *limitReader
dict *slidingWindow
fin bool
payloadLength int64
maskKey uint32
// util.ReaderFunc(mr.Read) to avoid continuous allocations.
readFunc util.ReaderFunc
}
func (mr *msgReader) reset(ctx context.Context, h header) {
mr.ctx = ctx
mr.flate = h.rsv1
mr.limitReader.reset(mr.readFunc)
if mr.flate {
mr.resetFlate()
}
mr.setFrame(h)
}
func (mr *msgReader) setFrame(h header) {
mr.fin = h.fin
mr.payloadLength = h.payloadLength
mr.maskKey = h.maskKey
}
func (mr *msgReader) Read(p []byte) (n int, err error) {
err = mr.c.readMu.lock(mr.ctx)
if err != nil {
return 0, fmt.Errorf("failed to read: %w", err)
}
defer mr.c.readMu.unlock()
n, err = mr.limitReader.Read(p)
if mr.flate && mr.flateContextTakeover() {
p = p[:n]
mr.dict.write(p)
}
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
mr.putFlateReader()
return n, io.EOF
}
if err != nil {
return n, fmt.Errorf("failed to read: %w", err)
}
return n, nil
}
func (mr *msgReader) read(p []byte) (int, error) {
for {
if mr.payloadLength == 0 {
if mr.fin {
if mr.flate {
return mr.flateTail.Read(p)
}
return 0, io.EOF
}
h, err := mr.c.readLoop(mr.ctx)
if err != nil {
return 0, err
}
if h.opcode != opContinuation {
err := errors.New("received new data message without finishing the previous message")
mr.c.writeError(StatusProtocolError, err)
return 0, err
}
mr.setFrame(h)
continue
}
if int64(len(p)) > mr.payloadLength {
p = p[:mr.payloadLength]
}
n, err := mr.c.readFramePayload(mr.ctx, p)
if err != nil {
return n, err
}
mr.payloadLength -= int64(n)
if !mr.c.client {
mr.maskKey = mask(p, mr.maskKey)
}
return n, nil
}
}
type limitReader struct {
c *Conn
r io.Reader
limit atomic.Int64
n int64
}
func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
lr := &limitReader{
c: c,
}
lr.limit.Store(limit)
lr.reset(r)
return lr
}
func (lr *limitReader) reset(r io.Reader) {
lr.n = lr.limit.Load()
lr.r = r
}
func (lr *limitReader) Read(p []byte) (int, error) {
if lr.n < 0 {
return lr.r.Read(p)
}
if lr.n == 0 {
err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
lr.c.writeError(StatusMessageTooBig, err)
return 0, err
}
if int64(len(p)) > lr.n {
p = p[:lr.n]
}
n, err := lr.r.Read(p)
lr.n -= int64(n)
if lr.n < 0 {
lr.n = 0
}
return n, err
}
//go:build !js
// +build !js
package websocket
import (
"bufio"
"context"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"time"
"compress/flate"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/util"
)
// Writer returns a writer bounded by the context that will write
// a WebSocket message of type dataType to the connection.
//
// You must close the writer once you have written the entire message.
//
// Only one writer can be open at a time, multiple calls will block until the previous writer
// is closed.
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
w, err := c.writer(ctx, typ)
if err != nil {
return nil, fmt.Errorf("failed to get writer: %w", err)
}
return w, nil
}
// Write writes a message to the connection.
//
// See the Writer method if you want to stream a message.
//
// If compression is disabled or the compression threshold is not met, then it
// will write the message in a single frame.
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
_, err := c.write(ctx, typ, p)
if err != nil {
return fmt.Errorf("failed to write msg: %w", err)
}
return nil
}
type msgWriter struct {
c *Conn
mu *mu
writeMu *mu
closed bool
ctx context.Context
opcode opcode
flate bool
trimWriter *trimLastFourBytesWriter
flateWriter *flate.Writer
}
func newMsgWriter(c *Conn) *msgWriter {
mw := &msgWriter{
c: c,
mu: newMu(c),
writeMu: newMu(c),
}
return mw
}
func (mw *msgWriter) ensureFlate() {
if mw.trimWriter == nil {
mw.trimWriter = &trimLastFourBytesWriter{
w: util.WriterFunc(mw.write),
}
}
if mw.flateWriter == nil {
mw.flateWriter = getFlateWriter(mw.trimWriter)
}
mw.flate = true
}
func (mw *msgWriter) flateContextTakeover() bool {
if mw.c.client {
return !mw.c.copts.clientNoContextTakeover
}
return !mw.c.copts.serverNoContextTakeover
}
func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
err := c.msgWriter.reset(ctx, typ)
if err != nil {
return nil, err
}
return c.msgWriter, nil
}
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
mw, err := c.writer(ctx, typ)
if err != nil {
return 0, err
}
if !c.flate() {
defer c.msgWriter.mu.unlock()
return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
}
n, err := mw.Write(p)
if err != nil {
return n, err
}
err = mw.Close()
return n, err
}
func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
err := mw.mu.lock(ctx)
if err != nil {
return err
}
mw.ctx = ctx
mw.opcode = opcode(typ)
mw.flate = false
mw.closed = false
mw.trimWriter.reset()
return nil
}
func (mw *msgWriter) putFlateWriter() {
if mw.flateWriter != nil {
putFlateWriter(mw.flateWriter)
mw.flateWriter = nil
}
}
// Write writes the given bytes to the WebSocket connection.
func (mw *msgWriter) Write(p []byte) (_ int, err error) {
err = mw.writeMu.lock(mw.ctx)
if err != nil {
return 0, fmt.Errorf("failed to write: %w", err)
}
defer mw.writeMu.unlock()
if mw.closed {
return 0, errors.New("cannot use closed writer")
}
defer func() {
if err != nil {
err = fmt.Errorf("failed to write: %w", err)
}
}()
if mw.c.flate() {
// Only enables flate if the length crosses the
// threshold on the first frame
if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold {
mw.ensureFlate()
}
}
if mw.flate {
return mw.flateWriter.Write(p)
}
return mw.write(p)
}
func (mw *msgWriter) write(p []byte) (int, error) {
n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
if err != nil {
return n, fmt.Errorf("failed to write data frame: %w", err)
}
mw.opcode = opContinuation
return n, nil
}
// Close flushes the frame to the connection.
func (mw *msgWriter) Close() (err error) {
defer errd.Wrap(&err, "failed to close writer")
err = mw.writeMu.lock(mw.ctx)
if err != nil {
return err
}
defer mw.writeMu.unlock()
if mw.closed {
return errors.New("writer already closed")
}
mw.closed = true
if mw.flate {
err = mw.flateWriter.Flush()
if err != nil {
return fmt.Errorf("failed to flush flate: %w", err)
}
}
_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
if err != nil {
return fmt.Errorf("failed to write fin frame: %w", err)
}
if mw.flate && !mw.flateContextTakeover() {
mw.putFlateWriter()
}
mw.mu.unlock()
return nil
}
func (mw *msgWriter) close() {
if mw.c.client {
mw.c.writeFrameMu.forceLock()
putBufioWriter(mw.c.bw)
}
mw.writeMu.forceLock()
mw.putFlateWriter()
}
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
_, err := c.writeFrame(ctx, true, false, opcode, p)
if err != nil {
return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
}
return nil
}
// writeFrame handles all writes to the connection.
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
err = c.writeFrameMu.lock(ctx)
if err != nil {
return 0, err
}
defer c.writeFrameMu.unlock()
select {
case <-c.closed:
return 0, net.ErrClosed
case c.writeTimeout <- ctx:
}
defer func() {
if err != nil {
select {
case <-c.closed:
err = net.ErrClosed
case <-ctx.Done():
err = ctx.Err()
default:
}
err = fmt.Errorf("failed to write frame: %w", err)
}
}()
c.writeHeader.fin = fin
c.writeHeader.opcode = opcode
c.writeHeader.payloadLength = int64(len(p))
if c.client {
c.writeHeader.masked = true
_, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4])
if err != nil {
return 0, fmt.Errorf("failed to generate masking key: %w", err)
}
c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:])
}
c.writeHeader.rsv1 = false
if flate && (opcode == opText || opcode == opBinary) {
c.writeHeader.rsv1 = true
}
err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:])
if err != nil {
return 0, err
}
n, err := c.writeFramePayload(p)
if err != nil {
return n, err
}
if c.writeHeader.fin {
err = c.bw.Flush()
if err != nil {
return n, fmt.Errorf("failed to flush: %w", err)
}
}
select {
case <-c.closed:
if opcode == opClose {
return n, nil
}
return n, net.ErrClosed
case c.writeTimeout <- context.Background():
}
return n, nil
}
func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
defer errd.Wrap(&err, "failed to write frame payload")
if !c.writeHeader.masked {
return c.bw.Write(p)
}
maskKey := c.writeHeader.maskKey
for len(p) > 0 {
// If the buffer is full, we need to flush.
if c.bw.Available() == 0 {
err = c.bw.Flush()
if err != nil {
return n, err
}
}
// Start of next write in the buffer.
i := c.bw.Buffered()
j := len(p)
if j > c.bw.Available() {
j = c.bw.Available()
}
_, err := c.bw.Write(p[:j])
if err != nil {
return n, err
}
maskKey = mask(c.writeBuf[i:c.bw.Buffered()], maskKey)
p = p[j:]
n += j
}
return n, nil
}
// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
// and returns it.
func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
var writeBuf []byte
bw.Reset(util.WriterFunc(func(p2 []byte) (int, error) {
writeBuf = p2[:cap(p2)]
return len(p2), nil
}))
bw.WriteByte(0)
bw.Flush()
bw.Reset(w)
return writeBuf
}
func (c *Conn) writeError(code StatusCode, err error) {
c.writeClose(code, err.Error())
}
// Package wsjson provides helpers for reading and writing JSON messages.
package wsjson // import "github.com/coder/websocket/wsjson"
import (
"context"
"encoding/json"
"fmt"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/bpool"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/util"
)
// Read reads a JSON message from c into v.
// It will reuse buffers in between calls to avoid allocations.
func Read(ctx context.Context, c *websocket.Conn, v interface{}) error {
return read(ctx, c, v)
}
func read(ctx context.Context, c *websocket.Conn, v interface{}) (err error) {
defer errd.Wrap(&err, "failed to read JSON message")
_, r, err := c.Reader(ctx)
if err != nil {
return err
}
b := bpool.Get()
defer bpool.Put(b)
_, err = b.ReadFrom(r)
if err != nil {
return err
}
err = json.Unmarshal(b.Bytes(), v)
if err != nil {
c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON")
return fmt.Errorf("failed to unmarshal JSON: %w", err)
}
return nil
}
// Write writes the JSON message v to c.
// It will reuse buffers in between calls to avoid allocations.
func Write(ctx context.Context, c *websocket.Conn, v interface{}) error {
return write(ctx, c, v)
}
func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) {
defer errd.Wrap(&err, "failed to write JSON message")
// json.Marshal cannot reuse buffers between calls as it has to return
// a copy of the byte slice but Encoder does as it directly writes to w.
err = json.NewEncoder(util.WriterFunc(func(p []byte) (int, error) {
err := c.Write(ctx, websocket.MessageText, p)
if err != nil {
return 0, err
}
return len(p), nil
})).Encode(v)
if err != nil {
return fmt.Errorf("failed to marshal JSON: %w", err)
}
return nil
}