Refactor the models and add a mapper

- A new mapper creates database repositories from the SourceGraph data
This commit is contained in:
Vladyslav Fedoriuk 2023-08-02 23:09:30 +02:00
parent 9d0e6e606d
commit 6484b85956
12 changed files with 188 additions and 100 deletions

View File

@ -9,7 +9,9 @@ from pytest_mock import MockerFixture
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
from app.database import Dependency, Repo, async_session_maker
from app.factories import RepoCreateDataFactory
from app.factories import DependencyCreateDataFactory
from app.scraper.factories import SourceGraphRepoDataFactory
from app.scraper.models import SourceGraphRepoData
@pytest.fixture(autouse=True, scope="session")
@ -73,19 +75,27 @@ async def test_db_session(
@pytest.fixture()
async def some_repos(
test_db_session: AsyncSession, repo_create_data_factory: RepoCreateDataFactory
test_db_session: AsyncSession,
source_graph_repo_data_factory: SourceGraphRepoDataFactory,
dependency_create_data_factory: DependencyCreateDataFactory,
) -> list[Repo]:
"""Create some repos."""
repo_create_data = repo_create_data_factory.batch(10)
assert repo_create_data == IsList(length=10)
source_graph_repos_data: list[
SourceGraphRepoData
] = source_graph_repo_data_factory.batch(10)
assert source_graph_repos_data == IsList(length=10)
repos = [
Repo(
url=str(repo.url),
url=str(source_graph_repo_data.repo_url),
description=source_graph_repo_data.description,
stars=source_graph_repo_data.stars,
source_graph_repo_id=source_graph_repo_data.repo_id,
dependencies=[
Dependency(name=dependency.name) for dependency in repo.dependencies
Dependency(**dependency_create_data.model_dump())
for dependency_create_data in dependency_create_data_factory.batch(5)
],
)
for repo in repo_create_data
for source_graph_repo_data in source_graph_repos_data
]
test_db_session.add_all(repos)
await test_db_session.flush()

View File

@ -19,7 +19,7 @@ from collections.abc import AsyncGenerator
from pathlib import PurePath
from typing import Final
from sqlalchemy import ForeignKey, String
from sqlalchemy import BigInteger, ForeignKey, String, Text
from sqlalchemy.ext.asyncio import (
AsyncAttrs,
AsyncEngine,
@ -58,6 +58,11 @@ class Repo(Base):
__tablename__ = "repo"
id: Mapped[int] = mapped_column(primary_key=True)
url: Mapped[str] = mapped_column(nullable=False, unique=True)
description: Mapped[str] = mapped_column(Text, nullable=False)
stars: Mapped[int] = mapped_column(BigInteger, nullable=False)
source_graph_repo_id: Mapped[int | None] = mapped_column(
BigInteger, nullable=True, unique=True
)
dependencies: Mapped[list["Dependency"]] = relationship(
"Dependency", secondary="repo_dependency", back_populates="repos"
)

View File

@ -2,7 +2,7 @@
from polyfactory.factories.pydantic_factory import ModelFactory
from polyfactory.pytest_plugin import register_fixture
from app.models import DependencyCreateData, RepoCreateData
from app.models import DependencyCreateData
@register_fixture
@ -10,14 +10,3 @@ class DependencyCreateDataFactory(ModelFactory[DependencyCreateData]):
"""Factory for creating DependencyCreateData."""
__model__ = DependencyCreateData
@register_fixture
class RepoCreateDataFactory(ModelFactory[RepoCreateData]):
"""Factory for creating RepoCreateData."""
__model__ = RepoCreateData
__randomize_collection_length__ = True
__min_collection_length__ = 2
__max_collection_length__ = 5

View File

@ -1,7 +1,7 @@
"""Module contains the models for the application."""
from typing import NewType
from pydantic import AnyUrl, BaseModel, Field
from pydantic import BaseModel
RepoId = NewType("RepoId", int)
DependencyId = NewType("DependencyId", int)
@ -11,25 +11,3 @@ class DependencyCreateData(BaseModel):
"""A dependency of a repository."""
name: str
class RepoCreateData(BaseModel):
"""A repository that is being tracked."""
url: AnyUrl
dependencies: list[DependencyCreateData] = Field(default_factory=list)
class DependencyDetails(BaseModel):
"""A single dependency."""
id: DependencyId
name: str
class RepoDetails(BaseModel):
"""A repository that is being tracked."""
id: RepoId
url: AnyUrl
dependencies: list[DependencyDetails] = Field(default_factory=list)

View File

@ -1,55 +1,25 @@
"""The client for the SourceGraph API."""
import asyncio
import datetime
from collections.abc import AsyncGenerator, Mapping
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import Any, AnyStr, Final, Literal, NewType, Self
from typing import Any, AnyStr, Final, Self
from urllib.parse import quote
import httpx
import stamina
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
from pydantic import (
BaseModel,
Field,
HttpUrl,
NonNegativeInt,
TypeAdapter,
computed_field,
)
from app.scraper.models import SourceGraphRepoDataListAdapter
#: The URL of the SourceGraph SSE API.
SOURCE_GRAPH_STREAM_API_URL: Final[
HttpUrl
] = "https://sourcegraph.com/.api/search/stream"
#: The ID of a repository from the SourceGraph API.
SourceGraphRepoId = NewType("SourceGraphRepoId", int)
class SourceGraphRepoData(BaseModel):
"""The data of a repository."""
type: Literal["repo"]
repo_id: SourceGraphRepoId = Field(..., alias="repositoryID")
repo_handle: str = Field(..., alias="repository")
stars: NonNegativeInt = Field(..., alias="repoStars")
last_fetched_at: datetime.datetime = Field(..., alias="repoLastFetched")
description: str = Field(default="")
@computed_field
@property
def repo_url(self: Self) -> HttpUrl:
"""The URL of the repository."""
return TypeAdapter(HttpUrl).validate_python(f"https://{self.repo_handle}")
#: The type adapter for the SourceGraphRepoData.
SourceGraphRepoDataAdapter = TypeAdapter(SourceGraphRepoData)
#: The type adapter for the SourceGraphRepoData list.
SourceGraphRepoDataListAdapter = TypeAdapter(list[SourceGraphRepoData])
#: The query parameters for the SourceGraph SSE API.
FASTAPI_REPOS_QUERY_PARAMS: Final[Mapping[str, str]] = {

12
app/scraper/factories.py Normal file
View File

@ -0,0 +1,12 @@
"""Factories for creating test data."""
from polyfactory.factories.pydantic_factory import ModelFactory
from polyfactory.pytest_plugin import register_fixture
from app.scraper.models import SourceGraphRepoData
@register_fixture
class SourceGraphRepoDataFactory(ModelFactory[SourceGraphRepoData]):
"""Factory for creating SourceGraphRepoData."""
__model__ = SourceGraphRepoData

28
app/scraper/mapper.py Normal file
View File

@ -0,0 +1,28 @@
"""Mapper for scraper."""
from collections.abc import Sequence
import sqlalchemy.sql
from sqlalchemy.ext.asyncio import AsyncSession
from app import database
from app.scraper.models import SourceGraphRepoData
async def create_repos_from_source_graph_repos_data(
    session: AsyncSession, source_graph_repo_data: Sequence[SourceGraphRepoData]
) -> Sequence[database.Repo]:
    """Create repos from source graph repos data.

    Each item is mapped to one parameter dictionary (``url``, ``description``,
    ``stars``, ``source_graph_repo_id``) and all of them are passed to a single
    ``insert(database.Repo).returning(database.Repo)`` statement executed on
    ``session``. The session is not committed here.

    :param session: The session used to execute the insert.
    :param source_graph_repo_data: The SourceGraph repositories to persist.
    :return: The ``database.Repo`` objects returned by the insert, or an empty
        list when ``source_graph_repo_data`` is empty.
    """
    if not source_graph_repo_data:
        # Nothing to insert: skip the database round trip instead of executing
        # an INSERT with an empty parameter list.
        return []
    insert_parameters = [
        {
            # The pydantic URL object is stored as a plain string.
            "url": str(repo_data.repo_url),
            "description": repo_data.description,
            "stars": repo_data.stars,
            "source_graph_repo_id": repo_data.repo_id,
        }
        for repo_data in source_graph_repo_data
    ]
    inserted_repos = await session.scalars(
        sqlalchemy.sql.insert(database.Repo).returning(database.Repo),
        insert_parameters,
    )
    return inserted_repos.all()

39
app/scraper/models.py Normal file
View File

@ -0,0 +1,39 @@
"""The models for the scraper."""
import datetime
from typing import Literal, NewType, Self
from pydantic import (
BaseModel,
Field,
HttpUrl,
NonNegativeInt,
TypeAdapter,
computed_field,
)
#: The ID of a repository from the SourceGraph API.
SourceGraphRepoId = NewType("SourceGraphRepoId", int)
class SourceGraphRepoData(BaseModel):
"""The data of a repository.

Most fields are declared with an alias (noted on each field below);
``repo_url`` is not an input field but is derived from ``repo_handle``.
"""
# Must be the literal string "repo".
type: Literal["repo"]
# Required; aliased to "repositoryID".
repo_id: SourceGraphRepoId = Field(..., alias="repositoryID")
# Required; aliased to "repository". Used to build ``repo_url``.
repo_handle: str = Field(..., alias="repository")
# Required non-negative integer; aliased to "repoStars".
stars: NonNegativeInt = Field(..., alias="repoStars")
# Required datetime; aliased to "repoLastFetched".
last_fetched_at: datetime.datetime = Field(..., alias="repoLastFetched")
# Optional; defaults to an empty string when not provided.
description: str = Field(default="")
@computed_field
@property
def repo_url(self: Self) -> HttpUrl:
"""The URL of the repository.

Built by prefixing ``repo_handle`` with ``https://`` and validating the
result as an ``HttpUrl``.
"""
return TypeAdapter(HttpUrl).validate_python(f"https://{self.repo_handle}")
#: The pydantic type adapter for a single SourceGraphRepoData object.
SourceGraphRepoDataAdapter = TypeAdapter(SourceGraphRepoData)
#: The pydantic type adapter for a list of SourceGraphRepoData objects.
SourceGraphRepoDataListAdapter = TypeAdapter(list[SourceGraphRepoData])

View File

@ -3,7 +3,7 @@ import pytest
from dirty_equals import HasLen, IsDatetime, IsInstance, IsPositiveInt
from pydantic import Json, TypeAdapter
from app.scraper.client import SourceGraphRepoData
from app.scraper.models import SourceGraphRepoData
@pytest.fixture()

View File

@ -0,0 +1,27 @@
"""The tests for the scraper mapper."""
import pytest
from dirty_equals import IsInstance, IsList
from sqlalchemy.ext.asyncio import AsyncSession
from app import database
from app.scraper.factories import SourceGraphRepoDataFactory
from app.scraper.mapper import create_repos_from_source_graph_repos_data
from app.scraper.models import SourceGraphRepoData
pytestmark = pytest.mark.anyio
async def test_create_repos_from_source_graph_repos_data(
    test_db_session: AsyncSession,
    source_graph_repo_data_factory: SourceGraphRepoDataFactory,
) -> None:
    """Check that the mapper persists a batch of SourceGraph repos.

    Five SourceGraph records are generated by the factory and handed to the
    mapper; every returned item must be a ``database.Repo`` with an ``id``.
    """
    batch_size = 5
    generated_data: list[SourceGraphRepoData] = source_graph_repo_data_factory.batch(
        batch_size
    )
    created_repos = await create_repos_from_source_graph_repos_data(
        test_db_session, generated_data
    )
    assert created_repos == IsList(length=batch_size)
    for created_repo in created_repos:
        assert created_repo == IsInstance[database.Repo]
        assert created_repo.id is not None

View File

@ -6,49 +6,75 @@ from dirty_equals import IsList
from sqlalchemy.ext.asyncio import AsyncSession
from app import database
from app.factories import RepoCreateDataFactory
from app.factories import DependencyCreateDataFactory
from app.models import DependencyCreateData
from app.scraper.factories import SourceGraphRepoDataFactory
from app.scraper.models import SourceGraphRepoData
pytestmark = pytest.mark.anyio
def _assert_repo_properties(
    repo: database.Repo, source_graph_repo_data: SourceGraphRepoData
) -> bool:
    """Verify that ``repo`` mirrors the SourceGraph data it was built from.

    The repo must have an ``id``, and its ``url``, ``description``, ``stars``
    and ``source_graph_repo_id`` must equal the corresponding SourceGraph
    values. Returns ``True`` when every assertion passes.
    """
    assert repo.id is not None
    # Repo attribute name -> value expected from the SourceGraph record.
    expected_values = {
        "url": str(source_graph_repo_data.repo_url),
        "description": source_graph_repo_data.description,
        "stars": source_graph_repo_data.stars,
        "source_graph_repo_id": source_graph_repo_data.repo_id,
    }
    for attribute_name, expected_value in expected_values.items():
        assert getattr(repo, attribute_name) == expected_value
    return True
async def test_create_repo_no_dependencies(
test_db_session: AsyncSession, repo_create_data_factory: RepoCreateDataFactory
test_db_session: AsyncSession,
source_graph_repo_data_factory: SourceGraphRepoDataFactory,
) -> None:
"""Test creating a repo."""
repo_create_data = repo_create_data_factory.build()
repo = database.Repo(url=str(repo_create_data.url))
source_graph_repo_data: SourceGraphRepoData = source_graph_repo_data_factory.build()
repo = database.Repo(
url=str(source_graph_repo_data.repo_url),
description=source_graph_repo_data.description,
stars=source_graph_repo_data.stars,
source_graph_repo_id=source_graph_repo_data.repo_id,
)
test_db_session.add(repo)
await test_db_session.flush()
await test_db_session.refresh(repo)
assert repo.id is not None
assert repo.url == str(repo_create_data.url)
_assert_repo_properties(repo, source_graph_repo_data)
assert (await repo.awaitable_attrs.dependencies) == IsList(length=0)
async def test_create_repo_with_dependencies(
test_db_session: AsyncSession, repo_create_data_factory: RepoCreateDataFactory
test_db_session: AsyncSession,
source_graph_repo_data_factory: SourceGraphRepoDataFactory,
dependency_create_data_factory: DependencyCreateDataFactory,
) -> None:
"""Test creating a repo with dependencies."""
repo_create_data = repo_create_data_factory.build()
source_graph_repo_data: SourceGraphRepoData = source_graph_repo_data_factory.build()
dependencies_create_data: list[
DependencyCreateData
] = dependency_create_data_factory.batch(5)
repo = database.Repo(
url=str(repo_create_data.url),
url=str(source_graph_repo_data.repo_url),
description=source_graph_repo_data.description,
stars=source_graph_repo_data.stars,
source_graph_repo_id=source_graph_repo_data.repo_id,
dependencies=[
database.Dependency(name=dependency.name)
for dependency in repo_create_data.dependencies
database.Dependency(**dependency_create_data.model_dump())
for dependency_create_data in dependencies_create_data
],
)
test_db_session.add(repo)
await test_db_session.flush()
await test_db_session.refresh(repo)
assert repo.id is not None
assert repo.url == str(repo_create_data.url)
_assert_repo_properties(repo, source_graph_repo_data)
repo_dependencies = await repo.awaitable_attrs.dependencies
assert 2 <= len(repo_dependencies) <= 5
assert repo_dependencies == IsList(length=len(repo_create_data.dependencies))
assert repo_dependencies == IsList(length=5)
assert all(
repo_dependency.name == dependency.name
for repo_dependency, dependency in zip(
repo_dependencies, repo_create_data.dependencies, strict=True
repo_dependencies, dependencies_create_data, strict=True
)
)
@ -66,12 +92,12 @@ async def test_list_repositories(
repos_from_db = repos_from_db_result.scalars().unique().all()
assert repos_from_db == IsList(length=10)
assert all(
repo.url == str(repo_create_data.url)
repo.id == repo_data.id
and all(
repo_dependency.name == dependency.name
for repo_dependency, dependency in zip(
repo.dependencies, repo_create_data.dependencies, strict=True
repo.dependencies, repo_data.dependencies, strict=True
)
)
for repo, repo_create_data in zip(repos_from_db, some_repos, strict=True)
for repo, repo_data in zip(repos_from_db, some_repos, strict=True)
)

View File

@ -1,15 +1,15 @@
"""Add Repo, Dependency, and RepoDependency tables
"""Create Repo, Dependency and RepoDependency tables
Revision ID: d8fc955c639b
Revision ID: 0232d84a5aea
Revises:
Create Date: 2023-07-28 23:41:00.169286
Create Date: 2023-08-02 22:14:12.910175
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "d8fc955c639b"
revision = "0232d84a5aea"
down_revision = None
branch_labels = None
depends_on = None
@ -28,7 +28,11 @@ def upgrade() -> None:
"repo",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("url", sa.String(), nullable=False),
sa.Column("description", sa.Text(), nullable=False),
sa.Column("stars", sa.BigInteger(), nullable=False),
sa.Column("source_graph_repo_id", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("source_graph_repo_id"),
sa.UniqueConstraint("url"),
)
op.create_table(