2025-12-10 12:02:17 +08:00

168 lines
6.9 KiB
Python

from typing import List
from fastapi import HTTPException, status
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from backend.core.security import get_password_hash
from backend.modules.audit.models import AuditAction, AuditLog, AuditResourceType
from backend.modules.customers.models import Customer
from backend.modules.aws_accounts.models import AWSCredential, CustomerCredential
from backend.modules.users.models import Role, RoleName, User
async def list_users(session: AsyncSession, customer_filter: int | None = None) -> List[User]:
    """Return users (optionally for one customer) with display-only fields attached.

    Each returned user is given three plain attributes for serialization:
    ``role_name``, ``customer_name`` and ``customer_credential_names`` -- the
    ``"name (account_id)"`` labels of the AWS credentials linked to the user's
    customer through CustomerCredential rows with ``is_allowed == 1``.

    Args:
        session: Async database session.
        customer_filter: When not None, only users whose ``customer_id`` equals
            this value are returned.

    Returns:
        The users, with the ``role`` and ``customer`` relationships eagerly loaded.
    """
    query = select(User).options(selectinload(User.role), selectinload(User.customer))
    # Explicit None check: a falsy-but-valid id must still filter.
    if customer_filter is not None:
        query = query.where(User.customer_id == customer_filter)
    users = (await session.scalars(query)).all()

    customer_ids = list({u.customer_id for u in users if u.customer_id is not None})
    cred_map: dict[int, list[str]] = {}
    if customer_ids:
        # One query for all customers instead of one per user.
        cred_rows = await session.execute(
            select(CustomerCredential.customer_id, AWSCredential.name, AWSCredential.account_id)
            .join(AWSCredential, AWSCredential.id == CustomerCredential.credential_id)
            .where(CustomerCredential.customer_id.in_(customer_ids))
            .where(CustomerCredential.is_allowed == 1)
        )
        for cid, name, account in cred_rows:
            cred_map.setdefault(cid, []).append(f"{name} ({account})")

    for user in users:
        setattr(user, "role_name", user.role.name if user.role else None)
        setattr(user, "customer_name", user.customer.name if user.customer else None)
        # Copy so users of the same customer do not share one mutable list.
        setattr(user, "customer_credential_names", list(cred_map.get(user.customer_id, [])))
    return users
async def create_user(
    session: AsyncSession,
    username: str,
    email: str | None,
    password: str,
    role_id: int,
    customer_id: int | None,
    actor: User,
) -> User:
    """Create a user and record a USER_CREATE audit entry in the same transaction.

    Args:
        session: Async database session; committed on success.
        username: Login name of the new user.
        email: Optional e-mail address.
        password: Plain-text password; only its hash (``get_password_hash``) is stored.
        role_id: Primary key of the Role to assign.
        customer_id: Owning customer; required unless the role is ADMIN.
        actor: The user performing the operation.

    Returns:
        The persisted user with ``role``/``customer`` populated and the
        ``role_name``/``customer_name`` display attributes set.

    Raises:
        HTTPException: 400 if the role or customer does not exist, if
            ``customer_id`` is missing for a non-ADMIN role, or if the username
            or e-mail is already taken; 403 if a non-ADMIN actor tries to create
            an ADMIN or a user outside its own customer.
    """
    is_admin_actor = actor.role.name == RoleName.ADMIN.value

    role = await session.get(Role, role_id)
    if not role:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Role not found")
    if role.name == RoleName.ADMIN.value and not is_admin_actor:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only ADMIN can create ADMIN")
    if role.name != RoleName.ADMIN.value and not customer_id:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="customer_id required")
    # Non-admins may only create users inside their own customer, matching the
    # tenant scoping that update_user/delete_user enforce.
    if not is_admin_actor and customer_id != actor.customer_id:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")

    customer = None
    if customer_id:
        customer = await session.get(Customer, customer_id)
        if not customer:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Customer not found")

    user = User(
        username=username,
        email=email,
        password_hash=get_password_hash(password),
        role_id=role_id,
        customer_id=customer_id,
    )
    session.add(user)
    try:
        # Flush (not commit): the INSERT runs and user.id is assigned, but the
        # user row and its audit entry are still committed together below.
        await session.flush()
    except IntegrityError as e:
        await session.rollback()
        msg = str(e.orig).lower()
        # Match the specific unique-constraint names before the bare column
        # names; otherwise a duplicate e-mail whose value contains "username"
        # would be reported as a duplicate username.
        if "uniq_username" in msg:
            detail = "用户名已存在"  # "username already exists"
        elif "uniq_email" in msg:
            detail = "邮箱已存在"  # "e-mail already exists"
        elif "username" in msg:
            detail = "用户名已存在"
        elif "email" in msg:
            detail = "邮箱已存在"
        else:
            raise
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail) from e

    await session.refresh(user)
    user.role = role
    user.customer = customer
    setattr(user, "role_name", role.name)
    setattr(user, "customer_name", customer.name if customer else None)
    session.add(
        AuditLog(
            user_id=actor.id,
            # Attribute the event to the target user's customer, as
            # update_user/delete_user do.
            customer_id=user.customer_id,
            action=AuditAction.USER_CREATE,
            resource_type=AuditResourceType.USER,
            resource_id=user.id,
            description=f"Create user {username}",
        )
    )
    await session.commit()
    return user
