Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 39 additions & 13 deletions livekit-agents/livekit/agents/voice/amd/detector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import os
from types import TracebackType
from typing import TYPE_CHECKING, Literal, TypedDict

Expand All @@ -16,6 +17,7 @@
from ...telemetry import trace_types, tracer
from ...types import NOT_GIVEN, NotGivenOr
from ...utils import EventEmitter, aio, is_given
from ...utils.misc import is_cloud
from ...utils.participant import wait_for_track_publication
from .classifier import (
AMD_PROMPT,
Expand All @@ -35,7 +37,7 @@
from ..agent_session import AgentSession

EVALUATED_LLM_MODELS: set[str] = {
"google/gemini-3.1-flash-lite-preview",
"google/gemini-3.1-flash-lite",
"google/gemini-3-flash-preview",
"openai/gpt-4.1",
"openai/gpt-5.2",
Expand Down Expand Up @@ -103,18 +105,24 @@ class AMD(EventEmitter[Literal["amd_prediction"]]):
session: The :class:`AgentSession` to wire AMD to.
llm: LLM used for greeting classification. Accepts an :class:`LLM`
instance or an inference model string (e.g.
``"openai/gpt-4.1-mini"``). Omit to fall back to the session's
own LLM.
``"openai/gpt-4.1-mini"``). When omitted, AMD auto-selects:
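if ``LIVEKIT_URL`` passes the ``is_cloud`` check and an API key and
secret are set (``LIVEKIT_INFERENCE_API_KEY`` /
``LIVEKIT_INFERENCE_API_SECRET``, falling back to ``LIVEKIT_API_KEY``
/ ``LIVEKIT_API_SECRET``), it uses ``"google/gemini-3.1-flash-lite"``
via the inference gateway; otherwise it falls back to the session's
own LLM.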
interrupt_on_machine: If ``True`` (default), interrupt any pending
agent speech immediately when a machine is detected.
ivr_detection: If ``True`` (default), automatically start IVR
navigation when a ``machine-ivr`` result is returned.
participant_identity: If set, only this participant's audio track
subscription triggers the detection timers. If omitted, the first
remote audio track wins.
stt: STT used for transcript generation. Required when the session
uses no STT (e.g. a realtime model). Omit to reuse the session's
STT transcripts.
stt: STT used for transcript generation. Accepts an :class:`STT`
instance or an inference model string (e.g.
``"cartesia/ink-whisper"``). When omitted, AMD auto-selects:
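if ``LIVEKIT_URL`` passes the ``is_cloud`` check and an API key and
secret are set (``LIVEKIT_INFERENCE_API_KEY`` /
``LIVEKIT_INFERENCE_API_SECRET``, falling back to ``LIVEKIT_API_KEY``
/ ``LIVEKIT_API_SECRET``), it uses ``"cartesia/ink-whisper"`` via the
inference gateway; otherwise it reuses the session's existing STT
transcripts.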
suppress_compatibility_warning: If ``True``, do not log a warning when
the resolved STT or LLM is not among the bundled AMD-tested model
strings. Has no effect on classification behavior.
Expand All @@ -123,19 +131,36 @@ class AMD(EventEmitter[Literal["amd_prediction"]]):
omitted, library defaults apply.
"""

_DEFAULT_LLM_MODEL: str = "google/gemini-3.1-flash-lite"
_DEFAULT_STT_MODEL: str = "cartesia/ink-whisper"

def __init__(
self,
session: AgentSession,
*,
llm: NotGivenOr[LLM | LLMModels | str] = "google/gemini-3.1-flash-lite-preview",
stt: NotGivenOr[STT | str] = "cartesia/ink-whisper",
llm: NotGivenOr[LLM | LLMModels | str] = NOT_GIVEN,
stt: NotGivenOr[STT | str] = NOT_GIVEN,
interrupt_on_machine: bool = True,
ivr_detection: bool = True,
participant_identity: NotGivenOr[str] = NOT_GIVEN,
suppress_compatibility_warning: bool = False,
detection_options: NotGivenOr[DetectionOptions] = NOT_GIVEN,
) -> None:
super().__init__()

if not is_given(llm) or not is_given(stt):
api_key = os.getenv("LIVEKIT_INFERENCE_API_KEY") or os.getenv("LIVEKIT_API_KEY")
api_secret = os.getenv("LIVEKIT_INFERENCE_API_SECRET") or os.getenv(
"LIVEKIT_API_SECRET"
)
auto_select = (
is_cloud(os.getenv("LIVEKIT_URL", "")) and bool(api_key) and bool(api_secret)
)
if not is_given(llm):
llm = self._DEFAULT_LLM_MODEL if auto_select else NOT_GIVEN
if not is_given(stt):
stt = self._DEFAULT_STT_MODEL if auto_select else NOT_GIVEN

self._llm_config: NotGivenOr[LLM | LLMModels | str] = llm
self._session: AgentSession = session
self._interrupt_on_machine = interrupt_on_machine
Expand All @@ -156,11 +181,12 @@ def __init__(
)

if not self._suppress_compatibility_warning:
_warn_if_not_evaluated(
self._stt.model if is_given(self._stt) else None,
EVALUATED_STT_MODELS,
model_kind="stt",
)
if is_given(self._stt):
_warn_if_not_evaluated(
self._stt.model,
EVALUATED_STT_MODELS,
model_kind="stt",
)

self._stt_task: asyncio.Task[None] | None = None
self._audio_ch: aio.Chan[rtc.AudioFrame] | None = None
Expand Down
Loading