"""Internal Notifier API: authenticates internal callers and forwards their
notification payloads to RabbitMQ. Also runs a separate metrics server."""
from fastapi import FastAPI, Depends, HTTPException, Request, Response
from fastapi.responses import JSONResponse
from fastapi.security.api_key import APIKeyHeader
from starlette.exceptions import HTTPException as StarletteHTTPException
from typing import Dict
from pydantic import BaseModel
from validator import verify_api_key
from db import get_db, create_connection_pool, close_connection_pool
from logger_handler import setup_logger
from rabbitmq_handler import send_message_to_rmq, create_connection, close_connection
import uvicorn
from uvicorn_logging_config import LOGGING_CONFIG
from contextlib import asynccontextmanager
from metrics_server import REQUEST_COUNTER
import asyncio

logger = setup_logger(__name__)

# Internal callers authenticate with this header; its value is checked
# against the hashed keys stored in the internal_api_keys table.
api_key_header_internal = APIKeyHeader(name="X-API-Key-Internal")


class Notification(BaseModel):
    """Request body accepted by POST /internal/receive-notifications."""

    # NOTE: "receipent" is misspelled, but the field name is part of the
    # public request schema, so it is kept for compatibility with callers.
    receipent_user_id: int
    message: Dict


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the MySQL pool and the RabbitMQ connection for the app's lifetime.

    The MySQL pool is created first and closed last. The nested
    try/finally blocks close the pool even if connecting to RabbitMQ fails,
    closing RabbitMQ fails, or the application exits with an exception.
    """
    logger.info("Starting application...")
    logger.info("Creating MySQL connection pool...")
    create_connection_pool()
    try:
        logger.info("Connecting to RabbitMQ...")
        app.state.rmq_connection = create_connection()
        try:
            yield
        finally:
            logger.info("Closing RabbitMQ connection...")
            # Close whatever connection is current; get_rmq_connection may
            # have replaced the startup connection after a reconnect.
            close_connection(app.state.rmq_connection)
    finally:
        logger.info("Closing MySQL connection pool...")
        close_connection_pool()


def get_rmq_connection(app: FastAPI):
    """Return the shared RabbitMQ connection, reconnecting if needed.

    A new connection is created when none is stored on ``app.state`` or
    when the stored one reports ``is_open`` as false.
    """
    # NOTE(review): receive_notifications is a sync endpoint, so FastAPI runs
    # it in a threadpool and this shared connection may be used from several
    # threads at once. If create_connection returns a pika BlockingConnection,
    # that is not thread-safe. Confirm against rabbitmq_handler.
    connection = getattr(app.state, "rmq_connection", None)
    if not connection or not connection.is_open:
        app.state.rmq_connection = create_connection()
    return app.state.rmq_connection


api = FastAPI(
    title="Internal Notifier API",
    description="API to forward messages to RabbitMQ",
    version="1.0.0",
    lifespan=lifespan,
)


@api.middleware("http")
async def prometheus_middleware(request: Request, call_next):
    """Increment REQUEST_COUNTER for every request (method, path, status).

    If the downstream handler raises, the request is counted with status 500
    and the original exception propagates unchanged.
    """
    # NOTE(review): labelling by the raw URL path lets arbitrary unknown
    # paths create new metric series; consider the route template instead.
    status = 500
    try:
        response = await call_next(request)
        status = response.status_code
    finally:
        REQUEST_COUNTER.labels(request.method, request.url.path, status).inc()
    # Returning outside ``finally``: a return inside it would discard an
    # in-flight exception (and ``response`` would be unbound).
    return response


def verify_api_key_dependency_internal(
    db=Depends(get_db),
    api_key: str = Depends(api_key_header_internal),
) -> str:
    """Authenticate an internal caller and return its program name.

    Each active row of internal_api_keys is checked with verify_api_key.
    The program_name of the first matching row is returned.

    Raises:
        HTTPException: 403 "Unauthorized" when no active key matches.
    """
    cursor = db.cursor()
    try:
        cursor.execute("SELECT program_name, api_key FROM internal_api_keys WHERE status = 'active'")
        rows = cursor.fetchall()
    finally:
        cursor.close()

    # NOTE(review): this performs one hash verification per active key on
    # every request; cost grows with the number of keys.
    for program_name, hashed_key in rows:
        if verify_api_key(api_key=api_key, hashed=hashed_key):
            return program_name
    raise HTTPException(status_code=403, detail="Unauthorized")


@api.exception_handler(StarletteHTTPException)
async def custom_http_exception_handler(request, exc):
    """Render HTTP errors as JSON, reporting 404s as 403 "Unauthorized".

    Presumably deliberate, so that unknown routes are indistinguishable from
    authentication failures. Other statuses keep their code and detail.
    """
    if exc.status_code == 404:
        return JSONResponse(
            status_code=403,
            content={"detail": "Unauthorized"},
        )
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.detail},
    )


@api.get("/health", tags=["Health"])
def return_health(request: Request, db=Depends(get_db)):
    """Report database and RabbitMQ health.

    Returns 200 with status "ok" when both checks pass; otherwise 500 with
    status "error" and per-component results.
    """
    db_status = "ok"
    try:
        cursor = db.cursor()
        try:
            cursor.execute("SELECT 1")
            cursor.fetchone()
        finally:
            cursor.close()
    except Exception as e:
        logger.error(f"Health check DB failed: {e}")
        db_status = "error"

    # Must start as "ok": previously this was only ever assigned on failure,
    # so a healthy RabbitMQ connection raised UnboundLocalError below.
    rmq_status = "ok"
    try:
        rmq_conn = getattr(request.app.state, "rmq_connection", None)
        if not rmq_conn or not rmq_conn.is_open:
            logger.error("Health check RMQ failed: connection closed or missing")
            rmq_status = "error"
    except Exception as e:
        logger.error(f"Health check RMQ failed: {e}")
        rmq_status = "error"

    overall_status = "ok" if db_status == "ok" and rmq_status == "ok" else "error"
    status_code = 200 if overall_status == "ok" else 500
    return JSONResponse(
        status_code=status_code,
        content={
            "status": overall_status,
            "components": {
                "database": db_status,
                "rabbitmq": rmq_status,
            },
            "message": "Service is running" if overall_status == "ok" else "One or more checks failed",
        },
    )


@api.post("/internal/receive-notifications")
def receive_notifications(
    notification_data: Notification,
    request: Request,
    db=Depends(get_db),
    program_name: str = Depends(verify_api_key_dependency_internal),
):
    """Forward an authenticated notification to RabbitMQ.

    ``db`` is not used directly here. It is kept so the endpoint's
    dependency set is unchanged.
    """
    rmq_connection = get_rmq_connection(request.app)
    logger.info(f"Received notifcation data from {program_name} for RMQ")
    send_message_to_rmq(rmq_connection, notification_data.receipent_user_id, notification_data.message)
    logger.info("Successfully delivered message to RMQ")
    return {"status": "queued"}


async def start_servers():
    """Serve the API on port 8101 and the metrics app on port 9000 concurrently."""
    # Pass the app object rather than the "main:api" import string. The
    # string re-imports this file as a second module when run as a script,
    # and fails outright if the file is not named main.py.
    config_main = uvicorn.Config(api, host="0.0.0.0", port=8101, log_config=LOGGING_CONFIG, log_level="info")
    config_metrics = uvicorn.Config("metrics_server:metrics_api", host="0.0.0.0", port=9000, log_level="info")
    server_main = uvicorn.Server(config_main)
    server_metrics = uvicorn.Server(config_metrics)
    await asyncio.gather(server_main.serve(), server_metrics.serve())


if __name__ == "__main__":
    asyncio.run(start_servers())