mirror of
https://github.com/Kludex/awesome-fastapi-projects.git
synced 2025-05-15 13:47:05 +00:00
Refactor the models and add a mapper
- A new mapper allows to create database repositories from the SourceGraph data
This commit is contained in:
parent
9d0e6e606d
commit
6484b85956
@ -9,7 +9,9 @@ from pytest_mock import MockerFixture
|
||||
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
|
||||
|
||||
from app.database import Dependency, Repo, async_session_maker
|
||||
from app.factories import RepoCreateDataFactory
|
||||
from app.factories import DependencyCreateDataFactory
|
||||
from app.scraper.factories import SourceGraphRepoDataFactory
|
||||
from app.scraper.models import SourceGraphRepoData
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="session")
|
||||
@ -73,19 +75,27 @@ async def test_db_session(
|
||||
|
||||
@pytest.fixture()
|
||||
async def some_repos(
|
||||
test_db_session: AsyncSession, repo_create_data_factory: RepoCreateDataFactory
|
||||
test_db_session: AsyncSession,
|
||||
source_graph_repo_data_factory: SourceGraphRepoDataFactory,
|
||||
dependency_create_data_factory: DependencyCreateDataFactory,
|
||||
) -> list[Repo]:
|
||||
"""Create some repos."""
|
||||
repo_create_data = repo_create_data_factory.batch(10)
|
||||
assert repo_create_data == IsList(length=10)
|
||||
source_graph_repos_data: list[
|
||||
SourceGraphRepoData
|
||||
] = source_graph_repo_data_factory.batch(10)
|
||||
assert source_graph_repos_data == IsList(length=10)
|
||||
repos = [
|
||||
Repo(
|
||||
url=str(repo.url),
|
||||
url=str(source_graph_repo_data.repo_url),
|
||||
description=source_graph_repo_data.description,
|
||||
stars=source_graph_repo_data.stars,
|
||||
source_graph_repo_id=source_graph_repo_data.repo_id,
|
||||
dependencies=[
|
||||
Dependency(name=dependency.name) for dependency in repo.dependencies
|
||||
Dependency(**dependency_create_data.model_dump())
|
||||
for dependency_create_data in dependency_create_data_factory.batch(5)
|
||||
],
|
||||
)
|
||||
for repo in repo_create_data
|
||||
for source_graph_repo_data in source_graph_repos_data
|
||||
]
|
||||
test_db_session.add_all(repos)
|
||||
await test_db_session.flush()
|
||||
|
@ -19,7 +19,7 @@ from collections.abc import AsyncGenerator
|
||||
from pathlib import PurePath
|
||||
from typing import Final
|
||||
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy import BigInteger, ForeignKey, String, Text
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncAttrs,
|
||||
AsyncEngine,
|
||||
@ -58,6 +58,11 @@ class Repo(Base):
|
||||
__tablename__ = "repo"
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
url: Mapped[str] = mapped_column(nullable=False, unique=True)
|
||||
description: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
stars: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
source_graph_repo_id: Mapped[int | None] = mapped_column(
|
||||
BigInteger, nullable=True, unique=True
|
||||
)
|
||||
dependencies: Mapped[list["Dependency"]] = relationship(
|
||||
"Dependency", secondary="repo_dependency", back_populates="repos"
|
||||
)
|
||||
|
@ -2,7 +2,7 @@
|
||||
from polyfactory.factories.pydantic_factory import ModelFactory
|
||||
from polyfactory.pytest_plugin import register_fixture
|
||||
|
||||
from app.models import DependencyCreateData, RepoCreateData
|
||||
from app.models import DependencyCreateData
|
||||
|
||||
|
||||
@register_fixture
|
||||
@ -10,14 +10,3 @@ class DependencyCreateDataFactory(ModelFactory[DependencyCreateData]):
|
||||
"""Factory for creating DependencyCreateData."""
|
||||
|
||||
__model__ = DependencyCreateData
|
||||
|
||||
|
||||
@register_fixture
|
||||
class RepoCreateDataFactory(ModelFactory[RepoCreateData]):
|
||||
"""Factory for creating RepoCreateData."""
|
||||
|
||||
__model__ = RepoCreateData
|
||||
|
||||
__randomize_collection_length__ = True
|
||||
__min_collection_length__ = 2
|
||||
__max_collection_length__ = 5
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Module contains the models for the application."""
|
||||
from typing import NewType
|
||||
|
||||
from pydantic import AnyUrl, BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
RepoId = NewType("RepoId", int)
|
||||
DependencyId = NewType("DependencyId", int)
|
||||
@ -11,25 +11,3 @@ class DependencyCreateData(BaseModel):
|
||||
"""A dependency of a repository."""
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
class RepoCreateData(BaseModel):
|
||||
"""A repository that is being tracked."""
|
||||
|
||||
url: AnyUrl
|
||||
dependencies: list[DependencyCreateData] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DependencyDetails(BaseModel):
|
||||
"""A single dependency."""
|
||||
|
||||
id: DependencyId
|
||||
name: str
|
||||
|
||||
|
||||
class RepoDetails(BaseModel):
|
||||
"""A repository that is being tracked."""
|
||||
|
||||
id: RepoId
|
||||
url: AnyUrl
|
||||
dependencies: list[DependencyDetails] = Field(default_factory=list)
|
||||
|
@ -1,55 +1,25 @@
|
||||
"""The client for the SourceGraph API."""
|
||||
import asyncio
|
||||
import datetime
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import timedelta
|
||||
from typing import Any, AnyStr, Final, Literal, NewType, Self
|
||||
from typing import Any, AnyStr, Final, Self
|
||||
from urllib.parse import quote
|
||||
|
||||
import httpx
|
||||
import stamina
|
||||
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
HttpUrl,
|
||||
NonNegativeInt,
|
||||
TypeAdapter,
|
||||
computed_field,
|
||||
)
|
||||
|
||||
from app.scraper.models import SourceGraphRepoDataListAdapter
|
||||
|
||||
#: The URL of the SourceGraph SSE API.
|
||||
SOURCE_GRAPH_STREAM_API_URL: Final[
|
||||
HttpUrl
|
||||
] = "https://sourcegraph.com/.api/search/stream"
|
||||
|
||||
#: The ID of a repository from the SourceGraph API.
|
||||
SourceGraphRepoId = NewType("SourceGraphRepoId", int)
|
||||
|
||||
|
||||
class SourceGraphRepoData(BaseModel):
|
||||
"""The data of a repository."""
|
||||
|
||||
type: Literal["repo"]
|
||||
repo_id: SourceGraphRepoId = Field(..., alias="repositoryID")
|
||||
repo_handle: str = Field(..., alias="repository")
|
||||
stars: NonNegativeInt = Field(..., alias="repoStars")
|
||||
last_fetched_at: datetime.datetime = Field(..., alias="repoLastFetched")
|
||||
description: str = Field(default="")
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def repo_url(self: Self) -> HttpUrl:
|
||||
"""The URL of the repository."""
|
||||
return TypeAdapter(HttpUrl).validate_python(f"https://{self.repo_handle}")
|
||||
|
||||
|
||||
#: The type adapter for the SourceGraphRepoData.
|
||||
SourceGraphRepoDataAdapter = TypeAdapter(SourceGraphRepoData)
|
||||
|
||||
#: The type adapter for the SourceGraphRepoData list.
|
||||
SourceGraphRepoDataListAdapter = TypeAdapter(list[SourceGraphRepoData])
|
||||
|
||||
#: The query parameters for the SourceGraph SSE API.
|
||||
FASTAPI_REPOS_QUERY_PARAMS: Final[Mapping[str, str]] = {
|
||||
|
12
app/scraper/factories.py
Normal file
12
app/scraper/factories.py
Normal file
@ -0,0 +1,12 @@
|
||||
"""Factories for creating test data."""
|
||||
from polyfactory.factories.pydantic_factory import ModelFactory
|
||||
from polyfactory.pytest_plugin import register_fixture
|
||||
|
||||
from app.scraper.models import SourceGraphRepoData
|
||||
|
||||
|
||||
@register_fixture
|
||||
class SourceGraphRepoDataFactory(ModelFactory[SourceGraphRepoData]):
|
||||
"""Factory for creating RepoCreateData."""
|
||||
|
||||
__model__ = SourceGraphRepoData
|
28
app/scraper/mapper.py
Normal file
28
app/scraper/mapper.py
Normal file
@ -0,0 +1,28 @@
|
||||
"""Mapper for scraper."""
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy.sql
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app import database
|
||||
from app.scraper.models import SourceGraphRepoData
|
||||
|
||||
|
||||
async def create_repos_from_source_graph_repos_data(
|
||||
session: AsyncSession, source_graph_repo_data: Sequence[SourceGraphRepoData]
|
||||
) -> Sequence[database.Repo]:
|
||||
"""Create repos from source graph repos data."""
|
||||
return (
|
||||
await session.scalars(
|
||||
sqlalchemy.sql.insert(database.Repo).returning(database.Repo),
|
||||
[
|
||||
{
|
||||
"url": str(repo_data.repo_url),
|
||||
"description": repo_data.description,
|
||||
"stars": repo_data.stars,
|
||||
"source_graph_repo_id": repo_data.repo_id,
|
||||
}
|
||||
for repo_data in source_graph_repo_data
|
||||
],
|
||||
)
|
||||
).all()
|
39
app/scraper/models.py
Normal file
39
app/scraper/models.py
Normal file
@ -0,0 +1,39 @@
|
||||
"""The models for the scraper."""
|
||||
import datetime
|
||||
from typing import Literal, NewType, Self
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
HttpUrl,
|
||||
NonNegativeInt,
|
||||
TypeAdapter,
|
||||
computed_field,
|
||||
)
|
||||
|
||||
#: The ID of a repository from the SourceGraph API.
|
||||
SourceGraphRepoId = NewType("SourceGraphRepoId", int)
|
||||
|
||||
|
||||
class SourceGraphRepoData(BaseModel):
|
||||
"""The data of a repository."""
|
||||
|
||||
type: Literal["repo"]
|
||||
repo_id: SourceGraphRepoId = Field(..., alias="repositoryID")
|
||||
repo_handle: str = Field(..., alias="repository")
|
||||
stars: NonNegativeInt = Field(..., alias="repoStars")
|
||||
last_fetched_at: datetime.datetime = Field(..., alias="repoLastFetched")
|
||||
description: str = Field(default="")
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def repo_url(self: Self) -> HttpUrl:
|
||||
"""The URL of the repository."""
|
||||
return TypeAdapter(HttpUrl).validate_python(f"https://{self.repo_handle}")
|
||||
|
||||
|
||||
#: The type adapter for the SourceGraphRepoData.
|
||||
SourceGraphRepoDataAdapter = TypeAdapter(SourceGraphRepoData)
|
||||
|
||||
#: The type adapter for the SourceGraphRepoData list.
|
||||
SourceGraphRepoDataListAdapter = TypeAdapter(list[SourceGraphRepoData])
|
@ -3,7 +3,7 @@ import pytest
|
||||
from dirty_equals import HasLen, IsDatetime, IsInstance, IsPositiveInt
|
||||
from pydantic import Json, TypeAdapter
|
||||
|
||||
from app.scraper.client import SourceGraphRepoData
|
||||
from app.scraper.models import SourceGraphRepoData
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
27
app/scraper/tests/test_mapper.py
Normal file
27
app/scraper/tests/test_mapper.py
Normal file
@ -0,0 +1,27 @@
|
||||
"""The tests for the scraper mapper."""
|
||||
import pytest
|
||||
from dirty_equals import IsInstance, IsList
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app import database
|
||||
from app.scraper.factories import SourceGraphRepoDataFactory
|
||||
from app.scraper.mapper import create_repos_from_source_graph_repos_data
|
||||
from app.scraper.models import SourceGraphRepoData
|
||||
|
||||
pytestmark = pytest.mark.anyio
|
||||
|
||||
|
||||
async def test_create_repos_from_source_graph_repos_data(
|
||||
test_db_session: AsyncSession,
|
||||
source_graph_repo_data_factory: SourceGraphRepoDataFactory,
|
||||
) -> None:
|
||||
"""Test creating repos from source graph repos data."""
|
||||
source_graph_repo_data: list[
|
||||
SourceGraphRepoData
|
||||
] = source_graph_repo_data_factory.batch(5)
|
||||
repos = await create_repos_from_source_graph_repos_data(
|
||||
test_db_session, source_graph_repo_data
|
||||
)
|
||||
assert repos == IsList(length=5)
|
||||
assert all(repo == IsInstance[database.Repo] for repo in repos)
|
||||
assert all(repo.id is not None for repo in repos)
|
@ -6,49 +6,75 @@ from dirty_equals import IsList
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app import database
|
||||
from app.factories import RepoCreateDataFactory
|
||||
from app.factories import DependencyCreateDataFactory
|
||||
from app.models import DependencyCreateData
|
||||
from app.scraper.factories import SourceGraphRepoDataFactory
|
||||
from app.scraper.models import SourceGraphRepoData
|
||||
|
||||
pytestmark = pytest.mark.anyio
|
||||
|
||||
|
||||
def _assert_repo_properties(
|
||||
repo: database.Repo, source_graph_repo_data: SourceGraphRepoData
|
||||
) -> bool:
|
||||
"""Assert that the repo has the expected properties."""
|
||||
assert repo.id is not None
|
||||
assert repo.url == str(source_graph_repo_data.repo_url)
|
||||
assert repo.description == source_graph_repo_data.description
|
||||
assert repo.stars == source_graph_repo_data.stars
|
||||
assert repo.source_graph_repo_id == source_graph_repo_data.repo_id
|
||||
return True
|
||||
|
||||
|
||||
async def test_create_repo_no_dependencies(
|
||||
test_db_session: AsyncSession, repo_create_data_factory: RepoCreateDataFactory
|
||||
test_db_session: AsyncSession,
|
||||
source_graph_repo_data_factory: SourceGraphRepoDataFactory,
|
||||
) -> None:
|
||||
"""Test creating a repo."""
|
||||
repo_create_data = repo_create_data_factory.build()
|
||||
repo = database.Repo(url=str(repo_create_data.url))
|
||||
source_graph_repo_data: SourceGraphRepoData = source_graph_repo_data_factory.build()
|
||||
repo = database.Repo(
|
||||
url=str(source_graph_repo_data.repo_url),
|
||||
description=source_graph_repo_data.description,
|
||||
stars=source_graph_repo_data.stars,
|
||||
source_graph_repo_id=source_graph_repo_data.repo_id,
|
||||
)
|
||||
test_db_session.add(repo)
|
||||
await test_db_session.flush()
|
||||
await test_db_session.refresh(repo)
|
||||
assert repo.id is not None
|
||||
assert repo.url == str(repo_create_data.url)
|
||||
_assert_repo_properties(repo, source_graph_repo_data)
|
||||
assert (await repo.awaitable_attrs.dependencies) == IsList(length=0)
|
||||
|
||||
|
||||
async def test_create_repo_with_dependencies(
|
||||
test_db_session: AsyncSession, repo_create_data_factory: RepoCreateDataFactory
|
||||
test_db_session: AsyncSession,
|
||||
source_graph_repo_data_factory: SourceGraphRepoDataFactory,
|
||||
dependency_create_data_factory: DependencyCreateDataFactory,
|
||||
) -> None:
|
||||
"""Test creating a repo with dependencies."""
|
||||
repo_create_data = repo_create_data_factory.build()
|
||||
source_graph_repo_data: SourceGraphRepoData = source_graph_repo_data_factory.build()
|
||||
dependencies_create_data: list[
|
||||
DependencyCreateData
|
||||
] = dependency_create_data_factory.batch(5)
|
||||
repo = database.Repo(
|
||||
url=str(repo_create_data.url),
|
||||
url=str(source_graph_repo_data.repo_url),
|
||||
description=source_graph_repo_data.description,
|
||||
stars=source_graph_repo_data.stars,
|
||||
source_graph_repo_id=source_graph_repo_data.repo_id,
|
||||
dependencies=[
|
||||
database.Dependency(name=dependency.name)
|
||||
for dependency in repo_create_data.dependencies
|
||||
database.Dependency(**dependency_create_data.model_dump())
|
||||
for dependency_create_data in dependencies_create_data
|
||||
],
|
||||
)
|
||||
test_db_session.add(repo)
|
||||
await test_db_session.flush()
|
||||
await test_db_session.refresh(repo)
|
||||
assert repo.id is not None
|
||||
assert repo.url == str(repo_create_data.url)
|
||||
_assert_repo_properties(repo, source_graph_repo_data)
|
||||
repo_dependencies = await repo.awaitable_attrs.dependencies
|
||||
assert 2 <= len(repo_dependencies) <= 5
|
||||
assert repo_dependencies == IsList(length=len(repo_create_data.dependencies))
|
||||
assert repo_dependencies == IsList(length=5)
|
||||
assert all(
|
||||
repo_dependency.name == dependency.name
|
||||
for repo_dependency, dependency in zip(
|
||||
repo_dependencies, repo_create_data.dependencies, strict=True
|
||||
repo_dependencies, dependencies_create_data, strict=True
|
||||
)
|
||||
)
|
||||
|
||||
@ -66,12 +92,12 @@ async def test_list_repositories(
|
||||
repos_from_db = repos_from_db_result.scalars().unique().all()
|
||||
assert repos_from_db == IsList(length=10)
|
||||
assert all(
|
||||
repo.url == str(repo_create_data.url)
|
||||
repo.id == repo_data.id
|
||||
and all(
|
||||
repo_dependency.name == dependency.name
|
||||
for repo_dependency, dependency in zip(
|
||||
repo.dependencies, repo_create_data.dependencies, strict=True
|
||||
repo.dependencies, repo_data.dependencies, strict=True
|
||||
)
|
||||
)
|
||||
for repo, repo_create_data in zip(repos_from_db, some_repos, strict=True)
|
||||
for repo, repo_data in zip(repos_from_db, some_repos, strict=True)
|
||||
)
|
||||
|
@ -1,15 +1,15 @@
|
||||
"""Add Repo, Dependency, and RepoDependency tables
|
||||
"""Create Repo, Dependency and RepoDependency tables
|
||||
|
||||
Revision ID: d8fc955c639b
|
||||
Revision ID: 0232d84a5aea
|
||||
Revises:
|
||||
Create Date: 2023-07-28 23:41:00.169286
|
||||
Create Date: 2023-08-02 22:14:12.910175
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d8fc955c639b"
|
||||
revision = "0232d84a5aea"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
@ -28,7 +28,11 @@ def upgrade() -> None:
|
||||
"repo",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("url", sa.String(), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=False),
|
||||
sa.Column("stars", sa.BigInteger(), nullable=False),
|
||||
sa.Column("source_graph_repo_id", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("source_graph_repo_id"),
|
||||
sa.UniqueConstraint("url"),
|
||||
)
|
||||
op.create_table(
|
Loading…
x
Reference in New Issue
Block a user