Commit a50e9c67 authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Buffer last keyframe.

parent bbd5ce0c
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
) )
const BufSize = 1500 const BufSize = 1500
const maxKeyframe = 1024
type entry struct { type entry struct {
seqno uint16 seqno uint16
...@@ -24,6 +25,9 @@ type Cache struct { ...@@ -24,6 +25,9 @@ type Cache struct {
// bitmap // bitmap
first uint16 first uint16
bitmap uint32 bitmap uint32
// buffered keyframe
kfTimestamp uint32
kfEntries []entry
// packet cache // packet cache
tail uint16 tail uint16
entries []entry entries []entry
...@@ -75,7 +79,7 @@ func (cache *Cache) set(seqno uint16) { ...@@ -75,7 +79,7 @@ func (cache *Cache) set(seqno uint16) {
} }
// Store a packet, setting bitmap at the same time // Store a packet, setting bitmap at the same time
func (cache *Cache) Store(seqno uint16, buf []byte) (uint16, uint16) { func (cache *Cache) Store(seqno uint16, timestamp uint32, keyframe bool, buf []byte) (uint16, uint16) {
cache.mu.Lock() cache.mu.Lock()
defer cache.mu.Unlock() defer cache.mu.Unlock()
...@@ -97,9 +101,39 @@ func (cache *Cache) Store(seqno uint16, buf []byte) (uint16, uint16) { ...@@ -97,9 +101,39 @@ func (cache *Cache) Store(seqno uint16, buf []byte) (uint16, uint16) {
} }
} }
} }
cache.set(seqno) cache.set(seqno)
doit := false
if keyframe {
if cache.kfTimestamp != timestamp {
cache.kfTimestamp = timestamp
cache.kfEntries = cache.kfEntries[:0]
}
doit = true
} else if len(cache.kfEntries) > 0 {
doit = cache.kfTimestamp == timestamp
}
if doit {
i := 0
for i < len(cache.kfEntries) {
if cache.kfEntries[i].seqno >= seqno {
break
}
i++
}
if i >= len(cache.kfEntries) || cache.kfEntries[i].seqno != seqno {
if len(cache.kfEntries) >= maxKeyframe {
cache.kfEntries = cache.kfEntries[:maxKeyframe-1]
}
cache.kfEntries = append(cache.kfEntries, entry{})
copy(cache.kfEntries[i+1:], cache.kfEntries[i:])
}
cache.kfEntries[i].seqno = seqno
cache.kfEntries[i].length = uint16(len(buf))
copy(cache.kfEntries[i].buf[:], buf)
}
i := cache.tail i := cache.tail
cache.entries[i].seqno = seqno cache.entries[i].seqno = seqno
copy(cache.entries[i].buf[:], buf) copy(cache.entries[i].buf[:], buf)
...@@ -118,20 +152,33 @@ func (cache *Cache) Expect(n int) { ...@@ -118,20 +152,33 @@ func (cache *Cache) Expect(n int) {
cache.expected += uint32(n) cache.expected += uint32(n)
} }
func get(seqno uint16, entries []entry, result []byte) uint16 {
for i := range entries {
if entries[i].length == 0 || entries[i].seqno != seqno {
continue
}
return uint16(copy(
result[:entries[i].length],
entries[i].buf[:]),
)
}
return 0
}
func (cache *Cache) Get(seqno uint16, result []byte) uint16 { func (cache *Cache) Get(seqno uint16, result []byte) uint16 {
cache.mu.Lock() cache.mu.Lock()
defer cache.mu.Unlock() defer cache.mu.Unlock()
for i := range cache.entries { n := get(seqno, cache.kfEntries, result)
if cache.entries[i].length == 0 || if n > 0 {
cache.entries[i].seqno != seqno { return n
continue
} }
return uint16(copy(
result[:cache.entries[i].length], n = get(seqno, cache.entries, result)
cache.entries[i].buf[:]), if n > 0 {
) return n
} }
return 0 return 0
} }
...@@ -151,6 +198,17 @@ func (cache *Cache) GetAt(seqno uint16, index uint16, result []byte) uint16 { ...@@ -151,6 +198,17 @@ func (cache *Cache) GetAt(seqno uint16, index uint16, result []byte) uint16 {
) )
} }
func (cache *Cache) Keyframe() (uint32, []uint16) {
cache.mu.Lock()
defer cache.mu.Unlock()
seqnos := make([]uint16, len(cache.kfEntries))
for i := range cache.kfEntries {
seqnos[i] = cache.kfEntries[i].seqno
}
return cache.kfTimestamp, seqnos
}
func (cache *Cache) resize(capacity int) { func (cache *Cache) resize(capacity int) {
if len(cache.entries) == capacity { if len(cache.entries) == capacity {
return return
......
...@@ -20,8 +20,8 @@ func TestCache(t *testing.T) { ...@@ -20,8 +20,8 @@ func TestCache(t *testing.T) {
buf1 := randomBuf() buf1 := randomBuf()
buf2 := randomBuf() buf2 := randomBuf()
cache := New(16) cache := New(16)
_, i1 := cache.Store(13, buf1) _, i1 := cache.Store(13, 0, false, buf1)
_, i2 := cache.Store(17, buf2) _, i2 := cache.Store(17, 0, false, buf2)
buf := make([]byte, BufSize) buf := make([]byte, BufSize)
...@@ -62,7 +62,7 @@ func TestCacheOverflow(t *testing.T) { ...@@ -62,7 +62,7 @@ func TestCacheOverflow(t *testing.T) {
cache := New(16) cache := New(16)
for i := 0; i < 32; i++ { for i := 0; i < 32; i++ {
cache.Store(uint16(i), []byte{uint8(i)}) cache.Store(uint16(i), 0, false, []byte{uint8(i)})
} }
for i := 0; i < 32; i++ { for i := 0; i < 32; i++ {
...@@ -84,7 +84,7 @@ func TestCacheGrow(t *testing.T) { ...@@ -84,7 +84,7 @@ func TestCacheGrow(t *testing.T) {
cache := New(16) cache := New(16)
for i := 0; i < 24; i++ { for i := 0; i < 24; i++ {
cache.Store(uint16(i), []byte{uint8(i)}) cache.Store(uint16(i), 0, false, []byte{uint8(i)})
} }
cache.Resize(32) cache.Resize(32)
...@@ -107,7 +107,7 @@ func TestCacheShrink(t *testing.T) { ...@@ -107,7 +107,7 @@ func TestCacheShrink(t *testing.T) {
cache := New(16) cache := New(16)
for i := 0; i < 24; i++ { for i := 0; i < 24; i++ {
cache.Store(uint16(i), []byte{uint8(i)}) cache.Store(uint16(i), 0, false, []byte{uint8(i)})
} }
cache.Resize(12) cache.Resize(12)
...@@ -150,6 +150,65 @@ func TestCacheGrowCond(t *testing.T) { ...@@ -150,6 +150,65 @@ func TestCacheGrowCond(t *testing.T) {
} }
} }
func TestKeyframe(t *testing.T) {
cache := New(16)
packet := make([]byte, 1)
buf := make([]byte, BufSize)
cache.Store(7, 57, true, packet)
cache.Store(8, 57, true, packet)
ts, kf := cache.Keyframe()
if ts != 57 || len(kf) != 2 {
t.Errorf("Got %v %v, expected %v %v", ts, len(kf), 57, 2)
}
for _, i := range kf {
l := cache.Get(i, buf)
if int(l) != len(packet) {
t.Errorf("Couldn't get %v", i)
}
}
for i := 0; i < 32; i++ {
cache.Store(uint16(9 + i), uint32(58 + i), false, packet)
}
ts, kf = cache.Keyframe()
if ts != 57 || len(kf) != 2 {
t.Errorf("Got %v %v, expected %v %v", ts, len(kf), 57, 2)
}
for _, i := range kf {
l := cache.Get(i, buf)
if int(l) != len(packet) {
t.Errorf("Couldn't get %v", i)
}
}
}
func TestKeyframeUnsorted(t *testing.T) {
cache := New(16)
packet := make([]byte, 1)
cache.Store(7, 57, true, packet)
cache.Store(9, 57, true, packet)
cache.Store(8, 57, true, packet)
cache.Store(10, 57, true, packet)
cache.Store(6, 57, true, packet)
cache.Store(8, 57, true, packet)
_, kf := cache.Keyframe()
if len(kf) != 5 {
t.Errorf("Got length %v, expected 5", len(kf))
}
for i, v := range kf {
if v != uint16(i + 6) {
t.Errorf("Position %v, expected %v, got %v\n",
i, i + 6, v)
}
}
}
func TestBitmap(t *testing.T) { func TestBitmap(t *testing.T) {
value := uint64(0xcdd58f1e035379c0) value := uint64(0xcdd58f1e035379c0)
packet := make([]byte, 1) packet := make([]byte, 1)
...@@ -159,7 +218,7 @@ func TestBitmap(t *testing.T) { ...@@ -159,7 +218,7 @@ func TestBitmap(t *testing.T) {
var first uint16 var first uint16
for i := 0; i < 64; i++ { for i := 0; i < 64; i++ {
if (value & (1 << i)) != 0 { if (value & (1 << i)) != 0 {
first, _ = cache.Store(uint16(42+i), packet) first, _ = cache.Store(uint16(42+i), 0, false, packet)
} }
} }
...@@ -175,13 +234,13 @@ func TestBitmapWrap(t *testing.T) { ...@@ -175,13 +234,13 @@ func TestBitmapWrap(t *testing.T) {
cache := New(16) cache := New(16)
cache.Store(0x7000, packet) cache.Store(0x7000, 0, false, packet)
cache.Store(0xA000, packet) cache.Store(0xA000, 0, false, packet)
var first uint16 var first uint16
for i := 0; i < 64; i++ { for i := 0; i < 64; i++ {
if (value & (1 << i)) != 0 { if (value & (1 << i)) != 0 {
first, _ = cache.Store(uint16(42+i), packet) first, _ = cache.Store(uint16(42+i), 0, false, packet)
} }
} }
...@@ -199,7 +258,7 @@ func TestBitmapGet(t *testing.T) { ...@@ -199,7 +258,7 @@ func TestBitmapGet(t *testing.T) {
for i := 0; i < 64; i++ { for i := 0; i < 64; i++ {
if (value & (1 << i)) != 0 { if (value & (1 << i)) != 0 {
cache.Store(uint16(42+i), packet) cache.Store(uint16(42+i), 0, false, packet)
} }
} }
...@@ -241,7 +300,7 @@ func TestBitmapPacket(t *testing.T) { ...@@ -241,7 +300,7 @@ func TestBitmapPacket(t *testing.T) {
for i := 0; i < 64; i++ { for i := 0; i < 64; i++ {
if (value & (1 << i)) != 0 { if (value & (1 << i)) != 0 {
cache.Store(uint16(42+i), packet) cache.Store(uint16(42+i), 0, false, packet)
} }
} }
...@@ -299,7 +358,7 @@ func BenchmarkCachePutGet(b *testing.B) { ...@@ -299,7 +358,7 @@ func BenchmarkCachePutGet(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
seqno := uint16(i) seqno := uint16(i)
cache.Store(seqno, buf) cache.Store(seqno, 0, false, buf)
for _, ch := range chans { for _, ch := range chans {
ch <- seqno ch <- seqno
} }
...@@ -350,7 +409,7 @@ func BenchmarkCachePutGetAt(b *testing.B) { ...@@ -350,7 +409,7 @@ func BenchmarkCachePutGetAt(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
seqno := uint16(i) seqno := uint16(i)
_, index := cache.Store(seqno, buf) _, index := cache.Store(seqno, 0, false, buf)
for _, ch := range chans { for _, ch := range chans {
ch <- is{index, seqno} ch <- is{index, seqno}
} }
......
...@@ -5,12 +5,24 @@ import ( ...@@ -5,12 +5,24 @@ import (
"log" "log"
"github.com/pion/rtp" "github.com/pion/rtp"
"github.com/pion/rtp/codecs"
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
"sfu/packetcache" "sfu/packetcache"
"sfu/rtptime" "sfu/rtptime"
) )
func isVP8Keyframe(packet *rtp.Packet) bool {
var vp8 codecs.VP8Packet
_, err := vp8.Unmarshal(packet.Payload)
if err != nil {
return false
}
return vp8.S != 0 && vp8.PID == 0 &&
len(vp8.Payload) > 0 && (vp8.Payload[0]&0x1) == 0
}
func readLoop(conn *rtpUpConnection, track *rtpUpTrack) { func readLoop(conn *rtpUpConnection, track *rtpUpTrack) {
writers := rtpWriterPool{conn: conn, track: track} writers := rtpWriterPool{conn: conn, track: track}
defer func() { defer func() {
...@@ -19,6 +31,7 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) { ...@@ -19,6 +31,7 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) {
}() }()
isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo isvideo := track.track.Kind() == webrtc.RTPCodecTypeVideo
codec := track.track.Codec().Name
buf := make([]byte, packetcache.BufSize) buf := make([]byte, packetcache.BufSize)
var packet rtp.Packet var packet rtp.Packet
for { for {
...@@ -39,8 +52,14 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) { ...@@ -39,8 +52,14 @@ func readLoop(conn *rtpUpConnection, track *rtpUpTrack) {
track.jitter.Accumulate(packet.Timestamp) track.jitter.Accumulate(packet.Timestamp)
first, index := kf := false
track.cache.Store(packet.SequenceNumber, buf[:bytes]) if isvideo && codec == webrtc.VP8 {
kf = isVP8Keyframe(&packet)
}
first, index := track.cache.Store(
packet.SequenceNumber, packet.Timestamp, kf, buf[:bytes],
)
if packet.SequenceNumber-first > 24 { if packet.SequenceNumber-first > 24 {
found, first, bitmap := track.cache.BitmapGet() found, first, bitmap := track.cache.BitmapGet()
if found { if found {
......
...@@ -138,7 +138,7 @@ func (wp *rtpWriterPool) write(seqno uint16, index uint16, delay uint32, isvideo ...@@ -138,7 +138,7 @@ func (wp *rtpWriterPool) write(seqno uint16, index uint16, delay uint32, isvideo
continue continue
} }
// audio, try again with a delay // audio, try again with a delay
d := delay/uint32(2*len(wp.writers)) d := delay / uint32(2*len(wp.writers))
timer := time.NewTimer(rtptime.ToDuration( timer := time.NewTimer(rtptime.ToDuration(
uint64(d), rtptime.JiffiesPerSec, uint64(d), rtptime.JiffiesPerSec,
)) ))
...@@ -208,6 +208,31 @@ func (writer *rtpWriter) add(track conn.DownTrack, add bool, max int) error { ...@@ -208,6 +208,31 @@ func (writer *rtpWriter) add(track conn.DownTrack, add bool, max int) error {
} }
} }
func sendKeyframe(track conn.DownTrack, cache *packetcache.Cache) {
_, kf := cache.Keyframe()
if len(kf) == 0 {
return
}
buf := make([]byte, packetcache.BufSize)
var packet rtp.Packet
for _, seqno := range kf {
bytes := cache.Get(seqno, buf)
if(bytes == 0) {
return
}
err := packet.Unmarshal(buf[:bytes])
if err != nil {
return
}
err = track.WriteRTP(&packet)
if err != nil && err != conn.ErrKeyframeNeeded {
return
}
track.Accumulate(uint32(bytes))
}
}
// rtpWriterLoop is the main loop of an rtpWriter. // rtpWriterLoop is the main loop of an rtpWriter.
func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) {
defer close(writer.done) defer close(writer.done)
...@@ -245,6 +270,7 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { ...@@ -245,6 +270,7 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) {
if cname != "" { if cname != "" {
action.track.SetCname(cname) action.track.SetCname(cname)
} }
go sendKeyframe(action.track, track.cache)
} else { } else {
found := false found := false
for i, t := range local { for i, t := range local {
...@@ -286,9 +312,10 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { ...@@ -286,9 +312,10 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) {
if err != nil { if err != nil {
if err == conn.ErrKeyframeNeeded { if err == conn.ErrKeyframeNeeded {
kfNeeded = true kfNeeded = true
} } else {
continue continue
} }
}
l.Accumulate(uint32(bytes)) l.Accumulate(uint32(bytes))
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment