From 1d381420bff04819c6468cda686e8b525c064c1b Mon Sep 17 00:00:00 2001 From: Dmitry Fedotov Date: Tue, 23 Sep 2025 08:59:10 +0300 Subject: [PATCH] simplify API --- checks.go | 28 ++++++------- checks_test.go | 33 +++++---------- entity.go | 26 ++---------- watchdog.go | 102 ++++++++++++++--------------------------------- watchdog_test.go | 29 +++++++------- 5 files changed, 72 insertions(+), 146 deletions(-) diff --git a/checks.go b/checks.go index 3510b6e..cae02f5 100644 --- a/checks.go +++ b/checks.go @@ -35,32 +35,32 @@ func GetHTTP(addr string, timeout time.Duration) (CheckFunc, error) { timeout = DefaultTimeout } - return func(ctx context.Context) (Status, error) { + return func(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) if err != nil { - return StatusUnknown, fmt.Errorf("failed to create http request: %w", err) + return fmt.Errorf("failed to create http request: %w", err) } resp, err := http.DefaultClient.Do(req) if err != nil { - return StatusDown, fmt.Errorf("do request failed: %w", err) + return fmt.Errorf("do request failed: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return StatusDown, fmt.Errorf("err reading response body: %w", err) + return fmt.Errorf("err reading response body: %w", err) } if resp.StatusCode != http.StatusOK { - return StatusDown, fmt.Errorf("got HTTP response code %d, body: %s", resp.StatusCode, string(body)) + return fmt.Errorf("got HTTP response code %d, body: %s", resp.StatusCode, string(body)) } - return StatusOK, nil + return nil }, nil } @@ -83,27 +83,27 @@ func HeadHTTP(addr string, timeout time.Duration) (CheckFunc, error) { timeout = DefaultTimeout } - return func(ctx context.Context) (Status, error) { + return func(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodHead, u.String(), nil) if err != nil { - return StatusUnknown, fmt.Errorf("failed to create http request: %w", err) + return fmt.Errorf("failed to create http request: %w", err) } resp, err := http.DefaultClient.Do(req) if err != nil { - return StatusDown, fmt.Errorf("do request failed: %w", err) + return fmt.Errorf("do request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return StatusDown, fmt.Errorf("got HTTP response code %d", resp.StatusCode) + return fmt.Errorf("got HTTP response code %d", resp.StatusCode) } - return StatusOK, nil + return nil }, nil } @@ -121,7 +121,7 @@ func DialTCP(addr string, timeout time.Duration) (CheckFunc, error) { timeout = DefaultTimeout } - return func(ctx context.Context) (Status, error) { + return func(ctx context.Context) error { deadline := time.Now().Add(timeout) if t, ok := ctx.Deadline(); ok && t.Before(deadline) { deadline = t @@ -129,11 +129,11 @@ func DialTCP(addr string, timeout time.Duration) (CheckFunc, error) { conn, err := net.DialTimeout("tcp", addr, time.Until(deadline)) if err != nil { - return StatusDown, fmt.Errorf("error dialing: %w", err) + return fmt.Errorf("error dialing: %w", err) } defer conn.Close() - return StatusOK, nil + return nil }, nil } diff --git a/checks_test.go b/checks_test.go index a2417df..2c31c4c 100644 --- a/checks_test.go +++ b/checks_test.go @@ -22,14 +22,10 @@ func TestGetHTTPSuccess(t *testing.T) { t.Fatal(err) } - status, err := fn(t.Context()) + err = fn(t.Context()) if err != nil { t.Fatal(err) } - - if status != watchdog.StatusOK { - t.Fatalf("incorrect status %s, expected %s", status, watchdog.StatusOK) - } } func TestGetHTTPError(t *testing.T) { @@ -39,8 +35,8 @@ func TestGetHTTPError(t *testing.T) { t.Fatal(err) } - status, err := fn(t.Context()) - if status != watchdog.StatusDown || err == nil { + err = fn(t.Context()) + if err == nil { t.Errorf("incorrect status for unavalable host") } } @@ -57,14 +53,10 @@ func TestHeadHTTPSuccess(t *testing.T) { t.Fatal(err) } - status, err := fn(t.Context()) + err = fn(t.Context()) if err != nil { t.Fatal(err) } - - if status != watchdog.StatusOK { - t.Fatalf("incorrect status %s, expected %s", status, watchdog.StatusOK) - } } func TestHeadHTTPError(t *testing.T) { @@ -74,8 +66,8 @@ func TestHeadHTTPError(t *testing.T) { t.Fatal(err) } - status, err := fn(t.Context()) - if status != watchdog.StatusDown || err == nil { + err = fn(t.Context()) + if err == nil { t.Errorf("incorrect status for unavalable host") } } @@ -101,15 +93,10 @@ func TestDialTCPSuccess(t *testing.T) { t.Fatal(err) } - status, err := fn(t.Context()) + err = fn(t.Context()) if err != nil { t.Fatal(err) } - - if status != watchdog.StatusOK { - t.Error("incorrect status for available addr") - } - } func TestDialTCPError(t *testing.T) { @@ -120,8 +107,8 @@ func TestDialTCPError(t *testing.T) { t.Fatal(err) } - status, err := fn(t.Context()) - if (status != watchdog.StatusDown) || (err == nil) { - t.Errorf("incorrect status %s, expected %s", status, watchdog.StatusDown) + err = fn(t.Context()) + if err == nil { + t.Errorf("incorrect status") } } diff --git a/entity.go b/entity.go index cad809f..b8f3d75 100644 --- a/entity.go +++ b/entity.go @@ -5,30 +5,11 @@ import ( "time" ) -type Status uint8 - -const ( - StatusUnknown Status = iota - StatusOK - StatusDown -) - -func (s Status) String() string { - switch s { - case StatusOK: - return "OK" - case StatusDown: - return "DOWN" - default: - return "UNKNOWN" - } -} - // CheckFunc is a function that does the actual work. // This package provides a number of check functions but // any function matching the signature may be provided. // It must obey context dealine and cancellation. -type CheckFunc func(context.Context) (Status, error) +type CheckFunc func(context.Context) error // Check represents a check that must be run by Watchdog. type Check struct { @@ -38,7 +19,6 @@ type Check struct { } type CheckResult struct { - Name string // identifier of check - Status Status // status as retuned by CheckFunc - Error error // error returned by CheckFunc + Name string // identifier of check + Error error // error returned by CheckFunc } diff --git a/watchdog.go b/watchdog.go index ed26ad7..9273dd5 100644 --- a/watchdog.go +++ b/watchdog.go @@ -16,7 +16,7 @@ var ( // Watchdog keeps checks to run either periodically // or on demand. type Watchdog struct { - checks checksMap + checks map[string]*wdCheck mu sync.Mutex events chan CheckResult // output channel @@ -28,47 +28,6 @@ type Watchdog struct { running int // number of active checks monitored } -type checksMap struct { - m map[string]*wdCheck -} - -func (c *checksMap) build() { - if c.m == nil { - c.m = make(map[string]*wdCheck) - } -} - -func (c *checksMap) Map() map[string]*wdCheck { - c.build() - - return c.m -} - -func (c *checksMap) Set(key string, v *wdCheck) { - c.build() - - c.m[key] = v -} - -func (c *checksMap) Lookup(key string) (*wdCheck, bool) { - c.build() - - v, ok := c.m[key] - return v, ok -} - -func (c *checksMap) Delete(key string) { - c.build() - - delete(c.m, key) -} - -func (c *checksMap) Len() int { - c.build() - - return len(c.m) -} - type wdCheck struct { check Check stop chan struct{} @@ -77,13 +36,16 @@ type wdCheck struct { // New creates instance of Watchdog with // provided checks. func New(checks ...Check) *Watchdog { - w := Watchdog{} + w := Watchdog{ + checks: make(map[string]*wdCheck), + } + for _, c := range checks { nc := &wdCheck{ check: c, } - w.checks.Set(c.Name, nc) + w.checks[c.Name] = nc } return &w @@ -121,14 +83,18 @@ func (w *Watchdog) AddChecks(checks ...Check) { w.mu.Lock() defer w.mu.Unlock() + if w.checks == nil { + w.checks = make(map[string]*wdCheck) + } + for _, c := range checks { nc := &wdCheck{ check: c, } - old, haveOld := w.checks.Lookup(c.Name) + old, haveOld := w.checks[c.Name] - w.checks.Set(c.Name, nc) + w.checks[c.Name] = nc if w.monitoring { w.startMonitoring(nc) @@ -146,7 +112,7 @@ func (w *Watchdog) RemoveChecks(names ...string) { defer w.mu.Unlock() for _, name := range names { - c, ok := w.checks.Lookup(name) + c, ok := w.checks[name] if !ok { continue } @@ -155,7 +121,7 @@ func (w *Watchdog) RemoveChecks(names ...string) { w.stopMonitoring(c) } - w.checks.Delete(name) + delete(w.checks, name) } } @@ -172,7 +138,7 @@ func (w *Watchdog) Start(concurrency int) (<-chan CheckResult, error) { w.mu.Lock() defer w.mu.Unlock() - if w.checks.Len() == 0 { + if len(w.checks) == 0 { return nil, ErrNotConfigured } @@ -181,7 +147,7 @@ func (w *Watchdog) Start(concurrency int) (<-chan CheckResult, error) { } if concurrency == 0 { - concurrency = w.checks.Len() + concurrency = len(w.checks) } if w.timeout == 0 { @@ -191,7 +157,7 @@ func (w *Watchdog) Start(concurrency int) (<-chan CheckResult, error) { w.events = make(chan CheckResult, concurrency) w.limiter = make(chan struct{}, concurrency) - for _, c := range w.checks.Map() { + for _, c := range w.checks { w.startMonitoring(c) } @@ -208,7 +174,7 @@ func (w *Watchdog) Stop() error { return ErrNotRunning } - for _, c := range w.checks.Map() { + for _, c := range w.checks { w.stopMonitoring(c) } @@ -221,13 +187,13 @@ func (w *Watchdog) Stop() error { func (w *Watchdog) RunImmediately(ctx context.Context, concurrency int) ([]CheckResult, error) { w.mu.Lock() - if w.checks.Len() == 0 { + if len(w.checks) == 0 { w.mu.Unlock() return nil, ErrNotConfigured } cp := w.copyChecks() - w.mu.Unlock() // release + w.mu.Unlock() if concurrency == 0 { concurrency = len(cp) @@ -248,8 +214,8 @@ func (w *Watchdog) RunImmediately(ctx context.Context, concurrency int) ([]Check } func (w *Watchdog) copyChecks() []Check { - cp := make([]Check, 0, w.checks.Len()) - for _, v := range w.checks.Map() { + cp := make([]Check, 0, len(w.checks)) + for _, v := range w.checks { cp = append(cp, v.check) } @@ -267,11 +233,7 @@ func (w *Watchdog) startMonitoring(wdc *wdCheck) { w.running++ go func() { - state := CheckResult{ - // on first run return anything - // other that OK - Status: StatusOK, - } + var curr error = nil ticker := time.Tick(wdc.check.Interval) @@ -281,23 +243,22 @@ func (w *Watchdog) startMonitoring(wdc *wdCheck) { ctx, cancel := context.WithTimeout(context.Background(), w.timeout) defer cancel() - status, err := c.Check(ctx) + err := c.Check(ctx) <-w.limiter r := CheckResult{ - Name: c.Name, - Status: status, - Error: err, + Name: c.Name, + Error: err, } // if status changed or we've got an error // then report this - if r.Status != state.Status || r.Error != nil { + if !errors.Is(r.Error, curr) { w.events <- r } - state = r + curr = r.Error select { case <-ticker: @@ -336,12 +297,11 @@ func runChecksConcurrently(ctx context.Context, ch []Check, concurrency int) []C // relying on assumption that CheckFunc obeys context // cancellation - status, err := e.Check(ctx) + err := e.Check(ctx) r := CheckResult{ - Name: e.Name, - Status: status, - Error: err, + Name: e.Name, + Error: err, } done <- r diff --git a/watchdog_test.go b/watchdog_test.go index 297d9f7..a777612 100644 --- a/watchdog_test.go +++ b/watchdog_test.go @@ -12,20 +12,19 @@ import ( type mockChecker struct { name string - status watchdog.Status err error called bool } -func (m *mockChecker) Func(ctx context.Context) (watchdog.Status, error) { +func (m *mockChecker) Func(ctx context.Context) error { m.called = true time.Sleep(time.Millisecond * 10) if err := ctx.Err(); err != nil { - return watchdog.StatusUnknown, err + return err } - return m.status, m.err + return m.err } func (m *mockChecker) HasBeenCalled() bool { @@ -40,8 +39,8 @@ func (m *mockChecker) Check() watchdog.Check { } } -func newMockChecker(name string, s watchdog.Status, err error) *mockChecker { - return &mockChecker{name: name, status: s, err: err} +func newMockChecker(name string, err error) *mockChecker { + return &mockChecker{name: name, err: err} } func TestCreateWith_new(t *testing.T) { @@ -59,8 +58,8 @@ func TestNew(t *testing.T) { t.Errorf("expected len = 0") } - m1 := newMockChecker("mock", watchdog.StatusOK, nil) - m2 := newMockChecker("mock2", watchdog.StatusOK, nil) + m1 := newMockChecker("mock", nil) + m2 := newMockChecker("mock2", nil) w = watchdog.New() w.AddChecks(m1.Check(), m2.Check()) @@ -96,8 +95,8 @@ func TestRunImmediately(t *testing.T) { t.Errorf("expected zero len slice for empty instance, got %d", len(out)) } - m1 := newMockChecker("mock", watchdog.StatusOK, nil) - m2 := newMockChecker("mock2", watchdog.StatusOK, nil) + m1 := newMockChecker("mock", nil) + m2 := newMockChecker("mock2", nil) w = watchdog.New(m1.Check(), m2.Check()) out, _ = w.RunImmediately(t.Context(), 0) @@ -119,8 +118,8 @@ func TestStartStop(t *testing.T) { t.Error("Start doen't error on empty checks slice") } - m1 := newMockChecker("mock", watchdog.StatusOK, nil) - m2 := newMockChecker("mock2", watchdog.StatusOK, nil) + m1 := newMockChecker("mock", nil) + m2 := newMockChecker("mock2", nil) w.AddChecks(m1.Check(), m2.Check()) @@ -167,7 +166,7 @@ func TestStartStop(t *testing.T) { func TestSetTimeout(t *testing.T) { w := new(watchdog.Watchdog) w.SetTimeout(time.Millisecond) - m1 := newMockChecker("mock", watchdog.StatusOK, nil) + m1 := newMockChecker("mock", nil) w.AddChecks(m1.Check()) out, _ := w.Start(0) @@ -175,8 +174,8 @@ func TestSetTimeout(t *testing.T) { w.Stop() res := <-out - if !(res.Status == watchdog.StatusUnknown) || !errors.Is(res.Error, context.DeadlineExceeded) { - t.Logf("got status: %s, err: %v", res.Status, res.Error) + if !errors.Is(res.Error, context.DeadlineExceeded) { + t.Logf("got err: %v", res.Error) t.Fatal("incorrect status for timed out op") } }