# app/api/dependencies/database.py
|
|||
|
|
from typing import AsyncIterator, Callable, Type, TypeVar

from asyncpg import Connection, Pool
from fastapi import Depends
from starlette.requests import Request

from app.db.repositories.base import BaseRepository
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _get_db_pool(request: Request) -> Pool:
    """Look up the shared connection pool stored on ``request.app.state.pool``.

    Raises:
        RuntimeError: when the application state has no ``pool`` attribute,
            or it is ``None`` (the pool was never initialized).
    """
    app_state = request.app.state
    db_pool = getattr(app_state, "pool", None)
    if db_pool is not None:
        return db_pool
    # Fail loudly here rather than handing ``None`` to callers that would
    # later call ``.acquire()`` on it.
    raise RuntimeError("Database pool not initialized on app.state.pool")
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def _get_connection_from_pool(
    pool: Pool = Depends(_get_db_pool),
) -> AsyncIterator[Connection]:
    """Borrow one connection from *pool* for the duration of the dependency.

    Private implementation: the connection is yielded to the consumer and is
    returned to the pool automatically when the ``async with`` block exits.
    """
    acquisition = pool.acquire()
    async with acquisition as connection:
        yield connection


# Public alias so routes can write ``Depends(get_connection)`` directly.
# Both names refer to the very same function object.
get_connection = _get_connection_from_pool
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Type variable for repository classes, so the returned dependency is
# statically typed as producing the concrete repository passed in.
_RepoT = TypeVar("_RepoT", bound=BaseRepository)


def get_repository(
    repo_type: Type[_RepoT],
) -> Callable[[Connection], _RepoT]:
    """Build a dependency that instantiates ``repo_type`` with a DB connection.

    Supports the legacy style ``Depends(get_repository(UserRepo))``. The
    returned dependency itself depends on ``get_connection``, so this style
    can coexist with routes that use ``Depends(get_connection)`` directly.

    Args:
        repo_type: Repository class to instantiate. It is called with the
            borrowed connection as its only positional argument.

    Returns:
        A dependency callable whose result is typed as the concrete
        ``repo_type`` rather than the generic ``BaseRepository``.
    """

    def _get_repo(conn: Connection = Depends(get_connection)) -> _RepoT:
        return repo_type(conn)

    return _get_repo
|