from datetime import datetime, timedelta, timezone
from typing import Dict, Optional, Literal

import jwt
from pydantic import ValidationError

from app.models.domain.users import User
from app.models.schemas.jwt import JWTMeta, JWTUser

# === Configuration ===
ALGORITHM = "HS256"

# Values of the `sub` claim used to tell the two token kinds apart.
JWT_SUBJECT_ACCESS = "access"
JWT_SUBJECT_REFRESH = "refresh"

# Token lifetimes.
ACCESS_TOKEN_EXPIRE_MINUTES = 15  # 15 minutes
REFRESH_TOKEN_EXPIRE_DAYS = 30  # 30 days


def _create_jwt_token(
    *,
    jwt_content: Dict[str, str],
    secret_key: str,
    expires_delta: timedelta,
    subject: Literal["access", "refresh"],
) -> str:
    """Build and sign a JWT carrying ``jwt_content`` plus ``exp``/``sub`` claims.

    Args:
        jwt_content: Claims to embed, usually a dumped Pydantic model
            (e.g. ``JWTUser(username=...).dict()``). It is copied, so the
            caller's dict is never mutated.
        secret_key: Key used to sign the token with ``ALGORITHM``.
        expires_delta: Token lifetime, added to the current UTC time to form ``exp``.
        subject: Value written to the ``sub`` claim ("access" or "refresh").

    Returns:
        The encoded token produced by ``jwt.encode``.
    """
    to_encode = jwt_content.copy()
    # Timezone-aware UTC "now": datetime.utcnow() is deprecated since Python 3.12
    # and yields a naive value; an aware datetime represents the same instant
    # unambiguously.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update(JWTMeta(exp=expire, sub=subject).dict())
    return jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)


# ========== Access token (sent by the client in the Authorization header) ==========
def create_access_token_for_user(user: User, secret_key: str) -> str:
    """Issue an access token for ``user``.

    The token is valid for ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes and carries
    ``sub="access"``.
    """
    return _create_jwt_token(
        jwt_content=JWTUser(username=user.username).dict(),
        secret_key=secret_key,
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
        subject=JWT_SUBJECT_ACCESS,
    )


# ========== Refresh token (intended to be issued/used only via an HttpOnly cookie) ==========
def create_refresh_token_for_user(user: User, secret_key: str) -> str:
    """Issue a refresh token for ``user``.

    The token is valid for ``REFRESH_TOKEN_EXPIRE_DAYS`` days and carries
    ``sub="refresh"``.

    Note: this minimal version uses a signed JWT as the refresh token. A more
    secure design would issue an opaque random string and store only its hash
    server-side.
    """
    return _create_jwt_token(
        jwt_content=JWTUser(username=user.username).dict(),
        secret_key=secret_key,
        expires_delta=timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS),
        subject=JWT_SUBJECT_REFRESH,
    )


# ========== Decoding and validation helpers ==========
def _decode_token(token: str, secret_key: str) -> Dict:
    """Verify ``token`` with ``secret_key`` and return its raw payload.

    Raises:
        ValueError: if PyJWT rejects the token (bad signature, expired,
            malformed, ...); the original ``jwt.PyJWTError`` is chained.
    """
    try:
        return jwt.decode(token, secret_key, algorithms=[ALGORITHM])
    except jwt.PyJWTError as decode_error:
        raise ValueError("unable to decode JWT token") from decode_error


def get_username_from_token(
    token: str,
    secret_key: str,
    expected_subject: Literal["access", "refresh"] = JWT_SUBJECT_ACCESS,
) -> str:
    """Decode ``token`` and return the username it carries.

    The ``sub`` claim must equal ``expected_subject``:
      - protected endpoints use ``expected_subject="access"`` (the default);
      - the refresh flow uses ``expected_subject="refresh"``.

    Raises:
        ValueError: if the token cannot be decoded, its ``sub`` does not match
            ``expected_subject``, or the payload does not validate as ``JWTUser``.
    """
    payload = _decode_token(token, secret_key)

    # Check `sub` explicitly so a refresh token cannot be used as an access
    # token (and vice versa).
    sub = payload.get("sub")
    if sub != expected_subject:
        raise ValueError(f"invalid token subject: expected '{expected_subject}', got '{sub}'")

    # Only the Pydantic validation can raise ValidationError; keep the try minimal.
    try:
        return JWTUser(**payload).username
    except ValidationError as validation_error:
        raise ValueError("malformed payload in token") from validation_error


# ========== Backward-compatible alias (for callers elsewhere that use the old API) ==========
def create_jwt_token(
    *,
    jwt_content: Dict[str, str],
    secret_key: str,
    expires_delta: timedelta,
) -> str:
    """Legacy token builder kept for compatibility; always issues ``sub="access"``.

    New code should call ``create_access_token_for_user`` or
    ``create_refresh_token_for_user`` instead.
    """
    return _create_jwt_token(
        jwt_content=jwt_content,
        secret_key=secret_key,
        expires_delta=expires_delta,
        subject=JWT_SUBJECT_ACCESS,
    )