mirror of https://github.com/0xERR0R/blocky.git
refactor: add `:` prefix to ports during config unmarshaling
This commit is contained in:
parent
36d443728d
commit
4b37b404bf
|
@ -170,6 +170,13 @@ func (l *ListenConfig) UnmarshalText(data []byte) error {
|
|||
|
||||
*l = strings.Split(addresses, ",")
|
||||
|
||||
// Prefix all ports with :
|
||||
for i, addr := range *l {
|
||||
if !strings.ContainsRune(addr, ':') {
|
||||
(*l)[i] = ":" + addr
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -462,7 +462,7 @@ bootstrapDns:
|
|||
err := l.UnmarshalText([]byte("55,:56"))
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(*l).Should(HaveLen(2))
|
||||
Expect(*l).Should(ContainElements("55", ":56"))
|
||||
Expect(*l).Should(ContainElements(":55", ":56"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
@ -958,7 +958,7 @@ bootstrapDns:
|
|||
})
|
||||
|
||||
func defaultTestFileConfig(config *Config) {
|
||||
Expect(config.Ports.DNS).Should(Equal(ListenConfig{"55553", ":55554", "[::1]:55555"}))
|
||||
Expect(config.Ports.DNS).Should(Equal(ListenConfig{":55553", ":55554", "[::1]:55555"}))
|
||||
Expect(config.Upstreams.Init.Strategy).Should(Equal(InitStrategyFailOnError))
|
||||
Expect(config.Upstreams.UserAgent).Should(Equal("testBlocky"))
|
||||
Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3))
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
|
@ -29,20 +30,27 @@ const (
|
|||
DS = dns.Type(dns.TypeDS)
|
||||
)
|
||||
|
||||
// GetIntPort returns an port for the current testing
|
||||
// GetIntPort returns a port for the current testing
|
||||
// process by adding the current ginkgo parallel process to
|
||||
// the base port and returning it as int
|
||||
// the base port and returning it as int.
|
||||
func GetIntPort(port int) int {
|
||||
return port + ginkgo.GinkgoParallelProcess()
|
||||
}
|
||||
|
||||
// GetStringPort returns an port for the current testing
|
||||
// GetStringPort returns a port for the current testing
|
||||
// process by adding the current ginkgo parallel process to
|
||||
// the base port and returning it as string
|
||||
// the base port and returning it as string.
|
||||
func GetStringPort(port int) string {
|
||||
return fmt.Sprintf("%d", GetIntPort(port))
|
||||
}
|
||||
|
||||
// GetHostPort returns a host:port string for the current testing
|
||||
// process by adding the current ginkgo parallel process to
|
||||
// the base port and returning it as string.
|
||||
func GetHostPort(host string, port int) string {
|
||||
return net.JoinHostPort(host, GetStringPort(port))
|
||||
}
|
||||
|
||||
// TempFile creates temp file with passed data
|
||||
func TempFile(data string) *os.File {
|
||||
f, err := os.CreateTemp("", "prefix")
|
||||
|
|
|
@ -69,14 +69,6 @@ func tlsCipherSuites() []uint16 {
|
|||
return tlsCipherSuites
|
||||
}
|
||||
|
||||
func getServerAddress(addr string) string {
|
||||
if !strings.Contains(addr, ":") {
|
||||
addr = fmt.Sprintf(":%s", addr)
|
||||
}
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
type NewServerFunc func(address string) (*dns.Server, error)
|
||||
|
||||
func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
|
||||
|
@ -204,7 +196,7 @@ func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error
|
|||
|
||||
addServers := func(newServer NewServerFunc, addresses config.ListenConfig) error {
|
||||
for _, address := range addresses {
|
||||
server, err := newServer(getServerAddress(address))
|
||||
server, err := newServer(address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -245,7 +237,7 @@ func newTCPListeners(proto string, addresses config.ListenConfig) ([]net.Listene
|
|||
listeners := make([]net.Listener, 0, len(addresses))
|
||||
|
||||
for _, address := range addresses {
|
||||
listener, err := net.Listen("tcp", getServerAddress(address))
|
||||
listener, err := net.Listen("tcp", address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("start %s listener on %s failed: %w", proto, address, err)
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -43,7 +44,7 @@ var (
|
|||
)
|
||||
|
||||
var _ = BeforeSuite(func() {
|
||||
baseURL = "http://localhost:" + GetStringPort(httpBasePort) + "/"
|
||||
baseURL = fmt.Sprintf("http://%s/", GetHostPort("localhost", httpBasePort))
|
||||
queryURL = baseURL + "dns-query"
|
||||
var upstreamGoogle, upstreamFritzbox, upstreamClient config.Upstream
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
|
@ -146,10 +147,10 @@ var _ = BeforeSuite(func() {
|
|||
},
|
||||
|
||||
Ports: config.Ports{
|
||||
DNS: config.ListenConfig{GetStringPort(dnsBasePort)},
|
||||
TLS: config.ListenConfig{GetStringPort(tlsBasePort)},
|
||||
HTTP: config.ListenConfig{GetStringPort(httpBasePort)},
|
||||
HTTPS: config.ListenConfig{GetStringPort(httpsBasePort)},
|
||||
DNS: config.ListenConfig{GetHostPort("", dnsBasePort)},
|
||||
TLS: config.ListenConfig{GetHostPort("", tlsBasePort)},
|
||||
HTTP: config.ListenConfig{GetHostPort("", httpBasePort)},
|
||||
HTTPS: config.ListenConfig{GetHostPort("", httpsBasePort)},
|
||||
},
|
||||
CertFile: certPem.Path,
|
||||
KeyFile: keyPem.Path,
|
||||
|
@ -633,7 +634,7 @@ var _ = Describe("Running DNS server", func() {
|
|||
},
|
||||
Blocking: config.Blocking{BlockType: "zeroIp"},
|
||||
Ports: config.Ports{
|
||||
DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)},
|
||||
DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)},
|
||||
},
|
||||
})
|
||||
|
||||
|
@ -677,7 +678,7 @@ var _ = Describe("Running DNS server", func() {
|
|||
},
|
||||
Blocking: config.Blocking{BlockType: "zeroIp"},
|
||||
Ports: config.Ports{
|
||||
DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)},
|
||||
DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)},
|
||||
},
|
||||
})
|
||||
|
||||
|
@ -751,7 +752,7 @@ var _ = Describe("Running DNS server", func() {
|
|||
})
|
||||
|
||||
func requestServer(request *dns.Msg) *dns.Msg {
|
||||
conn, err := net.Dial("udp", ":"+GetStringPort(dnsBasePort))
|
||||
conn, err := net.Dial("udp", GetHostPort("", dnsBasePort))
|
||||
if err != nil {
|
||||
Log().Fatal("could not connect to server: ", err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue