From f893837943918a76b6548d4a9ec5b4a792e51085 Mon Sep 17 00:00:00 2001 From: Jason Lunz Date: Sun, 31 May 2026 13:26:35 +0200 Subject: [PATCH 01/11] Replace the MemoryLimit constructor trio with MemoryWatch + options MemoryLimit, MemoryLimitWithObserver, and MemoryObserver were three positional-argument constructors expressing one concept: watch a stage's RSS, optionally enforce a limit, optionally log the peak. Collapse them into a single functional-options constructor: MemoryWatch(stage, eventHandler, WithMemoryLimit(n), WithPeakUsageLogging()) Calling MemoryWatch with neither option now reports an invalid-usage event and returns the stage unwrapped, mirroring the non-LimitableStage guard. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pipe/memorylimit.go | 236 +++++++++++---------------------- pipe/memorylimit_panic_test.go | 5 +- pipe/memorylimit_test.go | 24 +++- 3 files changed, 98 insertions(+), 167 deletions(-) diff --git a/pipe/memorylimit.go b/pipe/memorylimit.go index df362c6..c610657 100644 --- a/pipe/memorylimit.go +++ b/pipe/memorylimit.go @@ -12,11 +12,11 @@ import ( const memoryPollInterval = time.Second // ErrMemoryLimitExceeded is the error that will be used to kill a -// process, if necessary, from MemoryLimit. +// process, if necessary, from a MemoryWatch with WithMemoryLimit. var ErrMemoryLimitExceeded = errors.New("memory limit exceeded") // LimitableStage is the superset of `Stage` that must be implemented -// by stages passed to MemoryLimit and MemoryObserver. +// by stages passed to MemoryWatch. type LimitableStage interface { Stage @@ -24,108 +24,83 @@ type LimitableStage interface { Kill(error) } -// MemoryLimit watches the memory usage of the stage and stops it if it -// exceeds the given limit. +// MemoryWatchOption configures a MemoryWatch stage. +type MemoryWatchOption func(*memoryWatchConfig) + +type memoryWatchConfig struct { + limit *uint64 // non-nil enables kill-at-limit + observe bool // log peak RSS when the stage exits +} + +// WithMemoryLimit makes MemoryWatch kill the stage when its RSS exceeds +// byteLimit. +func WithMemoryLimit(byteLimit uint64) MemoryWatchOption { + return func(c *memoryWatchConfig) { + c.limit = &byteLimit + } +} + +// WithPeakUsageLogging makes MemoryWatch log the peak RSS when the stage +// exits. +func WithPeakUsageLogging() MemoryWatchOption { + return func(c *memoryWatchConfig) { + c.observe = true + } +} + +// MemoryWatch watches the memory usage of the stage and reports via +// eventHandler. With WithMemoryLimit it kills the stage when the limit is +// exceeded; with WithPeakUsageLogging it logs the peak RSS when the stage +// exits. At least one of the two options is required. // // If the event handler panics while reporting the over-limit event, the // stage is still killed. A panic in any other event-handler call (an -// RSS-read error) is recovered via StartOptions.PanicHandler and the -// stage keeps running unmonitored; see StartOptions.PanicHandler. -func MemoryLimit(stage Stage, byteLimit uint64, eventHandler func(e *Event)) Stage { +// RSS-read error, or the peak-usage report) is recovered via +// StartOptions.PanicHandler and the stage keeps running unmonitored; see +// StartOptions.PanicHandler. +func MemoryWatch(stage Stage, eventHandler func(e *Event), opts ...MemoryWatchOption) Stage { + var cfg memoryWatchConfig + for _, opt := range opts { + opt(&cfg) + } limitableStage, ok := stage.(LimitableStage) if !ok { eventHandler(&Event{ Command: stage.Name(), - Msg: "invalid pipe.MemoryLimit usage", - Err: fmt.Errorf("invalid pipe.MemoryLimit usage"), + Msg: "invalid pipe.MemoryWatch usage", + Err: fmt.Errorf("invalid pipe.MemoryWatch usage"), }) return stage } - return &memoryWatchStage{ - nameSuffix: " with memory limit", - stage: limitableStage, - watch: killAtLimit(byteLimit, eventHandler), - } -} - -func killAtLimit(byteLimit uint64, eventHandler func(e *Event)) memoryWatchFunc { - return func(ctx context.Context, stage LimitableStage) { - var consecutiveErrors int - - t := time.NewTicker(memoryPollInterval) - defer t.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-t.C: - rss, err := stage.GetRSSAnon(ctx) - if err != nil && !errors.Is(err, errProcessInfoMissing) { - consecutiveErrors++ - if consecutiveErrors >= 2 { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "error getting RSS", - Err: err, - }) - } - continue - } - consecutiveErrors = 0 - if rss < byteLimit { - continue - } - func() { - // Guarantee the over-limit stage is killed even if - // the user's event handler panics. - defer stage.Kill(ErrMemoryLimitExceeded) - eventHandler(&Event{ - Command: stage.Name(), - Msg: "stage exceeded allowed memory use", - Err: fmt.Errorf("stage exceeded allowed memory use"), - Context: map[string]interface{}{ - "limit": byteLimit, - "used": rss, - }, - }) - }() - return - } - } - } -} - -// MemoryLimitWithObserver combines MemoryLimit and MemoryObserver in -// one goroutine. It watches the memory usage of the stage, stops it -// if it exceeds the given limit, and logs the peak memory usage when -// the stage exits. -// -// Its event-handler panic behavior matches MemoryLimit: the over-limit -// kill always happens, while a panic in the RSS-error or peak-usage -// handler is recovered via StartOptions.PanicHandler and the stage keeps -// running unmonitored. See StartOptions.PanicHandler. -func MemoryLimitWithObserver(stage Stage, byteLimit uint64, eventHandler func(e *Event)) Stage { - limitableStage, ok := stage.(LimitableStage) - if !ok { + if cfg.limit == nil && !cfg.observe { eventHandler(&Event{ Command: stage.Name(), - Msg: "invalid pipe.MemoryLimitWithObserver usage", - Err: fmt.Errorf("invalid pipe.MemoryLimitWithObserver usage"), + Msg: "invalid pipe.MemoryWatch usage", + Err: fmt.Errorf( + "pipe.MemoryWatch requires WithMemoryLimit and/or WithPeakUsageLogging", + ), }) return stage } + nameSuffix := "" + if cfg.limit != nil { + nameSuffix = " with memory limit" + } + return &memoryWatchStage{ - nameSuffix: " with memory limit", + nameSuffix: nameSuffix, stage: limitableStage, - watch: killAtLimitAndObserve(byteLimit, eventHandler), + watch: cfg.watchFunc(eventHandler), } } -func killAtLimitAndObserve(byteLimit uint64, eventHandler func(e *Event)) memoryWatchFunc { +func (c *memoryWatchConfig) watchFunc(eventHandler func(e *Event)) memoryWatchFunc { + limit := c.limit + observe := c.observe + return func(ctx context.Context, stage LimitableStage) { var ( maxRSS uint64 @@ -139,18 +114,22 @@ func killAtLimitAndObserve(byteLimit uint64, eventHandler func(e *Event)) memory for { select { case <-ctx.Done(): - eventHandler(&Event{ - Command: stage.Name(), - Msg: "peak memory usage", - Context: map[string]interface{}{ - "max_rss_bytes": maxRSS, - "samples": samples, - "errors": errCount, - }, - }) + if observe { + eventHandler(&Event{ + Command: stage.Name(), + Msg: "peak memory usage", + Context: map[string]interface{}{ + "max_rss_bytes": maxRSS, + "samples": samples, + "errors": errCount, + }, + }) + } return case <-t.C: if killed { + // After a kill we only remain in the loop to emit the + // peak-usage event at ctx.Done; stop sampling. continue } @@ -178,7 +157,7 @@ func killAtLimitAndObserve(byteLimit uint64, eventHandler func(e *Event)) memory maxRSS = rss } - if rss >= byteLimit { + if limit != nil && rss >= *limit { func() { // Guarantee the over-limit stage is killed even if // the user's event handler panics. @@ -188,82 +167,15 @@ func killAtLimitAndObserve(byteLimit uint64, eventHandler func(e *Event)) memory Msg: "stage exceeded allowed memory use", Err: fmt.Errorf("stage exceeded allowed memory use"), Context: map[string]interface{}{ - "limit": byteLimit, + "limit": *limit, "used": rss, }, }) }() - killed = true - } - } - } - } -} - -// MemoryObserver watches memory use of the stage and logs the maximum -// value when the stage exits. -func MemoryObserver(stage Stage, eventHandler func(e *Event)) Stage { - limitableStage, ok := stage.(LimitableStage) - if !ok { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "invalid pipe.MemoryObserver usage", - Err: fmt.Errorf("invalid pipe.MemoryObserver usage"), - }) - return stage - } - - return &memoryWatchStage{ - stage: limitableStage, - watch: logMaxRSS(eventHandler), - } -} - -func logMaxRSS(eventHandler func(e *Event)) memoryWatchFunc { - - return func(ctx context.Context, stage LimitableStage) { - var ( - maxRSS uint64 - samples, errors, consecutiveErrors int - ) - - t := time.NewTicker(memoryPollInterval) - defer t.Stop() - - for { - select { - case <-ctx.Done(): - eventHandler(&Event{ - Command: stage.Name(), - Msg: "peak memory usage", - Context: map[string]interface{}{ - "max_rss_bytes": maxRSS, - "samples": samples, - "errors": errors, - }, - }) - - return - case <-t.C: - rss, err := stage.GetRSSAnon(ctx) - if err != nil { - errors++ - consecutiveErrors++ - if consecutiveErrors == 2 { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "error getting RSS", - Err: err, - }) + if !observe { + return } - // don't log any more errors until we get rss successfully. - continue - } - - consecutiveErrors = 0 - samples++ - if rss > maxRSS { - maxRSS = rss + killed = true } } } diff --git a/pipe/memorylimit_panic_test.go b/pipe/memorylimit_panic_test.go index da9977e..a591084 100644 --- a/pipe/memorylimit_panic_test.go +++ b/pipe/memorylimit_panic_test.go @@ -147,9 +147,12 @@ func (s *killTrackingStage) Kill(error) { // would never be killed and Wait would hang. func TestMemoryLimitKillsEvenIfEventHandlerPanics(t *testing.T) { stage := newKillTrackingStage() + limit := uint64(1) ms := &memoryWatchStage{ stage: stage, - watch: killAtLimit(1, func(*Event) { panic(memWatchPanicSentinel) }), + watch: (&memoryWatchConfig{limit: &limit}).watchFunc( + func(*Event) { panic(memWatchPanicSentinel) }, + ), } opts := StartOptions{ PanicHandler: func(p any) error { return fmt.Errorf("recovered: %v", p) }, diff --git a/pipe/memorylimit_test.go b/pipe/memorylimit_test.go index 582421d..4e46945 100644 --- a/pipe/memorylimit_test.go +++ b/pipe/memorylimit_test.go @@ -61,7 +61,7 @@ func testMemoryObserver(t *testing.T, mbs int, stage pipe.Stage) int { logger := log.New(buf, "testMemoryObserver", log.Ldate|log.Ltime) p := pipe.New(pipe.WithDir("/"), pipe.WithStdin(stdinReader), pipe.WithStdout(devNull)) - p.Add(pipe.MemoryObserver(stage, LogEventHandler(logger))) + p.Add(pipe.MemoryWatch(stage, LogEventHandler(logger), pipe.WithPeakUsageLogging())) require.NoError(t, p.Start(ctx)) // Write some nonsense data to less, but don't close stdin until we want it @@ -168,7 +168,7 @@ func testMemoryLimitWithObserverBelowLimit(t *testing.T, mbs int, stage pipe.Sta logger := log.New(buf, "testMemoryLimitWithObserverBelowLimit", log.Ldate|log.Ltime) p := pipe.New(pipe.WithDir("/"), pipe.WithStdin(stdinReader), pipe.WithStdout(devNull)) - p.Add(pipe.MemoryLimitWithObserver(stage, 100*1024*1024*1024, LogEventHandler(logger))) + p.Add(pipe.MemoryWatch(stage, LogEventHandler(logger), pipe.WithMemoryLimit(100*1024*1024*1024), pipe.WithPeakUsageLogging())) require.NoError(t, p.Start(ctx)) var bytes [1_000_000]byte @@ -218,7 +218,7 @@ func testMemoryLimit(t *testing.T, mbs int, limit uint64, stage pipe.Stage) (str return nil }, ), - pipe.MemoryLimit(stage, limit, LogEventHandler(logger)), + pipe.MemoryWatch(stage, LogEventHandler(logger), pipe.WithMemoryLimit(limit)), ) require.NoError(t, p.Start(ctx)) @@ -252,7 +252,7 @@ func testMemoryLimitWithObserver(t *testing.T, mbs int, limit uint64, stage pipe return nil }, ), - pipe.MemoryLimitWithObserver(stage, limit, LogEventHandler(logger)), + pipe.MemoryWatch(stage, LogEventHandler(logger), pipe.WithMemoryLimit(limit), pipe.WithPeakUsageLogging()), ) require.NoError(t, p.Start(ctx)) @@ -260,3 +260,19 @@ func testMemoryLimitWithObserver(t *testing.T, mbs int, limit uint64, stage pipe return buf.String(), err } + +// TestMemoryWatchRequiresAnOption verifies that MemoryWatch without +// WithMemoryLimit or WithPeakUsageLogging is rejected: it reports an +// invalid-usage event and returns the stage unwrapped (no watcher). +func TestMemoryWatchRequiresAnOption(t *testing.T) { + stage := pipe.Command("true") + + var events []*pipe.Event + got := pipe.MemoryWatch(stage, func(e *pipe.Event) { + events = append(events, e) + }) + + require.Same(t, stage, got, "expected the input stage returned unwrapped") + require.Len(t, events, 1) + require.Contains(t, events[0].Msg, "invalid pipe.MemoryWatch usage") +} From f05b5c6084ccf56f46dfb782fa2582428775e7b5 Mon Sep 17 00:00:00 2001 From: Michael Haggerty Date: Tue, 2 Jun 2026 10:06:08 +0200 Subject: [PATCH 02/11] memoryWatcher: new helper type For now it only has one method, which is a `memoryWatchFunc`. --- pipe/memorylimit.go | 155 ++++++++++++++++++++++++-------------------- 1 file changed, 83 insertions(+), 72 deletions(-) diff --git a/pipe/memorylimit.go b/pipe/memorylimit.go index c610657..dae4990 100644 --- a/pipe/memorylimit.go +++ b/pipe/memorylimit.go @@ -98,85 +98,96 @@ func MemoryWatch(stage Stage, eventHandler func(e *Event), opts ...MemoryWatchOp } func (c *memoryWatchConfig) watchFunc(eventHandler func(e *Event)) memoryWatchFunc { - limit := c.limit - observe := c.observe - - return func(ctx context.Context, stage LimitableStage) { - var ( - maxRSS uint64 - samples, errCount, consecutiveErrors int - killed bool - ) - - t := time.NewTicker(memoryPollInterval) - defer t.Stop() - - for { - select { - case <-ctx.Done(): - if observe { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "peak memory usage", - Context: map[string]interface{}{ - "max_rss_bytes": maxRSS, - "samples": samples, - "errors": errCount, - }, - }) - } - return - case <-t.C: - if killed { - // After a kill we only remain in the loop to emit the - // peak-usage event at ctx.Done; stop sampling. - continue - } + mw := memoryWatcher{ + cfg: c, + eventHandler: eventHandler, + } - rss, err := stage.GetRSSAnon(ctx) - if err != nil { - if !errors.Is(err, errProcessInfoMissing) { - errCount++ - consecutiveErrors++ - if consecutiveErrors == 2 { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "error getting RSS", - Err: err, - }) - } - } else { - consecutiveErrors = 0 - } - continue - } + return mw.watch +} - consecutiveErrors = 0 - samples++ - if rss > maxRSS { - maxRSS = rss - } +type memoryWatcher struct { + cfg *memoryWatchConfig + eventHandler func(e *Event) + + maxRSS uint64 + samples int + errCount int + consecutiveErrors int + killed bool +} - if limit != nil && rss >= *limit { - func() { - // Guarantee the over-limit stage is killed even if - // the user's event handler panics. - defer stage.Kill(ErrMemoryLimitExceeded) - eventHandler(&Event{ +// watch is a `memoryWatchFunc` that watches the memory usage of the +// specified `stage`. +func (mw *memoryWatcher) watch(ctx context.Context, stage LimitableStage) { + t := time.NewTicker(memoryPollInterval) + defer t.Stop() + + for { + select { + case <-ctx.Done(): + if mw.cfg.observe { + mw.eventHandler(&Event{ + Command: stage.Name(), + Msg: "peak memory usage", + Context: map[string]interface{}{ + "max_rss_bytes": mw.maxRSS, + "samples": mw.samples, + "errors": mw.errCount, + }, + }) + } + return + case <-t.C: + if mw.killed { + // After a kill we only remain in the loop to emit the + // peak-usage event at ctx.Done; stop sampling. + continue + } + + rss, err := stage.GetRSSAnon(ctx) + if err != nil { + if !errors.Is(err, errProcessInfoMissing) { + mw.errCount++ + mw.consecutiveErrors++ + if mw.consecutiveErrors == 2 { + mw.eventHandler(&Event{ Command: stage.Name(), - Msg: "stage exceeded allowed memory use", - Err: fmt.Errorf("stage exceeded allowed memory use"), - Context: map[string]interface{}{ - "limit": *limit, - "used": rss, - }, + Msg: "error getting RSS", + Err: err, }) - }() - if !observe { - return } - killed = true + } else { + mw.consecutiveErrors = 0 + } + continue + } + + mw.consecutiveErrors = 0 + mw.samples++ + if rss > mw.maxRSS { + mw.maxRSS = rss + } + + if mw.cfg.limit != nil && rss >= *mw.cfg.limit { + func() { + // Guarantee the over-limit stage is killed even if + // the user's event handler panics. + defer stage.Kill(ErrMemoryLimitExceeded) + mw.eventHandler(&Event{ + Command: stage.Name(), + Msg: "stage exceeded allowed memory use", + Err: fmt.Errorf("stage exceeded allowed memory use"), + Context: map[string]interface{}{ + "limit": *mw.cfg.limit, + "used": rss, + }, + }) + }() + if !mw.cfg.observe { + return } + mw.killed = true } } } From 2cc87f43bd4ee7cde8c0055f96e47a2fcfa94591 Mon Sep 17 00:00:00 2001 From: Michael Haggerty Date: Tue, 2 Jun 2026 10:15:11 +0200 Subject: [PATCH 03/11] memoryWatcher: add some methods Extract the following methods from `memoryWatcher.watch()`: * `handleGetRSSError()` * `killStage()` * `reportPeakUsage()` --- pipe/memorylimit.go | 89 +++++++++++++++++++++++++++------------------ 1 file changed, 53 insertions(+), 36 deletions(-) diff --git a/pipe/memorylimit.go b/pipe/memorylimit.go index dae4990..472cb01 100644 --- a/pipe/memorylimit.go +++ b/pipe/memorylimit.go @@ -127,15 +127,7 @@ func (mw *memoryWatcher) watch(ctx context.Context, stage LimitableStage) { select { case <-ctx.Done(): if mw.cfg.observe { - mw.eventHandler(&Event{ - Command: stage.Name(), - Msg: "peak memory usage", - Context: map[string]interface{}{ - "max_rss_bytes": mw.maxRSS, - "samples": mw.samples, - "errors": mw.errCount, - }, - }) + mw.reportPeakUsage(stage) } return case <-t.C: @@ -147,19 +139,7 @@ func (mw *memoryWatcher) watch(ctx context.Context, stage LimitableStage) { rss, err := stage.GetRSSAnon(ctx) if err != nil { - if !errors.Is(err, errProcessInfoMissing) { - mw.errCount++ - mw.consecutiveErrors++ - if mw.consecutiveErrors == 2 { - mw.eventHandler(&Event{ - Command: stage.Name(), - Msg: "error getting RSS", - Err: err, - }) - } - } else { - mw.consecutiveErrors = 0 - } + mw.handleGetRSSError(stage, err) continue } @@ -170,20 +150,8 @@ func (mw *memoryWatcher) watch(ctx context.Context, stage LimitableStage) { } if mw.cfg.limit != nil && rss >= *mw.cfg.limit { - func() { - // Guarantee the over-limit stage is killed even if - // the user's event handler panics. - defer stage.Kill(ErrMemoryLimitExceeded) - mw.eventHandler(&Event{ - Command: stage.Name(), - Msg: "stage exceeded allowed memory use", - Err: fmt.Errorf("stage exceeded allowed memory use"), - Context: map[string]interface{}{ - "limit": *mw.cfg.limit, - "used": rss, - }, - }) - }() + mw.killStage(stage, rss) + if !mw.cfg.observe { return } @@ -193,6 +161,55 @@ func (mw *memoryWatcher) watch(ctx context.Context, stage LimitableStage) { } } +// handleGetRSSError deals with error `err` that happened when trying +// to get `stage`'s memory usage. +func (mw *memoryWatcher) handleGetRSSError(stage LimitableStage, err error) { + if !errors.Is(err, errProcessInfoMissing) { + mw.errCount++ + mw.consecutiveErrors++ + if mw.consecutiveErrors == 2 { + mw.eventHandler(&Event{ + Command: stage.Name(), + Msg: "error getting RSS", + Err: err, + }) + } + } else { + mw.consecutiveErrors = 0 + } +} + +// killStage kills `stage` and reports and event saying what it did. +func (mw *memoryWatcher) killStage(stage LimitableStage, rss uint64) { + // Guarantee the over-limit stage is killed even if + // the user's event handler panics. + defer stage.Kill(ErrMemoryLimitExceeded) + + mw.eventHandler(&Event{ + Command: stage.Name(), + Msg: "stage exceeded allowed memory use", + Err: fmt.Errorf("stage exceeded allowed memory use"), + Context: map[string]any{ + "limit": *mw.cfg.limit, + "used": rss, + }, + }) +} + +// reportPeakUsage sends an event reporting the peak usage that has +// been seen for `stage`. +func (mw *memoryWatcher) reportPeakUsage(stage LimitableStage) { + mw.eventHandler(&Event{ + Command: stage.Name(), + Msg: "peak memory usage", + Context: map[string]any{ + "max_rss_bytes": mw.maxRSS, + "samples": mw.samples, + "errors": mw.errCount, + }, + }) +} + type memoryWatchStage struct { nameSuffix string stage LimitableStage From ebdc8f55031bd8849a82ccf03f62b8d727646c35 Mon Sep 17 00:00:00 2001 From: Michael Haggerty Date: Tue, 2 Jun 2026 10:48:09 +0200 Subject: [PATCH 04/11] memoryWatcher.watch(): simplify loop termination --- pipe/memorylimit.go | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/pipe/memorylimit.go b/pipe/memorylimit.go index 472cb01..420bee2 100644 --- a/pipe/memorylimit.go +++ b/pipe/memorylimit.go @@ -114,29 +114,19 @@ type memoryWatcher struct { samples int errCount int consecutiveErrors int - killed bool } // watch is a `memoryWatchFunc` that watches the memory usage of the // specified `stage`. func (mw *memoryWatcher) watch(ctx context.Context, stage LimitableStage) { t := time.NewTicker(memoryPollInterval) - defer t.Stop() +watchLoop: for { select { case <-ctx.Done(): - if mw.cfg.observe { - mw.reportPeakUsage(stage) - } - return + break watchLoop case <-t.C: - if mw.killed { - // After a kill we only remain in the loop to emit the - // peak-usage event at ctx.Done; stop sampling. - continue - } - rss, err := stage.GetRSSAnon(ctx) if err != nil { mw.handleGetRSSError(stage, err) @@ -152,13 +142,19 @@ func (mw *memoryWatcher) watch(ctx context.Context, stage LimitableStage) { if mw.cfg.limit != nil && rss >= *mw.cfg.limit { mw.killStage(stage, rss) - if !mw.cfg.observe { - return - } - mw.killed = true + // After a kill we wait for `ctx.Done()` and then emit + // the peak-usage event. + break watchLoop } } } + + t.Stop() + + if mw.cfg.observe { + <-ctx.Done() + mw.reportPeakUsage(stage) + } } // handleGetRSSError deals with error `err` that happened when trying From 4ab48c3c5599ae9a0228b68af13099e89f8f7eeb Mon Sep 17 00:00:00 2001 From: Michael Haggerty Date: Tue, 2 Jun 2026 10:58:48 +0200 Subject: [PATCH 05/11] memoryWatcher.update(): new method Extract method `memoryWatcher.update()` from `memoryWatcher.watch()`. --- pipe/memorylimit.go | 42 +++++++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/pipe/memorylimit.go b/pipe/memorylimit.go index 420bee2..e184e43 100644 --- a/pipe/memorylimit.go +++ b/pipe/memorylimit.go @@ -127,23 +127,8 @@ watchLoop: case <-ctx.Done(): break watchLoop case <-t.C: - rss, err := stage.GetRSSAnon(ctx) - if err != nil { - mw.handleGetRSSError(stage, err) - continue - } - - mw.consecutiveErrors = 0 - mw.samples++ - if rss > mw.maxRSS { - mw.maxRSS = rss - } - - if mw.cfg.limit != nil && rss >= *mw.cfg.limit { - mw.killStage(stage, rss) - - // After a kill we wait for `ctx.Done()` and then emit - // the peak-usage event. + if mw.update(ctx, stage) { + // The stage was killed. break watchLoop } } @@ -157,6 +142,29 @@ watchLoop: } } +// update samples the current memory usage and updates internal stats. +// Return true if the stage was killed for exceeding the memory limit. +func (mw *memoryWatcher) update(ctx context.Context, stage LimitableStage) bool { + rss, err := stage.GetRSSAnon(ctx) + if err != nil { + mw.handleGetRSSError(stage, err) + return false + } + + mw.consecutiveErrors = 0 + mw.samples++ + if rss > mw.maxRSS { + mw.maxRSS = rss + } + + if mw.cfg.limit != nil && rss >= *mw.cfg.limit { + mw.killStage(stage, rss) + return true + } + + return false +} + // handleGetRSSError deals with error `err` that happened when trying // to get `stage`'s memory usage. func (mw *memoryWatcher) handleGetRSSError(stage LimitableStage, err error) { From 8a96683d5de3c0d92b8203627e66a07815179940 Mon Sep 17 00:00:00 2001 From: Michael Haggerty Date: Tue, 2 Jun 2026 11:20:35 +0200 Subject: [PATCH 06/11] memoryWatcher.stage: new field Build the stage into the `memoryWatcher`, so it doesn't need to be passed around as an extra argument. --- pipe/memorylimit.go | 42 +++++++++++++++++++--------------- pipe/memorylimit_panic_test.go | 4 ++-- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/pipe/memorylimit.go b/pipe/memorylimit.go index e184e43..b72c4cc 100644 --- a/pipe/memorylimit.go +++ b/pipe/memorylimit.go @@ -93,13 +93,16 @@ func MemoryWatch(stage Stage, eventHandler func(e *Event), opts ...MemoryWatchOp return &memoryWatchStage{ nameSuffix: nameSuffix, stage: limitableStage, - watch: cfg.watchFunc(eventHandler), + watch: cfg.watchFunc(limitableStage, eventHandler), } } -func (c *memoryWatchConfig) watchFunc(eventHandler func(e *Event)) memoryWatchFunc { +func (c *memoryWatchConfig) watchFunc( + stage LimitableStage, eventHandler func(e *Event), +) memoryWatchFunc { mw := memoryWatcher{ cfg: c, + stage: stage, eventHandler: eventHandler, } @@ -108,6 +111,7 @@ func (c *memoryWatchConfig) watchFunc(eventHandler func(e *Event)) memoryWatchFu type memoryWatcher struct { cfg *memoryWatchConfig + stage LimitableStage eventHandler func(e *Event) maxRSS uint64 @@ -118,7 +122,7 @@ type memoryWatcher struct { // watch is a `memoryWatchFunc` that watches the memory usage of the // specified `stage`. -func (mw *memoryWatcher) watch(ctx context.Context, stage LimitableStage) { +func (mw *memoryWatcher) watch(ctx context.Context) { t := time.NewTicker(memoryPollInterval) watchLoop: @@ -127,7 +131,7 @@ watchLoop: case <-ctx.Done(): break watchLoop case <-t.C: - if mw.update(ctx, stage) { + if mw.update(ctx) { // The stage was killed. break watchLoop } @@ -138,16 +142,16 @@ watchLoop: if mw.cfg.observe { <-ctx.Done() - mw.reportPeakUsage(stage) + mw.reportPeakUsage() } } // update samples the current memory usage and updates internal stats. // Return true if the stage was killed for exceeding the memory limit. -func (mw *memoryWatcher) update(ctx context.Context, stage LimitableStage) bool { - rss, err := stage.GetRSSAnon(ctx) +func (mw *memoryWatcher) update(ctx context.Context) bool { + rss, err := mw.stage.GetRSSAnon(ctx) if err != nil { - mw.handleGetRSSError(stage, err) + mw.handleGetRSSError(err) return false } @@ -158,7 +162,7 @@ func (mw *memoryWatcher) update(ctx context.Context, stage LimitableStage) bool } if mw.cfg.limit != nil && rss >= *mw.cfg.limit { - mw.killStage(stage, rss) + mw.killStage(rss) return true } @@ -167,13 +171,13 @@ func (mw *memoryWatcher) update(ctx context.Context, stage LimitableStage) bool // handleGetRSSError deals with error `err` that happened when trying // to get `stage`'s memory usage. -func (mw *memoryWatcher) handleGetRSSError(stage LimitableStage, err error) { +func (mw *memoryWatcher) handleGetRSSError(err error) { if !errors.Is(err, errProcessInfoMissing) { mw.errCount++ mw.consecutiveErrors++ if mw.consecutiveErrors == 2 { mw.eventHandler(&Event{ - Command: stage.Name(), + Command: mw.stage.Name(), Msg: "error getting RSS", Err: err, }) @@ -183,14 +187,14 @@ func (mw *memoryWatcher) handleGetRSSError(stage LimitableStage, err error) { } } -// killStage kills `stage` and reports and event saying what it did. -func (mw *memoryWatcher) killStage(stage LimitableStage, rss uint64) { +// killStage kills the stage and reports and event saying what it did. +func (mw *memoryWatcher) killStage(rss uint64) { // Guarantee the over-limit stage is killed even if // the user's event handler panics. - defer stage.Kill(ErrMemoryLimitExceeded) + defer mw.stage.Kill(ErrMemoryLimitExceeded) mw.eventHandler(&Event{ - Command: stage.Name(), + Command: mw.stage.Name(), Msg: "stage exceeded allowed memory use", Err: fmt.Errorf("stage exceeded allowed memory use"), Context: map[string]any{ @@ -202,9 +206,9 @@ func (mw *memoryWatcher) killStage(stage LimitableStage, rss uint64) { // reportPeakUsage sends an event reporting the peak usage that has // been seen for `stage`. -func (mw *memoryWatcher) reportPeakUsage(stage LimitableStage) { +func (mw *memoryWatcher) reportPeakUsage() { mw.eventHandler(&Event{ - Command: stage.Name(), + Command: mw.stage.Name(), Msg: "peak memory usage", Context: map[string]any{ "max_rss_bytes": mw.maxRSS, @@ -223,7 +227,7 @@ type memoryWatchStage struct { watchErr error } -type memoryWatchFunc func(context.Context, LimitableStage) +type memoryWatchFunc func(context.Context) var _ LimitableStage = (*memoryWatchStage)(nil) @@ -265,7 +269,7 @@ func (m *memoryWatchStage) monitor(ctx context.Context, panicHandler StagePanicH m.watchErr = panicHandler(p) } }() - m.watch(ctx, m.stage) + m.watch(ctx) }() } diff --git a/pipe/memorylimit_panic_test.go b/pipe/memorylimit_panic_test.go index a591084..1731882 100644 --- a/pipe/memorylimit_panic_test.go +++ b/pipe/memorylimit_panic_test.go @@ -33,7 +33,7 @@ func (fakeLimitableStage) Kill(error) {} func panickingWatchStage() *memoryWatchStage { return &memoryWatchStage{ stage: fakeLimitableStage{}, - watch: func(context.Context, LimitableStage) { panic(memWatchPanicSentinel) }, + watch: func(context.Context) { panic(memWatchPanicSentinel) }, } } @@ -151,7 +151,7 @@ func TestMemoryLimitKillsEvenIfEventHandlerPanics(t *testing.T) { ms := &memoryWatchStage{ stage: stage, watch: (&memoryWatchConfig{limit: &limit}).watchFunc( - func(*Event) { panic(memWatchPanicSentinel) }, + stage, func(*Event) { panic(memWatchPanicSentinel) }, ), } opts := StartOptions{ From d49ce65286ac2097b8e08adb08d05e4269df3f99 Mon Sep 17 00:00:00 2001 From: Michael Haggerty Date: Tue, 2 Jun 2026 12:20:32 +0200 Subject: [PATCH 07/11] memoryWatchConfig: subsume type into `memoryWatcher` --- pipe/memorylimit.go | 56 +++++++++++++--------------------- pipe/memorylimit_panic_test.go | 10 ++++-- 2 files changed, 29 insertions(+), 37 deletions(-) diff --git a/pipe/memorylimit.go b/pipe/memorylimit.go index b72c4cc..c894b28 100644 --- a/pipe/memorylimit.go +++ b/pipe/memorylimit.go @@ -25,26 +25,21 @@ type LimitableStage interface { } // MemoryWatchOption configures a MemoryWatch stage. -type MemoryWatchOption func(*memoryWatchConfig) - -type memoryWatchConfig struct { - limit *uint64 // non-nil enables kill-at-limit - observe bool // log peak RSS when the stage exits -} +type MemoryWatchOption func(*memoryWatcher) // WithMemoryLimit makes MemoryWatch kill the stage when its RSS exceeds // byteLimit. func WithMemoryLimit(byteLimit uint64) MemoryWatchOption { - return func(c *memoryWatchConfig) { - c.limit = &byteLimit + return func(mw *memoryWatcher) { + mw.limit = &byteLimit } } // WithPeakUsageLogging makes MemoryWatch log the peak RSS when the stage // exits. func WithPeakUsageLogging() MemoryWatchOption { - return func(c *memoryWatchConfig) { - c.observe = true + return func(mw *memoryWatcher) { + mw.observe = true } } @@ -59,11 +54,6 @@ func WithPeakUsageLogging() MemoryWatchOption { // StartOptions.PanicHandler and the stage keeps running unmonitored; see // StartOptions.PanicHandler. func MemoryWatch(stage Stage, eventHandler func(e *Event), opts ...MemoryWatchOption) Stage { - var cfg memoryWatchConfig - for _, opt := range opts { - opt(&cfg) - } - limitableStage, ok := stage.(LimitableStage) if !ok { eventHandler(&Event{ @@ -74,7 +64,15 @@ func MemoryWatch(stage Stage, eventHandler func(e *Event), opts ...MemoryWatchOp return stage } - if cfg.limit == nil && !cfg.observe { + mw := memoryWatcher{ + stage: limitableStage, + eventHandler: eventHandler, + } + for _, opt := range opts { + opt(&mw) + } + + if mw.limit == nil && !mw.observe { eventHandler(&Event{ Command: stage.Name(), Msg: "invalid pipe.MemoryWatch usage", @@ -86,34 +84,24 @@ func MemoryWatch(stage Stage, eventHandler func(e *Event), opts ...MemoryWatchOp } nameSuffix := "" - if cfg.limit != nil { + if mw.limit != nil { nameSuffix = " with memory limit" } return &memoryWatchStage{ nameSuffix: nameSuffix, stage: limitableStage, - watch: cfg.watchFunc(limitableStage, eventHandler), + watch: mw.watch, } } -func (c *memoryWatchConfig) watchFunc( - stage LimitableStage, eventHandler func(e *Event), -) memoryWatchFunc { - mw := memoryWatcher{ - cfg: c, - stage: stage, - eventHandler: eventHandler, - } - - return mw.watch -} - type memoryWatcher struct { - cfg *memoryWatchConfig stage LimitableStage eventHandler func(e *Event) + limit *uint64 // non-nil enables kill-at-limit + observe bool // log peak RSS when the stage exits + maxRSS uint64 samples int errCount int @@ -140,7 +128,7 @@ watchLoop: t.Stop() - if mw.cfg.observe { + if mw.observe { <-ctx.Done() mw.reportPeakUsage() } @@ -161,7 +149,7 @@ func (mw *memoryWatcher) update(ctx context.Context) bool { mw.maxRSS = rss } - if mw.cfg.limit != nil && rss >= *mw.cfg.limit { + if mw.limit != nil && rss >= *mw.limit { mw.killStage(rss) return true } @@ -198,7 +186,7 @@ func (mw *memoryWatcher) killStage(rss uint64) { Msg: "stage exceeded allowed memory use", Err: fmt.Errorf("stage exceeded allowed memory use"), Context: map[string]any{ - "limit": *mw.cfg.limit, + "limit": *mw.limit, "used": rss, }, }) diff --git a/pipe/memorylimit_panic_test.go b/pipe/memorylimit_panic_test.go index 1731882..946d7b9 100644 --- a/pipe/memorylimit_panic_test.go +++ b/pipe/memorylimit_panic_test.go @@ -148,11 +148,15 @@ func (s *killTrackingStage) Kill(error) { func TestMemoryLimitKillsEvenIfEventHandlerPanics(t *testing.T) { stage := newKillTrackingStage() limit := uint64(1) + eventHandler := func(*Event) { panic(memWatchPanicSentinel) } + mw := memoryWatcher{ + stage: stage, + eventHandler: eventHandler, + limit: &limit, + } ms := &memoryWatchStage{ stage: stage, - watch: (&memoryWatchConfig{limit: &limit}).watchFunc( - stage, func(*Event) { panic(memWatchPanicSentinel) }, - ), + watch: mw.watch, } opts := StartOptions{ PanicHandler: func(p any) error { return fmt.Errorf("recovered: %v", p) }, From 0e8133e37ad5dc68c35007ae9a19159f09abf533 Mon Sep 17 00:00:00 2001 From: Michael Haggerty Date: Tue, 2 Jun 2026 12:33:41 +0200 Subject: [PATCH 08/11] memoryWatchStage.monitor(): only recover if there's a panic handler --- pipe/memorylimit.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pipe/memorylimit.go b/pipe/memorylimit.go index c894b28..7277dfb 100644 --- a/pipe/memorylimit.go +++ b/pipe/memorylimit.go @@ -249,14 +249,15 @@ func (m *memoryWatchStage) monitor(ctx context.Context, panicHandler StagePanicH go func() { defer m.wg.Done() - defer func() { - if p := recover(); p != nil { - if panicHandler == nil { - panic(p) + + if panicHandler != nil { + defer func() { + if p := recover(); p != nil { + m.watchErr = panicHandler(p) } - m.watchErr = panicHandler(p) - } - }() + }() + } + m.watch(ctx) }() } From 939002cfe4284288fa11f4dc5f2a37ccb58f70aa Mon Sep 17 00:00:00 2001 From: Jason Lunz Date: Tue, 2 Jun 2026 21:59:59 +0200 Subject: [PATCH 09/11] memoryWatcher.watch(): defer Ticker.Stop() Don't leak a goroutine on panic/recover --- pipe/memorylimit.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pipe/memorylimit.go b/pipe/memorylimit.go index 7277dfb..9e4febe 100644 --- a/pipe/memorylimit.go +++ b/pipe/memorylimit.go @@ -112,6 +112,7 @@ type memoryWatcher struct { // specified `stage`. func (mw *memoryWatcher) watch(ctx context.Context) { t := time.NewTicker(memoryPollInterval) + defer t.Stop() watchLoop: for { @@ -126,8 +127,6 @@ watchLoop: } } - t.Stop() - if mw.observe { <-ctx.Done() mw.reportPeakUsage() From 099bd875a438514c654c3be725260e5d51aa0f93 Mon Sep 17 00:00:00 2001 From: Jason Lunz Date: Tue, 2 Jun 2026 19:45:53 +0200 Subject: [PATCH 10/11] Test memory-watch panic handling via a booby-trapped GetRSSAnon Per mhagger's review suggestion on #51: rather than injecting a synthetic `watch` closure that panics, make `fakeLimitableStage.GetRSSAnon()` itself panic, and drive the test through the real `MemoryWatch` constructor and monitor goroutine. This exercises the genuine panic-propagation path. Co-authored-by: Michael Haggerty Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pipe/memorylimit_panic_test.go | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/pipe/memorylimit_panic_test.go b/pipe/memorylimit_panic_test.go index 946d7b9..ebc480f 100644 --- a/pipe/memorylimit_panic_test.go +++ b/pipe/memorylimit_panic_test.go @@ -14,10 +14,12 @@ import ( const memWatchPanicSentinel = "memwatch-panic-sentinel" const memWatchPanicChildEnv = "GO_PIPE_MEMWATCH_PANIC_CHILD" -// fakeLimitableStage is a minimal LimitableStage whose Wait returns -// immediately, letting a memoryWatchStage test exercise its watch -// goroutine in isolation. -type fakeLimitableStage struct{} +// fakeLimitableStage is a minimal LimitableStage whose `GetRSSAnon()` +// method panics, and whose `Wait()` method returns after that panic +// has been issued. +type fakeLimitableStage struct { + done chan struct{} +} func (fakeLimitableStage) Name() string { return "fake" } func (fakeLimitableStage) Preferences() StagePreferences { return StagePreferences{} } @@ -26,15 +28,23 @@ func (fakeLimitableStage) Start( ) error { return nil } -func (fakeLimitableStage) Wait() error { return nil } -func (fakeLimitableStage) GetRSSAnon(context.Context) (uint64, error) { return 0, nil } -func (fakeLimitableStage) Kill(error) {} +func (stage fakeLimitableStage) Wait() error { + <-stage.done + return nil +} + +func (stage fakeLimitableStage) GetRSSAnon(context.Context) (uint64, error) { + close(stage.done) + panic(memWatchPanicSentinel) +} + +func (fakeLimitableStage) Kill(error) {} -func panickingWatchStage() *memoryWatchStage { - return &memoryWatchStage{ - stage: fakeLimitableStage{}, - watch: func(context.Context) { panic(memWatchPanicSentinel) }, +func panickingWatchStage() Stage { + stage := fakeLimitableStage{ + done: make(chan struct{}), } + return MemoryWatch(stage, func(*Event) {}, WithMemoryLimit(1)) } // TestMemoryWatchStagePanicWithHandlerSurfaced verifies that a panic From 9967cc9f5c2bea1b42bd72aa1989539596093b88 Mon Sep 17 00:00:00 2001 From: Jason Lunz Date: Tue, 2 Jun 2026 19:52:59 +0200 Subject: [PATCH 11/11] Test no-handler panic propagation without a subprocess Per mhagger's review suggestion on #51: replace the re-exec'd subprocess test (and its child-process scaffolding) with a synchronous assert.PanicsWithValue. A panic in the monitor goroutine can't be observed from the test goroutine, so we drive the watch loop directly. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pipe/memorylimit_panic_test.go | 58 ++++++++++------------------------ 1 file changed, 17 insertions(+), 41 deletions(-) diff --git a/pipe/memorylimit_panic_test.go b/pipe/memorylimit_panic_test.go index ebc480f..07b058a 100644 --- a/pipe/memorylimit_panic_test.go +++ b/pipe/memorylimit_panic_test.go @@ -4,15 +4,14 @@ import ( "context" "fmt" "io" - "os" - "os/exec" "strings" "testing" "time" + + "github.com/stretchr/testify/assert" ) const memWatchPanicSentinel = "memwatch-panic-sentinel" -const memWatchPanicChildEnv = "GO_PIPE_MEMWATCH_PANIC_CHILD" // fakeLimitableStage is a minimal LimitableStage whose `GetRSSAnon()` // method panics, and whose `Wait()` method returns after that panic @@ -70,47 +69,24 @@ func TestMemoryWatchStagePanicWithHandlerSurfaced(t *testing.T) { } } -// TestMemoryWatchStagePanicWithoutHandlerPropagates verifies that when -// the memory-watch goroutine panics and no panic handler is installed, -// the panic propagates (crashing the process) rather than being -// silently swallowed. Because that would crash the test binary, the -// scenario runs in a re-exec'd subprocess. +// TestMemoryWatchStagePanicWithoutHandlerPropagates verifies that the +// memory-watch sampling path does not swallow a panic. The monitor +// goroutine only installs a recover when a handler is present (see +// memoryWatchStage.monitor), so we exercise update() directly and assert +// that it propagates the panic. update() is used rather than watch() so the +// assertion is synchronous and ticker-independent: a regression that stopped +// the panic would fail the test rather than hang on the ticker loop. func TestMemoryWatchStagePanicWithoutHandlerPropagates(t *testing.T) { - if os.Getenv(memWatchPanicChildEnv) == "1" { - runMemWatchPanicChild() - return - } - - cmd := exec.Command(os.Args[0], "-test.run=^TestMemoryWatchStagePanicWithoutHandlerPropagates$", "-test.v") //nolint:gosec // re-exec of this test binary with constant arguments. - cmd.Env = append(os.Environ(), memWatchPanicChildEnv+"=1") - out, err := cmd.CombinedOutput() - output := string(out) - - if err == nil { - t.Fatalf("expected subprocess to crash from a propagated panic, but it exited 0\noutput:\n%s", output) - } - if strings.Contains(output, "SURVIVED") { - t.Fatalf("panic was swallowed: Wait returned instead of propagating\noutput:\n%s", output) - } - if !strings.Contains(output, "panic:") || !strings.Contains(output, memWatchPanicSentinel) { - t.Fatalf("expected a propagated panic mentioning %q, got:\n%s", memWatchPanicSentinel, output) - } -} - -func runMemWatchPanicChild() { - ms := panickingWatchStage() - - if err := ms.Start(context.Background(), Env{}, nil, nil, StartOptions{}); err != nil { - os.Stdout.WriteString("SURVIVED: Start returned err=" + err.Error() + "\n") - os.Exit(0) + limit := uint64(1) + mw := memoryWatcher{ + stage: fakeLimitableStage{done: make(chan struct{})}, + eventHandler: func(*Event) {}, + limit: &limit, } - _ = ms.Wait() - - // Reaching this point at all indicates the panic was swallowed. - time.Sleep(2 * time.Second) - os.Stdout.WriteString("SURVIVED: Wait returned\n") - os.Exit(0) + assert.PanicsWithValue(t, memWatchPanicSentinel, func() { + mw.update(context.Background()) + }) } // killTrackingStage is a LimitableStage that reports an over-limit RSS