Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions examples/voice_agents/basic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self) -> None:
"with that in mind keep your responses concise and to the point."
"do not use emojis, asterisks, markdown, or other special characters in your responses."
"You are curious and friendly, and have a sense of humor."
"you will speak english to the user",
"You will speak english to the user over voice.",
tools=[EndCallTool()],
)

Expand Down Expand Up @@ -90,7 +90,7 @@ async def entrypoint(ctx: JobContext) -> None:
stt=inference.STT("deepgram/nova-3", language="multi"),
# A Large Language Model (LLM) is your agent's brain, processing user input and generating a response
# See all available models at https://docs.livekit.io/agents/models/llm/
llm=inference.LLM("openai/gpt-4.1-mini"),
llm=inference.LLM("google/gemini-3.5-flash"),
# Text-to-speech (TTS) is your agent's voice, turning the LLM's text into speech that the user can hear
# See all available models as well as voice selections at https://docs.livekit.io/agents/models/tts/
tts=inference.TTS("cartesia/sonic-3", voice="9626c31c-bec5-4cca-baa8-f8ba9e84c8bc"),
Expand All @@ -116,6 +116,11 @@ async def entrypoint(ctx: JobContext) -> None:
"filter_markdown",
text_transforms.replace({"LiveKit": "<<ˈ|l|aɪ|v>> <<ˈ|k|ɪ|t>>"}),
],
# automatically detect keyterms and apply them to the STT per user turn
keyterm_options={
"terms": ["LiveKit"],
"detection": {"enabled": True, "turn_interval": 1},
},

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For conversations where some context/user information is available before the call (e.g from a patient/customer profile loaded when starting), should we allow extracting keyterms from such context first?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps keep that on the developer side, they can pass the context like address, user name via keyterm_options={"terms": [...]} once the profile loads.

)

@session.on("metrics_collected")
Expand Down
3 changes: 3 additions & 0 deletions livekit-agents/livekit/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
AMDPredictionEvent,
)
from .voice.background_audio import AudioConfig, BackgroundAudioPlayer, BuiltinAudioClip, PlayHandle
from .voice.keyterms import KeytermDetectionOptions, KeytermOptions
from .voice.room_io import RoomInputOptions, RoomIO, RoomOutputOptions
from .voice.run_result import (
AgentHandoffEvent,
Expand Down Expand Up @@ -240,6 +241,8 @@ def __getattr__(name: str) -> typing.Any:
"InterruptionOptions",
"PreemptiveGenerationOptions",
"UserTurnLimitOptions",
"KeytermOptions",
"KeytermDetectionOptions",
"UserTurnExceededEvent",
]

Expand Down
30 changes: 30 additions & 0 deletions livekit-agents/livekit/agents/inference/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,24 @@ def _diarization_enabled(extra_kwargs: dict[str, Any] | None) -> bool:
return False


def _keyterms_extra_for_model(model: NotGivenOr[str], keyterms: list[str]) -> dict[str, Any] | None:
"""Map a provider-agnostic keyterms list onto the active provider's extra_kwargs key.

Returns None when the model does not support keyterm prompting. Called with an empty
list, it doubles as a capability check (non-None ⇒ supported). Keep every provider's
keyterm key here so capability inference and update_keyterms can't diverge.
"""
if not (is_given(model) and isinstance(model, str)):
return None
if model.startswith("deepgram/"):
return {"keyterm": list(keyterms)}
if model.startswith("assemblyai/"):
return {"keyterms_prompt": list(keyterms)}
if model.startswith("speechmatics/"):
return {"additional_vocab": [{"content": term} for term in keyterms]}
return None


STTLanguages = Literal["multi", "en", "de", "es", "fr", "ja", "pt", "zh", "hi"]


Expand Down Expand Up @@ -517,6 +535,7 @@ def __init__(
diarization=diarization_enabled,
aligned_transcript="word",
offline_recognize=False,
keyterms=_keyterms_extra_for_model(model, []) is not None,
),
)

Expand Down Expand Up @@ -634,6 +653,10 @@ def update_options(

self._opts.model = model
self._vad = _resolve_vad_for_model(model, self._vad)
self._capabilities = replace(
self._capabilities,
keyterms=_keyterms_extra_for_model(self._opts.model, []) is not None,
)
if is_given(language):
self._opts.language = LanguageCode(language)
if is_given(extra):
Expand All @@ -646,6 +669,13 @@ def update_options(
for stream in self._streams:
stream.update_options(model=model, language=language, extra=extra)

def update_keyterms(self, keyterms: list[str]) -> None:
extra = _keyterms_extra_for_model(self._opts.model, keyterms)
if extra is None:
super().update_keyterms(keyterms) # warn-and-skip for unsupported models
return
self.update_options(extra=extra)

def _sanitize_options(
self, *, language: NotGivenOr[STTLanguages | str] = NOT_GIVEN
) -> STTOptions:
Expand Down
6 changes: 6 additions & 0 deletions livekit-agents/livekit/agents/stt/fallback_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(
interim_results=all(t.capabilities.interim_results for t in stt),
diarization=all(t.capabilities.diarization for t in stt),
aligned_transcript=aligned_transcript,
keyterms=any(t.capabilities.keyterms for t in stt),
)
)

Expand Down Expand Up @@ -113,6 +114,11 @@ def model(self) -> str:
def provider(self) -> str:
return "livekit"

def update_keyterms(self, keyterms: list[str]) -> None:
# forward to every underlying STT; unsupported ones warn-and-skip internally
for stt_instance in self._stt_instances:
stt_instance.update_keyterms(keyterms)

async def _try_recognize(
self,
*,
Expand Down
4 changes: 4 additions & 0 deletions livekit-agents/livekit/agents/stt/stream_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(self, *, stt: STT, vad: VAD) -> None:
streaming=True,
interim_results=False,
diarization=False, # diarization requires streaming STT
keyterms=stt.capabilities.keyterms,
)
)
self._vad = vad
Expand All @@ -42,6 +43,9 @@ def model(self) -> str:
def provider(self) -> str:
return self._stt.provider

def update_keyterms(self, keyterms: list[str]) -> None:
self._stt.update_keyterms(keyterms)

async def _recognize_impl(
self,
buffer: utils.AudioBuffer,
Expand Down
14 changes: 14 additions & 0 deletions livekit-agents/livekit/agents/stt/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ class STTCapabilities:
aligned_transcript: Literal["word", "chunk", False] = False
offline_recognize: bool = True
"""Whether the STT supports batch recognition via recognize() method"""
keyterms: bool = False
"""Whether the STT supports keyterm prompting via update_keyterms()"""


class STTError(BaseModel):
Expand All @@ -152,6 +154,7 @@ def __init__(self, *, capabilities: STTCapabilities) -> None:
self._capabilities = capabilities
self._label = f"{type(self).__module__}.{type(self).__name__}"
self._recognize_metrics_needed = True
self._keyterms_unsupported_warned = False

@property
def label(self) -> str:
Expand Down Expand Up @@ -264,6 +267,17 @@ def _emit_error(self, api_error: Exception, recoverable: bool) -> None:
),
)

def update_keyterms(self, keyterms: list[str]) -> None:
"""Set the keyterms used to bias recognition toward specific words/phrases."""
if not self._capabilities.keyterms:
if not self._keyterms_unsupported_warned:
self._keyterms_unsupported_warned = True
logger.warning(
"keyterms are not supported by this STT, ignoring update_keyterms()",
extra={"stt": self._label},
)
return

