322 lines
11 KiB
Python
322 lines
11 KiB
Python
|
|
# app/api/routes/authentication.py
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
from typing import Optional, Any, TYPE_CHECKING
|
|||
|
|
from datetime import datetime, timedelta
|
|||
|
|
|
|||
|
|
from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
|
|||
|
|
from starlette.status import HTTP_201_CREATED, HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED
|
|||
|
|
|
|||
|
|
from app.api.dependencies.database import get_repository
|
|||
|
|
from app.core.config import get_app_settings
|
|||
|
|
from app.core.settings.app import AppSettings
|
|||
|
|
from app.db.errors import EntityDoesNotExist
|
|||
|
|
from app.db.repositories.users import UsersRepository
|
|||
|
|
|
|||
|
|
# Optional import: the email_codes repository may be absent at runtime.
# Only ImportError (which includes ModuleNotFoundError) means "not installed".
# Any other exception raised while importing that module is a real bug. It must
# surface instead of silently disabling email-code verification in `register`.
try:
    from app.db.repositories.email_codes import EmailCodesRepository  # type: ignore

    HAS_EMAIL_CODES_REPO = True
except ImportError:  # pragma: no cover
    EmailCodesRepository = None  # type: ignore
    HAS_EMAIL_CODES_REPO = False

# Type-checking only: lets Pylance/pyright resolve the name without importing it at runtime.
if TYPE_CHECKING:  # pragma: no cover
    from app.db.repositories.email_codes import EmailCodesRepository as _EmailCodesRepositoryT  # noqa: F401
