Commit d9cf32ed authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Move connections to their own package.

parent 7148ea18
package main
import (
"sfu/conn"
)
type clientCredentials struct {
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
......@@ -16,7 +20,7 @@ type client interface {
Id() string
Credentials() clientCredentials
SetPermissions(clientPermissions)
pushConn(id string, conn upConnection, tracks []upTrack, label string) error
pushConn(id string, conn conn.Up, tracks []conn.UpTrack, label string) error
pushClient(id, username string, add bool) error
}
......
......@@ -3,7 +3,8 @@
// This is not open source software. Copy it, and I'll break into your
// house and tell your three year-old that Santa doesn't exist.
package main
// Package conn defines interfaces for connections and tracks.
package conn
import (
"errors"
......@@ -15,29 +16,33 @@ import (
var ErrConnectionClosed = errors.New("connection is closed")
var ErrKeyframeNeeded = errors.New("keyframe needed")
type upConnection interface {
addLocal(downConnection) error
delLocal(downConnection) bool
// Type Up represents a connection in the client to server direction.
type Up interface {
AddLocal(Down) error
DelLocal(Down) bool
Id() string
Label() string
}
type upTrack interface {
addLocal(downTrack) error
delLocal(downTrack) bool
// Type UpTrack represents a track in the client to server direction.
type UpTrack interface {
AddLocal(DownTrack) error
DelLocal(DownTrack) bool
Label() string
Codec() *webrtc.RTPCodec
// get a recent packet. Returns 0 if the packet is not in cache.
getRTP(seqno uint16, result []byte) uint16
GetRTP(seqno uint16, result []byte) uint16
}
type downConnection interface {
// Type Down represents a connection in the server to client direction.
type Down interface {
GetMaxBitrate(now uint64) uint64
}
type downTrack interface {
// Type DownTrack represents a track in the server to client direction.
type DownTrack interface {
WriteRTP(packat *rtp.Packet) error
Accumulate(bytes uint32)
setTimeOffset(ntp uint64, rtp uint32)
setCname(string)
SetTimeOffset(ntp uint64, rtp uint32)
SetCname(string)
}
......@@ -14,6 +14,8 @@ import (
"github.com/pion/rtp/codecs"
"github.com/pion/webrtc/v3"
"github.com/pion/webrtc/v3/pkg/media/samplebuilder"
"sfu/conn"
)
type diskClient struct {
......@@ -81,7 +83,7 @@ func (client *diskClient) kick(message string) error {
return err
}
func (client *diskClient) pushConn(id string, conn upConnection, tracks []upTrack, label string) error {
func (client *diskClient) pushConn(id string, up conn.Up, tracks []conn.UpTrack, label string) error {
client.mu.Lock()
defer client.mu.Unlock()
......@@ -95,7 +97,7 @@ func (client *diskClient) pushConn(id string, conn upConnection, tracks []upTrac
delete(client.down, id)
}
if conn == nil {
if up == nil {
return nil
}
......@@ -109,12 +111,12 @@ func (client *diskClient) pushConn(id string, conn upConnection, tracks []upTrac
client.down = make(map[string]*diskConn)
}
down, err := newDiskConn(directory, label, conn, tracks)
down, err := newDiskConn(directory, label, up, tracks)
if err != nil {
return err
}
client.down[conn.Id()] = down
client.down[up.Id()] = down
return nil
}
......@@ -125,7 +127,7 @@ type diskConn struct {
mu sync.Mutex
file *os.File
remote upConnection
remote conn.Up
tracks []*diskTrack
width, height uint32
}
......@@ -150,7 +152,7 @@ func (conn *diskConn) reopen() error {
}
func (conn *diskConn) Close() error {
conn.remote.delLocal(conn)
conn.remote.DelLocal(conn)
conn.mu.Lock()
tracks := make([]*diskTrack, 0, len(conn.tracks))
......@@ -164,7 +166,7 @@ func (conn *diskConn) Close() error {
conn.mu.Unlock()
for _, t := range tracks {
t.remote.delLocal(t)
t.remote.DelLocal(t)
}
return nil
}
......@@ -196,7 +198,7 @@ func openDiskFile(directory, label string) (*os.File, error) {
}
type diskTrack struct {
remote upTrack
remote conn.UpTrack
conn *diskConn
writer webm.BlockWriteCloser
......@@ -206,7 +208,7 @@ type diskTrack struct {
origin uint64
}
func newDiskConn(directory, label string, up upConnection, remoteTracks []upTrack) (*diskConn, error) {
func newDiskConn(directory, label string, up conn.Up, remoteTracks []conn.UpTrack) (*diskConn, error) {
conn := diskConn{
directory: directory,
label: label,
......@@ -231,10 +233,10 @@ func newDiskConn(directory, label string, up upConnection, remoteTracks []upTrac
conn: &conn,
}
conn.tracks = append(conn.tracks, track)
remote.addLocal(track)
remote.AddLocal(track)
}
err := up.addLocal(&conn)
err := up.AddLocal(&conn)
if err != nil {
return nil, err
}
......@@ -242,10 +244,10 @@ func newDiskConn(directory, label string, up upConnection, remoteTracks []upTrac
return &conn, nil
}
func (t *diskTrack) setTimeOffset(ntp uint64, rtp uint32) {
func (t *diskTrack) SetTimeOffset(ntp uint64, rtp uint32) {
}
func (t *diskTrack) setCname(string) {
func (t *diskTrack) SetCname(string) {
}
func clonePacket(packet *rtp.Packet) *rtp.Packet {
......@@ -310,7 +312,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error {
if t.writer == nil {
if !keyframe {
return ErrKeyframeNeeded
return conn.ErrKeyframeNeeded
}
return nil
}
......
......@@ -14,14 +14,15 @@ import (
"sync/atomic"
"time"
"github.com/pion/rtcp"
"github.com/pion/rtp"
"github.com/pion/webrtc/v3"
"sfu/conn"
"sfu/estimator"
"sfu/jitter"
"sfu/packetcache"
"sfu/rtptime"
"github.com/pion/rtcp"
"github.com/pion/rtp"
"github.com/pion/webrtc/v3"
)
type bitrate struct {
......@@ -71,7 +72,7 @@ type iceConnection interface {
type rtpDownTrack struct {
track *webrtc.Track
remote upTrack
remote conn.UpTrack
maxBitrate *bitrate
rate *estimator.Estimator
stats *receiverStats
......@@ -91,25 +92,25 @@ func (down *rtpDownTrack) Accumulate(bytes uint32) {
down.rate.Accumulate(bytes)
}
func (down *rtpDownTrack) setTimeOffset(ntp uint64, rtp uint32) {
func (down *rtpDownTrack) SetTimeOffset(ntp uint64, rtp uint32) {
atomic.StoreUint64(&down.remoteNTPTime, ntp)
atomic.StoreUint32(&down.remoteRTPTime, rtp)
}
func (down *rtpDownTrack) setCname(cname string) {
func (down *rtpDownTrack) SetCname(cname string) {
down.cname.Store(cname)
}
type rtpDownConnection struct {
id string
pc *webrtc.PeerConnection
remote upConnection
remote conn.Up
tracks []*rtpDownTrack
maxREMBBitrate *bitrate
iceCandidates []*webrtc.ICECandidateInit
}
func newDownConn(c client, id string, remote upConnection) (*rtpDownConnection, error) {
func newDownConn(c client, id string, remote conn.Up) (*rtpDownConnection, error) {
pc, err := c.Group().API().NewPeerConnection(iceConfiguration())
if err != nil {
return nil, err
......@@ -193,7 +194,7 @@ type rtpUpTrack struct {
mu sync.Mutex
cname string
local []downTrack
local []conn.DownTrack
srTime uint64
srNTPTime uint64
srRTPTime uint32
......@@ -201,17 +202,17 @@ type rtpUpTrack struct {
type localTrackAction struct {
add bool
track downTrack
track conn.DownTrack
}
func (up *rtpUpTrack) notifyLocal(add bool, track downTrack) {
func (up *rtpUpTrack) notifyLocal(add bool, track conn.DownTrack) {
select {
case up.localCh <- localTrackAction{add, track}:
case <-up.readerDone:
}
}
func (up *rtpUpTrack) addLocal(local downTrack) error {
func (up *rtpUpTrack) AddLocal(local conn.DownTrack) error {
up.mu.Lock()
for _, t := range up.local {
if t == local {
......@@ -226,7 +227,7 @@ func (up *rtpUpTrack) addLocal(local downTrack) error {
return nil
}
func (up *rtpUpTrack) delLocal(local downTrack) bool {
func (up *rtpUpTrack) DelLocal(local conn.DownTrack) bool {
up.mu.Lock()
for i, l := range up.local {
if l == local {
......@@ -240,15 +241,15 @@ func (up *rtpUpTrack) delLocal(local downTrack) bool {
return false
}
func (up *rtpUpTrack) getLocal() []downTrack {
func (up *rtpUpTrack) getLocal() []conn.DownTrack {
up.mu.Lock()
defer up.mu.Unlock()
local := make([]downTrack, len(up.local))
local := make([]conn.DownTrack, len(up.local))
copy(local, up.local)
return local
}
func (up *rtpUpTrack) getRTP(seqno uint16, result []byte) uint16 {
func (up *rtpUpTrack) GetRTP(seqno uint16, result []byte) uint16 {
return up.cache.Get(seqno, result)
}
......@@ -278,7 +279,7 @@ type rtpUpConnection struct {
mu sync.Mutex
tracks []*rtpUpTrack
local []downConnection
local []conn.Down
}
func (up *rtpUpConnection) getTracks() []*rtpUpTrack {
......@@ -297,7 +298,7 @@ func (up *rtpUpConnection) Label() string {
return up.label
}
func (up *rtpUpConnection) addLocal(local downConnection) error {
func (up *rtpUpConnection) AddLocal(local conn.Down) error {
up.mu.Lock()
defer up.mu.Unlock()
for _, t := range up.local {
......@@ -309,7 +310,7 @@ func (up *rtpUpConnection) addLocal(local downConnection) error {
return nil
}
func (up *rtpUpConnection) delLocal(local downConnection) bool {
func (up *rtpUpConnection) DelLocal(local conn.Down) bool {
up.mu.Lock()
defer up.mu.Unlock()
for i, l := range up.local {
......@@ -321,10 +322,10 @@ func (up *rtpUpConnection) delLocal(local downConnection) bool {
return false
}
func (up *rtpUpConnection) getLocal() []downConnection {
func (up *rtpUpConnection) getLocal() []conn.Down {
up.mu.Lock()
defer up.mu.Unlock()
local := make([]downConnection, len(up.local))
local := make([]conn.Down, len(up.local))
copy(local, up.local)
return local
}
......@@ -396,10 +397,10 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) {
return nil, err
}
conn := &rtpUpConnection{id: id, pc: pc}
up := &rtpUpConnection{id: id, pc: pc}
pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) {
conn.mu.Lock()
up.mu.Lock()
mid := getTrackMid(pc, remote)
if mid == "" {
......@@ -407,7 +408,7 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) {
return
}
label, ok := conn.labels[mid]
label, ok := up.labels[mid]
if !ok {
log.Printf("Couldn't get track's label")
isvideo := remote.Kind() == webrtc.RTPCodecTypeVideo
......@@ -428,34 +429,34 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) {
readerDone: make(chan struct{}),
}
conn.tracks = append(conn.tracks, track)
up.tracks = append(up.tracks, track)
go readLoop(conn, track)
go readLoop(up, track)
go rtcpUpListener(conn, track, receiver)
go rtcpUpListener(up, track, receiver)
complete := conn.complete()
var tracks []upTrack
if(complete) {
tracks = make([]upTrack, len(conn.tracks))
for i, t := range conn.tracks {
complete := up.complete()
var tracks []conn.UpTrack
if complete {
tracks = make([]conn.UpTrack, len(up.tracks))
for i, t := range up.tracks {
tracks[i] = t
}
}
// pushConn might need to take the lock
conn.mu.Unlock()
up.mu.Unlock()
if complete {
clients := c.Group().getClients(c)
for _, cc := range clients {
cc.pushConn(conn.id, conn, tracks, conn.label)
cc.pushConn(up.id, up, tracks, up.label)
}
go rtcpUpSender(conn)
go rtcpUpSender(up)
}
})
return conn, nil
return up, nil
}
func readLoop(conn *rtpUpConnection, track *rtpUpTrack) {
......@@ -606,7 +607,7 @@ func sendRecovery(p *rtcp.TransportLayerNack, track *rtpDownTrack) {
buf := make([]byte, packetcache.BufSize)
for _, nack := range p.Nacks {
for _, seqno := range nack.PacketList() {
l := track.remote.getRTP(seqno, buf)
l := track.remote.GetRTP(seqno, buf)
if l == 0 {
continue
}
......@@ -650,7 +651,7 @@ func rtcpUpListener(conn *rtpUpConnection, track *rtpUpTrack, r *webrtc.RTPRecei
track.srRTPTime = p.RTPTime
track.mu.Unlock()
for _, l := range local {
l.setTimeOffset(p.NTPTime, p.RTPTime)
l.SetTimeOffset(p.NTPTime, p.RTPTime)
}
case *rtcp.SourceDescription:
for _, c := range p.Chunks {
......@@ -665,7 +666,7 @@ func rtcpUpListener(conn *rtpUpConnection, track *rtpUpTrack, r *webrtc.RTPRecei
track.cname = i.Text
track.mu.Unlock()
for _, l := range local {
l.setCname(i.Text)
l.SetCname(i.Text)
}
}
}
......
......@@ -7,6 +7,7 @@ import (
"github.com/pion/rtp"
"sfu/conn"
"sfu/packetcache"
"sfu/rtptime"
)
......@@ -43,7 +44,7 @@ func sqrt(n int) int {
}
// add adds or removes a track from a writer pool
func (wp *rtpWriterPool) add(track downTrack, add bool) error {
func (wp *rtpWriterPool) add(track conn.DownTrack, add bool) error {
n := 4
if wp.count > 16 {
n = sqrt(wp.count)
......@@ -166,7 +167,7 @@ var ErrUnknownTrack = errors.New("unknown track")
type writerAction struct {
add bool
track downTrack
track conn.DownTrack
maxTracks int
ch chan error
}
......@@ -192,7 +193,7 @@ func newRtpWriter(conn *rtpUpConnection, track *rtpUpTrack) *rtpWriter {
}
// add adds or removes a track from a writer.
func (writer *rtpWriter) add(track downTrack, add bool, max int) error {
func (writer *rtpWriter) add(track conn.DownTrack, add bool, max int) error {
ch := make(chan error, 1)
select {
case writer.action <- writerAction{add, track, max, ch}:
......@@ -208,13 +209,13 @@ func (writer *rtpWriter) add(track downTrack, add bool, max int) error {
}
// rtpWriterLoop is the main loop of an rtpWriter.
func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack) {
func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) {
defer close(writer.done)
buf := make([]byte, packetcache.BufSize)
var packet rtp.Packet
local := make([]downTrack, 0)
local := make([]conn.DownTrack, 0)
// reset whenever a new track is inserted
firSent := false
......@@ -239,10 +240,10 @@ func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack)
cname := track.cname
track.mu.Unlock()
if ntp != 0 {
action.track.setTimeOffset(ntp, rtp)
action.track.SetTimeOffset(ntp, rtp)
}
if cname != "" {
action.track.setCname(cname)
action.track.SetCname(cname)
}
} else {
found := false
......@@ -283,7 +284,7 @@ func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack)
for _, l := range local {
err := l.WriteRTP(&packet)
if err != nil {
if err == ErrKeyframeNeeded {
if err == conn.ErrKeyframeNeeded {
kfNeeded = true
}
continue
......@@ -292,9 +293,9 @@ func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack)
}
if kfNeeded {
err := conn.sendFIR(track, !firSent)
err := up.sendFIR(track, !firSent)
if err == ErrUnsupportedFeedback {
conn.sendPLI(track)
up.sendPLI(track)
}
firSent = true
}
......
......@@ -14,10 +14,11 @@ import (
"sync"
"time"
"sfu/estimator"
"github.com/gorilla/websocket"
"github.com/pion/webrtc/v3"
"sfu/conn"
"sfu/estimator"
)
var iceConf webrtc.Configuration
......@@ -300,7 +301,7 @@ func getConn(c *webClient, id string) iceConnection {
return nil
}
func addDownConn(c *webClient, id string, remote upConnection) (*rtpDownConnection, error) {
func addDownConn(c *webClient, id string, remote conn.Up) (*rtpDownConnection, error) {
conn, err := newDownConn(c, id, remote)
if err != nil {
return nil, err
......@@ -333,7 +334,7 @@ func addDownConn(c *webClient, id string, remote upConnection) (*rtpDownConnecti
}
})
err = remote.addLocal(conn)
err = remote.AddLocal(conn)
if err != nil {
conn.pc.Close()
return nil, err
......@@ -355,18 +356,18 @@ func delDownConn(c *webClient, id string) bool {
return false
}
conn.remote.delLocal(conn)
conn.remote.DelLocal(conn)
for _, track := range conn.tracks {
// we only insert the track after we get an answer, so
// ignore errors here.
track.remote.delLocal(track)
track.remote.DelLocal(track)
}
conn.pc.Close()
delete(c.down, id)
return true
}
func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack upTrack, remoteConn upConnection) (*webrtc.RTPSender, error) {
func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack conn.UpTrack, remoteConn conn.Up) (*webrtc.RTPSender, error) {
var pt uint8
var ssrc uint32
var id, label string
......@@ -524,7 +525,7 @@ func gotAnswer(c *webClient, id string, answer webrtc.SessionDescription) error
}
for _, t := range down.tracks {
t.remote.addLocal(t)
t.remote.AddLocal(t)
}
return nil
}
......@@ -568,7 +569,7 @@ func (c *webClient) isRequested(label string) bool {
return c.requested[label] != 0
}
func addDownConnTracks(c *webClient, remote upConnection, tracks []upTrack) (*rtpDownConnection, error) {
func addDownConnTracks(c *webClient, remote conn.Up, tracks []conn.UpTrack) (*rtpDownConnection, error) {
requested := false
for _, t := range tracks {
if c.isRequested(t.Label()) {
......@@ -601,13 +602,13 @@ func addDownConnTracks(c *webClient, remote upConnection, tracks []upTrack) (*rt
return down, nil
}
func (c *webClient) pushConn(id string, conn upConnection, tracks []upTrack, label string) error {
err := c.action(pushConnAction{id, conn, tracks})
func (c *webClient) pushConn(id string, up conn.Up, tracks []conn.UpTrack, label string) error {
err := c.action(pushConnAction{id, up, tracks})
if err != nil {
return err
}
if conn != nil && label != "" {
err := c.action(addLabelAction{conn.Id(), conn.Label()})
if up != nil && label != "" {
err := c.action(addLabelAction{up.Id(), up.Label()})
if err != nil {
return err
}
......@@ -726,8 +727,8 @@ func startClient(conn *websocket.Conn) (err error) {
type pushConnAction struct {
id string
conn upConnection
tracks []upTrack
conn conn.Up
tracks []conn.UpTrack
}
type addLabelAction struct {
......@@ -749,9 +750,9 @@ type kickAction struct {
message string
}
func clientLoop(c *webClient, conn *websocket.Conn) error {
func clientLoop(c *webClient, ws *websocket.Conn) error {
read := make(chan interface{}, 1)
go clientReader(conn, read, c.done)
go clientReader(ws, read, c.done)
defer func() {
c.setRequested(map[string]uint32{})
......@@ -848,7 +849,7 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
case pushConnsAction:
for _, u := range c.up {
tracks := u.getTracks()
ts := make([]upTrack, len(tracks))
ts := make([]conn.UpTrack, len(tracks))
for i, t := range tracks {
ts[i] = t
}
......@@ -861,7 +862,7 @@ func clientLoop(c *webClient, conn *websocket.Conn) error {
return err
}
tracks := make(
[]upTrack, len(down.tracks),
[]conn.UpTrack, len(down.tracks),
)
for i, t := range down.tracks {
tracks[i] = t.remote
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment