Fix GitHub Actions Worklows and refactor tests (#27)

* Refactor conftest

* Fix docstrings
This commit is contained in:
Vladyslav Fedoriuk 2023-11-18 20:27:38 +01:00 committed by GitHub
parent a2ed5df38a
commit 0610553651
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 129 additions and 61 deletions

View File

@ -10,13 +10,18 @@ on:
workflows:
- "Scraping the repositories from Source Graph"
- "Python App Quality and Testing"
branches: [main]
branches: [master]
types:
- completed
# Allows you to run this workflow manually from the Actions tab
# https://docs.github.com/en/actions/using-workflows/manually-running-a-workflow
# https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#workflow_dispatch
workflow_dispatch:
# Allows to run this workflow on push events to the master branch
# https://docs.github.com/en/actions/using-workflows/triggering-a-workflow#using-activity-types-and-filters-with-multiple-events
push:
branches: [master]
page_build:
# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages
permissions:

View File

@ -7,8 +7,13 @@ from typing import Literal
import pytest
import stamina
from dirty_equals import IsList
from pytest_mock import MockerFixture
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
from sqlalchemy.ext.asyncio import (
AsyncConnection,
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from app.database import Dependency, Repo
from app.factories import DependencyCreateDataFactory
@ -28,10 +33,38 @@ def _deactivate_retries() -> None:
stamina.set_active(False)
@pytest.fixture(autouse=True)
def _test_db(mocker: MockerFixture) -> None:
@pytest.fixture(scope="session")
def db_path() -> str:
"""Use the in-memory database for tests."""
mocker.patch("app.database.DB_PATH", "")
return "" # ":memory:"
@pytest.fixture(scope="session")
def db_connection_string(
db_path: str,
) -> str:
"""Provide the connection string for the in-memory database."""
return f"sqlite+aiosqlite:///{db_path}"
@pytest.fixture(scope="session", params=[{"echo": False}], ids=["echo=False"])
async def db_engine(
db_connection_string: str,
request: pytest.FixtureRequest,
) -> AsyncGenerator[AsyncEngine, None, None]:
"""Create the database engine."""
# echo=True enables logging of all SQL statements
# https://docs.sqlalchemy.org/en/20/core/engines.html#sqlalchemy.create_engine.params.echo
engine = create_async_engine(
db_connection_string,
**request.param, # type: ignore
)
try:
yield engine
finally:
# for AsyncEngine created in function scope, close and
# clean-up pooled connections
await engine.dispose()
@pytest.fixture(scope="session")
@ -49,35 +82,68 @@ def event_loop(
@pytest.fixture(scope="session")
async def test_db_connection() -> AsyncGenerator[AsyncConnection, None]:
"""Use the in-memory database for tests."""
from app.database import Base, engine
async def _database_objects(
db_engine: AsyncEngine,
) -> AsyncGenerator[None, None]:
"""Create the database objects (tables, etc.)."""
from app.database import Base
# Enters a transaction
# https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#sqlalchemy.ext.asyncio.AsyncConnection.begin
try:
async with engine.begin() as conn:
async with db_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
yield conn
yield
finally:
# for AsyncEngine created in function scope, close and
# clean-up pooled connections
await engine.dispose()
# Clean up after the testing session is over
async with db_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@pytest.fixture(scope="session")
async def db_connection(
db_engine: AsyncEngine,
) -> AsyncGenerator[AsyncConnection, None]:
"""Create a database connection."""
# Return connection with no transaction
# https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#sqlalchemy.ext.asyncio.AsyncEngine.connect
async with db_engine.connect() as conn:
yield conn
@pytest.fixture()
async def test_db_session(
test_db_connection: AsyncConnection,
async def db_session(
db_engine: AsyncEngine,
_database_objects: None,
) -> AsyncGenerator[AsyncSession, None]:
"""Use the in-memory database for tests."""
from app.uow import async_session_uow
async with async_session_uow() as session:
"""Create a database session."""
# The `async_sessionmaker` function is used to create a Session factory
# https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#sqlalchemy.ext.asyncio.async_sessionmaker
async_session_factory = async_sessionmaker(
db_engine, expire_on_commit=False, autoflush=False, autocommit=False
)
async with async_session_factory() as session:
yield session
@pytest.fixture()
async def db_uow(
db_session: AsyncSession,
) -> AsyncGenerator[AsyncSession, None]:
"""Provide a transactional scope around a series of operations."""
# This context manager will start a transaction, and roll it back at the end
# https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#sqlalchemy.ext.asyncio.AsyncSessionTransaction
async with db_session.begin() as transaction:
try:
yield db_session
finally:
await transaction.rollback()
@pytest.fixture()
async def some_repos(
test_db_session: AsyncSession,
db_session: AsyncSession,
source_graph_repo_data_factory: SourceGraphRepoDataFactory,
dependency_create_data_factory: DependencyCreateDataFactory,
) -> list[Repo]:
@ -99,7 +165,6 @@ async def some_repos(
)
for source_graph_repo_data in source_graph_repos_data
]
test_db_session.add_all(repos)
await test_db_session.flush()
await asyncio.gather(*[test_db_session.refresh(repo) for repo in repos])
db_session.add_all(repos)
await db_session.flush()
return repos

View File

@ -15,7 +15,6 @@ The module defines the following models:
The database is accessed asynchronously using SQLAlchemy's async API.
"""
from collections.abc import AsyncGenerator
from pathlib import PurePath
from typing import Final
@ -34,11 +33,13 @@ from sqlalchemy.orm import (
relationship,
)
DB_PATH: Final[PurePath] = PurePath(__file__).parent.parent / "db.sqlite3"
from app.types import RevisionHash, SourceGraphRepoId
SQLALCHEMY_DATABASE_URL: Final[str] = f"sqlite+aiosqlite:///{DB_PATH}"
_DB_PATH: Final[PurePath] = PurePath(__file__).parent.parent / "db.sqlite3"
engine: Final[AsyncEngine] = create_async_engine(SQLALCHEMY_DATABASE_URL)
_SQLALCHEMY_DATABASE_URL: Final[str] = f"sqlite+aiosqlite:///{_DB_PATH}"
engine: Final[AsyncEngine] = create_async_engine(_SQLALCHEMY_DATABASE_URL)
async_session_maker: Final[async_sessionmaker[AsyncSession]] = async_sessionmaker(
engine, expire_on_commit=False, autoflush=False, autocommit=False
@ -55,12 +56,6 @@ metadata = MetaData(
)
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
"""Get an async session."""
async with async_session_maker() as session:
yield session
Base = declarative_base(metadata=metadata, cls=AsyncAttrs)
@ -72,13 +67,13 @@ class Repo(Base):
url: Mapped[str] = mapped_column(nullable=False, unique=True)
description: Mapped[str] = mapped_column(Text, nullable=False)
stars: Mapped[int] = mapped_column(BigInteger, nullable=False)
source_graph_repo_id: Mapped[int | None] = mapped_column(
source_graph_repo_id: Mapped[SourceGraphRepoId | None] = mapped_column(
BigInteger, nullable=True, unique=True
)
dependencies: Mapped[list["Dependency"]] = relationship(
"Dependency", secondary="repo_dependency", back_populates="repos"
)
last_checked_revision: Mapped[str | None] = mapped_column(
last_checked_revision: Mapped[RevisionHash | None] = mapped_column(
String(255), nullable=True
)
__table_args__ = (UniqueConstraint("url", "source_graph_repo_id"),)

View File

@ -2,7 +2,7 @@
from pydantic import BaseModel, ConfigDict, NonNegativeInt
from app.types import DependencyId, RepoId, RevisionHash
from app.types import DependencyId, RepoId, RevisionHash, SourceGraphRepoId
class DependencyCreateData(BaseModel):
@ -33,6 +33,6 @@ class RepoDetail(BaseModel):
url: str
description: str
stars: NonNegativeInt
source_graph_repo_id: int
source_graph_repo_id: SourceGraphRepoId | None
dependencies: list[DependencyDetail]
last_checked_revision: RevisionHash | None

View File

@ -72,7 +72,7 @@ class AsyncSourceGraphSSEClient:
headers["Last-Event-ID"] = self._last_event_id
async with aconnect_sse(
client=self._aclient,
url=str(SOURCE_GRAPH_STREAM_API_URL),
url=SOURCE_GRAPH_STREAM_API_URL,
method="GET",
headers=headers,
**kwargs,

View File

@ -7,6 +7,6 @@ from app.source_graph.models import SourceGraphRepoData
@register_fixture
class SourceGraphRepoDataFactory(ModelFactory[SourceGraphRepoData]):
"""Factory for creating RepoCreateData."""
"""Factory for creating SourceGraphRepoData."""
__model__ = SourceGraphRepoData

View File

@ -1,6 +1,6 @@
"""The models for the Source Graph data."""
import datetime
from typing import Literal, NewType, Self
from typing import Literal, Self
from pydantic import (
BaseModel,
@ -11,8 +11,7 @@ from pydantic import (
computed_field,
)
#: The ID of a repository from the SourceGraph API.
SourceGraphRepoId = NewType("SourceGraphRepoId", int)
from app.types import SourceGraphRepoId
class SourceGraphRepoData(BaseModel):

View File

@ -14,7 +14,7 @@ pytestmark = pytest.mark.anyio
async def test_create_or_update_repos_from_source_graph_repos_data(
test_db_session: AsyncSession,
db_session: AsyncSession,
source_graph_repo_data_factory: SourceGraphRepoDataFactory,
) -> None:
"""Test creating repos from source graph repos data."""
@ -22,7 +22,7 @@ async def test_create_or_update_repos_from_source_graph_repos_data(
SourceGraphRepoData
] = source_graph_repo_data_factory.batch(5)
repos = await create_or_update_repos_from_source_graph_repos_data(
test_db_session, source_graph_repo_data
db_session, source_graph_repo_data
)
assert repos == IsList(length=5)
assert all(repo == IsInstance[database.Repo] for repo in repos)
@ -31,12 +31,12 @@ async def test_create_or_update_repos_from_source_graph_repos_data(
async def test_create_or_update_repos_from_source_graph_repos_data_update(
some_repos: list[database.Repo],
test_db_session: AsyncSession,
db_session: AsyncSession,
source_graph_repo_data_factory: SourceGraphRepoDataFactory,
) -> None:
"""Test updating repos from source graph repos data."""
assert (
await test_db_session.execute(
await db_session.execute(
sqlalchemy.select(sqlalchemy.func.count(database.Repo.id))
)
).scalar() == len(some_repos)
@ -53,13 +53,13 @@ async def test_create_or_update_repos_from_source_graph_repos_data_update(
for repo, repo_data in zip(some_repos, source_graph_repos_data, strict=True)
]
repos = await create_or_update_repos_from_source_graph_repos_data(
test_db_session, source_graph_repos_data
db_session, source_graph_repos_data
)
assert repos == IsList(length=len(some_repos))
assert all(repo == IsInstance[database.Repo] for repo in repos)
assert all(repo.id is not None for repo in repos)
assert (
await test_db_session.execute(
await db_session.execute(
sqlalchemy.select(sqlalchemy.func.count(database.Repo.id))
)
).scalar() == len(some_repos)

View File

@ -27,7 +27,7 @@ def _assert_repo_properties(
async def test_create_repo_no_dependencies(
test_db_session: AsyncSession,
db_session: AsyncSession,
source_graph_repo_data_factory: SourceGraphRepoDataFactory,
) -> None:
"""Test creating a repo."""
@ -38,15 +38,15 @@ async def test_create_repo_no_dependencies(
stars=source_graph_repo_data.stars,
source_graph_repo_id=source_graph_repo_data.repo_id,
)
test_db_session.add(repo)
await test_db_session.flush()
await test_db_session.refresh(repo)
db_session.add(repo)
await db_session.flush()
await db_session.refresh(repo)
_assert_repo_properties(repo, source_graph_repo_data)
assert (await repo.awaitable_attrs.dependencies) == IsList(length=0)
async def test_create_repo_with_dependencies(
test_db_session: AsyncSession,
db_session: AsyncSession,
source_graph_repo_data_factory: SourceGraphRepoDataFactory,
dependency_create_data_factory: DependencyCreateDataFactory,
) -> None:
@ -65,9 +65,8 @@ async def test_create_repo_with_dependencies(
for dependency_create_data in dependencies_create_data
],
)
test_db_session.add(repo)
await test_db_session.flush()
await test_db_session.refresh(repo)
db_session.add(repo)
await db_session.flush()
_assert_repo_properties(repo, source_graph_repo_data)
repo_dependencies = await repo.awaitable_attrs.dependencies
assert repo_dependencies == IsList(length=5)
@ -80,11 +79,11 @@ async def test_create_repo_with_dependencies(
async def test_list_repositories(
test_db_session: AsyncSession,
db_session: AsyncSession,
some_repos: list[database.Repo],
) -> None:
"""Test listing repositories."""
repos_from_db_result = await test_db_session.execute(
repos_from_db_result = await db_session.execute(
sa.select(database.Repo).options(
sqlalchemy.orm.joinedload(database.Repo.dependencies)
)

View File

@ -1,6 +1,11 @@
"""Type definitions for the application."""
from typing import NewType
#: The ID of a repository from the database.
RepoId = NewType("RepoId", int)
#: The ID of a repository from the SourceGraph API.
SourceGraphRepoId = NewType("SourceGraphRepoId", int)
#: The ID of a dependency from the database.
DependencyId = NewType("DependencyId", int)
#: The revision hash of a repository.
RevisionHash = NewType("RevisionHash", str)

View File

@ -20,8 +20,8 @@ async def async_session_uow() -> AsyncGenerator[AsyncSession, None]:
:return: a UoW instance
"""
async with async_session_maker() as session:
async with session.begin():
async with session.begin() as transaction:
try:
yield session
finally:
await session.rollback()
await transaction.rollback()