kodit 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- kodit/_version.py +2 -2
- kodit/app.py +6 -0
- kodit/bm25/local_bm25.py +8 -0
- kodit/bm25/vectorchord_bm25.py +4 -1
- kodit/cli.py +8 -2
- kodit/config.py +14 -24
- kodit/embedding/embedding_factory.py +25 -6
- kodit/embedding/embedding_provider/embedding_provider.py +2 -2
- kodit/embedding/embedding_provider/openai_embedding_provider.py +3 -1
- kodit/embedding/local_vector_search_service.py +4 -0
- kodit/embedding/vectorchord_vector_search_service.py +10 -2
- kodit/enrichment/enrichment_factory.py +26 -7
- kodit/enrichment/enrichment_provider/local_enrichment_provider.py +4 -0
- kodit/enrichment/enrichment_provider/openai_enrichment_provider.py +5 -1
- kodit/indexing/indexing_service.py +28 -3
- kodit/log.py +126 -24
- kodit/migrations/versions/9e53ea8bb3b0_add_authors.py +103 -0
- kodit/source/git.py +16 -0
- kodit/source/ignore.py +53 -0
- kodit/source/source_factories.py +356 -0
- kodit/source/source_models.py +52 -2
- kodit/source/source_repository.py +80 -16
- kodit/source/source_service.py +45 -155
- {kodit-0.2.1.dist-info → kodit-0.2.3.dist-info}/METADATA +4 -2
- {kodit-0.2.1.dist-info → kodit-0.2.3.dist-info}/RECORD +28 -24
- {kodit-0.2.1.dist-info → kodit-0.2.3.dist-info}/WHEEL +0 -0
- {kodit-0.2.1.dist-info → kodit-0.2.3.dist-info}/entry_points.txt +0 -0
- {kodit-0.2.1.dist-info → kodit-0.2.3.dist-info}/licenses/LICENSE +0 -0
kodit/source/source_factories.py
ADDED
@@ -0,0 +1,356 @@
+"""Source factories for creating different types of sources.
+
+This module provides factory classes for creating sources, improving cohesion by
+separating the concerns of different source types.
+"""
+
+import mimetypes
+import shutil
+import tempfile
+from abc import ABC, abstractmethod
+from datetime import UTC, datetime
+from hashlib import sha256
+from pathlib import Path
+from typing import Protocol
+
+import aiofiles
+import git
+import structlog
+from tqdm import tqdm
+
+from kodit.source.ignore import IgnorePatterns
+from kodit.source.source_models import (
+    Author,
+    AuthorFileMapping,
+    File,
+    Source,
+    SourceType,
+)
+from kodit.source.source_repository import SourceRepository
+
+
+class WorkingCopyProvider(Protocol):
+    """Protocol for providing working copies of sources."""
+
+    async def prepare(self, uri: str) -> Path:
+        """Prepare a working copy and return its path."""
+        ...
+
+
+class FileMetadataExtractor(Protocol):
+    """Protocol for extracting file metadata."""
+
+    async def extract(self, path: Path, source: Source) -> File:
+        """Extract metadata from a file."""
+        ...
+
+
+class AuthorExtractor(Protocol):
+    """Protocol for extracting author information."""
+
+    async def extract(self, path: Path, source: Source) -> list[Author]:
+        """Extract authors for a file."""
+        ...
+
+
+class SourceFactory(ABC):
+    """Abstract base class for source factories."""
+
+    def __init__(
+        self,
+        working_copy: WorkingCopyProvider,
+        metadata_extractor: FileMetadataExtractor,
+        author_extractor: AuthorExtractor,
+        repository: SourceRepository,
+    ) -> None:
+        """Initialize the source factory."""
+        self.working_copy = working_copy
+        self.metadata_extractor = metadata_extractor
+        self.author_extractor = author_extractor
+        self.repository = repository
+        self.log = structlog.get_logger(__name__)
+
+    @abstractmethod
+    async def create(self, uri: str) -> Source:
+        """Create a source from a URI."""
+        ...
+
+    async def _process_files(self, source: Source, files: list[Path]) -> None:
+        """Process files for a source."""
+        for path in tqdm(files, total=len(files), leave=False):
+            if not path.is_file():
+                continue
+
+            # Extract file metadata
+            file_record = await self.metadata_extractor.extract(path, source)
+            await self.repository.create_file(file_record)
+
+            # Extract authors
+            authors = await self.author_extractor.extract(path, source)
+            for author in authors:
+                await self.repository.upsert_author_file_mapping(
+                    AuthorFileMapping(
+                        author_id=author.id,
+                        file_id=file_record.id,
+                    )
+                )
+
+
+class GitSourceFactory(SourceFactory):
+    """Factory for creating Git sources."""
+
+    async def create(self, uri: str) -> Source:
+        """Create a git source from a URI."""
+        # Normalize the URI
+        self.log.debug("Normalising git uri", uri=uri)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            git.Repo.clone_from(uri, temp_dir)
+            remote = git.Repo(temp_dir).remote()
+            uri = remote.url
+
+        # Check if source already exists
+        self.log.debug("Checking if source already exists", uri=uri)
+        source = await self.repository.get_source_by_uri(uri)
+
+        if source:
+            self.log.info("Source already exists, reusing...", source_id=source.id)
+            return source
+
+        # Prepare working copy
+        clone_path = await self.working_copy.prepare(uri)
+
+        # Create source record
+        self.log.debug("Creating source", uri=uri, clone_path=str(clone_path))
+        source = await self.repository.create_source(
+            Source(
+                uri=uri,
+                cloned_path=str(clone_path),
+                source_type=SourceType.GIT,
+            )
+        )
+
+        # Get files to process using ignore patterns
+        ignore_patterns = IgnorePatterns(clone_path)
+        files = [
+            f
+            for f in clone_path.rglob("*")
+            if f.is_file() and not ignore_patterns.should_ignore(f)
+        ]
+
+        # Process files
+        self.log.info("Inspecting files", source_id=source.id, num_files=len(files))
+        await self._process_files(source, files)
+
+        return source
+
+
+class FolderSourceFactory(SourceFactory):
+    """Factory for creating folder sources."""
+
+    async def create(self, uri: str) -> Source:
+        """Create a folder source from a path."""
+        directory = Path(uri).expanduser().resolve()
+
+        # Check if source already exists
+        source = await self.repository.get_source_by_uri(directory.as_uri())
+        if source:
+            self.log.info("Source already exists, reusing...", source_id=source.id)
+            return source
+
+        # Validate directory exists
+        if not directory.exists():
+            msg = f"Folder does not exist: {directory}"
+            raise ValueError(msg)
+
+        # Prepare working copy
+        clone_path = await self.working_copy.prepare(directory.as_uri())
+
+        # Create source record
+        source = await self.repository.create_source(
+            Source(
+                uri=directory.as_uri(),
+                cloned_path=str(clone_path),
+                source_type=SourceType.FOLDER,
+            )
+        )
+
+        # Get all files to process
+        files = [f for f in clone_path.rglob("*") if f.is_file()]
+
+        # Process files
+        await self._process_files(source, files)
+
+        return source
+
+
+class GitWorkingCopyProvider:
+    """Working copy provider for Git repositories."""
+
+    def __init__(self, clone_dir: Path) -> None:
+        """Initialize the provider."""
+        self.clone_dir = clone_dir
+        self.log = structlog.get_logger(__name__)
+
+    async def prepare(self, uri: str) -> Path:
+        """Prepare a Git working copy."""
+        # Create a unique directory name for the clone
+        clone_path = self.clone_dir / uri.replace("/", "_").replace(":", "_")
+        clone_path.mkdir(parents=True, exist_ok=True)
+
+        try:
+            self.log.info("Cloning repository", uri=uri, clone_path=str(clone_path))
+            git.Repo.clone_from(uri, clone_path)
+        except git.GitCommandError as e:
+            if "already exists and is not an empty directory" not in str(e):
+                msg = f"Failed to clone repository: {e}"
+                raise ValueError(msg) from e
+            self.log.info("Repository already exists, reusing...", uri=uri)
+
+        return clone_path
+
+
+class FolderWorkingCopyProvider:
+    """Working copy provider for local folders."""
+
+    def __init__(self, clone_dir: Path) -> None:
+        """Initialize the provider."""
+        self.clone_dir = clone_dir
+
+    async def prepare(self, uri: str) -> Path:
+        """Prepare a folder working copy."""
+        # Handle file:// URIs
+        if uri.startswith("file://"):
+            from urllib.parse import urlparse
+
+            parsed = urlparse(uri)
+            directory = Path(parsed.path).expanduser().resolve()
+        else:
+            directory = Path(uri).expanduser().resolve()
+
+        # Clone into a local directory
+        clone_path = self.clone_dir / directory.as_posix().replace("/", "_")
+        clone_path.mkdir(parents=True, exist_ok=True)
+
+        # Copy all files recursively, preserving directory structure, ignoring
+        # hidden files
+        shutil.copytree(
+            directory,
+            clone_path,
+            ignore=shutil.ignore_patterns(".*"),
+            dirs_exist_ok=True,
+        )
+
+        return clone_path
+
+
+class BaseFileMetadataExtractor:
+    """Base class for file metadata extraction with common functionality."""
+
+    async def extract(self, path: Path, source: Source) -> File:
+        """Extract metadata from a file."""
+        # Get timestamps - to be implemented by subclasses
+        created_at, updated_at = await self._get_timestamps(path, source)
+
+        # Read file content and calculate metadata
+        async with aiofiles.open(path, "rb") as f:
+            content = await f.read()
+            mime_type = mimetypes.guess_type(path)
+            sha = sha256(content).hexdigest()
+
+        return File(
+            created_at=created_at,
+            updated_at=updated_at,
+            source_id=source.id,
+            cloned_path=str(path),
+            mime_type=mime_type[0]
+            if mime_type and mime_type[0]
+            else "application/octet-stream",
+            uri=path.as_uri(),
+            sha256=sha,
+            size_bytes=len(content),
+        )
+
+    async def _get_timestamps(
+        self, path: Path, source: Source
+    ) -> tuple[datetime, datetime]:
+        """Get creation and modification timestamps. To be implemented by subclasses."""
+        raise NotImplementedError
+
+
+class GitFileMetadataExtractor(BaseFileMetadataExtractor):
+    """Git-specific implementation for extracting file metadata."""
+
+    async def _get_timestamps(
+        self, path: Path, source: Source
+    ) -> tuple[datetime, datetime]:
+        """Get timestamps from Git history."""
+        git_repo = git.Repo(source.cloned_path)
+        commits = list(git_repo.iter_commits(paths=str(path), all=True))
+
+        if commits:
+            last_modified_at = commits[0].committed_datetime
+            first_modified_at = commits[-1].committed_datetime
+            return first_modified_at, last_modified_at
+        # Fallback to current time if no commits found
+        now = datetime.now(UTC)
+        return now, now
+
+
+class FolderFileMetadataExtractor(BaseFileMetadataExtractor):
+    """Folder-specific implementation for extracting file metadata."""
+
+    async def _get_timestamps(
+        self,
+        path: Path,
+        source: Source,  # noqa: ARG002
+    ) -> tuple[datetime, datetime]:
+        """Get timestamps from file system."""
+        stat = path.stat()
+        file_created_at = datetime.fromtimestamp(stat.st_ctime, UTC)
+        file_modified_at = datetime.fromtimestamp(stat.st_mtime, UTC)
+        return file_created_at, file_modified_at
+
+
+class GitAuthorExtractor:
+    """Author extractor for Git repositories."""
+
+    def __init__(self, repository: SourceRepository) -> None:
+        """Initialize the extractor."""
+        self.repository = repository
+
+    async def extract(self, path: Path, source: Source) -> list[Author]:
+        """Extract authors from a Git file."""
+        authors: list[Author] = []
+        git_repo = git.Repo(source.cloned_path)
+
+        try:
+            # Get the file's blame
+            blames = git_repo.blame("HEAD", str(path))
+
+            # Extract the blame's authors
+            actors = [
+                commit.author
+                for blame in blames or []
+                for commit in blame
+                if isinstance(commit, git.Commit)
+            ]
+
+            # Get or create the authors in the database
+            for actor in actors:
+                if actor.email:
+                    author = Author.from_actor(actor)
+                    author = await self.repository.upsert_author(author)
+                    authors.append(author)
+        except git.GitCommandError:
+            # Handle cases where file might not be tracked
+            pass
+
+        return authors
+
+
+class NoOpAuthorExtractor:
+    """No-op author extractor for sources that don't have author information."""
+
+    async def extract(self, path: Path, source: Source) -> list[Author]:  # noqa: ARG002
+        """Return empty list of authors."""
+        return []
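
To show how these pieces are meant to compose, here is a minimal wiring sketch based only on the constructor and method signatures in the new module above. The function name, clone directory, and the idea of assembling the factory by hand are assumptions for illustration, not code shipped in kodit.

from pathlib import Path

from kodit.source.source_factories import (
    GitAuthorExtractor,
    GitFileMetadataExtractor,
    GitSourceFactory,
    GitWorkingCopyProvider,
)
from kodit.source.source_repository import SourceRepository


async def index_git_repo(repository: SourceRepository, uri: str):
    """Hypothetical helper: clone a repo and record its files and authors."""
    clone_dir = Path.home() / ".kodit" / "clones"  # assumed location
    factory = GitSourceFactory(
        working_copy=GitWorkingCopyProvider(clone_dir),
        metadata_extractor=GitFileMetadataExtractor(),
        author_extractor=GitAuthorExtractor(repository),
        repository=repository,
    )
    # create() normalises the URI, reuses an existing Source if present,
    # clones the working copy, and runs the metadata and author extractors.
    return await factory.create(uri)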
kodit/source/source_models.py
CHANGED
@@ -5,7 +5,11 @@ It includes models for tracking different types of sources (git repositories and
 folders) and their relationships.
 """
 
-
+import datetime
+from enum import Enum as EnumType
+
+from git import Actor
+from sqlalchemy import Enum, ForeignKey, Integer, String, UniqueConstraint
 from sqlalchemy.orm import Mapped, mapped_column
 
 from kodit.database import Base, CommonMixin
@@ -14,6 +18,14 @@ from kodit.database import Base, CommonMixin
 __all__ = ["File", "Source"]
 
 
+class SourceType(EnumType):
+    """The type of source."""
+
+    UNKNOWN = 0
+    FOLDER = 1
+    GIT = 2
+
+
 class Source(Base, CommonMixin):
     """Base model for tracking code sources.
 
@@ -32,12 +44,45 @@ class Source(Base, CommonMixin):
     __tablename__ = "sources"
     uri: Mapped[str] = mapped_column(String(1024), index=True, unique=True)
     cloned_path: Mapped[str] = mapped_column(String(1024), index=True)
+    type: Mapped[SourceType] = mapped_column(
+        Enum(SourceType), default=SourceType.UNKNOWN, index=True
+    )
 
-    def __init__(self, uri: str, cloned_path: str) -> None:
+    def __init__(self, uri: str, cloned_path: str, source_type: SourceType) -> None:
         """Initialize a new Source instance for typing purposes."""
         super().__init__()
         self.uri = uri
         self.cloned_path = cloned_path
+        self.type = source_type
+
+
+class Author(Base, CommonMixin):
+    """Author model."""
+
+    __tablename__ = "authors"
+
+    __table_args__ = (UniqueConstraint("name", "email", name="uix_author"),)
+
+    name: Mapped[str] = mapped_column(String(255), index=True)
+    email: Mapped[str] = mapped_column(String(255), index=True)
+
+    @staticmethod
+    def from_actor(actor: Actor) -> "Author":
+        """Create an Author from an Actor."""
+        return Author(name=actor.name, email=actor.email)
+
+
+class AuthorFileMapping(Base, CommonMixin):
+    """Author file mapping model."""
+
+    __tablename__ = "author_file_mappings"
+
+    __table_args__ = (
+        UniqueConstraint("author_id", "file_id", name="uix_author_file_mapping"),
+    )
+
+    author_id: Mapped[int] = mapped_column(ForeignKey("authors.id"), index=True)
+    file_id: Mapped[int] = mapped_column(ForeignKey("files.id"), index=True)
 
 
 class File(Base, CommonMixin):
@@ -51,9 +96,12 @@ class File(Base, CommonMixin):
     cloned_path: Mapped[str] = mapped_column(String(1024), index=True)
     sha256: Mapped[str] = mapped_column(String(64), default="", index=True)
     size_bytes: Mapped[int] = mapped_column(Integer, default=0)
+    extension: Mapped[str] = mapped_column(String(255), default="", index=True)
 
     def __init__(  # noqa: PLR0913
         self,
+        created_at: datetime.datetime,
+        updated_at: datetime.datetime,
         source_id: int,
        cloned_path: str,
         mime_type: str = "",
@@ -63,6 +111,8 @@ class File(Base, CommonMixin):
     ) -> None:
         """Initialize a new File instance for typing purposes."""
         super().__init__()
+        self.created_at = created_at
+        self.updated_at = updated_at
         self.source_id = source_id
         self.cloned_path = cloned_path
         self.mime_type = mime_type
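
A brief, hypothetical illustration of the new models and the changed Source constructor; the URI, path, and author values below are made up.

from git import Actor

from kodit.source.source_models import Author, Source, SourceType

# The Source constructor now requires an explicit source type.
source = Source(
    uri="https://example.com/repo.git",
    cloned_path="/tmp/kodit/example",
    source_type=SourceType.GIT,
)

# Authors can be built straight from GitPython Actor objects.
author = Author.from_actor(Actor("Jane Doe", "jane@example.com"))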
kodit/source/source_repository.py
CHANGED
@@ -3,7 +3,13 @@
 from sqlalchemy import func, select
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from kodit.source.source_models import
+from kodit.source.source_models import (
+    Author,
+    AuthorFileMapping,
+    File,
+    Source,
+    SourceType,
+)
 
 
 class SourceRepository:
@@ -22,22 +28,12 @@ class SourceRepository:
         self.session = session
 
     async def create_source(self, source: Source) -> Source:
-        """
+        """Add a new source to the database."""
+        # Validate the source
+        if source.type == SourceType.UNKNOWN:
+            msg = "Source type is required"
+            raise ValueError(msg)
 
-        This method creates both a Source record and a linked FolderSource record
-        in a single transaction.
-
-        Args:
-            path: The absolute path of the folder to create a source for.
-
-        Returns:
-            The created Source model instance.
-
-        Note:
-            This method commits the transaction to ensure the source.id is available
-            for creating the linked FolderSource record.
-
-        """
         self.session.add(source)
         await self.session.commit()
         return source
@@ -52,6 +48,12 @@ class SourceRepository:
         await self.session.commit()
         return file
 
+    async def list_files_for_source(self, source_id: int) -> list[File]:
+        """List all files for a source."""
+        query = select(File).where(File.source_id == source_id)
+        result = await self.session.execute(query)
+        return list(result.scalars())
+
     async def num_files_for_source(self, source_id: int) -> int:
         """Get the number of files for a source.
 
@@ -103,3 +105,65 @@ class SourceRepository:
         query = select(Source).where(Source.id == source_id)
         result = await self.session.execute(query)
         return result.scalar_one_or_none()
+
+    async def get_author_by_email(self, email: str) -> Author | None:
+        """Get an author by email."""
+        query = select(Author).where(Author.email == email)
+        result = await self.session.execute(query)
+        return result.scalar_one_or_none()
+
+    async def upsert_author(self, author: Author) -> Author:
+        """Create a new author or return existing one if email already exists.
+
+        Args:
+            author: The Author instance to upsert.
+
+        Returns:
+            The existing Author if one with the same email exists, otherwise the newly
+            created Author.
+
+        """
+        # First check if author already exists with same name and email
+        query = select(Author).where(
+            Author.name == author.name, Author.email == author.email
+        )
+        result = await self.session.execute(query)
+        existing_author = result.scalar_one_or_none()
+
+        if existing_author:
+            return existing_author
+
+        # Author doesn't exist, create new one
+        self.session.add(author)
+        await self.session.commit()
+        return author
+
+    async def upsert_author_file_mapping(
+        self, mapping: AuthorFileMapping
+    ) -> AuthorFileMapping:
+        """Create a new author file mapping or return existing one if already exists."""
+        # First check if mapping already exists with same author_id and file_id
+        query = select(AuthorFileMapping).where(
+            AuthorFileMapping.author_id == mapping.author_id,
+            AuthorFileMapping.file_id == mapping.file_id,
+        )
+        result = await self.session.execute(query)
+        existing_mapping = result.scalar_one_or_none()
+
+        if existing_mapping:
+            return existing_mapping
+
+        # Mapping doesn't exist, create new one
+        self.session.add(mapping)
+        await self.session.commit()
+        return mapping
+
+    async def list_files_for_author(self, author_id: int) -> list[File]:
+        """List all files for an author."""
+        query = (
+            select(File)
+            .join(AuthorFileMapping)
+            .where(AuthorFileMapping.author_id == author_id)
+        )
+        result = await self.session.execute(query)
+        return list(result.scalars())