diff --git a/src/db.py b/src/db.py index 42940c7..336001a 100644 --- a/src/db.py +++ b/src/db.py @@ -5,7 +5,6 @@ import os import time import sys - db_username = return_credentials("/etc/secrets/db_username") db_password = return_credentials("/etc/secrets/db_password") db_host = os.getenv("BACKEND_API_INTERNAL_DB_HOST","localhost") @@ -24,17 +23,20 @@ MYSQL_CONFIG = { _pool_lock = threading.Lock() _connection_pool = None + def create_connection_pool(): global _connection_pool for attempt in range(1, MAX_RETRIES+1): try: print(f"[MySQL] Attempt {attempt} to connect...") - _connection_pool = mysql.connector.pooling.MySQLConnectionPool( + pool = mysql.connector.pooling.MySQLConnectionPool( pool_name="mypool", pool_size=5, pool_reset_session=True, **MYSQL_CONFIG ) + with _pool_lock: + _connection_pool = pool print("[MySQL] Connection pool created successfully.") return except mysql.connector.Error as e: @@ -44,6 +46,15 @@ def create_connection_pool(): print(f"[MySQL] Failed to connect after {MAX_RETRIES} attempts — exiting.") sys.exit(1) + +def close_connection_pool(): + global _connection_pool + with _pool_lock: + if _connection_pool: + _connection_pool = None + print("[MySQL] Connection pool closed.") + + def get_connection_pool(): global _connection_pool with _pool_lock: @@ -51,6 +62,7 @@ def get_connection_pool(): create_connection_pool() return _connection_pool + def get_db(): pool = get_connection_pool() conn = pool.get_connection() @@ -58,4 +70,3 @@ def get_db(): yield conn finally: conn.close() - diff --git a/src/main.py b/src/main.py index d36c4f5..27a8767 100644 --- a/src/main.py +++ b/src/main.py @@ -1,15 +1,16 @@ -from fastapi import FastAPI, Depends, HTTPException +from fastapi import FastAPI, Depends, HTTPException, Request from fastapi.responses import JSONResponse from fastapi.security.api_key import APIKeyHeader from starlette.exceptions import HTTPException as StarletteHTTPException from typing import Dict from pydantic import BaseModel from validator import verify_api_key -from db import get_db +from db import get_db, create_connection_pool, close_connection_pool from logger_handler import setup_logger -from rabbitmq_handler import send_message_to_rmq +from rabbitmq_handler import send_message_to_rmq, create_connection, close_connection import uvicorn from uvicorn_logging_config import LOGGING_CONFIG +from contextlib import asynccontextmanager logger = setup_logger(__name__) @@ -19,10 +20,29 @@ class Notification(BaseModel): receipent_user_id : int message : Dict + + +@asynccontextmanager +async def lifespan(app: FastAPI): + logger.info("Starting application...") + + logger.info("Creating MySQL connection pool...") + create_connection_pool() + logger.info("Connecting to RabbitMQ...") + app.state.rmq_connection = create_connection() + + yield + logger.info("Closing RabbitMQ connection...") + close_connection(app.state.rmq_connection) + logger.info("Closing MySQL connection pool...") + close_connection_pool() + + api = FastAPI( title="Internal Notifier API", description="API to forward messages to RabbitMQ", - version="1.0.0" + version="1.0.0", + lifespan=lifespan ) def verify_api_key_dependency_internal(db=Depends(get_db), api_key: str = Depends(api_key_header_internal)) -> str: @@ -49,15 +69,15 @@ async def custom_http_exception_handler(request,exc): @api.post("/internal/receive-notifications") def receive_notifications( notification_data: Notification, + request: Request, db = Depends(get_db), program_name: str = Depends(verify_api_key_dependency_internal) ): logger.info(f"Received notifcation data from {program_name} for RMQ") - send_message_to_rmq(notification_data.user_id,notification_data.message) + send_message_to_rmq(request.app.state.rmq_connection,notification_data.receipent_user_id,notification_data.message) logger.info("Successfully delivered message to RMQ") return {"status": "queued"} - if __name__ == "__main__": uvicorn.run( "main:api", diff --git a/src/rabbitmq_handler.py b/src/rabbitmq_handler.py index 74beca1..c6ee28a 100644 --- a/src/rabbitmq_handler.py +++ b/src/rabbitmq_handler.py @@ -1,7 +1,6 @@ import pika from typing import Dict from secret_handler import return_credentials -import ssl import json import time import sys @@ -16,15 +15,12 @@ rmq_exchange = os.getenv("BACKEND_API_INTERNAL_RMQ_EXCHANGE","app_notifications" MAX_RETRIES = 5 RETRY_DELAY = 5 -def send_message_to_rmq(user_id: int, message: Dict): + +def create_connection(): credentials = pika.PlainCredentials(username=rmq_username, password=rmq_password) - context = ssl.create_default_context() - context.check_hostname = False - ssl_options = pika.SSLOptions(context) conn_params = pika.ConnectionParameters( host=rmq_host, - port=5671, - ssl_options=ssl_options, + port=5672, credentials=credentials, virtual_host=rmq_vhost ) @@ -32,22 +28,8 @@ def send_message_to_rmq(user_id: int, message: Dict): for attempt in range(1, MAX_RETRIES + 1): try: connection = pika.BlockingConnection(conn_params) - channel = connection.channel() - channel.exchange_declare(exchange=rmq_exchange, exchange_type="topic", durable=True) - channel.confirm_delivery() - channel.basic_publish( - exchange=rmq_exchange, - routing_key=f"notify.user.{user_id}", - body=json.dumps(message), - properties=pika.BasicProperties( - content_type="application/json", - delivery_mode=2 - ), - mandatory=True - ) - connection.close() - return - + print("[RMQ] Connection established.") + return connection except Exception as e: print(f"[RMQ] Attempt {attempt} failed: {e}") if attempt < MAX_RETRIES: @@ -56,5 +38,34 @@ def send_message_to_rmq(user_id: int, message: Dict): print("[RMQ] Failed to connect after maximum retries — exiting.") sys.exit(1) + +def send_message_to_rmq(connection, user_id: int, message: Dict): + if not connection or not connection.is_open: + raise RuntimeError("RabbitMQ connection is not open") + + channel = connection.channel() + channel.exchange_declare(exchange=rmq_exchange, exchange_type="topic", durable=True) + channel.confirm_delivery() + + channel.basic_publish( + exchange=rmq_exchange, + routing_key=f"notify.user.{user_id}", + body=json.dumps(message), + properties=pika.BasicProperties( + content_type="application/json", + delivery_mode=2 + ), + mandatory=True + ) + + +def close_connection(connection): + if connection and connection.is_open: + connection.close() + print("[RMQ] Connection closed.") + + if __name__ == "__main__": - send_message_to_rmq(1, {"type": "notification", "content": "Vault TLS cert reloaded successfully."}) \ No newline at end of file + conn = create_connection() + send_message_to_rmq(conn, 1, {"type": "notification", "content": "Vault TLS cert reloaded successfully."}) + close_connection(conn)