diff --git a/pipe/memorylimit.go b/pipe/memorylimit.go index df362c6..9e4febe 100644 --- a/pipe/memorylimit.go +++ b/pipe/memorylimit.go @@ -12,11 +12,11 @@ import ( const memoryPollInterval = time.Second // ErrMemoryLimitExceeded is the error that will be used to kill a -// process, if necessary, from MemoryLimit. +// process, if necessary, from a MemoryWatch with WithMemoryLimit. var ErrMemoryLimitExceeded = errors.New("memory limit exceeded") // LimitableStage is the superset of `Stage` that must be implemented -// by stages passed to MemoryLimit and MemoryObserver. +// by stages passed to MemoryWatch. type LimitableStage interface { Stage @@ -24,252 +24,187 @@ type LimitableStage interface { Kill(error) } -// MemoryLimit watches the memory usage of the stage and stops it if it -// exceeds the given limit. +// MemoryWatchOption configures a MemoryWatch stage. +type MemoryWatchOption func(*memoryWatcher) + +// WithMemoryLimit makes MemoryWatch kill the stage when its RSS exceeds +// byteLimit. +func WithMemoryLimit(byteLimit uint64) MemoryWatchOption { + return func(mw *memoryWatcher) { + mw.limit = &byteLimit + } +} + +// WithPeakUsageLogging makes MemoryWatch log the peak RSS when the stage +// exits. +func WithPeakUsageLogging() MemoryWatchOption { + return func(mw *memoryWatcher) { + mw.observe = true + } +} + +// MemoryWatch watches the memory usage of the stage and reports via +// eventHandler. With WithMemoryLimit it kills the stage when the limit is +// exceeded; with WithPeakUsageLogging it logs the peak RSS when the stage +// exits. At least one of the two options is required. // // If the event handler panics while reporting the over-limit event, the // stage is still killed. A panic in any other event-handler call (an -// RSS-read error) is recovered via StartOptions.PanicHandler and the -// stage keeps running unmonitored; see StartOptions.PanicHandler. 
-func MemoryLimit(stage Stage, byteLimit uint64, eventHandler func(e *Event)) Stage { - +// RSS-read error, or the peak-usage report) is recovered via +// StartOptions.PanicHandler and the stage keeps running unmonitored; see +// StartOptions.PanicHandler. +func MemoryWatch(stage Stage, eventHandler func(e *Event), opts ...MemoryWatchOption) Stage { limitableStage, ok := stage.(LimitableStage) if !ok { eventHandler(&Event{ Command: stage.Name(), - Msg: "invalid pipe.MemoryLimit usage", - Err: fmt.Errorf("invalid pipe.MemoryLimit usage"), + Msg: "invalid pipe.MemoryWatch usage", + Err: fmt.Errorf("invalid pipe.MemoryWatch usage"), }) return stage } - return &memoryWatchStage{ - nameSuffix: " with memory limit", - stage: limitableStage, - watch: killAtLimit(byteLimit, eventHandler), + mw := memoryWatcher{ + stage: limitableStage, + eventHandler: eventHandler, } -} - -func killAtLimit(byteLimit uint64, eventHandler func(e *Event)) memoryWatchFunc { - return func(ctx context.Context, stage LimitableStage) { - var consecutiveErrors int - - t := time.NewTicker(memoryPollInterval) - defer t.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-t.C: - rss, err := stage.GetRSSAnon(ctx) - if err != nil && !errors.Is(err, errProcessInfoMissing) { - consecutiveErrors++ - if consecutiveErrors >= 2 { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "error getting RSS", - Err: err, - }) - } - continue - } - consecutiveErrors = 0 - if rss < byteLimit { - continue - } - func() { - // Guarantee the over-limit stage is killed even if - // the user's event handler panics. 
- defer stage.Kill(ErrMemoryLimitExceeded) - eventHandler(&Event{ - Command: stage.Name(), - Msg: "stage exceeded allowed memory use", - Err: fmt.Errorf("stage exceeded allowed memory use"), - Context: map[string]interface{}{ - "limit": byteLimit, - "used": rss, - }, - }) - }() - return - } - } + for _, opt := range opts { + opt(&mw) } -} -// MemoryLimitWithObserver combines MemoryLimit and MemoryObserver in -// one goroutine. It watches the memory usage of the stage, stops it -// if it exceeds the given limit, and logs the peak memory usage when -// the stage exits. -// -// Its event-handler panic behavior matches MemoryLimit: the over-limit -// kill always happens, while a panic in the RSS-error or peak-usage -// handler is recovered via StartOptions.PanicHandler and the stage keeps -// running unmonitored. See StartOptions.PanicHandler. -func MemoryLimitWithObserver(stage Stage, byteLimit uint64, eventHandler func(e *Event)) Stage { - limitableStage, ok := stage.(LimitableStage) - if !ok { + if mw.limit == nil && !mw.observe { eventHandler(&Event{ Command: stage.Name(), - Msg: "invalid pipe.MemoryLimitWithObserver usage", - Err: fmt.Errorf("invalid pipe.MemoryLimitWithObserver usage"), + Msg: "invalid pipe.MemoryWatch usage", + Err: fmt.Errorf( + "pipe.MemoryWatch requires WithMemoryLimit and/or WithPeakUsageLogging", + ), }) return stage } + nameSuffix := "" + if mw.limit != nil { + nameSuffix = " with memory limit" + } + return &memoryWatchStage{ - nameSuffix: " with memory limit", + nameSuffix: nameSuffix, stage: limitableStage, - watch: killAtLimitAndObserve(byteLimit, eventHandler), + watch: mw.watch, } } -func killAtLimitAndObserve(byteLimit uint64, eventHandler func(e *Event)) memoryWatchFunc { - return func(ctx context.Context, stage LimitableStage) { - var ( - maxRSS uint64 - samples, errCount, consecutiveErrors int - killed bool - ) - - t := time.NewTicker(memoryPollInterval) - defer t.Stop() - - for { - select { - case <-ctx.Done(): - 
eventHandler(&Event{ - Command: stage.Name(), - Msg: "peak memory usage", - Context: map[string]interface{}{ - "max_rss_bytes": maxRSS, - "samples": samples, - "errors": errCount, - }, - }) - return - case <-t.C: - if killed { - continue - } +type memoryWatcher struct { + stage LimitableStage + eventHandler func(e *Event) - rss, err := stage.GetRSSAnon(ctx) - if err != nil { - if !errors.Is(err, errProcessInfoMissing) { - errCount++ - consecutiveErrors++ - if consecutiveErrors == 2 { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "error getting RSS", - Err: err, - }) - } - } else { - consecutiveErrors = 0 - } - continue - } + limit *uint64 // non-nil enables kill-at-limit + observe bool // log peak RSS when the stage exits - consecutiveErrors = 0 - samples++ - if rss > maxRSS { - maxRSS = rss - } + maxRSS uint64 + samples int + errCount int + consecutiveErrors int +} - if rss >= byteLimit { - func() { - // Guarantee the over-limit stage is killed even if - // the user's event handler panics. - defer stage.Kill(ErrMemoryLimitExceeded) - eventHandler(&Event{ - Command: stage.Name(), - Msg: "stage exceeded allowed memory use", - Err: fmt.Errorf("stage exceeded allowed memory use"), - Context: map[string]interface{}{ - "limit": byteLimit, - "used": rss, - }, - }) - }() - killed = true - } +// watch is a `memoryWatchFunc` that watches the memory usage of the +// specified `stage`. +func (mw *memoryWatcher) watch(ctx context.Context) { + t := time.NewTicker(memoryPollInterval) + defer t.Stop() + +watchLoop: + for { + select { + case <-ctx.Done(): + break watchLoop + case <-t.C: + if mw.update(ctx) { + // The stage was killed. + break watchLoop } } } + + if mw.observe { + <-ctx.Done() + mw.reportPeakUsage() + } } -// MemoryObserver watches memory use of the stage and logs the maximum -// value when the stage exits. 
-func MemoryObserver(stage Stage, eventHandler func(e *Event)) Stage { - limitableStage, ok := stage.(LimitableStage) - if !ok { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "invalid pipe.MemoryObserver usage", - Err: fmt.Errorf("invalid pipe.MemoryObserver usage"), - }) - return stage +// update samples the current memory usage and updates internal stats. +// Return true if the stage was killed for exceeding the memory limit. +func (mw *memoryWatcher) update(ctx context.Context) bool { + rss, err := mw.stage.GetRSSAnon(ctx) + if err != nil { + mw.handleGetRSSError(err) + return false } - return &memoryWatchStage{ - stage: limitableStage, - watch: logMaxRSS(eventHandler), + mw.consecutiveErrors = 0 + mw.samples++ + if rss > mw.maxRSS { + mw.maxRSS = rss } -} -func logMaxRSS(eventHandler func(e *Event)) memoryWatchFunc { - - return func(ctx context.Context, stage LimitableStage) { - var ( - maxRSS uint64 - samples, errors, consecutiveErrors int - ) - - t := time.NewTicker(memoryPollInterval) - defer t.Stop() - - for { - select { - case <-ctx.Done(): - eventHandler(&Event{ - Command: stage.Name(), - Msg: "peak memory usage", - Context: map[string]interface{}{ - "max_rss_bytes": maxRSS, - "samples": samples, - "errors": errors, - }, - }) - - return - case <-t.C: - rss, err := stage.GetRSSAnon(ctx) - if err != nil { - errors++ - consecutiveErrors++ - if consecutiveErrors == 2 { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "error getting RSS", - Err: err, - }) - } - // don't log any more errors until we get rss successfully. - continue - } + if mw.limit != nil && rss >= *mw.limit { + mw.killStage(rss) + return true + } - consecutiveErrors = 0 - samples++ - if rss > maxRSS { - maxRSS = rss - } - } + return false +} + +// handleGetRSSError deals with error `err` that happened when trying +// to get `stage`'s memory usage. 
+func (mw *memoryWatcher) handleGetRSSError(err error) { + if !errors.Is(err, errProcessInfoMissing) { + mw.errCount++ + mw.consecutiveErrors++ + if mw.consecutiveErrors == 2 { + mw.eventHandler(&Event{ + Command: mw.stage.Name(), + Msg: "error getting RSS", + Err: err, + }) } + } else { + mw.consecutiveErrors = 0 } } +// killStage kills the stage and reports an event saying what it did. +func (mw *memoryWatcher) killStage(rss uint64) { + // Guarantee the over-limit stage is killed even if + // the user's event handler panics. + defer mw.stage.Kill(ErrMemoryLimitExceeded) + + mw.eventHandler(&Event{ + Command: mw.stage.Name(), + Msg: "stage exceeded allowed memory use", + Err: fmt.Errorf("stage exceeded allowed memory use"), + Context: map[string]any{ + "limit": *mw.limit, + "used": rss, + }, + }) +} + +// reportPeakUsage sends an event reporting the peak usage that has +// been seen for `stage`. +func (mw *memoryWatcher) reportPeakUsage() { + mw.eventHandler(&Event{ + Command: mw.stage.Name(), + Msg: "peak memory usage", + Context: map[string]any{ + "max_rss_bytes": mw.maxRSS, + "samples": mw.samples, + "errors": mw.errCount, + }, + }) +} + type memoryWatchStage struct { nameSuffix string stage LimitableStage @@ -279,7 +214,7 @@ type memoryWatchStage struct { watchErr error } -type memoryWatchFunc func(context.Context, LimitableStage) +type memoryWatchFunc func(context.Context) var _ LimitableStage = (*memoryWatchStage)(nil) @@ -313,15 +248,16 @@ func (m *memoryWatchStage) monitor(ctx context.Context, panicHandler StagePanicH go func() { defer m.wg.Done() - defer func() { - if p := recover(); p != nil { - if panicHandler == nil { - panic(p) + + if panicHandler != nil { + defer func() { + if p := recover(); p != nil { + m.watchErr = panicHandler(p) } - m.watchErr = panicHandler(p) - } - }() - m.watch(ctx, m.stage) + }() + } + + m.watch(ctx) }() } diff --git a/pipe/memorylimit_panic_test.go b/pipe/memorylimit_panic_test.go index da9977e..07b058a 100644 ---
a/pipe/memorylimit_panic_test.go +++ b/pipe/memorylimit_panic_test.go @@ -4,20 +4,21 @@ import ( "context" "fmt" "io" - "os" - "os/exec" "strings" "testing" "time" + + "github.com/stretchr/testify/assert" ) const memWatchPanicSentinel = "memwatch-panic-sentinel" -const memWatchPanicChildEnv = "GO_PIPE_MEMWATCH_PANIC_CHILD" -// fakeLimitableStage is a minimal LimitableStage whose Wait returns -// immediately, letting a memoryWatchStage test exercise its watch -// goroutine in isolation. -type fakeLimitableStage struct{} +// fakeLimitableStage is a minimal LimitableStage whose `GetRSSAnon()` +// method panics, and whose `Wait()` method returns after that panic +// has been issued. +type fakeLimitableStage struct { + done chan struct{} +} func (fakeLimitableStage) Name() string { return "fake" } func (fakeLimitableStage) Preferences() StagePreferences { return StagePreferences{} } @@ -26,15 +27,23 @@ func (fakeLimitableStage) Start( ) error { return nil } -func (fakeLimitableStage) Wait() error { return nil } -func (fakeLimitableStage) GetRSSAnon(context.Context) (uint64, error) { return 0, nil } -func (fakeLimitableStage) Kill(error) {} - -func panickingWatchStage() *memoryWatchStage { - return &memoryWatchStage{ - stage: fakeLimitableStage{}, - watch: func(context.Context, LimitableStage) { panic(memWatchPanicSentinel) }, +func (stage fakeLimitableStage) Wait() error { + <-stage.done + return nil +} + +func (stage fakeLimitableStage) GetRSSAnon(context.Context) (uint64, error) { + close(stage.done) + panic(memWatchPanicSentinel) +} + +func (fakeLimitableStage) Kill(error) {} + +func panickingWatchStage() Stage { + stage := fakeLimitableStage{ + done: make(chan struct{}), } + return MemoryWatch(stage, func(*Event) {}, WithMemoryLimit(1)) } // TestMemoryWatchStagePanicWithHandlerSurfaced verifies that a panic @@ -60,47 +69,24 @@ func TestMemoryWatchStagePanicWithHandlerSurfaced(t *testing.T) { } } -// TestMemoryWatchStagePanicWithoutHandlerPropagates verifies that 
when -// the memory-watch goroutine panics and no panic handler is installed, -// the panic propagates (crashing the process) rather than being -// silently swallowed. Because that would crash the test binary, the -// scenario runs in a re-exec'd subprocess. +// TestMemoryWatchStagePanicWithoutHandlerPropagates verifies that the +// memory-watch sampling path does not swallow a panic. The monitor +// goroutine only installs a recover when a handler is present (see +// memoryWatchStage.monitor), so we exercise update() directly and assert +// that it propagates the panic. update() is used rather than watch() so the +// assertion is synchronous and ticker-independent: a regression that stopped +// the panic would fail the test rather than hang on the ticker loop. func TestMemoryWatchStagePanicWithoutHandlerPropagates(t *testing.T) { - if os.Getenv(memWatchPanicChildEnv) == "1" { - runMemWatchPanicChild() - return + limit := uint64(1) + mw := memoryWatcher{ + stage: fakeLimitableStage{done: make(chan struct{})}, + eventHandler: func(*Event) {}, + limit: &limit, } - cmd := exec.Command(os.Args[0], "-test.run=^TestMemoryWatchStagePanicWithoutHandlerPropagates$", "-test.v") //nolint:gosec // re-exec of this test binary with constant arguments. 
- cmd.Env = append(os.Environ(), memWatchPanicChildEnv+"=1") - out, err := cmd.CombinedOutput() - output := string(out) - - if err == nil { - t.Fatalf("expected subprocess to crash from a propagated panic, but it exited 0\noutput:\n%s", output) - } - if strings.Contains(output, "SURVIVED") { - t.Fatalf("panic was swallowed: Wait returned instead of propagating\noutput:\n%s", output) - } - if !strings.Contains(output, "panic:") || !strings.Contains(output, memWatchPanicSentinel) { - t.Fatalf("expected a propagated panic mentioning %q, got:\n%s", memWatchPanicSentinel, output) - } -} - -func runMemWatchPanicChild() { - ms := panickingWatchStage() - - if err := ms.Start(context.Background(), Env{}, nil, nil, StartOptions{}); err != nil { - os.Stdout.WriteString("SURVIVED: Start returned err=" + err.Error() + "\n") - os.Exit(0) - } - - _ = ms.Wait() - - // Reaching this point at all indicates the panic was swallowed. - time.Sleep(2 * time.Second) - os.Stdout.WriteString("SURVIVED: Wait returned\n") - os.Exit(0) + assert.PanicsWithValue(t, memWatchPanicSentinel, func() { + mw.update(context.Background()) + }) } // killTrackingStage is a LimitableStage that reports an over-limit RSS @@ -147,9 +133,16 @@ func (s *killTrackingStage) Kill(error) { // would never be killed and Wait would hang. 
func TestMemoryLimitKillsEvenIfEventHandlerPanics(t *testing.T) { stage := newKillTrackingStage() + limit := uint64(1) + eventHandler := func(*Event) { panic(memWatchPanicSentinel) } + mw := memoryWatcher{ + stage: stage, + eventHandler: eventHandler, + limit: &limit, + } ms := &memoryWatchStage{ stage: stage, - watch: killAtLimit(1, func(*Event) { panic(memWatchPanicSentinel) }), + watch: mw.watch, } opts := StartOptions{ PanicHandler: func(p any) error { return fmt.Errorf("recovered: %v", p) }, diff --git a/pipe/memorylimit_test.go b/pipe/memorylimit_test.go index 582421d..4e46945 100644 --- a/pipe/memorylimit_test.go +++ b/pipe/memorylimit_test.go @@ -61,7 +61,7 @@ func testMemoryObserver(t *testing.T, mbs int, stage pipe.Stage) int { logger := log.New(buf, "testMemoryObserver", log.Ldate|log.Ltime) p := pipe.New(pipe.WithDir("/"), pipe.WithStdin(stdinReader), pipe.WithStdout(devNull)) - p.Add(pipe.MemoryObserver(stage, LogEventHandler(logger))) + p.Add(pipe.MemoryWatch(stage, LogEventHandler(logger), pipe.WithPeakUsageLogging())) require.NoError(t, p.Start(ctx)) // Write some nonsense data to less, but don't close stdin until we want it @@ -168,7 +168,7 @@ func testMemoryLimitWithObserverBelowLimit(t *testing.T, mbs int, stage pipe.Sta logger := log.New(buf, "testMemoryLimitWithObserverBelowLimit", log.Ldate|log.Ltime) p := pipe.New(pipe.WithDir("/"), pipe.WithStdin(stdinReader), pipe.WithStdout(devNull)) - p.Add(pipe.MemoryLimitWithObserver(stage, 100*1024*1024*1024, LogEventHandler(logger))) + p.Add(pipe.MemoryWatch(stage, LogEventHandler(logger), pipe.WithMemoryLimit(100*1024*1024*1024), pipe.WithPeakUsageLogging())) require.NoError(t, p.Start(ctx)) var bytes [1_000_000]byte @@ -218,7 +218,7 @@ func testMemoryLimit(t *testing.T, mbs int, limit uint64, stage pipe.Stage) (str return nil }, ), - pipe.MemoryLimit(stage, limit, LogEventHandler(logger)), + pipe.MemoryWatch(stage, LogEventHandler(logger), pipe.WithMemoryLimit(limit)), ) require.NoError(t, 
p.Start(ctx)) @@ -252,7 +252,7 @@ func testMemoryLimitWithObserver(t *testing.T, mbs int, limit uint64, stage pipe return nil }, ), - pipe.MemoryLimitWithObserver(stage, limit, LogEventHandler(logger)), + pipe.MemoryWatch(stage, LogEventHandler(logger), pipe.WithMemoryLimit(limit), pipe.WithPeakUsageLogging()), ) require.NoError(t, p.Start(ctx)) @@ -260,3 +260,19 @@ func testMemoryLimitWithObserver(t *testing.T, mbs int, limit uint64, stage pipe return buf.String(), err } + +// TestMemoryWatchRequiresAnOption verifies that MemoryWatch without +// WithMemoryLimit or WithPeakUsageLogging is rejected: it reports an +// invalid-usage event and returns the stage unwrapped (no watcher). +func TestMemoryWatchRequiresAnOption(t *testing.T) { + stage := pipe.Command("true") + + var events []*pipe.Event + got := pipe.MemoryWatch(stage, func(e *pipe.Event) { + events = append(events, e) + }) + + require.Same(t, stage, got, "expected the input stage returned unwrapped") + require.Len(t, events, 1) + require.Contains(t, events[0].Msg, "invalid pipe.MemoryWatch usage") +}