|
|||
|
|
|
|||
|
|
from app.models.schemas.users import (
|
|||
|
|
UserInLogin,
|
|||
|
|
UserInResponse,
|
|||
|
|
UserWithToken,
|
|||
|
|
RegisterWithEmailIn,
|
|||
|
|
)
|
|||
|
|
from app.models.schemas.email_code import EmailCodeSendIn, EmailCodeSendOut
|
|||
|
|
from app.resources import strings
|
|||
|
|
from app.services import jwt
|
|||
|
|
from app.services.mailer import send_email
|
|||
|
|
from app.services.authentication import (
|
|||
|
|
check_email_is_taken,
|
|||
|
|
assert_passwords_match,
|
|||
|
|
make_unique_username,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# Router for the auth endpoints in this module: email-code, login, register,
# refresh and logout. The mount prefix is configured by the caller.
# NOTE(review): the refresh cookie is scoped to path "/api/auth". Confirm this
# router is actually mounted under that prefix, or browsers won't send the cookie.
router = APIRouter()
|
|||
|
|
|
|||
|
|
# ================= Cookie helpers (minimal change, no extra module needed) =================

REFRESH_COOKIE_NAME = "refresh_token"

# Attributes shared by set and delete. A browser only replaces (and so removes)
# a cookie when the deletion uses the same name and path it was set with.
# Keeping the values in one place stops the two helpers from drifting apart.
_REFRESH_COOKIE_PATH = "/api/auth"
_REFRESH_COOKIE_SAMESITE = "lax"


def set_refresh_cookie(
    resp: Response,
    token: str,
    *,
    max_age_days: int = 30,
    secure: bool = True,
) -> None:
    """Deliver the refresh token only via an HttpOnly cookie.

    - ``HttpOnly``: not readable from page JavaScript.
    - ``SameSite=Lax``: the cookie is not sent on cross-site form POSTs.
    - ``Secure``: defaults to True and should stay True in production. Pass
      ``secure=False`` only for plain-HTTP local development.
    - ``Path=/api/auth``: narrows which requests carry the cookie.

    Args:
        resp: Response the ``Set-Cookie`` header is added to.
        token: Encoded refresh token.
        max_age_days: Cookie lifetime in days.
        secure: Whether to set the ``Secure`` attribute.
    """
    resp.set_cookie(
        key=REFRESH_COOKIE_NAME,
        value=token,
        max_age=max_age_days * 24 * 3600,
        httponly=True,
        secure=secure,
        samesite=_REFRESH_COOKIE_SAMESITE,
        path=_REFRESH_COOKIE_PATH,
    )


def clear_refresh_cookie(resp: Response, *, secure: bool = True) -> None:
    """Expire the refresh-token cookie.

    Uses the same path/samesite values as :func:`set_refresh_cookie`, so the
    deletion targets the cookie that was actually set. Pass the same
    ``secure`` value that was used when setting it.
    """
    resp.delete_cookie(
        key=REFRESH_COOKIE_NAME,
        path=_REFRESH_COOKIE_PATH,
        httponly=True,
        secure=secure,
        samesite=_REFRESH_COOKIE_SAMESITE,
    )
|
|||
|
|
|
|||
|
|
# Factory handed to Depends(): yields the email-codes repository when it is
# importable, otherwise a dependency that resolves to None.
def _provide_optional_email_codes_repo():
    """Return a FastAPI dependency callable for the optional email-codes repo.

    If ``HAS_EMAIL_CODES_REPO`` is false, the returned coroutine function
    resolves to ``None``. Endpoints then treat "no repository" as "skip
    code storage/verification". Otherwise this delegates to
    ``get_repository(EmailCodesRepository)``.
    """
    if not HAS_EMAIL_CODES_REPO:
        async def _none():
            return None

        return _none

    return get_repository(EmailCodesRepository)  # type: ignore[name-defined]
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========= 发送邮箱验证码 =========
|
|||
|
|
@router.post(
    "/email-code",
    response_model=EmailCodeSendOut,
    name="auth:email-code",
)
async def send_email_code(
    payload: EmailCodeSendIn = Body(...),
    settings: AppSettings = Depends(get_app_settings),
    email_codes_repo: Optional[Any] = Depends(_provide_optional_email_codes_repo()),
) -> EmailCodeSendOut:
    """Send an email verification code and store it in the database (if the repository exists).

    Generates a 6-digit numeric code and computes its expiry from
    ``settings.email_code_expires_minutes``. It persists the code via
    ``email_codes_repo.create_code`` when a repository is available, then
    mails it to ``payload.email``.

    Note: when no repository is available, the code is still mailed but is
    not stored anywhere.
    """
    import secrets
    from html import escape as _escape_html

    # 1) 6-digit numeric code from a CSPRNG. The `random` module is predictable
    #    and must not be used for authentication secrets.
    code = f"{secrets.randbelow(1_000_000):06d}"

    # 2) Expiry time (naive UTC, kept as-is for the repository).
    # NOTE(review): datetime.utcnow() is deprecated since Python 3.12; switching to
    # an aware datetime must be coordinated with the email_codes repository/DB column.
    expires_at = datetime.utcnow() + timedelta(minutes=settings.email_code_expires_minutes)

    # 3) Persist the code (optional).
    if email_codes_repo is not None:
        await email_codes_repo.create_code(  # type: ignore[attr-defined]
            email=payload.email,
            code=code,
            scene=payload.scene,
            expires_at=expires_at,
        )

    # 4) Send the email. `scene` comes from the request body, so escape it before
    #    embedding it in HTML to prevent markup injection into the message.
    subject = f"【AI平台】{payload.scene} 验证码:{code}"
    scene_html = _escape_html(f"{payload.scene}")
    html_body = f"""
    <div style="font-family:Arial,Helvetica,sans-serif;font-size:14px;line-height:1.6">
    <p>您好!</p>
    <p>您正在进行 <b>{scene_html}</b> 操作,本次验证码为:</p>
    <p style="font-size:22px;font-weight:700;letter-spacing:2px">{code}</p>
    <p>有效期:{settings.email_code_expires_minutes} 分钟;请勿泄露给他人。</p>
    </div>
    """
    # NOTE(review): send_email is called synchronously. If it performs blocking
    # SMTP I/O, it stalls the event loop; consider running it in a threadpool.
    send_email(payload.email, subject, html_body)
    return EmailCodeSendOut(ok=True)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========= 登录 =========
|
|||
|
|
@router.post(
    "/login",
    response_model=UserInResponse,
    response_model_exclude_none=True,
    name="auth:login",
)
async def login(
    response: Response,
    user_login: UserInLogin = Body(..., embed=True, alias="user"),
    users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
    settings: AppSettings = Depends(get_app_settings),
) -> UserInResponse:
    """Log in with email + password.

    Issues an access token, returned in the response body, and a refresh
    token, set as an HttpOnly cookie. Unknown email and wrong password both
    yield the same 400 error, so callers cannot tell which one failed.
    """
    bad_credentials = HTTPException(
        status_code=HTTP_400_BAD_REQUEST,
        detail=strings.INCORRECT_LOGIN_INPUT,
    )

    try:
        account = await users_repo.get_user_by_email(email=user_login.email)
    except EntityDoesNotExist as lookup_error:
        raise bad_credentials from lookup_error

    if not account.check_password(user_login.password):
        raise bad_credentials

    signing_key = str(settings.secret_key.get_secret_value())

    # Short-lived access token plus long-lived refresh token.
    access_token = jwt.create_access_token_for_user(account, signing_key)
    refresh_token = jwt.create_refresh_token_for_user(account, signing_key)

    # The refresh token is only ever delivered through the HttpOnly cookie.
    set_refresh_cookie(response, refresh_token, max_age_days=jwt.REFRESH_TOKEN_EXPIRE_DAYS)

    profile = UserWithToken(
        username=account.username,
        email=account.email,
        bio=account.bio,
        image=account.image,
        token=access_token,  # access still returned in the body for frontend compatibility
        email_verified=getattr(account, "email_verified", False),
        roles=getattr(account, "roles", []),
    )
    return UserInResponse(user=profile)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========= 注册 =========
|
|||
|
|
@router.post(
    "",
    status_code=HTTP_201_CREATED,
    response_model=UserInResponse,
    response_model_exclude_none=True,
    name="auth:register",
)
async def register(
    response: Response,
    payload: RegisterWithEmailIn = Body(..., embed=True, alias="user"),
    users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
    settings: AppSettings = Depends(get_app_settings),
    email_codes_repo: Optional[Any] = Depends(_provide_optional_email_codes_repo()),
) -> UserInResponse:
    """Register a new account with email, password and email verification code.

    Steps:
        1) Check that the two passwords match.
        2) Check that the email is not already registered.
        3) Verify and consume the email code, if an email-codes repository is available.
        4) Derive a unique username from the email.
        5) Create the user.
        6) Only if a code was actually verified, and the users repository
           provides ``set_email_verified``, mark the email verified (best effort).
        7) Issue the access token (in the body) and the refresh token (HttpOnly cookie).

    When no email-codes repository is available, step 3 is skipped and the
    account is NOT reported or stored as verified.

    Raises:
        HTTPException(400): passwords differ, email already taken, or the
            verification code is invalid/expired.
    """
    import logging

    # 1) Passwords must match.
    try:
        assert_passwords_match(payload.password, payload.confirm_password)
    except ValueError as mismatch:
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST,
            detail="Passwords do not match",
        ) from mismatch

    # 2) Email must not already be registered.
    if await check_email_is_taken(users_repo, payload.email):
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST,
            detail=strings.EMAIL_TAKEN,
        )

    # 3) Verify the email code. Track whether verification really happened:
    #    without a codes repository nothing proves the user owns the address.
    code_verified = False
    if email_codes_repo is not None:
        code_verified = bool(
            await email_codes_repo.verify_and_consume(  # type: ignore[attr-defined]
                email=payload.email,
                code=payload.code,
                scene="register",
            )
        )
        if not code_verified:
            raise HTTPException(
                status_code=HTTP_400_BAD_REQUEST,
                detail="Invalid or expired verification code",
            )

    # 4) Unique username.
    username = await make_unique_username(users_repo, payload.email)

    # 5) Create the user.
    user = await users_repo.create_user(
        username=username,
        email=payload.email,
        password=payload.password,
    )

    # 6) Mark the email verified only when a code was actually checked, then
    #    re-read the user. Best effort: a failure here must not abort registration.
    if code_verified and hasattr(users_repo, "set_email_verified"):
        try:
            await users_repo.set_email_verified(email=payload.email, verified=True)  # type: ignore[attr-defined]
            user = await users_repo.get_user_by_email(email=payload.email)
        except Exception:
            logging.getLogger(__name__).exception(
                "Failed to mark email as verified for a newly registered user"
            )

    # 7) Issue access & refresh tokens (refresh via cookie).
    secret = str(settings.secret_key.get_secret_value())
    access = jwt.create_access_token_for_user(user, secret)
    refresh = jwt.create_refresh_token_for_user(user, secret)
    set_refresh_cookie(response, refresh, max_age_days=jwt.REFRESH_TOKEN_EXPIRE_DAYS)

    return UserInResponse(
        user=UserWithToken(
            username=user.username,
            email=user.email,
            bio=user.bio,
            image=user.image,
            token=access,
            # Fall back to what we actually know rather than assuming "verified".
            email_verified=getattr(user, "email_verified", code_verified),
            roles=getattr(user, "roles", []),
        ),
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========= 刷新 Access(仅 Cookie 取 refresh)=========
|
|||
|
|
@router.post(
    "/refresh",
    name="auth:refresh",
)
async def refresh_access_token(
    request: Request,
    response: Response,
    users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
    settings: AppSettings = Depends(get_app_settings),
) -> dict:
    """Issue a new access token from the refresh token stored in the HttpOnly cookie.

    Minimal version: the refresh token is not rotated. It is re-set
    unchanged, with a fresh cookie max-age. Rotation or replay detection
    would need server-side token storage.

    Raises:
        HTTPException(401): cookie missing, token invalid, or user not found.
    """
    refresh = request.cookies.get(REFRESH_COOKIE_NAME)
    if not refresh:
        raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Missing refresh token")

    secret = str(settings.secret_key.get_secret_value())
    try:
        username = jwt.get_username_from_token(refresh, secret, expected_subject=jwt.JWT_SUBJECT_REFRESH)
    except ValueError:
        raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid refresh token") from None

    # Look the user up by username when the repository supports it; otherwise
    # (or if no such username exists) fall back to an email lookup. Only
    # "not found" is treated as an auth failure. Other errors (DB outages,
    # programming bugs) propagate instead of masquerading as a 401.
    user = None
    get_by_username = getattr(users_repo, "get_user_by_username", None)
    if get_by_username is not None:
        try:
            user = await get_by_username(username=username)
        except EntityDoesNotExist:
            user = None

    if user is None:
        try:
            user = await users_repo.get_user_by_email(email=username)
        except EntityDoesNotExist as exc:
            raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="User not found") from exc

    # New access token; the same refresh token keeps being used (no rotation).
    access = jwt.create_access_token_for_user(user, secret)
    # Re-set the cookie with the same value to refresh its max-age.
    # NOTE(review): the JWT's own expiry is unchanged, so the cookie may outlive the token.
    set_refresh_cookie(response, refresh, max_age_days=jwt.REFRESH_TOKEN_EXPIRE_DAYS)

    return {"token": access, "expires_in": jwt.ACCESS_TOKEN_EXPIRE_MINUTES * 60}
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========= 登出(清 Cookie;前端清本地 access)=========
|
|||
|
|
@router.post("/logout", name="auth:logout")
async def logout(response: Response) -> dict:
    """Log out by expiring the refresh-token cookie.

    Access tokens are not tracked server-side here; the client is expected to
    discard its locally stored access token.
    """
    clear_refresh_cookie(response)
    result = {"ok": True}
    return result
|