Improved RabbitMQ handling
- Switched to async because it offers easy automatic reconnect features on connection failure - Adjusted notification path to reflect that change - Added in memory message queue if RMQ is not reachable
This commit is contained in:
parent
c0fcdaeb4f
commit
873e7e7854
@ -1,15 +1,19 @@
|
||||
aio-pika==9.5.7
|
||||
aiormq==6.9.0
|
||||
annotated-types==0.7.0
|
||||
anyio==4.11.0
|
||||
argon2-cffi==25.1.0
|
||||
argon2-cffi-bindings==25.1.0
|
||||
cffi==2.0.0
|
||||
click==8.3.0
|
||||
fastapi==0.118.0
|
||||
fastapi==0.119.0
|
||||
h11==0.16.0
|
||||
idna==3.10
|
||||
idna==3.11
|
||||
multidict==6.7.0
|
||||
mysql-connector-python==9.4.0
|
||||
pika==1.3.2
|
||||
pamqp==3.3.0
|
||||
prometheus_client==0.23.1
|
||||
propcache==0.4.1
|
||||
pycparser==2.23
|
||||
pydantic==2.12.0
|
||||
pydantic_core==2.41.1
|
||||
@ -18,3 +22,4 @@ starlette==0.48.0
|
||||
typing-inspection==0.4.2
|
||||
typing_extensions==4.15.0
|
||||
uvicorn==0.37.0
|
||||
yarl==1.22.0
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import mysql.connector
|
||||
from mysql.connector import pooling, Error
|
||||
import threading
|
||||
from secret_handler import return_credentials
|
||||
@ -41,7 +40,7 @@ def create_connection_pool():
|
||||
try:
|
||||
logger.info(f"[MySQL] Attempt {attempt} to connect...")
|
||||
pool = pooling.MySQLConnectionPool(
|
||||
pool_name="royalroadPool",
|
||||
pool_name="ApiInternalPool",
|
||||
pool_size=5,
|
||||
pool_reset_session=True,
|
||||
**MYSQL_CONFIG
|
||||
@ -103,7 +102,6 @@ def _pool_healthcheck():
|
||||
conn = pool.get_connection()
|
||||
conn.ping(reconnect=True, attempts=3, delay=1)
|
||||
conn.close()
|
||||
logger.debug("[MySQL] Pool healthcheck OK.")
|
||||
except Error as e:
|
||||
logger.warning(f"[MySQL] Pool healthcheck failed: {e}")
|
||||
create_connection_pool()
|
||||
|
||||
69
src/main.py
69
src/main.py
@ -7,7 +7,7 @@ from pydantic import BaseModel
|
||||
from validator import verify_api_key
|
||||
from db import get_db, create_connection_pool, close_connection_pool, start_healthcheck_thread
|
||||
from logger_handler import setup_logger
|
||||
from rabbitmq_handler import send_message_to_rmq, create_connection, close_connection
|
||||
from rabbitmq_handler import RabbitMQProducer
|
||||
import uvicorn
|
||||
from uvicorn_logging_config import LOGGING_CONFIG
|
||||
from contextlib import asynccontextmanager
|
||||
@ -15,7 +15,7 @@ from metrics_server import REQUEST_COUNTER
|
||||
import asyncio
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
producer = RabbitMQProducer()
|
||||
api_key_header_internal = APIKeyHeader(name="X-API-Key-Internal")
|
||||
|
||||
class Notification(BaseModel):
|
||||
@ -30,24 +30,21 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
logger.info("Creating MySQL connection pool...")
|
||||
create_connection_pool()
|
||||
logger.info("Connecting to RabbitMQ...")
|
||||
app.state.rmq_connection = create_connection()
|
||||
|
||||
start_healthcheck_thread()
|
||||
logger.info("MySQL healthcheck thread started.")
|
||||
|
||||
logger.info("Starting RabbitMQ producer...")
|
||||
await producer.connect()
|
||||
app.state.rmq_producer = producer
|
||||
logger.info("[FastAPI] RabbitMQ producer initialized.")
|
||||
|
||||
yield
|
||||
logger.info("Closing RabbitMQ connection...")
|
||||
close_connection(app.state.rmq_connection)
|
||||
logger.info("Closing RabbitMQ producer...")
|
||||
await producer.close()
|
||||
|
||||
logger.info("Closing MySQL connection pool...")
|
||||
close_connection_pool()
|
||||
|
||||
def get_rmq_connection(app: FastAPI):
|
||||
connection = getattr(app.state, "rmq_connection", None)
|
||||
if not connection or not connection.is_open:
|
||||
app.state.rmq_connection = create_connection()
|
||||
return app.state.rmq_connection
|
||||
|
||||
api = FastAPI(
|
||||
title="Internal Notifier API",
|
||||
description="API to forward messages to RabbitMQ",
|
||||
@ -87,54 +84,24 @@ async def custom_http_exception_handler(request,exc):
|
||||
content={"detail": exc.detail}
|
||||
)
|
||||
|
||||
@api.get("/health", tags=["Health"])
|
||||
def return_health(request:Request, db=Depends(get_db)):
|
||||
try:
|
||||
cursor = db.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
cursor.fetchone()
|
||||
db_status = "ok"
|
||||
except Exception as e:
|
||||
logger.error(f"Health check DB failed: {e}")
|
||||
db_status = "error"
|
||||
|
||||
try:
|
||||
rmq_conn = getattr(request.app.state, "rmq_connection", None)
|
||||
if not rmq_conn or not rmq_conn.is_open:
|
||||
logger.error("Health check RMQ failed: connection closed or missing")
|
||||
rmq_status = "error"
|
||||
except Exception as e:
|
||||
logger.error(f"Health check RMQ failed: {e}")
|
||||
rmq_status = "error"
|
||||
|
||||
overall_status = "ok" if db_status == "ok" and rmq_status == "ok" else "error"
|
||||
status_code = 200 if overall_status == "ok" else 500
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content={"status": overall_status,
|
||||
"components": {
|
||||
"database": db_status,
|
||||
"rabbitmq": rmq_status
|
||||
},
|
||||
"message": "Service is running" if overall_status == "ok" else "One or more checks failed"}
|
||||
)
|
||||
|
||||
@api.post("/internal/receive-notifications")
|
||||
def receive_notifications(
|
||||
async def receive_notifications(
|
||||
notification_data: Notification,
|
||||
request: Request,
|
||||
db = Depends(get_db),
|
||||
program_name: str = Depends(verify_api_key_dependency_internal)
|
||||
):
|
||||
rmq_connection = get_rmq_connection(request.app)
|
||||
|
||||
logger.info(f"Received notifcation data from {program_name} for RMQ")
|
||||
send_message_to_rmq(rmq_connection,notification_data.receipent_user_id,notification_data.message)
|
||||
logger.info("Successfully delivered message to RMQ")
|
||||
await request.app.state.rmq_producer.publish(
|
||||
notification_data.receipent_user_id,
|
||||
notification_data.message
|
||||
)
|
||||
logger.info("Successfully queued for delivery to RMQ")
|
||||
return {"status": "queued"}
|
||||
|
||||
async def start_servers():
|
||||
config_main = uvicorn.Config("main:api", host="0.0.0.0", port=8101, log_config=LOGGING_CONFIG, log_level="info")
|
||||
config_main = uvicorn.Config("main:api", host="0.0.0.0", port=8101, log_level="info")
|
||||
config_metrics = uvicorn.Config("metrics_server:metrics_api", host="0.0.0.0", port=9000, log_level="info")
|
||||
|
||||
server_main = uvicorn.Server(config_main)
|
||||
|
||||
@ -1,72 +1,86 @@
|
||||
import pika
|
||||
from typing import Dict
|
||||
import asyncio
|
||||
import aio_pika
|
||||
from secret_handler import return_credentials
|
||||
import json
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
from logger_handler import setup_logger
|
||||
import json
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
rmq_username = return_credentials("/etc/secrets/rmq_username")
|
||||
rmq_password = return_credentials("/etc/secrets/rmq_password")
|
||||
rmq_host = os.getenv("BACKEND_API_INTERNAL_RMQ_HOST","localhost")
|
||||
rmq_vhost = os.getenv("BACKEND_API_INTERNAL_RMQ_VHOST","app_notifications")
|
||||
rmq_exchange = os.getenv("BACKEND_API_INTERNAL_RMQ_EXCHANGE","app_notifications")
|
||||
rmq_host = os.getenv("BACKEND_API_INTERNAL_RMQ_HOST", "localhost")
|
||||
rmq_vhost = os.getenv("BACKEND_API_INTERNAL_RMQ_VHOST", "app_notifications")
|
||||
rmq_exchange = os.getenv("BACKEND_API_INTERNAL_RMQ_EXCHANGE", "app_notifications")
|
||||
|
||||
MAX_RETRIES = 5
|
||||
RETRY_DELAY = 5
|
||||
RABBITMQ_URL = f"amqp://{rmq_username}:{rmq_password}@{rmq_host}/{rmq_vhost}"
|
||||
|
||||
|
||||
def create_connection():
|
||||
credentials = pika.PlainCredentials(username=rmq_username, password=rmq_password)
|
||||
conn_params = pika.ConnectionParameters(
|
||||
host=rmq_host,
|
||||
port=5672,
|
||||
credentials=credentials,
|
||||
virtual_host=rmq_vhost
|
||||
)
|
||||
class RabbitMQProducer:
|
||||
def __init__(self, url=RABBITMQ_URL, exchange_name=rmq_exchange):
|
||||
self.url = url
|
||||
self.exchange_name = exchange_name
|
||||
self.connection: aio_pika.RobustConnection | None = None
|
||||
self.channel: aio_pika.RobustChannel | None = None
|
||||
self.exchange: aio_pika.Exchange | None = None
|
||||
self._queue: asyncio.Queue[tuple[int, dict]] = asyncio.Queue()
|
||||
self._flush_task: asyncio.Task | None = None
|
||||
self._closing = False
|
||||
self._ready = asyncio.Event()
|
||||
|
||||
for attempt in range(1, MAX_RETRIES + 1):
|
||||
try:
|
||||
connection = pika.BlockingConnection(conn_params)
|
||||
print("[RMQ] Connection established.")
|
||||
return connection
|
||||
except Exception as e:
|
||||
print(f"[RMQ] Attempt {attempt} failed: {e}")
|
||||
if attempt < MAX_RETRIES:
|
||||
time.sleep(RETRY_DELAY)
|
||||
else:
|
||||
print("[RMQ] Failed to connect after maximum retries — exiting.")
|
||||
sys.exit(1)
|
||||
async def connect(self):
|
||||
self.connection = await aio_pika.connect_robust(self.url)
|
||||
self.channel = await self.connection.channel(publisher_confirms=True)
|
||||
self.exchange = await self.channel.declare_exchange(
|
||||
self.exchange_name, aio_pika.ExchangeType.TOPIC, durable=True
|
||||
)
|
||||
logger.info(f"[aio-pika] Connected and exchange '{self.exchange_name}' ready.")
|
||||
self._ready.set()
|
||||
self._flush_task = asyncio.create_task(self._flush_queue_loop())
|
||||
|
||||
async def publish(self, routing_key: int, message: dict):
|
||||
await self._queue.put((routing_key, message))
|
||||
logger.debug(f"[RabbitMQ] Queued message for {routing_key}")
|
||||
|
||||
async def _flush_queue_loop(self):
|
||||
await self._ready.wait()
|
||||
logger.debug(f"here")
|
||||
while True:
|
||||
routing_key, message = await self._queue.get()
|
||||
try:
|
||||
await self.exchange.publish(
|
||||
aio_pika.Message(
|
||||
body=json.dumps(message).encode(),
|
||||
content_type="application/json",
|
||||
delivery_mode=aio_pika.DeliveryMode.PERSISTENT
|
||||
),
|
||||
routing_key=f"notify.user.{routing_key}",
|
||||
mandatory=True
|
||||
)
|
||||
logger.debug(f"[aio-pika] Published message to notify.user.{routing_key}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[aio-pika] Publish failed, requeuing: {e}")
|
||||
await asyncio.sleep(2)
|
||||
await self._queue.put((routing_key, message))
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
async def close(self):
|
||||
self._closing = True
|
||||
if self._flush_task and not self._flush_task.done():
|
||||
self._flush_task.cancel()
|
||||
try:
|
||||
await self._flush_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self.connection:
|
||||
await self.connection.close()
|
||||
logger.info("[aio-pika] Connection closed.")
|
||||
|
||||
|
||||
def send_message_to_rmq(connection, user_id: int, message: Dict):
|
||||
if not connection or not connection.is_open:
|
||||
print("[RMQ] Connection lost, reconnecting...")
|
||||
connection = create_connection()
|
||||
async def main():
|
||||
producer = RabbitMQProducer()
|
||||
await producer.connect()
|
||||
await producer.publish("1", {"type": "notification", "content": "Test message"})
|
||||
await producer.close()
|
||||
|
||||
channel = connection.channel()
|
||||
channel.exchange_declare(exchange=rmq_exchange, exchange_type="topic", durable=True)
|
||||
channel.confirm_delivery()
|
||||
|
||||
channel.basic_publish(
|
||||
exchange=rmq_exchange,
|
||||
routing_key=f"notify.user.{user_id}",
|
||||
body=json.dumps(message),
|
||||
properties=pika.BasicProperties(
|
||||
content_type="application/json",
|
||||
delivery_mode=2
|
||||
),
|
||||
mandatory=True
|
||||
)
|
||||
|
||||
|
||||
def close_connection(connection):
|
||||
if connection and connection.is_open:
|
||||
connection.close()
|
||||
print("[RMQ] Connection closed.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
conn = create_connection()
|
||||
send_message_to_rmq(conn, 1, {"type": "notification", "content": "Vault TLS cert reloaded successfully."})
|
||||
close_connection(conn)
|
||||
|
||||
@ -12,9 +12,9 @@ def verify_api_key(api_key: str, hashed: str) -> bool:
|
||||
return False
|
||||
|
||||
if __name__=="__main__":
|
||||
plain_key = "super-secret-api-key"
|
||||
#hashed_key = hash_api_key(plain_key)
|
||||
hashed_key = '$argon2id$v=19$m=65536,t=3,p=4$vqU+MRafVW1b8AtF+zHb0w$p1J4Gyb0jhlVtKgYyjTITxfU97YaayeS3s3qFFP5sVM'
|
||||
plain_key = "password"
|
||||
hashed_key = hash_api_key(plain_key)
|
||||
#hashed_key = '$argon2id$v=19$m=65536,t=3,p=4$vqU+MRafVW1b8AtF+zHb0w$p1J4Gyb0jhlVtKgYyjTITxfU97YaayeS3s3qFFP5sVM'
|
||||
|
||||
print("Hashed API Key:", hashed_key)
|
||||
print("Verification:", verify_api_key(plain_key, hashed_key))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user