refactor: add `:` prefix to ports during config unmarshaling

This commit is contained in:
ThinkChaos 2024-03-30 18:02:27 -04:00
parent 36d443728d
commit 4b37b404bf
No known key found for this signature in database
5 changed files with 32 additions and 24 deletions

View File

@ -170,6 +170,13 @@ func (l *ListenConfig) UnmarshalText(data []byte) error {
*l = strings.Split(addresses, ",")
// Prefix all ports with :
for i, addr := range *l {
if !strings.ContainsRune(addr, ':') {
(*l)[i] = ":" + addr
}
}
return nil
}

View File

@ -462,7 +462,7 @@ bootstrapDns:
err := l.UnmarshalText([]byte("55,:56"))
Expect(err).Should(Succeed())
Expect(*l).Should(HaveLen(2))
Expect(*l).Should(ContainElements("55", ":56"))
Expect(*l).Should(ContainElements(":55", ":56"))
})
})
})
@ -958,7 +958,7 @@ bootstrapDns:
})
func defaultTestFileConfig(config *Config) {
Expect(config.Ports.DNS).Should(Equal(ListenConfig{"55553", ":55554", "[::1]:55555"}))
Expect(config.Ports.DNS).Should(Equal(ListenConfig{":55553", ":55554", "[::1]:55555"}))
Expect(config.Upstreams.Init.Strategy).Should(Equal(InitStrategyFailOnError))
Expect(config.Upstreams.UserAgent).Should(Equal("testBlocky"))
Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3))

View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"os"
@ -29,20 +30,27 @@ const (
DS = dns.Type(dns.TypeDS)
)
// GetIntPort returns an port for the current testing
// GetIntPort returns a port for the current testing
// process by adding the current ginkgo parallel process to
// the base port and returning it as int
// the base port and returning it as int.
func GetIntPort(port int) int {
return port + ginkgo.GinkgoParallelProcess()
}
// GetStringPort returns an port for the current testing
// GetStringPort returns a port for the current testing
// process by adding the current ginkgo parallel process to
// the base port and returning it as string
// the base port and returning it as string.
func GetStringPort(port int) string {
return fmt.Sprintf("%d", GetIntPort(port))
}
// GetHostPort returns a host:port string for the current testing
// process by adding the current ginkgo parallel process to
// the base port and returning it as string.
func GetHostPort(host string, port int) string {
return net.JoinHostPort(host, GetStringPort(port))
}
// TempFile creates temp file with passed data
func TempFile(data string) *os.File {
f, err := os.CreateTemp("", "prefix")

View File

@ -69,14 +69,6 @@ func tlsCipherSuites() []uint16 {
return tlsCipherSuites
}
func getServerAddress(addr string) string {
if !strings.Contains(addr, ":") {
addr = fmt.Sprintf(":%s", addr)
}
return addr
}
type NewServerFunc func(address string) (*dns.Server, error)
func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
@ -204,7 +196,7 @@ func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error
addServers := func(newServer NewServerFunc, addresses config.ListenConfig) error {
for _, address := range addresses {
server, err := newServer(getServerAddress(address))
server, err := newServer(address)
if err != nil {
return err
}
@ -245,7 +237,7 @@ func newTCPListeners(proto string, addresses config.ListenConfig) ([]net.Listene
listeners := make([]net.Listener, 0, len(addresses))
for _, address := range addresses {
listener, err := net.Listen("tcp", getServerAddress(address))
listener, err := net.Listen("tcp", address)
if err != nil {
return nil, fmt.Errorf("start %s listener on %s failed: %w", proto, address, err)
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
@ -43,7 +44,7 @@ var (
)
var _ = BeforeSuite(func() {
baseURL = "http://localhost:" + GetStringPort(httpBasePort) + "/"
baseURL = fmt.Sprintf("http://%s/", GetHostPort("localhost", httpBasePort))
queryURL = baseURL + "dns-query"
var upstreamGoogle, upstreamFritzbox, upstreamClient config.Upstream
ctx, cancelFn := context.WithCancel(context.Background())
@ -146,10 +147,10 @@ var _ = BeforeSuite(func() {
},
Ports: config.Ports{
DNS: config.ListenConfig{GetStringPort(dnsBasePort)},
TLS: config.ListenConfig{GetStringPort(tlsBasePort)},
HTTP: config.ListenConfig{GetStringPort(httpBasePort)},
HTTPS: config.ListenConfig{GetStringPort(httpsBasePort)},
DNS: config.ListenConfig{GetHostPort("", dnsBasePort)},
TLS: config.ListenConfig{GetHostPort("", tlsBasePort)},
HTTP: config.ListenConfig{GetHostPort("", httpBasePort)},
HTTPS: config.ListenConfig{GetHostPort("", httpsBasePort)},
},
CertFile: certPem.Path,
KeyFile: keyPem.Path,
@ -633,7 +634,7 @@ var _ = Describe("Running DNS server", func() {
},
Blocking: config.Blocking{BlockType: "zeroIp"},
Ports: config.Ports{
DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)},
DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)},
},
})
@ -677,7 +678,7 @@ var _ = Describe("Running DNS server", func() {
},
Blocking: config.Blocking{BlockType: "zeroIp"},
Ports: config.Ports{
DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)},
DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)},
},
})
@ -751,7 +752,7 @@ var _ = Describe("Running DNS server", func() {
})
func requestServer(request *dns.Msg) *dns.Msg {
conn, err := net.Dial("udp", ":"+GetStringPort(dnsBasePort))
conn, err := net.Dial("udp", GetHostPort("", dnsBasePort))
if err != nil {
Log().Fatal("could not connect to server: ", err)
}