mirror of https://github.com/0xERR0R/blocky.git
refactor(server): deduplicate `tls.Config` setup
This commit is contained in:
parent
c389a4a0f4
commit
c6e3de4ae0
|
@ -50,7 +50,7 @@ type Server struct {
|
|||
queryResolver resolver.ChainedResolver
|
||||
cfg *config.Config
|
||||
httpMux *chi.Mux
|
||||
cert tls.Certificate
|
||||
tlsCfg *tls.Config
|
||||
}
|
||||
|
||||
func logger() *logrus.Entry {
|
||||
|
@ -98,20 +98,38 @@ func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func newTLSConfig(cfg *config.Config) (*tls.Config, error) {
|
||||
var cert tls.Certificate
|
||||
|
||||
cert, err := retrieveCertificate(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't retrieve cert: %w", err)
|
||||
}
|
||||
|
||||
// #nosec G402 // See TLSVersion.validate
|
||||
res := &tls.Config{
|
||||
MinVersion: uint16(cfg.MinTLSServeVer),
|
||||
CipherSuites: tlsCipherSuites(),
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// NewServer creates new server instance with passed config
|
||||
//
|
||||
//nolint:funlen
|
||||
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
|
||||
var cert tls.Certificate
|
||||
var tlsCfg *tls.Config
|
||||
|
||||
if len(cfg.Ports.HTTPS) > 0 || len(cfg.Ports.TLS) > 0 {
|
||||
cert, err = retrieveCertificate(cfg)
|
||||
tlsCfg, err = newTLSConfig(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't retrieve cert: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
dnsServers, err := createServers(cfg, cert)
|
||||
dnsServers, err := createServers(cfg, tlsCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("server creation failed: %w", err)
|
||||
}
|
||||
|
@ -147,7 +165,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
|
|||
cfg: cfg,
|
||||
httpListeners: httpListeners,
|
||||
httpsListeners: httpsListeners,
|
||||
cert: cert,
|
||||
tlsCfg: tlsCfg,
|
||||
}
|
||||
|
||||
server.printConfiguration()
|
||||
|
@ -165,7 +183,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
|
|||
return server, err
|
||||
}
|
||||
|
||||
func createServers(cfg *config.Config, cert tls.Certificate) ([]*dns.Server, error) {
|
||||
func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error) {
|
||||
var dnsServers []*dns.Server
|
||||
|
||||
var err *multierror.Error
|
||||
|
@ -187,7 +205,7 @@ func createServers(cfg *config.Config, cert tls.Certificate) ([]*dns.Server, err
|
|||
addServers(createUDPServer, cfg.Ports.DNS),
|
||||
addServers(createTCPServer, cfg.Ports.DNS),
|
||||
addServers(func(address string) (*dns.Server, error) {
|
||||
return createTLSServer(cfg, address, cert)
|
||||
return createTLSServer(address, tlsCfg)
|
||||
}, cfg.Ports.TLS))
|
||||
|
||||
return dnsServers, err.ErrorOrNil()
|
||||
|
@ -222,17 +240,12 @@ func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener,
|
|||
return listeners, nil
|
||||
}
|
||||
|
||||
func createTLSServer(cfg *config.Config, address string, cert tls.Certificate) (*dns.Server, error) {
|
||||
func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) {
|
||||
return &dns.Server{
|
||||
Addr: address,
|
||||
Net: "tcp-tls",
|
||||
//nolint:gosec
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: uint16(cfg.MinTLSServeVer),
|
||||
CipherSuites: tlsCipherSuites(),
|
||||
},
|
||||
Handler: dns.NewServeMux(),
|
||||
Addr: address,
|
||||
Net: "tcp-tls",
|
||||
TLSConfig: tlsCfg,
|
||||
Handler: dns.NewServeMux(),
|
||||
NotifyStartedFunc: func() {
|
||||
logger().Infof("TLS server is up and running on address %s", address)
|
||||
},
|
||||
|
@ -508,12 +521,7 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) {
|
|||
ReadTimeout: readTimeout,
|
||||
ReadHeaderTimeout: readHeaderTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
//nolint:gosec
|
||||
TLSConfig: &tls.Config{
|
||||
MinVersion: uint16(s.cfg.MinTLSServeVer),
|
||||
CipherSuites: tlsCipherSuites(),
|
||||
Certificates: []tls.Certificate{s.cert},
|
||||
},
|
||||
TLSConfig: s.tlsCfg,
|
||||
}
|
||||
|
||||
if err := server.ServeTLS(listener, "", ""); err != nil {
|
||||
|
|
|
@ -745,7 +745,7 @@ var _ = Describe("Running DNS server", func() {
|
|||
}
|
||||
sut, err := NewServer(ctx, &cfg)
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(sut.cert.Certificate).ShouldNot(BeNil())
|
||||
Expect(sut.tlsCfg.Certificates).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
Loading…
Reference in New Issue