fix(gateway): preserve WhatsApp pairing approvals across JID/LID alias flips

This commit is contained in:
QuenVix
2026-05-23 10:25:35 +03:00
committed by Teknium
parent 3127a41cb1
commit 52a368fa72
2 changed files with 136 additions and 10 deletions

View File

@@ -28,6 +28,10 @@ import time
from pathlib import Path
from typing import Optional
from gateway.whatsapp_identity import (
expand_whatsapp_aliases,
normalize_whatsapp_identifier,
)
from hermes_constants import get_hermes_dir
from utils import atomic_replace
@@ -110,12 +114,40 @@ class PairingStore:
def _save_json(self, path: Path, data: dict) -> None:
_secure_write(path, json.dumps(data, indent=2, ensure_ascii=False))
def _normalize_user_id(self, platform: str, user_id: str) -> str:
"""Normalize platform-specific user IDs before persisting them."""
raw_user_id = str(user_id or "").strip()
if platform == "whatsapp":
return normalize_whatsapp_identifier(raw_user_id) or raw_user_id
return raw_user_id
def _user_id_aliases(self, platform: str, user_id: str) -> set[str]:
"""Return all known equivalent user IDs for auth/rate-limit checks."""
raw_user_id = str(user_id or "").strip()
if not raw_user_id:
return set()
aliases = {raw_user_id, self._normalize_user_id(platform, raw_user_id)}
if platform == "whatsapp":
aliases.update(expand_whatsapp_aliases(raw_user_id))
aliases.discard("")
return aliases
def _user_ids_match(self, platform: str, left: str, right: str) -> bool:
"""Return True when two user IDs represent the same principal."""
left_aliases = self._user_id_aliases(platform, left)
right_aliases = self._user_id_aliases(platform, right)
return bool(left_aliases and right_aliases and (left_aliases & right_aliases))
# ----- Approved users -----
def is_approved(self, platform: str, user_id: str) -> bool:
"""Check if a user is approved (paired) on a platform."""
approved = self._load_json(self._approved_path(platform))
return user_id in approved
for approved_user_id in approved:
if self._user_ids_match(platform, approved_user_id, user_id):
return True
return False
def list_approved(self, platform: str = None) -> list:
"""List approved users, optionally filtered by platform."""
@@ -130,7 +162,16 @@ class PairingStore:
def _approve_user(self, platform: str, user_id: str, user_name: str = "") -> None:
"""Add a user to the approved list. Must be called under self._lock."""
approved = self._load_json(self._approved_path(platform))
approved[user_id] = {
normalized_user_id = self._normalize_user_id(platform, user_id)
duplicate_ids = [
approved_user_id
for approved_user_id in approved
if self._user_ids_match(platform, approved_user_id, normalized_user_id)
]
for approved_user_id in duplicate_ids:
del approved[approved_user_id]
approved[normalized_user_id] = {
"user_name": user_name,
"approved_at": time.time(),
}
@@ -141,8 +182,14 @@ class PairingStore:
path = self._approved_path(platform)
with self._lock:
approved = self._load_json(path)
if user_id in approved:
del approved[user_id]
matching_ids = [
approved_user_id
for approved_user_id in approved
if self._user_ids_match(platform, approved_user_id, user_id)
]
if matching_ids:
for approved_user_id in matching_ids:
del approved[approved_user_id]
self._save_json(path, approved)
return True
return False
@@ -170,6 +217,7 @@ class PairingStore:
"""
with self._lock:
self._cleanup_expired(platform)
normalized_user_id = self._normalize_user_id(platform, user_id)
# Check lockout
if self._is_locked_out(platform):
@@ -198,7 +246,7 @@ class PairingStore:
pending[entry_id] = {
"hash": code_hash,
"salt": salt.hex(),
"user_id": user_id,
"user_id": normalized_user_id,
"user_name": user_name,
"created_at": time.time(),
}
@@ -326,15 +374,20 @@ class PairingStore:
def _is_rate_limited(self, platform: str, user_id: str) -> bool:
"""Check if a user has requested a code too recently."""
limits = self._load_json(self._rate_limit_path())
key = f"{platform}:{user_id}"
last_request = limits.get(key, 0)
return (time.time() - last_request) < RATE_LIMIT_SECONDS
for alias in self._user_id_aliases(platform, user_id):
key = f"{platform}:{alias}"
last_request = limits.get(key, 0)
if (time.time() - last_request) < RATE_LIMIT_SECONDS:
return True
return False
def _record_rate_limit(self, platform: str, user_id: str) -> None:
"""Record the time of a pairing request for rate limiting."""
limits = self._load_json(self._rate_limit_path())
key = f"{platform}:{user_id}"
limits[key] = time.time()
now = time.time()
for alias in self._user_id_aliases(platform, user_id):
key = f"{platform}:{alias}"
limits[key] = now
self._save_json(self._rate_limit_path(), limits)
def _is_locked_out(self, platform: str) -> bool: