64 lines
2.2 KiB
Python
64 lines
2.2 KiB
Python
|
|
from dataclasses import dataclass
|
||
|
|
from typing import Iterable, Optional
|
||
|
|
|
||
|
|
from fastapi import Depends, HTTPException, status
|
||
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||
|
|
from sqlalchemy import select
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
|
from sqlalchemy.orm import selectinload
|
||
|
|
|
||
|
|
from backend.core.security import decode_token
|
||
|
|
from backend.db.session import get_session
|
||
|
|
from backend.modules.users.models import RoleName, User
|
||
|
|
|
||
|
|
# Bearer-token extractor shared by the auth dependencies below.
# auto_error=False: when the Authorization header is missing or is not a
# Bearer scheme, FastAPI injects ``None`` instead of raising, so
# ``get_current_user`` can produce its own 401 response.
security = HTTPBearer(auto_error=False)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class AuthUser:
    """Authenticated caller for the current request.

    Bundles the loaded ``User`` row with the role name used for authorization
    checks, the user's customer id, and the raw bearer token that
    authenticated the request.
    """

    # ORM row for the authenticated user.
    user: User
    # Role name that role guards compare against ``RoleName`` values.
    role_name: str
    # The user's customer id; may be ``None``.
    customer_id: Optional[int]
    # Raw bearer token string as presented in the Authorization header.
    token: str
|
||
|
|
|
||
|
|
|
||
|
|
def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error carrying the ``WWW-Authenticate: Bearer`` challenge (RFC 6750)."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    session: AsyncSession = Depends(get_session),
) -> AuthUser:
    """Resolve the request's bearer token to an active ``AuthUser``.

    The token is decoded with ``decode_token``. Its ``sub`` claim is parsed as
    an integer user id, and that user is loaded with its ``role`` and
    ``customer`` relationships eagerly fetched.

    Args:
        credentials: Bearer credentials extracted by ``security``; ``None``
            when no usable Authorization header was sent.
        session: Async database session.

    Returns:
        An ``AuthUser`` for the loaded user.

    Raises:
        HTTPException: 401 if the header is missing, the token fails to
            decode, ``sub`` is missing or not an integer, or no user matches.
            403 if the user exists but ``is_active`` is falsy.
    """
    if credentials is None:
        raise _unauthorized("Not authenticated")

    token = credentials.credentials
    try:
        payload = decode_token(token)
    except ValueError as exc:
        raise _unauthorized("Invalid token") from exc

    raw_user_id = payload.get("sub")
    if not raw_user_id:
        raise _unauthorized("Invalid token payload")
    try:
        user_id = int(raw_user_id)
    except (TypeError, ValueError) as exc:
        # A non-numeric ``sub`` would otherwise escape as an unhandled 500.
        raise _unauthorized("Invalid token payload") from exc

    user = await session.scalar(
        select(User)
        .where(User.id == user_id)
        # Eager-load relationships so later attribute access does not need
        # implicit lazy loading on the async session.
        .options(selectinload(User.role), selectinload(User.customer))
    )
    if user is None:
        raise _unauthorized("User not found")
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User disabled")

    # NOTE(review): the token's "role" claim takes precedence over the role
    # currently stored for the user, so a role change in the database is not
    # reflected until a new token is issued. Confirm this is intended.
    role_name = payload.get("role") or (user.role.name if user.role else "")
    return AuthUser(user=user, role_name=role_name, customer_id=user.customer_id, token=token)
|
||
|
|
|
||
|
|
|
||
|
|
def require_roles(roles: Iterable[RoleName]):
    """Create a FastAPI dependency that only admits callers with an allowed role.

    The permitted set is built once, from the ``.value`` of each role in
    ``roles``, when this factory is called. The returned async dependency
    resolves the caller through ``get_current_user`` and checks
    ``AuthUser.role_name`` against that set.

    Args:
        roles: Roles allowed to pass. An empty iterable admits nobody.

    Returns:
        An async dependency that returns the ``AuthUser`` when the role is
        permitted. Otherwise it raises ``HTTPException`` (403, "Insufficient role").
    """
    permitted = frozenset(role.value for role in roles)

    async def _check_role(auth_user: AuthUser = Depends(get_current_user)) -> AuthUser:
        if auth_user.role_name in permitted:
            return auth_user
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Insufficient role")

    return _check_role
|
||
|
|
|
||
|
|
|
||
|
|
def require_admin(
    auth_user: AuthUser = Depends(require_roles((RoleName.ADMIN,))),
) -> AuthUser:
    """FastAPI dependency that only admits ``RoleName.ADMIN`` callers.

    The role check is performed by the ``require_roles`` dependency; this
    function simply returns the already-authorized ``AuthUser``.
    """
    return auth_user
|