# File: echoes-of-the-ash/api/redis_manager.py
# Snapshot metadata: 2025-11-27 16:27:01 +01:00 · 456 lines · 18 KiB · Python

"""
Redis Manager for Echoes of the Ashes.

Handles Redis pub/sub for cross-worker communication and caching for performance.

Key features:
- Pub/Sub channels for location broadcasts, personal messages, and global broadcasts
- Player session caching (location, HP, stats)
- Location player registry (set of character IDs per location)
- Inventory caching with aggressive invalidation
- Combat state caching
- Dropped-item lists per location
- Worker registry with heartbeats
- Distributed locks for background tasks
- Disconnected player tracking
- Connected-player connection counters
"""
import asyncio
import json
import time
import uuid
from typing import Dict, List, Optional, Set, Any, Callable
import redis.asyncio as redis
from redis.asyncio.client import PubSub
class RedisManager:
    """Manage the Redis connection, pub/sub, and caches for one worker process.

    Each instance gets a short random ``worker_id``. The id is stamped on every
    published message so that a worker can ignore its own messages when they
    come back through pub/sub.

    Values written to Redis hashes are stored as strings; the ``get_*`` methods
    return them as strings and leave type conversion to the caller.
    """

    # Default lifetime of a cached player session, in seconds (30 minutes).
    _SESSION_TTL = 1800

    # Compare-and-delete executed server-side, so a lock is only removed while
    # it still holds this worker's id (no window between the check and the DEL).
    _RELEASE_LOCK_SCRIPT = (
        'if redis.call("get", KEYS[1]) == ARGV[1] then '
        'return redis.call("del", KEYS[1]) '
        "else return 0 end"
    )

    def __init__(self, redis_url: str = "redis://echoes_of_the_ashes_redis:6379"):
        """Store configuration only; no network activity happens here.

        Args:
            redis_url: URL of the Redis server passed to ``redis.from_url``.
        """
        self.redis_url = redis_url
        self.redis_client: Optional[redis.Redis] = None
        self.pubsub: Optional[PubSub] = None
        self.worker_id = str(uuid.uuid4())[:8]  # Unique worker identifier
        self.subscribed_channels: Set[str] = set()
        self.message_handlers: Dict[str, Callable] = {}
        self._listener_task: Optional[asyncio.Task] = None

    def _require_client(self) -> "redis.Redis":
        """Return the Redis client, or raise if ``connect()`` has not been called.

        Raises:
            RuntimeError: If the client has not been initialized.
        """
        if self.redis_client is None:
            raise RuntimeError("Redis client not initialized")
        return self.redis_client

    @staticmethod
    async def _close(resource: Any) -> None:
        """Close a client or pubsub object, preferring ``aclose()`` when it exists."""
        closer = getattr(resource, "aclose", None)
        if closer is None:
            # Older redis-py releases only expose the coroutine ``close()``.
            closer = resource.close
        await closer()

    async def connect(self) -> None:
        """Create the Redis client and its pub/sub object."""
        self.redis_client = redis.from_url(
            self.redis_url,
            encoding="utf-8",
            decode_responses=True,
            max_connections=50,
        )
        self.pubsub = self.redis_client.pubsub()
        # A fresh pubsub object has no subscriptions, so the local record
        # must start empty or later subscribe calls would be skipped.
        self.subscribed_channels.clear()
        print(f"✅ Redis connected (Worker: {self.worker_id})")

    async def disconnect(self) -> None:
        """Stop the listener task, unsubscribe, and close the Redis connection."""
        task, self._listener_task = self._listener_task, None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                # The listener had already failed on its own; report it but
                # keep going so the connections below are still closed.
                print(f"⚠️ Redis listener stopped with error: {e}")
        try:
            if self.pubsub:
                await self.pubsub.unsubscribe()
                await self._close(self.pubsub)
        finally:
            self.subscribed_channels.clear()
            if self.redis_client:
                await self._close(self.redis_client)
        print(f"🔌 Redis disconnected (Worker: {self.worker_id})")

    # ==================== PUB/SUB ====================

    async def subscribe_to_channels(self, channels: List[str]) -> None:
        """Subscribe to every channel in *channels* that is not subscribed yet.

        Raises:
            RuntimeError: If ``connect()`` has not been called.
        """
        if not self.pubsub:
            raise RuntimeError("Redis pubsub not initialized")
        # dict.fromkeys drops duplicates while keeping the caller's order.
        new_channels = [
            channel
            for channel in dict.fromkeys(channels)
            if channel not in self.subscribed_channels
        ]
        if new_channels:
            # One SUBSCRIBE for all new channels instead of one per channel.
            await self.pubsub.subscribe(*new_channels)
            self.subscribed_channels.update(new_channels)
        print(f"📡 Worker {self.worker_id} subscribed to {len(channels)} channels")

    async def unsubscribe_from_channel(self, channel: str) -> None:
        """Unsubscribe from *channel* if this worker is subscribed to it."""
        if self.pubsub and channel in self.subscribed_channels:
            await self.pubsub.unsubscribe(channel)
            self.subscribed_channels.discard(channel)

    async def publish_to_channel(self, channel: str, message: Dict[str, Any]) -> None:
        """Publish *message* as JSON on a Redis channel.

        ``worker_id`` and ``timestamp`` are added to the payload. Keys with the
        same names in *message* take precedence over the added ones.

        Raises:
            RuntimeError: If ``connect()`` has not been called.
        """
        client = self._require_client()
        message_data = {
            "worker_id": self.worker_id,
            "timestamp": time.time(),
            **message,
        }
        await client.publish(channel, json.dumps(message_data))

    async def publish_to_location(self, location_id: str, message: Dict[str, Any]) -> None:
        """Publish *message* on the ``location:<location_id>`` channel."""
        await self.publish_to_channel(f"location:{location_id}", message)

    async def publish_to_player(self, character_id: int, message: Dict[str, Any]) -> None:
        """Publish *message* on the ``player:<character_id>`` channel."""
        await self.publish_to_channel(f"player:{character_id}", message)

    async def publish_global_broadcast(self, message: Dict[str, Any]) -> None:
        """Publish *message* on the ``game:broadcast`` channel."""
        await self.publish_to_channel("game:broadcast", message)

    async def listen_for_messages(self, handler: Callable) -> None:
        """Listen for Redis pub/sub messages and route them to *handler*.

        Messages published by this same worker are skipped. Errors raised while
        decoding or handling one message are printed and do not stop the loop.

        Args:
            handler: Async function that receives (channel, message_data)

        Raises:
            RuntimeError: If ``connect()`` has not been called.
        """
        if not self.pubsub:
            raise RuntimeError("Redis pubsub not initialized")
        print(f"👂 Worker {self.worker_id} listening for Redis messages...")
        async for message in self.pubsub.listen():
            # Subscribe/unsubscribe confirmations carry no payload to route.
            if message["type"] != "message":
                continue
            channel = message["channel"]
            try:
                data = json.loads(message["data"])
                # Don't process messages from this same worker (already handled locally)
                if data.get("worker_id") == self.worker_id:
                    continue
                await handler(channel, data)
            except json.JSONDecodeError:
                print(f"⚠️ Invalid JSON in Redis message: {message['data']}")
            except Exception as e:
                print(f"❌ Error handling Redis message: {e}")

    def start_listener(self, handler: Callable) -> None:
        """Start ``listen_for_messages`` as a background task.

        Must be called while an event loop is running, because it uses
        ``asyncio.create_task``. The task is cancelled by ``disconnect()``.
        """
        self._listener_task = asyncio.create_task(self.listen_for_messages(handler))

    # ==================== PLAYER SESSIONS ====================

    async def _write_session_fields(
        self, character_id: int, fields: Dict[str, str], ttl: int
    ) -> None:
        """Write *fields* into the session hash and refresh its TTL in one transaction."""
        key = f"player:{character_id}:session"
        async with self._require_client().pipeline(transaction=True) as pipe:
            pipe.hset(key, mapping=fields)
            pipe.expire(key, ttl)
            await pipe.execute()

    async def set_player_session(
        self, character_id: int, session_data: Dict[str, Any], ttl: int = _SESSION_TTL
    ) -> None:
        """Cache player session data (30 min TTL by default).

        Fields already present in the hash but absent from *session_data* are
        left untouched. An empty *session_data* is a no-op.

        Args:
            character_id: Player's character ID
            session_data: Dict with keys like 'location_id', 'hp', 'level', etc.
            ttl: Time-to-live in seconds (default 30 minutes)

        Raises:
            RuntimeError: If ``connect()`` has not been called.
        """
        self._require_client()
        if not session_data:
            # HSET needs at least one field; there is nothing to store.
            return
        # Convert all values to strings for Redis hash
        string_data = {k: str(v) for k, v in session_data.items()}
        await self._write_session_fields(character_id, string_data, ttl)

    async def get_player_session(self, character_id: int) -> Optional[Dict[str, Any]]:
        """Return the cached session hash, or None when nothing is cached.

        All values come back as strings; convert as needed.
        """
        data = await self._require_client().hgetall(f"player:{character_id}:session")
        return data or None

    async def update_player_session_field(
        self, character_id: int, field: str, value: Any, ttl: int = _SESSION_TTL
    ) -> None:
        """Update a single field in player session (e.g., HP, location).

        The session TTL is refreshed as part of the same write.

        Args:
            character_id: Player's character ID
            field: Name of the hash field to set
            value: New value; stored as ``str(value)``
            ttl: Time-to-live in seconds applied to the session (default 30 minutes)
        """
        await self._write_session_fields(character_id, {field: str(value)}, ttl)

    async def delete_player_session(self, character_id: int) -> None:
        """Delete player session from cache (force reload from DB)."""
        await self._require_client().delete(f"player:{character_id}:session")

    # ==================== LOCATION PLAYER REGISTRY ====================

    async def add_player_to_location(self, character_id: int, location_id: str) -> None:
        """Add player to location's player set."""
        await self._require_client().sadd(f"location:{location_id}:players", character_id)

    async def remove_player_from_location(self, character_id: int, location_id: str) -> None:
        """Remove player from location's player set."""
        await self._require_client().srem(f"location:{location_id}:players", character_id)

    async def move_player_between_locations(
        self, character_id: int, from_location: str, to_location: str
    ) -> None:
        """Move player from one location set to another in a single transaction.

        The player is added to *to_location* even if they were not recorded in
        *from_location*.
        """
        async with self._require_client().pipeline(transaction=True) as pipe:
            pipe.srem(f"location:{from_location}:players", character_id)
            pipe.sadd(f"location:{to_location}:players", character_id)
            await pipe.execute()

    async def get_players_in_location(self, location_id: str) -> List[int]:
        """Get list of all player IDs in a location."""
        members = await self._require_client().smembers(f"location:{location_id}:players")
        return [int(member) for member in members]

    async def is_player_in_location(self, character_id: int, location_id: str) -> bool:
        """Check if player is in a specific location."""
        key = f"location:{location_id}:players"
        return bool(await self._require_client().sismember(key, character_id))

    # ==================== INVENTORY CACHING ====================

    async def cache_inventory(
        self, character_id: int, inventory_data: List[Dict], ttl: int = 600
    ) -> None:
        """Cache player inventory as a JSON string (10 min TTL).

        Args:
            character_id: Player's character ID
            inventory_data: List of inventory items
            ttl: Time-to-live in seconds (default 10 minutes)
        """
        key = f"player:{character_id}:inventory"
        await self._require_client().setex(key, ttl, json.dumps(inventory_data))

    async def get_cached_inventory(self, character_id: int) -> Optional[List[Dict]]:
        """Return the cached inventory, or None when nothing is cached."""
        data = await self._require_client().get(f"player:{character_id}:inventory")
        if not data:
            return None
        return json.loads(data)

    async def invalidate_inventory(self, character_id: int) -> None:
        """Delete inventory cache (force reload from DB)."""
        await self._require_client().delete(f"player:{character_id}:inventory")

    # ==================== COMBAT STATE CACHING ====================

    async def cache_combat_state(self, character_id: int, combat_data: Dict[str, Any]) -> None:
        """Cache active combat state (no expiration, deleted when combat ends).

        An empty *combat_data* is a no-op.

        Args:
            character_id: Player's character ID
            combat_data: Combat state dict (npc_id, npc_hp, turn, etc.)
        """
        client = self._require_client()
        if not combat_data:
            # HSET needs at least one field; there is nothing to store.
            return
        # Convert to strings for hash
        string_data = {k: str(v) for k, v in combat_data.items()}
        await client.hset(f"player:{character_id}:combat", mapping=string_data)

    async def get_combat_state(self, character_id: int) -> Optional[Dict[str, Any]]:
        """Return the cached combat hash (string values), or None when absent."""
        data = await self._require_client().hgetall(f"player:{character_id}:combat")
        return data or None

    async def update_combat_field(self, character_id: int, field: str, value: Any) -> None:
        """Update single field in combat state (e.g., npc_hp, turn)."""
        await self._require_client().hset(f"player:{character_id}:combat", field, str(value))

    async def delete_combat_state(self, character_id: int) -> None:
        """Delete combat state (combat ended)."""
        await self._require_client().delete(f"player:{character_id}:combat")

    # ==================== DROPPED ITEMS ====================

    async def add_dropped_item(
        self, location_id: str, item_data: Dict[str, Any], ttl: int = 3600
    ) -> None:
        """Add a dropped item to location's list (1 hour TTL).

        The TTL applies to the whole list and is reset on every call, so all
        items at a location expire together.

        Args:
            location_id: Location where item was dropped
            item_data: Item details (item_id, unique_item_id, timestamp, etc.)
            ttl: Time-to-live in seconds (default 1 hour)
        """
        key = f"location:{location_id}:dropped_items"
        async with self._require_client().pipeline(transaction=True) as pipe:
            pipe.rpush(key, json.dumps(item_data))
            pipe.expire(key, ttl)
            await pipe.execute()

    async def get_dropped_items(self, location_id: str) -> List[Dict[str, Any]]:
        """Get all dropped items in a location."""
        key = f"location:{location_id}:dropped_items"
        items = await self._require_client().lrange(key, 0, -1)
        return [json.loads(item) for item in items]

    async def remove_dropped_item(self, location_id: str, item_data: Dict[str, Any]) -> None:
        """Remove one occurrence of a dropped item (when picked up).

        The match is on the serialized JSON text, so *item_data* must serialize
        to exactly the string that was stored (same keys, order, and values).
        """
        key = f"location:{location_id}:dropped_items"
        await self._require_client().lrem(key, 1, json.dumps(item_data))

    # ==================== WORKER REGISTRY ====================

    async def register_worker(self) -> None:
        """Add this worker to ``active_workers`` and write its heartbeat hash."""
        client = self._require_client()
        await client.sadd("active_workers", self.worker_id)
        # Set heartbeat timestamp
        await client.hset(
            f"worker:{self.worker_id}:heartbeat",
            mapping={
                "timestamp": str(time.time()),
                "status": "online",
            },
        )

    async def unregister_worker(self) -> None:
        """Remove this worker from ``active_workers`` and delete its heartbeat."""
        client = self._require_client()
        await client.srem("active_workers", self.worker_id)
        await client.delete(f"worker:{self.worker_id}:heartbeat")

    async def get_active_workers(self) -> List[str]:
        """Get list of all active worker IDs."""
        return list(await self._require_client().smembers("active_workers"))

    async def update_heartbeat(self) -> None:
        """Update worker heartbeat timestamp."""
        await self._require_client().hset(
            f"worker:{self.worker_id}:heartbeat",
            "timestamp",
            str(time.time()),
        )

    # ==================== DISTRIBUTED LOCKS ====================

    async def acquire_lock(self, lock_name: str, ttl: int = 60) -> bool:
        """Acquire a distributed lock for background tasks.

        Args:
            lock_name: Name of the lock (e.g., "spawn_task", "regen_task")
            ttl: Lock expiration in seconds (default 60s)

        Returns:
            True if lock acquired, False if the lock key already exists
        """
        # SET key value NX EX ttl (only set if not exists, with expiration)
        result = await self._require_client().set(
            f"lock:{lock_name}",
            self.worker_id,
            nx=True,
            ex=ttl,
        )
        return bool(result)

    async def release_lock(self, lock_name: str) -> None:
        """Release a distributed lock, but only if this worker still owns it.

        The ownership check and the delete run as one server-side script. A
        separate GET followed by DEL could remove a lock that expired and was
        re-acquired by another worker between the two commands.
        """
        await self._require_client().eval(
            self._RELEASE_LOCK_SCRIPT, 1, f"lock:{lock_name}", self.worker_id
        )

    # ==================== DISCONNECTED PLAYERS ====================

    async def mark_player_disconnected(self, character_id: int) -> None:
        """Mark player as disconnected (but keep in location registry).

        Does nothing when the player has no cached session.
        """
        if not await self.get_player_session(character_id):
            return
        # Both fields are written together so a reader never sees a session
        # flagged as disconnected without its disconnect time.
        await self._write_session_fields(
            character_id,
            {"websocket_connected": "false", "disconnect_time": str(time.time())},
            self._SESSION_TTL,
        )

    async def mark_player_connected(self, character_id: int) -> None:
        """Mark player as connected and clear any recorded disconnect time."""
        await self.update_player_session_field(character_id, "websocket_connected", "true")
        # Remove disconnect time
        await self._require_client().hdel(f"player:{character_id}:session", "disconnect_time")

    async def is_player_connected(self, character_id: int) -> bool:
        """Check if player is currently connected via WebSocket."""
        session = await self.get_player_session(character_id)
        if not session:
            return False
        return session.get("websocket_connected") == "true"

    async def get_disconnect_duration(self, character_id: int) -> Optional[float]:
        """Get how long player has been disconnected (in seconds).

        Returns:
            Seconds since the recorded disconnect time, or None when there is
            no session, the player is connected, or no disconnect time is set.
        """
        session = await self.get_player_session(character_id)
        if not session or session.get("websocket_connected") == "true":
            return None
        disconnect_time = session.get("disconnect_time")
        if not disconnect_time:
            return None
        return time.time() - float(disconnect_time)

    async def cleanup_disconnected_player(self, character_id: int) -> None:
        """Remove disconnected player from location registry (after timeout).

        Also deletes the cached session. Does nothing when no session exists.
        """
        session = await self.get_player_session(character_id)
        if not session:
            return
        location_id = session.get("location_id")
        if location_id:
            await self.remove_player_from_location(character_id, location_id)
        await self.delete_player_session(character_id)

    # ==================== UTILITY ====================

    async def ping(self) -> bool:
        """Test Redis connection; returns False on any failure instead of raising."""
        try:
            await self._require_client().ping()
            return True
        except Exception:
            return False

    async def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Counters come from the ``stats`` section of Redis INFO. The
        ``connected_clients`` value is reported in the ``clients`` section, so
        that section is queried as well.
        """
        client = self._require_client()
        stats = await client.info("stats")
        clients = await client.info("clients")
        return {
            "total_commands_processed": stats.get("total_commands_processed", 0),
            "instantaneous_ops_per_sec": stats.get("instantaneous_ops_per_sec", 0),
            "keyspace_hits": stats.get("keyspace_hits", 0),
            "keyspace_misses": stats.get("keyspace_misses", 0),
            "connected_clients": clients.get("connected_clients", 0),
        }

    # ==================== CONNECTED PLAYERS COUNTER ====================

    async def increment_connected_player(self, player_id: int) -> None:
        """Increment connection count for a player."""
        await self._require_client().hincrby("connected_players_counts", str(player_id), 1)

    async def decrement_connected_player(self, player_id: int) -> None:
        """Decrement connection count for a player. Remove if 0."""
        client = self._require_client()
        key = "connected_players_counts"
        count = await client.hincrby(key, str(player_id), -1)
        if count <= 0:
            await client.hdel(key, str(player_id))

    async def get_connected_player_count(self) -> int:
        """Get total number of unique connected players."""
        return await self._require_client().hlen("connected_players_counts")
# Global instance shared by importers of this module. Constructing it only
# stores configuration (default Redis URL); no Redis connection exists until
# its connect() coroutine is awaited.
redis_manager = RedisManager()