Commit 208f023d authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Make upConnections generic.

parent 0f96f944
...@@ -7,338 +7,38 @@ package main ...@@ -7,338 +7,38 @@ package main
import ( import (
"errors" "errors"
"sync"
"sync/atomic"
"sfu/estimator"
"sfu/jitter"
"sfu/packetcache"
"sfu/rtptime"
"github.com/pion/rtp" "github.com/pion/rtp"
"github.com/pion/webrtc/v2" "github.com/pion/webrtc/v2"
) )
type localTrackAction struct {
add bool
track downTrack
}
type upTrack struct {
track *webrtc.Track
label string
rate *estimator.Estimator
cache *packetcache.Cache
jitter *jitter.Estimator
maxBitrate uint64
lastPLI uint64
lastFIR uint64
firSeqno uint32
localCh chan localTrackAction // signals that local has changed
writerDone chan struct{} // closed when the loop dies
mu sync.Mutex
local []downTrack
srTime uint64
srNTPTime uint64
srRTPTime uint32
}
func (up *upTrack) notifyLocal(add bool, track downTrack) {
select {
case up.localCh <- localTrackAction{add, track}:
case <-up.writerDone:
}
}
func (up *upTrack) addLocal(local downTrack) error {
up.mu.Lock()
for _, t := range up.local {
if t == local {
up.mu.Unlock()
return nil
}
}
up.local = append(up.local, local)
up.mu.Unlock()
up.notifyLocal(true, local)
return nil
}
func (up *upTrack) delLocal(local downTrack) bool {
up.mu.Lock()
for i, l := range up.local {
if l == local {
up.local = append(up.local[:i], up.local[i+1:]...)
up.mu.Unlock()
up.notifyLocal(false, l)
return true
}
}
up.mu.Unlock()
return false
}
func (up *upTrack) getLocal() []downTrack {
up.mu.Lock()
defer up.mu.Unlock()
local := make([]downTrack, len(up.local))
copy(local, up.local)
return local
}
func (up *upTrack) hasRtcpFb(tpe, parameter string) bool {
for _, fb := range up.track.Codec().RTCPFeedback {
if fb.Type == tpe && fb.Parameter == parameter {
return true
}
}
return false
}
type iceConnection interface {
addICECandidate(candidate *webrtc.ICECandidateInit) error
flushICECandidates() error
}
type upConnection struct {
id string
label string
pc *webrtc.PeerConnection
labels map[string]string
iceCandidates []*webrtc.ICECandidateInit
mu sync.Mutex
closed bool
tracks []*upTrack
local []downConnection
}
var ErrConnectionClosed = errors.New("connection is closed") var ErrConnectionClosed = errors.New("connection is closed")
var ErrKeyframeNeeded = errors.New("keyframe needed")
func (up *upConnection) getTracks() []*upTrack { type upConnection interface {
up.mu.Lock() addLocal(downConnection) error
defer up.mu.Unlock() delLocal(downConnection) bool
tracks := make([]*upTrack, len(up.tracks)) Id() string
copy(tracks, up.tracks) Label() string
return tracks
}
func (up *upConnection) addLocal(local downConnection) error {
up.mu.Lock()
defer up.mu.Unlock()
if up.closed {
return ErrConnectionClosed
}
for _, t := range up.local {
if t == local {
return nil
}
}
up.local = append(up.local, local)
return nil
}
func (up *upConnection) delLocal(local downConnection) bool {
up.mu.Lock()
defer up.mu.Unlock()
for i, l := range up.local {
if l == local {
up.local = append(up.local[:i], up.local[i+1:]...)
return true
}
}
return false
}
func (up *upConnection) getLocal() []downConnection {
up.mu.Lock()
defer up.mu.Unlock()
local := make([]downConnection, len(up.local))
copy(local, up.local)
return local
}
func (up *upConnection) Close() error {
up.mu.Lock()
defer up.mu.Unlock()
go func(local []downConnection) {
for _, l := range local {
l.Close()
}
}(up.local)
up.local = nil
up.closed = true
return up.pc.Close()
}
func (up *upConnection) addICECandidate(candidate *webrtc.ICECandidateInit) error {
if up.pc.RemoteDescription() != nil {
return up.pc.AddICECandidate(*candidate)
}
up.iceCandidates = append(up.iceCandidates, candidate)
return nil
}
func flushICECandidates(pc *webrtc.PeerConnection, candidates []*webrtc.ICECandidateInit) error {
if pc.RemoteDescription() == nil {
return errors.New("flushICECandidates called in bad state")
}
var err error
for _, candidate := range candidates {
err2 := pc.AddICECandidate(*candidate)
if err == nil {
err = err2
}
}
return err
}
func (up *upConnection) flushICECandidates() error {
err := flushICECandidates(up.pc, up.iceCandidates)
up.iceCandidates = nil
return err
}
func getUpMid(pc *webrtc.PeerConnection, track *webrtc.Track) string {
for _, t := range pc.GetTransceivers() {
if t.Receiver() != nil && t.Receiver().Track() == track {
return t.Mid()
}
}
return ""
}
// called locked
func (up *upConnection) complete() bool {
for mid, _ := range up.labels {
found := false
for _, t := range up.tracks {
m := getUpMid(up.pc, t.track)
if m == mid {
found = true
break
}
}
if !found {
return false
}
}
return true
}
type bitrate struct {
bitrate uint64
jiffies uint64
}
const receiverReportTimeout = 8 * rtptime.JiffiesPerSec
func (br *bitrate) Set(bitrate uint64, now uint64) {
// this is racy -- a reader might read the
// data between the two writes. This shouldn't
// matter, we'll recover at the next sample.
atomic.StoreUint64(&br.bitrate, bitrate)
atomic.StoreUint64(&br.jiffies, now)
}
func (br *bitrate) Get(now uint64) uint64 {
ts := atomic.LoadUint64(&br.jiffies)
if now < ts || now-ts > receiverReportTimeout {
return ^uint64(0)
}
return atomic.LoadUint64(&br.bitrate)
}
type receiverStats struct {
loss uint32
jitter uint32
jiffies uint64
} }
func (s *receiverStats) Set(loss uint8, jitter uint32, now uint64) { type upTrack interface {
atomic.StoreUint32(&s.loss, uint32(loss)) addLocal(downTrack) error
atomic.StoreUint32(&s.jitter, jitter) delLocal(downTrack) bool
atomic.StoreUint64(&s.jiffies, now) Label() string
Codec() *webrtc.RTPCodec
// get a recent packet. Returns 0 if the packet is not in cache.
getRTP(seqno uint16, result []byte) uint16
// returns the last timestamp, if possible
getTimestamp() (uint32, bool)
} }
func (s *receiverStats) Get(now uint64) (uint8, uint32) { type downConnection interface {
ts := atomic.LoadUint64(&s.jiffies) Close() error
if now < ts || now > ts+receiverReportTimeout {
return 0, 0
}
return uint8(atomic.LoadUint32(&s.loss)), atomic.LoadUint32(&s.jitter)
} }
var ErrKeyframeNeeded = errors.New("keyframe needed")
type downTrack interface { type downTrack interface {
WriteRTP(packat *rtp.Packet) error WriteRTP(packat *rtp.Packet) error
Accumulate(bytes uint32) Accumulate(bytes uint32)
GetMaxBitrate(now uint64) uint64 GetMaxBitrate(now uint64) uint64
} }
type rtpDownTrack struct {
track *webrtc.Track
remote *upTrack
maxLossBitrate *bitrate
maxREMBBitrate *bitrate
rate *estimator.Estimator
stats *receiverStats
srTime uint64
srNTPTime uint64
rtt uint64
}
func (down *rtpDownTrack) WriteRTP(packet *rtp.Packet) error {
return down.track.WriteRTP(packet)
}
func (down *rtpDownTrack) Accumulate(bytes uint32) {
down.rate.Accumulate(bytes)
}
func (down *rtpDownTrack) GetMaxBitrate(now uint64) uint64 {
br1 := down.maxLossBitrate.Get(now)
br2 := down.maxREMBBitrate.Get(now)
if br1 < br2 {
return br1
}
return br2
}
type downConnection interface {
Close() error
}
type rtpDownConnection struct {
id string
client *webClient
pc *webrtc.PeerConnection
remote *upConnection
tracks []*rtpDownTrack
iceCandidates []*webrtc.ICECandidateInit
}
func (down *rtpDownConnection) Close() error {
return down.client.action(delConnAction{down.id})
}
func (down *rtpDownConnection) addICECandidate(candidate *webrtc.ICECandidateInit) error {
if down.pc.RemoteDescription() != nil {
return down.pc.AddICECandidate(*candidate)
}
down.iceCandidates = append(down.iceCandidates, candidate)
return nil
}
func (down *rtpDownConnection) flushICECandidates() error {
err := flushICECandidates(down.pc, down.iceCandidates)
down.iceCandidates = nil
return err
}
...@@ -24,15 +24,15 @@ type diskClient struct { ...@@ -24,15 +24,15 @@ type diskClient struct {
closed bool closed bool
} }
func (client *diskClient) getGroup() *group { func (client *diskClient) Group() *group {
return client.group return client.group
} }
func (client *diskClient) getId() string { func (client *diskClient) Id() string {
return client.id return client.id
} }
func (client *diskClient) getUsername() string { func (client *diskClient) Username() string {
return "RECORDING" return "RECORDING"
} }
...@@ -52,7 +52,7 @@ func (client *diskClient) Close() error { ...@@ -52,7 +52,7 @@ func (client *diskClient) Close() error {
return nil return nil
} }
func (client *diskClient) pushConn(conn *upConnection, tracks []*upTrack, label string) error { func (client *diskClient) pushConn(conn upConnection, tracks []upTrack, label string) error {
client.mu.Lock() client.mu.Lock()
defer client.mu.Unlock() defer client.mu.Unlock()
...@@ -75,15 +75,13 @@ func (client *diskClient) pushConn(conn *upConnection, tracks []*upTrack, label ...@@ -75,15 +75,13 @@ func (client *diskClient) pushConn(conn *upConnection, tracks []*upTrack, label
return nil return nil
} }
var _ client = &diskClient{}
type diskConn struct { type diskConn struct {
directory string directory string
label string label string
mu sync.Mutex mu sync.Mutex
file *os.File file *os.File
remote *upConnection remote upConnection
tracks []*diskTrack tracks []*diskTrack
width, height uint32 width, height uint32
} }
...@@ -154,7 +152,7 @@ func openDiskFile(directory, label string) (*os.File, error) { ...@@ -154,7 +152,7 @@ func openDiskFile(directory, label string) (*os.File, error) {
} }
type diskTrack struct { type diskTrack struct {
remote *upTrack remote upTrack
conn *diskConn conn *diskConn
writer webm.BlockWriteCloser writer webm.BlockWriteCloser
...@@ -162,7 +160,7 @@ type diskTrack struct { ...@@ -162,7 +160,7 @@ type diskTrack struct {
timestamp uint32 timestamp uint32
} }
func newDiskConn(directory, label string, up *upConnection, remoteTracks []*upTrack) (*diskConn, error) { func newDiskConn(directory, label string, up upConnection, remoteTracks []upTrack) (*diskConn, error) {
conn := diskConn{ conn := diskConn{
directory: directory, directory: directory,
label: label, label: label,
...@@ -172,7 +170,7 @@ func newDiskConn(directory, label string, up *upConnection, remoteTracks []*upTr ...@@ -172,7 +170,7 @@ func newDiskConn(directory, label string, up *upConnection, remoteTracks []*upTr
video := false video := false
for _, remote := range remoteTracks { for _, remote := range remoteTracks {
var builder *samplebuilder.SampleBuilder var builder *samplebuilder.SampleBuilder
switch remote.track.Codec().Name { switch remote.Codec().Name {
case webrtc.Opus: case webrtc.Opus:
builder = samplebuilder.New(16, &codecs.OpusPacket{}) builder = samplebuilder.New(16, &codecs.OpusPacket{})
case webrtc.VP8: case webrtc.VP8:
...@@ -245,7 +243,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error { ...@@ -245,7 +243,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error {
keyframe := true keyframe := true
switch t.remote.track.Codec().Name { switch t.remote.Codec().Name {
case webrtc.VP8: case webrtc.VP8:
if len(sample.Data) < 1 { if len(sample.Data) < 1 {
return nil return nil
...@@ -265,7 +263,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error { ...@@ -265,7 +263,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error {
return nil return nil
} }
tm := t.timestamp / (t.remote.track.Codec().ClockRate / 1000) tm := t.timestamp / (t.remote.Codec().ClockRate / 1000)
_, err := t.writer.Write(keyframe, int64(tm), sample.Data) _, err := t.writer.Write(keyframe, int64(tm), sample.Data)
if err != nil { if err != nil {
return err return err
...@@ -275,7 +273,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error { ...@@ -275,7 +273,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error {
// called locked // called locked
func (t *diskTrack) initWriter(data []byte) error { func (t *diskTrack) initWriter(data []byte) error {
switch t.remote.track.Codec().Name { switch t.remote.Codec().Name {
case webrtc.VP8: case webrtc.VP8:
if len(data) < 10 { if len(data) < 10 {
return nil return nil
...@@ -300,9 +298,9 @@ func (conn *diskConn) initWriter(width, height uint32) error { ...@@ -300,9 +298,9 @@ func (conn *diskConn) initWriter(width, height uint32) error {
} }
var entries []webm.TrackEntry var entries []webm.TrackEntry
for i, t := range conn.tracks { for i, t := range conn.tracks {
codec := t.remote.track.Codec() codec := t.remote.Codec()
var entry webm.TrackEntry var entry webm.TrackEntry
switch t.remote.track.Codec().Name { switch t.remote.Codec().Name {
case webrtc.Opus: case webrtc.Opus:
entry = webm.TrackEntry{ entry = webm.TrackEntry{
Name: "Audio", Name: "Audio",
......
...@@ -22,10 +22,10 @@ import ( ...@@ -22,10 +22,10 @@ import (
) )
type client interface { type client interface {
getGroup() *group Group() *group
getId() string Id() string
getUsername() string Username() string
pushConn(conn *upConnection, tracks []*upTrack, label string) error pushConn(conn upConnection, tracks []upTrack, label string) error
pushClient(id, username string, add bool) error pushClient(id, username string, add bool) error
} }
...@@ -58,8 +58,8 @@ type delConnAction struct { ...@@ -58,8 +58,8 @@ type delConnAction struct {
} }
type addConnAction struct { type addConnAction struct {
conn *upConnection conn upConnection
tracks []*upTrack tracks []upTrack
} }
type addLabelAction struct { type addLabelAction struct {
...@@ -230,20 +230,20 @@ func addClient(name string, c client, user, pass string) (*group, error) { ...@@ -230,20 +230,20 @@ func addClient(name string, c client, user, pass string) (*group, error) {
return nil, userError("too many users") return nil, userError("too many users")
} }
} }
if g.clients[c.getId()] != nil { if g.clients[c.Id()] != nil {
return nil, protocolError("duplicate client id") return nil, protocolError("duplicate client id")
} }
g.clients[c.getId()] = c g.clients[c.Id()] = c
go func(clients []client) { go func(clients []client) {
c.pushClient(c.getId(), c.getUsername(), true) c.pushClient(c.Id(), c.Username(), true)
for _, cc := range clients { for _, cc := range clients {
err := c.pushClient(cc.getId(), cc.getUsername(), true) err := c.pushClient(cc.Id(), cc.Username(), true)
if err == ErrClientDead { if err == ErrClientDead {
return return
} }
cc.pushClient(c.getId(), c.getUsername(), true) cc.pushClient(c.Id(), c.Username(), true)
} }
}(g.getClientsUnlocked(c)) }(g.getClientsUnlocked(c))
...@@ -251,19 +251,19 @@ func addClient(name string, c client, user, pass string) (*group, error) { ...@@ -251,19 +251,19 @@ func addClient(name string, c client, user, pass string) (*group, error) {
} }
func delClient(c client) { func delClient(c client) {
g := c.getGroup() g := c.Group()
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
if g.clients[c.getId()] != c { if g.clients[c.Id()] != c {
log.Printf("Deleting unknown client") log.Printf("Deleting unknown client")
return return
} }
delete(g.clients, c.getId()) delete(g.clients, c.Id())
go func(clients []client) { go func(clients []client) {
for _, cc := range clients { for _, cc := range clients {
cc.pushClient(c.getId(), c.getUsername(), false) cc.pushClient(c.Id(), c.Username(), false)
} }
}(g.getClientsUnlocked(nil)) }(g.getClientsUnlocked(nil))
} }
......
// Copyright (c) 2020 by Juliusz Chroboczek.
// This is not open source software. Copy it, and I'll break into your
// house and tell your three year-old that Santa doesn't exist.
package main
import (
"errors"
"io"
"log"
"math/bits"
"sync"
"sync/atomic"
"time"
"sfu/estimator"
"sfu/jitter"
"sfu/packetcache"
"sfu/rtptime"
"github.com/pion/rtcp"
"github.com/pion/rtp"
"github.com/pion/webrtc/v2"
)
type bitrate struct {
bitrate uint64
jiffies uint64
}
func (br *bitrate) Set(bitrate uint64, now uint64) {
atomic.StoreUint64(&br.bitrate, bitrate)
atomic.StoreUint64(&br.jiffies, now)
}
func (br *bitrate) Get(now uint64) uint64 {
ts := atomic.LoadUint64(&br.jiffies)
if now < ts || now-ts > receiverReportTimeout {
return ^uint64(0)
}
return atomic.LoadUint64(&br.bitrate)
}
type receiverStats struct {
loss uint32
jitter uint32
jiffies uint64
}
func (s *receiverStats) Set(loss uint8, jitter uint32, now uint64) {
atomic.StoreUint32(&s.loss, uint32(loss))
atomic.StoreUint32(&s.jitter, jitter)
atomic.StoreUint64(&s.jiffies, now)
}
func (s *receiverStats) Get(now uint64) (uint8, uint32) {
ts := atomic.LoadUint64(&s.jiffies)
if now < ts || now > ts+receiverReportTimeout {
return 0, 0
}
return uint8(atomic.LoadUint32(&s.loss)), atomic.LoadUint32(&s.jitter)
}
const receiverReportTimeout = 8 * rtptime.JiffiesPerSec
type iceConnection interface {
addICECandidate(candidate *webrtc.ICECandidateInit) error
flushICECandidates() error
}
type rtpDownTrack struct {
track *webrtc.Track
remote upTrack
maxLossBitrate *bitrate
maxREMBBitrate *bitrate
rate *estimator.Estimator
stats *receiverStats
srTime uint64
srNTPTime uint64
rtt uint64
}
func (down *rtpDownTrack) WriteRTP(packet *rtp.Packet) error {
return down.track.WriteRTP(packet)
}
func (down *rtpDownTrack) Accumulate(bytes uint32) {
down.rate.Accumulate(bytes)
}
func (down *rtpDownTrack) GetMaxBitrate(now uint64) uint64 {
br1 := down.maxLossBitrate.Get(now)
br2 := down.maxREMBBitrate.Get(now)
if br1 < br2 {
return br1
}
return br2
}
type rtpDownConnection struct {
id string
pc *webrtc.PeerConnection
remote upConnection
tracks []*rtpDownTrack
iceCandidates []*webrtc.ICECandidateInit
close func() error
}
func newDownConn(id string, remote upConnection) (*rtpDownConnection, error) {
pc, err := groups.api.NewPeerConnection(iceConfiguration())
if err != nil {
return nil, err
}
pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) {
log.Printf("Got track on downstream connection")
})
conn := &rtpDownConnection{
id: id,
pc: pc,
remote: remote,
}
return conn, nil
}
func (down *rtpDownConnection) Close() error {
return down.close()
}
func (down *rtpDownConnection) addICECandidate(candidate *webrtc.ICECandidateInit) error {
if down.pc.RemoteDescription() != nil {
return down.pc.AddICECandidate(*candidate)
}
down.iceCandidates = append(down.iceCandidates, candidate)
return nil
}
func flushICECandidates(pc *webrtc.PeerConnection, candidates []*webrtc.ICECandidateInit) error {
if pc.RemoteDescription() == nil {
return errors.New("flushICECandidates called in bad state")
}
var err error
for _, candidate := range candidates {
err2 := pc.AddICECandidate(*candidate)
if err == nil {
err = err2
}
}
return err
}
func (down *rtpDownConnection) flushICECandidates() error {
err := flushICECandidates(down.pc, down.iceCandidates)
down.iceCandidates = nil
return err
}
type rtpUpTrack struct {
track *webrtc.Track
label string
rate *estimator.Estimator
cache *packetcache.Cache
jitter *jitter.Estimator
maxBitrate uint64
lastPLI uint64
lastFIR uint64
firSeqno uint32
localCh chan localTrackAction
writerDone chan struct{}
mu sync.Mutex
local []downTrack
srTime uint64
srNTPTime uint64
srRTPTime uint32
}
type localTrackAction struct {
add bool
track downTrack
}
func (up *rtpUpTrack) notifyLocal(add bool, track downTrack) {
select {
case up.localCh <- localTrackAction{add, track}:
case <-up.writerDone:
}
}
func (up *rtpUpTrack) addLocal(local downTrack) error {
up.mu.Lock()
for _, t := range up.local {
if t == local {
up.mu.Unlock()
return nil
}
}
up.local = append(up.local, local)
up.mu.Unlock()
up.notifyLocal(true, local)
return nil
}
func (up *rtpUpTrack) delLocal(local downTrack) bool {
up.mu.Lock()
for i, l := range up.local {
if l == local {
up.local = append(up.local[:i], up.local[i+1:]...)
up.mu.Unlock()
up.notifyLocal(false, l)
return true
}
}
up.mu.Unlock()
return false
}
func (up *rtpUpTrack) getLocal() []downTrack {
up.mu.Lock()
defer up.mu.Unlock()
local := make([]downTrack, len(up.local))
copy(local, up.local)
return local
}
func (up *rtpUpTrack) getRTP(seqno uint16, result []byte) uint16 {
return up.cache.Get(seqno, result)
}
func (up *rtpUpTrack) getTimestamp() (uint32, bool) {
buf := make([]byte, packetcache.BufSize)
l := up.cache.GetLast(buf)
if l == 0 {
return 0, false
}
var packet rtp.Packet
err := packet.Unmarshal(buf)
if err != nil {
return 0, false
}
return packet.Timestamp, true
}
func (up *rtpUpTrack) Label() string {
return up.label
}
func (up *rtpUpTrack) Codec() *webrtc.RTPCodec {
return up.track.Codec()
}
func (up *rtpUpTrack) hasRtcpFb(tpe, parameter string) bool {
for _, fb := range up.track.Codec().RTCPFeedback {
if fb.Type == tpe && fb.Parameter == parameter {
return true
}
}
return false
}
type rtpUpConnection struct {
id string
label string
pc *webrtc.PeerConnection
labels map[string]string
iceCandidates []*webrtc.ICECandidateInit
mu sync.Mutex
closed bool
tracks []*rtpUpTrack
local []downConnection
}
func (up *rtpUpConnection) getTracks() []*rtpUpTrack {
up.mu.Lock()
defer up.mu.Unlock()
tracks := make([]*rtpUpTrack, len(up.tracks))
copy(tracks, up.tracks)
return tracks
}
func (up *rtpUpConnection) Id() string {
return up.id
}
func (up *rtpUpConnection) Label() string {
return up.label
}
func (up *rtpUpConnection) addLocal(local downConnection) error {
up.mu.Lock()
defer up.mu.Unlock()
if up.closed {
return ErrConnectionClosed
}
for _, t := range up.local {
if t == local {
return nil
}
}
up.local = append(up.local, local)
return nil
}
func (up *rtpUpConnection) delLocal(local downConnection) bool {
up.mu.Lock()
defer up.mu.Unlock()
for i, l := range up.local {
if l == local {
up.local = append(up.local[:i], up.local[i+1:]...)
return true
}
}
return false
}
func (up *rtpUpConnection) getLocal() []downConnection {
up.mu.Lock()
defer up.mu.Unlock()
local := make([]downConnection, len(up.local))
copy(local, up.local)
return local
}
func (up *rtpUpConnection) Close() error {
up.mu.Lock()
defer up.mu.Unlock()
go func(local []downConnection) {
for _, l := range local {
l.Close()
}
}(up.local)
up.local = nil
up.closed = true
return up.pc.Close()
}
func (up *rtpUpConnection) addICECandidate(candidate *webrtc.ICECandidateInit) error {
if up.pc.RemoteDescription() != nil {
return up.pc.AddICECandidate(*candidate)
}
up.iceCandidates = append(up.iceCandidates, candidate)
return nil
}
func (up *rtpUpConnection) flushICECandidates() error {
err := flushICECandidates(up.pc, up.iceCandidates)
up.iceCandidates = nil
return err
}
func getTrackMid(pc *webrtc.PeerConnection, track *webrtc.Track) string {
for _, t := range pc.GetTransceivers() {
if t.Receiver() != nil && t.Receiver().Track() == track {
return t.Mid()
}
}
return ""
}
// called locked
func (up *rtpUpConnection) complete() bool {
for mid, _ := range up.labels {
found := false
for _, t := range up.tracks {
m := getTrackMid(up.pc, t.track)
if m == mid {
found = true
break
}
}
if !found {
return false
}
}
return true
}
func newUpConn(c client, id string) (*rtpUpConnection, error) {
pc, err := groups.api.NewPeerConnection(iceConfiguration())
if err != nil {
return nil, err
}
_, err = pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio,
webrtc.RtpTransceiverInit{
Direction: webrtc.RTPTransceiverDirectionRecvonly,
},
)
if err != nil {
pc.Close()
return nil, err
}
_, err = pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo,
webrtc.RtpTransceiverInit{
Direction: webrtc.RTPTransceiverDirectionRecvonly,
},
)
if err != nil {
pc.Close()
return nil, err
}
conn := &rtpUpConnection{id: id, pc: pc}
pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) {
conn.mu.Lock()
defer conn.mu.Unlock()
mid := getTrackMid(pc, remote)
if mid == "" {
log.Printf("Couldn't get track's mid")
return
}
label, ok := conn.labels[mid]
if !ok {
log.Printf("Couldn't get track's label")
isvideo := remote.Kind() == webrtc.RTPCodecTypeVideo
if isvideo {
label = "video"
} else {
label = "audio"
}
}
track := &rtpUpTrack{
track: remote,
label: label,
cache: packetcache.New(32),
rate: estimator.New(time.Second),
jitter: jitter.New(remote.Codec().ClockRate),
maxBitrate: ^uint64(0),
localCh: make(chan localTrackAction, 2),
writerDone: make(chan struct{}),
}
conn.tracks = append(conn.tracks, track)
if remote.Kind() == webrtc.RTPCodecTypeVideo {
atomic.AddUint32(&c.Group().videoCount, 1)
}
go readLoop(conn, track)
go rtcpUpListener(conn, track, receiver)
if conn.complete() {
tracks := make([]upTrack, len(conn.tracks))
for i, t := range conn.tracks {
tracks[i] = t
}
clients := c.Group().getClients(c)
for _, cc := range clients {
cc.pushConn(conn, tracks, conn.label)
}
go rtcpUpSender(conn)
}
})
return conn, nil
}
type packetIndex struct {
seqno uint16
index uint16
}
func readLoop(conn *rtpUpConnection, track *rtpUpTrack) {
isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo
ch := make(chan packetIndex, 32)
defer close(ch)
go writeLoop(conn, track, ch)
buf := make([]byte, packetcache.BufSize)
var packet rtp.Packet
drop := 0
for {
bytes, err := track.track.Read(buf)
if err != nil {
if err != io.EOF {
log.Printf("%v", err)
}
break
}
track.rate.Accumulate(uint32(bytes))
err = packet.Unmarshal(buf[:bytes])
if err != nil {
log.Printf("%v", err)
continue
}
track.jitter.Accumulate(packet.Timestamp)
first, index :=
track.cache.Store(packet.SequenceNumber, buf[:bytes])
if packet.SequenceNumber-first > 24 {
found, first, bitmap := track.cache.BitmapGet()
if found {
err := conn.sendNACK(track, first, bitmap)
if err != nil {
log.Printf("%v", err)
}
}
}
if drop > 0 {
if packet.Marker {
// last packet in frame
drop = 0
} else {
drop--
}
continue
}
select {
case ch <- packetIndex{packet.SequenceNumber, index}:
default:
if isvideo {
// the writer is congested. Drop until
// the end of the frame.
if isvideo && !packet.Marker {
drop = 7
}
}
}
}
}
func writeLoop(conn *rtpUpConnection, track *rtpUpTrack, ch <-chan packetIndex) {
defer close(track.writerDone)
buf := make([]byte, packetcache.BufSize)
var packet rtp.Packet
local := make([]downTrack, 0)
firSent := false
for {
select {
case action := <-track.localCh:
if action.add {
local = append(local, action.track)
firSent = false
} else {
found := false
for i, t := range local {
if t == action.track {
local = append(local[:i], local[i+1:]...)
found = true
break
}
}
if !found {
log.Printf("Deleting unknown track!")
}
}
case pi, ok := <-ch:
if !ok {
return
}
bytes := track.cache.GetAt(pi.seqno, pi.index, buf)
if bytes == 0 {
continue
}
err := packet.Unmarshal(buf[:bytes])
if err != nil {
log.Printf("%v", err)
continue
}
kfNeeded := false
for _, l := range local {
err := l.WriteRTP(&packet)
if err != nil {
if err == ErrKeyframeNeeded {
kfNeeded = true
} else if err != io.ErrClosedPipe {
log.Printf("WriteRTP: %v", err)
}
continue
}
l.Accumulate(uint32(bytes))
}
if kfNeeded {
err := conn.sendFIR(track, !firSent)
if err == ErrUnsupportedFeedback {
err := conn.sendPLI(track)
if err != nil &&
err != ErrUnsupportedFeedback {
log.Printf("sendPLI: %v", err)
}
} else if err != nil {
log.Printf("sendFIR: %v", err)
}
firSent = true
}
}
}
}
var ErrUnsupportedFeedback = errors.New("unsupported feedback type")
var ErrRateLimited = errors.New("rate limited")
func (up *rtpUpConnection) sendPLI(track *rtpUpTrack) error {
if !track.hasRtcpFb("nack", "pli") {
return ErrUnsupportedFeedback
}
last := atomic.LoadUint64(&track.lastPLI)
now := rtptime.Jiffies()
if now >= last && now-last < rtptime.JiffiesPerSec/5 {
return ErrRateLimited
}
atomic.StoreUint64(&track.lastPLI, now)
return sendPLI(up.pc, track.track.SSRC())
}
func sendPLI(pc *webrtc.PeerConnection, ssrc uint32) error {
return pc.WriteRTCP([]rtcp.Packet{
&rtcp.PictureLossIndication{MediaSSRC: ssrc},
})
}
func (up *rtpUpConnection) sendFIR(track *rtpUpTrack, increment bool) error {
// we need to reliably increment the seqno, even if we are going
// to drop the packet due to rate limiting.
var seqno uint8
if increment {
seqno = uint8(atomic.AddUint32(&track.firSeqno, 1) & 0xFF)
} else {
seqno = uint8(atomic.LoadUint32(&track.firSeqno) & 0xFF)
}
if !track.hasRtcpFb("ccm", "fir") {
return ErrUnsupportedFeedback
}
last := atomic.LoadUint64(&track.lastFIR)
now := rtptime.Jiffies()
if now >= last && now-last < rtptime.JiffiesPerSec/5 {
return ErrRateLimited
}
atomic.StoreUint64(&track.lastFIR, now)
return sendFIR(up.pc, track.track.SSRC(), seqno)
}
func sendFIR(pc *webrtc.PeerConnection, ssrc uint32, seqno uint8) error {
return pc.WriteRTCP([]rtcp.Packet{
&rtcp.FullIntraRequest{
FIR: []rtcp.FIREntry{
rtcp.FIREntry{
SSRC: ssrc,
SequenceNumber: seqno,
},
},
},
})
}
func sendREMB(pc *webrtc.PeerConnection, ssrc uint32, bitrate uint64) error {
return pc.WriteRTCP([]rtcp.Packet{
&rtcp.ReceiverEstimatedMaximumBitrate{
Bitrate: bitrate,
SSRCs: []uint32{ssrc},
},
})
}
func (up *rtpUpConnection) sendNACK(track *rtpUpTrack, first uint16, bitmap uint16) error {
if !track.hasRtcpFb("nack", "") {
return nil
}
err := sendNACK(up.pc, track.track.SSRC(), first, bitmap)
if err == nil {
track.cache.Expect(1 + bits.OnesCount16(bitmap))
}
return err
}
func sendNACK(pc *webrtc.PeerConnection, ssrc uint32, first uint16, bitmap uint16) error {
packet := rtcp.Packet(
&rtcp.TransportLayerNack{
MediaSSRC: ssrc,
Nacks: []rtcp.NackPair{
rtcp.NackPair{
first,
rtcp.PacketBitmap(bitmap),
},
},
},
)
return pc.WriteRTCP([]rtcp.Packet{packet})
}
func sendRecovery(p *rtcp.TransportLayerNack, track *rtpDownTrack) {
var packet rtp.Packet
buf := make([]byte, packetcache.BufSize)
for _, nack := range p.Nacks {
for _, seqno := range nack.PacketList() {
l := track.remote.getRTP(seqno, buf)
if l == 0 {
continue
}
err := packet.Unmarshal(buf[:l])
if err != nil {
continue
}
err = track.track.WriteRTP(&packet)
if err != nil {
log.Printf("WriteRTP: %v", err)
continue
}
track.rate.Accumulate(uint32(l))
}
}
}
func rtcpUpListener(conn *rtpUpConnection, track *rtpUpTrack, r *webrtc.RTPReceiver) {
for {
firstSR := false
ps, err := r.ReadRTCP()
if err != nil {
if err != io.EOF {
log.Printf("ReadRTCP: %v", err)
}
return
}
now := rtptime.Jiffies()
for _, p := range ps {
switch p := p.(type) {
case *rtcp.SenderReport:
track.mu.Lock()
if track.srTime == 0 {
firstSR = true
}
track.srTime = now
track.srNTPTime = p.NTPTime
track.srRTPTime = p.RTPTime
track.mu.Unlock()
case *rtcp.SourceDescription:
}
}
if firstSR {
// this is the first SR we got for at least one track,
// quickly propagate the time offsets downstream
local := conn.getLocal()
for _, l := range local {
l, ok := l.(*rtpDownConnection)
if ok {
err := sendSR(l)
if err != nil {
log.Printf("sendSR: %v", err)
}
}
}
}
}
}
func sendRR(conn *rtpUpConnection) error {
conn.mu.Lock()
defer conn.mu.Unlock()
if len(conn.tracks) == 0 {
return nil
}
now := rtptime.Jiffies()
reports := make([]rtcp.ReceptionReport, 0, len(conn.tracks))
for _, t := range conn.tracks {
expected, lost, totalLost, eseqno := t.cache.GetStats(true)
if expected == 0 {
expected = 1
}
if lost >= expected {
lost = expected - 1
}
t.mu.Lock()
srTime := t.srTime
srNTPTime := t.srNTPTime
t.mu.Unlock()
var delay uint64
if srTime != 0 {
delay = (now - srTime) /
(rtptime.JiffiesPerSec / 0x10000)
}
reports = append(reports, rtcp.ReceptionReport{
SSRC: t.track.SSRC(),
FractionLost: uint8((lost * 256) / expected),
TotalLost: totalLost,
LastSequenceNumber: eseqno,
Jitter: t.jitter.Jitter(),
LastSenderReport: uint32(srNTPTime >> 16),
Delay: uint32(delay),
})
}
return conn.pc.WriteRTCP([]rtcp.Packet{
&rtcp.ReceiverReport{
Reports: reports,
},
})
}
func rtcpUpSender(conn *rtpUpConnection) {
for {
time.Sleep(time.Second)
err := sendRR(conn)
if err != nil {
if err == io.EOF || err == io.ErrClosedPipe {
return
}
log.Printf("sendRR: %v", err)
}
}
}
func sendSR(conn *rtpDownConnection) error {
// since this is only called after all tracks have been created,
// there is no need for locking.
packets := make([]rtcp.Packet, 0, len(conn.tracks))
now := time.Now()
nowNTP := rtptime.TimeToNTP(now)
jiffies := rtptime.TimeToJiffies(now)
for _, t := range conn.tracks {
clockrate := t.track.Codec().ClockRate
remote := t.remote
var nowRTP uint32
switch r := remote.(type) {
case *rtpUpTrack:
r.mu.Lock()
lastTime := r.srTime
srNTPTime := r.srNTPTime
srRTPTime := r.srRTPTime
r.mu.Unlock()
if lastTime == 0 {
// we never got a remote SR, skip this track
continue
}
if srNTPTime != 0 {
srTime := rtptime.NTPToTime(srNTPTime)
d := now.Sub(srTime)
if d > 0 && d < time.Hour {
delay := rtptime.FromDuration(
d, clockrate,
)
nowRTP = srRTPTime + uint32(delay)
}
}
default:
ts, ok := remote.getTimestamp()
if !ok {
continue
}
nowRTP = ts
}
p, b := t.rate.Totals()
packets = append(packets,
&rtcp.SenderReport{
SSRC: t.track.SSRC(),
NTPTime: nowNTP,
RTPTime: nowRTP,
PacketCount: p,
OctetCount: b,
})
atomic.StoreUint64(&t.srTime, jiffies)
atomic.StoreUint64(&t.srNTPTime, nowNTP)
}
if len(packets) == 0 {
return nil
}
return conn.pc.WriteRTCP(packets)
}
func rtcpDownSender(conn *rtpDownConnection) {
for {
time.Sleep(time.Second)
err := sendSR(conn)
if err != nil {
if err == io.EOF || err == io.ErrClosedPipe {
return
}
log.Printf("sendSR: %v", err)
}
}
}
const (
minLossRate = 9600
initLossRate = 512 * 1000
maxLossRate = 1 << 30
)
func (track *rtpDownTrack) updateRate(loss uint8, now uint64) {
rate := track.maxLossBitrate.Get(now)
if rate < minLossRate || rate > maxLossRate {
// no recent feedback, reset
rate = initLossRate
}
if loss < 5 {
// if our actual rate is low, then we're not probing the
// bottleneck
r, _ := track.rate.Estimate()
actual := 8 * uint64(r)
if actual >= (rate*7)/8 {
// loss < 0.02, multiply by 1.05
rate = rate * 269 / 256
if rate > maxLossRate {
rate = maxLossRate
}
}
} else if loss > 25 {
// loss > 0.1, multiply by (1 - loss/2)
rate = rate * (512 - uint64(loss)) / 512
if rate < minLossRate {
rate = minLossRate
}
}
// update unconditionally, to set the timestamp
track.maxLossBitrate.Set(rate, now)
}
func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RTPSender) {
var gotFir bool
lastFirSeqno := uint8(0)
for {
ps, err := s.ReadRTCP()
if err != nil {
if err != io.EOF {
log.Printf("ReadRTCP: %v", err)
}
return
}
jiffies := rtptime.Jiffies()
for _, p := range ps {
switch p := p.(type) {
case *rtcp.PictureLossIndication:
remote, ok := conn.remote.(*rtpUpConnection)
if !ok {
continue
}
rt, ok := track.remote.(*rtpUpTrack)
if !ok {
continue
}
err := remote.sendPLI(rt)
if err != nil {
log.Printf("sendPLI: %v", err)
}
case *rtcp.FullIntraRequest:
found := false
var seqno uint8
for _, entry := range p.FIR {
if entry.SSRC == track.track.SSRC() {
found = true
seqno = entry.SequenceNumber
break
}
}
if !found {
log.Printf("Misdirected FIR")
continue
}
increment := true
if gotFir {
increment = seqno != lastFirSeqno
}
gotFir = true
lastFirSeqno = seqno
remote, ok := conn.remote.(*rtpUpConnection)
if !ok {
continue
}
rt, ok := track.remote.(*rtpUpTrack)
if !ok {
continue
}
err := remote.sendFIR(rt, increment)
if err == ErrUnsupportedFeedback {
err := remote.sendPLI(rt)
if err != nil {
log.Printf("sendPLI: %v", err)
}
} else if err != nil {
log.Printf("sendFIR: %v", err)
}
case *rtcp.ReceiverEstimatedMaximumBitrate:
track.maxREMBBitrate.Set(p.Bitrate, jiffies)
case *rtcp.ReceiverReport:
for _, r := range p.Reports {
if r.SSRC == track.track.SSRC() {
handleReport(track, r, jiffies)
}
}
case *rtcp.SenderReport:
for _, r := range p.Reports {
if r.SSRC == track.track.SSRC() {
handleReport(track, r, jiffies)
}
}
case *rtcp.TransportLayerNack:
maxBitrate := track.GetMaxBitrate(jiffies)
bitrate, _ := track.rate.Estimate()
if uint64(bitrate)*7/8 < maxBitrate {
sendRecovery(p, track)
}
}
}
}
}
func handleReport(track *rtpDownTrack, report rtcp.ReceptionReport, jiffies uint64) {
track.stats.Set(report.FractionLost, report.Jitter, jiffies)
track.updateRate(report.FractionLost, jiffies)
if report.LastSenderReport != 0 {
jiffies := rtptime.Jiffies()
srTime := atomic.LoadUint64(&track.srTime)
if jiffies < srTime || jiffies-srTime > 8*rtptime.JiffiesPerSec {
return
}
srNTPTime := atomic.LoadUint64(&track.srNTPTime)
if report.LastSenderReport == uint32(srNTPTime>>16) {
delay := uint64(report.Delay) *
(rtptime.JiffiesPerSec / 0x10000)
if delay > jiffies-srTime {
return
}
rtt := (jiffies - srTime) - delay
oldrtt := atomic.LoadUint64(&track.rtt)
newrtt := rtt
if oldrtt > 0 {
newrtt = (3*oldrtt + rtt) / 4
}
atomic.StoreUint64(&track.rtt, newrtt)
}
}
}
func updateUpTrack(track *rtpUpTrack, maxVideoRate uint64) uint64 {
now := rtptime.Jiffies()
isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo
clockrate := track.track.Codec().ClockRate
minrate := uint64(minAudioRate)
rate := ^uint64(0)
if isvideo {
minrate = minVideoRate
rate = maxVideoRate
if rate < minrate {
rate = minrate
}
}
local := track.getLocal()
var maxrto uint64
for _, l := range local {
bitrate := l.GetMaxBitrate(now)
if bitrate == ^uint64(0) {
continue
}
if bitrate <= minrate {
rate = minrate
break
}
if rate > bitrate {
rate = bitrate
}
ll, ok := l.(*rtpDownTrack)
if ok {
_, j := ll.stats.Get(now)
jitter := uint64(j) *
(rtptime.JiffiesPerSec /
uint64(clockrate))
rtt := atomic.LoadUint64(&ll.rtt)
rto := rtt + 4*jitter
if rto > maxrto {
maxrto = rto
}
}
}
track.maxBitrate = rate
_, r := track.rate.Estimate()
packets := int((uint64(r) * maxrto * 4) / rtptime.JiffiesPerSec)
if packets < 32 {
packets = 32
}
if packets > 256 {
packets = 256
}
track.cache.ResizeCond(packets)
return rate
}
...@@ -8,10 +8,8 @@ package main ...@@ -8,10 +8,8 @@ package main
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"io"
"log" "log"
"math" "math"
"math/bits"
"os" "os"
"strings" "strings"
"sync" "sync"
...@@ -19,13 +17,8 @@ import ( ...@@ -19,13 +17,8 @@ import (
"time" "time"
"sfu/estimator" "sfu/estimator"
"sfu/jitter"
"sfu/packetcache"
"sfu/rtptime"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/pion/rtcp"
"github.com/pion/rtp"
"github.com/pion/webrtc/v2" "github.com/pion/webrtc/v2"
) )
...@@ -104,18 +97,18 @@ type webClient struct { ...@@ -104,18 +97,18 @@ type webClient struct {
mu sync.Mutex mu sync.Mutex
down map[string]*rtpDownConnection down map[string]*rtpDownConnection
up map[string]*upConnection up map[string]*rtpUpConnection
} }
func (c *webClient) getGroup() *group { func (c *webClient) Group() *group {
return c.group return c.group
} }
func (c *webClient) getId() string { func (c *webClient) Id() string {
return c.id return c.id
} }
func (c *webClient) getUsername() string { func (c *webClient) Username() string {
return c.username return c.username
} }
...@@ -196,135 +189,37 @@ type closeMessage struct { ...@@ -196,135 +189,37 @@ type closeMessage struct {
data []byte data []byte
} }
func startClient(conn *websocket.Conn) (err error) { func getUpConn(c *webClient, id string) *rtpUpConnection {
var m clientMessage
err = conn.SetReadDeadline(time.Now().Add(15 * time.Second))
if err != nil {
conn.Close()
return
}
err = conn.ReadJSON(&m)
if err != nil {
conn.Close()
return
}
err = conn.SetReadDeadline(time.Time{})
if err != nil {
conn.Close()
return
}
if m.Type != "handshake" {
conn.Close()
return
}
if strings.ContainsRune(m.Username, ' ') {
err = userError("don't put spaces in your username")
return
}
c := &webClient{
id: m.Id,
username: m.Username,
actionCh: make(chan interface{}, 10),
done: make(chan struct{}),
}
defer close(c.done)
c.writeCh = make(chan interface{}, 25)
defer func() {
if isWSNormalError(err) {
err = nil
} else {
m, e := errorToWSCloseMessage(err)
if m != "" {
c.write(clientMessage{
Type: "error",
Value: m,
})
}
select {
case c.writeCh <- closeMessage{e}:
case <-c.writerDone:
}
}
close(c.writeCh)
c.writeCh = nil
}()
c.writerDone = make(chan struct{})
go clientWriter(conn, c.writeCh, c.writerDone)
g, err := addClient(m.Group, c, m.Username, m.Password)
if err != nil {
return
}
c.group = g
defer delClient(c)
return clientLoop(c, conn)
}
func getUpConn(c *webClient, id string) *upConnection {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if c.up == nil { if c.up == nil {
return nil return nil
} }
conn := c.up[id] return c.up[id]
if conn == nil {
return nil
}
return conn
} }
func getUpConns(c *webClient) []*upConnection { func getUpConns(c *webClient) []*rtpUpConnection {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
up := make([]*upConnection, 0, len(c.up)) up := make([]*rtpUpConnection, 0, len(c.up))
for _, u := range c.up { for _, u := range c.up {
up = append(up, u) up = append(up, u)
} }
return up return up
} }
func addUpConn(c *webClient, id string) (*upConnection, error) { func addUpConn(c *webClient, id string) (*rtpUpConnection, error) {
pc, err := groups.api.NewPeerConnection(iceConfiguration()) conn, err := newUpConn(c, id)
if err != nil {
return nil, err
}
_, err = pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio,
webrtc.RtpTransceiverInit{
Direction: webrtc.RTPTransceiverDirectionRecvonly,
},
)
if err != nil { if err != nil {
pc.Close()
return nil, err return nil, err
} }
_, err = pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo,
webrtc.RtpTransceiverInit{
Direction: webrtc.RTPTransceiverDirectionRecvonly,
},
)
if err != nil {
pc.Close()
return nil, err
}
conn := &upConnection{id: id, pc: pc}
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if c.up == nil { if c.up == nil {
c.up = make(map[string]*upConnection) c.up = make(map[string]*rtpUpConnection)
} }
if c.up[id] != nil || (c.down != nil && c.down[id] != nil) { if c.up[id] != nil || (c.down != nil && c.down[id] != nil) {
conn.pc.Close() conn.pc.Close()
...@@ -332,386 +227,13 @@ func addUpConn(c *webClient, id string) (*upConnection, error) { ...@@ -332,386 +227,13 @@ func addUpConn(c *webClient, id string) (*upConnection, error) {
} }
c.up[id] = conn c.up[id] = conn
pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { conn.pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
sendICE(c, id, candidate) sendICE(c, id, candidate)
}) })
pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) {
conn.mu.Lock()
defer conn.mu.Unlock()
mid := getUpMid(pc, remote)
if mid == "" {
log.Printf("Couldn't get track's mid")
return
}
label, ok := conn.labels[mid]
if !ok {
log.Printf("Couldn't get track's label")
isvideo := remote.Kind() == webrtc.RTPCodecTypeVideo
if isvideo {
label = "video"
} else {
label = "audio"
}
}
track := &upTrack{
track: remote,
label: label,
cache: packetcache.New(32),
rate: estimator.New(time.Second),
jitter: jitter.New(remote.Codec().ClockRate),
maxBitrate: ^uint64(0),
localCh: make(chan localTrackAction, 2),
writerDone: make(chan struct{}),
}
conn.tracks = append(conn.tracks, track)
if remote.Kind() == webrtc.RTPCodecTypeVideo {
atomic.AddUint32(&c.group.videoCount, 1)
}
go readLoop(conn, track)
go rtcpUpListener(conn, track, receiver)
if conn.complete() {
// cannot call getTracks, we're locked
tracks := make([]*upTrack, len(conn.tracks))
copy(tracks, conn.tracks)
clients := c.group.getClients(c)
for _, cc := range clients {
cc.pushConn(conn, tracks, conn.label)
}
go rtcpUpSender(conn)
}
})
return conn, nil return conn, nil
} }
type packetIndex struct {
seqno uint16
index uint16
}
func readLoop(conn *upConnection, track *upTrack) {
isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo
ch := make(chan packetIndex, 32)
defer close(ch)
go writeLoop(conn, track, ch)
buf := make([]byte, packetcache.BufSize)
var packet rtp.Packet
drop := 0
for {
bytes, err := track.track.Read(buf)
if err != nil {
if err != io.EOF {
log.Printf("%v", err)
}
break
}
track.rate.Accumulate(uint32(bytes))
err = packet.Unmarshal(buf[:bytes])
if err != nil {
log.Printf("%v", err)
continue
}
track.jitter.Accumulate(packet.Timestamp)
first, index :=
track.cache.Store(packet.SequenceNumber, buf[:bytes])
if packet.SequenceNumber-first > 24 {
found, first, bitmap := track.cache.BitmapGet()
if found {
err := conn.sendNACK(track, first, bitmap)
if err != nil {
log.Printf("%v", err)
}
}
}
if drop > 0 {
if packet.Marker {
// last packet in frame
drop = 0
} else {
drop--
}
continue
}
select {
case ch <- packetIndex{packet.SequenceNumber, index}:
default:
if isvideo {
// the writer is congested. Drop until
// the end of the frame.
if isvideo && !packet.Marker {
drop = 7
}
}
}
}
}
func writeLoop(conn *upConnection, track *upTrack, ch <-chan packetIndex) {
defer close(track.writerDone)
buf := make([]byte, packetcache.BufSize)
var packet rtp.Packet
local := make([]downTrack, 0)
firSent := false
for {
select {
case action := <-track.localCh:
if action.add {
local = append(local, action.track)
firSent = false
} else {
found := false
for i, t := range local {
if t == action.track {
local = append(local[:i], local[i+1:]...)
found = true
break
}
}
if !found {
log.Printf("Deleting unknown track!")
}
}
case pi, ok := <-ch:
if !ok {
return
}
bytes := track.cache.GetAt(pi.seqno, pi.index, buf)
if bytes == 0 {
continue
}
err := packet.Unmarshal(buf[:bytes])
if err != nil {
log.Printf("%v", err)
continue
}
kfNeeded := false
for _, l := range local {
err := l.WriteRTP(&packet)
if err != nil {
if err == ErrKeyframeNeeded {
kfNeeded = true
} else if err != io.ErrClosedPipe {
log.Printf("WriteRTP: %v", err)
}
continue
}
l.Accumulate(uint32(bytes))
}
if kfNeeded {
err := conn.sendFIR(track, !firSent)
if err == ErrUnsupportedFeedback {
err := conn.sendPLI(track)
if err != nil &&
err != ErrUnsupportedFeedback {
log.Printf("sendPLI: %v", err)
}
} else if err != nil {
log.Printf("sendFIR: %v", err)
}
firSent = true
}
}
}
}
func rtcpUpListener(conn *upConnection, track *upTrack, r *webrtc.RTPReceiver) {
for {
firstSR := false
ps, err := r.ReadRTCP()
if err != nil {
if err != io.EOF {
log.Printf("ReadRTCP: %v", err)
}
return
}
now := rtptime.Jiffies()
for _, p := range ps {
switch p := p.(type) {
case *rtcp.SenderReport:
track.mu.Lock()
if track.srTime == 0 {
firstSR = true
}
track.srTime = now
track.srNTPTime = p.NTPTime
track.srRTPTime = p.RTPTime
track.mu.Unlock()
case *rtcp.SourceDescription:
}
}
if firstSR {
// this is the first SR we got for at least one track,
// quickly propagate the time offsets downstream
local := conn.getLocal()
for _, l := range local {
l, ok := l.(*rtpDownConnection)
if ok {
err := sendSR(l)
if err != nil {
log.Printf("sendSR: %v", err)
}
}
}
}
}
}
func sendRR(conn *upConnection) error {
conn.mu.Lock()
defer conn.mu.Unlock()
if len(conn.tracks) == 0 {
return nil
}
now := rtptime.Jiffies()
reports := make([]rtcp.ReceptionReport, 0, len(conn.tracks))
for _, t := range conn.tracks {
expected, lost, totalLost, eseqno := t.cache.GetStats(true)
if expected == 0 {
expected = 1
}
if lost >= expected {
lost = expected - 1
}
t.mu.Lock()
srTime := t.srTime
srNTPTime := t.srNTPTime
t.mu.Unlock()
var delay uint64
if srTime != 0 {
delay = (now - srTime) /
(rtptime.JiffiesPerSec / 0x10000)
}
reports = append(reports, rtcp.ReceptionReport{
SSRC: t.track.SSRC(),
FractionLost: uint8((lost * 256) / expected),
TotalLost: totalLost,
LastSequenceNumber: eseqno,
Jitter: t.jitter.Jitter(),
LastSenderReport: uint32(srNTPTime >> 16),
Delay: uint32(delay),
})
}
return conn.pc.WriteRTCP([]rtcp.Packet{
&rtcp.ReceiverReport{
Reports: reports,
},
})
}
func rtcpUpSender(conn *upConnection) {
for {
time.Sleep(time.Second)
err := sendRR(conn)
if err != nil {
if err == io.EOF || err == io.ErrClosedPipe {
return
}
log.Printf("sendRR: %v", err)
}
}
}
func sendSR(conn *rtpDownConnection) error {
// since this is only called after all tracks have been created,
// there is no need for locking.
packets := make([]rtcp.Packet, 0, len(conn.tracks))
now := time.Now()
nowNTP := rtptime.TimeToNTP(now)
jiffies := rtptime.TimeToJiffies(now)
for _, t := range conn.tracks {
clockrate := t.track.Codec().ClockRate
remote := t.remote
remote.mu.Lock()
lastTime := remote.srTime
srNTPTime := remote.srNTPTime
srRTPTime := remote.srRTPTime
remote.mu.Unlock()
if lastTime == 0 {
// we never got a remote SR, skip this track
continue
}
nowRTP := srRTPTime
if srNTPTime != 0 {
srTime := rtptime.NTPToTime(srNTPTime)
delay := now.Sub(srTime)
if delay > 0 && delay < time.Hour {
d := rtptime.FromDuration(delay, clockrate)
nowRTP = srRTPTime + uint32(d)
}
}
p, b := t.rate.Totals()
packets = append(packets,
&rtcp.SenderReport{
SSRC: t.track.SSRC(),
NTPTime: nowNTP,
RTPTime: nowRTP,
PacketCount: p,
OctetCount: b,
})
atomic.StoreUint64(&t.srTime, jiffies)
atomic.StoreUint64(&t.srNTPTime, nowNTP)
}
if len(packets) == 0 {
return nil
}
return conn.pc.WriteRTCP(packets)
}
func rtcpDownSender(conn *rtpDownConnection) {
for {
time.Sleep(time.Second)
err := sendSR(conn)
if err != nil {
if err == io.EOF || err == io.ErrClosedPipe {
return
}
log.Printf("sendSR: %v", err)
}
}
}
func delUpConn(c *webClient, id string) bool { func delUpConn(c *webClient, id string) bool {
c.mu.Lock() c.mu.Lock()
if c.up == nil { if c.up == nil {
...@@ -769,37 +291,28 @@ func getConn(c *webClient, id string) iceConnection { ...@@ -769,37 +291,28 @@ func getConn(c *webClient, id string) iceConnection {
return nil return nil
} }
func addDownConn(c *webClient, id string, remote *upConnection) (*rtpDownConnection, error) { func addDownConn(c *webClient, id string, remote upConnection) (*rtpDownConnection, error) {
pc, err := groups.api.NewPeerConnection(iceConfiguration()) conn, err := newDownConn(id, remote)
if err != nil { if err != nil {
return nil, err return nil, err
} }
pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { c.mu.Lock()
sendICE(c, id, candidate) defer c.mu.Unlock()
})
pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) {
log.Printf("Got track on downstream connection")
})
if c.down == nil { if c.down == nil {
c.down = make(map[string]*rtpDownConnection) c.down = make(map[string]*rtpDownConnection)
} }
conn := &rtpDownConnection{
id: id,
client: c,
pc: pc,
remote: remote,
}
c.mu.Lock()
defer c.mu.Unlock()
if c.down[id] != nil || (c.up != nil && c.up[id] != nil) { if c.down[id] != nil || (c.up != nil && c.up[id] != nil) {
conn.pc.Close() conn.pc.Close()
return nil, errors.New("Adding duplicate connection") return nil, errors.New("Adding duplicate connection")
} }
conn.pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
sendICE(c, id, candidate)
})
err = remote.addLocal(conn) err = remote.addLocal(conn)
if err != nil { if err != nil {
conn.pc.Close() conn.pc.Close()
...@@ -807,6 +320,10 @@ func addDownConn(c *webClient, id string, remote *upConnection) (*rtpDownConnect ...@@ -807,6 +320,10 @@ func addDownConn(c *webClient, id string, remote *upConnection) (*rtpDownConnect
} }
c.down[id] = conn c.down[id] = conn
conn.close = func() error {
return c.action(delConnAction{conn.id})
}
return conn, nil return conn, nil
} }
...@@ -833,13 +350,21 @@ func delDownConn(c *webClient, id string) bool { ...@@ -833,13 +350,21 @@ func delDownConn(c *webClient, id string) bool {
return true return true
} }
func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack *upTrack, remoteConn *upConnection) (*webrtc.RTPSender, error) { func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack upTrack, remoteConn upConnection) (*webrtc.RTPSender, error) {
local, err := conn.pc.NewTrack( var pt uint8
remoteTrack.track.PayloadType(), var ssrc uint32
remoteTrack.track.SSRC(), var id, label string
remoteTrack.track.ID(), switch rt := remoteTrack.(type) {
remoteTrack.track.Label(), case *rtpUpTrack:
) pt = rt.track.PayloadType()
ssrc = rt.track.SSRC()
id = rt.track.ID()
label = rt.track.Label()
default:
return nil, errors.New("not implemented yet")
}
local, err := conn.pc.NewTrack(pt, ssrc, id, label)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -864,319 +389,6 @@ func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack *upTrack, r ...@@ -864,319 +389,6 @@ func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack *upTrack, r
return s, nil return s, nil
} }
const (
minLossRate = 9600
initLossRate = 512 * 1000
maxLossRate = 1 << 30
)
func (track *rtpDownTrack) updateRate(loss uint8, now uint64) {
rate := track.maxLossBitrate.Get(now)
if rate < minLossRate || rate > maxLossRate {
// no recent feedback, reset
rate = initLossRate
}
if loss < 5 {
// if our actual rate is low, then we're not probing the
// bottleneck
r, _ := track.rate.Estimate()
actual := 8 * uint64(r)
if actual >= (rate*7)/8 {
// loss < 0.02, multiply by 1.05
rate = rate * 269 / 256
if rate > maxLossRate {
rate = maxLossRate
}
}
} else if loss > 25 {
// loss > 0.1, multiply by (1 - loss/2)
rate = rate * (512 - uint64(loss)) / 512
if rate < minLossRate {
rate = minLossRate
}
}
// update unconditionally, to set the timestamp
track.maxLossBitrate.Set(rate, now)
}
func rtcpDownListener(conn *rtpDownConnection, track *rtpDownTrack, s *webrtc.RTPSender) {
var gotFir bool
lastFirSeqno := uint8(0)
for {
ps, err := s.ReadRTCP()
if err != nil {
if err != io.EOF {
log.Printf("ReadRTCP: %v", err)
}
return
}
jiffies := rtptime.Jiffies()
for _, p := range ps {
switch p := p.(type) {
case *rtcp.PictureLossIndication:
err := conn.remote.sendPLI(track.remote)
if err != nil {
log.Printf("sendPLI: %v", err)
}
case *rtcp.FullIntraRequest:
found := false
var seqno uint8
for _, entry := range p.FIR {
if entry.SSRC == track.track.SSRC() {
found = true
seqno = entry.SequenceNumber
break
}
}
if !found {
log.Printf("Misdirected FIR")
continue
}
increment := true
if gotFir {
increment = seqno != lastFirSeqno
}
gotFir = true
lastFirSeqno = seqno
err := conn.remote.sendFIR(
track.remote, increment,
)
if err == ErrUnsupportedFeedback {
err := conn.remote.sendPLI(track.remote)
if err != nil {
log.Printf("sendPLI: %v", err)
}
} else if err != nil {
log.Printf("sendFIR: %v", err)
}
case *rtcp.ReceiverEstimatedMaximumBitrate:
track.maxREMBBitrate.Set(p.Bitrate, jiffies)
case *rtcp.ReceiverReport:
for _, r := range p.Reports {
if r.SSRC == track.track.SSRC() {
handleReport(track, r, jiffies)
}
}
case *rtcp.SenderReport:
for _, r := range p.Reports {
if r.SSRC == track.track.SSRC() {
handleReport(track, r, jiffies)
}
}
case *rtcp.TransportLayerNack:
maxBitrate := track.GetMaxBitrate(jiffies)
bitrate, _ := track.rate.Estimate()
if uint64(bitrate)*7/8 < maxBitrate {
sendRecovery(p, track)
}
}
}
}
}
func handleReport(track *rtpDownTrack, report rtcp.ReceptionReport, jiffies uint64) {
track.stats.Set(report.FractionLost, report.Jitter, jiffies)
track.updateRate(report.FractionLost, jiffies)
if report.LastSenderReport != 0 {
jiffies := rtptime.Jiffies()
srTime := atomic.LoadUint64(&track.srTime)
if jiffies < srTime || jiffies-srTime > 8*rtptime.JiffiesPerSec {
return
}
srNTPTime := atomic.LoadUint64(&track.srNTPTime)
if report.LastSenderReport == uint32(srNTPTime>>16) {
delay := uint64(report.Delay) *
(rtptime.JiffiesPerSec / 0x10000)
if delay > jiffies-srTime {
return
}
rtt := (jiffies - srTime) - delay
oldrtt := atomic.LoadUint64(&track.rtt)
newrtt := rtt
if oldrtt > 0 {
newrtt = (3*oldrtt + rtt) / 4
}
atomic.StoreUint64(&track.rtt, newrtt)
}
}
}
func updateUpTrack(track *upTrack, maxVideoRate uint64) uint64 {
now := rtptime.Jiffies()
isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo
clockrate := track.track.Codec().ClockRate
minrate := uint64(minAudioRate)
rate := ^uint64(0)
if isvideo {
minrate = minVideoRate
rate = maxVideoRate
if rate < minrate {
rate = minrate
}
}
local := track.getLocal()
var maxrto uint64
for _, l := range local {
bitrate := l.GetMaxBitrate(now)
if bitrate == ^uint64(0) {
continue
}
if bitrate <= minrate {
rate = minrate
break
}
if rate > bitrate {
rate = bitrate
}
ll, ok := l.(*rtpDownTrack)
if ok {
_, j := ll.stats.Get(now)
jitter := uint64(j) *
(rtptime.JiffiesPerSec /
uint64(clockrate))
rtt := atomic.LoadUint64(&ll.rtt)
rto := rtt + 4*jitter
if rto > maxrto {
maxrto = rto
}
}
}
track.maxBitrate = rate
_, r := track.rate.Estimate()
packets := int((uint64(r) * maxrto * 4) / rtptime.JiffiesPerSec)
if packets < 32 {
packets = 32
}
if packets > 256 {
packets = 256
}
track.cache.ResizeCond(packets)
return rate
}
var ErrUnsupportedFeedback = errors.New("unsupported feedback type")
var ErrRateLimited = errors.New("rate limited")
func (up *upConnection) sendPLI(track *upTrack) error {
if !track.hasRtcpFb("nack", "pli") {
return ErrUnsupportedFeedback
}
last := atomic.LoadUint64(&track.lastPLI)
now := rtptime.Jiffies()
if now >= last && now-last < rtptime.JiffiesPerSec/5 {
return ErrRateLimited
}
atomic.StoreUint64(&track.lastPLI, now)
return sendPLI(up.pc, track.track.SSRC())
}
func sendPLI(pc *webrtc.PeerConnection, ssrc uint32) error {
return pc.WriteRTCP([]rtcp.Packet{
&rtcp.PictureLossIndication{MediaSSRC: ssrc},
})
}
func (up *upConnection) sendFIR(track *upTrack, increment bool) error {
// we need to reliably increment the seqno, even if we are going
// to drop the packet due to rate limiting.
var seqno uint8
if increment {
seqno = uint8(atomic.AddUint32(&track.firSeqno, 1) & 0xFF)
} else {
seqno = uint8(atomic.LoadUint32(&track.firSeqno) & 0xFF)
}
if !track.hasRtcpFb("ccm", "fir") {
return ErrUnsupportedFeedback
}
last := atomic.LoadUint64(&track.lastFIR)
now := rtptime.Jiffies()
if now >= last && now-last < rtptime.JiffiesPerSec/5 {
return ErrRateLimited
}
atomic.StoreUint64(&track.lastFIR, now)
return sendFIR(up.pc, track.track.SSRC(), seqno)
}
func sendFIR(pc *webrtc.PeerConnection, ssrc uint32, seqno uint8) error {
return pc.WriteRTCP([]rtcp.Packet{
&rtcp.FullIntraRequest{
FIR: []rtcp.FIREntry{
rtcp.FIREntry{
SSRC: ssrc,
SequenceNumber: seqno,
},
},
},
})
}
func sendREMB(pc *webrtc.PeerConnection, ssrc uint32, bitrate uint64) error {
return pc.WriteRTCP([]rtcp.Packet{
&rtcp.ReceiverEstimatedMaximumBitrate{
Bitrate: bitrate,
SSRCs: []uint32{ssrc},
},
})
}
func (up *upConnection) sendNACK(track *upTrack, first uint16, bitmap uint16) error {
if !track.hasRtcpFb("nack", "") {
return nil
}
err := sendNACK(up.pc, track.track.SSRC(), first, bitmap)
if err == nil {
track.cache.Expect(1 + bits.OnesCount16(bitmap))
}
return err
}
func sendNACK(pc *webrtc.PeerConnection, ssrc uint32, first uint16, bitmap uint16) error {
packet := rtcp.Packet(
&rtcp.TransportLayerNack{
MediaSSRC: ssrc,
Nacks: []rtcp.NackPair{
rtcp.NackPair{
first,
rtcp.PacketBitmap(bitmap),
},
},
},
)
return pc.WriteRTCP([]rtcp.Packet{packet})
}
func sendRecovery(p *rtcp.TransportLayerNack, track *rtpDownTrack) {
var packet rtp.Packet
buf := make([]byte, packetcache.BufSize)
for _, nack := range p.Nacks {
for _, seqno := range nack.PacketList() {
l := track.remote.cache.Get(seqno, buf)
if l == 0 {
continue
}
err := packet.Unmarshal(buf[:l])
if err != nil {
continue
}
err = track.track.WriteRTP(&packet)
if err != nil {
log.Printf("WriteRTP: %v", err)
continue
}
track.rate.Accumulate(uint32(l))
}
}
}
func negotiate(c *webClient, down *rtpDownConnection) error { func negotiate(c *webClient, down *rtpDownConnection) error {
offer, err := down.pc.CreateOffer(nil) offer, err := down.pc.CreateOffer(nil)
if err != nil { if err != nil {
...@@ -1200,7 +412,7 @@ func negotiate(c *webClient, down *rtpDownConnection) error { ...@@ -1200,7 +412,7 @@ func negotiate(c *webClient, down *rtpDownConnection) error {
for _, tr := range down.tracks { for _, tr := range down.tracks {
if tr.track == track { if tr.track == track {
labels[t.Mid()] = tr.remote.label labels[t.Mid()] = tr.remote.Label()
} }
} }
} }
...@@ -1313,7 +525,7 @@ func (c *webClient) setRequested(requested map[string]uint32) error { ...@@ -1313,7 +525,7 @@ func (c *webClient) setRequested(requested map[string]uint32) error {
} }
func pushConns(c client) { func pushConns(c client) {
clients := c.getGroup().getClients(c) clients := c.Group().getClients(c)
for _, cc := range clients { for _, cc := range clients {
ccc, ok := cc.(*webClient) ccc, ok := cc.(*webClient)
if ok { if ok {
...@@ -1326,10 +538,10 @@ func (c *webClient) isRequested(label string) bool { ...@@ -1326,10 +538,10 @@ func (c *webClient) isRequested(label string) bool {
return c.requested[label] != 0 return c.requested[label] != 0
} }
func addDownConnTracks(c *webClient, remote *upConnection, tracks []*upTrack) (*rtpDownConnection, error) { func addDownConnTracks(c *webClient, remote upConnection, tracks []upTrack) (*rtpDownConnection, error) {
requested := false requested := false
for _, t := range tracks { for _, t := range tracks {
if c.isRequested(t.label) { if c.isRequested(t.Label()) {
requested = true requested = true
break break
} }
...@@ -1338,13 +550,13 @@ func addDownConnTracks(c *webClient, remote *upConnection, tracks []*upTrack) (* ...@@ -1338,13 +550,13 @@ func addDownConnTracks(c *webClient, remote *upConnection, tracks []*upTrack) (*
return nil, nil return nil, nil
} }
down, err := addDownConn(c, remote.id, remote) down, err := addDownConn(c, remote.Id(), remote)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, t := range tracks { for _, t := range tracks {
if !c.isRequested(t.label) { if !c.isRequested(t.Label()) {
continue continue
} }
_, err = addDownTrack(c, down, t, remote) _, err = addDownTrack(c, down, t, remote)
...@@ -1359,13 +571,13 @@ func addDownConnTracks(c *webClient, remote *upConnection, tracks []*upTrack) (* ...@@ -1359,13 +571,13 @@ func addDownConnTracks(c *webClient, remote *upConnection, tracks []*upTrack) (*
return down, nil return down, nil
} }
func (c *webClient) pushConn(conn *upConnection, tracks []*upTrack, label string) error { func (c *webClient) pushConn(conn upConnection, tracks []upTrack, label string) error {
err := c.action(addConnAction{conn, tracks}) err := c.action(addConnAction{conn, tracks})
if err != nil { if err != nil {
return err return err
} }
if label != "" { if label != "" {
err := c.action(addLabelAction{conn.id, conn.label}) err := c.action(addLabelAction{conn.Id(), conn.Label()})
if err != nil { if err != nil {
return err return err
} }
...@@ -1373,6 +585,78 @@ func (c *webClient) pushConn(conn *upConnection, tracks []*upTrack, label string ...@@ -1373,6 +585,78 @@ func (c *webClient) pushConn(conn *upConnection, tracks []*upTrack, label string
return nil return nil
} }
func startClient(conn *websocket.Conn) (err error) {
var m clientMessage
err = conn.SetReadDeadline(time.Now().Add(15 * time.Second))
if err != nil {
conn.Close()
return
}
err = conn.ReadJSON(&m)
if err != nil {
conn.Close()
return
}
err = conn.SetReadDeadline(time.Time{})
if err != nil {
conn.Close()
return
}
if m.Type != "handshake" {
conn.Close()
return
}
if strings.ContainsRune(m.Username, ' ') {
err = userError("don't put spaces in your username")
return
}
c := &webClient{
id: m.Id,
username: m.Username,
actionCh: make(chan interface{}, 10),
done: make(chan struct{}),
}
defer close(c.done)
c.writeCh = make(chan interface{}, 25)
defer func() {
if isWSNormalError(err) {
err = nil
} else {
m, e := errorToWSCloseMessage(err)
if m != "" {
c.write(clientMessage{
Type: "error",
Value: m,
})
}
select {
case c.writeCh <- closeMessage{e}:
case <-c.writerDone:
}
}
close(c.writeCh)
c.writeCh = nil
}()
c.writerDone = make(chan struct{})
go clientWriter(conn, c.writeCh, c.writerDone)
g, err := addClient(m.Group, c, m.Username, m.Password)
if err != nil {
return
}
c.group = g
defer delClient(c)
return clientLoop(c, conn)
}
func clientLoop(c *webClient, conn *websocket.Conn) error { func clientLoop(c *webClient, conn *websocket.Conn) error {
read := make(chan interface{}, 1) read := make(chan interface{}, 1)
go clientReader(conn, read, c.done) go clientReader(conn, read, c.done)
...@@ -1469,7 +753,11 @@ func clientLoop(c *webClient, conn *websocket.Conn) error { ...@@ -1469,7 +753,11 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
case pushConnsAction: case pushConnsAction:
for _, u := range c.up { for _, u := range c.up {
tracks := u.getTracks() tracks := u.getTracks()
go a.c.pushConn(u, tracks, u.label) ts := make([]upTrack, len(tracks))
for i, t := range tracks {
ts[i] = t
}
go a.c.pushConn(u, ts, u.label)
} }
case connectionFailedAction: case connectionFailedAction:
found := delUpConn(c, a.id) found := delUpConn(c, a.id)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment