"""The tests for the source graph mapper to the database objects."""
import pytest
import sqlalchemy
from dirty_equals import IsInstance, IsList
from sqlalchemy.ext.asyncio import AsyncSession

from app import database
from app.source_graph.factories import SourceGraphRepoDataFactory
from app.source_graph.mapper import create_or_update_repos_from_source_graph_repos_data
from app.source_graph.models import SourceGraphRepoData

pytestmark = pytest.mark.anyio


async def _count_repos(db_session: AsyncSession) -> int | None:
    """Return the number of ``database.Repo`` rows visible through ``db_session``."""
    result = await db_session.execute(
        sqlalchemy.select(sqlalchemy.func.count(database.Repo.id))
    )
    return result.scalar()


async def test_create_or_update_repos_from_source_graph_repos_data(
    db_session: AsyncSession,
    source_graph_repo_data_factory: SourceGraphRepoDataFactory,
) -> None:
    """Test creating repos from source graph repos data."""
    source_graph_repos_data: list[
        SourceGraphRepoData
    ] = source_graph_repo_data_factory.batch(5)

    repos = await create_or_update_repos_from_source_graph_repos_data(
        db_session, source_graph_repos_data
    )

    assert repos == IsList(length=5)
    assert all(repo == IsInstance[database.Repo] for repo in repos)
    assert all(repo.id is not None for repo in repos)
    # The returned objects must correspond to rows that are actually persisted.
    assert await _count_repos(db_session) == len(source_graph_repos_data)


async def test_create_or_update_repos_from_source_graph_repos_data_update(
    some_repos: list[database.Repo],
    db_session: AsyncSession,
    source_graph_repo_data_factory: SourceGraphRepoDataFactory,
) -> None:
    """Test updating repos from source graph repos data.

    Each generated payload is re-keyed onto the ``repositoryID`` of an existing
    repo, so the mapper is expected to update those rows in place rather than
    insert new ones.
    """
    assert await _count_repos(db_session) == len(some_repos)
    # Capture the ids up front so the comparison below does not depend on the
    # state of the ``some_repos`` objects after the mapper has run.
    existing_repo_ids = {repo.id for repo in some_repos}

    generated_repos_data: list[
        SourceGraphRepoData
    ] = source_graph_repo_data_factory.batch(len(some_repos))
    # Point every freshly generated payload at an already existing repo.
    source_graph_repos_data = [
        SourceGraphRepoData(
            **(
                repo_data.model_dump(by_alias=True)
                | {"repositoryID": repo.source_graph_repo_id}
            )
        )
        for repo, repo_data in zip(some_repos, generated_repos_data, strict=True)
    ]

    repos = await create_or_update_repos_from_source_graph_repos_data(
        db_session, source_graph_repos_data
    )

    assert repos == IsList(length=len(some_repos))
    assert all(repo == IsInstance[database.Repo] for repo in repos)
    assert all(repo.id is not None for repo in repos)
    # An update must reuse the existing rows, not replace them with new ones.
    assert {repo.id for repo in repos} == existing_repo_ids
    assert await _count_repos(db_session) == len(some_repos)