diff --git a/checks.go b/checks.go index 3510b6e..cae02f5 100644 --- a/checks.go +++ b/checks.go @@ -35,32 +35,32 @@ func GetHTTP(addr string, timeout time.Duration) (CheckFunc, error) { timeout = DefaultTimeout } - return func(ctx context.Context) (Status, error) { + return func(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) if err != nil { - return StatusUnknown, fmt.Errorf("failed to create http request: %w", err) + return fmt.Errorf("failed to create http request: %w", err) } resp, err := http.DefaultClient.Do(req) if err != nil { - return StatusDown, fmt.Errorf("do request failed: %w", err) + return fmt.Errorf("do request failed: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return StatusDown, fmt.Errorf("err reading response body: %w", err) + return fmt.Errorf("err reading response body: %w", err) } if resp.StatusCode != http.StatusOK { - return StatusDown, fmt.Errorf("got HTTP response code %d, body: %s", resp.StatusCode, string(body)) + return fmt.Errorf("got HTTP response code %d, body: %s", resp.StatusCode, string(body)) } - return StatusOK, nil + return nil }, nil } @@ -83,27 +83,27 @@ func HeadHTTP(addr string, timeout time.Duration) (CheckFunc, error) { timeout = DefaultTimeout } - return func(ctx context.Context) (Status, error) { + return func(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodHead, u.String(), nil) if err != nil { - return StatusUnknown, fmt.Errorf("failed to create http request: %w", err) + return fmt.Errorf("failed to create http request: %w", err) } resp, err := http.DefaultClient.Do(req) if err != nil { - return StatusDown, fmt.Errorf("do request failed: %w", err) + return fmt.Errorf("do request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != 
http.StatusOK { - return StatusDown, fmt.Errorf("got HTTP response code %d", resp.StatusCode) + return fmt.Errorf("got HTTP response code %d", resp.StatusCode) } - return StatusOK, nil + return nil }, nil } @@ -121,7 +121,7 @@ func DialTCP(addr string, timeout time.Duration) (CheckFunc, error) { timeout = DefaultTimeout } - return func(ctx context.Context) (Status, error) { + return func(ctx context.Context) error { deadline := time.Now().Add(timeout) if t, ok := ctx.Deadline(); ok && t.Before(deadline) { deadline = t @@ -129,11 +129,11 @@ func DialTCP(addr string, timeout time.Duration) (CheckFunc, error) { conn, err := net.DialTimeout("tcp", addr, time.Until(deadline)) if err != nil { - return StatusDown, fmt.Errorf("error dialing: %w", err) + return fmt.Errorf("error dialing: %w", err) } defer conn.Close() - return StatusOK, nil + return nil }, nil } diff --git a/checks_test.go b/checks_test.go index a2417df..f113e77 100644 --- a/checks_test.go +++ b/checks_test.go @@ -11,6 +11,7 @@ import ( ) func TestGetHTTPSuccess(t *testing.T) { + t.Parallel() srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK")) })) @@ -22,30 +23,28 @@ func TestGetHTTPSuccess(t *testing.T) { t.Fatal(err) } - status, err := fn(t.Context()) + err = fn(t.Context()) if err != nil { t.Fatal(err) } - - if status != watchdog.StatusOK { - t.Fatalf("incorrect status %s, expected %s", status, watchdog.StatusOK) - } } func TestGetHTTPError(t *testing.T) { + t.Parallel() addr := "https://127.0.0.1:42014" fn, err := watchdog.GetHTTP(addr, time.Second) if err != nil { t.Fatal(err) } - status, err := fn(t.Context()) - if status != watchdog.StatusDown || err == nil { + err = fn(t.Context()) + if err == nil { t.Errorf("incorrect status for unavalable host") } } func TestHeadHTTPSuccess(t *testing.T) { + t.Parallel() srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) @@ -57,30 
+56,28 @@ func TestHeadHTTPSuccess(t *testing.T) { t.Fatal(err) } - status, err := fn(t.Context()) + err = fn(t.Context()) if err != nil { t.Fatal(err) } - - if status != watchdog.StatusOK { - t.Fatalf("incorrect status %s, expected %s", status, watchdog.StatusOK) - } } func TestHeadHTTPError(t *testing.T) { + t.Parallel() addr := "https://127.0.0.1:42014" fn, err := watchdog.HeadHTTP(addr, time.Second) if err != nil { t.Fatal(err) } - status, err := fn(t.Context()) - if status != watchdog.StatusDown || err == nil { + err = fn(t.Context()) + if err == nil { t.Errorf("incorrect status for unavalable host") } } func TestDialTCPSuccess(t *testing.T) { + t.Parallel() addr := "127.0.0.1:42013" lis, err := net.Listen("tcp", addr) if err != nil { @@ -101,18 +98,14 @@ func TestDialTCPSuccess(t *testing.T) { t.Fatal(err) } - status, err := fn(t.Context()) + err = fn(t.Context()) if err != nil { t.Fatal(err) } - - if status != watchdog.StatusOK { - t.Error("incorrect status for available addr") - } - } func TestDialTCPError(t *testing.T) { + t.Parallel() addr := "127.0.0.1:65535" // check for non-existent addr fn, err := watchdog.DialTCP(addr, time.Second) @@ -120,8 +113,8 @@ func TestDialTCPError(t *testing.T) { t.Fatal(err) } - status, err := fn(t.Context()) - if (status != watchdog.StatusDown) || (err == nil) { - t.Errorf("incorrect status %s, expected %s", status, watchdog.StatusDown) + err = fn(t.Context()) + if err == nil { + t.Errorf("incorrect status") } } diff --git a/entity.go b/entity.go index cad809f..b8f3d75 100644 --- a/entity.go +++ b/entity.go @@ -5,30 +5,11 @@ import ( "time" ) -type Status uint8 - -const ( - StatusUnknown Status = iota - StatusOK - StatusDown -) - -func (s Status) String() string { - switch s { - case StatusOK: - return "OK" - case StatusDown: - return "DOWN" - default: - return "UNKNOWN" - } -} - // CheckFunc is a function that does the actual work. 
// This package provides a number of check functions but // any function matching the signature may be provided. // It must obey context dealine and cancellation. -type CheckFunc func(context.Context) (Status, error) +type CheckFunc func(context.Context) error // Check represents a check that must be run by Watchdog. type Check struct { @@ -38,7 +19,6 @@ type Check struct { } type CheckResult struct { - Name string // identifier of check - Status Status // status as retuned by CheckFunc - Error error // error returned by CheckFunc + Name string // identifier of check + Error error // error returned by CheckFunc } diff --git a/go.mod b/go.mod index bf9d428..c02d419 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,5 @@ module code.uint32.ru/tiny/watchdog -go 1.24.4 +go 1.24 require golang.org/x/sync v0.15.0 diff --git a/watchdog.go b/watchdog.go index d3da178..345b6fd 100644 --- a/watchdog.go +++ b/watchdog.go @@ -16,17 +16,16 @@ var ( // Watchdog keeps checks to run either periodically // or on demand. type Watchdog struct { - checks []*wdCheck + checks map[string]*wdCheck mu sync.Mutex - monitoring bool // is monitoring currently in progress - events chan CheckResult // output channel limiter chan struct{} // TODO: use proper limiter here - timeout time.Duration + timeout time.Duration // timeout for checks to complete - running int + monitoring bool // is monitoring currently in progress + running int // number of active checks monitored } type wdCheck struct { @@ -37,29 +36,26 @@ type wdCheck struct { // New creates instance of Watchdog with // provided checks. 
func New(checks ...Check) *Watchdog { - ch := make([]*wdCheck, len(checks)) + w := Watchdog{ + checks: make(map[string]*wdCheck), + } - for i := range checks { - ch[i] = &wdCheck{ - check: checks[i], + for _, c := range checks { + nc := &wdCheck{ + check: c, } + + w.checks[c.Name] = nc } - w := &Watchdog{ - checks: ch, - } - - return w + return &w } func (w *Watchdog) ListChecks() []Check { w.mu.Lock() defer w.mu.Unlock() - out := make([]Check, len(w.checks)) - for i := range w.checks { - out[i] = w.checks[i].check - } + out := w.copyChecks() return out } @@ -81,20 +77,31 @@ func (w *Watchdog) SetTimeout(d time.Duration) { // AddChecks adds checks to the group. // If monitoring is in progress then monitoring it started for the newly added // check as well. -// Check may have duplicate Name fields but note that RemoveChecks removes checks -// by their Name fields. +// Checks may not have duplicate Name fields. A new check with the same +// name overwrites the previous one. func (w *Watchdog) AddChecks(checks ...Check) { w.mu.Lock() defer w.mu.Unlock() - for i := range checks { + if w.checks == nil { + w.checks = make(map[string]*wdCheck) + } + + for _, c := range checks { nc := &wdCheck{ - check: checks[i], + check: c, } - w.checks = append(w.checks, nc) + + old, haveOld := w.checks[c.Name] + + w.checks[c.Name] = nc if w.monitoring { w.startMonitoring(nc) + + if haveOld { + w.stopMonitoring(old) + } } } } @@ -104,19 +111,18 @@ func (w *Watchdog) RemoveChecks(names ...string) { w.mu.Lock() defer w.mu.Unlock() - remaining := make([]*wdCheck, 0, len(w.checks)-len(names)) - for _, c := range w.checks { - if slices.Contains(names, c.check.Name) { - if w.monitoring { - w.stopMonitoring(c) - } + for _, name := range names { + c, ok := w.checks[name] + if !ok { continue } - remaining = append(remaining, c) - } + if w.monitoring { + w.stopMonitoring(c) + } - w.checks = remaining + delete(w.checks, name) + } } // Start starts monitoring. 
@@ -151,12 +157,10 @@ func (w *Watchdog) Start(concurrency int) (<-chan CheckResult, error) { w.events = make(chan CheckResult, concurrency) w.limiter = make(chan struct{}, concurrency) - for i := range w.checks { - w.startMonitoring(w.checks[i]) + for _, c := range w.checks { + w.startMonitoring(c) } - w.monitoring = true - return w.events, nil } @@ -170,8 +174,8 @@ func (w *Watchdog) Stop() error { return ErrNotRunning } - for i := range w.checks { - w.stopMonitoring(w.checks[i]) + for _, c := range w.checks { + w.stopMonitoring(c) } return nil @@ -182,13 +186,14 @@ func (w *Watchdog) Stop() error { // Otherwise at most concurrency checks will be allowed to run simultaneously. func (w *Watchdog) RunImmediately(ctx context.Context, concurrency int) ([]CheckResult, error) { w.mu.Lock() + if len(w.checks) == 0 { w.mu.Unlock() return nil, ErrNotConfigured } cp := w.copyChecks() - w.mu.Unlock() // release + w.mu.Unlock() if concurrency == 0 { concurrency = len(cp) @@ -209,43 +214,26 @@ func (w *Watchdog) RunImmediately(ctx context.Context, concurrency int) ([]Check } func (w *Watchdog) copyChecks() []Check { - cp := make([]Check, len(w.checks)) - for i := range w.checks { - cp[i] = w.checks[i].check + cp := make([]Check, 0, len(w.checks)) + for _, v := range w.checks { + cp = append(cp, v.check) } return cp } func (w *Watchdog) startMonitoring(wdc *wdCheck) { - wdc.stop = make(chan struct{}) c := wdc.check - // this method is called only with - // w.mu locked + if !w.monitoring { + w.monitoring = true + } + w.running++ go func() { - defer func() { - w.mu.Lock() - defer w.mu.Unlock() - - w.running-- - if w.running == 0 { - // last goroutine to exit will also - // close the output chan - close(w.events) - w.monitoring = false - } - }() - - state := CheckResult{ - // if first run return anything - // other that OK, we'll report it - // if first run is OK, then we do not need to report - Status: StatusOK, - } + var curr error = nil ticker := time.Tick(wdc.check.Interval) 
@@ -255,21 +243,22 @@ func (w *Watchdog) startMonitoring(wdc *wdCheck) { ctx, cancel := context.WithTimeout(context.Background(), w.timeout) defer cancel() - status, err := c.Check(ctx) + err := c.Check(ctx) <-w.limiter - s := CheckResult{ - Name: c.Name, - Status: status, - Error: err, + r := CheckResult{ + Name: c.Name, + Error: err, } - if s.Status != state.Status || s.Error != nil { - w.events <- s + if (err != nil && curr == nil) || + (curr != nil && err == nil) { + // status changed, let's report + w.events <- r } - state = s + curr = err select { case <-ticker: @@ -284,50 +273,52 @@ func (w *Watchdog) startMonitoring(wdc *wdCheck) { func (w *Watchdog) stopMonitoring(wdc *wdCheck) { close(wdc.stop) + w.running-- + + if w.running == 0 { + w.monitoring = false + close(w.events) + } } func runChecksConcurrently(ctx context.Context, ch []Check, concurrency int) []CheckResult { - statuses := make([]CheckResult, 0, len(ch)) - m := sync.Mutex{} // for append operations - sema := make(chan struct{}, concurrency) // semaphore to limit concurrency - done := make(chan struct{}, len(ch)) - - count := len(ch) + done := make(chan CheckResult, len(ch)) + wg := new(sync.WaitGroup) + wg.Add(len(ch)) for _, e := range ch { - sema <- struct{}{} // acquire - go func() error { + go func() { + sema <- struct{}{} // acquire defer func() { - <-sema - done <- struct{}{} + <-sema // release + wg.Done() }() // relying on assumption that CheckFunc obeys context // cancellation - status, err := e.Check(ctx) + err := e.Check(ctx) r := CheckResult{ - Name: e.Name, - Status: status, - Error: err, + Name: e.Name, + Error: err, } - m.Lock() - defer m.Unlock() - statuses = append(statuses, r) - - return nil + done <- r }() } - // wait for all to finish - for range done { - count-- - if count == 0 { - close(done) - } + go func() { + wg.Wait() + close(done) + }() + + results := make([]CheckResult, 0, len(ch)) + + // collect results + for r := range done { + results = append(results, r) } - 
return statuses + return results } diff --git a/watchdog_test.go b/watchdog_test.go index 297d9f7..ba24417 100644 --- a/watchdog_test.go +++ b/watchdog_test.go @@ -3,6 +3,8 @@ package watchdog_test import ( "context" "errors" + "net/http" + "net/http/httptest" "reflect" "testing" "time" @@ -12,20 +14,19 @@ import ( type mockChecker struct { name string - status watchdog.Status err error called bool } -func (m *mockChecker) Func(ctx context.Context) (watchdog.Status, error) { +func (m *mockChecker) Func(ctx context.Context) error { m.called = true time.Sleep(time.Millisecond * 10) if err := ctx.Err(); err != nil { - return watchdog.StatusUnknown, err + return err } - return m.status, m.err + return m.err } func (m *mockChecker) HasBeenCalled() bool { @@ -40,8 +41,8 @@ func (m *mockChecker) Check() watchdog.Check { } } -func newMockChecker(name string, s watchdog.Status, err error) *mockChecker { - return &mockChecker{name: name, status: s, err: err} +func newMockChecker(name string, err error) *mockChecker { + return &mockChecker{name: name, err: err} } func TestCreateWith_new(t *testing.T) { @@ -59,8 +60,8 @@ func TestNew(t *testing.T) { t.Errorf("expected len = 0") } - m1 := newMockChecker("mock", watchdog.StatusOK, nil) - m2 := newMockChecker("mock2", watchdog.StatusOK, nil) + m1 := newMockChecker("mock", nil) + m2 := newMockChecker("mock2", nil) w = watchdog.New() w.AddChecks(m1.Check(), m2.Check()) @@ -96,8 +97,8 @@ func TestRunImmediately(t *testing.T) { t.Errorf("expected zero len slice for empty instance, got %d", len(out)) } - m1 := newMockChecker("mock", watchdog.StatusOK, nil) - m2 := newMockChecker("mock2", watchdog.StatusOK, nil) + m1 := newMockChecker("mock", nil) + m2 := newMockChecker("mock2", nil) w = watchdog.New(m1.Check(), m2.Check()) out, _ = w.RunImmediately(t.Context(), 0) @@ -119,8 +120,8 @@ func TestStartStop(t *testing.T) { t.Error("Start doen't error on empty checks slice") } - m1 := newMockChecker("mock", watchdog.StatusOK, nil) - m2 := 
newMockChecker("mock2", watchdog.StatusOK, nil) + m1 := newMockChecker("mock", nil) + m2 := newMockChecker("mock2", nil) w.AddChecks(m1.Check(), m2.Check()) @@ -167,7 +168,7 @@ func TestStartStop(t *testing.T) { func TestSetTimeout(t *testing.T) { w := new(watchdog.Watchdog) w.SetTimeout(time.Millisecond) - m1 := newMockChecker("mock", watchdog.StatusOK, nil) + m1 := newMockChecker("mock", nil) w.AddChecks(m1.Check()) out, _ := w.Start(0) @@ -175,8 +176,67 @@ func TestSetTimeout(t *testing.T) { w.Stop() res := <-out - if !(res.Status == watchdog.StatusUnknown) || !errors.Is(res.Error, context.DeadlineExceeded) { - t.Logf("got status: %s, err: %v", res.Status, res.Error) + if !errors.Is(res.Error, context.DeadlineExceeded) { + t.Logf("got err: %v", res.Error) t.Fatal("incorrect status for timed out op") } } + +func TestDuplicates(t *testing.T) { + gethandler := func() http.HandlerFunc { + count := 0 + return func(w http.ResponseWriter, r *http.Request) { + count++ + if count >= 10 { + // we're up again + w.Write([]byte("OK")) + return + } + + http.Error(w, "down", http.StatusInternalServerError) + } + } + + srv := httptest.NewServer(gethandler()) + + addr := srv.URL + + fn, err := watchdog.GetHTTP(addr, time.Second) + if err != nil { + t.Fatal(err) + } + + w := new(watchdog.Watchdog) + + check := watchdog.Check{ + Name: "test", + Interval: time.Millisecond * 10, + Check: fn, + } + + w.AddChecks(check) + w.SetTimeout(time.Millisecond * 10) + + out, err := w.Start(1000) + if err != nil { + t.Error("Start returns error", err) + } + + go func() { + time.Sleep(time.Second) + w.Stop() + }() + + count := 0 + + for range out { + count++ + } + + if count != 2 { + // must only return one event on initial failure and + // one event when endpoint becomes available + t.Error("incorrect result count received from chan") + t.Log("received count is", count) + } +}