117 lines
3.9 KiB
Python
Raw Normal View History

2025-12-04 10:04:21 +08:00
from datetime import datetime, timedelta, timezone
from typing import Dict, Literal, Optional

import jwt
from pydantic import ValidationError

from app.models.domain.users import User
from app.models.schemas.jwt import JWTMeta, JWTUser
# === Configuration ===

# Signing/verification algorithm used for every token in this module.
ALGORITHM = "HS256"

# Values written into the "sub" claim so the two token kinds can be told apart.
JWT_SUBJECT_ACCESS, JWT_SUBJECT_REFRESH = "access", "refresh"

# Token lifetimes.
ACCESS_TOKEN_EXPIRE_MINUTES = 15  # access token: 15 minutes
REFRESH_TOKEN_EXPIRE_DAYS = 30  # refresh token: 30 days
def _create_jwt_token(
    *,
    jwt_content: Dict[str, str],
    secret_key: str,
    expires_delta: timedelta,
    subject: Literal["access", "refresh"],
) -> str:
    """Build a signed JWT from ``jwt_content`` plus ``exp`` / ``sub`` claims.

    Args:
        jwt_content: Claims to embed, typically the ``.dict()`` of a Pydantic
            model such as ``JWTUser(username=...)``. It is copied, so the
            caller's mapping is never mutated.
        secret_key: Key used to sign the token.
        expires_delta: Token lifetime, measured from the current UTC time.
        subject: Value for the ``sub`` claim (``"access"`` or ``"refresh"``).

    Returns:
        The value returned by ``jwt.encode`` for the combined claims, signed
        with ``ALGORITHM``.
    """
    to_encode = dict(jwt_content)
    # Timezone-aware UTC: datetime.utcnow() is deprecated since Python 3.12
    # and yields a naive datetime whose UTC-ness is only implicit. PyJWT turns
    # both into the same integer timestamp, so the encoded ``exp`` is unchanged.
    expire = datetime.now(timezone.utc) + expires_delta
    # Keys produced by JWTMeta(...).dict() overwrite same-named keys in jwt_content.
    to_encode.update(JWTMeta(exp=expire, sub=subject).dict())
    return jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
# ========== Access Token（给前端放到 Authorization 里用） ==========
def create_access_token_for_user(user: User, secret_key: str) -> str:
    """Issue an access token for ``user``.

    The claims come from ``JWTUser(username=user.username)``. The token gets
    ``sub="access"`` and expires ``ACCESS_TOKEN_EXPIRE_MINUTES`` (15) minutes
    after issue. It is intended for the frontend to send in the
    Authorization header.
    """
    claims = JWTUser(username=user.username).dict()
    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    return _create_jwt_token(
        subject=JWT_SUBJECT_ACCESS,
        expires_delta=lifetime,
        secret_key=secret_key,
        jwt_content=claims,
    )
# ========== Refresh Token（仅通过 HttpOnly Cookie 下发/使用） ==========
def create_refresh_token_for_user(user: User, secret_key: str) -> str:
    """Issue a refresh token for ``user``.

    The claims come from ``JWTUser(username=user.username)``. The token gets
    ``sub="refresh"`` and expires ``REFRESH_TOKEN_EXPIRE_DAYS`` (30) days
    after issue. It is intended to be delivered and used only through an
    HttpOnly cookie.

    Note: this minimal version uses a JWT as the refresh token. For stronger
    security, switch to a random opaque string and store only its hash on
    the server.
    """
    refresh_claims = JWTUser(username=user.username).dict()
    refresh_lifetime = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    return _create_jwt_token(
        subject=JWT_SUBJECT_REFRESH,
        expires_delta=refresh_lifetime,
        secret_key=secret_key,
        jwt_content=refresh_claims,
    )
# ========== 解码与校验工具 ==========
def _decode_token(token: str, secret_key: str) -> Dict:
    """Verify ``token`` against ``secret_key`` and return its raw claims.

    Only ``ALGORITHM`` is accepted.

    Raises:
        ValueError: if ``jwt.decode`` raises any ``jwt.PyJWTError``. The
            original error is chained as ``__cause__``.
    """
    try:
        claims = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
    except jwt.PyJWTError as err:
        raise ValueError("unable to decode JWT token") from err
    return claims
def get_username_from_token(
    token: str,
    secret_key: str,
    expected_subject: Literal["access", "refresh"] = JWT_SUBJECT_ACCESS,
) -> str:
    """Decode ``token`` and return the username it carries.

    The ``sub`` claim must equal ``expected_subject``. This prevents a refresh
    token from being accepted where an access token is expected, and vice
    versa:

    - protected endpoints: ``expected_subject="access"`` (the default)
    - refresh flow: ``expected_subject="refresh"``

    Raises:
        ValueError: if the token cannot be decoded, if its subject does not
            match, or if the payload does not validate as ``JWTUser``.
    """
    payload = _decode_token(token, secret_key)

    # Reject a token of the wrong kind before inspecting the user fields.
    actual_subject = payload.get("sub")
    if actual_subject != expected_subject:
        raise ValueError(
            f"invalid token subject: expected '{expected_subject}', got '{actual_subject}'"
        )

    # Let Pydantic validate and extract the user fields.
    try:
        user_claims = JWTUser(**payload)
    except ValidationError as validation_error:
        raise ValueError("malformed payload in token") from validation_error
    return user_claims.username
# ========== 兼容旧用法的别名(如果你项目其他地方直接调用了它) ==========
def create_jwt_token(
    *,
    jwt_content: Dict[str, str],
    secret_key: str,
    expires_delta: timedelta,
) -> str:
    """Backward-compatible signer kept for older call sites.

    Every token produced here is treated as an access token (``sub="access"``).
    New code should call ``create_access_token_for_user`` or
    ``create_refresh_token_for_user`` instead.
    """
    # Legacy callers never distinguished token kinds, so default to "access".
    legacy_subject = JWT_SUBJECT_ACCESS
    return _create_jwt_token(
        subject=legacy_subject,
        expires_delta=expires_delta,
        secret_key=secret_key,
        jwt_content=jwt_content,
    )