Commit 0e777472 authored by Kim "BKC" Carlbäcker's avatar Kim "BKC" Carlbäcker Committed by Nick Thomas

Bug-fix for redis connection

parent a806da23
......@@ -105,12 +105,16 @@ SentinelMaster = "mymaster"
Optional fields are as follows:
```
[redis]
ReadTimeout = 1000
DB = 0
ReadTimeout = "1s"
KeepAlivePeriod = "5m"
MaxIdle = 1
MaxActive = 1
```
- `ReadTimeout` is how many milliseconds that a redis read-command can take. Defaults to `1000`
- `DB` is the Database to connect to. Defaults to `0`
- `ReadTimeout` is how long a redis read-command can take. Defaults to `1s`
- `KeepAlivePeriod` is how long the redis connection is to be kept alive without anything flowing through it. Defaults to `5m`
- `MaxIdle` is how many idle connections can be in the redis-pool at once. Defaults to 1
- `MaxActive` is how many connections the pool can keep. Defaults to 1
......
......@@ -17,12 +17,25 @@ func (u *TomlURL) UnmarshalText(text []byte) error {
return err
}
type TomlDuration struct {
time.Duration
}
func (d *TomlDuration) UnmarshalTest(text []byte) error {
temp, err := time.ParseDuration(string(text))
d.Duration = temp
return err
}
type RedisConfig struct {
URL TomlURL
Sentinel []TomlURL
SentinelMaster string
Password string
ReadTimeout *int
DB *int
ReadTimeout *TomlDuration
WriteTimeout *TomlDuration
KeepAlivePeriod *TomlDuration
MaxIdle *int
MaxActive *int
}
......
package redis
import (
"errors"
"fmt"
"log"
"strings"
......@@ -34,7 +33,7 @@ var (
totalMessages = prometheus.NewCounter(
prometheus.CounterOpts{
Name: "gitlab_workhorse_keywather_total_messages",
Help: "How many messages gitlab-workhorse has recieved in total on pubsub.",
Help: "How many messages gitlab-workhorse has received in total on pubsub.",
},
)
)
......@@ -58,13 +57,11 @@ type KeyChan struct {
Chan chan string
}
func processInner(conn redis.Conn) {
redisReconnectTimeout.Reset()
func processInner(conn redis.Conn) error {
defer conn.Close()
psc := redis.PubSubConn{Conn: conn}
if err := psc.Subscribe(keySubChannel); err != nil {
return
return err
}
defer psc.Unsubscribe(keySubChannel)
......@@ -72,18 +69,36 @@ func processInner(conn redis.Conn) {
switch v := psc.Receive().(type) {
case redis.Message:
totalMessages.Inc()
msg := strings.SplitN(string(v.Data), "=", 2)
dataStr := string(v.Data)
msg := strings.SplitN(dataStr, "=", 2)
if len(msg) != 2 {
helper.LogError(nil, errors.New("Redis subscribe error: got an invalid notification"))
helper.LogError(nil, fmt.Errorf("Redis receive error: got an invalid notification: %q", dataStr))
continue
}
key, value := msg[0], msg[1]
notifyChanWatchers(key, value)
case error:
helper.LogError(nil, fmt.Errorf("Redis subscribe error: %s", v))
return
helper.LogError(nil, fmt.Errorf("Redis receive error: %s", v))
// Intermittent error, return nil so that it doesn't wait before reconnect
return nil
}
}
}
func dialPubSub(dialer redisDialerFunc) (redis.Conn, error) {
conn, err := dialer()
if err != nil {
return nil, err
}
// Make sure Redis is actually connected
conn.Do("PING")
if err := conn.Err(); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
// Process redis subscriptions
......@@ -97,13 +112,19 @@ func Process(reconnect bool) {
for loop {
loop = reconnect
log.Println("Connecting to redis")
conn, err := redisDialFunc()
conn, err := dialPubSub(workerDialFunc)
if err != nil {
helper.LogError(nil, fmt.Errorf("Failed to connect to redis: %s", err))
time.Sleep(redisReconnectTimeout.Duration())
continue
}
processInner(conn)
redisReconnectTimeout.Reset()
if err = processInner(conn); err != nil {
helper.LogError(nil, fmt.Errorf("Failed to process redis-queue: %s", err))
continue
}
}
}
......
......@@ -103,7 +103,6 @@ func TestWatchKeyNoChange(t *testing.T) {
processMessages(1, "something")
wg.Wait()
}
func TestWatchKeyTimeout(t *testing.T) {
......
......@@ -3,6 +3,8 @@ package redis
import (
"errors"
"fmt"
"net"
"net/url"
"time"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
......@@ -18,10 +20,26 @@ var (
)
const (
// Max Idle Connections in the pool.
defaultMaxIdle = 1
// Max Active Connections in the pool.
defaultMaxActive = 1
// Timeout for Read operations on the pool. 1 second is technically overkill,
// it's just for sanity.
defaultReadTimeout = 1 * time.Second
// Timeout for Write operations on the pool. 1 second is technically overkill,
// it's just for sanity.
defaultWriteTimeout = 1 * time.Second
// Timeout before killing Idle connections in the pool. 3 minutes seemed good.
// If you _actually_ hit this timeout often, you should consider turning of
// redis-support since it's not necessary at that point...
defaultIdleTimeout = 3 * time.Minute
// KeepAlivePeriod is to keep a TCP connection open for an extended period of
// time without being killed. This is used both in the pool, and in the
// worker-connection.
// See https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive for more
// information.
defaultKeepAlivePeriod = 5 * time.Minute
)
var (
......@@ -65,37 +83,91 @@ func sentinelConn(master string, urls []config.TomlURL) *sentinel.Sentinel {
}
}
var redisDialFunc func() (redis.Conn, error)
var poolDialFunc func() (redis.Conn, error)
var workerDialFunc func() (redis.Conn, error)
func dialOptionsBuilder(cfg *config.RedisConfig) []redis.DialOption {
func timeoutDialOptions(cfg *config.RedisConfig) []redis.DialOption {
readTimeout := defaultReadTimeout
writeTimeout := defaultWriteTimeout
if cfg != nil {
if cfg.ReadTimeout != nil {
readTimeout = time.Millisecond * time.Duration(*cfg.ReadTimeout)
readTimeout = cfg.ReadTimeout.Duration
}
if cfg.WriteTimeout != nil {
writeTimeout = cfg.WriteTimeout.Duration
}
}
return []redis.DialOption{
redis.DialReadTimeout(readTimeout),
redis.DialWriteTimeout(writeTimeout),
}
}
func dialOptionsBuilder(cfg *config.RedisConfig, setTimeouts bool) []redis.DialOption {
var dopts []redis.DialOption
if setTimeouts {
dopts = timeoutDialOptions(cfg)
}
if cfg == nil {
return dopts
}
dopts := []redis.DialOption{redis.DialReadTimeout(readTimeout)}
if cfg.Password != "" {
dopts = append(dopts, redis.DialPassword(cfg.Password))
}
if cfg.DB != nil {
dopts = append(dopts, redis.DialDatabase(*cfg.DB))
}
return dopts
}
// DefaultDialFunc should always used. Only exception is for unit-tests.
func DefaultDialFunc(cfg *config.RedisConfig) func() (redis.Conn, error) {
dopts := dialOptionsBuilder(cfg)
innerDial := func() (redis.Conn, error) {
return redis.Dial(cfg.URL.Scheme, cfg.URL.Host, dopts...)
func keepAliveDialer(timeout time.Duration) func(string, string) (net.Conn, error) {
return func(network, address string) (net.Conn, error) {
addr, err := net.ResolveTCPAddr(network, address)
if err != nil {
return nil, err
}
if sntnl != nil {
innerDial = func() (redis.Conn, error) {
tc, err := net.DialTCP(network, nil, addr)
if err != nil {
return nil, err
}
if err := tc.SetKeepAlive(true); err != nil {
return nil, err
}
if err := tc.SetKeepAlivePeriod(timeout); err != nil {
return nil, err
}
return tc, nil
}
}
type redisDialerFunc func() (redis.Conn, error)
func sentinelDialer(dopts []redis.DialOption, keepAlivePeriod time.Duration) redisDialerFunc {
return func() (redis.Conn, error) {
address, err := sntnl.MasterAddr()
if err != nil {
return nil, err
}
dopts = append(dopts, redis.DialNetDial(keepAliveDialer(keepAlivePeriod)))
return redis.Dial("tcp", address, dopts...)
}
}
func defaultDialer(dopts []redis.DialOption, keepAlivePeriod time.Duration, url url.URL) redisDialerFunc {
return func() (redis.Conn, error) {
if url.Scheme == "unix" {
return redis.Dial(url.Scheme, url.Path, dopts...)
}
dopts = append(dopts, redis.DialNetDial(keepAliveDialer(keepAlivePeriod)))
return redis.Dial(url.Scheme, url.Host, dopts...)
}
}
func countDialer(dialer redisDialerFunc) redisDialerFunc {
return func() (redis.Conn, error) {
c, err := innerDial()
c, err := dialer()
if err == nil {
totalConnections.Inc()
}
......@@ -103,8 +175,21 @@ func DefaultDialFunc(cfg *config.RedisConfig) func() (redis.Conn, error) {
}
}
// DefaultDialFunc should always used. Only exception is for unit-tests.
func DefaultDialFunc(cfg *config.RedisConfig, setReadTimeout bool) func() (redis.Conn, error) {
keepAlivePeriod := defaultKeepAlivePeriod
if cfg.KeepAlivePeriod != nil {
keepAlivePeriod = cfg.KeepAlivePeriod.Duration
}
dopts := dialOptionsBuilder(cfg, setReadTimeout)
if sntnl != nil {
return countDialer(sentinelDialer(dopts, keepAlivePeriod))
}
return countDialer(defaultDialer(dopts, keepAlivePeriod, cfg.URL.URL))
}
// Configure redis-connection
func Configure(cfg *config.RedisConfig, dialFunc func() (redis.Conn, error)) {
func Configure(cfg *config.RedisConfig, dialFunc func(*config.RedisConfig, bool) func() (redis.Conn, error)) {
if cfg == nil {
return
}
......@@ -117,12 +202,13 @@ func Configure(cfg *config.RedisConfig, dialFunc func() (redis.Conn, error)) {
maxActive = *cfg.MaxActive
}
sntnl = sentinelConn(cfg.SentinelMaster, cfg.Sentinel)
redisDialFunc = dialFunc
workerDialFunc = dialFunc(cfg, false)
poolDialFunc = dialFunc(cfg, true)
pool = &redis.Pool{
MaxIdle: maxIdle, // Keep at most X hot connections
MaxActive: maxActive, // Keep at most X live connections, 0 means unlimited
IdleTimeout: defaultIdleTimeout, // X time until an unused connection is closed
Dial: redisDialFunc,
Dial: poolDialFunc,
Wait: true,
}
if sntnl != nil {
......
......@@ -5,6 +5,7 @@ import (
"time"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"github.com/garyburd/redigo/redis"
"github.com/rafaeljusto/redigomock"
......@@ -17,8 +18,10 @@ import (
func setupMockPool() (*redigomock.Conn, func()) {
conn := redigomock.NewConn()
cfg := &config.RedisConfig{URL: config.TomlURL{}}
Configure(cfg, func() (redis.Conn, error) {
Configure(cfg, func(_ *config.RedisConfig, _ bool) func() (redis.Conn, error) {
return func() (redis.Conn, error) {
return conn, nil
}
})
return conn, func() {
pool = nil
......@@ -33,7 +36,7 @@ func TestConfigureNoConfig(t *testing.T) {
func TestConfigureMinimalConfig(t *testing.T) {
cfg := &config.RedisConfig{URL: config.TomlURL{}, Password: ""}
Configure(cfg, DefaultDialFunc(cfg))
Configure(cfg, DefaultDialFunc)
if assert.NotNil(t, pool, "Pool should not be nil") {
assert.Equal(t, 1, pool.MaxIdle)
assert.Equal(t, 1, pool.MaxActive)
......@@ -43,7 +46,8 @@ func TestConfigureMinimalConfig(t *testing.T) {
}
func TestConfigureFullConfig(t *testing.T) {
i, a, r := 4, 10, 3
i, a := 4, 10
r := config.TomlDuration{Duration: 3}
cfg := &config.RedisConfig{
URL: config.TomlURL{},
Password: "",
......@@ -51,7 +55,7 @@ func TestConfigureFullConfig(t *testing.T) {
MaxActive: &a,
ReadTimeout: &r,
}
Configure(cfg, DefaultDialFunc(cfg))
Configure(cfg, DefaultDialFunc)
if assert.NotNil(t, pool, "Pool should not be nil") {
assert.Equal(t, i, pool.MaxIdle)
assert.Equal(t, a, pool.MaxActive)
......@@ -88,3 +92,51 @@ func TestGetStringFail(t *testing.T) {
_, err := GetString("foobar")
assert.Error(t, err, "Expected error when not connected to redis")
}
func TestSentinelConnNoSentinel(t *testing.T) {
s := sentinelConn("", []config.TomlURL{})
assert.Nil(t, s, "Sentinel without urls should return nil")
}
func TestSentinelConnTwoURLs(t *testing.T) {
urls := []string{"tcp://10.0.0.1:12345", "tcp://10.0.0.2:12345"}
var sentinelUrls []config.TomlURL
for _, url := range urls {
parsedURL := helper.URLMustParse(url)
sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL})
}
s := sentinelConn("foobar", sentinelUrls)
assert.Equal(t, len(urls), len(s.Addrs))
for i := range urls {
assert.Equal(t, urls[i], s.Addrs[i])
}
}
func TestDialOptionsBuildersPassword(t *testing.T) {
dopts := dialOptionsBuilder(&config.RedisConfig{Password: "foo"}, false)
assert.Equal(t, 1, len(dopts))
}
func TestDialOptionsBuildersSetTimeouts(t *testing.T) {
dopts := dialOptionsBuilder(nil, true)
assert.Equal(t, 2, len(dopts))
}
func TestDialOptionsBuildersSetTimeoutsConfig(t *testing.T) {
cfg := &config.RedisConfig{
ReadTimeout: &config.TomlDuration{Duration: time.Second * time.Duration(15)},
WriteTimeout: &config.TomlDuration{Duration: time.Second * time.Duration(15)},
}
dopts := dialOptionsBuilder(cfg, true)
assert.Equal(t, 2, len(dopts))
}
func TestDialOptionsBuildersSelectDB(t *testing.T) {
db := 3
dopts := dialOptionsBuilder(&config.RedisConfig{DB: &db}, false)
assert.Equal(t, 1, len(dopts))
}
......@@ -133,7 +133,7 @@ func main() {
cfg.Redis = cfgFromFile.Redis
redis.Configure(cfg.Redis, redis.DefaultDialFunc(cfg.Redis))
redis.Configure(cfg.Redis, redis.DefaultDialFunc)
go redis.Process(true)
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment