46 lines
1.3 KiB
Python
Raw Normal View History

2025-12-04 10:04:21 +08:00
# app/api/dependencies/database.py
from typing import AsyncIterator, Callable, Type, TypeVar

from asyncpg import Connection, Pool
from fastapi import Depends
from starlette.requests import Request

from app.db.repositories.base import BaseRepository
def _get_db_pool(request: Request) -> Pool:
    """Return the asyncpg pool that the application stored on ``app.state.pool``.

    Args:
        request: The incoming request; only ``request.app.state`` is read.

    Raises:
        RuntimeError: If ``app.state`` has no ``pool`` attribute, or it is
            ``None``. The error message says exactly where the pool is
            expected, instead of surfacing a bare ``AttributeError``.
    """
    app_state = request.app.state
    db_pool = getattr(app_state, "pool", None)
    if db_pool is not None:
        return db_pool
    raise RuntimeError("Database pool not initialized on app.state.pool")
async def _get_connection_from_pool(
    pool: Pool = Depends(_get_db_pool),
) -> AsyncIterator[Connection]:
    """Private implementation: lend one connection from ``pool`` to the caller.

    The connection is obtained with ``pool.acquire()`` used as an async
    context manager. It is therefore given back to the pool when that
    ``async with`` block exits, i.e. once this generator is resumed past its
    ``yield`` or closed. No manual release is needed.

    Args:
        pool: The connection pool, resolved through the ``_get_db_pool``
            dependency.

    Yields:
        A single connection borrowed from ``pool``.
    """
    async with pool.acquire() as borrowed_conn:
        yield borrowed_conn
# Public alias so route handlers can write ``Depends(get_connection)`` directly.
# This binds the very same function object as ``_get_connection_from_pool``
# (not a wrapper). Both names therefore refer to one dependency callable, and
# anything keyed on the callable (e.g. an ``app.dependency_overrides`` entry)
# treats them identically.
get_connection = _get_connection_from_pool
# Bound to BaseRepository so get_repository() preserves the concrete
# repository subclass for static type checkers.
_RepoT = TypeVar("_RepoT", bound=BaseRepository)


def get_repository(
    repo_type: Type[_RepoT],
) -> Callable[[Connection], _RepoT]:
    """Build a dependency that constructs ``repo_type`` around a DB connection.

    Kept for the legacy call style ``Depends(get_repository(UserRepo))``. The
    returned dependency itself depends on ``get_connection``, so routes that
    use this factory and routes that use ``Depends(get_connection)`` directly
    can coexist.

    ``repo_type`` is typed through a bound TypeVar. Type checkers therefore see
    the concrete class passed in (e.g. ``UserRepo``) as the dependency's
    result, rather than the ``BaseRepository`` base class.

    Args:
        repo_type: Repository class to instantiate; it is called as
            ``repo_type(conn)``.

    Returns:
        A newly created dependency callable (a fresh closure on every call).
        It receives the connection resolved from ``get_connection`` and
        returns ``repo_type(conn)``.
    """

    def _get_repo(conn: Connection = Depends(get_connection)) -> _RepoT:
        return repo_type(conn)

    return _get_repo