refactor(server): deduplicate `tls.Config` setup

This commit is contained in:
ThinkChaos 2024-03-29 19:27:26 -04:00
parent c389a4a0f4
commit c6e3de4ae0
No known key found for this signature in database
2 changed files with 33 additions and 25 deletions

View File

@ -50,7 +50,7 @@ type Server struct {
queryResolver resolver.ChainedResolver queryResolver resolver.ChainedResolver
cfg *config.Config cfg *config.Config
httpMux *chi.Mux httpMux *chi.Mux
cert tls.Certificate tlsCfg *tls.Config
} }
func logger() *logrus.Entry { func logger() *logrus.Entry {
@ -98,20 +98,38 @@ func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
return return
} }
func newTLSConfig(cfg *config.Config) (*tls.Config, error) {
var cert tls.Certificate
cert, err := retrieveCertificate(cfg)
if err != nil {
return nil, fmt.Errorf("can't retrieve cert: %w", err)
}
// #nosec G402 // See TLSVersion.validate
res := &tls.Config{
MinVersion: uint16(cfg.MinTLSServeVer),
CipherSuites: tlsCipherSuites(),
Certificates: []tls.Certificate{cert},
}
return res, nil
}
// NewServer creates new server instance with passed config // NewServer creates new server instance with passed config
// //
//nolint:funlen //nolint:funlen
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) { func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
var cert tls.Certificate var tlsCfg *tls.Config
if len(cfg.Ports.HTTPS) > 0 || len(cfg.Ports.TLS) > 0 { if len(cfg.Ports.HTTPS) > 0 || len(cfg.Ports.TLS) > 0 {
cert, err = retrieveCertificate(cfg) tlsCfg, err = newTLSConfig(cfg)
if err != nil { if err != nil {
return nil, fmt.Errorf("can't retrieve cert: %w", err) return nil, err
} }
} }
dnsServers, err := createServers(cfg, cert) dnsServers, err := createServers(cfg, tlsCfg)
if err != nil { if err != nil {
return nil, fmt.Errorf("server creation failed: %w", err) return nil, fmt.Errorf("server creation failed: %w", err)
} }
@ -147,7 +165,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
cfg: cfg, cfg: cfg,
httpListeners: httpListeners, httpListeners: httpListeners,
httpsListeners: httpsListeners, httpsListeners: httpsListeners,
cert: cert, tlsCfg: tlsCfg,
} }
server.printConfiguration() server.printConfiguration()
@ -165,7 +183,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
return server, err return server, err
} }
func createServers(cfg *config.Config, cert tls.Certificate) ([]*dns.Server, error) { func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error) {
var dnsServers []*dns.Server var dnsServers []*dns.Server
var err *multierror.Error var err *multierror.Error
@ -187,7 +205,7 @@ func createServers(cfg *config.Config, cert tls.Certificate) ([]*dns.Server, err
addServers(createUDPServer, cfg.Ports.DNS), addServers(createUDPServer, cfg.Ports.DNS),
addServers(createTCPServer, cfg.Ports.DNS), addServers(createTCPServer, cfg.Ports.DNS),
addServers(func(address string) (*dns.Server, error) { addServers(func(address string) (*dns.Server, error) {
return createTLSServer(cfg, address, cert) return createTLSServer(address, tlsCfg)
}, cfg.Ports.TLS)) }, cfg.Ports.TLS))
return dnsServers, err.ErrorOrNil() return dnsServers, err.ErrorOrNil()
@ -222,17 +240,12 @@ func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener,
return listeners, nil return listeners, nil
} }
func createTLSServer(cfg *config.Config, address string, cert tls.Certificate) (*dns.Server, error) { func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) {
return &dns.Server{ return &dns.Server{
Addr: address, Addr: address,
Net: "tcp-tls", Net: "tcp-tls",
//nolint:gosec TLSConfig: tlsCfg,
TLSConfig: &tls.Config{ Handler: dns.NewServeMux(),
Certificates: []tls.Certificate{cert},
MinVersion: uint16(cfg.MinTLSServeVer),
CipherSuites: tlsCipherSuites(),
},
Handler: dns.NewServeMux(),
NotifyStartedFunc: func() { NotifyStartedFunc: func() {
logger().Infof("TLS server is up and running on address %s", address) logger().Infof("TLS server is up and running on address %s", address)
}, },
@ -508,12 +521,7 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) {
ReadTimeout: readTimeout, ReadTimeout: readTimeout,
ReadHeaderTimeout: readHeaderTimeout, ReadHeaderTimeout: readHeaderTimeout,
WriteTimeout: writeTimeout, WriteTimeout: writeTimeout,
//nolint:gosec TLSConfig: s.tlsCfg,
TLSConfig: &tls.Config{
MinVersion: uint16(s.cfg.MinTLSServeVer),
CipherSuites: tlsCipherSuites(),
Certificates: []tls.Certificate{s.cert},
},
} }
if err := server.ServeTLS(listener, "", ""); err != nil { if err := server.ServeTLS(listener, "", ""); err != nil {

View File

@ -745,7 +745,7 @@ var _ = Describe("Running DNS server", func() {
} }
sut, err := NewServer(ctx, &cfg) sut, err := NewServer(ctx, &cfg)
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
Expect(sut.cert.Certificate).ShouldNot(BeNil()) Expect(sut.tlsCfg.Certificates).ShouldNot(BeEmpty())
}) })
}) })
}) })