fixed generics

This commit is contained in:
Kwitsch 2024-04-17 17:12:00 +00:00
parent 5eb351f23c
commit 0f6c847529
1 changed files with 16 additions and 45 deletions

View File

@ -2,57 +2,28 @@ package util
import (
"math"
"strconv"
"sync/atomic"
"time"
"github.com/miekg/dns"
)
// ttlInput is the input type for TTL values and consists of the following types:
// int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, string, time.Duration
type ttlInput interface {
int | int8 | int16 | int32 | int64 | uint | uint8 | uint32 | uint64 | string | time.Duration
// TTLInput is the input type for TTL values and consists of the following underlying types:
// int, uint, uint32, int64
type TTLInput interface {
~int | ~uint | ~uint32 | ~int64
}
// ToTTL converts the input to a TTL of seconds as uint32.
func ToTTL[T ttlInput](input T) uint32 {
// If the input is of underlying type time.Duration, the value is converted to seconds.
// If the input is negative, the TTL is set to 0.
// If the input is greater than the maximum value of uint32, the TTL is set to math.MaxUint32.
func ToTTL[T TTLInput](input T) uint32 {
// use int64 as the intermediate type
res := int64(0)
res := int64(input)
switch typedInput := any(input).(type) {
case string:
if seconds, err := strconv.Atoi(typedInput); err == nil {
res = int64(seconds)
} else {
if duration, err := time.ParseDuration(typedInput); err == nil {
res = int64(duration.Seconds())
}
}
case time.Duration:
res = int64(typedInput.Seconds())
case int:
res = int64(typedInput)
case int8:
res = int64(typedInput)
case int16:
res = int64(typedInput)
case int32:
res = int64(typedInput)
case int64:
res = typedInput
case uint:
res = int64(typedInput)
case uint8:
res = int64(typedInput)
case uint16:
res = int64(typedInput)
case uint32:
res = int64(typedInput)
case uint64:
res = int64(typedInput)
default:
panic("invalid TTL value input type")
// check if the input is of underlying type time.Duration
if durType, ok := any(input).(interface{ Seconds() float64 }); ok {
res = int64(durType.Seconds())
}
// check if the value is negative or greater than the maximum value of uint32
@ -70,7 +41,7 @@ func ToTTL[T ttlInput](input T) uint32 {
// SetAnswerMinTTL sets the TTL of all answers in the message that are less than the specified minimum TTL to
// the minimum TTL.
func SetAnswerMinTTL[T ttlInput](msg *dns.Msg, min T) {
func SetAnswerMinTTL[T TTLInput](msg *dns.Msg, min T) {
minTTL := ToTTL(min)
for _, answer := range msg.Answer {
if atomic.LoadUint32(&answer.Header().Ttl) < minTTL {
@ -81,7 +52,7 @@ func SetAnswerMinTTL[T ttlInput](msg *dns.Msg, min T) {
// SetAnswerMaxTTL sets the TTL of all answers in the message that are greater than the specified maximum TTL
// to the maximum TTL.
func SetAnswerMaxTTL[T ttlInput](msg *dns.Msg, max T) {
func SetAnswerMaxTTL[T TTLInput](msg *dns.Msg, max T) {
maxTTL := ToTTL(max)
for _, answer := range msg.Answer {
if atomic.LoadUint32(&answer.Header().Ttl) > maxTTL && maxTTL != 0 {
@ -92,7 +63,7 @@ func SetAnswerMaxTTL[T ttlInput](msg *dns.Msg, max T) {
// SetAnswerMinMaxTTL sets the TTL of all answers in the message that are less than the specified minimum TTL
// to the minimum TTL and the TTL of all answers that are greater than the specified maximum TTL to the maximum TTL.
func SetAnswerMinMaxTTL[T ttlInput](msg *dns.Msg, min, max T) {
func SetAnswerMinMaxTTL[T TTLInput, TT TTLInput](msg *dns.Msg, min T, max TT) {
minTTL := ToTTL(min)
maxTTL := ToTTL(max)
@ -124,7 +95,7 @@ func GetAnswerMinTTL(msg *dns.Msg) uint32 {
// AdjustAnswerTTL adjusts the TTL of all answers in the message by the difference between the lowest TTL
// and the answer's TTL plus the specified adjustment.
func AdjustAnswerTTL[T ttlInput](msg *dns.Msg, adjustment T) {
func AdjustAnswerTTL[T TTLInput](msg *dns.Msg, adjustment T) {
minTTL := GetAnswerMinTTL(msg)
adjustmentTTL := ToTTL(adjustment)