blocky/cmd/blocking_test.go

176 lines
5.4 KiB
Go

package cmd
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"github.com/sirupsen/logrus/hooks/test"
"github.com/0xERR0R/blocky/api"
"github.com/0xERR0R/blocky/log"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// Specs for the "blocking" CLI sub-commands (enable, disable, status).
// Each spec runs a command against an httptest server whose handler is
// mockFn. It then inspects either the returned error or the last log
// message captured by loggerHook.
var _ = Describe("Blocking command", func() {
	var (
		// ts is the mock REST API server, started fresh for every spec.
		ts *httptest.Server
		// mockFn is the handler served by ts; containers override it in BeforeEach.
		mockFn func(w http.ResponseWriter, _ *http.Request)
		// loggerHook captures log entries so specs can assert on command output.
		loggerHook *test.Hook
	)

	// JustBeforeEach runs after every BeforeEach, so the server is started
	// with whichever mockFn the innermost container configured.
	JustBeforeEach(func() {
		ts = testHTTPAPIServer(mockFn)
	})

	JustAfterEach(func() {
		ts.Close()
	})

	BeforeEach(func() {
		// Default handler writes nothing, so net/http replies 200 OK with an empty body.
		mockFn = func(w http.ResponseWriter, _ *http.Request) {}
		loggerHook = test.NewGlobal()
		// NOTE(review): a new hook is attached to log.Log() for every spec
		// and never detached; only the newest one is inspected. Consider
		// replacing the logger's hooks if the accumulation becomes a problem.
		log.Log().AddHook(loggerHook)
	})

	AfterEach(func() {
		loggerHook.Reset()
	})

	Describe("enable blocking", func() {
		When("Enable blocking is called via REST", func() {
			It("should enable the blocking status", func() {
				Expect(enableBlocking(newBlockingCommand(), []string{})).Should(Succeed())
				Expect(loggerHook.LastEntry().Message).Should(Equal("OK"))
			})
		})
		When("Wrong url is used", func() {
			It("Should end with error", func() {
				// Override the port set by testHTTPAPIServer so the client dials port 0.
				apiPort = 0
				err := enableBlocking(newBlockingCommand(), []string{})
				Expect(err).Should(HaveOccurred())
				Expect(err.Error()).Should(ContainSubstring("connection refused"))
			})
		})
		When("Server returns internal error", func() {
			BeforeEach(func() {
				mockFn = func(w http.ResponseWriter, _ *http.Request) {
					w.WriteHeader(http.StatusInternalServerError)
				}
			})
			It("Should end with error", func() {
				err := enableBlocking(newBlockingCommand(), []string{})
				Expect(err).Should(HaveOccurred())
				Expect(err.Error()).Should(ContainSubstring("500 Internal Server Error"))
			})
		})
	})

	Describe("disable blocking", func() {
		When("disable blocking is called via REST", func() {
			It("should disable the blocking status", func() {
				Expect(disableBlocking(newBlockingCommand(), []string{})).Should(Succeed())
				Expect(loggerHook.LastEntry().Message).Should(Equal("OK"))
			})
		})
		When("Wrong url is used", func() {
			It("Should end with error", func() {
				// Override the port set by testHTTPAPIServer so the client dials port 0.
				apiPort = 0
				err := disableBlocking(newBlockingCommand(), []string{})
				Expect(err).Should(HaveOccurred())
				Expect(err.Error()).Should(ContainSubstring("connection refused"))
			})
		})
		When("Server returns internal error", func() {
			BeforeEach(func() {
				mockFn = func(w http.ResponseWriter, _ *http.Request) {
					w.WriteHeader(http.StatusInternalServerError)
				}
			})
			It("Should end with error", func() {
				err := disableBlocking(newBlockingCommand(), []string{})
				Expect(err).Should(HaveOccurred())
				Expect(err.Error()).Should(ContainSubstring("500 Internal Server Error"))
			})
		})
	})

	Describe("status blocking", func() {
		When("status blocking is called via REST and blocking is enabled", func() {
			BeforeEach(func() {
				// Respond with a JSON-encoded "enabled" status.
				mockFn = func(w http.ResponseWriter, _ *http.Request) {
					w.Header().Add("Content-Type", "application/json")
					i := 5
					response, err := json.Marshal(api.ApiBlockingStatus{
						Enabled:         true,
						AutoEnableInSec: &i,
					})
					Expect(err).Should(Succeed())
					_, err = w.Write(response)
					Expect(err).Should(Succeed())
				}
			})
			It("should query the blocking status", func() {
				Expect(statusBlocking(newBlockingCommand(), []string{})).Should(Succeed())
				Expect(loggerHook.LastEntry().Message).Should(Equal("blocking enabled"))
			})
		})
		When("status blocking is called via REST and blocking is disabled", func() {
			// autoEnable is assigned by each spec before the command runs; the
			// handler dereferences it at request time, so each spec sees its own value.
			var autoEnable int
			disabledGroups := []string{"abc"}
			BeforeEach(func() {
				mockFn = func(w http.ResponseWriter, _ *http.Request) {
					w.Header().Add("Content-Type", "application/json")
					response, err := json.Marshal(api.ApiBlockingStatus{
						Enabled:         false,
						AutoEnableInSec: &autoEnable,
						DisabledGroups:  &disabledGroups,
					})
					Expect(err).Should(Succeed())
					_, err = w.Write(response)
					Expect(err).Should(Succeed())
				}
			})
			It("should show the blocking status with time", func() {
				autoEnable = 5
				Expect(statusBlocking(newBlockingCommand(), []string{})).Should(Succeed())
				Expect(loggerHook.LastEntry().Message).Should(Equal("blocking disabled for groups: 'abc', for 5 seconds"))
			})
			It("should show the blocking status", func() {
				autoEnable = 0
				Expect(statusBlocking(newBlockingCommand(), []string{})).Should(Succeed())
				Expect(loggerHook.LastEntry().Message).Should(Equal("blocking disabled for groups: abc"))
			})
		})
		When("Wrong url is used", func() {
			It("Should end with error", func() {
				// Override the port set by testHTTPAPIServer so the client dials port 0.
				apiPort = 0
				err := statusBlocking(newBlockingCommand(), []string{})
				Expect(err).Should(HaveOccurred())
				Expect(err.Error()).Should(ContainSubstring("connection refused"))
			})
		})
		When("Server returns internal error", func() {
			BeforeEach(func() {
				mockFn = func(w http.ResponseWriter, _ *http.Request) {
					w.WriteHeader(http.StatusInternalServerError)
				}
			})
			It("Should end with error", func() {
				err := statusBlocking(newBlockingCommand(), []string{})
				Expect(err).Should(HaveOccurred())
				Expect(err.Error()).Should(ContainSubstring("500 Internal Server Error"))
			})
		})
	})
})
// testHTTPAPIServer starts an httptest server that answers every request
// with fn. It points the package-level apiHost/apiPort at that server, so
// the CLI commands under test talk to the mock. The caller must Close the
// returned server.
//
// If the server URL cannot be split into a host and a 16-bit port, it
// closes the server and panics. Ignoring the error would silently aim the
// client at port 0 (or a truncated port) and leak the listener.
func testHTTPAPIServer(fn func(w http.ResponseWriter, _ *http.Request)) *httptest.Server {
	ts := httptest.NewServer(http.HandlerFunc(fn))

	u, err := url.Parse(ts.URL)
	if err != nil {
		ts.Close()
		panic("testHTTPAPIServer: parsing server URL " + ts.URL + ": " + err.Error())
	}

	// bitSize 16 guarantees the value fits apiPort without truncation.
	port, err := strconv.ParseUint(u.Port(), 10, 16)
	if err != nil {
		ts.Close()
		panic("testHTTPAPIServer: parsing port of " + ts.URL + ": " + err.Error())
	}

	apiHost = u.Hostname()
	apiPort = uint16(port)

	return ts
}