fix: duration checks to take into account values can be negative

Replace `IsZero` with `IsAboveZero` to help us avoid this mistake again.
This commit is contained in:
ThinkChaos 2023-04-19 21:14:29 -04:00
parent cfc3699ab5
commit f22e310501
7 changed files with 29 additions and 18 deletions

View File

@ -20,7 +20,7 @@ type CachingConfig struct {
// IsEnabled implements `config.Configurable`.
func (c *CachingConfig) IsEnabled() bool {
return c.MaxCachingTime > 0
return c.MaxCachingTime.IsAboveZero()
}
// LogConfig implements `config.Configurable`.
@ -42,7 +42,7 @@ func (c *CachingConfig) LogConfig(logger *logrus.Entry) {
func (c *CachingConfig) EnablePrefetch() {
const day = Duration(24 * time.Hour)
if c.MaxCachingTime.IsZero() {
if !c.IsEnabled() {
// make sure resolver gets enabled
c.MaxCachingTime = day
}

View File

@ -298,7 +298,7 @@ func (c *SourceLoadingConfig) LogConfig(logger *logrus.Entry) {
logger.Debugf("maxErrorsPerSource = %d", c.MaxErrorsPerSource)
logger.Debugf("strategy = %s", c.Strategy)
if c.RefreshPeriod > 0 {
if c.RefreshPeriod.IsAboveZero() {
logger.Infof("refresh = every %s", c.RefreshPeriod)
} else {
logger.Debug("refresh = disabled")

View File

@ -785,8 +785,8 @@ func defaultTestFileConfig() {
Expect(config.Filtering.QueryTypes).Should(HaveLen(2))
Expect(config.FqdnOnly.Enable).Should(BeTrue())
Expect(config.Caching.MaxCachingTime.IsZero()).Should(BeTrue())
Expect(config.Caching.MinCachingTime.IsZero()).Should(BeTrue())
Expect(config.Caching.MaxCachingTime).Should(BeZero())
Expect(config.Caching.MinCachingTime).Should(BeZero())
Expect(config.DoHUserAgent).Should(Equal("testBlocky"))
Expect(config.MinTLSServeVer).Should(Equal("1.3"))

View File

@ -14,8 +14,8 @@ func (c Duration) ToDuration() time.Duration {
return time.Duration(c)
}
func (c Duration) IsZero() bool {
return c.ToDuration() == 0
func (c Duration) IsAboveZero() bool {
return c.ToDuration() > 0
}
func (c Duration) Seconds() float64 {

View File

@ -31,14 +31,25 @@ var _ = Describe("Duration", func() {
})
})
Describe("IsZero", func() {
It("should be true for zero", func() {
Expect(d.IsZero()).Should(BeTrue())
Expect(Duration(0).IsZero()).Should(BeTrue())
Describe("IsAboveZero", func() {
It("should be false for zero", func() {
Expect(d.IsAboveZero()).Should(BeFalse())
Expect(Duration(0).IsAboveZero()).Should(BeFalse())
})
It("should be false for non-zero", func() {
Expect(Duration(time.Second).IsZero()).Should(BeFalse())
It("should be false for negative", func() {
Expect(Duration(-1).IsAboveZero()).Should(BeFalse())
})
It("should be true for positive", func() {
Expect(Duration(1).IsAboveZero()).Should(BeTrue())
})
})
Describe("SecondsU32", func() {
It("should return the seconds", func() {
Expect(Duration(time.Minute).SecondsU32()).Should(Equal(uint32(60)))
Expect(Duration(time.Hour).SecondsU32()).Should(Equal(uint32(3600)))
})
})
})

View File

@ -78,7 +78,7 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
cachingCfg := cfg.Caching
cachingCfg.EnablePrefetch()
if cachingCfg.MinCachingTime.IsZero() {
if !cachingCfg.MinCachingTime.IsAboveZero() {
// Set a min time in case the user didn't to avoid prefetching too often
cachingCfg.MinCachingTime = config.Duration(time.Hour)
}
@ -116,7 +116,7 @@ func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) {
ctx := context.Background()
timeout := cfg.UpstreamTimeout
if timeout.IsZero() {
if timeout.IsAboveZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout.ToDuration())

View File

@ -217,7 +217,7 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response,
// put value into cache
r.resultCache.Put(cacheKey, &cacheValue{response.Res, prefetch}, r.adjustTTLs(response.Res.Answer))
} else if response.Res.Rcode == dns.RcodeNameError {
if r.cfg.CacheTimeNegative > 0 {
if r.cfg.CacheTimeNegative.IsAboveZero() {
// put negative cache if result code is NXDOMAIN
r.resultCache.Put(cacheKey, &cacheValue{response.Res, prefetch}, r.cfg.CacheTimeNegative.ToDuration())
}
@ -244,13 +244,13 @@ func (r *CachingResolver) adjustTTLs(answer []dns.RR) (maxTTL time.Duration) {
for _, a := range answer {
// if TTL < mitTTL -> adjust the value, set minTTL
if r.cfg.MinCachingTime > 0 {
if r.cfg.MinCachingTime.IsAboveZero() {
if atomic.LoadUint32(&a.Header().Ttl) < r.cfg.MinCachingTime.SecondsU32() {
atomic.StoreUint32(&a.Header().Ttl, r.cfg.MinCachingTime.SecondsU32())
}
}
if r.cfg.MaxCachingTime > 0 {
if r.cfg.MaxCachingTime.IsAboveZero() {
if atomic.LoadUint32(&a.Header().Ttl) > r.cfg.MaxCachingTime.SecondsU32() {
atomic.StoreUint32(&a.Header().Ttl, r.cfg.MaxCachingTime.SecondsU32())
}