From 873e7e7854c657349935d8a521d75899411328c7 Mon Sep 17 00:00:00 2001 From: Florian Date: Sun, 12 Oct 2025 16:54:08 +0200 Subject: [PATCH] Improved RabbitMQ handling - Switched to async because it offers easy automatic reconnect features on connection failure - Adjusted notification path to reflect that change - Added in memory message queue if RMQ is not reachable --- requirements.txt | 11 +++- src/db.py | 4 +- src/main.py | 69 ++++++--------------- src/rabbitmq_handler.py | 134 ++++++++++++++++++++++------------------ src/validator.py | 6 +- 5 files changed, 104 insertions(+), 120 deletions(-) diff --git a/requirements.txt b/requirements.txt index 177a42f..c5a5175 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,19 @@ +aio-pika==9.5.7 +aiormq==6.9.0 annotated-types==0.7.0 anyio==4.11.0 argon2-cffi==25.1.0 argon2-cffi-bindings==25.1.0 cffi==2.0.0 click==8.3.0 -fastapi==0.118.0 +fastapi==0.119.0 h11==0.16.0 -idna==3.10 +idna==3.11 +multidict==6.7.0 mysql-connector-python==9.4.0 -pika==1.3.2 +pamqp==3.3.0 prometheus_client==0.23.1 +propcache==0.4.1 pycparser==2.23 pydantic==2.12.0 pydantic_core==2.41.1 @@ -18,3 +22,4 @@ starlette==0.48.0 typing-inspection==0.4.2 typing_extensions==4.15.0 uvicorn==0.37.0 +yarl==1.22.0 diff --git a/src/db.py b/src/db.py index 8dada75..8f70f85 100644 --- a/src/db.py +++ b/src/db.py @@ -1,4 +1,3 @@ -import mysql.connector from mysql.connector import pooling, Error import threading from secret_handler import return_credentials @@ -41,7 +40,7 @@ def create_connection_pool(): try: logger.info(f"[MySQL] Attempt {attempt} to connect...") pool = pooling.MySQLConnectionPool( - pool_name="royalroadPool", + pool_name="ApiInternalPool", pool_size=5, pool_reset_session=True, **MYSQL_CONFIG @@ -103,7 +102,6 @@ def _pool_healthcheck(): conn = pool.get_connection() conn.ping(reconnect=True, attempts=3, delay=1) conn.close() - logger.debug("[MySQL] Pool healthcheck OK.") except Error as e: logger.warning(f"[MySQL] Pool healthcheck failed: {e}") create_connection_pool() diff --git a/src/main.py b/src/main.py index 3dede5d..635a060 100644 --- a/src/main.py +++ b/src/main.py @@ -7,7 +7,7 @@ from pydantic import BaseModel from validator import verify_api_key from db import get_db, create_connection_pool, close_connection_pool, start_healthcheck_thread from logger_handler import setup_logger -from rabbitmq_handler import send_message_to_rmq, create_connection, close_connection +from rabbitmq_handler import RabbitMQProducer import uvicorn from uvicorn_logging_config import LOGGING_CONFIG from contextlib import asynccontextmanager @@ -15,7 +15,7 @@ from metrics_server import REQUEST_COUNTER import asyncio logger = setup_logger(__name__) - +producer = RabbitMQProducer() api_key_header_internal = APIKeyHeader(name="X-API-Key-Internal") class Notification(BaseModel): @@ -30,24 +30,21 @@ async def lifespan(app: FastAPI): logger.info("Creating MySQL connection pool...") create_connection_pool() - logger.info("Connecting to RabbitMQ...") - app.state.rmq_connection = create_connection() - start_healthcheck_thread() logger.info("MySQL healthcheck thread started.") + logger.info("Starting RabbitMQ producer...") + await producer.connect() + app.state.rmq_producer = producer + logger.info("[FastAPI] RabbitMQ producer initialized.") + yield - logger.info("Closing RabbitMQ connection...") - close_connection(app.state.rmq_connection) + logger.info("Closing RabbitMQ producer...") + await producer.close() + logger.info("Closing MySQL connection pool...") close_connection_pool() -def get_rmq_connection(app: FastAPI): - connection = getattr(app.state, "rmq_connection", None) - if not connection or not connection.is_open: - app.state.rmq_connection = create_connection() - return app.state.rmq_connection - api = FastAPI( title="Internal Notifier API", description="API to forward messages to RabbitMQ", @@ -87,54 +84,24 @@ async def custom_http_exception_handler(request,exc): content={"detail": exc.detail} ) -@api.get("/health", tags=["Health"]) -def return_health(request:Request, db=Depends(get_db)): - try: - cursor = db.cursor() - cursor.execute("SELECT 1") - cursor.fetchone() - db_status = "ok" - except Exception as e: - logger.error(f"Health check DB failed: {e}") - db_status = "error" - - try: - rmq_conn = getattr(request.app.state, "rmq_connection", None) - if not rmq_conn or not rmq_conn.is_open: - logger.error("Health check RMQ failed: connection closed or missing") - rmq_status = "error" - except Exception as e: - logger.error(f"Health check RMQ failed: {e}") - rmq_status = "error" - - overall_status = "ok" if db_status == "ok" and rmq_status == "ok" else "error" - status_code = 200 if overall_status == "ok" else 500 - - return JSONResponse( - status_code=status_code, - content={"status": overall_status, - "components": { - "database": db_status, - "rabbitmq": rmq_status - }, - "message": "Service is running" if overall_status == "ok" else "One or more checks failed"} - ) - @api.post("/internal/receive-notifications") -def receive_notifications( +async def receive_notifications( notification_data: Notification, request: Request, db = Depends(get_db), program_name: str = Depends(verify_api_key_dependency_internal) ): - rmq_connection = get_rmq_connection(request.app) + logger.info(f"Received notifcation data from {program_name} for RMQ") - send_message_to_rmq(rmq_connection,notification_data.receipent_user_id,notification_data.message) - logger.info("Successfully delivered message to RMQ") + await request.app.state.rmq_producer.publish( + notification_data.receipent_user_id, + notification_data.message + ) + logger.info("Successfully queued for delivery to RMQ") return {"status": "queued"} async def start_servers(): - config_main = uvicorn.Config("main:api", host="0.0.0.0", port=8101, log_config=LOGGING_CONFIG, log_level="info") + config_main = uvicorn.Config("main:api", host="0.0.0.0", port=8101, log_level="info") config_metrics = uvicorn.Config("metrics_server:metrics_api", host="0.0.0.0", port=9000, log_level="info") server_main = uvicorn.Server(config_main) diff --git a/src/rabbitmq_handler.py b/src/rabbitmq_handler.py index 68ea583..1aa4495 100644 --- a/src/rabbitmq_handler.py +++ b/src/rabbitmq_handler.py @@ -1,72 +1,86 @@ -import pika -from typing import Dict +import asyncio +import aio_pika from secret_handler import return_credentials -import json -import time -import sys import os +from logger_handler import setup_logger +import json + +logger = setup_logger(__name__) rmq_username = return_credentials("/etc/secrets/rmq_username") rmq_password = return_credentials("/etc/secrets/rmq_password") -rmq_host = os.getenv("BACKEND_API_INTERNAL_RMQ_HOST","localhost") -rmq_vhost = os.getenv("BACKEND_API_INTERNAL_RMQ_VHOST","app_notifications") -rmq_exchange = os.getenv("BACKEND_API_INTERNAL_RMQ_EXCHANGE","app_notifications") +rmq_host = os.getenv("BACKEND_API_INTERNAL_RMQ_HOST", "localhost") +rmq_vhost = os.getenv("BACKEND_API_INTERNAL_RMQ_VHOST", "app_notifications") +rmq_exchange = os.getenv("BACKEND_API_INTERNAL_RMQ_EXCHANGE", "app_notifications") -MAX_RETRIES = 5 -RETRY_DELAY = 5 +RABBITMQ_URL = f"amqp://{rmq_username}:{rmq_password}@{rmq_host}/{rmq_vhost}" -def create_connection(): - credentials = pika.PlainCredentials(username=rmq_username, password=rmq_password) - conn_params = pika.ConnectionParameters( - host=rmq_host, - port=5672, - credentials=credentials, - virtual_host=rmq_vhost - ) +class RabbitMQProducer: + def __init__(self, url=RABBITMQ_URL, exchange_name=rmq_exchange): + self.url = url + self.exchange_name = exchange_name + self.connection: aio_pika.RobustConnection | None = None + self.channel: aio_pika.RobustChannel | None = None + self.exchange: aio_pika.Exchange | None = None + self._queue: asyncio.Queue[tuple[int, dict]] = asyncio.Queue() + self._flush_task: asyncio.Task | None = None + self._closing = False + self._ready = asyncio.Event() - for attempt in range(1, MAX_RETRIES + 1): - try: - connection = pika.BlockingConnection(conn_params) - print("[RMQ] Connection established.") - return connection - except Exception as e: - print(f"[RMQ] Attempt {attempt} failed: {e}") - if attempt < MAX_RETRIES: - time.sleep(RETRY_DELAY) - else: - print("[RMQ] Failed to connect after maximum retries — exiting.") - sys.exit(1) + async def connect(self): + self.connection = await aio_pika.connect_robust(self.url) + self.channel = await self.connection.channel(publisher_confirms=True) + self.exchange = await self.channel.declare_exchange( + self.exchange_name, aio_pika.ExchangeType.TOPIC, durable=True + ) + logger.info(f"[aio-pika] Connected and exchange '{self.exchange_name}' ready.") + self._ready.set() + self._flush_task = asyncio.create_task(self._flush_queue_loop()) + + async def publish(self, routing_key: int, message: dict): + await self._queue.put((routing_key, message)) + logger.debug(f"[RabbitMQ] Queued message for {routing_key}") + + async def _flush_queue_loop(self): + await self._ready.wait() + logger.debug(f"here") + while True: + routing_key, message = await self._queue.get() + try: + await self.exchange.publish( + aio_pika.Message( + body=json.dumps(message).encode(), + content_type="application/json", + delivery_mode=aio_pika.DeliveryMode.PERSISTENT + ), + routing_key=f"notify.user.{routing_key}", + mandatory=True + ) + logger.debug(f"[aio-pika] Published message to notify.user.{routing_key}") + except Exception as e: + logger.warning(f"[aio-pika] Publish failed, requeuing: {e}") + await asyncio.sleep(2) + await self._queue.put((routing_key, message)) + finally: + self._queue.task_done() + + async def close(self): + self._closing = True + if self._flush_task and not self._flush_task.done(): + self._flush_task.cancel() + try: + await self._flush_task + except asyncio.CancelledError: + pass + if self.connection: + await self.connection.close() + logger.info("[aio-pika] Connection closed.") -def send_message_to_rmq(connection, user_id: int, message: Dict): - if not connection or not connection.is_open: - print("[RMQ] Connection lost, reconnecting...") - connection = create_connection() +async def main(): + producer = RabbitMQProducer() + await producer.connect() + await producer.publish("1", {"type": "notification", "content": "Test message"}) + await producer.close() - channel = connection.channel() - channel.exchange_declare(exchange=rmq_exchange, exchange_type="topic", durable=True) - channel.confirm_delivery() - - channel.basic_publish( - exchange=rmq_exchange, - routing_key=f"notify.user.{user_id}", - body=json.dumps(message), - properties=pika.BasicProperties( - content_type="application/json", - delivery_mode=2 - ), - mandatory=True - ) - - -def close_connection(connection): - if connection and connection.is_open: - connection.close() - print("[RMQ] Connection closed.") - - -if __name__ == "__main__": - conn = create_connection() - send_message_to_rmq(conn, 1, {"type": "notification", "content": "Vault TLS cert reloaded successfully."}) - close_connection(conn) diff --git a/src/validator.py b/src/validator.py index bb9a9cc..6dfca96 100644 --- a/src/validator.py +++ b/src/validator.py @@ -12,9 +12,9 @@ def verify_api_key(api_key: str, hashed: str) -> bool: return False if __name__=="__main__": - plain_key = "super-secret-api-key" - #hashed_key = hash_api_key(plain_key) - hashed_key = '$argon2id$v=19$m=65536,t=3,p=4$vqU+MRafVW1b8AtF+zHb0w$p1J4Gyb0jhlVtKgYyjTITxfU97YaayeS3s3qFFP5sVM' + plain_key = "password" + hashed_key = hash_api_key(plain_key) + #hashed_key = '$argon2id$v=19$m=65536,t=3,p=4$vqU+MRafVW1b8AtF+zHb0w$p1J4Gyb0jhlVtKgYyjTITxfU97YaayeS3s3qFFP5sVM' print("Hashed API Key:", hashed_key) print("Verification:", verify_api_key(plain_key, hashed_key))