diff --git a/CLAUDE.md b/CLAUDE.md index 65f22d4..5156803 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -49,7 +49,7 @@ Endpoints inject repositories with `FromDI(Repository)` from `modern_di_fastapi` - `app/models.py` — `BigIntAuditBase` from `advanced_alchemy` (auto `id`, `created_at`, `updated_at`). The module aliases `orm_registry.metadata` onto `orm.DeclarativeBase.metadata` so Alembic autogenerate sees both. New models go here. - `app/repositories.py` — Subclass `SQLAlchemyAsyncRepositoryService[Model]` with a nested `BaseRepository(SQLAlchemyAsyncRepository[Model])`. Routes use the service methods (`list`, `get_one_or_none`, `create`, `update`, `create_many`, `upsert_many`). -- `app/resources/db.py` — `CustomAsyncSession.close()` does `expunge_all()` instead of closing when bound to an `AsyncConnection`. This is what enables the test rollback pattern below — do not "fix" it. +- `app/resources/db.py` — `create_session` passes `join_transaction_mode="create_savepoint"`. This is inert in production (the session binds to an engine) but enables the test rollback pattern below: when a test binds the session to a connection already in a transaction, the session owns its own savepoint so the outer transaction survives commits — do not "fix" it. - `migrations/env.py` swaps the asyncpg driver for the sync `postgresql` driver and uses `app.models.METADATA` as `target_metadata`. ### Settings @@ -61,7 +61,7 @@ Endpoints inject repositories with `FromDI(Repository)` from `modern_di_fastapi` `tests/conftest.py` provides the test isolation pattern — read it before adding fixtures: - `app` fixture builds a fresh app via `LifespanManager`. -- `db_session` opens a connection, begins a transaction, begins a nested savepoint, and **overrides `Dependencies.database_engine`** with the connection itself. The nested savepoint is rolled back at teardown so each test starts clean. This is why `CustomAsyncSession.close` must `expunge_all` rather than close — closing would commit the outer transaction. +- `db_session` opens a connection, begins a transaction, and **overrides `Dependencies.database_engine`** with the connection itself. Each session built against that connection uses `join_transaction_mode="create_savepoint"`, so `auto_commit` releases the session's own savepoint while the outer transaction is rolled back at teardown — each test starts clean. - `set_async_session_in_base_sqlalchemy_factory` wires `db_session` into `SQLAlchemyFactory.__async_session__` so `polyfactory` factories in `tests/factories.py` (`DeckModelFactory`, `CardModelFactory`) persist via the rolled-back session. Test modules that use these factories opt in with `pytestmark = [pytest.mark.usefixtures("set_async_session_in_base_sqlalchemy_factory")]`. `pytest.ini_options` sets `asyncio_mode = "auto"` — async tests do not need `@pytest.mark.asyncio`. Coverage runs by default (`--cov=. --cov-report term-missing`). diff --git a/app/resources/db.py b/app/resources/db.py index c2723d5..99944bc 100644 --- a/app/resources/db.py +++ b/app/resources/db.py @@ -25,16 +25,16 @@ async def close_sa_engine(engine: sa.AsyncEngine) -> None: await engine.dispose() -class CustomAsyncSession(sa.AsyncSession): - async def close(self) -> None: - if isinstance(self.bind, sa.AsyncConnection): - return self.expunge_all() - - return await super().close() - - def create_session(engine: sa.AsyncEngine) -> sa.AsyncSession: - return CustomAsyncSession(engine, expire_on_commit=False, autoflush=False) + # join_transaction_mode is inert in production (the session binds to an engine); when tests bind + # the session to a connection already in a transaction, it makes the session own a savepoint so + # the outer transaction survives commits and the per-test rollback stays clean. + return sa.AsyncSession( + engine, + expire_on_commit=False, + autoflush=False, + join_transaction_mode="create_savepoint", + ) async def close_session(session: sa.AsyncSession) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index f788758..314e500 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -47,11 +47,15 @@ async def db_session(di_container: modern_di.Container) -> typing.AsyncIterator[ engine = create_sa_engine() connection = await engine.connect() transaction = await connection.begin() - await connection.begin_nested() di_container.override(ioc.Dependencies.database_engine, connection) try: - yield AsyncSession(connection, expire_on_commit=False, autoflush=False) + yield AsyncSession( + connection, + expire_on_commit=False, + autoflush=False, + join_transaction_mode="create_savepoint", + ) finally: if connection.in_transaction(): await transaction.rollback()