# Captured code-viewer metadata (not part of the program):
# All checks were successful
# Build & Publish to GHCR / build (push) Successful in 38s
# 196 lines
# 5.6 KiB
# Python
# 196 lines
# 5.6 KiB
# Python
import asyncio
import json
import uuid
from contextlib import asynccontextmanager
from hashlib import sha256
from typing import List, Optional

import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.security.api_key import APIKeyHeader
from pydantic import BaseModel
from starlette.exceptions import HTTPException as StarletteHTTPException

from db import close_connection_pool, create_connection_pool, get_db
from logger_handler import setup_logger
from metrics_server import REQUEST_COUNTER
from secret_handler import encrypt_token
from uvicorn_logging_config import LOGGING_CONFIG
from validator import is_valid_platform, is_valid_token, verify_api_key
|
|
|
|
|
|
# Module logger, obtained from the project's setup_logger helper.
logger = setup_logger(__name__)

# Security scheme that reads the client's API key from the "X-API-Key" request
# header; verify_api_key_dependency depends on it to authenticate callers.
api_key_header = APIKeyHeader(name="X-API-Key")
|
|
|
|
|
|
def hash_token(token: str) -> str:
    """Return the lowercase hex SHA-256 digest of *token*.

    The digest is deterministic, so it can serve as the lookup key
    (``hashed_token``) for a token whose stored value is encrypted.
    """
    digest = sha256()
    digest.update(token.encode())
    return digest.hexdigest()
