Skip to content

Commit e2e0fe9

Browse files
authored
chore(amd): update default models and default behavior (#5713)
1 parent 3078382 commit e2e0fe9

1 file changed

Lines changed: 39 additions & 13 deletions

File tree

livekit-agents/livekit/agents/voice/amd/detector.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import os
45
from types import TracebackType
56
from typing import TYPE_CHECKING, Literal, TypedDict
67

@@ -16,6 +17,7 @@
1617
from ...telemetry import trace_types, tracer
1718
from ...types import NOT_GIVEN, NotGivenOr
1819
from ...utils import EventEmitter, aio, is_given
20+
from ...utils.misc import is_cloud
1921
from ...utils.participant import wait_for_track_publication
2022
from .classifier import (
2123
AMD_PROMPT,
@@ -35,7 +37,7 @@
3537
from ..agent_session import AgentSession
3638

3739
EVALUATED_LLM_MODELS: set[str] = {
38-
"google/gemini-3.1-flash-lite-preview",
40+
"google/gemini-3.1-flash-lite",
3941
"google/gemini-3-flash-preview",
4042
"openai/gpt-4.1",
4143
"openai/gpt-5.2",
@@ -103,18 +105,24 @@ class AMD(EventEmitter[Literal["amd_prediction"]]):
103105
session: The :class:`AgentSession` to wire AMD to.
104106
llm: LLM used for greeting classification. Accepts an :class:`LLM`
105107
instance or an inference model string (e.g.
106-
``"openai/gpt-4.1-mini"``). Omit to fall back to the session's
107-
own LLM.
108+
``"openai/gpt-4.1-mini"``). When omitted, AMD auto-selects:
109+
if LiveKit inference credentials are available in the environment
110+
it uses ``"google/gemini-3.1-flash-lite"`` via the
111+
inference gateway; otherwise it falls back to the session's own
112+
LLM.
108113
interrupt_on_machine: If ``True`` (default), interrupt any pending
109114
agent speech immediately when a machine is detected.
110115
ivr_detection: If ``True`` (default), automatically start IVR
111116
navigation when a ``machine-ivr`` result is returned.
112117
participant_identity: If set, only this participant's audio track
113118
subscription triggers the detection timers. If omitted, the first
114119
remote audio track wins.
115-
stt: STT used for transcript generation. Required when the session
116-
uses no STT (e.g. a realtime model). Omit to reuse the session's
117-
STT transcripts.
120+
stt: STT used for transcript generation. Accepts an :class:`STT`
121+
instance or an inference model string (e.g.
122+
``"cartesia/ink-whisper"``). When omitted, AMD auto-selects:
123+
if LiveKit inference credentials are available it uses
124+
``"cartesia/ink-whisper"`` via the inference gateway; otherwise
125+
it reuses the session's existing STT transcripts.
118126
suppress_compatibility_warning: If ``True``, do not log a warning when
119127
the resolved STT or LLM is not among the bundled AMD-tested model
120128
strings. Has no effect on classification behavior.
@@ -123,19 +131,36 @@ class AMD(EventEmitter[Literal["amd_prediction"]]):
123131
omitted, library defaults apply.
124132
"""
125133

134+
_DEFAULT_LLM_MODEL: str = "google/gemini-3.1-flash-lite"
135+
_DEFAULT_STT_MODEL: str = "cartesia/ink-whisper"
136+
126137
def __init__(
127138
self,
128139
session: AgentSession,
129140
*,
130-
llm: NotGivenOr[LLM | LLMModels | str] = "google/gemini-3.1-flash-lite-preview",
131-
stt: NotGivenOr[STT | str] = "cartesia/ink-whisper",
141+
llm: NotGivenOr[LLM | LLMModels | str] = NOT_GIVEN,
142+
stt: NotGivenOr[STT | str] = NOT_GIVEN,
132143
interrupt_on_machine: bool = True,
133144
ivr_detection: bool = True,
134145
participant_identity: NotGivenOr[str] = NOT_GIVEN,
135146
suppress_compatibility_warning: bool = False,
136147
detection_options: NotGivenOr[DetectionOptions] = NOT_GIVEN,
137148
) -> None:
138149
super().__init__()
150+
151+
if not is_given(llm) or not is_given(stt):
152+
api_key = os.getenv("LIVEKIT_INFERENCE_API_KEY") or os.getenv("LIVEKIT_API_KEY")
153+
api_secret = os.getenv("LIVEKIT_INFERENCE_API_SECRET") or os.getenv(
154+
"LIVEKIT_API_SECRET"
155+
)
156+
auto_select = (
157+
is_cloud(os.getenv("LIVEKIT_URL", "")) and bool(api_key) and bool(api_secret)
158+
)
159+
if not is_given(llm):
160+
llm = self._DEFAULT_LLM_MODEL if auto_select else NOT_GIVEN
161+
if not is_given(stt):
162+
stt = self._DEFAULT_STT_MODEL if auto_select else NOT_GIVEN
163+
139164
self._llm_config: NotGivenOr[LLM | LLMModels | str] = llm
140165
self._session: AgentSession = session
141166
self._interrupt_on_machine = interrupt_on_machine
@@ -156,11 +181,12 @@ def __init__(
156181
)
157182

158183
if not self._suppress_compatibility_warning:
159-
_warn_if_not_evaluated(
160-
self._stt.model if is_given(self._stt) else None,
161-
EVALUATED_STT_MODELS,
162-
model_kind="stt",
163-
)
184+
if is_given(self._stt):
185+
_warn_if_not_evaluated(
186+
self._stt.model,
187+
EVALUATED_STT_MODELS,
188+
model_kind="stt",
189+
)
164190

165191
self._stt_task: asyncio.Task[None] | None = None
166192
self._audio_ch: aio.Chan[rtc.AudioFrame] | None = None

0 commit comments

Comments (0)