// blocky/redis/redis.go
//
// 352 lines
// 7.6 KiB
// Go

package redis
import (
"bytes"
"context"
"encoding/json"
"fmt"
"math"
"strings"
"time"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/go-redis/redis/v8"
"github.com/google/uuid"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
)
const (
	// SyncChannelName is the redis pub/sub channel used to exchange sync
	// messages between blocky instances.
	SyncChannelName = "blocky_sync"
	// CacheStorePrefix is prepended to every cache key stored in redis.
	CacheStorePrefix = "blocky:cache:"

	// chanCap is the buffer size of the send, cache and enabled channels.
	chanCap = 1000
	// cacheReason is the Reason attached to responses restored from redis.
	cacheReason = "EXTERNAL_CACHE"
	// defaultCacheTime is the redis expiry used when no positive TTL is available.
	defaultCacheTime = time.Second
)

// Discriminators for redisMessage.Type.
const (
	messageTypeCache  = iota // payload is a packed DNS message
	messageTypeEnable        // payload is a JSON-encoded EnabledMessage
)
type (
	// bufferMessage is a cache entry queued by PublishCache for the
	// background goroutine to publish and store in redis.
	bufferMessage struct {
		Key     string
		Message *dns.Msg
	}

	// redisMessage is the JSON envelope sent over SyncChannelName.
	redisMessage struct {
		Key     string `json:"k,omitempty"` // cache key; left empty for enable messages
		Type    int    `json:"t"`           // messageTypeCache or messageTypeEnable
		Message []byte `json:"m"`           // packed DNS message or JSON-encoded EnabledMessage
		Client  []byte `json:"c"`           // id of the sending instance
	}

	// CacheMessage is delivered on Client.CacheChannel: a cache key together
	// with the response to store under it.
	CacheMessage struct {
		Key      string
		Response *model.Response
	}

	// EnabledMessage is the payload of messageTypeEnable messages, published
	// via PublishEnabled and delivered on Client.EnabledChannel.
	EnabledMessage struct {
		State    bool          `json:"s"`
		Duration time.Duration `json:"d,omitempty"`
		Groups   []string      `json:"g,omitempty"`
	}
)
// Client exchanges DNS cache entries and enabled-state changes with other
// blocky instances through redis.
//
// Entries received from other instances (and those loaded by GetRedisCache)
// arrive on CacheChannel; enabled-state changes arrive on EnabledChannel.
type Client struct {
	// connection and configuration
	config *config.Redis
	client *redis.Client
	l      *logrus.Entry

	// id identifies this instance so its own pub/sub messages can be ignored.
	id []byte

	// sendBuffer holds cache entries waiting to be published by the
	// background goroutine started in startup.
	sendBuffer chan *bufferMessage

	// CacheChannel delivers cache entries received from redis.
	CacheChannel chan *CacheMessage
	// EnabledChannel delivers enabled-state changes from other instances.
	EnabledChannel chan *EnabledMessage
}
// New creates a Client for the redis server (or sentinel-managed master)
// described by cfg and starts its background sync goroutine.
//
// It returns (nil, nil) when cfg is nil or has no address, which disables
// redis. If the connection check, id generation or subscription fails, the
// underlying connection is closed and (nil, err) is returned. The background
// goroutine closes the connection once ctx is done.
func New(ctx context.Context, cfg *config.Redis) (*Client, error) {
	// disable redis if no address is provided
	if cfg == nil || len(cfg.Address) == 0 {
		return nil, nil //nolint:nilnil
	}

	var baseClient *redis.Client
	if len(cfg.SentinelAddresses) > 0 {
		// With sentinels configured, Address is used as the master name.
		baseClient = redis.NewFailoverClient(&redis.FailoverOptions{
			MasterName: cfg.Address,
			// NOTE(review): sentinel auth reuses cfg.Username while the password
			// has its own setting (cfg.SentinelPassword) — confirm this is intended.
			SentinelUsername: cfg.Username,
			SentinelPassword: cfg.SentinelPassword,
			SentinelAddrs:    cfg.SentinelAddresses,
			Username:         cfg.Username,
			Password:         cfg.Password,
			DB:               cfg.Database,
			MaxRetries:       cfg.ConnectionAttempts,
			MaxRetryBackoff:  cfg.ConnectionCooldown.ToDuration(),
		})
	} else {
		baseClient = redis.NewClient(&redis.Options{
			Addr:            cfg.Address,
			Username:        cfg.Username,
			Password:        cfg.Password,
			DB:              cfg.Database,
			MaxRetries:      cfg.ConnectionAttempts,
			MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(),
		})
	}

	rdb := baseClient.WithContext(ctx)

	// Every failure path below must release the connection pool; nothing
	// else holds a reference to rdb yet.
	if err := rdb.Ping(ctx).Err(); err != nil {
		_ = rdb.Close()

		return nil, err
	}

	id, err := uuid.New().MarshalBinary()
	if err != nil {
		_ = rdb.Close()

		return nil, err
	}

	res := &Client{
		config:         cfg,
		client:         rdb,
		l:              log.PrefixedLog("redis"),
		id:             id,
		sendBuffer:     make(chan *bufferMessage, chanCap),
		CacheChannel:   make(chan *CacheMessage, chanCap),
		EnabledChannel: make(chan *EnabledMessage, chanCap),
	}

	// start channel handling go routine
	if err := res.startup(ctx); err != nil {
		// Do not hand out a client whose sendBuffer would never be drained.
		_ = rdb.Close()

		return nil, err
	}

	return res, nil
}
// PublishCache queues message to be published to redis under key, without
// blocking the caller.
//
// Empty keys and nil messages are ignored. If the send buffer is full (for
// example because the background goroutine stopped after its context was
// cancelled) the message is dropped and a warning is logged, instead of
// blocking the caller forever.
func (c *Client) PublishCache(key string, message *dns.Msg) {
	if len(key) == 0 || message == nil {
		return
	}

	select {
	case c.sendBuffer <- &bufferMessage{Key: key, Message: message}:
	default:
		c.l.Warn("PublishCache send buffer full, dropping message for key: ", key)
	}
}
// PublishEnabled publishes state to the other blocky instances on
// SyncChannelName.
//
// Encoding and publish errors are logged, not returned.
func (c *Client) PublishEnabled(ctx context.Context, state *EnabledMessage) {
	binState, err := json.Marshal(state)
	if err != nil {
		c.l.Error("PublishEnabled marshal state error: ", err)

		return
	}

	binMsg, err := json.Marshal(redisMessage{
		Type:    messageTypeEnable,
		Message: binState,
		Client:  c.id,
	})
	if err != nil {
		c.l.Error("PublishEnabled marshal message error: ", err)

		return
	}

	if err := c.client.Publish(ctx, SyncChannelName, binMsg).Err(); err != nil {
		c.l.Error("PublishEnabled publish error: ", err)
	}
}
// GetRedisCache asynchronously scans all keys under CacheStorePrefix and sends
// each stored entry to CacheChannel.
//
// The scan runs in its own goroutine and stops early when ctx is done.
// Per-key read errors are logged and skipped.
func (c *Client) GetRedisCache(ctx context.Context) {
	c.l.Debug("GetRedisCache")

	go func() {
		iter := c.client.Scan(ctx, 0, prefixKey("*"), 0).Iterator()

		for iter.Next(ctx) {
			response, err := c.getResponse(ctx, iter.Val())
			if err != nil {
				c.l.Error("GetRedisCache ", err)

				continue
			}

			if response != nil && !util.CtxSend(ctx, c.CacheChannel, response) {
				return
			}
		}

		// Checked after the loop so failures of any SCAN page (not only the
		// first) are reported; cancellation is an expected way to stop.
		if err := iter.Err(); err != nil && ctx.Err() == nil {
			c.l.Error("GetRedisCache ", err)
		}
	}()
}
// startup subscribes to SyncChannelName and starts a goroutine that forwards
// received messages to processReceivedMessage and publishes entries queued in
// sendBuffer.
//
// If the subscription cannot be confirmed, it is closed and the error is
// returned without starting the goroutine. The goroutine closes both the
// subscription and the redis client when ctx is done.
func (c *Client) startup(ctx context.Context) error {
	ps := c.client.Subscribe(ctx, SyncChannelName)

	// Wait for the subscription confirmation before starting the loop.
	if _, err := ps.Receive(ctx); err != nil {
		_ = ps.Close()

		return err
	}

	go func() {
		defer func() {
			_ = ps.Close()
			_ = c.client.Close()
		}()

		msgCh := ps.Channel()

		for {
			select {
			// received message from subscription
			case msg, ok := <-msgCh:
				if !ok {
					// A closed channel is always ready; disable this case so the
					// loop does not spin, but keep publishing from the buffer.
					msgCh = nil

					continue
				}

				c.l.Debug("Received message: ", msg)

				if msg != nil && len(msg.Payload) > 0 {
					c.processReceivedMessage(ctx, msg)
				}
			// publish message from buffer
			case s := <-c.sendBuffer:
				c.publishMessageFromBuffer(ctx, s)
			// context is done
			case <-ctx.Done():
				return
			}
		}
	}()

	return nil
}
// publishMessageFromBuffer packs s.Message with name compression enabled,
// publishes it to the other instances on SyncChannelName, and stores it in
// redis under prefixKey(s.Key) with an expiry computed by getTTL.
//
// Errors are logged. The redis SET is still attempted if only the pub/sub
// envelope fails to encode.
func (c *Client) publishMessageFromBuffer(ctx context.Context, s *bufferMessage) {
	// Pack a shallow copy so that enabling compression does not write to the
	// caller's message, which may still be in use by other goroutines.
	msg := *s.Message
	msg.Compress = true

	binRes, err := msg.Pack()
	if err != nil {
		c.l.Error("PublishCache pack error: ", err)

		return
	}

	binMsg, err := json.Marshal(redisMessage{
		Key:     s.Key,
		Type:    messageTypeCache,
		Message: binRes,
		Client:  c.id,
	})
	if err != nil {
		c.l.Error("PublishCache marshal error: ", err)
	} else if err := c.client.Publish(ctx, SyncChannelName, binMsg).Err(); err != nil {
		c.l.Error("PublishCache publish error: ", err)
	}

	if err := c.client.Set(ctx, prefixKey(s.Key), binRes, c.getTTL(&msg)).Err(); err != nil {
		c.l.Error("PublishCache set error: ", err)
	}
}
// processReceivedMessage decodes a pub/sub payload and forwards it to
// CacheChannel or EnabledChannel according to its type.
//
// Messages published by this instance (matching c.id) are ignored. Decode
// errors and unknown message types are logged.
func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message) {
	var rm redisMessage
	if err := json.Unmarshal([]byte(msg.Payload), &rm); err != nil {
		c.l.Error("Processing error: ", err)

		return
	}

	// only handle messages sent from a different blocky instance
	if bytes.Equal(rm.Client, c.id) {
		return
	}

	switch rm.Type {
	case messageTypeCache:
		cm, err := convertMessage(&rm, 0)
		if err != nil {
			c.l.Error("Processing CacheMessage error: ", err)

			return
		}

		util.CtxSend(ctx, c.CacheChannel, cm)
	case messageTypeEnable:
		var em EnabledMessage
		if err := json.Unmarshal(rm.Message, &em); err != nil {
			c.l.Error("Processing EnabledMessage error: ", err)

			return
		}

		util.CtxSend(ctx, c.EnabledChannel, &em)
	default:
		c.l.Warn("Unknown message type: ", rm.Type)
	}
}
// getResponse returns model.Response for a key
func (c *Client) getResponse(ctx context.Context, key string) (*CacheMessage, error) {
resp, err := c.client.Get(ctx, key).Result()
if err == nil {
var ttl time.Duration
ttl, err = c.client.TTL(ctx, key).Result()
if err == nil {
var result *CacheMessage
result, err = convertMessage(&redisMessage{
Key: cleanKey(key),
Message: []byte(resp),
}, ttl)
if err != nil {
return nil, fmt.Errorf("conversion error: %w", err)
}
return result, nil
}
}
return nil, err
}
// convertMessage converts redisMessage to CacheMessage
func convertMessage(message *redisMessage, ttl time.Duration) (*CacheMessage, error) {
msg := dns.Msg{}
err := msg.Unpack(message.Message)
if err == nil {
if ttl > 0 {
for _, a := range msg.Answer {
a.Header().Ttl = uint32(ttl.Seconds())
}
}
res := &CacheMessage{
Key: message.Key,
Response: &model.Response{
RType: model.ResponseTypeCACHED,
Reason: cacheReason,
Res: &msg,
},
}
return res, nil
}
return nil, err
}
// getTTL returns the redis expiry for msg: the smallest TTL among its answer
// records, capped at math.MaxInt32 seconds.
//
// It falls back to defaultCacheTime when msg has no answers or when the
// smallest TTL is 0.
func (c *Client) getTTL(msg *dns.Msg) time.Duration {
	// Without answers there is no TTL to derive. Falling through would yield
	// math.MaxInt32 seconds (~68 years) and keep the key in redis practically
	// forever.
	if len(msg.Answer) == 0 {
		return defaultCacheTime
	}

	ttl := uint32(math.MaxInt32)
	for _, rr := range msg.Answer {
		ttl = min(ttl, rr.Header().Ttl)
	}

	if ttl == 0 {
		return defaultCacheTime
	}

	return time.Duration(ttl) * time.Second
}
// prefixKey with CacheStorePrefix
func prefixKey(key string) string {
return fmt.Sprintf("%s%s", CacheStorePrefix, key)
}
// cleanKey converts a redis key back into a cache key by removing one leading
// CacheStorePrefix. Keys without the prefix are returned unchanged.
func cleanKey(key string) string {
	withoutPrefix, _ := strings.CutPrefix(key, CacheStorePrefix)

	return withoutPrefix
}