async def update_user(session: AsyncSession, user_id: int, data: dict, actor: User) -> User:
    """Apply a partial update to a user and audit it in the same transaction.

    Args:
        session: Async database session; committed on success.
        user_id: Primary key of the user to update.
        data: Mapping of field name to new value. A ``password`` key is hashed
            into ``password_hash`` (an empty/None password leaves it unchanged);
            every other key is assigned onto the User. For non-ADMIN actors a
            ``customer_id`` key is forced to the actor's own customer. The
            caller's dict is not modified.
        actor: The user performing the operation.

    Returns:
        The updated user with ``role``/``customer`` populated and the
        ``role_name``/``customer_name`` display attributes set.

    Raises:
        HTTPException: 404 if the user does not exist; 403 if a non-ADMIN actor
            targets a user outside its own customer or tries to grant ADMIN;
            400 if the new role does not exist or the username/e-mail is taken.
    """
    # Work on a copy so the caller's dict is never mutated.
    changes = dict(data)

    user = await session.get(User, user_id)
    if not user:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
    is_admin_actor = actor.role.name == RoleName.ADMIN.value
    # A non-admin without a customer must not match customer-less users
    # (None == None would otherwise pass the tenant check).
    if not is_admin_actor and (actor.customer_id is None or user.customer_id != actor.customer_id):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")

    if "role_id" in changes:
        role = await session.get(Role, changes["role_id"])
        if not role:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Role not found")
        if role.name == RoleName.ADMIN.value and not is_admin_actor:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only ADMIN can grant ADMIN")
    if "customer_id" in changes and not is_admin_actor:
        changes["customer_id"] = actor.customer_id

    # Always strip the password key: it is not a model column, and the
    # plain-text value must never reach setattr or the audit payload.
    new_password = changes.pop("password", None)
    if new_password:
        user.password_hash = get_password_hash(new_password)
    for field, value in changes.items():
        setattr(user, field, value)

    try:
        # Flush (not commit) so the UPDATE and the audit entry commit together.
        await session.flush()
    except IntegrityError as e:
        await session.rollback()
        msg = str(e.orig).lower()
        # Specific constraint names first; see note on column-name fallback.
        if "uniq_username" in msg:
            detail = "用户名已存在"  # "username already exists"
        elif "uniq_email" in msg:
            detail = "邮箱已存在"  # "e-mail already exists"
        elif "username" in msg:
            detail = "用户名已存在"
        elif "email" in msg:
            detail = "邮箱已存在"
        else:
            raise
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail) from e

    await session.refresh(user)
    role = await session.get(Role, user.role_id)
    customer = await session.get(Customer, user.customer_id) if user.customer_id else None
    user.role = role
    user.customer = customer
    setattr(user, "role_name", role.name if role else None)
    setattr(user, "customer_name", customer.name if customer else None)
    session.add(
        AuditLog(
            user_id=actor.id,
            customer_id=user.customer_id,
            action=AuditAction.USER_UPDATE,
            resource_type=AuditResourceType.USER,
            resource_id=user.id,
            description=f"Update user {user.username}",
            payload=changes,
        )
    )
    await session.commit()
    return user
async def delete_user(session: AsyncSession, user_id: int, actor: User) -> None:
    """Delete a user and record a USER_DELETE audit entry in the same commit.

    Args:
        session: Async database session; committed on success.
        user_id: Primary key of the user to delete.
        actor: The user performing the operation.

    Raises:
        HTTPException: 404 if the user does not exist; 403 if a non-ADMIN actor
            targets a user outside its own customer.
    """
    user = await session.get(User, user_id)
    if not user:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
    is_admin_actor = actor.role.name == RoleName.ADMIN.value
    # A non-admin without a customer must not match customer-less users
    # (None == None would otherwise pass the tenant check).
    if not is_admin_actor and (actor.customer_id is None or user.customer_id != actor.customer_id):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")

    # Capture audit details before the object is marked deleted.
    target_id = user.id
    target_customer_id = user.customer_id
    target_username = user.username

    await session.delete(user)
    session.add(
        AuditLog(
            user_id=actor.id,
            customer_id=target_customer_id,
            action=AuditAction.USER_DELETE,
            resource_type=AuditResourceType.USER,
            resource_id=target_id,
            description=f"Delete user {target_username}",
        )
    )
    await session.commit()