diff --git a/src/appose/python_worker.py b/src/appose/python_worker.py
index d4c2429..12c1371 100644
--- a/src/appose/python_worker.py
+++ b/src/appose/python_worker.py
@@ -222,8 +222,15 @@ def _process_input(self) -> None:
                 else:
                     # Create a thread and save a reference to it, in case its script
                     # kills the thread. This happens e.g. if it calls sys.exit.
-                    task._thread = Thread(target=task._run, name=f"Appose-{uuid}")
-                    task._thread.start()
+                    #
+                    # Assign task._thread only AFTER start() returns. Otherwise the
+                    # janitor (_cleanup_threads) can observe task._thread set while
+                    # the thread is not yet alive (the window between Thread()
+                    # construction and start()) and spuriously fail the task with
+                    # "thread death". See apposed/appose#15.
+                    t = Thread(target=task._run, name=f"Appose-{uuid}")
+                    t.start()
+                    task._thread = t
 
             elif request_type == RequestType.CANCEL:
                 task = self.tasks.get(uuid)
diff --git a/tests/test_service.py b/tests/test_service.py
index 5a4afcc..501fbce 100644
--- a/tests/test_service.py
+++ b/tests/test_service.py
@@ -7,6 +7,7 @@
 from appose.service import ResponseType, TaskException, TaskStatus
 from tests.test_base import execute_and_assert, maybe_debug
 from pathlib import Path
+import threading
 import time
 import os
 import re
@@ -354,3 +355,45 @@ def test_task_result_null():
 
         # result() should return None.
         assert task.result() is None
+
+
+def test_thread_death_stress():
+    """Flood the worker with concurrent tiny tasks; none may fail.
+
+    Regression test for the spurious 'thread death' race
+    (apposed/appose#15). No task here can legitimately die, so any
+    'thread death' is the bug.
+    """
+    env = appose.system()
+    n_threads = 16
+    n_tasks = 200  # per thread
+    errors = []
+    err_lock = threading.Lock()
+    submit_lock = threading.Lock()  # serialize stdin writes only
+
+    with env.python() as service:
+        maybe_debug(service)
+
+        def worker():
+            for _ in range(n_tasks):
+                # Submit inside the try as well: an exception raised by
+                # task() or start() must be recorded rather than silently
+                # killing this thread and letting the test pass.
+                try:
+                    with submit_lock:
+                        task = service.task("task.outputs['result'] = 1")
+                        task.start()
+                    task.wait_for()
+                except Exception as e:
+                    with err_lock:
+                        errors.append(f"{type(e).__name__}: {e}")
+
+        threads = [threading.Thread(target=worker) for _ in range(n_threads)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+    assert not errors, (
+        f"{len(errors)}/{n_threads * n_tasks} tasks failed; sample: {errors[:5]}"
+    )