diff --git a/.gitignore b/.gitignore index 0dbf2f2..a61f434 100644 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,4 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +test.py diff --git a/src/db.py b/src/db.py index 37650ec..d007d28 100644 --- a/src/db.py +++ b/src/db.py @@ -24,17 +24,20 @@ MYSQL_CONFIG = { _pool_lock = threading.Lock() _connection_pool = None + def create_connection_pool(): global _connection_pool for attempt in range(1, MAX_RETRIES+1): try: print(f"[MySQL] Attempt {attempt} to connect...") - _connection_pool = mysql.connector.pooling.MySQLConnectionPool( + pool = mysql.connector.pooling.MySQLConnectionPool( pool_name="mypool", pool_size=5, pool_reset_session=True, **MYSQL_CONFIG ) + with _pool_lock: + _connection_pool = pool print("[MySQL] Connection pool created successfully.") return except mysql.connector.Error as e: @@ -44,6 +47,15 @@ def create_connection_pool(): print(f"[MySQL] Failed to connect after {MAX_RETRIES} attempts — exiting.") sys.exit(1) + +def close_connection_pool(): + global _connection_pool + with _pool_lock: + if _connection_pool: + _connection_pool = None + print("[MySQL] Connection pool closed.") + + def get_connection_pool(): global _connection_pool with _pool_lock: @@ -51,11 +63,18 @@ def get_connection_pool(): create_connection_pool() return _connection_pool + def get_db(): pool = get_connection_pool() - conn = pool.get_connection() + try: + conn = pool.get_connection() + if not conn.is_connected(): + conn.reconnect(attempts=MAX_RETRIES, delay=RETRY_DELAY) + except Exception: + create_connection_pool() + pool = get_connection_pool() + conn = pool.get_connection() try: yield conn finally: conn.close() - diff --git a/src/main.py b/src/main.py index 6dbb713..cebbe22 100644 --- a/src/main.py +++ b/src/main.py @@ -1,4 +1,4 @@ -from fastapi import FastAPI, Depends, HTTPException +from fastapi import FastAPI, Depends, HTTPException, Request from fastapi.responses import JSONResponse from fastapi.security.api_key import APIKeyHeader from starlette.exceptions import HTTPException as StarletteHTTPException @@ -6,12 +6,13 @@ from typing import Optional,List from pydantic import BaseModel from validator import is_valid_platform,is_valid_token,verify_api_key from secret_handler import encrypt_token -from db import get_db +from db import get_db, create_connection_pool, close_connection_pool from logger_handler import setup_logger import uuid from hashlib import sha256 import uvicorn from uvicorn_logging_config import LOGGING_CONFIG +from contextlib import asynccontextmanager @@ -31,10 +32,24 @@ class TokenRequest(BaseModel): locale : Optional[str] = None topics : Optional[List[str]] = None + +@asynccontextmanager +async def lifespan(app: FastAPI): + logger.info("Starting application...") + + logger.info("Creating MySQL connection pool...") + create_connection_pool() + + yield + logger.info("Closing MySQL connection pool...") + close_connection_pool() + + api = FastAPI( title="Device Token Management", description="API for requesting tokens", - version="1.0.0" + version="1.0.0", + lifespan=lifespan )