47 lines
1.2 KiB
Python
47 lines
1.2 KiB
Python
|
|
from types import TracebackType
|
||
|
|
from typing import Optional, Type
|
||
|
|
|
||
|
|
from asyncpg import Connection
|
||
|
|
from asyncpg.pool import Pool
|
||
|
|
|
||
|
|
|
||
|
|
class FakeAsyncPGPool:
    """Pool stand-in that serves a single connection inside one transaction.

    ``create_pool`` acquires one connection from the wrapped real pool and
    starts a transaction on it. Every ``acquire()`` then hands out that same
    connection. ``close()`` rolls the transaction back, releases the
    connection, and closes the real pool, so work done through this fake is
    never committed.
    """

    def __init__(self, pool: "Pool") -> None:
        """Wrap *pool*; call ``create_pool`` to get a usable instance."""
        self._pool = pool
        # Populated by create_pool; None until then and again after close().
        self._conn: Optional["Connection"] = None
        self._tx = None

    @classmethod
    async def create_pool(cls, pool: "Pool") -> "FakeAsyncPGPool":
        """Build a fake pool whose shared connection has an open transaction.

        Args:
            pool: The real pool to borrow a single connection from.

        Returns:
            A ready-to-use ``FakeAsyncPGPool``.

        Raises:
            Whatever ``pool.acquire()``, ``conn.transaction()`` or
            ``tx.start()`` raise. If the transaction cannot be set up, the
            acquired connection is released before the error propagates.
        """
        fake = cls(pool)
        conn = await pool.acquire()
        try:
            tx = conn.transaction()
            await tx.start()
        except BaseException:
            # Without this the borrowed connection would leak from the real pool.
            await pool.release(conn)
            raise
        fake._conn = conn
        fake._tx = tx
        return fake

    async def close(self) -> None:
        """Roll back the transaction, release the connection, close the pool.

        Each step runs even if an earlier one raises, so a failed rollback
        cannot leave the connection checked out or the real pool open. The
        first exception raised is propagated. Calling this before
        ``create_pool`` or a second time skips the rollback/release steps.
        """
        conn, tx = self._conn, self._tx
        # Clear state first so a repeated close() does not reuse them.
        self._conn = None
        self._tx = None
        try:
            if tx is not None:
                await tx.rollback()
        finally:
            try:
                if conn is not None:
                    await self._pool.release(conn)
            finally:
                await self._pool.close()

    def acquire(self, *, timeout: Optional[float] = None) -> "FakePoolAcquireContent":
        """Return a context manager yielding the shared connection.

        ``timeout`` is accepted so the call signature matches a real pool's
        ``acquire``; it is ignored because no waiting ever happens here.
        """
        return FakePoolAcquireContent(self)
|
||
|
|
|
||
|
|
|
||
|
|
class FakePoolAcquireContent:
    """Async context manager returned by ``FakeAsyncPGPool.acquire``.

    Entering yields the fake pool's single shared connection. Exiting does
    nothing, because that connection stays with the fake pool for its whole
    lifetime and is only released when the fake pool is closed.
    """

    def __init__(self, pool: "FakeAsyncPGPool") -> None:
        """Remember the fake pool whose connection will be handed out."""
        self._pool = pool

    async def __aenter__(self) -> "Connection":
        """Return the shared connection (the same object on every entry).

        ``FakeAsyncPGPool.__init__`` initialises it to None, so the result is
        None if the fake pool was not built via ``create_pool``.
        """
        return self._pool._conn

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        """Deliberately do not release the shared connection.

        Returning None (falsy) means any exception raised in the ``async
        with`` body, including cancellation, propagates unchanged.
        """
        return None
|