diff --git a/chain_processor.go b/chain_processor.go new file mode 100644 index 0000000..07a84b1 --- /dev/null +++ b/chain_processor.go @@ -0,0 +1,23 @@ +package script + +import "context" + +// Chain chains provided Processors. +// When an error is returned by a Processor in chain, processing +// stops and the error is retuned without running further stages. +func Chain(processors ...Processor) Processor { + return func(ctx context.Context, in []string) ([]string, error) { + var err error + for _, p := range processors { + // not checking ctx expiry here, + // let the processor handle it + + in, err = p(ctx, in) + if err != nil { + return nil, err + } + } + + return in, nil + } +} diff --git a/chain_processor_test.go b/chain_processor_test.go new file mode 100644 index 0000000..671b7ab --- /dev/null +++ b/chain_processor_test.go @@ -0,0 +1,28 @@ +package script + +import ( + "context" + "slices" + "testing" +) + +func TestChain(t *testing.T) { + p := func(_ context.Context, in []string) ([]string, error) { + in[0] = in[0] + in[0] + return in, nil + } + + chain := Chain(p, p, p) + + in := []string{"a"} + want := []string{"aaaaaaaa"} + + res, err := chain(t.Context(), in) + if err != nil { + t.Fatal(err) + } + + if !slices.Equal(res, want) { + t.Fatalf("slices are not equal, have: %+v, want: %+v", res, want) + } +} diff --git a/runner.go b/runner.go index c0e88a7..b01c0c3 100644 --- a/runner.go +++ b/runner.go @@ -72,20 +72,20 @@ func Run(ctx context.Context, r RunConfig) error { // read input from Reader and forward to Processor grp.Go(func() error { - // closing chan for processor to complete operations + // closing chan to Processor defer close(rdch) for range r.Offset { _, err := r.Input.Read() if err != nil { - return fmt.Errorf("could not advance to required offset: %w", err) + return fmt.Errorf("could not advance to required offset (%d): %w", r.Offset, err) } } count := 0 for { inp, err := r.Input.Read() - if err != nil && errors.Is(err, EOF) { + if errors.Is(err, EOF) { return nil } else if err != nil { return err @@ -101,7 +101,7 @@ func Run(ctx context.Context, r RunConfig) error { count++ - if count == r.Limit { // will never happen if limit set to 0 + if count == r.Limit { // will never happen if limit has been set to 0 return nil } } @@ -124,27 +124,22 @@ func Run(ctx context.Context, r RunConfig) error { // run processing routines grp.Go(func() error { - // will close write chan once - // all workers are done + // closing chan to Writer defer close(wrch) - workergrp, innrctx := errgroup.WithContext(ctx) + workergrp := errgroup.Group{} for range r.Concurrency { workergrp.Go(func() error { - // not paying attention to context here - // because we must complete writes for inp := range rdch { - result, err := r.Processor(innrctx, inp) + result, err := r.Processor(ctx, inp) if err != nil { return err } - wrch <- result - - // if one of workers died or parent context expired - // we should die too - if err := innrctx.Err(); err != nil { + select { + case wrch <- result: + case <-ctx.Done(): return nil } }