97 lines
3.2 KiB
Python
97 lines
3.2 KiB
Python
from datetime import timedelta
|
|
import secrets
|
|
from fastapi import Depends, HTTPException, Request, Response, status
|
|
from sqlalchemy.orm import Session as DBSession
|
|
|
|
from app.core.database import get_db
|
|
from app.models.session import Session as SessionModel
|
|
from app.models.user import User
|
|
|
|
# Cookie that carries the opaque server-side session token.
SESSION_COOKIE_NAME = "bacchus_session"

# Cookie / request-header pair compared by the CSRF check.
CSRF_COOKIE_NAME = "bacchus_csrf"
CSRF_HEADER_NAME = "X-CSRF-Token"

# Lifetime used for the session cookie's max_age.
SESSION_TTL = timedelta(hours=8)
|
|
|
|
def _new_token() -> str:
    """Return a fresh URL-safe random token built from 32 random bytes."""
    entropy_bytes = 32
    return secrets.token_urlsafe(entropy_bytes)
|
|
|
|
# ---------- Role normalization ----------
|
|
def _normalize_role(role_raw) -> str:
    """Reduce a role given as enum member, string or ``None`` to a bare name.

    Objects exposing a ``value`` attribute (e.g. enum members) are replaced
    by that value. A falsy role falls back to ``"user"``. Surrounding
    whitespace is stripped, only the part after the last dot is kept, and
    the result is lower-cased.
    """
    # Enum member? -> use its value.
    unwrapped = getattr(role_raw, "value", role_raw)
    text = str(unwrapped or "user").strip()
    # e.g. "UserRole.admin" -> "admin"; strings without a dot pass through.
    return text.rpartition(".")[2].lower()
|
|
|
|
def issue_csrf_cookie(resp: Response) -> str:
    """Generate a new CSRF token, set it as a cookie on *resp*, and return it.

    The cookie is set for path ``/`` with ``SameSite=lax`` and a max age of
    7200 seconds. No ``httponly`` argument is passed to ``set_cookie``.
    """
    csrf_token = _new_token()
    cookie_options = {
        "key": CSRF_COOKIE_NAME,
        "value": csrf_token,
        "max_age": 7200,
        "secure": False,  # PROD: True
        "samesite": "lax",
        "path": "/",
    }
    resp.set_cookie(**cookie_options)
    return csrf_token
|
|
|
|
def clear_csrf_cookie(resp: Response) -> None:
    """Delete the CSRF cookie on *resp* using the path and SameSite it was set with."""
    resp.delete_cookie(
        key=CSRF_COOKIE_NAME,
        path="/",
        samesite="lax",
    )
|
|
|
|
def verify_csrf(request: Request) -> None:
    """Enforce the cookie-vs-header CSRF check on state-changing requests.

    GET, HEAD and OPTIONS requests are skipped. For every other method the
    CSRF cookie and the CSRF request header must both be present, non-empty
    and equal.

    Raises:
        HTTPException: 403 if either value is missing or the two differ.
    """
    # CSRF is only checked for mutating methods.
    if request.method in ("GET", "HEAD", "OPTIONS"):
        return

    cookie = request.cookies.get(CSRF_COOKIE_NAME)
    header = request.headers.get(CSRF_HEADER_NAME)

    # Both values must actually be present: a plain ``cookie != header``
    # accepts a request that sends neither (None == None).
    # Compare as bytes in constant time; compare_digest rejects non-ASCII str.
    if (
        not cookie
        or not header
        or not secrets.compare_digest(cookie.encode("utf-8"), header.encode("utf-8"))
    ):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="CSRF check failed")
|
|
|
|
def create_session(db: DBSession, user_id: int) -> str:
    """Store a new session row for *user_id*, commit it, and return its token."""
    session_token = _new_token()
    record = SessionModel(user_id=user_id, token=session_token)
    db.add(record)
    db.commit()
    return session_token
|
|
|
|
def set_session_cookie(resp: Response, token: str) -> None:
    """Attach the session *token* to *resp* as an HTTP-only cookie.

    The cookie is set for path ``/`` with ``SameSite=lax`` and a max age
    equal to ``SESSION_TTL`` in whole seconds.
    """
    lifetime_seconds = int(SESSION_TTL.total_seconds())
    cookie_options = {
        "key": SESSION_COOKIE_NAME,
        "value": token,
        "httponly": True,
        "secure": False,  # PROD: True
        "samesite": "lax",
        "max_age": lifetime_seconds,
        "path": "/",
    }
    resp.set_cookie(**cookie_options)
|
|
|
|
def clear_session_cookie(resp: Response) -> None:
    """Delete the session cookie on *resp* using the path and SameSite it was set with."""
    resp.delete_cookie(
        key=SESSION_COOKIE_NAME,
        path="/",
        samesite="lax",
    )
|
|
|
|
def get_current_user(request: Request, db: DBSession = Depends(get_db)) -> User:
    """Resolve the user that owns the session token in the request cookie.

    Args:
        request: Incoming request; the session cookie is read from it.
        db: Database session used to look up the session row by token.

    Returns:
        The user attached to the matching session row.

    Raises:
        HTTPException: 401 if the cookie is missing or empty, no session row
            matches the token, or the matching row has no user.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid or expired session",
    )

    token = request.cookies.get(SESSION_COOKIE_NAME)
    if not token:
        raise unauthorized

    session = db.query(SessionModel).filter_by(token=token).first()
    if not session:
        raise unauthorized

    # NOTE(review): no server-side expiry check is visible here, only the
    # cookie max_age bounds the lifetime. Confirm expiry is enforced elsewhere.
    user = session.user
    # Never hand None to callers: role checks read the role via getattr with
    # a default and would treat a missing user as an ordinary "user".
    if user is None:
        raise unauthorized
    return user
|
|
|
|
def requires_role(*roles: str):
    """Build a dependency that only admits users holding one of *roles*.

    Role names are normalized once when the dependency is built. The
    returned callable obtains the current user via ``get_current_user``,
    normalizes its ``role`` attribute, and returns the user when the role
    is allowed. With no *roles* given, any authenticated user is admitted.

    Raises (from the returned dependency):
        HTTPException: 403 if the user's role is not among *roles*.
    """
    allowed = tuple(map(_normalize_role, roles))

    def dep(user: User = Depends(get_current_user)):
        current = _normalize_role(getattr(user, "role", None))
        # An empty allow-list means "no role restriction".
        if not allowed or current in allowed:
            return user
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Insufficient role")

    return dep
|
|
|
|
# ---- Hybrid gates ----
|
|
def requires_role_relaxed(*roles: str):
    """Return the same dependency as ``requires_role(*roles)``."""
    gate = requires_role(*roles)
    return gate
|
|
|
|
def requires_role_mgmt(*roles: str):
    """Return the same dependency as ``requires_role(*roles)``."""
    # Later, optionally enforce the session type "management" here.
    gate = requires_role(*roles)
    return gate
|