refactor: add `:` prefix to ports during config unmarshaling

This commit is contained in:
ThinkChaos 2024-03-30 18:02:27 -04:00
parent 36d443728d
commit 4b37b404bf
No known key found for this signature in database
5 changed files with 32 additions and 24 deletions

View File

@ -170,6 +170,13 @@ func (l *ListenConfig) UnmarshalText(data []byte) error {
*l = strings.Split(addresses, ",") *l = strings.Split(addresses, ",")
// Prefix all ports with :
for i, addr := range *l {
if !strings.ContainsRune(addr, ':') {
(*l)[i] = ":" + addr
}
}
return nil return nil
} }

View File

@ -462,7 +462,7 @@ bootstrapDns:
err := l.UnmarshalText([]byte("55,:56")) err := l.UnmarshalText([]byte("55,:56"))
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
Expect(*l).Should(HaveLen(2)) Expect(*l).Should(HaveLen(2))
Expect(*l).Should(ContainElements("55", ":56")) Expect(*l).Should(ContainElements(":55", ":56"))
}) })
}) })
}) })
@ -958,7 +958,7 @@ bootstrapDns:
}) })
func defaultTestFileConfig(config *Config) { func defaultTestFileConfig(config *Config) {
Expect(config.Ports.DNS).Should(Equal(ListenConfig{"55553", ":55554", "[::1]:55555"})) Expect(config.Ports.DNS).Should(Equal(ListenConfig{":55553", ":55554", "[::1]:55555"}))
Expect(config.Upstreams.Init.Strategy).Should(Equal(InitStrategyFailOnError)) Expect(config.Upstreams.Init.Strategy).Should(Equal(InitStrategyFailOnError))
Expect(config.Upstreams.UserAgent).Should(Equal("testBlocky")) Expect(config.Upstreams.UserAgent).Should(Equal("testBlocky"))
Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3)) Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3))

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
@ -29,20 +30,27 @@ const (
DS = dns.Type(dns.TypeDS) DS = dns.Type(dns.TypeDS)
) )
// GetIntPort returns an port for the current testing // GetIntPort returns a port for the current testing
// process by adding the current ginkgo parallel process to // process by adding the current ginkgo parallel process to
// the base port and returning it as int // the base port and returning it as int.
func GetIntPort(port int) int { func GetIntPort(port int) int {
return port + ginkgo.GinkgoParallelProcess() return port + ginkgo.GinkgoParallelProcess()
} }
// GetStringPort returns an port for the current testing // GetStringPort returns a port for the current testing
// process by adding the current ginkgo parallel process to // process by adding the current ginkgo parallel process to
// the base port and returning it as string // the base port and returning it as string.
func GetStringPort(port int) string { func GetStringPort(port int) string {
return fmt.Sprintf("%d", GetIntPort(port)) return fmt.Sprintf("%d", GetIntPort(port))
} }
// GetHostPort returns a host:port string for the current testing
// process by adding the current ginkgo parallel process to
// the base port and returning it as string.
func GetHostPort(host string, port int) string {
return net.JoinHostPort(host, GetStringPort(port))
}
// TempFile creates temp file with passed data // TempFile creates temp file with passed data
func TempFile(data string) *os.File { func TempFile(data string) *os.File {
f, err := os.CreateTemp("", "prefix") f, err := os.CreateTemp("", "prefix")

View File

@ -69,14 +69,6 @@ func tlsCipherSuites() []uint16 {
return tlsCipherSuites return tlsCipherSuites
} }
func getServerAddress(addr string) string {
if !strings.Contains(addr, ":") {
addr = fmt.Sprintf(":%s", addr)
}
return addr
}
type NewServerFunc func(address string) (*dns.Server, error) type NewServerFunc func(address string) (*dns.Server, error)
func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) { func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
@ -204,7 +196,7 @@ func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error
addServers := func(newServer NewServerFunc, addresses config.ListenConfig) error { addServers := func(newServer NewServerFunc, addresses config.ListenConfig) error {
for _, address := range addresses { for _, address := range addresses {
server, err := newServer(getServerAddress(address)) server, err := newServer(address)
if err != nil { if err != nil {
return err return err
} }
@ -245,7 +237,7 @@ func newTCPListeners(proto string, addresses config.ListenConfig) ([]net.Listene
listeners := make([]net.Listener, 0, len(addresses)) listeners := make([]net.Listener, 0, len(addresses))
for _, address := range addresses { for _, address := range addresses {
listener, err := net.Listen("tcp", getServerAddress(address)) listener, err := net.Listen("tcp", address)
if err != nil { if err != nil {
return nil, fmt.Errorf("start %s listener on %s failed: %w", proto, address, err) return nil, fmt.Errorf("start %s listener on %s failed: %w", proto, address, err)
} }

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/base64" "encoding/base64"
"fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
@ -43,7 +44,7 @@ var (
) )
var _ = BeforeSuite(func() { var _ = BeforeSuite(func() {
baseURL = "http://localhost:" + GetStringPort(httpBasePort) + "/" baseURL = fmt.Sprintf("http://%s/", GetHostPort("localhost", httpBasePort))
queryURL = baseURL + "dns-query" queryURL = baseURL + "dns-query"
var upstreamGoogle, upstreamFritzbox, upstreamClient config.Upstream var upstreamGoogle, upstreamFritzbox, upstreamClient config.Upstream
ctx, cancelFn := context.WithCancel(context.Background()) ctx, cancelFn := context.WithCancel(context.Background())
@ -146,10 +147,10 @@ var _ = BeforeSuite(func() {
}, },
Ports: config.Ports{ Ports: config.Ports{
DNS: config.ListenConfig{GetStringPort(dnsBasePort)}, DNS: config.ListenConfig{GetHostPort("", dnsBasePort)},
TLS: config.ListenConfig{GetStringPort(tlsBasePort)}, TLS: config.ListenConfig{GetHostPort("", tlsBasePort)},
HTTP: config.ListenConfig{GetStringPort(httpBasePort)}, HTTP: config.ListenConfig{GetHostPort("", httpBasePort)},
HTTPS: config.ListenConfig{GetStringPort(httpsBasePort)}, HTTPS: config.ListenConfig{GetHostPort("", httpsBasePort)},
}, },
CertFile: certPem.Path, CertFile: certPem.Path,
KeyFile: keyPem.Path, KeyFile: keyPem.Path,
@ -633,7 +634,7 @@ var _ = Describe("Running DNS server", func() {
}, },
Blocking: config.Blocking{BlockType: "zeroIp"}, Blocking: config.Blocking{BlockType: "zeroIp"},
Ports: config.Ports{ Ports: config.Ports{
DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)}, DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)},
}, },
}) })
@ -677,7 +678,7 @@ var _ = Describe("Running DNS server", func() {
}, },
Blocking: config.Blocking{BlockType: "zeroIp"}, Blocking: config.Blocking{BlockType: "zeroIp"},
Ports: config.Ports{ Ports: config.Ports{
DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)}, DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)},
}, },
}) })
@ -751,7 +752,7 @@ var _ = Describe("Running DNS server", func() {
}) })
func requestServer(request *dns.Msg) *dns.Msg { func requestServer(request *dns.Msg) *dns.Msg {
conn, err := net.Dial("udp", ":"+GetStringPort(dnsBasePort)) conn, err := net.Dial("udp", GetHostPort("", dnsBasePort))
if err != nil { if err != nil {
Log().Fatal("could not connect to server: ", err) Log().Fatal("could not connect to server: ", err)
} }