"""Device Token Management API.

Registers and unregisters device tokens for API-key-authenticated users,
exposes a database-backed health check, counts requests via
``REQUEST_COUNTER``, and runs a separate metrics app alongside the main one.
"""

import asyncio
import json
import uuid
from contextlib import asynccontextmanager, closing
from hashlib import sha256
from typing import Optional, List

import uvicorn
from fastapi import FastAPI, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.security.api_key import APIKeyHeader
from pydantic import BaseModel
from starlette.exceptions import HTTPException as StarletteHTTPException

from validator import is_valid_platform, is_valid_token, verify_api_key
from secret_handler import encrypt_token
from db import get_db, create_connection_pool, close_connection_pool
from logger_handler import setup_logger
from uvicorn_logging_config import LOGGING_CONFIG
from metrics_server import REQUEST_COUNTER

logger = setup_logger(__name__)

# Reads the caller's API key from the "X-API-Key" request header.
api_key_header = APIKeyHeader(name="X-API-Key")

# Detail returned to clients on unexpected server-side failures. The real
# error is logged instead of being echoed back, since it may expose SQL or
# driver internals.
INTERNAL_ERROR_DETAIL = "Internal server error"


def hash_token(token: str) -> str:
    """Return the hex SHA-256 digest of *token*, used as its lookup key."""
    return sha256(token.encode()).hexdigest()


class TokenRequest(BaseModel):
    """Request body shared by the register and unregister endpoints."""

    # NOTE(review): not read by the handlers; the acting user id comes from
    # the API key (verify_api_key_dependency).
    user_id: int
    # Raw device token: stored encrypted and looked up by its SHA-256 hash.
    token: str
    platform: str
    app_ver: str
    locale: Optional[str] = None
    topics: Optional[List[str]] = None


def _serialize_topics(topics: Optional[List[str]]) -> Optional[str]:
    """Encode *topics* as a JSON array string for storage.

    MySQL drivers do not bind a Python list as a single scalar query
    parameter, so the list is serialized. ``None`` is passed through so the
    column stays NULL.
    """
    # NOTE(review): assumes the topics column is JSON or text - confirm schema.
    return None if topics is None else json.dumps(topics)