|
|
|
|
class TokenRequest(BaseModel):
    """JSON body accepted by both ``/register_token`` and ``/unregister-token``.

    Fields:
        user_id: client-supplied user id. NOTE(review): the endpoints in this
            module use the id resolved from the API key and never read this field.
        token: the device token in plaintext, as sent by the client.
        platform: client platform name, checked with ``is_valid_platform``.
        app_ver: client application version string.
        locale: optional client locale; defaults to None.
        topics: optional list of topic names; defaults to None.
    """

    user_id: int
    token: str
    platform: str
    app_ver: str
    locale: Optional[str] = None
    topics: Optional[List[str]] = None
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the MySQL connection pool for the lifetime of the application.

    The pool is created before the app starts serving and closed when the
    lifespan ends. The close runs in ``finally`` so the pool is released even
    when the lifespan is exited with an exception (for example on cancellation
    during shutdown), which a bare statement after ``yield`` would skip.
    """
    logger.info("Starting application...")

    logger.info("Creating MySQL connection pool...")
    create_connection_pool()

    try:
        yield
    finally:
        logger.info("Closing MySQL connection pool...")
        close_connection_pool()
|
|
|
|
|
|
# FastAPI application. The `lifespan` handler creates the MySQL connection
# pool on startup and closes it on shutdown.
api = FastAPI(
    title="Device Token Management",
    description="API for requesting tokens",
    version="1.0.0",
    lifespan=lifespan
)
|
|
|
|
|
|
@api.middleware("http")
async def prometheus_middleware(request: Request, call_next):
    """Count every HTTP request in ``REQUEST_COUNTER``.

    The counter is labelled with (method, path, status). If the downstream
    handler raises, the request is still counted with status 500 and the
    exception propagates unchanged.
    """
    # NOTE(review): request.url.path is the raw path, so every distinct URL
    # (including path parameters or scanned garbage paths) creates a new label
    # series. Consider labelling by the matched route template instead.
    observed_status = 500  # reported when call_next raises
    try:
        response = await call_next(request)
        observed_status = response.status_code
    finally:
        REQUEST_COUNTER.labels(request.method, request.url.path, observed_status).inc()
    return response
|
|
|
|
@api.get("/health", tags=["Health"])
def return_health(db=Depends(get_db)):
    """Report liveness plus database reachability.

    Runs ``SELECT 1`` on the connection provided by ``get_db``.

    Returns:
        JSONResponse with status 200 and ``{"status": "ok", ...}`` when the
        query succeeds. Otherwise it logs the failure and returns status 500
        with ``{"status": "error", ...}``.
    """
    db_status = "ok"
    cursor = None
    try:
        cursor = db.cursor()
        cursor.execute("SELECT 1")
        cursor.fetchone()
    except Exception as e:
        logger.error(f"Health check DB failed: {e}")
        db_status = "error"
    finally:
        # Release the cursor on every path. This is best-effort: a failing close
        # (e.g. on a dead connection) must not turn the probe into an unhandled error.
        if cursor is not None:
            try:
                cursor.close()
            except Exception as close_err:
                logger.error(f"Health check cursor close failed: {close_err}")

    return JSONResponse(
        status_code=200 if db_status == "ok" else 500,
        content={"status": db_status, "message": "Service is running"}
    )
|
|
|
|
|
|
def verify_api_key_dependency(db=Depends(get_db), api_key: str = Depends(api_key_header)) -> int:
    """Resolve the ``X-API-Key`` header to the id of an active user.

    Loads ``(user_id, api_key)`` for every user with ``status = 'active'``. Each
    stored key is checked against the supplied key with ``verify_api_key``, and
    the first match wins.

    Returns:
        The ``user_id`` of the first active user whose stored key matches.

    Raises:
        HTTPException: 403 when no active user's key matches.
    """
    cursor = db.cursor()
    try:
        cursor.execute("SELECT user_id, api_key FROM users WHERE status = 'active'")
        rows = cursor.fetchall()
    finally:
        # fetchall() has already materialized the rows, so the cursor can be
        # released before the (potentially slow) verification loop.
        cursor.close()

    # NOTE(review): this makes one verify_api_key call per active user on every
    # request. If verify_api_key is a salted/slow hash check, cost grows
    # linearly with the user count; a key-id/prefix column would allow a
    # single-row lookup.
    for user_id, hashed_key in rows:
        if verify_api_key(api_key=api_key, hashed=hashed_key):
            return user_id

    raise HTTPException(status_code=403, detail="Unauthorized here")
|
|
|
|
@api.exception_handler(StarletteHTTPException)
async def custom_http_exception_handler(request, exc):
    """Render HTTP errors as ``{"detail": ...}`` JSON, answering 404s with 401.

    Any 404 is returned as 401 ``"Unauthorized"`` (presumably so unknown paths
    are not distinguishable from protected ones; confirm intent). All other
    errors keep their status code and detail. Headers attached to the exception
    are forwarded, as Starlette's default handler does. Examples are ``Allow``
    on a 405 and ``WWW-Authenticate`` on an auth failure; previously they were
    silently dropped.
    """
    if exc.status_code == 404:
        return JSONResponse(
            status_code=401,
            content={"detail": "Unauthorized"}
        )
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.detail},
        headers=getattr(exc, "headers", None),
    )
|
|
|
|
@api.post("/register_token")
def register_token(
    request_data: TokenRequest,
    db = Depends(get_db),
    user_id: int = Depends(verify_api_key_dependency)
):
    """Register or refresh a device token for the API-key owner.

    ``user_id`` comes from ``verify_api_key_dependency`` (the API key), not from
    ``request_data.user_id``. The platform and token are checked with
    ``is_valid_platform`` / ``is_valid_token``. The token is then encrypted with
    ``encrypt_token`` for storage and looked up by its SHA-256 hash.

    If a row already exists for this user and hash, it is updated in place and
    set back to ``'active'``. Otherwise a new row with a random UUID
    ``token_id`` is inserted. The write is committed before returning.

    Returns:
        ``{"status": "registered"}``.

    Raises:
        HTTPException: 403 if the platform or token fails validation. 500 if
            the database work fails; in that case the transaction is rolled
            back, the error is logged server-side, and the client receives a
            generic detail instead of the raw database error.
    """
    logger.info(f"Registering token for user_id={user_id}, platform={request_data.platform}")

    if not is_valid_platform(request_data.platform) or not is_valid_token(request_data.token):
        raise HTTPException(status_code=403, detail="Unauthorized")

    secure_token = encrypt_token(request_data.token)
    hashed_token = hash_token(request_data.token)
    # A Python list cannot be bound as a single SQL parameter, so the topics
    # are serialized to a JSON array string (NULL when omitted).
    # NOTE(review): assumes the topics column stores JSON/text; confirm schema.
    topics = None if request_data.topics is None else json.dumps(request_data.topics)

    cursor = None
    try:
        cursor = db.cursor()
        # Existence check only. LIMIT 1 also guarantees no unread rows are left
        # on the cursor before the next execute().
        cursor.execute(
            "SELECT 1 FROM device_tokens WHERE user_id=%s AND hashed_token=%s LIMIT 1",
            (user_id, hashed_token),
        )
        existing = cursor.fetchone()

        if existing:
            # Re-registering must re-activate a token that /unregister-token
            # previously marked 'expired'.
            cursor.execute("""
                UPDATE device_tokens
                SET platform=%s, app_ver=%s, locale=%s, topics=%s, status=%s, last_seen_at=NOW()
                WHERE user_id=%s AND hashed_token=%s
            """, (request_data.platform,
                  request_data.app_ver,
                  request_data.locale,
                  topics,
                  'active',
                  user_id,
                  hashed_token,
                  ))
        else:
            token_id = str(uuid.uuid4())
            logger.info(f"Creating new entry user_id={user_id}, token_id={token_id}")
            cursor.execute("""
                INSERT INTO device_tokens
                (token_id, user_id, platform, token, hashed_token, status, app_ver, locale, topics, created_at)
                VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,NOW())
            """, (token_id,
                  user_id,
                  request_data.platform,
                  secure_token,
                  hashed_token,
                  'active',
                  request_data.app_ver,
                  request_data.locale,
                  topics,
                  ))

        db.commit()
    except Exception as e:
        # Keep database details in the server log; do not echo them to clients.
        logger.error(f"Failed registering token for user_id={user_id}: {e}")
        try:
            db.rollback()
        except Exception as rollback_err:
            logger.error(f"Rollback failed for user_id={user_id}: {rollback_err}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
    finally:
        if cursor is not None:
            try:
                cursor.close()
            except Exception as close_err:
                logger.error(f"Cursor close failed for user_id={user_id}: {close_err}")

    logger.info(f"Success: Registering token for user_id={user_id}, platform={request_data.platform}")
    return {"status":"registered"}
|
|
|
|
|
|
@api.post("/unregister-token")
def unregister_token(
    request_data: TokenRequest,
    db = Depends(get_db),
    user_id: int = Depends(verify_api_key_dependency)
):
    """Mark one of the API-key owner's device tokens as ``'expired'``.

    The token is located by its SHA-256 hash and restricted to rows owned by
    ``user_id`` (resolved from the API key). The caller therefore cannot expire
    another user's token. The change is committed explicitly, mirroring
    ``register_token``.

    Returns:
        ``{"status": "unregistered"}``. This is also returned when no matching
        row exists, as before.

    Raises:
        HTTPException: 500 if the database work fails. The transaction is
            rolled back, the error is logged server-side, and the client gets a
            generic detail.
    """
    logger.info(f"Unregistering token for user_id={user_id}, platform={request_data.platform}")

    hashed_token = hash_token(request_data.token)

    cursor = None
    try:
        cursor = db.cursor()
        # Scope by user_id: previously any authenticated caller could expire any
        # user's token just by knowing it.
        cursor.execute("""
            UPDATE device_tokens
            SET status=%s, last_seen_at=NOW()
            WHERE user_id=%s AND hashed_token=%s
        """, ('expired', user_id, hashed_token))
        # Without a commit the UPDATE is discarded when the connection is
        # returned or closed (unless the pool happens to run in autocommit).
        db.commit()
    except Exception as e:
        logger.error(f"Failed unregistering token for user_id={user_id}: {e}")
        try:
            db.rollback()
        except Exception as rollback_err:
            logger.error(f"Rollback failed for user_id={user_id}: {rollback_err}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
    finally:
        if cursor is not None:
            try:
                cursor.close()
            except Exception as close_err:
                logger.error(f"Cursor close failed for user_id={user_id}: {close_err}")

    logger.info(f"Success: Unregistering token for user_id={user_id}, platform={request_data.platform}")
    return {"status":"unregistered"}
|
|
|
|
async def start_servers():
    """Run the API app and the metrics app concurrently in this process.

    The API (``main:api``) listens on 0.0.0.0:8100 and the metrics app
    (``metrics_server:metrics_api``) on 0.0.0.0:9000. Both ``serve()``
    coroutines are awaited together with ``asyncio.gather``.
    """
    # uvicorn.Config applies its log_config (logging.config.dictConfig) as soon
    # as it is constructed. Left unset, the metrics Config falls back to
    # uvicorn's built-in default config. Because it is built second, it would
    # overwrite the custom LOGGING_CONFIG just installed for the API server.
    # Passing the same config to both keeps one consistent logging setup.
    config_main = uvicorn.Config("main:api", host="0.0.0.0", port=8100, log_config=LOGGING_CONFIG, log_level="info")
    config_metrics = uvicorn.Config(
        "metrics_server:metrics_api",
        host="0.0.0.0",
        port=9000,
        log_config=LOGGING_CONFIG,
        log_level="info",
    )

    server_main = uvicorn.Server(config_main)
    server_metrics = uvicorn.Server(config_metrics)

    await asyncio.gather(server_main.serve(), server_metrics.serve())
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the API and metrics servers on a new event loop.
    asyncio.run(start_servers())
|