mirror of https://github.com/0xERR0R/blocky.git
refactor(server): setup TLS listeners manually to remove `ServeTLS` use
This commit is contained in:
parent
c6e3de4ae0
commit
17b2a94a64
|
@ -50,7 +50,6 @@ type Server struct {
|
|||
queryResolver resolver.ChainedResolver
|
||||
cfg *config.Config
|
||||
httpMux *chi.Mux
|
||||
tlsCfg *tls.Config
|
||||
}
|
||||
|
||||
func logger() *logrus.Entry {
|
||||
|
@ -117,8 +116,6 @@ func newTLSConfig(cfg *config.Config) (*tls.Config, error) {
|
|||
}
|
||||
|
||||
// NewServer creates new server instance with passed config
|
||||
//
|
||||
//nolint:funlen
|
||||
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
|
||||
var tlsCfg *tls.Config
|
||||
|
||||
|
@ -134,7 +131,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
|
|||
return nil, fmt.Errorf("server creation failed: %w", err)
|
||||
}
|
||||
|
||||
httpListeners, httpsListeners, err := createHTTPListeners(cfg)
|
||||
httpListeners, httpsListeners, err := createHTTPListeners(cfg, tlsCfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -165,7 +162,6 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
|
|||
cfg: cfg,
|
||||
httpListeners: httpListeners,
|
||||
httpsListeners: httpsListeners,
|
||||
tlsCfg: tlsCfg,
|
||||
}
|
||||
|
||||
server.printConfiguration()
|
||||
|
@ -211,13 +207,15 @@ func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error
|
|||
return dnsServers, err.ErrorOrNil()
|
||||
}
|
||||
|
||||
func createHTTPListeners(cfg *config.Config) (httpListeners, httpsListeners []net.Listener, err error) {
|
||||
httpListeners, err = newListeners("http", cfg.Ports.HTTP)
|
||||
func createHTTPListeners(
|
||||
cfg *config.Config, tlsCfg *tls.Config,
|
||||
) (httpListeners, httpsListeners []net.Listener, err error) {
|
||||
httpListeners, err = newTCPListeners("http", cfg.Ports.HTTP)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
httpsListeners, err = newListeners("https", cfg.Ports.HTTPS)
|
||||
httpsListeners, err = newTLSListeners("https", cfg.Ports.HTTPS, tlsCfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -225,7 +223,7 @@ func createHTTPListeners(cfg *config.Config) (httpListeners, httpsListeners []ne
|
|||
return httpListeners, httpsListeners, nil
|
||||
}
|
||||
|
||||
func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) {
|
||||
func newTCPListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) {
|
||||
listeners := make([]net.Listener, 0, len(addresses))
|
||||
|
||||
for _, address := range addresses {
|
||||
|
@ -240,6 +238,19 @@ func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener,
|
|||
return listeners, nil
|
||||
}
|
||||
|
||||
func newTLSListeners(proto string, addresses config.ListenConfig, tlsCfg *tls.Config) ([]net.Listener, error) {
|
||||
listeners, err := newTCPListeners(proto, addresses)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, inner := range listeners {
|
||||
listeners[i] = tls.NewListener(inner, tlsCfg)
|
||||
}
|
||||
|
||||
return listeners, nil
|
||||
}
|
||||
|
||||
func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) {
|
||||
return &dns.Server{
|
||||
Addr: address,
|
||||
|
@ -521,10 +532,9 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) {
|
|||
ReadTimeout: readTimeout,
|
||||
ReadHeaderTimeout: readHeaderTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
TLSConfig: s.tlsCfg,
|
||||
}
|
||||
|
||||
if err := server.ServeTLS(listener, "", ""); err != nil {
|
||||
if err := server.Serve(listener); err != nil {
|
||||
errCh <- fmt.Errorf("start https listener failed: %w", err)
|
||||
}
|
||||
}()
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -741,11 +740,12 @@ var _ = Describe("Running DNS server", func() {
|
|||
cfg.KeyFile = ""
|
||||
cfg.CertFile = ""
|
||||
cfg.Ports = config.Ports{
|
||||
HTTPS: []string{fmt.Sprintf(":%d", GetIntPort(httpsBasePort)+100)},
|
||||
HTTPS: []string{":0"},
|
||||
}
|
||||
sut, err := NewServer(ctx, &cfg)
|
||||
|
||||
sut, err := newTLSConfig(&cfg)
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(sut.tlsCfg.Certificates).ShouldNot(BeEmpty())
|
||||
Expect(sut.Certificates).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
Loading…
Reference in New Issue