Refactor DB and RMQ handlers: add connection pool lifecycle, remove TLS for RMQ, switch to lifespan events in FastAPI, ensure safe startup/shutdown.
This commit is contained in:
parent
a211bd3156
commit
482a90ae7b
17
src/db.py
17
src/db.py
@ -5,7 +5,6 @@ import os
|
|||||||
import time
|
import time
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
db_username = return_credentials("/etc/secrets/db_username")
|
db_username = return_credentials("/etc/secrets/db_username")
|
||||||
db_password = return_credentials("/etc/secrets/db_password")
|
db_password = return_credentials("/etc/secrets/db_password")
|
||||||
db_host = os.getenv("BACKEND_API_INTERNAL_DB_HOST","localhost")
|
db_host = os.getenv("BACKEND_API_INTERNAL_DB_HOST","localhost")
|
||||||
@ -24,17 +23,20 @@ MYSQL_CONFIG = {
|
|||||||
_pool_lock = threading.Lock()
|
_pool_lock = threading.Lock()
|
||||||
_connection_pool = None
|
_connection_pool = None
|
||||||
|
|
||||||
|
|
||||||
def create_connection_pool():
|
def create_connection_pool():
|
||||||
global _connection_pool
|
global _connection_pool
|
||||||
for attempt in range(1, MAX_RETRIES+1):
|
for attempt in range(1, MAX_RETRIES+1):
|
||||||
try:
|
try:
|
||||||
print(f"[MySQL] Attempt {attempt} to connect...")
|
print(f"[MySQL] Attempt {attempt} to connect...")
|
||||||
_connection_pool = mysql.connector.pooling.MySQLConnectionPool(
|
pool = mysql.connector.pooling.MySQLConnectionPool(
|
||||||
pool_name="mypool",
|
pool_name="mypool",
|
||||||
pool_size=5,
|
pool_size=5,
|
||||||
pool_reset_session=True,
|
pool_reset_session=True,
|
||||||
**MYSQL_CONFIG
|
**MYSQL_CONFIG
|
||||||
)
|
)
|
||||||
|
with _pool_lock:
|
||||||
|
_connection_pool = pool
|
||||||
print("[MySQL] Connection pool created successfully.")
|
print("[MySQL] Connection pool created successfully.")
|
||||||
return
|
return
|
||||||
except mysql.connector.Error as e:
|
except mysql.connector.Error as e:
|
||||||
@ -44,6 +46,15 @@ def create_connection_pool():
|
|||||||
print(f"[MySQL] Failed to connect after {MAX_RETRIES} attempts — exiting.")
|
print(f"[MySQL] Failed to connect after {MAX_RETRIES} attempts — exiting.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def close_connection_pool():
|
||||||
|
global _connection_pool
|
||||||
|
with _pool_lock:
|
||||||
|
if _connection_pool:
|
||||||
|
_connection_pool = None
|
||||||
|
print("[MySQL] Connection pool closed.")
|
||||||
|
|
||||||
|
|
||||||
def get_connection_pool():
|
def get_connection_pool():
|
||||||
global _connection_pool
|
global _connection_pool
|
||||||
with _pool_lock:
|
with _pool_lock:
|
||||||
@ -51,6 +62,7 @@ def get_connection_pool():
|
|||||||
create_connection_pool()
|
create_connection_pool()
|
||||||
return _connection_pool
|
return _connection_pool
|
||||||
|
|
||||||
|
|
||||||
def get_db():
|
def get_db():
|
||||||
pool = get_connection_pool()
|
pool = get_connection_pool()
|
||||||
conn = pool.get_connection()
|
conn = pool.get_connection()
|
||||||
@ -58,4 +70,3 @@ def get_db():
|
|||||||
yield conn
|
yield conn
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
|||||||
32
src/main.py
32
src/main.py
@ -1,15 +1,16 @@
|
|||||||
from fastapi import FastAPI, Depends, HTTPException
|
from fastapi import FastAPI, Depends, HTTPException, Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi.security.api_key import APIKeyHeader
|
from fastapi.security.api_key import APIKeyHeader
|
||||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from validator import verify_api_key
|
from validator import verify_api_key
|
||||||
from db import get_db
|
from db import get_db, create_connection_pool, close_connection_pool
|
||||||
from logger_handler import setup_logger
|
from logger_handler import setup_logger
|
||||||
from rabbitmq_handler import send_message_to_rmq
|
from rabbitmq_handler import send_message_to_rmq, create_connection, close_connection
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from uvicorn_logging_config import LOGGING_CONFIG
|
from uvicorn_logging_config import LOGGING_CONFIG
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
logger = setup_logger(__name__)
|
logger = setup_logger(__name__)
|
||||||
|
|
||||||
@ -19,10 +20,29 @@ class Notification(BaseModel):
|
|||||||
receipent_user_id : int
|
receipent_user_id : int
|
||||||
message : Dict
|
message : Dict
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
logger.info("Starting application...")
|
||||||
|
|
||||||
|
logger.info("Creating MySQL connection pool...")
|
||||||
|
create_connection_pool()
|
||||||
|
logger.info("Connecting to RabbitMQ...")
|
||||||
|
app.state.rmq_connection = create_connection()
|
||||||
|
|
||||||
|
yield
|
||||||
|
logger.info("Closing RabbitMQ connection...")
|
||||||
|
close_connection(app.state.rmq_connection)
|
||||||
|
logger.info("Closing MySQL connection pool...")
|
||||||
|
close_connection_pool()
|
||||||
|
|
||||||
|
|
||||||
api = FastAPI(
|
api = FastAPI(
|
||||||
title="Internal Notifier API",
|
title="Internal Notifier API",
|
||||||
description="API to forward messages to RabbitMQ",
|
description="API to forward messages to RabbitMQ",
|
||||||
version="1.0.0"
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
def verify_api_key_dependency_internal(db=Depends(get_db), api_key: str = Depends(api_key_header_internal)) -> str:
|
def verify_api_key_dependency_internal(db=Depends(get_db), api_key: str = Depends(api_key_header_internal)) -> str:
|
||||||
@ -49,15 +69,15 @@ async def custom_http_exception_handler(request,exc):
|
|||||||
@api.post("/internal/receive-notifications")
|
@api.post("/internal/receive-notifications")
|
||||||
def receive_notifications(
|
def receive_notifications(
|
||||||
notification_data: Notification,
|
notification_data: Notification,
|
||||||
|
request: Request,
|
||||||
db = Depends(get_db),
|
db = Depends(get_db),
|
||||||
program_name: str = Depends(verify_api_key_dependency_internal)
|
program_name: str = Depends(verify_api_key_dependency_internal)
|
||||||
):
|
):
|
||||||
logger.info(f"Received notifcation data from {program_name} for RMQ")
|
logger.info(f"Received notifcation data from {program_name} for RMQ")
|
||||||
send_message_to_rmq(notification_data.user_id,notification_data.message)
|
send_message_to_rmq(request.app.state.rmq_connection,notification_data.receipent_user_id,notification_data.message)
|
||||||
logger.info("Successfully delivered message to RMQ")
|
logger.info("Successfully delivered message to RMQ")
|
||||||
return {"status": "queued"}
|
return {"status": "queued"}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
"main:api",
|
"main:api",
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
import pika
|
import pika
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from secret_handler import return_credentials
|
from secret_handler import return_credentials
|
||||||
import ssl
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import sys
|
import sys
|
||||||
@ -16,15 +15,12 @@ rmq_exchange = os.getenv("BACKEND_API_INTERNAL_RMQ_EXCHANGE","app_notifications"
|
|||||||
MAX_RETRIES = 5
|
MAX_RETRIES = 5
|
||||||
RETRY_DELAY = 5
|
RETRY_DELAY = 5
|
||||||
|
|
||||||
def send_message_to_rmq(user_id: int, message: Dict):
|
|
||||||
|
def create_connection():
|
||||||
credentials = pika.PlainCredentials(username=rmq_username, password=rmq_password)
|
credentials = pika.PlainCredentials(username=rmq_username, password=rmq_password)
|
||||||
context = ssl.create_default_context()
|
|
||||||
context.check_hostname = False
|
|
||||||
ssl_options = pika.SSLOptions(context)
|
|
||||||
conn_params = pika.ConnectionParameters(
|
conn_params = pika.ConnectionParameters(
|
||||||
host=rmq_host,
|
host=rmq_host,
|
||||||
port=5671,
|
port=5672,
|
||||||
ssl_options=ssl_options,
|
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
virtual_host=rmq_vhost
|
virtual_host=rmq_vhost
|
||||||
)
|
)
|
||||||
@ -32,22 +28,8 @@ def send_message_to_rmq(user_id: int, message: Dict):
|
|||||||
for attempt in range(1, MAX_RETRIES + 1):
|
for attempt in range(1, MAX_RETRIES + 1):
|
||||||
try:
|
try:
|
||||||
connection = pika.BlockingConnection(conn_params)
|
connection = pika.BlockingConnection(conn_params)
|
||||||
channel = connection.channel()
|
print("[RMQ] Connection established.")
|
||||||
channel.exchange_declare(exchange=rmq_exchange, exchange_type="topic", durable=True)
|
return connection
|
||||||
channel.confirm_delivery()
|
|
||||||
channel.basic_publish(
|
|
||||||
exchange=rmq_exchange,
|
|
||||||
routing_key=f"notify.user.{user_id}",
|
|
||||||
body=json.dumps(message),
|
|
||||||
properties=pika.BasicProperties(
|
|
||||||
content_type="application/json",
|
|
||||||
delivery_mode=2
|
|
||||||
),
|
|
||||||
mandatory=True
|
|
||||||
)
|
|
||||||
connection.close()
|
|
||||||
return
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[RMQ] Attempt {attempt} failed: {e}")
|
print(f"[RMQ] Attempt {attempt} failed: {e}")
|
||||||
if attempt < MAX_RETRIES:
|
if attempt < MAX_RETRIES:
|
||||||
@ -56,5 +38,34 @@ def send_message_to_rmq(user_id: int, message: Dict):
|
|||||||
print("[RMQ] Failed to connect after maximum retries — exiting.")
|
print("[RMQ] Failed to connect after maximum retries — exiting.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def send_message_to_rmq(connection, user_id: int, message: Dict):
|
||||||
|
if not connection or not connection.is_open:
|
||||||
|
raise RuntimeError("RabbitMQ connection is not open")
|
||||||
|
|
||||||
|
channel = connection.channel()
|
||||||
|
channel.exchange_declare(exchange=rmq_exchange, exchange_type="topic", durable=True)
|
||||||
|
channel.confirm_delivery()
|
||||||
|
|
||||||
|
channel.basic_publish(
|
||||||
|
exchange=rmq_exchange,
|
||||||
|
routing_key=f"notify.user.{user_id}",
|
||||||
|
body=json.dumps(message),
|
||||||
|
properties=pika.BasicProperties(
|
||||||
|
content_type="application/json",
|
||||||
|
delivery_mode=2
|
||||||
|
),
|
||||||
|
mandatory=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def close_connection(connection):
|
||||||
|
if connection and connection.is_open:
|
||||||
|
connection.close()
|
||||||
|
print("[RMQ] Connection closed.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
send_message_to_rmq(1, {"type": "notification", "content": "Vault TLS cert reloaded successfully."})
|
conn = create_connection()
|
||||||
|
send_message_to_rmq(conn, 1, {"type": "notification", "content": "Vault TLS cert reloaded successfully."})
|
||||||
|
close_connection(conn)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user