kodit 0.2.8__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic.

Files changed (37)
  1. kodit/_version.py +2 -2
  2. kodit/app.py +36 -1
  3. kodit/application/factories/__init__.py +1 -0
  4. kodit/application/factories/code_indexing_factory.py +119 -0
  5. kodit/application/services/{indexing_application_service.py → code_indexing_application_service.py} +159 -198
  6. kodit/cli.py +214 -62
  7. kodit/config.py +40 -3
  8. kodit/domain/entities.py +7 -5
  9. kodit/domain/repositories.py +33 -0
  10. kodit/domain/services/bm25_service.py +14 -17
  11. kodit/domain/services/embedding_service.py +10 -14
  12. kodit/domain/services/snippet_service.py +198 -0
  13. kodit/domain/value_objects.py +301 -21
  14. kodit/infrastructure/bm25/local_bm25_repository.py +20 -12
  15. kodit/infrastructure/bm25/vectorchord_bm25_repository.py +31 -11
  16. kodit/infrastructure/cloning/metadata.py +1 -0
  17. kodit/infrastructure/embedding/embedding_providers/hash_embedding_provider.py +14 -25
  18. kodit/infrastructure/embedding/local_vector_search_repository.py +26 -38
  19. kodit/infrastructure/embedding/vectorchord_vector_search_repository.py +50 -35
  20. kodit/infrastructure/enrichment/enrichment_factory.py +1 -1
  21. kodit/infrastructure/indexing/auto_indexing_service.py +84 -0
  22. kodit/infrastructure/indexing/indexing_factory.py +8 -91
  23. kodit/infrastructure/indexing/snippet_domain_service_factory.py +37 -0
  24. kodit/infrastructure/snippet_extraction/languages/java.scm +12 -0
  25. kodit/infrastructure/snippet_extraction/snippet_extraction_factory.py +3 -31
  26. kodit/infrastructure/sqlalchemy/embedding_repository.py +14 -3
  27. kodit/infrastructure/sqlalchemy/snippet_repository.py +174 -2
  28. kodit/mcp.py +61 -49
  29. {kodit-0.2.8.dist-info → kodit-0.3.0.dist-info}/METADATA +1 -1
  30. {kodit-0.2.8.dist-info → kodit-0.3.0.dist-info}/RECORD +33 -31
  31. kodit/application/commands/__init__.py +0 -1
  32. kodit/application/commands/snippet_commands.py +0 -22
  33. kodit/application/services/snippet_application_service.py +0 -149
  34. kodit/infrastructure/enrichment/legacy_enrichment_models.py +0 -42
  35. {kodit-0.2.8.dist-info → kodit-0.3.0.dist-info}/WHEEL +0 -0
  36. {kodit-0.2.8.dist-info → kodit-0.3.0.dist-info}/entry_points.txt +0 -0
  37. {kodit-0.2.8.dist-info → kodit-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -7,10 +7,10 @@ from kodit.domain.entities import EmbeddingType
7
7
  from kodit.domain.value_objects import (
8
8
  EmbeddingRequest,
9
9
  EmbeddingResponse,
10
+ IndexRequest,
10
11
  IndexResult,
11
- VectorIndexRequest,
12
- VectorSearchQueryRequest,
13
- VectorSearchResult,
12
+ SearchRequest,
13
+ SearchResult,
14
14
  )
15
15
 
16
16
 
@@ -29,14 +29,12 @@ class VectorSearchRepository(ABC):
29
29
 
30
30
  @abstractmethod
31
31
  def index_documents(
32
- self, request: VectorIndexRequest
32
+ self, request: IndexRequest
33
33
  ) -> AsyncGenerator[list[IndexResult], None]:
34
34
  """Index documents for vector search."""
35
35
 
36
36
  @abstractmethod
37
- async def search(
38
- self, request: VectorSearchQueryRequest
39
- ) -> Sequence[VectorSearchResult]:
37
+ async def search(self, request: SearchRequest) -> Sequence[SearchResult]:
40
38
  """Search documents using vector similarity."""
41
39
 
42
40
  @abstractmethod
@@ -65,7 +63,7 @@ class EmbeddingDomainService:
65
63
  self.vector_search_repository = vector_search_repository
66
64
 
67
65
  async def index_documents(
68
- self, request: VectorIndexRequest
66
+ self, request: IndexRequest
69
67
  ) -> AsyncGenerator[list[IndexResult], None]:
70
68
  """Index documents using domain business rules.
71
69
 
@@ -94,15 +92,13 @@ class EmbeddingDomainService:
94
92
  return
95
93
 
96
94
  # Domain logic: create new request with validated documents
97
- validated_request = VectorIndexRequest(documents=valid_documents)
95
+ validated_request = IndexRequest(documents=valid_documents)
98
96
  async for result in self.vector_search_repository.index_documents(
99
97
  validated_request
100
98
  ):
101
99
  yield result
102
100
 
103
- async def search(
104
- self, request: VectorSearchQueryRequest
105
- ) -> Sequence[VectorSearchResult]:
101
+ async def search(self, request: SearchRequest) -> Sequence[SearchResult]:
106
102
  """Search documents using domain business rules.
107
103
 
108
104
  Args:
@@ -124,8 +120,8 @@ class EmbeddingDomainService:
124
120
 
125
121
  # Domain logic: normalize query
126
122
  normalized_query = request.query.strip()
127
- normalized_request = VectorSearchQueryRequest(
128
- query=normalized_query, top_k=request.top_k
123
+ normalized_request = SearchRequest(
124
+ query=normalized_query, top_k=request.top_k, snippet_ids=request.snippet_ids
129
125
  )
130
126
 
131
127
  return await self.vector_search_repository.search(normalized_request)
@@ -0,0 +1,198 @@
1
+ """Domain service for snippet operations."""
2
+
3
+ from pathlib import Path
4
+ from typing import Any
5
+
6
+ import structlog
7
+
8
+ from kodit.domain.entities import Snippet
9
+ from kodit.domain.enums import SnippetExtractionStrategy
10
+ from kodit.domain.interfaces import ProgressCallback
11
+ from kodit.domain.repositories import FileRepository, SnippetRepository
12
+ from kodit.domain.services.snippet_extraction_service import (
13
+ SnippetExtractionDomainService,
14
+ )
15
+ from kodit.domain.value_objects import (
16
+ MultiSearchRequest,
17
+ SnippetExtractionRequest,
18
+ SnippetListItem,
19
+ )
20
+ from kodit.reporting import Reporter
21
+
22
+
23
class SnippetDomainService:
    """Domain service for snippet-related operations.

    This service consolidates snippet operations that were previously
    spread between application services. It handles:
    - Snippet extraction from files
    - Snippet persistence
    - Snippet querying and filtering
    """

    # MIME types that are never handed to snippet extraction.
    _MIME_BLACKLIST: tuple[str, ...] = ("unknown/unknown",)

    def __init__(
        self,
        snippet_extraction_service: SnippetExtractionDomainService,
        snippet_repository: SnippetRepository,
        file_repository: FileRepository,
    ) -> None:
        """Initialize the snippet domain service.

        Args:
            snippet_extraction_service: Service for extracting snippets from files
            snippet_repository: Repository for snippet persistence
            file_repository: Repository for file operations

        """
        self.snippet_extraction_service = snippet_extraction_service
        self.snippet_repository = snippet_repository
        self.file_repository = file_repository
        self.log = structlog.get_logger(__name__)

    async def extract_and_create_snippets(
        self,
        index_id: int,
        strategy: SnippetExtractionStrategy,
        progress_callback: ProgressCallback | None = None,
    ) -> list[Snippet]:
        """Extract snippets from all files in an index and persist them.

        Files whose MIME type is blacklisted are skipped, and files that fail
        with ``OSError`` or ``ValueError`` are logged and skipped. Progress is
        reported once per file regardless of whether it was processed, so the
        reported position always advances through ``len(files)``.

        Args:
            index_id: The ID of the index to create snippets for
            strategy: The extraction strategy to use
            progress_callback: Optional callback for progress reporting

        Returns:
            List of the snippet objects returned by the repository's ``save``

        """
        files = await self.file_repository.get_files_for_index(index_id)
        total = len(files)
        created_snippets: list[Snippet] = []

        reporter = Reporter(self.log, progress_callback)
        await reporter.start(
            "create_snippets", total, "Creating snippets from files..."
        )

        for i, file in enumerate(files, 1):
            if self._should_process_file(file):
                created_snippets.extend(
                    await self._create_snippets_for_file(file, index_id, strategy)
                )

            # Step unconditionally: skipped and failed files still count
            # towards progress, otherwise the position can stall below total.
            await reporter.step(
                "create_snippets",
                current=i,
                total=total,
                message=f"Processing {file.cloned_path}...",
            )

        await reporter.done("create_snippets")
        return created_snippets

    async def _create_snippets_for_file(
        self,
        file: Any,
        index_id: int,
        strategy: SnippetExtractionStrategy,
    ) -> list[Snippet]:
        """Extract and persist the snippets of one file, tolerating bad files.

        Args:
            file: The file record; ``cloned_path`` and ``id`` are read from it
            index_id: The ID of the index the snippets belong to
            strategy: The extraction strategy to use

        Returns:
            The snippets saved for this file. If ``OSError`` or ``ValueError``
            is raised part-way through, the snippets saved before the failure
            are still returned.

        """
        saved: list[Snippet] = []
        try:
            # Extract snippets from file
            request = SnippetExtractionRequest(Path(file.cloned_path), strategy)
            result = await self.snippet_extraction_service.extract_snippets(request)

            # Create and persist snippet entities
            for snippet_content in result.snippets:
                snippet = Snippet(
                    file_id=file.id,
                    index_id=index_id,
                    content=snippet_content,
                )
                saved.append(await self.snippet_repository.save(snippet))
        except (OSError, ValueError) as e:
            # Best-effort: one unreadable or unparsable file must not abort
            # the whole indexing run.
            self.log.debug(
                "Skipping file",
                file=file.cloned_path,
                error=str(e),
            )
        return saved

    async def get_snippets_for_index(self, index_id: int) -> list[Snippet]:
        """Get all snippets for a specific index.

        Args:
            index_id: The ID of the index

        Returns:
            List of Snippet entities for the index

        """
        # This delegates to the repository but provides a domain-level interface
        return list(await self.snippet_repository.get_by_index(index_id))

    async def update_snippet_content(self, snippet_id: int, content: str) -> None:
        """Update the content of an existing snippet.

        Args:
            snippet_id: The ID of the snippet to update
            content: The new content for the snippet

        Raises:
            ValueError: If the repository returns no snippet for ``snippet_id``

        """
        # Get the snippet first to ensure it exists
        snippet = await self.snippet_repository.get(snippet_id)
        if not snippet:
            msg = f"Snippet not found: {snippet_id}"
            raise ValueError(msg)

        # Update the content
        snippet.content = content
        await self.snippet_repository.save(snippet)

    async def delete_snippets_for_index(self, index_id: int) -> None:
        """Delete all snippets for a specific index.

        Args:
            index_id: The ID of the index

        """
        await self.snippet_repository.delete_by_index(index_id)

    async def search_snippets(
        self, request: MultiSearchRequest
    ) -> list[SnippetListItem]:
        """Search snippets with filters.

        Args:
            request: The search request containing filters

        Returns:
            List of matching snippet items

        """
        return list(await self.snippet_repository.search(request))

    async def list_snippets(
        self, file_path: str | None = None, source_uri: str | None = None
    ) -> list[SnippetListItem]:
        """List snippets with optional filtering.

        Args:
            file_path: Optional file path to filter by
            source_uri: Optional source URI to filter by

        Returns:
            List of snippet items matching the criteria

        """
        return list(await self.snippet_repository.list_snippets(file_path, source_uri))

    def _should_process_file(self, file: Any) -> bool:
        """Check if a file should be processed for snippet extraction.

        Args:
            file: The file to check; only its ``mime_type`` is inspected

        Returns:
            True unless the file's MIME type is in the blacklist

        """
        # Skip unsupported file types
        return file.mime_type not in self._MIME_BLACKLIST
@@ -4,7 +4,12 @@ from dataclasses import dataclass
4
4
  from datetime import datetime
5
5
  from enum import Enum
6
6
  from pathlib import Path
7
+ from typing import Any, ClassVar
7
8
 
9
+ from sqlalchemy import JSON, DateTime, Integer, Text
10
+ from sqlalchemy.orm import Mapped, mapped_column
11
+
12
+ from kodit.domain.entities import Base
8
13
  from kodit.domain.enums import SnippetExtractionStrategy
9
14
 
10
15
 
@@ -40,6 +45,14 @@ class Document:
40
45
  text: str
41
46
 
42
47
 
48
@dataclass
class DocumentSearchResult:
    """Generic search hit for a document: a snippet ID paired with a score."""

    # Identifier of the snippet this result refers to.
    snippet_id: int
    # Score attached to the hit (presumably higher is better; confirm per backend).
    score: float
54
+
55
+
43
56
  @dataclass
44
57
  class SearchResult:
45
58
  """Generic search result model."""
@@ -56,12 +69,12 @@ class IndexRequest:
56
69
 
57
70
 
58
71
  @dataclass
59
- class SimpleSearchRequest:
72
+ class SearchRequest:
60
73
  """Generic search request (single query string)."""
61
74
 
62
75
  query: str
63
76
  top_k: int = 10
64
- search_type: SearchType = SearchType.BM25
77
+ snippet_ids: list[int] | None = None
65
78
 
66
79
 
67
80
  @dataclass
@@ -78,17 +91,75 @@ class IndexResult:
78
91
  snippet_id: int
79
92
 
80
93
 
81
- # Legacy aliases for backward compatibility
82
- BM25Document = Document
83
- BM25SearchResult = SearchResult
84
- BM25IndexRequest = IndexRequest
85
- BM25SearchRequest = SimpleSearchRequest
86
- BM25DeleteRequest = DeleteRequest
87
-
88
- VectorSearchRequest = Document
89
- VectorSearchResult = SearchResult
90
- VectorIndexRequest = IndexRequest
91
- VectorSearchQueryRequest = SimpleSearchRequest
94
+ @dataclass(frozen=True)
95
+ class SnippetSearchFilters:
96
+ """Value object for filtering snippet search results."""
97
+
98
+ language: str | None = None
99
+ author: str | None = None
100
+ created_after: datetime | None = None
101
+ created_before: datetime | None = None
102
+ source_repo: str | None = None
103
+
104
+ @classmethod
105
+ def from_cli_params(
106
+ cls,
107
+ language: str | None = None,
108
+ author: str | None = None,
109
+ created_after: str | None = None,
110
+ created_before: str | None = None,
111
+ source_repo: str | None = None,
112
+ ) -> "SnippetSearchFilters | None":
113
+ """Create SnippetSearchFilters from CLI parameters.
114
+
115
+ Args:
116
+ language: Programming language filter (e.g., python, go, javascript)
117
+ author: Author name filter
118
+ created_after: Date string in YYYY-MM-DD format for filtering snippets
119
+ created after
120
+ created_before: Date string in YYYY-MM-DD format for filtering snippets
121
+ created before
122
+ source_repo: Source repository filter (e.g., github.com/example/repo)
123
+
124
+ Returns:
125
+ SnippetSearchFilters instance if any filters are provided, None otherwise
126
+
127
+ Raises:
128
+ ValueError: If date strings are in invalid format
129
+
130
+ """
131
+ # Only create filters if at least one parameter is provided
132
+ if not any([language, author, created_after, created_before, source_repo]):
133
+ return None
134
+
135
+ # Parse date strings if provided
136
+ parsed_created_after = None
137
+ if created_after:
138
+ try:
139
+ parsed_created_after = datetime.fromisoformat(created_after)
140
+ except ValueError as e:
141
+ raise ValueError(
142
+ f"Invalid date format for created_after: {created_after}. "
143
+ "Expected ISO 8601 format (YYYY-MM-DD)"
144
+ ) from e
145
+
146
+ parsed_created_before = None
147
+ if created_before:
148
+ try:
149
+ parsed_created_before = datetime.fromisoformat(created_before)
150
+ except ValueError as e:
151
+ raise ValueError(
152
+ f"Invalid date format for created_before: {created_before}. "
153
+ "Expected ISO 8601 format (YYYY-MM-DD)"
154
+ ) from e
155
+
156
+ return cls(
157
+ language=language,
158
+ author=author,
159
+ created_after=parsed_created_after,
160
+ created_before=parsed_created_before,
161
+ source_repo=source_repo,
162
+ )
92
163
 
93
164
 
94
165
  @dataclass
@@ -99,6 +170,7 @@ class MultiSearchRequest:
99
170
  text_query: str | None = None
100
171
  code_query: str | None = None
101
172
  keywords: list[str] | None = None
173
+ filters: SnippetSearchFilters | None = None
102
174
 
103
175
 
104
176
  @dataclass
@@ -196,14 +268,6 @@ class EnrichmentIndexRequest:
196
268
  requests: list[EnrichmentRequest]
197
269
 
198
270
 
199
- @dataclass
200
- class EnrichmentSearchRequest:
201
- """Domain model for enrichment search request."""
202
-
203
- query: str
204
- top_k: int = 10
205
-
206
-
207
271
  @dataclass
208
272
  class IndexView:
209
273
  """Domain model for index information."""
@@ -213,3 +277,219 @@ class IndexView:
213
277
  num_snippets: int
214
278
  updated_at: datetime | None = None
215
279
  source: str | None = None
280
+
281
+
282
@dataclass
class SnippetListItem:
    """Domain model for one entry in a snippet listing, with file information."""

    # Snippet identifier.
    id: int
    # Path of the file associated with the snippet.
    file_path: str
    # Snippet text.
    content: str
    # URI of the source associated with the snippet.
    source_uri: str
290
+
291
+
292
class LanguageMapping:
    """Value object for language-to-extension mappings.

    Holds the table of programming languages and the file extensions that
    belong to each, and exposes lookups in both directions through
    classmethods. Callers never receive the internal lists themselves.
    """

    # Language name -> file extensions (no leading dot)
    _LANGUAGE_TO_EXTENSIONS: ClassVar[dict[str, list[str]]] = {
        "python": ["py", "pyw", "pyx", "pxd"],
        "go": ["go"],
        "javascript": ["js", "jsx", "mjs"],
        "typescript": ["ts", "tsx"],
        "java": ["java"],
        "csharp": ["cs"],
        "cpp": ["cpp", "cc", "cxx", "hpp"],
        "c": ["c", "h"],
        "rust": ["rs"],
        "php": ["php"],
        "ruby": ["rb"],
        "swift": ["swift"],
        "kotlin": ["kt", "kts"],
        "scala": ["scala"],
        "r": ["r", "R"],
        "matlab": ["m"],
        "perl": ["pl", "pm"],
        "bash": ["sh", "bash"],
        "powershell": ["ps1"],
        "sql": ["sql"],
        "html": ["html", "htm"],
        "css": ["css", "scss", "sass"],
        "yaml": ["yml", "yaml"],
        "json": ["json"],
        "xml": ["xml"],
        "markdown": ["md", "markdown"],
    }

    @classmethod
    def get_extensions_for_language(cls, language: str) -> list[str]:
        """Return the file extensions registered for a language.

        Args:
            language: The programming language name (case-insensitive)

        Returns:
            A fresh list of extensions (without leading dots) for the language

        Raises:
            ValueError: If the language is not supported

        """
        key = language.lower()
        if key not in cls._LANGUAGE_TO_EXTENSIONS:
            raise ValueError(f"Unsupported language: {language}")

        # Hand out a new list so callers cannot mutate the shared table
        return list(cls._LANGUAGE_TO_EXTENSIONS[key])

    @classmethod
    def get_language_for_extension(cls, extension: str) -> str:
        """Return the language a file extension belongs to.

        Args:
            extension: The file extension (with or without leading dot)

        Returns:
            The programming language name

        Raises:
            ValueError: If the extension is not supported

        """
        # Normalise: drop one leading dot, compare in lower case
        normalized = extension.removeprefix(".").lower()

        # First language (in table order) that lists the extension wins
        found = next(
            (
                name
                for name, known in cls._LANGUAGE_TO_EXTENSIONS.items()
                if normalized in known
            ),
            None,
        )
        if found is None:
            raise ValueError(f"Unsupported file extension: {extension}")
        return found

    @classmethod
    def get_extension_to_language_map(cls) -> dict[str, str]:
        """Build the reverse mapping from file extension to language name.

        Returns:
            Dictionary mapping file extensions (without leading dots) to language names

        """
        return {
            ext: name
            for name, known in cls._LANGUAGE_TO_EXTENSIONS.items()
            for ext in known
        }

    @classmethod
    def get_supported_languages(cls) -> list[str]:
        """Return the names of all supported programming languages.

        Returns:
            List of supported language names

        """
        return list(cls._LANGUAGE_TO_EXTENSIONS)

    @classmethod
    def get_supported_extensions(cls) -> list[str]:
        """Return every supported file extension, in table order.

        Returns:
            List of supported file extensions (without leading dots)

        """
        return [
            ext for known in cls._LANGUAGE_TO_EXTENSIONS.values() for ext in known
        ]

    @classmethod
    def is_supported_language(cls, language: str) -> bool:
        """Tell whether a language is in the table.

        Args:
            language: The programming language name (case-insensitive)

        Returns:
            True if the language is supported, False otherwise

        """
        return language.lower() in cls._LANGUAGE_TO_EXTENSIONS

    @classmethod
    def is_supported_extension(cls, extension: str) -> bool:
        """Tell whether a file extension maps to a known language.

        Args:
            extension: The file extension (with or without leading dot)

        Returns:
            True if the extension is supported, False otherwise

        """
        try:
            cls.get_language_for_extension(extension)
        except ValueError:
            return False
        else:
            return True

    @classmethod
    def get_extensions_with_fallback(cls, language: str) -> list[str]:
        """Return a language's extensions, or the lower-cased name itself.

        Args:
            language: The programming language name (case-insensitive)

        Returns:
            List of file extensions (without leading dots) for the language, or
            [language.lower()] if not found.

        """
        key = language.lower()
        if not cls.is_supported_language(key):
            return [key]
        return cls.get_extensions_for_language(key)
459
+
460
+
461
+ # Database models for value objects
462
class BM25DocumentModel(Base):
    """BM25 document model.

    Declarative model mapped to the ``bm25_documents`` table.
    """

    __tablename__ = "bm25_documents"

    # Integer primary key.
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Document text; required (NOT NULL).
    content: Mapped[str] = mapped_column(Text, nullable=False)
    # Optional JSON payload. The attribute is named ``document_metadata``,
    # presumably because ``metadata`` is reserved on declarative classes.
    document_metadata: Mapped[dict[str, Any] | None] = mapped_column(
        JSON, nullable=True
    )
    # Timezone-aware timestamps. No default is declared here, so the writer
    # must supply both values.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
478
+
479
+
480
class VectorDocumentModel(Base):
    """Vector document model.

    Declarative model mapped to the ``vector_documents`` table.
    """

    __tablename__ = "vector_documents"

    # Integer primary key.
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Document text; required (NOT NULL).
    content: Mapped[str] = mapped_column(Text, nullable=False)
    # Optional JSON payload. The attribute is named ``document_metadata``,
    # presumably because ``metadata`` is reserved on declarative classes.
    document_metadata: Mapped[dict[str, Any] | None] = mapped_column(
        JSON, nullable=True
    )
    # Timezone-aware timestamps. No default is declared here, so the writer
    # must supply both values.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
@@ -12,10 +12,10 @@ import structlog
12
12
 
13
13
  from kodit.domain.services.bm25_service import BM25Repository
14
14
  from kodit.domain.value_objects import (
15
- BM25DeleteRequest,
16
- BM25IndexRequest,
17
- BM25SearchRequest,
18
- BM25SearchResult,
15
+ DeleteRequest,
16
+ IndexRequest,
17
+ SearchRequest,
18
+ SearchResult,
19
19
  )
20
20
 
21
21
  if TYPE_CHECKING:
@@ -68,7 +68,7 @@ class LocalBM25Repository(BM25Repository):
68
68
  show_progress=True,
69
69
  )
70
70
 
71
- async def index_documents(self, request: BM25IndexRequest) -> None:
71
+ async def index_documents(self, request: IndexRequest) -> None:
72
72
  """Index documents for BM25 search."""
73
73
  self.log.debug("Indexing corpus")
74
74
  if not request.documents:
@@ -84,7 +84,7 @@ class LocalBM25Repository(BM25Repository):
84
84
  async with aiofiles.open(self.index_path / SNIPPET_IDS_FILE, "w") as f:
85
85
  await f.write(json.dumps(self.snippet_ids))
86
86
 
87
- async def search(self, request: BM25SearchRequest) -> list[BM25SearchResult]:
87
+ async def search(self, request: SearchRequest) -> list[SearchResult]:
88
88
  """Search documents using BM25."""
89
89
  if request.top_k == 0:
90
90
  self.log.warning("Top k is 0, returning empty list")
@@ -117,13 +117,21 @@ class LocalBM25Repository(BM25Repository):
117
117
  k=top_k,
118
118
  )
119
119
  self.log.debug("Raw results", results=results, scores=scores)
120
- return [
121
- BM25SearchResult(snippet_id=int(result), score=float(score))
122
- for result, score in zip(results[0], scores[0], strict=False)
123
- if score > 0.0
124
- ]
125
120
 
126
- async def delete_documents(self, request: BM25DeleteRequest) -> None:
121
+ # Filter results by snippet_ids if provided
122
+ filtered_results = []
123
+ for result, score in zip(results[0], scores[0], strict=False):
124
+ snippet_id = int(result)
125
+ if score > 0.0 and (
126
+ request.snippet_ids is None or snippet_id in request.snippet_ids
127
+ ):
128
+ filtered_results.append(
129
+ SearchResult(snippet_id=snippet_id, score=float(score))
130
+ )
131
+
132
+ return filtered_results
133
+
134
+ async def delete_documents(self, request: DeleteRequest) -> None:
127
135
  """Delete documents from the index."""
128
136
  # request parameter is unused as deletion is not supported
129
137
  # ruff: noqa: ARG002