kodit 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic. Click here for more details.

Files changed (117)
  1. kodit/_version.py +2 -2
  2. kodit/application/__init__.py +1 -0
  3. kodit/application/commands/__init__.py +1 -0
  4. kodit/application/commands/snippet_commands.py +22 -0
  5. kodit/application/services/__init__.py +1 -0
  6. kodit/application/services/indexing_application_service.py +363 -0
  7. kodit/application/services/snippet_application_service.py +143 -0
  8. kodit/cli.py +105 -82
  9. kodit/database.py +0 -22
  10. kodit/domain/__init__.py +1 -0
  11. kodit/{source/source_models.py → domain/entities.py} +88 -19
  12. kodit/domain/enums.py +9 -0
  13. kodit/domain/interfaces.py +27 -0
  14. kodit/domain/repositories.py +95 -0
  15. kodit/domain/services/__init__.py +1 -0
  16. kodit/domain/services/bm25_service.py +124 -0
  17. kodit/domain/services/embedding_service.py +155 -0
  18. kodit/domain/services/enrichment_service.py +48 -0
  19. kodit/domain/services/ignore_service.py +45 -0
  20. kodit/domain/services/indexing_service.py +203 -0
  21. kodit/domain/services/snippet_extraction_service.py +89 -0
  22. kodit/domain/services/source_service.py +83 -0
  23. kodit/domain/value_objects.py +215 -0
  24. kodit/infrastructure/__init__.py +1 -0
  25. kodit/infrastructure/bm25/__init__.py +1 -0
  26. kodit/infrastructure/bm25/bm25_factory.py +28 -0
  27. kodit/{bm25/local_bm25.py → infrastructure/bm25/local_bm25_repository.py} +33 -22
  28. kodit/{bm25/vectorchord_bm25.py → infrastructure/bm25/vectorchord_bm25_repository.py} +40 -35
  29. kodit/infrastructure/cloning/__init__.py +1 -0
  30. kodit/infrastructure/cloning/folder/__init__.py +1 -0
  31. kodit/infrastructure/cloning/folder/factory.py +119 -0
  32. kodit/infrastructure/cloning/folder/working_copy.py +38 -0
  33. kodit/infrastructure/cloning/git/__init__.py +1 -0
  34. kodit/infrastructure/cloning/git/factory.py +133 -0
  35. kodit/infrastructure/cloning/git/working_copy.py +32 -0
  36. kodit/infrastructure/cloning/metadata.py +127 -0
  37. kodit/infrastructure/embedding/__init__.py +1 -0
  38. kodit/infrastructure/embedding/embedding_factory.py +87 -0
  39. kodit/infrastructure/embedding/embedding_providers/__init__.py +1 -0
  40. kodit/infrastructure/embedding/embedding_providers/batching.py +93 -0
  41. kodit/infrastructure/embedding/embedding_providers/hash_embedding_provider.py +79 -0
  42. kodit/infrastructure/embedding/embedding_providers/local_embedding_provider.py +129 -0
  43. kodit/infrastructure/embedding/embedding_providers/openai_embedding_provider.py +113 -0
  44. kodit/infrastructure/embedding/local_vector_search_repository.py +114 -0
  45. kodit/{embedding/vectorchord_vector_search_service.py → infrastructure/embedding/vectorchord_vector_search_repository.py} +65 -46
  46. kodit/infrastructure/enrichment/__init__.py +1 -0
  47. kodit/{enrichment → infrastructure/enrichment}/enrichment_factory.py +28 -12
  48. kodit/infrastructure/enrichment/legacy_enrichment_models.py +42 -0
  49. kodit/{enrichment/enrichment_provider → infrastructure/enrichment}/local_enrichment_provider.py +38 -26
  50. kodit/infrastructure/enrichment/null_enrichment_provider.py +25 -0
  51. kodit/infrastructure/enrichment/openai_enrichment_provider.py +89 -0
  52. kodit/infrastructure/git/__init__.py +1 -0
  53. kodit/{source/git.py → infrastructure/git/git_utils.py} +10 -2
  54. kodit/infrastructure/ignore/__init__.py +1 -0
  55. kodit/{source/ignore.py → infrastructure/ignore/ignore_pattern_provider.py} +23 -6
  56. kodit/infrastructure/indexing/__init__.py +1 -0
  57. kodit/infrastructure/indexing/fusion_service.py +55 -0
  58. kodit/infrastructure/indexing/index_repository.py +296 -0
  59. kodit/infrastructure/indexing/indexing_factory.py +111 -0
  60. kodit/infrastructure/snippet_extraction/__init__.py +1 -0
  61. kodit/infrastructure/snippet_extraction/language_detection_service.py +39 -0
  62. kodit/infrastructure/snippet_extraction/snippet_extraction_factory.py +95 -0
  63. kodit/infrastructure/snippet_extraction/snippet_query_provider.py +45 -0
  64. kodit/{snippets/method_snippets.py → infrastructure/snippet_extraction/tree_sitter_snippet_extractor.py} +123 -61
  65. kodit/infrastructure/sqlalchemy/__init__.py +1 -0
  66. kodit/{embedding → infrastructure/sqlalchemy}/embedding_repository.py +40 -24
  67. kodit/infrastructure/sqlalchemy/file_repository.py +73 -0
  68. kodit/infrastructure/sqlalchemy/repository.py +121 -0
  69. kodit/infrastructure/sqlalchemy/snippet_repository.py +75 -0
  70. kodit/infrastructure/ui/__init__.py +1 -0
  71. kodit/infrastructure/ui/progress.py +127 -0
  72. kodit/{util → infrastructure/ui}/spinner.py +19 -4
  73. kodit/mcp.py +50 -28
  74. kodit/migrations/env.py +1 -4
  75. kodit/reporting.py +78 -0
  76. {kodit-0.2.4.dist-info → kodit-0.2.5.dist-info}/METADATA +1 -1
  77. kodit-0.2.5.dist-info/RECORD +99 -0
  78. kodit/bm25/__init__.py +0 -1
  79. kodit/bm25/keyword_search_factory.py +0 -17
  80. kodit/bm25/keyword_search_service.py +0 -34
  81. kodit/embedding/__init__.py +0 -1
  82. kodit/embedding/embedding_factory.py +0 -69
  83. kodit/embedding/embedding_models.py +0 -28
  84. kodit/embedding/embedding_provider/__init__.py +0 -1
  85. kodit/embedding/embedding_provider/embedding_provider.py +0 -92
  86. kodit/embedding/embedding_provider/hash_embedding_provider.py +0 -86
  87. kodit/embedding/embedding_provider/local_embedding_provider.py +0 -96
  88. kodit/embedding/embedding_provider/openai_embedding_provider.py +0 -73
  89. kodit/embedding/local_vector_search_service.py +0 -87
  90. kodit/embedding/vector_search_service.py +0 -55
  91. kodit/enrichment/__init__.py +0 -1
  92. kodit/enrichment/enrichment_provider/__init__.py +0 -1
  93. kodit/enrichment/enrichment_provider/enrichment_provider.py +0 -36
  94. kodit/enrichment/enrichment_provider/openai_enrichment_provider.py +0 -79
  95. kodit/enrichment/enrichment_service.py +0 -45
  96. kodit/indexing/__init__.py +0 -1
  97. kodit/indexing/fusion.py +0 -67
  98. kodit/indexing/indexing_models.py +0 -43
  99. kodit/indexing/indexing_repository.py +0 -216
  100. kodit/indexing/indexing_service.py +0 -344
  101. kodit/snippets/__init__.py +0 -1
  102. kodit/snippets/languages/__init__.py +0 -53
  103. kodit/snippets/snippets.py +0 -50
  104. kodit/source/__init__.py +0 -1
  105. kodit/source/source_factories.py +0 -356
  106. kodit/source/source_repository.py +0 -169
  107. kodit/source/source_service.py +0 -150
  108. kodit/util/__init__.py +0 -1
  109. kodit-0.2.4.dist-info/RECORD +0 -71
  110. /kodit/{snippets → infrastructure/snippet_extraction}/languages/csharp.scm +0 -0
  111. /kodit/{snippets → infrastructure/snippet_extraction}/languages/go.scm +0 -0
  112. /kodit/{snippets → infrastructure/snippet_extraction}/languages/javascript.scm +0 -0
  113. /kodit/{snippets → infrastructure/snippet_extraction}/languages/python.scm +0 -0
  114. /kodit/{snippets → infrastructure/snippet_extraction}/languages/typescript.scm +0 -0
  115. {kodit-0.2.4.dist-info → kodit-0.2.5.dist-info}/WHEEL +0 -0
  116. {kodit-0.2.4.dist-info → kodit-0.2.5.dist-info}/entry_points.txt +0 -0
  117. {kodit-0.2.4.dist-info → kodit-0.2.5.dist-info}/licenses/LICENSE +0 -0
