From b0e88a1dbc1df516d1ad724a2227cf0588338fd2 Mon Sep 17 00:00:00 2001 From: Florian Date: Mon, 13 Oct 2025 17:48:28 +0200 Subject: [PATCH] Progress! --- .gitignore | 2 +- Dockerfile | 2 +- src/db.py | 32 +++++++++++------------------- src/rabbitmq_handler.py | 34 ++++++++++++++++++++++++-------- src/secret_handler.py | 42 ++++++++++++++++++++++++++++++++++++++-- src/send_notification.py | 35 ++++++--------------------------- 6 files changed, 85 insertions(+), 62 deletions(-) diff --git a/.gitignore b/.gitignore index 0dbf2f2..f587d29 100644 --- a/.gitignore +++ b/.gitignore @@ -167,4 +167,4 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ - +test.py diff --git a/Dockerfile b/Dockerfile index a54e052..d90c0bd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,4 +8,4 @@ WORKDIR /app COPY src/ /app/ -ENTRYPOINT ["python","main.py"] +ENTRYPOINT ["python","rabbitmq_handler.py"] diff --git a/src/db.py b/src/db.py index 8a22a68..f598269 100644 --- a/src/db.py +++ b/src/db.py @@ -1,3 +1,4 @@ +from contextlib import asynccontextmanager import aiomysql import asyncio from secret_handler import return_credentials @@ -51,32 +52,19 @@ class DBManager: logger.debug("[DB] Healthcheck OK") except Exception as e: logger.warning(f"[DB] Healthcheck failed: {e}") - + + @asynccontextmanager async def acquire(self): - if not self._pool: - raise RuntimeError("DB pool not initialized") - return await self._pool.acquire() + conn = await self._pool.acquire() + try: + yield conn + finally: + self._pool.release(conn) async def release(self, conn): if self._pool: self._pool.release(conn) - async def execute(self, query, *args, retries=3): - for attempt in range(1, retries + 1): - conn = await self.acquire() - try: - async with conn.cursor() as cur: - await cur.execute(query, args) - if cur.description: - return await cur.fetchall() - return None - except aiomysql.OperationalError as e: - logger.warning(f"[DB] Query failed (attempt {attempt}/{retries}): {e}") - await asyncio.sleep(2 ** (attempt - 1)) - finally: - await self.release(conn) - raise RuntimeError("DB query failed after retries") - async def close(self): self._closing = True if self._health_task and not self._health_task.done(): @@ -90,4 +78,6 @@ class DBManager: await self._pool.wait_closed() logger.info("[DB] Connection pool closed") -db_manager = DBManager(host=db_host, port=3306, user=db_username, password=db_password, db=db_database) \ No newline at end of file +#db_manager = DBManager(host=db_host, port=3306, user=db_username, password=db_password, db=db_database) +db_manager = DBManager(host=db_host, port=30006, user=db_username, password=db_password, db=db_database) + diff --git a/src/rabbitmq_handler.py b/src/rabbitmq_handler.py index 981d8cf..fda7064 100644 --- a/src/rabbitmq_handler.py +++ b/src/rabbitmq_handler.py @@ -1,21 +1,22 @@ import asyncio import aio_pika from aio_pika.exceptions import AMQPException -from secret_handler import return_credentials +from secret_handler import return_credentials, database_lookup, decrypt_token import os from logger_handler import setup_logger import json from db import db_manager from send_notification import send_notification +import signal logger = setup_logger(__name__) rmq_username = return_credentials("/etc/secrets/rmq_username") rmq_password = return_credentials("/etc/secrets/rmq_password") -rmq_host = os.getenv("BACKEND_API_INTERNAL_RMQ_HOST", "localhost") -rmq_vhost = os.getenv("BACKEND_API_INTERNAL_RMQ_VHOST", "app_notifications") -rmq_exchange = os.getenv("BACKEND_API_INTERNAL_RMQ_EXCHANGE", "app_notifications") +rmq_host = os.getenv("BACKEND_PN_RMQ_HOST", "localhost:30672") +rmq_vhost = os.getenv("BACKEND_PN_RMQ_VHOST", "app_notifications") +rmq_exchange = os.getenv("BACKEND_PN_RMQ_EXCHANGE", "app_notifications") RABBITMQ_URL = f"amqp://{rmq_username}:{rmq_password}@{rmq_host}/{rmq_vhost}" @@ -37,7 +38,7 @@ class RabbitMQConsumer: self.exchange = await self.channel.declare_exchange( self.exchange_name, aio_pika.ExchangeType.TOPIC, durable=True ) - self.queue = await self.channel.declare_queue("backend_push_notifications", durable=True) + self.queue = await self.channel.declare_queue("notifications", durable=True) await self.queue.bind(self.exchange, routing_key="notify.user.*") logger.info("[Consumer] Connected, queue bound to notify.user.*") @@ -48,7 +49,18 @@ class RabbitMQConsumer: try: data = json.loads(message.body.decode()) logger.info(f"[Consumer] Received: {data}") - await send_notification(routing_key=message.routing_key,message=data,db_manager=self.db_manager) + logger.info(message.routing_key) + encrypted_tokens = await database_lookup(message.routing_key, db_manager) + if not encrypted_tokens: + logger.warning(f"No push tokens found for user {message.routing_key}") + return + + token_map = {row["uuid"]: row["token"].decode() for row in encrypted_tokens} + for uuid, token in token_map.items(): + decrypted_token = decrypt_token(token) + token_map[uuid] = decrypted_token + await send_notification(message=data,push_tokens=token_map) + except json.JSONDecodeError as e: logger.error(f"[Consumer] Bad message, discarding: {e}") await message.nack(requeue=False) @@ -76,11 +88,17 @@ class RabbitMQConsumer: async def main(): + await db_manager.connect() consumer = RabbitMQConsumer(db_manager=db_manager) await consumer.connect() - logger.info("Creating MySQL connection pool...") - await db_manager.connect() + await consumer.consume() + stop_event = asyncio.Event() + for sig in (signal.SIGINT, signal.SIGTERM): + asyncio.get_running_loop().add_signal_handler(sig, stop_event.set) + await stop_event.wait() + await consumer.close() + await db_manager.close() if __name__ == "__main__": asyncio.run(main()) diff --git a/src/secret_handler.py b/src/secret_handler.py index 33d66a5..d9cd8f4 100644 --- a/src/secret_handler.py +++ b/src/secret_handler.py @@ -1,12 +1,50 @@ +from cryptography.fernet import Fernet import sys +import asyncio +from logger_handler import setup_logger + +logger = setup_logger(__name__) + +try: + with open("/etc/secrets/encryption_key","rb") as file: + encryption_key = file.read() +except FileNotFoundError: + logger.fatal("[Secret Handler] Encryption key not found") + sys.exit(1) +except Exception as e: + logger.fatal(f"[Secret Handler] Failed to read encryption key: {e}") + sys.exit(1) + +fernet = Fernet(encryption_key) + +def encrypt_token(token:str)->str: + return fernet.encrypt(token.encode()).decode() + +def decrypt_token(token:str)->str: + return fernet.decrypt(token.encode()).decode() def return_credentials(path: str)->str: try: with open (path) as file: return file.read().strip() except FileNotFoundError: - print(f"[FATAL] Secret file not found: {path}") + logger.fatal(f"[Secret Handler] Secret file not found: {path}") sys.exit(1) except Exception as e: - print(f"[FATAL] Failed to read secret file {path}: {e}") + logger.fatal(f"[Secret Handler] Failed to read secret file {path}: {e}") sys.exit(1) + +async def database_lookup(routing_key: str, db_manager): + try: + user_id = int(routing_key.split('.')[-1]) + except ValueError: + logger.error(f"[DB] Invalid user id supplied:{routing_key}") + return [] + + async with db_manager.acquire() as conn: + async with conn.cursor() as cur: + await cur.execute("SELECT token_id AS uuid,token FROM device_tokens WHERE user_id=%s", + (user_id,)) + if cur.description: + return await cur.fetchall() + return [] \ No newline at end of file diff --git a/src/send_notification.py b/src/send_notification.py index 90b63cc..27dd8d6 100644 --- a/src/send_notification.py +++ b/src/send_notification.py @@ -2,27 +2,20 @@ import aiohttp import asyncio from logger_handler import setup_logger -API_ENDPOINT="https://exp.host/fakeUSer/api/v2/push/send" +#API_ENDPOINT="https://exp.host/fakeUSer/api/v2/push/send" +API_ENDPOINT="http://127.0.0.1:8000/honk" logger = setup_logger(__name__) - - - async def send_notification( - routing_key: str, message: dict, - db_manager, + push_tokens, max_retries: int = 5, timeout: int = 5, ): - push_tokens = await database_lookup(routing_key, db_manager) - if not push_tokens: - logger.warning(f"No push tokens found for user {routing_key}") - return results = {} - for token, uuid in push_tokens: - results[token] = await _send_to_token(token, uuid, message, max_retries, timeout) + for uuid, token in push_tokens.items(): + results[uuid] = await _send_to_token(token, uuid, message, max_retries, timeout) return results @@ -38,7 +31,7 @@ async def _send_to_token(token: str, uuid:str , message: dict, max_retries: int, headers={"Content-Type": "application/json"}, timeout=timeout ) as response: - await response.raise_for_status() + response.raise_for_status() logger.info(f"Notification sent successfully to uuid {uuid}") return {"status": "ok"} @@ -65,19 +58,3 @@ def create_payload(push_token: str, message: dict) -> dict: "sound": "default", "priority": "high" } - -async def database_lookup(routing_key: str, db_manager): - try: - user_id = int(routing_key.split('.')[-1]) - except ValueError: - logger.error(f"[DB] Invalid user id supplied:{routing_key}") - return [] - - async with db_manager.acquire() as conn: - async with conn.cursor() as cur: - await cur.execute("SELECT tokend_id AS uuid,token FROM device_tokens WHERE user_id=%s", - (user_id,)) - if cur.description: - return await cur.fetchall() - return [] -