From 6484b85956f03afeb98a2971eed9e1a61b57318f Mon Sep 17 00:00:00 2001 From: Vladyslav Fedoriuk Date: Wed, 2 Aug 2023 23:09:30 +0200 Subject: [PATCH] Refactor the models and add a mapper - A new mapper allows creating database repositories from the SourceGraph data --- app/conftest.py | 24 +++++-- app/database.py | 7 +- app/factories.py | 13 +--- app/models.py | 24 +------ app/scraper/client.py | 36 +---------- app/scraper/factories.py | 12 ++++ app/scraper/mapper.py | 28 ++++++++ app/scraper/models.py | 39 +++++++++++ app/scraper/tests/test_client.py | 2 +- app/scraper/tests/test_mapper.py | 27 ++++++++ app/tests/test_database.py | 64 +++++++++++++------ ...32d84a5aea_create_repo_dependency_and_.py} | 12 ++-- 12 files changed, 188 insertions(+), 100 deletions(-) create mode 100644 app/scraper/factories.py create mode 100644 app/scraper/mapper.py create mode 100644 app/scraper/models.py create mode 100644 app/scraper/tests/test_mapper.py rename migrations/versions/{d8fc955c639b_add_repo_dependency_and_repodependency_.py => 0232d84a5aea_create_repo_dependency_and_.py} (78%) diff --git a/app/conftest.py b/app/conftest.py index a0d415a..4d64287 100644 --- a/app/conftest.py +++ b/app/conftest.py @@ -9,7 +9,9 @@ from pytest_mock import MockerFixture from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from app.database import Dependency, Repo, async_session_maker -from app.factories import RepoCreateDataFactory +from app.factories import DependencyCreateDataFactory +from app.scraper.factories import SourceGraphRepoDataFactory +from app.scraper.models import SourceGraphRepoData @pytest.fixture(autouse=True, scope="session") @@ -73,19 +75,27 @@ async def test_db_session( @pytest.fixture() async def some_repos( - test_db_session: AsyncSession, repo_create_data_factory: RepoCreateDataFactory + test_db_session: AsyncSession, + source_graph_repo_data_factory: SourceGraphRepoDataFactory, + dependency_create_data_factory: DependencyCreateDataFactory, ) -> 
list[Repo]: """Create some repos.""" - repo_create_data = repo_create_data_factory.batch(10) - assert repo_create_data == IsList(length=10) + source_graph_repos_data: list[ + SourceGraphRepoData + ] = source_graph_repo_data_factory.batch(10) + assert source_graph_repos_data == IsList(length=10) repos = [ Repo( - url=str(repo.url), + url=str(source_graph_repo_data.repo_url), + description=source_graph_repo_data.description, + stars=source_graph_repo_data.stars, + source_graph_repo_id=source_graph_repo_data.repo_id, dependencies=[ - Dependency(name=dependency.name) for dependency in repo.dependencies + Dependency(**dependency_create_data.model_dump()) + for dependency_create_data in dependency_create_data_factory.batch(5) ], ) - for repo in repo_create_data + for source_graph_repo_data in source_graph_repos_data ] test_db_session.add_all(repos) await test_db_session.flush() diff --git a/app/database.py b/app/database.py index d94709f..0127fb0 100644 --- a/app/database.py +++ b/app/database.py @@ -19,7 +19,7 @@ from collections.abc import AsyncGenerator from pathlib import PurePath from typing import Final -from sqlalchemy import ForeignKey, String +from sqlalchemy import BigInteger, ForeignKey, String, Text from sqlalchemy.ext.asyncio import ( AsyncAttrs, AsyncEngine, @@ -58,6 +58,11 @@ class Repo(Base): __tablename__ = "repo" id: Mapped[int] = mapped_column(primary_key=True) url: Mapped[str] = mapped_column(nullable=False, unique=True) + description: Mapped[str] = mapped_column(Text, nullable=False) + stars: Mapped[int] = mapped_column(BigInteger, nullable=False) + source_graph_repo_id: Mapped[int | None] = mapped_column( + BigInteger, nullable=True, unique=True + ) dependencies: Mapped[list["Dependency"]] = relationship( "Dependency", secondary="repo_dependency", back_populates="repos" ) diff --git a/app/factories.py b/app/factories.py index 696b16d..9a8b48f 100644 --- a/app/factories.py +++ b/app/factories.py @@ -2,7 +2,7 @@ from 
polyfactory.factories.pydantic_factory import ModelFactory from polyfactory.pytest_plugin import register_fixture -from app.models import DependencyCreateData, RepoCreateData +from app.models import DependencyCreateData @register_fixture @@ -10,14 +10,3 @@ class DependencyCreateDataFactory(ModelFactory[DependencyCreateData]): """Factory for creating DependencyCreateData.""" __model__ = DependencyCreateData - - -@register_fixture -class RepoCreateDataFactory(ModelFactory[RepoCreateData]): - """Factory for creating RepoCreateData.""" - - __model__ = RepoCreateData - - __randomize_collection_length__ = True - __min_collection_length__ = 2 - __max_collection_length__ = 5 diff --git a/app/models.py b/app/models.py index 1d6b6e4..c7db2a0 100644 --- a/app/models.py +++ b/app/models.py @@ -1,7 +1,7 @@ """Module contains the models for the application.""" from typing import NewType -from pydantic import AnyUrl, BaseModel, Field +from pydantic import BaseModel RepoId = NewType("RepoId", int) DependencyId = NewType("DependencyId", int) @@ -11,25 +11,3 @@ class DependencyCreateData(BaseModel): """A dependency of a repository.""" name: str - - -class RepoCreateData(BaseModel): - """A repository that is being tracked.""" - - url: AnyUrl - dependencies: list[DependencyCreateData] = Field(default_factory=list) - - -class DependencyDetails(BaseModel): - """A single dependency.""" - - id: DependencyId - name: str - - -class RepoDetails(BaseModel): - """A repository that is being tracked.""" - - id: RepoId - url: AnyUrl - dependencies: list[DependencyDetails] = Field(default_factory=list) diff --git a/app/scraper/client.py b/app/scraper/client.py index c9a1f65..2c9d6dc 100644 --- a/app/scraper/client.py +++ b/app/scraper/client.py @@ -1,55 +1,25 @@ """The client for the SourceGraph API.""" import asyncio -import datetime from collections.abc import AsyncGenerator, Mapping from contextlib import asynccontextmanager from datetime import timedelta -from typing import Any, AnyStr, Final, 
Literal, NewType, Self +from typing import Any, AnyStr, Final, Self from urllib.parse import quote import httpx import stamina from httpx_sse import EventSource, ServerSentEvent, aconnect_sse from pydantic import ( - BaseModel, - Field, HttpUrl, - NonNegativeInt, - TypeAdapter, - computed_field, ) +from app.scraper.models import SourceGraphRepoDataListAdapter + #: The URL of the SourceGraph SSE API. SOURCE_GRAPH_STREAM_API_URL: Final[ HttpUrl ] = "https://sourcegraph.com/.api/search/stream" -#: The ID of a repository from the SourceGraph API. -SourceGraphRepoId = NewType("SourceGraphRepoId", int) - - -class SourceGraphRepoData(BaseModel): - """The data of a repository.""" - - type: Literal["repo"] - repo_id: SourceGraphRepoId = Field(..., alias="repositoryID") - repo_handle: str = Field(..., alias="repository") - stars: NonNegativeInt = Field(..., alias="repoStars") - last_fetched_at: datetime.datetime = Field(..., alias="repoLastFetched") - description: str = Field(default="") - - @computed_field - @property - def repo_url(self: Self) -> HttpUrl: - """The URL of the repository.""" - return TypeAdapter(HttpUrl).validate_python(f"https://{self.repo_handle}") - - -#: The type adapter for the SourceGraphRepoData. -SourceGraphRepoDataAdapter = TypeAdapter(SourceGraphRepoData) - -#: The type adapter for the SourceGraphRepoData list. -SourceGraphRepoDataListAdapter = TypeAdapter(list[SourceGraphRepoData]) #: The query parameters for the SourceGraph SSE API. 
FASTAPI_REPOS_QUERY_PARAMS: Final[Mapping[str, str]] = { diff --git a/app/scraper/factories.py b/app/scraper/factories.py new file mode 100644 index 0000000..2c6b211 --- /dev/null +++ b/app/scraper/factories.py @@ -0,0 +1,12 @@ +"""Factories for creating test data.""" +from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.pytest_plugin import register_fixture + +from app.scraper.models import SourceGraphRepoData + + +@register_fixture +class SourceGraphRepoDataFactory(ModelFactory[SourceGraphRepoData]): + """Factory for creating SourceGraphRepoData.""" + + __model__ = SourceGraphRepoData diff --git a/app/scraper/mapper.py b/app/scraper/mapper.py new file mode 100644 index 0000000..7add08d --- /dev/null +++ b/app/scraper/mapper.py @@ -0,0 +1,28 @@ +"""Mapper for scraper.""" +from collections.abc import Sequence + +import sqlalchemy.sql +from sqlalchemy.ext.asyncio import AsyncSession + +from app import database +from app.scraper.models import SourceGraphRepoData + + +async def create_repos_from_source_graph_repos_data( + session: AsyncSession, source_graph_repo_data: Sequence[SourceGraphRepoData] +) -> Sequence[database.Repo]: + """Create repos from source graph repos data.""" + return ( + await session.scalars( + sqlalchemy.sql.insert(database.Repo).returning(database.Repo), + [ + { + "url": str(repo_data.repo_url), + "description": repo_data.description, + "stars": repo_data.stars, + "source_graph_repo_id": repo_data.repo_id, + } + for repo_data in source_graph_repo_data + ], + ) + ).all() diff --git a/app/scraper/models.py b/app/scraper/models.py new file mode 100644 index 0000000..e38735e --- /dev/null +++ b/app/scraper/models.py @@ -0,0 +1,39 @@ +"""The models for the scraper.""" +import datetime +from typing import Literal, NewType, Self + +from pydantic import ( + BaseModel, + Field, + HttpUrl, + NonNegativeInt, + TypeAdapter, + computed_field, +) + +#: The ID of a repository from the SourceGraph API. 
+SourceGraphRepoId = NewType("SourceGraphRepoId", int) + + +class SourceGraphRepoData(BaseModel): + """The data of a repository.""" + + type: Literal["repo"] + repo_id: SourceGraphRepoId = Field(..., alias="repositoryID") + repo_handle: str = Field(..., alias="repository") + stars: NonNegativeInt = Field(..., alias="repoStars") + last_fetched_at: datetime.datetime = Field(..., alias="repoLastFetched") + description: str = Field(default="") + + @computed_field + @property + def repo_url(self: Self) -> HttpUrl: + """The URL of the repository.""" + return TypeAdapter(HttpUrl).validate_python(f"https://{self.repo_handle}") + + +#: The type adapter for the SourceGraphRepoData. +SourceGraphRepoDataAdapter = TypeAdapter(SourceGraphRepoData) + +#: The type adapter for the SourceGraphRepoData list. +SourceGraphRepoDataListAdapter = TypeAdapter(list[SourceGraphRepoData]) diff --git a/app/scraper/tests/test_client.py b/app/scraper/tests/test_client.py index f3972d9..3869693 100644 --- a/app/scraper/tests/test_client.py +++ b/app/scraper/tests/test_client.py @@ -3,7 +3,7 @@ import pytest from dirty_equals import HasLen, IsDatetime, IsInstance, IsPositiveInt from pydantic import Json, TypeAdapter -from app.scraper.client import SourceGraphRepoData +from app.scraper.models import SourceGraphRepoData @pytest.fixture() diff --git a/app/scraper/tests/test_mapper.py b/app/scraper/tests/test_mapper.py new file mode 100644 index 0000000..efac398 --- /dev/null +++ b/app/scraper/tests/test_mapper.py @@ -0,0 +1,27 @@ +"""The tests for the scraper mapper.""" +import pytest +from dirty_equals import IsInstance, IsList +from sqlalchemy.ext.asyncio import AsyncSession + +from app import database +from app.scraper.factories import SourceGraphRepoDataFactory +from app.scraper.mapper import create_repos_from_source_graph_repos_data +from app.scraper.models import SourceGraphRepoData + +pytestmark = pytest.mark.anyio + + +async def test_create_repos_from_source_graph_repos_data( + 
test_db_session: AsyncSession, + source_graph_repo_data_factory: SourceGraphRepoDataFactory, +) -> None: + """Test creating repos from source graph repos data.""" + source_graph_repo_data: list[ + SourceGraphRepoData + ] = source_graph_repo_data_factory.batch(5) + repos = await create_repos_from_source_graph_repos_data( + test_db_session, source_graph_repo_data + ) + assert repos == IsList(length=5) + assert all(repo == IsInstance[database.Repo] for repo in repos) + assert all(repo.id is not None for repo in repos) diff --git a/app/tests/test_database.py b/app/tests/test_database.py index e3819fc..3f8c36e 100644 --- a/app/tests/test_database.py +++ b/app/tests/test_database.py @@ -6,49 +6,75 @@ from dirty_equals import IsList from sqlalchemy.ext.asyncio import AsyncSession from app import database -from app.factories import RepoCreateDataFactory +from app.factories import DependencyCreateDataFactory +from app.models import DependencyCreateData +from app.scraper.factories import SourceGraphRepoDataFactory +from app.scraper.models import SourceGraphRepoData pytestmark = pytest.mark.anyio +def _assert_repo_properties( + repo: database.Repo, source_graph_repo_data: SourceGraphRepoData +) -> bool: + """Assert that the repo has the expected properties.""" + assert repo.id is not None + assert repo.url == str(source_graph_repo_data.repo_url) + assert repo.description == source_graph_repo_data.description + assert repo.stars == source_graph_repo_data.stars + assert repo.source_graph_repo_id == source_graph_repo_data.repo_id + return True + + async def test_create_repo_no_dependencies( - test_db_session: AsyncSession, repo_create_data_factory: RepoCreateDataFactory + test_db_session: AsyncSession, + source_graph_repo_data_factory: SourceGraphRepoDataFactory, ) -> None: """Test creating a repo.""" - repo_create_data = repo_create_data_factory.build() - repo = database.Repo(url=str(repo_create_data.url)) + source_graph_repo_data: SourceGraphRepoData = 
source_graph_repo_data_factory.build() + repo = database.Repo( + url=str(source_graph_repo_data.repo_url), + description=source_graph_repo_data.description, + stars=source_graph_repo_data.stars, + source_graph_repo_id=source_graph_repo_data.repo_id, + ) test_db_session.add(repo) await test_db_session.flush() await test_db_session.refresh(repo) - assert repo.id is not None - assert repo.url == str(repo_create_data.url) + _assert_repo_properties(repo, source_graph_repo_data) assert (await repo.awaitable_attrs.dependencies) == IsList(length=0) async def test_create_repo_with_dependencies( - test_db_session: AsyncSession, repo_create_data_factory: RepoCreateDataFactory + test_db_session: AsyncSession, + source_graph_repo_data_factory: SourceGraphRepoDataFactory, + dependency_create_data_factory: DependencyCreateDataFactory, ) -> None: """Test creating a repo with dependencies.""" - repo_create_data = repo_create_data_factory.build() + source_graph_repo_data: SourceGraphRepoData = source_graph_repo_data_factory.build() + dependencies_create_data: list[ + DependencyCreateData + ] = dependency_create_data_factory.batch(5) repo = database.Repo( - url=str(repo_create_data.url), + url=str(source_graph_repo_data.repo_url), + description=source_graph_repo_data.description, + stars=source_graph_repo_data.stars, + source_graph_repo_id=source_graph_repo_data.repo_id, dependencies=[ - database.Dependency(name=dependency.name) - for dependency in repo_create_data.dependencies + database.Dependency(**dependency_create_data.model_dump()) + for dependency_create_data in dependencies_create_data ], ) test_db_session.add(repo) await test_db_session.flush() await test_db_session.refresh(repo) - assert repo.id is not None - assert repo.url == str(repo_create_data.url) + _assert_repo_properties(repo, source_graph_repo_data) repo_dependencies = await repo.awaitable_attrs.dependencies - assert 2 <= len(repo_dependencies) <= 5 - assert repo_dependencies == 
IsList(length=len(repo_create_data.dependencies)) + assert repo_dependencies == IsList(length=5) assert all( repo_dependency.name == dependency.name for repo_dependency, dependency in zip( - repo_dependencies, repo_create_data.dependencies, strict=True + repo_dependencies, dependencies_create_data, strict=True ) ) @@ -66,12 +92,12 @@ async def test_list_repositories( repos_from_db = repos_from_db_result.scalars().unique().all() assert repos_from_db == IsList(length=10) assert all( - repo.url == str(repo_create_data.url) + repo.id == repo_data.id and all( repo_dependency.name == dependency.name for repo_dependency, dependency in zip( - repo.dependencies, repo_create_data.dependencies, strict=True + repo.dependencies, repo_data.dependencies, strict=True ) ) - for repo, repo_create_data in zip(repos_from_db, some_repos, strict=True) + for repo, repo_data in zip(repos_from_db, some_repos, strict=True) ) diff --git a/migrations/versions/d8fc955c639b_add_repo_dependency_and_repodependency_.py b/migrations/versions/0232d84a5aea_create_repo_dependency_and_.py similarity index 78% rename from migrations/versions/d8fc955c639b_add_repo_dependency_and_repodependency_.py rename to migrations/versions/0232d84a5aea_create_repo_dependency_and_.py index a83e310..8e8a4c1 100644 --- a/migrations/versions/d8fc955c639b_add_repo_dependency_and_repodependency_.py +++ b/migrations/versions/0232d84a5aea_create_repo_dependency_and_.py @@ -1,15 +1,15 @@ -"""Add Repo, Dependency, and RepoDependency tables +"""Create Repo, Dependency and RepoDependency tables -Revision ID: d8fc955c639b +Revision ID: 0232d84a5aea Revises: -Create Date: 2023-07-28 23:41:00.169286 +Create Date: 2023-08-02 22:14:12.910175 """ import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. 
-revision = "d8fc955c639b" +revision = "0232d84a5aea" down_revision = None branch_labels = None depends_on = None @@ -28,7 +28,11 @@ def upgrade() -> None: "repo", sa.Column("id", sa.Integer(), nullable=False), sa.Column("url", sa.String(), nullable=False), + sa.Column("description", sa.Text(), nullable=False), + sa.Column("stars", sa.BigInteger(), nullable=False), + sa.Column("source_graph_repo_id", sa.BigInteger(), nullable=True), sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("source_graph_repo_id"), sa.UniqueConstraint("url"), ) op.create_table(