// Mirror of https://github.com/0xERR0R/blocky.git (Go source, 352 lines, 7.6 KiB).

package redis
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"math"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/0xERR0R/blocky/config"
|
|
"github.com/0xERR0R/blocky/log"
|
|
"github.com/0xERR0R/blocky/model"
|
|
"github.com/0xERR0R/blocky/util"
|
|
"github.com/go-redis/redis/v8"
|
|
"github.com/google/uuid"
|
|
"github.com/miekg/dns"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
const (
	// SyncChannelName is the redis pubsub channel on which cache entries and
	// enable/disable state are exchanged between blocky instances.
	SyncChannelName = "blocky_sync"
	// CacheStorePrefix is prepended to every cache key stored in redis.
	CacheStorePrefix = "blocky:cache:"
)

const (
	// chanCap is the buffer size of the send buffer, CacheChannel and EnabledChannel.
	chanCap = 1000
	// cacheReason is the Reason attached to responses restored from redis.
	cacheReason = "EXTERNAL_CACHE"
	// defaultCacheTime is the redis expiry used when a response's minimum answer TTL is 0.
	defaultCacheTime = time.Second
)

// Values of redisMessage.Type.
const (
	messageTypeCache = iota
	messageTypeEnable
)
|
|
|
|
// sendBuffer message
|
|
type bufferMessage struct {
|
|
Key string
|
|
Message *dns.Msg
|
|
}
|
|
|
|
// redisMessage is the JSON envelope sent over the SyncChannelName pubsub
// channel. Field order is kept as-is because it determines the encoded bytes.
type redisMessage struct {
	Key     string `json:"k,omitempty"` // cache key (cache messages only)
	Type    int    `json:"t"`           // messageTypeCache or messageTypeEnable
	Message []byte `json:"m"`           // packed DNS message or JSON-encoded EnabledMessage
	Client  []byte `json:"c"`           // id of the publishing Client
}
|
|
|
|
// CacheChannel message
|
|
type CacheMessage struct {
|
|
Key string
|
|
Response *model.Response
|
|
}
|
|
|
|
// EnabledMessage carries the blocking enable/disable state exchanged between
// blocky instances: it is sent with PublishEnabled and delivered to
// Client.EnabledChannel. It is JSON-encoded with the short keys below.
type EnabledMessage struct {
	// State is the requested state (presumably true = blocking enabled; confirm with caller).
	State bool `json:"s"`
	// Duration is omitted when zero; presumably how long the state applies (TODO confirm).
	Duration time.Duration `json:"d,omitempty"`
	// Groups is omitted when empty; presumably the affected blocking groups (TODO confirm).
	Groups []string `json:"g,omitempty"`
}
|
|
|
|
// Client for redis communication
|
|
type Client struct {
|
|
config *config.Redis
|
|
client *redis.Client
|
|
l *logrus.Entry
|
|
id []byte
|
|
sendBuffer chan *bufferMessage
|
|
CacheChannel chan *CacheMessage
|
|
EnabledChannel chan *EnabledMessage
|
|
}
|
|
|
|
// New creates a new redis client
|
|
func New(ctx context.Context, cfg *config.Redis) (*Client, error) {
|
|
// disable redis if no address is provided
|
|
if cfg == nil || len(cfg.Address) == 0 {
|
|
return nil, nil //nolint:nilnil
|
|
}
|
|
|
|
var baseClient *redis.Client
|
|
if len(cfg.SentinelAddresses) > 0 {
|
|
baseClient = redis.NewFailoverClient(&redis.FailoverOptions{
|
|
MasterName: cfg.Address,
|
|
SentinelUsername: cfg.Username,
|
|
SentinelPassword: cfg.SentinelPassword,
|
|
SentinelAddrs: cfg.SentinelAddresses,
|
|
Username: cfg.Username,
|
|
Password: cfg.Password,
|
|
DB: cfg.Database,
|
|
MaxRetries: cfg.ConnectionAttempts,
|
|
MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(),
|
|
})
|
|
} else {
|
|
baseClient = redis.NewClient(&redis.Options{
|
|
Addr: cfg.Address,
|
|
Username: cfg.Username,
|
|
Password: cfg.Password,
|
|
DB: cfg.Database,
|
|
MaxRetries: cfg.ConnectionAttempts,
|
|
MaxRetryBackoff: cfg.ConnectionCooldown.ToDuration(),
|
|
})
|
|
}
|
|
|
|
rdb := baseClient.WithContext(ctx)
|
|
|
|
_, err := rdb.Ping(ctx).Result()
|
|
if err == nil {
|
|
var id []byte
|
|
|
|
id, err = uuid.New().MarshalBinary()
|
|
if err == nil {
|
|
// construct client
|
|
res := &Client{
|
|
config: cfg,
|
|
client: rdb,
|
|
l: log.PrefixedLog("redis"),
|
|
id: id,
|
|
sendBuffer: make(chan *bufferMessage, chanCap),
|
|
CacheChannel: make(chan *CacheMessage, chanCap),
|
|
EnabledChannel: make(chan *EnabledMessage, chanCap),
|
|
}
|
|
|
|
// start channel handling go routine
|
|
err = res.startup(ctx)
|
|
|
|
return res, err
|
|
}
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
// PublishCache publish cache to redis async
|
|
func (c *Client) PublishCache(key string, message *dns.Msg) {
|
|
if len(key) > 0 && message != nil {
|
|
c.sendBuffer <- &bufferMessage{
|
|
Key: key,
|
|
Message: message,
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) PublishEnabled(ctx context.Context, state *EnabledMessage) {
|
|
binState, sErr := json.Marshal(state)
|
|
if sErr == nil {
|
|
binMsg, mErr := json.Marshal(redisMessage{
|
|
Type: messageTypeEnable,
|
|
Message: binState,
|
|
Client: c.id,
|
|
})
|
|
|
|
if mErr == nil {
|
|
c.client.Publish(ctx, SyncChannelName, binMsg)
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetRedisCache reads the redis cache and publish it to the channel
|
|
func (c *Client) GetRedisCache(ctx context.Context) {
|
|
c.l.Debug("GetRedisCache")
|
|
|
|
go func() {
|
|
iter := c.client.Scan(ctx, 0, prefixKey("*"), 0).Iterator()
|
|
if err := iter.Err(); err != nil {
|
|
c.l.Error("GetRedisCache ", err)
|
|
|
|
return
|
|
}
|
|
|
|
for iter.Next(ctx) {
|
|
response, err := c.getResponse(ctx, iter.Val())
|
|
if err == nil {
|
|
if response != nil {
|
|
if !util.CtxSend(ctx, c.CacheChannel, response) {
|
|
return
|
|
}
|
|
}
|
|
} else {
|
|
c.l.Error("GetRedisCache ", err)
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
// startup starts a new goroutine for subscription and translation
|
|
func (c *Client) startup(ctx context.Context) error {
|
|
ps := c.client.Subscribe(ctx, SyncChannelName)
|
|
|
|
_, err := ps.Receive(ctx)
|
|
if err == nil {
|
|
go func() {
|
|
for {
|
|
select {
|
|
// received message from subscription
|
|
case msg := <-ps.Channel():
|
|
c.l.Debug("Received message: ", msg)
|
|
|
|
if msg != nil && len(msg.Payload) > 0 {
|
|
// message is not empty
|
|
c.processReceivedMessage(ctx, msg)
|
|
}
|
|
// publish message from buffer
|
|
case s := <-c.sendBuffer:
|
|
c.publishMessageFromBuffer(ctx, s)
|
|
// context is done
|
|
case <-ctx.Done():
|
|
c.client.Close()
|
|
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (c *Client) publishMessageFromBuffer(ctx context.Context, s *bufferMessage) {
|
|
origRes := s.Message
|
|
origRes.Compress = true
|
|
binRes, pErr := origRes.Pack()
|
|
|
|
if pErr == nil {
|
|
binMsg, mErr := json.Marshal(redisMessage{
|
|
Key: s.Key,
|
|
Type: messageTypeCache,
|
|
Message: binRes,
|
|
Client: c.id,
|
|
})
|
|
|
|
if mErr == nil {
|
|
c.client.Publish(ctx, SyncChannelName, binMsg)
|
|
}
|
|
|
|
c.client.Set(ctx,
|
|
prefixKey(s.Key),
|
|
binRes,
|
|
c.getTTL(origRes))
|
|
}
|
|
}
|
|
|
|
func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message) {
|
|
var rm redisMessage
|
|
|
|
if err := json.Unmarshal([]byte(msg.Payload), &rm); err != nil {
|
|
c.l.Error("Processing error: ", err)
|
|
|
|
return
|
|
}
|
|
|
|
// message was sent from a different blocky instance
|
|
if !bytes.Equal(rm.Client, c.id) {
|
|
switch rm.Type {
|
|
case messageTypeCache:
|
|
var cm *CacheMessage
|
|
|
|
cm, err := convertMessage(&rm, 0)
|
|
if err != nil {
|
|
c.l.Error("Processing CacheMessage error: ", err)
|
|
|
|
return
|
|
}
|
|
|
|
util.CtxSend(ctx, c.CacheChannel, cm)
|
|
case messageTypeEnable:
|
|
var msg EnabledMessage
|
|
|
|
if err := json.Unmarshal(rm.Message, &msg); err != nil {
|
|
c.l.Error("Processing EnabledMessage error: ", err)
|
|
|
|
return
|
|
}
|
|
|
|
util.CtxSend(ctx, c.EnabledChannel, &msg)
|
|
default:
|
|
c.l.Warn("Unknown message type: ", rm.Type)
|
|
}
|
|
}
|
|
}
|
|
|
|
// getResponse returns model.Response for a key
|
|
func (c *Client) getResponse(ctx context.Context, key string) (*CacheMessage, error) {
|
|
resp, err := c.client.Get(ctx, key).Result()
|
|
if err == nil {
|
|
var ttl time.Duration
|
|
ttl, err = c.client.TTL(ctx, key).Result()
|
|
|
|
if err == nil {
|
|
var result *CacheMessage
|
|
|
|
result, err = convertMessage(&redisMessage{
|
|
Key: cleanKey(key),
|
|
Message: []byte(resp),
|
|
}, ttl)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("conversion error: %w", err)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
// convertMessage converts redisMessage to CacheMessage
|
|
func convertMessage(message *redisMessage, ttl time.Duration) (*CacheMessage, error) {
|
|
msg := dns.Msg{}
|
|
|
|
err := msg.Unpack(message.Message)
|
|
if err == nil {
|
|
if ttl > 0 {
|
|
for _, a := range msg.Answer {
|
|
a.Header().Ttl = uint32(ttl.Seconds())
|
|
}
|
|
}
|
|
|
|
res := &CacheMessage{
|
|
Key: message.Key,
|
|
Response: &model.Response{
|
|
RType: model.ResponseTypeCACHED,
|
|
Reason: cacheReason,
|
|
Res: &msg,
|
|
},
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
// getTTL of dns message or return defaultCacheTime if 0
|
|
func (c *Client) getTTL(dns *dns.Msg) time.Duration {
|
|
ttl := uint32(math.MaxInt32)
|
|
for _, a := range dns.Answer {
|
|
ttl = min(ttl, a.Header().Ttl)
|
|
}
|
|
|
|
if ttl == 0 {
|
|
return defaultCacheTime
|
|
}
|
|
|
|
return time.Duration(ttl) * time.Second
|
|
}
|
|
|
|
// prefixKey with CacheStorePrefix
|
|
func prefixKey(key string) string {
|
|
return fmt.Sprintf("%s%s", CacheStorePrefix, key)
|
|
}
|
|
|
|
// cleanKey trims CacheStorePrefix prefix
|
|
func cleanKey(key string) string {
|
|
return strings.TrimPrefix(key, CacheStorePrefix)
|
|
}
|