kodit/snippets/languages/__init__.py DELETED
@@ -1,53 +0,0 @@
1
- """Detect the language of a file."""
2
-
3
- from pathlib import Path
4
- from typing import cast
5
-
6
- from tree_sitter_language_pack import SupportedLanguage
7
-
8
- # Mapping of file extensions to programming languages
9
- LANGUAGE_MAP: dict[str, str] = {
10
- # JavaScript/TypeScript
11
- "js": "javascript",
12
- "jsx": "javascript",
13
- "ts": "typescript",
14
- "tsx": "typescript",
15
- # Python
16
- "py": "python",
17
- # Rust
18
- "rs": "rust",
19
- # Go
20
- "go": "go",
21
- # C/C++
22
- "cpp": "cpp",
23
- "hpp": "cpp",
24
- "c": "c",
25
- "h": "c",
26
- # C#
27
- "cs": "csharp",
28
- # Ruby
29
- "rb": "ruby",
30
- # Java
31
- "java": "java",
32
- # PHP
33
- "php": "php",
34
- # Swift
35
- "swift": "swift",
36
- # Kotlin
37
- "kt": "kotlin",
38
- }
39
-
40
-
41
- def detect_language(file_path: Path) -> SupportedLanguage:
42
- """Detect the language of a file."""
43
- suffix = file_path.suffix.removeprefix(".").lower()
44
- msg = f"Unsupported language for file suffix: {suffix}"
45
- lang = LANGUAGE_MAP.get(suffix)
46
- if lang is None:
47
- raise ValueError(msg)
48
-
49
- # Try to cast the language to a SupportedLanguage
50
- try:
51
- return cast("SupportedLanguage", lang)
52
- except Exception as e:
53
- raise ValueError(msg) from e
kodit/snippets/snippets.py DELETED
@@ -1,50 +0,0 @@
1
- """Generate snippets from a file."""
2
-
3
- from dataclasses import dataclass
4
- from pathlib import Path
5
-
6
- from kodit.snippets.languages import detect_language
7
- from kodit.snippets.method_snippets import MethodSnippets
8
-
9
-
10
- @dataclass
11
- class Snippet:
12
- """A snippet of code."""
13
-
14
- text: str
15
-
16
-
17
- class SnippetService:
18
- """Factory for generating snippets from a file.
19
-
20
- This is required because there's going to be multiple ways to generate snippets.
21
- """
22
-
23
- def __init__(self) -> None:
24
- """Initialize the snippet factory."""
25
- self.language_dir = Path(__file__).parent / "languages"
26
-
27
- def snippets_for_file(self, file_path: Path) -> list[Snippet]:
28
- """Generate snippets from a file."""
29
- language = detect_language(file_path)
30
-
31
- try:
32
- query_path = self.language_dir / f"{language}.scm"
33
- with query_path.open() as f:
34
- query = f.read()
35
- except Exception as e:
36
- msg = f"Unsupported language: {file_path}"
37
- raise ValueError(msg) from e
38
-
39
- method_analser = MethodSnippets(language, query)
40
-
41
- try:
42
- file_bytes = file_path.read_bytes()
43
- except Exception as e:
44
- msg = f"Failed to read file: {file_path}"
45
- raise ValueError(msg) from e
46
-
47
- method_snippets = method_analser.extract(file_bytes)
48
- all_snippets = [Snippet(text=snippet) for snippet in method_snippets]
49
- # Remove any snippets that are empty
50
- return [snippet for snippet in all_snippets if snippet.text.strip()]
kodit/source/__init__.py DELETED
@@ -1 +0,0 @@
1
- """Sources package for managing code source repositories and local directories."""
kodit/source/source_factories.py DELETED
@@ -1,356 +0,0 @@
1
- """Source factories for creating different types of sources.
2
-
3
- This module provides factory classes for creating sources, improving cohesion by
4
- separating the concerns of different source types.
5
- """
6
-
7
- import mimetypes
8
- import shutil
9
- import tempfile
10
- from abc import ABC, abstractmethod
11
- from datetime import UTC, datetime
12
- from hashlib import sha256
13
- from pathlib import Path
14
- from typing import Protocol
15
-
16
- import aiofiles
17
- import git
18
- import structlog
19
- from tqdm import tqdm
20
-
21
- from kodit.source.ignore import IgnorePatterns
22
- from kodit.source.source_models import (
23
- Author,
24
- AuthorFileMapping,
25
- File,
26
- Source,
27
- SourceType,
28
- )
29
- from kodit.source.source_repository import SourceRepository
30
-
31
-
32
- class WorkingCopyProvider(Protocol):
33
- """Protocol for providing working copies of sources."""
34
-
35
- async def prepare(self, uri: str) -> Path:
36
- """Prepare a working copy and return its path."""
37
- ...
38
-
39
-
40
- class FileMetadataExtractor(Protocol):
41
- """Protocol for extracting file metadata."""
42
-
43
- async def extract(self, path: Path, source: Source) -> File:
44
- """Extract metadata from a file."""
45
- ...
46
-
47
-
48
- class AuthorExtractor(Protocol):
49
- """Protocol for extracting author information."""
50
-
51
- async def extract(self, path: Path, source: Source) -> list[Author]:
52
- """Extract authors for a file."""
53
- ...
54
-
55
-
56
- class SourceFactory(ABC):
57
- """Abstract base class for source factories."""
58
-
59
- def __init__(
60
- self,
61
- working_copy: WorkingCopyProvider,
62
- metadata_extractor: FileMetadataExtractor,
63
- author_extractor: AuthorExtractor,
64
- repository: SourceRepository,
65
- ) -> None:
66
- """Initialize the source factory."""
67
- self.working_copy = working_copy
68
- self.metadata_extractor = metadata_extractor
69
- self.author_extractor = author_extractor
70
- self.repository = repository
71
- self.log = structlog.get_logger(__name__)
72
-
73
- @abstractmethod
74
- async def create(self, uri: str) -> Source:
75
- """Create a source from a URI."""
76
- ...
77
-
78
- async def _process_files(self, source: Source, files: list[Path]) -> None:
79
- """Process files for a source."""
80
- for path in tqdm(files, total=len(files), leave=False):
81
- if not path.is_file():
82
- continue
83
-
84
- # Extract file metadata
85
- file_record = await self.metadata_extractor.extract(path, source)
86
- await self.repository.create_file(file_record)
87
-
88
- # Extract authors
89
- authors = await self.author_extractor.extract(path, source)
90
- for author in authors:
91
- await self.repository.upsert_author_file_mapping(
92
- AuthorFileMapping(
93
- author_id=author.id,
94
- file_id=file_record.id,
95
- )
96
- )
97
-
98
-
99
- class GitSourceFactory(SourceFactory):
100
- """Factory for creating Git sources."""
101
-
102
- async def create(self, uri: str) -> Source:
103
- """Create a git source from a URI."""
104
- # Normalize the URI
105
- self.log.debug("Normalising git uri", uri=uri)
106
- with tempfile.TemporaryDirectory() as temp_dir:
107
- git.Repo.clone_from(uri, temp_dir)
108
- remote = git.Repo(temp_dir).remote()
109
- uri = remote.url
110
-
111
- # Check if source already exists
112
- self.log.debug("Checking if source already exists", uri=uri)
113
- source = await self.repository.get_source_by_uri(uri)
114
-
115
- if source:
116
- self.log.info("Source already exists, reusing...", source_id=source.id)
117
- return source
118
-
119
- # Prepare working copy
120
- clone_path = await self.working_copy.prepare(uri)
121
-
122
- # Create source record
123
- self.log.debug("Creating source", uri=uri, clone_path=str(clone_path))
124
- source = await self.repository.create_source(
125
- Source(
126
- uri=uri,
127
- cloned_path=str(clone_path),
128
- source_type=SourceType.GIT,
129
- )
130
- )
131
-
132
- # Get files to process using ignore patterns
133
- ignore_patterns = IgnorePatterns(clone_path)
134
- files = [
135
- f
136
- for f in clone_path.rglob("*")
137
- if f.is_file() and not ignore_patterns.should_ignore(f)
138
- ]
139
-
140
- # Process files
141
- self.log.info("Inspecting files", source_id=source.id, num_files=len(files))
142
- await self._process_files(source, files)
143
-
144
- return source
145
-
146
-
147
- class FolderSourceFactory(SourceFactory):
148
- """Factory for creating folder sources."""
149
-
150
- async def create(self, uri: str) -> Source:
151
- """Create a folder source from a path."""
152
- directory = Path(uri).expanduser().resolve()
153
-
154
- # Check if source already exists
155
- source = await self.repository.get_source_by_uri(directory.as_uri())
156
- if source:
157
- self.log.info("Source already exists, reusing...", source_id=source.id)
158
- return source
159
-
160
- # Validate directory exists
161
- if not directory.exists():
162
- msg = f"Folder does not exist: {directory}"
163
- raise ValueError(msg)
164
-
165
- # Prepare working copy
166
- clone_path = await self.working_copy.prepare(directory.as_uri())
167
-
168
- # Create source record
169
- source = await self.repository.create_source(
170
- Source(
171
- uri=directory.as_uri(),
172
- cloned_path=str(clone_path),
173
- source_type=SourceType.FOLDER,
174
- )
175
- )
176
-
177
- # Get all files to process
178
- files = [f for f in clone_path.rglob("*") if f.is_file()]
179
-
180
- # Process files
181
- await self._process_files(source, files)
182
-
183
- return source
184
-
185
-
186
- class GitWorkingCopyProvider:
187
- """Working copy provider for Git repositories."""
188
-
189
- def __init__(self, clone_dir: Path) -> None:
190
- """Initialize the provider."""
191
- self.clone_dir = clone_dir
192
- self.log = structlog.get_logger(__name__)
193
-
194
- async def prepare(self, uri: str) -> Path:
195
- """Prepare a Git working copy."""
196
- # Create a unique directory name for the clone
197
- clone_path = self.clone_dir / uri.replace("/", "_").replace(":", "_")
198
- clone_path.mkdir(parents=True, exist_ok=True)
199
-
200
- try:
201
- self.log.info("Cloning repository", uri=uri, clone_path=str(clone_path))
202
- git.Repo.clone_from(uri, clone_path)
203
- except git.GitCommandError as e:
204
- if "already exists and is not an empty directory" not in str(e):
205
- msg = f"Failed to clone repository: {e}"
206
- raise ValueError(msg) from e
207
- self.log.info("Repository already exists, reusing...", uri=uri)
208
-
209
- return clone_path
210
-
211
-
212
- class FolderWorkingCopyProvider:
213
- """Working copy provider for local folders."""
214
-
215
- def __init__(self, clone_dir: Path) -> None:
216
- """Initialize the provider."""
217
- self.clone_dir = clone_dir
218
-
219
- async def prepare(self, uri: str) -> Path:
220
- """Prepare a folder working copy."""
221
- # Handle file:// URIs
222
- if uri.startswith("file://"):
223
- from urllib.parse import urlparse
224
-
225
- parsed = urlparse(uri)
226
- directory = Path(parsed.path).expanduser().resolve()
227
- else:
228
- directory = Path(uri).expanduser().resolve()
229
-
230
- # Clone into a local directory
231
- clone_path = self.clone_dir / directory.as_posix().replace("/", "_")
232
- clone_path.mkdir(parents=True, exist_ok=True)
233
-
234
- # Copy all files recursively, preserving directory structure, ignoring
235
- # hidden files
236
- shutil.copytree(
237
- directory,
238
- clone_path,
239
- ignore=shutil.ignore_patterns(".*"),
240
- dirs_exist_ok=True,
241
- )
242
-
243
- return clone_path
244
-
245
-
246
- class BaseFileMetadataExtractor:
247
- """Base class for file metadata extraction with common functionality."""
248
-
249
- async def extract(self, path: Path, source: Source) -> File:
250
- """Extract metadata from a file."""
251
- # Get timestamps - to be implemented by subclasses
252
- created_at, updated_at = await self._get_timestamps(path, source)
253
-
254
- # Read file content and calculate metadata
255
- async with aiofiles.open(path, "rb") as f:
256
- content = await f.read()
257
- mime_type = mimetypes.guess_type(path)
258
- sha = sha256(content).hexdigest()
259
-
260
- return File(
261
- created_at=created_at,
262
- updated_at=updated_at,
263
- source_id=source.id,
264
- cloned_path=str(path),
265
- mime_type=mime_type[0]
266
- if mime_type and mime_type[0]
267
- else "application/octet-stream",
268
- uri=path.as_uri(),
269
- sha256=sha,
270
- size_bytes=len(content),
271
- )
272
-
273
- async def _get_timestamps(
274
- self, path: Path, source: Source
275
- ) -> tuple[datetime, datetime]:
276
- """Get creation and modification timestamps. To be implemented by subclasses."""
277
- raise NotImplementedError
278
-
279
-
280
- class GitFileMetadataExtractor(BaseFileMetadataExtractor):
281
- """Git-specific implementation for extracting file metadata."""
282
-
283
- async def _get_timestamps(
284
- self, path: Path, source: Source
285
- ) -> tuple[datetime, datetime]:
286
- """Get timestamps from Git history."""
287
- git_repo = git.Repo(source.cloned_path)
288
- commits = list(git_repo.iter_commits(paths=str(path), all=True))
289
-
290
- if commits:
291
- last_modified_at = commits[0].committed_datetime
292
- first_modified_at = commits[-1].committed_datetime
293
- return first_modified_at, last_modified_at
294
- # Fallback to current time if no commits found
295
- now = datetime.now(UTC)
296
- return now, now
297
-
298
-
299
- class FolderFileMetadataExtractor(BaseFileMetadataExtractor):
300
- """Folder-specific implementation for extracting file metadata."""
301
-
302
- async def _get_timestamps(
303
- self,
304
- path: Path,
305
- source: Source, # noqa: ARG002
306
- ) -> tuple[datetime, datetime]:
307
- """Get timestamps from file system."""
308
- stat = path.stat()
309
- file_created_at = datetime.fromtimestamp(stat.st_ctime, UTC)
310
- file_modified_at = datetime.fromtimestamp(stat.st_mtime, UTC)
311
- return file_created_at, file_modified_at
312
-
313
-
314
- class GitAuthorExtractor:
315
- """Author extractor for Git repositories."""
316
-
317
- def __init__(self, repository: SourceRepository) -> None:
318
- """Initialize the extractor."""
319
- self.repository = repository
320
-
321
- async def extract(self, path: Path, source: Source) -> list[Author]:
322
- """Extract authors from a Git file."""
323
- authors: list[Author] = []
324
- git_repo = git.Repo(source.cloned_path)
325
-
326
- try:
327
- # Get the file's blame
328
- blames = git_repo.blame("HEAD", str(path))
329
-
330
- # Extract the blame's authors
331
- actors = [
332
- commit.author
333
- for blame in blames or []
334
- for commit in blame
335
- if isinstance(commit, git.Commit)
336
- ]
337
-
338
- # Get or create the authors in the database
339
- for actor in actors:
340
- if actor.email:
341
- author = Author.from_actor(actor)
342
- author = await self.repository.upsert_author(author)
343
- authors.append(author)
344
- except git.GitCommandError:
345
- # Handle cases where file might not be tracked
346
- pass
347
-
348
- return authors
349
-
350
-
351
- class NoOpAuthorExtractor:
352
- """No-op author extractor for sources that don't have author information."""
353
-
354
- async def extract(self, path: Path, source: Source) -> list[Author]: # noqa: ARG002
355
- """Return empty list of authors."""
356
- return []
kodit/source/source_repository.py DELETED
@@ -1,169 +0,0 @@
1
- """Source repository for database operations."""
2
-
3
- from sqlalchemy import func, select
4
- from sqlalchemy.ext.asyncio import AsyncSession
5
-
6
- from kodit.source.source_models import (
7
- Author,
8
- AuthorFileMapping,
9
- File,
10
- Source,
11
- SourceType,
12
- )
13
-
14
-
15
- class SourceRepository:
16
- """Repository for managing source database operations.
17
-
18
- This class provides methods for creating and retrieving source records from the
19
- database. It handles the low-level database operations and transaction management.
20
-
21
- Args:
22
- session: The SQLAlchemy async session to use for database operations.
23
-
24
- """
25
-
26
- def __init__(self, session: AsyncSession) -> None:
27
- """Initialize the source repository."""
28
- self.session = session
29
-
30
- async def create_source(self, source: Source) -> Source:
31
- """Add a new source to the database."""
32
- # Validate the source
33
- if source.type == SourceType.UNKNOWN:
34
- msg = "Source type is required"
35
- raise ValueError(msg)
36
-
37
- self.session.add(source)
38
- await self.session.commit()
39
- return source
40
-
41
- async def create_file(self, file: File) -> File:
42
- """Create a new file record in the database.
43
-
44
- This method creates a new File record and adds it to the session.
45
-
46
- """
47
- self.session.add(file)
48
- await self.session.commit()
49
- return file
50
-
51
- async def list_files_for_source(self, source_id: int) -> list[File]:
52
- """List all files for a source."""
53
- query = select(File).where(File.source_id == source_id)
54
- result = await self.session.execute(query)
55
- return list(result.scalars())
56
-
57
- async def num_files_for_source(self, source_id: int) -> int:
58
- """Get the number of files for a source.
59
-
60
- Args:
61
- source_id: The ID of the source to get the number of files for.
62
-
63
- Returns:
64
- The number of files for the source.
65
-
66
- """
67
- query = (
68
- select(func.count()).select_from(File).where(File.source_id == source_id)
69
- )
70
- result = await self.session.execute(query)
71
- return result.scalar_one()
72
-
73
- async def list_sources(self) -> list[Source]:
74
- """Retrieve all sources from the database.
75
-
76
- Returns:
77
- A list of Source instances.
78
-
79
- """
80
- query = select(Source).limit(10)
81
- result = await self.session.execute(query)
82
- return list(result.scalars())
83
-
84
- async def get_source_by_uri(self, uri: str) -> Source | None:
85
- """Get a source by its URI.
86
-
87
- Args:
88
- uri: The URI of the source to get.
89
-
90
- Returns:
91
- The source with the given URI, or None if it does not exist.
92
-
93
- """
94
- query = select(Source).where(Source.uri == uri)
95
- result = await self.session.execute(query)
96
- return result.scalar_one_or_none()
97
-
98
- async def get_source_by_id(self, source_id: int) -> Source | None:
99
- """Get a source by its ID.
100
-
101
- Args:
102
- source_id: The ID of the source to get.
103
-
104
- """
105
- query = select(Source).where(Source.id == source_id)
106
- result = await self.session.execute(query)
107
- return result.scalar_one_or_none()
108
-
109
- async def get_author_by_email(self, email: str) -> Author | None:
110
- """Get an author by email."""
111
- query = select(Author).where(Author.email == email)
112
- result = await self.session.execute(query)
113
- return result.scalar_one_or_none()
114
-
115
- async def upsert_author(self, author: Author) -> Author:
116
- """Create a new author or return existing one if email already exists.
117
-
118
- Args:
119
- author: The Author instance to upsert.
120
-
121
- Returns:
122
- The existing Author if one with the same email exists, otherwise the newly
123
- created Author.
124
-
125
- """
126
- # First check if author already exists with same name and email
127
- query = select(Author).where(
128
- Author.name == author.name, Author.email == author.email
129
- )
130
- result = await self.session.execute(query)
131
- existing_author = result.scalar_one_or_none()
132
-
133
- if existing_author:
134
- return existing_author
135
-
136
- # Author doesn't exist, create new one
137
- self.session.add(author)
138
- await self.session.commit()
139
- return author
140
-
141
- async def upsert_author_file_mapping(
142
- self, mapping: AuthorFileMapping
143
- ) -> AuthorFileMapping:
144
- """Create a new author file mapping or return existing one if already exists."""
145
- # First check if mapping already exists with same author_id and file_id
146
- query = select(AuthorFileMapping).where(
147
- AuthorFileMapping.author_id == mapping.author_id,
148
- AuthorFileMapping.file_id == mapping.file_id,
149
- )
150
- result = await self.session.execute(query)
151
- existing_mapping = result.scalar_one_or_none()
152
-
153
- if existing_mapping:
154
- return existing_mapping
155
-
156
- # Mapping doesn't exist, create new one
157
- self.session.add(mapping)
158
- await self.session.commit()
159
- return mapping
160
-
161
- async def list_files_for_author(self, author_id: int) -> list[File]:
162
- """List all files for an author."""
163
- query = (
164
- select(File)
165
- .join(AuthorFileMapping)
166
- .where(AuthorFileMapping.author_id == author_id)
167
- )
168
- result = await self.session.execute(query)
169
- return list(result.scalars())