Refactor DB and RMQ handlers: add connection pool lifecycle, remove TLS for RMQ, switch to lifespan events in FastAPI, ensure safe startup/shutdown.

This commit is contained in:
florian 2025-10-08 17:41:03 +02:00
parent a211bd3156
commit 482a90ae7b
3 changed files with 75 additions and 33 deletions

View File

@ -5,7 +5,6 @@ import os
import time import time
import sys import sys
db_username = return_credentials("/etc/secrets/db_username") db_username = return_credentials("/etc/secrets/db_username")
db_password = return_credentials("/etc/secrets/db_password") db_password = return_credentials("/etc/secrets/db_password")
db_host = os.getenv("BACKEND_API_INTERNAL_DB_HOST","localhost") db_host = os.getenv("BACKEND_API_INTERNAL_DB_HOST","localhost")
@ -24,17 +23,20 @@ MYSQL_CONFIG = {
_pool_lock = threading.Lock() _pool_lock = threading.Lock()
_connection_pool = None _connection_pool = None
def create_connection_pool(): def create_connection_pool():
global _connection_pool global _connection_pool
for attempt in range(1, MAX_RETRIES+1): for attempt in range(1, MAX_RETRIES+1):
try: try:
print(f"[MySQL] Attempt {attempt} to connect...") print(f"[MySQL] Attempt {attempt} to connect...")
_connection_pool = mysql.connector.pooling.MySQLConnectionPool( pool = mysql.connector.pooling.MySQLConnectionPool(
pool_name="mypool", pool_name="mypool",
pool_size=5, pool_size=5,
pool_reset_session=True, pool_reset_session=True,
**MYSQL_CONFIG **MYSQL_CONFIG
) )
with _pool_lock:
_connection_pool = pool
print("[MySQL] Connection pool created successfully.") print("[MySQL] Connection pool created successfully.")
return return
except mysql.connector.Error as e: except mysql.connector.Error as e:
@ -44,6 +46,15 @@ def create_connection_pool():
print(f"[MySQL] Failed to connect after {MAX_RETRIES} attempts — exiting.") print(f"[MySQL] Failed to connect after {MAX_RETRIES} attempts — exiting.")
sys.exit(1) sys.exit(1)
def close_connection_pool():
global _connection_pool
with _pool_lock:
if _connection_pool:
_connection_pool = None
print("[MySQL] Connection pool closed.")
def get_connection_pool(): def get_connection_pool():
global _connection_pool global _connection_pool
with _pool_lock: with _pool_lock:
@ -51,6 +62,7 @@ def get_connection_pool():
create_connection_pool() create_connection_pool()
return _connection_pool return _connection_pool
def get_db(): def get_db():
pool = get_connection_pool() pool = get_connection_pool()
conn = pool.get_connection() conn = pool.get_connection()
@ -58,4 +70,3 @@ def get_db():
yield conn yield conn
finally: finally:
conn.close() conn.close()

View File

@ -1,15 +1,16 @@
from fastapi import FastAPI, Depends, HTTPException from fastapi import FastAPI, Depends, HTTPException, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi.security.api_key import APIKeyHeader from fastapi.security.api_key import APIKeyHeader
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
from typing import Dict from typing import Dict
from pydantic import BaseModel from pydantic import BaseModel
from validator import verify_api_key from validator import verify_api_key
from db import get_db from db import get_db, create_connection_pool, close_connection_pool
from logger_handler import setup_logger from logger_handler import setup_logger
from rabbitmq_handler import send_message_to_rmq from rabbitmq_handler import send_message_to_rmq, create_connection, close_connection
import uvicorn import uvicorn
from uvicorn_logging_config import LOGGING_CONFIG from uvicorn_logging_config import LOGGING_CONFIG
from contextlib import asynccontextmanager
logger = setup_logger(__name__) logger = setup_logger(__name__)
@ -19,10 +20,29 @@ class Notification(BaseModel):
receipent_user_id : int receipent_user_id : int
message : Dict message : Dict
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Starting application...")
logger.info("Creating MySQL connection pool...")
create_connection_pool()
logger.info("Connecting to RabbitMQ...")
app.state.rmq_connection = create_connection()
yield
logger.info("Closing RabbitMQ connection...")
close_connection(app.state.rmq_connection)
logger.info("Closing MySQL connection pool...")
close_connection_pool()
api = FastAPI( api = FastAPI(
title="Internal Notifier API", title="Internal Notifier API",
description="API to forward messages to RabbitMQ", description="API to forward messages to RabbitMQ",
version="1.0.0" version="1.0.0",
lifespan=lifespan
) )
def verify_api_key_dependency_internal(db=Depends(get_db), api_key: str = Depends(api_key_header_internal)) -> str: def verify_api_key_dependency_internal(db=Depends(get_db), api_key: str = Depends(api_key_header_internal)) -> str:
@ -49,15 +69,15 @@ async def custom_http_exception_handler(request,exc):
@api.post("/internal/receive-notifications") @api.post("/internal/receive-notifications")
def receive_notifications( def receive_notifications(
notification_data: Notification, notification_data: Notification,
request: Request,
db = Depends(get_db), db = Depends(get_db),
program_name: str = Depends(verify_api_key_dependency_internal) program_name: str = Depends(verify_api_key_dependency_internal)
): ):
logger.info(f"Received notifcation data from {program_name} for RMQ") logger.info(f"Received notifcation data from {program_name} for RMQ")
send_message_to_rmq(notification_data.user_id,notification_data.message) send_message_to_rmq(request.app.state.rmq_connection,notification_data.receipent_user_id,notification_data.message)
logger.info("Successfully delivered message to RMQ") logger.info("Successfully delivered message to RMQ")
return {"status": "queued"} return {"status": "queued"}
if __name__ == "__main__": if __name__ == "__main__":
uvicorn.run( uvicorn.run(
"main:api", "main:api",

View File

@ -1,7 +1,6 @@
import pika import pika
from typing import Dict from typing import Dict
from secret_handler import return_credentials from secret_handler import return_credentials
import ssl
import json import json
import time import time
import sys import sys
@ -16,15 +15,12 @@ rmq_exchange = os.getenv("BACKEND_API_INTERNAL_RMQ_EXCHANGE","app_notifications"
MAX_RETRIES = 5 MAX_RETRIES = 5
RETRY_DELAY = 5 RETRY_DELAY = 5
def send_message_to_rmq(user_id: int, message: Dict):
def create_connection():
credentials = pika.PlainCredentials(username=rmq_username, password=rmq_password) credentials = pika.PlainCredentials(username=rmq_username, password=rmq_password)
context = ssl.create_default_context()
context.check_hostname = False
ssl_options = pika.SSLOptions(context)
conn_params = pika.ConnectionParameters( conn_params = pika.ConnectionParameters(
host=rmq_host, host=rmq_host,
port=5671, port=5672,
ssl_options=ssl_options,
credentials=credentials, credentials=credentials,
virtual_host=rmq_vhost virtual_host=rmq_vhost
) )
@ -32,9 +28,25 @@ def send_message_to_rmq(user_id: int, message: Dict):
for attempt in range(1, MAX_RETRIES + 1): for attempt in range(1, MAX_RETRIES + 1):
try: try:
connection = pika.BlockingConnection(conn_params) connection = pika.BlockingConnection(conn_params)
print("[RMQ] Connection established.")
return connection
except Exception as e:
print(f"[RMQ] Attempt {attempt} failed: {e}")
if attempt < MAX_RETRIES:
time.sleep(RETRY_DELAY)
else:
print("[RMQ] Failed to connect after maximum retries — exiting.")
sys.exit(1)
def send_message_to_rmq(connection, user_id: int, message: Dict):
if not connection or not connection.is_open:
raise RuntimeError("RabbitMQ connection is not open")
channel = connection.channel() channel = connection.channel()
channel.exchange_declare(exchange=rmq_exchange, exchange_type="topic", durable=True) channel.exchange_declare(exchange=rmq_exchange, exchange_type="topic", durable=True)
channel.confirm_delivery() channel.confirm_delivery()
channel.basic_publish( channel.basic_publish(
exchange=rmq_exchange, exchange=rmq_exchange,
routing_key=f"notify.user.{user_id}", routing_key=f"notify.user.{user_id}",
@ -45,16 +57,15 @@ def send_message_to_rmq(user_id: int, message: Dict):
), ),
mandatory=True mandatory=True
) )
connection.close()
return
except Exception as e:
print(f"[RMQ] Attempt {attempt} failed: {e}") def close_connection(connection):
if attempt < MAX_RETRIES: if connection and connection.is_open:
time.sleep(RETRY_DELAY) connection.close()
else: print("[RMQ] Connection closed.")
print("[RMQ] Failed to connect after maximum retries — exiting.")
sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":
send_message_to_rmq(1, {"type": "notification", "content": "Vault TLS cert reloaded successfully."}) conn = create_connection()
send_message_to_rmq(conn, 1, {"type": "notification", "content": "Vault TLS cert reloaded successfully."})
close_connection(conn)