Progress!
This commit is contained in:
parent
d193bc05f1
commit
b0e88a1dbc
2
.gitignore
vendored
2
.gitignore
vendored
@ -167,4 +167,4 @@ cython_debug/
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
test.py
|
||||
|
||||
@ -8,4 +8,4 @@ WORKDIR /app
|
||||
|
||||
COPY src/ /app/
|
||||
|
||||
ENTRYPOINT ["python","main.py"]
|
||||
ENTRYPOINT ["python","rabbitmq_handler.py"]
|
||||
|
||||
32
src/db.py
32
src/db.py
@ -1,3 +1,4 @@
|
||||
from contextlib import asynccontextmanager
|
||||
import aiomysql
|
||||
import asyncio
|
||||
from secret_handler import return_credentials
|
||||
@ -51,32 +52,19 @@ class DBManager:
|
||||
logger.debug("[DB] Healthcheck OK")
|
||||
except Exception as e:
|
||||
logger.warning(f"[DB] Healthcheck failed: {e}")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def acquire(self):
|
||||
if not self._pool:
|
||||
raise RuntimeError("DB pool not initialized")
|
||||
return await self._pool.acquire()
|
||||
conn = await self._pool.acquire()
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
self._pool.release(conn)
|
||||
|
||||
async def release(self, conn):
|
||||
if self._pool:
|
||||
self._pool.release(conn)
|
||||
|
||||
async def execute(self, query, *args, retries=3):
|
||||
for attempt in range(1, retries + 1):
|
||||
conn = await self.acquire()
|
||||
try:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, args)
|
||||
if cur.description:
|
||||
return await cur.fetchall()
|
||||
return None
|
||||
except aiomysql.OperationalError as e:
|
||||
logger.warning(f"[DB] Query failed (attempt {attempt}/{retries}): {e}")
|
||||
await asyncio.sleep(2 ** (attempt - 1))
|
||||
finally:
|
||||
await self.release(conn)
|
||||
raise RuntimeError("DB query failed after retries")
|
||||
|
||||
async def close(self):
|
||||
self._closing = True
|
||||
if self._health_task and not self._health_task.done():
|
||||
@ -90,4 +78,6 @@ class DBManager:
|
||||
await self._pool.wait_closed()
|
||||
logger.info("[DB] Connection pool closed")
|
||||
|
||||
db_manager = DBManager(host=db_host, port=3306, user=db_username, password=db_password, db=db_database)
|
||||
#db_manager = DBManager(host=db_host, port=3306, user=db_username, password=db_password, db=db_database)
|
||||
db_manager = DBManager(host=db_host, port=30006, user=db_username, password=db_password, db=db_database)
|
||||
|
||||
|
||||
@ -1,21 +1,22 @@
|
||||
import asyncio
|
||||
import aio_pika
|
||||
from aio_pika.exceptions import AMQPException
|
||||
from secret_handler import return_credentials
|
||||
from secret_handler import return_credentials, database_lookup, decrypt_token
|
||||
import os
|
||||
from logger_handler import setup_logger
|
||||
import json
|
||||
from db import db_manager
|
||||
from send_notification import send_notification
|
||||
import signal
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
rmq_username = return_credentials("/etc/secrets/rmq_username")
|
||||
rmq_password = return_credentials("/etc/secrets/rmq_password")
|
||||
rmq_host = os.getenv("BACKEND_API_INTERNAL_RMQ_HOST", "localhost")
|
||||
rmq_vhost = os.getenv("BACKEND_API_INTERNAL_RMQ_VHOST", "app_notifications")
|
||||
rmq_exchange = os.getenv("BACKEND_API_INTERNAL_RMQ_EXCHANGE", "app_notifications")
|
||||
rmq_host = os.getenv("BACKEND_PN_RMQ_HOST", "localhost:30672")
|
||||
rmq_vhost = os.getenv("BACKEND_PN_RMQ_VHOST", "app_notifications")
|
||||
rmq_exchange = os.getenv("BACKEND_PN_RMQ_EXCHANGE", "app_notifications")
|
||||
|
||||
RABBITMQ_URL = f"amqp://{rmq_username}:{rmq_password}@{rmq_host}/{rmq_vhost}"
|
||||
|
||||
@ -37,7 +38,7 @@ class RabbitMQConsumer:
|
||||
self.exchange = await self.channel.declare_exchange(
|
||||
self.exchange_name, aio_pika.ExchangeType.TOPIC, durable=True
|
||||
)
|
||||
self.queue = await self.channel.declare_queue("backend_push_notifications", durable=True)
|
||||
self.queue = await self.channel.declare_queue("notifications", durable=True)
|
||||
await self.queue.bind(self.exchange, routing_key="notify.user.*")
|
||||
logger.info("[Consumer] Connected, queue bound to notify.user.*")
|
||||
|
||||
@ -48,7 +49,18 @@ class RabbitMQConsumer:
|
||||
try:
|
||||
data = json.loads(message.body.decode())
|
||||
logger.info(f"[Consumer] Received: {data}")
|
||||
await send_notification(routing_key=message.routing_key,message=data,db_manager=self.db_manager)
|
||||
logger.info(message.routing_key)
|
||||
encrypted_tokens = await database_lookup(message.routing_key, db_manager)
|
||||
if not encrypted_tokens:
|
||||
logger.warning(f"No push tokens found for user {message.routing_key}")
|
||||
return
|
||||
|
||||
token_map = {row["uuid"]: row["token"].decode() for row in encrypted_tokens}
|
||||
for uuid, token in token_map.items():
|
||||
decrypted_token = decrypt_token(token)
|
||||
token_map[uuid] = decrypted_token
|
||||
await send_notification(message=data,push_tokens=token_map)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"[Consumer] Bad message, discarding: {e}")
|
||||
await message.nack(requeue=False)
|
||||
@ -76,11 +88,17 @@ class RabbitMQConsumer:
|
||||
|
||||
|
||||
async def main():
|
||||
await db_manager.connect()
|
||||
consumer = RabbitMQConsumer(db_manager=db_manager)
|
||||
await consumer.connect()
|
||||
logger.info("Creating MySQL connection pool...")
|
||||
await db_manager.connect()
|
||||
await consumer.consume()
|
||||
stop_event = asyncio.Event()
|
||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||
asyncio.get_running_loop().add_signal_handler(sig, stop_event.set)
|
||||
|
||||
await stop_event.wait()
|
||||
await consumer.close()
|
||||
await db_manager.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
@ -1,12 +1,50 @@
|
||||
from cryptography.fernet import Fernet
|
||||
import sys
|
||||
import asyncio
|
||||
from logger_handler import setup_logger
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
try:
|
||||
with open("/etc/secrets/encryption_key","rb") as file:
|
||||
encryption_key = file.read()
|
||||
except FileNotFoundError:
|
||||
logger.fatal("[Secret Handler] Encryption key not found")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
logger.fatal(f"[Secret Handler] Failed to read encryption key: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
fernet = Fernet(encryption_key)
|
||||
|
||||
def encrypt_token(token:str)->str:
|
||||
return fernet.encrypt(token.encode()).decode()
|
||||
|
||||
def decrypt_token(token:str)->str:
|
||||
return fernet.decrypt(token.encode()).decode()
|
||||
|
||||
def return_credentials(path: str)->str:
|
||||
try:
|
||||
with open (path) as file:
|
||||
return file.read().strip()
|
||||
except FileNotFoundError:
|
||||
print(f"[FATAL] Secret file not found: {path}")
|
||||
logger.fatal(f"[Secret Handler] Secret file not found: {path}")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"[FATAL] Failed to read secret file {path}: {e}")
|
||||
logger.fatal(f"[Secret Handler] Failed to read secret file {path}: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
async def database_lookup(routing_key: str, db_manager):
|
||||
try:
|
||||
user_id = int(routing_key.split('.')[-1])
|
||||
except ValueError:
|
||||
logger.error(f"[DB] Invalid user id supplied:{routing_key}")
|
||||
return []
|
||||
|
||||
async with db_manager.acquire() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute("SELECT token_id AS uuid,token FROM device_tokens WHERE user_id=%s",
|
||||
(user_id,))
|
||||
if cur.description:
|
||||
return await cur.fetchall()
|
||||
return []
|
||||
@ -2,27 +2,20 @@ import aiohttp
|
||||
import asyncio
|
||||
from logger_handler import setup_logger
|
||||
|
||||
API_ENDPOINT="https://exp.host/fakeUSer/api/v2/push/send"
|
||||
#API_ENDPOINT="https://exp.host/fakeUSer/api/v2/push/send"
|
||||
API_ENDPOINT="http://127.0.0.1:8000/honk"
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
|
||||
|
||||
async def send_notification(
|
||||
routing_key: str,
|
||||
message: dict,
|
||||
db_manager,
|
||||
push_tokens,
|
||||
max_retries: int = 5,
|
||||
timeout: int = 5,
|
||||
):
|
||||
push_tokens = await database_lookup(routing_key, db_manager)
|
||||
if not push_tokens:
|
||||
logger.warning(f"No push tokens found for user {routing_key}")
|
||||
return
|
||||
|
||||
results = {}
|
||||
for token, uuid in push_tokens:
|
||||
results[token] = await _send_to_token(token, uuid, message, max_retries, timeout)
|
||||
for uuid, token in push_tokens.items():
|
||||
results[uuid] = await _send_to_token(token, uuid, message, max_retries, timeout)
|
||||
|
||||
return results
|
||||
|
||||
@ -38,7 +31,7 @@ async def _send_to_token(token: str, uuid:str , message: dict, max_retries: int,
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=timeout
|
||||
) as response:
|
||||
await response.raise_for_status()
|
||||
response.raise_for_status()
|
||||
logger.info(f"Notification sent successfully to uuid {uuid}")
|
||||
return {"status": "ok"}
|
||||
|
||||
@ -65,19 +58,3 @@ def create_payload(push_token: str, message: dict) -> dict:
|
||||
"sound": "default",
|
||||
"priority": "high"
|
||||
}
|
||||
|
||||
async def database_lookup(routing_key: str, db_manager):
|
||||
try:
|
||||
user_id = int(routing_key.split('.')[-1])
|
||||
except ValueError:
|
||||
logger.error(f"[DB] Invalid user id supplied:{routing_key}")
|
||||
return []
|
||||
|
||||
async with db_manager.acquire() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute("SELECT tokend_id AS uuid,token FROM device_tokens WHERE user_id=%s",
|
||||
(user_id,))
|
||||
if cur.description:
|
||||
return await cur.fetchall()
|
||||
return []
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user