Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions livekit-agents/livekit/agents/inference/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,8 @@ def __init__(
conn_options: APIConnectOptions,
) -> None:
super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
self._stt: STT = stt
self._opts = opts
self._session = stt._ensure_session()
self._request_id = str(utils.shortuuid("stt_request_"))

self._speaking = False
Expand Down Expand Up @@ -588,6 +588,7 @@ async def _send_session_update(self, msg: dict[str, Any]) -> None:
async def _run(self) -> None:
"""Main loop for streaming transcription."""
closing_ws = False
http_session = self._stt._ensure_session()

@utils.log_exceptions(logger=logger)
async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
Expand Down Expand Up @@ -632,7 +633,7 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
):
if closing_ws or self._session.closed:
if closing_ws or http_session.closed:
return
raise APIStatusError(
message="LiveKit Inference STT connection closed unexpectedly"
Expand Down Expand Up @@ -665,7 +666,7 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:

ws: aiohttp.ClientWebSocketResponse | None = None
try:
ws = await self._connect_ws()
ws = await self._connect_ws(http_session)
self._ws = ws
tasks = [
asyncio.create_task(send_task(ws)),
Expand All @@ -680,7 +681,9 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
if ws is not None:
await ws.close()

async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
async def _connect_ws(
self, http_session: aiohttp.ClientSession
) -> aiohttp.ClientWebSocketResponse:
"""Connect to the LiveKit Inference STT WebSocket."""
params: dict[str, Any] = {
"settings": {
Expand Down Expand Up @@ -718,7 +721,7 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
}
try:
ws = await asyncio.wait_for(
self._session.ws_connect(
http_session.ws_connect(
f"{base_url}/stt?model={self._opts.model}", headers=headers
),
self._conn_options.timeout,
Expand Down
36 changes: 34 additions & 2 deletions livekit-agents/livekit/agents/utils/http_context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import contextlib
import contextvars
from collections.abc import Callable
from collections.abc import AsyncIterator, Callable

import aiohttp

Expand Down Expand Up @@ -45,7 +46,9 @@ def http_session() -> aiohttp.ClientSession:
val = _ContextVar.get(None)
if val is None:
raise RuntimeError(
"Attempted to use an http session outside of a job context. This is probably because you are trying to use a plugin without using the agent worker api. You may need to create your own aiohttp.ClientSession, pass it into the plugin constructor as a kwarg, and manage its lifecycle." # noqa: E501
"Attempted to use an http session outside of a job context. This is probably because you are trying to use a plugin without using the agent worker api. " # noqa: E501
"If you're running plugins outside the agent worker (e.g. tests or scripts), wrap your code with `async with livekit.agents.utils.http_context.open(): ...`. " # noqa: E501
"Alternatively, create your own aiohttp.ClientSession, pass it into the plugin constructor as a kwarg, and manage its lifecycle." # noqa: E501
)

return val()
Expand All @@ -57,3 +60,32 @@ async def _close_http_ctx() -> None:
logger.debug("http_session(): closing the httpclient ctx")
await val().close()
_ContextVar.set(None)


@contextlib.asynccontextmanager
async def open() -> AsyncIterator[aiohttp.ClientSession]: # noqa: A001
"""Bind a process-local aiohttp.ClientSession to the current asyncio context.

Use this when running plugins outside a job worker (e.g. tests, scripts,
notebooks) so that ``http_session()`` returns a usable session inside the
``async with`` block. The session is closed and the context is reset on exit.

If an http session context is already bound (nested call, or already set up
by the worker), this is a no-op pass-through — the existing session is
yielded and left untouched on exit.

Example::

async with utils.http_context.open():
async with AgentSession() as session:
await session.start(MyAgent())
"""
if _ContextVar.get(None) is not None:
yield _ContextVar.get()() # type: ignore[misc]
return

factory = _new_session_ctx()
try:
yield factory()
finally:
await _close_http_ctx()
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ def __init__(
conn_options: APIConnectOptions,
) -> None:
super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
self._stt: STT = stt
self._opts = opts
self._session = stt._ensure_session()
self._request_id = str(uuid.uuid4())
self._reconnect_event = asyncio.Event()
self._speaking = False
Expand All @@ -230,6 +230,7 @@ def update_options(
async def _run(self) -> None:
"""Main loop for streaming transcription."""
closing_ws = False
http_session = self._stt._ensure_session()

async def keepalive_task(ws: aiohttp.ClientWebSocketResponse) -> None:
try:
Expand Down Expand Up @@ -275,7 +276,7 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
):
if closing_ws or self._session.closed:
if closing_ws or http_session.closed:
return
raise APIStatusError(message="Cartesia STT connection closed unexpectedly")

Expand All @@ -292,7 +293,7 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:

while True:
try:
ws = await self._connect_ws()
ws = await self._connect_ws(http_session)
tasks = [
asyncio.create_task(send_task(ws)),
asyncio.create_task(recv_task(ws)),
Expand Down Expand Up @@ -323,7 +324,9 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
if ws is not None:
await ws.close()

async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
async def _connect_ws(
self, http_session: aiohttp.ClientSession
) -> aiohttp.ClientWebSocketResponse:
"""Connect to the Cartesia STT WebSocket."""
params = {
"model": self._opts.model,
Expand All @@ -342,7 +345,7 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:

try:
ws = await asyncio.wait_for(
self._session.ws_connect(
http_session.ws_connect(
ws_url,
headers={
"User-Agent": USER_AGENT,
Expand Down
3 changes: 2 additions & 1 deletion makefile
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ unit-tests:
tests/test_tool_proxy.py \
tests/test_endpointing.py \
tests/test_amd_classifier.py \
tests/test_session_host.py
tests/test_session_host.py \
tests/test_http_context_helper.py

# ============================================
# Development Workflows
Expand Down
94 changes: 94 additions & 0 deletions tests/test_http_context_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Tests for the public `utils.http_context.open()` helper and for the
inference STT error surface when called outside a job context.
"""

from __future__ import annotations

import asyncio

import aiohttp
import pytest

from livekit.agents import inference
from livekit.agents.utils import http_context


async def test_open_yields_working_session_and_closes_on_exit() -> None:
with pytest.raises(RuntimeError):
http_context.http_session()

async with http_context.open() as session:
assert isinstance(session, aiohttp.ClientSession)
assert not session.closed
# http_session() returns the same instance inside the block
assert http_context.http_session() is session

assert session.closed
with pytest.raises(RuntimeError):
http_context.http_session()


async def test_open_is_reentrant_inner_does_not_close_outer() -> None:
async with http_context.open() as outer:
async with http_context.open() as inner:
# nested open() reuses the outer session — does not create a new one
assert inner is outer

# outer session is untouched after inner exits
assert not outer.closed
assert http_context.http_session() is outer

assert outer.closed


async def test_open_isolated_per_task() -> None:
"""Each asyncio.Task gets its own http session context — they don't share."""
barrier = asyncio.Event()

async def worker() -> tuple[aiohttp.ClientSession, bool]:
async with http_context.open() as session:
await barrier.wait()
still_open = not session.closed
return session, still_open

task_a = asyncio.create_task(worker())
task_b = asyncio.create_task(worker())
await asyncio.sleep(0.01)
barrier.set()

(sess_a, a_open), (sess_b, b_open) = await asyncio.gather(task_a, task_b)
assert sess_a is not sess_b
assert a_open and b_open
assert sess_a.closed and sess_b.closed


async def test_http_session_error_message_points_to_helper() -> None:
with pytest.raises(RuntimeError) as exc_info:
http_context.http_session()
msg = str(exc_info.value)
assert "http_context.open()" in msg


async def test_inference_stt_surfaces_real_error_outside_ctx(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Regression: previously, calling `inference.STT().stream()` outside a job
context raised `AttributeError: 'SpeechStream' object has no attribute
'_session'` from a background task — masking the real "no http context" error.

After the fix, _ensure_session() runs inside _run(), so the actual
RuntimeError surfaces through the stream's main task.
"""
monkeypatch.setenv("LIVEKIT_API_KEY", "test-key")
monkeypatch.setenv("LIVEKIT_API_SECRET", "test-secret")

stt = inference.STT(model="cartesia/sonic-3")
stream = stt.stream()

# SpeechStream no longer eagerly grabs `_session` in __init__.
assert not hasattr(stream, "_session")

with pytest.raises(RuntimeError, match="http_context.open"):
await stream._task

await stream.aclose()