Deployment configuration
This commit is contained in:
62
src/db.py
Normal file
62
src/db.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import mysql.connector
|
||||
from mysql.connector import pooling
|
||||
import threading
|
||||
from hvac_handler import get_secret
|
||||
import os
|
||||
import time
|
||||
import sys
|
||||
|
||||
|
||||
db_username = get_secret("secret/api/db", "username")
|
||||
db_password = get_secret("secret/api/db", "password")
|
||||
db_host = os.getenv("BACKEND_API_DB_HOST","localhost")
|
||||
db_database = os.getenv("BACKEND_API_DB_DATABASE","app")
|
||||
|
||||
MAX_RETRIES = 5
|
||||
RETRY_DELAY = 5
|
||||
|
||||
MYSQL_CONFIG = {
|
||||
"host": db_host,
|
||||
"user": db_username,
|
||||
"password": db_password,
|
||||
"database": db_database
|
||||
}
|
||||
|
||||
_pool_lock = threading.Lock()
|
||||
_connection_pool = None
|
||||
|
||||
def create_connection_pool():
|
||||
global _connection_pool
|
||||
for attempt in range(1, MAX_RETRIES+1):
|
||||
try:
|
||||
print(f"[MySQL] Attempt {attempt} to connect...")
|
||||
_connection_pool = mysql.connector.pooling.MySQLConnectionPool(
|
||||
pool_name="mypool",
|
||||
pool_size=5,
|
||||
pool_reset_session=True,
|
||||
**MYSQL_CONFIG
|
||||
)
|
||||
print("[MySQL] Connection pool created successfully.")
|
||||
return
|
||||
except mysql.connector.Error as e:
|
||||
print(f"[MySQL] Attempt {attempt} failed: {e}")
|
||||
if attempt < MAX_RETRIES:
|
||||
time.sleep(RETRY_DELAY)
|
||||
print(f"[MySQL] Failed to connect after {MAX_RETRIES} attempts — exiting.")
|
||||
sys.exit(1)
|
||||
|
||||
def get_connection_pool():
|
||||
global _connection_pool
|
||||
with _pool_lock:
|
||||
if _connection_pool is None:
|
||||
create_connection_pool
|
||||
return _connection_pool
|
||||
|
||||
def get_db():
|
||||
pool = get_connection_pool()
|
||||
conn = pool.get_connection()
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
43
src/hvac_handler.py
Normal file
43
src/hvac_handler.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import hvac
|
||||
import base64
|
||||
import os
|
||||
import time
|
||||
import sys
|
||||
|
||||
HVAC_AGENT_URL = os.getenv("HVAC_AGENT_URL","http://vault-agent:8201")
|
||||
|
||||
MAX_RETRIES = 5
|
||||
BACKOFF = 5
|
||||
|
||||
def get_client():
|
||||
for attempt in range(1, MAX_RETRIES+1):
|
||||
try:
|
||||
client = hvac.Client(url=HVAC_AGENT_URL)
|
||||
if client.is_authenticated():
|
||||
return client
|
||||
raise Exception("Not authenticated")
|
||||
except Exception as e:
|
||||
print(f"Vault connection failed (attempt {attempt}/{MAX_RETRIES}): {e}")
|
||||
time.sleep(BACKOFF * attempt)
|
||||
print("Vault unreachable after retries. Exiting.")
|
||||
sys.exit(1)
|
||||
|
||||
client = get_client()
|
||||
|
||||
def get_secret(path:str, key:str):
|
||||
try:
|
||||
secret = client.secrets.kv.v2.read_secret_version(
|
||||
mount_point="kv",
|
||||
path=path
|
||||
)
|
||||
return secret["data"]["data"][key]
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch secret '{path}:{key}': {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def encrypt_token(token: str) -> str:
|
||||
response = client.secrets.transit.encrypt_data(
|
||||
name='push-tokens',
|
||||
plaintext=base64.b64encode(token.encode()).decode()
|
||||
)
|
||||
return response['data']['ciphertext']
|
||||
13
src/logger_handler.py
Normal file
13
src/logger_handler.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import logging
|
||||
|
||||
def setup_logger(name: str) -> logging.Logger:
|
||||
logger = logging.getLogger(name)
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
return logger
|
||||
148
src/main.py
Normal file
148
src/main.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from fastapi import FastAPI, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.security.api_key import APIKeyHeader
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from typing import Optional,List
|
||||
from pydantic import BaseModel
|
||||
from validator import is_valid_platform,is_valid_token,verify_api_key
|
||||
from hvac_handler import encrypt_token
|
||||
from db import get_db
|
||||
from logger_handler import setup_logger
|
||||
import uuid
|
||||
from hashlib import sha256
|
||||
import uvicorn
|
||||
from uvicorn_logging_config import LOGGING_CONFIG
|
||||
|
||||
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
api_key_header = APIKeyHeader(name="X-API-Key")
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
return sha256(token.encode()).hexdigest()
|
||||
|
||||
class TokenRequest(BaseModel):
|
||||
user_id : int
|
||||
token : str
|
||||
platform : str
|
||||
app_ver : str
|
||||
locale : Optional[str] = None
|
||||
topics : Optional[List[str]] = None
|
||||
|
||||
api = FastAPI(
|
||||
title="Device Token Management",
|
||||
description="API for requesting tokens",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
|
||||
def verify_api_key_dependency(db=Depends(get_db), api_key: str = Depends(api_key_header)) -> int:
|
||||
cursor = db.cursor()
|
||||
cursor.execute("SELECT user_id, api_key FROM users WHERE status = 'active'")
|
||||
for user_id, hashed_key in cursor.fetchall():
|
||||
if verify_api_key(api_key=api_key, hashed=hashed_key):
|
||||
return user_id
|
||||
raise HTTPException(status_code=403, detail="Unauthorized here")
|
||||
|
||||
@api.exception_handler(StarletteHTTPException)
|
||||
async def custom_http_exception_handler(request,exc):
|
||||
if exc.status_code == 404:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Unauthorized"}
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"detail": exc.detail}
|
||||
)
|
||||
|
||||
@api.post("/register_token")
|
||||
def register_token(
|
||||
request_data: TokenRequest,
|
||||
db = Depends(get_db),
|
||||
user_id: int = Depends(verify_api_key_dependency)
|
||||
):
|
||||
logger.info(f"Registering token for user_id={user_id}, platform={request_data.platform}")
|
||||
if not is_valid_platform(request_data.platform) or not is_valid_token(request_data.token):
|
||||
raise HTTPException(status_code=403,detail="Unathorized")
|
||||
|
||||
secure_token = encrypt_token(request_data.token)
|
||||
hashed_token = hash_token(request_data.token)
|
||||
|
||||
try:
|
||||
cursor = db.cursor()
|
||||
cursor.execute(
|
||||
"SELECT * FROM device_tokens WHERE user_id=%s AND hashed_token=%s",
|
||||
(user_id,hashed_token))
|
||||
existing = cursor.fetchone()
|
||||
|
||||
if existing:
|
||||
cursor.execute("""
|
||||
UPDATE device_tokens
|
||||
SET platform=%s, app_ver=%s, locale=%s, topics=%s, last_seen_at=NOW()
|
||||
WHERE user_id=%s AND hashed_token=%s
|
||||
""", (request_data.platform,
|
||||
request_data.app_ver,
|
||||
request_data.locale,
|
||||
request_data.topics,
|
||||
user_id,
|
||||
hashed_token
|
||||
))
|
||||
else:
|
||||
token_id = str(uuid.uuid4())
|
||||
logger.info(f"Creating new entry user_id={user_id}, token_id={token_id}")
|
||||
cursor.execute("""
|
||||
INSERT INTO device_tokens
|
||||
(token_id, user_id, platform, token, hashed_token, status, app_ver, locale, topics, created_at)
|
||||
VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,NOW())
|
||||
""", (token_id,
|
||||
user_id,
|
||||
request_data.platform,
|
||||
secure_token,
|
||||
hashed_token,
|
||||
'active',
|
||||
request_data.app_ver,
|
||||
request_data.locale,
|
||||
request_data.topics
|
||||
))
|
||||
|
||||
db.commit()
|
||||
logger.info(f"Success: Registering token for user_id={user_id}, platform={request_data.platform}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
return {"status":"registered"}
|
||||
|
||||
|
||||
@api.post("/unregister-token")
|
||||
def unregister_token(
|
||||
request_data: TokenRequest,
|
||||
db = Depends(get_db),
|
||||
user_id: int = Depends(verify_api_key_dependency)
|
||||
):
|
||||
logger.info(f"Unregistering token for user_id={user_id}, platform={request_data.platform}")
|
||||
hashed_token = hash_token(request_data.token)
|
||||
try:
|
||||
cursor = db.cursor()
|
||||
cursor.execute("""
|
||||
UPDATE device_tokens
|
||||
SET status=%s, last_seen_at=NOW()
|
||||
WHERE hashed_token=%s
|
||||
""", ('expired', hashed_token))
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
logger.info(f"Success: Unregistering token for user_id={user_id}, platform={request_data.platform}")
|
||||
return {"status":"unregistered"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"main:api",
|
||||
host="0.0.0.0",
|
||||
port=8100,
|
||||
log_config=LOGGING_CONFIG,
|
||||
log_level="info"
|
||||
)
|
||||
39
src/uvicorn_logging_config.py
Normal file
39
src/uvicorn_logging_config.py
Normal file
@@ -0,0 +1,39 @@
|
||||
LOGGING_CONFIG = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"default": {
|
||||
"format": "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
"datefmt": "%Y-%m-%d %H:%M:%S",
|
||||
}
|
||||
},
|
||||
"handlers": {
|
||||
"default": {
|
||||
"class": "logging.StreamHandler",
|
||||
"formatter": "default",
|
||||
"stream": "ext://sys.stdout"
|
||||
}
|
||||
},
|
||||
"loggers": {
|
||||
"": { # root logger
|
||||
"handlers": ["default"],
|
||||
"level": "INFO",
|
||||
"propagate": False
|
||||
},
|
||||
"uvicorn": {
|
||||
"handlers": ["default"],
|
||||
"level": "INFO",
|
||||
"propagate": False
|
||||
},
|
||||
"uvicorn.error": {
|
||||
"handlers": ["default"],
|
||||
"level": "INFO",
|
||||
"propagate": False
|
||||
},
|
||||
"uvicorn.access": {
|
||||
"handlers": ["default"],
|
||||
"level": "INFO",
|
||||
"propagate": False
|
||||
}
|
||||
}
|
||||
}
|
||||
33
src/validator.py
Normal file
33
src/validator.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from argon2 import PasswordHasher
|
||||
|
||||
def is_valid_platform(platform):
|
||||
if platform not in ["ios","android","web"]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def is_valid_token(token): #Later check for specific Firebase tokens
|
||||
"""
|
||||
Correct length
|
||||
No malicious characters
|
||||
Freshness?
|
||||
"""
|
||||
return True
|
||||
|
||||
ph = PasswordHasher()
|
||||
|
||||
def hash_api_key(api_key: str) -> str:
|
||||
return ph.hash(api_key)
|
||||
|
||||
def verify_api_key(api_key: str, hashed: str) -> bool:
|
||||
try:
|
||||
return ph.verify(hashed, api_key)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
if __name__=="__main__":
|
||||
plain_key = "super-secret-api-key"
|
||||
#hashed_key = hash_api_key(plain_key)
|
||||
hashed_key = '$argon2id$v=19$m=65536,t=3,p=4$vqU+MRafVW1b8AtF+zHb0w$p1J4Gyb0jhlVtKgYyjTITxfU97YaayeS3s3qFFP5sVM'
|
||||
|
||||
print("Hashed API Key:", hashed_key)
|
||||
print("Verification:", verify_api_key(plain_key, hashed_key))
|
||||
Reference in New Issue
Block a user