def stream(
self,
*,
Expand Down
3 changes: 3 additions & 0 deletions livekit-agents/livekit/agents/voice/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
UserStateChangedEvent,
UserTurnExceededEvent,
)
from .keyterms import KeytermDetectionOptions, KeytermOptions
from .room_io import (
_ParticipantAudioOutput,
_ParticipantStreamTranscriptionOutput,
Expand Down Expand Up @@ -49,6 +50,8 @@
"FunctionToolsExecutedEvent",
"AgentFalseInterruptionEvent",
"UserTurnExceededEvent",
"KeytermOptions",
"KeytermDetectionOptions",
"TranscriptSynchronizer",
"io",
"room_io",
Expand Down
10 changes: 10 additions & 0 deletions livekit-agents/livekit/agents/voice/agent_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,13 @@ async def _start_session(self, *, reuse_resources: _ReusableResources | None = N
else:
self._audio_recognition.start()

# bind the session's keyterm detector to this activity's STT and LLM
self._session._keyterm_detector.start(
self._session,
stt=self.stt if isinstance(self.stt, stt.STT) else None,
llm=self.llm if isinstance(self.llm, llm.LLM) else None,
)

@tracer.start_as_current_span("drain_agent_activity")
async def drain(
self, *, new_activity: AgentActivity | None = None
Expand Down Expand Up @@ -903,6 +910,8 @@ async def _pause_scheduling_task(
if self._scheduling_paused:
return

await self._session._keyterm_detector.aclose()

self._scheduling_paused = True
self._drain_blocked_tasks = blocked_tasks or []
self._wake_up_scheduling_task()
Expand Down Expand Up @@ -1037,6 +1046,7 @@ async def aclose(self) -> None:

self._closed = True
self._cancel_preemptive_generation()
await self._session._keyterm_detector.aclose()

# on_exit_task should be awaited in `drain`
self._on_exit_task = None
Expand Down
24 changes: 24 additions & 0 deletions livekit-agents/livekit/agents/voice/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
UserStateChangedEvent,
)
from .ivr import IVRActivity
from .keyterms import KeytermDetectionOptions, KeytermDetector, KeytermOptions, _resolve_detection
from .recorder_io import RecorderIO
from .remote_session import RoomSessionTransport, SessionHost
from .run_result import RunResult
Expand Down Expand Up @@ -141,6 +142,7 @@ class SessionConnectOptions:
@dataclass
class AgentSessionOptions:
turn_handling: TurnHandlingOptions
keyterm_detection: KeytermDetectionOptions
max_tool_steps: int
user_away_timeout: float | None
min_consecutive_speech_delay: float
Expand Down Expand Up @@ -228,6 +230,7 @@ def __init__(
llm: NotGivenOr[llm.LLM | llm.RealtimeModel | LLMModels | str] = NOT_GIVEN,
tts: NotGivenOr[tts.TTS | TTSModels | str] = NOT_GIVEN,
turn_handling: NotGivenOr[TurnHandlingOptions] = NOT_GIVEN,
keyterm_options: NotGivenOr[KeytermOptions] = NOT_GIVEN,
# Tool settings
tools: NotGivenOr[list[llm.Tool | llm.Toolset]] = NOT_GIVEN,
tool_handling: NotGivenOr[ToolHandlingOptions] = NOT_GIVEN,
Expand Down Expand Up @@ -285,6 +288,9 @@ def __init__(
providing external tools for the agent to use.
userdata (Userdata_T, optional): Arbitrary per-session user data.
turn_handling (TurnHandlingOptions, optional): Configuration for turn handling.
keyterm_options (KeytermOptions, optional): Keyterm prompting for the STT. Holds
user-defined ``terms`` and optional automatic ``detection`` config. Applies to
supported STTs; unsupported ones warn and ignore it.
max_endpointing_delay (float): Maximum time-in-seconds the agent
will wait before terminating the turn. Default ``3.0`` s.
max_tool_steps (int): Maximum consecutive tool calls per LLM turn.
Expand Down Expand Up @@ -367,6 +373,8 @@ def __init__(
user_turn_limit = _resolve_user_turn_limit(turn_handling.get("user_turn_limit"))
raw_turn_detection = turn_handling.get("turn_detection", None)

keyterm_opts: KeytermOptions = keyterm_options if is_given(keyterm_options) else {}

# This is the "global" chat_context, it holds the entire conversation history
self._chat_ctx = ChatContext.empty()
self._opts = AgentSessionOptions(
Expand All @@ -377,6 +385,7 @@ def __init__(
preemptive_generation=preemptive_gen,
user_turn_limit=user_turn_limit,
),
keyterm_detection=_resolve_detection(keyterm_opts.get("detection")),
max_tool_steps=max_tool_steps,
user_away_timeout=user_away_timeout,
min_consecutive_speech_delay=min_consecutive_speech_delay,
Expand Down Expand Up @@ -409,6 +418,11 @@ def __init__(
self._llm = llm or None
self._tts = tts or None

self._keyterm_detector = KeytermDetector(
user_keyterms=keyterm_opts.get("terms"),
options=self._opts.keyterm_detection,
)

self._turn_detection = raw_turn_detection
self._interruption_detection = interruption.get("mode", NOT_GIVEN)
self._mcp_servers = mcp_servers or None
Expand Down Expand Up @@ -554,6 +568,11 @@ def conn_options(self) -> SessionConnectOptions:
def history(self) -> llm.ChatContext:
return self._chat_ctx

@property
def keyterms(self) -> list[str]:
"""The effective keyterms (user-defined + auto-detected) currently applied to the STT."""
return self._keyterm_detector.keyterms

@property
def current_speech(self) -> SpeechHandle | None:
return self._activity.current_speech if self._activity is not None else None
Expand Down Expand Up @@ -1054,6 +1073,7 @@ def update_options(
*,
endpointing_opts: NotGivenOr[EndpointingOptions] = NOT_GIVEN,
turn_detection: NotGivenOr[TurnDetectionMode | None] = NOT_GIVEN,
keyterms: NotGivenOr[list[str]] = NOT_GIVEN,
# deprecated
min_endpointing_delay: NotGivenOr[float] = NOT_GIVEN,
max_endpointing_delay: NotGivenOr[float] = NOT_GIVEN,
Expand All @@ -1065,9 +1085,13 @@ def update_options(
endpointing_opts (NotGivenOr[EndpointingOptions], optional): Endpointing options.
turn_detection (NotGivenOr[TurnDetectionMode | None], optional): Strategy for deciding
when the user has finished speaking. ``None`` reverts to automatic selection.
keyterms (NotGivenOr[list[str]], optional): Replace the user-defined keyterms applied
to the STT. Auto-detected keyterms are left untouched.
min_endpointing_delay: Deprecated, use ``endpointing_opts`` instead.
max_endpointing_delay: Deprecated, use ``endpointing_opts`` instead.
"""
if is_given(keyterms):
self._keyterm_detector.set_user_keyterms(keyterms)
if is_given(min_endpointing_delay) or is_given(max_endpointing_delay):
logger.warning(
"min_endpointing_delay and max_endpointing_delay are deprecated, "
Expand Down
Loading