From f4134323a91ab39a3e65c2432e8ef2d0ae150dd0 Mon Sep 17 00:00:00 2001 From: Neppkun Date: Tue, 2 Jun 2026 21:43:40 +0300 Subject: [PATCH 01/12] Add voice support (send + receive + DAVE/MLS E2EE) Implements joining/leaving/moving voice channels, sending audio (Opus passthrough + PCM encode), receiving audio (per-user sinks), and DAVE/MLS end-to-end encryption, faithful to discord.http's pure-asyncio, stdlib-first design. - Phase 0: merge VoiceState/PartialVoiceState into channel.py; delete voice.py - Phase 1: op-4 sender, VOICE_SERVER_UPDATE parser + routing, voice-client registry on Client, PartialChannel.connect() - voice/ subpackage: voice WS (v8) + heartbeat, UDP/IP-discovery, RTP framing, aead_aes256_gcm_rtpsize via cryptography (no PyNaCl), ctypes libopus binding (optional), Ogg/Opus parser, async AudioSource/player + FFmpeg sources, receiver + sinks, DAVE/MLS via the optional davey extra - Phase 7: ExponentialBackoff, voice-close-code reconnect/resume, opt-in resume_voice persistence, default insta-leave on shard reset - Phase 8: public exports, voice example, pyproject voice package + extra Transport encryption needs no extra deps. libopus (PCM encode/decode) and FFmpeg (transcoding) are optional external runtimes; davey (DAVE E2EE) is the sole optional Python dep, gated behind discord.http[voice]. Co-Authored-By: Claude Opus 4.8 --- discord_http/channel.py | 265 +++++++++++- discord_http/client.py | 160 ++++++- discord_http/gateway/cache.py | 3 +- discord_http/gateway/parser.py | 37 +- discord_http/gateway/shard.py | 61 +++ discord_http/guild.py | 2 +- discord_http/utils.py | 55 +++ discord_http/voice.py | 209 --------- discord_http/voice/__init__.py | 36 ++ discord_http/voice/client.py | 347 +++++++++++++++ discord_http/voice/connection.py | 619 ++++++++++++++++++++++++++ discord_http/voice/dave.py | 395 +++++++++++++++++ discord_http/voice/encryptor.py | 82 ++++ discord_http/voice/enums.py | 36 ++ discord_http/voice/gateway_udp.py | 162 +++++++ discord_http/voice/oggparse.py | 205 +++++++++ discord_http/voice/opus.py | 598 ++++++++++++++++++++++++++ discord_http/voice/player.py | 692 ++++++++++++++++++++++++++++++ discord_http/voice/receiver.py | 257 +++++++++++ discord_http/voice/sinks.py | 176 ++++++++ discord_http/voice/socket.py | 419 ++++++++++++++++++ examples/voice_example.py | 104 +++++ pyproject.toml | 7 +- tests/test_voice_encryptor.py | 59 +++ tests/test_voice_oggparse.py | 136 ++++++ 25 files changed, 4895 insertions(+), 227 deletions(-) delete mode 100644 discord_http/voice.py create mode 100644 discord_http/voice/__init__.py create mode 100644 discord_http/voice/client.py create mode 100644 discord_http/voice/connection.py create mode 100644 discord_http/voice/dave.py create mode 100644 discord_http/voice/encryptor.py create mode 100644 discord_http/voice/enums.py create mode 100644 discord_http/voice/gateway_udp.py create mode 100644 discord_http/voice/oggparse.py create mode 100644 discord_http/voice/opus.py create mode 100644 discord_http/voice/player.py create mode 100644 discord_http/voice/receiver.py create mode 100644 discord_http/voice/sinks.py create mode 100644 discord_http/voice/socket.py create mode 100644 examples/voice_example.py create mode 100644 tests/test_voice_encryptor.py create mode 100644 tests/test_voice_oggparse.py diff --git a/discord_http/channel.py b/discord_http/channel.py index f26151e..b528461 100644 --- a/discord_http/channel.py +++ b/discord_http/channel.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterator, Callable, Generator from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Self, overload, Literal +from typing import TYPE_CHECKING, Any, Self, overload, Literal from . import utils from .embeds import Embed @@ -32,6 +32,7 @@ from .member import ThreadMember from .message import PartialMessage, Message, Poll from .user import PartialUser, User + from .voice.client import VoiceClient MISSING = utils.MISSING @@ -48,6 +49,7 @@ "NewsThread", "PartialChannel", "PartialThread", + "PartialVoiceState", "PrivateThread", "PublicThread", "StageChannel", @@ -56,6 +58,7 @@ "Thread", "VoiceChannel", "VoiceRegion", + "VoiceState", ) @@ -478,6 +481,74 @@ async def create_invite( data=r.response ) + async def connect( + self, + *, + timeout: float = 30.0, + reconnect: bool = True, + self_deaf: bool = False, + self_mute: bool = False + ) -> "VoiceClient": + """ + Connect to this voice channel. + + Parameters + ---------- + timeout: + How long to wait, in seconds, for the voice handshake to complete. + reconnect: + Whether to automatically reconnect if the voice connection drops. + self_deaf: + Whether the bot should be self-deafened. + self_mute: + Whether the bot should be self-muted. + + Returns + ------- + The voice client for the connection. + + Raises + ------ + `TypeError` + If the channel is not a voice or stage channel. + `ValueError` + If the channel is not associated with a guild. + `NotImplementedError` + If the gateway is not available. + `RuntimeError` + If the bot is already connected to a voice channel in this guild. + """ + if self.type not in ( + ChannelType.unknown, + ChannelType.guild_voice, + ChannelType.guild_stage_voice + ): + raise TypeError("Cannot connect to a non-voice channel") + + if not self.guild_id: + raise ValueError("Cannot connect to a voice channel without a guild") + + client = self._state.bot + + if not client.gateway: + raise NotImplementedError("gateway is not available") + + if client._get_voice_client(self.guild_id) is not None: + raise RuntimeError("Already connected to a voice channel in this guild") + + from .voice.client import VoiceClient + + vc = VoiceClient(client, self) + client._add_voice_client(self.guild_id, vc) + await vc.connect( + timeout=timeout, + reconnect=reconnect, + self_deaf=self_deaf, + self_mute=self_mute + ) + + return vc + async def send( self, content: str | None = MISSING, @@ -2732,3 +2803,195 @@ async def create_stage_instance( guild=self.guild ) return self._stage_instance + + +class PartialVoiceState(PartialBase): + """ Represents a partial voice state object. """ + + __slots__ = ( + "_state", + "channel_id", + "guild_id", + ) + + def __init__( + self, + *, + state: "DiscordAPI", + id: int, # noqa: A002 + channel_id: int | None = None, + guild_id: int | None = None, + ): + self._state = state + + self.id: int = int(id) + """ The ID of the user this voice state belongs to. """ + + self.channel_id: int | None = channel_id + """ The ID of the voice channel this user is in, if any. """ + + self.guild_id: int | None = guild_id + """ The ID of the guild this voice state is in, if any. """ + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return "PartialVoiceState" + + async def fetch(self) -> "VoiceState": + """ + Fetches the voice state of the member. + + Returns + ------- + The voice state of the member + + Raises + ------ + `NotFound` + - If the member is not in the guild + - If the member is not in a voice channel + """ + if not self.guild_id: + raise ValueError("Cannot fetch voice state without guild_id") + + r = await self._state.query( + "GET", + f"/guilds/{self.guild_id}/voice-states/{self.id}" + ) + + guild = self._state.cache.get_guild(self.guild_id) + channel = None + if self.channel_id is not None: + channel = self._state.cache.get_channel(self.guild_id, self.channel_id) + + return VoiceState( + state=self._state, + data=r.response, + guild=guild, + channel=channel + ) + + async def edit( + self, + *, + suppress: bool = MISSING, + ) -> None: + """ + Updates the voice state of the member. + + Parameters + ---------- + suppress: + Whether to suppress the user + """ + if not self.guild_id: + raise ValueError("Cannot update voice state without guild_id") + + data: dict[str, Any] = {} + + if suppress is not MISSING: + data["suppress"] = bool(suppress) + + await self._state.query( + "PATCH", + f"/guilds/{self.guild_id}/voice-states/{int(self.id)}", + json=data, + res_method="text" + ) + + +class VoiceState(PartialVoiceState): + """ Represents a voice state object. """ + + __slots__ = ( + "channel", + "deaf", + "guild", + "member", + "mute", + "request_to_speak_timestamp", + "self_deaf", + "self_mute", + "self_stream", + "self_video", + "session_id", + "suppress", + "user", + ) + + def __init__( + self, + *, + state: "DiscordAPI", + data: dict, + guild: "PartialGuild | None", + channel: "BaseChannel | PartialChannel | None" + ): + from .user import PartialUser + + super().__init__( + state=state, + id=int(data["user_id"]), + guild_id=utils.get_int(data, "guild_id"), + channel_id=utils.get_int(data, "channel_id") + ) + + self.session_id: str = data["session_id"] + """ The session ID of the voice state. """ + + self.user: PartialUser = PartialUser(state=state, id=int(data["user_id"])) + """ The user this voice state belongs to. """ + + self.member: "Member | None" = None + """ The member this voice state belongs to, if any. """ + + self.channel: "BaseChannel | PartialChannel | None" = channel + """ The voice channel this user is in, if any. """ + + self.guild: "PartialGuild | None" = guild + """ The guild this voice state is in, if any. """ + + self.deaf: bool = data["deaf"] + """ Whether the user is deafened by the server. """ + + self.mute: bool = data["mute"] + """ Whether the user is muted by the server. """ + + self.self_deaf: bool = data["self_deaf"] + """ Whether the user is deafened by themselves. """ + + self.self_mute: bool = data["self_mute"] + """ Whether the user is muted by themselves. """ + + self.self_stream: bool = data.get("self_stream", False) + """ Whether the user is streaming. """ + + self.self_video: bool = data["self_video"] + """ Whether the user is using video. """ + + self.suppress: bool = data["suppress"] + """ Whether the user is suppressed by the server. """ + + self.request_to_speak_timestamp: datetime | None = None + """ The timestamp when the user requested to speak, if any. """ + + self._from_data(data) + + def __repr__(self) -> str: + return f"" + + def _from_data(self, data: dict) -> None: + if data.get("member") and self.guild: + from .member import Member + self.member = Member( + state=self._state, + guild=self.guild, + data=data["member"] + ) + + if data.get("request_to_speak_timestamp"): + self.request_to_speak_timestamp = utils.parse_time( + data["request_to_speak_timestamp"] + ) diff --git a/discord_http/client.py b/discord_http/client.py index baad1a2..44bf3c1 100644 --- a/discord_http/client.py +++ b/discord_http/client.py @@ -13,7 +13,7 @@ from . import utils, __version__ from .automod import PartialAutoModRule, AutoModRule from .backend import DiscordHTTP -from .channel import PartialChannel, BaseChannel +from .channel import PartialChannel, BaseChannel, PartialVoiceState, VoiceState from .commands import Command, Interaction, Listener, Cog, SubGroup from .context import Context from .emoji import PartialEmoji, Emoji @@ -34,13 +34,13 @@ from .sticker import PartialSticker, Sticker from .user import User, PartialUser, Application from .view import InteractionStorage -from .voice import PartialVoiceState, VoiceState from .webhook import PartialWebhook, Webhook if TYPE_CHECKING: from .gateway.client import GatewayClient from .gateway.flags import GatewayCacheFlags, Intents from .gateway.object import PlayingStatus + from .voice.client import VoiceClient _log = logging.getLogger(__name__) @@ -102,6 +102,9 @@ class Client: Whether to disable the default GET path or not, if not provided, it will use `False`. The default GET path only provides information about the bot and when it was last rebooted. Usually a great tool to just validate that your bot is online. + resume_voice: bool + Whether to remember and revive voice clients across shard reboots, if not provided, it will use `False`. + Requires the `guild_voice_states` intent; otherwise voice clients are torn down on shard reset. """ def __init__( self, @@ -126,7 +129,8 @@ def __init__( intents: "Intents | None" = None, logging_level: int = logging.INFO, disable_default_get_path: bool = False, - debug_events: bool = False + debug_events: bool = False, + resume_voice: bool = False ): if application_id is not None: _log.warning( @@ -149,6 +153,7 @@ def __init__( self.logging_level: int = logging_level self.debug_events: bool = debug_events self.enable_gateway: bool = enable_gateway + self.resume_voice: bool = resume_voice self.playing_status: "PlayingStatus | None" = playing_status self.guild_ready_timeout: float = guild_ready_timeout self.chunk_guilds_on_startup: bool = chunk_guilds_on_startup @@ -204,6 +209,7 @@ def __init__( self._after_invoke: tuple[Callable, bool] | None = None self._waiting_listeners: dict[str, list[tuple[asyncio.Future, Callable]]] = {} self._background_tasks: set[asyncio.Task] = set() + self._voice_clients: dict[int, "VoiceClient"] = {} utils.setup_logger(level=self.logging_level) @@ -216,6 +222,149 @@ def _cleanup_task(self, task: asyncio.Task) -> None: except Exception: pass + def _get_voice_client(self, guild_id: int) -> "VoiceClient | None": + """ + Get the voice client for a guild, if one is registered. + + Parameters + ---------- + guild_id: + The guild to get the voice client for. + + Returns + ------- + The voice client, or ``None`` if none is registered. + """ + return self._voice_clients.get(guild_id) + + def _add_voice_client(self, guild_id: int, voice_client: "VoiceClient") -> None: + """ + Register a voice client for a guild. + + Parameters + ---------- + guild_id: + The guild to register the voice client for. + voice_client: + The voice client to register. + """ + self._voice_clients[guild_id] = voice_client + + def _remove_voice_client(self, guild_id: int) -> None: + """ + Remove the voice client for a guild, if one is registered. + + Parameters + ---------- + guild_id: + The guild to remove the voice client for. + """ + self._voice_clients.pop(guild_id, None) + + def _voice_clients_for_shard(self, shard_id: int) -> "list[VoiceClient]": + """ + Return the registered voice clients whose guilds belong to a shard. + + Parameters + ---------- + shard_id: + The shard to enumerate voice clients for. + + Returns + ------- + The voice clients belonging to the shard's guilds. + """ + if not self.gateway: + return [] + + return [ + vc for guild_id, vc in list(self._voice_clients.items()) + if self.get_shard_by_guild_id(guild_id) == shard_id + ] + + def _has_voice_states_intent(self, shard_id: int) -> bool: + """ + Whether the ``guild_voice_states`` intent is enabled for a shard. + + Parameters + ---------- + shard_id: + The shard to check the intents of. + + Returns + ------- + ``True`` if the intent is available, otherwise ``False``. + """ + from .gateway.flags import Intents + + if not self.gateway: + return False + + shard = self.gateway.get_shard(shard_id) + if shard is None: + return False + + return Intents.guild_voice_states in shard.intents + + async def _revive_voice_clients(self, shard_id: int) -> None: + """ + Re-establish remembered voice clients after a shard READY/RESUMED. + + This is a no-op unless :attr:`resume_voice` is enabled, the + ``guild_voice_states`` intent is available, and there are voice clients + belonging to the shard. Each remembered client re-issues op4 and runs a + fresh handshake, resuming playback where possible. + + Parameters + ---------- + shard_id: + The shard that just became ready or resumed. + """ + if not self.resume_voice: + return + + voice_clients = self._voice_clients_for_shard(shard_id) + if not voice_clients: + return + + if not self._has_voice_states_intent(shard_id): + _log.warning( + "resume_voice is enabled but the guild_voice_states intent is missing; " + "tearing down voice clients for shard %s", + shard_id + ) + await self._teardown_voice_clients_for_shard(shard_id) + return + + for vc in voice_clients: + if vc.is_connected(): + # Still alive (e.g. a RESUMED where nothing was torn down). + continue + try: + await vc.connect( + self_deaf=vc.connection._self_deaf, + self_mute=vc.connection._self_mute, + ) + except Exception as exc: + _log.warning("Failed to revive voice client for guild %s", vc.guild_id, exc_info=exc) + await vc._cleanup() + + async def _teardown_voice_clients_for_shard(self, shard_id: int) -> None: + """ + Tear down every voice client belonging to a shard's guilds. + + Stops playback, closes the websocket and UDP transport, and removes the + client from the registry. Used when a shard is reset or killed and the + connections should not survive. + + Parameters + ---------- + shard_id: + The shard whose voice clients should be torn down. + """ + for vc in self._voice_clients_for_shard(shard_id): + await vc._cleanup() + async def _cooldown_cleanup_loop(self) -> None: """ Periodically sweeps expired cooldown buckets that accumulate between invocations. """ while True: @@ -514,6 +663,11 @@ def user(self) -> User: return self.application.bot + @property + def voice_clients(self) -> list["VoiceClient"]: + """ Returns a list of all the voice clients the bot is connected to. """ + return list(self._voice_clients.values()) + @property def guilds(self) -> list[Guild | PartialGuild]: """ diff --git a/discord_http/gateway/cache.py b/discord_http/gateway/cache.py index ed5d8a8..09b0773 100644 --- a/discord_http/gateway/cache.py +++ b/discord_http/gateway/cache.py @@ -1,7 +1,6 @@ from typing import TYPE_CHECKING -from ..channel import BaseChannel -from ..voice import VoiceState, PartialVoiceState +from ..channel import BaseChannel, PartialVoiceState, VoiceState from .flags import GatewayCacheFlags diff --git a/discord_http/gateway/parser.py b/discord_http/gateway/parser.py index bc6e46d..6d8586a 100644 --- a/discord_http/gateway/parser.py +++ b/discord_http/gateway/parser.py @@ -7,7 +7,10 @@ from .. import utils from ..audit import AuditLogEntry from ..automod import AutoModRule -from ..channel import BaseChannel, PartialChannel, StageInstance, PartialThread +from ..channel import ( + BaseChannel, PartialChannel, PartialThread, + PartialVoiceState, StageInstance, VoiceState +) from ..emoji import Emoji, EmojiParser from ..entitlements import Entitlements from ..enums import ChannelType @@ -20,7 +23,6 @@ from ..soundboard import PartialSoundboardSound, SoundboardSound from ..sticker import Sticker from ..user import User, PartialUser -from ..voice import VoiceState, PartialVoiceState from .enums import PollVoteActionType from .flags import GatewayCacheFlags @@ -1336,16 +1338,24 @@ def invite_delete(self, data: dict) -> tuple[PartialInvite]: ), ) - """ - This is just a placeholder for now. - I am unsure if I ever will handle voice communication with discord.http/gateway - Let this be a reminder for myself in later time + def voice_server_update(self, data: dict) -> tuple[dict]: + """ + Voice server update event. - - AlexFlipnote, 9. October 2024 + Parameters + ---------- + data: + Data received from the event. - def voice_channel_effect_send(self, data: dict) -> tuple[None]: - return (None,) - """ + Returns + ------- + The raw voice server update payload. + """ + vc = self.bot._get_voice_client(int(data["guild_id"])) + if vc is not None: + self.bot.loop.create_task(vc.on_voice_server_update(data)) + + return (data,) def voice_state_update(self, data: dict) -> tuple[ VoiceState | PartialVoiceState | None, @@ -1385,6 +1395,13 @@ def voice_state_update(self, data: dict) -> tuple[ ) self.bot.cache.update_voice_state(vs) + + bot_user = self.bot.application.bot if self.bot.application else None + if bot_user is not None and int(data["user_id"]) == bot_user.id: + vc = self.bot._get_voice_client(int(data["guild_id"])) + if vc is not None: + self.bot.loop.create_task(vc.on_voice_state_update(data)) + return (before_vs, vs) def typing_start(self, data: dict) -> tuple[TypingStartEvent]: diff --git a/discord_http/gateway/shard.py b/discord_http/gateway/shard.py index 1f68ba0..cf46fe0 100644 --- a/discord_http/gateway/shard.py +++ b/discord_http/gateway/shard.py @@ -370,7 +370,31 @@ def _reset_buffer(self) -> None: self._buffer = bytearray() self._zlib = zlib.decompressobj() + def _revive_voice_clients(self) -> None: + """ Schedule revival of remembered voice clients after READY/RESUMED. """ + if not self.bot._voice_clients: + return + task = asyncio.create_task( + self.bot._revive_voice_clients(self.shard_id), + name=f"discord.http/gateway/shard-{self.shard_id}/revive_voice" + ) + self.bot._background_tasks.add(task) + task.add_done_callback(self.bot._cleanup_task) + + def _teardown_voice_clients(self) -> None: + """ Insta-leave any voice clients on this shard unless persistence is on. """ + if self.bot.resume_voice or not self.bot._voice_clients: + return + task = asyncio.create_task( + self.bot._teardown_voice_clients_for_shard(self.shard_id), + name=f"discord.http/gateway/shard-{self.shard_id}/teardown_voice" + ) + self.bot._background_tasks.add(task) + task.add_done_callback(self.bot._cleanup_task) + def _reset_instance(self) -> None: + self._teardown_voice_clients() + self._reset_buffer() self.status.reset() @@ -569,7 +593,11 @@ async def received_message(self, raw_msg: str | bytes) -> None: name=f"discord.http/gateway/shard-{self.shard_id}/delay_ready" ) + self._revive_voice_clients() + case "RESUMED": + self._revive_voice_clients() + if self.bot.has_any_dispatch("shard_resumed"): self.bot.dispatch( "shard_resumed", @@ -1084,6 +1112,39 @@ async def change_presence(self, status: PlayingStatus) -> None: "d": status.to_dict() }) + async def change_voice_state( + self, + *, + guild_id: int, + channel_id: int | None, + self_mute: bool = False, + self_deaf: bool = False + ) -> None: + """ + Changes the voice state of the shard for the specified guild. + + Parameters + ---------- + guild_id: + The guild to change the voice state in. + channel_id: + The voice channel to connect to, or ``None`` to disconnect. + self_mute: + Whether the bot is self-muted. + self_deaf: + Whether the bot is self-deafened. + """ + _log.debug(f"Changing voice state in Shard {self.shard_id} for guild {guild_id} to channel {channel_id}") + await self.send_message({ + "op": int(PayloadType.voice_state), + "d": { + "guild_id": str(guild_id), + "channel_id": str(channel_id) if channel_id is not None else None, + "self_mute": bool(self_mute), + "self_deaf": bool(self_deaf) + } + }) + def payload(self, op: PayloadType) -> dict: """ Returns a payload for the websocket. diff --git a/discord_http/guild.py b/discord_http/guild.py index 794fdbb..668875e 100644 --- a/discord_http/guild.py +++ b/discord_http/guild.py @@ -29,7 +29,7 @@ from .message import Message from .soundboard import SoundboardSound, PartialSoundboardSound from .sticker import Sticker, PartialSticker -from .voice import VoiceState, PartialVoiceState +from .channel import VoiceState, PartialVoiceState if TYPE_CHECKING: from .audit import AuditLogEntry diff --git a/discord_http/utils.py b/discord_http/utils.py index 0f1579a..b85cbae 100644 --- a/discord_http/utils.py +++ b/discord_http/utils.py @@ -3,6 +3,7 @@ import logging import orjson import posixpath +import random import re import struct import sys @@ -294,6 +295,60 @@ def to_dict(self) -> dict[str, float]: } +class ExponentialBackoff: + """ + A small helper that produces exponentially increasing delays. + + Each call to :meth:`delay` returns ``base * 2 ** exp`` (capped at + ``max_delay``), incrementing an internal exponent so successive calls grow + geometrically. Optional jitter spreads retries out to avoid thundering-herd + reconnect storms. Calling :meth:`reset` returns the backoff to its initial + state, e.g. after a successful reconnect. + + Parameters + ---------- + base: + The base delay, in seconds, used as the multiplier for the exponent. + max_delay: + The maximum delay, in seconds, that any single call may return. + jitter: + Whether to apply random jitter to each returned delay. + """ + + def __init__(self, base: float = 1.0, *, max_delay: float = 60.0, jitter: bool = True): + self.base: float = base + """ The base delay, in seconds. """ + + self.max_delay: float = max_delay + """ The maximum delay, in seconds, returned by :meth:`delay`. """ + + self.jitter: bool = jitter + """ Whether random jitter is applied to each delay. """ + + self._exp: int = 0 + + def delay(self) -> float: + """ + Return the next backoff delay and advance the internal exponent. + + Returns + ------- + The next delay, in seconds, capped at :attr:`max_delay` and + optionally jittered. + """ + self._exp += 1 + value = min(self.base * (2 ** (self._exp - 1)), self.max_delay) + + if self.jitter: + value *= random.uniform(0.5, 1.0) + + return value + + def reset(self) -> None: + """ Reset the internal exponent so the next delay starts from the base. """ + self._exp = 0 + + def format_small_unit(seconds: float | timedelta) -> str: """ Helper to scale sub-second values to the appropriate unit. diff --git a/discord_http/voice.py b/discord_http/voice.py deleted file mode 100644 index 733a947..0000000 --- a/discord_http/voice.py +++ /dev/null @@ -1,209 +0,0 @@ -from datetime import datetime -from typing import TYPE_CHECKING, Any - -from . import utils -from .object import PartialBase -from .user import PartialUser - -MISSING = utils.MISSING - -if TYPE_CHECKING: - from .channel import BaseChannel, PartialChannel - from .guild import PartialGuild - from .http import DiscordAPI - from .member import Member - -__all__ = ( - "PartialVoiceState", - "VoiceState", -) - - -class PartialVoiceState(PartialBase): - """ Represents a partial voice state object. """ - - __slots__ = ( - "_state", - "channel_id", - "guild_id", - ) - - def __init__( - self, - *, - state: "DiscordAPI", - id: int, # noqa: A002 - channel_id: int | None = None, - guild_id: int | None = None, - ): - self._state = state - - self.id: int = int(id) - """ The ID of the user this voice state belongs to. """ - - self.channel_id: int | None = channel_id - """ The ID of the voice channel this user is in, if any. """ - - self.guild_id: int | None = guild_id - """ The ID of the guild this voice state is in, if any. """ - - def __repr__(self) -> str: - return f"" - - def __str__(self) -> str: - return "PartialVoiceState" - - async def fetch(self) -> "VoiceState": - """ - Fetches the voice state of the member. - - Returns - ------- - The voice state of the member - - Raises - ------ - `NotFound` - - If the member is not in the guild - - If the member is not in a voice channel - """ - if not self.guild_id: - raise ValueError("Cannot fetch voice state without guild_id") - - r = await self._state.query( - "GET", - f"/guilds/{self.guild_id}/voice-states/{self.id}" - ) - - guild = self._state.cache.get_guild(self.guild_id) - channel = None - if self.channel_id is not None: - channel = self._state.cache.get_channel(self.guild_id, self.channel_id) - - return VoiceState( - state=self._state, - data=r.response, - guild=guild, - channel=channel - ) - - async def edit( - self, - *, - suppress: bool = MISSING, - ) -> None: - """ - Updates the voice state of the member. - - Parameters - ---------- - suppress: - Whether to suppress the user - """ - if not self.guild_id: - raise ValueError("Cannot update voice state without guild_id") - - data: dict[str, Any] = {} - - if suppress is not MISSING: - data["suppress"] = bool(suppress) - - await self._state.query( - "PATCH", - f"/guilds/{self.guild_id}/voice-states/{int(self.id)}", - json=data, - res_method="text" - ) - - -class VoiceState(PartialVoiceState): - """ Represents a voice state object. """ - - __slots__ = ( - "channel", - "deaf", - "guild", - "member", - "mute", - "request_to_speak_timestamp", - "self_deaf", - "self_mute", - "self_stream", - "self_video", - "session_id", - "suppress", - "user", - ) - - def __init__( - self, - *, - state: "DiscordAPI", - data: dict, - guild: "PartialGuild | None", - channel: "BaseChannel | PartialChannel | None" - ): - super().__init__( - state=state, - id=int(data["user_id"]), - guild_id=utils.get_int(data, "guild_id"), - channel_id=utils.get_int(data, "channel_id") - ) - - self.session_id: str = data["session_id"] - """ The session ID of the voice state. """ - - self.user: PartialUser = PartialUser(state=state, id=int(data["user_id"])) - """ The user this voice state belongs to. """ - - self.member: "Member | None" = None - """ The member this voice state belongs to, if any. """ - - self.channel: "BaseChannel | PartialChannel | None" = channel - """ The voice channel this user is in, if any. """ - - self.guild: "PartialGuild | None" = guild - """ The guild this voice state is in, if any. """ - - self.deaf: bool = data["deaf"] - """ Whether the user is deafened by the server. """ - - self.mute: bool = data["mute"] - """ Whether the user is muted by the server. """ - - self.self_deaf: bool = data["self_deaf"] - """ Whether the user is deafened by themselves. """ - - self.self_mute: bool = data["self_mute"] - """ Whether the user is muted by themselves. """ - - self.self_stream: bool = data.get("self_stream", False) - """ Whether the user is streaming. """ - - self.self_video: bool = data["self_video"] - """ Whether the user is using video. """ - - self.suppress: bool = data["suppress"] - """ Whether the user is suppressed by the server. """ - - self.request_to_speak_timestamp: datetime | None = None - """ The timestamp when the user requested to speak, if any. """ - - self._from_data(data) - - def __repr__(self) -> str: - return f"" - - def _from_data(self, data: dict) -> None: - if data.get("member") and self.guild: - from .member import Member - self.member = Member( - state=self._state, - guild=self.guild, - data=data["member"] - ) - - if data.get("request_to_speak_timestamp"): - self.request_to_speak_timestamp = utils.parse_time( - data["request_to_speak_timestamp"] - ) diff --git a/discord_http/voice/__init__.py b/discord_http/voice/__init__.py new file mode 100644 index 0000000..ec6d7c5 --- /dev/null +++ b/discord_http/voice/__init__.py @@ -0,0 +1,36 @@ +# ruff: noqa: F403, F405 +from . import opus +from .client import * +from .connection import * +from .dave import has_dave, max_protocol_version +from .enums import SUPPORTED_MODES, VoiceOp +from .opus import OPUS_SILENCE, OpusError, OpusNotLoaded, is_loaded, load_opus +from .player import * +from .receiver import * +from .sinks import * + +__all__ = ( + "OPUS_SILENCE", + "SUPPORTED_MODES", + "AudioPlayer", + "AudioSink", + "AudioSource", + "CallbackSink", + "FFmpegOpusAudio", + "FFmpegPCMAudio", + "OpusError", + "OpusNotLoaded", + "PCMAudio", + "PCMVolumeTransformer", + "VoiceClient", + "VoiceConnection", + "VoiceData", + "VoiceOp", + "VoiceReceiver", + "WaveSink", + "has_dave", + "is_loaded", + "load_opus", + "max_protocol_version", + "opus", +) diff --git a/discord_http/voice/client.py b/discord_http/voice/client.py new file mode 100644 index 0000000..5935571 --- /dev/null +++ b/discord_http/voice/client.py @@ -0,0 +1,347 @@ +import asyncio +import logging +import struct + +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +from .connection import VoiceConnection + +if TYPE_CHECKING: + from ..channel import PartialChannel + from ..client import Client + from .opus import Encoder + from .player import AudioPlayer + from .receiver import VoiceReceiver + from .sinks import AudioSink + +__all__ = ("VoiceClient",) + +_log = logging.getLogger(__name__) + + +class VoiceClient: + """ + The public handle for an active voice connection in a guild. + + Wraps the lower-level :class:`VoiceConnection` and exposes playback, + receiving, and RTP transmission helpers. + """ + + def __init__(self, client: "Client", channel: "PartialChannel"): + self.client: "Client" = client + """ The bot client that owns this voice client. """ + + self.bot: "Client" = client + """ Alias of :attr:`client`. """ + + self.channel: "PartialChannel" = channel + """ The voice channel this client is connected to. """ + + if channel.guild_id is None: + raise ValueError("Cannot create a voice client for a channel without a guild") + + self.guild_id: int = channel.guild_id + """ The ID of the guild this voice client is in. """ + + self.connection: VoiceConnection = VoiceConnection(self) + """ The underlying voice connection state machine. """ + + self._player: "AudioPlayer | None" = None + self._receiver: "VoiceReceiver | None" = None + self._encoder: "Encoder | None" = None + + @property + def loop(self) -> asyncio.AbstractEventLoop: + """ The event loop the client runs on. """ + return self.client.loop + + @property + def user_id(self) -> int: + """ The ID of the bot user. """ + return self.client.user.id + + @property + def ssrc(self) -> int | None: + """ The SSRC assigned to this connection. """ + return self.connection.ssrc + + @property + def secret_key(self) -> bytes | None: + """ The transport secret key, once known. """ + return self.connection.secret_key + + @property + def latency(self) -> float: + """ The latency of the most recent voice heartbeat, in seconds. """ + return self.connection.latency + + @property + def average_latency(self) -> float: + """ The average latency of recent voice heartbeats, in seconds. """ + return self.connection.average_latency + + @property + def voice_privacy_code(self) -> str | None: + """ The DAVE voice privacy code, if available. """ + return self.connection.voice_privacy_code + + @property + def endpoint(self) -> str | None: + """ The voice server endpoint host. """ + return self.connection.endpoint + + @property + def session_id(self) -> str | None: + """ The voice session ID. """ + return self.connection.session_id + + def is_connected(self) -> bool: + """ Whether the voice connection is established. """ + return self.connection.is_connected() + + async def connect( + self, + *, + timeout: float = 30.0, + reconnect: bool = True, + self_deaf: bool = False, + self_mute: bool = False + ) -> None: + """ + Connect to the voice channel. + + Parameters + ---------- + timeout: + The maximum time to wait for the handshake, in seconds. + reconnect: + Whether to attempt reconnection on failure. + self_deaf: + Whether to join self-deafened. + self_mute: + Whether to join self-muted. + """ + await self.connection.connect( + timeout=timeout, + reconnect=reconnect, + self_deaf=self_deaf, + self_mute=self_mute, + ) + + async def disconnect(self, *, force: bool = True) -> None: + """ + Disconnect from the voice channel and clean up. + + Parameters + ---------- + force: + Whether to force the disconnect even on error. + """ + if self._player is not None: + self._player.stop() + self._player = None + + if self._receiver is not None: + self._receiver.stop() + self._receiver = None + + if self._encoder is not None: + self._encoder.cleanup() + self._encoder = None + + await self.connection.disconnect(force=force) + self.client._remove_voice_client(self.guild_id) + + async def _cleanup(self) -> None: + """ + Tear down the voice client locally without relying on the gateway. + + Stops the player and receiver, closes the websocket and UDP transport, + and removes the client from the registry. Safe to call when the owning + shard has been reset or killed and op4 can no longer be sent. + """ + if self._player is not None: + self._player.stop() + self._player = None + + if self._receiver is not None: + self._receiver.stop() + self._receiver = None + + if self._encoder is not None: + self._encoder.cleanup() + self._encoder = None + + await self.connection.close_transport() + self.client._remove_voice_client(self.guild_id) + + async def move_to(self, channel: "PartialChannel") -> None: + """ + Move to a different voice channel. + + Parameters + ---------- + channel: + The channel to move to. + """ + await self.connection.move_to(channel) + self.channel = channel + + async def on_voice_state_update(self, data: dict) -> None: + """ + Forward a VOICE_STATE_UPDATE to the connection. + + Parameters + ---------- + data: + The raw voice state update payload. + """ + self.connection.on_voice_state_update(data) + + async def on_voice_server_update(self, data: dict) -> None: + """ + Forward a VOICE_SERVER_UPDATE to the connection. + + Parameters + ---------- + data: + The raw voice server update payload. + """ + self.connection.on_voice_server_update(data) + + async def speak(self, speaking: bool = True) -> None: + """ + Send the SPEAKING frame to the voice gateway. + + Parameters + ---------- + speaking: + Whether the bot is speaking. + """ + if self.connection.socket is None or self.connection.ssrc is None: + return + await self.connection.socket.send_speaking(1 if speaking else 0, ssrc=self.connection.ssrc) + + def _get_encoder(self) -> "Encoder": + """ + Return the cached Opus encoder, creating it on first use. + + Returns + ------- + The Opus encoder for this client. + """ + if self._encoder is None: + from .opus import Encoder + + self._encoder = Encoder() + return self._encoder + + def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None: + """ + Frame, encrypt, and transmit a single audio packet over UDP. + + Parameters + ---------- + data: + The audio payload: PCM when ``encode`` is ``True`` else a raw Opus packet. + encode: + Whether ``data`` is PCM that must be Opus-encoded first. + """ + connection = self.connection + if connection.encryptor is None or connection.transport is None or connection.ssrc is None: + return + + opus = self._get_encoder().encode(data) if encode else data + + # DAVE end-to-end encryption applies to the Opus payload before RTP framing. + if connection.can_encrypt(): + opus = connection.dave_encrypt_opus(opus) + + connection.sequence = (connection.sequence + 1) % (2 ** 16) + + header = bytearray(12) + struct.pack_into(">BBHII", header, 0, 0x80, 0x78, connection.sequence, connection.timestamp, connection.ssrc) + + connection.timestamp = (connection.timestamp + 960) % (2 ** 32) + + packet = connection.encryptor.encrypt(bytes(header), opus) + + try: + connection.transport.sendto(packet) + except OSError: + _log.debug("Failed to send audio packet for guild %s", self.guild_id) + + def play( + self, + audio: Any, # noqa: ANN401 + *, + after: Callable[[Exception | None], object] | None = None + ) -> None: + """ + Play an audio source over the connection. + + Parameters + ---------- + audio: + The audio source, path, bytes, or stream to play. + after: + A callback invoked with any error once playback finishes. + """ + from .player import AudioPlayer, _resolve_source + + if self._player is not None: + self._player.stop() + + source = _resolve_source(audio) + player = AudioPlayer(source, self, after=after) + self._player = player + player.start() + + def pause(self) -> None: + """ Pause the current playback. """ + if self._player is not None: + self._player.pause() + + def resume(self) -> None: + """ Resume paused playback. """ + if self._player is not None: + self._player.resume() + + def stop(self) -> None: + """ Stop the current playback. """ + if self._player is not None: + self._player.stop() + self._player = None + + def is_playing(self) -> bool: + """ Whether audio is currently playing. """ + return self._player is not None and self._player.is_playing() + + def is_paused(self) -> bool: + """ Whether playback is currently paused. """ + return self._player is not None and self._player.is_paused() + + def listen(self, sink: "AudioSink") -> None: + """ + Start receiving voice into the given sink. + + Parameters + ---------- + sink: + The audio sink to write received audio into. + """ + from .receiver import VoiceReceiver + + if self._receiver is None: + self._receiver = VoiceReceiver(self) + self._receiver.start(sink) + + def stop_listening(self) -> None: + """ Stop receiving voice. """ + if self._receiver is not None: + self._receiver.stop() + + def is_listening(self) -> bool: + """ Whether the client is currently receiving voice. """ + return self._receiver is not None and self._receiver.is_listening() diff --git a/discord_http/voice/connection.py b/discord_http/voice/connection.py new file mode 100644 index 0000000..7bf5d45 --- /dev/null +++ b/discord_http/voice/connection.py @@ -0,0 +1,619 @@ +import asyncio +import logging + +from typing import TYPE_CHECKING, Any + +from ..utils import ExponentialBackoff +from .encryptor import Encryptor +from .enums import SUPPORTED_MODES +from .gateway_udp import VoiceUDPProtocol, create_udp +from .socket import VoiceCloseCode, VoiceSocket + +if TYPE_CHECKING: + from ..channel import PartialChannel + from .client import VoiceClient + from .dave import DaveManager + +__all__ = ("VoiceConnection",) + +_log = logging.getLogger(__name__) + +#: The maximum number of full-reconnect attempts before giving up. +MAX_RECONNECT_ATTEMPTS = 5 + + +class VoiceConnection: + """ + The transport and control-plane state machine for a voice connection. + + Coordinates the gateway voice-state/voice-server handshake, the voice + websocket, UDP IP discovery, and the encryption handshake to reach a + fully connected state. + """ + + def __init__(self, voice_client: "VoiceClient"): + self.voice_client: "VoiceClient" = voice_client + """ The voice client that owns this connection. """ + + self.guild_id: int = voice_client.guild_id + """ The ID of the guild this connection is for. """ + + self.channel_id: int | None = voice_client.channel.id + """ The ID of the voice channel currently targeted. """ + + self.user_id: int = voice_client.user_id + """ The ID of the bot user. """ + + self.socket: VoiceSocket | None = None + """ The voice websocket, if open. """ + + self.udp: VoiceUDPProtocol | None = None + """ The UDP protocol, if connected. """ + + self.transport: asyncio.DatagramTransport | None = None + """ The UDP datagram transport, if connected. """ + + self.encryptor: Encryptor | None = None + """ The transport encryptor, once the secret key is known. """ + + self.token: str | None = None + """ The voice connection token from the voice server update. """ + + self.endpoint: str | None = None + """ The voice server endpoint host, without scheme or port. """ + + self.session_id: str | None = None + """ The voice session ID from the voice state update. """ + + self.server_id: int | None = None + """ The server (guild) ID from the voice server update. """ + + self.ssrc: int | None = None + """ The synchronisation source identifier assigned by the gateway. """ + + self.secret_key: bytes | None = None + """ The secret key used for transport encryption. """ + + self.endpoint_ip: str | None = None + """ The discovered external IP address. """ + + self.endpoint_port: int | None = None + """ The discovered external UDP port. """ + + self.mode: str | None = None + """ The negotiated encryption mode. """ + + self.sequence: int = 0 + """ The RTP sequence counter. """ + + self.timestamp: int = 0 + """ The RTP timestamp counter. """ + + self.dave_session: "DaveManager | None" = None + """ The DAVE/MLS session manager, if a protocol version was negotiated. """ + + self.dave_protocol_version: int = 0 + """ The negotiated DAVE protocol version (0 if not in use). """ + + self.pending_transitions: dict[int, Any] = {} + """ Pending DAVE protocol transitions, keyed by transition ID. """ + + self._state_event: asyncio.Event = asyncio.Event() + self._server_event: asyncio.Event = asyncio.Event() + self._ready_event: asyncio.Event = asyncio.Event() + self._connected_event: asyncio.Event = asyncio.Event() + + self._reconnect: bool = True + self._self_mute: bool = False + self._self_deaf: bool = False + self._closing: bool = False + self._reconnect_task: asyncio.Task | None = None + self._backoff: ExponentialBackoff = ExponentialBackoff(base=1.0, max_delay=30.0) + + @property + def latency(self) -> float: + """ The latency of the most recent voice heartbeat, in seconds. """ + if self.socket is None: + return float("inf") + return self.socket.latency + + @property + def average_latency(self) -> float: + """ The average latency of recent voice heartbeats, in seconds. """ + if self.socket is None: + return float("inf") + return self.socket.average_latency + + @property + def voice_privacy_code(self) -> str | None: + """ The DAVE voice privacy code, if a DAVE session is active. """ + if self.dave_session is None: + return None + return self.dave_session.voice_privacy_code + + def is_connected(self) -> bool: + """ Whether the connection has completed its handshake. """ + return self._connected_event.is_set() + + async def connect( + self, + *, + timeout: float = 30.0, + reconnect: bool = True, + self_deaf: bool = False, + self_mute: bool = False + ) -> None: + """ + Establish the full voice connection. + + Parameters + ---------- + timeout: + The maximum time to wait for the handshake, in seconds. + reconnect: + Whether to attempt reconnection on failure. + self_deaf: + Whether to join self-deafened. + self_mute: + Whether to join self-muted. + + Raises + ------ + RuntimeError + If no shard can be resolved for the guild. + TimeoutError + If the handshake does not complete within ``timeout``. + """ + client = self.voice_client.client + + shard_id = client.get_shard_by_guild_id(self.guild_id) + if shard_id is None: + raise RuntimeError(f"Could not resolve a shard for guild {self.guild_id}") + + shard = client.gateway.get_shard(shard_id) if client.gateway else None + if shard is None: + raise RuntimeError(f"Could not resolve shard {shard_id} for guild {self.guild_id}") + + self._reconnect = reconnect + self._self_mute = self_mute + self._self_deaf = self_deaf + self._closing = False + self._backoff.reset() + + self._state_event.clear() + self._server_event.clear() + self._ready_event.clear() + self._connected_event.clear() + + await shard.change_voice_state( + guild_id=self.guild_id, + channel_id=self.channel_id, + self_mute=self_mute, + self_deaf=self_deaf, + ) + + await asyncio.wait_for(self._wait_for_handshake(), timeout) + + async def _wait_for_handshake(self) -> None: + """ Wait for the gateway handshake then drive the voice socket to connected. """ + await self._state_event.wait() + await self._server_event.wait() + + self.socket = VoiceSocket(self) + await self.socket.connect() + + await self._connected_event.wait() + + def _on_socket_closed(self, close_code: int | None) -> None: + """ + React to the voice websocket closing by deciding whether to reconnect. + + Parameters + ---------- + close_code: + The websocket close code, if one was reported. + """ + if self._closing: + return + + if self._reconnect_task is not None and not self._reconnect_task.done(): + return + + self._reconnect_task = asyncio.create_task( + self._handle_close(close_code), + name=f"discord.http/voice/connection-{self.guild_id}/reconnect" + ) + + async def _handle_close(self, close_code: int | None) -> None: + """ + Drive reconnect/resume logic based on the voice gateway close code. + + Parameters + ---------- + close_code: + The websocket close code, if one was reported. + """ + if close_code in (VoiceCloseCode.disconnected, VoiceCloseCode.call_terminated): + _log.info("Voice connection for guild %s disconnected (code %s); tearing down", self.guild_id, close_code) + await self._teardown_and_remove() + return + + if close_code == VoiceCloseCode.rate_limited: + _log.warning("Voice connection for guild %s was rate limited (code %s); not reconnecting", self.guild_id, close_code) + await self._teardown_and_remove() + return + + if close_code == VoiceCloseCode.voice_server_crashed: + _log.info("Voice server for guild %s crashed (code %s); resuming", self.guild_id, close_code) + await self._resume() + return + + if close_code in (VoiceCloseCode.normal, VoiceCloseCode.going_away): + _log.debug("Voice connection for guild %s closed cleanly (code %s)", self.guild_id, close_code) + await self._teardown_and_remove() + return + + if not self._reconnect: + _log.debug("Voice connection for guild %s closed (code %s); reconnect disabled", self.guild_id, close_code) + await self._teardown_and_remove() + return + + await self._full_reconnect(close_code) + + async def _resume(self) -> None: + """ Re-open the voice websocket and RESUME (op 7) the existing session. """ + if self.socket is not None: + self.socket._request_close() + await self.socket.close() + + try: + self._connected_event.clear() + self.socket = VoiceSocket(self) + await self.socket.connect(resume=True) + except Exception as exc: + _log.warning("Voice resume for guild %s failed; falling back to full reconnect", self.guild_id, exc_info=exc) + await self._full_reconnect(VoiceCloseCode.voice_server_crashed) + + async def _full_reconnect(self, close_code: int | None) -> None: + """ + Re-issue op4 and run a fresh handshake, retrying with exponential backoff. + + Parameters + ---------- + close_code: + The close code that triggered the reconnect, for logging. + """ + if self.socket is not None: + self.socket._request_close() + await self.socket.close() + self.socket = None + + if self.transport is not None: + self.transport.close() + self.transport = None + self.udp = None + + self._backoff.reset() + + for attempt in range(1, MAX_RECONNECT_ATTEMPTS + 1): + if self._closing: + return + + delay = self._backoff.delay() + _log.info( + "Reconnecting voice for guild %s (close code %s), attempt %s/%s in %.2fs", + self.guild_id, close_code, attempt, MAX_RECONNECT_ATTEMPTS, delay + ) + await asyncio.sleep(delay) + + try: + await self.connect( + reconnect=self._reconnect, + self_deaf=self._self_deaf, + self_mute=self._self_mute, + ) + except Exception as exc: + _log.warning("Voice reconnect attempt %s for guild %s failed", attempt, self.guild_id, exc_info=exc) + continue + else: + _log.info("Voice connection for guild %s reconnected", self.guild_id) + return + + _log.error("Voice connection for guild %s could not reconnect after %s attempts; tearing down", self.guild_id, MAX_RECONNECT_ATTEMPTS) + await self._teardown_and_remove() + + async def _teardown_and_remove(self) -> None: + """ Tear down the connection and remove the voice client from the registry. """ + try: + await self.voice_client.disconnect(force=True) + except Exception as exc: + _log.debug("Error during voice teardown for guild %s", self.guild_id, exc_info=exc) + + def on_voice_state_update(self, data: dict) -> None: + """ + Handle a VOICE_STATE_UPDATE for the bot. + + Parameters + ---------- + data: + The raw voice state update payload. + """ + self.session_id = data.get("session_id") or self.session_id + + channel_id = data.get("channel_id") + self.channel_id = int(channel_id) if channel_id is not None else None + + if self.session_id is not None: + self._state_event.set() + + def on_voice_server_update(self, data: dict) -> None: + """ + Handle a VOICE_SERVER_UPDATE for the guild. + + Parameters + ---------- + data: + The raw voice server update payload. + """ + self.token = data.get("token") + + endpoint = data.get("endpoint") + if endpoint: + endpoint = endpoint.removeprefix("wss://").removeprefix("ws://") + endpoint = endpoint.split("/", 1)[0] + endpoint = endpoint.rsplit(":", 1)[0] + self.endpoint = endpoint + + server_id = data.get("guild_id") or data.get("server_id") + self.server_id = int(server_id) if server_id is not None else None + + if self.token is not None and self.endpoint is not None: + self._server_event.set() + + async def on_ready(self, data: dict) -> None: + """ + Handle the voice READY (op 2): set up UDP and select the protocol. + + Parameters + ---------- + data: + The READY payload containing ssrc, ip, port and modes. + """ + self.ssrc = int(data["ssrc"]) + ip = data["ip"] + port = int(data["port"]) + modes = data.get("modes", []) + + self.mode = next((m for m in SUPPORTED_MODES if m in modes), SUPPORTED_MODES[0]) + + self.transport, self.udp = await create_udp(self, ip, port) + + discovered_ip, discovered_port = await self.udp.discover_ip(self.ssrc) + self.endpoint_ip = discovered_ip + self.endpoint_port = discovered_port + + self._ready_event.set() + + if self.socket is not None: + await self.socket.send_select_protocol(discovered_ip, discovered_port, self.mode) + + async def on_session_description(self, data: dict) -> None: + """ + Handle the SESSION_DESCRIPTION (op 4): build the encryptor. + + Parameters + ---------- + data: + The session description payload with the secret key and mode. + """ + secret_key = bytes(data["secret_key"]) + self.secret_key = secret_key + self.mode = data.get("mode", self.mode) + self.encryptor = Encryptor(secret_key) + + dave_version = int(data.get("dave_protocol_version", 0) or 0) + self.dave_protocol_version = dave_version + if dave_version > 0: + await self.reinit_dave_session() + + self._connected_event.set() + + async def on_speaking(self, data: dict) -> None: + """ + Handle a SPEAKING (op 5) frame from another user. + + Parameters + ---------- + data: + The speaking payload with ssrc, user_id and speaking flags. + """ + receiver = self.voice_client._receiver + if receiver is None: + return + + ssrc = data.get("ssrc") + user_id = data.get("user_id") + if ssrc is not None and user_id is not None: + receiver.add_ssrc(int(ssrc), int(user_id)) + + async def on_resumed(self, data: dict) -> None: # noqa: ARG002 + """ + Handle the RESUMED (op 9) frame. + + Parameters + ---------- + data: + The resumed payload. + """ + _log.debug("Voice connection for guild %s resumed", self.guild_id) + + async def on_dave_binary(self, opcode: int, payload: bytes) -> None: + """ + Handle an inbound binary DAVE frame. + + Parameters + ---------- + opcode: + The voice opcode of the binary frame. + payload: + The binary payload following the opcode. + """ + from .dave import has_dave + + if not has_dave: + _log.warning("Received DAVE binary op %s but the davey library is not available", opcode) + return + + if self.dave_session is None: + await self.reinit_dave_session() + + if self.dave_session is not None: + await self.dave_session.handle_binary(opcode, payload) + + async def reinit_dave_session(self) -> None: + """ Create or reset the DAVE session for the negotiated protocol version. """ + from .dave import DaveManager, has_dave + + if not has_dave: + if self.dave_protocol_version > 0: + raise RuntimeError( + "Discord negotiated a DAVE protocol version but the davey library is not installed" + ) + return + + if self.dave_session is None: + self.dave_session = DaveManager(self) + + await self.dave_session.reinit(self.dave_protocol_version) + + def can_encrypt(self) -> bool: + """ Whether a DAVE session is ready to encrypt Opus payloads. """ + return self.dave_session is not None and self.dave_session.ready + + def dave_encrypt_opus(self, opus: bytes) -> bytes: + """ + Encrypt an Opus payload through the DAVE session, if active. + + Parameters + ---------- + opus: + The Opus payload to encrypt. + + Returns + ------- + The DAVE-encrypted Opus payload, or the input unchanged when inactive. + """ + if self.dave_session is None or not self.dave_session.ready: + return opus + return self.dave_session.encrypt_opus(opus) + + def dave_decrypt_opus(self, user_id: int, opus: bytes) -> bytes: + """ + Decrypt an Opus payload through the DAVE session, if active. + + Parameters + ---------- + user_id: + The user ID the payload was received from. + opus: + The Opus payload to decrypt. + + Returns + ------- + The decrypted Opus payload, or the input unchanged when inactive. + """ + if self.dave_session is None or not self.dave_session.ready: + return opus + return self.dave_session.decrypt_opus(user_id, opus) + + async def disconnect(self, *, force: bool = True) -> None: + """ + Tear down the voice connection. + + Parameters + ---------- + force: + Whether to force the disconnect even if the gateway update fails. + """ + self._closing = True + + if self._reconnect_task is not None and self._reconnect_task is not asyncio.current_task(): + self._reconnect_task.cancel() + self._reconnect_task = None + + if self.socket is not None: + self.socket._request_close() + + client = self.voice_client.client + try: + shard_id = client.get_shard_by_guild_id(self.guild_id) + shard = client.gateway.get_shard(shard_id) if (client.gateway and shard_id is not None) else None + if shard is not None: + await shard.change_voice_state(guild_id=self.guild_id, channel_id=None) + except Exception as exc: + if not force: + raise + _log.debug("Failed to send voice disconnect for guild %s", self.guild_id, exc_info=exc) + + if self.socket is not None: + await self.socket.close() + self.socket = None + + if self.transport is not None: + self.transport.close() + self.transport = None + + self.udp = None + self.encryptor = None + self.secret_key = None + self.ssrc = None + self._connected_event.clear() + self._ready_event.clear() + + async def close_transport(self) -> None: + """ + Close the websocket and UDP transport without notifying the gateway. + + Used when the owning shard has been reset or killed, so op4 can no + longer be sent. Cancels any pending reconnect and clears local state. + """ + self._closing = True + + if self._reconnect_task is not None and self._reconnect_task is not asyncio.current_task(): + self._reconnect_task.cancel() + self._reconnect_task = None + + if self.socket is not None: + self.socket._request_close() + await self.socket.close() + self.socket = None + + if self.transport is not None: + self.transport.close() + self.transport = None + + self.udp = None + self.encryptor = None + self.secret_key = None + self.ssrc = None + self._connected_event.clear() + self._ready_event.clear() + + async def move_to(self, channel: "PartialChannel") -> None: + """ + Move the connection to a different voice channel. + + Parameters + ---------- + channel: + The channel to move to. + """ + client = self.voice_client.client + + shard_id = client.get_shard_by_guild_id(self.guild_id) + shard = client.gateway.get_shard(shard_id) if (client.gateway and shard_id is not None) else None + if shard is None: + raise RuntimeError(f"Could not resolve a shard for guild {self.guild_id}") + + self.channel_id = channel.id + await shard.change_voice_state(guild_id=self.guild_id, channel_id=channel.id) diff --git a/discord_http/voice/dave.py b/discord_http/voice/dave.py new file mode 100644 index 0000000..1e68b9e --- /dev/null +++ b/discord_http/voice/dave.py @@ -0,0 +1,395 @@ +import logging + +from typing import TYPE_CHECKING, Any + +try: + import davey # type: ignore[import-not-found] + has_dave = True +except ImportError: + davey = None + has_dave = False + +from .enums import VoiceOp + +if TYPE_CHECKING: + from .connection import VoiceConnection + +# The ``davey`` session object is dynamically typed (the package may be absent and its +# exact API is not statically known here), so it is treated as ``Any`` to keep static +# analysis clean without scattering per-access suppressions. +_Session = Any + +__all__ = ( + "DaveManager", + "has_dave", + "max_protocol_version", +) + +_log = logging.getLogger(__name__) + + +def max_protocol_version() -> int: + """ + The maximum DAVE protocol version this installation can negotiate. + + Returns + ------- + The ``davey.DAVE_PROTOCOL_VERSION`` when the optional ``davey`` package is + installed, otherwise ``0`` (signalling that only transport encryption is available). + """ + if has_dave and davey is not None: + return int(davey.DAVE_PROTOCOL_VERSION) + return 0 + + +class DaveManager: + """ + Manages DAVE (Discord Audio/Video End-to-end Encryption) for a voice connection. + + DAVE layers end-to-end encryption on top of Discord's transport encryption using + the MLS (Messaging Layer Security) protocol, provided by the optional ``davey`` + package. When ``davey`` is not installed, this manager degrades gracefully: every + method becomes a safe no-op and audio is passed through unchanged so that the base + library keeps working with transport encryption only. + + The manager wraps a single ``davey.DaveSession`` and is driven by the voice + connection and socket, which forward the negotiated protocol version and the binary + MLS opcodes (21-31) received from the voice gateway. Outbound MLS messages are sent + back through ``connection.socket.send_binary(opcode, payload)``. + + Parameters + ---------- + connection: + The voice connection this manager operates on. Used to read ``user_id`` and + ``channel_id``, to send binary MLS operations, and to learn the negotiated version. + """ + + __slots__ = ( + "_connection", + "_pending_transition", + "_session", + "_version", + ) + + def __init__(self, connection: "VoiceConnection"): + self._connection: "VoiceConnection" = connection + self._session: _Session | None = None + self._version: int = 0 + self._pending_transition: tuple[int, int] | None = None + + @property + def ready(self) -> bool: + """ Whether the underlying MLS session has completed its handshake. """ + if self._session is not None: + try: + return bool(self._session.ready) + except AttributeError: + return False + return False + + @property + def voice_privacy_code(self) -> str | None: + """ The privacy code users can compare out-of-band to verify the E2EE session, if any. """ + if self._session is not None: + try: + code = self._session.voice_privacy_code + except AttributeError: + return None + return str(code) if code is not None else None + return None + + def can_encrypt(self) -> bool: + """ Whether end-to-end encryption is currently active and usable. """ + return self._version > 0 and self._session is not None and self.ready + + async def reinit(self, version: int) -> None: + """ + Create or reinitialise the MLS session for a given protocol version. + + When ``version`` is greater than ``0`` but ``davey`` is not installed, a clear, + actionable :class:`RuntimeError` is raised telling the user how to install the + optional dependency. When a session is created, its serialized MLS key package is + sent to the voice gateway. + + Parameters + ---------- + version: + The negotiated DAVE protocol version. ``0`` disables end-to-end encryption. + + Raises + ------ + RuntimeError + If a non-zero DAVE version is requested but ``davey`` is not installed. + """ + self._version = version + + if version <= 0: + self._session = None + return + + if not has_dave or davey is None: + raise RuntimeError( + "Discord negotiated DAVE end-to-end encryption " + f"(protocol version {version}), but the optional 'davey' package is not " + 'installed. Install it with: pip install "discord.http[voice]"' + ) + + try: + self._session = davey.DaveSession( + version, + self._connection.user_id, + self._connection.channel_id, + ) + except Exception as exc: + _log.warning("Failed to initialise DAVE session: %s", exc) + self._session = None + return + + await self._send_key_package() + + async def _send_key_package(self) -> None: + """ Serialize and send our MLS key package to the voice gateway. """ + if self._session is None: + return + + try: + key_package = self._session.get_serialized_key_package() + except AttributeError: + return + + await self._connection.socket.send_binary( + int(VoiceOp.dave_mls_key_package), key_package + ) + + def set_passthrough_mode(self, enabled: bool) -> None: + """ + Toggle passthrough mode on the MLS session. + + While in passthrough mode the session does not transform media, allowing audio to + flow during transitions where some participants are not yet on the new epoch. + + Parameters + ---------- + enabled: + ``True`` to pass media through unchanged, ``False`` to resume encryption. + """ + if self._session is not None: + try: + self._session.set_passthrough_mode(enabled) + except AttributeError: + pass + + def encrypt_opus(self, opus: bytes) -> bytes: + """ + End-to-end encrypt an outbound Opus frame. + + Parameters + ---------- + opus: + The plaintext Opus frame. + + Returns + ------- + The encrypted frame, or the input unchanged when E2EE is not active. + """ + if not self.can_encrypt() or self._session is None: + return opus + try: + return bytes(self._session.encrypt_opus(opus)) + except AttributeError: + return opus + + def decrypt_opus(self, user_id: int, opus: bytes) -> bytes: + """ + End-to-end decrypt an inbound Opus frame from a given user. + + Parameters + ---------- + user_id: + The user the frame originated from. + opus: + The received Opus frame. + + Returns + ------- + The decrypted frame, or the input unchanged when E2EE is not active. + """ + if not self.can_encrypt() or self._session is None: + return opus + try: + return bytes(self._session.decrypt_opus(user_id, opus)) + except AttributeError: + return opus + + async def handle_binary(self, opcode: int, payload: bytes) -> None: + """ + Dispatch a binary DAVE/MLS operation (opcodes 21-31) received from the gateway. + + Parameters + ---------- + opcode: + The voice opcode, expected to be one of the DAVE ops 21-31. + payload: + The raw binary payload following the opcode. + """ + if opcode == VoiceOp.dave_prepare_transition: + await self._handle_prepare_transition(payload) + elif opcode == VoiceOp.dave_execute_transition: + await self._handle_execute_transition(payload) + elif opcode == VoiceOp.dave_prepare_epoch: + await self._handle_prepare_epoch(payload) + elif opcode == VoiceOp.dave_mls_external_sender: + self._handle_external_sender(payload) + elif opcode == VoiceOp.dave_mls_proposals: + await self._handle_proposals(payload) + elif opcode == VoiceOp.dave_mls_announce_commit_transition: + await self._handle_commit(payload) + elif opcode == VoiceOp.dave_mls_welcome: + await self._handle_welcome(payload) + else: + _log.debug("Unhandled DAVE binary opcode %s", opcode) + + async def _handle_prepare_transition(self, payload: bytes) -> None: + """ Handle PREPARE_TRANSITION (21): record the pending transition and acknowledge. """ + transition_id, version = self._parse_transition(payload) + self._pending_transition = (transition_id, version) + + if transition_id == 0: + await self._execute_transition(transition_id, version) + else: + await self._connection.socket.send_binary( + int(VoiceOp.dave_transition_ready), self._encode_transition_id(transition_id) + ) + + async def _handle_execute_transition(self, payload: bytes) -> None: + """ Handle EXECUTE_TRANSITION (22): apply the pending version and passthrough state. """ + transition_id, _ = self._parse_transition(payload) + + if self._pending_transition is not None: + pending_id, version = self._pending_transition + if pending_id == transition_id: + await self._execute_transition(transition_id, version) + return + + _log.debug("Received EXECUTE_TRANSITION for unknown transition %s", transition_id) + + async def _execute_transition(self, transition_id: int, version: int) -> None: + """ Apply a transition: switch protocol version and update passthrough mode. """ + self._version = version + self.set_passthrough_mode(version == 0) + self._pending_transition = None + _log.debug("Executed DAVE transition %s to version %s", transition_id, version) + + async def _handle_prepare_epoch(self, payload: bytes) -> None: + """ Handle PREPARE_EPOCH (24): reinitialise the session for a new MLS epoch. """ + _epoch, version = self._parse_transition(payload) + await self.reinit(version) + + def _handle_external_sender(self, payload: bytes) -> None: + """ Handle MLS_EXTERNAL_SENDER (25): register the gateway's external sender. """ + if self._session is not None: + try: + self._session.set_external_sender(payload) + except AttributeError: + pass + + async def _handle_proposals(self, payload: bytes) -> None: + """ Handle MLS_PROPOSALS (27): process proposals and forward any commit/welcome. """ + if self._session is None: + return + + try: + result = self._session.process_proposals(payload) + except AttributeError: + return + except Exception as exc: + _log.warning("Failed to process MLS proposals: %s", exc) + await self._recover_from_invalid_commit() + return + + if result is None: + return + + commit_welcome = self._extract_commit_welcome(result) + if commit_welcome is not None: + await self._connection.socket.send_binary( + int(VoiceOp.dave_mls_commit_welcome), commit_welcome + ) + + async def _handle_commit(self, payload: bytes) -> None: + """ Handle MLS_ANNOUNCE_COMMIT_TRANSITION (29): apply the announced commit. """ + if self._session is None: + return + + try: + self._session.process_commit(payload) + except AttributeError: + return + except Exception as exc: + _log.warning("Failed to process MLS commit: %s", exc) + await self._recover_from_invalid_commit() + + async def _handle_welcome(self, payload: bytes) -> None: + """ Handle MLS_WELCOME (30): join the group from the received welcome message. """ + if self._session is None: + return + + try: + self._session.process_welcome(payload) + except AttributeError: + return + except Exception as exc: + _log.warning("Failed to process MLS welcome: %s", exc) + await self._recover_from_invalid_commit() + + async def _recover_from_invalid_commit(self) -> None: + """ Notify the gateway of an invalid commit/welcome and reinitialise the session. """ + await self._connection.socket.send_binary( + int(VoiceOp.dave_mls_invalid_commit_welcome), b"" + ) + await self.reinit(self._version) + + @staticmethod + def _extract_commit_welcome(result: _Session) -> bytes | None: + """ + Extract serialized commit/welcome bytes from a ``davey.CommitWelcome`` result. + + The exact ``davey`` API may differ; this tolerates a serialize method, raw bytes, + or a ``None`` result. + + Parameters + ---------- + result: + The (dynamically typed) value returned by ``davey``'s proposal processing. + + Returns + ------- + The serialized commit/welcome bytes, or ``None`` when there is nothing to send. + """ + if result is None: + return None + if isinstance(result, (bytes, bytearray)): + return bytes(result) + if hasattr(result, "serialize"): + try: + return bytes(result.serialize()) + except Exception: + return None + return None + + @staticmethod + def _parse_transition(payload: bytes) -> tuple[int, int]: + """ + Parse a transition payload into ``(transition_id, version)``. + + Transition payloads carry a 2-byte big-endian transition id optionally followed by + a 1-byte protocol version. Missing fields default to ``0``. + """ + transition_id = int.from_bytes(payload[:2], "big") if len(payload) >= 2 else 0 + version = payload[2] if len(payload) >= 3 else 0 + return transition_id, version + + @staticmethod + def _encode_transition_id(transition_id: int) -> bytes: + """ Encode a transition id as a 2-byte big-endian payload for TRANSITION_READY. """ + return transition_id.to_bytes(2, "big") diff --git a/discord_http/voice/encryptor.py b/discord_http/voice/encryptor.py new file mode 100644 index 0000000..8fc8353 --- /dev/null +++ b/discord_http/voice/encryptor.py @@ -0,0 +1,82 @@ +import struct + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +__all__ = ( + "Encryptor", +) + + +class Encryptor: + """ + Handles Discord voice transport encryption. + + Implements the ``aead_aes256_gcm_rtpsize`` mode, which encrypts the RTP + payload with AES-256-GCM while authenticating the unencrypted RTP header + as additional authenticated data (AAD). + """ + + __slots__ = ( + "_aead", + "_nonce", + ) + + def __init__(self, secret_key: bytes): + self._aead = AESGCM(bytes(secret_key)) + self._nonce = 0 + + @property + def mode(self) -> str: + """ The encryption mode implemented by this encryptor. """ + return "aead_aes256_gcm_rtpsize" + + def encrypt(self, header: bytes, plaintext: bytes) -> bytes: + """ + Encrypt an RTP payload. + + Parameters + ---------- + header: + The unencrypted RTP header, used as the additional authenticated data. + plaintext: + The payload to encrypt, usually an Opus frame. + + Returns + ------- + The packet, consisting of the header, the ciphertext, and the 4-byte big-endian nonce counter. + """ + nonce = self._nonce + nonce_bytes = struct.pack(">I", nonce) + b"\x00" * 8 + ciphertext = self._aead.encrypt(nonce_bytes, plaintext, header) + + self._nonce = (self._nonce + 1) & 0xFFFFFFFF + + return header + ciphertext + struct.pack(">I", nonce) + + def decrypt(self, packet: bytes) -> bytes: + """ + Decrypt a received RTP packet. + + Parameters + ---------- + packet: + The full received packet, including the header, ciphertext, and trailing nonce counter. + + Returns + ------- + The decrypted payload, usually an Opus frame. + """ + nonce_bytes = packet[-4:] + b"\x00" * 8 + + offset = 12 + csrc_count = packet[0] & 0x0F + offset += csrc_count * 4 + + if packet[0] & 0x10: + length = struct.unpack(">H", packet[offset + 2:offset + 4])[0] + offset += 4 + length * 4 + + header = packet[:offset] + ciphertext = packet[offset:-4] + + return self._aead.decrypt(nonce_bytes, ciphertext, header) diff --git a/discord_http/voice/enums.py b/discord_http/voice/enums.py new file mode 100644 index 0000000..5baafec --- /dev/null +++ b/discord_http/voice/enums.py @@ -0,0 +1,36 @@ +from ..enums import BaseEnum + +__all__ = ( + "SUPPORTED_MODES", + "VoiceOp", +) + +SUPPORTED_MODES: tuple[str, ...] = ("aead_aes256_gcm_rtpsize",) + + +class VoiceOp(BaseEnum): + """ Represents the opcode type of a voice gateway payload. """ + identify = 0 + select_protocol = 1 + ready = 2 + heartbeat = 3 + session_description = 4 + speaking = 5 + heartbeat_ack = 6 + resume = 7 + hello = 8 + resumed = 9 + clients_connect = 11 + client_connect = 12 + client_disconnect = 13 + dave_prepare_transition = 21 + dave_execute_transition = 22 + dave_transition_ready = 23 + dave_prepare_epoch = 24 + dave_mls_external_sender = 25 + dave_mls_key_package = 26 + dave_mls_proposals = 27 + dave_mls_commit_welcome = 28 + dave_mls_announce_commit_transition = 29 + dave_mls_welcome = 30 + dave_mls_invalid_commit_welcome = 31 diff --git a/discord_http/voice/gateway_udp.py b/discord_http/voice/gateway_udp.py new file mode 100644 index 0000000..3e9dd17 --- /dev/null +++ b/discord_http/voice/gateway_udp.py @@ -0,0 +1,162 @@ +import asyncio +import logging +import struct + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .connection import VoiceConnection + +__all__ = ("VoiceUDPProtocol",) + +_log = logging.getLogger(__name__) + + +class VoiceUDPProtocol(asyncio.DatagramProtocol): + """ + The UDP transport protocol for Discord voice. + + Handles IP discovery, routes inbound RTP packets to the receiver, and + drops RTCP control traffic. + """ + + def __init__(self, connection: "VoiceConnection"): + self.connection: "VoiceConnection" = connection + """ The voice connection that owns this protocol. """ + + self.transport: asyncio.DatagramTransport | None = None + """ The UDP datagram transport, if connected. """ + + self._discovery_future: asyncio.Future[bytes] | None = None + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + """ + Store the transport once the endpoint is created. + + Parameters + ---------- + transport: + The datagram transport for this protocol. + """ + self.transport = transport # type: ignore[assignment] + + def error_received(self, exc: Exception) -> None: + """ + Log a transport-level error. + + Parameters + ---------- + exc: + The exception reported by the transport. + """ + _log.warning("Voice UDP error for guild %s: %s", self.connection.guild_id, exc) + + def connection_lost(self, exc: Exception | None) -> None: + """ + Handle the transport being closed. + + Parameters + ---------- + exc: + The exception that caused the loss, if any. + """ + if exc is not None: + _log.debug("Voice UDP connection lost for guild %s", self.connection.guild_id, exc_info=exc) + self.transport = None + + def datagram_received(self, data: bytes, addr: tuple) -> None: # noqa: ARG002 + """ + Route an inbound datagram to the right consumer. + + Parameters + ---------- + data: + The raw datagram payload. + addr: + The source address of the datagram. + """ + if len(data) < 2: + return + + # IP discovery response (type 0x0002 in the second byte). + if data[1] == 0x02 and self._discovery_future is not None and not self._discovery_future.done(): + self._discovery_future.set_result(data) + return + + # Drop RTCP control packets (payload types 200-204). + payload_type = data[1] & 0x7F + if 200 <= payload_type <= 204: + return + + # Otherwise treat it as RTP and hand it to the receiver, if any. + receiver = self.connection.voice_client._receiver + if receiver is not None: + receiver.unpack(data) + + async def discover_ip(self, ssrc: int) -> tuple[str, int]: + """ + Perform IP discovery to learn this client's external address. + + Parameters + ---------- + ssrc: + The SSRC assigned by the voice gateway. + + Returns + ------- + The externally visible IP address and UDP port. + + Raises + ------ + RuntimeError + If the transport is not available. + """ + if self.transport is None: + raise RuntimeError("UDP transport is not available for IP discovery") + + loop = asyncio.get_running_loop() + self._discovery_future = loop.create_future() + + request = struct.pack(">HHI", 0x1, 70, ssrc) + b"\x00" * 66 + self.transport.sendto(request) + + try: + data = await asyncio.wait_for(self._discovery_future, timeout=20.0) + finally: + self._discovery_future = None + + # External IP is a null-terminated string starting at offset 8. + ip_end = data.index(0, 8) + ip = data[8:ip_end].decode("ascii") + port = struct.unpack_from(">H", data, len(data) - 2)[0] + + return ip, port + + +async def create_udp( + connection: "VoiceConnection", + ip: str, + port: int +) -> tuple[asyncio.DatagramTransport, VoiceUDPProtocol]: + """ + Create a connected UDP datagram endpoint for voice. + + Parameters + ---------- + connection: + The voice connection that owns the new protocol. + ip: + The voice server IP address to connect to. + port: + The voice server UDP port to connect to. + + Returns + ------- + The datagram transport and its protocol. + """ + loop = asyncio.get_running_loop() + transport, protocol = await loop.create_datagram_endpoint( + lambda: VoiceUDPProtocol(connection), + remote_addr=(ip, port), + ) + return transport, protocol diff --git a/discord_http/voice/oggparse.py b/discord_http/voice/oggparse.py new file mode 100644 index 0000000..82473d7 --- /dev/null +++ b/discord_http/voice/oggparse.py @@ -0,0 +1,205 @@ +import struct + +from collections.abc import Iterator +from typing import IO + +__all__ = ( + "OggPage", + "OggStream", +) + +# 4-byte capture pattern that begins every Ogg page. +_OGG_MAGIC = b"OggS" + +# Fixed header layout that follows the 4-byte capture pattern, little-endian: +# x - version (1 byte, ignored, must be 0) +# B - header_type (1 byte) +# Q - granule_position (8 bytes, signed treated as unsigned here) +# I - bitstream_serial_number (4 bytes) +# I - page_sequence_number (4 bytes) +# I - CRC checksum (4 bytes) +# B - page_segments (1 byte, number of segments N) +# This is exactly 23 bytes, mirroring discord.py's well-known approach. +_HEADER_STRUCT = struct.Struct(" None: + header = stream.read(_HEADER_STRUCT.size) + if len(header) < _HEADER_STRUCT.size: + raise ValueError("Incomplete Ogg page header") + + ( + self.header_type, + self.granule_position, + self.bitstream_serial_number, + self.page_sequence_number, + self.crc_checksum, + page_segments, + ) = _HEADER_STRUCT.unpack(header) + + self.segtable = stream.read(page_segments) + if len(self.segtable) < page_segments: + raise ValueError("Incomplete Ogg page segment table") + + body_length = sum(self.segtable) + self.data = stream.read(body_length) + if len(self.data) < body_length: + raise ValueError("Incomplete Ogg page body") + + def iter_packets(self) -> Iterator[tuple[bytes, bool]]: + """ + Yield the packet chunks contained in this single page. + + Each yielded tuple is ``(packet_bytes, complete)`` where ``complete`` is + ``True`` when the accumulated chunk terminates a packet within this page + (the lacing value was ``0-254``) and ``False`` when the packet continues + into the next page (the final lacing value was exactly ``255``). + + Yields + ------ + tuple[bytes, bool] + A chunk of packet data and whether it completes the packet. + """ + offset = 0 + partial = bytearray() + + for lacing in self.segtable: + chunk = self.data[offset:offset + lacing] + offset += lacing + partial += chunk + + if lacing < 255: + yield bytes(partial), True + partial = bytearray() + + # A trailing run of 255s means the packet spills into the next page. + if partial: + yield bytes(partial), False + + +class OggStream: + """ + A reader that extracts raw Opus packets from an Ogg/Opus byte stream. + + The stream is scanned for ``b"OggS"`` capture patterns; each page is parsed + into an :class:`OggPage` and packets are reassembled across page boundaries + according to the Ogg lacing rules (a trailing lacing value of ``255`` + continues a packet into the following page). + + Notes + ----- + The first two packets of a standard Ogg/Opus stream are the ``OpusHead`` and + ``OpusTags`` metadata headers. They are yielded as-is; consumers that only + want audio frames should skip any packet starting with ``b"OpusHead"`` or + ``b"OpusTags"``. They are never silently dropped. + """ + + __slots__ = ("stream",) + + def __init__(self, stream: IO[bytes]) -> None: + self.stream = stream + + def _find_next_page(self) -> bool: + """ + Advance the stream to just after the next ``b"OggS"`` capture pattern. + + Returns + ------- + bool + ``True`` if a capture pattern was found, ``False`` at end of stream. + """ + head = self.stream.read(4) + if head == _OGG_MAGIC: + return True + + # Slide a 4-byte window forward one byte at a time until the magic is + # found or the stream is exhausted. + while True: + byte = self.stream.read(1) + if not byte: + return False + + head = head[1:] + byte + if head == _OGG_MAGIC: + return True + + def iter_pages(self) -> Iterator[OggPage]: + """ + Yield each :class:`OggPage` found in the stream, in order. + + Yields + ------ + OggPage + The next parsed page. + """ + while self._find_next_page(): + yield OggPage(self.stream) + + def iter_packets(self) -> Iterator[bytes]: + """ + Yield fully reassembled Opus packets across page boundaries. + + Packets split by ``255`` lacing values, both within a page and across + pages, are concatenated before being yielded. + + Yields + ------ + bytes + A complete Opus packet, including the ``OpusHead``/``OpusTags`` + header packets at the start of the stream. + """ + partial = bytearray() + + for page in self.iter_pages(): + for chunk, complete in page.iter_packets(): + partial += chunk + if complete: + yield bytes(partial) + partial = bytearray() + + # A stream that ends on a 255-run is malformed, but flush whatever we + # accumulated rather than silently discarding trailing data. + if partial: + yield bytes(partial) diff --git a/discord_http/voice/opus.py b/discord_http/voice/opus.py new file mode 100644 index 0000000..d66f782 --- /dev/null +++ b/discord_http/voice/opus.py @@ -0,0 +1,598 @@ +import ctypes +import ctypes.util +import logging + +__all__ = ( + "OPUS_APPLICATION_AUDIO", + "OPUS_APPLICATION_LOWDELAY", + "OPUS_APPLICATION_VOIP", + "OPUS_SILENCE", + "SAMPLES_PER_FRAME", + "SAMPLE_RATE", + "Decoder", + "Encoder", + "OpusError", + "OpusNotLoaded", + "is_loaded", + "load_opus", +) + +_log = logging.getLogger(__name__) + + +# Audio constants, fixed for Discord voice (48kHz stereo, 20ms frames). +SAMPLE_RATE = 48000 +""" The sample rate Discord expects, in Hz. """ + +CHANNELS = 2 +""" The number of audio channels Discord expects (stereo). """ + +FRAME_LENGTH = 20 +""" The length of a single audio frame, in milliseconds. """ + +SAMPLES_PER_FRAME = SAMPLE_RATE // 1000 * FRAME_LENGTH +""" The number of samples per channel in a single 20ms frame (960). """ + +SAMPLE_SIZE = 2 +""" The size of a single sample, in bytes (signed 16-bit). """ + +FRAME_SIZE = SAMPLES_PER_FRAME * CHANNELS * SAMPLE_SIZE +""" The size of a decoded 20ms PCM frame, in bytes (s16le, stereo). """ + +OPUS_SILENCE = b"\xf8\xff\xfe" +""" The magic Opus frame that encodes silence. """ + + +# Opus application types. +OPUS_APPLICATION_VOIP = 2048 +OPUS_APPLICATION_AUDIO = 2049 +OPUS_APPLICATION_LOWDELAY = 2051 + +# Opus CTL request constants. +OPUS_SET_BITRATE_REQUEST = 4002 +OPUS_SET_BANDWIDTH_REQUEST = 4008 +OPUS_SET_INBAND_FEC_REQUEST = 4012 +OPUS_SET_PACKET_LOSS_PERC_REQUEST = 4014 +OPUS_SET_SIGNAL_REQUEST = 4024 + +# Opus value constants for CTL requests. +OPUS_AUTO = -1000 +OPUS_SIGNAL_VOICE = 3001 +OPUS_SIGNAL_MUSIC = 3002 +OPUS_BANDWIDTH_FULLBAND = 1105 + +# Opus error codes (used when raising OpusError). +OPUS_OK = 0 + +# Named aliases that the public set_* helpers accept. +_BANDWIDTHS: dict[str, int] = { + "auto": OPUS_AUTO, + "fullband": OPUS_BANDWIDTH_FULLBAND, +} + +_SIGNALS: dict[str, int] = { + "auto": OPUS_AUTO, + "voice": OPUS_SIGNAL_VOICE, + "music": OPUS_SIGNAL_MUSIC, +} + + +class OpusError(Exception): + """ Raised when libopus returns an error code. """ + + +class OpusNotLoaded(Exception): # noqa: N818 + """ Raised when an Opus operation is attempted but libopus is not available. """ + + +# Opaque handle types. libopus only ever hands these back as pointers. +EncoderStruct = ctypes.c_void_p +DecoderStruct = ctypes.c_void_p + +# Module-level loader state. ``_lib`` is the loaded CDLL or None, and +# ``_loaded`` is a sentinel so a failed/empty lazy search is not repeated. +_lib: ctypes.CDLL | None = None +_loaded: bool = False + + +def _configure_lib(lib: ctypes.CDLL) -> None: + """ + Configure the ``argtypes`` and ``restype`` of every function we bind. + + Parameters + ---------- + lib: + The freshly loaded libopus shared library. + """ + lib.opus_strerror.argtypes = [ctypes.c_int] + lib.opus_strerror.restype = ctypes.c_char_p + + lib.opus_encoder_get_size.argtypes = [ctypes.c_int] + lib.opus_encoder_get_size.restype = ctypes.c_int + + lib.opus_encoder_create.argtypes = [ + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.POINTER(ctypes.c_int), + ] + lib.opus_encoder_create.restype = EncoderStruct + + # ``opus_encoder_ctl`` is variadic; arguments are supplied per-call. + lib.opus_encoder_ctl.restype = ctypes.c_int + + lib.opus_encode.argtypes = [ + EncoderStruct, + ctypes.POINTER(ctypes.c_int16), + ctypes.c_int, + ctypes.POINTER(ctypes.c_ubyte), + ctypes.c_int32, + ] + lib.opus_encode.restype = ctypes.c_int32 + + lib.opus_encoder_destroy.argtypes = [EncoderStruct] + lib.opus_encoder_destroy.restype = None + + lib.opus_decoder_get_size.argtypes = [ctypes.c_int] + lib.opus_decoder_get_size.restype = ctypes.c_int + + lib.opus_decoder_create.argtypes = [ + ctypes.c_int, + ctypes.c_int, + ctypes.POINTER(ctypes.c_int), + ] + lib.opus_decoder_create.restype = DecoderStruct + + # ``opus_decoder_ctl`` is variadic; arguments are supplied per-call. + lib.opus_decoder_ctl.restype = ctypes.c_int + + lib.opus_decode.argtypes = [ + DecoderStruct, + ctypes.POINTER(ctypes.c_ubyte), + ctypes.c_int32, + ctypes.POINTER(ctypes.c_int16), + ctypes.c_int, + ctypes.c_int, + ] + lib.opus_decode.restype = ctypes.c_int + + lib.opus_decoder_destroy.argtypes = [DecoderStruct] + lib.opus_decoder_destroy.restype = None + + lib.opus_packet_get_nb_frames.argtypes = [ctypes.POINTER(ctypes.c_ubyte), ctypes.c_int] + lib.opus_packet_get_nb_frames.restype = ctypes.c_int + + lib.opus_packet_get_samples_per_frame.argtypes = [ctypes.POINTER(ctypes.c_ubyte), ctypes.c_int] + lib.opus_packet_get_samples_per_frame.restype = ctypes.c_int + + +def load_opus(name: str | None = None) -> None: + """ + Load libopus and configure all of its bindings. + + Parameters + ---------- + name: + An explicit path or library name to load. When omitted, the system + library is located with :func:`ctypes.util.find_library`. + + Raises + ------ + OpusError + If an explicit ``name`` was given but the library could not be loaded. + """ + global _lib, _loaded # noqa: PLW0603 + + _loaded = True + + location = name + if location is None: + location = ctypes.util.find_library("opus") + + if location is None: + # No library on the system; this is not fatal. The passthrough/E2EE + # paths can still operate without libopus present. + _lib = None + _log.debug("libopus could not be located; Opus encode/decode is unavailable") + return + + try: + lib = ctypes.CDLL(location) + _configure_lib(lib) + except (OSError, AttributeError) as exc: + _lib = None + if name is not None: + # The caller explicitly asked for this library, so surface failure. + raise OpusError(f"Could not load libopus from {location!r}") from exc + _log.warning("Found libopus at %r but failed to load it: %s", location, exc) + return + + _lib = lib + _log.debug("Successfully loaded libopus from %r", location) + + +def is_loaded() -> bool: + """ + Whether libopus is currently loaded. + + A lazy :func:`load_opus` is attempted once if loading has not yet been + attempted. A missing system library results in ``False`` rather than an + exception. + + Returns + ------- + ``True`` if libopus is loaded and ready, ``False`` otherwise. + """ + if not _loaded: + load_opus() + + return _lib is not None + + +def _get_lib() -> ctypes.CDLL: + """ + Return the loaded libopus library, loading it lazily if needed. + + Returns + ------- + The loaded libopus shared library. + + Raises + ------ + OpusNotLoaded + If libopus could not be found or loaded. + """ + if not _loaded: + load_opus() + + if _lib is None: + raise OpusNotLoaded("libopus is not loaded; install the Opus shared library to use voice encode/decode") + + return _lib + + +def _strerror(code: int) -> str: + """ + Resolve a libopus error code into a human-readable string. + + Parameters + ---------- + code: + The negative error code returned by a libopus call. + + Returns + ------- + The decoded error string. + """ + lib = _lib + if lib is None: + return f"error code {code}" + + message: bytes | None = lib.opus_strerror(code) + if message is None: + return f"error code {code}" + + return message.decode("utf-8", "replace") + + +def _as_ubyte_ptr(data: bytes) -> "ctypes._Pointer[ctypes.c_ubyte]": + """ + Copy ``data`` into a ctypes buffer and return a ``c_ubyte`` pointer to it. + + The caller must keep a reference to ``data`` alive only for the duration of + the libopus call; libopus does not retain the pointer. + + Parameters + ---------- + data: + The bytes to expose to libopus. + + Returns + ------- + A pointer to the start of a mutable copy of ``data``. + """ + buffer = ctypes.create_string_buffer(bytes(data), len(data)) + return ctypes.cast(buffer, ctypes.POINTER(ctypes.c_ubyte)) + + +def _as_int16_ptr(data: bytes) -> "ctypes._Pointer[ctypes.c_int16]": + """ + Copy ``data`` into a ctypes buffer and return a ``c_int16`` pointer to it. + + Parameters + ---------- + data: + The raw PCM bytes to expose to libopus. + + Returns + ------- + A pointer to the start of a mutable copy of ``data``. + """ + buffer = ctypes.create_string_buffer(bytes(data), len(data)) + return ctypes.cast(buffer, ctypes.POINTER(ctypes.c_int16)) + + +def _check(code: int) -> int: + """ + Raise :class:`OpusError` if ``code`` indicates a libopus failure. + + Parameters + ---------- + code: + The integer return value of a libopus call. + + Returns + ------- + The original ``code`` when it is non-negative. + + Raises + ------ + OpusError + If ``code`` is negative. + """ + if code < OPUS_OK: + raise OpusError(_strerror(code)) + + return code + + +class Encoder: + """ A libopus encoder configured for Discord voice (48kHz, stereo). """ + + def __init__(self, application: int = OPUS_APPLICATION_AUDIO): + # Set first so __del__ is safe even if _get_lib() raises below. + self._state: int = 0 + self._lib: ctypes.CDLL = _get_lib() + self.application = application + """ The Opus application type the encoder was created with. """ + + error = ctypes.c_int() + state: int = self._lib.opus_encoder_create( + ctypes.c_int(SAMPLE_RATE), + ctypes.c_int(CHANNELS), + ctypes.c_int(application), + ctypes.byref(error), + ) + + _check(error.value) + self._state = state + + def __del__(self) -> None: + self.cleanup() + + def _ctl(self, request: int, value: int) -> int: + """ + Issue a CTL request to the encoder. + + Parameters + ---------- + request: + The CTL request constant. + value: + The integer value to set. + + Returns + ------- + The libopus return code. + """ + return _check(self._lib.opus_encoder_ctl(self._state, ctypes.c_int(request), ctypes.c_int(value))) + + def set_bitrate(self, kbps: int) -> None: + """ + Set the target bitrate. + + Parameters + ---------- + kbps: + The target bitrate in kilobits per second. + """ + clamped = min(512, max(16, kbps)) + self._ctl(OPUS_SET_BITRATE_REQUEST, clamped * 1024) + + def set_fec(self, enabled: bool) -> None: + """ + Enable or disable in-band forward error correction. + + Parameters + ---------- + enabled: + Whether FEC should be enabled. + """ + self._ctl(OPUS_SET_INBAND_FEC_REQUEST, 1 if enabled else 0) + + def set_expected_packet_loss_percent(self, pct: float) -> None: + """ + Set the expected packet-loss percentage used to tune FEC. + + Parameters + ---------- + pct: + The expected packet loss, as a fraction between 0 and 1. + """ + value = min(100, max(0, int(pct * 100))) + self._ctl(OPUS_SET_PACKET_LOSS_PERC_REQUEST, value) + + def set_bandwidth(self, name: str) -> None: + """ + Set the encoder bandwidth. + + Parameters + ---------- + name: + The bandwidth name, one of ``auto`` or ``fullband``. + + Raises + ------ + KeyError + If ``name`` is not a recognised bandwidth. + """ + self._ctl(OPUS_SET_BANDWIDTH_REQUEST, _BANDWIDTHS[name]) + + def set_signal_type(self, name: str) -> None: + """ + Set the signal type hint. + + Parameters + ---------- + name: + The signal type name, one of ``auto``, ``voice`` or ``music``. + + Raises + ------ + KeyError + If ``name`` is not a recognised signal type. + """ + self._ctl(OPUS_SET_SIGNAL_REQUEST, _SIGNALS[name]) + + def encode(self, pcm: bytes, frame_size: int = SAMPLES_PER_FRAME) -> bytes: + """ + Encode a single frame of PCM audio into an Opus packet. + + Parameters + ---------- + pcm: + The raw signed 16-bit little-endian stereo PCM data. + frame_size: + The number of samples per channel in the frame. + + Returns + ------- + The encoded Opus packet. + + Raises + ------ + OpusError + If libopus fails to encode the frame. + """ + max_data_bytes = len(pcm) + pcm_ptr = _as_int16_ptr(pcm) + output = (ctypes.c_ubyte * max_data_bytes)() + + result: int = self._lib.opus_encode( + self._state, + pcm_ptr, + ctypes.c_int(frame_size), + ctypes.cast(output, ctypes.POINTER(ctypes.c_ubyte)), + ctypes.c_int32(max_data_bytes), + ) + + _check(result) + + return bytes(output[:result]) + + def cleanup(self) -> None: + """ Free the underlying libopus encoder. """ + if self._state: + self._lib.opus_encoder_destroy(self._state) + self._state = 0 + + +class Decoder: + """ A libopus decoder configured for Discord voice (48kHz, stereo). """ + + def __init__(self): + # Set first so __del__ is safe even if _get_lib() raises below. + self._state: int = 0 + self._lib: ctypes.CDLL = _get_lib() + + error = ctypes.c_int() + state: int = self._lib.opus_decoder_create( + ctypes.c_int(SAMPLE_RATE), + ctypes.c_int(CHANNELS), + ctypes.byref(error), + ) + + _check(error.value) + self._state = state + + def __del__(self) -> None: + self.cleanup() + + @staticmethod + def packet_get_nb_frames(data: bytes) -> int: + """ + Return the number of frames in an Opus packet. + + Parameters + ---------- + data: + The Opus packet to inspect. + + Returns + ------- + The number of frames the packet contains. + """ + lib = _get_lib() + data_ptr = _as_ubyte_ptr(data) + return _check(lib.opus_packet_get_nb_frames(data_ptr, ctypes.c_int(len(data)))) + + @staticmethod + def packet_get_samples_per_frame(data: bytes) -> int: + """ + Return the number of samples per frame for an Opus packet. + + Parameters + ---------- + data: + The Opus packet to inspect. + + Returns + ------- + The number of samples per channel in each frame. + """ + lib = _get_lib() + data_ptr = _as_ubyte_ptr(data) + return _check(lib.opus_packet_get_samples_per_frame(data_ptr, ctypes.c_int(SAMPLE_RATE))) + + def decode(self, data: bytes | None, *, fec: bool = False) -> bytes: + """ + Decode an Opus packet into PCM audio. + + Parameters + ---------- + data: + The Opus packet to decode, or ``None`` to perform packet-loss + concealment for a single 20ms frame. + fec: + Whether to decode using forward error correction. + + Returns + ------- + The decoded signed 16-bit little-endian stereo PCM data. + + Raises + ------ + OpusError + If libopus fails to decode the packet. + """ + if data is None: + frame_size = SAMPLES_PER_FRAME + data_ptr: "ctypes._Pointer[ctypes.c_ubyte] | None" = None + data_len = 0 + else: + frames = self.packet_get_nb_frames(data) + samples_per_frame = self.packet_get_samples_per_frame(data) + frame_size = frames * samples_per_frame + data_ptr = _as_ubyte_ptr(data) + data_len = len(data) + + pcm = (ctypes.c_int16 * (frame_size * CHANNELS))() + + result: int = self._lib.opus_decode( + self._state, + data_ptr, + ctypes.c_int32(data_len), + ctypes.cast(pcm, ctypes.POINTER(ctypes.c_int16)), + ctypes.c_int(frame_size), + ctypes.c_int(1 if fec else 0), + ) + + _check(result) + + return bytes(bytearray(pcm)[: result * CHANNELS * SAMPLE_SIZE]) + + def cleanup(self) -> None: + """ Free the underlying libopus decoder. """ + if self._state: + self._lib.opus_decoder_destroy(self._state) + self._state = 0 diff --git a/discord_http/voice/player.py b/discord_http/voice/player.py new file mode 100644 index 0000000..94b274d --- /dev/null +++ b/discord_http/voice/player.py @@ -0,0 +1,692 @@ +import abc +import asyncio +import io +import logging +import os +import shutil + +from array import array +from collections.abc import AsyncIterable, Callable +from typing import TYPE_CHECKING + +from .oggparse import OggPage +from .opus import OPUS_SILENCE + +if TYPE_CHECKING: + from .client import VoiceClient + +__all__ = ( + "AudioPlayer", + "AudioSource", + "FFmpegOpusAudio", + "FFmpegPCMAudio", + "PCMAudio", + "PCMVolumeTransformer", +) + +_log = logging.getLogger(__name__) + + +# The size of a single 20ms PCM frame, in bytes (48kHz, stereo, s16le). +FRAME_SIZE = 3840 + +# The number of bytes pulled from ffmpeg stdout per read when parsing Ogg/Opus. +_OGG_READ_CHUNK = 8192 + +# The 4-byte capture pattern that begins every Ogg page. +_OGG_MAGIC = b"OggS" + +# The fixed-size Ogg page header that follows the capture pattern (see oggparse). +_OGG_HEADER_SIZE = 23 + +# The signed 16-bit value range, used when clamping scaled PCM samples. +_INT16_MIN = -32768 +_INT16_MAX = 32767 + + +class AudioSource(abc.ABC): + """ + An abstract base class for an audio source. + + A source yields raw audio one 20ms frame at a time via :meth:`read`. When + :meth:`is_opus` returns ``True`` each frame is a complete Opus packet ready + to be sent on the wire; otherwise each frame is 3840 bytes of signed 16-bit + little-endian PCM (48kHz, stereo) which the player encodes before sending. + """ + + @abc.abstractmethod + async def read(self) -> bytes: + """ + Read the next 20ms frame of audio. + + Returns + ------- + bytes + A single Opus packet when :meth:`is_opus` is ``True``, otherwise + exactly 3840 bytes of s16le PCM. Empty bytes signal end of stream. + """ + raise NotImplementedError + + def is_opus(self) -> bool: + """ + Whether :meth:`read` yields pre-encoded Opus packets. + + Returns + ------- + bool + ``True`` if frames are Opus packets, ``False`` if they are PCM. + """ + return False + + def cleanup(self) -> None: + """ Release any resources held by the source. """ + return + + +class PCMAudio(AudioSource): + """ + An audio source that reads raw PCM frames from a binary stream. + + The stream must contain signed 16-bit little-endian PCM at 48kHz in stereo. + + Parameters + ---------- + stream: + A readable binary stream of s16le PCM data. + """ + + def __init__(self, stream: io.IOBase) -> None: + self.stream = stream + + async def read(self) -> bytes: + """ Read one 3840-byte PCM frame, or empty bytes at end of stream. """ + ret = self.stream.read(FRAME_SIZE) + if len(ret) != FRAME_SIZE: + return b"" + return ret + + +class PCMVolumeTransformer(AudioSource): + """ + A wrapper that scales the volume of a PCM (non-Opus) audio source. + + Samples are scaled with the stdlib :mod:`array` module because ``audioop`` + is removed in Python 3.13. Each signed 16-bit sample is multiplied by the + volume factor and clamped back into the int16 range. + + Parameters + ---------- + original: + The PCM source to wrap. Its :meth:`AudioSource.is_opus` must be ``False``. + volume: + The initial volume multiplier, where ``1.0`` is unchanged. + + Raises + ------ + TypeError + If ``original`` is not an :class:`AudioSource`. + ValueError + If ``original`` yields Opus packets rather than PCM. + """ + + def __init__(self, original: AudioSource, volume: float = 1.0) -> None: + if not isinstance(original, AudioSource): + raise TypeError(f"Expected AudioSource, got {type(original).__name__}") + if original.is_opus(): + raise ValueError("PCMVolumeTransformer only supports non-Opus sources") + + self.original = original + self._volume = max(volume, 0.0) + + @property + def volume(self) -> float: + """ The volume multiplier, where ``1.0`` is unchanged. """ + return self._volume + + @volume.setter + def volume(self, value: float) -> None: + self._volume = max(value, 0.0) + + async def read(self) -> bytes: + """ Read one frame from the wrapped source with volume applied. """ + data = await self.original.read() + if not data: + return b"" + + samples = array("h") + samples.frombytes(data) + for i, sample in enumerate(samples): + scaled = int(sample * self._volume) + samples[i] = min(max(scaled, _INT16_MIN), _INT16_MAX) + + return samples.tobytes() + + def cleanup(self) -> None: + """ Clean up the wrapped source. """ + self.original.cleanup() + + +class _FFmpegAudio(AudioSource): + """ + A base audio source backed by an ``ffmpeg`` subprocess. + + The subprocess is launched lazily on the first :meth:`read` via + :func:`asyncio.create_subprocess_exec`, so no blocking thread is ever used. + When ``pipe`` is set the ``source`` is streamed into ffmpeg's stdin: a small + pump task copies bytes from a readable stream (or an async iterable) into + stdin and closes it at end of input. + + Parameters + ---------- + source: + A path/URL (when ``pipe`` is ``False``) or a readable stream / async + iterable of bytes (when ``pipe`` is ``True``). + args: + The ffmpeg output argument list following the ``-i`` input specifier. + before_args: + The ffmpeg input argument list placed before the ``-i`` input specifier. + executable: + The ffmpeg executable name or path. + pipe: + Whether ``source`` is piped to ffmpeg stdin rather than read as input. + + Raises + ------ + FileNotFoundError + If the ffmpeg executable cannot be located on ``PATH``. + """ + + def __init__( + self, + source: str | io.IOBase | AsyncIterable[bytes], + *, + args: list[str], + before_args: list[str] | None = None, + executable: str = "ffmpeg", + pipe: bool = False, + ) -> None: + if shutil.which(executable) is None: + raise FileNotFoundError(f"ffmpeg executable {executable!r} was not found on PATH") + + self._source = source + self._executable = executable + self._pipe = pipe + self._args = args + self._before_args = before_args or [] + + self._process: asyncio.subprocess.Process | None = None + self._stdin_task: asyncio.Task[None] | None = None + self._stdout: asyncio.StreamReader | None = None + + async def _spawn(self) -> None: + """ Launch the ffmpeg subprocess and start the stdin pump if piping. """ + stdin = asyncio.subprocess.PIPE if self._pipe else asyncio.subprocess.DEVNULL + input_arg = "pipe:0" if self._pipe else self._source + + if not isinstance(input_arg, str): + # Non-pipe sources must be a path or URL string. + raise TypeError(f"Expected str source for non-piped ffmpeg, got {type(input_arg).__name__}") + + self._process = await asyncio.create_subprocess_exec( + self._executable, + *self._before_args, + "-i", + input_arg, + *self._args, + stdin=stdin, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.DEVNULL, + ) + self._stdout = self._process.stdout + + if self._pipe: + self._stdin_task = asyncio.create_task(self._pump_stdin()) + + async def _pump_stdin(self) -> None: + """ Copy the piped source into ffmpeg stdin, then close it. """ + process = self._process + if process is None or process.stdin is None: + return + + stdin = process.stdin + try: + if isinstance(self._source, AsyncIterable): + async for chunk in self._source: + stdin.write(chunk) + await stdin.drain() + elif isinstance(self._source, io.IOBase): + while True: + chunk = self._source.read(_OGG_READ_CHUNK) + if not chunk: + break + stdin.write(chunk) + await stdin.drain() + except (BrokenPipeError, ConnectionResetError): + # ffmpeg may exit early (e.g. on stop); nothing more to feed. + pass + finally: + try: + stdin.close() + except (BrokenPipeError, ConnectionResetError): + pass + + def cleanup(self) -> None: + """ Terminate the ffmpeg subprocess and cancel the stdin pump. """ + if self._stdin_task is not None and not self._stdin_task.done(): + self._stdin_task.cancel() + self._stdin_task = None + + process = self._process + if process is not None and process.returncode is None: + try: + process.kill() + except ProcessLookupError: + pass + + self._process = None + self._stdout = None + + +class FFmpegPCMAudio(_FFmpegAudio): + """ + An audio source that transcodes input to PCM with ``ffmpeg``. + + ffmpeg decodes the input to signed 16-bit little-endian PCM at 48kHz in + stereo, and :meth:`read` returns one 3840-byte frame per call via + :meth:`asyncio.StreamReader.readexactly`. The final, possibly short frame is + returned once before end of stream is signalled with empty bytes. + + Because the output is PCM, libopus is required to encode it before sending. + + Parameters + ---------- + source: + A path/URL, or (when ``pipe`` is ``True``) a readable stream / async + iterable of bytes fed to ffmpeg stdin. + before_options: + Extra ffmpeg arguments placed before ``-i`` (e.g. seek options). + options: + Extra ffmpeg output arguments placed after the format flags. + pipe: + Whether ``source`` is piped to ffmpeg stdin. + executable: + The ffmpeg executable name or path. + """ + + def __init__( + self, + source: str | io.IOBase | AsyncIterable[bytes], + *, + before_options: str | None = None, + options: str | None = None, + pipe: bool = False, + executable: str = "ffmpeg", + ) -> None: + before_args = before_options.split() if before_options is not None else None + + args = ["-f", "s16le", "-ar", "48000", "-ac", "2", "-loglevel", "warning"] + if options is not None: + args.extend(options.split()) + args.append("pipe:1") + + super().__init__(source, args=args, before_args=before_args, executable=executable, pipe=pipe) + + async def read(self) -> bytes: + """ Read one 3840-byte PCM frame from ffmpeg, or empty bytes at EOF. """ + if self._stdout is None: + await self._spawn() + + assert self._stdout is not None + try: + return await self._stdout.readexactly(FRAME_SIZE) + except asyncio.IncompleteReadError as exc: + # End of stream: return the trailing partial frame, then b"". + return bytes(exc.partial) + + +class FFmpegOpusAudio(_FFmpegAudio): + """ + An audio source that encodes input to Opus with ``ffmpeg``. + + ffmpeg produces an Ogg/Opus stream on stdout and this class extracts raw + Opus packets from it, so no libopus binding is needed on the Python side. + + Approach for async Ogg parsing + ------------------------------- + :class:`~discord_http.voice.oggparse.OggPage` parses synchronously from a + file-like ``read`` and raises ``ValueError`` if a page is truncated. To keep + the event loop unblocked we never hand it ffmpeg's pipe directly. Instead + bytes are accumulated in a :class:`bytearray` and parsed one *complete page + at a time*: + + 1. Chunks are read from ffmpeg stdout with ``await stdout.read(n)`` and + appended to the buffer. + 2. The buffer is scanned for the ``b"OggS"`` capture pattern. Using only the + fixed header and segment table, the full byte length of the page is + computed up front; parsing is deferred until that many bytes are present, + so :class:`OggPage` never sees a truncated page. Each fully-parsed page is + removed from the front of the buffer. + 3. Packets are reassembled across pages (a trailing ``255`` lacing value + continues a packet into the next page), with the ``OpusHead``/``OpusTags`` + header packets skipped. + + More bytes are read and the pass retried until at least one packet is ready. + Empty bytes are returned once stdout is exhausted and the buffer drained. + + Parameters + ---------- + source: + A path/URL, or (when ``pipe`` is ``True``) a readable stream / async + iterable of bytes fed to ffmpeg stdin. + bitrate: + The target Opus bitrate in kilobits per second. + before_options: + Extra ffmpeg arguments placed before ``-i`` (e.g. seek options). + options: + Extra ffmpeg output arguments placed after the format flags. + pipe: + Whether ``source`` is piped to ffmpeg stdin. + executable: + The ffmpeg executable name or path. + """ + + def __init__( + self, + source: str | io.IOBase | AsyncIterable[bytes], + *, + bitrate: int = 128, + before_options: str | None = None, + options: str | None = None, + pipe: bool = False, + executable: str = "ffmpeg", + ) -> None: + before_args = before_options.split() if before_options is not None else None + + args = [ + "-c:a", "libopus", + "-f", "opus", + "-ar", "48000", + "-ac", "2", + "-b:a", f"{bitrate}k", + "-loglevel", "warning", + ] + if options is not None: + args.extend(options.split()) + args.append("pipe:1") + + super().__init__(source, args=args, before_args=before_args, executable=executable, pipe=pipe) + + self._buffer = bytearray() + self._partial = bytearray() + self._packets: list[bytes] = [] + self._eof = False + + async def _fill_buffer(self) -> bool: + """ + Read one chunk from ffmpeg stdout into the buffer. + + Returns + ------- + bool + ``True`` if data was read, ``False`` at end of stdout. + """ + assert self._stdout is not None + chunk = await self._stdout.read(_OGG_READ_CHUNK) + if not chunk: + self._eof = True + return False + + self._buffer.extend(chunk) + return True + + def _drain_buffer(self) -> None: + """ Parse every complete Ogg page currently held in the buffer. """ + while True: + index = self._buffer.find(_OGG_MAGIC) + if index < 0: + break + + # An Ogg page is: magic (4) + fixed header (23) + segment table + # (page_segments) + body (sum of lacing values). Bail until the full + # page has arrived so OggPage never reads a truncated header/body. + header_end = index + 4 + _OGG_HEADER_SIZE + if len(self._buffer) < header_end: + break + + page_segments = self._buffer[header_end - 1] + table_end = header_end + page_segments + if len(self._buffer) < table_end: + break + + body_size = sum(self._buffer[header_end:table_end]) + page_end = table_end + body_size + if len(self._buffer) < page_end: + break + + # The slice starts just after the magic, matching what OggPage expects. + page = OggPage(io.BytesIO(self._buffer[index + 4:page_end])) + for chunk, complete in page.iter_packets(): + self._partial.extend(chunk) + if complete: + packet = bytes(self._partial) + self._partial.clear() + if not packet.startswith((b"OpusHead", b"OpusTags")): + self._packets.append(packet) + + del self._buffer[:page_end] + + async def read(self) -> bytes: + """ Read one Opus packet from ffmpeg's Ogg stream, or empty bytes at EOF. """ + if self._stdout is None: + await self._spawn() + + while not self._packets: + self._drain_buffer() + if self._packets: + break + if self._eof: + return b"" + await self._fill_buffer() + + return self._packets.pop(0) + + def is_opus(self) -> bool: + """ Whether frames are Opus packets (always ``True`` for this source). """ + return True + + +class AudioPlayer: + """ + A paced player that streams an :class:`AudioSource` to a voice client. + + The player runs as an :class:`asyncio.Task`. Frames are read and sent every + :attr:`DELAY` seconds with drift correction: each frame's deadline is + computed from a fixed start time so transient delays do not accumulate. + + Pause, resume and stop are controlled by :class:`asyncio.Event` flags, and + the source can be hot-swapped mid-playback via :meth:`set_source`. When the + stream ends, five :data:`~discord_http.voice.opus.OPUS_SILENCE` frames are + sent, speaking is disabled, and the optional ``after`` callback is invoked. + + Parameters + ---------- + source: + The audio source to play. + voice_client: + The voice client used to speak and send packets. + after: + An optional callable invoked with an exception (or ``None`` on success) + once playback ends. It is called synchronously and may not be a + coroutine; exceptions raised by it are logged and swallowed. + """ + + DELAY: float = 0.02 + + def __init__( + self, + source: AudioSource, + voice_client: "VoiceClient", + *, + after: Callable[[Exception | None], object] | None = None, + ) -> None: + self.source = source + self.voice_client = voice_client + self.after = after + + self._loop = voice_client.loop + self._task: asyncio.Task[None] | None = None + self._resumed = asyncio.Event() + self._resumed.set() + self._end = asyncio.Event() + self._error: Exception | None = None + + def start(self) -> None: + """ Schedule the playback task on the voice client's event loop. """ + if self._task is not None: + raise RuntimeError("Player has already been started") + self._task = self._loop.create_task(self._run()) + + async def _run(self) -> None: + """ Drive the playback loop with drift-corrected pacing. """ + try: + await self.voice_client.speak(True) + + start = self._loop.time() + count = 0 + + while not self._end.is_set(): + if not self._resumed.is_set(): + await self._resumed.wait() + # Re-anchor pacing after a pause so we do not burst frames. + start = self._loop.time() + count = 0 + + data = await self.source.read() + if not data: + break + + self.voice_client.send_audio_packet(data, encode=not self.source.is_opus()) + + count += 1 + deadline = start + count * self.DELAY + await asyncio.sleep(max(0.0, deadline - self._loop.time())) + except Exception as exc: + self._error = exc + finally: + await self._cleanup() + + async def _cleanup(self) -> None: + """ Flush silence, stop speaking, clean up and invoke ``after``. """ + try: + for _ in range(5): + self.voice_client.send_audio_packet(OPUS_SILENCE, encode=False) + except Exception: + _log.exception("Failed to send trailing silence frames") + + try: + await self.voice_client.speak(False) + except Exception: + _log.exception("Failed to disable speaking") + + self.source.cleanup() + + if self.after is not None: + try: + self.after(self._error) + except Exception: + _log.exception("Error calling the after callback") + elif self._error is not None: + _log.exception("Exception in audio player", exc_info=self._error) + + def stop(self) -> None: + """ Stop playback as soon as possible and resume any paused loop. """ + self._end.set() + self._resumed.set() + if self._task is not None: + self._task.cancel() + + def pause(self) -> None: + """ Pause playback, halting reads until :meth:`resume` is called. """ + self._resumed.clear() + + def resume(self) -> None: + """ Resume playback after a :meth:`pause`. """ + self._resumed.set() + + def is_playing(self) -> bool: + """ + Whether audio is currently playing. + + Returns + ------- + bool + ``True`` if the loop is running and not paused. + """ + return not self._end.is_set() and self._resumed.is_set() + + def is_paused(self) -> bool: + """ + Whether playback is paused. + + Returns + ------- + bool + ``True`` if the loop is running but paused. + """ + return not self._end.is_set() and not self._resumed.is_set() + + def set_source(self, source: AudioSource) -> None: + """ + Hot-swap the audio source without interrupting the player task. + + Parameters + ---------- + source: + The new audio source to read from. + """ + self.pause() + self.source.cleanup() + self.source = source + self.resume() + + +def _resolve_source(audio: object) -> AudioSource: + """ + Coerce arbitrary audio input into an :class:`AudioSource`. + + Parameters + ---------- + audio: + One of: an :class:`AudioSource` (returned as-is); a ``str`` or + :class:`os.PathLike` path/URL; raw ``bytes``/``bytearray``/``memoryview``; + a readable :class:`io.IOBase` stream; or an :class:`~collections.abc.AsyncIterable` + of ``bytes``. The latter four are decoded by ffmpeg into Opus. + + Returns + ------- + AudioSource + A source ready to be played. + + Raises + ------ + TypeError + If ``audio`` is not a supported type. + FileNotFoundError + If ffmpeg is required but not found on ``PATH``. + """ + if isinstance(audio, AudioSource): + return audio + + if isinstance(audio, (str, os.PathLike)): + return FFmpegOpusAudio(os.fspath(audio)) + + if isinstance(audio, (bytes, bytearray, memoryview)): + return FFmpegOpusAudio(io.BytesIO(bytes(audio)), pipe=True) + + if isinstance(audio, io.IOBase): + return FFmpegOpusAudio(audio, pipe=True) + + if isinstance(audio, AsyncIterable): + return FFmpegOpusAudio(audio, pipe=True) + + raise TypeError(f"Unsupported audio source type: {type(audio).__name__}") diff --git a/discord_http/voice/receiver.py b/discord_http/voice/receiver.py new file mode 100644 index 0000000..4f8660d --- /dev/null +++ b/discord_http/voice/receiver.py @@ -0,0 +1,257 @@ +import logging +import struct + +from typing import TYPE_CHECKING + +from . import opus +from .opus import Decoder +from .sinks import VoiceData + +if TYPE_CHECKING: + from .client import VoiceClient + from .sinks import AudioSink + +__all__ = ( + "VoiceReceiver", +) + +_log = logging.getLogger(__name__) + + +# The fixed-length portion of an RTP header is 12 bytes; the SSRC is the final +# 32-bit big-endian field, occupying bytes 8..12. +_RTP_HEADER_LENGTH = 12 +_SSRC_OFFSET = 8 + +# The RTP timestamp is a 32-bit big-endian field occupying bytes 4..8. +_TIMESTAMP_OFFSET = 4 + +# The RTP sequence number is a 16-bit big-endian field occupying bytes 2..4. +_SEQUENCE_OFFSET = 2 + + +class VoiceReceiver: + """ Consumes incoming RTP voice packets and dispatches audio to an :class:`AudioSink`. """ + + def __init__(self, voice_client: "VoiceClient") -> None: + """ + Create a receiver bound to a voice client. + + Parameters + ---------- + voice_client: + The voice client this receiver belongs to, used to reach the + connection (encryptor, DAVE hooks) and event loop. + """ + self.voice_client = voice_client + """ The voice client this receiver belongs to. """ + + self.sink: "AudioSink | None" = None + """ The sink currently receiving audio, or ``None`` when not listening. """ + + self._ssrc_map: dict[int, int] = {} + """ Maps an RTP SSRC to the user ID it belongs to. """ + + # Lazily-created per-SSRC Opus decoders. Only populated when the active + # sink wants PCM, since Opus passthrough never needs to decode. + self._decoders: dict[int, Decoder] = {} + + # Per-SSRC last seen RTP sequence number, used for lightweight + # packet-loss concealment when decoding to PCM. + self._last_seq: dict[int, int] = {} + + # Whether the missing-libopus warning has already been emitted, so the + # synchronous UDP callback does not spam the log on every packet. + self._warned_no_opus = False + + def start(self, sink: "AudioSink") -> None: + """ + Begin listening, dispatching received audio to ``sink``. + + Parameters + ---------- + sink: + The sink to receive decoded PCM or raw Opus audio. + """ + self.sink = sink + + def stop(self) -> None: + """ Stop listening and release any per-SSRC decoders and the sink. """ + sink = self.sink + self.sink = None + + for decoder in self._decoders.values(): + decoder.cleanup() + + self._decoders.clear() + self._last_seq.clear() + + if sink is not None: + try: + sink.cleanup() + except Exception: + _log.exception("Error while cleaning up audio sink") + + def is_listening(self) -> bool: + """ + Whether a sink is currently attached. + + Returns + ------- + ``True`` if listening, ``False`` otherwise. + """ + return self.sink is not None + + def add_ssrc(self, ssrc: int, user_id: int) -> None: + """ + Associate an RTP SSRC with a user ID. + + Parameters + ---------- + ssrc: + The RTP synchronisation source identifier. + user_id: + The user ID that owns the SSRC. + """ + self._ssrc_map[ssrc] = user_id + + def remove_user(self, user_id: int) -> None: + """ + Remove every SSRC mapping and decoder belonging to a user. + + Parameters + ---------- + user_id: + The user ID to forget. + """ + stale = [ssrc for ssrc, uid in self._ssrc_map.items() if uid == user_id] + + for ssrc in stale: + del self._ssrc_map[ssrc] + self._last_seq.pop(ssrc, None) + + decoder = self._decoders.pop(ssrc, None) + if decoder is not None: + decoder.cleanup() + + def _get_decoder(self, ssrc: int) -> Decoder: + """ + Return the per-SSRC decoder, creating it on first use. + + Parameters + ---------- + ssrc: + The RTP synchronisation source identifier to decode for. + + Returns + ------- + The decoder dedicated to ``ssrc``. + """ + decoder = self._decoders.get(ssrc) + if decoder is None: + decoder = Decoder() + self._decoders[ssrc] = decoder + + return decoder + + def unpack(self, packet: bytes) -> None: + """ + Decrypt, decode and dispatch a single received RTP packet. + + This is called synchronously from the UDP datagram callback, so it must + never raise: every failure is logged and swallowed. + + Parameters + ---------- + packet: + The raw RTP packet as received from the voice UDP socket. + """ + sink = self.sink + if sink is None: + return + + if len(packet) < _RTP_HEADER_LENGTH: + return + + ssrc = struct.unpack_from(">I", packet, _SSRC_OFFSET)[0] + timestamp = struct.unpack_from(">I", packet, _TIMESTAMP_OFFSET)[0] + sequence = struct.unpack_from(">H", packet, _SEQUENCE_OFFSET)[0] + user_id = self._ssrc_map.get(ssrc) + + connection = self.voice_client.connection + + encryptor = connection.encryptor + if encryptor is None: + # No session key yet (or already torn down); nothing decryptable. + return + + try: + payload = encryptor.decrypt(packet) + except Exception: + _log.exception("Failed to transport-decrypt incoming voice packet") + return + + # DAVE end-to-end decryption, applied only when a session is active and + # we know who the sender is. Tolerant: skip silently when inactive. + if user_id is not None and connection.can_encrypt(): + try: + payload = connection.dave_decrypt_opus(user_id, payload) + except Exception: + _log.exception("Failed to DAVE-decrypt incoming voice packet") + return + + if sink.wants_opus(): + data = VoiceData(user=user_id, pcm=None, opus=payload, timestamp=timestamp, ssrc=ssrc) + else: + pcm = self._decode_pcm(ssrc, sequence, payload) + if pcm is None: + return + data = VoiceData(user=user_id, pcm=pcm, opus=None, timestamp=timestamp, ssrc=ssrc) + + try: + sink.write(user_id, data) + except Exception: + _log.exception("Error in audio sink while writing received voice data") + + def _decode_pcm(self, ssrc: int, sequence: int, payload: bytes) -> bytes | None: + """ + Decode an Opus payload to PCM, applying lightweight packet-loss concealment. + + Parameters + ---------- + ssrc: + The RTP synchronisation source identifier of the sender. + sequence: + The RTP sequence number of the packet, used to detect gaps. + payload: + The Opus payload to decode. + + Returns + ------- + The decoded signed 16-bit little-endian stereo PCM, or ``None`` when + libopus is unavailable and the packet must be dropped. + """ + if not opus.is_loaded(): + # PCM was requested but libopus is missing. Log once and drop rather + # than raising, so the synchronous UDP callback never crashes. + if not self._warned_no_opus: + self._warned_no_opus = True + _log.warning("libopus is not loaded; dropping received voice (PCM decoding unavailable)") + return None + + try: + decoder = self._get_decoder(ssrc) + + # Detect a sequence gap and conceal a single lost frame before + # decoding the packet we actually received. RTP sequence numbers are + # 16-bit and wrap, so compare modulo 2**16. + last = self._last_seq.get(ssrc) + if last is not None and (sequence - last) & 0xFFFF > 1: + decoder.decode(None) + + self._last_seq[ssrc] = sequence + + return decoder.decode(payload) + except Exception: + _log.exception("Failed to Opus-decode received voice packet") + return None diff --git a/discord_http/voice/sinks.py b/discord_http/voice/sinks.py new file mode 100644 index 0000000..c9b5efa --- /dev/null +++ b/discord_http/voice/sinks.py @@ -0,0 +1,176 @@ +import abc +import io +import logging +import os +import wave + +from collections.abc import Callable +from dataclasses import dataclass + +__all__ = ( + "AudioSink", + "CallbackSink", + "VoiceData", + "WaveSink", +) + +_log = logging.getLogger(__name__) + + +@dataclass(slots=True) +class VoiceData: + """ Represents a single chunk of received voice audio for one speaker. """ + + user: int | None + """ The user ID this audio belongs to, or ``None`` if unknown. """ + + pcm: bytes | None + """ The decoded 48kHz 16-bit stereo PCM payload, if available. """ + + opus: bytes | None + """ The raw Opus payload, if available. """ + + timestamp: int + """ The RTP timestamp of the packet this data came from. """ + + ssrc: int + """ The RTP SSRC of the sender this data came from. """ + + +class AudioSink(abc.ABC): + """ Abstract base class for consumers of received voice audio. """ + + def wants_opus(self) -> bool: + """ + Whether this sink wants raw Opus payloads instead of decoded PCM. + + When ``False`` (the default) the receiver decodes packets to PCM + before handing them to :meth:`write`. + + Returns + ------- + ``True`` if the sink consumes Opus, ``False`` for PCM + """ + return False + + @abc.abstractmethod + def write(self, user: int | None, data: VoiceData) -> None: + """ + Consume a single chunk of received voice audio. + + Parameters + ---------- + user: + The user ID the audio belongs to, or ``None`` if unknown + data: + The voice data container holding the PCM and/or Opus payload + """ + raise NotImplementedError + + def cleanup(self) -> None: + """ Finalize the sink, flushing and releasing any held resources. """ + return + + +class CallbackSink(AudioSink): + """ Audio sink that forwards every received chunk to a callback. """ + + def __init__( + self, + callback: Callable[[int | None, VoiceData], object], + *, + opus: bool = False + ) -> None: + """ + Create a sink that forwards received audio to a callback. + + Parameters + ---------- + callback: + The callable invoked as ``callback(user, data)`` for each chunk + opus: + Whether to request raw Opus payloads instead of decoded PCM + """ + self.callback = callback + self.opus = opus + + def wants_opus(self) -> bool: + """ + Whether this sink wants raw Opus payloads instead of decoded PCM. + + Returns + ------- + The value of the ``opus`` flag passed at construction + """ + return self.opus + + def write(self, user: int | None, data: VoiceData) -> None: + """ + Forward the received audio chunk to the callback. + + Parameters + ---------- + user: + The user ID the audio belongs to, or ``None`` if unknown + data: + The voice data container holding the PCM and/or Opus payload + """ + self.callback(user, data) + + +class WaveSink(AudioSink): + """ Audio sink that writes received PCM to a single 48kHz 16-bit stereo WAV file. """ + + def __init__(self, destination: str | os.PathLike | io.IOBase) -> None: + """ + Create a sink that writes received PCM to a WAV file. + + Parameters + ---------- + destination: + A file path or writable binary stream to receive the WAV data + """ + self.destination = destination + self._file: wave.Wave_write | None = None + + def wants_opus(self) -> bool: + """ + Whether this sink wants raw Opus payloads instead of decoded PCM. + + Returns + ------- + Always ``False`` as the WAV file stores PCM + """ + return False + + def _ensure_open(self) -> wave.Wave_write: + """ Open the wave file lazily, configuring it for 48kHz 16-bit stereo. """ + file = self._file + if file is None: + file = wave.open(self.destination, "wb") # type: ignore[arg-type] # noqa: SIM115 + file.setnchannels(2) + file.setsampwidth(2) + file.setframerate(48000) + self._file = file + return file + + def write(self, user: int | None, data: VoiceData) -> None: # noqa: ARG002 + """ + Append the chunk's PCM payload to the WAV file. + + Parameters + ---------- + user: + The user ID the audio belongs to, or ``None`` if unknown (unused) + data: + The voice data container holding the PCM payload + """ + if data.pcm is None: + return + self._ensure_open().writeframes(data.pcm) + + def cleanup(self) -> None: + """ Finalize the WAV file, writing headers and closing the stream. """ + if self._file is not None: + self._file.close() + self._file = None diff --git a/discord_http/voice/socket.py b/discord_http/voice/socket.py new file mode 100644 index 0000000..eb8b638 --- /dev/null +++ b/discord_http/voice/socket.py @@ -0,0 +1,419 @@ +import asyncio +import logging +import struct +import time + +import orjson + +from aiohttp import ClientSession, ClientWebSocketResponse, WSMsgType +from collections import deque +from collections.abc import Coroutine +from typing import TYPE_CHECKING, Any + +from enum import IntEnum + +from .enums import VoiceOp + +if TYPE_CHECKING: + from .connection import VoiceConnection + +__all__ = ("VoiceCloseCode", "VoiceSocket") + +_log = logging.getLogger(__name__) + + +class VoiceCloseCode(IntEnum): + """ The voice gateway websocket close codes that govern reconnect behaviour. """ + + normal = 1000 + going_away = 1001 + disconnected = 4014 + voice_server_crashed = 4015 + unknown_encryption_mode = 4016 + bad_request = 4020 + rate_limited = 4021 + call_terminated = 4022 + + +class VoiceSocket: + """ + The voice gateway websocket connection. + + Handles the voice gateway protocol (version 8): the initial handshake, + heartbeating, latency tracking, and dispatching of both JSON and binary + (DAVE) frames to the owning :class:`VoiceConnection`. + """ + + def __init__(self, connection: "VoiceConnection"): + self.connection: "VoiceConnection" = connection + """ The voice connection that owns this socket. """ + + self.ws: ClientWebSocketResponse | None = None + """ The underlying websocket connection, if open. """ + + self.seq_ack: int = -1 + """ The last sequence number received from the voice gateway. """ + + self._closing: bool = False + self._resuming: bool = False + self._own_session: bool = False + self._session: ClientSession | None = None + + self._heartbeat_interval: float = 0.0 + self._heartbeat_task: asyncio.Task | None = None + self._receive_task: asyncio.Task | None = None + + self._out_seq: int = 0 + self._last_send: float = 0.0 + self._latencies: deque[float] = deque(maxlen=20) + + @property + def latency(self) -> float: + """ The latency of the most recent heartbeat, in seconds, or ``inf`` if unknown. """ + if not self._latencies: + return float("inf") + return self._latencies[-1] + + @property + def average_latency(self) -> float: + """ The average latency over the last few heartbeats, in seconds, or ``inf`` if unknown. """ + if not self._latencies: + return float("inf") + return sum(self._latencies) / len(self._latencies) + + def _get_session(self) -> ClientSession: + """ + Return a usable aiohttp session, reusing the bot's if reachable. + + Returns + ------- + The shared HTTP session, or a freshly created one owned by this socket. + """ + try: + session = self.connection.voice_client.client.state.http.session + except AttributeError: + session = None + + if session is not None: + return session + + self._own_session = True + self._session = ClientSession() + return self._session + + async def connect(self, *, resume: bool = False) -> None: + """ + Open the voice websocket and start the receive loop. + + Parameters + ---------- + resume: + Whether to RESUME (op 7) an existing session rather than IDENTIFY (op 0). + """ + self._closing = False + self._resuming = resume + session = self._get_session() + endpoint = self.connection.endpoint + self.ws = await session.ws_connect(f"wss://{endpoint}/?v=8") + + self._receive_task = asyncio.create_task( + self._receive_loop(), + name=f"discord.http/voice/socket-{self.connection.guild_id}/receive" + ) + + async def _receive_loop(self) -> None: + """ Continuously receive frames and dispatch them; never blocks on handlers. """ + ws = self.ws + if ws is None: + return + + close_code: int | None = None + + try: + async for msg in ws: + if msg.type is WSMsgType.TEXT: + self._dispatch_text(msg.data) + + elif msg.type is WSMsgType.BINARY: + self._dispatch_binary(msg.data) + + elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): + _log.debug("Voice socket for guild %s received close frame", self.connection.guild_id) + break + + elif msg.type is WSMsgType.ERROR: + _log.warning("Voice socket for guild %s received error: %s", self.connection.guild_id, msg.data) + break + + except asyncio.CancelledError: + raise + + except Exception as exc: + _log.debug("Voice socket for guild %s receive loop ended", self.connection.guild_id, exc_info=exc) + + close_code = ws.close_code + + if not self._closing: + self.connection._on_socket_closed(close_code) + + def _request_close(self) -> None: + """ Mark the socket as intentionally closing so the receive loop suppresses reconnect. """ + self._closing = True + + def _dispatch_text(self, raw: str | bytes) -> None: + """ + Parse and dispatch a text frame by its voice opcode. + + Parameters + ---------- + raw: + The raw JSON text frame received from the voice gateway. + """ + payload: dict = orjson.loads(raw) + + seq = payload.get("seq") + if seq is not None: + self.seq_ack = seq + + op = payload.get("op") + data: dict = payload.get("d") or {} + + try: + voice_op = VoiceOp(op) + except ValueError: + _log.debug("Voice socket for guild %s received unknown op %s", self.connection.guild_id, op) + return + + match voice_op: + case VoiceOp.hello: + self._heartbeat_interval = float(data["heartbeat_interval"]) / 1000 + self._start_heartbeat() + if self._resuming: + self._schedule(self.send_resume()) + else: + self._schedule(self.send_identify()) + + case VoiceOp.ready: + self._schedule(self.connection.on_ready(data)) + + case VoiceOp.session_description: + self._schedule(self.connection.on_session_description(data)) + + case VoiceOp.speaking: + self._schedule(self.connection.on_speaking(data)) + + case VoiceOp.heartbeat_ack: + if self._last_send: + self._latencies.append(time.perf_counter() - self._last_send) + + case VoiceOp.resumed: + self._schedule(self.connection.on_resumed(data)) + + case _: + _log.debug("Voice socket for guild %s received unhandled op %s", self.connection.guild_id, voice_op) + + def _dispatch_binary(self, raw: bytes) -> None: + """ + Parse and dispatch a binary DAVE frame. + + Parameters + ---------- + raw: + The raw binary frame: ``seq(2B >H) + opcode(1B) + payload``. + """ + if len(raw) < 3: + return + + seq, opcode = struct.unpack_from(">HB", raw, 0) + payload = raw[3:] + + self.seq_ack = seq + + self._schedule(self.connection.on_dave_binary(opcode, payload)) + + def _schedule(self, coro: Coroutine[Any, Any, Any]) -> None: + """ + Schedule a coroutine as a task so the receive loop never blocks. + + Parameters + ---------- + coro: + The coroutine to run independently of the receive loop. + """ + asyncio.create_task( # noqa: RUF006 + self._guard(coro), + name=f"discord.http/voice/socket-{self.connection.guild_id}/dispatch" + ) + + async def _guard(self, coro: Coroutine[Any, Any, Any]) -> None: + """ + Run a scheduled coroutine, logging any exception it raises. + + Parameters + ---------- + coro: + The coroutine to await. + """ + try: + await coro + except Exception as exc: + _log.error("Error in voice socket handler for guild %s", self.connection.guild_id, exc_info=exc) + + def _start_heartbeat(self) -> None: + """ (Re)start the heartbeat task using the negotiated interval. """ + if self._heartbeat_task is not None and not self._heartbeat_task.done(): + self._heartbeat_task.cancel() + + self._heartbeat_task = asyncio.create_task( + self._heartbeat_loop(), + name=f"discord.http/voice/socket-{self.connection.guild_id}/heartbeat" + ) + + async def _heartbeat_loop(self) -> None: + """ Send a heartbeat (op 3) every interval until cancelled. """ + try: + while True: + await self._send_heartbeat() + await asyncio.sleep(self._heartbeat_interval) + except asyncio.CancelledError: + pass + except Exception as exc: + _log.debug("Voice heartbeat for guild %s stopped", self.connection.guild_id, exc_info=exc) + + async def _send_heartbeat(self) -> None: + """ Send a single heartbeat frame, recording the send time for latency. """ + self._last_send = time.perf_counter() + nonce = int(time.time() * 1000) + await self._send_json({ + "op": int(VoiceOp.heartbeat), + "d": { + "t": nonce, + "seq_ack": self.seq_ack, + } + }) + + async def _send_json(self, payload: dict) -> None: + """ + Send a JSON frame over the websocket. + + Parameters + ---------- + payload: + The payload to serialise and send. + """ + if self.ws is None or self.ws.closed: + return + await self.ws.send_bytes(orjson.dumps(payload)) + + async def send_identify(self) -> None: + """ Send the IDENTIFY (op 0) frame, advertising DAVE support. """ + from .dave import max_protocol_version + + await self._send_json({ + "op": int(VoiceOp.identify), + "d": { + "server_id": str(self.connection.guild_id), + "user_id": str(self.connection.user_id), + "session_id": self.connection.session_id, + "token": self.connection.token, + "max_dave_protocol_version": max_protocol_version(), + } + }) + + async def send_select_protocol(self, ip: str, port: int, mode: str) -> None: + """ + Send the SELECT_PROTOCOL (op 1) frame after IP discovery. + + Parameters + ---------- + ip: + The externally discovered IP address. + port: + The externally discovered UDP port. + mode: + The negotiated encryption mode. + """ + await self._send_json({ + "op": int(VoiceOp.select_protocol), + "d": { + "protocol": "udp", + "data": { + "address": ip, + "port": port, + "mode": mode, + } + } + }) + + async def send_speaking(self, speaking: int, *, ssrc: int, delay: int = 0) -> None: + """ + Send the SPEAKING (op 5) frame. + + Parameters + ---------- + speaking: + The speaking bitflag (1 to indicate microphone audio). + ssrc: + The SSRC of the connection. + delay: + The voice delay, in milliseconds. + """ + await self._send_json({ + "op": int(VoiceOp.speaking), + "d": { + "speaking": int(speaking), + "delay": int(delay), + "ssrc": int(ssrc), + } + }) + + async def send_resume(self) -> None: + """ Send the RESUME (op 7) frame to resume an interrupted session. """ + await self._send_json({ + "op": int(VoiceOp.resume), + "d": { + "server_id": str(self.connection.guild_id), + "session_id": self.connection.session_id, + "token": self.connection.token, + "seq_ack": self.seq_ack, + } + }) + + async def send_binary(self, opcode: int, payload: bytes) -> None: + """ + Send a binary DAVE frame. + + Parameters + ---------- + opcode: + The voice opcode for the binary frame. + payload: + The binary payload to send after the opcode. + """ + if self.ws is None or self.ws.closed: + return + + self._out_seq = (self._out_seq + 1) & 0xFFFF + frame = struct.pack(">H", 0) + bytes([opcode & 0xFF]) + payload + await self.ws.send_bytes(frame) + + async def close(self) -> None: + """ Cancel the background tasks and close the websocket. """ + self._closing = True + + if self._heartbeat_task is not None: + self._heartbeat_task.cancel() + self._heartbeat_task = None + + if self._receive_task is not None: + self._receive_task.cancel() + self._receive_task = None + + if self.ws is not None and not self.ws.closed: + await self.ws.close() + self.ws = None + + if self._own_session and self._session is not None: + await self._session.close() + self._session = None + self._own_session = False diff --git a/examples/voice_example.py b/examples/voice_example.py new file mode 100644 index 0000000..c9efbe9 --- /dev/null +++ b/examples/voice_example.py @@ -0,0 +1,104 @@ +import asyncio + +from discord_http import BaseChannel, Client, Context, VoiceClient, WaveSink +from discord_http.gateway import Intents + +# Voice requires a gateway connection (to send the voice-state update) and the +# guild_voice_states intent (so the bot receives its own voice server/state updates). +# +# Codec notes: +# * Passing an ``.mp3``/``.opus`` file plays through ffmpeg -> Ogg/Opus and needs +# ONLY ffmpeg installed (no libopus) -- the audio is sent as opus passthrough. +# * PCM encode/decode (raw PCM sources, volume transforms, or receiving/decoding +# other users' audio) additionally needs libopus loaded (``discord_http.voice.load_opus``). +# * DAVE end-to-end encryption (MLS) is optional and needs: pip install "discord.http[voice]" +client = Client( + token="BOT_TOKEN", + enable_gateway=True, + intents=( + Intents.guild_messages | + Intents.guild_voice_states + ) +) + + +@client.command() +async def join(ctx: Context): + """ Join the caller's voice channel and play a song """ + if ctx.guild is None or ctx.author is None: + return ctx.response.send_message("This command can only be used in a guild.") + + # Resolve a voice channel to connect to (here a hard-coded id for brevity). + channel = await client.fetch_channel(1234567890, guild_id=ctx.guild.id) + if not isinstance(channel, BaseChannel): + return ctx.response.send_message("Could not find that channel.") + + vc: VoiceClient = await channel.connect() + + # Play a local file (mp3 -> opus passthrough, ffmpeg only). + await vc.play("song.mp3") + + return ctx.response.send_message(f"Now playing, latency: {vc.latency:.1f}ms") + + +@client.command() +async def pause(ctx: Context): + """ Pause / resume the current track """ + vc = client._get_voice_client(ctx.guild.id) if ctx.guild else None + if vc is None: + return ctx.response.send_message("Not connected.") + + if vc.is_paused(): + vc.resume() + return ctx.response.send_message("Resumed.") + + vc.pause() + return ctx.response.send_message("Paused.") + + +@client.command() +async def leave(ctx: Context): + """ Stop playback and disconnect """ + vc = client._get_voice_client(ctx.guild.id) if ctx.guild else None + if vc is None: + return ctx.response.send_message("Not connected.") + + vc.stop() + await vc.disconnect() + return ctx.response.send_message("Disconnected.") + + +async def voice_demo(channel: BaseChannel, move_to: BaseChannel) -> None: + """ + A standalone walkthrough of the voice API. + + Parameters + ---------- + channel: + The voice channel to connect to first. + move_to: + A second voice channel to move into mid-session. + """ + vc: VoiceClient = await channel.connect() + + # Playback controls. + await vc.play("song.mp3") + vc.pause() + vc.resume() + + # Hop to another channel without disconnecting. + await vc.move_to(move_to) + + # Receiving: write everyone's audio into a single WAV file. + # (decoding opus -> PCM for the WAV needs libopus loaded.) + vc.listen(WaveSink("out.wav")) + await asyncio.sleep(10) + vc.stop_listening() + + print(f"voice latency: {vc.latency:.1f}ms (avg {vc.average_latency:.1f}ms)") + + vc.stop() + await vc.disconnect() + + +client.start(host="127.0.0.1", port=8080) diff --git a/pyproject.toml b/pyproject.toml index 20a7564..4f10cef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,11 +38,15 @@ docs = [ "sphinx>=8.2.3", "sphinx-autodoc-typehints>=3.2.0", ] +voice = [ + "davey>=0.1.0", +] [tool.setuptools] packages = [ "discord_http", "discord_http.gateway", + "discord_http.voice", ] [tool.setuptools.dynamic] @@ -59,6 +63,7 @@ output-format = "concise" include = [ "discord_http/*.py", "discord_http/gateway/*.py", + "discord_http/voice/*.py", ] exclude = [ @@ -143,7 +148,6 @@ ignore = [ "FURB103", # write-whole-file # Pylint Warnings - "PLW0717", # Try clause contains too many statements "PLW2901", # Overwrite loop control variable # Type hints @@ -179,6 +183,7 @@ pythonVersion = "3.11" include = [ "discord_http", "discord_http.gateway", + "discord_http.voice", ] exclude = [ diff --git a/tests/test_voice_encryptor.py b/tests/test_voice_encryptor.py new file mode 100644 index 0000000..d648f4a --- /dev/null +++ b/tests/test_voice_encryptor.py @@ -0,0 +1,59 @@ +import os +import struct +import unittest + +from discord_http.voice.encryptor import Encryptor + + +class TestVoiceEncryptor(unittest.TestCase): + def test_mode(self) -> None: + key = b"\x00" * 32 + enc = Encryptor(key) + self.assertEqual(enc.mode, "aead_aes256_gcm_rtpsize") + + def test_roundtrip_basic_header(self) -> None: + key = os.urandom(32) + header = struct.pack(">BBHII", 0x80, 0x78, 1, 2, 3) + plaintext = b"opus-frame-data" + + sender = Encryptor(key) + packet = sender.encrypt(header, plaintext) + + self.assertEqual(packet[:12], header) + self.assertEqual(packet[-4:], struct.pack(">I", 0)) + + receiver = Encryptor(key) + self.assertEqual(receiver.decrypt(packet), plaintext) + + def test_roundtrip_with_extension(self) -> None: + key = os.urandom(32) + + # base header with the extension bit (0x10) set on byte0 + base = struct.pack(">BBHII", 0x90, 0x78, 5, 6, 7) + # one-byte RTP extension: 0xBE 0xDE profile, length = 1 word (4 bytes) + extension = b"\xbe\xde" + struct.pack(">H", 1) + b"\x01\x02\x03\x04" + header = base + extension + plaintext = b"another-opus-frame" + + sender = Encryptor(key) + packet = sender.encrypt(header, plaintext) + + self.assertEqual(packet[:len(header)], header) + + receiver = Encryptor(key) + self.assertEqual(receiver.decrypt(packet), plaintext) + + def test_nonce_increments(self) -> None: + key = os.urandom(32) + header = struct.pack(">BBHII", 0x80, 0x78, 1, 2, 3) + + sender = Encryptor(key) + first = sender.encrypt(header, b"a") + second = sender.encrypt(header, b"a") + + self.assertEqual(first[-4:], struct.pack(">I", 0)) + self.assertEqual(second[-4:], struct.pack(">I", 1)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_voice_oggparse.py b/tests/test_voice_oggparse.py new file mode 100644 index 0000000..59c853a --- /dev/null +++ b/tests/test_voice_oggparse.py @@ -0,0 +1,136 @@ +import io +import shutil +import struct +import subprocess + +import pytest + +from discord_http.voice.oggparse import OggPage, OggStream + + +def _build_page( + body: bytes, + segtable: bytes, + *, + header_type: int = 0, + granule_position: int = 0, + serial: int = 1, + sequence: int = 0, + crc: int = 0, +) -> bytes: + """Build a single valid Ogg page from a body and a hand-crafted segment table.""" + assert sum(segtable) == len(body), "segment table must sum to body length" + header = struct.pack( + "<4sBBQIIIB", + b"OggS", + 0, # version + header_type, + granule_position, + serial, + sequence, + crc, + len(segtable), + ) + return header + segtable + body + + +def test_single_page_single_packet() -> None: + body = b"hello opus" + page_bytes = _build_page(body, bytes([len(body)])) + + stream = OggStream(io.BytesIO(page_bytes)) + packets = list(stream.iter_packets()) + + assert packets == [body] + + +def test_page_header_fields_parsed() -> None: + body = b"\x00\x01\x02\x03" + page_bytes = _build_page( + body, + bytes([len(body)]), + header_type=0x02, + granule_position=12345, + sequence=7, + ) + + # Skip the 4-byte magic, then parse the page directly. + buffer = io.BytesIO(page_bytes) + assert buffer.read(4) == b"OggS" + page = OggPage(buffer) + + assert page.header_type == 0x02 + assert page.granule_position == 12345 + assert page.page_sequence_number == 7 + assert page.segtable == bytes([len(body)]) + assert page.data == body + + +def test_multiple_packets_in_one_page() -> None: + packet_a = b"first" + packet_b = b"second-packet" + body = packet_a + packet_b + segtable = bytes([len(packet_a), len(packet_b)]) + + stream = OggStream(io.BytesIO(_build_page(body, segtable))) + assert list(stream.iter_packets()) == [packet_a, packet_b] + + +def test_packet_spanning_segments_via_255_lacing() -> None: + # A packet exactly 255 bytes long needs a 255 lacing + a 0 lacing terminator. + body = b"x" * 255 + segtable = bytes([255, 0]) + + stream = OggStream(io.BytesIO(_build_page(body, segtable))) + assert list(stream.iter_packets()) == [body] + + +def test_packet_spanning_pages() -> None: + # First page ends mid-packet (trailing 255 lacing), second page continues it. + head = b"a" * 255 + tail = b"bcd" + page_one = _build_page(head, bytes([255]), sequence=0) + page_two = _build_page(tail, bytes([len(tail)]), header_type=0x01, sequence=1) + + stream = OggStream(io.BytesIO(page_one + page_two)) + assert list(stream.iter_packets()) == [head + tail] + + +def test_scans_past_leading_garbage() -> None: + body = b"payload" + page_bytes = _build_page(body, bytes([len(body)])) + + stream = OggStream(io.BytesIO(b"garbage-before-magic" + page_bytes)) + assert list(stream.iter_packets()) == [body] + + +def test_ffmpeg_generated_opus_stream() -> None: + ffmpeg = shutil.which("ffmpeg") + if ffmpeg is None: + pytest.skip("ffmpeg not available on PATH") + + result = subprocess.run( # noqa: S603 + [ + ffmpeg, + "-f", "lavfi", + "-i", "sine=frequency=440:duration=1", + "-c:a", "libopus", + "-f", "ogg", + "-", + ], + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + check=True, + ) + data = result.stdout + assert data, "ffmpeg produced no output" + + packets = list(OggStream(io.BytesIO(data)).iter_packets()) + + assert len(packets) > 2 + assert packets[0].startswith(b"OpusHead") + assert packets[1].startswith(b"OpusTags") + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__, "-q"])) From 7a4ef5a6e1767d8ffaf04d91fc988df81fba42d6 Mon Sep 17 00:00:00 2001 From: Neppkun Date: Tue, 2 Jun 2026 21:49:01 +0300 Subject: [PATCH 02/12] Restore PLW0717 to ruff ignore list `PLW0717` ("Try clause contains too many statements") is a valid, active rule in the canonical lint toolchain (`make lint` -> uv-resolved ruff 0.15.x) and was deliberately ignored in the original pyproject.toml. An earlier commit removed it after a misdiagnosis from a system-installed ruff build that didn't recognize the selector; that re-exposed 10 PLW0717 violations (7 in pre-existing files). Restoring the maintainer's ignore entry makes `make lint` pass clean again. Co-Authored-By: Claude Opus 4.8 --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 4f10cef..14757b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,6 +148,7 @@ ignore = [ "FURB103", # write-whole-file # Pylint Warnings + "PLW0717", # Try clause contains too many statements "PLW2901", # Overwrite loop control variable # Type hints From d0350db4cc2c6e1f87585dd8090f70be9a4f4e44 Mon Sep 17 00:00:00 2001 From: Neppkun Date: Tue, 2 Jun 2026 22:55:01 +0300 Subject: [PATCH 03/12] Address PR review feedback on voice support Implements AlexFlipnote's review comments on #29: - enums: rename VoiceOp -> VoiceOpType to match the library's *Type naming - errors: move OpusError/OpusNotLoaded into the root errors.py - opus: drop the module-level `global` loader state in favour of an encapsulated loader object - socket: VoiceCloseCode now subclasses BaseEnum; reuse the bot's shared aiohttp session instead of the deep try/except + self-owned session - gateway_udp: narrow the transport with isinstance instead of a type: ignore, keeping full type safety - connection: parse the voice endpoint with the library's URL helper, import has_dave/DaveManager at module level, and read the reconnect attempt cap from the new Client.voice_reconnect_attempts variable - dave: early-return guard in set_passthrough_mode and a match/case in handle_binary - client (voice): move_to accepts a channel or an int id; play uses a strict AudioSourceInput type alias instead of Any - gateway: voice clients are driven from shard special handlers, not from asyncio tasks created inside the parser; the parser methods are pure again - gateway/client: decouple voice from the intent socket lifecycle (drop the READY/RESUMED revive + reset teardown and the resume_voice plumbing), since the voice websocket manages its own reconnect - voice package: drop __all__ from __init__ and rely on each module's __all__ - example: target the caller's current voice channel instead of hard-coded ids - logging: use f-strings throughout the voice package to match the codebase - tests: port the ogg parser tests to unittest (the project's runner) All checks green: ruff clean, pyright 0 errors, 106 unittest tests pass. Co-Authored-By: Claude Opus 4.8 --- discord_http/client.py | 118 ++---------------- discord_http/errors.py | 10 ++ discord_http/gateway/parser.py | 10 -- discord_http/gateway/shard.py | 56 ++++----- discord_http/voice/__init__.py | 34 +----- discord_http/voice/client.py | 18 +-- discord_http/voice/connection.py | 54 ++++----- discord_http/voice/dave.py | 67 +++++----- discord_http/voice/enums.py | 4 +- discord_http/voice/gateway_udp.py | 9 +- discord_http/voice/opus.py | 59 +++++---- discord_http/voice/player.py | 5 + discord_http/voice/socket.py | 82 ++++++------- examples/voice_example.py | 38 ++++-- tests/test_voice_oggparse.py | 195 +++++++++++++++--------------- 15 files changed, 326 insertions(+), 433 deletions(-) diff --git a/discord_http/client.py b/discord_http/client.py index 44bf3c1..1d2e5a7 100644 --- a/discord_http/client.py +++ b/discord_http/client.py @@ -102,9 +102,9 @@ class Client: Whether to disable the default GET path or not, if not provided, it will use `False`. The default GET path only provides information about the bot and when it was last rebooted. Usually a great tool to just validate that your bot is online. - resume_voice: bool - Whether to remember and revive voice clients across shard reboots, if not provided, it will use `False`. - Requires the `guild_voice_states` intent; otherwise voice clients are torn down on shard reset. + voice_reconnect_attempts: int + How many times a voice connection will try to fully reconnect after an + unexpected close before giving up, if not provided, it will use `5`. """ def __init__( self, @@ -130,7 +130,7 @@ def __init__( logging_level: int = logging.INFO, disable_default_get_path: bool = False, debug_events: bool = False, - resume_voice: bool = False + voice_reconnect_attempts: int = 5 ): if application_id is not None: _log.warning( @@ -153,7 +153,11 @@ def __init__( self.logging_level: int = logging_level self.debug_events: bool = debug_events self.enable_gateway: bool = enable_gateway - self.resume_voice: bool = resume_voice + self.voice_reconnect_attempts: int = voice_reconnect_attempts + """ + How many times a voice connection will try to fully reconnect after an + unexpected close before giving up. + """ self.playing_status: "PlayingStatus | None" = playing_status self.guild_ready_timeout: float = guild_ready_timeout self.chunk_guilds_on_startup: bool = chunk_guilds_on_startup @@ -261,110 +265,6 @@ def _remove_voice_client(self, guild_id: int) -> None: """ self._voice_clients.pop(guild_id, None) - def _voice_clients_for_shard(self, shard_id: int) -> "list[VoiceClient]": - """ - Return the registered voice clients whose guilds belong to a shard. - - Parameters - ---------- - shard_id: - The shard to enumerate voice clients for. - - Returns - ------- - The voice clients belonging to the shard's guilds. - """ - if not self.gateway: - return [] - - return [ - vc for guild_id, vc in list(self._voice_clients.items()) - if self.get_shard_by_guild_id(guild_id) == shard_id - ] - - def _has_voice_states_intent(self, shard_id: int) -> bool: - """ - Whether the ``guild_voice_states`` intent is enabled for a shard. - - Parameters - ---------- - shard_id: - The shard to check the intents of. - - Returns - ------- - ``True`` if the intent is available, otherwise ``False``. - """ - from .gateway.flags import Intents - - if not self.gateway: - return False - - shard = self.gateway.get_shard(shard_id) - if shard is None: - return False - - return Intents.guild_voice_states in shard.intents - - async def _revive_voice_clients(self, shard_id: int) -> None: - """ - Re-establish remembered voice clients after a shard READY/RESUMED. - - This is a no-op unless :attr:`resume_voice` is enabled, the - ``guild_voice_states`` intent is available, and there are voice clients - belonging to the shard. Each remembered client re-issues op4 and runs a - fresh handshake, resuming playback where possible. - - Parameters - ---------- - shard_id: - The shard that just became ready or resumed. - """ - if not self.resume_voice: - return - - voice_clients = self._voice_clients_for_shard(shard_id) - if not voice_clients: - return - - if not self._has_voice_states_intent(shard_id): - _log.warning( - "resume_voice is enabled but the guild_voice_states intent is missing; " - "tearing down voice clients for shard %s", - shard_id - ) - await self._teardown_voice_clients_for_shard(shard_id) - return - - for vc in voice_clients: - if vc.is_connected(): - # Still alive (e.g. a RESUMED where nothing was torn down). - continue - try: - await vc.connect( - self_deaf=vc.connection._self_deaf, - self_mute=vc.connection._self_mute, - ) - except Exception as exc: - _log.warning("Failed to revive voice client for guild %s", vc.guild_id, exc_info=exc) - await vc._cleanup() - - async def _teardown_voice_clients_for_shard(self, shard_id: int) -> None: - """ - Tear down every voice client belonging to a shard's guilds. - - Stops playback, closes the websocket and UDP transport, and removes the - client from the registry. Used when a shard is reset or killed and the - connections should not survive. - - Parameters - ---------- - shard_id: - The shard whose voice clients should be torn down. - """ - for vc in self._voice_clients_for_shard(shard_id): - await vc._cleanup() - async def _cooldown_cleanup_loop(self) -> None: """ Periodically sweeps expired cooldown buckets that accumulate between invocations. """ while True: diff --git a/discord_http/errors.py b/discord_http/errors.py index dcf469e..409c82d 100644 --- a/discord_http/errors.py +++ b/discord_http/errors.py @@ -20,6 +20,8 @@ "HTTPException", "InvalidMember", "NotFound", + "OpusError", + "OpusNotLoaded", "Ratelimited", "UserMissingPermissions", ) @@ -29,6 +31,14 @@ class DiscordException(Exception): # noqa: N818 """ Base exception for discord_http. """ +class OpusError(DiscordException): + """ Raised when libopus returns an error code. """ + + +class OpusNotLoaded(DiscordException): + """ Raised when an Opus operation is attempted but libopus is not available. """ + + class CheckFailed(DiscordException): """ Raised whenever a check fails. """ diff --git a/discord_http/gateway/parser.py b/discord_http/gateway/parser.py index 6d8586a..f216006 100644 --- a/discord_http/gateway/parser.py +++ b/discord_http/gateway/parser.py @@ -1351,10 +1351,6 @@ def voice_server_update(self, data: dict) -> tuple[dict]: ------- The raw voice server update payload. """ - vc = self.bot._get_voice_client(int(data["guild_id"])) - if vc is not None: - self.bot.loop.create_task(vc.on_voice_server_update(data)) - return (data,) def voice_state_update(self, data: dict) -> tuple[ @@ -1396,12 +1392,6 @@ def voice_state_update(self, data: dict) -> tuple[ self.bot.cache.update_voice_state(vs) - bot_user = self.bot.application.bot if self.bot.application else None - if bot_user is not None and int(data["user_id"]) == bot_user.id: - vc = self.bot._get_voice_client(int(data["guild_id"])) - if vc is not None: - self.bot.loop.create_task(vc.on_voice_state_update(data)) - return (before_vs, vs) def typing_start(self, data: dict) -> tuple[TypingStartEvent]: diff --git a/discord_http/gateway/shard.py b/discord_http/gateway/shard.py index cf46fe0..173c19b 100644 --- a/discord_http/gateway/shard.py +++ b/discord_http/gateway/shard.py @@ -352,6 +352,8 @@ def __init__( "GUILD_CREATE": (self._parse_guild_create, True), "GUILD_DELETE": (self._parse_guild_delete, False), "GUILD_MEMBERS_CHUNK": (self._parse_guild_members_chunk, True), + "VOICE_STATE_UPDATE": (self._parse_voice_state_update, False), + "VOICE_SERVER_UPDATE": (self._parse_voice_server_update, False), } @property @@ -370,31 +372,7 @@ def _reset_buffer(self) -> None: self._buffer = bytearray() self._zlib = zlib.decompressobj() - def _revive_voice_clients(self) -> None: - """ Schedule revival of remembered voice clients after READY/RESUMED. """ - if not self.bot._voice_clients: - return - task = asyncio.create_task( - self.bot._revive_voice_clients(self.shard_id), - name=f"discord.http/gateway/shard-{self.shard_id}/revive_voice" - ) - self.bot._background_tasks.add(task) - task.add_done_callback(self.bot._cleanup_task) - - def _teardown_voice_clients(self) -> None: - """ Insta-leave any voice clients on this shard unless persistence is on. """ - if self.bot.resume_voice or not self.bot._voice_clients: - return - task = asyncio.create_task( - self.bot._teardown_voice_clients_for_shard(self.shard_id), - name=f"discord.http/gateway/shard-{self.shard_id}/teardown_voice" - ) - self.bot._background_tasks.add(task) - task.add_done_callback(self.bot._cleanup_task) - def _reset_instance(self) -> None: - self._teardown_voice_clients() - self._reset_buffer() self.status.reset() @@ -593,11 +571,7 @@ async def received_message(self, raw_msg: str | bytes) -> None: name=f"discord.http/gateway/shard-{self.shard_id}/delay_ready" ) - self._revive_voice_clients() - case "RESUMED": - self._revive_voice_clients() - if self.bot.has_any_dispatch("shard_resumed"): self.bot.dispatch( "shard_resumed", @@ -1084,6 +1058,32 @@ def _parse_guild_delete(self, data: dict) -> None: self._send_dispatch(event_name, guild) + def _parse_voice_state_update(self, data: dict) -> None: + payload = self.parser.voice_state_update(data) + + bot_user = self.bot.application.bot if self.bot.application else None + if ( + bot_user is not None + and data.get("guild_id") is not None + and int(data["user_id"]) == bot_user.id + ): + vc = self.bot._get_voice_client(int(data["guild_id"])) + if vc is not None: + vc.on_voice_state_update(data) + + if self.bot.has_any_dispatch("voice_state_update"): + self._send_dispatch("voice_state_update", *payload) + + def _parse_voice_server_update(self, data: dict) -> None: + (payload,) = self.parser.voice_server_update(data) + + vc = self.bot._get_voice_client(int(data["guild_id"])) + if vc is not None: + vc.on_voice_server_update(data) + + if self.bot.has_any_dispatch("voice_server_update"): + self._send_dispatch("voice_server_update", payload) + async def _parse_guild_members_chunk(self, data: dict) -> None: result = self.parser.guild_members_chunk(data) diff --git a/discord_http/voice/__init__.py b/discord_http/voice/__init__.py index ec6d7c5..62d40a7 100644 --- a/discord_http/voice/__init__.py +++ b/discord_http/voice/__init__.py @@ -1,36 +1,10 @@ -# ruff: noqa: F403, F405 +# ruff: noqa: F401, F403 from . import opus from .client import * from .connection import * -from .dave import has_dave, max_protocol_version -from .enums import SUPPORTED_MODES, VoiceOp -from .opus import OPUS_SILENCE, OpusError, OpusNotLoaded, is_loaded, load_opus +from .dave import * +from .enums import * +from .opus import * from .player import * from .receiver import * from .sinks import * - -__all__ = ( - "OPUS_SILENCE", - "SUPPORTED_MODES", - "AudioPlayer", - "AudioSink", - "AudioSource", - "CallbackSink", - "FFmpegOpusAudio", - "FFmpegPCMAudio", - "OpusError", - "OpusNotLoaded", - "PCMAudio", - "PCMVolumeTransformer", - "VoiceClient", - "VoiceConnection", - "VoiceData", - "VoiceOp", - "VoiceReceiver", - "WaveSink", - "has_dave", - "is_loaded", - "load_opus", - "max_protocol_version", - "opus", -) diff --git a/discord_http/voice/client.py b/discord_http/voice/client.py index 5935571..b68261e 100644 --- a/discord_http/voice/client.py +++ b/discord_http/voice/client.py @@ -3,7 +3,7 @@ import struct from collections.abc import Callable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from .connection import VoiceConnection @@ -11,7 +11,7 @@ from ..channel import PartialChannel from ..client import Client from .opus import Encoder - from .player import AudioPlayer + from .player import AudioPlayer, AudioSourceInput from .receiver import VoiceReceiver from .sinks import AudioSink @@ -176,19 +176,21 @@ async def _cleanup(self) -> None: await self.connection.close_transport() self.client._remove_voice_client(self.guild_id) - async def move_to(self, channel: "PartialChannel") -> None: + async def move_to(self, channel: "PartialChannel | int") -> None: """ Move to a different voice channel. Parameters ---------- channel: - The channel to move to. + The channel to move to, either a channel object or its ID. """ + if isinstance(channel, int): + channel = self.client.get_partial_channel(channel, guild_id=self.guild_id) await self.connection.move_to(channel) self.channel = channel - async def on_voice_state_update(self, data: dict) -> None: + def on_voice_state_update(self, data: dict) -> None: """ Forward a VOICE_STATE_UPDATE to the connection. @@ -199,7 +201,7 @@ async def on_voice_state_update(self, data: dict) -> None: """ self.connection.on_voice_state_update(data) - async def on_voice_server_update(self, data: dict) -> None: + def on_voice_server_update(self, data: dict) -> None: """ Forward a VOICE_SERVER_UPDATE to the connection. @@ -270,11 +272,11 @@ def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None: try: connection.transport.sendto(packet) except OSError: - _log.debug("Failed to send audio packet for guild %s", self.guild_id) + _log.debug(f"Failed to send audio packet for guild {self.guild_id}") def play( self, - audio: Any, # noqa: ANN401 + audio: "AudioSourceInput", *, after: Callable[[Exception | None], object] | None = None ) -> None: diff --git a/discord_http/voice/connection.py b/discord_http/voice/connection.py index 7bf5d45..14ce8c6 100644 --- a/discord_http/voice/connection.py +++ b/discord_http/voice/connection.py @@ -3,7 +3,8 @@ from typing import TYPE_CHECKING, Any -from ..utils import ExponentialBackoff +from ..utils import URL, ExponentialBackoff +from .dave import DaveManager, has_dave from .encryptor import Encryptor from .enums import SUPPORTED_MODES from .gateway_udp import VoiceUDPProtocol, create_udp @@ -12,15 +13,11 @@ if TYPE_CHECKING: from ..channel import PartialChannel from .client import VoiceClient - from .dave import DaveManager __all__ = ("VoiceConnection",) _log = logging.getLogger(__name__) -#: The maximum number of full-reconnect attempts before giving up. -MAX_RECONNECT_ATTEMPTS = 5 - class VoiceConnection: """ @@ -234,27 +231,27 @@ async def _handle_close(self, close_code: int | None) -> None: The websocket close code, if one was reported. """ if close_code in (VoiceCloseCode.disconnected, VoiceCloseCode.call_terminated): - _log.info("Voice connection for guild %s disconnected (code %s); tearing down", self.guild_id, close_code) + _log.info(f"Voice connection for guild {self.guild_id} disconnected (code {close_code}); tearing down") await self._teardown_and_remove() return if close_code == VoiceCloseCode.rate_limited: - _log.warning("Voice connection for guild %s was rate limited (code %s); not reconnecting", self.guild_id, close_code) + _log.warning(f"Voice connection for guild {self.guild_id} was rate limited (code {close_code}); not reconnecting") await self._teardown_and_remove() return if close_code == VoiceCloseCode.voice_server_crashed: - _log.info("Voice server for guild %s crashed (code %s); resuming", self.guild_id, close_code) + _log.info(f"Voice server for guild {self.guild_id} crashed (code {close_code}); resuming") await self._resume() return if close_code in (VoiceCloseCode.normal, VoiceCloseCode.going_away): - _log.debug("Voice connection for guild %s closed cleanly (code %s)", self.guild_id, close_code) + _log.debug(f"Voice connection for guild {self.guild_id} closed cleanly (code {close_code})") await self._teardown_and_remove() return if not self._reconnect: - _log.debug("Voice connection for guild %s closed (code %s); reconnect disabled", self.guild_id, close_code) + _log.debug(f"Voice connection for guild {self.guild_id} closed (code {close_code}); reconnect disabled") await self._teardown_and_remove() return @@ -271,8 +268,8 @@ async def _resume(self) -> None: self.socket = VoiceSocket(self) await self.socket.connect(resume=True) except Exception as exc: - _log.warning("Voice resume for guild %s failed; falling back to full reconnect", self.guild_id, exc_info=exc) - await self._full_reconnect(VoiceCloseCode.voice_server_crashed) + _log.warning(f"Voice resume for guild {self.guild_id} failed; falling back to full reconnect", exc_info=exc) + await self._full_reconnect(int(VoiceCloseCode.voice_server_crashed)) async def _full_reconnect(self, close_code: int | None) -> None: """ @@ -295,14 +292,16 @@ async def _full_reconnect(self, close_code: int | None) -> None: self._backoff.reset() - for attempt in range(1, MAX_RECONNECT_ATTEMPTS + 1): + max_attempts = self.voice_client.client.voice_reconnect_attempts + + for attempt in range(1, max_attempts + 1): if self._closing: return delay = self._backoff.delay() _log.info( - "Reconnecting voice for guild %s (close code %s), attempt %s/%s in %.2fs", - self.guild_id, close_code, attempt, MAX_RECONNECT_ATTEMPTS, delay + f"Reconnecting voice for guild {self.guild_id} (close code {close_code}), " + f"attempt {attempt}/{max_attempts} in {delay:.2f}s" ) await asyncio.sleep(delay) @@ -313,13 +312,13 @@ async def _full_reconnect(self, close_code: int | None) -> None: self_mute=self._self_mute, ) except Exception as exc: - _log.warning("Voice reconnect attempt %s for guild %s failed", attempt, self.guild_id, exc_info=exc) + _log.warning(f"Voice reconnect attempt {attempt} for guild {self.guild_id} failed", exc_info=exc) continue else: - _log.info("Voice connection for guild %s reconnected", self.guild_id) + _log.info(f"Voice connection for guild {self.guild_id} reconnected") return - _log.error("Voice connection for guild %s could not reconnect after %s attempts; tearing down", self.guild_id, MAX_RECONNECT_ATTEMPTS) + _log.error(f"Voice connection for guild {self.guild_id} could not reconnect after {max_attempts} attempts; tearing down") await self._teardown_and_remove() async def _teardown_and_remove(self) -> None: @@ -327,7 +326,7 @@ async def _teardown_and_remove(self) -> None: try: await self.voice_client.disconnect(force=True) except Exception as exc: - _log.debug("Error during voice teardown for guild %s", self.guild_id, exc_info=exc) + _log.debug(f"Error during voice teardown for guild {self.guild_id}", exc_info=exc) def on_voice_state_update(self, data: dict) -> None: """ @@ -359,10 +358,9 @@ def on_voice_server_update(self, data: dict) -> None: endpoint = data.get("endpoint") if endpoint: - endpoint = endpoint.removeprefix("wss://").removeprefix("ws://") - endpoint = endpoint.split("/", 1)[0] - endpoint = endpoint.rsplit(":", 1)[0] - self.endpoint = endpoint + # Discord sends the endpoint without a scheme (e.g. "host.discord.media:443"); + # the URL helper cleanly extracts the host without the port for us. + self.endpoint = URL(f"wss://{endpoint}").host or endpoint server_id = data.get("guild_id") or data.get("server_id") self.server_id = int(server_id) if server_id is not None else None @@ -445,7 +443,7 @@ async def on_resumed(self, data: dict) -> None: # noqa: ARG002 data: The resumed payload. """ - _log.debug("Voice connection for guild %s resumed", self.guild_id) + _log.debug(f"Voice connection for guild {self.guild_id} resumed") async def on_dave_binary(self, opcode: int, payload: bytes) -> None: """ @@ -458,10 +456,8 @@ async def on_dave_binary(self, opcode: int, payload: bytes) -> None: payload: The binary payload following the opcode. """ - from .dave import has_dave - if not has_dave: - _log.warning("Received DAVE binary op %s but the davey library is not available", opcode) + _log.warning(f"Received DAVE binary op {opcode} but the davey library is not available") return if self.dave_session is None: @@ -472,8 +468,6 @@ async def on_dave_binary(self, opcode: int, payload: bytes) -> None: async def reinit_dave_session(self) -> None: """ Create or reset the DAVE session for the negotiated protocol version. """ - from .dave import DaveManager, has_dave - if not has_dave: if self.dave_protocol_version > 0: raise RuntimeError( @@ -553,7 +547,7 @@ async def disconnect(self, *, force: bool = True) -> None: except Exception as exc: if not force: raise - _log.debug("Failed to send voice disconnect for guild %s", self.guild_id, exc_info=exc) + _log.debug(f"Failed to send voice disconnect for guild {self.guild_id}", exc_info=exc) if self.socket is not None: await self.socket.close() diff --git a/discord_http/voice/dave.py b/discord_http/voice/dave.py index 1e68b9e..a5aba62 100644 --- a/discord_http/voice/dave.py +++ b/discord_http/voice/dave.py @@ -9,7 +9,7 @@ davey = None has_dave = False -from .enums import VoiceOp +from .enums import VoiceOpType if TYPE_CHECKING: from .connection import VoiceConnection @@ -141,7 +141,7 @@ async def reinit(self, version: int) -> None: self._connection.channel_id, ) except Exception as exc: - _log.warning("Failed to initialise DAVE session: %s", exc) + _log.warning(f"Failed to initialise DAVE session: {exc}") self._session = None return @@ -158,7 +158,7 @@ async def _send_key_package(self) -> None: return await self._connection.socket.send_binary( - int(VoiceOp.dave_mls_key_package), key_package + int(VoiceOpType.dave_mls_key_package), key_package ) def set_passthrough_mode(self, enabled: bool) -> None: @@ -173,11 +173,13 @@ def set_passthrough_mode(self, enabled: bool) -> None: enabled: ``True`` to pass media through unchanged, ``False`` to resume encryption. """ - if self._session is not None: - try: - self._session.set_passthrough_mode(enabled) - except AttributeError: - pass + if self._session is None: + return + + try: + self._session.set_passthrough_mode(enabled) + except AttributeError: + pass def encrypt_opus(self, opus: bytes) -> bytes: """ @@ -232,22 +234,23 @@ async def handle_binary(self, opcode: int, payload: bytes) -> None: payload: The raw binary payload following the opcode. """ - if opcode == VoiceOp.dave_prepare_transition: - await self._handle_prepare_transition(payload) - elif opcode == VoiceOp.dave_execute_transition: - await self._handle_execute_transition(payload) - elif opcode == VoiceOp.dave_prepare_epoch: - await self._handle_prepare_epoch(payload) - elif opcode == VoiceOp.dave_mls_external_sender: - self._handle_external_sender(payload) - elif opcode == VoiceOp.dave_mls_proposals: - await self._handle_proposals(payload) - elif opcode == VoiceOp.dave_mls_announce_commit_transition: - await self._handle_commit(payload) - elif opcode == VoiceOp.dave_mls_welcome: - await self._handle_welcome(payload) - else: - _log.debug("Unhandled DAVE binary opcode %s", opcode) + match opcode: + case VoiceOpType.dave_prepare_transition: + await self._handle_prepare_transition(payload) + case VoiceOpType.dave_execute_transition: + await self._handle_execute_transition(payload) + case VoiceOpType.dave_prepare_epoch: + await self._handle_prepare_epoch(payload) + case VoiceOpType.dave_mls_external_sender: + self._handle_external_sender(payload) + case VoiceOpType.dave_mls_proposals: + await self._handle_proposals(payload) + case VoiceOpType.dave_mls_announce_commit_transition: + await self._handle_commit(payload) + case VoiceOpType.dave_mls_welcome: + await self._handle_welcome(payload) + case _: + _log.debug(f"Unhandled DAVE binary opcode {opcode}") async def _handle_prepare_transition(self, payload: bytes) -> None: """ Handle PREPARE_TRANSITION (21): record the pending transition and acknowledge. """ @@ -258,7 +261,7 @@ async def _handle_prepare_transition(self, payload: bytes) -> None: await self._execute_transition(transition_id, version) else: await self._connection.socket.send_binary( - int(VoiceOp.dave_transition_ready), self._encode_transition_id(transition_id) + int(VoiceOpType.dave_transition_ready), self._encode_transition_id(transition_id) ) async def _handle_execute_transition(self, payload: bytes) -> None: @@ -271,14 +274,14 @@ async def _handle_execute_transition(self, payload: bytes) -> None: await self._execute_transition(transition_id, version) return - _log.debug("Received EXECUTE_TRANSITION for unknown transition %s", transition_id) + _log.debug(f"Received EXECUTE_TRANSITION for unknown transition {transition_id}") async def _execute_transition(self, transition_id: int, version: int) -> None: """ Apply a transition: switch protocol version and update passthrough mode. """ self._version = version self.set_passthrough_mode(version == 0) self._pending_transition = None - _log.debug("Executed DAVE transition %s to version %s", transition_id, version) + _log.debug(f"Executed DAVE transition {transition_id} to version {version}") async def _handle_prepare_epoch(self, payload: bytes) -> None: """ Handle PREPARE_EPOCH (24): reinitialise the session for a new MLS epoch. """ @@ -303,7 +306,7 @@ async def _handle_proposals(self, payload: bytes) -> None: except AttributeError: return except Exception as exc: - _log.warning("Failed to process MLS proposals: %s", exc) + _log.warning(f"Failed to process MLS proposals: {exc}") await self._recover_from_invalid_commit() return @@ -313,7 +316,7 @@ async def _handle_proposals(self, payload: bytes) -> None: commit_welcome = self._extract_commit_welcome(result) if commit_welcome is not None: await self._connection.socket.send_binary( - int(VoiceOp.dave_mls_commit_welcome), commit_welcome + int(VoiceOpType.dave_mls_commit_welcome), commit_welcome ) async def _handle_commit(self, payload: bytes) -> None: @@ -326,7 +329,7 @@ async def _handle_commit(self, payload: bytes) -> None: except AttributeError: return except Exception as exc: - _log.warning("Failed to process MLS commit: %s", exc) + _log.warning(f"Failed to process MLS commit: {exc}") await self._recover_from_invalid_commit() async def _handle_welcome(self, payload: bytes) -> None: @@ -339,13 +342,13 @@ async def _handle_welcome(self, payload: bytes) -> None: except AttributeError: return except Exception as exc: - _log.warning("Failed to process MLS welcome: %s", exc) + _log.warning(f"Failed to process MLS welcome: {exc}") await self._recover_from_invalid_commit() async def _recover_from_invalid_commit(self) -> None: """ Notify the gateway of an invalid commit/welcome and reinitialise the session. """ await self._connection.socket.send_binary( - int(VoiceOp.dave_mls_invalid_commit_welcome), b"" + int(VoiceOpType.dave_mls_invalid_commit_welcome), b"" ) await self.reinit(self._version) diff --git a/discord_http/voice/enums.py b/discord_http/voice/enums.py index 5baafec..25f9445 100644 --- a/discord_http/voice/enums.py +++ b/discord_http/voice/enums.py @@ -2,13 +2,13 @@ __all__ = ( "SUPPORTED_MODES", - "VoiceOp", + "VoiceOpType", ) SUPPORTED_MODES: tuple[str, ...] = ("aead_aes256_gcm_rtpsize",) -class VoiceOp(BaseEnum): +class VoiceOpType(BaseEnum): """ Represents the opcode type of a voice gateway payload. """ identify = 0 select_protocol = 1 diff --git a/discord_http/voice/gateway_udp.py b/discord_http/voice/gateway_udp.py index 3e9dd17..7ce2d16 100644 --- a/discord_http/voice/gateway_udp.py +++ b/discord_http/voice/gateway_udp.py @@ -38,7 +38,10 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: transport: The datagram transport for this protocol. """ - self.transport = transport # type: ignore[assignment] + # A DatagramProtocol is only ever driven by a DatagramTransport; the + # isinstance check narrows the BaseTransport type without a cast. + if isinstance(transport, asyncio.DatagramTransport): + self.transport = transport def error_received(self, exc: Exception) -> None: """ @@ -49,7 +52,7 @@ def error_received(self, exc: Exception) -> None: exc: The exception reported by the transport. """ - _log.warning("Voice UDP error for guild %s: %s", self.connection.guild_id, exc) + _log.warning(f"Voice UDP error for guild {self.connection.guild_id}: {exc}") def connection_lost(self, exc: Exception | None) -> None: """ @@ -61,7 +64,7 @@ def connection_lost(self, exc: Exception | None) -> None: The exception that caused the loss, if any. """ if exc is not None: - _log.debug("Voice UDP connection lost for guild %s", self.connection.guild_id, exc_info=exc) + _log.debug(f"Voice UDP connection lost for guild {self.connection.guild_id}", exc_info=exc) self.transport = None def datagram_received(self, data: bytes, addr: tuple) -> None: # noqa: ARG002 diff --git a/discord_http/voice/opus.py b/discord_http/voice/opus.py index d66f782..d04991a 100644 --- a/discord_http/voice/opus.py +++ b/discord_http/voice/opus.py @@ -2,6 +2,8 @@ import ctypes.util import logging +from ..errors import OpusError, OpusNotLoaded + __all__ = ( "OPUS_APPLICATION_AUDIO", "OPUS_APPLICATION_LOWDELAY", @@ -77,22 +79,31 @@ } -class OpusError(Exception): - """ Raised when libopus returns an error code. """ +# Opaque handle types. libopus only ever hands these back as pointers. +EncoderStruct = ctypes.c_void_p +DecoderStruct = ctypes.c_void_p -class OpusNotLoaded(Exception): # noqa: N818 - """ Raised when an Opus operation is attempted but libopus is not available. """ +class _OpusLoader: + """ + Holds the lazily-loaded libopus handle and its load state. + Encapsulating the state on an instance avoids module-level ``global`` + statements: :func:`load_opus` mutates the attributes of this singleton + instead of rebinding module globals. + """ -# Opaque handle types. libopus only ever hands these back as pointers. -EncoderStruct = ctypes.c_void_p -DecoderStruct = ctypes.c_void_p + __slots__ = ("attempted", "lib") + + def __init__(self) -> None: + self.lib: ctypes.CDLL | None = None + """ The loaded libopus shared library, or ``None`` if unavailable. """ -# Module-level loader state. ``_lib`` is the loaded CDLL or None, and -# ``_loaded`` is a sentinel so a failed/empty lazy search is not repeated. -_lib: ctypes.CDLL | None = None -_loaded: bool = False + self.attempted: bool = False + """ Whether a load has been attempted (so it is not retried endlessly). """ + + +_loader = _OpusLoader() def _configure_lib(lib: ctypes.CDLL) -> None: @@ -181,9 +192,7 @@ def load_opus(name: str | None = None) -> None: OpusError If an explicit ``name`` was given but the library could not be loaded. """ - global _lib, _loaded # noqa: PLW0603 - - _loaded = True + _loader.attempted = True location = name if location is None: @@ -192,7 +201,7 @@ def load_opus(name: str | None = None) -> None: if location is None: # No library on the system; this is not fatal. The passthrough/E2EE # paths can still operate without libopus present. - _lib = None + _loader.lib = None _log.debug("libopus could not be located; Opus encode/decode is unavailable") return @@ -200,15 +209,15 @@ def load_opus(name: str | None = None) -> None: lib = ctypes.CDLL(location) _configure_lib(lib) except (OSError, AttributeError) as exc: - _lib = None + _loader.lib = None if name is not None: # The caller explicitly asked for this library, so surface failure. raise OpusError(f"Could not load libopus from {location!r}") from exc - _log.warning("Found libopus at %r but failed to load it: %s", location, exc) + _log.warning(f"Found libopus at {location!r} but failed to load it: {exc}") return - _lib = lib - _log.debug("Successfully loaded libopus from %r", location) + _loader.lib = lib + _log.debug(f"Successfully loaded libopus from {location!r}") def is_loaded() -> bool: @@ -223,10 +232,10 @@ def is_loaded() -> bool: ------- ``True`` if libopus is loaded and ready, ``False`` otherwise. """ - if not _loaded: + if not _loader.attempted: load_opus() - return _lib is not None + return _loader.lib is not None def _get_lib() -> ctypes.CDLL: @@ -242,13 +251,13 @@ def _get_lib() -> ctypes.CDLL: OpusNotLoaded If libopus could not be found or loaded. """ - if not _loaded: + if not _loader.attempted: load_opus() - if _lib is None: + if _loader.lib is None: raise OpusNotLoaded("libopus is not loaded; install the Opus shared library to use voice encode/decode") - return _lib + return _loader.lib def _strerror(code: int) -> str: @@ -264,7 +273,7 @@ def _strerror(code: int) -> str: ------- The decoded error string. """ - lib = _lib + lib = _loader.lib if lib is None: return f"error code {code}" diff --git a/discord_http/voice/player.py b/discord_http/voice/player.py index 94b274d..71cc17d 100644 --- a/discord_http/voice/player.py +++ b/discord_http/voice/player.py @@ -18,6 +18,7 @@ __all__ = ( "AudioPlayer", "AudioSource", + "AudioSourceInput", "FFmpegOpusAudio", "FFmpegPCMAudio", "PCMAudio", @@ -650,6 +651,10 @@ def set_source(self, source: AudioSource) -> None: self.resume() +AudioSourceInput = AudioSource | str | os.PathLike | bytes | bytearray | memoryview | io.IOBase | AsyncIterable[bytes] +""" The set of inputs accepted as audio sources by :meth:`VoiceClient.play`. """ + + def _resolve_source(audio: object) -> AudioSource: """ Coerce arbitrary audio input into an :class:`AudioSource`. diff --git a/discord_http/voice/socket.py b/discord_http/voice/socket.py index eb8b638..1254540 100644 --- a/discord_http/voice/socket.py +++ b/discord_http/voice/socket.py @@ -10,9 +10,8 @@ from collections.abc import Coroutine from typing import TYPE_CHECKING, Any -from enum import IntEnum - -from .enums import VoiceOp +from ..enums import BaseEnum +from .enums import VoiceOpType if TYPE_CHECKING: from .connection import VoiceConnection @@ -22,7 +21,7 @@ _log = logging.getLogger(__name__) -class VoiceCloseCode(IntEnum): +class VoiceCloseCode(BaseEnum): """ The voice gateway websocket close codes that govern reconnect behaviour. """ normal = 1000 @@ -56,8 +55,6 @@ def __init__(self, connection: "VoiceConnection"): self._closing: bool = False self._resuming: bool = False - self._own_session: bool = False - self._session: ClientSession | None = None self._heartbeat_interval: float = 0.0 self._heartbeat_task: asyncio.Task | None = None @@ -81,25 +78,24 @@ def average_latency(self) -> float: return float("inf") return sum(self._latencies) / len(self._latencies) - def _get_session(self) -> ClientSession: + @property + def session(self) -> ClientSession: """ - Return a usable aiohttp session, reusing the bot's if reachable. + The shared aiohttp session from the bot's HTTP client. Returns ------- - The shared HTTP session, or a freshly created one owned by this socket. - """ - try: - session = self.connection.voice_client.client.state.http.session - except AttributeError: - session = None + The bot's HTTP session, reused for the voice websocket. - if session is not None: - return session - - self._own_session = True - self._session = ClientSession() - return self._session + Raises + ------ + RuntimeError + If the HTTP session is not available (the client is not running). + """ + session = self.connection.voice_client.client.state.http.session + if session is None: + raise RuntimeError("HTTP session is not available; the client must be running to open a voice socket") + return session async def connect(self, *, resume: bool = False) -> None: """ @@ -112,9 +108,8 @@ async def connect(self, *, resume: bool = False) -> None: """ self._closing = False self._resuming = resume - session = self._get_session() endpoint = self.connection.endpoint - self.ws = await session.ws_connect(f"wss://{endpoint}/?v=8") + self.ws = await self.session.ws_connect(f"wss://{endpoint}/?v=8") self._receive_task = asyncio.create_task( self._receive_loop(), @@ -138,18 +133,18 @@ async def _receive_loop(self) -> None: self._dispatch_binary(msg.data) elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): - _log.debug("Voice socket for guild %s received close frame", self.connection.guild_id) + _log.debug(f"Voice socket for guild {self.connection.guild_id} received close frame") break elif msg.type is WSMsgType.ERROR: - _log.warning("Voice socket for guild %s received error: %s", self.connection.guild_id, msg.data) + _log.warning(f"Voice socket for guild {self.connection.guild_id} received error: {msg.data}") break except asyncio.CancelledError: raise except Exception as exc: - _log.debug("Voice socket for guild %s receive loop ended", self.connection.guild_id, exc_info=exc) + _log.debug(f"Voice socket for guild {self.connection.guild_id} receive loop ended", exc_info=exc) close_code = ws.close_code @@ -179,13 +174,13 @@ def _dispatch_text(self, raw: str | bytes) -> None: data: dict = payload.get("d") or {} try: - voice_op = VoiceOp(op) + voice_op = VoiceOpType(op) except ValueError: - _log.debug("Voice socket for guild %s received unknown op %s", self.connection.guild_id, op) + _log.debug(f"Voice socket for guild {self.connection.guild_id} received unknown op {op}") return match voice_op: - case VoiceOp.hello: + case VoiceOpType.hello: self._heartbeat_interval = float(data["heartbeat_interval"]) / 1000 self._start_heartbeat() if self._resuming: @@ -193,24 +188,24 @@ def _dispatch_text(self, raw: str | bytes) -> None: else: self._schedule(self.send_identify()) - case VoiceOp.ready: + case VoiceOpType.ready: self._schedule(self.connection.on_ready(data)) - case VoiceOp.session_description: + case VoiceOpType.session_description: self._schedule(self.connection.on_session_description(data)) - case VoiceOp.speaking: + case VoiceOpType.speaking: self._schedule(self.connection.on_speaking(data)) - case VoiceOp.heartbeat_ack: + case VoiceOpType.heartbeat_ack: if self._last_send: self._latencies.append(time.perf_counter() - self._last_send) - case VoiceOp.resumed: + case VoiceOpType.resumed: self._schedule(self.connection.on_resumed(data)) case _: - _log.debug("Voice socket for guild %s received unhandled op %s", self.connection.guild_id, voice_op) + _log.debug(f"Voice socket for guild {self.connection.guild_id} received unhandled op {voice_op}") def _dispatch_binary(self, raw: bytes) -> None: """ @@ -257,7 +252,7 @@ async def _guard(self, coro: Coroutine[Any, Any, Any]) -> None: try: await coro except Exception as exc: - _log.error("Error in voice socket handler for guild %s", self.connection.guild_id, exc_info=exc) + _log.error(f"Error in voice socket handler for guild {self.connection.guild_id}", exc_info=exc) def _start_heartbeat(self) -> None: """ (Re)start the heartbeat task using the negotiated interval. """ @@ -278,14 +273,14 @@ async def _heartbeat_loop(self) -> None: except asyncio.CancelledError: pass except Exception as exc: - _log.debug("Voice heartbeat for guild %s stopped", self.connection.guild_id, exc_info=exc) + _log.debug(f"Voice heartbeat for guild {self.connection.guild_id} stopped", exc_info=exc) async def _send_heartbeat(self) -> None: """ Send a single heartbeat frame, recording the send time for latency. """ self._last_send = time.perf_counter() nonce = int(time.time() * 1000) await self._send_json({ - "op": int(VoiceOp.heartbeat), + "op": int(VoiceOpType.heartbeat), "d": { "t": nonce, "seq_ack": self.seq_ack, @@ -310,7 +305,7 @@ async def send_identify(self) -> None: from .dave import max_protocol_version await self._send_json({ - "op": int(VoiceOp.identify), + "op": int(VoiceOpType.identify), "d": { "server_id": str(self.connection.guild_id), "user_id": str(self.connection.user_id), @@ -334,7 +329,7 @@ async def send_select_protocol(self, ip: str, port: int, mode: str) -> None: The negotiated encryption mode. """ await self._send_json({ - "op": int(VoiceOp.select_protocol), + "op": int(VoiceOpType.select_protocol), "d": { "protocol": "udp", "data": { @@ -359,7 +354,7 @@ async def send_speaking(self, speaking: int, *, ssrc: int, delay: int = 0) -> No The voice delay, in milliseconds. """ await self._send_json({ - "op": int(VoiceOp.speaking), + "op": int(VoiceOpType.speaking), "d": { "speaking": int(speaking), "delay": int(delay), @@ -370,7 +365,7 @@ async def send_speaking(self, speaking: int, *, ssrc: int, delay: int = 0) -> No async def send_resume(self) -> None: """ Send the RESUME (op 7) frame to resume an interrupted session. """ await self._send_json({ - "op": int(VoiceOp.resume), + "op": int(VoiceOpType.resume), "d": { "server_id": str(self.connection.guild_id), "session_id": self.connection.session_id, @@ -412,8 +407,3 @@ async def close(self) -> None: if self.ws is not None and not self.ws.closed: await self.ws.close() self.ws = None - - if self._own_session and self._session is not None: - await self._session.close() - self._session = None - self._own_session = False diff --git a/examples/voice_example.py b/examples/voice_example.py index c9efbe9..ea76e6e 100644 --- a/examples/voice_example.py +++ b/examples/voice_example.py @@ -1,6 +1,6 @@ import asyncio -from discord_http import BaseChannel, Client, Context, VoiceClient, WaveSink +from discord_http import BaseChannel, Client, Context, PartialChannel, VoiceClient, WaveSink from discord_http.gateway import Intents # Voice requires a gateway connection (to send the voice-state update) and the @@ -22,21 +22,39 @@ ) +def caller_voice_channel(ctx: Context) -> "BaseChannel | PartialChannel | None": + """ + Resolve the voice channel the invoking member is currently sitting in. + + This reads the member's cached voice state (populated from the gateway via the + guild_voice_states intent) instead of relying on a hard-coded channel id, so the + bot always follows whoever ran the command. + """ + if ctx.guild is None or ctx.author is None: + return None + + voice_state = ctx.guild.get_member_voice_state(ctx.author.id) + if voice_state is None or voice_state.channel_id is None: + return None + + # ``VoiceState`` exposes the resolved ``channel`` directly; the partial variant only + # carries the id, so fall back to a partial channel that ``connect()`` can act on. + return getattr(voice_state, "channel", None) or client.get_partial_channel( + voice_state.channel_id, guild_id=ctx.guild.id + ) + + @client.command() async def join(ctx: Context): """ Join the caller's voice channel and play a song """ - if ctx.guild is None or ctx.author is None: - return ctx.response.send_message("This command can only be used in a guild.") - - # Resolve a voice channel to connect to (here a hard-coded id for brevity). - channel = await client.fetch_channel(1234567890, guild_id=ctx.guild.id) - if not isinstance(channel, BaseChannel): - return ctx.response.send_message("Could not find that channel.") + channel = caller_voice_channel(ctx) + if channel is None: + return ctx.response.send_message("Join a voice channel first, then try again.") vc: VoiceClient = await channel.connect() # Play a local file (mp3 -> opus passthrough, ffmpeg only). - await vc.play("song.mp3") + vc.play("song.mp3") return ctx.response.send_message(f"Now playing, latency: {vc.latency:.1f}ms") @@ -82,7 +100,7 @@ async def voice_demo(channel: BaseChannel, move_to: BaseChannel) -> None: vc: VoiceClient = await channel.connect() # Playback controls. - await vc.play("song.mp3") + vc.play("song.mp3") vc.pause() vc.resume() diff --git a/tests/test_voice_oggparse.py b/tests/test_voice_oggparse.py index 59c853a..f8033ea 100644 --- a/tests/test_voice_oggparse.py +++ b/tests/test_voice_oggparse.py @@ -2,8 +2,7 @@ import shutil import struct import subprocess - -import pytest +import unittest from discord_http.voice.oggparse import OggPage, OggStream @@ -19,7 +18,8 @@ def _build_page( crc: int = 0, ) -> bytes: """Build a single valid Ogg page from a body and a hand-crafted segment table.""" - assert sum(segtable) == len(body), "segment table must sum to body length" + if sum(segtable) != len(body): + raise ValueError("segment table must sum to body length") header = struct.pack( "<4sBBQIIIB", b"OggS", @@ -34,103 +34,98 @@ def _build_page( return header + segtable + body -def test_single_page_single_packet() -> None: - body = b"hello opus" - page_bytes = _build_page(body, bytes([len(body)])) - - stream = OggStream(io.BytesIO(page_bytes)) - packets = list(stream.iter_packets()) - - assert packets == [body] - - -def test_page_header_fields_parsed() -> None: - body = b"\x00\x01\x02\x03" - page_bytes = _build_page( - body, - bytes([len(body)]), - header_type=0x02, - granule_position=12345, - sequence=7, - ) - - # Skip the 4-byte magic, then parse the page directly. - buffer = io.BytesIO(page_bytes) - assert buffer.read(4) == b"OggS" - page = OggPage(buffer) - - assert page.header_type == 0x02 - assert page.granule_position == 12345 - assert page.page_sequence_number == 7 - assert page.segtable == bytes([len(body)]) - assert page.data == body - - -def test_multiple_packets_in_one_page() -> None: - packet_a = b"first" - packet_b = b"second-packet" - body = packet_a + packet_b - segtable = bytes([len(packet_a), len(packet_b)]) - - stream = OggStream(io.BytesIO(_build_page(body, segtable))) - assert list(stream.iter_packets()) == [packet_a, packet_b] - - -def test_packet_spanning_segments_via_255_lacing() -> None: - # A packet exactly 255 bytes long needs a 255 lacing + a 0 lacing terminator. - body = b"x" * 255 - segtable = bytes([255, 0]) - - stream = OggStream(io.BytesIO(_build_page(body, segtable))) - assert list(stream.iter_packets()) == [body] - - -def test_packet_spanning_pages() -> None: - # First page ends mid-packet (trailing 255 lacing), second page continues it. - head = b"a" * 255 - tail = b"bcd" - page_one = _build_page(head, bytes([255]), sequence=0) - page_two = _build_page(tail, bytes([len(tail)]), header_type=0x01, sequence=1) - - stream = OggStream(io.BytesIO(page_one + page_two)) - assert list(stream.iter_packets()) == [head + tail] - - -def test_scans_past_leading_garbage() -> None: - body = b"payload" - page_bytes = _build_page(body, bytes([len(body)])) - - stream = OggStream(io.BytesIO(b"garbage-before-magic" + page_bytes)) - assert list(stream.iter_packets()) == [body] - - -def test_ffmpeg_generated_opus_stream() -> None: - ffmpeg = shutil.which("ffmpeg") - if ffmpeg is None: - pytest.skip("ffmpeg not available on PATH") - - result = subprocess.run( # noqa: S603 - [ - ffmpeg, - "-f", "lavfi", - "-i", "sine=frequency=440:duration=1", - "-c:a", "libopus", - "-f", "ogg", - "-", - ], - stdout=subprocess.PIPE, - stderr=subprocess.DEVNULL, - check=True, - ) - data = result.stdout - assert data, "ffmpeg produced no output" - - packets = list(OggStream(io.BytesIO(data)).iter_packets()) - - assert len(packets) > 2 - assert packets[0].startswith(b"OpusHead") - assert packets[1].startswith(b"OpusTags") +class TestOggParse(unittest.TestCase): + def test_single_page_single_packet(self) -> None: + body = b"hello opus" + page_bytes = _build_page(body, bytes([len(body)])) + + stream = OggStream(io.BytesIO(page_bytes)) + packets = list(stream.iter_packets()) + + self.assertEqual(packets, [body]) + + def test_page_header_fields_parsed(self) -> None: + body = b"\x00\x01\x02\x03" + page_bytes = _build_page( + body, + bytes([len(body)]), + header_type=0x02, + granule_position=12345, + sequence=7, + ) + + # Skip the 4-byte magic, then parse the page directly. + buffer = io.BytesIO(page_bytes) + self.assertEqual(buffer.read(4), b"OggS") + page = OggPage(buffer) + + self.assertEqual(page.header_type, 0x02) + self.assertEqual(page.granule_position, 12345) + self.assertEqual(page.page_sequence_number, 7) + self.assertEqual(page.segtable, bytes([len(body)])) + self.assertEqual(page.data, body) + + def test_multiple_packets_in_one_page(self) -> None: + packet_a = b"first" + packet_b = b"second-packet" + body = packet_a + packet_b + segtable = bytes([len(packet_a), len(packet_b)]) + + stream = OggStream(io.BytesIO(_build_page(body, segtable))) + self.assertEqual(list(stream.iter_packets()), [packet_a, packet_b]) + + def test_packet_spanning_segments_via_255_lacing(self) -> None: + # A packet exactly 255 bytes long needs a 255 lacing + a 0 lacing terminator. + body = b"x" * 255 + segtable = bytes([255, 0]) + + stream = OggStream(io.BytesIO(_build_page(body, segtable))) + self.assertEqual(list(stream.iter_packets()), [body]) + + def test_packet_spanning_pages(self) -> None: + # First page ends mid-packet (trailing 255 lacing), second page continues it. + head = b"a" * 255 + tail = b"bcd" + page_one = _build_page(head, bytes([255]), sequence=0) + page_two = _build_page(tail, bytes([len(tail)]), header_type=0x01, sequence=1) + + stream = OggStream(io.BytesIO(page_one + page_two)) + self.assertEqual(list(stream.iter_packets()), [head + tail]) + + def test_scans_past_leading_garbage(self) -> None: + body = b"payload" + page_bytes = _build_page(body, bytes([len(body)])) + + stream = OggStream(io.BytesIO(b"garbage-before-magic" + page_bytes)) + self.assertEqual(list(stream.iter_packets()), [body]) + + def test_ffmpeg_generated_opus_stream(self) -> None: + ffmpeg = shutil.which("ffmpeg") + if ffmpeg is None: + self.skipTest("ffmpeg not available on PATH") + + result = subprocess.run( # noqa: S603 + [ + ffmpeg, + "-f", "lavfi", + "-i", "sine=frequency=440:duration=1", + "-c:a", "libopus", + "-f", "ogg", + "-", + ], + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + check=True, + ) + data = result.stdout + self.assertTrue(data, "ffmpeg produced no output") + + packets = list(OggStream(io.BytesIO(data)).iter_packets()) + + self.assertGreater(len(packets), 2) + self.assertTrue(packets[0].startswith(b"OpusHead")) + self.assertTrue(packets[1].startswith(b"OpusTags")) if __name__ == "__main__": - raise SystemExit(pytest.main([__file__, "-q"])) + unittest.main() From 346e3260a15ee6be1d59598677eaaca516d684cf Mon Sep 17 00:00:00 2001 From: Neppkun Date: Tue, 2 Jun 2026 23:02:59 +0300 Subject: [PATCH 04/12] Fix voice example: enable voice-state caching so member lookup works get_member_voice_state() only returns data when the library is caching voice states. The guild_voice_states intent makes Discord SEND the updates, but gateway_cache decides what is kept; with no gateway_cache the cache flags are None, update_voice_state() no-ops, and the bot always thinks the caller is not in a voice channel. Enable GatewayCacheFlags.guilds | GatewayCacheFlags.voice_states (voice states need a cached guild to hang on) and document why both the intent and the cache flag are required. Co-Authored-By: Claude Opus 4.8 --- examples/voice_example.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/voice_example.py b/examples/voice_example.py index ea76e6e..2dc352b 100644 --- a/examples/voice_example.py +++ b/examples/voice_example.py @@ -1,11 +1,17 @@ import asyncio from discord_http import BaseChannel, Client, Context, PartialChannel, VoiceClient, WaveSink -from discord_http.gateway import Intents +from discord_http.gateway import GatewayCacheFlags, Intents # Voice requires a gateway connection (to send the voice-state update) and the # guild_voice_states intent (so the bot receives its own voice server/state updates). # +# To follow whoever ran a command, the bot also needs to *cache* voice states: the +# guild_voice_states intent only makes Discord SEND the updates, while gateway_cache +# decides what the library keeps. Without GatewayCacheFlags.voice_states (and a guild +# cache to hang them on), get_member_voice_state() is always empty and the bot will +# think nobody is in a channel. +# # Codec notes: # * Passing an ``.mp3``/``.opus`` file plays through ffmpeg -> Ogg/Opus and needs # ONLY ffmpeg installed (no libopus) -- the audio is sent as opus passthrough. @@ -18,6 +24,10 @@ intents=( Intents.guild_messages | Intents.guild_voice_states + ), + gateway_cache=( + GatewayCacheFlags.guilds | + GatewayCacheFlags.voice_states ) ) From 05b11e68267b19a5379753961dfe1b45a138aa21 Mon Sep 17 00:00:00 2001 From: Neppkun Date: Tue, 2 Jun 2026 23:14:24 +0300 Subject: [PATCH 05/12] Fix voice example: use ctx.user and add the guilds intent Two reasons the bot thought the caller was not in voice: - ctx.author is only set for interactions tied to a message; for a slash command it is None. The invoking member is ctx.user. - Without Intents.guilds the bot never receives GUILD_CREATE, so no guild is cached: ctx.guild falls back to an empty stub and Cache.update_voice_state bails (get_guild returns None), so voice states are never stored. Add the guilds intent so the guild (and its voice states) actually get cached. Co-Authored-By: Claude Opus 4.8 --- examples/voice_example.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/voice_example.py b/examples/voice_example.py index 2dc352b..688787d 100644 --- a/examples/voice_example.py +++ b/examples/voice_example.py @@ -6,11 +6,14 @@ # Voice requires a gateway connection (to send the voice-state update) and the # guild_voice_states intent (so the bot receives its own voice server/state updates). # -# To follow whoever ran a command, the bot also needs to *cache* voice states: the -# guild_voice_states intent only makes Discord SEND the updates, while gateway_cache -# decides what the library keeps. Without GatewayCacheFlags.voice_states (and a guild -# cache to hang them on), get_member_voice_state() is always empty and the bot will -# think nobody is in a channel. +# To follow whoever ran a command, the bot must *cache* voice states, which needs: +# * Intents.guilds -> so GUILD_CREATE fires and guilds get cached (voice +# states are stored on the guild; without it ctx.guild +# is an empty stub and nothing is ever kept) +# * Intents.guild_voice_states -> so Discord actually SENDS voice state updates +# * GatewayCacheFlags.guilds | .voice_states -> so the library keeps both +# Miss any of these and get_member_voice_state() stays empty -> the bot thinks nobody +# is in a channel. # # Codec notes: # * Passing an ``.mp3``/``.opus`` file plays through ffmpeg -> Ogg/Opus and needs @@ -22,6 +25,7 @@ token="BOT_TOKEN", enable_gateway=True, intents=( + Intents.guilds | Intents.guild_messages | Intents.guild_voice_states ), @@ -40,10 +44,10 @@ def caller_voice_channel(ctx: Context) -> "BaseChannel | PartialChannel | None": guild_voice_states intent) instead of relying on a hard-coded channel id, so the bot always follows whoever ran the command. """ - if ctx.guild is None or ctx.author is None: + if ctx.guild is None or ctx.user is None: return None - voice_state = ctx.guild.get_member_voice_state(ctx.author.id) + voice_state = ctx.guild.get_member_voice_state(ctx.user.id) if voice_state is None or voice_state.channel_id is None: return None From dfb41a2670cad55720450220026c0a109af3b23a Mon Sep 17 00:00:00 2001 From: Neppkun Date: Wed, 3 Jun 2026 00:14:56 +0300 Subject: [PATCH 06/12] Fix voice connection handshake and DAVE framing The voice handshake never completed: connect() always timed out waiting on the voice gateway. Debugging live against a real channel surfaced five distinct protocol bugs (each revealed by the next close code once the prior was fixed: timeout -> 4003 -> 4006 -> 4017 -> 4005 -> connected). socket.py: - _send_json: send JSON control frames as TEXT (send_str) instead of BINARY (send_bytes). The voice gateway reserves binary frames for DAVE/E2EE opcodes, so IDENTIFY sent as binary was silently ignored -> handshake timeout. - HELLO handling: send IDENTIFY/RESUME *before* starting the heartbeat loop (via a single ordered _handle_hello coroutine). Previously the heartbeat task was created first and op 3 raced ahead of op 0 -> 4003 Not authenticated. - _heartbeat_loop: sleep one interval before the first beat so nothing is sent between IDENTIFY and READY (mirrors discord.py's keep-alive). - send_binary: frame outbound DAVE messages as opcode(1B)+payload with NO 2-byte sequence prefix. Inbound frames carry a seq prefix but outbound must not; the extra leading 0x00 made Discord read the frame as opcode 0 (IDENTIFY) -> 4005 Already authenticated. Switch to ws.receive() loop so Discord's close reason is logged (invaluable for diagnosis). - add send_transition_ready (op 23) as a JSON frame, per the protocol. connection.py: - on_voice_server_update: preserve the endpoint port. Discord assigns voice servers on non-443 ports (e.g. :2053) and the session/token are bound to that host:port; stripping it connected us to a different instance -> 4006 Session is no longer valid. Only strip a scheme if present. dave.py: - _handle_proposals: decode the operation_type(1B) prefix and pass (operation_type, proposals) to davey.process_proposals (it requires both args); concatenate commit + welcome for the reply. - _handle_commit / _handle_welcome: strip the transition_id(2B) prefix before passing the MLS blob to davey, and ack non-zero transitions via the JSON TRANSITION_READY (op 23) rather than a binary frame. Verified end to end against a live channel that requires DAVE: the bot joins, plays audio, stays connected, and negotiates E2EE (can_encrypt() == True). --- discord_http/voice/connection.py | 14 ++++-- discord_http/voice/dave.py | 82 ++++++++++++++++++++++++-------- discord_http/voice/socket.py | 80 ++++++++++++++++++++++++++----- 3 files changed, 139 insertions(+), 37 deletions(-) diff --git a/discord_http/voice/connection.py b/discord_http/voice/connection.py index 14ce8c6..594aaf5 100644 --- a/discord_http/voice/connection.py +++ b/discord_http/voice/connection.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any -from ..utils import URL, ExponentialBackoff +from ..utils import ExponentialBackoff from .dave import DaveManager, has_dave from .encryptor import Encryptor from .enums import SUPPORTED_MODES @@ -358,9 +358,15 @@ def on_voice_server_update(self, data: dict) -> None: endpoint = data.get("endpoint") if endpoint: - # Discord sends the endpoint without a scheme (e.g. "host.discord.media:443"); - # the URL helper cleanly extracts the host without the port for us. - self.endpoint = URL(f"wss://{endpoint}").host or endpoint + # Discord sends the endpoint as "host:port" without a scheme, and the + # port is NOT always 443 (e.g. "c-ams20-....discord.media:2053"). The + # token/session are bound to that specific host:port, so we MUST keep + # the port intact -- connecting to host:443 instead reaches a + # different voice server instance and Discord closes the socket with + # 4006 "Session is no longer valid". Only strip a scheme if present. + if endpoint.startswith("wss://"): + endpoint = endpoint[len("wss://"):] + self.endpoint = endpoint.rstrip("/") server_id = data.get("guild_id") or data.get("server_id") self.server_id = int(server_id) if server_id is not None else None diff --git a/discord_http/voice/dave.py b/discord_http/voice/dave.py index a5aba62..4b900ba 100644 --- a/discord_http/voice/dave.py +++ b/discord_http/voice/dave.py @@ -260,9 +260,7 @@ async def _handle_prepare_transition(self, payload: bytes) -> None: if transition_id == 0: await self._execute_transition(transition_id, version) else: - await self._connection.socket.send_binary( - int(VoiceOpType.dave_transition_ready), self._encode_transition_id(transition_id) - ) + await self._connection.socket.send_transition_ready(transition_id) async def _handle_execute_transition(self, payload: bytes) -> None: """ Handle EXECUTE_TRANSITION (22): apply the pending version and passthrough state. """ @@ -297,12 +295,28 @@ def _handle_external_sender(self, payload: bytes) -> None: pass async def _handle_proposals(self, payload: bytes) -> None: - """ Handle MLS_PROPOSALS (27): process proposals and forward any commit/welcome. """ - if self._session is None: + """ + Handle MLS_PROPOSALS (27): process proposals and forward any commit/welcome. + + The payload is ``operation_type(1B) + proposals``: the first byte selects + append (``0``) vs revoke, and the remainder is the serialized proposals. + """ + if self._session is None or davey is None: + return + + if len(payload) < 1: return + optype = payload[0] + proposals = payload[1:] + operation_type = ( + davey.ProposalsOperationType.append + if optype == 0 + else davey.ProposalsOperationType.revoke + ) + try: - result = self._session.process_proposals(payload) + result = self._session.process_proposals(operation_type, proposals) except AttributeError: return except Exception as exc: @@ -310,9 +324,6 @@ async def _handle_proposals(self, payload: bytes) -> None: await self._recover_from_invalid_commit() return - if result is None: - return - commit_welcome = self._extract_commit_welcome(result) if commit_welcome is not None: await self._connection.socket.send_binary( @@ -320,30 +331,54 @@ async def _handle_proposals(self, payload: bytes) -> None: ) async def _handle_commit(self, payload: bytes) -> None: - """ Handle MLS_ANNOUNCE_COMMIT_TRANSITION (29): apply the announced commit. """ + """ + Handle MLS_ANNOUNCE_COMMIT_TRANSITION (29): apply the announced commit. + + The payload is ``transition_id(2B big-endian) + commit``. + """ if self._session is None: return + transition_id = int.from_bytes(payload[:2], "big") if len(payload) >= 2 else 0 + commit = payload[2:] + try: - self._session.process_commit(payload) + self._session.process_commit(commit) except AttributeError: return except Exception as exc: _log.warning(f"Failed to process MLS commit: {exc}") await self._recover_from_invalid_commit() + return + + if transition_id != 0: + self._pending_transition = (transition_id, self._version) + await self._connection.socket.send_transition_ready(transition_id) async def _handle_welcome(self, payload: bytes) -> None: - """ Handle MLS_WELCOME (30): join the group from the received welcome message. """ + """ + Handle MLS_WELCOME (30): join the group from the received welcome message. + + The payload is ``transition_id(2B big-endian) + welcome``. + """ if self._session is None: return + transition_id = int.from_bytes(payload[:2], "big") if len(payload) >= 2 else 0 + welcome = payload[2:] + try: - self._session.process_welcome(payload) + self._session.process_welcome(welcome) except AttributeError: return except Exception as exc: _log.warning(f"Failed to process MLS welcome: {exc}") await self._recover_from_invalid_commit() + return + + if transition_id != 0: + self._pending_transition = (transition_id, self._version) + await self._connection.socket.send_transition_ready(transition_id) async def _recover_from_invalid_commit(self) -> None: """ Notify the gateway of an invalid commit/welcome and reinitialise the session. """ @@ -355,10 +390,12 @@ async def _recover_from_invalid_commit(self) -> None: @staticmethod def _extract_commit_welcome(result: _Session) -> bytes | None: """ - Extract serialized commit/welcome bytes from a ``davey.CommitWelcome`` result. + Extract the bytes to send for a ``davey.CommitWelcome`` result. - The exact ``davey`` API may differ; this tolerates a serialize method, raw bytes, - or a ``None`` result. + ``davey``'s ``process_proposals`` returns a ``CommitWelcome`` carrying a + ``commit`` and an optional ``welcome``. Discord expects them concatenated + as ``commit + welcome`` (commit alone when there is no welcome). This also + tolerates raw bytes or a ``None`` result. Parameters ---------- @@ -373,6 +410,14 @@ def _extract_commit_welcome(result: _Session) -> bytes | None: return None if isinstance(result, (bytes, bytearray)): return bytes(result) + + commit = getattr(result, "commit", None) + if commit is not None: + welcome = getattr(result, "welcome", None) + if welcome: + return bytes(commit) + bytes(welcome) + return bytes(commit) + if hasattr(result, "serialize"): try: return bytes(result.serialize()) @@ -391,8 +436,3 @@ def _parse_transition(payload: bytes) -> tuple[int, int]: transition_id = int.from_bytes(payload[:2], "big") if len(payload) >= 2 else 0 version = payload[2] if len(payload) >= 3 else 0 return transition_id, version - - @staticmethod - def _encode_transition_id(transition_id: int) -> bytes: - """ Encode a transition id as a 2-byte big-endian payload for TRANSITION_READY. """ - return transition_id.to_bytes(2, "big") diff --git a/discord_http/voice/socket.py b/discord_http/voice/socket.py index 1254540..6e84ff0 100644 --- a/discord_http/voice/socket.py +++ b/discord_http/voice/socket.py @@ -60,7 +60,6 @@ def __init__(self, connection: "VoiceConnection"): self._heartbeat_task: asyncio.Task | None = None self._receive_task: asyncio.Task | None = None - self._out_seq: int = 0 self._last_send: float = 0.0 self._latencies: deque[float] = deque(maxlen=20) @@ -125,7 +124,9 @@ async def _receive_loop(self) -> None: close_code: int | None = None try: - async for msg in ws: + while True: + msg = await ws.receive() + if msg.type is WSMsgType.TEXT: self._dispatch_text(msg.data) @@ -133,7 +134,13 @@ async def _receive_loop(self) -> None: self._dispatch_binary(msg.data) elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): - _log.debug(f"Voice socket for guild {self.connection.guild_id} received close frame") + # Discord's close reason (msg.extra) is invaluable for diagnosing + # voice failures (e.g. "E2EE/DAVE protocol required" for 4017), + # so surface it alongside the code. + _log.debug( + f"Voice socket for guild {self.connection.guild_id} received close frame " + f"(code={msg.data!r}, reason={msg.extra!r})" + ) break elif msg.type is WSMsgType.ERROR: @@ -182,11 +189,7 @@ def _dispatch_text(self, raw: str | bytes) -> None: match voice_op: case VoiceOpType.hello: self._heartbeat_interval = float(data["heartbeat_interval"]) / 1000 - self._start_heartbeat() - if self._resuming: - self._schedule(self.send_resume()) - else: - self._schedule(self.send_identify()) + self._schedule(self._handle_hello()) case VoiceOpType.ready: self._schedule(self.connection.on_ready(data)) @@ -254,6 +257,23 @@ async def _guard(self, coro: Coroutine[Any, Any, Any]) -> None: except Exception as exc: _log.error(f"Error in voice socket handler for guild {self.connection.guild_id}", exc_info=exc) + async def _handle_hello(self) -> None: + """ + React to HELLO (op 8): authenticate, then start heartbeating. + + IDENTIFY/RESUME MUST be the first payload sent on the voice gateway. + The heartbeat loop emits a heartbeat (op 3) immediately, so it can only + be started *after* authentication has been sent; otherwise Discord sees + a payload before IDENTIFY and closes the socket with code 4003 + ("Not authenticated"). + """ + if self._resuming: + await self.send_resume() + else: + await self.send_identify() + + self._start_heartbeat() + def _start_heartbeat(self) -> None: """ (Re)start the heartbeat task using the negotiated interval. """ if self._heartbeat_task is not None and not self._heartbeat_task.done(): @@ -268,8 +288,13 @@ async def _heartbeat_loop(self) -> None: """ Send a heartbeat (op 3) every interval until cancelled. """ try: while True: - await self._send_heartbeat() + # Sleep *before* the first beat: the voice gateway must complete + # the IDENTIFY -> READY -> SESSION_DESCRIPTION handshake without + # any heartbeat interleaved. Sending op 3 before READY makes + # Discord invalidate the session (close code 4006). This mirrors + # discord.py's voice keep-alive, which also waits one interval. await asyncio.sleep(self._heartbeat_interval) + await self._send_heartbeat() except asyncio.CancelledError: pass except Exception as exc: @@ -298,7 +323,12 @@ async def _send_json(self, payload: dict) -> None: """ if self.ws is None or self.ws.closed: return - await self.ws.send_bytes(orjson.dumps(payload)) + # JSON control frames MUST be sent as text frames: the voice gateway + # reserves binary frames for DAVE/E2EE opcodes (see ``send_binary`` and + # ``_dispatch_binary``). Sending JSON via ``send_bytes`` makes Discord + # treat IDENTIFY as a malformed DAVE frame, so it never replies with + # READY/SESSION_DESCRIPTION and the handshake times out. + await self.ws.send_str(orjson.dumps(payload).decode("utf-8")) async def send_identify(self) -> None: """ Send the IDENTIFY (op 0) frame, advertising DAVE support. """ @@ -374,10 +404,37 @@ async def send_resume(self) -> None: } }) + async def send_transition_ready(self, transition_id: int) -> None: + """ + Send the DAVE TRANSITION_READY (op 23) acknowledgement. + + This is a JSON control frame (not a binary DAVE frame): it carries the + ``transition_id`` as JSON, matching the voice gateway protocol. + + Parameters + ---------- + transition_id: + The id of the transition being acknowledged. + """ + await self._send_json({ + "op": int(VoiceOpType.dave_transition_ready), + "d": { + "transition_id": transition_id, + } + }) + async def send_binary(self, opcode: int, payload: bytes) -> None: """ Send a binary DAVE frame. + Outbound binary frames are framed as ``opcode(1B) + payload`` with NO + sequence prefix. This is asymmetric with *inbound* binary frames, which + Discord prefixes with a 2-byte sequence number (``seq(2B) + opcode(1B) + + payload``, handled in :meth:`_dispatch_binary`). Prefixing outbound + frames with the 2-byte sequence makes Discord read the leading ``0x00`` + byte as opcode 0 (IDENTIFY) and close the socket with 4005 + ("Already authenticated"). + Parameters ---------- opcode: @@ -388,8 +445,7 @@ async def send_binary(self, opcode: int, payload: bytes) -> None: if self.ws is None or self.ws.closed: return - self._out_seq = (self._out_seq + 1) & 0xFFFF - frame = struct.pack(">H", 0) + bytes([opcode & 0xFF]) + payload + frame = bytes([opcode & 0xFF]) + payload await self.ws.send_bytes(frame) async def close(self) -> None: From f65b08e27901d1a8735c31e61acada3652fdc1ba Mon Sep 17 00:00:00 2001 From: Neppkun Date: Wed, 3 Jun 2026 00:23:30 +0300 Subject: [PATCH 07/12] Add VoiceConnection.client property for readability Address review feedback on the long attribute chain `self.voice_client.client.voice_reconnect_attempts`. Introduce a `client` property on VoiceConnection and use it everywhere the owning bot client was reached via `self.voice_client.client`. --- discord_http/voice/connection.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/discord_http/voice/connection.py b/discord_http/voice/connection.py index 594aaf5..c70f84a 100644 --- a/discord_http/voice/connection.py +++ b/discord_http/voice/connection.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from ..channel import PartialChannel + from ..client import Client from .client import VoiceClient __all__ = ("VoiceConnection",) @@ -107,6 +108,11 @@ def __init__(self, voice_client: "VoiceClient"): self._reconnect_task: asyncio.Task | None = None self._backoff: ExponentialBackoff = ExponentialBackoff(base=1.0, max_delay=30.0) + @property + def client(self) -> "Client": + """ The bot client that owns this voice connection. """ + return self.voice_client.client + @property def latency(self) -> float: """ The latency of the most recent voice heartbeat, in seconds. """ @@ -161,7 +167,7 @@ async def connect( TimeoutError If the handshake does not complete within ``timeout``. """ - client = self.voice_client.client + client = self.client shard_id = client.get_shard_by_guild_id(self.guild_id) if shard_id is None: @@ -292,7 +298,7 @@ async def _full_reconnect(self, close_code: int | None) -> None: self._backoff.reset() - max_attempts = self.voice_client.client.voice_reconnect_attempts + max_attempts = self.client.voice_reconnect_attempts for attempt in range(1, max_attempts + 1): if self._closing: @@ -544,7 +550,7 @@ async def disconnect(self, *, force: bool = True) -> None: if self.socket is not None: self.socket._request_close() - client = self.voice_client.client + client = self.client try: shard_id = client.get_shard_by_guild_id(self.guild_id) shard = client.gateway.get_shard(shard_id) if (client.gateway and shard_id is not None) else None @@ -608,7 +614,7 @@ async def move_to(self, channel: "PartialChannel") -> None: channel: The channel to move to. """ - client = self.voice_client.client + client = self.client shard_id = client.get_shard_by_guild_id(self.guild_id) shard = client.gateway.get_shard(shard_id) if (client.gateway and shard_id is not None) else None From f5488e2a7222f6a17ee8e0316f7d58fe2028597d Mon Sep 17 00:00:00 2001 From: Neppkun Date: Thu, 4 Jun 2026 15:16:30 +0300 Subject: [PATCH 08/12] Fix voice rejoin, shutdown reconnect, ResourceWarning, and add public APIs Address review feedback on voice support: - Rejoin (DAVE WrongEpoch -> 4006): route DAVE control ops 21/22/24, which arrive as JSON text frames rather than binary, to the DAVE manager so the local MLS epoch stays in sync with the gateway. Add a configurable policy for the 4006 "session invalid" close (most common when a channel empties and Discord tears down the DAVE session): disconnect by default, or reconnect when connect(reconnect_on_session_invalid=True) is passed. - Shutdown reconnect: tear down active voice clients in GatewayClient.close() before killing shards, so the voice sockets don't try to reconnect on the 1006 close during shutdown. - ResourceWarning ("unclosed transport" / "I/O operation on closed pipe"): close the ffmpeg subprocess transport and reap the process in _FFmpegAudio.cleanup() so the Windows Proactor pipe transports are released deterministically instead of by the GC. - PartialVoiceState.channel: add a `channel` property to PartialVoiceState (and override it on VoiceState) so user_vc.channel.connect() works and is statically known for both types. - Public get_voice_client: rename Client._get_voice_client to the public Client.get_voice_client and update all callers and the example. --- discord_http/channel.py | 41 ++++++++++++++++++++--- discord_http/client.py | 2 +- discord_http/gateway/client.py | 11 +++++++ discord_http/gateway/shard.py | 4 +-- discord_http/voice/client.py | 6 ++++ discord_http/voice/connection.py | 45 +++++++++++++++++++++++++ discord_http/voice/dave.py | 56 ++++++++++++++++++-------------- discord_http/voice/player.py | 43 +++++++++++++++++++++--- discord_http/voice/socket.py | 13 ++++++++ examples/voice_example.py | 15 ++++----- 10 files changed, 193 insertions(+), 43 deletions(-) diff --git a/discord_http/channel.py b/discord_http/channel.py index b528461..1573250 100644 --- a/discord_http/channel.py +++ b/discord_http/channel.py @@ -486,6 +486,7 @@ async def connect( *, timeout: float = 30.0, reconnect: bool = True, + reconnect_on_session_invalid: bool = False, self_deaf: bool = False, self_mute: bool = False ) -> "VoiceClient": @@ -498,6 +499,11 @@ async def connect( How long to wait, in seconds, for the voice handshake to complete. reconnect: Whether to automatically reconnect if the voice connection drops. + reconnect_on_session_invalid: + What to do when Discord invalidates the voice session (close code + 4006), which most commonly happens when the channel empties out and + Discord tears down the DAVE/MLS session. By default (``False``) the + bot disconnects; set to ``True`` to attempt a full reconnect instead. self_deaf: Whether the bot should be self-deafened. self_mute: @@ -533,7 +539,7 @@ async def connect( if not client.gateway: raise NotImplementedError("gateway is not available") - if client._get_voice_client(self.guild_id) is not None: + if client.get_voice_client(self.guild_id) is not None: raise RuntimeError("Already connected to a voice channel in this guild") from .voice.client import VoiceClient @@ -543,6 +549,7 @@ async def connect( await vc.connect( timeout=timeout, reconnect=reconnect, + reconnect_on_session_invalid=reconnect_on_session_invalid, self_deaf=self_deaf, self_mute=self_mute ) @@ -2839,6 +2846,21 @@ def __repr__(self) -> str: def __str__(self) -> str: return "PartialVoiceState" + @property + def channel(self) -> "BaseChannel | PartialChannel | None": + """ + The voice channel this user is in, if any. + + Returns a `PartialChannel` built from `channel_id` so it can be used + directly (e.g. ``await voice_state.channel.connect()``) even when the + full channel object is not cached. + """ + if self.channel_id is None: + return None + return self._state.bot.get_partial_channel( + self.channel_id, guild_id=self.guild_id + ) + async def fetch(self) -> "VoiceState": """ Fetches the voice state of the member. @@ -2906,7 +2928,7 @@ class VoiceState(PartialVoiceState): """ Represents a voice state object. """ __slots__ = ( - "channel", + "_channel", "deaf", "guild", "member", @@ -2947,8 +2969,7 @@ def __init__( self.member: "Member | None" = None """ The member this voice state belongs to, if any. """ - self.channel: "BaseChannel | PartialChannel | None" = channel - """ The voice channel this user is in, if any. """ + self._channel: "BaseChannel | PartialChannel | None" = channel self.guild: "PartialGuild | None" = guild """ The guild this voice state is in, if any. """ @@ -2982,6 +3003,18 @@ def __init__( def __repr__(self) -> str: return f"" + @property + def channel(self) -> "BaseChannel | PartialChannel | None": + """ + The voice channel this user is in, if any. + + Prefers the resolved channel object (when cached); otherwise falls back + to a `PartialChannel` built from `channel_id` so it remains usable. + """ + if self._channel is not None: + return self._channel + return super().channel + def _from_data(self, data: dict) -> None: if data.get("member") and self.guild: from .member import Member diff --git a/discord_http/client.py b/discord_http/client.py index 1d2e5a7..d442259 100644 --- a/discord_http/client.py +++ b/discord_http/client.py @@ -226,7 +226,7 @@ def _cleanup_task(self, task: asyncio.Task) -> None: except Exception: pass - def _get_voice_client(self, guild_id: int) -> "VoiceClient | None": + def get_voice_client(self, guild_id: int) -> "VoiceClient | None": """ Get the voice client for a guild, if one is registered. diff --git a/discord_http/gateway/client.py b/discord_http/gateway/client.py index 0b59067..095355e 100644 --- a/discord_http/gateway/client.py +++ b/discord_http/gateway/client.py @@ -234,6 +234,17 @@ def start(self) -> None: async def close(self) -> None: """ Close the gateway client. """ + # Tear down any active voice connections first. Each voice websocket is a + # separate connection to Discord; once the gateway drops, Discord closes + # them abnormally (code 1006) and the voice socket would otherwise try to + # reconnect mid-shutdown. ``_cleanup`` closes them locally (no op4) and + # marks them as intentionally closing so no reconnect is scheduled. + for vc in list(self.bot._voice_clients.values()): + try: + await vc._cleanup() + except Exception as exc: + _log.debug("Error cleaning up voice client during shutdown", exc_info=exc) + to_close = [ asyncio.ensure_future(shard.close(kill=True)) for shard in self.__shards.values() diff --git a/discord_http/gateway/shard.py b/discord_http/gateway/shard.py index 173c19b..bc78857 100644 --- a/discord_http/gateway/shard.py +++ b/discord_http/gateway/shard.py @@ -1067,7 +1067,7 @@ def _parse_voice_state_update(self, data: dict) -> None: and data.get("guild_id") is not None and int(data["user_id"]) == bot_user.id ): - vc = self.bot._get_voice_client(int(data["guild_id"])) + vc = self.bot.get_voice_client(int(data["guild_id"])) if vc is not None: vc.on_voice_state_update(data) @@ -1077,7 +1077,7 @@ def _parse_voice_state_update(self, data: dict) -> None: def _parse_voice_server_update(self, data: dict) -> None: (payload,) = self.parser.voice_server_update(data) - vc = self.bot._get_voice_client(int(data["guild_id"])) + vc = self.bot.get_voice_client(int(data["guild_id"])) if vc is not None: vc.on_voice_server_update(data) diff --git a/discord_http/voice/client.py b/discord_http/voice/client.py index b68261e..34c4879 100644 --- a/discord_http/voice/client.py +++ b/discord_http/voice/client.py @@ -105,6 +105,7 @@ async def connect( *, timeout: float = 30.0, reconnect: bool = True, + reconnect_on_session_invalid: bool = False, self_deaf: bool = False, self_mute: bool = False ) -> None: @@ -117,6 +118,10 @@ async def connect( The maximum time to wait for the handshake, in seconds. reconnect: Whether to attempt reconnection on failure. + reconnect_on_session_invalid: + Whether to reconnect when Discord invalidates the session (close code + 4006), e.g. after the channel empties and the DAVE session is torn + down. Defaults to ``False`` (disconnect instead of reconnecting). self_deaf: Whether to join self-deafened. self_mute: @@ -125,6 +130,7 @@ async def connect( await self.connection.connect( timeout=timeout, reconnect=reconnect, + reconnect_on_session_invalid=reconnect_on_session_invalid, self_deaf=self_deaf, self_mute=self_mute, ) diff --git a/discord_http/voice/connection.py b/discord_http/voice/connection.py index c70f84a..0a144c0 100644 --- a/discord_http/voice/connection.py +++ b/discord_http/voice/connection.py @@ -102,6 +102,7 @@ def __init__(self, voice_client: "VoiceClient"): self._connected_event: asyncio.Event = asyncio.Event() self._reconnect: bool = True + self._reconnect_on_session_invalid: bool = False self._self_mute: bool = False self._self_deaf: bool = False self._closing: bool = False @@ -143,6 +144,7 @@ async def connect( *, timeout: float = 30.0, reconnect: bool = True, + reconnect_on_session_invalid: bool = False, self_deaf: bool = False, self_mute: bool = False ) -> None: @@ -155,6 +157,10 @@ async def connect( The maximum time to wait for the handshake, in seconds. reconnect: Whether to attempt reconnection on failure. + reconnect_on_session_invalid: + Whether to reconnect when Discord invalidates the session (close code + 4006), e.g. after the channel empties and the DAVE session is torn + down. Defaults to ``False`` (disconnect instead of reconnecting). self_deaf: Whether to join self-deafened. self_mute: @@ -178,6 +184,7 @@ async def connect( raise RuntimeError(f"Could not resolve shard {shard_id} for guild {self.guild_id}") self._reconnect = reconnect + self._reconnect_on_session_invalid = reconnect_on_session_invalid self._self_mute = self_mute self._self_deaf = self_deaf self._closing = False @@ -251,6 +258,19 @@ async def _handle_close(self, close_code: int | None) -> None: await self._resume() return + if close_code == VoiceCloseCode.session_invalid: + # Discord invalidates the session (4006) most commonly when the channel + # empties out and the DAVE/MLS session is torn down. The session is gone + # for good, so reconnecting only helps if the caller explicitly opted in; + # otherwise we disconnect rather than retry-and-time-out. + if self._reconnect and self._reconnect_on_session_invalid: + _log.info(f"Voice session for guild {self.guild_id} invalidated (code {close_code}); reconnecting") + await self._full_reconnect(close_code) + else: + _log.info(f"Voice session for guild {self.guild_id} invalidated (code {close_code}); disconnecting") + await self._teardown_and_remove() + return + if close_code in (VoiceCloseCode.normal, VoiceCloseCode.going_away): _log.debug(f"Voice connection for guild {self.guild_id} closed cleanly (code {close_code})") await self._teardown_and_remove() @@ -314,6 +334,7 @@ async def _full_reconnect(self, close_code: int | None) -> None: try: await self.connect( reconnect=self._reconnect, + reconnect_on_session_invalid=self._reconnect_on_session_invalid, self_deaf=self._self_deaf, self_mute=self._self_mute, ) @@ -478,6 +499,30 @@ async def on_dave_binary(self, opcode: int, payload: bytes) -> None: if self.dave_session is not None: await self.dave_session.handle_binary(opcode, payload) + async def on_dave_json(self, opcode: int, data: dict) -> None: + """ + Handle an inbound JSON DAVE control op (transition/epoch, ops 21/22/24). + + These ops arrive as JSON text frames rather than binary DAVE frames, so + they carry the decoded ``d`` payload instead of raw bytes. + + Parameters + ---------- + opcode: + The voice opcode of the control frame. + data: + The decoded JSON ``d`` payload. + """ + if not has_dave: + _log.warning(f"Received DAVE JSON op {opcode} but the davey library is not available") + return + + if self.dave_session is None: + await self.reinit_dave_session() + + if self.dave_session is not None: + await self.dave_session.handle_json(opcode, data) + async def reinit_dave_session(self) -> None: """ Create or reset the DAVE session for the negotiated protocol version. """ if not has_dave: diff --git a/discord_http/voice/dave.py b/discord_http/voice/dave.py index 4b900ba..4d7e3c6 100644 --- a/discord_http/voice/dave.py +++ b/discord_http/voice/dave.py @@ -235,12 +235,6 @@ async def handle_binary(self, opcode: int, payload: bytes) -> None: The raw binary payload following the opcode. """ match opcode: - case VoiceOpType.dave_prepare_transition: - await self._handle_prepare_transition(payload) - case VoiceOpType.dave_execute_transition: - await self._handle_execute_transition(payload) - case VoiceOpType.dave_prepare_epoch: - await self._handle_prepare_epoch(payload) case VoiceOpType.dave_mls_external_sender: self._handle_external_sender(payload) case VoiceOpType.dave_mls_proposals: @@ -252,9 +246,35 @@ async def handle_binary(self, opcode: int, payload: bytes) -> None: case _: _log.debug(f"Unhandled DAVE binary opcode {opcode}") - async def _handle_prepare_transition(self, payload: bytes) -> None: + async def handle_json(self, opcode: int, data: dict) -> None: + """ + Dispatch a JSON DAVE control op (21, 22, 24) received from the gateway. + + Unlike the MLS data ops (25-31), the transition/epoch control ops arrive as + regular JSON text frames rather than binary frames, so they are routed here + with the decoded ``d`` payload instead of raw bytes. + + Parameters + ---------- + opcode: + The voice opcode, expected to be one of 21, 22 or 24. + data: + The decoded JSON ``d`` payload of the frame. + """ + match opcode: + case VoiceOpType.dave_prepare_transition: + await self._handle_prepare_transition(data) + case VoiceOpType.dave_execute_transition: + await self._handle_execute_transition(data) + case VoiceOpType.dave_prepare_epoch: + await self._handle_prepare_epoch(data) + case _: + _log.debug(f"Unhandled DAVE JSON opcode {opcode}") + + async def _handle_prepare_transition(self, data: dict) -> None: """ Handle PREPARE_TRANSITION (21): record the pending transition and acknowledge. """ - transition_id, version = self._parse_transition(payload) + transition_id = int(data.get("transition_id", 0) or 0) + version = int(data.get("protocol_version", 0) or 0) self._pending_transition = (transition_id, version) if transition_id == 0: @@ -262,9 +282,9 @@ async def _handle_prepare_transition(self, payload: bytes) -> None: else: await self._connection.socket.send_transition_ready(transition_id) - async def _handle_execute_transition(self, payload: bytes) -> None: + async def _handle_execute_transition(self, data: dict) -> None: """ Handle EXECUTE_TRANSITION (22): apply the pending version and passthrough state. """ - transition_id, _ = self._parse_transition(payload) + transition_id = int(data.get("transition_id", 0) or 0) if self._pending_transition is not None: pending_id, version = self._pending_transition @@ -281,9 +301,9 @@ async def _execute_transition(self, transition_id: int, version: int) -> None: self._pending_transition = None _log.debug(f"Executed DAVE transition {transition_id} to version {version}") - async def _handle_prepare_epoch(self, payload: bytes) -> None: + async def _handle_prepare_epoch(self, data: dict) -> None: """ Handle PREPARE_EPOCH (24): reinitialise the session for a new MLS epoch. """ - _epoch, version = self._parse_transition(payload) + version = int(data.get("protocol_version", 0) or 0) await self.reinit(version) def _handle_external_sender(self, payload: bytes) -> None: @@ -424,15 +444,3 @@ def _extract_commit_welcome(result: _Session) -> bytes | None: except Exception: return None return None - - @staticmethod - def _parse_transition(payload: bytes) -> tuple[int, int]: - """ - Parse a transition payload into ``(transition_id, version)``. - - Transition payloads carry a 2-byte big-endian transition id optionally followed by - a 1-byte protocol version. Missing fields default to ``0``. - """ - transition_id = int.from_bytes(payload[:2], "big") if len(payload) >= 2 else 0 - version = payload[2] if len(payload) >= 3 else 0 - return transition_id, version diff --git a/discord_http/voice/player.py b/discord_http/voice/player.py index 71cc17d..1c24c6b 100644 --- a/discord_http/voice/player.py +++ b/discord_http/voice/player.py @@ -218,6 +218,7 @@ def __init__( self._process: asyncio.subprocess.Process | None = None self._stdin_task: asyncio.Task[None] | None = None self._stdout: asyncio.StreamReader | None = None + self._reap_task: asyncio.Task[None] | None = None async def _spawn(self) -> None: """ Launch the ffmpeg subprocess and start the stdin pump if piping. """ @@ -278,15 +279,49 @@ def cleanup(self) -> None: self._stdin_task = None process = self._process - if process is not None and process.returncode is None: + if process is not None: + if process.returncode is None: + try: + process.kill() + except ProcessLookupError: + pass + + # Close the subprocess transport so its stdio pipe transports are + # released deterministically. On the Windows Proactor event loop an + # unclosed pipe transport otherwise triggers a ResourceWarning + # ("unclosed transport" / "I/O operation on closed pipe") when it is + # finalized by the garbage collector. + transport = getattr(process, "_transport", None) + if transport is not None: + try: + transport.close() + except Exception: + pass + + # Reap the process so the OS releases it and the pipe transports + # finish closing. ``cleanup`` is synchronous, so schedule the wait on + # the running loop when there is one (there is none during a hard + # interpreter shutdown, where closing the transport above suffices). try: - process.kill() - except ProcessLookupError: - pass + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop is not None: + # Keep a reference so the fire-and-forget reaping task is not + # garbage-collected before it completes. + self._reap_task = loop.create_task(self._reap_process(process)) self._process = None self._stdout = None + @staticmethod + async def _reap_process(process: "asyncio.subprocess.Process") -> None: + """ Await the ffmpeg subprocess so it is fully reaped after termination. """ + try: + await process.wait() + except Exception: + pass + class FFmpegPCMAudio(_FFmpegAudio): """ diff --git a/discord_http/voice/socket.py b/discord_http/voice/socket.py index 6e84ff0..7577311 100644 --- a/discord_http/voice/socket.py +++ b/discord_http/voice/socket.py @@ -26,6 +26,7 @@ class VoiceCloseCode(BaseEnum): normal = 1000 going_away = 1001 + session_invalid = 4006 disconnected = 4014 voice_server_crashed = 4015 unknown_encryption_mode = 4016 @@ -207,6 +208,18 @@ def _dispatch_text(self, raw: str | bytes) -> None: case VoiceOpType.resumed: self._schedule(self.connection.on_resumed(data)) + case ( + VoiceOpType.dave_prepare_transition + | VoiceOpType.dave_execute_transition + | VoiceOpType.dave_prepare_epoch + ): + # The DAVE transition/epoch control ops arrive as JSON text frames + # (only the MLS data ops 25-31 are binary), so route them with the + # decoded payload. Handling these keeps the local MLS session in + # step with the gateway's epoch; ignoring them desyncs the epoch and + # leads to MLS WrongEpoch errors and a 4006 close on rejoin. + self._schedule(self.connection.on_dave_json(int(voice_op), data)) + case _: _log.debug(f"Voice socket for guild {self.connection.guild_id} received unhandled op {voice_op}") diff --git a/examples/voice_example.py b/examples/voice_example.py index 688787d..ebb138d 100644 --- a/examples/voice_example.py +++ b/examples/voice_example.py @@ -48,14 +48,13 @@ def caller_voice_channel(ctx: Context) -> "BaseChannel | PartialChannel | None": return None voice_state = ctx.guild.get_member_voice_state(ctx.user.id) - if voice_state is None or voice_state.channel_id is None: + if voice_state is None: return None - # ``VoiceState`` exposes the resolved ``channel`` directly; the partial variant only - # carries the id, so fall back to a partial channel that ``connect()`` can act on. - return getattr(voice_state, "channel", None) or client.get_partial_channel( - voice_state.channel_id, guild_id=ctx.guild.id - ) + # ``channel`` is available on both ``VoiceState`` and ``PartialVoiceState`` and + # always returns something ``connect()`` can act on (a partial channel when the + # full object is not cached), or ``None`` when the user is not in a channel. + return voice_state.channel @client.command() @@ -76,7 +75,7 @@ async def join(ctx: Context): @client.command() async def pause(ctx: Context): """ Pause / resume the current track """ - vc = client._get_voice_client(ctx.guild.id) if ctx.guild else None + vc = client.get_voice_client(ctx.guild.id) if ctx.guild else None if vc is None: return ctx.response.send_message("Not connected.") @@ -91,7 +90,7 @@ async def pause(ctx: Context): @client.command() async def leave(ctx: Context): """ Stop playback and disconnect """ - vc = client._get_voice_client(ctx.guild.id) if ctx.guild else None + vc = client.get_voice_client(ctx.guild.id) if ctx.guild else None if vc is None: return ctx.response.send_message("Not connected.") From 20cc1a28b0187df34386bb3c35bab3d1c73e4f3e Mon Sep 17 00:00:00 2001 From: Neppkun Date: Thu, 4 Jun 2026 15:30:07 +0300 Subject: [PATCH 09/12] Address review: validate voice_reconnect_attempts and preserve DAVE transition version - Client: raise ValueError when voice_reconnect_attempts is negative, so the voice retry logic can't receive an invalid count. - DaveManager: when handling a commit/welcome that drives a transition, preserve the negotiated protocol version recorded by DAVE_PREPARE_TRANSITION instead of overwriting it with the current version. A new _pending_transition_version() helper keeps the pending version when the transition_id matches and otherwise falls back to the current version (e.g. when the commit arrives first), so EXECUTE_TRANSITION applies the negotiated epoch. --- discord_http/client.py | 2 ++ discord_http/voice/dave.py | 24 ++++++++++++++++++++++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/discord_http/client.py b/discord_http/client.py index d442259..a6d73fa 100644 --- a/discord_http/client.py +++ b/discord_http/client.py @@ -153,6 +153,8 @@ def __init__( self.logging_level: int = logging_level self.debug_events: bool = debug_events self.enable_gateway: bool = enable_gateway + if voice_reconnect_attempts < 0: + raise ValueError("voice_reconnect_attempts must be >= 0") self.voice_reconnect_attempts: int = voice_reconnect_attempts """ How many times a voice connection will try to fully reconnect after an diff --git a/discord_http/voice/dave.py b/discord_http/voice/dave.py index 4d7e3c6..b6f510a 100644 --- a/discord_http/voice/dave.py +++ b/discord_http/voice/dave.py @@ -372,7 +372,8 @@ async def _handle_commit(self, payload: bytes) -> None: return if transition_id != 0: - self._pending_transition = (transition_id, self._version) + version = self._pending_transition_version(transition_id) + self._pending_transition = (transition_id, version) await self._connection.socket.send_transition_ready(transition_id) async def _handle_welcome(self, payload: bytes) -> None: @@ -397,9 +398,28 @@ async def _handle_welcome(self, payload: bytes) -> None: return if transition_id != 0: - self._pending_transition = (transition_id, self._version) + version = self._pending_transition_version(transition_id) + self._pending_transition = (transition_id, version) await self._connection.socket.send_transition_ready(transition_id) + def _pending_transition_version(self, transition_id: int) -> int: + """ + Resolve the protocol version to record for a transition driven by an + incoming commit/welcome. + + When ``DAVE_PREPARE_TRANSITION`` already recorded the negotiated target + version for this ``transition_id``, preserve it so the later + ``EXECUTE_TRANSITION`` applies the negotiated epoch rather than the + current version. Otherwise (no matching pending transition, e.g. the + commit/welcome arrived first) fall back to the current version. + """ + if ( + self._pending_transition is not None + and self._pending_transition[0] == transition_id + ): + return self._pending_transition[1] + return self._version + async def _recover_from_invalid_commit(self) -> None: """ Notify the gateway of an invalid commit/welcome and reinitialise the session. """ await self._connection.socket.send_binary( From 40a928f6aa41142634c64bf54f860b91d135c622 Mon Sep 17 00:00:00 2001 From: Neppkun Date: Thu, 4 Jun 2026 18:26:38 +0300 Subject: [PATCH 10/12] Address CodeRabbit review findings and use URL parser for voice endpoint CodeRabbit findings: - channel.py: remove stale voice-client registry entry if connect() fails, so a failed/timed-out join no longer blocks all future joins. - channel.py: PartialVoiceState.fetch() now resolves channel from the fresh fetched payload's channel_id instead of the stale partial's id. - connection.py: exponential backoff now grows across reconnect attempts (connect() no longer resets backoff on internal reconnect calls). - connection.py: a successful RESUMED frame re-sets the connected event so is_connected() recovers after a server-crash resume. - connection.py: clear the stale MLS/DAVE session when a new SESSION_DESCRIPTION negotiates dave_protocol_version 0. - gateway_udp.py: detect RTCP control packets on the unmasked second byte so types 200-204 are actually dropped. - player.py: use shlex.split() for ffmpeg before_options/options to preserve shell quoting. - player.py: FFmpegPCMAudio.read() discards any trailing short PCM frame and signals EOF instead of forwarding a partial frame into libopus. - player.py: AudioPlayer marks itself finished on natural EOF so is_playing() reports correctly. - receiver.py: stop the previous listening session before swapping sinks. - examples/voice_example.py: convert latency (seconds) to ms before labeling. Maintainer feedback: - connection.py: parse the voice server endpoint via utils.URL while preserving the schemeless host:port semantics required to avoid 4006. --- discord_http/channel.py | 23 +++++++++------ discord_http/voice/connection.py | 49 ++++++++++++++++++++++++++----- discord_http/voice/gateway_udp.py | 7 +++-- discord_http/voice/player.py | 25 ++++++++++------ discord_http/voice/receiver.py | 5 ++++ examples/voice_example.py | 4 +-- 6 files changed, 83 insertions(+), 30 deletions(-) diff --git a/discord_http/channel.py b/discord_http/channel.py index 1573250..d39df48 100644 --- a/discord_http/channel.py +++ b/discord_http/channel.py @@ -546,13 +546,17 @@ async def connect( vc = VoiceClient(client, self) client._add_voice_client(self.guild_id, vc) - await vc.connect( - timeout=timeout, - reconnect=reconnect, - reconnect_on_session_invalid=reconnect_on_session_invalid, - self_deaf=self_deaf, - self_mute=self_mute - ) + try: + await vc.connect( + timeout=timeout, + reconnect=reconnect, + reconnect_on_session_invalid=reconnect_on_session_invalid, + self_deaf=self_deaf, + self_mute=self_mute + ) + except Exception: + client._remove_voice_client(self.guild_id) + raise return vc @@ -2885,8 +2889,9 @@ async def fetch(self) -> "VoiceState": guild = self._state.cache.get_guild(self.guild_id) channel = None - if self.channel_id is not None: - channel = self._state.cache.get_channel(self.guild_id, self.channel_id) + channel_id = utils.get_int(r.response, "channel_id") + if channel_id is not None: + channel = self._state.cache.get_channel(self.guild_id, channel_id) return VoiceState( state=self._state, diff --git a/discord_http/voice/connection.py b/discord_http/voice/connection.py index 0a144c0..c15b360 100644 --- a/discord_http/voice/connection.py +++ b/discord_http/voice/connection.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any -from ..utils import ExponentialBackoff +from ..utils import URL, ExponentialBackoff from .dave import DaveManager, has_dave from .encryptor import Encryptor from .enums import SUPPORTED_MODES @@ -146,7 +146,8 @@ async def connect( reconnect: bool = True, reconnect_on_session_invalid: bool = False, self_deaf: bool = False, - self_mute: bool = False + self_mute: bool = False, + _internal_reconnect: bool = False ) -> None: """ Establish the full voice connection. @@ -188,7 +189,14 @@ async def connect( self._self_mute = self_mute self._self_deaf = self_deaf self._closing = False - self._backoff.reset() + + # Only reset the exponential backoff on a genuinely fresh/initial + # connect. The reconnect loop in _full_reconnect() drives its own + # backoff and passes _internal_reconnect=True so that each retry's + # connect() call does NOT reset it -- otherwise the delay would never + # grow and every attempt would sleep the base delay. + if not _internal_reconnect: + self._backoff.reset() self._state_event.clear() self._server_event.clear() @@ -337,6 +345,8 @@ async def _full_reconnect(self, close_code: int | None) -> None: reconnect_on_session_invalid=self._reconnect_on_session_invalid, self_deaf=self._self_deaf, self_mute=self._self_mute, + # Preserve the growing backoff across reconnect attempts. + _internal_reconnect=True, ) except Exception as exc: _log.warning(f"Voice reconnect attempt {attempt} for guild {self.guild_id} failed", exc_info=exc) @@ -390,10 +400,23 @@ def on_voice_server_update(self, data: dict) -> None: # token/session are bound to that specific host:port, so we MUST keep # the port intact -- connecting to host:443 instead reaches a # different voice server instance and Discord closes the socket with - # 4006 "Session is no longer valid". Only strip a scheme if present. - if endpoint.startswith("wss://"): - endpoint = endpoint[len("wss://"):] - self.endpoint = endpoint.rstrip("/") + # 4006 "Session is no longer valid". + # + # ``urlparse`` (which ``utils.URL`` wraps) treats a bare "host:port" + # as scheme:path, so we normalise by stripping any existing scheme + # and prepending "wss://" before parsing, then reconstruct the + # schemeless host[:port] string that self.endpoint is meant to hold. + host_port = endpoint.rstrip("/") + for _scheme in ("wss://", "https://"): + if host_port.startswith(_scheme): + host_port = host_port[len(_scheme):] + break + + url = URL("wss://" + host_port) + if url.port is not None: + self.endpoint = f"{url.host}:{url.port}" + else: + self.endpoint = url.host or host_port server_id = data.get("guild_id") or data.get("server_id") self.server_id = int(server_id) if server_id is not None else None @@ -446,6 +469,14 @@ async def on_session_description(self, data: dict) -> None: self.dave_protocol_version = dave_version if dave_version > 0: await self.reinit_dave_session() + elif self.dave_session is not None: + # DAVE was negotiated off (version 0) but a prior connection left a + # session behind. Tear down the MLS session so can_encrypt() and the + # opus wrappers don't keep using stale E2EE state. DaveManager has no + # dedicated close/cleanup; reinit(0) clears its internal session, then + # we drop the reference. + await self.dave_session.reinit(0) + self.dave_session = None self._connected_event.set() @@ -477,6 +508,10 @@ async def on_resumed(self, data: dict) -> None: # noqa: ARG002 The resumed payload. """ _log.debug(f"Voice connection for guild {self.guild_id} resumed") + # A successful RESUMED frame means the connection is live again. _resume() + # cleared _connected_event when re-opening the socket, so re-set it here; + # otherwise is_connected() would stay False forever after a server crash. + self._connected_event.set() async def on_dave_binary(self, opcode: int, payload: bytes) -> None: """ diff --git a/discord_http/voice/gateway_udp.py b/discord_http/voice/gateway_udp.py index 7ce2d16..86bfa16 100644 --- a/discord_http/voice/gateway_udp.py +++ b/discord_http/voice/gateway_udp.py @@ -86,9 +86,10 @@ def datagram_received(self, data: bytes, addr: tuple) -> None: # noqa: ARG002 self._discovery_future.set_result(data) return - # Drop RTCP control packets (payload types 200-204). - payload_type = data[1] & 0x7F - if 200 <= payload_type <= 204: + # Drop RTCP control packets. For RTCP the second byte is the raw packet + # type (200-204); for RTP it is the marker bit plus a 7-bit payload type, + # so RTCP must be detected on the unmasked byte before any RTP masking. + if 200 <= data[1] <= 204: return # Otherwise treat it as RTP and hand it to the receiver, if any. diff --git a/discord_http/voice/player.py b/discord_http/voice/player.py index 1c24c6b..bebe3e2 100644 --- a/discord_http/voice/player.py +++ b/discord_http/voice/player.py @@ -3,6 +3,7 @@ import io import logging import os +import shlex import shutil from array import array @@ -329,8 +330,8 @@ class FFmpegPCMAudio(_FFmpegAudio): ffmpeg decodes the input to signed 16-bit little-endian PCM at 48kHz in stereo, and :meth:`read` returns one 3840-byte frame per call via - :meth:`asyncio.StreamReader.readexactly`. The final, possibly short frame is - returned once before end of stream is signalled with empty bytes. + :meth:`asyncio.StreamReader.readexactly`. Any trailing partial frame at the + end of stream is discarded and end of stream is signalled with empty bytes. Because the output is PCM, libopus is required to encode it before sending. @@ -358,11 +359,11 @@ def __init__( pipe: bool = False, executable: str = "ffmpeg", ) -> None: - before_args = before_options.split() if before_options is not None else None + before_args = shlex.split(before_options) if before_options is not None else None args = ["-f", "s16le", "-ar", "48000", "-ac", "2", "-loglevel", "warning"] if options is not None: - args.extend(options.split()) + args.extend(shlex.split(options)) args.append("pipe:1") super().__init__(source, args=args, before_args=before_args, executable=executable, pipe=pipe) @@ -375,9 +376,11 @@ async def read(self) -> bytes: assert self._stdout is not None try: return await self._stdout.readexactly(FRAME_SIZE) - except asyncio.IncompleteReadError as exc: - # End of stream: return the trailing partial frame, then b"". - return bytes(exc.partial) + except asyncio.IncompleteReadError: + # End of stream: discard any trailing partial frame and signal EOF. + # The contract requires exactly FRAME_SIZE bytes or empty bytes; a + # short PCM frame would be mis-encoded by libopus downstream. + return b"" class FFmpegOpusAudio(_FFmpegAudio): @@ -436,7 +439,7 @@ def __init__( pipe: bool = False, executable: str = "ffmpeg", ) -> None: - before_args = before_options.split() if before_options is not None else None + before_args = shlex.split(before_options) if before_options is not None else None args = [ "-c:a", "libopus", @@ -447,7 +450,7 @@ def __init__( "-loglevel", "warning", ] if options is not None: - args.extend(options.split()) + args.extend(shlex.split(options)) args.append("pipe:1") super().__init__(source, args=args, before_args=before_args, executable=executable, pipe=pipe) @@ -613,6 +616,10 @@ async def _run(self) -> None: async def _cleanup(self) -> None: """ Flush silence, stop speaking, clean up and invoke ``after``. """ + # Mark the player as finished so ``is_playing()``/``is_paused()`` report + # correctly after natural EOF. Idempotent: ``stop()`` may have set it. + self._end.set() + try: for _ in range(5): self.voice_client.send_audio_packet(OPUS_SILENCE, encode=False) diff --git a/discord_http/voice/receiver.py b/discord_http/voice/receiver.py index 4f8660d..d75c61c 100644 --- a/discord_http/voice/receiver.py +++ b/discord_http/voice/receiver.py @@ -73,6 +73,11 @@ def start(self, sink: "AudioSink") -> None: sink: The sink to receive decoded PCM or raw Opus audio. """ + # Tear down any in-progress session so its sink is cleaned up and stale + # per-SSRC decoder/sequence state does not leak into the new sink. + if self.sink is not None: + self.stop() + self.sink = sink def stop(self) -> None: diff --git a/examples/voice_example.py b/examples/voice_example.py index ebb138d..f53bc48 100644 --- a/examples/voice_example.py +++ b/examples/voice_example.py @@ -69,7 +69,7 @@ async def join(ctx: Context): # Play a local file (mp3 -> opus passthrough, ffmpeg only). vc.play("song.mp3") - return ctx.response.send_message(f"Now playing, latency: {vc.latency:.1f}ms") + return ctx.response.send_message(f"Now playing, latency: {vc.latency * 1000:.1f}ms") @client.command() @@ -126,7 +126,7 @@ async def voice_demo(channel: BaseChannel, move_to: BaseChannel) -> None: await asyncio.sleep(10) vc.stop_listening() - print(f"voice latency: {vc.latency:.1f}ms (avg {vc.average_latency:.1f}ms)") + print(f"voice latency: {vc.latency * 1000:.1f}ms (avg {vc.average_latency * 1000:.1f}ms)") vc.stop() await vc.disconnect() From a84d5225bae2d57a2de329f9a1c0cc721974fb12 Mon Sep 17 00:00:00 2001 From: Neppkun Date: Thu, 4 Jun 2026 20:07:38 +0300 Subject: [PATCH 11/12] Fix voice session-invalid recovery, reconnect backoff, and review findings This addresses a second round of review feedback plus a real reconnect bug observed in the field, across the voice stack. == reconnect_on_session_invalid (voice/connection.py) == Symptom: a human leaving+rejoining the bot's channel rebuilds the DAVE/MLS group and Discord closes the voice websocket with 4006 ("session no longer valid"). The old reconnect re-issued op4 change_voice_state with the SAME channel_id, which is a no-op for the voice server (the bot never left at the gateway level), so no fresh VOICE_SERVER_UPDATE ever arrived and every attempt timed out on _server_event.wait(). A subsequent play command timed out too. Fix is a hybrid that mirrors discord.py's verified patterns: - _soft_reconnect(): on 4006, first try a no-leave recovery -- close the old socket/UDP locally (WITHOUT sending op4), open a fresh VoiceSocket and re-IDENTIFY (op 0) with the still-valid session_id/token/endpoint. This is discord.py's _potential_reconnect approach and avoids any user-visible leave/rejoin blip when the credentials are still usable. - _full_reconnect(force_refresh=...): only if the soft attempt fails do we fall back to the leave/rejoin bounce. _force_voice_refresh() sends change_voice_state(channel_id=None), waits for the gateway leave ack (_left_event) so the leave+rejoin is not coalesced into a no-op, then connect() rejoins -- forcing Discord to allocate a fresh voice server. The bounce is gated behind force_refresh so unrelated reconnect paths (resume fallback, generic codes) do NOT cause a spurious leave/rejoin. - connect() now joins using the stable voice_client.channel.id instead of the mutable self.channel_id, because the leave echo (channel_id=None) clobbers self.channel_id via on_voice_state_update and would otherwise rejoin nothing. - _left_event is set/cleared in on_voice_state_update to signal the leave ack. Why hybrid rather than always-leave: discord.py treats 4006 as an unhandled code and always leaves+rejoins, but its 4014 path (_potential_reconnect) proves a no-leave re-IDENTIFY works when creds survive. Soft-first gives a no-blip recovery in the common case while keeping the robust leave/rejoin fallback. One soft attempt per 4006 event prevents a soft->4006->soft loop. == Review findings == channel.py (PartialChannel.connect): catch BaseException, not just Exception, around vc.connect(). asyncio.CancelledError is a BaseException on 3.11+, so a cancelled/timed-out connect previously skipped cleanup and left a stale _voice_clients entry that blocked all future connects in the guild. voice/client.py: removed the VoiceClient.client/.bot duplication, keeping `bot` to match the library-wide convention (self.bot is used 243x vs self.client 10x, all in voice; no external/test refs to a voice .client). Updated internal refs and the VoiceConnection.client property to read self.voice_client.bot. voice/player.py: - FFmpegOpusAudio._packets switched from list+pop(0) (O(n)) to collections.deque + popleft() for the FIFO packet queue. - AudioPlayer._run now logs a warning and re-raises if _cleanup() is interrupted by CancelledError, so the silently-skipped trailing silence/speaking-off is at least observable. Did not use asyncio.shield (risks hanging shutdown). voice/gateway_udp.py: IP-discovery detection strengthened from a single data[1]==0x02 check to data[0]==0x00 and data[1]==0x02 and len(data)>=74 (the fixed discovery response shape), so a coincidental RTP packet can't resolve the discovery future. voice/socket.py: - _schedule now retains strong references to dispatched tasks in self._dispatch_tasks (with a done-callback to discard), instead of relying on a fire-and-forget create_task that could be GC'd mid-flight; dropped the noqa. - Added VoiceCloseCode.dave_e2ee_required = 4017 for completeness/discoverability (referenced in a _receive_loop comment). voice/sinks.py: documented that WaveSink mixes all speakers into one stream and ignores the `user` arg (behavior unchanged) to prevent surprise. voice/connection.py (also): added a clarifying note in _full_reconnect that the backoff is reset once before the loop and grows across attempts; renamed a loop variable _scheme -> scheme (RUF052); reworked a DAVE docstring summary (D205). == Validation == ruff check --config pyproject.toml (ruff 0.15.15, matching uv.lock) -> clean pyright on all touched files -> 0 errors/warnings (the pre-existing dave.py:141 typing error on the baseline is unrelated and untouched) pytest -k "voice or channel" -> 11 passed --- discord_http/channel.py | 6 +- discord_http/voice/client.py | 15 ++- discord_http/voice/connection.py | 147 +++++++++++++++++++++++++++--- discord_http/voice/dave.py | 6 +- discord_http/voice/gateway_udp.py | 12 ++- discord_http/voice/player.py | 14 ++- discord_http/voice/sinks.py | 10 +- discord_http/voice/socket.py | 8 +- 8 files changed, 182 insertions(+), 36 deletions(-) diff --git a/discord_http/channel.py b/discord_http/channel.py index d39df48..1431439 100644 --- a/discord_http/channel.py +++ b/discord_http/channel.py @@ -554,7 +554,11 @@ async def connect( self_deaf=self_deaf, self_mute=self_mute ) - except Exception: + except BaseException: + # Catch BaseException, not just Exception: a timeout or a cancelled + # connect (asyncio.CancelledError is a BaseException on 3.11+) must + # still remove the half-registered voice client, otherwise the stale + # entry blocks future connect attempts in this guild. client._remove_voice_client(self.guild_id) raise diff --git a/discord_http/voice/client.py b/discord_http/voice/client.py index 34c4879..a25b143 100644 --- a/discord_http/voice/client.py +++ b/discord_http/voice/client.py @@ -29,11 +29,8 @@ class VoiceClient: """ def __init__(self, client: "Client", channel: "PartialChannel"): - self.client: "Client" = client - """ The bot client that owns this voice client. """ - self.bot: "Client" = client - """ Alias of :attr:`client`. """ + """ The bot client that owns this voice client. """ self.channel: "PartialChannel" = channel """ The voice channel this client is connected to. """ @@ -54,12 +51,12 @@ def __init__(self, client: "Client", channel: "PartialChannel"): @property def loop(self) -> asyncio.AbstractEventLoop: """ The event loop the client runs on. """ - return self.client.loop + return self.bot.loop @property def user_id(self) -> int: """ The ID of the bot user. """ - return self.client.user.id + return self.bot.user.id @property def ssrc(self) -> int | None: @@ -157,7 +154,7 @@ async def disconnect(self, *, force: bool = True) -> None: self._encoder = None await self.connection.disconnect(force=force) - self.client._remove_voice_client(self.guild_id) + self.bot._remove_voice_client(self.guild_id) async def _cleanup(self) -> None: """ @@ -180,7 +177,7 @@ async def _cleanup(self) -> None: self._encoder = None await self.connection.close_transport() - self.client._remove_voice_client(self.guild_id) + self.bot._remove_voice_client(self.guild_id) async def move_to(self, channel: "PartialChannel | int") -> None: """ @@ -192,7 +189,7 @@ async def move_to(self, channel: "PartialChannel | int") -> None: The channel to move to, either a channel object or its ID. """ if isinstance(channel, int): - channel = self.client.get_partial_channel(channel, guild_id=self.guild_id) + channel = self.bot.get_partial_channel(channel, guild_id=self.guild_id) await self.connection.move_to(channel) self.channel = channel diff --git a/discord_http/voice/connection.py b/discord_http/voice/connection.py index c15b360..77bba5b 100644 --- a/discord_http/voice/connection.py +++ b/discord_http/voice/connection.py @@ -100,6 +100,10 @@ def __init__(self, voice_client: "VoiceClient"): self._server_event: asyncio.Event = asyncio.Event() self._ready_event: asyncio.Event = asyncio.Event() self._connected_event: asyncio.Event = asyncio.Event() + # Set when the gateway acknowledges the bot leaving the channel + # (VOICE_STATE_UPDATE with channel_id=None); used by the reconnect + # bounce to confirm the leave before rejoining. + self._left_event: asyncio.Event = asyncio.Event() self._reconnect: bool = True self._reconnect_on_session_invalid: bool = False @@ -112,7 +116,7 @@ def __init__(self, voice_client: "VoiceClient"): @property def client(self) -> "Client": """ The bot client that owns this voice connection. """ - return self.voice_client.client + return self.voice_client.bot @property def latency(self) -> float: @@ -203,9 +207,13 @@ async def connect( self._ready_event.clear() self._connected_event.clear() + # Join the voice client's channel, which is the authoritative target. + # Do NOT use self.channel_id here: the reconnect bounce sends op4 with + # channel_id=None and the gateway echoes a VOICE_STATE_UPDATE that + # clears self.channel_id, so reading it could rejoin the wrong channel. await shard.change_voice_state( guild_id=self.guild_id, - channel_id=self.channel_id, + channel_id=self.voice_client.channel.id, self_mute=self_mute, self_deaf=self_deaf, ) @@ -267,16 +275,28 @@ async def _handle_close(self, close_code: int | None) -> None: return if close_code == VoiceCloseCode.session_invalid: - # Discord invalidates the session (4006) most commonly when the channel - # empties out and the DAVE/MLS session is torn down. The session is gone - # for good, so reconnecting only helps if the caller explicitly opted in; - # otherwise we disconnect rather than retry-and-time-out. - if self._reconnect and self._reconnect_on_session_invalid: - _log.info(f"Voice session for guild {self.guild_id} invalidated (code {close_code}); reconnecting") - await self._full_reconnect(close_code) - else: + # Discord invalidates the session (4006) most commonly when another + # member leaves/rejoins and the DAVE/MLS group is rebuilt. The bot is + # still in the channel, so reconnecting only helps if the caller opted + # in; otherwise we disconnect rather than retry-and-time-out. + if not (self._reconnect and self._reconnect_on_session_invalid): _log.info(f"Voice session for guild {self.guild_id} invalidated (code {close_code}); disconnecting") await self._teardown_and_remove() + return + + # Try a soft reconnect first: re-open the voice websocket and + # re-IDENTIFY with the existing token/session, WITHOUT leaving the + # channel. This avoids a visible disconnect/rejoin blip when the + # credentials are still usable. + _log.info(f"Voice session for guild {self.guild_id} invalidated (code {close_code}); attempting soft reconnect") + if await self._soft_reconnect(): + _log.info(f"Voice connection for guild {self.guild_id} recovered without rejoining") + return + + # The existing credentials are no longer usable; fall back to a full + # reconnect that drops and re-acquires the gateway voice state. + _log.info(f"Soft reconnect for guild {self.guild_id} failed; falling back to rejoin") + await self._full_reconnect(close_code, force_refresh=True) return if close_code in (VoiceCloseCode.normal, VoiceCloseCode.going_away): @@ -305,7 +325,52 @@ async def _resume(self) -> None: _log.warning(f"Voice resume for guild {self.guild_id} failed; falling back to full reconnect", exc_info=exc) await self._full_reconnect(int(VoiceCloseCode.voice_server_crashed)) - async def _full_reconnect(self, close_code: int | None) -> None: + async def _soft_reconnect(self, timeout: float = 10.0) -> bool: + """ + Re-open the voice websocket and re-IDENTIFY without leaving the channel. + + On a 4006 the gateway voice state (session id) and the voice server + token/endpoint are often still valid; the bot never left the channel. + Opening a fresh socket and sending IDENTIFY (op 0) with those existing + credentials can recover the session with no user-visible blip. Returns + ``True`` on a successful handshake, ``False`` if the credentials are + missing or the handshake does not complete in time (the caller then + falls back to a full leave/rejoin reconnect). + + Parameters + ---------- + timeout: + How long to wait for the re-handshake to complete, in seconds. + """ + if not (self.session_id and self.token and self.endpoint): + return False + + # Tear down the old socket/UDP locally WITHOUT sending op4, so the gateway + # keeps our voice state and we can reuse the session id and token. + if self.socket is not None: + self.socket._request_close() + await self.socket.close() + self.socket = None + + if self.transport is not None: + self.transport.close() + self.transport = None + self.udp = None + + self._ready_event.clear() + self._connected_event.clear() + + try: + self.socket = VoiceSocket(self) + await self.socket.connect() + await asyncio.wait_for(self._connected_event.wait(), timeout) + except Exception as exc: + _log.debug(f"Soft reconnect for guild {self.guild_id} did not complete", exc_info=exc) + return False + + return True + + async def _full_reconnect(self, close_code: int | None, *, force_refresh: bool = False) -> None: """ Re-issue op4 and run a fresh handshake, retrying with exponential backoff. @@ -313,6 +378,12 @@ async def _full_reconnect(self, close_code: int | None) -> None: ---------- close_code: The close code that triggered the reconnect, for logging. + force_refresh: + When ``True``, drop and re-acquire the gateway voice state (leave and + rejoin the channel) before each attempt. Needed for a 4006 where the + bot is still in the channel, so re-issuing op4 with the same channel + would be a no-op and yield no fresh VOICE_SERVER_UPDATE. Defaults to + ``False`` so unrelated reconnects do not cause a visible leave/rejoin. """ if self.socket is not None: self.socket._request_close() @@ -324,6 +395,9 @@ async def _full_reconnect(self, close_code: int | None) -> None: self.transport = None self.udp = None + # Each attempt below deliberately re-issues op4 (change_voice_state) via + # connect(..., _internal_reconnect=True). The backoff is reset once here + # and then grows across attempts (connect() skips resetting it). self._backoff.reset() max_attempts = self.client.voice_reconnect_attempts @@ -340,6 +414,10 @@ async def _full_reconnect(self, close_code: int | None) -> None: await asyncio.sleep(delay) try: + if force_refresh: + # Force Discord to allocate a fresh voice server before + # rejoining (only needed for the 4006 leave/rejoin fallback). + await self._force_voice_refresh() await self.connect( reconnect=self._reconnect, reconnect_on_session_invalid=self._reconnect_on_session_invalid, @@ -358,6 +436,41 @@ async def _full_reconnect(self, close_code: int | None) -> None: _log.error(f"Voice connection for guild {self.guild_id} could not reconnect after {max_attempts} attempts; tearing down") await self._teardown_and_remove() + async def _force_voice_refresh(self) -> None: + """ + Drop the gateway voice state so Discord re-allocates a fresh voice server. + + On a 4006 ("session no longer valid") the voice session is dead but the + bot is still in the channel at the gateway level. Re-issuing op4 with the + same ``channel_id`` is then a no-op for the voice server: no new + VOICE_SERVER_UPDATE arrives and the handshake times out waiting for it. + + Sending ``channel_id=None`` first makes the gateway drop the voice state, + so the subsequent rejoin in :meth:`connect` is a genuine state change and + yields a fresh token/endpoint. We wait for the gateway to acknowledge the + leave (via :attr:`_left_event`) so it does not coalesce the leave and the + immediate rejoin into a single no-op. + """ + client = self.client + shard_id = client.get_shard_by_guild_id(self.guild_id) + shard = client.gateway.get_shard(shard_id) if (client.gateway and shard_id is not None) else None + if shard is None: + return + + self._left_event.clear() + try: + await shard.change_voice_state(guild_id=self.guild_id, channel_id=None) + except Exception as exc: + _log.debug(f"Failed to send voice-state reset for guild {self.guild_id}", exc_info=exc) + return + + try: + await asyncio.wait_for(self._left_event.wait(), timeout=5.0) + except TimeoutError: + # No leave ack arrived (e.g. the bot was already out); a short settle + # still lets the gateway register the reset before we rejoin. + await asyncio.sleep(0.25) + async def _teardown_and_remove(self) -> None: """ Tear down the connection and remove the voice client from the registry. """ try: @@ -379,6 +492,12 @@ def on_voice_state_update(self, data: dict) -> None: channel_id = data.get("channel_id") self.channel_id = int(channel_id) if channel_id is not None else None + # Track leave/rejoin so the reconnect bounce can await the leave ack. + if self.channel_id is None: + self._left_event.set() + else: + self._left_event.clear() + if self.session_id is not None: self._state_event.set() @@ -407,9 +526,9 @@ def on_voice_server_update(self, data: dict) -> None: # and prepending "wss://" before parsing, then reconstruct the # schemeless host[:port] string that self.endpoint is meant to hold. host_port = endpoint.rstrip("/") - for _scheme in ("wss://", "https://"): - if host_port.startswith(_scheme): - host_port = host_port[len(_scheme):] + for scheme in ("wss://", "https://"): + if host_port.startswith(scheme): + host_port = host_port[len(scheme):] break url = URL("wss://" + host_port) diff --git a/discord_http/voice/dave.py b/discord_http/voice/dave.py index b6f510a..d9c0c39 100644 --- a/discord_http/voice/dave.py +++ b/discord_http/voice/dave.py @@ -404,10 +404,10 @@ async def _handle_welcome(self, payload: bytes) -> None: def _pending_transition_version(self, transition_id: int) -> int: """ - Resolve the protocol version to record for a transition driven by an - incoming commit/welcome. + Resolve the protocol version to record for a transition. - When ``DAVE_PREPARE_TRANSITION`` already recorded the negotiated target + The transition is driven by an incoming commit/welcome. When + ``DAVE_PREPARE_TRANSITION`` already recorded the negotiated target version for this ``transition_id``, preserve it so the later ``EXECUTE_TRANSITION`` applies the negotiated epoch rather than the current version. Otherwise (no matching pending transition, e.g. the diff --git a/discord_http/voice/gateway_udp.py b/discord_http/voice/gateway_udp.py index 86bfa16..28e497e 100644 --- a/discord_http/voice/gateway_udp.py +++ b/discord_http/voice/gateway_udp.py @@ -81,8 +81,16 @@ def datagram_received(self, data: bytes, addr: tuple) -> None: # noqa: ARG002 if len(data) < 2: return - # IP discovery response (type 0x0002 in the second byte). - if data[1] == 0x02 and self._discovery_future is not None and not self._discovery_future.done(): + # IP discovery response: 2-byte big-endian type 0x0002 (bytes 0-1) and a + # fixed 74-byte length. Checking all three avoids matching a coincidental + # RTP packet whose second byte happens to be 0x02. + if ( + data[0] == 0x00 + and data[1] == 0x02 + and len(data) >= 74 + and self._discovery_future is not None + and not self._discovery_future.done() + ): self._discovery_future.set_result(data) return diff --git a/discord_http/voice/player.py b/discord_http/voice/player.py index bebe3e2..b8e5911 100644 --- a/discord_http/voice/player.py +++ b/discord_http/voice/player.py @@ -7,6 +7,7 @@ import shutil from array import array +from collections import deque from collections.abc import AsyncIterable, Callable from typing import TYPE_CHECKING @@ -457,7 +458,7 @@ def __init__( self._buffer = bytearray() self._partial = bytearray() - self._packets: list[bytes] = [] + self._packets: deque[bytes] = deque() self._eof = False async def _fill_buffer(self) -> bool: @@ -527,7 +528,7 @@ async def read(self) -> bytes: return b"" await self._fill_buffer() - return self._packets.pop(0) + return self._packets.popleft() def is_opus(self) -> bool: """ Whether frames are Opus packets (always ``True`` for this source). """ @@ -612,7 +613,14 @@ async def _run(self) -> None: except Exception as exc: self._error = exc finally: - await self._cleanup() + try: + await self._cleanup() + except asyncio.CancelledError: + _log.warning( + f"Audio player cleanup for guild {self.voice_client.guild_id} " + "was interrupted; trailing silence and speaking-off may be skipped" + ) + raise async def _cleanup(self) -> None: """ Flush silence, stop speaking, clean up and invoke ``after``. """ diff --git a/discord_http/voice/sinks.py b/discord_http/voice/sinks.py index c9b5efa..052c07d 100644 --- a/discord_http/voice/sinks.py +++ b/discord_http/voice/sinks.py @@ -119,7 +119,12 @@ def write(self, user: int | None, data: VoiceData) -> None: class WaveSink(AudioSink): - """ Audio sink that writes received PCM to a single 48kHz 16-bit stereo WAV file. """ + """ + Audio sink that writes received PCM to a single 48kHz 16-bit stereo WAV file. + + All speakers are mixed into a single stream; the ``user`` argument to + :meth:`write` is ignored, so per-speaker separation is not preserved. + """ def __init__(self, destination: str | os.PathLike | io.IOBase) -> None: """ @@ -161,7 +166,8 @@ def write(self, user: int | None, data: VoiceData) -> None: # noqa: ARG002 Parameters ---------- user: - The user ID the audio belongs to, or ``None`` if unknown (unused) + The user ID the audio belongs to, or ``None`` if unknown (unused; + all speakers are mixed into the same WAV stream) data: The voice data container holding the PCM payload """ diff --git a/discord_http/voice/socket.py b/discord_http/voice/socket.py index 7577311..e09b232 100644 --- a/discord_http/voice/socket.py +++ b/discord_http/voice/socket.py @@ -30,6 +30,7 @@ class VoiceCloseCode(BaseEnum): disconnected = 4014 voice_server_crashed = 4015 unknown_encryption_mode = 4016 + dave_e2ee_required = 4017 bad_request = 4020 rate_limited = 4021 call_terminated = 4022 @@ -60,6 +61,7 @@ def __init__(self, connection: "VoiceConnection"): self._heartbeat_interval: float = 0.0 self._heartbeat_task: asyncio.Task | None = None self._receive_task: asyncio.Task | None = None + self._dispatch_tasks: set[asyncio.Task] = set() self._last_send: float = 0.0 self._latencies: deque[float] = deque(maxlen=20) @@ -92,7 +94,7 @@ def session(self) -> ClientSession: RuntimeError If the HTTP session is not available (the client is not running). """ - session = self.connection.voice_client.client.state.http.session + session = self.connection.voice_client.bot.state.http.session if session is None: raise RuntimeError("HTTP session is not available; the client must be running to open a voice socket") return session @@ -251,10 +253,12 @@ def _schedule(self, coro: Coroutine[Any, Any, Any]) -> None: coro: The coroutine to run independently of the receive loop. """ - asyncio.create_task( # noqa: RUF006 + task = asyncio.create_task( self._guard(coro), name=f"discord.http/voice/socket-{self.connection.guild_id}/dispatch" ) + self._dispatch_tasks.add(task) + task.add_done_callback(self._dispatch_tasks.discard) async def _guard(self, coro: Coroutine[Any, Any, Any]) -> None: """ From c0ecdacc7a09e251d2f98614d55636498404a7c8 Mon Sep 17 00:00:00 2001 From: Neppkun Date: Thu, 4 Jun 2026 20:31:38 +0300 Subject: [PATCH 12/12] Lower auto-handled voice log lines from INFO/WARN to DEBUG Routine, automatically-recovered voice events were logging at INFO/WARN and spamming the console during normal operation (e.g. a member rejoining triggers a 4006 + MLS rebuild that the library transparently recovers from). Demote those to DEBUG; keep INFO/WARN/ERROR only for conditions that are NOT auto-recovered and likely need operator attention. Demoted to DEBUG (transient / self-healing): - connection.py: disconnect+teardown (4014/4022), server-crash resume (4015), 4006 disconnect (when reconnect not opted in), soft-reconnect attempt/success/ fallback, resume-failed fallback, per-attempt "Reconnecting..." line, failed reconnect attempt (will retry), and reconnect success. - socket.py: websocket ERROR frame (the receive loop breaks and the connection auto-reconnects). - dave.py: "Failed to process MLS proposals/commit/welcome" -- each immediately calls _recover_from_invalid_commit() (notify gateway + reinit), so they are fully handled. Kept as-is (genuine, non-auto-recovered failures): - connection.py: rate-limited give-up (WARN), reconnect exhausted after N attempts + teardown (ERROR), DAVE op received without the davey library (WARN). - socket.py: unexpected exception in a dispatch handler (ERROR). - dave.py: DAVE session init failure -- degrades to no session (WARN). - player.py: cleanup interrupted by cancellation -- re-raised, not swallowed (WARN). No behavior change; logging levels only. --- discord_http/voice/connection.py | 20 ++++++++++---------- discord_http/voice/dave.py | 6 +++--- discord_http/voice/socket.py | 2 +- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/discord_http/voice/connection.py b/discord_http/voice/connection.py index 77bba5b..e17ce77 100644 --- a/discord_http/voice/connection.py +++ b/discord_http/voice/connection.py @@ -260,7 +260,7 @@ async def _handle_close(self, close_code: int | None) -> None: The websocket close code, if one was reported. """ if close_code in (VoiceCloseCode.disconnected, VoiceCloseCode.call_terminated): - _log.info(f"Voice connection for guild {self.guild_id} disconnected (code {close_code}); tearing down") + _log.debug(f"Voice connection for guild {self.guild_id} disconnected (code {close_code}); tearing down") await self._teardown_and_remove() return @@ -270,7 +270,7 @@ async def _handle_close(self, close_code: int | None) -> None: return if close_code == VoiceCloseCode.voice_server_crashed: - _log.info(f"Voice server for guild {self.guild_id} crashed (code {close_code}); resuming") + _log.debug(f"Voice server for guild {self.guild_id} crashed (code {close_code}); resuming") await self._resume() return @@ -280,7 +280,7 @@ async def _handle_close(self, close_code: int | None) -> None: # still in the channel, so reconnecting only helps if the caller opted # in; otherwise we disconnect rather than retry-and-time-out. if not (self._reconnect and self._reconnect_on_session_invalid): - _log.info(f"Voice session for guild {self.guild_id} invalidated (code {close_code}); disconnecting") + _log.debug(f"Voice session for guild {self.guild_id} invalidated (code {close_code}); disconnecting") await self._teardown_and_remove() return @@ -288,14 +288,14 @@ async def _handle_close(self, close_code: int | None) -> None: # re-IDENTIFY with the existing token/session, WITHOUT leaving the # channel. This avoids a visible disconnect/rejoin blip when the # credentials are still usable. - _log.info(f"Voice session for guild {self.guild_id} invalidated (code {close_code}); attempting soft reconnect") + _log.debug(f"Voice session for guild {self.guild_id} invalidated (code {close_code}); attempting soft reconnect") if await self._soft_reconnect(): - _log.info(f"Voice connection for guild {self.guild_id} recovered without rejoining") + _log.debug(f"Voice connection for guild {self.guild_id} recovered without rejoining") return # The existing credentials are no longer usable; fall back to a full # reconnect that drops and re-acquires the gateway voice state. - _log.info(f"Soft reconnect for guild {self.guild_id} failed; falling back to rejoin") + _log.debug(f"Soft reconnect for guild {self.guild_id} failed; falling back to rejoin") await self._full_reconnect(close_code, force_refresh=True) return @@ -322,7 +322,7 @@ async def _resume(self) -> None: self.socket = VoiceSocket(self) await self.socket.connect(resume=True) except Exception as exc: - _log.warning(f"Voice resume for guild {self.guild_id} failed; falling back to full reconnect", exc_info=exc) + _log.debug(f"Voice resume for guild {self.guild_id} failed; falling back to full reconnect", exc_info=exc) await self._full_reconnect(int(VoiceCloseCode.voice_server_crashed)) async def _soft_reconnect(self, timeout: float = 10.0) -> bool: @@ -407,7 +407,7 @@ async def _full_reconnect(self, close_code: int | None, *, force_refresh: bool = return delay = self._backoff.delay() - _log.info( + _log.debug( f"Reconnecting voice for guild {self.guild_id} (close code {close_code}), " f"attempt {attempt}/{max_attempts} in {delay:.2f}s" ) @@ -427,10 +427,10 @@ async def _full_reconnect(self, close_code: int | None, *, force_refresh: bool = _internal_reconnect=True, ) except Exception as exc: - _log.warning(f"Voice reconnect attempt {attempt} for guild {self.guild_id} failed", exc_info=exc) + _log.debug(f"Voice reconnect attempt {attempt} for guild {self.guild_id} failed", exc_info=exc) continue else: - _log.info(f"Voice connection for guild {self.guild_id} reconnected") + _log.debug(f"Voice connection for guild {self.guild_id} reconnected") return _log.error(f"Voice connection for guild {self.guild_id} could not reconnect after {max_attempts} attempts; tearing down") diff --git a/discord_http/voice/dave.py b/discord_http/voice/dave.py index d9c0c39..072354c 100644 --- a/discord_http/voice/dave.py +++ b/discord_http/voice/dave.py @@ -340,7 +340,7 @@ async def _handle_proposals(self, payload: bytes) -> None: except AttributeError: return except Exception as exc: - _log.warning(f"Failed to process MLS proposals: {exc}") + _log.debug(f"Failed to process MLS proposals: {exc}") await self._recover_from_invalid_commit() return @@ -367,7 +367,7 @@ async def _handle_commit(self, payload: bytes) -> None: except AttributeError: return except Exception as exc: - _log.warning(f"Failed to process MLS commit: {exc}") + _log.debug(f"Failed to process MLS commit: {exc}") await self._recover_from_invalid_commit() return @@ -393,7 +393,7 @@ async def _handle_welcome(self, payload: bytes) -> None: except AttributeError: return except Exception as exc: - _log.warning(f"Failed to process MLS welcome: {exc}") + _log.debug(f"Failed to process MLS welcome: {exc}") await self._recover_from_invalid_commit() return diff --git a/discord_http/voice/socket.py b/discord_http/voice/socket.py index e09b232..4bdbc2d 100644 --- a/discord_http/voice/socket.py +++ b/discord_http/voice/socket.py @@ -147,7 +147,7 @@ async def _receive_loop(self) -> None: break elif msg.type is WSMsgType.ERROR: - _log.warning(f"Voice socket for guild {self.connection.guild_id} received error: {msg.data}") + _log.debug(f"Voice socket for guild {self.connection.guild_id} received error: {msg.data}") break except asyncio.CancelledError: