Skip to content

Commit 913011b

Browse files
committed
properly handle nil input/output
1 parent fe0dde2 commit 913011b

3 files changed

Lines changed: 33 additions & 2 deletions

File tree

pipe/function.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"io"
7+
"strings"
78
)
89

910
// StageFunc is a function that can be used to power a `goStage`. It
@@ -66,11 +67,19 @@ func (s *goStage) Start(
6667
if stdin, ok := stdin.(readerNopCloser); ok {
6768
r = stdin.Reader
6869
}
70+
if r == nil {
71+
// treat nil as empty input.
72+
r = strings.NewReader("")
73+
}
6974

7075
var w io.Writer = stdout
7176
if stdout, ok := stdout.(writerNopCloser); ok {
7277
w = stdout.Writer
7378
}
79+
if w == nil {
80+
// treat nil output as /dev/null
81+
w = io.Discard
82+
}
7483

7584
go func() {
7685
defer close(s.done)

pipe/memorylimit_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,8 @@ func testMemoryLimit(t *testing.T, mbs int, limit uint64, stage pipe.Stage) (str
210210
for i := 0; i < mbs; i++ {
211211
_, err := stdout.Write(bytes[:])
212212
if err != nil {
213-
require.ErrorIs(t, err, syscall.EPIPE)
213+
assert.ErrorIs(t, err, syscall.EPIPE)
214+
return nil
214215
}
215216
}
216217

@@ -244,7 +245,8 @@ func testMemoryLimitWithObserver(t *testing.T, mbs int, limit uint64, stage pipe
244245
for i := 0; i < mbs; i++ {
245246
_, err := stdout.Write(bytes[:])
246247
if err != nil {
247-
require.ErrorIs(t, err, syscall.EPIPE)
248+
assert.ErrorIs(t, err, syscall.EPIPE)
249+
return nil
248250
}
249251
}
250252
return nil

pipe/pipeline_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,26 @@ func TestPrintf(t *testing.T) {
798798
}
799799
}
800800

801+
func TestPrintlnNoOutput(t *testing.T) {
802+
t.Parallel()
803+
ctx := context.Background()
804+
p := pipe.New()
805+
p.Add(pipe.Println("Look Ma, no output!"))
806+
assert.NoError(t, p.Run(ctx))
807+
}
808+
809+
func TestFunctionNoInput(t *testing.T) {
810+
t.Parallel()
811+
ctx := context.Background()
812+
p := pipe.New()
813+
p.Add(pipe.Function("read-all", func(_ context.Context, _ pipe.Env, stdin io.Reader, _ io.Writer) error {
814+
n, err := io.Copy(io.Discard, stdin)
815+
assert.Equal(t, int64(0), n)
816+
return err
817+
}))
818+
assert.NoError(t, p.Run(ctx))
819+
}
820+
801821
func TestErrors(t *testing.T) {
802822
t.Parallel()
803823
ctx := context.Background()

0 commit comments

Comments
 (0)