from datetime import timedelta
import secrets

from fastapi import Depends, HTTPException, Request, Response, status
from sqlalchemy.orm import Session as DBSession

from app.core.database import get_db
from app.models.session import Session as SessionModel
from app.models.user import User

SESSION_COOKIE_NAME = "bacchus_session"
CSRF_COOKIE_NAME = "bacchus_csrf"
CSRF_HEADER_NAME = "X-CSRF-Token"
SESSION_TTL = timedelta(hours=8)

# Lifetime of the CSRF cookie in seconds.
# NOTE(review): this is shorter than SESSION_TTL, so the CSRF cookie can expire
# while the session cookie is still valid; confirm the frontend re-issues it.
CSRF_COOKIE_MAX_AGE = 7200

# Whether auth cookies carry the "Secure" flag. Must be True in production (HTTPS).
COOKIE_SECURE = False

# Non-mutating HTTP methods; these skip the CSRF check.
_CSRF_SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})

_INVALID_SESSION_DETAIL = "Invalid or expired session"


def _new_token() -> str:
    """Return a fresh URL-safe random token built from 32 random bytes."""
    return secrets.token_urlsafe(32)


# ---------- Role normalisation ----------

def _normalize_role(role_raw) -> str:
    """Normalise a role value to a lowercase plain string.

    Accepts enum members (their ``.value`` is used), strings such as
    ``"UserRole.admin"`` (only the part after the last dot is kept), and
    falsy values, which default to ``"user"``.
    """
    # Enum member -> its value
    if hasattr(role_raw, "value"):
        role_raw = role_raw.value
    role = str(role_raw or "user").strip()
    if "." in role:  # e.g. "UserRole.admin" produced by str() of an enum
        role = role.split(".")[-1]
    return role.lower()


def issue_csrf_cookie(resp: Response) -> str:
    """Generate a new CSRF token, set it as a cookie on ``resp`` and return it.

    The cookie is deliberately not HttpOnly: ``verify_csrf`` compares it with a
    request header, so client-side code presumably needs to read it and echo it
    back (double-submit cookie pattern).
    """
    token = _new_token()
    resp.set_cookie(
        key=CSRF_COOKIE_NAME,
        value=token,
        max_age=CSRF_COOKIE_MAX_AGE,
        secure=COOKIE_SECURE,
        samesite="lax",
        path="/",
    )
    return token


def clear_csrf_cookie(resp: Response) -> None:
    """Delete the CSRF cookie on ``resp``."""
    resp.delete_cookie(key=CSRF_COOKIE_NAME, path="/", samesite="lax")


def verify_csrf(request: Request) -> None:
    """Enforce the double-submit CSRF check for state-changing requests.

    Safe methods (GET, HEAD, OPTIONS) are skipped. For every other method the
    CSRF cookie and the ``X-CSRF-Token`` header must both be present and equal.

    Raises:
        HTTPException: 403 if either value is missing or they do not match.
    """
    # CSRF protection only applies to mutating methods.
    if request.method in _CSRF_SAFE_METHODS:
        return
    cookie = request.cookies.get(CSRF_COOKIE_NAME)
    header = request.headers.get(CSRF_HEADER_NAME)
    # Both values must be present; otherwise a request carrying neither would
    # compare None == None and pass. compare_digest avoids timing leaks; bytes
    # are compared because compare_digest rejects non-ASCII str arguments.
    if (
        not cookie
        or not header
        or not secrets.compare_digest(cookie.encode("utf-8"), header.encode("utf-8"))
    ):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="CSRF check failed")


def create_session(db: DBSession, user_id: int) -> str:
    """Persist a new session row for ``user_id`` and return its token.

    Commits the given database session.
    """
    token = _new_token()
    db.add(SessionModel(user_id=user_id, token=token))
    db.commit()
    return token


def set_session_cookie(resp: Response, token: str) -> None:
    """Set the HttpOnly session cookie on ``resp``; it lives for SESSION_TTL."""
    resp.set_cookie(
        key=SESSION_COOKIE_NAME,
        value=token,
        httponly=True,
        secure=COOKIE_SECURE,
        samesite="lax",
        max_age=int(SESSION_TTL.total_seconds()),
        path="/",
    )


def clear_session_cookie(resp: Response) -> None:
    """Delete the session cookie on ``resp``."""
    resp.delete_cookie(key=SESSION_COOKIE_NAME, path="/", samesite="lax")


def _invalid_session() -> HTTPException:
    """Build the 401 error raised for any unusable session."""
    return HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=_INVALID_SESSION_DETAIL)


def get_current_user(request: Request, db: DBSession = Depends(get_db)) -> User:
    """FastAPI dependency: resolve the user owning the request's session cookie.

    Raises:
        HTTPException: 401 if the cookie is missing, no session row matches the
            token, or the matched session has no associated user.
    """
    token = request.cookies.get(SESSION_COOKIE_NAME)
    if not token:
        raise _invalid_session()
    session = db.query(SessionModel).filter_by(token=token).first()
    if not session:
        raise _invalid_session()
    # TODO(review): SESSION_TTL is only enforced through the cookie's max_age;
    # the stored session row is never checked for age here, so a captured token
    # stays valid until its row is deleted. Add a server-side expiry check once
    # the Session model's timestamp column is confirmed.
    user = session.user
    # Reject orphaned sessions: returning None would make requires_role treat
    # the request as role "user" and authorise it.
    if user is None:
        raise _invalid_session()
    return user


def requires_role(*roles: str):
    """Build a FastAPI dependency that requires the current user to hold a role.

    Role names are normalised with ``_normalize_role`` on both sides. If no
    roles are given, any authenticated user is accepted.

    Returns:
        A dependency callable that returns the current ``User`` or raises
        HTTPException 403 ("Insufficient role").
    """
    roles_norm = tuple(_normalize_role(r) for r in roles)

    def dep(user: User = Depends(get_current_user)):
        user_role = _normalize_role(getattr(user, "role", None))
        if roles_norm and user_role not in roles_norm:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Insufficient role")
        return user

    return dep


# ---- Hybrid gates ----

def requires_role_relaxed(*roles: str):
    """Role gate for relaxed endpoints; currently identical to ``requires_role``."""
    return requires_role(*roles)


def requires_role_mgmt(*roles: str):
    """Role gate for management endpoints; currently identical to ``requires_role``."""
    # Later: optionally enforce a "management" session type here.
    return requires_role(*roles)