Commit
This commit is contained in:
0
api/core/__init__.py
Normal file
0
api/core/__init__.py
Normal file
32
api/core/config.py
Normal file
32
api/core/config.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Configuration module for the API.
|
||||
All environment variables and constants are defined here.
|
||||
"""
|
||||
import os
|
||||
|
||||
# JWT Configuration
|
||||
SECRET_KEY = os.getenv("JWT_SECRET_KEY", "your-secret-key-change-in-production-please")
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days
|
||||
|
||||
# Internal API Key (for bot communication)
|
||||
API_INTERNAL_KEY = os.getenv("API_INTERNAL_KEY", "change-this-internal-key")
|
||||
|
||||
# CORS Origins
|
||||
CORS_ORIGINS = [
|
||||
"https://staging.echoesoftheash.com",
|
||||
"http://localhost:3000",
|
||||
"http://localhost:5173"
|
||||
]
|
||||
|
||||
# Database Configuration (imported from database module)
|
||||
# DB settings are in database.py since they're tightly coupled with SQLAlchemy
|
||||
|
||||
# Image Directory
|
||||
from pathlib import Path
|
||||
IMAGES_DIR = Path(__file__).parent.parent.parent / "images"
|
||||
|
||||
# Game Constants
|
||||
MOVEMENT_COOLDOWN = 5 # seconds
|
||||
BASE_CARRYING_CAPACITY = 10.0 # kg
|
||||
BASE_VOLUME_CAPACITY = 10.0 # liters
|
||||
127
api/core/security.py
Normal file
127
api/core/security.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Security module for authentication and authorization.
|
||||
Handles JWT tokens, password hashing, and auth dependencies.
|
||||
"""
|
||||
import jwt
|
||||
import bcrypt
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
|
||||
from .config import SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES, API_INTERNAL_KEY
|
||||
from .. import database as db
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
def create_access_token(data: dict) -> str:
|
||||
"""Create a JWT access token"""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict:
|
||||
"""Decode JWT token and return payload"""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has expired"
|
||||
)
|
||||
except (jwt.InvalidTokenError, jwt.DecodeError, Exception):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials"
|
||||
)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a password using bcrypt"""
|
||||
return bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
|
||||
|
||||
|
||||
def verify_password(password: str, password_hash: str) -> bool:
|
||||
"""Verify a password against its hash"""
|
||||
return bcrypt.checkpw(password.encode('utf-8'), password_hash.encode('utf-8'))
|
||||
|
||||
|
||||
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> Dict[str, Any]:
|
||||
"""
|
||||
Verify JWT token and return current character (requires character selection).
|
||||
This is the main auth dependency for protected endpoints.
|
||||
"""
|
||||
try:
|
||||
token = credentials.credentials
|
||||
payload = decode_token(token)
|
||||
|
||||
# New system: account_id + character_id
|
||||
account_id = payload.get("account_id")
|
||||
if account_id is not None:
|
||||
character_id = payload.get("character_id")
|
||||
if character_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="No character selected. Please select a character first."
|
||||
)
|
||||
|
||||
player = await db.get_player_by_id(character_id)
|
||||
if player is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Character not found"
|
||||
)
|
||||
|
||||
# Verify character belongs to account
|
||||
if player.get('account_id') != account_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Character does not belong to this account"
|
||||
)
|
||||
|
||||
return player
|
||||
|
||||
# Old system fallback: player_id (for backward compatibility)
|
||||
player_id = payload.get("player_id")
|
||||
if player_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid token: no player or character ID"
|
||||
)
|
||||
|
||||
player = await db.get_player_by_id(player_id)
|
||||
if player is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Player not found"
|
||||
)
|
||||
|
||||
return player
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has expired"
|
||||
)
|
||||
except (jwt.InvalidTokenError, jwt.DecodeError, Exception) as e:
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials"
|
||||
)
|
||||
|
||||
|
||||
async def verify_internal_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""Verify internal API key for bot endpoints"""
|
||||
if credentials.credentials != API_INTERNAL_KEY:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid internal API key"
|
||||
)
|
||||
return True
|
||||
209
api/core/websockets.py
Normal file
209
api/core/websockets.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""
|
||||
WebSocket connection manager for real-time game updates.
|
||||
Handles WebSocket connections and Redis pub/sub for cross-worker communication.
|
||||
"""
|
||||
from typing import Dict, Optional, List
|
||||
from fastapi import WebSocket
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""
|
||||
Manages WebSocket connections for real-time game updates.
|
||||
Tracks active connections and provides methods for broadcasting messages.
|
||||
Uses Redis pub/sub for cross-worker communication.
|
||||
"""
|
||||
def __init__(self):
|
||||
# Maps player_id -> List of WebSocket connections (local to this worker only)
|
||||
self.active_connections: Dict[int, List[WebSocket]] = {}
|
||||
# Maps player_id -> username for debugging
|
||||
self.player_usernames: Dict[int, str] = {}
|
||||
# Redis manager instance (injected later)
|
||||
self.redis_manager = None
|
||||
|
||||
def set_redis_manager(self, redis_manager):
|
||||
"""Inject Redis manager after initialization."""
|
||||
self.redis_manager = redis_manager
|
||||
|
||||
async def connect(self, websocket: WebSocket, player_id: int, username: str):
|
||||
"""Accept a new WebSocket connection and track it."""
|
||||
await websocket.accept()
|
||||
|
||||
if player_id not in self.active_connections:
|
||||
self.active_connections[player_id] = []
|
||||
|
||||
self.active_connections[player_id].append(websocket)
|
||||
self.player_usernames[player_id] = username
|
||||
|
||||
# Subscribe to player's personal channel (only if first connection)
|
||||
if len(self.active_connections[player_id]) == 1 and self.redis_manager:
|
||||
await self.redis_manager.subscribe_to_channels([f"player:{player_id}"])
|
||||
await self.redis_manager.mark_player_connected(player_id)
|
||||
|
||||
logger.info(f"WebSocket connected: {username} (player_id={player_id}, worker={self.redis_manager.worker_id if self.redis_manager else 'N/A'})")
|
||||
|
||||
async def disconnect(self, player_id: int, websocket: WebSocket):
|
||||
"""Remove a WebSocket connection."""
|
||||
if player_id in self.active_connections:
|
||||
username = self.player_usernames.get(player_id, "unknown")
|
||||
|
||||
if websocket in self.active_connections[player_id]:
|
||||
self.active_connections[player_id].remove(websocket)
|
||||
|
||||
# If no more connections for this player, cleanup
|
||||
if not self.active_connections[player_id]:
|
||||
del self.active_connections[player_id]
|
||||
if player_id in self.player_usernames:
|
||||
del self.player_usernames[player_id]
|
||||
|
||||
# Unsubscribe from player's personal channel
|
||||
if self.redis_manager:
|
||||
await self.redis_manager.unsubscribe_from_channel(f"player:{player_id}")
|
||||
await self.redis_manager.mark_player_disconnected(player_id)
|
||||
|
||||
logger.info(f"All WebSockets disconnected: {username} (player_id={player_id})")
|
||||
else:
|
||||
logger.info(f"WebSocket disconnected: {username} (player_id={player_id}). Remaining connections: {len(self.active_connections[player_id])}")
|
||||
|
||||
async def send_personal_message(self, player_id: int, message: dict):
|
||||
"""Send a message to a specific player via Redis pub/sub."""
|
||||
if self.redis_manager:
|
||||
# Send locally first if player is connected to this worker
|
||||
if player_id in self.active_connections:
|
||||
await self._send_direct(player_id, message)
|
||||
else:
|
||||
# Publish to Redis (player might be on another worker)
|
||||
await self.redis_manager.publish_to_player(player_id, message)
|
||||
else:
|
||||
# Fallback to direct send (single worker mode)
|
||||
await self._send_direct(player_id, message)
|
||||
|
||||
async def _send_direct(self, player_id: int, message: dict):
|
||||
"""Directly send to local WebSocket connections."""
|
||||
if player_id in self.active_connections:
|
||||
connections = self.active_connections[player_id]
|
||||
disconnected_sockets = []
|
||||
|
||||
for websocket in connections:
|
||||
try:
|
||||
logger.debug(f"Sending {message.get('type')} to player {player_id}")
|
||||
await websocket.send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send message to player {player_id}: {e}")
|
||||
disconnected_sockets.append(websocket)
|
||||
|
||||
# Cleanup failed sockets
|
||||
for ws in disconnected_sockets:
|
||||
await self.disconnect(player_id, ws)
|
||||
|
||||
async def broadcast(self, message: dict, exclude_player_id: Optional[int] = None):
|
||||
"""Broadcast a message to all connected players via Redis."""
|
||||
if self.redis_manager:
|
||||
await self.redis_manager.publish_global_broadcast(message)
|
||||
|
||||
# ALSO send to LOCAL connections immediately
|
||||
for player_id in list(self.active_connections.keys()):
|
||||
if player_id != exclude_player_id:
|
||||
await self._send_direct(player_id, message)
|
||||
else:
|
||||
# Fallback: direct broadcast to local connections
|
||||
for player_id in list(self.active_connections.keys()):
|
||||
if player_id != exclude_player_id:
|
||||
await self._send_direct(player_id, message)
|
||||
|
||||
async def send_to_location(self, location_id: str, message: dict, exclude_player_id: Optional[int] = None):
|
||||
"""Send a message to all players in a specific location via Redis pub/sub."""
|
||||
if self.redis_manager:
|
||||
# Use Redis pub/sub for cross-worker broadcast
|
||||
message_with_exclude = {
|
||||
**message,
|
||||
"exclude_player_id": exclude_player_id
|
||||
}
|
||||
await self.redis_manager.publish_to_location(location_id, message_with_exclude)
|
||||
|
||||
# ALSO send to LOCAL connections immediately (don't wait for Redis roundtrip)
|
||||
player_ids = await self.redis_manager.get_players_in_location(location_id)
|
||||
for player_id in player_ids:
|
||||
if player_id == exclude_player_id:
|
||||
continue
|
||||
if player_id in self.active_connections:
|
||||
await self._send_direct(player_id, message)
|
||||
else:
|
||||
# Fallback: Query DB and send directly (single worker mode)
|
||||
from .. import database as db
|
||||
players_in_location = await db.get_players_in_location(location_id)
|
||||
|
||||
active_players = [p for p in players_in_location if p['id'] in self.active_connections and p['id'] != exclude_player_id]
|
||||
if not active_players:
|
||||
return
|
||||
|
||||
logger.info(f"Broadcasting to location {location_id}: {message.get('type')} (excluding player {exclude_player_id})")
|
||||
|
||||
sent_count = 0
|
||||
for player in active_players:
|
||||
player_id = player['id']
|
||||
await self._send_direct(player_id, message)
|
||||
sent_count += 1
|
||||
|
||||
logger.info(f"Sent {message.get('type')} to {sent_count} players")
|
||||
|
||||
async def handle_redis_message(self, channel: str, data: dict):
|
||||
"""
|
||||
Handle incoming Redis pub/sub messages and route to local WebSocket connections.
|
||||
This method is called by RedisManager when a message arrives on a subscribed channel.
|
||||
"""
|
||||
try:
|
||||
# Extract message type and data
|
||||
message = {
|
||||
"type": data.get("type"),
|
||||
"data": data.get("data")
|
||||
}
|
||||
|
||||
# Determine routing based on channel type
|
||||
if channel.startswith("player:"):
|
||||
# Personal message to specific player
|
||||
player_id = int(channel.split(":")[1])
|
||||
if player_id in self.active_connections:
|
||||
await self._send_direct(player_id, message)
|
||||
|
||||
elif channel.startswith("location:"):
|
||||
# Broadcast to all players in location (only local connections)
|
||||
location_id = channel.split(":")[1]
|
||||
exclude_player_id = data.get("exclude_player_id")
|
||||
|
||||
# Get players from Redis location registry
|
||||
if self.redis_manager:
|
||||
player_ids = await self.redis_manager.get_players_in_location(location_id)
|
||||
|
||||
for player_id in player_ids:
|
||||
if player_id == exclude_player_id:
|
||||
continue
|
||||
|
||||
# Only send if this worker has the connection
|
||||
if player_id in self.active_connections:
|
||||
await self._send_direct(player_id, message)
|
||||
|
||||
elif channel == "game:broadcast":
|
||||
# Global broadcast to all local connections
|
||||
exclude_player_id = data.get("exclude_player_id")
|
||||
|
||||
for player_id in list(self.active_connections.keys()):
|
||||
if player_id != exclude_player_id:
|
||||
await self._send_direct(player_id, message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling Redis message on channel {channel}: {e}")
|
||||
|
||||
def has_players_in_location(self, location_id: str) -> bool:
|
||||
"""Check if there are any players with active connections in a specific location."""
|
||||
return len(self.active_connections) > 0
|
||||
|
||||
def get_connected_count(self) -> int:
|
||||
"""Get the number of active WebSocket connections."""
|
||||
return len(self.active_connections)
|
||||
|
||||
|
||||
# Global connection manager instance
|
||||
manager = ConnectionManager()
|
||||
Reference in New Issue
Block a user