diff --git a/README.md b/README.md index b9fb861..600d760 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ func main() { Processor: processor, } - if err := script.Run(ctx, conf); err != nil { + if _, err := script.Run(ctx, conf); err != nil { fmt.Println(err) } } @@ -130,8 +130,8 @@ func main() { Processor: p.Process, // Process implements script.Processor } - if err := script.Run(ctx, conf); err != nil { + if _, err := script.Run(ctx, conf); err != nil { fmt.Println(err) } } -``` \ No newline at end of file +``` diff --git a/chain_processor.go b/chain_processor.go new file mode 100644 index 0000000..07a84b1 --- /dev/null +++ b/chain_processor.go @@ -0,0 +1,23 @@ +package script + +import "context" + +// Chain chains provided Processors. +// When an error is returned by a Processor in chain, processing +// stops and the error is retuned without running further stages. +func Chain(processors ...Processor) Processor { + return func(ctx context.Context, in []string) ([]string, error) { + var err error + for _, p := range processors { + // not checking ctx expiry here, + // let the processor handle it + + in, err = p(ctx, in) + if err != nil { + return nil, err + } + } + + return in, nil + } +} diff --git a/chain_processor_test.go b/chain_processor_test.go new file mode 100644 index 0000000..671b7ab --- /dev/null +++ b/chain_processor_test.go @@ -0,0 +1,28 @@ +package script + +import ( + "context" + "slices" + "testing" +) + +func TestChain(t *testing.T) { + p := func(_ context.Context, in []string) ([]string, error) { + in[0] = in[0] + in[0] + return in, nil + } + + chain := Chain(p, p, p) + + in := []string{"a"} + want := []string{"aaaaaaaa"} + + res, err := chain(t.Context(), in) + if err != nil { + t.Fatal(err) + } + + if !slices.Equal(res, want) { + t.Fatalf("slices are not equal, have: %+v, want: %+v", res, want) + } +} diff --git a/runner.go b/runner.go index c0e88a7..48fae07 100644 --- a/runner.go +++ b/runner.go @@ -5,13 +5,14 @@ import ( "errors" "fmt" "io" + "sync/atomic" "golang.org/x/sync/errgroup" ) var ( EOF error = io.EOF - ErrNoProcessors = errors.New("no processors provided") + ErrNoProcessors = errors.New("script: no processors provided") ) var ( @@ -51,17 +52,22 @@ type RunConfig struct { Concurrency int } +type RunResult struct { + Read int // number of records read without error (offset count not included) + Processed int // number of records processed without error + Written int // number of records written to Writer without error +} + // Run starts the script described by r. // First Read is called offset times with output of Read being discarded. // Then limit Reads are made and processor is called for each portion -// of data. If limit is 0 then Run keep processing input until it receives +// of data. If limit is 0 then Run keeps processing input until it receives // EOF from Reader. // Run fails on any error including Reader error, Writer error and Processor error. -// If an error is encountered the writer operation will be attampted anyway so that -// the output is left in consistent state, recording what has been actually done -// by Processor. -func Run(ctx context.Context, r RunConfig) error { - if r.Concurrency == 0 { +// The returned RunResult is AWAYS VALID and indicates the actual progress of script. +// Returned error explains why Run failed. It may be either read, process or write error. +func Run(ctx context.Context, r RunConfig) (RunResult, error) { + if r.Concurrency <= 0 { r.Concurrency = 1 } @@ -70,27 +76,33 @@ func Run(ctx context.Context, r RunConfig) error { rdch := make(chan []string, r.Concurrency) wrch := make(chan []string, r.Concurrency) + var read, proc, written uint32 + // read input from Reader and forward to Processor grp.Go(func() error { - // closing chan for processor to complete operations + // closing chan to Processor defer close(rdch) for range r.Offset { _, err := r.Input.Read() if err != nil { - return fmt.Errorf("could not advance to required offset: %w", err) + return fmt.Errorf("script: could not advance to required offset (%d): %s", r.Offset, err) } } count := 0 + for { inp, err := r.Input.Read() - if err != nil && errors.Is(err, EOF) { + if errors.Is(err, EOF) { return nil } else if err != nil { - return err + return fmt.Errorf("script: read error: %s", err) } + // increment read count + read++ + select { case rdch <- inp: case <-ctx.Done(): @@ -101,7 +113,7 @@ func Run(ctx context.Context, r RunConfig) error { count++ - if count == r.Limit { // will never happen if limit set to 0 + if count == r.Limit { // will never happen if limit has been set to 0 return nil } } @@ -109,14 +121,23 @@ func Run(ctx context.Context, r RunConfig) error { // read output of Processor and write to Writer grp.Go(func() error { + defer func() { + for range wrch { + // NOP to drain channel + } + }() + // not paying attention to context here // because we must complete writes // this is run within group so that write // error would cancel group context for outp := range wrch { if err := r.Output.Write(outp); err != nil { - return err + return fmt.Errorf("script: write error: %s", err) } + + //increment write count + written++ } return nil @@ -124,27 +145,33 @@ func Run(ctx context.Context, r RunConfig) error { // run processing routines grp.Go(func() error { - // will close write chan once - // all workers are done + // closing chan to Writer defer close(wrch) + defer func() { + for range rdch { + // NOP to drain channel + } + }() - workergrp, innrctx := errgroup.WithContext(ctx) + workergrp := errgroup.Group{} for range r.Concurrency { workergrp.Go(func() error { - // not paying attention to context here - // because we must complete writes for inp := range rdch { - result, err := r.Processor(innrctx, inp) + result, err := r.Processor(ctx, inp) if err != nil { - return err + return fmt.Errorf("script: process error: %s", err) } - wrch <- result + // increment processed count + atomic.AddUint32(&proc, 1) - // if one of workers died or parent context expired - // we should die too - if err := innrctx.Err(); err != nil { + select { + case wrch <- result: + case <-ctx.Done(): + // this case is a must if writer fails + // otherwise we'd want to push process result + // to wrch return nil } } @@ -160,9 +187,11 @@ func Run(ctx context.Context, r RunConfig) error { return nil }) - if err := grp.Wait(); err != nil { - return err - } + err := grp.Wait() // if this is a context expiry then error is nil - return nil + return RunResult{ + Read: int(read), + Processed: int(proc), + Written: int(written), + }, err } diff --git a/runner_test.go b/runner_test.go index 3f2c53f..abb09e0 100644 --- a/runner_test.go +++ b/runner_test.go @@ -28,7 +28,8 @@ func TestBasicRun(t *testing.T) { Processor: echoProcessor, } - if err := script.Run(t.Context(), conf); err != nil { + res, err := script.Run(t.Context(), conf) + if err != nil { t.Fatal(err) } @@ -37,6 +38,10 @@ func TestBasicRun(t *testing.T) { if !reflect.DeepEqual(input, output) { t.Errorf("incorrect output, want: %v, got: %v", input, output) } + + if res.Read != 1 || res.Processed != 1 || res.Written != 1 { + t.Fatal("incorrect process result, want all fields to equal 1") + } } type infiniteReader struct{} @@ -60,7 +65,7 @@ func TestRunnerObeysContext(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond) defer cancel() - if err := script.Run(ctx, conf); err != nil { + if _, err := script.Run(ctx, conf); err != nil { t.Fatal(err) } }