11from __future__ import annotations
22
33import asyncio
4+ import os
45from types import TracebackType
56from typing import TYPE_CHECKING , Literal , TypedDict
67
1617from ...telemetry import trace_types , tracer
1718from ...types import NOT_GIVEN , NotGivenOr
1819from ...utils import EventEmitter , aio , is_given
20+ from ...utils .misc import is_cloud
1921from ...utils .participant import wait_for_track_publication
2022from .classifier import (
2123 AMD_PROMPT ,
3537 from ..agent_session import AgentSession
3638
3739EVALUATED_LLM_MODELS : set [str ] = {
38- "google/gemini-3.1-flash-lite-preview " ,
40+ "google/gemini-3.1-flash-lite" ,
3941 "google/gemini-3-flash-preview" ,
4042 "openai/gpt-4.1" ,
4143 "openai/gpt-5.2" ,
@@ -103,18 +105,24 @@ class AMD(EventEmitter[Literal["amd_prediction"]]):
103105 session: The :class:`AgentSession` to wire AMD to.
104106 llm: LLM used for greeting classification. Accepts an :class:`LLM`
105107 instance or an inference model string (e.g.
106- ``"openai/gpt-4.1-mini"``). Omit to fall back to the session's
107- own LLM.
108+ ``"openai/gpt-4.1-mini"``). When omitted, AMD auto-selects:
109+ if LiveKit inference credentials are available in the environment
110+ it uses ``"google/gemini-3.1-flash-lite"`` via the
111+ inference gateway; otherwise it falls back to the session's own
112+ LLM.
108113 interrupt_on_machine: If ``True`` (default), interrupt any pending
109114 agent speech immediately when a machine is detected.
110115 ivr_detection: If ``True`` (default), automatically start IVR
111116 navigation when a ``machine-ivr`` result is returned.
112117 participant_identity: If set, only this participant's audio track
113118 subscription triggers the detection timers. If omitted, the first
114119 remote audio track wins.
115- stt: STT used for transcript generation. Required when the session
116- uses no STT (e.g. a realtime model). Omit to reuse the session's
117- STT transcripts.
120+ stt: STT used for transcript generation. Accepts an :class:`STT`
121+ instance or an inference model string (e.g.
122+ ``"cartesia/ink-whisper"``). When omitted, AMD auto-selects:
123+ if LiveKit inference credentials are available it uses
124+ ``"cartesia/ink-whisper"`` via the inference gateway; otherwise
125+ it reuses the session's existing STT transcripts.
118126 suppress_compatibility_warning: If ``True``, do not log a warning when
119127 the resolved STT or LLM is not among the bundled AMD-tested model
120128 strings. Has no effect on classification behavior.
@@ -123,19 +131,36 @@ class AMD(EventEmitter[Literal["amd_prediction"]]):
123131 omitted, library defaults apply.
124132 """
125133
134+ _DEFAULT_LLM_MODEL : str = "google/gemini-3.1-flash-lite"
135+ _DEFAULT_STT_MODEL : str = "cartesia/ink-whisper"
136+
126137 def __init__ (
127138 self ,
128139 session : AgentSession ,
129140 * ,
130- llm : NotGivenOr [LLM | LLMModels | str ] = "google/gemini-3.1-flash-lite-preview" ,
131- stt : NotGivenOr [STT | str ] = "cartesia/ink-whisper" ,
141+ llm : NotGivenOr [LLM | LLMModels | str ] = NOT_GIVEN ,
142+ stt : NotGivenOr [STT | str ] = NOT_GIVEN ,
132143 interrupt_on_machine : bool = True ,
133144 ivr_detection : bool = True ,
134145 participant_identity : NotGivenOr [str ] = NOT_GIVEN ,
135146 suppress_compatibility_warning : bool = False ,
136147 detection_options : NotGivenOr [DetectionOptions ] = NOT_GIVEN ,
137148 ) -> None :
138149 super ().__init__ ()
150+
151+ if not is_given (llm ) or not is_given (stt ):
152+ api_key = os .getenv ("LIVEKIT_INFERENCE_API_KEY" ) or os .getenv ("LIVEKIT_API_KEY" )
153+ api_secret = os .getenv ("LIVEKIT_INFERENCE_API_SECRET" ) or os .getenv (
154+ "LIVEKIT_API_SECRET"
155+ )
156+ auto_select = (
157+ is_cloud (os .getenv ("LIVEKIT_URL" , "" )) and bool (api_key ) and bool (api_secret )
158+ )
159+ if not is_given (llm ):
160+ llm = self ._DEFAULT_LLM_MODEL if auto_select else NOT_GIVEN
161+ if not is_given (stt ):
162+ stt = self ._DEFAULT_STT_MODEL if auto_select else NOT_GIVEN
163+
139164 self ._llm_config : NotGivenOr [LLM | LLMModels | str ] = llm
140165 self ._session : AgentSession = session
141166 self ._interrupt_on_machine = interrupt_on_machine
@@ -156,11 +181,12 @@ def __init__(
156181 )
157182
158183 if not self ._suppress_compatibility_warning :
159- _warn_if_not_evaluated (
160- self ._stt .model if is_given (self ._stt ) else None ,
161- EVALUATED_STT_MODELS ,
162- model_kind = "stt" ,
163- )
184+ if is_given (self ._stt ):
185+ _warn_if_not_evaluated (
186+ self ._stt .model ,
187+ EVALUATED_STT_MODELS ,
188+ model_kind = "stt" ,
189+ )
164190
165191 self ._stt_task : asyncio .Task [None ] | None = None
166192 self ._audio_ch : aio .Chan [rtc .AudioFrame ] | None = None
0 commit comments