Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/mcp/client/stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,12 @@ async def stdin_writer():
except ProcessLookupError: # pragma: no cover
# Process already exited, which is fine
pass

if process.stdout: # pragma: no branch
try:
await process.stdout.aclose()
except Exception: # pragma: no cover
pass
await read_stream.aclose()
await write_stream.aclose()
await read_stream_writer.aclose()
Expand Down
7 changes: 5 additions & 2 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import anyio
import httpx
from anyio.abc import TaskGroup
from anyio.abc import TaskGroup, TaskStatus
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
from pydantic import ValidationError

Expand Down Expand Up @@ -437,10 +437,13 @@ async def post_writer(
write_stream: ContextSendStream[SessionMessage],
start_get_stream: Callable[[], None],
tg: TaskGroup,
*,
task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED,
) -> None:
"""Handle writing requests to the server."""
try:
async with write_stream_reader, read_stream_writer, write_stream:
task_status.started(None)

async def _handle_message(session_message: SessionMessage) -> None:
message = session_message.message
Expand Down Expand Up @@ -570,7 +573,7 @@ async def streamable_http_client(
def start_get_stream() -> None:
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)

tg.start_soon(
await tg.start(
transport.post_writer,
client,
write_stream_reader,
Expand Down
40 changes: 40 additions & 0 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,46 @@ async def test_streamable_http_client_basic_connection(basic_app: Starlette) ->
assert result.server_info.name == SERVER_NAME


@pytest.mark.anyio
async def test_streamable_http_client_no_race_on_consecutive_requests(basic_app: Starlette) -> None:
"""The first request after initialize can run repeatedly without racing startup."""
for iteration in range(10): # pragma: no branch
async with (
make_client(basic_app) as http_client,
streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream),
ClientSession(read_stream, write_stream) as session,
):
await session.initialize()

tools = await session.list_tools()
assert len(tools.tools) == 8, f"Iteration {iteration}: expected 8 tools, got {len(tools.tools)}"
assert tools.tools[0].name == "test_tool"

tools2 = await session.list_tools()
assert len(tools2.tools) == 8

resource = await session.read_resource(uri="foobar://test-iteration")
assert len(resource.contents) == 1


@pytest.mark.anyio
async def test_streamable_http_client_rapid_request_sequence(basic_app: Starlette) -> None:
"""A rapid sequence of requests reuses the initialized stream reliably."""
async with (
make_client(basic_app) as http_client,
streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream),
ClientSession(read_stream, write_stream) as session,
):
await session.initialize()

for i in range(20):
tools = await session.list_tools()
assert len(tools.tools) == 8, f"Request {i}: expected 8 tools, got {len(tools.tools)}"

resource = await session.read_resource(uri="foobar://final-test")
assert len(resource.contents) == 1


@pytest.mark.anyio
async def test_streamable_http_client_resource_read(initialized_client_session: ClientSession) -> None:
"""A resource read round-trips its arguments and the handler's content."""
Expand Down
Loading