Merge pull request 'Improved RabbitMQ handling' (#5) from feature/rabbitmq-improved-connection into main
All checks were successful
Build & Publish to GHCR / build (push) Successful in 1m5s

Reviewed-on: #5
florian 2025-10-12 21:51:31 +02:00
commit 2be35613e2
5 changed files with 104 additions and 120 deletions

View File

@@ -1,15 +1,19 @@
+aio-pika==9.5.7
+aiormq==6.9.0
 annotated-types==0.7.0
 anyio==4.11.0
 argon2-cffi==25.1.0
 argon2-cffi-bindings==25.1.0
 cffi==2.0.0
 click==8.3.0
-fastapi==0.118.0
+fastapi==0.119.0
 h11==0.16.0
-idna==3.10
+idna==3.11
+multidict==6.7.0
 mysql-connector-python==9.4.0
-pika==1.3.2
+pamqp==3.3.0
 prometheus_client==0.23.1
+propcache==0.4.1
 pycparser==2.23
 pydantic==2.12.0
 pydantic_core==2.41.1
@@ -18,3 +22,4 @@ starlette==0.48.0
 typing-inspection==0.4.2
 typing_extensions==4.15.0
 uvicorn==0.37.0
+yarl==1.22.0
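The dependency changes track the switch from the blocking pika client to aio-pika; aiormq, pamqp, multidict, propcache and yarl arrive as its transitive dependencies, alongside routine bumps of fastapi and idna. For orientation, the smallest useful aio-pika publish looks roughly like this (a sketch only; the broker URL and routing key are placeholders, not values from this repo):

import asyncio
import aio_pika

async def demo():
    # connect_robust() transparently re-establishes the connection if the broker drops it
    connection = await aio_pika.connect_robust("amqp://guest:guest@localhost/")
    channel = await connection.channel()
    await channel.default_exchange.publish(
        aio_pika.Message(body=b'{"hello": "world"}', content_type="application/json"),
        routing_key="demo.queue",  # placeholder routing key
    )
    await connection.close()

asyncio.run(demo())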

View File

@@ -1,4 +1,3 @@
-import mysql.connector
 from mysql.connector import pooling, Error
 import threading
 from secret_handler import return_credentials
@@ -41,7 +40,7 @@ def create_connection_pool():
         try:
             logger.info(f"[MySQL] Attempt {attempt} to connect...")
             pool = pooling.MySQLConnectionPool(
-                pool_name="royalroadPool",
+                pool_name="ApiInternalPool",
                 pool_size=5,
                 pool_reset_session=True,
                 **MYSQL_CONFIG
@@ -103,7 +102,6 @@ def _pool_healthcheck():
             conn = pool.get_connection()
             conn.ping(reconnect=True, attempts=3, delay=1)
             conn.close()
-            logger.debug("[MySQL] Pool healthcheck OK.")
         except Error as e:
             logger.warning(f"[MySQL] Pool healthcheck failed: {e}")
             create_connection_pool()
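For readers without the full file open: start_healthcheck_thread (imported by main.py below) is not touched by these hunks. A plausible shape for it, assuming a simple daemon-thread loop around _pool_healthcheck (an illustration, not the actual implementation):

import threading
import time

def start_healthcheck_thread(interval: int = 30):
    # Hypothetical sketch: call _pool_healthcheck() periodically from a daemon thread.
    def _loop():
        while True:
            _pool_healthcheck()
            time.sleep(interval)
    t = threading.Thread(target=_loop, name="mysql-healthcheck", daemon=True)
    t.start()
    return t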

View File

@@ -7,7 +7,7 @@ from pydantic import BaseModel
 from validator import verify_api_key
 from db import get_db, create_connection_pool, close_connection_pool, start_healthcheck_thread
 from logger_handler import setup_logger
-from rabbitmq_handler import send_message_to_rmq, create_connection, close_connection
+from rabbitmq_handler import RabbitMQProducer
 import uvicorn
 from uvicorn_logging_config import LOGGING_CONFIG
 from contextlib import asynccontextmanager
@@ -15,7 +15,7 @@ from metrics_server import REQUEST_COUNTER
 import asyncio
 
 logger = setup_logger(__name__)
+producer = RabbitMQProducer()
 
 api_key_header_internal = APIKeyHeader(name="X-API-Key-Internal")
 
 class Notification(BaseModel):
@@ -30,24 +30,21 @@ async def lifespan(app: FastAPI):
     logger.info("Creating MySQL connection pool...")
     create_connection_pool()
-    logger.info("Connecting to RabbitMQ...")
-    app.state.rmq_connection = create_connection()
     start_healthcheck_thread()
     logger.info("MySQL healthcheck thread started.")
+    logger.info("Starting RabbitMQ producer...")
+    await producer.connect()
+    app.state.rmq_producer = producer
+    logger.info("[FastAPI] RabbitMQ producer initialized.")
     yield
-    logger.info("Closing RabbitMQ connection...")
-    close_connection(app.state.rmq_connection)
+    logger.info("Closing RabbitMQ producer...")
+    await producer.close()
     logger.info("Closing MySQL connection pool...")
     close_connection_pool()
-
-def get_rmq_connection(app: FastAPI):
-    connection = getattr(app.state, "rmq_connection", None)
-    if not connection or not connection.is_open:
-        app.state.rmq_connection = create_connection()
-    return app.state.rmq_connection
 
 api = FastAPI(
     title="Internal Notifier API",
     description="API to forward messages to RabbitMQ",
@@ -87,54 +84,24 @@ async def custom_http_exception_handler(request,exc):
         content={"detail": exc.detail}
     )
 
-@api.get("/health", tags=["Health"])
-def return_health(request:Request, db=Depends(get_db)):
-    try:
-        cursor = db.cursor()
-        cursor.execute("SELECT 1")
-        cursor.fetchone()
-        db_status = "ok"
-    except Exception as e:
-        logger.error(f"Health check DB failed: {e}")
-        db_status = "error"
-
-    try:
-        rmq_conn = getattr(request.app.state, "rmq_connection", None)
-        if not rmq_conn or not rmq_conn.is_open:
-            logger.error("Health check RMQ failed: connection closed or missing")
-            rmq_status = "error"
-    except Exception as e:
-        logger.error(f"Health check RMQ failed: {e}")
-        rmq_status = "error"
-
-    overall_status = "ok" if db_status == "ok" and rmq_status == "ok" else "error"
-    status_code = 200 if overall_status == "ok" else 500
-
-    return JSONResponse(
-        status_code=status_code,
-        content={"status": overall_status,
-                 "components": {
-                     "database": db_status,
-                     "rabbitmq": rmq_status
-                 },
-                 "message": "Service is running" if overall_status == "ok" else "One or more checks failed"}
-    )
-
 @api.post("/internal/receive-notifications")
-def receive_notifications(
+async def receive_notifications(
     notification_data: Notification,
     request: Request,
     db = Depends(get_db),
     program_name: str = Depends(verify_api_key_dependency_internal)
 ):
-    rmq_connection = get_rmq_connection(request.app)
     logger.info(f"Received notifcation data from {program_name} for RMQ")
-    send_message_to_rmq(rmq_connection,notification_data.receipent_user_id,notification_data.message)
-    logger.info("Successfully delivered message to RMQ")
+    await request.app.state.rmq_producer.publish(
+        notification_data.receipent_user_id,
+        notification_data.message
+    )
+    logger.info("Successfully queued for delivery to RMQ")
     return {"status": "queued"}
 
 async def start_servers():
-    config_main = uvicorn.Config("main:api", host="0.0.0.0", port=8101, log_config=LOGGING_CONFIG, log_level="info")
+    config_main = uvicorn.Config("main:api", host="0.0.0.0", port=8101, log_level="info")
     config_metrics = uvicorn.Config("metrics_server:metrics_api", host="0.0.0.0", port=9000, log_level="info")
     server_main = uvicorn.Server(config_main)
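With the producer wired into app.state, the endpoint can be exercised roughly like this (a sketch using httpx purely for illustration; the API key value is a placeholder, and the Notification fields are assumed from how the handler above uses them):

import httpx

resp = httpx.post(
    "http://localhost:8101/internal/receive-notifications",
    headers={"X-API-Key-Internal": "example-key"},  # placeholder key, not a real secret
    json={
        "receipent_user_id": 1,  # field name as spelled in the model
        "message": {"type": "notification", "content": "Test message"},
    },
)
print(resp.status_code, resp.json())  # expect 200 and {"status": "queued"}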

View File

@@ -1,72 +1,86 @@
-import pika
-from typing import Dict
+import asyncio
+import aio_pika
 from secret_handler import return_credentials
-import json
-import time
-import sys
 import os
+from logger_handler import setup_logger
+import json
+
+logger = setup_logger(__name__)
 
 rmq_username = return_credentials("/etc/secrets/rmq_username")
 rmq_password = return_credentials("/etc/secrets/rmq_password")
-rmq_host = os.getenv("BACKEND_API_INTERNAL_RMQ_HOST","localhost")
-rmq_vhost = os.getenv("BACKEND_API_INTERNAL_RMQ_VHOST","app_notifications")
-rmq_exchange = os.getenv("BACKEND_API_INTERNAL_RMQ_EXCHANGE","app_notifications")
+rmq_host = os.getenv("BACKEND_API_INTERNAL_RMQ_HOST", "localhost")
+rmq_vhost = os.getenv("BACKEND_API_INTERNAL_RMQ_VHOST", "app_notifications")
+rmq_exchange = os.getenv("BACKEND_API_INTERNAL_RMQ_EXCHANGE", "app_notifications")
 
-MAX_RETRIES = 5
-RETRY_DELAY = 5
+RABBITMQ_URL = f"amqp://{rmq_username}:{rmq_password}@{rmq_host}/{rmq_vhost}"
 
-def create_connection():
-    credentials = pika.PlainCredentials(username=rmq_username, password=rmq_password)
-    conn_params = pika.ConnectionParameters(
-        host=rmq_host,
-        port=5672,
-        credentials=credentials,
-        virtual_host=rmq_vhost
-    )
-
-    for attempt in range(1, MAX_RETRIES + 1):
-        try:
-            connection = pika.BlockingConnection(conn_params)
-            print("[RMQ] Connection established.")
-            return connection
-        except Exception as e:
-            print(f"[RMQ] Attempt {attempt} failed: {e}")
-            if attempt < MAX_RETRIES:
-                time.sleep(RETRY_DELAY)
-            else:
-                print("[RMQ] Failed to connect after maximum retries — exiting.")
-                sys.exit(1)
-
-def send_message_to_rmq(connection, user_id: int, message: Dict):
-    if not connection or not connection.is_open:
-        print("[RMQ] Connection lost, reconnecting...")
-        connection = create_connection()
-
-    channel = connection.channel()
-    channel.exchange_declare(exchange=rmq_exchange, exchange_type="topic", durable=True)
-    channel.confirm_delivery()
-    channel.basic_publish(
-        exchange=rmq_exchange,
-        routing_key=f"notify.user.{user_id}",
-        body=json.dumps(message),
-        properties=pika.BasicProperties(
-            content_type="application/json",
-            delivery_mode=2
-        ),
-        mandatory=True
-    )
-
-def close_connection(connection):
-    if connection and connection.is_open:
-        connection.close()
-        print("[RMQ] Connection closed.")
-
-if __name__ == "__main__":
-    conn = create_connection()
-    send_message_to_rmq(conn, 1, {"type": "notification", "content": "Vault TLS cert reloaded successfully."})
-    close_connection(conn)
+class RabbitMQProducer:
+    def __init__(self, url=RABBITMQ_URL, exchange_name=rmq_exchange):
+        self.url = url
+        self.exchange_name = exchange_name
+        self.connection: aio_pika.RobustConnection | None = None
+        self.channel: aio_pika.RobustChannel | None = None
+        self.exchange: aio_pika.Exchange | None = None
+        self._queue: asyncio.Queue[tuple[int, dict]] = asyncio.Queue()
+        self._flush_task: asyncio.Task | None = None
+        self._closing = False
+        self._ready = asyncio.Event()
+
+    async def connect(self):
+        self.connection = await aio_pika.connect_robust(self.url)
+        self.channel = await self.connection.channel(publisher_confirms=True)
+        self.exchange = await self.channel.declare_exchange(
+            self.exchange_name, aio_pika.ExchangeType.TOPIC, durable=True
+        )
+        logger.info(f"[aio-pika] Connected and exchange '{self.exchange_name}' ready.")
+        self._ready.set()
+        self._flush_task = asyncio.create_task(self._flush_queue_loop())
+
+    async def publish(self, routing_key: int, message: dict):
+        await self._queue.put((routing_key, message))
+        logger.debug(f"[RabbitMQ] Queued message for {routing_key}")
+
+    async def _flush_queue_loop(self):
+        await self._ready.wait()
+        logger.debug(f"here")
+        while True:
+            routing_key, message = await self._queue.get()
+            try:
+                await self.exchange.publish(
+                    aio_pika.Message(
+                        body=json.dumps(message).encode(),
+                        content_type="application/json",
+                        delivery_mode=aio_pika.DeliveryMode.PERSISTENT
+                    ),
+                    routing_key=f"notify.user.{routing_key}",
+                    mandatory=True
+                )
+                logger.debug(f"[aio-pika] Published message to notify.user.{routing_key}")
+            except Exception as e:
+                logger.warning(f"[aio-pika] Publish failed, requeuing: {e}")
+                await asyncio.sleep(2)
+                await self._queue.put((routing_key, message))
+            finally:
+                self._queue.task_done()
+
+    async def close(self):
+        self._closing = True
+        if self._flush_task and not self._flush_task.done():
+            self._flush_task.cancel()
+            try:
+                await self._flush_task
+            except asyncio.CancelledError:
+                pass
+        if self.connection:
+            await self.connection.close()
+            logger.info("[aio-pika] Connection closed.")
+
+async def main():
+    producer = RabbitMQProducer()
+    await producer.connect()
+    await producer.publish("1", {"type": "notification", "content": "Test message"})
+    await producer.close()
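The new producer decouples callers from broker availability: publish() only puts the message on an in-process asyncio.Queue, and _flush_queue_loop() drains it after connect(), re-queueing on failure. The module's own smoke test can be run standalone roughly like this (a sketch; the __main__ guard is not visible in the hunk above, and the module name is assumed from the import in main.py):

import asyncio
from rabbitmq_handler import main

# Requires the secrets under /etc/secrets and the BACKEND_API_INTERNAL_RMQ_* env vars.
asyncio.run(main())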

View File

@@ -12,9 +12,9 @@ def verify_api_key(api_key: str, hashed: str) -> bool:
         return False
 
 if __name__=="__main__":
-    plain_key = "super-secret-api-key"
-    #hashed_key = hash_api_key(plain_key)
-    hashed_key = '$argon2id$v=19$m=65536,t=3,p=4$vqU+MRafVW1b8AtF+zHb0w$p1J4Gyb0jhlVtKgYyjTITxfU97YaayeS3s3qFFP5sVM'
+    plain_key = "password"
+    hashed_key = hash_api_key(plain_key)
+    #hashed_key = '$argon2id$v=19$m=65536,t=3,p=4$vqU+MRafVW1b8AtF+zHb0w$p1J4Gyb0jhlVtKgYyjTITxfU97YaayeS3s3qFFP5sVM'
     print("Hashed API Key:", hashed_key)
     print("Verification:", verify_api_key(plain_key, hashed_key))