Refactor DB and RMQ handlers: add connection pool lifecycle, remove TLS for RMQ, switch to lifespan events in FastAPI, ensure safe startup/shutdown.

This commit is contained in:
florian 2025-10-08 17:41:03 +02:00
parent a211bd3156
commit 482a90ae7b
3 changed files with 75 additions and 33 deletions

View File

@ -5,7 +5,6 @@ import os
import time
import sys
db_username = return_credentials("/etc/secrets/db_username")
db_password = return_credentials("/etc/secrets/db_password")
db_host = os.getenv("BACKEND_API_INTERNAL_DB_HOST","localhost")
@ -24,17 +23,20 @@ MYSQL_CONFIG = {
_pool_lock = threading.Lock()
_connection_pool = None
def create_connection_pool():
global _connection_pool
for attempt in range(1, MAX_RETRIES+1):
try:
print(f"[MySQL] Attempt {attempt} to connect...")
_connection_pool = mysql.connector.pooling.MySQLConnectionPool(
pool = mysql.connector.pooling.MySQLConnectionPool(
pool_name="mypool",
pool_size=5,
pool_reset_session=True,
**MYSQL_CONFIG
)
with _pool_lock:
_connection_pool = pool
print("[MySQL] Connection pool created successfully.")
return
except mysql.connector.Error as e:
@ -44,6 +46,15 @@ def create_connection_pool():
print(f"[MySQL] Failed to connect after {MAX_RETRIES} attempts — exiting.")
sys.exit(1)
def close_connection_pool():
global _connection_pool
with _pool_lock:
if _connection_pool:
_connection_pool = None
print("[MySQL] Connection pool closed.")
def get_connection_pool():
global _connection_pool
with _pool_lock:
@ -51,6 +62,7 @@ def get_connection_pool():
create_connection_pool()
return _connection_pool
def get_db():
pool = get_connection_pool()
conn = pool.get_connection()
@ -58,4 +70,3 @@ def get_db():
yield conn
finally:
conn.close()

View File

@ -1,15 +1,16 @@
from fastapi import FastAPI, Depends, HTTPException
from fastapi import FastAPI, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.security.api_key import APIKeyHeader
from starlette.exceptions import HTTPException as StarletteHTTPException
from typing import Dict
from pydantic import BaseModel
from validator import verify_api_key
from db import get_db
from db import get_db, create_connection_pool, close_connection_pool
from logger_handler import setup_logger
from rabbitmq_handler import send_message_to_rmq
from rabbitmq_handler import send_message_to_rmq, create_connection, close_connection
import uvicorn
from uvicorn_logging_config import LOGGING_CONFIG
from contextlib import asynccontextmanager
logger = setup_logger(__name__)
@ -19,10 +20,29 @@ class Notification(BaseModel):
receipent_user_id : int
message : Dict
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Starting application...")
logger.info("Creating MySQL connection pool...")
create_connection_pool()
logger.info("Connecting to RabbitMQ...")
app.state.rmq_connection = create_connection()
yield
logger.info("Closing RabbitMQ connection...")
close_connection(app.state.rmq_connection)
logger.info("Closing MySQL connection pool...")
close_connection_pool()
api = FastAPI(
title="Internal Notifier API",
description="API to forward messages to RabbitMQ",
version="1.0.0"
version="1.0.0",
lifespan=lifespan
)
def verify_api_key_dependency_internal(db=Depends(get_db), api_key: str = Depends(api_key_header_internal)) -> str:
@ -49,15 +69,15 @@ async def custom_http_exception_handler(request,exc):
@api.post("/internal/receive-notifications")
def receive_notifications(
notification_data: Notification,
request: Request,
db = Depends(get_db),
program_name: str = Depends(verify_api_key_dependency_internal)
):
logger.info(f"Received notifcation data from {program_name} for RMQ")
send_message_to_rmq(notification_data.user_id,notification_data.message)
send_message_to_rmq(request.app.state.rmq_connection,notification_data.receipent_user_id,notification_data.message)
logger.info("Successfully delivered message to RMQ")
return {"status": "queued"}
if __name__ == "__main__":
uvicorn.run(
"main:api",

View File

@ -1,7 +1,6 @@
import pika
from typing import Dict
from secret_handler import return_credentials
import ssl
import json
import time
import sys
@ -16,15 +15,12 @@ rmq_exchange = os.getenv("BACKEND_API_INTERNAL_RMQ_EXCHANGE","app_notifications"
MAX_RETRIES = 5
RETRY_DELAY = 5
def send_message_to_rmq(user_id: int, message: Dict):
def create_connection():
credentials = pika.PlainCredentials(username=rmq_username, password=rmq_password)
context = ssl.create_default_context()
context.check_hostname = False
ssl_options = pika.SSLOptions(context)
conn_params = pika.ConnectionParameters(
host=rmq_host,
port=5671,
ssl_options=ssl_options,
port=5672,
credentials=credentials,
virtual_host=rmq_vhost
)
@ -32,22 +28,8 @@ def send_message_to_rmq(user_id: int, message: Dict):
for attempt in range(1, MAX_RETRIES + 1):
try:
connection = pika.BlockingConnection(conn_params)
channel = connection.channel()
channel.exchange_declare(exchange=rmq_exchange, exchange_type="topic", durable=True)
channel.confirm_delivery()
channel.basic_publish(
exchange=rmq_exchange,
routing_key=f"notify.user.{user_id}",
body=json.dumps(message),
properties=pika.BasicProperties(
content_type="application/json",
delivery_mode=2
),
mandatory=True
)
connection.close()
return
print("[RMQ] Connection established.")
return connection
except Exception as e:
print(f"[RMQ] Attempt {attempt} failed: {e}")
if attempt < MAX_RETRIES:
@ -56,5 +38,34 @@ def send_message_to_rmq(user_id: int, message: Dict):
print("[RMQ] Failed to connect after maximum retries — exiting.")
sys.exit(1)
def send_message_to_rmq(connection, user_id: int, message: Dict):
if not connection or not connection.is_open:
raise RuntimeError("RabbitMQ connection is not open")
channel = connection.channel()
channel.exchange_declare(exchange=rmq_exchange, exchange_type="topic", durable=True)
channel.confirm_delivery()
channel.basic_publish(
exchange=rmq_exchange,
routing_key=f"notify.user.{user_id}",
body=json.dumps(message),
properties=pika.BasicProperties(
content_type="application/json",
delivery_mode=2
),
mandatory=True
)
def close_connection(connection):
if connection and connection.is_open:
connection.close()
print("[RMQ] Connection closed.")
if __name__ == "__main__":
send_message_to_rmq(1, {"type": "notification", "content": "Vault TLS cert reloaded successfully."})
conn = create_connection()
send_message_to_rmq(conn, 1, {"type": "notification", "content": "Vault TLS cert reloaded successfully."})
close_connection(conn)