Commit 1580169e authored by Matthew Holt's avatar Matthew Holt

vendor: Update quic-go

parent 95514da9
...@@ -8,19 +8,20 @@ import ( ...@@ -8,19 +8,20 @@ import (
var bufferPool sync.Pool var bufferPool sync.Pool
func getPacketBuffer() []byte { func getPacketBuffer() *[]byte {
return bufferPool.Get().([]byte) return bufferPool.Get().(*[]byte)
} }
func putPacketBuffer(buf []byte) { func putPacketBuffer(buf *[]byte) {
if cap(buf) != int(protocol.MaxReceivePacketSize) { if cap(*buf) != int(protocol.MaxReceivePacketSize) {
panic("putPacketBuffer called with packet of wrong size!") panic("putPacketBuffer called with packet of wrong size!")
} }
bufferPool.Put(buf[:0]) bufferPool.Put(buf)
} }
func init() { func init() {
bufferPool.New = func() interface{} { bufferPool.New = func() interface{} {
return make([]byte, 0, protocol.MaxReceivePacketSize) b := make([]byte, 0, protocol.MaxReceivePacketSize)
return &b
} }
} }
...@@ -85,6 +85,14 @@ func Dial( ...@@ -85,6 +85,14 @@ func Dial(
} }
} }
// check that all versions are actually supported
if config != nil {
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
}
}
}
clientConfig := populateClientConfig(config) clientConfig := populateClientConfig(config)
c := &client{ c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr}, conn: &conn{pconn: pconn, currentAddr: remoteAddr},
...@@ -132,6 +140,18 @@ func populateClientConfig(config *Config) *Config { ...@@ -132,6 +140,18 @@ func populateClientConfig(config *Config) *Config {
if maxReceiveConnectionFlowControlWindow == 0 { if maxReceiveConnectionFlowControlWindow == 0 {
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
} }
maxIncomingStreams := config.MaxIncomingStreams
if maxIncomingStreams == 0 {
maxIncomingStreams = protocol.DefaultMaxIncomingStreams
} else if maxIncomingStreams < 0 {
maxIncomingStreams = 0
}
maxIncomingUniStreams := config.MaxIncomingUniStreams
if maxIncomingUniStreams == 0 {
maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
return &Config{ return &Config{
Versions: versions, Versions: versions,
...@@ -140,6 +160,8 @@ func populateClientConfig(config *Config) *Config { ...@@ -140,6 +160,8 @@ func populateClientConfig(config *Config) *Config {
RequestConnectionIDOmission: config.RequestConnectionIDOmission, RequestConnectionIDOmission: config.RequestConnectionIDOmission,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams,
KeepAlive: config.KeepAlive, KeepAlive: config.KeepAlive,
} }
} }
...@@ -171,9 +193,8 @@ func (c *client) dialTLS() error { ...@@ -171,9 +193,8 @@ func (c *client) dialTLS() error {
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
IdleTimeout: c.config.IdleTimeout, IdleTimeout: c.config.IdleTimeout,
OmitConnectionID: c.config.RequestConnectionIDOmission, OmitConnectionID: c.config.RequestConnectionIDOmission,
// TODO(#523): make these values configurable MaxBidiStreams: uint16(c.config.MaxIncomingStreams),
MaxBidiStreamID: protocol.MaxBidiStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient), MaxUniStreams: uint16(c.config.MaxIncomingUniStreams),
MaxUniStreamID: protocol.MaxUniStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient),
} }
csc := handshake.NewCryptoStreamConn(nil) csc := handshake.NewCryptoStreamConn(nil)
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version) extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version)
...@@ -245,7 +266,7 @@ func (c *client) listen() { ...@@ -245,7 +266,7 @@ func (c *client) listen() {
for { for {
var n int var n int
var addr net.Addr var addr net.Addr
data := getPacketBuffer() data := *getPacketBuffer()
data = data[:protocol.MaxReceivePacketSize] data = data[:protocol.MaxReceivePacketSize]
// The packet size should not exceed protocol.MaxReceivePacketSize bytes // The packet size should not exceed protocol.MaxReceivePacketSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable // If it does, we only read a truncated packet, which will then end up undecryptable
...@@ -347,6 +368,8 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { ...@@ -347,6 +368,8 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
} }
} }
utils.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if !ok { if !ok {
return qerr.InvalidVersion return qerr.InvalidVersion
......
...@@ -108,7 +108,9 @@ func (c *client) handleHeaderStream() { ...@@ -108,7 +108,9 @@ func (c *client) handleHeaderStream() {
for err == nil { for err == nil {
err = c.readResponse(h2framer, decoder) err = c.readResponse(h2framer, decoder)
} }
if quicErr, ok := err.(*qerr.QuicError); !ok || quicErr.ErrorCode != qerr.PeerGoingAway {
utils.Debugf("Error handling header stream: %s", err) utils.Debugf("Error handling header stream: %s", err)
}
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error()) c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
// stop all running request // stop all running request
close(c.headerErrored) close(c.headerErrored)
...@@ -202,6 +204,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { ...@@ -202,6 +204,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
bodySent = true bodySent = true
} }
ctx := req.Context()
for !(bodySent && receivedResponse) { for !(bodySent && receivedResponse) {
select { select {
case res = <-responseChan: case res = <-responseChan:
...@@ -214,8 +217,16 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { ...@@ -214,8 +217,16 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
case <-ctx.Done():
// error code 6 signals that stream was canceled
dataStream.CancelRead(6)
dataStream.CancelWrite(6)
c.mutex.Lock()
delete(c.responses, dataStream.StreamID())
c.mutex.Unlock()
return nil, ctx.Err()
case <-c.headerErrored: case <-c.headerErrored:
// an error occured on the header stream // an error occurred on the header stream
_ = c.CloseWithError(c.headerErr) _ = c.CloseWithError(c.headerErr)
return nil, c.headerErr return nil, c.headerErr
} }
......
...@@ -76,9 +76,8 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra ...@@ -76,9 +76,8 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra
if !validPseudoPath(path) { if !validPseudoPath(path) {
if req.URL.Opaque != "" { if req.URL.Opaque != "" {
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
} else {
return nil, fmt.Errorf("invalid request :path %q", orig)
} }
return nil, fmt.Errorf("invalid request :path %q", orig)
} }
} }
} }
......
...@@ -3,7 +3,6 @@ package h2quic ...@@ -3,7 +3,6 @@ package h2quic
import ( import (
"bytes" "bytes"
"errors" "errors"
"io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/textproto" "net/textproto"
...@@ -16,7 +15,7 @@ import ( ...@@ -16,7 +15,7 @@ import (
// copied from net/http2/transport.go // copied from net/http2/transport.go
var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit") var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
var noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil)) var noBody = ioutil.NopCloser(bytes.NewReader(nil))
// from the handleResponse function // from the handleResponse function
func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) { func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) {
...@@ -33,16 +32,7 @@ func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) { ...@@ -33,16 +32,7 @@ func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) {
return nil, errors.New("malformed non-numeric status pseudo header") return nil, errors.New("malformed non-numeric status pseudo header")
} }
if statusCode == 100 { // TODO: handle statusCode == 100
// TODO: handle this
// traceGot100Continue(cs.trace)
// if cs.on100 != nil {
// cs.on100() // forces any write delay timer to fire
// }
// cs.pastHeaders = false // do it all again
// return nil, nil
}
header := make(http.Header) header := make(http.Header)
res := &http.Response{ res := &http.Response{
...@@ -78,13 +68,7 @@ func setLength(res *http.Response, isHead, streamEnded bool) *http.Response { ...@@ -78,13 +68,7 @@ func setLength(res *http.Response, isHead, streamEnded bool) *http.Response {
if clens := res.Header["Content-Length"]; len(clens) == 1 { if clens := res.Header["Content-Length"]; len(clens) == 1 {
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
res.ContentLength = clen64 res.ContentLength = clen64
} else {
// TODO: care? unlike http/1, it won't mess up our framing, so it's
// more safe smuggling-wise to ignore.
} }
} else if len(clens) > 1 {
// TODO: care? unlike http/1, it won't mess up our framing, so it's
// more safe smuggling-wise to ignore.
} }
} }
return res return res
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
) )
// Connection is a UDP connection // Connection is a UDP connection
...@@ -43,6 +44,8 @@ func (d Direction) String() string { ...@@ -43,6 +44,8 @@ func (d Direction) String() string {
} }
} }
// Is says if one direction matches another direction.
// For example, incoming matches both incoming and both, but not outgoing.
func (d Direction) Is(dir Direction) bool { func (d Direction) Is(dir Direction) bool {
if d == DirectionBoth || dir == DirectionBoth { if d == DirectionBoth || dir == DirectionBoth {
return true return true
...@@ -131,12 +134,20 @@ func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*Qu ...@@ -131,12 +134,20 @@ func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*Qu
version: version, version: version,
} }
utils.Debugf("Starting UDP Proxy %s <-> %s", conn.LocalAddr(), raddr)
go p.runProxy() go p.runProxy()
return &p, nil return &p, nil
} }
// Close stops the UDP Proxy // Close stops the UDP Proxy
func (p *QuicProxy) Close() error { func (p *QuicProxy) Close() error {
p.mutex.Lock()
defer p.mutex.Unlock()
for _, c := range p.clientDict {
if err := c.ServerConn.Close(); err != nil {
return err
}
}
return p.conn.Close() return p.conn.Close()
} }
...@@ -189,19 +200,27 @@ func (p *QuicProxy) runProxy() error { ...@@ -189,19 +200,27 @@ func (p *QuicProxy) runProxy() error {
packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1) packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1)
if p.dropPacket(DirectionIncoming, packetCount) { if p.dropPacket(DirectionIncoming, packetCount) {
if utils.Debug() {
utils.Debugf("dropping incoming packet %d (%d bytes)", packetCount, n)
}
continue continue
} }
// Send the packet to the server // Send the packet to the server
delay := p.delayPacket(DirectionIncoming, packetCount) delay := p.delayPacket(DirectionIncoming, packetCount)
if delay != 0 { if delay != 0 {
if utils.Debug() {
utils.Debugf("delaying incoming packet %d (%d bytes) to %s by %s", packetCount, n, conn.ServerConn.RemoteAddr(), delay)
}
time.AfterFunc(delay, func() { time.AfterFunc(delay, func() {
// TODO: handle error // TODO: handle error
_, _ = conn.ServerConn.Write(raw) _, _ = conn.ServerConn.Write(raw)
}) })
} else { } else {
_, err := conn.ServerConn.Write(raw) if utils.Debug() {
if err != nil { utils.Debugf("forwarding incoming packet %d (%d bytes) to %s", packetCount, n, conn.ServerConn.RemoteAddr())
}
if _, err := conn.ServerConn.Write(raw); err != nil {
return err return err
} }
} }
...@@ -221,18 +240,26 @@ func (p *QuicProxy) runConnection(conn *connection) error { ...@@ -221,18 +240,26 @@ func (p *QuicProxy) runConnection(conn *connection) error {
packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1) packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
if p.dropPacket(DirectionOutgoing, packetCount) { if p.dropPacket(DirectionOutgoing, packetCount) {
if utils.Debug() {
utils.Debugf("dropping outgoing packet %d (%d bytes)", packetCount, n)
}
continue continue
} }
delay := p.delayPacket(DirectionOutgoing, packetCount) delay := p.delayPacket(DirectionOutgoing, packetCount)
if delay != 0 { if delay != 0 {
if utils.Debug() {
utils.Debugf("delaying outgoing packet %d (%d bytes) to %s by %s", packetCount, n, conn.ClientAddr, delay)
}
time.AfterFunc(delay, func() { time.AfterFunc(delay, func() {
// TODO: handle error // TODO: handle error
_, _ = p.conn.WriteToUDP(raw, conn.ClientAddr) _, _ = p.conn.WriteToUDP(raw, conn.ClientAddr)
}) })
} else { } else {
_, err := p.conn.WriteToUDP(raw, conn.ClientAddr) if utils.Debug() {
if err != nil { utils.Debugf("forwarding outgoing packet %d (%d bytes) to %s", packetCount, n, conn.ClientAddr)
}
if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
return err return err
} }
} }
......
...@@ -22,7 +22,9 @@ const ( ...@@ -22,7 +22,9 @@ const (
) )
var ( var (
// PRData contains dataLen bytes of pseudo-random data.
PRData = GeneratePRData(dataLen) PRData = GeneratePRData(dataLen)
// PRDataLong contains dataLenLong bytes of pseudo-random data.
PRDataLong = GeneratePRData(dataLenLong) PRDataLong = GeneratePRData(dataLenLong)
server *h2quic.Server server *h2quic.Server
...@@ -105,11 +107,13 @@ func StartQuicServer(versions []protocol.VersionNumber) { ...@@ -105,11 +107,13 @@ func StartQuicServer(versions []protocol.VersionNumber) {
}() }()
} }
// StopQuicServer stops the h2quic.Server.
func StopQuicServer() { func StopQuicServer() {
Expect(server.Close()).NotTo(HaveOccurred()) Expect(server.Close()).NotTo(HaveOccurred())
Eventually(stoppedServing).Should(BeClosed()) Eventually(stoppedServing).Should(BeClosed())
} }
// Port returns the UDP port of the QUIC server.
func Port() string { func Port() string {
return port return port
} }
...@@ -113,15 +113,23 @@ type StreamError interface { ...@@ -113,15 +113,23 @@ type StreamError interface {
// A Session is a QUIC connection between two peers. // A Session is a QUIC connection between two peers.
type Session interface { type Session interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available. // AcceptStream returns the next stream opened by the peer, blocking until one is available.
// Since stream 1 is reserved for the crypto stream, the first stream is either 2 (for a client) or 3 (for a server).
AcceptStream() (Stream, error) AcceptStream() (Stream, error)
// OpenStream opens a new QUIC stream, returning a special error when the peer's concurrent stream limit is reached. // AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
// New streams always have the smallest possible stream ID. AcceptUniStream() (ReceiveStream, error)
// TODO: Enable testing for the special error // OpenStream opens a new bidirectional QUIC stream.
// It returns a special error when the peer's concurrent stream limit is reached.
// TODO(#1152): Enable testing for the special error
OpenStream() (Stream, error) OpenStream() (Stream, error)
// OpenStreamSync opens a new QUIC stream, blocking until the peer's concurrent stream limit allows a new stream to be opened. // OpenStreamSync opens a new bidirectional QUIC stream.
// It always picks the smallest possible stream ID. // It blocks until the peer's concurrent stream limit allows a new stream to be opened.
OpenStreamSync() (Stream, error) OpenStreamSync() (Stream, error)
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
// It returns a special error when the peer's concurrent stream limit is reached.
// TODO(#1152): Enable testing for the special error
OpenUniStream() (SendStream, error)
// OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
// It blocks until the peer's concurrent stream limit allows a new stream to be opened.
OpenUniStreamSync() (SendStream, error)
// LocalAddr returns the local address. // LocalAddr returns the local address.
LocalAddr() net.Addr LocalAddr() net.Addr
// RemoteAddr returns the address of the peer. // RemoteAddr returns the address of the peer.
...@@ -166,6 +174,17 @@ type Config struct { ...@@ -166,6 +174,17 @@ type Config struct {
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data. // MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client. // If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
MaxReceiveConnectionFlowControlWindow uint64 MaxReceiveConnectionFlowControlWindow uint64
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any bidirectional streams.
// Values larger than 65535 (math.MaxUint16) are invalid.
MaxIncomingStreams int
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
// This value doesn't have any effect in Google QUIC.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any unidirectional streams.
// Values larger than 65535 (math.MaxUint16) are invalid.
MaxIncomingUniStreams int
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive. // KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
KeepAlive bool KeepAlive bool
} }
......
...@@ -10,15 +10,13 @@ import ( ...@@ -10,15 +10,13 @@ import (
// SentPacketHandler handles ACKs received for outgoing packets // SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface { type SentPacketHandler interface {
// SentPacket may modify the packet // SentPacket may modify the packet
SentPacket(packet *Packet) error SentPacket(packet *Packet)
SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber)
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
SetHandshakeComplete() SetHandshakeComplete()
// SendingAllowed says if a packet can be sent. // The SendMode determines if and what kind of packets can be sent.
// Sending packets might not be possible because: SendMode() SendMode
// * we're congestion limited
// * we're tracking the maximum number of sent packets
SendingAllowed() bool
// TimeUntilSend is the time when the next packet should be sent. // TimeUntilSend is the time when the next packet should be sent.
// It is used for pacing packets. // It is used for pacing packets.
TimeUntilSend() time.Time TimeUntilSend() time.Time
...@@ -32,10 +30,10 @@ type SentPacketHandler interface { ...@@ -32,10 +30,10 @@ type SentPacketHandler interface {
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
DequeuePacketForRetransmission() (packet *Packet) DequeuePacketForRetransmission() (packet *Packet)
GetLeastUnacked() protocol.PacketNumber GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen
GetAlarmTimeout() time.Time GetAlarmTimeout() time.Time
OnAlarm() OnAlarm() error
} }
// ReceivedPacketHandler handles ACKs needed to send for incoming packets // ReceivedPacketHandler handles ACKs needed to send for incoming packets
......
...@@ -11,12 +11,19 @@ import ( ...@@ -11,12 +11,19 @@ import (
// +gen linkedlist // +gen linkedlist
type Packet struct { type Packet struct {
PacketNumber protocol.PacketNumber PacketNumber protocol.PacketNumber
PacketType protocol.PacketType
Frames []wire.Frame Frames []wire.Frame
Length protocol.ByteCount Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel EncryptionLevel protocol.EncryptionLevel
largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK
sendTime time.Time sendTime time.Time
queuedForRetransmission bool
includedInBytesInFlight bool
retransmittedAs []protocol.PacketNumber
isRetransmission bool // we need a separate bool here because 0 is a valid packet number
retransmissionOf protocol.PacketNumber
} }
// GetFramesForRetransmission gets all the frames for retransmission // GetFramesForRetransmission gets all the frames for retransmission
......
...@@ -3,7 +3,9 @@ package ackhandler ...@@ -3,7 +3,9 @@ package ackhandler
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
) )
...@@ -15,6 +17,7 @@ type receivedPacketHandler struct { ...@@ -15,6 +17,7 @@ type receivedPacketHandler struct {
packetHistory *receivedPacketHistory packetHistory *receivedPacketHistory
ackSendDelay time.Duration ackSendDelay time.Duration
rttStats *congestion.RTTStats
packetsReceivedSinceLastAck int packetsReceivedSinceLastAck int
retransmittablePacketsReceivedSinceLastAck int retransmittablePacketsReceivedSinceLastAck int
...@@ -25,29 +28,54 @@ type receivedPacketHandler struct { ...@@ -25,29 +28,54 @@ type receivedPacketHandler struct {
version protocol.VersionNumber version protocol.VersionNumber
} }
const (
// maximum delay that can be applied to an ACK for a retransmittable packet
ackSendDelay = 25 * time.Millisecond
// initial maximum number of retransmittable packets received before sending an ack.
initialRetransmittablePacketsBeforeAck = 2
// number of retransmittable that an ACK is sent for
retransmittablePacketsBeforeAck = 10
// 1/5 RTT delay when doing ack decimation
ackDecimationDelay = 1.0 / 4
// 1/8 RTT delay when doing ack decimation
shortAckDecimationDelay = 1.0 / 8
// Minimum number of packets received before ack decimation is enabled.
// This intends to avoid the beginning of slow start, when CWNDs may be
// rapidly increasing.
minReceivedBeforeAckDecimation = 100
// Maximum number of packets to ack immediately after a missing packet for
// fast retransmission to kick in at the sender. This limit is created to
// reduce the number of acks sent that have no benefit for fast retransmission.
// Set to the number of nacks needed for fast retransmit plus one for protection
// against an ack loss
maxPacketsAfterNewMissing = 4
)
// NewReceivedPacketHandler creates a new receivedPacketHandler // NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHandler { func NewReceivedPacketHandler(rttStats *congestion.RTTStats, version protocol.VersionNumber) ReceivedPacketHandler {
return &receivedPacketHandler{ return &receivedPacketHandler{
packetHistory: newReceivedPacketHistory(), packetHistory: newReceivedPacketHistory(),
ackSendDelay: protocol.AckSendDelay, ackSendDelay: ackSendDelay,
rttStats: rttStats,
version: version, version: version,
} }
} }
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error { func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
if packetNumber < h.ignoreBelow {
return nil
}
isMissing := h.isMissing(packetNumber)
if packetNumber > h.largestObserved { if packetNumber > h.largestObserved {
h.largestObserved = packetNumber h.largestObserved = packetNumber
h.largestObservedReceivedTime = rcvTime h.largestObservedReceivedTime = rcvTime
} }
if packetNumber < h.ignoreBelow {
return nil
}
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil { if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
return err return err
} }
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck) h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck, isMissing)
return nil return nil
} }
...@@ -58,35 +86,68 @@ func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) { ...@@ -58,35 +86,68 @@ func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
h.packetHistory.DeleteBelow(p) h.packetHistory.DeleteBelow(p)
} }
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) { // isMissing says if a packet was reported missing in the last ACK.
h.packetsReceivedSinceLastAck++ func (h *receivedPacketHandler) isMissing(p protocol.PacketNumber) bool {
if h.lastAck == nil {
if shouldInstigateAck { return false
h.retransmittablePacketsReceivedSinceLastAck++
} }
return p < h.lastAck.LargestAcked && !h.lastAck.AcksPacket(p)
}
// always ack the first packet func (h *receivedPacketHandler) hasNewMissingPackets() bool {
if h.lastAck == nil { if h.lastAck == nil {
h.ackQueued = true return false
} }
highestRange := h.packetHistory.GetHighestAckRange()
return highestRange.First >= h.lastAck.LargestAcked && highestRange.Len() <= maxPacketsAfterNewMissing
}
// if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK // maybeQueueAck queues an ACK, if necessary.
// note that it cannot be a duplicate because they're already filtered out by ReceivedPacket() // It is implemented analogously to Chrome's QuicConnection::MaybeQueueAck()
if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked { // in ACK_DECIMATION_WITH_REORDERING mode.
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck, wasMissing bool) {
h.packetsReceivedSinceLastAck++
// always ack the first packet
if h.lastAck == nil {
h.ackQueued = true h.ackQueued = true
return
} }
// check if a new missing range above the previously was created // Send an ACK if this packet was reported missing in an ACK sent before.
if h.lastAck != nil && h.packetHistory.GetHighestAckRange().First > h.lastAck.LargestAcked { // Ack decimation with reordering relies on the timer to send an ACK, but if
// missing packets we reported in the previous ack, send an ACK immediately.
if wasMissing {
h.ackQueued = true h.ackQueued = true
} }
if !h.ackQueued && shouldInstigateAck { if !h.ackQueued && shouldInstigateAck {
if h.retransmittablePacketsReceivedSinceLastAck >= protocol.RetransmittablePacketsBeforeAck { h.retransmittablePacketsReceivedSinceLastAck++
if packetNumber > minReceivedBeforeAckDecimation {
// ack up to 10 packets at once
if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck {
h.ackQueued = true h.ackQueued = true
} else if h.ackAlarm.IsZero() {
// wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack
ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay)))
h.ackAlarm = rcvTime.Add(ackDelay)
}
} else { } else {
if h.ackAlarm.IsZero() { // send an ACK every 2 retransmittable packets
h.ackAlarm = rcvTime.Add(h.ackSendDelay) if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck {
h.ackQueued = true
} else if h.ackAlarm.IsZero() {
h.ackAlarm = rcvTime.Add(ackSendDelay)
}
}
// If there are new missing packets to report, set a short timer to send an ACK.
if h.hasNewMissingPackets() {
// wait the minimum of 1/8 min RTT and the existing ack time
ackDelay := float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay)
ackTime := rcvTime.Add(time.Duration(ackDelay))
if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) {
h.ackAlarm = ackTime
} }
} }
} }
...@@ -118,7 +179,6 @@ func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame { ...@@ -118,7 +179,6 @@ func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
h.ackQueued = false h.ackQueued = false
h.packetsReceivedSinceLastAck = 0 h.packetsReceivedSinceLastAck = 0
h.retransmittablePacketsReceivedSinceLastAck = 0 h.retransmittablePacketsReceivedSinceLastAck = 0
return ack return ack
} }
......
package ackhandler
import "fmt"
// The SendMode says what kind of packets can be sent.
type SendMode uint8
const (
// SendNone means that no packets should be sent
SendNone SendMode = iota
// SendAck means an ACK-only packet should be sent
SendAck
// SendRetransmission means that retransmissions should be sent
SendRetransmission
// SendAny packet should be sent
SendAny
)
func (s SendMode) String() string {
switch s {
case SendNone:
return "none"
case SendAck:
return "ack"
case SendRetransmission:
return "retransmission"
case SendAny:
return "any"
default:
return fmt.Sprintf("invalid send mode: %d", s)
}
}
package ackhandler
import (
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type sentPacketHistory struct {
packetList *PacketList
packetMap map[protocol.PacketNumber]*PacketElement
firstOutstanding *PacketElement
}
func newSentPacketHistory() *sentPacketHistory {
return &sentPacketHistory{
packetList: NewPacketList(),
packetMap: make(map[protocol.PacketNumber]*PacketElement),
}
}
func (h *sentPacketHistory) SentPacket(p *Packet) {
h.sentPacketImpl(p)
}
func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
el := h.packetList.PushBack(*p)
h.packetMap[p.PacketNumber] = el
if h.firstOutstanding == nil {
h.firstOutstanding = el
}
return el
}
func (h *sentPacketHistory) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
retransmission, ok := h.packetMap[retransmissionOf]
// The retransmitted packet is not present anymore.
// This can happen if it was acked in between dequeueing of the retransmission and sending.
// Just treat the retransmissions as normal packets.
// TODO: This won't happen if we clear packets queued for retransmission on new ACKs.
if !ok {
for _, packet := range packets {
h.sentPacketImpl(packet)
}
return
}
retransmission.Value.retransmittedAs = make([]protocol.PacketNumber, len(packets))
for i, packet := range packets {
retransmission.Value.retransmittedAs[i] = packet.PacketNumber
el := h.sentPacketImpl(packet)
el.Value.isRetransmission = true
el.Value.retransmissionOf = retransmissionOf
}
}
func (h *sentPacketHistory) GetPacket(p protocol.PacketNumber) *Packet {
if el, ok := h.packetMap[p]; ok {
return &el.Value
}
return nil
}
// Iterate iterates through all packets.
// The callback must not modify the history.
func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error {
cont := true
for el := h.packetList.Front(); cont && el != nil; el = el.Next() {
var err error
cont, err = cb(&el.Value)
if err != nil {
return err
}
}
return nil
}
// FirstOutStanding returns the first outstanding packet.
// It must not be modified (e.g. retransmitted).
// Use DequeueFirstPacketForRetransmission() to retransmit it.
func (h *sentPacketHistory) FirstOutstanding() *Packet {
if h.firstOutstanding == nil {
return nil
}
return &h.firstOutstanding.Value
}
// QueuePacketForRetransmission marks a packet for retransmission.
// A packet can only be queued once.
func (h *sentPacketHistory) QueuePacketForRetransmission(pn protocol.PacketNumber) (*Packet, error) {
el, ok := h.packetMap[pn]
if !ok {
return nil, fmt.Errorf("sent packet history: packet %d not found", pn)
}
if el.Value.queuedForRetransmission {
return nil, fmt.Errorf("sent packet history BUG: packet %d already queued for retransmission", pn)
}
el.Value.queuedForRetransmission = true
if el == h.firstOutstanding {
h.readjustFirstOutstanding()
}
return &el.Value, nil
}
// readjustFirstOutstanding readjusts the pointer to the first outstanding packet.
// This is necessary every time the first outstanding packet is deleted or retransmitted.
func (h *sentPacketHistory) readjustFirstOutstanding() {
el := h.firstOutstanding.Next()
for el != nil && el.Value.queuedForRetransmission {
el = el.Next()
}
h.firstOutstanding = el
}
func (h *sentPacketHistory) Len() int {
return len(h.packetMap)
}
func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
el, ok := h.packetMap[p]
if !ok {
return fmt.Errorf("packet %d not found in sent packet history", p)
}
if el == h.firstOutstanding {
h.readjustFirstOutstanding()
}
h.packetList.Remove(el)
delete(h.packetMap, p)
return nil
}
...@@ -55,28 +55,28 @@ func (c *certChain) GetLeafCert(sni string) ([]byte, error) { ...@@ -55,28 +55,28 @@ func (c *certChain) GetLeafCert(sni string) ([]byte, error) {
return cert.Certificate[0], nil return cert.Certificate[0], nil
} }
func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) { func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
c := cc.config conf := c.config
c, err := maybeGetConfigForClient(c, sni) conf, err := maybeGetConfigForClient(conf, sni)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// The rest of this function is mostly copied from crypto/tls.getCertificate // The rest of this function is mostly copied from crypto/tls.getCertificate
if c.GetCertificate != nil { if conf.GetCertificate != nil {
cert, err := c.GetCertificate(&tls.ClientHelloInfo{ServerName: sni}) cert, err := conf.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
if cert != nil || err != nil { if cert != nil || err != nil {
return cert, err return cert, err
} }
} }
if len(c.Certificates) == 0 { if len(conf.Certificates) == 0 {
return nil, errNoMatchingCertificate return nil, errNoMatchingCertificate
} }
if len(c.Certificates) == 1 || c.NameToCertificate == nil { if len(conf.Certificates) == 1 || conf.NameToCertificate == nil {
// There's only one choice, so no point doing any work. // There's only one choice, so no point doing any work.
return &c.Certificates[0], nil return &conf.Certificates[0], nil
} }
name := strings.ToLower(sni) name := strings.ToLower(sni)
...@@ -84,7 +84,7 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) { ...@@ -84,7 +84,7 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
name = name[:len(name)-1] name = name[:len(name)-1]
} }
if cert, ok := c.NameToCertificate[name]; ok { if cert, ok := conf.NameToCertificate[name]; ok {
return cert, nil return cert, nil
} }
...@@ -94,13 +94,13 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) { ...@@ -94,13 +94,13 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
for i := range labels { for i := range labels {
labels[i] = "*" labels[i] = "*"
candidate := strings.Join(labels, ".") candidate := strings.Join(labels, ".")
if cert, ok := c.NameToCertificate[candidate]; ok { if cert, ok := conf.NameToCertificate[candidate]; ok {
return cert, nil return cert, nil
} }
} }
// If nothing matches, return the first certificate. // If nothing matches, return the first certificate.
return &c.Certificates[0], nil return &conf.Certificates[0], nil
} }
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) { func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
......
package crypto package crypto
import ( import (
"crypto"
"encoding/binary"
"github.com/bifurcation/mint" "github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
) )
const ( const (
clientExporterLabel = "EXPORTER-QUIC client 1-RTT Secret" clientExporterLabel = "EXPORTER-QUIC client 1rtt"
serverExporterLabel = "EXPORTER-QUIC server 1-RTT Secret" serverExporterLabel = "EXPORTER-QUIC server 1rtt"
) )
// A TLSExporter gets the negotiated ciphersuite and computes exporter // A TLSExporter gets the negotiated ciphersuite and computes exporter
...@@ -16,6 +19,16 @@ type TLSExporter interface { ...@@ -16,6 +19,16 @@ type TLSExporter interface {
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
} }
func qhkdfExpand(secret []byte, label string, length int) []byte {
// The last byte should be 0x0.
// Since Go initializes the slice to 0, we don't need to set it explicitly.
qlabel := make([]byte, 2+1+5+len(label)+1)
binary.BigEndian.PutUint16(qlabel[0:2], uint16(length))
qlabel[2] = uint8(5 + len(label))
copy(qlabel[3:], []byte("QUIC "+label))
return mint.HkdfExpand(crypto.SHA256, secret, qlabel, length)
}
// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance // DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance
func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) { func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) {
var myLabel, otherLabel string var myLabel, otherLabel string
...@@ -43,7 +56,7 @@ func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) ...@@ -43,7 +56,7 @@ func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
key = mint.HkdfExpandLabel(cs.Hash, secret, "key", nil, cs.KeyLen) key = qhkdfExpand(secret, "key", cs.KeyLen)
iv = mint.HkdfExpandLabel(cs.Hash, secret, "iv", nil, cs.IvLen) iv = qhkdfExpand(secret, "iv", cs.IvLen)
return key, iv, nil return key, iv, nil
} }
...@@ -31,14 +31,14 @@ func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspec ...@@ -31,14 +31,14 @@ func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspec
func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) { func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
connID := make([]byte, 8) connID := make([]byte, 8)
binary.BigEndian.PutUint64(connID, uint64(connectionID)) binary.BigEndian.PutUint64(connID, uint64(connectionID))
cleartextSecret := mint.HkdfExtract(crypto.SHA256, []byte(quicVersion1Salt), connID) handshakeSecret := mint.HkdfExtract(crypto.SHA256, quicVersion1Salt, connID)
clientSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC client cleartext Secret", []byte{}, crypto.SHA256.Size()) clientSecret = qhkdfExpand(handshakeSecret, "client hs", crypto.SHA256.Size())
serverSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC server cleartext Secret", []byte{}, crypto.SHA256.Size()) serverSecret = qhkdfExpand(handshakeSecret, "server hs", crypto.SHA256.Size())
return return
} }
func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) { func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) {
key = mint.HkdfExpandLabel(crypto.SHA256, secret, "key", nil, 16) key = qhkdfExpand(secret, "key", 16)
iv = mint.HkdfExpandLabel(crypto.SHA256, secret, "iv", nil, 12) iv = qhkdfExpand(secret, "iv", 12)
return return
} }
...@@ -7,6 +7,9 @@ import ( ...@@ -7,6 +7,9 @@ import (
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
// A CookieHandler generates and validates cookies.
// The cookie is sent in the TLS Retry.
// By including the cookie in its ClientHello, a client can proof ownership of its source address.
type CookieHandler struct { type CookieHandler struct {
callback func(net.Addr, *Cookie) bool callback func(net.Addr, *Cookie) bool
...@@ -15,6 +18,7 @@ type CookieHandler struct { ...@@ -15,6 +18,7 @@ type CookieHandler struct {
var _ mint.CookieHandler = &CookieHandler{} var _ mint.CookieHandler = &CookieHandler{}
// NewCookieHandler creates a new CookieHandler.
func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, error) { func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, error) {
cookieGenerator, err := NewCookieGenerator() cookieGenerator, err := NewCookieGenerator()
if err != nil { if err != nil {
...@@ -26,6 +30,7 @@ func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, er ...@@ -26,6 +30,7 @@ func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, er
}, nil }, nil
} }
// Generate a new cookie for a mint connection.
func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) { func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
if h.callback(conn.RemoteAddr(), nil) { if h.callback(conn.RemoteAddr(), nil) {
return nil, nil return nil, nil
...@@ -33,6 +38,7 @@ func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) { ...@@ -33,6 +38,7 @@ func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
return h.cookieGenerator.NewToken(conn.RemoteAddr()) return h.cookieGenerator.NewToken(conn.RemoteAddr())
} }
// Validate a cookie.
func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool { func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool {
data, err := h.cookieGenerator.DecodeToken(token) data, err := h.cookieGenerator.DecodeToken(token)
if err != nil { if err != nil {
......
...@@ -429,7 +429,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T ...@@ -429,7 +429,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
replyMap := h.params.getHelloMap() replyMap := h.params.getHelloMap()
// add crypto parameters // add crypto parameters
verTag := &bytes.Buffer{} verTag := &bytes.Buffer{}
for _, v := range protocol.GetGreasedVersions(h.supportedVersions) { for _, v := range h.supportedVersions {
utils.BigEndian.WriteUint32(verTag, uint32(v)) utils.BigEndian.WriteUint32(verTag, uint32(v))
} }
replyMap[TagPUBS] = ephermalKex.PublicKey() replyMap[TagPUBS] = ephermalKex.PublicKey()
......
...@@ -84,7 +84,7 @@ func (h HandshakeMessage) Write(b *bytes.Buffer) { ...@@ -84,7 +84,7 @@ func (h HandshakeMessage) Write(b *bytes.Buffer) {
offset := uint32(0) offset := uint32(0)
for i, t := range h.getTagsSorted() { for i, t := range h.getTagsSorted() {
v := data[Tag(t)] v := data[t]
b.Write(v) b.Write(v)
offset += uint32(len(v)) offset += uint32(len(v))
binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t)) binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t))
...@@ -111,8 +111,7 @@ func (h *HandshakeMessage) getTagsSorted() []Tag { ...@@ -111,8 +111,7 @@ func (h *HandshakeMessage) getTagsSorted() []Tag {
func (h HandshakeMessage) String() string { func (h HandshakeMessage) String() string {
var pad string var pad string
res := tagToString(h.Tag) + ":\n" res := tagToString(h.Tag) + ":\n"
for _, t := range h.getTagsSorted() { for _, tag := range h.getTagsSorted() {
tag := Tag(t)
if tag == TagPAD { if tag == TagPAD {
pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(tag), len(h.Data[tag])) pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(tag), len(h.Data[tag]))
} else { } else {
......
...@@ -102,32 +102,37 @@ func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error { ...@@ -102,32 +102,37 @@ func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS") return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
} }
var pubs_kexs []struct{Length uint32; Value []byte} var pubsKexs []struct {
var last_len uint32 Length uint32
Value []byte
for i := 0; i < len(pubs)-3; i += int(last_len)+3 { }
var lastLen uint32
for i := 0; i < len(pubs)-3; i += int(lastLen) + 3 {
// the PUBS value is always prepended by 3 byte little endian length field // the PUBS value is always prepended by 3 byte little endian length field
err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &last_len); err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &lastLen)
if err != nil { if err != nil {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS not decodable") return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS not decodable")
} }
if last_len == 0 { if lastLen == 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS") return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
} }
if i+3+int(last_len) > len(pubs) { if i+3+int(lastLen) > len(pubs) {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS") return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
} }
pubs_kexs = append(pubs_kexs, struct{Length uint32; Value []byte}{last_len, pubs[i+3:i+3+int(last_len)]}) pubsKexs = append(pubsKexs, struct {
Length uint32
Value []byte
}{lastLen, pubs[i+3 : i+3+int(lastLen)]})
} }
if c255Foundat >= len(pubs_kexs) { if c255Foundat >= len(pubsKexs) {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS not in PUBS") return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS not in PUBS")
} }
if pubs_kexs[c255Foundat].Length != 32 { if pubsKexs[c255Foundat].Length != 32 {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS") return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
} }
...@@ -137,8 +142,7 @@ func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error { ...@@ -137,8 +142,7 @@ func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
return err return err
} }
s.sharedSecret, err = s.kex.CalculateSharedKey(pubsKexs[c255Foundat].Value)
s.sharedSecret, err = s.kex.CalculateSharedKey(pubs_kexs[c255Foundat].Value)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -11,12 +11,12 @@ const quicTLSExtensionType = 26 ...@@ -11,12 +11,12 @@ const quicTLSExtensionType = 26
const ( const (
initialMaxStreamDataParameterID transportParameterID = 0x0 initialMaxStreamDataParameterID transportParameterID = 0x0
initialMaxDataParameterID transportParameterID = 0x1 initialMaxDataParameterID transportParameterID = 0x1
initialMaxStreamIDBiDiParameterID transportParameterID = 0x2 initialMaxStreamsBiDiParameterID transportParameterID = 0x2
idleTimeoutParameterID transportParameterID = 0x3 idleTimeoutParameterID transportParameterID = 0x3
omitConnectionIDParameterID transportParameterID = 0x4 omitConnectionIDParameterID transportParameterID = 0x4
maxPacketSizeParameterID transportParameterID = 0x5 maxPacketSizeParameterID transportParameterID = 0x5
statelessResetTokenParameterID transportParameterID = 0x6 statelessResetTokenParameterID transportParameterID = 0x6
initialMaxStreamIDUniParameterID transportParameterID = 0x8 initialMaxStreamsUniParameterID transportParameterID = 0x8
) )
type transportParameter struct { type transportParameter struct {
......
...@@ -3,13 +3,13 @@ package handshake ...@@ -3,13 +3,13 @@ package handshake
import ( import (
"errors" "errors"
"fmt" "fmt"
"math"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/bifurcation/mint" "github.com/bifurcation/mint"
"github.com/bifurcation/mint/syntax" "github.com/bifurcation/mint/syntax"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
) )
type extensionHandlerClient struct { type extensionHandlerClient struct {
...@@ -31,7 +31,10 @@ func NewExtensionHandlerClient( ...@@ -31,7 +31,10 @@ func NewExtensionHandlerClient(
supportedVersions []protocol.VersionNumber, supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber, version protocol.VersionNumber,
) TLSExtensionHandler { ) TLSExtensionHandler {
paramsChan := make(chan TransportParameters, 1) // The client reads the transport parameters from the Encrypted Extensions message.
// The paramsChan is used in the session's run loop's select statement.
// We have to use an unbuffered channel here to make sure that the session actually processes the transport parameters immediately.
paramsChan := make(chan TransportParameters)
return &extensionHandlerClient{ return &extensionHandlerClient{
ourParams: params, ourParams: params,
paramsChan: paramsChan, paramsChan: paramsChan,
...@@ -46,6 +49,7 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi ...@@ -46,6 +49,7 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi
return nil return nil
} }
utils.Debugf("Sending Transport Parameters: %s", h.ourParams)
data, err := syntax.Marshal(clientHelloTransportParameters{ data, err := syntax.Marshal(clientHelloTransportParameters{
InitialVersion: uint32(h.initialVersion), InitialVersion: uint32(h.initialVersion),
Parameters: h.ourParams.getTransportParameters(), Parameters: h.ourParams.getTransportParameters(),
...@@ -63,17 +67,12 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte ...@@ -63,17 +67,12 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte
return err return err
} }
if hType != mint.HandshakeTypeEncryptedExtensions && hType != mint.HandshakeTypeNewSessionTicket { if hType != mint.HandshakeTypeEncryptedExtensions {
if found { if found {
return fmt.Errorf("Unexpected QUIC extension in handshake message %d", hType) return fmt.Errorf("Unexpected QUIC extension in handshake message %d", hType)
} }
return nil return nil
} }
if hType == mint.HandshakeTypeNewSessionTicket {
// the extension it's optional in the NewSessionTicket message
// TODO: handle this
return nil
}
// hType == mint.HandshakeTypeEncryptedExtensions // hType == mint.HandshakeTypeEncryptedExtensions
if !found { if !found {
...@@ -119,12 +118,11 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte ...@@ -119,12 +118,11 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte
// TODO: return the right error here // TODO: return the right error here
return errors.New("server didn't sent stateless_reset_token") return errors.New("server didn't sent stateless_reset_token")
} }
params, err := readTransportParamters(eetp.Parameters) params, err := readTransportParameters(eetp.Parameters)
if err != nil { if err != nil {
return err return err
} }
// TODO(#878): remove this when implementing the MAX_STREAM_ID frame utils.Debugf("Received Transport Parameters: %s", params)
params.MaxStreams = math.MaxUint32
h.paramsChan <- *params h.paramsChan <- *params
return nil return nil
} }
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"github.com/bifurcation/mint" "github.com/bifurcation/mint"
"github.com/bifurcation/mint/syntax" "github.com/bifurcation/mint/syntax"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
) )
type extensionHandlerServer struct { type extensionHandlerServer struct {
...@@ -29,6 +30,8 @@ func NewExtensionHandlerServer( ...@@ -29,6 +30,8 @@ func NewExtensionHandlerServer(
supportedVersions []protocol.VersionNumber, supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber, version protocol.VersionNumber,
) TLSExtensionHandler { ) TLSExtensionHandler {
// Processing the ClientHello is performed statelessly (and from a single go-routine).
// Therefore, we have to use a buffered chan to pass the transport parameters to that go routine.
paramsChan := make(chan TransportParameters, 1) paramsChan := make(chan TransportParameters, 1)
return &extensionHandlerServer{ return &extensionHandlerServer{
ourParams: params, ourParams: params,
...@@ -53,6 +56,7 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi ...@@ -53,6 +56,7 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi
for i, v := range supportedVersions { for i, v := range supportedVersions {
versions[i] = uint32(v) versions[i] = uint32(v)
} }
utils.Debugf("Sending Transport Parameters: %s", h.ourParams)
data, err := syntax.Marshal(encryptedExtensionsTransportParameters{ data, err := syntax.Marshal(encryptedExtensionsTransportParameters{
NegotiatedVersion: uint32(h.version), NegotiatedVersion: uint32(h.version),
SupportedVersions: versions, SupportedVersions: versions,
...@@ -100,10 +104,11 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte ...@@ -100,10 +104,11 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte
return errors.New("client sent a stateless reset token") return errors.New("client sent a stateless reset token")
} }
} }
params, err := readTransportParamters(chtp.Parameters) params, err := readTransportParameters(chtp.Parameters)
if err != nil { if err != nil {
return err return err
} }
utils.Debugf("Received Transport Parameters: %s", params)
h.paramsChan <- *params h.paramsChan <- *params
return nil return nil
} }
......
...@@ -20,8 +20,10 @@ type TransportParameters struct { ...@@ -20,8 +20,10 @@ type TransportParameters struct {
StreamFlowControlWindow protocol.ByteCount StreamFlowControlWindow protocol.ByteCount
ConnectionFlowControlWindow protocol.ByteCount ConnectionFlowControlWindow protocol.ByteCount
MaxBidiStreamID protocol.StreamID // only used for IETF QUIC MaxPacketSize protocol.ByteCount
MaxUniStreamID protocol.StreamID // only used for IETF QUIC
MaxUniStreams uint16 // only used for IETF QUIC
MaxBidiStreams uint16 // only used for IETF QUIC
MaxStreams uint32 // only used for gQUIC MaxStreams uint32 // only used for gQUIC
OmitConnectionID bool OmitConnectionID bool
...@@ -93,7 +95,7 @@ func (p *TransportParameters) getHelloMap() map[Tag][]byte { ...@@ -93,7 +95,7 @@ func (p *TransportParameters) getHelloMap() map[Tag][]byte {
} }
// readTransportParameters reads the transport parameters sent in the QUIC TLS extension // readTransportParameters reads the transport parameters sent in the QUIC TLS extension
func readTransportParamters(paramsList []transportParameter) (*TransportParameters, error) { func readTransportParameters(paramsList []transportParameter) (*TransportParameters, error) {
params := &TransportParameters{} params := &TransportParameters{}
var foundInitialMaxStreamData bool var foundInitialMaxStreamData bool
...@@ -114,18 +116,16 @@ func readTransportParamters(paramsList []transportParameter) (*TransportParamete ...@@ -114,18 +116,16 @@ func readTransportParamters(paramsList []transportParameter) (*TransportParamete
return nil, fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", len(p.Value)) return nil, fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", len(p.Value))
} }
params.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value)) params.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value))
case initialMaxStreamIDBiDiParameterID: case initialMaxStreamsBiDiParameterID:
if len(p.Value) != 4 { if len(p.Value) != 2 {
return nil, fmt.Errorf("wrong length for initial_max_stream_id_bidi: %d (expected 4)", len(p.Value)) return nil, fmt.Errorf("wrong length for initial_max_stream_id_bidi: %d (expected 2)", len(p.Value))
} }
// TODO(#1154): validate the stream ID params.MaxBidiStreams = binary.BigEndian.Uint16(p.Value)
params.MaxBidiStreamID = protocol.StreamID(binary.BigEndian.Uint32(p.Value)) case initialMaxStreamsUniParameterID:
case initialMaxStreamIDUniParameterID: if len(p.Value) != 2 {
if len(p.Value) != 4 { return nil, fmt.Errorf("wrong length for initial_max_stream_id_uni: %d (expected 2)", len(p.Value))
return nil, fmt.Errorf("wrong length for initial_max_stream_id_uni: %d (expected 4)", len(p.Value))
} }
// TODO(#1154): validate the stream ID params.MaxUniStreams = binary.BigEndian.Uint16(p.Value)
params.MaxUniStreamID = protocol.StreamID(binary.BigEndian.Uint32(p.Value))
case idleTimeoutParameterID: case idleTimeoutParameterID:
foundIdleTimeout = true foundIdleTimeout = true
if len(p.Value) != 2 { if len(p.Value) != 2 {
...@@ -137,6 +137,15 @@ func readTransportParamters(paramsList []transportParameter) (*TransportParamete ...@@ -137,6 +137,15 @@ func readTransportParamters(paramsList []transportParameter) (*TransportParamete
return nil, fmt.Errorf("wrong length for omit_connection_id: %d (expected empty)", len(p.Value)) return nil, fmt.Errorf("wrong length for omit_connection_id: %d (expected empty)", len(p.Value))
} }
params.OmitConnectionID = true params.OmitConnectionID = true
case maxPacketSizeParameterID:
if len(p.Value) != 2 {
return nil, fmt.Errorf("wrong length for max_packet_size: %d (expected 2)", len(p.Value))
}
maxPacketSize := protocol.ByteCount(binary.BigEndian.Uint16(p.Value))
if maxPacketSize < 1200 {
return nil, fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", maxPacketSize)
}
params.MaxPacketSize = maxPacketSize
} }
} }
...@@ -153,10 +162,10 @@ func (p *TransportParameters) getTransportParameters() []transportParameter { ...@@ -153,10 +162,10 @@ func (p *TransportParameters) getTransportParameters() []transportParameter {
binary.BigEndian.PutUint32(initialMaxStreamData, uint32(p.StreamFlowControlWindow)) binary.BigEndian.PutUint32(initialMaxStreamData, uint32(p.StreamFlowControlWindow))
initialMaxData := make([]byte, 4) initialMaxData := make([]byte, 4)
binary.BigEndian.PutUint32(initialMaxData, uint32(p.ConnectionFlowControlWindow)) binary.BigEndian.PutUint32(initialMaxData, uint32(p.ConnectionFlowControlWindow))
initialMaxBidiStreamID := make([]byte, 4) initialMaxBidiStreamID := make([]byte, 2)
binary.BigEndian.PutUint32(initialMaxBidiStreamID, uint32(p.MaxBidiStreamID)) binary.BigEndian.PutUint16(initialMaxBidiStreamID, p.MaxBidiStreams)
initialMaxUniStreamID := make([]byte, 4) initialMaxUniStreamID := make([]byte, 2)
binary.BigEndian.PutUint32(initialMaxUniStreamID, uint32(p.MaxUniStreamID)) binary.BigEndian.PutUint16(initialMaxUniStreamID, p.MaxUniStreams)
idleTimeout := make([]byte, 2) idleTimeout := make([]byte, 2)
binary.BigEndian.PutUint16(idleTimeout, uint16(p.IdleTimeout/time.Second)) binary.BigEndian.PutUint16(idleTimeout, uint16(p.IdleTimeout/time.Second))
maxPacketSize := make([]byte, 2) maxPacketSize := make([]byte, 2)
...@@ -164,8 +173,8 @@ func (p *TransportParameters) getTransportParameters() []transportParameter { ...@@ -164,8 +173,8 @@ func (p *TransportParameters) getTransportParameters() []transportParameter {
params := []transportParameter{ params := []transportParameter{
{initialMaxStreamDataParameterID, initialMaxStreamData}, {initialMaxStreamDataParameterID, initialMaxStreamData},
{initialMaxDataParameterID, initialMaxData}, {initialMaxDataParameterID, initialMaxData},
{initialMaxStreamIDBiDiParameterID, initialMaxBidiStreamID}, {initialMaxStreamsBiDiParameterID, initialMaxBidiStreamID},
{initialMaxStreamIDUniParameterID, initialMaxUniStreamID}, {initialMaxStreamsUniParameterID, initialMaxUniStreamID},
{idleTimeoutParameterID, idleTimeout}, {idleTimeoutParameterID, idleTimeout},
{maxPacketSizeParameterID, maxPacketSize}, {maxPacketSizeParameterID, maxPacketSize},
} }
...@@ -174,3 +183,9 @@ func (p *TransportParameters) getTransportParameters() []transportParameter { ...@@ -174,3 +183,9 @@ func (p *TransportParameters) getTransportParameters() []transportParameter {
} }
return params return params
} }
// String returns a string representation, intended for logging.
// It should only used for IETF QUIC.
func (p *TransportParameters) String() string {
return fmt.Sprintf("&handshake.TransportParameters{StreamFlowControlWindow: %#x, ConnectionFlowControlWindow: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, OmitConnectionID: %t, IdleTimeout: %s}", p.StreamFlowControlWindow, p.ConnectionFlowControlWindow, p.MaxBidiStreams, p.MaxUniStreams, p.OmitConnectionID, p.IdleTimeout)
}
...@@ -61,18 +61,6 @@ func (mr *MockSentPacketHandlerMockRecorder) GetAlarmTimeout() *gomock.Call { ...@@ -61,18 +61,6 @@ func (mr *MockSentPacketHandlerMockRecorder) GetAlarmTimeout() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAlarmTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).GetAlarmTimeout)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAlarmTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).GetAlarmTimeout))
} }
// GetLeastUnacked mocks base method
func (m *MockSentPacketHandler) GetLeastUnacked() protocol.PacketNumber {
ret := m.ctrl.Call(m, "GetLeastUnacked")
ret0, _ := ret[0].(protocol.PacketNumber)
return ret0
}
// GetLeastUnacked indicates an expected call of GetLeastUnacked
func (mr *MockSentPacketHandlerMockRecorder) GetLeastUnacked() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeastUnacked", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLeastUnacked))
}
// GetLowestPacketNotConfirmedAcked mocks base method // GetLowestPacketNotConfirmedAcked mocks base method
func (m *MockSentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber { func (m *MockSentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber {
ret := m.ctrl.Call(m, "GetLowestPacketNotConfirmedAcked") ret := m.ctrl.Call(m, "GetLowestPacketNotConfirmedAcked")
...@@ -85,6 +73,18 @@ func (mr *MockSentPacketHandlerMockRecorder) GetLowestPacketNotConfirmedAcked() ...@@ -85,6 +73,18 @@ func (mr *MockSentPacketHandlerMockRecorder) GetLowestPacketNotConfirmedAcked()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLowestPacketNotConfirmedAcked)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLowestPacketNotConfirmedAcked))
} }
// GetPacketNumberLen mocks base method
func (m *MockSentPacketHandler) GetPacketNumberLen(arg0 protocol.PacketNumber) protocol.PacketNumberLen {
ret := m.ctrl.Call(m, "GetPacketNumberLen", arg0)
ret0, _ := ret[0].(protocol.PacketNumberLen)
return ret0
}
// GetPacketNumberLen indicates an expected call of GetPacketNumberLen
func (mr *MockSentPacketHandlerMockRecorder) GetPacketNumberLen(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPacketNumberLen", reflect.TypeOf((*MockSentPacketHandler)(nil).GetPacketNumberLen), arg0)
}
// GetStopWaitingFrame mocks base method // GetStopWaitingFrame mocks base method
func (m *MockSentPacketHandler) GetStopWaitingFrame(arg0 bool) *wire.StopWaitingFrame { func (m *MockSentPacketHandler) GetStopWaitingFrame(arg0 bool) *wire.StopWaitingFrame {
ret := m.ctrl.Call(m, "GetStopWaitingFrame", arg0) ret := m.ctrl.Call(m, "GetStopWaitingFrame", arg0)
...@@ -98,8 +98,10 @@ func (mr *MockSentPacketHandlerMockRecorder) GetStopWaitingFrame(arg0 interface{ ...@@ -98,8 +98,10 @@ func (mr *MockSentPacketHandlerMockRecorder) GetStopWaitingFrame(arg0 interface{
} }
// OnAlarm mocks base method // OnAlarm mocks base method
func (m *MockSentPacketHandler) OnAlarm() { func (m *MockSentPacketHandler) OnAlarm() error {
m.ctrl.Call(m, "OnAlarm") ret := m.ctrl.Call(m, "OnAlarm")
ret0, _ := ret[0].(error)
return ret0
} }
// OnAlarm indicates an expected call of OnAlarm // OnAlarm indicates an expected call of OnAlarm
...@@ -119,23 +121,21 @@ func (mr *MockSentPacketHandlerMockRecorder) ReceivedAck(arg0, arg1, arg2, arg3 ...@@ -119,23 +121,21 @@ func (mr *MockSentPacketHandlerMockRecorder) ReceivedAck(arg0, arg1, arg2, arg3
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedAck", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedAck), arg0, arg1, arg2, arg3) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedAck", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedAck), arg0, arg1, arg2, arg3)
} }
// SendingAllowed mocks base method // SendMode mocks base method
func (m *MockSentPacketHandler) SendingAllowed() bool { func (m *MockSentPacketHandler) SendMode() ackhandler.SendMode {
ret := m.ctrl.Call(m, "SendingAllowed") ret := m.ctrl.Call(m, "SendMode")
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(ackhandler.SendMode)
return ret0 return ret0
} }
// SendingAllowed indicates an expected call of SendingAllowed // SendMode indicates an expected call of SendMode
func (mr *MockSentPacketHandlerMockRecorder) SendingAllowed() *gomock.Call { func (mr *MockSentPacketHandlerMockRecorder) SendMode() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendingAllowed", reflect.TypeOf((*MockSentPacketHandler)(nil).SendingAllowed)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMode", reflect.TypeOf((*MockSentPacketHandler)(nil).SendMode))
} }
// SentPacket mocks base method // SentPacket mocks base method
func (m *MockSentPacketHandler) SentPacket(arg0 *ackhandler.Packet) error { func (m *MockSentPacketHandler) SentPacket(arg0 *ackhandler.Packet) {
ret := m.ctrl.Call(m, "SentPacket", arg0) m.ctrl.Call(m, "SentPacket", arg0)
ret0, _ := ret[0].(error)
return ret0
} }
// SentPacket indicates an expected call of SentPacket // SentPacket indicates an expected call of SentPacket
...@@ -143,6 +143,16 @@ func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0 interface{}) *gomoc ...@@ -143,6 +143,16 @@ func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0 interface{}) *gomoc
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0)
} }
// SentPacketsAsRetransmission mocks base method
func (m *MockSentPacketHandler) SentPacketsAsRetransmission(arg0 []*ackhandler.Packet, arg1 protocol.PacketNumber) {
m.ctrl.Call(m, "SentPacketsAsRetransmission", arg0, arg1)
}
// SentPacketsAsRetransmission indicates an expected call of SentPacketsAsRetransmission
func (mr *MockSentPacketHandlerMockRecorder) SentPacketsAsRetransmission(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacketsAsRetransmission", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacketsAsRetransmission), arg0, arg1)
}
// SetHandshakeComplete mocks base method // SetHandshakeComplete mocks base method
func (m *MockSentPacketHandler) SetHandshakeComplete() { func (m *MockSentPacketHandler) SetHandshakeComplete() {
m.ctrl.Call(m, "SetHandshakeComplete") m.ctrl.Call(m, "SetHandshakeComplete")
......
...@@ -27,14 +27,14 @@ const ( ...@@ -27,14 +27,14 @@ const (
type PacketType uint8 type PacketType uint8
const ( const (
// PacketTypeInitial is the packet type of a Initial packet // PacketTypeInitial is the packet type of an Initial packet
PacketTypeInitial PacketType = 2 PacketTypeInitial PacketType = 0x7f
// PacketTypeRetry is the packet type of a Retry packet // PacketTypeRetry is the packet type of a Retry packet
PacketTypeRetry PacketType = 3 PacketTypeRetry PacketType = 0x7e
// PacketTypeHandshake is the packet type of a Cleartext packet // PacketTypeHandshake is the packet type of a Handshake packet
PacketTypeHandshake PacketType = 4 PacketTypeHandshake PacketType = 0x7d
// PacketType0RTT is the packet type of a 0-RTT packet // PacketType0RTT is the packet type of a 0-RTT packet
PacketType0RTT PacketType = 5 PacketType0RTT PacketType = 0x7c
) )
func (t PacketType) String() string { func (t PacketType) String() string {
...@@ -77,11 +77,11 @@ const DefaultTCPMSS ByteCount = 1460 ...@@ -77,11 +77,11 @@ const DefaultTCPMSS ByteCount = 1460
// MinClientHelloSize is the minimum size the server expects an inchoate CHLO to have (in gQUIC) // MinClientHelloSize is the minimum size the server expects an inchoate CHLO to have (in gQUIC)
const MinClientHelloSize = 1024 const MinClientHelloSize = 1024
// MinInitialPacketSize is the minimum size an Initial packet (in IETF QUIC) is requried to have. // MinInitialPacketSize is the minimum size an Initial packet (in IETF QUIC) is required to have.
const MinInitialPacketSize = 1200 const MinInitialPacketSize = 1200
// MaxClientHellos is the maximum number of times we'll send a client hello // MaxClientHellos is the maximum number of times we'll send a client hello
// The value 3 accounts for: // The value 3 accounts for:
// * one failure due to an incorrect or missing source-address token // * one failure due to an incorrect or missing source-address token
// * one failure due the server's certificate chain being unavailible and the server being unwilling to send it without a valid source-address token // * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token
const MaxClientHellos = 3 const MaxClientHellos = 3
...@@ -2,9 +2,11 @@ package protocol ...@@ -2,9 +2,11 @@ package protocol
import "time" import "time"
// MaxPacketSize is the maximum packet size that we use for sending packets. // MaxPacketSizeIPv4 is the maximum packet size that we use for sending IPv4 packets.
// It includes the QUIC packet header, but excludes the UDP and IP header. const MaxPacketSizeIPv4 = 1252
const MaxPacketSize ByteCount = 1200
// MaxPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets.
const MaxPacketSizeIPv6 = 1232
// NonForwardSecurePacketSizeReduction is the number of bytes a non forward-secure packet has to be smaller than a forward-secure packet // NonForwardSecurePacketSizeReduction is the number of bytes a non forward-secure packet has to be smaller than a forward-secure packet
// This makes sure that those packets can always be retransmitted without splitting the contained StreamFrames // This makes sure that those packets can always be retransmitted without splitting the contained StreamFrames
...@@ -24,10 +26,6 @@ const MaxUndecryptablePackets = 10 ...@@ -24,10 +26,6 @@ const MaxUndecryptablePackets = 10
// This timeout allows the Go scheduler to switch to the Go rountine that reads the crypto stream and to escalate the crypto // This timeout allows the Go scheduler to switch to the Go rountine that reads the crypto stream and to escalate the crypto
const PublicResetTimeout = 500 * time.Millisecond const PublicResetTimeout = 500 * time.Millisecond
// AckSendDelay is the maximum delay that can be applied to an ACK for a retransmittable packet
// This is the value Chromium is using
const AckSendDelay = 25 * time.Millisecond
// ReceiveStreamFlowControlWindow is the stream-level flow control window for receiving data // ReceiveStreamFlowControlWindow is the stream-level flow control window for receiving data
// This is the value that Google servers are using // This is the value that Google servers are using
const ReceiveStreamFlowControlWindow = (1 << 10) * 32 // 32 kB const ReceiveStreamFlowControlWindow = (1 << 10) * 32 // 32 kB
...@@ -59,8 +57,11 @@ const ConnectionFlowControlMultiplier = 1.5 ...@@ -59,8 +57,11 @@ const ConnectionFlowControlMultiplier = 1.5
// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client // WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client
const WindowUpdateThreshold = 0.25 const WindowUpdateThreshold = 0.25
// MaxIncomingStreams is the maximum number of streams that a peer may open // DefaultMaxIncomingStreams is the maximum number of streams that a peer may open
const MaxIncomingStreams = 100 const DefaultMaxIncomingStreams = 100
// DefaultMaxIncomingUniStreams is the maximum number of unidirectional streams that a peer may open
const DefaultMaxIncomingUniStreams = 100
// MaxStreamsMultiplier is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this procentual increase and the absolute increment specified by MaxStreamsMinimumIncrement is used. // MaxStreamsMultiplier is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this procentual increase and the absolute increment specified by MaxStreamsMinimumIncrement is used.
const MaxStreamsMultiplier = 1.1 const MaxStreamsMultiplier = 1.1
...@@ -68,10 +69,6 @@ const MaxStreamsMultiplier = 1.1 ...@@ -68,10 +69,6 @@ const MaxStreamsMultiplier = 1.1
// MaxStreamsMinimumIncrement is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this absolute increment and the procentual increase specified by MaxStreamsMultiplier is used. // MaxStreamsMinimumIncrement is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this absolute increment and the procentual increase specified by MaxStreamsMultiplier is used.
const MaxStreamsMinimumIncrement = 10 const MaxStreamsMinimumIncrement = 10
// MaxNewStreamIDDelta is the maximum difference between and a newly opened Stream and the highest StreamID that a client has ever opened
// note that the number of streams is half this value, since the client can only open streams with open StreamID
const MaxNewStreamIDDelta = 4 * MaxIncomingStreams
// MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed. // MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed.
const MaxSessionUnprocessedPackets = DefaultMaxCongestionWindow const MaxSessionUnprocessedPackets = DefaultMaxCongestionWindow
...@@ -84,8 +81,15 @@ const MaxTrackedSkippedPackets = 10 ...@@ -84,8 +81,15 @@ const MaxTrackedSkippedPackets = 10
// CookieExpiryTime is the valid time of a cookie // CookieExpiryTime is the valid time of a cookie
const CookieExpiryTime = 24 * time.Hour const CookieExpiryTime = 24 * time.Hour
// MaxTrackedSentPackets is maximum number of sent packets saved for either later retransmission or entropy calculation // MaxOutstandingSentPackets is maximum number of packets saved for retransmission.
const MaxTrackedSentPackets = 2 * DefaultMaxCongestionWindow // When reached, it imposes a soft limit on sending new packets:
// Sending ACKs and retransmission is still allowed, but now new regular packets can be sent.
const MaxOutstandingSentPackets = 2 * DefaultMaxCongestionWindow
// MaxTrackedSentPackets is maximum number of sent packets saved for retransmission.
// When reached, no more packets will be sent.
// This value *must* be larger than MaxOutstandingSentPackets.
const MaxTrackedSentPackets = MaxOutstandingSentPackets * 5 / 4
// MaxTrackedReceivedAckRanges is the maximum number of ACK ranges tracked // MaxTrackedReceivedAckRanges is the maximum number of ACK ranges tracked
const MaxTrackedReceivedAckRanges = DefaultMaxCongestionWindow const MaxTrackedReceivedAckRanges = DefaultMaxCongestionWindow
...@@ -93,9 +97,6 @@ const MaxTrackedReceivedAckRanges = DefaultMaxCongestionWindow ...@@ -93,9 +97,6 @@ const MaxTrackedReceivedAckRanges = DefaultMaxCongestionWindow
// MaxNonRetransmittableAcks is the maximum number of packets containing an ACK, but no retransmittable frames, that we send in a row // MaxNonRetransmittableAcks is the maximum number of packets containing an ACK, but no retransmittable frames, that we send in a row
const MaxNonRetransmittableAcks = 19 const MaxNonRetransmittableAcks = 19
// RetransmittablePacketsBeforeAck is the number of retransmittable that an ACK is sent for
const RetransmittablePacketsBeforeAck = 10
// MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames // MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames
// prevents DoS attacks against the streamFrameSorter // prevents DoS attacks against the streamFrameSorter
const MaxStreamFrameSorterGaps = 1000 const MaxStreamFrameSorterGaps = 1000
......
...@@ -4,10 +4,11 @@ import ( ...@@ -4,10 +4,11 @@ import (
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"math"
) )
// VersionNumber is a version number as int // VersionNumber is a version number as int
type VersionNumber int32 type VersionNumber uint32
// gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions // gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions
const ( const (
...@@ -20,7 +21,7 @@ const ( ...@@ -20,7 +21,7 @@ const (
Version39 VersionNumber = gquicVersion0 + 3*0x100 + 0x9 + iota Version39 VersionNumber = gquicVersion0 + 3*0x100 + 0x9 + iota
VersionTLS VersionNumber = 101 VersionTLS VersionNumber = 101
VersionWhatever VersionNumber = 0 // for when the version doesn't matter VersionWhatever VersionNumber = 0 // for when the version doesn't matter
VersionUnknown VersionNumber = -1 VersionUnknown VersionNumber = math.MaxUint32
) )
// SupportedVersions lists the versions that the server supports // SupportedVersions lists the versions that the server supports
...@@ -29,6 +30,11 @@ var SupportedVersions = []VersionNumber{ ...@@ -29,6 +30,11 @@ var SupportedVersions = []VersionNumber{
Version39, Version39,
} }
// IsValidVersion says if the version is known to quic-go
func IsValidVersion(v VersionNumber) bool {
return v == VersionTLS || IsSupportedVersion(SupportedVersions, v)
}
// UsesTLS says if this QUIC version uses TLS 1.3 for the handshake // UsesTLS says if this QUIC version uses TLS 1.3 for the handshake
func (vn VersionNumber) UsesTLS() bool { func (vn VersionNumber) UsesTLS() bool {
return vn == VersionTLS return vn == VersionTLS
...@@ -46,7 +52,7 @@ func (vn VersionNumber) String() string { ...@@ -46,7 +52,7 @@ func (vn VersionNumber) String() string {
if vn.isGQUIC() { if vn.isGQUIC() {
return fmt.Sprintf("gQUIC %d", vn.toGQUICVersion()) return fmt.Sprintf("gQUIC %d", vn.toGQUICVersion())
} }
return fmt.Sprintf("%d", vn) return fmt.Sprintf("%#x", uint32(vn))
} }
} }
...@@ -71,6 +77,11 @@ func (vn VersionNumber) UsesIETFFrameFormat() bool { ...@@ -71,6 +77,11 @@ func (vn VersionNumber) UsesIETFFrameFormat() bool {
return vn != Version39 return vn != Version39
} }
// UsesStopWaitingFrames tells if this version uses STOP_WAITING frames
func (vn VersionNumber) UsesStopWaitingFrames() bool {
return vn == Version39
}
// StreamContributesToConnectionFlowControl says if a stream contributes to connection-level flow control // StreamContributesToConnectionFlowControl says if a stream contributes to connection-level flow control
func (vn VersionNumber) StreamContributesToConnectionFlowControl(id StreamID) bool { func (vn VersionNumber) StreamContributesToConnectionFlowControl(id StreamID) bool {
if id == vn.CryptoStreamID() { if id == vn.CryptoStreamID() {
......
...@@ -31,7 +31,7 @@ func (t *Timer) Reset(deadline time.Time) { ...@@ -31,7 +31,7 @@ func (t *Timer) Reset(deadline time.Time) {
if !t.t.Stop() && !t.read { if !t.t.Stop() && !t.read {
<-t.t.C <-t.t.C
} }
t.t.Reset(deadline.Sub(time.Now())) t.t.Reset(time.Until(deadline))
t.read = false t.read = false
t.deadline = deadline t.deadline = deadline
......
...@@ -60,7 +60,7 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, ...@@ -60,7 +60,7 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame,
if ackBlock > frame.LargestAcked { if ackBlock > frame.LargestAcked {
return nil, errors.New("invalid first ACK range") return nil, errors.New("invalid first ACK range")
} }
smallest := frame.LargestAcked - protocol.PacketNumber(ackBlock) smallest := frame.LargestAcked - ackBlock
// read all the other ACK ranges // read all the other ACK ranges
if numBlocks > 0 { if numBlocks > 0 {
...@@ -86,7 +86,7 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, ...@@ -86,7 +86,7 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame,
if ackBlock > largest { if ackBlock > largest {
return nil, errInvalidAckRanges return nil, errInvalidAckRanges
} }
smallest = largest - protocol.PacketNumber(ackBlock) smallest = largest - ackBlock
frame.AckRanges = append(frame.AckRanges, AckRange{First: smallest, Last: largest}) frame.AckRanges = append(frame.AckRanges, AckRange{First: smallest, Last: largest})
} }
...@@ -144,7 +144,7 @@ func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount { ...@@ -144,7 +144,7 @@ func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
return f.lengthLegacy(version) return f.lengthLegacy(version)
} }
length := 1 + utils.VarIntLen(uint64(f.LargestAcked)) + utils.VarIntLen(uint64(encodeAckDelay(f.DelayTime))) length := 1 + utils.VarIntLen(uint64(f.LargestAcked)) + utils.VarIntLen(encodeAckDelay(f.DelayTime))
var lowestInFirstRange protocol.PacketNumber var lowestInFirstRange protocol.PacketNumber
if f.HasMissingRanges() { if f.HasMissingRanges() {
......
...@@ -7,3 +7,8 @@ type AckRange struct { ...@@ -7,3 +7,8 @@ type AckRange struct {
First protocol.PacketNumber First protocol.PacketNumber
Last protocol.PacketNumber Last protocol.PacketNumber
} }
// Len returns the number of packets contained in this ACK range
func (r AckRange) Len() protocol.PacketNumber {
return r.Last - r.First + 1
}
...@@ -48,7 +48,7 @@ func ParseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber) ...@@ -48,7 +48,7 @@ func ParseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber)
reasonPhraseLen = uint64(length) reasonPhraseLen = uint64(length)
} }
// shortcut to prevent the unneccessary allocation of dataLen bytes // shortcut to prevent the unnecessary allocation of dataLen bytes
// if the dataLen is larger than the remaining length of the packet // if the dataLen is larger than the remaining length of the packet
// reading the whole reason phrase would result in EOF when attempting to READ // reading the whole reason phrase would result in EOF when attempting to READ
if int(reasonPhraseLen) > r.Len() { if int(reasonPhraseLen) > r.Len() {
...@@ -62,7 +62,7 @@ func ParseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber) ...@@ -62,7 +62,7 @@ func ParseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber)
} }
return &ConnectionCloseFrame{ return &ConnectionCloseFrame{
ErrorCode: qerr.ErrorCode(errorCode), ErrorCode: errorCode,
ReasonPhrase: string(reasonPhrase), ReasonPhrase: string(reasonPhrase),
}, nil }, nil
} }
......
...@@ -41,7 +41,7 @@ func ParseGoawayFrame(r *bytes.Reader, _ protocol.VersionNumber) (*GoawayFrame, ...@@ -41,7 +41,7 @@ func ParseGoawayFrame(r *bytes.Reader, _ protocol.VersionNumber) (*GoawayFrame,
return nil, err return nil, err
} }
if reasonPhraseLen > uint16(protocol.MaxPacketSize) { if reasonPhraseLen > uint16(protocol.MaxReceivePacketSize) {
return nil, qerr.Error(qerr.InvalidGoawayData, "reason phrase too long") return nil, qerr.Error(qerr.InvalidGoawayData, "reason phrase too long")
} }
......
...@@ -50,6 +50,7 @@ func ParseHeaderSentByServer(b *bytes.Reader, version protocol.VersionNumber) (* ...@@ -50,6 +50,7 @@ func ParseHeaderSentByServer(b *bytes.Reader, version protocol.VersionNumber) (*
// the client knows the version that this packet was sent with // the client knows the version that this packet was sent with
isPublicHeader = !version.UsesTLS() isPublicHeader = !version.UsesTLS()
} }
return parsePacketHeader(b, protocol.PerspectiveServer, isPublicHeader) return parsePacketHeader(b, protocol.PerspectiveServer, isPublicHeader)
} }
...@@ -61,12 +62,13 @@ func ParseHeaderSentByClient(b *bytes.Reader) (*Header, error) { ...@@ -61,12 +62,13 @@ func ParseHeaderSentByClient(b *bytes.Reader) (*Header, error) {
} }
_ = b.UnreadByte() // unread the type byte _ = b.UnreadByte() // unread the type byte
// If this is a gQUIC header 0x80 and 0x40 will be set to 0. // In an IETF QUIC packet header
// If this is an IETF QUIC header there are two options: // * either 0x80 is set (for the Long Header)
// * either 0x80 will be 1 (for the Long Header) // * or 0x8 is unset (for the Short Header)
// * or 0x40 (the Connection ID Flag) will be 0 (for the Short Header), since we don't the client to omit it // In a gQUIC Public Header
isPublicHeader := typeByte&0xc0 == 0 // * 0x80 is always unset and
// * and 0x8 is always set (this is the Connection ID flag, which the client always sets)
isPublicHeader := typeByte&0x88 == 0x8
return parsePacketHeader(b, protocol.PerspectiveClient, isPublicHeader) return parsePacketHeader(b, protocol.PerspectiveClient, isPublicHeader)
} }
......
...@@ -2,6 +2,7 @@ package wire ...@@ -2,6 +2,7 @@ package wire
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
...@@ -31,14 +32,8 @@ func parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, typeByte byte ...@@ -31,14 +32,8 @@ func parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, typeByte byte
if err != nil { if err != nil {
return nil, err return nil, err
} }
pn, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return nil, err
}
h := &Header{ h := &Header{
ConnectionID: protocol.ConnectionID(connID), ConnectionID: protocol.ConnectionID(connID),
PacketNumber: protocol.PacketNumber(pn),
PacketNumberLen: protocol.PacketNumberLen4,
Version: protocol.VersionNumber(v), Version: protocol.VersionNumber(v),
} }
if v == 0 { // version negotiation packet if v == 0 { // version negotiation packet
...@@ -60,6 +55,12 @@ func parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, typeByte byte ...@@ -60,6 +55,12 @@ func parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, typeByte byte
return h, nil return h, nil
} }
h.IsLongHeader = true h.IsLongHeader = true
pn, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return nil, err
}
h.PacketNumber = protocol.PacketNumber(pn)
h.PacketNumberLen = protocol.PacketNumberLen4
h.Type = protocol.PacketType(typeByte & 0x7f) h.Type = protocol.PacketType(typeByte & 0x7f)
if sentBy == protocol.PerspectiveClient && (h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeHandshake && h.Type != protocol.PacketType0RTT) { if sentBy == protocol.PerspectiveClient && (h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeHandshake && h.Type != protocol.PacketType0RTT) {
return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type)) return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type))
...@@ -71,26 +72,40 @@ func parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, typeByte byte ...@@ -71,26 +72,40 @@ func parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, typeByte byte
} }
func parseShortHeader(b *bytes.Reader, typeByte byte) (*Header, error) { func parseShortHeader(b *bytes.Reader, typeByte byte) (*Header, error) {
hasConnID := typeByte&0x40 > 0 omitConnID := typeByte&0x40 > 0
var connID uint64 var connID uint64
if hasConnID { if !omitConnID {
var err error var err error
connID, err = utils.BigEndian.ReadUint64(b) connID, err = utils.BigEndian.ReadUint64(b)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
pnLen := 1 << ((typeByte & 0x3) - 1) // bit 4 must be set, bit 5 must be unset
if typeByte&0x18 != 0x10 {
return nil, errors.New("invalid bit 4 and 5")
}
var pnLen protocol.PacketNumberLen
switch typeByte & 0x7 {
case 0x0:
pnLen = protocol.PacketNumberLen1
case 0x1:
pnLen = protocol.PacketNumberLen2
case 0x2:
pnLen = protocol.PacketNumberLen4
default:
return nil, errors.New("invalid short header type")
}
pn, err := utils.BigEndian.ReadUintN(b, uint8(pnLen)) pn, err := utils.BigEndian.ReadUintN(b, uint8(pnLen))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Header{ return &Header{
KeyPhase: int(typeByte&0x20) >> 5, KeyPhase: int(typeByte&0x20) >> 5,
OmitConnectionID: !hasConnID, OmitConnectionID: omitConnID,
ConnectionID: protocol.ConnectionID(connID), ConnectionID: protocol.ConnectionID(connID),
PacketNumber: protocol.PacketNumber(pn), PacketNumber: protocol.PacketNumber(pn),
PacketNumberLen: protocol.PacketNumberLen(pnLen), PacketNumberLen: pnLen,
}, nil }, nil
} }
...@@ -112,17 +127,17 @@ func (h *Header) writeLongHeader(b *bytes.Buffer) error { ...@@ -112,17 +127,17 @@ func (h *Header) writeLongHeader(b *bytes.Buffer) error {
} }
func (h *Header) writeShortHeader(b *bytes.Buffer) error { func (h *Header) writeShortHeader(b *bytes.Buffer) error {
typeByte := byte(h.KeyPhase << 5) typeByte := byte(0x10)
if !h.OmitConnectionID { typeByte ^= byte(h.KeyPhase << 5)
if h.OmitConnectionID {
typeByte ^= 0x40 typeByte ^= 0x40
} }
switch h.PacketNumberLen { switch h.PacketNumberLen {
case protocol.PacketNumberLen1: case protocol.PacketNumberLen1:
typeByte ^= 0x1
case protocol.PacketNumberLen2: case protocol.PacketNumberLen2:
typeByte ^= 0x2 typeByte ^= 0x1
case protocol.PacketNumberLen4: case protocol.PacketNumberLen4:
typeByte ^= 0x3 typeByte ^= 0x2
default: default:
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
} }
......
...@@ -134,10 +134,9 @@ func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Hea ...@@ -134,10 +134,9 @@ func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Hea
} }
} }
// Contrary to what the gQUIC wire spec says, the 0x4 bit only indicates the presence of the diversification nonce for packets sent by the server.
// It doesn't have any meaning when sent by the client.
if packetSentBy == protocol.PerspectiveServer && publicFlagByte&0x04 > 0 { if packetSentBy == protocol.PerspectiveServer && publicFlagByte&0x04 > 0 {
// TODO: remove the if once the Google servers send the correct value
// assume that a packet doesn't contain a diversification nonce if the version flag or the reset flag is set, no matter what the public flag says
// see https://github.com/lucas-clemente/quic-go/issues/232
if !header.VersionFlag && !header.ResetFlag { if !header.VersionFlag && !header.ResetFlag {
header.DiversificationNonce = make([]byte, 32) header.DiversificationNonce = make([]byte, 32)
if _, err := io.ReadFull(b, header.DiversificationNonce); err != nil { if _, err := io.ReadFull(b, header.DiversificationNonce); err != nil {
...@@ -148,7 +147,7 @@ func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Hea ...@@ -148,7 +147,7 @@ func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Hea
// Version (optional) // Version (optional)
if !header.ResetFlag && header.VersionFlag { if !header.ResetFlag && header.VersionFlag {
if packetSentBy == protocol.PerspectiveServer { // parse the version negotiaton packet if packetSentBy == protocol.PerspectiveServer { // parse the version negotiation packet
if b.Len() == 0 { if b.Len() == 0 {
return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list") return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
} }
...@@ -236,7 +235,7 @@ func (h *Header) logPublicHeader() { ...@@ -236,7 +235,7 @@ func (h *Header) logPublicHeader() {
} }
ver := "(unset)" ver := "(unset)"
if h.Version != 0 { if h.Version != 0 {
ver = fmt.Sprintf("%s", h.Version) ver = h.Version.String()
} }
utils.Debugf(" Public Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", connID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce) utils.Debugf(" Public Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", connID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce)
} }
...@@ -56,7 +56,7 @@ func ParseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamF ...@@ -56,7 +56,7 @@ func ParseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamF
if err != nil { if err != nil {
return nil, err return nil, err
} }
// shortcut to prevent the unneccessary allocation of dataLen bytes // shortcut to prevent the unnecessary allocation of dataLen bytes
// if the dataLen is larger than the remaining length of the packet // if the dataLen is larger than the remaining length of the packet
// reading the packet contents would result in EOF when attempting to READ // reading the packet contents would result in EOF when attempting to READ
if dataLen > uint64(r.Len()) { if dataLen > uint64(r.Len()) {
......
...@@ -52,7 +52,7 @@ func parseLegacyStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamF ...@@ -52,7 +52,7 @@ func parseLegacyStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamF
} }
} }
// shortcut to prevent the unneccessary allocation of dataLen bytes // shortcut to prevent the unnecessary allocation of dataLen bytes
// if the dataLen is larger than the remaining length of the packet // if the dataLen is larger than the remaining length of the packet
// reading the packet contents would result in EOF when attempting to READ // reading the packet contents would result in EOF when attempting to READ
if int(dataLen) > r.Len() { if int(dataLen) > r.Len() {
......
...@@ -10,50 +10,37 @@ import ( ...@@ -10,50 +10,37 @@ import (
// ComposeGQUICVersionNegotiation composes a Version Negotiation Packet for gQUIC // ComposeGQUICVersionNegotiation composes a Version Negotiation Packet for gQUIC
func ComposeGQUICVersionNegotiation(connID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { func ComposeGQUICVersionNegotiation(connID protocol.ConnectionID, versions []protocol.VersionNumber) []byte {
fullReply := &bytes.Buffer{} buf := &bytes.Buffer{}
ph := Header{ ph := Header{
ConnectionID: connID, ConnectionID: connID,
PacketNumber: 1, PacketNumber: 1,
VersionFlag: true, VersionFlag: true,
IsVersionNegotiation: true, IsVersionNegotiation: true,
} }
if err := ph.writePublicHeader(fullReply, protocol.PerspectiveServer, protocol.VersionWhatever); err != nil { if err := ph.writePublicHeader(buf, protocol.PerspectiveServer, protocol.VersionWhatever); err != nil {
utils.Errorf("error composing version negotiation packet: %s", err.Error()) utils.Errorf("error composing version negotiation packet: %s", err.Error())
return nil return nil
} }
writeVersions(fullReply, versions) for _, v := range versions {
return fullReply.Bytes() utils.BigEndian.WriteUint32(buf, uint32(v))
}
return buf.Bytes()
} }
// ComposeVersionNegotiation composes a Version Negotiation according to the IETF draft // ComposeVersionNegotiation composes a Version Negotiation according to the IETF draft
func ComposeVersionNegotiation( func ComposeVersionNegotiation(
connID protocol.ConnectionID, connID protocol.ConnectionID,
pn protocol.PacketNumber,
versions []protocol.VersionNumber, versions []protocol.VersionNumber,
) []byte { ) []byte {
fullReply := &bytes.Buffer{} greasedVersions := protocol.GetGreasedVersions(versions)
buf := bytes.NewBuffer(make([]byte, 0, 1+8+4+len(greasedVersions)*4))
r := make([]byte, 1) r := make([]byte, 1)
_, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here. _, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here.
h := Header{ buf.WriteByte(r[0] | 0x80)
IsLongHeader: true, utils.BigEndian.WriteUint64(buf, uint64(connID))
Type: protocol.PacketType(r[0] | 0x80), utils.BigEndian.WriteUint32(buf, 0) // version 0
ConnectionID: connID, for _, v := range greasedVersions {
PacketNumber: pn,
Version: 0,
IsVersionNegotiation: true,
}
if err := h.writeHeader(fullReply); err != nil {
utils.Errorf("error composing version negotiation packet: %s", err.Error())
return nil
}
writeVersions(fullReply, versions)
return fullReply.Bytes()
}
// writeVersions writes the versions for a Version Negotiation Packet.
// It inserts one reserved version number at a random position.
func writeVersions(buf *bytes.Buffer, supported []protocol.VersionNumber) {
for _, v := range protocol.GetGreasedVersions(supported) {
utils.BigEndian.WriteUint32(buf, uint32(v)) utils.BigEndian.WriteUint32(buf, uint32(v))
} }
return buf.Bytes()
} }
...@@ -139,8 +139,8 @@ func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, versio ...@@ -139,8 +139,8 @@ func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, versio
// packUnencryptedPacket provides a low-overhead way to pack a packet. // packUnencryptedPacket provides a low-overhead way to pack a packet.
// It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available. // It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available.
func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, pers protocol.Perspective) ([]byte, error) { func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, pers protocol.Perspective) ([]byte, error) {
raw := getPacketBuffer() raw := *getPacketBuffer()
buffer := bytes.NewBuffer(raw) buffer := bytes.NewBuffer(raw[:0])
if err := hdr.Write(buffer, pers, hdr.Version); err != nil { if err := hdr.Write(buffer, pers, hdr.Version); err != nil {
return nil, err return nil, err
} }
......
...@@ -24,8 +24,9 @@ type packetUnpacker struct { ...@@ -24,8 +24,9 @@ type packetUnpacker struct {
} }
func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) { func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) {
buf := getPacketBuffer() buf := *getPacketBuffer()
defer putPacketBuffer(buf) buf = buf[:0]
defer putPacketBuffer(&buf)
decrypted, encryptionLevel, err := u.aead.Open(buf, data, hdr.PacketNumber, headerBinary) decrypted, encryptionLevel, err := u.aead.Open(buf, data, hdr.PacketNumber, headerBinary)
if err != nil { if err != nil {
// Wrap err in quicError so that public reset is sent by session // Wrap err in quicError so that public reset is sent by session
......
...@@ -31,6 +31,7 @@ func (e *QuicError) Error() string { ...@@ -31,6 +31,7 @@ func (e *QuicError) Error() string {
return fmt.Sprintf("%s: %s", e.ErrorCode.String(), e.ErrorMessage) return fmt.Sprintf("%s: %s", e.ErrorCode.String(), e.ErrorMessage)
} }
// Timeout says if this error is a timeout.
func (e *QuicError) Timeout() bool { func (e *QuicError) Timeout() bool {
switch e.ErrorCode { switch e.ErrorCode {
case NetworkIdleTimeout, case NetworkIdleTimeout,
......
...@@ -124,7 +124,7 @@ func (s *receiveStream) Read(p []byte) (int, error) { ...@@ -124,7 +124,7 @@ func (s *receiveStream) Read(p []byte) (int, error) {
} else { } else {
select { select {
case <-s.readChan: case <-s.readChan:
case <-time.After(deadline.Sub(time.Now())): case <-time.After(time.Until(deadline)):
} }
} }
s.mutex.Lock() s.mutex.Lock()
......
...@@ -116,7 +116,7 @@ func (s *sendStream) Write(p []byte) (int, error) { ...@@ -116,7 +116,7 @@ func (s *sendStream) Write(p []byte) (int, error) {
} else { } else {
select { select {
case <-s.writeChan: case <-s.writeChan:
case <-time.After(deadline.Sub(time.Now())): case <-time.After(time.Until(deadline)):
} }
} }
s.mutex.Lock() s.mutex.Lock()
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt"
"net" "net"
"sync" "sync"
"time" "time"
...@@ -85,9 +86,12 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, ...@@ -85,9 +86,12 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
} }
config = populateServerConfig(config) config = populateServerConfig(config)
// check if any of the supported versions supports TLS
var supportsTLS bool var supportsTLS bool
for _, v := range config.Versions { for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
}
// check if any of the supported versions supports TLS
if v.UsesTLS() { if v.UsesTLS() {
supportsTLS = true supportsTLS = true
break break
...@@ -136,10 +140,11 @@ func (s *server) setupTLS() error { ...@@ -136,10 +140,11 @@ func (s *server) setupTLS() error {
case tlsSession := <-sessionChan: case tlsSession := <-sessionChan:
connID := tlsSession.connID connID := tlsSession.connID
sess := tlsSession.sess sess := tlsSession.sess
s.sessionsMutex.Lock()
if _, ok := s.sessions[connID]; ok { // drop this session if it already exists if _, ok := s.sessions[connID]; ok { // drop this session if it already exists
return s.sessionsMutex.Unlock()
continue
} }
s.sessionsMutex.Lock()
s.sessions[connID] = sess s.sessions[connID] = sess
s.sessionsMutex.Unlock() s.sessionsMutex.Unlock()
s.runHandshakeAndSession(sess, connID) s.runHandshakeAndSession(sess, connID)
...@@ -198,6 +203,18 @@ func populateServerConfig(config *Config) *Config { ...@@ -198,6 +203,18 @@ func populateServerConfig(config *Config) *Config {
if maxReceiveConnectionFlowControlWindow == 0 { if maxReceiveConnectionFlowControlWindow == 0 {
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowServer maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowServer
} }
maxIncomingStreams := config.MaxIncomingStreams
if maxIncomingStreams == 0 {
maxIncomingStreams = protocol.DefaultMaxIncomingStreams
} else if maxIncomingStreams < 0 {
maxIncomingStreams = 0
}
maxIncomingUniStreams := config.MaxIncomingUniStreams
if maxIncomingUniStreams == 0 {
maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
return &Config{ return &Config{
Versions: versions, Versions: versions,
...@@ -207,13 +224,15 @@ func populateServerConfig(config *Config) *Config { ...@@ -207,13 +224,15 @@ func populateServerConfig(config *Config) *Config {
KeepAlive: config.KeepAlive, KeepAlive: config.KeepAlive,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams,
} }
} }
// serve listens on an existing PacketConn // serve listens on an existing PacketConn
func (s *server) serve() { func (s *server) serve() {
for { for {
data := getPacketBuffer() data := *getPacketBuffer()
data = data[:protocol.MaxReceivePacketSize] data = data[:protocol.MaxReceivePacketSize]
// The packet size should not exceed protocol.MaxReceivePacketSize bytes // The packet size should not exceed protocol.MaxReceivePacketSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable // If it does, we only read a truncated packet, which will then end up undecryptable
...@@ -309,7 +328,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet ...@@ -309,7 +328,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
var pr *wire.PublicReset var pr *wire.PublicReset
pr, err = wire.ParsePublicReset(r) pr, err = wire.ParsePublicReset(r)
if err != nil { if err != nil {
utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.") utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.", hdr.ConnectionID)
} else { } else {
utils.Infof("Received a Public Reset for connection %x, rejected packet number: 0x%x.", hdr.ConnectionID, pr.RejectedPacketNumber) utils.Infof("Received a Public Reset for connection %x, rejected packet number: 0x%x.", hdr.ConnectionID, pr.RejectedPacketNumber)
} }
......
...@@ -36,7 +36,6 @@ type serverTLS struct { ...@@ -36,7 +36,6 @@ type serverTLS struct {
config *Config config *Config
supportedVersions []protocol.VersionNumber supportedVersions []protocol.VersionNumber
mintConf *mint.Config mintConf *mint.Config
cookieProtector mint.CookieProtector
params *handshake.TransportParameters params *handshake.TransportParameters
newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error)
...@@ -72,9 +71,8 @@ func newServerTLS( ...@@ -72,9 +71,8 @@ func newServerTLS(
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
IdleTimeout: config.IdleTimeout, IdleTimeout: config.IdleTimeout,
// TODO(#523): make these values configurable MaxBidiStreams: uint16(config.MaxIncomingStreams),
MaxBidiStreamID: protocol.MaxBidiStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveServer), MaxUniStreams: uint16(config.MaxIncomingUniStreams),
MaxUniStreamID: protocol.MaxUniStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveServer),
}, },
} }
s.newMintConn = s.newMintConnImpl s.newMintConn = s.newMintConnImpl
...@@ -85,7 +83,7 @@ func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data [] ...@@ -85,7 +83,7 @@ func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []
utils.Debugf("Received a Packet. Handling it statelessly.") utils.Debugf("Received a Packet. Handling it statelessly.")
sess, err := s.handleInitialImpl(remoteAddr, hdr, data) sess, err := s.handleInitialImpl(remoteAddr, hdr, data)
if err != nil { if err != nil {
utils.Errorf("Error occured handling initial packet: %s", err) utils.Errorf("Error occurred handling initial packet: %s", err)
return return
} }
if sess == nil { // a stateless reset was done if sess == nil { // a stateless reset was done
...@@ -132,7 +130,7 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat ...@@ -132,7 +130,7 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat
// check version, if not matching send VNP // check version, if not matching send VNP
if !protocol.IsSupportedVersion(s.supportedVersions, hdr.Version) { if !protocol.IsSupportedVersion(s.supportedVersions, hdr.Version) {
utils.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) utils.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
_, err := s.conn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, hdr.PacketNumber, s.supportedVersions), remoteAddr) _, err := s.conn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, s.supportedVersions), remoteAddr)
return nil, err return nil, err
} }
...@@ -149,7 +147,7 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat ...@@ -149,7 +147,7 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat
sess, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead) sess, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead)
if err != nil { if err != nil {
if ccerr := s.sendConnectionClose(remoteAddr, hdr, aead, err); ccerr != nil { if ccerr := s.sendConnectionClose(remoteAddr, hdr, aead, err); ccerr != nil {
utils.Debugf("Error sending CONNECTION_CLOSE: ", ccerr) utils.Debugf("Error sending CONNECTION_CLOSE: %s", ccerr)
} }
return nil, err return nil, err
} }
......
...@@ -12,8 +12,6 @@ type streamFramer struct { ...@@ -12,8 +12,6 @@ type streamFramer struct {
cryptoStream cryptoStreamI cryptoStream cryptoStreamI
version protocol.VersionNumber version protocol.VersionNumber
retransmissionQueue []*wire.StreamFrame
streamQueueMutex sync.Mutex streamQueueMutex sync.Mutex
activeStreams map[protocol.StreamID]struct{} activeStreams map[protocol.StreamID]struct{}
streamQueue []protocol.StreamID streamQueue []protocol.StreamID
...@@ -33,10 +31,6 @@ func newStreamFramer( ...@@ -33,10 +31,6 @@ func newStreamFramer(
} }
} }
func (f *streamFramer) AddFrameForRetransmission(frame *wire.StreamFrame) {
f.retransmissionQueue = append(f.retransmissionQueue, frame)
}
func (f *streamFramer) AddActiveStream(id protocol.StreamID) { func (f *streamFramer) AddActiveStream(id protocol.StreamID) {
if id == f.version.CryptoStreamID() { // the crypto stream is handled separately if id == f.version.CryptoStreamID() { // the crypto stream is handled separately
f.streamQueueMutex.Lock() f.streamQueueMutex.Lock()
...@@ -52,15 +46,6 @@ func (f *streamFramer) AddActiveStream(id protocol.StreamID) { ...@@ -52,15 +46,6 @@ func (f *streamFramer) AddActiveStream(id protocol.StreamID) {
f.streamQueueMutex.Unlock() f.streamQueueMutex.Unlock()
} }
func (f *streamFramer) PopStreamFrames(maxLen protocol.ByteCount) []*wire.StreamFrame {
fs, currentLen := f.maybePopFramesForRetransmission(maxLen)
return append(fs, f.maybePopNormalFrames(maxLen-currentLen)...)
}
func (f *streamFramer) HasFramesForRetransmission() bool {
return len(f.retransmissionQueue) > 0
}
func (f *streamFramer) HasCryptoStreamData() bool { func (f *streamFramer) HasCryptoStreamData() bool {
f.streamQueueMutex.Lock() f.streamQueueMutex.Lock()
hasCryptoStreamData := f.hasCryptoStreamData hasCryptoStreamData := f.hasCryptoStreamData
...@@ -76,34 +61,7 @@ func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.Str ...@@ -76,34 +61,7 @@ func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.Str
return frame return frame
} }
func (f *streamFramer) maybePopFramesForRetransmission(maxTotalLen protocol.ByteCount) (res []*wire.StreamFrame, currentLen protocol.ByteCount) { func (f *streamFramer) PopStreamFrames(maxTotalLen protocol.ByteCount) []*wire.StreamFrame {
for len(f.retransmissionQueue) > 0 {
frame := f.retransmissionQueue[0]
frame.DataLenPresent = true
maxLen := maxTotalLen - currentLen
if frame.Length(f.version) > maxLen && maxLen < protocol.MinStreamFrameSize {
break
}
splitFrame, err := frame.MaybeSplitOffFrame(maxLen, f.version)
if err != nil { // maxLen is too small. Can't split frame
break
}
if splitFrame != nil { // frame was split
res = append(res, splitFrame)
currentLen += splitFrame.Length(f.version)
break
}
f.retransmissionQueue = f.retransmissionQueue[1:]
res = append(res, frame)
currentLen += frame.Length(f.version)
}
return
}
func (f *streamFramer) maybePopNormalFrames(maxTotalLen protocol.ByteCount) []*wire.StreamFrame {
var currentLen protocol.ByteCount var currentLen protocol.ByteCount
var frames []*wire.StreamFrame var frames []*wire.StreamFrame
f.streamQueueMutex.Lock() f.streamQueueMutex.Lock()
......
...@@ -35,6 +35,8 @@ var _ streamManager = &streamsMap{} ...@@ -35,6 +35,8 @@ var _ streamManager = &streamsMap{}
func newStreamsMap( func newStreamsMap(
sender streamSender, sender streamSender,
newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController, newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
maxIncomingStreams int,
maxIncomingUniStreams int,
perspective protocol.Perspective, perspective protocol.Perspective,
version protocol.VersionNumber, version protocol.VersionNumber,
) streamManager { ) streamManager {
...@@ -46,7 +48,7 @@ func newStreamsMap( ...@@ -46,7 +48,7 @@ func newStreamsMap(
var firstOutgoingBidiStream, firstOutgoingUniStream, firstIncomingBidiStream, firstIncomingUniStream protocol.StreamID var firstOutgoingBidiStream, firstOutgoingUniStream, firstIncomingBidiStream, firstIncomingUniStream protocol.StreamID
if perspective == protocol.PerspectiveServer { if perspective == protocol.PerspectiveServer {
firstOutgoingBidiStream = 1 firstOutgoingBidiStream = 1
firstIncomingBidiStream = 4 // the crypto stream is handled separatedly firstIncomingBidiStream = 4 // the crypto stream is handled separately
firstOutgoingUniStream = 3 firstOutgoingUniStream = 3
firstIncomingUniStream = 2 firstIncomingUniStream = 2
} else { } else {
...@@ -69,11 +71,10 @@ func newStreamsMap( ...@@ -69,11 +71,10 @@ func newStreamsMap(
newBidiStream, newBidiStream,
sender.queueControlFrame, sender.queueControlFrame,
) )
// TODO(#523): make these values configurable
m.incomingBidiStreams = newIncomingBidiStreamsMap( m.incomingBidiStreams = newIncomingBidiStreamsMap(
firstIncomingBidiStream, firstIncomingBidiStream,
protocol.MaxBidiStreamID(protocol.MaxIncomingStreams, perspective), protocol.MaxBidiStreamID(maxIncomingStreams, perspective),
protocol.MaxIncomingStreams, maxIncomingStreams,
sender.queueControlFrame, sender.queueControlFrame,
newBidiStream, newBidiStream,
) )
...@@ -82,11 +83,10 @@ func newStreamsMap( ...@@ -82,11 +83,10 @@ func newStreamsMap(
newUniSendStream, newUniSendStream,
sender.queueControlFrame, sender.queueControlFrame,
) )
// TODO(#523): make these values configurable
m.incomingUniStreams = newIncomingUniStreamsMap( m.incomingUniStreams = newIncomingUniStreamsMap(
firstIncomingUniStream, firstIncomingUniStream,
protocol.MaxUniStreamID(protocol.MaxIncomingStreams, perspective), protocol.MaxUniStreamID(maxIncomingUniStreams, perspective),
protocol.MaxIncomingStreams, maxIncomingUniStreams,
sender.queueControlFrame, sender.queueControlFrame,
newUniReceiveStream, newUniReceiveStream,
) )
...@@ -206,8 +206,14 @@ func (m *streamsMap) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error { ...@@ -206,8 +206,14 @@ func (m *streamsMap) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error {
} }
func (m *streamsMap) UpdateLimits(p *handshake.TransportParameters) { func (m *streamsMap) UpdateLimits(p *handshake.TransportParameters) {
m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamID) // Max{Uni,Bidi}StreamID returns the highest stream ID that the peer is allowed to open.
m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamID) // Invert the perspective to determine the value that we are allowed to open.
peerPers := protocol.PerspectiveServer
if m.perspective == protocol.PerspectiveServer {
peerPers = protocol.PerspectiveClient
}
m.outgoingBidiStreams.SetMaxStream(protocol.MaxBidiStreamID(int(p.MaxBidiStreams), peerPers))
m.outgoingUniStreams.SetMaxStream(protocol.MaxUniStreamID(int(p.MaxUniStreams), peerPers))
} }
func (m *streamsMap) CloseWithError(err error) { func (m *streamsMap) CloseWithError(err error) {
......
...@@ -69,20 +69,25 @@ func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) { ...@@ -69,20 +69,25 @@ func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) {
} }
func (m *incomingBidiStreamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { func (m *incomingBidiStreamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
m.mutex.RLock()
if id > m.maxStream { if id > m.maxStream {
m.mutex.RUnlock()
return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream)
} }
// if the id is smaller than the highest we accepted // if the id is smaller than the highest we accepted
// * this stream exists in the map, and we can return it, or // * this stream exists in the map, and we can return it, or
// * this stream was already closed, then we can return the nil // * this stream was already closed, then we can return the nil
if id <= m.highestStream { if id <= m.highestStream {
m.mutex.RLock()
s := m.streams[id] s := m.streams[id]
m.mutex.RUnlock() m.mutex.RUnlock()
return s, nil return s, nil
} }
m.mutex.RUnlock()
m.mutex.Lock() m.mutex.Lock()
// no need to check the two error conditions from above again
// * maxStream can only increase, so if the id was valid before, it definitely is valid now
// * highestStream is only modified by this function
var start protocol.StreamID var start protocol.StreamID
if m.highestStream == 0 { if m.highestStream == 0 {
start = m.nextStream start = m.nextStream
......
...@@ -67,20 +67,25 @@ func (m *incomingItemsMap) AcceptStream() (item, error) { ...@@ -67,20 +67,25 @@ func (m *incomingItemsMap) AcceptStream() (item, error) {
} }
func (m *incomingItemsMap) GetOrOpenStream(id protocol.StreamID) (item, error) { func (m *incomingItemsMap) GetOrOpenStream(id protocol.StreamID) (item, error) {
m.mutex.RLock()
if id > m.maxStream { if id > m.maxStream {
m.mutex.RUnlock()
return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream)
} }
// if the id is smaller than the highest we accepted // if the id is smaller than the highest we accepted
// * this stream exists in the map, and we can return it, or // * this stream exists in the map, and we can return it, or
// * this stream was already closed, then we can return the nil // * this stream was already closed, then we can return the nil
if id <= m.highestStream { if id <= m.highestStream {
m.mutex.RLock()
s := m.streams[id] s := m.streams[id]
m.mutex.RUnlock() m.mutex.RUnlock()
return s, nil return s, nil
} }
m.mutex.RUnlock()
m.mutex.Lock() m.mutex.Lock()
// no need to check the two error conditions from above again
// * maxStream can only increase, so if the id was valid before, it definitely is valid now
// * highestStream is only modified by this function
var start protocol.StreamID var start protocol.StreamID
if m.highestStream == 0 { if m.highestStream == 0 {
start = m.nextStream start = m.nextStream
......
...@@ -69,20 +69,25 @@ func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) { ...@@ -69,20 +69,25 @@ func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) {
} }
func (m *incomingUniStreamsMap) GetOrOpenStream(id protocol.StreamID) (receiveStreamI, error) { func (m *incomingUniStreamsMap) GetOrOpenStream(id protocol.StreamID) (receiveStreamI, error) {
m.mutex.RLock()
if id > m.maxStream { if id > m.maxStream {
m.mutex.RUnlock()
return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream)
} }
// if the id is smaller than the highest we accepted // if the id is smaller than the highest we accepted
// * this stream exists in the map, and we can return it, or // * this stream exists in the map, and we can return it, or
// * this stream was already closed, then we can return the nil // * this stream was already closed, then we can return the nil
if id <= m.highestStream { if id <= m.highestStream {
m.mutex.RLock()
s := m.streams[id] s := m.streams[id]
m.mutex.RUnlock() m.mutex.RUnlock()
return s, nil return s, nil
} }
m.mutex.RUnlock()
m.mutex.Lock() m.mutex.Lock()
// no need to check the two error conditions from above again
// * maxStream can only increase, so if the id was valid before, it definitely is valid now
// * highestStream is only modified by this function
var start protocol.StreamID var start protocol.StreamID
if m.highestStream == 0 { if m.highestStream == 0 {
start = m.nextStream start = m.nextStream
......
...@@ -39,11 +39,10 @@ var _ streamManager = &streamsMapLegacy{} ...@@ -39,11 +39,10 @@ var _ streamManager = &streamsMapLegacy{}
var errMapAccess = errors.New("streamsMap: Error accessing the streams map") var errMapAccess = errors.New("streamsMap: Error accessing the streams map")
func newStreamsMapLegacy(newStream func(protocol.StreamID) streamI, pers protocol.Perspective) streamManager { func newStreamsMapLegacy(newStream func(protocol.StreamID) streamI, maxStreams int, pers protocol.Perspective) streamManager {
// add some tolerance to the maximum incoming streams value // add some tolerance to the maximum incoming streams value
maxStreams := uint32(protocol.MaxIncomingStreams)
maxIncomingStreams := utils.MaxUint32( maxIncomingStreams := utils.MaxUint32(
maxStreams+protocol.MaxStreamsMinimumIncrement, uint32(maxStreams)+protocol.MaxStreamsMinimumIncrement,
uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)), uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)),
) )
sm := streamsMapLegacy{ sm := streamsMapLegacy{
...@@ -131,7 +130,10 @@ func (m *streamsMapLegacy) openRemoteStream(id protocol.StreamID) (streamI, erro ...@@ -131,7 +130,10 @@ func (m *streamsMapLegacy) openRemoteStream(id protocol.StreamID) (streamI, erro
if m.numIncomingStreams >= m.maxIncomingStreams { if m.numIncomingStreams >= m.maxIncomingStreams {
return nil, qerr.TooManyOpenStreams return nil, qerr.TooManyOpenStreams
} }
if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer { // maxNewStreamIDDelta is the maximum difference between and a newly opened Stream and the highest StreamID that a client has ever opened
// note that the number of streams is half this value, since the client can only open streams with open StreamID
maxStreamIDDelta := protocol.StreamID(4 * m.maxIncomingStreams)
if id+maxStreamIDDelta < m.highestStreamOpenedByPeer {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer)) return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer))
} }
...@@ -185,6 +187,14 @@ func (m *streamsMapLegacy) OpenStreamSync() (Stream, error) { ...@@ -185,6 +187,14 @@ func (m *streamsMapLegacy) OpenStreamSync() (Stream, error) {
} }
} }
func (m *streamsMapLegacy) OpenUniStream() (SendStream, error) {
return nil, errors.New("gQUIC doesn't support unidirectional streams")
}
func (m *streamsMapLegacy) OpenUniStreamSync() (SendStream, error) {
return nil, errors.New("gQUIC doesn't support unidirectional streams")
}
// AcceptStream returns the next stream opened by the peer // AcceptStream returns the next stream opened by the peer
// it blocks until a new stream is opened // it blocks until a new stream is opened
func (m *streamsMapLegacy) AcceptStream() (Stream, error) { func (m *streamsMapLegacy) AcceptStream() (Stream, error) {
...@@ -206,6 +216,10 @@ func (m *streamsMapLegacy) AcceptStream() (Stream, error) { ...@@ -206,6 +216,10 @@ func (m *streamsMapLegacy) AcceptStream() (Stream, error) {
return str, nil return str, nil
} }
func (m *streamsMapLegacy) AcceptUniStream() (ReceiveStream, error) {
return nil, errors.New("gQUIC doesn't support unidirectional streams")
}
func (m *streamsMapLegacy) DeleteStream(id protocol.StreamID) error { func (m *streamsMapLegacy) DeleteStream(id protocol.StreamID) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
......
...@@ -85,10 +85,11 @@ func (m *outgoingBidiStreamsMap) openStreamImpl() (streamI, error) { ...@@ -85,10 +85,11 @@ func (m *outgoingBidiStreamsMap) openStreamImpl() (streamI, error) {
} }
func (m *outgoingBidiStreamsMap) GetStream(id protocol.StreamID) (streamI, error) { func (m *outgoingBidiStreamsMap) GetStream(id protocol.StreamID) (streamI, error) {
m.mutex.RLock()
if id >= m.nextStream { if id >= m.nextStream {
m.mutex.RUnlock()
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id)) return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id))
} }
m.mutex.RLock()
s := m.streams[id] s := m.streams[id]
m.mutex.RUnlock() m.mutex.RUnlock()
return s, nil return s, nil
......
...@@ -86,10 +86,11 @@ func (m *outgoingItemsMap) openStreamImpl() (item, error) { ...@@ -86,10 +86,11 @@ func (m *outgoingItemsMap) openStreamImpl() (item, error) {
} }
func (m *outgoingItemsMap) GetStream(id protocol.StreamID) (item, error) { func (m *outgoingItemsMap) GetStream(id protocol.StreamID) (item, error) {
m.mutex.RLock()
if id >= m.nextStream { if id >= m.nextStream {
m.mutex.RUnlock()
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id)) return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id))
} }
m.mutex.RLock()
s := m.streams[id] s := m.streams[id]
m.mutex.RUnlock() m.mutex.RUnlock()
return s, nil return s, nil
......
...@@ -85,10 +85,11 @@ func (m *outgoingUniStreamsMap) openStreamImpl() (sendStreamI, error) { ...@@ -85,10 +85,11 @@ func (m *outgoingUniStreamsMap) openStreamImpl() (sendStreamI, error) {
} }
func (m *outgoingUniStreamsMap) GetStream(id protocol.StreamID) (sendStreamI, error) { func (m *outgoingUniStreamsMap) GetStream(id protocol.StreamID) (sendStreamI, error) {
m.mutex.RLock()
if id >= m.nextStream { if id >= m.nextStream {
m.mutex.RUnlock()
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id)) return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id))
} }
m.mutex.RLock()
s := m.streams[id] s := m.streams[id]
m.mutex.RUnlock() m.mutex.RUnlock()
return s, nil return s, nil
......
...@@ -201,6 +201,8 @@ func (s State) String() string { ...@@ -201,6 +201,8 @@ func (s State) String() string {
return "Client WAIT_CV" return "Client WAIT_CV"
case StateClientWaitFinished: case StateClientWaitFinished:
return "Client WAIT_FINISHED" return "Client WAIT_FINISHED"
case StateClientWaitCertCR:
return "Client WAIT_CERT_CR"
case StateClientConnected: case StateClientConnected:
return "Client CONNECTED" return "Client CONNECTED"
case StateServerStart: case StateServerStart:
......
...@@ -265,7 +265,7 @@ type Conn struct { ...@@ -265,7 +265,7 @@ type Conn struct {
EarlyData []byte EarlyData []byte
state StateConnected state stateConnected
hState HandshakeState hState HandshakeState
handshakeMutex sync.Mutex handshakeMutex sync.Mutex
handshakeAlert Alert handshakeAlert Alert
...@@ -345,7 +345,7 @@ func (c *Conn) consumeRecord() error { ...@@ -345,7 +345,7 @@ func (c *Conn) consumeRecord() error {
} }
var connected bool var connected bool
c.state, connected = state.(StateConnected) c.state, connected = state.(stateConnected)
if !connected { if !connected {
logf(logTypeHandshake, "Disconnected after state transition: %v", alert) logf(logTypeHandshake, "Disconnected after state transition: %v", alert)
c.sendAlert(alert) c.sendAlert(alert)
...@@ -385,7 +385,7 @@ func (c *Conn) consumeRecord() error { ...@@ -385,7 +385,7 @@ func (c *Conn) consumeRecord() error {
// Read application data up to the size of buffer. Handshake and alert records // Read application data up to the size of buffer. Handshake and alert records
// are consumed by the Conn object directly. // are consumed by the Conn object directly.
func (c *Conn) Read(buffer []byte) (int, error) { func (c *Conn) Read(buffer []byte) (int, error) {
if _, connected := c.hState.(StateConnected); !connected && c.config.NonBlocking { if _, connected := c.hState.(stateConnected); !connected {
return 0, errors.New("Read called before the handshake completed") return 0, errors.New("Read called before the handshake completed")
} }
logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer)) logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer))
...@@ -661,7 +661,7 @@ func (c *Conn) HandshakeSetup() Alert { ...@@ -661,7 +661,7 @@ func (c *Conn) HandshakeSetup() Alert {
} }
if c.isClient { if c.isClient {
state, actions, alert = ClientStateStart{Config: c.config, Opts: opts, hsCtx: c.hsCtx}.Next(nil) state, actions, alert = clientStateStart{Config: c.config, Opts: opts, hsCtx: c.hsCtx}.Next(nil)
if alert != AlertNoAlert { if alert != AlertNoAlert {
logf(logTypeHandshake, "Error initializing client state: %v", alert) logf(logTypeHandshake, "Error initializing client state: %v", alert)
return alert return alert
...@@ -688,7 +688,7 @@ func (c *Conn) HandshakeSetup() Alert { ...@@ -688,7 +688,7 @@ func (c *Conn) HandshakeSetup() Alert {
return AlertInternalError return AlertInternalError
} }
} }
state = ServerStateStart{Config: c.config, conn: c, hsCtx: c.hsCtx} state = serverStateStart{Config: c.config, conn: c, hsCtx: c.hsCtx}
} }
c.hState = state c.hState = state
...@@ -751,7 +751,7 @@ func (c *Conn) Handshake() Alert { ...@@ -751,7 +751,7 @@ func (c *Conn) Handshake() Alert {
logf(logTypeHandshake, "(Re-)entering handshake, state=%v", c.hState) logf(logTypeHandshake, "(Re-)entering handshake, state=%v", c.hState)
state := c.hState state := c.hState
_, connected := state.(StateConnected) _, connected := state.(stateConnected)
hmr := &handshakeMessageReaderImpl{hsCtx: &c.hsCtx} hmr := &handshakeMessageReaderImpl{hsCtx: &c.hsCtx}
for !connected { for !connected {
...@@ -784,9 +784,9 @@ func (c *Conn) Handshake() Alert { ...@@ -784,9 +784,9 @@ func (c *Conn) Handshake() Alert {
c.hState = state c.hState = state
logf(logTypeHandshake, "state is now %s", c.GetHsState()) logf(logTypeHandshake, "state is now %s", c.GetHsState())
_, connected = state.(StateConnected) _, connected = state.(stateConnected)
if connected { if connected {
c.state = state.(StateConnected) c.state = state.(stateConnected)
c.handshakeComplete = true c.handshakeComplete = true
} }
...@@ -852,7 +852,7 @@ func (c *Conn) GetHsState() State { ...@@ -852,7 +852,7 @@ func (c *Conn) GetHsState() State {
} }
func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
_, connected := c.hState.(StateConnected) _, connected := c.hState.(stateConnected)
if !connected { if !connected {
return nil, fmt.Errorf("Cannot compute exporter when state is not connected") return nil, fmt.Errorf("Cannot compute exporter when state is not connected")
} }
......
...@@ -11,9 +11,11 @@ import ( ...@@ -11,9 +11,11 @@ import (
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix"
"encoding/asn1" "encoding/asn1"
"fmt" "fmt"
"math/big" "math/big"
"time"
"golang.org/x/crypto/curve25519" "golang.org/x/crypto/curve25519"
...@@ -616,3 +618,50 @@ func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet { ...@@ -616,3 +618,50 @@ func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet {
iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen), iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen),
} }
} }
func MakeNewSelfSignedCert(name string, alg SignatureScheme) (crypto.Signer, *x509.Certificate, error) {
priv, err := newSigningKey(alg)
if err != nil {
return nil, nil, err
}
cert, err := newSelfSigned(name, alg, priv)
if err != nil {
return nil, nil, err
}
return priv, cert, nil
}
func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) {
sigAlg, ok := x509AlgMap[alg]
if !ok {
return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg)
}
if len(name) == 0 {
return nil, fmt.Errorf("tls.selfsigned: No name provided")
}
serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0))
if err != nil {
return nil, err
}
template := &x509.Certificate{
SerialNumber: serial,
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(0, 0, 1),
SignatureAlgorithm: sigAlg,
Subject: pkix.Name{CommonName: name},
DNSNames: []string{name},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv)
if err != nil {
return nil, err
}
// It is safe to ignore the error here because we're parsing known-good data
cert, _ := x509.ParseCertificate(der)
return cert, nil
}
...@@ -81,8 +81,8 @@ func (hc *HandshakeContext) SetVersion(version uint16) { ...@@ -81,8 +81,8 @@ func (hc *HandshakeContext) SetVersion(version uint16) {
} }
} }
// StateConnected is symmetric between client and server // stateConnected is symmetric between client and server
type StateConnected struct { type stateConnected struct {
Params ConnectionParameters Params ConnectionParameters
hsCtx HandshakeContext hsCtx HandshakeContext
isClient bool isClient bool
...@@ -95,16 +95,16 @@ type StateConnected struct { ...@@ -95,16 +95,16 @@ type StateConnected struct {
verifiedChains [][]*x509.Certificate verifiedChains [][]*x509.Certificate
} }
var _ HandshakeState = &StateConnected{} var _ HandshakeState = &stateConnected{}
func (state StateConnected) State() State { func (state stateConnected) State() State {
if state.isClient { if state.isClient {
return StateClientConnected return StateClientConnected
} }
return StateServerConnected return StateServerConnected
} }
func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) { func (state *stateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) {
var trafficKeys keySet var trafficKeys keySet
if state.isClient { if state.isClient {
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret, state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
...@@ -130,7 +130,7 @@ func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAct ...@@ -130,7 +130,7 @@ func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAct
return toSend, AlertNoAlert return toSend, AlertNoAlert
} }
func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) { func (state *stateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) {
tkt, err := NewSessionTicket(length, lifetime) tkt, err := NewSessionTicket(length, lifetime)
if err != nil { if err != nil {
logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err) logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err)
...@@ -172,11 +172,11 @@ func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLif ...@@ -172,11 +172,11 @@ func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLif
} }
// Next does nothing for this state. // Next does nothing for this state.
func (state StateConnected) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { func (state stateConnected) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
return state, nil, AlertNoAlert return state, nil, AlertNoAlert
} }
func (state StateConnected) ProcessMessage(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { func (state stateConnected) ProcessMessage(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil { if hm == nil {
logf(logTypeHandshake, "[StateConnected] Unexpected message") logf(logTypeHandshake, "[StateConnected] Unexpected message")
return nil, nil, AlertUnexpectedMessage return nil, nil, AlertUnexpectedMessage
......
...@@ -145,7 +145,7 @@ ...@@ -145,7 +145,7 @@
"importpath": "github.com/lucas-clemente/quic-go", "importpath": "github.com/lucas-clemente/quic-go",
"repository": "https://github.com/lucas-clemente/quic-go", "repository": "https://github.com/lucas-clemente/quic-go",
"vcs": "git", "vcs": "git",
"revision": "d71850eb2ff581620f2f5742b558a97de22c13f6", "revision": "9fa739409e6edddbbd47c8031cb7bb3d1a209cc8",
"branch": "master", "branch": "master",
"notests": true "notests": true
}, },
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment