Commit 461c78b0 authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Fix race condition in estimator.

parent b5f8ea0e
...@@ -3,64 +3,85 @@ ...@@ -3,64 +3,85 @@
package estimator package estimator
import ( import (
"sync/atomic" "sync"
"time" "time"
"github.com/jech/galene/rtptime" "github.com/jech/galene/rtptime"
) )
type Estimator struct { type Estimator struct {
interval uint64 interval uint64
mu sync.Mutex
time uint64 time uint64
bytes uint32 bytes uint32
packets uint32 packets uint32
totalBytes uint32 totalBytes uint64
totalPackets uint32 totalPackets uint64
rate uint32 rate uint32
packetRate uint32 packetRate uint32
} }
// New creates a new estimator that estimates rate over the last interval. // New creates a new estimator that estimates rate over the last interval.
func New(interval time.Duration) *Estimator { func New(interval time.Duration) *Estimator {
return new(rtptime.Now(rtptime.JiffiesPerSec), interval)
}
func new(now uint64, interval time.Duration) *Estimator {
return &Estimator{ return &Estimator{
interval: uint64( interval: uint64(
rtptime.FromDuration(interval, rtptime.JiffiesPerSec), rtptime.FromDuration(interval, rtptime.JiffiesPerSec),
), ),
time: rtptime.Now(rtptime.JiffiesPerSec), time: now,
} }
} }
// called locked
func (e *Estimator) swap(now uint64) { func (e *Estimator) swap(now uint64) {
tm := atomic.LoadUint64(&e.time) jiffies := now - e.time
jiffies := now - tm bytes := e.bytes
bytes := atomic.SwapUint32(&e.bytes, 0) e.bytes = 0
packets := atomic.SwapUint32(&e.packets, 0) packets := e.packets
atomic.AddUint32(&e.totalBytes, bytes) e.packets = 0
atomic.AddUint32(&e.totalPackets, packets) e.totalBytes += uint64(bytes)
e.totalPackets += uint64(packets)
var rate, packetRate uint32 var rate, packetRate uint32
if jiffies >= rtptime.JiffiesPerSec/1000 { if jiffies >= rtptime.JiffiesPerSec/1000 {
rate = uint32((uint64(bytes)*rtptime.JiffiesPerSec + jiffies/2) / jiffies) rate = uint32((uint64(bytes)*rtptime.JiffiesPerSec + jiffies/2) / jiffies)
packetRate = uint32((uint64(packets)*rtptime.JiffiesPerSec + jiffies/2) / jiffies) packetRate = uint32((uint64(packets)*rtptime.JiffiesPerSec + jiffies/2) / jiffies)
} }
atomic.StoreUint32(&e.rate, rate) e.rate = rate
atomic.StoreUint32(&e.packetRate, packetRate) e.packetRate = packetRate
atomic.StoreUint64(&e.time, now) e.time = now
} }
// Accumulate records one packet of size bytes // Accumulate records one packet of size bytes
func (e *Estimator) Accumulate(bytes uint32) { func (e *Estimator) Accumulate(bytes uint32) {
atomic.AddUint32(&e.bytes, bytes) e.mu.Lock()
atomic.AddUint32(&e.packets, 1) if e.bytes < ^uint32(0)-bytes {
e.bytes += bytes
}
if e.packets < ^uint32(0)-1 {
e.packets += 1
}
e.mu.Unlock()
} }
// called locked
func (e *Estimator) estimate(now uint64) (uint32, uint32) { func (e *Estimator) estimate(now uint64) (uint32, uint32) {
tm := atomic.LoadUint64(&e.time) if now < e.time {
if now < tm || now-tm > e.interval { // time went backwards
if e.time-now > e.interval {
e.time = now
e.packets = 0
e.bytes = 0
}
} else if now-e.time >= e.interval {
e.swap(now) e.swap(now)
} }
return atomic.LoadUint32(&e.rate), atomic.LoadUint32(&e.packetRate) return e.rate, e.packetRate
} }
// Estimate returns an estimate of the rate over the last interval. // Estimate returns an estimate of the rate over the last interval.
...@@ -68,12 +89,15 @@ func (e *Estimator) estimate(now uint64) (uint32, uint32) { ...@@ -68,12 +89,15 @@ func (e *Estimator) estimate(now uint64) (uint32, uint32) {
// passed to New. It returns the byte rate and the packet rate, in units // passed to New. It returns the byte rate and the packet rate, in units
// per second. // per second.
func (e *Estimator) Estimate() (uint32, uint32) { func (e *Estimator) Estimate() (uint32, uint32) {
e.mu.Lock()
defer e.mu.Unlock()
return e.estimate(rtptime.Now(rtptime.JiffiesPerSec)) return e.estimate(rtptime.Now(rtptime.JiffiesPerSec))
} }
// Totals returns the total number of bytes and packets accumulated. // Totals returns the total number of bytes and packets accumulated.
func (e *Estimator) Totals() (uint32, uint32) { func (e *Estimator) Totals() (uint64, uint64) {
b := atomic.LoadUint32(&e.totalBytes) + atomic.LoadUint32(&e.bytes) e.mu.Lock()
p := atomic.LoadUint32(&e.totalPackets) + atomic.LoadUint32(&e.packets) defer e.mu.Unlock()
return p, b return e.totalPackets + uint64(e.packets),
e.totalBytes + uint64(e.bytes)
} }
...@@ -3,13 +3,15 @@ package estimator ...@@ -3,13 +3,15 @@ package estimator
import ( import (
"testing" "testing"
"time" "time"
"sync"
"sync/atomic"
"github.com/jech/galene/rtptime" "github.com/jech/galene/rtptime"
) )
func TestEstimator(t *testing.T) { func TestEstimator(t *testing.T) {
now := rtptime.Jiffies() now := rtptime.Jiffies()
e := New(time.Second) e := new(now, time.Second)
e.estimate(now) e.estimate(now)
e.Accumulate(42) e.Accumulate(42)
...@@ -44,3 +46,97 @@ func TestEstimator(t *testing.T) { ...@@ -44,3 +46,97 @@ func TestEstimator(t *testing.T) {
} }
} }
func TestEstimatorMany(t *testing.T) {
now := rtptime.Jiffies()
e := new(now, time.Second)
for i := 0; i < 10000; i++ {
e.Accumulate(42)
now += rtptime.JiffiesPerSec / 1000
b, p := e.estimate(now)
if i >= 1000 {
if p != 1000 || b != p*42 {
t.Errorf("Got %v %v (%v), expected %v %v",
p, b, 1000, i, p*42,
)
}
}
}
}
func TestEstimatorParallel(t *testing.T) {
now := make([]uint64, 1)
now[0] = rtptime.Jiffies()
getNow := func() uint64 {
return atomic.LoadUint64(&now[0])
}
addNow := func(v uint64) {
atomic.AddUint64(&now[0], v)
}
e := new(getNow(), time.Second)
estimate := func() (uint32, uint32) {
e.mu.Lock()
defer e.mu.Unlock()
return e.estimate(getNow())
}
f := func(n int) {
for i := 0; i < 10000; i++ {
e.Accumulate(42)
addNow(rtptime.JiffiesPerSec / 1000)
b, p := estimate()
if i >= 1000 {
if b != p * 42 {
t.Errorf("%v: Got %v %v (%v), expected %v %v",
n, p, b, i, 1000, p*42,
)
}
}
}
}
var wg sync.WaitGroup
for i := 0; i < 16; i++ {
wg.Add(1)
go func(i int) {
f(i)
wg.Done()
}(i)
}
wg.Wait()
}
func BenchmarkEstimator(b *testing.B) {
e := New(time.Second)
e.Estimate()
time.Sleep(time.Millisecond)
e.Estimate()
b.ResetTimer()
for i := 0; i < 1000 * b.N; i++ {
e.Accumulate(100)
}
e.Estimate()
}
func BenchmarkEstimatorParallel(b *testing.B) {
e := New(time.Second)
e.Estimate()
time.Sleep(time.Millisecond)
e.Estimate()
b.ResetTimer()
b.RunParallel(func (pb *testing.PB) {
for pb.Next() {
for i := 0; i < 1000; i++ {
e.Accumulate(100)
}
}
})
e.Estimate()
}
...@@ -1070,8 +1070,8 @@ func sendSR(conn *rtpDownConnection) error { ...@@ -1070,8 +1070,8 @@ func sendSR(conn *rtpDownConnection) error {
SSRC: uint32(t.ssrc), SSRC: uint32(t.ssrc),
NTPTime: nowNTP, NTPTime: nowNTP,
RTPTime: nowRTP, RTPTime: nowRTP,
PacketCount: p, PacketCount: uint32(p),
OctetCount: b, OctetCount: uint32(b),
}) })
t.setSRTime(jiffies, nowNTP) t.setSRTime(jiffies, nowNTP)
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment