From 6769a0aece9a37e9816b14e81c5593fbd06048e3 Mon Sep 17 00:00:00 2001 From: Sami Rusani Date: Sun, 12 Apr 2026 17:24:58 +0200 Subject: [PATCH] fix(matrix): add outbound mention payloads --- gateway/platforms/matrix.py | 115 ++++++++++++++++++++------- tests/gateway/test_matrix_mention.py | 77 ++++++++++++++++++ 2 files changed, 163 insertions(+), 29 deletions(-) diff --git a/gateway/platforms/matrix.py b/gateway/platforms/matrix.py index 7f719b525..06d222fe8 100644 --- a/gateway/platforms/matrix.py +++ b/gateway/platforms/matrix.py @@ -114,6 +114,9 @@ _CRYPTO_DB_PATH = _STORE_DIR / "crypto.db" # Grace period: ignore messages older than this many seconds before startup. _STARTUP_GRACE_SECONDS = 5 +_OUTBOUND_MENTION_RE = re.compile( + r"(? Dict[str, Any]: + """Build Matrix text content with HTML and outbound mention metadata.""" + msg_content: Dict[str, Any] = {"msgtype": msgtype, "body": text} + mention_user_ids = self._extract_outbound_mentions(text) + if mention_user_ids: + msg_content["m.mentions"] = {"user_ids": mention_user_ids} + + html_source = self._inject_outbound_mention_links(text) + html = self._markdown_to_html(html_source) + if html and html != text: + msg_content["format"] = "org.matrix.custom.html" + msg_content["formatted_body"] = html + + return msg_content + + def _extract_outbound_mentions(self, text: str) -> list[str]: + """Return unique Matrix user IDs mentioned in outbound text.""" + protected, _ = self._protect_outbound_mention_regions(text) + seen: Set[str] = set() + mentions: list[str] = [] + for match in _OUTBOUND_MENTION_RE.finditer(protected): + user_id = match.group(1) + if user_id not in seen: + seen.add(user_id) + mentions.append(user_id) + return mentions + + def _inject_outbound_mention_links(self, text: str) -> str: + """Wrap outbound Matrix mentions in markdown links outside code spans.""" + if not text: + return text + + protected, placeholders = self._protect_outbound_mention_regions(text) + + linked = _OUTBOUND_MENTION_RE.sub( + lambda match: f"[{match.group(1)}](https://matrix.to/#/{match.group(1)})", + protected, + ) + + for idx, original in enumerate(placeholders): + linked = linked.replace(f"\x00MENTION_PROTECTED{idx}\x00", original) + + return linked + + def _protect_outbound_mention_regions(self, text: str) -> tuple[str, list[str]]: + """Protect markdown regions where outbound mentions should stay literal.""" + placeholders: list[str] = [] + + def _protect(fragment: str) -> str: + idx = len(placeholders) + placeholders.append(fragment) + return f"\x00MENTION_PROTECTED{idx}\x00" + + protected = re.sub( + r"```[\s\S]*?```", + lambda match: _protect(match.group(0)), + text or "", + ) + protected = re.sub( + r"`[^`\n]+`", + lambda match: _protect(match.group(0)), + protected, + ) + protected = re.sub( + r"\[[^\]]+\]\([^)]+\)", + lambda match: _protect(match.group(0)), + protected, + ) + + return protected, placeholders + def _is_bot_mentioned( self, body: str, diff --git a/tests/gateway/test_matrix_mention.py b/tests/gateway/test_matrix_mention.py index 3809c33fc..ff4032505 100644 --- a/tests/gateway/test_matrix_mention.py +++ b/tests/gateway/test_matrix_mention.py @@ -173,6 +173,83 @@ class TestStripMention: assert result == "" +# --------------------------------------------------------------------------- +# Outbound mention payloads +# --------------------------------------------------------------------------- + + +class TestOutboundMentions: + def setup_method(self): + self.adapter = _make_adapter() + self.mock_client = MagicMock() + self.mock_client.send_message_event = AsyncMock(return_value="$evt1") + self.adapter._client = self.mock_client + + @staticmethod + def _sent_content(mock_client): + call_args = mock_client.send_message_event.call_args + return call_args.args[2] if len(call_args.args) > 2 else call_args.kwargs["content"] + + @pytest.mark.asyncio + async def test_send_adds_matrix_mentions_and_formatted_body(self): + result = await self.adapter.send( + "!room1:example.org", + "Hello @alice:example.org, please check this.", + ) + + assert result.success is True + content = self._sent_content(self.mock_client) + assert content["m.mentions"] == {"user_ids": ["@alice:example.org"]} + assert content["formatted_body"] == ( + 'Hello ' + "@alice:example.org, please check this." + ) + + @pytest.mark.asyncio + async def test_send_dedupes_mentions_and_ignores_code_spans(self): + await self.adapter.send( + "!room1:example.org", + "Ping @alice:example.org and @alice:example.org, not `@code:example.org`.", + ) + + content = self._sent_content(self.mock_client) + assert content["m.mentions"] == {"user_ids": ["@alice:example.org"]} + assert "@code:example.org" not in content["formatted_body"] + + @pytest.mark.asyncio + async def test_edit_message_preserves_mentions(self): + result = await self.adapter.edit_message( + "!room1:example.org", + "$original", + "Updated for @alice:example.org", + ) + + assert result.success is True + content = self._sent_content(self.mock_client) + assert content["m.mentions"] == {"user_ids": ["@alice:example.org"]} + assert content["m.new_content"]["m.mentions"] == {"user_ids": ["@alice:example.org"]} + assert content["m.new_content"]["formatted_body"] == ( + 'Updated for ' + "@alice:example.org" + ) + assert content["formatted_body"] == ( + '* Updated for ' + "@alice:example.org" + ) + + @pytest.mark.asyncio + async def test_send_notice_adds_mentions(self): + result = await self.adapter.send_notice( + "!room1:example.org", + "Heads up @alice:example.org", + ) + + assert result.success is True + content = self._sent_content(self.mock_client) + assert content["msgtype"] == "m.notice" + assert content["m.mentions"] == {"user_ids": ["@alice:example.org"]} + + # --------------------------------------------------------------------------- # Require-mention gating in _on_room_message # ---------------------------------------------------------------------------