def _rollback_quietly(db) -> None:
    """Roll back the current transaction, logging (not raising) on failure."""
    try:
        db.rollback()
    except Exception:
        logger.exception("Rollback failed")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the MySQL connection pool on startup and close it on shutdown."""
    logger.info("Starting application...")
    logger.info("Creating MySQL connection pool...")
    create_connection_pool()
    try:
        yield
    finally:
        # Close the pool even if the application exits with an error.
        logger.info("Closing MySQL connection pool...")
        close_connection_pool()


api = FastAPI(
    title="Device Token Management",
    description="API for requesting tokens",
    version="1.0.0",
    lifespan=lifespan,
)


@api.middleware("http")
async def prometheus_middleware(request: Request, call_next):
    """Count every request by method, URL path and response status.

    A request whose downstream handler raises is counted with status 500,
    and the exception propagates unchanged.
    """
    status = 500
    try:
        response = await call_next(request)
        status = response.status_code
    finally:
        # NOTE(review): labelling by raw URL path creates one series per
        # distinct path (including unknown paths); consider the route template.
        REQUEST_COUNTER.labels(request.method, request.url.path, status).inc()
    return response


@api.get("/health", tags=["Health"])
def return_health(db=Depends(get_db)):
    """Report service health by probing the database with ``SELECT 1``.

    Returns HTTP 200 with status "ok" when the probe succeeds. Otherwise
    returns HTTP 500 with status "error".
    """
    try:
        with closing(db.cursor()) as cursor:
            cursor.execute("SELECT 1")
            cursor.fetchone()
        db_status = "ok"
    except Exception as e:
        logger.error(f"Health check DB failed: {e}")
        db_status = "error"
    return JSONResponse(
        status_code=200 if db_status == "ok" else 500,
        content={"status": db_status, "message": "Service is running"},
    )


def verify_api_key_dependency(
    db=Depends(get_db), api_key: str = Depends(api_key_header)
) -> int:
    """Resolve the ``X-API-Key`` header to the id of an active user.

    Returns:
        The ``user_id`` of the first active user whose stored key verifies.

    Raises:
        HTTPException: 403 when no active user's stored key matches.
    """
    with closing(db.cursor()) as cursor:
        cursor.execute("SELECT user_id, api_key FROM users WHERE status = 'active'")
        rows = cursor.fetchall()
    # NOTE(review): every request checks the key against all active users,
    # which scales linearly with the user count. A key-id/prefix column would
    # allow a single-row lookup.
    for user_id, hashed_key in rows:
        if verify_api_key(api_key=api_key, hashed=hashed_key):
            return user_id
    raise HTTPException(status_code=403, detail="Unauthorized here")


@api.exception_handler(StarletteHTTPException)
async def custom_http_exception_handler(request, exc):
    """Render HTTP errors as ``{"detail": ...}`` JSON responses.

    404s are reported as 401 "Unauthorized" (presumably so unknown routes
    look the same as auth failures). For all other errors the original
    status, detail and any headers attached to the exception are kept.
    """
    if exc.status_code == 404:
        return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.detail},
        headers=getattr(exc, "headers", None),
    )


@api.post("/register_token")
def register_token(
    request_data: TokenRequest,
    db=Depends(get_db),
    user_id: int = Depends(verify_api_key_dependency),
):
    """Create or refresh the authenticated user's device token.

    The token is stored encrypted and located by its SHA-256 hash. An
    existing ``(user_id, hashed_token)`` row has its metadata refreshed and
    is set back to 'active'. Otherwise a new 'active' row is inserted.

    Raises:
        HTTPException: 403 for an invalid platform or token. 500 on a
            database error, after the transaction is rolled back.
    """
    logger.info(f"Registering token for user_id={user_id}, platform={request_data.platform}")
    if not is_valid_platform(request_data.platform) or not is_valid_token(request_data.token):
        raise HTTPException(status_code=403, detail="Unauthorized")

    secure_token = encrypt_token(request_data.token)
    hashed_token = hash_token(request_data.token)
    topics = _serialize_topics(request_data.topics)

    try:
        with closing(db.cursor()) as cursor:
            # LIMIT 1: only existence matters, and leaving extra rows unread
            # can break the next execute on unbuffered cursors.
            cursor.execute(
                "SELECT 1 FROM device_tokens WHERE user_id=%s AND hashed_token=%s LIMIT 1",
                (user_id, hashed_token),
            )
            existing = cursor.fetchone()
            if existing:
                # Also re-activate: a token unregistered earlier is 'expired'
                # and would otherwise stay unusable despite "registered".
                cursor.execute(
                    """
                    UPDATE device_tokens
                    SET status=%s, platform=%s, app_ver=%s, locale=%s, topics=%s,
                        last_seen_at=NOW()
                    WHERE user_id=%s AND hashed_token=%s
                    """,
                    (
                        'active',
                        request_data.platform,
                        request_data.app_ver,
                        request_data.locale,
                        topics,
                        user_id,
                        hashed_token,
                    ),
                )
            else:
                token_id = str(uuid.uuid4())
                logger.info(f"Creating new entry user_id={user_id}, token_id={token_id}")
                cursor.execute(
                    """
                    INSERT INTO device_tokens
                        (token_id, user_id, platform, token, hashed_token, status,
                         app_ver, locale, topics, created_at)
                    VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,NOW())
                    """,
                    (
                        token_id,
                        user_id,
                        request_data.platform,
                        secure_token,
                        hashed_token,
                        'active',
                        request_data.app_ver,
                        request_data.locale,
                        topics,
                    ),
                )
        db.commit()
    except Exception as e:
        _rollback_quietly(db)
        logger.exception(f"Failed registering token for user_id={user_id}")
        raise HTTPException(status_code=500, detail=INTERNAL_ERROR_DETAIL) from e

    logger.info(f"Success: Registering token for user_id={user_id}, platform={request_data.platform}")
    return {"status": "registered"}


@api.post("/unregister-token")
def unregister_token(
    request_data: TokenRequest,
    db=Depends(get_db),
    user_id: int = Depends(verify_api_key_dependency),
):
    """Mark the authenticated user's device token as 'expired'.

    Only rows owned by the authenticated user are changed. Success is
    reported even when no matching row exists.

    Raises:
        HTTPException: 500 on a database error, after the transaction is
            rolled back.
    """
    logger.info(f"Unregistering token for user_id={user_id}, platform={request_data.platform}")
    hashed_token = hash_token(request_data.token)
    try:
        with closing(db.cursor()) as cursor:
            # Scope by user_id so one caller cannot expire another user's token.
            cursor.execute(
                """
                UPDATE device_tokens
                SET status=%s, last_seen_at=NOW()
                WHERE user_id=%s AND hashed_token=%s
                """,
                ('expired', user_id, hashed_token),
            )
            updated = cursor.rowcount
        # Commit explicitly, as register_token does; without it the
        # update is not persisted.
        db.commit()
    except Exception as e:
        _rollback_quietly(db)
        logger.exception(f"Failed unregistering token for user_id={user_id}")
        raise HTTPException(status_code=500, detail=INTERNAL_ERROR_DETAIL) from e

    if updated == 0:
        logger.info(f"No matching token to unregister for user_id={user_id}")
    logger.info(f"Success: Unregistering token for user_id={user_id}, platform={request_data.platform}")
    return {"status": "unregistered"}


async def start_servers():
    """Serve the main API on port 8100 and the metrics app on port 9000.

    Both uvicorn servers run concurrently in the current event loop.
    """
    # NOTE(review): "main:api" assumes this module is importable as `main`.
    config_main = uvicorn.Config(
        "main:api",
        host="0.0.0.0",
        port=8100,
        log_config=LOGGING_CONFIG,
        log_level="info",
    )
    config_metrics = uvicorn.Config(
        "metrics_server:metrics_api",
        host="0.0.0.0",
        port=9000,
        log_level="info",
    )
    server_main = uvicorn.Server(config_main)
    server_metrics = uvicorn.Server(config_metrics)
    await asyncio.gather(server_main.serve(), server_metrics.serve())


if __name__ == "__main__":
    asyncio.run(start_servers())