Progress!
This commit is contained in:
parent
d193bc05f1
commit
b0e88a1dbc
2
.gitignore
vendored
2
.gitignore
vendored
@ -167,4 +167,4 @@ cython_debug/
|
|||||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
#.idea/
|
#.idea/
|
||||||
|
test.py
|
||||||
|
|||||||
@ -8,4 +8,4 @@ WORKDIR /app
|
|||||||
|
|
||||||
COPY src/ /app/
|
COPY src/ /app/
|
||||||
|
|
||||||
ENTRYPOINT ["python","main.py"]
|
ENTRYPOINT ["python","rabbitmq_handler.py"]
|
||||||
|
|||||||
30
src/db.py
30
src/db.py
@ -1,3 +1,4 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
import aiomysql
|
import aiomysql
|
||||||
import asyncio
|
import asyncio
|
||||||
from secret_handler import return_credentials
|
from secret_handler import return_credentials
|
||||||
@ -52,31 +53,18 @@ class DBManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[DB] Healthcheck failed: {e}")
|
logger.warning(f"[DB] Healthcheck failed: {e}")
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
async def acquire(self):
|
async def acquire(self):
|
||||||
if not self._pool:
|
conn = await self._pool.acquire()
|
||||||
raise RuntimeError("DB pool not initialized")
|
try:
|
||||||
return await self._pool.acquire()
|
yield conn
|
||||||
|
finally:
|
||||||
|
self._pool.release(conn)
|
||||||
|
|
||||||
async def release(self, conn):
|
async def release(self, conn):
|
||||||
if self._pool:
|
if self._pool:
|
||||||
self._pool.release(conn)
|
self._pool.release(conn)
|
||||||
|
|
||||||
async def execute(self, query, *args, retries=3):
|
|
||||||
for attempt in range(1, retries + 1):
|
|
||||||
conn = await self.acquire()
|
|
||||||
try:
|
|
||||||
async with conn.cursor() as cur:
|
|
||||||
await cur.execute(query, args)
|
|
||||||
if cur.description:
|
|
||||||
return await cur.fetchall()
|
|
||||||
return None
|
|
||||||
except aiomysql.OperationalError as e:
|
|
||||||
logger.warning(f"[DB] Query failed (attempt {attempt}/{retries}): {e}")
|
|
||||||
await asyncio.sleep(2 ** (attempt - 1))
|
|
||||||
finally:
|
|
||||||
await self.release(conn)
|
|
||||||
raise RuntimeError("DB query failed after retries")
|
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
self._closing = True
|
self._closing = True
|
||||||
if self._health_task and not self._health_task.done():
|
if self._health_task and not self._health_task.done():
|
||||||
@ -90,4 +78,6 @@ class DBManager:
|
|||||||
await self._pool.wait_closed()
|
await self._pool.wait_closed()
|
||||||
logger.info("[DB] Connection pool closed")
|
logger.info("[DB] Connection pool closed")
|
||||||
|
|
||||||
db_manager = DBManager(host=db_host, port=3306, user=db_username, password=db_password, db=db_database)
|
#db_manager = DBManager(host=db_host, port=3306, user=db_username, password=db_password, db=db_database)
|
||||||
|
db_manager = DBManager(host=db_host, port=30006, user=db_username, password=db_password, db=db_database)
|
||||||
|
|
||||||
|
|||||||
@ -1,21 +1,22 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import aio_pika
|
import aio_pika
|
||||||
from aio_pika.exceptions import AMQPException
|
from aio_pika.exceptions import AMQPException
|
||||||
from secret_handler import return_credentials
|
from secret_handler import return_credentials, database_lookup, decrypt_token
|
||||||
import os
|
import os
|
||||||
from logger_handler import setup_logger
|
from logger_handler import setup_logger
|
||||||
import json
|
import json
|
||||||
from db import db_manager
|
from db import db_manager
|
||||||
from send_notification import send_notification
|
from send_notification import send_notification
|
||||||
|
import signal
|
||||||
|
|
||||||
logger = setup_logger(__name__)
|
logger = setup_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
rmq_username = return_credentials("/etc/secrets/rmq_username")
|
rmq_username = return_credentials("/etc/secrets/rmq_username")
|
||||||
rmq_password = return_credentials("/etc/secrets/rmq_password")
|
rmq_password = return_credentials("/etc/secrets/rmq_password")
|
||||||
rmq_host = os.getenv("BACKEND_API_INTERNAL_RMQ_HOST", "localhost")
|
rmq_host = os.getenv("BACKEND_PN_RMQ_HOST", "localhost:30672")
|
||||||
rmq_vhost = os.getenv("BACKEND_API_INTERNAL_RMQ_VHOST", "app_notifications")
|
rmq_vhost = os.getenv("BACKEND_PN_RMQ_VHOST", "app_notifications")
|
||||||
rmq_exchange = os.getenv("BACKEND_API_INTERNAL_RMQ_EXCHANGE", "app_notifications")
|
rmq_exchange = os.getenv("BACKEND_PN_RMQ_EXCHANGE", "app_notifications")
|
||||||
|
|
||||||
RABBITMQ_URL = f"amqp://{rmq_username}:{rmq_password}@{rmq_host}/{rmq_vhost}"
|
RABBITMQ_URL = f"amqp://{rmq_username}:{rmq_password}@{rmq_host}/{rmq_vhost}"
|
||||||
|
|
||||||
@ -37,7 +38,7 @@ class RabbitMQConsumer:
|
|||||||
self.exchange = await self.channel.declare_exchange(
|
self.exchange = await self.channel.declare_exchange(
|
||||||
self.exchange_name, aio_pika.ExchangeType.TOPIC, durable=True
|
self.exchange_name, aio_pika.ExchangeType.TOPIC, durable=True
|
||||||
)
|
)
|
||||||
self.queue = await self.channel.declare_queue("backend_push_notifications", durable=True)
|
self.queue = await self.channel.declare_queue("notifications", durable=True)
|
||||||
await self.queue.bind(self.exchange, routing_key="notify.user.*")
|
await self.queue.bind(self.exchange, routing_key="notify.user.*")
|
||||||
logger.info("[Consumer] Connected, queue bound to notify.user.*")
|
logger.info("[Consumer] Connected, queue bound to notify.user.*")
|
||||||
|
|
||||||
@ -48,7 +49,18 @@ class RabbitMQConsumer:
|
|||||||
try:
|
try:
|
||||||
data = json.loads(message.body.decode())
|
data = json.loads(message.body.decode())
|
||||||
logger.info(f"[Consumer] Received: {data}")
|
logger.info(f"[Consumer] Received: {data}")
|
||||||
await send_notification(routing_key=message.routing_key,message=data,db_manager=self.db_manager)
|
logger.info(message.routing_key)
|
||||||
|
encrypted_tokens = await database_lookup(message.routing_key, db_manager)
|
||||||
|
if not encrypted_tokens:
|
||||||
|
logger.warning(f"No push tokens found for user {message.routing_key}")
|
||||||
|
return
|
||||||
|
|
||||||
|
token_map = {row["uuid"]: row["token"].decode() for row in encrypted_tokens}
|
||||||
|
for uuid, token in token_map.items():
|
||||||
|
decrypted_token = decrypt_token(token)
|
||||||
|
token_map[uuid] = decrypted_token
|
||||||
|
await send_notification(message=data,push_tokens=token_map)
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.error(f"[Consumer] Bad message, discarding: {e}")
|
logger.error(f"[Consumer] Bad message, discarding: {e}")
|
||||||
await message.nack(requeue=False)
|
await message.nack(requeue=False)
|
||||||
@ -76,11 +88,17 @@ class RabbitMQConsumer:
|
|||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
await db_manager.connect()
|
||||||
consumer = RabbitMQConsumer(db_manager=db_manager)
|
consumer = RabbitMQConsumer(db_manager=db_manager)
|
||||||
await consumer.connect()
|
await consumer.connect()
|
||||||
logger.info("Creating MySQL connection pool...")
|
await consumer.consume()
|
||||||
await db_manager.connect()
|
stop_event = asyncio.Event()
|
||||||
|
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||||
|
asyncio.get_running_loop().add_signal_handler(sig, stop_event.set)
|
||||||
|
|
||||||
|
await stop_event.wait()
|
||||||
|
await consumer.close()
|
||||||
|
await db_manager.close()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|||||||
@ -1,12 +1,50 @@
|
|||||||
|
from cryptography.fernet import Fernet
|
||||||
import sys
|
import sys
|
||||||
|
import asyncio
|
||||||
|
from logger_handler import setup_logger
|
||||||
|
|
||||||
|
logger = setup_logger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open("/etc/secrets/encryption_key","rb") as file:
|
||||||
|
encryption_key = file.read()
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.fatal("[Secret Handler] Encryption key not found")
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
logger.fatal(f"[Secret Handler] Failed to read encryption key: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
fernet = Fernet(encryption_key)
|
||||||
|
|
||||||
|
def encrypt_token(token:str)->str:
|
||||||
|
return fernet.encrypt(token.encode()).decode()
|
||||||
|
|
||||||
|
def decrypt_token(token:str)->str:
|
||||||
|
return fernet.decrypt(token.encode()).decode()
|
||||||
|
|
||||||
def return_credentials(path: str)->str:
|
def return_credentials(path: str)->str:
|
||||||
try:
|
try:
|
||||||
with open (path) as file:
|
with open (path) as file:
|
||||||
return file.read().strip()
|
return file.read().strip()
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
print(f"[FATAL] Secret file not found: {path}")
|
logger.fatal(f"[Secret Handler] Secret file not found: {path}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[FATAL] Failed to read secret file {path}: {e}")
|
logger.fatal(f"[Secret Handler] Failed to read secret file {path}: {e}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
async def database_lookup(routing_key: str, db_manager):
|
||||||
|
try:
|
||||||
|
user_id = int(routing_key.split('.')[-1])
|
||||||
|
except ValueError:
|
||||||
|
logger.error(f"[DB] Invalid user id supplied:{routing_key}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async with db_manager.acquire() as conn:
|
||||||
|
async with conn.cursor() as cur:
|
||||||
|
await cur.execute("SELECT token_id AS uuid,token FROM device_tokens WHERE user_id=%s",
|
||||||
|
(user_id,))
|
||||||
|
if cur.description:
|
||||||
|
return await cur.fetchall()
|
||||||
|
return []
|
||||||
@ -2,27 +2,20 @@ import aiohttp
|
|||||||
import asyncio
|
import asyncio
|
||||||
from logger_handler import setup_logger
|
from logger_handler import setup_logger
|
||||||
|
|
||||||
API_ENDPOINT="https://exp.host/fakeUSer/api/v2/push/send"
|
#API_ENDPOINT="https://exp.host/fakeUSer/api/v2/push/send"
|
||||||
|
API_ENDPOINT="http://127.0.0.1:8000/honk"
|
||||||
logger = setup_logger(__name__)
|
logger = setup_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def send_notification(
|
async def send_notification(
|
||||||
routing_key: str,
|
|
||||||
message: dict,
|
message: dict,
|
||||||
db_manager,
|
push_tokens,
|
||||||
max_retries: int = 5,
|
max_retries: int = 5,
|
||||||
timeout: int = 5,
|
timeout: int = 5,
|
||||||
):
|
):
|
||||||
push_tokens = await database_lookup(routing_key, db_manager)
|
|
||||||
if not push_tokens:
|
|
||||||
logger.warning(f"No push tokens found for user {routing_key}")
|
|
||||||
return
|
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
for token, uuid in push_tokens:
|
for uuid, token in push_tokens.items():
|
||||||
results[token] = await _send_to_token(token, uuid, message, max_retries, timeout)
|
results[uuid] = await _send_to_token(token, uuid, message, max_retries, timeout)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@ -38,7 +31,7 @@ async def _send_to_token(token: str, uuid:str , message: dict, max_retries: int,
|
|||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
timeout=timeout
|
timeout=timeout
|
||||||
) as response:
|
) as response:
|
||||||
await response.raise_for_status()
|
response.raise_for_status()
|
||||||
logger.info(f"Notification sent successfully to uuid {uuid}")
|
logger.info(f"Notification sent successfully to uuid {uuid}")
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
@ -65,19 +58,3 @@ def create_payload(push_token: str, message: dict) -> dict:
|
|||||||
"sound": "default",
|
"sound": "default",
|
||||||
"priority": "high"
|
"priority": "high"
|
||||||
}
|
}
|
||||||
|
|
||||||
async def database_lookup(routing_key: str, db_manager):
|
|
||||||
try:
|
|
||||||
user_id = int(routing_key.split('.')[-1])
|
|
||||||
except ValueError:
|
|
||||||
logger.error(f"[DB] Invalid user id supplied:{routing_key}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
async with db_manager.acquire() as conn:
|
|
||||||
async with conn.cursor() as cur:
|
|
||||||
await cur.execute("SELECT tokend_id AS uuid,token FROM device_tokens WHERE user_id=%s",
|
|
||||||
(user_id,))
|
|
||||||
if cur.description:
|
|
||||||
return await cur.fetchall()
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user