Commit 7126394e authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Fix locking in group.go.

Also export some fields as thread-safe methods.
parent 938d231b
...@@ -13,7 +13,6 @@ import ( ...@@ -13,7 +13,6 @@ import (
"sort" "sort"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
...@@ -31,16 +30,47 @@ const ( ...@@ -31,16 +30,47 @@ const (
) )
type group struct { type group struct {
name string name string
dead bool
description *groupDescription
locked uint32
mu sync.Mutex mu sync.Mutex
description *groupDescription
// indicates that the group no longer exists, but it still has clients
dead bool
locked bool
clients map[string]client clients map[string]client
history []chatHistoryEntry history []chatHistoryEntry
} }
func (g *group) Locked() bool {
g.mu.Lock()
defer g.mu.Unlock()
return g.locked
}
func (g *group) SetLocked(locked bool) {
g.mu.Lock()
defer g.mu.Unlock()
g.locked = locked
}
func (g *group) Public() bool {
g.mu.Lock()
defer g.mu.Unlock()
return g.description.Public
}
func (g *group) Redirect() string {
g.mu.Lock()
defer g.mu.Unlock()
return g.description.Redirect
}
func (g *group) AllowRecording() bool {
g.mu.Lock()
defer g.mu.Unlock()
return g.description.AllowRecording
}
var groups struct { var groups struct {
mu sync.Mutex mu sync.Mutex
groups map[string]*group groups map[string]*group
...@@ -94,9 +124,19 @@ func addGroup(name string, desc *groupDescription) (*group, error) { ...@@ -94,9 +124,19 @@ func addGroup(name string, desc *groupDescription) (*group, error) {
clients: make(map[string]client), clients: make(map[string]client),
} }
groups.groups[name] = g groups.groups[name] = g
} else if desc != nil { return g, nil
}
g.mu.Lock()
defer g.mu.Unlock()
if desc != nil {
g.description = desc g.description = desc
} else if g.dead || time.Since(g.description.loadTime) > 5*time.Second { g.dead = false
return g, nil
}
if g.dead || time.Since(g.description.loadTime) > 5*time.Second {
changed, err := descriptionChanged(name, g.description) changed, err := descriptionChanged(name, g.description)
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
...@@ -176,6 +216,9 @@ func addClient(name string, c client) (*group, error) { ...@@ -176,6 +216,9 @@ func addClient(name string, c client) (*group, error) {
return nil, err return nil, err
} }
g.mu.Lock()
defer g.mu.Unlock()
perms, err := getPermission(g.description, c.Credentials()) perms, err := getPermission(g.description, c.Credentials())
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -183,13 +226,10 @@ func addClient(name string, c client) (*group, error) { ...@@ -183,13 +226,10 @@ func addClient(name string, c client) (*group, error) {
c.SetPermissions(perms) c.SetPermissions(perms)
if !perms.Op && atomic.LoadUint32(&g.locked) != 0 { if !perms.Op && g.locked {
return nil, userError("group is locked") return nil, userError("group is locked")
} }
g.mu.Lock()
defer g.mu.Unlock()
if !perms.Op && g.description.MaxClients > 0 { if !perms.Op && g.description.MaxClients > 0 {
if len(g.clients) >= g.description.MaxClients { if len(g.clients) >= g.description.MaxClients {
return nil, userError("too many users") return nil, userError("too many users")
...@@ -423,8 +463,8 @@ type publicGroup struct { ...@@ -423,8 +463,8 @@ type publicGroup struct {
func getPublicGroups() []publicGroup { func getPublicGroups() []publicGroup {
gs := make([]publicGroup, 0) gs := make([]publicGroup, 0)
rangeGroups(func (g *group) bool { rangeGroups(func(g *group) bool {
if g.description.Public { if g.Public() {
gs = append(gs, publicGroup{ gs = append(gs, publicGroup{
Name: g.name, Name: g.name,
ClientCount: len(g.clients), ClientCount: len(g.clients),
......
...@@ -12,7 +12,6 @@ import ( ...@@ -12,7 +12,6 @@ import (
"os" "os"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"sfu/estimator" "sfu/estimator"
...@@ -713,10 +712,10 @@ func startClient(conn *websocket.Conn) (err error) { ...@@ -713,10 +712,10 @@ func startClient(conn *websocket.Conn) (err error) {
} }
return return
} }
if g.description.Redirect != "" { if redirect := g.Redirect(); redirect != "" {
// We normally redirect at the HTTP level, but the group // We normally redirect at the HTTP level, but the group
// description could have been edited in the meantime. // description could have been edited in the meantime.
err = userError("group is now at " + g.description.Redirect) err = userError("group is now at " + redirect)
return return
} }
c.group = g c.group = g
...@@ -940,10 +939,7 @@ func failConnection(c *webClient, id string, message string) error { ...@@ -940,10 +939,7 @@ func failConnection(c *webClient, id string, message string) error {
} }
func setPermissions(g *group, id string, perm string) error { func setPermissions(g *group, id string, perm string) error {
g.mu.Lock() client := g.getClient(id)
defer g.mu.Unlock()
client := g.getClientUnlocked(id)
if client == nil { if client == nil {
return userError("no such user") return userError("no such user")
} }
...@@ -956,7 +952,7 @@ func setPermissions(g *group, id string, perm string) error { ...@@ -956,7 +952,7 @@ func setPermissions(g *group, id string, perm string) error {
switch perm { switch perm {
case "op": case "op":
c.permissions.Op = true c.permissions.Op = true
if g.description.AllowRecording { if g.AllowRecording() {
c.permissions.Record = true c.permissions.Record = true
} }
case "unop": case "unop":
...@@ -1071,11 +1067,7 @@ func handleClientMessage(c *webClient, m clientMessage) error { ...@@ -1071,11 +1067,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
if !c.permissions.Op { if !c.permissions.Op {
return c.error(userError("not authorised")) return c.error(userError("not authorised"))
} }
var locked uint32 c.group.SetLocked(m.Kind == "lock")
if m.Kind == "lock" {
locked = 1
}
atomic.StoreUint32(&c.group.locked, locked)
case "record": case "record":
if !c.permissions.Record { if !c.permissions.Record {
return c.error(userError("not authorised")) return c.error(userError("not authorised"))
......
...@@ -151,8 +151,8 @@ func groupHandler(w http.ResponseWriter, r *http.Request) { ...@@ -151,8 +151,8 @@ func groupHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
if g.description.Redirect != "" { if redirect := g.Redirect(); redirect != "" {
http.Redirect(w, r, g.description.Redirect, http.Redirect(w, r, redirect,
http.StatusPermanentRedirect) http.StatusPermanentRedirect)
return return
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment