Refactor DB and RMQ handlers: add connection pool lifecycle, remove TLS for RMQ, switch to lifespan events in FastAPI, ensure safe startup/shutdown.
This commit is contained in:
parent
a211bd3156
commit
482a90ae7b
17
src/db.py
17
src/db.py
@ -5,7 +5,6 @@ import os
|
||||
import time
|
||||
import sys
|
||||
|
||||
|
||||
db_username = return_credentials("/etc/secrets/db_username")
|
||||
db_password = return_credentials("/etc/secrets/db_password")
|
||||
db_host = os.getenv("BACKEND_API_INTERNAL_DB_HOST","localhost")
|
||||
@ -24,17 +23,20 @@ MYSQL_CONFIG = {
|
||||
_pool_lock = threading.Lock()
|
||||
_connection_pool = None
|
||||
|
||||
|
||||
def create_connection_pool():
|
||||
global _connection_pool
|
||||
for attempt in range(1, MAX_RETRIES+1):
|
||||
try:
|
||||
print(f"[MySQL] Attempt {attempt} to connect...")
|
||||
_connection_pool = mysql.connector.pooling.MySQLConnectionPool(
|
||||
pool = mysql.connector.pooling.MySQLConnectionPool(
|
||||
pool_name="mypool",
|
||||
pool_size=5,
|
||||
pool_reset_session=True,
|
||||
**MYSQL_CONFIG
|
||||
)
|
||||
with _pool_lock:
|
||||
_connection_pool = pool
|
||||
print("[MySQL] Connection pool created successfully.")
|
||||
return
|
||||
except mysql.connector.Error as e:
|
||||
@ -44,6 +46,15 @@ def create_connection_pool():
|
||||
print(f"[MySQL] Failed to connect after {MAX_RETRIES} attempts — exiting.")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def close_connection_pool():
|
||||
global _connection_pool
|
||||
with _pool_lock:
|
||||
if _connection_pool:
|
||||
_connection_pool = None
|
||||
print("[MySQL] Connection pool closed.")
|
||||
|
||||
|
||||
def get_connection_pool():
|
||||
global _connection_pool
|
||||
with _pool_lock:
|
||||
@ -51,6 +62,7 @@ def get_connection_pool():
|
||||
create_connection_pool()
|
||||
return _connection_pool
|
||||
|
||||
|
||||
def get_db():
|
||||
pool = get_connection_pool()
|
||||
conn = pool.get_connection()
|
||||
@ -58,4 +70,3 @@ def get_db():
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
32
src/main.py
32
src/main.py
@ -1,15 +1,16 @@
|
||||
from fastapi import FastAPI, Depends, HTTPException
|
||||
from fastapi import FastAPI, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.security.api_key import APIKeyHeader
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from typing import Dict
|
||||
from pydantic import BaseModel
|
||||
from validator import verify_api_key
|
||||
from db import get_db
|
||||
from db import get_db, create_connection_pool, close_connection_pool
|
||||
from logger_handler import setup_logger
|
||||
from rabbitmq_handler import send_message_to_rmq
|
||||
from rabbitmq_handler import send_message_to_rmq, create_connection, close_connection
|
||||
import uvicorn
|
||||
from uvicorn_logging_config import LOGGING_CONFIG
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
@ -19,10 +20,29 @@ class Notification(BaseModel):
|
||||
receipent_user_id : int
|
||||
message : Dict
|
||||
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logger.info("Starting application...")
|
||||
|
||||
logger.info("Creating MySQL connection pool...")
|
||||
create_connection_pool()
|
||||
logger.info("Connecting to RabbitMQ...")
|
||||
app.state.rmq_connection = create_connection()
|
||||
|
||||
yield
|
||||
logger.info("Closing RabbitMQ connection...")
|
||||
close_connection(app.state.rmq_connection)
|
||||
logger.info("Closing MySQL connection pool...")
|
||||
close_connection_pool()
|
||||
|
||||
|
||||
api = FastAPI(
|
||||
title="Internal Notifier API",
|
||||
description="API to forward messages to RabbitMQ",
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
def verify_api_key_dependency_internal(db=Depends(get_db), api_key: str = Depends(api_key_header_internal)) -> str:
|
||||
@ -49,15 +69,15 @@ async def custom_http_exception_handler(request,exc):
|
||||
@api.post("/internal/receive-notifications")
|
||||
def receive_notifications(
|
||||
notification_data: Notification,
|
||||
request: Request,
|
||||
db = Depends(get_db),
|
||||
program_name: str = Depends(verify_api_key_dependency_internal)
|
||||
):
|
||||
logger.info(f"Received notifcation data from {program_name} for RMQ")
|
||||
send_message_to_rmq(notification_data.user_id,notification_data.message)
|
||||
send_message_to_rmq(request.app.state.rmq_connection,notification_data.receipent_user_id,notification_data.message)
|
||||
logger.info("Successfully delivered message to RMQ")
|
||||
return {"status": "queued"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"main:api",
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import pika
|
||||
from typing import Dict
|
||||
from secret_handler import return_credentials
|
||||
import ssl
|
||||
import json
|
||||
import time
|
||||
import sys
|
||||
@ -16,15 +15,12 @@ rmq_exchange = os.getenv("BACKEND_API_INTERNAL_RMQ_EXCHANGE","app_notifications"
|
||||
MAX_RETRIES = 5
|
||||
RETRY_DELAY = 5
|
||||
|
||||
def send_message_to_rmq(user_id: int, message: Dict):
|
||||
|
||||
def create_connection():
|
||||
credentials = pika.PlainCredentials(username=rmq_username, password=rmq_password)
|
||||
context = ssl.create_default_context()
|
||||
context.check_hostname = False
|
||||
ssl_options = pika.SSLOptions(context)
|
||||
conn_params = pika.ConnectionParameters(
|
||||
host=rmq_host,
|
||||
port=5671,
|
||||
ssl_options=ssl_options,
|
||||
port=5672,
|
||||
credentials=credentials,
|
||||
virtual_host=rmq_vhost
|
||||
)
|
||||
@ -32,22 +28,8 @@ def send_message_to_rmq(user_id: int, message: Dict):
|
||||
for attempt in range(1, MAX_RETRIES + 1):
|
||||
try:
|
||||
connection = pika.BlockingConnection(conn_params)
|
||||
channel = connection.channel()
|
||||
channel.exchange_declare(exchange=rmq_exchange, exchange_type="topic", durable=True)
|
||||
channel.confirm_delivery()
|
||||
channel.basic_publish(
|
||||
exchange=rmq_exchange,
|
||||
routing_key=f"notify.user.{user_id}",
|
||||
body=json.dumps(message),
|
||||
properties=pika.BasicProperties(
|
||||
content_type="application/json",
|
||||
delivery_mode=2
|
||||
),
|
||||
mandatory=True
|
||||
)
|
||||
connection.close()
|
||||
return
|
||||
|
||||
print("[RMQ] Connection established.")
|
||||
return connection
|
||||
except Exception as e:
|
||||
print(f"[RMQ] Attempt {attempt} failed: {e}")
|
||||
if attempt < MAX_RETRIES:
|
||||
@ -56,5 +38,34 @@ def send_message_to_rmq(user_id: int, message: Dict):
|
||||
print("[RMQ] Failed to connect after maximum retries — exiting.")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def send_message_to_rmq(connection, user_id: int, message: Dict):
|
||||
if not connection or not connection.is_open:
|
||||
raise RuntimeError("RabbitMQ connection is not open")
|
||||
|
||||
channel = connection.channel()
|
||||
channel.exchange_declare(exchange=rmq_exchange, exchange_type="topic", durable=True)
|
||||
channel.confirm_delivery()
|
||||
|
||||
channel.basic_publish(
|
||||
exchange=rmq_exchange,
|
||||
routing_key=f"notify.user.{user_id}",
|
||||
body=json.dumps(message),
|
||||
properties=pika.BasicProperties(
|
||||
content_type="application/json",
|
||||
delivery_mode=2
|
||||
),
|
||||
mandatory=True
|
||||
)
|
||||
|
||||
|
||||
def close_connection(connection):
|
||||
if connection and connection.is_open:
|
||||
connection.close()
|
||||
print("[RMQ] Connection closed.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
send_message_to_rmq(1, {"type": "notification", "content": "Vault TLS cert reloaded successfully."})
|
||||
conn = create_connection()
|
||||
send_message_to_rmq(conn, 1, {"type": "notification", "content": "Vault TLS cert reloaded successfully."})
|
||||
close_connection(conn)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user