Refactor the models and add a mapper

- A new mapper allows to create database repositories from the SourceGraph data
This commit is contained in:
Vladyslav Fedoriuk 2023-08-02 23:09:30 +02:00
parent 9d0e6e606d
commit 6484b85956
12 changed files with 188 additions and 100 deletions

View File

@ -9,7 +9,9 @@ from pytest_mock import MockerFixture
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
from app.database import Dependency, Repo, async_session_maker from app.database import Dependency, Repo, async_session_maker
from app.factories import RepoCreateDataFactory from app.factories import DependencyCreateDataFactory
from app.scraper.factories import SourceGraphRepoDataFactory
from app.scraper.models import SourceGraphRepoData
@pytest.fixture(autouse=True, scope="session") @pytest.fixture(autouse=True, scope="session")
@ -73,19 +75,27 @@ async def test_db_session(
@pytest.fixture() @pytest.fixture()
async def some_repos( async def some_repos(
test_db_session: AsyncSession, repo_create_data_factory: RepoCreateDataFactory test_db_session: AsyncSession,
source_graph_repo_data_factory: SourceGraphRepoDataFactory,
dependency_create_data_factory: DependencyCreateDataFactory,
) -> list[Repo]: ) -> list[Repo]:
"""Create some repos.""" """Create some repos."""
repo_create_data = repo_create_data_factory.batch(10) source_graph_repos_data: list[
assert repo_create_data == IsList(length=10) SourceGraphRepoData
] = source_graph_repo_data_factory.batch(10)
assert source_graph_repos_data == IsList(length=10)
repos = [ repos = [
Repo( Repo(
url=str(repo.url), url=str(source_graph_repo_data.repo_url),
description=source_graph_repo_data.description,
stars=source_graph_repo_data.stars,
source_graph_repo_id=source_graph_repo_data.repo_id,
dependencies=[ dependencies=[
Dependency(name=dependency.name) for dependency in repo.dependencies Dependency(**dependency_create_data.model_dump())
for dependency_create_data in dependency_create_data_factory.batch(5)
], ],
) )
for repo in repo_create_data for source_graph_repo_data in source_graph_repos_data
] ]
test_db_session.add_all(repos) test_db_session.add_all(repos)
await test_db_session.flush() await test_db_session.flush()

View File

@ -19,7 +19,7 @@ from collections.abc import AsyncGenerator
from pathlib import PurePath from pathlib import PurePath
from typing import Final from typing import Final
from sqlalchemy import ForeignKey, String from sqlalchemy import BigInteger, ForeignKey, String, Text
from sqlalchemy.ext.asyncio import ( from sqlalchemy.ext.asyncio import (
AsyncAttrs, AsyncAttrs,
AsyncEngine, AsyncEngine,
@ -58,6 +58,11 @@ class Repo(Base):
__tablename__ = "repo" __tablename__ = "repo"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
url: Mapped[str] = mapped_column(nullable=False, unique=True) url: Mapped[str] = mapped_column(nullable=False, unique=True)
description: Mapped[str] = mapped_column(Text, nullable=False)
stars: Mapped[int] = mapped_column(BigInteger, nullable=False)
source_graph_repo_id: Mapped[int | None] = mapped_column(
BigInteger, nullable=True, unique=True
)
dependencies: Mapped[list["Dependency"]] = relationship( dependencies: Mapped[list["Dependency"]] = relationship(
"Dependency", secondary="repo_dependency", back_populates="repos" "Dependency", secondary="repo_dependency", back_populates="repos"
) )

View File

@ -2,7 +2,7 @@
from polyfactory.factories.pydantic_factory import ModelFactory from polyfactory.factories.pydantic_factory import ModelFactory
from polyfactory.pytest_plugin import register_fixture from polyfactory.pytest_plugin import register_fixture
from app.models import DependencyCreateData, RepoCreateData from app.models import DependencyCreateData
@register_fixture @register_fixture
@ -10,14 +10,3 @@ class DependencyCreateDataFactory(ModelFactory[DependencyCreateData]):
"""Factory for creating DependencyCreateData.""" """Factory for creating DependencyCreateData."""
__model__ = DependencyCreateData __model__ = DependencyCreateData
@register_fixture
class RepoCreateDataFactory(ModelFactory[RepoCreateData]):
"""Factory for creating RepoCreateData."""
__model__ = RepoCreateData
__randomize_collection_length__ = True
__min_collection_length__ = 2
__max_collection_length__ = 5

View File

@ -1,7 +1,7 @@
"""Module contains the models for the application.""" """Module contains the models for the application."""
from typing import NewType from typing import NewType
from pydantic import AnyUrl, BaseModel, Field from pydantic import BaseModel
RepoId = NewType("RepoId", int) RepoId = NewType("RepoId", int)
DependencyId = NewType("DependencyId", int) DependencyId = NewType("DependencyId", int)
@ -11,25 +11,3 @@ class DependencyCreateData(BaseModel):
"""A dependency of a repository.""" """A dependency of a repository."""
name: str name: str
class RepoCreateData(BaseModel):
"""A repository that is being tracked."""
url: AnyUrl
dependencies: list[DependencyCreateData] = Field(default_factory=list)
class DependencyDetails(BaseModel):
"""A single dependency."""
id: DependencyId
name: str
class RepoDetails(BaseModel):
"""A repository that is being tracked."""
id: RepoId
url: AnyUrl
dependencies: list[DependencyDetails] = Field(default_factory=list)

View File

@ -1,55 +1,25 @@
"""The client for the SourceGraph API.""" """The client for the SourceGraph API."""
import asyncio import asyncio
import datetime
from collections.abc import AsyncGenerator, Mapping from collections.abc import AsyncGenerator, Mapping
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import timedelta from datetime import timedelta
from typing import Any, AnyStr, Final, Literal, NewType, Self from typing import Any, AnyStr, Final, Self
from urllib.parse import quote from urllib.parse import quote
import httpx import httpx
import stamina import stamina
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
from pydantic import ( from pydantic import (
BaseModel,
Field,
HttpUrl, HttpUrl,
NonNegativeInt,
TypeAdapter,
computed_field,
) )
from app.scraper.models import SourceGraphRepoDataListAdapter
#: The URL of the SourceGraph SSE API. #: The URL of the SourceGraph SSE API.
SOURCE_GRAPH_STREAM_API_URL: Final[ SOURCE_GRAPH_STREAM_API_URL: Final[
HttpUrl HttpUrl
] = "https://sourcegraph.com/.api/search/stream" ] = "https://sourcegraph.com/.api/search/stream"
#: The ID of a repository from the SourceGraph API.
SourceGraphRepoId = NewType("SourceGraphRepoId", int)
class SourceGraphRepoData(BaseModel):
"""The data of a repository."""
type: Literal["repo"]
repo_id: SourceGraphRepoId = Field(..., alias="repositoryID")
repo_handle: str = Field(..., alias="repository")
stars: NonNegativeInt = Field(..., alias="repoStars")
last_fetched_at: datetime.datetime = Field(..., alias="repoLastFetched")
description: str = Field(default="")
@computed_field
@property
def repo_url(self: Self) -> HttpUrl:
"""The URL of the repository."""
return TypeAdapter(HttpUrl).validate_python(f"https://{self.repo_handle}")
#: The type adapter for the SourceGraphRepoData.
SourceGraphRepoDataAdapter = TypeAdapter(SourceGraphRepoData)
#: The type adapter for the SourceGraphRepoData list.
SourceGraphRepoDataListAdapter = TypeAdapter(list[SourceGraphRepoData])
#: The query parameters for the SourceGraph SSE API. #: The query parameters for the SourceGraph SSE API.
FASTAPI_REPOS_QUERY_PARAMS: Final[Mapping[str, str]] = { FASTAPI_REPOS_QUERY_PARAMS: Final[Mapping[str, str]] = {

12
app/scraper/factories.py Normal file
View File

@ -0,0 +1,12 @@
"""Factories for creating test data."""
from polyfactory.factories.pydantic_factory import ModelFactory
from polyfactory.pytest_plugin import register_fixture
from app.scraper.models import SourceGraphRepoData
@register_fixture
class SourceGraphRepoDataFactory(ModelFactory[SourceGraphRepoData]):
"""Factory for creating RepoCreateData."""
__model__ = SourceGraphRepoData

28
app/scraper/mapper.py Normal file
View File

@ -0,0 +1,28 @@
"""Mapper for scraper."""
from collections.abc import Sequence
import sqlalchemy.sql
from sqlalchemy.ext.asyncio import AsyncSession
from app import database
from app.scraper.models import SourceGraphRepoData
async def create_repos_from_source_graph_repos_data(
session: AsyncSession, source_graph_repo_data: Sequence[SourceGraphRepoData]
) -> Sequence[database.Repo]:
"""Create repos from source graph repos data."""
return (
await session.scalars(
sqlalchemy.sql.insert(database.Repo).returning(database.Repo),
[
{
"url": str(repo_data.repo_url),
"description": repo_data.description,
"stars": repo_data.stars,
"source_graph_repo_id": repo_data.repo_id,
}
for repo_data in source_graph_repo_data
],
)
).all()

39
app/scraper/models.py Normal file
View File

@ -0,0 +1,39 @@
"""The models for the scraper."""
import datetime
from typing import Literal, NewType, Self
from pydantic import (
BaseModel,
Field,
HttpUrl,
NonNegativeInt,
TypeAdapter,
computed_field,
)
#: The ID of a repository from the SourceGraph API.
SourceGraphRepoId = NewType("SourceGraphRepoId", int)
class SourceGraphRepoData(BaseModel):
"""The data of a repository."""
type: Literal["repo"]
repo_id: SourceGraphRepoId = Field(..., alias="repositoryID")
repo_handle: str = Field(..., alias="repository")
stars: NonNegativeInt = Field(..., alias="repoStars")
last_fetched_at: datetime.datetime = Field(..., alias="repoLastFetched")
description: str = Field(default="")
@computed_field
@property
def repo_url(self: Self) -> HttpUrl:
"""The URL of the repository."""
return TypeAdapter(HttpUrl).validate_python(f"https://{self.repo_handle}")
#: The type adapter for the SourceGraphRepoData.
SourceGraphRepoDataAdapter = TypeAdapter(SourceGraphRepoData)
#: The type adapter for the SourceGraphRepoData list.
SourceGraphRepoDataListAdapter = TypeAdapter(list[SourceGraphRepoData])

View File

@ -3,7 +3,7 @@ import pytest
from dirty_equals import HasLen, IsDatetime, IsInstance, IsPositiveInt from dirty_equals import HasLen, IsDatetime, IsInstance, IsPositiveInt
from pydantic import Json, TypeAdapter from pydantic import Json, TypeAdapter
from app.scraper.client import SourceGraphRepoData from app.scraper.models import SourceGraphRepoData
@pytest.fixture() @pytest.fixture()

View File

@ -0,0 +1,27 @@
"""The tests for the scraper mapper."""
import pytest
from dirty_equals import IsInstance, IsList
from sqlalchemy.ext.asyncio import AsyncSession
from app import database
from app.scraper.factories import SourceGraphRepoDataFactory
from app.scraper.mapper import create_repos_from_source_graph_repos_data
from app.scraper.models import SourceGraphRepoData
pytestmark = pytest.mark.anyio
async def test_create_repos_from_source_graph_repos_data(
test_db_session: AsyncSession,
source_graph_repo_data_factory: SourceGraphRepoDataFactory,
) -> None:
"""Test creating repos from source graph repos data."""
source_graph_repo_data: list[
SourceGraphRepoData
] = source_graph_repo_data_factory.batch(5)
repos = await create_repos_from_source_graph_repos_data(
test_db_session, source_graph_repo_data
)
assert repos == IsList(length=5)
assert all(repo == IsInstance[database.Repo] for repo in repos)
assert all(repo.id is not None for repo in repos)

View File

@ -6,49 +6,75 @@ from dirty_equals import IsList
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app import database from app import database
from app.factories import RepoCreateDataFactory from app.factories import DependencyCreateDataFactory
from app.models import DependencyCreateData
from app.scraper.factories import SourceGraphRepoDataFactory
from app.scraper.models import SourceGraphRepoData
pytestmark = pytest.mark.anyio pytestmark = pytest.mark.anyio
def _assert_repo_properties(
repo: database.Repo, source_graph_repo_data: SourceGraphRepoData
) -> bool:
"""Assert that the repo has the expected properties."""
assert repo.id is not None
assert repo.url == str(source_graph_repo_data.repo_url)
assert repo.description == source_graph_repo_data.description
assert repo.stars == source_graph_repo_data.stars
assert repo.source_graph_repo_id == source_graph_repo_data.repo_id
return True
async def test_create_repo_no_dependencies( async def test_create_repo_no_dependencies(
test_db_session: AsyncSession, repo_create_data_factory: RepoCreateDataFactory test_db_session: AsyncSession,
source_graph_repo_data_factory: SourceGraphRepoDataFactory,
) -> None: ) -> None:
"""Test creating a repo.""" """Test creating a repo."""
repo_create_data = repo_create_data_factory.build() source_graph_repo_data: SourceGraphRepoData = source_graph_repo_data_factory.build()
repo = database.Repo(url=str(repo_create_data.url)) repo = database.Repo(
url=str(source_graph_repo_data.repo_url),
description=source_graph_repo_data.description,
stars=source_graph_repo_data.stars,
source_graph_repo_id=source_graph_repo_data.repo_id,
)
test_db_session.add(repo) test_db_session.add(repo)
await test_db_session.flush() await test_db_session.flush()
await test_db_session.refresh(repo) await test_db_session.refresh(repo)
assert repo.id is not None _assert_repo_properties(repo, source_graph_repo_data)
assert repo.url == str(repo_create_data.url)
assert (await repo.awaitable_attrs.dependencies) == IsList(length=0) assert (await repo.awaitable_attrs.dependencies) == IsList(length=0)
async def test_create_repo_with_dependencies( async def test_create_repo_with_dependencies(
test_db_session: AsyncSession, repo_create_data_factory: RepoCreateDataFactory test_db_session: AsyncSession,
source_graph_repo_data_factory: SourceGraphRepoDataFactory,
dependency_create_data_factory: DependencyCreateDataFactory,
) -> None: ) -> None:
"""Test creating a repo with dependencies.""" """Test creating a repo with dependencies."""
repo_create_data = repo_create_data_factory.build() source_graph_repo_data: SourceGraphRepoData = source_graph_repo_data_factory.build()
dependencies_create_data: list[
DependencyCreateData
] = dependency_create_data_factory.batch(5)
repo = database.Repo( repo = database.Repo(
url=str(repo_create_data.url), url=str(source_graph_repo_data.repo_url),
description=source_graph_repo_data.description,
stars=source_graph_repo_data.stars,
source_graph_repo_id=source_graph_repo_data.repo_id,
dependencies=[ dependencies=[
database.Dependency(name=dependency.name) database.Dependency(**dependency_create_data.model_dump())
for dependency in repo_create_data.dependencies for dependency_create_data in dependencies_create_data
], ],
) )
test_db_session.add(repo) test_db_session.add(repo)
await test_db_session.flush() await test_db_session.flush()
await test_db_session.refresh(repo) await test_db_session.refresh(repo)
assert repo.id is not None _assert_repo_properties(repo, source_graph_repo_data)
assert repo.url == str(repo_create_data.url)
repo_dependencies = await repo.awaitable_attrs.dependencies repo_dependencies = await repo.awaitable_attrs.dependencies
assert 2 <= len(repo_dependencies) <= 5 assert repo_dependencies == IsList(length=5)
assert repo_dependencies == IsList(length=len(repo_create_data.dependencies))
assert all( assert all(
repo_dependency.name == dependency.name repo_dependency.name == dependency.name
for repo_dependency, dependency in zip( for repo_dependency, dependency in zip(
repo_dependencies, repo_create_data.dependencies, strict=True repo_dependencies, dependencies_create_data, strict=True
) )
) )
@ -66,12 +92,12 @@ async def test_list_repositories(
repos_from_db = repos_from_db_result.scalars().unique().all() repos_from_db = repos_from_db_result.scalars().unique().all()
assert repos_from_db == IsList(length=10) assert repos_from_db == IsList(length=10)
assert all( assert all(
repo.url == str(repo_create_data.url) repo.id == repo_data.id
and all( and all(
repo_dependency.name == dependency.name repo_dependency.name == dependency.name
for repo_dependency, dependency in zip( for repo_dependency, dependency in zip(
repo.dependencies, repo_create_data.dependencies, strict=True repo.dependencies, repo_data.dependencies, strict=True
) )
) )
for repo, repo_create_data in zip(repos_from_db, some_repos, strict=True) for repo, repo_data in zip(repos_from_db, some_repos, strict=True)
) )

View File

@ -1,15 +1,15 @@
"""Add Repo, Dependency, and RepoDependency tables """Create Repo, Dependency and RepoDependency tables
Revision ID: d8fc955c639b Revision ID: 0232d84a5aea
Revises: Revises:
Create Date: 2023-07-28 23:41:00.169286 Create Date: 2023-08-02 22:14:12.910175
""" """
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = "d8fc955c639b" revision = "0232d84a5aea"
down_revision = None down_revision = None
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@ -28,7 +28,11 @@ def upgrade() -> None:
"repo", "repo",
sa.Column("id", sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column("url", sa.String(), nullable=False), sa.Column("url", sa.String(), nullable=False),
sa.Column("description", sa.Text(), nullable=False),
sa.Column("stars", sa.BigInteger(), nullable=False),
sa.Column("source_graph_repo_id", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("source_graph_repo_id"),
sa.UniqueConstraint("url"), sa.UniqueConstraint("url"),
) )
op.create_table( op.create_table(