46 lines
1.3 KiB
Python
46 lines
1.3 KiB
Python
# app/api/dependencies/database.py
|
||
from typing import AsyncIterator, Callable, Type, TypeVar

from asyncpg import Connection, Pool
from fastapi import Depends
from starlette.requests import Request

from app.db.repositories.base import BaseRepository
|
||
|
||
|
||
def _get_db_pool(request: Request) -> Pool:
    """Return the connection pool stored on ``request.app.state.pool``.

    Raises:
        RuntimeError: if ``app.state`` has no ``pool`` attribute, or it is
            ``None`` (i.e. the pool was never set up).
    """
    state = request.app.state
    db_pool = getattr(state, "pool", None)
    if db_pool is not None:
        return db_pool
    raise RuntimeError("Database pool not initialized on app.state.pool")
|
||
|
||
|
||
async def _get_connection_from_pool(
    pool: Pool = Depends(_get_db_pool),
) -> AsyncIterator[Connection]:
    """Private implementation: lend one connection from ``pool``.

    The connection is yielded while inside ``async with pool.acquire()``, so
    it is handed back to the pool automatically once the generator finishes.
    """
    acquisition = pool.acquire()
    async with acquisition as connection:
        yield connection


# Public alias so routes can depend on it directly: Depends(get_connection)
get_connection = _get_connection_from_pool
|
||
|
||
|
||
# Bound TypeVar so the dependency's return type is the concrete repository
# class passed in (e.g. UserRepo), not the erased BaseRepository.
_RepoT = TypeVar("_RepoT", bound=BaseRepository)


def get_repository(
    repo_type: Type[_RepoT],
) -> Callable[[Connection], _RepoT]:
    """Build a dependency that constructs a ``repo_type`` repository.

    Kept for the legacy style ``Depends(get_repository(UserRepo))``. The
    returned callable itself depends on ``get_connection``, so this style
    can coexist with routes that use ``Depends(get_connection)`` directly.

    Args:
        repo_type: Repository class to instantiate; it is called with the
            injected connection as its single positional argument.

    Returns:
        A dependency callable that receives a connection (injected via
        ``Depends(get_connection)``) and returns ``repo_type(conn)``. Static
        type checkers see its result as ``repo_type``'s own class.
    """

    def _get_repo(conn: Connection = Depends(get_connection)) -> _RepoT:
        return repo_type(conn)

    return _get_repo
|