Merge pull request 'Improved RabbitMQ handling' (#5) from feature/rabbitmq-improved-connection into main
All checks were successful
Build & Publish to GHCR / build (push) Successful in 1m5s
All checks were successful
Build & Publish to GHCR / build (push) Successful in 1m5s
Reviewed-on: #5
This commit is contained in:
commit
2be35613e2
@ -1,15 +1,19 @@
|
|||||||
|
aio-pika==9.5.7
|
||||||
|
aiormq==6.9.0
|
||||||
annotated-types==0.7.0
|
annotated-types==0.7.0
|
||||||
anyio==4.11.0
|
anyio==4.11.0
|
||||||
argon2-cffi==25.1.0
|
argon2-cffi==25.1.0
|
||||||
argon2-cffi-bindings==25.1.0
|
argon2-cffi-bindings==25.1.0
|
||||||
cffi==2.0.0
|
cffi==2.0.0
|
||||||
click==8.3.0
|
click==8.3.0
|
||||||
fastapi==0.118.0
|
fastapi==0.119.0
|
||||||
h11==0.16.0
|
h11==0.16.0
|
||||||
idna==3.10
|
idna==3.11
|
||||||
|
multidict==6.7.0
|
||||||
mysql-connector-python==9.4.0
|
mysql-connector-python==9.4.0
|
||||||
pika==1.3.2
|
pamqp==3.3.0
|
||||||
prometheus_client==0.23.1
|
prometheus_client==0.23.1
|
||||||
|
propcache==0.4.1
|
||||||
pycparser==2.23
|
pycparser==2.23
|
||||||
pydantic==2.12.0
|
pydantic==2.12.0
|
||||||
pydantic_core==2.41.1
|
pydantic_core==2.41.1
|
||||||
@ -18,3 +22,4 @@ starlette==0.48.0
|
|||||||
typing-inspection==0.4.2
|
typing-inspection==0.4.2
|
||||||
typing_extensions==4.15.0
|
typing_extensions==4.15.0
|
||||||
uvicorn==0.37.0
|
uvicorn==0.37.0
|
||||||
|
yarl==1.22.0
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
import mysql.connector
|
|
||||||
from mysql.connector import pooling, Error
|
from mysql.connector import pooling, Error
|
||||||
import threading
|
import threading
|
||||||
from secret_handler import return_credentials
|
from secret_handler import return_credentials
|
||||||
@ -41,7 +40,7 @@ def create_connection_pool():
|
|||||||
try:
|
try:
|
||||||
logger.info(f"[MySQL] Attempt {attempt} to connect...")
|
logger.info(f"[MySQL] Attempt {attempt} to connect...")
|
||||||
pool = pooling.MySQLConnectionPool(
|
pool = pooling.MySQLConnectionPool(
|
||||||
pool_name="royalroadPool",
|
pool_name="ApiInternalPool",
|
||||||
pool_size=5,
|
pool_size=5,
|
||||||
pool_reset_session=True,
|
pool_reset_session=True,
|
||||||
**MYSQL_CONFIG
|
**MYSQL_CONFIG
|
||||||
@ -103,7 +102,6 @@ def _pool_healthcheck():
|
|||||||
conn = pool.get_connection()
|
conn = pool.get_connection()
|
||||||
conn.ping(reconnect=True, attempts=3, delay=1)
|
conn.ping(reconnect=True, attempts=3, delay=1)
|
||||||
conn.close()
|
conn.close()
|
||||||
logger.debug("[MySQL] Pool healthcheck OK.")
|
|
||||||
except Error as e:
|
except Error as e:
|
||||||
logger.warning(f"[MySQL] Pool healthcheck failed: {e}")
|
logger.warning(f"[MySQL] Pool healthcheck failed: {e}")
|
||||||
create_connection_pool()
|
create_connection_pool()
|
||||||
|
|||||||
69
src/main.py
69
src/main.py
@ -7,7 +7,7 @@ from pydantic import BaseModel
|
|||||||
from validator import verify_api_key
|
from validator import verify_api_key
|
||||||
from db import get_db, create_connection_pool, close_connection_pool, start_healthcheck_thread
|
from db import get_db, create_connection_pool, close_connection_pool, start_healthcheck_thread
|
||||||
from logger_handler import setup_logger
|
from logger_handler import setup_logger
|
||||||
from rabbitmq_handler import send_message_to_rmq, create_connection, close_connection
|
from rabbitmq_handler import RabbitMQProducer
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from uvicorn_logging_config import LOGGING_CONFIG
|
from uvicorn_logging_config import LOGGING_CONFIG
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
@ -15,7 +15,7 @@ from metrics_server import REQUEST_COUNTER
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
logger = setup_logger(__name__)
|
logger = setup_logger(__name__)
|
||||||
|
producer = RabbitMQProducer()
|
||||||
api_key_header_internal = APIKeyHeader(name="X-API-Key-Internal")
|
api_key_header_internal = APIKeyHeader(name="X-API-Key-Internal")
|
||||||
|
|
||||||
class Notification(BaseModel):
|
class Notification(BaseModel):
|
||||||
@ -30,24 +30,21 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
logger.info("Creating MySQL connection pool...")
|
logger.info("Creating MySQL connection pool...")
|
||||||
create_connection_pool()
|
create_connection_pool()
|
||||||
logger.info("Connecting to RabbitMQ...")
|
|
||||||
app.state.rmq_connection = create_connection()
|
|
||||||
|
|
||||||
start_healthcheck_thread()
|
start_healthcheck_thread()
|
||||||
logger.info("MySQL healthcheck thread started.")
|
logger.info("MySQL healthcheck thread started.")
|
||||||
|
|
||||||
|
logger.info("Starting RabbitMQ producer...")
|
||||||
|
await producer.connect()
|
||||||
|
app.state.rmq_producer = producer
|
||||||
|
logger.info("[FastAPI] RabbitMQ producer initialized.")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
logger.info("Closing RabbitMQ connection...")
|
logger.info("Closing RabbitMQ producer...")
|
||||||
close_connection(app.state.rmq_connection)
|
await producer.close()
|
||||||
|
|
||||||
logger.info("Closing MySQL connection pool...")
|
logger.info("Closing MySQL connection pool...")
|
||||||
close_connection_pool()
|
close_connection_pool()
|
||||||
|
|
||||||
def get_rmq_connection(app: FastAPI):
|
|
||||||
connection = getattr(app.state, "rmq_connection", None)
|
|
||||||
if not connection or not connection.is_open:
|
|
||||||
app.state.rmq_connection = create_connection()
|
|
||||||
return app.state.rmq_connection
|
|
||||||
|
|
||||||
api = FastAPI(
|
api = FastAPI(
|
||||||
title="Internal Notifier API",
|
title="Internal Notifier API",
|
||||||
description="API to forward messages to RabbitMQ",
|
description="API to forward messages to RabbitMQ",
|
||||||
@ -87,54 +84,24 @@ async def custom_http_exception_handler(request,exc):
|
|||||||
content={"detail": exc.detail}
|
content={"detail": exc.detail}
|
||||||
)
|
)
|
||||||
|
|
||||||
@api.get("/health", tags=["Health"])
|
|
||||||
def return_health(request:Request, db=Depends(get_db)):
|
|
||||||
try:
|
|
||||||
cursor = db.cursor()
|
|
||||||
cursor.execute("SELECT 1")
|
|
||||||
cursor.fetchone()
|
|
||||||
db_status = "ok"
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Health check DB failed: {e}")
|
|
||||||
db_status = "error"
|
|
||||||
|
|
||||||
try:
|
|
||||||
rmq_conn = getattr(request.app.state, "rmq_connection", None)
|
|
||||||
if not rmq_conn or not rmq_conn.is_open:
|
|
||||||
logger.error("Health check RMQ failed: connection closed or missing")
|
|
||||||
rmq_status = "error"
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Health check RMQ failed: {e}")
|
|
||||||
rmq_status = "error"
|
|
||||||
|
|
||||||
overall_status = "ok" if db_status == "ok" and rmq_status == "ok" else "error"
|
|
||||||
status_code = 200 if overall_status == "ok" else 500
|
|
||||||
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=status_code,
|
|
||||||
content={"status": overall_status,
|
|
||||||
"components": {
|
|
||||||
"database": db_status,
|
|
||||||
"rabbitmq": rmq_status
|
|
||||||
},
|
|
||||||
"message": "Service is running" if overall_status == "ok" else "One or more checks failed"}
|
|
||||||
)
|
|
||||||
|
|
||||||
@api.post("/internal/receive-notifications")
|
@api.post("/internal/receive-notifications")
|
||||||
def receive_notifications(
|
async def receive_notifications(
|
||||||
notification_data: Notification,
|
notification_data: Notification,
|
||||||
request: Request,
|
request: Request,
|
||||||
db = Depends(get_db),
|
db = Depends(get_db),
|
||||||
program_name: str = Depends(verify_api_key_dependency_internal)
|
program_name: str = Depends(verify_api_key_dependency_internal)
|
||||||
):
|
):
|
||||||
rmq_connection = get_rmq_connection(request.app)
|
|
||||||
logger.info(f"Received notifcation data from {program_name} for RMQ")
|
logger.info(f"Received notifcation data from {program_name} for RMQ")
|
||||||
send_message_to_rmq(rmq_connection,notification_data.receipent_user_id,notification_data.message)
|
await request.app.state.rmq_producer.publish(
|
||||||
logger.info("Successfully delivered message to RMQ")
|
notification_data.receipent_user_id,
|
||||||
|
notification_data.message
|
||||||
|
)
|
||||||
|
logger.info("Successfully queued for delivery to RMQ")
|
||||||
return {"status": "queued"}
|
return {"status": "queued"}
|
||||||
|
|
||||||
async def start_servers():
|
async def start_servers():
|
||||||
config_main = uvicorn.Config("main:api", host="0.0.0.0", port=8101, log_config=LOGGING_CONFIG, log_level="info")
|
config_main = uvicorn.Config("main:api", host="0.0.0.0", port=8101, log_level="info")
|
||||||
config_metrics = uvicorn.Config("metrics_server:metrics_api", host="0.0.0.0", port=9000, log_level="info")
|
config_metrics = uvicorn.Config("metrics_server:metrics_api", host="0.0.0.0", port=9000, log_level="info")
|
||||||
|
|
||||||
server_main = uvicorn.Server(config_main)
|
server_main = uvicorn.Server(config_main)
|
||||||
|
|||||||
@ -1,72 +1,86 @@
|
|||||||
import pika
|
import asyncio
|
||||||
from typing import Dict
|
import aio_pika
|
||||||
from secret_handler import return_credentials
|
from secret_handler import return_credentials
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import sys
|
|
||||||
import os
|
import os
|
||||||
|
from logger_handler import setup_logger
|
||||||
|
import json
|
||||||
|
|
||||||
|
logger = setup_logger(__name__)
|
||||||
|
|
||||||
rmq_username = return_credentials("/etc/secrets/rmq_username")
|
rmq_username = return_credentials("/etc/secrets/rmq_username")
|
||||||
rmq_password = return_credentials("/etc/secrets/rmq_password")
|
rmq_password = return_credentials("/etc/secrets/rmq_password")
|
||||||
rmq_host = os.getenv("BACKEND_API_INTERNAL_RMQ_HOST","localhost")
|
rmq_host = os.getenv("BACKEND_API_INTERNAL_RMQ_HOST", "localhost")
|
||||||
rmq_vhost = os.getenv("BACKEND_API_INTERNAL_RMQ_VHOST","app_notifications")
|
rmq_vhost = os.getenv("BACKEND_API_INTERNAL_RMQ_VHOST", "app_notifications")
|
||||||
rmq_exchange = os.getenv("BACKEND_API_INTERNAL_RMQ_EXCHANGE","app_notifications")
|
rmq_exchange = os.getenv("BACKEND_API_INTERNAL_RMQ_EXCHANGE", "app_notifications")
|
||||||
|
|
||||||
MAX_RETRIES = 5
|
RABBITMQ_URL = f"amqp://{rmq_username}:{rmq_password}@{rmq_host}/{rmq_vhost}"
|
||||||
RETRY_DELAY = 5
|
|
||||||
|
|
||||||
|
|
||||||
def create_connection():
|
class RabbitMQProducer:
|
||||||
credentials = pika.PlainCredentials(username=rmq_username, password=rmq_password)
|
def __init__(self, url=RABBITMQ_URL, exchange_name=rmq_exchange):
|
||||||
conn_params = pika.ConnectionParameters(
|
self.url = url
|
||||||
host=rmq_host,
|
self.exchange_name = exchange_name
|
||||||
port=5672,
|
self.connection: aio_pika.RobustConnection | None = None
|
||||||
credentials=credentials,
|
self.channel: aio_pika.RobustChannel | None = None
|
||||||
virtual_host=rmq_vhost
|
self.exchange: aio_pika.Exchange | None = None
|
||||||
)
|
self._queue: asyncio.Queue[tuple[int, dict]] = asyncio.Queue()
|
||||||
|
self._flush_task: asyncio.Task | None = None
|
||||||
|
self._closing = False
|
||||||
|
self._ready = asyncio.Event()
|
||||||
|
|
||||||
for attempt in range(1, MAX_RETRIES + 1):
|
async def connect(self):
|
||||||
try:
|
self.connection = await aio_pika.connect_robust(self.url)
|
||||||
connection = pika.BlockingConnection(conn_params)
|
self.channel = await self.connection.channel(publisher_confirms=True)
|
||||||
print("[RMQ] Connection established.")
|
self.exchange = await self.channel.declare_exchange(
|
||||||
return connection
|
self.exchange_name, aio_pika.ExchangeType.TOPIC, durable=True
|
||||||
except Exception as e:
|
)
|
||||||
print(f"[RMQ] Attempt {attempt} failed: {e}")
|
logger.info(f"[aio-pika] Connected and exchange '{self.exchange_name}' ready.")
|
||||||
if attempt < MAX_RETRIES:
|
self._ready.set()
|
||||||
time.sleep(RETRY_DELAY)
|
self._flush_task = asyncio.create_task(self._flush_queue_loop())
|
||||||
else:
|
|
||||||
print("[RMQ] Failed to connect after maximum retries — exiting.")
|
async def publish(self, routing_key: int, message: dict):
|
||||||
sys.exit(1)
|
await self._queue.put((routing_key, message))
|
||||||
|
logger.debug(f"[RabbitMQ] Queued message for {routing_key}")
|
||||||
|
|
||||||
|
async def _flush_queue_loop(self):
|
||||||
|
await self._ready.wait()
|
||||||
|
logger.debug(f"here")
|
||||||
|
while True:
|
||||||
|
routing_key, message = await self._queue.get()
|
||||||
|
try:
|
||||||
|
await self.exchange.publish(
|
||||||
|
aio_pika.Message(
|
||||||
|
body=json.dumps(message).encode(),
|
||||||
|
content_type="application/json",
|
||||||
|
delivery_mode=aio_pika.DeliveryMode.PERSISTENT
|
||||||
|
),
|
||||||
|
routing_key=f"notify.user.{routing_key}",
|
||||||
|
mandatory=True
|
||||||
|
)
|
||||||
|
logger.debug(f"[aio-pika] Published message to notify.user.{routing_key}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[aio-pika] Publish failed, requeuing: {e}")
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
await self._queue.put((routing_key, message))
|
||||||
|
finally:
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
self._closing = True
|
||||||
|
if self._flush_task and not self._flush_task.done():
|
||||||
|
self._flush_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._flush_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
if self.connection:
|
||||||
|
await self.connection.close()
|
||||||
|
logger.info("[aio-pika] Connection closed.")
|
||||||
|
|
||||||
|
|
||||||
def send_message_to_rmq(connection, user_id: int, message: Dict):
|
async def main():
|
||||||
if not connection or not connection.is_open:
|
producer = RabbitMQProducer()
|
||||||
print("[RMQ] Connection lost, reconnecting...")
|
await producer.connect()
|
||||||
connection = create_connection()
|
await producer.publish("1", {"type": "notification", "content": "Test message"})
|
||||||
|
await producer.close()
|
||||||
|
|
||||||
channel = connection.channel()
|
|
||||||
channel.exchange_declare(exchange=rmq_exchange, exchange_type="topic", durable=True)
|
|
||||||
channel.confirm_delivery()
|
|
||||||
|
|
||||||
channel.basic_publish(
|
|
||||||
exchange=rmq_exchange,
|
|
||||||
routing_key=f"notify.user.{user_id}",
|
|
||||||
body=json.dumps(message),
|
|
||||||
properties=pika.BasicProperties(
|
|
||||||
content_type="application/json",
|
|
||||||
delivery_mode=2
|
|
||||||
),
|
|
||||||
mandatory=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def close_connection(connection):
|
|
||||||
if connection and connection.is_open:
|
|
||||||
connection.close()
|
|
||||||
print("[RMQ] Connection closed.")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
conn = create_connection()
|
|
||||||
send_message_to_rmq(conn, 1, {"type": "notification", "content": "Vault TLS cert reloaded successfully."})
|
|
||||||
close_connection(conn)
|
|
||||||
|
|||||||
@ -12,9 +12,9 @@ def verify_api_key(api_key: str, hashed: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
if __name__=="__main__":
|
if __name__=="__main__":
|
||||||
plain_key = "super-secret-api-key"
|
plain_key = "password"
|
||||||
#hashed_key = hash_api_key(plain_key)
|
hashed_key = hash_api_key(plain_key)
|
||||||
hashed_key = '$argon2id$v=19$m=65536,t=3,p=4$vqU+MRafVW1b8AtF+zHb0w$p1J4Gyb0jhlVtKgYyjTITxfU97YaayeS3s3qFFP5sVM'
|
#hashed_key = '$argon2id$v=19$m=65536,t=3,p=4$vqU+MRafVW1b8AtF+zHb0w$p1J4Gyb0jhlVtKgYyjTITxfU97YaayeS3s3qFFP5sVM'
|
||||||
|
|
||||||
print("Hashed API Key:", hashed_key)
|
print("Hashed API Key:", hashed_key)
|
||||||
print("Verification:", verify_api_key(plain_key, hashed_key))
|
print("Verification:", verify_api_key(plain_key, hashed_key))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user