Commit 0e7bf0b3 authored by Juliusz Chroboczek's avatar Juliusz Chroboczek

Carry group around when pushing connections.

This avoids a race condition if the group changes before the connections
are pushed.
parent b134bfcf
......@@ -87,7 +87,11 @@ func (client *Client) Kick(id, user, message string) error {
return err
}
func (client *Client) PushConn(id string, up conn.Up, tracks []conn.UpTrack, label string) error {
func (client *Client) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, label string) error {
if client.group != g {
return nil
}
client.mu.Lock()
defer client.mu.Unlock()
......
......@@ -97,7 +97,7 @@ type Client interface {
Challengeable
SetPermissions(ClientPermissions)
OverridePermissions(*Group) bool
PushConn(id string, conn conn.Up, tracks []conn.UpTrack, label string) error
PushConn(g *Group, id string, conn conn.Up, tracks []conn.UpTrack, label string) error
PushClient(id, username string, add bool) error
}
......
......@@ -450,7 +450,7 @@ func newUpConn(c group.Client, id string) (*rtpUpConnection, error) {
if complete {
clients := c.Group().GetClients(c)
for _, cc := range clients {
cc.PushConn(up.id, up, tracks, up.label)
cc.PushConn(c.Group(), up.id, up, tracks, up.label)
}
go rtcpUpSender(up)
}
......
......@@ -258,7 +258,7 @@ func delUpConn(c *webClient, id string) bool {
if g != nil {
go func(clients []group.Client) {
for _, c := range clients {
err := c.PushConn(conn.id, nil, nil, "")
err := c.PushConn(g, conn.id, nil, nil, "")
if err != nil {
log.Printf("PushConn: %v", err)
}
......@@ -582,21 +582,16 @@ func (c *webClient) setRequested(requested map[string]uint32) error {
c.requested = requested
go pushConns(c)
go pushConns(c, c.group)
return nil
}
func pushConns(c group.Client) {
group := c.Group()
if group == nil {
log.Printf("Pushing connections to unjoined client")
return
}
clients := group.GetClients(c)
func pushConns(c group.Client, g *group.Group) {
clients := g.GetClients(c)
for _, cc := range clients {
ccc, ok := cc.(*webClient)
if ok {
ccc.action(pushConnsAction{c})
ccc.action(pushConnsAction{g, c})
}
}
}
......@@ -638,8 +633,8 @@ func addDownConnTracks(c *webClient, remote conn.Up, tracks []conn.UpTrack) (*rt
return down, nil
}
func (c *webClient) PushConn(id string, up conn.Up, tracks []conn.UpTrack, label string) error {
err := c.action(pushConnAction{id, up, tracks})
func (c *webClient) PushConn(g *group.Group, id string, up conn.Up, tracks []conn.UpTrack, label string) error {
err := c.action(pushConnAction{g, id, up, tracks})
if err != nil {
return err
}
......@@ -709,6 +704,7 @@ func StartClient(conn *websocket.Conn) error {
}
type pushConnAction struct {
group *group.Group
id string
conn conn.Up
tracks []conn.UpTrack
......@@ -720,7 +716,8 @@ type addLabelAction struct {
}
type pushConnsAction struct {
c group.Client
group *group.Group
client group.Client
}
type connectionFailedAction struct {
......@@ -736,24 +733,10 @@ type kickAction struct {
}
func clientLoop(c *webClient, ws *websocket.Conn) error {
defer func() {
if c.group != nil {
group.DelClient(c)
c.group = nil
}
}()
read := make(chan interface{}, 1)
go clientReader(ws, read, c.done)
defer func() {
c.setRequested(map[string]uint32{})
if c.up != nil {
for id := range c.up {
delUpConn(c, id)
}
}
}()
defer leaveGroup(c)
readTime := time.Now()
......@@ -779,6 +762,10 @@ func clientLoop(c *webClient, ws *websocket.Conn) error {
case a := <-c.actionCh:
switch a := a.(type) {
case pushConnAction:
g := c.group
if g == nil || a.group != g {
return nil
}
if a.conn == nil {
found := delDownConn(c, a.id)
if found {
......@@ -821,6 +808,10 @@ func clientLoop(c *webClient, ws *websocket.Conn) error {
Value: &label,
})
case pushConnsAction:
g := c.group
if g == nil || a.group != g {
return nil
}
for _, u := range c.up {
if !u.complete() {
continue
......@@ -831,8 +822,8 @@ func clientLoop(c *webClient, ws *websocket.Conn) error {
ts[i] = t
}
go func() {
err := a.c.PushConn(
u.id, u, ts, u.label,
err := a.client.PushConn(
g, u.id, u, ts, u.label,
)
if err != nil {
log.Printf(
......@@ -855,6 +846,7 @@ func clientLoop(c *webClient, ws *websocket.Conn) error {
tracks[i] = t.remote
}
go c.PushConn(
c.group,
down.remote.Id(), down.remote,
tracks, down.remote.Label(),
)
......@@ -935,6 +927,24 @@ func failUpConnection(c *webClient, id string, message string) error {
return nil
}
func leaveGroup(c *webClient) {
if c.group == nil {
return
}
c.setRequested(map[string]uint32{})
if c.up != nil {
for id := range c.up {
delUpConn(c, id)
}
}
group.DelClient(c)
c.permissions = group.ClientPermissions{}
c.group = nil
}
func failDownConnection(c *webClient, id string, message string) error {
if id != "" {
err := c.write(clientMessage{
......@@ -1009,8 +1019,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
if c.group == nil || c.group.Name() != m.Group {
return group.ProtocolError("you are not joined")
}
c.group = nil
c.permissions = group.ClientPermissions{}
leaveGroup(c)
perms := c.permissions
return c.write(clientMessage{
Type: "joined",
......@@ -1245,7 +1254,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
disk.Close()
return c.error(err)
}
go pushConns(disk)
go pushConns(disk, c.group)
case "unrecord":
if !c.permissions.Record {
return c.error(group.UserError("not authorised"))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment