Commit d9cf32ed authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Move connections to their own package.

parent 7148ea18
package main package main
import (
"sfu/conn"
)
type clientCredentials struct { type clientCredentials struct {
Username string `json:"username,omitempty"` Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"` Password string `json:"password,omitempty"`
...@@ -16,7 +20,7 @@ type client interface { ...@@ -16,7 +20,7 @@ type client interface {
Id() string Id() string
Credentials() clientCredentials Credentials() clientCredentials
SetPermissions(clientPermissions) SetPermissions(clientPermissions)
pushConn(id string, conn upConnection, tracks []upTrack, label string) error pushConn(id string, conn conn.Up, tracks []conn.UpTrack, label string) error
pushClient(id, username string, add bool) error pushClient(id, username string, add bool) error
} }
......
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
// This is not open source software. Copy it, and I'll break into your // This is not open source software. Copy it, and I'll break into your
// house and tell your three year-old that Santa doesn't exist. // house and tell your three year-old that Santa doesn't exist.
package main // Package conn defines interfaces for connections and tracks.
package conn
import ( import (
"errors" "errors"
...@@ -15,29 +16,33 @@ import ( ...@@ -15,29 +16,33 @@ import (
var ErrConnectionClosed = errors.New("connection is closed") var ErrConnectionClosed = errors.New("connection is closed")
var ErrKeyframeNeeded = errors.New("keyframe needed") var ErrKeyframeNeeded = errors.New("keyframe needed")
type upConnection interface { // Type Up represents a connection in the client to server direction.
addLocal(downConnection) error type Up interface {
delLocal(downConnection) bool AddLocal(Down) error
DelLocal(Down) bool
Id() string Id() string
Label() string Label() string
} }
type upTrack interface { // Type UpTrack represents a track in the client to server direction.
addLocal(downTrack) error type UpTrack interface {
delLocal(downTrack) bool AddLocal(DownTrack) error
DelLocal(DownTrack) bool
Label() string Label() string
Codec() *webrtc.RTPCodec Codec() *webrtc.RTPCodec
// get a recent packet. Returns 0 if the packet is not in cache. // get a recent packet. Returns 0 if the packet is not in cache.
getRTP(seqno uint16, result []byte) uint16 GetRTP(seqno uint16, result []byte) uint16
} }
type downConnection interface { // Type Down represents a connection in the server to client direction.
type Down interface {
GetMaxBitrate(now uint64) uint64 GetMaxBitrate(now uint64) uint64
} }
type downTrack interface { // Type DownTrack represents a track in the server to client direction.
type DownTrack interface {
WriteRTP(packat *rtp.Packet) error WriteRTP(packat *rtp.Packet) error
Accumulate(bytes uint32) Accumulate(bytes uint32)
setTimeOffset(ntp uint64, rtp uint32) SetTimeOffset(ntp uint64, rtp uint32)
setCname(string) SetCname(string)
} }
...@@ -14,6 +14,8 @@ import ( ...@@ -14,6 +14,8 @@ import (
"github.com/pion/rtp/codecs" "github.com/pion/rtp/codecs"
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
"github.com/pion/webrtc/v3/pkg/media/samplebuilder" "github.com/pion/webrtc/v3/pkg/media/samplebuilder"
"sfu/conn"
) )
type diskClient struct { type diskClient struct {
...@@ -81,7 +83,7 @@ func (client *diskClient) kick(message string) error { ...@@ -81,7 +83,7 @@ func (client *diskClient) kick(message string) error {
return err return err
} }
func (client *diskClient) pushConn(id string, conn upConnection, tracks []upTrack, label string) error { func (client *diskClient) pushConn(id string, up conn.Up, tracks []conn.UpTrack, label string) error {
client.mu.Lock() client.mu.Lock()
defer client.mu.Unlock() defer client.mu.Unlock()
...@@ -95,7 +97,7 @@ func (client *diskClient) pushConn(id string, conn upConnection, tracks []upTrac ...@@ -95,7 +97,7 @@ func (client *diskClient) pushConn(id string, conn upConnection, tracks []upTrac
delete(client.down, id) delete(client.down, id)
} }
if conn == nil { if up == nil {
return nil return nil
} }
...@@ -109,12 +111,12 @@ func (client *diskClient) pushConn(id string, conn upConnection, tracks []upTrac ...@@ -109,12 +111,12 @@ func (client *diskClient) pushConn(id string, conn upConnection, tracks []upTrac
client.down = make(map[string]*diskConn) client.down = make(map[string]*diskConn)
} }
down, err := newDiskConn(directory, label, conn, tracks) down, err := newDiskConn(directory, label, up, tracks)
if err != nil { if err != nil {
return err return err
} }
client.down[conn.Id()] = down client.down[up.Id()] = down
return nil return nil
} }
...@@ -125,7 +127,7 @@ type diskConn struct { ...@@ -125,7 +127,7 @@ type diskConn struct {
mu sync.Mutex mu sync.Mutex
file *os.File file *os.File
remote upConnection remote conn.Up
tracks []*diskTrack tracks []*diskTrack
width, height uint32 width, height uint32
} }
...@@ -150,7 +152,7 @@ func (conn *diskConn) reopen() error { ...@@ -150,7 +152,7 @@ func (conn *diskConn) reopen() error {
} }
func (conn *diskConn) Close() error { func (conn *diskConn) Close() error {
conn.remote.delLocal(conn) conn.remote.DelLocal(conn)
conn.mu.Lock() conn.mu.Lock()
tracks := make([]*diskTrack, 0, len(conn.tracks)) tracks := make([]*diskTrack, 0, len(conn.tracks))
...@@ -164,7 +166,7 @@ func (conn *diskConn) Close() error { ...@@ -164,7 +166,7 @@ func (conn *diskConn) Close() error {
conn.mu.Unlock() conn.mu.Unlock()
for _, t := range tracks { for _, t := range tracks {
t.remote.delLocal(t) t.remote.DelLocal(t)
} }
return nil return nil
} }
...@@ -196,7 +198,7 @@ func openDiskFile(directory, label string) (*os.File, error) { ...@@ -196,7 +198,7 @@ func openDiskFile(directory, label string) (*os.File, error) {
} }
type diskTrack struct { type diskTrack struct {
remote upTrack remote conn.UpTrack
conn *diskConn conn *diskConn
writer webm.BlockWriteCloser writer webm.BlockWriteCloser
...@@ -206,7 +208,7 @@ type diskTrack struct { ...@@ -206,7 +208,7 @@ type diskTrack struct {
origin uint64 origin uint64
} }
func newDiskConn(directory, label string, up upConnection, remoteTracks []upTrack) (*diskConn, error) { func newDiskConn(directory, label string, up conn.Up, remoteTracks []conn.UpTrack) (*diskConn, error) {
conn := diskConn{ conn := diskConn{
directory: directory, directory: directory,
label: label, label: label,
...@@ -231,10 +233,10 @@ func newDiskConn(directory, label string, up upConnection, remoteTracks []upTrac ...@@ -231,10 +233,10 @@ func newDiskConn(directory, label string, up upConnection, remoteTracks []upTrac
conn: &conn, conn: &conn,
} }
conn.tracks = append(conn.tracks, track) conn.tracks = append(conn.tracks, track)
remote.addLocal(track) remote.AddLocal(track)
} }
err := up.addLocal(&conn) err := up.AddLocal(&conn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -242,10 +244,10 @@ func newDiskConn(directory, label string, up upConnection, remoteTracks []upTrac ...@@ -242,10 +244,10 @@ func newDiskConn(directory, label string, up upConnection, remoteTracks []upTrac
return &conn, nil return &conn, nil
} }
func (t *diskTrack) setTimeOffset(ntp uint64, rtp uint32) { func (t *diskTrack) SetTimeOffset(ntp uint64, rtp uint32) {
} }
func (t *diskTrack) setCname(string) { func (t *diskTrack) SetCname(string) {
} }
func clonePacket(packet *rtp.Packet) *rtp.Packet { func clonePacket(packet *rtp.Packet) *rtp.Packet {
...@@ -310,7 +312,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error { ...@@ -310,7 +312,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error {
if t.writer == nil { if t.writer == nil {
if !keyframe { if !keyframe {
return ErrKeyframeNeeded return conn.ErrKeyframeNeeded
} }
return nil return nil
} }
......
...@@ -14,14 +14,15 @@ import ( ...@@ -14,14 +14,15 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/pion/rtcp"
"github.com/pion/rtp"
"github.com/pion/webrtc/v3"
"sfu/conn"
"sfu/estimator" "sfu/estimator"
"sfu/jitter" "sfu/jitter"
"sfu/packetcache" "sfu/packetcache"
"sfu/rtptime" "sfu/rtptime"
"github.com/pion/rtcp"
"github.com/pion/rtp"
"github.com/pion/webrtc/v3"
) )
type bitrate struct { type bitrate struct {
...@@ -71,7 +72,7 @@ type iceConnection interface { ...@@ -71,7 +72,7 @@ type iceConnection interface {
type rtpDownTrack struct { type rtpDownTrack struct {
track *webrtc.Track track *webrtc.Track
remote upTrack remote conn.UpTrack
maxBitrate *bitrate maxBitrate *bitrate
rate *estimator.Estimator rate *estimator.Estimator
stats *receiverStats stats *receiverStats
...@@ -91,25 +92,25 @@ func (down *rtpDownTrack) Accumulate(bytes uint32) { ...@@ -91,25 +92,25 @@ func (down *rtpDownTrack) Accumulate(bytes uint32) {
down.rate.Accumulate(bytes) down.rate.Accumulate(bytes)
} }
func (down *rtpDownTrack) setTimeOffset(ntp uint64, rtp uint32) { func (down *rtpDownTrack) SetTimeOffset(ntp uint64, rtp uint32) {
atomic.StoreUint64(&down.remoteNTPTime, ntp) atomic.StoreUint64(&down.remoteNTPTime, ntp)
atomic.StoreUint32(&down.remoteRTPTime, rtp) atomic.StoreUint32(&down.remoteRTPTime, rtp)
} }
func (down *rtpDownTrack) setCname(cname string) { func (down *rtpDownTrack) SetCname(cname string) {
down.cname.Store(cname) down.cname.Store(cname)
} }
type rtpDownConnection struct { type rtpDownConnection struct {
id string id string
pc *webrtc.PeerConnection pc *webrtc.PeerConnection
remote upConnection remote conn.Up
tracks []*rtpDownTrack tracks []*rtpDownTrack
maxREMBBitrate *bitrate maxREMBBitrate *bitrate
iceCandidates []*webrtc.ICECandidateInit iceCandidates []*webrtc.ICECandidateInit
} }
func newDownConn(c client, id string, remote upConnection) (*rtpDownConnection, error) { func newDownConn(c client, id string, remote conn.Up) (*rtpDownConnection, error) {
pc, err := c.Group().API().NewPeerConnection(iceConfiguration()) pc, err := c.Group().API().NewPeerConnection(iceConfiguration())
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -193,7 +194,7 @@ type rtpUpTrack struct { ...@@ -193,7 +194,7 @@ type rtpUpTrack struct {
mu sync.Mutex mu sync.Mutex
cname string cname string
local []downTrack local []conn.DownTrack
srTime uint64 srTime uint64
srNTPTime uint64 srNTPTime uint64
srRTPTime uint32 srRTPTime uint32
...@@ -201,17 +202,17 @@ type rtpUpTrack struct { ...@@ -201,17 +202,17 @@ type rtpUpTrack struct {
type localTrackAction struct { type localTrackAction struct {
add bool add bool
track downTrack track conn.DownTrack
} }
func (up *rtpUpTrack) notifyLocal(add bool, track downTrack) { func (up *rtpUpTrack) notifyLocal(add bool, track conn.DownTrack) {
select { select {
case up.localCh <- localTrackAction{add, track}: case up.localCh <- localTrackAction{add, track}:
case <-up.readerDone: case <-up.readerDone:
} }
} }
func (up *rtpUpTrack) addLocal(local downTrack) error { func (up *rtpUpTrack) AddLocal(local conn.DownTrack) error {
up.mu.Lock() up.mu.Lock()
for _, t := range up.local { for _, t := range up.local {
if t == local { if t == local {
...@@ -226,7 +227,7 @@ func (up *rtpUpTrack) addLocal(local downTrack) error { ...@@ -226,7 +227,7 @@ func (up *rtpUpTrack) addLocal(local downTrack) error {
return nil return nil
} }
func (up *rtpUpTrack) delLocal(local downTrack) bool { func (up *rtpUpTrack) DelLocal(local conn.DownTrack) bool {
up.mu.Lock() up.mu.Lock()
for i, l := range up.local { for i, l := range up.local {
if l == local { if l == local {
...@@ -240,15 +241,15 @@ func (up *rtpUpTrack) delLocal(local downTrack) bool { ...@@ -240,15 +241,15 @@ func (up *rtpUpTrack) delLocal(local downTrack) bool {
return false return false
} }
func (up *rtpUpTrack) getLocal() []downTrack { func (up *rtpUpTrack) getLocal() []conn.DownTrack {
up.mu.Lock() up.mu.Lock()
defer up.mu.Unlock() defer up.mu.Unlock()
local := make([]downTrack, len(up.local)) local := make([]conn.DownTrack, len(up.local))
copy(local, up.local) copy(local, up.local)
return local return local
} }
func (up *rtpUpTrack) getRTP(seqno uint16, result []byte) uint16 { func (up *rtpUpTrack) GetRTP(seqno uint16, result []byte) uint16 {
return up.cache.Get(seqno, result) return up.cache.Get(seqno, result)
} }
...@@ -278,7 +279,7 @@ type rtpUpConnection struct { ...@@ -278,7 +279,7 @@ type rtpUpConnection struct {
mu sync.Mutex mu sync.Mutex
tracks []*rtpUpTrack tracks []*rtpUpTrack
local []downConnection local []conn.Down
} }
func (up *rtpUpConnection) getTracks() []*rtpUpTrack { func (up *rtpUpConnection) getTracks() []*rtpUpTrack {
...@@ -297,7 +298,7 @@ func (up *rtpUpConnection) Label() string { ...@@ -297,7 +298,7 @@ func (up *rtpUpConnection) Label() string {
return up.label return up.label
} }
func (up *rtpUpConnection) addLocal(local downConnection) error { func (up *rtpUpConnection) AddLocal(local conn.Down) error {
up.mu.Lock() up.mu.Lock()
defer up.mu.Unlock() defer up.mu.Unlock()
for _, t := range up.local { for _, t := range up.local {
...@@ -309,7 +310,7 @@ func (up *rtpUpConnection) addLocal(local downConnection) error { ...@@ -309,7 +310,7 @@ func (up *rtpUpConnection) addLocal(local downConnection) error {
return nil return nil
} }
func (up *rtpUpConnection) delLocal(local downConnection) bool { func (up *rtpUpConnection) DelLocal(local conn.Down) bool {
up.mu.Lock() up.mu.Lock()
defer up.mu.Unlock() defer up.mu.Unlock()
for i, l := range up.local { for i, l := range up.local {
...@@ -321,10 +322,10 @@ func (up *rtpUpConnection) delLocal(local downConnection) bool { ...@@ -321,10 +322,10 @@ func (up *rtpUpConnection) delLocal(local downConnection) bool {
return false return false
} }
func (up *rtpUpConnection) getLocal() []downConnection { func (up *rtpUpConnection) getLocal() []conn.Down {
up.mu.Lock() up.mu.Lock()
defer up.mu.Unlock() defer up.mu.Unlock()
local := make([]downConnection, len(up.local)) local := make([]conn.Down, len(up.local))
copy(local, up.local) copy(local, up.local)
return local return local
} }
...@@ -396,10 +397,10 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) { ...@@ -396,10 +397,10 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) {
return nil, err return nil, err
} }
conn := &rtpUpConnection{id: id, pc: pc} up := &rtpUpConnection{id: id, pc: pc}
pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) { pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) {
conn.mu.Lock() up.mu.Lock()
mid := getTrackMid(pc, remote) mid := getTrackMid(pc, remote)
if mid == "" { if mid == "" {
...@@ -407,7 +408,7 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) { ...@@ -407,7 +408,7 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) {
return return
} }
label, ok := conn.labels[mid] label, ok := up.labels[mid]
if !ok { if !ok {
log.Printf("Couldn't get track's label") log.Printf("Couldn't get track's label")
isvideo := remote.Kind() == webrtc.RTPCodecTypeVideo isvideo := remote.Kind() == webrtc.RTPCodecTypeVideo
...@@ -428,34 +429,34 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) { ...@@ -428,34 +429,34 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) {
readerDone: make(chan struct{}), readerDone: make(chan struct{}),
} }
conn.tracks = append(conn.tracks, track) up.tracks = append(up.tracks, track)
go readLoop(conn, track) go readLoop(up, track)
go rtcpUpListener(conn, track, receiver) go rtcpUpListener(up, track, receiver)
complete := conn.complete() complete := up.complete()
var tracks []upTrack var tracks []conn.UpTrack
if(complete) { if complete {
tracks = make([]upTrack, len(conn.tracks)) tracks = make([]conn.UpTrack, len(up.tracks))
for i, t := range conn.tracks { for i, t := range up.tracks {
tracks[i] = t tracks[i] = t
} }
} }
// pushConn might need to take the lock // pushConn might need to take the lock
conn.mu.Unlock() up.mu.Unlock()
if complete { if complete {
clients := c.Group().getClients(c) clients := c.Group().getClients(c)
for _, cc := range clients { for _, cc := range clients {
cc.pushConn(conn.id, conn, tracks, conn.label) cc.pushConn(up.id, up, tracks, up.label)
} }
go rtcpUpSender(conn) go rtcpUpSender(up)
} }
}) })
return conn, nil return up, nil
} }
func readLoop(conn *rtpUpConnection, track *rtpUpTrack) { func readLoop(conn *rtpUpConnection, track *rtpUpTrack) {
...@@ -606,7 +607,7 @@ func sendRecovery(p *rtcp.TransportLayerNack, track *rtpDownTrack) { ...@@ -606,7 +607,7 @@ func sendRecovery(p *rtcp.TransportLayerNack, track *rtpDownTrack) {
buf := make([]byte, packetcache.BufSize) buf := make([]byte, packetcache.BufSize)
for _, nack := range p.Nacks { for _, nack := range p.Nacks {
for _, seqno := range nack.PacketList() { for _, seqno := range nack.PacketList() {
l := track.remote.getRTP(seqno, buf) l := track.remote.GetRTP(seqno, buf)
if l == 0 { if l == 0 {
continue continue
} }
...@@ -650,7 +651,7 @@ func rtcpUpListener(conn *rtpUpConnection, track *rtpUpTrack, r *webrtc.RTPRecei ...@@ -650,7 +651,7 @@ func rtcpUpListener(conn *rtpUpConnection, track *rtpUpTrack, r *webrtc.RTPRecei
track.srRTPTime = p.RTPTime track.srRTPTime = p.RTPTime
track.mu.Unlock() track.mu.Unlock()
for _, l := range local { for _, l := range local {
l.setTimeOffset(p.NTPTime, p.RTPTime) l.SetTimeOffset(p.NTPTime, p.RTPTime)
} }
case *rtcp.SourceDescription: case *rtcp.SourceDescription:
for _, c := range p.Chunks { for _, c := range p.Chunks {
...@@ -665,7 +666,7 @@ func rtcpUpListener(conn *rtpUpConnection, track *rtpUpTrack, r *webrtc.RTPRecei ...@@ -665,7 +666,7 @@ func rtcpUpListener(conn *rtpUpConnection, track *rtpUpTrack, r *webrtc.RTPRecei
track.cname = i.Text track.cname = i.Text
track.mu.Unlock() track.mu.Unlock()
for _, l := range local { for _, l := range local {
l.setCname(i.Text) l.SetCname(i.Text)
} }
} }
} }
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"github.com/pion/rtp" "github.com/pion/rtp"
"sfu/conn"
"sfu/packetcache" "sfu/packetcache"
"sfu/rtptime" "sfu/rtptime"
) )
...@@ -43,7 +44,7 @@ func sqrt(n int) int { ...@@ -43,7 +44,7 @@ func sqrt(n int) int {
} }
// add adds or removes a track from a writer pool // add adds or removes a track from a writer pool
func (wp *rtpWriterPool) add(track downTrack, add bool) error { func (wp *rtpWriterPool) add(track conn.DownTrack, add bool) error {
n := 4 n := 4
if wp.count > 16 { if wp.count > 16 {
n = sqrt(wp.count) n = sqrt(wp.count)
...@@ -166,7 +167,7 @@ var ErrUnknownTrack = errors.New("unknown track") ...@@ -166,7 +167,7 @@ var ErrUnknownTrack = errors.New("unknown track")
type writerAction struct { type writerAction struct {
add bool add bool
track downTrack track conn.DownTrack
maxTracks int maxTracks int
ch chan error ch chan error
} }
...@@ -192,7 +193,7 @@ func newRtpWriter(conn *rtpUpConnection, track *rtpUpTrack) *rtpWriter { ...@@ -192,7 +193,7 @@ func newRtpWriter(conn *rtpUpConnection, track *rtpUpTrack) *rtpWriter {
} }
// add adds or removes a track from a writer. // add adds or removes a track from a writer.
func (writer *rtpWriter) add(track downTrack, add bool, max int) error { func (writer *rtpWriter) add(track conn.DownTrack, add bool, max int) error {
ch := make(chan error, 1) ch := make(chan error, 1)
select { select {
case writer.action <- writerAction{add, track, max, ch}: case writer.action <- writerAction{add, track, max, ch}:
...@@ -208,13 +209,13 @@ func (writer *rtpWriter) add(track downTrack, add bool, max int) error { ...@@ -208,13 +209,13 @@ func (writer *rtpWriter) add(track downTrack, add bool, max int) error {
} }
// rtpWriterLoop is the main loop of an rtpWriter. // rtpWriterLoop is the main loop of an rtpWriter.
func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack) { func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) {
defer close(writer.done) defer close(writer.done)
buf := make([]byte, packetcache.BufSize) buf := make([]byte, packetcache.BufSize)
var packet rtp.Packet var packet rtp.Packet
local := make([]downTrack, 0) local := make([]conn.DownTrack, 0)
// reset whenever a new track is inserted // reset whenever a new track is inserted
firSent := false firSent := false
...@@ -239,10 +240,10 @@ func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack) ...@@ -239,10 +240,10 @@ func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack)
cname := track.cname cname := track.cname
track.mu.Unlock() track.mu.Unlock()
if ntp != 0 { if ntp != 0 {
action.track.setTimeOffset(ntp, rtp) action.track.SetTimeOffset(ntp, rtp)
} }
if cname != "" { if cname != "" {
action.track.setCname(cname) action.track.SetCname(cname)
} }
} else { } else {
found := false found := false
...@@ -283,7 +284,7 @@ func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack) ...@@ -283,7 +284,7 @@ func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack)
for _, l := range local { for _, l := range local {
err := l.WriteRTP(&packet) err := l.WriteRTP(&packet)
if err != nil { if err != nil {
if err == ErrKeyframeNeeded { if err == conn.ErrKeyframeNeeded {
kfNeeded = true kfNeeded = true
} }
continue continue
...@@ -292,9 +293,9 @@ func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack) ...@@ -292,9 +293,9 @@ func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack)
} }
if kfNeeded { if kfNeeded {
err := conn.sendFIR(track, !firSent) err := up.sendFIR(track, !firSent)
if err == ErrUnsupportedFeedback { if err == ErrUnsupportedFeedback {
conn.sendPLI(track) up.sendPLI(track)
} }
firSent = true firSent = true
} }
......
...@@ -14,10 +14,11 @@ import ( ...@@ -14,10 +14,11 @@ import (
"sync" "sync"
"time" "time"
"sfu/estimator"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
"sfu/conn"
"sfu/estimator"
) )
var iceConf webrtc.Configuration var iceConf webrtc.Configuration
...@@ -300,7 +301,7 @@ func getConn(c *webClient, id string) iceConnection { ...@@ -300,7 +301,7 @@ func getConn(c *webClient, id string) iceConnection {
return nil return nil
} }
func addDownConn(c *webClient, id string, remote upConnection) (*rtpDownConnection, error) { func addDownConn(c *webClient, id string, remote conn.Up) (*rtpDownConnection, error) {
conn, err := newDownConn(c, id, remote) conn, err := newDownConn(c, id, remote)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -333,7 +334,7 @@ func addDownConn(c *webClient, id string, remote upConnection) (*rtpDownConnecti ...@@ -333,7 +334,7 @@ func addDownConn(c *webClient, id string, remote upConnection) (*rtpDownConnecti
} }
}) })
err = remote.addLocal(conn) err = remote.AddLocal(conn)
if err != nil { if err != nil {
conn.pc.Close() conn.pc.Close()
return nil, err return nil, err
...@@ -355,18 +356,18 @@ func delDownConn(c *webClient, id string) bool { ...@@ -355,18 +356,18 @@ func delDownConn(c *webClient, id string) bool {
return false return false
} }
conn.remote.delLocal(conn) conn.remote.DelLocal(conn)
for _, track := range conn.tracks { for _, track := range conn.tracks {
// we only insert the track after we get an answer, so // we only insert the track after we get an answer, so
// ignore errors here. // ignore errors here.
track.remote.delLocal(track) track.remote.DelLocal(track)
} }
conn.pc.Close() conn.pc.Close()
delete(c.down, id) delete(c.down, id)
return true return true
} }
func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack upTrack, remoteConn upConnection) (*webrtc.RTPSender, error) { func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack conn.UpTrack, remoteConn conn.Up) (*webrtc.RTPSender, error) {
var pt uint8 var pt uint8
var ssrc uint32 var ssrc uint32
var id, label string var id, label string
...@@ -524,7 +525,7 @@ func gotAnswer(c *webClient, id string, answer webrtc.SessionDescription) error ...@@ -524,7 +525,7 @@ func gotAnswer(c *webClient, id string, answer webrtc.SessionDescription) error
} }
for _, t := range down.tracks { for _, t := range down.tracks {
t.remote.addLocal(t) t.remote.AddLocal(t)
} }
return nil return nil
} }
...@@ -568,7 +569,7 @@ func (c *webClient) isRequested(label string) bool { ...@@ -568,7 +569,7 @@ func (c *webClient) isRequested(label string) bool {
return c.requested[label] != 0 return c.requested[label] != 0
} }
func addDownConnTracks(c *webClient, remote upConnection, tracks []upTrack) (*rtpDownConnection, error) { func addDownConnTracks(c *webClient, remote conn.Up, tracks []conn.UpTrack) (*rtpDownConnection, error) {
requested := false requested := false
for _, t := range tracks { for _, t := range tracks {
if c.isRequested(t.Label()) { if c.isRequested(t.Label()) {
...@@ -601,13 +602,13 @@ func addDownConnTracks(c *webClient, remote upConnection, tracks []upTrack) (*rt ...@@ -601,13 +602,13 @@ func addDownConnTracks(c *webClient, remote upConnection, tracks []upTrack) (*rt
return down, nil return down, nil
} }
func (c *webClient) pushConn(id string, conn upConnection, tracks []upTrack, label string) error { func (c *webClient) pushConn(id string, up conn.Up, tracks []conn.UpTrack, label string) error {
err := c.action(pushConnAction{id, conn, tracks}) err := c.action(pushConnAction{id, up, tracks})
if err != nil { if err != nil {
return err return err
} }
if conn != nil && label != "" { if up != nil && label != "" {
err := c.action(addLabelAction{conn.Id(), conn.Label()}) err := c.action(addLabelAction{up.Id(), up.Label()})
if err != nil { if err != nil {
return err return err
} }
...@@ -726,8 +727,8 @@ func startClient(conn *websocket.Conn) (err error) { ...@@ -726,8 +727,8 @@ func startClient(conn *websocket.Conn) (err error) {
type pushConnAction struct { type pushConnAction struct {
id string id string
conn upConnection conn conn.Up
tracks []upTrack tracks []conn.UpTrack
} }
type addLabelAction struct { type addLabelAction struct {
...@@ -749,9 +750,9 @@ type kickAction struct { ...@@ -749,9 +750,9 @@ type kickAction struct {
message string message string
} }
func clientLoop(c *webClient, conn *websocket.Conn) error { func clientLoop(c *webClient, ws *websocket.Conn) error {
read := make(chan interface{}, 1) read := make(chan interface{}, 1)
go clientReader(conn, read, c.done) go clientReader(ws, read, c.done)
defer func() { defer func() {
c.setRequested(map[string]uint32{}) c.setRequested(map[string]uint32{})
...@@ -848,7 +849,7 @@ func clientLoop(c *webClient, conn *websocket.Conn) error { ...@@ -848,7 +849,7 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
case pushConnsAction: case pushConnsAction:
for _, u := range c.up { for _, u := range c.up {
tracks := u.getTracks() tracks := u.getTracks()
ts := make([]upTrack, len(tracks)) ts := make([]conn.UpTrack, len(tracks))
for i, t := range tracks { for i, t := range tracks {
ts[i] = t ts[i] = t
} }
...@@ -861,7 +862,7 @@ func clientLoop(c *webClient, conn *websocket.Conn) error { ...@@ -861,7 +862,7 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
return err return err
} }
tracks := make( tracks := make(
[]upTrack, len(down.tracks), []conn.UpTrack, len(down.tracks),
) )
for i, t := range down.tracks { for i, t := range down.tracks {
tracks[i] = t.remote tracks[i] = t.remote
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment