mirror of https://github.com/0xERR0R/blocky.git
refactor: add `:` prefix to ports during config unmarshaling
This commit is contained in:
parent
36d443728d
commit
4b37b404bf
|
@ -170,6 +170,13 @@ func (l *ListenConfig) UnmarshalText(data []byte) error {
|
||||||
|
|
||||||
*l = strings.Split(addresses, ",")
|
*l = strings.Split(addresses, ",")
|
||||||
|
|
||||||
|
// Prefix all ports with :
|
||||||
|
for i, addr := range *l {
|
||||||
|
if !strings.ContainsRune(addr, ':') {
|
||||||
|
(*l)[i] = ":" + addr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -462,7 +462,7 @@ bootstrapDns:
|
||||||
err := l.UnmarshalText([]byte("55,:56"))
|
err := l.UnmarshalText([]byte("55,:56"))
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(*l).Should(HaveLen(2))
|
Expect(*l).Should(HaveLen(2))
|
||||||
Expect(*l).Should(ContainElements("55", ":56"))
|
Expect(*l).Should(ContainElements(":55", ":56"))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -958,7 +958,7 @@ bootstrapDns:
|
||||||
})
|
})
|
||||||
|
|
||||||
func defaultTestFileConfig(config *Config) {
|
func defaultTestFileConfig(config *Config) {
|
||||||
Expect(config.Ports.DNS).Should(Equal(ListenConfig{"55553", ":55554", "[::1]:55555"}))
|
Expect(config.Ports.DNS).Should(Equal(ListenConfig{":55553", ":55554", "[::1]:55555"}))
|
||||||
Expect(config.Upstreams.Init.Strategy).Should(Equal(InitStrategyFailOnError))
|
Expect(config.Upstreams.Init.Strategy).Should(Equal(InitStrategyFailOnError))
|
||||||
Expect(config.Upstreams.UserAgent).Should(Equal("testBlocky"))
|
Expect(config.Upstreams.UserAgent).Should(Equal("testBlocky"))
|
||||||
Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3))
|
Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3))
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
|
@ -29,20 +30,27 @@ const (
|
||||||
DS = dns.Type(dns.TypeDS)
|
DS = dns.Type(dns.TypeDS)
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetIntPort returns an port for the current testing
|
// GetIntPort returns a port for the current testing
|
||||||
// process by adding the current ginkgo parallel process to
|
// process by adding the current ginkgo parallel process to
|
||||||
// the base port and returning it as int
|
// the base port and returning it as int.
|
||||||
func GetIntPort(port int) int {
|
func GetIntPort(port int) int {
|
||||||
return port + ginkgo.GinkgoParallelProcess()
|
return port + ginkgo.GinkgoParallelProcess()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStringPort returns an port for the current testing
|
// GetStringPort returns a port for the current testing
|
||||||
// process by adding the current ginkgo parallel process to
|
// process by adding the current ginkgo parallel process to
|
||||||
// the base port and returning it as string
|
// the base port and returning it as string.
|
||||||
func GetStringPort(port int) string {
|
func GetStringPort(port int) string {
|
||||||
return fmt.Sprintf("%d", GetIntPort(port))
|
return fmt.Sprintf("%d", GetIntPort(port))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetHostPort returns a host:port string for the current testing
|
||||||
|
// process by adding the current ginkgo parallel process to
|
||||||
|
// the base port and returning it as string.
|
||||||
|
func GetHostPort(host string, port int) string {
|
||||||
|
return net.JoinHostPort(host, GetStringPort(port))
|
||||||
|
}
|
||||||
|
|
||||||
// TempFile creates temp file with passed data
|
// TempFile creates temp file with passed data
|
||||||
func TempFile(data string) *os.File {
|
func TempFile(data string) *os.File {
|
||||||
f, err := os.CreateTemp("", "prefix")
|
f, err := os.CreateTemp("", "prefix")
|
||||||
|
|
|
@ -69,14 +69,6 @@ func tlsCipherSuites() []uint16 {
|
||||||
return tlsCipherSuites
|
return tlsCipherSuites
|
||||||
}
|
}
|
||||||
|
|
||||||
func getServerAddress(addr string) string {
|
|
||||||
if !strings.Contains(addr, ":") {
|
|
||||||
addr = fmt.Sprintf(":%s", addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return addr
|
|
||||||
}
|
|
||||||
|
|
||||||
type NewServerFunc func(address string) (*dns.Server, error)
|
type NewServerFunc func(address string) (*dns.Server, error)
|
||||||
|
|
||||||
func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
|
func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
|
||||||
|
@ -204,7 +196,7 @@ func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error
|
||||||
|
|
||||||
addServers := func(newServer NewServerFunc, addresses config.ListenConfig) error {
|
addServers := func(newServer NewServerFunc, addresses config.ListenConfig) error {
|
||||||
for _, address := range addresses {
|
for _, address := range addresses {
|
||||||
server, err := newServer(getServerAddress(address))
|
server, err := newServer(address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -245,7 +237,7 @@ func newTCPListeners(proto string, addresses config.ListenConfig) ([]net.Listene
|
||||||
listeners := make([]net.Listener, 0, len(addresses))
|
listeners := make([]net.Listener, 0, len(addresses))
|
||||||
|
|
||||||
for _, address := range addresses {
|
for _, address := range addresses {
|
||||||
listener, err := net.Listen("tcp", getServerAddress(address))
|
listener, err := net.Listen("tcp", address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("start %s listener on %s failed: %w", proto, address, err)
|
return nil, fmt.Errorf("start %s listener on %s failed: %w", proto, address, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -43,7 +44,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = BeforeSuite(func() {
|
var _ = BeforeSuite(func() {
|
||||||
baseURL = "http://localhost:" + GetStringPort(httpBasePort) + "/"
|
baseURL = fmt.Sprintf("http://%s/", GetHostPort("localhost", httpBasePort))
|
||||||
queryURL = baseURL + "dns-query"
|
queryURL = baseURL + "dns-query"
|
||||||
var upstreamGoogle, upstreamFritzbox, upstreamClient config.Upstream
|
var upstreamGoogle, upstreamFritzbox, upstreamClient config.Upstream
|
||||||
ctx, cancelFn := context.WithCancel(context.Background())
|
ctx, cancelFn := context.WithCancel(context.Background())
|
||||||
|
@ -146,10 +147,10 @@ var _ = BeforeSuite(func() {
|
||||||
},
|
},
|
||||||
|
|
||||||
Ports: config.Ports{
|
Ports: config.Ports{
|
||||||
DNS: config.ListenConfig{GetStringPort(dnsBasePort)},
|
DNS: config.ListenConfig{GetHostPort("", dnsBasePort)},
|
||||||
TLS: config.ListenConfig{GetStringPort(tlsBasePort)},
|
TLS: config.ListenConfig{GetHostPort("", tlsBasePort)},
|
||||||
HTTP: config.ListenConfig{GetStringPort(httpBasePort)},
|
HTTP: config.ListenConfig{GetHostPort("", httpBasePort)},
|
||||||
HTTPS: config.ListenConfig{GetStringPort(httpsBasePort)},
|
HTTPS: config.ListenConfig{GetHostPort("", httpsBasePort)},
|
||||||
},
|
},
|
||||||
CertFile: certPem.Path,
|
CertFile: certPem.Path,
|
||||||
KeyFile: keyPem.Path,
|
KeyFile: keyPem.Path,
|
||||||
|
@ -633,7 +634,7 @@ var _ = Describe("Running DNS server", func() {
|
||||||
},
|
},
|
||||||
Blocking: config.Blocking{BlockType: "zeroIp"},
|
Blocking: config.Blocking{BlockType: "zeroIp"},
|
||||||
Ports: config.Ports{
|
Ports: config.Ports{
|
||||||
DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)},
|
DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -677,7 +678,7 @@ var _ = Describe("Running DNS server", func() {
|
||||||
},
|
},
|
||||||
Blocking: config.Blocking{BlockType: "zeroIp"},
|
Blocking: config.Blocking{BlockType: "zeroIp"},
|
||||||
Ports: config.Ports{
|
Ports: config.Ports{
|
||||||
DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)},
|
DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -751,7 +752,7 @@ var _ = Describe("Running DNS server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
func requestServer(request *dns.Msg) *dns.Msg {
|
func requestServer(request *dns.Msg) *dns.Msg {
|
||||||
conn, err := net.Dial("udp", ":"+GetStringPort(dnsBasePort))
|
conn, err := net.Dial("udp", GetHostPort("", dnsBasePort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Log().Fatal("could not connect to server: ", err)
|
Log().Fatal("could not connect to server: ", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue