diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index 273251385..8cb4f7c0e 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -1025,7 +1025,20 @@ class BasePlatformAdapter(ABC): self._post_delivery_callbacks: Dict[str, Any] = {} self._expected_cancelled_tasks: set[asyncio.Task] = set() self._busy_session_handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]] = None - # Chats where auto-TTS on voice input is disabled (set by /voice off) + # Auto-TTS on voice input: ``_auto_tts_default`` is the global default + # (``voice.auto_tts`` in config.yaml, pushed by GatewayRunner on connect). + # Per-chat overrides live in two sets populated from ``_voice_mode``: + # - ``_auto_tts_enabled_chats``: chat explicitly opted in via ``/voice on`` + # or ``/voice tts`` (mode is ``voice_only`` or ``all``). Fires even when + # the global default is False. + # - ``_auto_tts_disabled_chats``: chat explicitly opted out via + # ``/voice off`` (mode is ``off``). Suppresses auto-TTS even when the + # global default is True. + # The gate in _process_message() is: + # fire if chat in _auto_tts_enabled_chats + # OR (_auto_tts_default and chat not in _auto_tts_disabled_chats) + self._auto_tts_default: bool = False + self._auto_tts_enabled_chats: set = set() self._auto_tts_disabled_chats: set = set() # Chats where typing indicator is paused (e.g. during approval waits). # _keep_typing skips send_typing when the chat_id is in this set. @@ -1047,6 +1060,21 @@ class BasePlatformAdapter(ABC): def fatal_error_retryable(self) -> bool: return self._fatal_error_retryable + def _should_auto_tts_for_chat(self, chat_id: str) -> bool: + """Whether auto-TTS on voice input should fire for ``chat_id``. + + Decision layers (Issue #16007): + 1. Explicit ``/voice on`` or ``/voice tts`` → always fire (even if + ``voice.auto_tts`` is False). + 2. Explicit ``/voice off`` → never fire. + 3. Fall back to the global ``voice.auto_tts`` config default. + """ + if chat_id in self._auto_tts_enabled_chats: + return True + if chat_id in self._auto_tts_disabled_chats: + return False + return bool(self._auto_tts_default) + def set_fatal_error_handler(self, handler: Callable[["BasePlatformAdapter"], Awaitable[None] | None]) -> None: self._fatal_error_handler = handler @@ -2214,12 +2242,14 @@ class BasePlatformAdapter(ABC): logger.info("[%s] extract_local_files found %d file(s) in response", self.name, len(local_files)) # Auto-TTS: if voice message, generate audio FIRST (before sending text) - # Skipped when the chat has voice mode disabled (/voice off) + # Gated via ``_should_auto_tts_for_chat``: fires when the chat has + # an explicit ``/voice on|tts`` opt-in OR when ``voice.auto_tts`` is + # True globally and no ``/voice off`` has been issued. _tts_path = None - if (event.message_type == MessageType.VOICE + if (self._should_auto_tts_for_chat(event.source.chat_id) + and event.message_type == MessageType.VOICE and text_content - and not media_files - and event.source.chat_id not in self._auto_tts_disabled_chats): + and not media_files): try: from tools.tts_tool import text_to_speech_tool, check_tts_requirements if check_tts_requirements(): diff --git a/gateway/run.py b/gateway/run.py index f1aafcdf3..497d9241c 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -881,23 +881,74 @@ class GatewayRunner: return if disabled: disabled_chats.add(chat_id) + # ``/voice off`` also clears any explicit enable — it's a hard override. + enabled_chats = getattr(adapter, "_auto_tts_enabled_chats", None) + if isinstance(enabled_chats, set): + enabled_chats.discard(chat_id) else: disabled_chats.discard(chat_id) - def _sync_voice_mode_state_to_adapter(self, adapter) -> None: - """Restore persisted /voice off state into a live platform adapter.""" - disabled_chats = getattr(adapter, "_auto_tts_disabled_chats", None) - if not isinstance(disabled_chats, set): + def _set_adapter_auto_tts_enabled(self, adapter, chat_id: str, enabled: bool) -> None: + """Update an adapter's per-chat auto-TTS opt-in set if present. + + Used for ``/voice on``/``/voice tts`` where the user explicitly wants + auto-TTS even when ``voice.auto_tts`` is False globally. + """ + enabled_chats = getattr(adapter, "_auto_tts_enabled_chats", None) + if not isinstance(enabled_chats, set): return + if enabled: + enabled_chats.add(chat_id) + # An explicit opt-in clears any stale /voice off for this chat. + disabled_chats = getattr(adapter, "_auto_tts_disabled_chats", None) + if isinstance(disabled_chats, set): + disabled_chats.discard(chat_id) + else: + enabled_chats.discard(chat_id) + + def _sync_voice_mode_state_to_adapter(self, adapter) -> None: + """Restore persisted /voice state into a live platform adapter. + + Populates three fields from config + ``self._voice_mode``: + - ``_auto_tts_default``: global default from ``voice.auto_tts`` + - ``_auto_tts_enabled_chats``: chats with mode ``voice_only``/``all`` + - ``_auto_tts_disabled_chats``: chats with mode ``off`` + """ platform = getattr(adapter, "platform", None) if not isinstance(platform, Platform): return - disabled_chats.clear() + + disabled_chats = getattr(adapter, "_auto_tts_disabled_chats", None) + enabled_chats = getattr(adapter, "_auto_tts_enabled_chats", None) + if not isinstance(disabled_chats, set) and not isinstance(enabled_chats, set): + return + + # Push the global voice.auto_tts default (config.yaml) onto the adapter. + # Lazy import to avoid adding a module-level dep from gateway → hermes_cli. + try: + from hermes_cli.config import load_config as _load_full_config + _full_cfg = _load_full_config() + _auto_tts_default = bool( + (_full_cfg.get("voice") or {}).get("auto_tts", False) + ) + except Exception: + _auto_tts_default = False + if hasattr(adapter, "_auto_tts_default"): + adapter._auto_tts_default = _auto_tts_default + prefix = f"{platform.value}:" - disabled_chats.update( - key[len(prefix):] for key, mode in self._voice_mode.items() - if mode == "off" and key.startswith(prefix) - ) + if isinstance(disabled_chats, set): + disabled_chats.clear() + disabled_chats.update( + key[len(prefix):] for key, mode in self._voice_mode.items() + if mode == "off" and key.startswith(prefix) + ) + if isinstance(enabled_chats, set): + enabled_chats.clear() + enabled_chats.update( + key[len(prefix):] for key, mode in self._voice_mode.items() + if mode in ("voice_only", "all") and key.startswith(prefix) + ) async def _safe_adapter_disconnect(self, adapter, platform) -> None: """Call adapter.disconnect() defensively, swallowing any error. @@ -5977,7 +6028,7 @@ class GatewayRunner: self._voice_mode[voice_key] = "voice_only" self._save_voice_modes() if adapter: - self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False) + self._set_adapter_auto_tts_enabled(adapter, chat_id, enabled=True) return ( "Voice mode enabled.\n" "I'll reply with voice when you send voice messages.\n" @@ -5993,7 +6044,7 @@ class GatewayRunner: self._voice_mode[voice_key] = "all" self._save_voice_modes() if adapter: - self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False) + self._set_adapter_auto_tts_enabled(adapter, chat_id, enabled=True) return ( "Auto-TTS enabled.\n" "All replies will include a voice message." @@ -6032,7 +6083,7 @@ class GatewayRunner: self._voice_mode[voice_key] = "voice_only" self._save_voice_modes() if adapter: - self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False) + self._set_adapter_auto_tts_enabled(adapter, chat_id, enabled=True) return "Voice mode enabled." else: self._voice_mode[voice_key] = "off" @@ -6083,7 +6134,7 @@ class GatewayRunner: adapter._voice_sources[guild_id] = event.source.to_dict() self._voice_mode[self._voice_key(event.source.platform, event.source.chat_id)] = "all" self._save_voice_modes() - self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=False) + self._set_adapter_auto_tts_enabled(adapter, event.source.chat_id, enabled=True) return ( f"Joined voice channel **{voice_channel.name}**.\n" f"I'll speak my replies and listen to you. Use /voice leave to disconnect." diff --git a/tests/gateway/test_voice_command.py b/tests/gateway/test_voice_command.py index ed36b976e..2e9c54608 100644 --- a/tests/gateway/test_voice_command.py +++ b/tests/gateway/test_voice_command.py @@ -177,6 +177,53 @@ class TestHandleVoiceCommand: assert adapter._auto_tts_disabled_chats == {"123"} + def test_sync_populates_enabled_chats_from_voice_modes(self, runner): + """Issue #16007: sync also restores per-chat /voice on|tts opt-ins. + + The adapter's ``_auto_tts_enabled_chats`` must mirror chats whose + persisted voice_mode is ``voice_only`` or ``all`` — without this, + ``/voice on`` was relying on a "not in disabled set" default that + silently enabled auto-TTS for every chat. + """ + from gateway.config import Platform + runner._voice_mode = { + "telegram:off_chat": "off", + "telegram:on_chat": "voice_only", + "telegram:tts_chat": "all", + "slack:999": "voice_only", # wrong platform, must be ignored + } + adapter = SimpleNamespace( + _auto_tts_default=False, + _auto_tts_disabled_chats=set(), + _auto_tts_enabled_chats=set(), + platform=Platform.TELEGRAM, + ) + + runner._sync_voice_mode_state_to_adapter(adapter) + + assert adapter._auto_tts_disabled_chats == {"off_chat"} + assert adapter._auto_tts_enabled_chats == {"on_chat", "tts_chat"} + + def test_sync_pushes_config_default_onto_adapter(self, runner, monkeypatch): + """Issue #16007: ``voice.auto_tts`` must propagate to ``_auto_tts_default``.""" + from gateway.config import Platform + + fake_cfg = {"voice": {"auto_tts": True}} + monkeypatch.setattr( + "hermes_cli.config.load_config", + lambda: fake_cfg, + ) + adapter = SimpleNamespace( + _auto_tts_default=False, + _auto_tts_disabled_chats=set(), + _auto_tts_enabled_chats=set(), + platform=Platform.TELEGRAM, + ) + + runner._sync_voice_mode_state_to_adapter(adapter) + + assert adapter._auto_tts_default is True + def test_restart_restores_voice_off_state(self, runner, tmp_path): from gateway.config import Platform runner._VOICE_MODE_PATH.write_text(json.dumps({"telegram:123": "off"})) @@ -2706,3 +2753,56 @@ class TestUDPKeepalive: mock_conn.send_packet.assert_called_with(b'\xf8\xff\xfe') finally: DiscordAdapter._KEEPALIVE_INTERVAL = original_interval + + +# ===================================================================== +# BasePlatformAdapter._should_auto_tts_for_chat — gate for auto-TTS +# on voice input. Regression test for Issue #16007. +# ===================================================================== + +class TestShouldAutoTtsForChat: + """Three-layer gate: per-chat enable > per-chat disable > config default.""" + + def _make_adapter(self, *, default: bool, enabled=(), disabled=()): + """Build a bare adapter with only the attrs the gate reads.""" + adapter = SimpleNamespace( + _auto_tts_default=default, + _auto_tts_enabled_chats=set(enabled), + _auto_tts_disabled_chats=set(disabled), + ) + # Bind the unbound method — _should_auto_tts_for_chat only reads the + # three attrs above via ``self.``, so an unbound call works. + from gateway.platforms.base import BasePlatformAdapter + return BasePlatformAdapter._should_auto_tts_for_chat, adapter + + def test_default_false_no_override_suppresses(self): + """Issue #16007: voice.auto_tts=False and no per-chat state → no TTS.""" + fn, adapter = self._make_adapter(default=False) + assert fn(adapter, "chat1") is False + + def test_default_true_no_override_fires(self): + fn, adapter = self._make_adapter(default=True) + assert fn(adapter, "chat1") is True + + def test_explicit_enable_overrides_false_default(self): + """``/voice on`` with config auto_tts=False still fires.""" + fn, adapter = self._make_adapter(default=False, enabled={"chat1"}) + assert fn(adapter, "chat1") is True + + def test_explicit_disable_overrides_true_default(self): + """``/voice off`` with config auto_tts=True still suppresses.""" + fn, adapter = self._make_adapter(default=True, disabled={"chat1"}) + assert fn(adapter, "chat1") is False + + def test_enabled_wins_over_disabled(self): + """An explicit enable beats an explicit disable (enable takes priority).""" + fn, adapter = self._make_adapter( + default=False, enabled={"chat1"}, disabled={"chat1"} + ) + assert fn(adapter, "chat1") is True + + def test_per_chat_isolation(self): + """Enable for chat1 doesn't leak to chat2.""" + fn, adapter = self._make_adapter(default=False, enabled={"chat1"}) + assert fn(adapter, "chat1") is True + assert fn(adapter, "chat2") is False