# app/services/password_reset.py import os import hashlib import secrets from datetime import datetime, timedelta, timezone from fastapi import HTTPException, Request from asyncpg import Connection from app.db.repositories.users import UsersRepository from app.db.queries.queries import queries # aiosql 生成的 Queries 对象 from app.services import security # ✅ 使用项目原有 passlib 封装 from app.db.errors import EntityDoesNotExist # 用于兜底 try/except # 业务常量 RESET_SCENE = "reset" RESET_PURPOSE = "reset" CODE_TTL_MINUTES = 30 # 验证码有效期(分钟) # ===== 小工具 ===== def _sha256_hex(s: str) -> str: return hashlib.sha256(s.encode("utf-8")).hexdigest() def _email_html(code: str) -> str: return f"""

重置你的密码

你的验证码({CODE_TTL_MINUTES} 分钟内有效):

{code}

若非本人操作请忽略此邮件。

""" def _first_row(maybe_rows): """ aiosql + asyncpg 在 SELECT 时可能返回: - asyncpg.Record - list[Record] - dict-like 统一取“第一条/单条”。 """ if maybe_rows is None: return None if isinstance(maybe_rows, list): return maybe_rows[0] if maybe_rows else None return maybe_rows def _get_key(row, key: str): """ 兼容 asyncpg.Record / dict / tuple(list) 仅用于取 'id' 这种关键字段 """ if row is None: return None # dict-like / Record try: if key in row: return row[key] # type: ignore[index] except Exception: pass # 某些驱动可能支持 .get try: return row.get(key) # type: ignore[attr-defined] except Exception: pass # 最后尝试属性 return getattr(row, key, None) async def _get_user_by_email_optional(users_repo: UsersRepository, *, email: str): """ 安全获取用户: - 若仓库实现了 get_user_by_email_optional,直接用 - 否则回退到 get_user_by_email,并用 try/except 屏蔽不存在异常 返回 UserInDB 或 None """ # 新接口:优先调用 if hasattr(users_repo, "get_user_by_email_optional"): try: return await users_repo.get_user_by_email_optional(email=email) # type: ignore[attr-defined] except Exception: return None # 旧接口:try/except 防止抛出不存在 try: return await users_repo.get_user_by_email(email=email) except EntityDoesNotExist: return None # ===== 主流程 ===== async def send_reset_code_by_email( request: Request, conn: Connection, users_repo: UsersRepository, email: str, ) -> None: """ 若邮箱存在:生成 6 位验证码 -> 只存哈希 -> 发送邮件(或开发阶段打印) 若不存在:静默返回,防止枚举邮箱 """ user = await _get_user_by_email_optional(users_repo, email=email) if not user: return # 静默 # 6 位数字验证码(明文只用于发送/展示,数据库只存哈希) code = f"{secrets.randbelow(1_000_000):06d}" code_hash = _sha256_hex(code) expires_at = datetime.now(timezone.utc) + timedelta(minutes=CODE_TTL_MINUTES) request_ip = request.client.host if request.client else None user_agent = request.headers.get("user-agent", "") await queries.create_email_code( conn, email=email, scene=RESET_SCENE, purpose=RESET_PURPOSE, code_hash=code_hash, expires_at=expires_at, request_ip=request_ip, user_agent=user_agent, ) # === 发送邮件 === try: # 如果你已有统一邮件服务,可直接调用;没有则打印在开发日志 from app.services.mailer import send_email # 可选 # 你的 send_email 若是异步函数,这里 await;若是同步也能正常抛异常被捕获 maybe_coro = send_email( to_email=email, subject="重置密码验证码", html=_email_html(code), ) if hasattr(maybe_coro, "__await__"): await maybe_coro # 兼容 async 版本 except Exception: print(f"[DEV] reset code for {email}: {code} (expires in {CODE_TTL_MINUTES}m)") async def reset_password_with_code( conn: Connection, users_repo: UsersRepository, *, email: str, code: str, new_password: str, ) -> None: """ 校验验证码 -> 修改密码 -> 标记验证码已使用 -> 清理历史 """ code_hash = _sha256_hex(code.strip()) # 1) 校验验证码(只接受未使用且未过期) rec = await queries.get_valid_email_code( conn, email=email, scene=RESET_SCENE, purpose=RESET_PURPOSE, code_hash=code_hash, ) rec = _first_row(rec) if rec is None: raise HTTPException(status_code=400, detail="验证码无效或已过期") # 2) 查用户(安全获取,避免抛异常 & 防枚举) user = await _get_user_by_email_optional(users_repo, email=email) if user is None: # 与验证码错误同样提示,避免暴露邮箱存在性 raise HTTPException(status_code=400, detail="验证码无效或已过期") # 3) 生成新 salt / hash —— ✅ 使用项目原有 passlib 方案 # 关键点:和登录校验保持一致,对 (salt + plain_password) 做 passlib 哈希 new_salt = os.urandom(16).hex() new_hashed = security.get_password_hash(new_salt + new_password) # 4) 优先用 id 更新;若没有 id(历史坑),则回退用 email 更新 updated = None try: user_id = getattr(user, "id", None) if user_id: updated = await queries.update_user_password_by_id( conn, id=user_id, new_salt=new_salt, new_password=new_hashed, # ✅ passlib 生成的带前缀哈希 ) else: updated = await queries.update_user_password_by_email( conn, email=email, new_salt=new_salt, new_password=new_hashed, ) except Exception: # 极端情况下,id 更新失败也再补 email 更新,确保不中断 updated = await queries.update_user_password_by_email( conn, email=email, new_salt=new_salt, new_password=new_hashed, ) # aiosql 有时会返回 list,若是空 list 视为失败 if isinstance(updated, list) and not updated: raise HTTPException(status_code=500, detail="密码更新失败") # 5) 标记验证码已用 & 清理 rec_id = _get_key(rec, "id") if rec_id is not None: await queries.mark_email_code_used(conn, id=rec_id) else: print("[WARN] Could not resolve email_code.id to mark consumed.") await queries.delete_expired_email_codes(conn)