117 lines
3.9 KiB
Python
117 lines
3.9 KiB
Python
from datetime import datetime, timedelta, timezone
from typing import Dict, Literal, Optional

import jwt
from pydantic import ValidationError

from app.models.domain.users import User
from app.models.schemas.jwt import JWTMeta, JWTUser
|
||
|
||
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Signing algorithm, used both when encoding and when decoding tokens.
ALGORITHM: str = "HS256"

# Values stored in the ``sub`` claim so the two kinds of token can be told apart.
JWT_SUBJECT_ACCESS: str = "access"
JWT_SUBJECT_REFRESH: str = "refresh"

# Token lifetimes.
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15  # 15 minutes
REFRESH_TOKEN_EXPIRE_DAYS: int = 30  # 30 days
|
||
|
||
|
||
def _create_jwt_token(
    *,
    jwt_content: Dict[str, str],
    secret_key: str,
    expires_delta: timedelta,
    subject: Literal["access", "refresh"],
) -> str:
    """Build and sign a JWT carrying ``jwt_content`` plus ``exp`` / ``sub``.

    The caller's claims are copied (the input mapping is not mutated). The
    ``exp`` and ``sub`` values produced by ``JWTMeta`` are then merged on top,
    so they override any same-named keys in ``jwt_content``.

    Args:
        jwt_content: Base claims, typically a Pydantic model dumped to a dict
            (e.g. ``JWTUser(username=...).dict()``).
        secret_key: Key used to sign the token with ``ALGORITHM``.
        expires_delta: Token lifetime, added to the current UTC time.
        subject: Token kind stored in the ``sub`` claim.

    Returns:
        The encoded token as returned by ``jwt.encode``.
    """
    # Use a timezone-aware UTC timestamp. datetime.utcnow() is deprecated since
    # Python 3.12 and yields a naive value that is easy to misread as local
    # time. PyJWT converts aware datetimes for ``exp`` to the correct UTC epoch.
    expire = datetime.now(timezone.utc) + expires_delta
    meta_claims = JWTMeta(exp=expire, sub=subject).dict()
    to_encode = {**jwt_content, **meta_claims}
    return jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
|
||
|
||
|
||
# ========== Access token (the frontend sends it in the Authorization header) ==========
def create_access_token_for_user(user: User, secret_key: str) -> str:
    """Issue a short-lived access token for ``user``.

    The payload holds the claims produced by ``JWTUser(username=user.username)``
    plus ``sub="access"``. The token expires ``ACCESS_TOKEN_EXPIRE_MINUTES``
    (15) minutes after issuance.
    """
    claims = JWTUser(username=user.username).dict()
    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    return _create_jwt_token(
        jwt_content=claims,
        secret_key=secret_key,
        expires_delta=lifetime,
        subject=JWT_SUBJECT_ACCESS,
    )
|
||
|
||
|
||
# ========== Refresh token (issued and used only via an HttpOnly cookie) ==========
def create_refresh_token_for_user(user: User, secret_key: str) -> str:
    """Issue a long-lived refresh token for ``user``.

    The payload holds the claims produced by ``JWTUser(username=user.username)``
    plus ``sub="refresh"``. The token expires ``REFRESH_TOKEN_EXPIRE_DAYS``
    (30) days after issuance.

    Note: this minimal version uses a JWT as the refresh token. For stronger
    security, switch to a random opaque string and store only its hash
    server-side.
    """
    claims = JWTUser(username=user.username).dict()
    lifetime = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    return _create_jwt_token(
        jwt_content=claims,
        secret_key=secret_key,
        expires_delta=lifetime,
        subject=JWT_SUBJECT_REFRESH,
    )
|
||
|
||
|
||
# ========== Decoding and validation helpers ==========
def _decode_token(token: str, secret_key: str) -> Dict:
    """Verify ``token`` with ``secret_key`` and return its raw payload.

    Only ``ALGORITHM`` is accepted when decoding.

    Raises:
        ValueError: if PyJWT rejects the token for any reason (any
            ``jwt.PyJWTError``); the original error is chained as the cause.
    """
    accepted_algorithms = [ALGORITHM]
    try:
        payload = jwt.decode(token, secret_key, algorithms=accepted_algorithms)
    except jwt.PyJWTError as exc:
        raise ValueError("unable to decode JWT token") from exc
    return payload
|
||
|
||
|
||
def get_username_from_token(
    token: str,
    secret_key: str,
    expected_subject: Literal["access", "refresh"] = JWT_SUBJECT_ACCESS,
) -> str:
    """Decode ``token`` and return the username stored in it.

    The ``sub`` claim must equal ``expected_subject`` (``"access"`` by
    default). This stops a refresh token from being accepted where an access
    token is expected, and vice versa.

    - Protected endpoints: ``expected_subject="access"``.
    - Refresh flow: ``expected_subject="refresh"``.

    Raises:
        ValueError: the token cannot be decoded, its ``sub`` does not match
            ``expected_subject``, or the payload fails ``JWTUser`` validation.
    """
    payload = _decode_token(token, secret_key)

    # Check the token kind explicitly before trusting any other claim.
    actual_subject = payload.get("sub")
    if actual_subject != expected_subject:
        raise ValueError(
            f"invalid token subject: expected '{expected_subject}', got '{actual_subject}'"
        )

    # Let Pydantic validate the payload and extract the username.
    try:
        token_user = JWTUser(**payload)
    except ValidationError as validation_error:
        raise ValueError("malformed payload in token") from validation_error
    return token_user.username
|
||
|
||
|
||
# ========== Backward-compatible alias (for code elsewhere that still calls it directly) ==========
def create_jwt_token(
    *,
    jwt_content: Dict[str, str],
    secret_key: str,
    expires_delta: timedelta,
) -> str:
    """Legacy signing entry point; always issues an access token (``sub="access"``).

    Equivalent to calling ``_create_jwt_token`` with
    ``subject=JWT_SUBJECT_ACCESS``. New code should call
    ``create_access_token_for_user`` or ``create_refresh_token_for_user``.
    """
    return _create_jwt_token(
        subject=JWT_SUBJECT_ACCESS,
        expires_delta=expires_delta,
        secret_key=secret_key,
        jwt_content=jwt_content,
    )
|