# Listing metadata: 144 lines, 4.6 KiB, Python

from fastapi import FastAPI, Depends, HTTPException, Request, Response
from fastapi.responses import JSONResponse
from fastapi.security.api_key import APIKeyHeader
from starlette.exceptions import HTTPException as StarletteHTTPException
from typing import Dict
from pydantic import BaseModel
from validator import verify_api_key
from db import get_db, create_connection_pool, close_connection_pool
from logger_handler import setup_logger
from rabbitmq_handler import send_message_to_rmq, create_connection, close_connection
import uvicorn
from uvicorn_logging_config import LOGGING_CONFIG
from contextlib import asynccontextmanager
from metrics_server import REQUEST_COUNTER
import asyncio
# Module-level logger obtained from the project's logger_handler.setup_logger.
logger = setup_logger(__name__)
# Security scheme reading the internal API key from the "X-API-Key-Internal"
# request header; verify_api_key_dependency_internal checks it against the DB.
api_key_header_internal = APIKeyHeader(name="X-API-Key-Internal")
class Notification(BaseModel):
    """Request body accepted by ``POST /internal/receive-notifications``.

    The field names form the public JSON contract of the endpoint (including
    the ``receipent`` spelling), so they must not be renamed.
    """

    # Presumably the id of the user the notification is addressed to; passed
    # to send_message_to_rmq as its second argument.
    receipent_user_id: int
    # Arbitrary JSON object, passed unchanged to send_message_to_rmq.
    message: Dict
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Open shared resources on startup and release them on shutdown.

    Startup creates the MySQL connection pool and stores a RabbitMQ connection
    on ``app.state.rmq_connection``; shutdown closes that connection (whichever
    one is current, since get_rmq_connection may have replaced it) and then
    the MySQL pool.

    Args:
        app: The FastAPI application whose ``state`` holds the RMQ connection.
    """
    logger.info("Starting application...")
    logger.info("Creating MySQL connection pool...")
    create_connection_pool()
    # Nested try/finally: the pool is closed even if connecting to RabbitMQ
    # fails at startup or closing the RabbitMQ connection raises at shutdown.
    try:
        logger.info("Connecting to RabbitMQ...")
        app.state.rmq_connection = create_connection()
        try:
            yield
        finally:
            logger.info("Closing RabbitMQ connection...")
            close_connection(app.state.rmq_connection)
    finally:
        logger.info("Closing MySQL connection pool...")
        close_connection_pool()
def get_rmq_connection(app: FastAPI):
    """Return the application's RabbitMQ connection, reconnecting if needed.

    If ``app.state.rmq_connection`` is missing, falsy, or no longer open, a
    new connection is created and stored back on ``app.state``.

    NOTE(review): receive_notifications is a sync endpoint, which FastAPI runs
    in a worker threadpool, so this single connection may be used from several
    threads at once. Confirm the RMQ client returned by create_connection is
    safe for that.
    """
    state = app.state
    current = getattr(state, "rmq_connection", None)
    if current and current.is_open:
        return state.rmq_connection
    state.rmq_connection = create_connection()
    return state.rmq_connection
# The ASGI application; start_servers serves it through the "main:api"
# import string. Startup/shutdown resource handling lives in `lifespan`.
api = FastAPI(
    lifespan=lifespan,
    title="Internal Notifier API",
    version="1.0.0",
    description="API to forward messages to RabbitMQ",
)
def _metrics_path_label(request: Request) -> str:
    """Return a bounded-cardinality ``path`` label for *request*.

    Prefers the ``path`` template of a route object recorded in
    ``scope["route"]`` (if the router stored one). Otherwise the raw path is
    used only when it equals a registered route path, and everything else
    collapses into a single ``"<unmatched>"`` bucket, so arbitrary URLs cannot
    create unbounded numbers of time series.
    """
    template = getattr(request.scope.get("route"), "path", None)
    if template is not None:
        return template
    raw_path = request.url.path
    registered = {getattr(route, "path", None) for route in request.app.routes}
    return raw_path if raw_path in registered else "<unmatched>"


@api.middleware("http")
async def prometheus_middleware(request: Request, call_next):
    """Count every HTTP request in REQUEST_COUNTER.

    Labels are (method, route path label, status code). If the downstream
    handler raises, the request is still counted, with status 500, and the
    exception propagates unchanged.
    """
    status = 500  # reported when call_next raises instead of returning
    try:
        response = await call_next(request)
        status = response.status_code
        return response
    finally:
        REQUEST_COUNTER.labels(request.method, _metrics_path_label(request), status).inc()
def verify_api_key_dependency_internal(db=Depends(get_db), api_key: str = Depends(api_key_header_internal)) -> str:
    """Resolve the internal API key header to the program that owns it.

    Loads every row of ``internal_api_keys`` with ``status = 'active'`` and
    checks the supplied key against each stored hash via ``verify_api_key``.

    Returns:
        The ``program_name`` of the first matching active key.

    Raises:
        HTTPException: 403 "Unauthorized" if no active key matches.
    """
    cursor = db.cursor()
    # Close the cursor as soon as the rows are fetched; previously it was
    # never closed, leaking it on every request.
    try:
        cursor.execute("SELECT program_name, api_key FROM internal_api_keys WHERE status = 'active'")
        active_keys = cursor.fetchall()
    finally:
        cursor.close()
    # NOTE(review): this performs one verify_api_key call per active key; if
    # that is a slow password hash, cost grows linearly with the key count.
    for program_name, hashed_key in active_keys:
        if verify_api_key(api_key=api_key, hashed=hashed_key):
            return program_name
    raise HTTPException(status_code=403, detail="Unauthorized")
@api.exception_handler(StarletteHTTPException)
async def custom_http_exception_handler(request, exc):
    """Render HTTP exceptions as JSON, hiding 404s behind a generic 403.

    Any 404 (including unknown routes) is reported as 403 "Unauthorized" so
    callers cannot probe which paths exist. Other exceptions keep their status
    code and detail and, like Starlette's default handler, keep any headers
    attached to the exception (e.g. ``Allow`` on 405).
    """
    if exc.status_code == 404:
        return JSONResponse(
            status_code=403,
            content={"detail": "Unauthorized"},
        )
    headers = getattr(exc, "headers", None)
    # 204 and 304 responses must not carry a body.
    if exc.status_code in {204, 304}:
        return Response(status_code=exc.status_code, headers=headers)
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.detail},
        headers=headers,
    )
@api.get("/health", tags=["Health"])
def return_health(request: Request, db=Depends(get_db)):
    """Report the health of the database and RabbitMQ connection.

    The database check runs ``SELECT 1``. The RabbitMQ check only inspects the
    connection stored on ``app.state`` (it does not reconnect).

    Returns:
        JSONResponse with ``status``, per-component ``components`` and a
        ``message``; HTTP 200 when both checks pass, 500 otherwise.
    """
    db_status = "ok"
    try:
        cursor = db.cursor()
        try:
            cursor.execute("SELECT 1")
            cursor.fetchone()
        finally:
            cursor.close()  # previously leaked on every probe
    except Exception as e:
        logger.error(f"Health check DB failed: {e}")
        db_status = "error"

    # Default to "ok": previously rmq_status was never assigned on the
    # success path, so a healthy RabbitMQ raised UnboundLocalError below.
    rmq_status = "ok"
    try:
        rmq_conn = getattr(request.app.state, "rmq_connection", None)
        if not rmq_conn or not rmq_conn.is_open:
            logger.error("Health check RMQ failed: connection closed or missing")
            rmq_status = "error"
    except Exception as e:
        logger.error(f"Health check RMQ failed: {e}")
        rmq_status = "error"

    overall_status = "ok" if db_status == "ok" and rmq_status == "ok" else "error"
    status_code = 200 if overall_status == "ok" else 500
    return JSONResponse(
        status_code=status_code,
        content={
            "status": overall_status,
            "components": {
                "database": db_status,
                "rabbitmq": rmq_status,
            },
            "message": "Service is running" if overall_status == "ok" else "One or more checks failed",
        },
    )
@api.post("/internal/receive-notifications")
def receive_notifications(
    notification_data: Notification,
    request: Request,
    db=Depends(get_db),
    program_name: str = Depends(verify_api_key_dependency_internal),
):
    """Forward an authenticated notification to RabbitMQ.

    The caller is authenticated by verify_api_key_dependency_internal, which
    yields the calling program's name. The recipient id and message are then
    handed to send_message_to_rmq on the (possibly re-established) shared
    connection.

    Args:
        notification_data: Validated request body.
        request: Used to reach ``app.state`` for the RabbitMQ connection.
        db: Not used directly in this handler; kept for signature compatibility.
        program_name: Name of the authenticated calling program.

    Returns:
        ``{"status": "queued"}`` once send_message_to_rmq returns.
    """
    # NOTE(review): as a sync endpoint this runs in FastAPI's threadpool, so
    # concurrent requests may share one RMQ connection; confirm that is safe.
    rmq_connection = get_rmq_connection(request.app)
    logger.info(f"Received notification data from {program_name} for RMQ")
    send_message_to_rmq(rmq_connection, notification_data.receipent_user_id, notification_data.message)
    logger.info("Successfully delivered message to RMQ")
    return {"status": "queued"}
async def start_servers():
    """Run the API (port 8101) and the metrics app (port 9000) concurrently.

    Both uvicorn servers share this process's event loop and its global
    logging configuration.
    """
    # uvicorn.Config applies its log_config when constructed. Logging is
    # process-global, so the metrics config must not pass uvicorn's default
    # log_config, which would overwrite LOGGING_CONFIG already applied for
    # the main server. log_config=None leaves the existing setup in place.
    # NOTE(review): "main:api" makes uvicorn import this file again as module
    # `main` when it is run as a script.
    config_main = uvicorn.Config("main:api", host="0.0.0.0", port=8101, log_config=LOGGING_CONFIG, log_level="info")
    config_metrics = uvicorn.Config("metrics_server:metrics_api", host="0.0.0.0", port=9000, log_config=None, log_level="info")
    server_main = uvicorn.Server(config_main)
    server_metrics = uvicorn.Server(config_metrics)
    await asyncio.gather(server_main.serve(), server_metrics.serve())


if __name__ == "__main__":
    asyncio.run(start_servers())