kodit-0.2.7-py3-none-any.whl → kodit-0.2.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kodit might be problematic. Click here for more details.
- kodit/_version.py +2 -2
- kodit/application/factories/__init__.py +1 -0
- kodit/application/factories/code_indexing_factory.py +119 -0
- kodit/application/services/{indexing_application_service.py → code_indexing_application_service.py} +159 -198
- kodit/cli.py +199 -62
- kodit/domain/entities.py +7 -5
- kodit/domain/repositories.py +33 -0
- kodit/domain/services/bm25_service.py +14 -17
- kodit/domain/services/embedding_service.py +10 -14
- kodit/domain/services/snippet_service.py +198 -0
- kodit/domain/value_objects.py +301 -21
- kodit/infrastructure/bm25/local_bm25_repository.py +20 -12
- kodit/infrastructure/bm25/vectorchord_bm25_repository.py +31 -11
- kodit/infrastructure/cloning/git/working_copy.py +5 -2
- kodit/infrastructure/cloning/metadata.py +1 -0
- kodit/infrastructure/embedding/embedding_providers/hash_embedding_provider.py +14 -25
- kodit/infrastructure/embedding/local_vector_search_repository.py +26 -38
- kodit/infrastructure/embedding/vectorchord_vector_search_repository.py +50 -35
- kodit/infrastructure/enrichment/enrichment_factory.py +1 -1
- kodit/infrastructure/indexing/indexing_factory.py +8 -91
- kodit/infrastructure/indexing/snippet_domain_service_factory.py +37 -0
- kodit/infrastructure/snippet_extraction/languages/java.scm +12 -0
- kodit/infrastructure/snippet_extraction/snippet_extraction_factory.py +3 -31
- kodit/infrastructure/sqlalchemy/embedding_repository.py +14 -3
- kodit/infrastructure/sqlalchemy/snippet_repository.py +174 -2
- kodit/mcp.py +61 -49
- {kodit-0.2.7.dist-info → kodit-0.2.9.dist-info}/METADATA +1 -1
- {kodit-0.2.7.dist-info → kodit-0.2.9.dist-info}/RECORD +31 -30
- kodit/application/commands/__init__.py +0 -1
- kodit/application/commands/snippet_commands.py +0 -22
- kodit/application/services/snippet_application_service.py +0 -149
- kodit/infrastructure/enrichment/legacy_enrichment_models.py +0 -42
- {kodit-0.2.7.dist-info → kodit-0.2.9.dist-info}/WHEEL +0 -0
- {kodit-0.2.7.dist-info → kodit-0.2.9.dist-info}/entry_points.txt +0 -0
- {kodit-0.2.7.dist-info → kodit-0.2.9.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
"""Domain service for snippet operations."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import structlog
|
|
7
|
+
|
|
8
|
+
from kodit.domain.entities import Snippet
|
|
9
|
+
from kodit.domain.enums import SnippetExtractionStrategy
|
|
10
|
+
from kodit.domain.interfaces import ProgressCallback
|
|
11
|
+
from kodit.domain.repositories import FileRepository, SnippetRepository
|
|
12
|
+
from kodit.domain.services.snippet_extraction_service import (
|
|
13
|
+
SnippetExtractionDomainService,
|
|
14
|
+
)
|
|
15
|
+
from kodit.domain.value_objects import (
|
|
16
|
+
MultiSearchRequest,
|
|
17
|
+
SnippetExtractionRequest,
|
|
18
|
+
SnippetListItem,
|
|
19
|
+
)
|
|
20
|
+
from kodit.reporting import Reporter
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SnippetDomainService:
    """Domain service for snippet-related operations.

    Bundles the snippet workflows behind one domain-level interface:

    - extracting snippets from the files of an index and persisting them
    - reading, updating and deleting persisted snippets
    - searching and listing snippets through the snippet repository
    """

    # MIME types that are never handed to snippet extraction.
    _MIME_BLACKLIST = ("unknown/unknown",)

    def __init__(
        self,
        snippet_extraction_service: SnippetExtractionDomainService,
        snippet_repository: SnippetRepository,
        file_repository: FileRepository,
    ) -> None:
        """Initialize the snippet domain service.

        Args:
            snippet_extraction_service: Service used to extract snippets from files
            snippet_repository: Repository used to persist and query snippets
            file_repository: Repository used to look up the files of an index

        """
        self.snippet_extraction_service = snippet_extraction_service
        self.snippet_repository = snippet_repository
        self.file_repository = file_repository
        self.log = structlog.get_logger(__name__)

    async def extract_and_create_snippets(
        self,
        index_id: int,
        strategy: SnippetExtractionStrategy,
        progress_callback: ProgressCallback | None = None,
    ) -> list[Snippet]:
        """Extract snippets from all files in an index and persist them.

        Files rejected by ``_should_process_file`` are skipped. A file whose
        extraction or persistence raises ``OSError`` or ``ValueError`` is logged
        at debug level and skipped; snippets of that file saved before the error
        stay in the result. A progress step is reported for every file,
        including skipped ones, so the reported position always reaches the
        total.

        Args:
            index_id: The ID of the index to create snippets for
            strategy: The extraction strategy to use
            progress_callback: Optional callback for progress reporting

        Returns:
            The snippet entities returned by the repository's ``save`` call, in
            the order they were created

        """
        files = await self.file_repository.get_files_for_index(index_id)
        total_files = len(files)
        created_snippets: list[Snippet] = []

        reporter = Reporter(self.log, progress_callback)
        await reporter.start(
            "create_snippets", total_files, "Creating snippets from files..."
        )

        for i, file in enumerate(files, 1):
            if self._should_process_file(file):
                try:
                    # Extract snippets from file
                    request = SnippetExtractionRequest(Path(file.cloned_path), strategy)
                    result = await self.snippet_extraction_service.extract_snippets(
                        request
                    )

                    # Create and persist snippet entities
                    for snippet_content in result.snippets:
                        snippet = Snippet(
                            file_id=file.id,
                            index_id=index_id,
                            content=snippet_content,
                        )
                        saved_snippet = await self.snippet_repository.save(snippet)
                        created_snippets.append(saved_snippet)
                except (OSError, ValueError) as e:
                    # Best effort: one unreadable file must not abort the index.
                    self.log.debug(
                        "Skipping file",
                        file=file.cloned_path,
                        error=str(e),
                    )

            # Step unconditionally: skipped files still count as handled,
            # otherwise progress would stall below the announced total.
            await reporter.step(
                "create_snippets",
                current=i,
                total=total_files,
                message=f"Processing {file.cloned_path}...",
            )

        await reporter.done("create_snippets")
        return created_snippets

    async def get_snippets_for_index(self, index_id: int) -> list[Snippet]:
        """Get all snippets for a specific index.

        Args:
            index_id: The ID of the index

        Returns:
            List of Snippet entities for the index

        """
        # Thin delegation; the repository result is materialized into a list.
        return list(await self.snippet_repository.get_by_index(index_id))

    async def update_snippet_content(self, snippet_id: int, content: str) -> None:
        """Update the content of an existing snippet.

        Args:
            snippet_id: The ID of the snippet to update
            content: The new content for the snippet

        Raises:
            ValueError: If the repository returns no snippet for ``snippet_id``

        """
        # Load first so a missing snippet is reported instead of silently created.
        snippet = await self.snippet_repository.get(snippet_id)
        if not snippet:
            msg = f"Snippet not found: {snippet_id}"
            raise ValueError(msg)

        snippet.content = content
        await self.snippet_repository.save(snippet)

    async def delete_snippets_for_index(self, index_id: int) -> None:
        """Delete all snippets for a specific index.

        Args:
            index_id: The ID of the index

        """
        await self.snippet_repository.delete_by_index(index_id)

    async def search_snippets(
        self, request: MultiSearchRequest
    ) -> list[SnippetListItem]:
        """Search snippets with filters.

        Args:
            request: The search request containing filters

        Returns:
            List of matching snippet items

        """
        return list(await self.snippet_repository.search(request))

    async def list_snippets(
        self, file_path: str | None = None, source_uri: str | None = None
    ) -> list[SnippetListItem]:
        """List snippets with optional filtering.

        Args:
            file_path: Optional file path to filter by
            source_uri: Optional source URI to filter by

        Returns:
            List of snippet items matching the criteria

        """
        return list(await self.snippet_repository.list_snippets(file_path, source_uri))

    def _should_process_file(self, file: Any) -> bool:
        """Check if a file should be processed for snippet extraction.

        Args:
            file: The file to check; only its ``mime_type`` attribute is read

        Returns:
            True unless the file's MIME type is blacklisted

        """
        return file.mime_type not in self._MIME_BLACKLIST
|
kodit/domain/value_objects.py
CHANGED
|
@@ -4,7 +4,12 @@ from dataclasses import dataclass
|
|
|
4
4
|
from datetime import datetime
|
|
5
5
|
from enum import Enum
|
|
6
6
|
from pathlib import Path
|
|
7
|
+
from typing import Any, ClassVar
|
|
7
8
|
|
|
9
|
+
from sqlalchemy import JSON, DateTime, Integer, Text
|
|
10
|
+
from sqlalchemy.orm import Mapped, mapped_column
|
|
11
|
+
|
|
12
|
+
from kodit.domain.entities import Base
|
|
8
13
|
from kodit.domain.enums import SnippetExtractionStrategy
|
|
9
14
|
|
|
10
15
|
|
|
@@ -40,6 +45,14 @@ class Document:
|
|
|
40
45
|
text: str
|
|
41
46
|
|
|
42
47
|
|
|
48
|
+
@dataclass
class DocumentSearchResult:
    """One hit of a generic document search: a snippet id and its score."""

    # Identifier of the snippet that matched.
    snippet_id: int
    # Score attached to the hit; scale and ordering depend on the backend
    # that produced it (not defined here).
    score: float
|
|
54
|
+
|
|
55
|
+
|
|
43
56
|
@dataclass
|
|
44
57
|
class SearchResult:
|
|
45
58
|
"""Generic search result model."""
|
|
@@ -56,12 +69,12 @@ class IndexRequest:
|
|
|
56
69
|
|
|
57
70
|
|
|
58
71
|
@dataclass
|
|
59
|
-
class
|
|
72
|
+
class SearchRequest:
|
|
60
73
|
"""Generic search request (single query string)."""
|
|
61
74
|
|
|
62
75
|
query: str
|
|
63
76
|
top_k: int = 10
|
|
64
|
-
|
|
77
|
+
snippet_ids: list[int] | None = None
|
|
65
78
|
|
|
66
79
|
|
|
67
80
|
@dataclass
|
|
@@ -78,17 +91,75 @@ class IndexResult:
|
|
|
78
91
|
snippet_id: int
|
|
79
92
|
|
|
80
93
|
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
94
|
+
@dataclass(frozen=True)
|
|
95
|
+
class SnippetSearchFilters:
|
|
96
|
+
"""Value object for filtering snippet search results."""
|
|
97
|
+
|
|
98
|
+
language: str | None = None
|
|
99
|
+
author: str | None = None
|
|
100
|
+
created_after: datetime | None = None
|
|
101
|
+
created_before: datetime | None = None
|
|
102
|
+
source_repo: str | None = None
|
|
103
|
+
|
|
104
|
+
@classmethod
|
|
105
|
+
def from_cli_params(
|
|
106
|
+
cls,
|
|
107
|
+
language: str | None = None,
|
|
108
|
+
author: str | None = None,
|
|
109
|
+
created_after: str | None = None,
|
|
110
|
+
created_before: str | None = None,
|
|
111
|
+
source_repo: str | None = None,
|
|
112
|
+
) -> "SnippetSearchFilters | None":
|
|
113
|
+
"""Create SnippetSearchFilters from CLI parameters.
|
|
114
|
+
|
|
115
|
+
Args:
|
|
116
|
+
language: Programming language filter (e.g., python, go, javascript)
|
|
117
|
+
author: Author name filter
|
|
118
|
+
created_after: Date string in YYYY-MM-DD format for filtering snippets
|
|
119
|
+
created after
|
|
120
|
+
created_before: Date string in YYYY-MM-DD format for filtering snippets
|
|
121
|
+
created before
|
|
122
|
+
source_repo: Source repository filter (e.g., github.com/example/repo)
|
|
123
|
+
|
|
124
|
+
Returns:
|
|
125
|
+
SnippetSearchFilters instance if any filters are provided, None otherwise
|
|
126
|
+
|
|
127
|
+
Raises:
|
|
128
|
+
ValueError: If date strings are in invalid format
|
|
129
|
+
|
|
130
|
+
"""
|
|
131
|
+
# Only create filters if at least one parameter is provided
|
|
132
|
+
if not any([language, author, created_after, created_before, source_repo]):
|
|
133
|
+
return None
|
|
134
|
+
|
|
135
|
+
# Parse date strings if provided
|
|
136
|
+
parsed_created_after = None
|
|
137
|
+
if created_after:
|
|
138
|
+
try:
|
|
139
|
+
parsed_created_after = datetime.fromisoformat(created_after)
|
|
140
|
+
except ValueError as e:
|
|
141
|
+
raise ValueError(
|
|
142
|
+
f"Invalid date format for created_after: {created_after}. "
|
|
143
|
+
"Expected ISO 8601 format (YYYY-MM-DD)"
|
|
144
|
+
) from e
|
|
145
|
+
|
|
146
|
+
parsed_created_before = None
|
|
147
|
+
if created_before:
|
|
148
|
+
try:
|
|
149
|
+
parsed_created_before = datetime.fromisoformat(created_before)
|
|
150
|
+
except ValueError as e:
|
|
151
|
+
raise ValueError(
|
|
152
|
+
f"Invalid date format for created_before: {created_before}. "
|
|
153
|
+
"Expected ISO 8601 format (YYYY-MM-DD)"
|
|
154
|
+
) from e
|
|
155
|
+
|
|
156
|
+
return cls(
|
|
157
|
+
language=language,
|
|
158
|
+
author=author,
|
|
159
|
+
created_after=parsed_created_after,
|
|
160
|
+
created_before=parsed_created_before,
|
|
161
|
+
source_repo=source_repo,
|
|
162
|
+
)
|
|
92
163
|
|
|
93
164
|
|
|
94
165
|
@dataclass
|
|
@@ -99,6 +170,7 @@ class MultiSearchRequest:
|
|
|
99
170
|
text_query: str | None = None
|
|
100
171
|
code_query: str | None = None
|
|
101
172
|
keywords: list[str] | None = None
|
|
173
|
+
filters: SnippetSearchFilters | None = None
|
|
102
174
|
|
|
103
175
|
|
|
104
176
|
@dataclass
|
|
@@ -196,14 +268,6 @@ class EnrichmentIndexRequest:
|
|
|
196
268
|
requests: list[EnrichmentRequest]
|
|
197
269
|
|
|
198
270
|
|
|
199
|
-
@dataclass
|
|
200
|
-
class EnrichmentSearchRequest:
|
|
201
|
-
"""Domain model for enrichment search request."""
|
|
202
|
-
|
|
203
|
-
query: str
|
|
204
|
-
top_k: int = 10
|
|
205
|
-
|
|
206
|
-
|
|
207
271
|
@dataclass
|
|
208
272
|
class IndexView:
|
|
209
273
|
"""Domain model for index information."""
|
|
@@ -213,3 +277,219 @@ class IndexView:
|
|
|
213
277
|
num_snippets: int
|
|
214
278
|
updated_at: datetime | None = None
|
|
215
279
|
source: str | None = None
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
@dataclass
class SnippetListItem:
    """Flat list entry pairing a snippet with information about its file."""

    # Identifier of the snippet.
    id: int
    # Path of the file associated with the snippet.
    file_path: str
    # Text content of the snippet.
    content: str
    # URI of the source associated with the snippet — exact format is set by
    # whoever builds the item (not visible here).
    source_uri: str
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
class LanguageMapping:
    """Value object for language-to-extension mappings.

    Holds the table of known programming languages and their file extensions
    and offers lookups in both directions. All access goes through
    classmethods; lists handed to callers are fresh copies, so the shared table
    cannot be modified through a return value.
    """

    # Language name (lowercase) -> file extensions without the leading dot.
    # "r" lists both "r" and "R" so case-sensitive consumers of the raw
    # extension lists see both spellings.
    _LANGUAGE_TO_EXTENSIONS: ClassVar[dict[str, list[str]]] = {
        "python": ["py", "pyw", "pyx", "pxd"],
        "go": ["go"],
        "javascript": ["js", "jsx", "mjs"],
        "typescript": ["ts", "tsx"],
        "java": ["java"],
        "csharp": ["cs"],
        "cpp": ["cpp", "cc", "cxx", "hpp"],
        "c": ["c", "h"],
        "rust": ["rs"],
        "php": ["php"],
        "ruby": ["rb"],
        "swift": ["swift"],
        "kotlin": ["kt", "kts"],
        "scala": ["scala"],
        "r": ["r", "R"],
        "matlab": ["m"],
        "perl": ["pl", "pm"],
        "bash": ["sh", "bash"],
        "powershell": ["ps1"],
        "sql": ["sql"],
        "html": ["html", "htm"],
        "css": ["css", "scss", "sass"],
        "yaml": ["yml", "yaml"],
        "json": ["json"],
        "xml": ["xml"],
        "markdown": ["md", "markdown"],
    }

    @staticmethod
    def _normalize_extension(extension: str) -> str:
        """Strip one leading dot and lowercase the extension for lookups."""
        return extension.removeprefix(".").lower()

    @classmethod
    def get_extensions_for_language(cls, language: str) -> list[str]:
        """Get file extensions for a given language.

        Args:
            language: The programming language name (case-insensitive)

        Returns:
            List of file extensions (without leading dots) for the language

        Raises:
            ValueError: If the language is not supported

        """
        extensions = cls._LANGUAGE_TO_EXTENSIONS.get(language.lower())
        if extensions is None:
            raise ValueError(f"Unsupported language: {language}")

        # Copy so callers cannot mutate the class-level table.
        return list(extensions)

    @classmethod
    def get_language_for_extension(cls, extension: str) -> str:
        """Get language for a given file extension.

        Args:
            extension: The file extension (with or without leading dot)

        Returns:
            The programming language name

        Raises:
            ValueError: If the extension is not supported

        """
        ext_clean = cls._normalize_extension(extension)

        # First language (in table order) listing the extension wins.
        for language, extensions in cls._LANGUAGE_TO_EXTENSIONS.items():
            if ext_clean in extensions:
                return language

        raise ValueError(f"Unsupported file extension: {extension}")

    @classmethod
    def get_extension_to_language_map(cls) -> dict[str, str]:
        """Get a mapping from file extensions to language names.

        Returns:
            Dictionary mapping file extensions (without leading dots) to language names

        """
        return {
            extension: language
            for language, extensions in cls._LANGUAGE_TO_EXTENSIONS.items()
            for extension in extensions
        }

    @classmethod
    def get_supported_languages(cls) -> list[str]:
        """Get list of all supported programming languages.

        Returns:
            List of supported language names

        """
        return list(cls._LANGUAGE_TO_EXTENSIONS)

    @classmethod
    def get_supported_extensions(cls) -> list[str]:
        """Get list of all supported file extensions.

        Returns:
            List of supported file extensions (without leading dots)

        """
        return [
            extension
            for extensions in cls._LANGUAGE_TO_EXTENSIONS.values()
            for extension in extensions
        ]

    @classmethod
    def is_supported_language(cls, language: str) -> bool:
        """Check if a language is supported.

        Args:
            language: The programming language name (case-insensitive)

        Returns:
            True if the language is supported, False otherwise

        """
        return language.lower() in cls._LANGUAGE_TO_EXTENSIONS

    @classmethod
    def is_supported_extension(cls, extension: str) -> bool:
        """Check if a file extension is supported.

        Args:
            extension: The file extension (with or without leading dot)

        Returns:
            True if the extension is supported, False otherwise

        """
        ext_clean = cls._normalize_extension(extension)
        return any(
            ext_clean in extensions
            for extensions in cls._LANGUAGE_TO_EXTENSIONS.values()
        )

    @classmethod
    def get_extensions_with_fallback(cls, language: str) -> list[str]:
        """Get file extensions for a language, falling back to passed language name.

        Args:
            language: The programming language name (case-insensitive)

        Returns:
            List of file extensions (without leading dots) for the language, or
            [language.lower()] if not found.

        """
        language_lower = language.lower()
        extensions = cls._LANGUAGE_TO_EXTENSIONS.get(language_lower)
        if extensions is None:
            return [language_lower]
        return list(extensions)
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
# Database models for value objects
|
|
462
|
+
class BM25DocumentModel(Base):
    """Declarative model for rows of the ``bm25_documents`` table.

    No column defaults are declared here, so ``created_at`` and ``updated_at``
    must be supplied by whoever inserts the row.
    """

    __tablename__ = "bm25_documents"

    # Integer primary key.
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Required document text.
    content: Mapped[str] = mapped_column(Text, nullable=False)
    # Optional free-form JSON metadata attached to the document.
    document_metadata: Mapped[dict[str, Any] | None] = mapped_column(
        JSON,
        nullable=True,
    )
    # Required timezone-aware timestamps.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
    )
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
class VectorDocumentModel(Base):
    """Declarative model for rows of the ``vector_documents`` table.

    No column defaults are declared here, so ``created_at`` and ``updated_at``
    must be supplied by whoever inserts the row.
    """

    __tablename__ = "vector_documents"

    # Integer primary key.
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Required document text.
    content: Mapped[str] = mapped_column(Text, nullable=False)
    # Optional free-form JSON metadata attached to the document.
    document_metadata: Mapped[dict[str, Any] | None] = mapped_column(
        JSON,
        nullable=True,
    )
    # Required timezone-aware timestamps.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
    )
|
|
@@ -12,10 +12,10 @@ import structlog
|
|
|
12
12
|
|
|
13
13
|
from kodit.domain.services.bm25_service import BM25Repository
|
|
14
14
|
from kodit.domain.value_objects import (
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
15
|
+
DeleteRequest,
|
|
16
|
+
IndexRequest,
|
|
17
|
+
SearchRequest,
|
|
18
|
+
SearchResult,
|
|
19
19
|
)
|
|
20
20
|
|
|
21
21
|
if TYPE_CHECKING:
|
|
@@ -68,7 +68,7 @@ class LocalBM25Repository(BM25Repository):
|
|
|
68
68
|
show_progress=True,
|
|
69
69
|
)
|
|
70
70
|
|
|
71
|
-
async def index_documents(self, request:
|
|
71
|
+
async def index_documents(self, request: IndexRequest) -> None:
|
|
72
72
|
"""Index documents for BM25 search."""
|
|
73
73
|
self.log.debug("Indexing corpus")
|
|
74
74
|
if not request.documents:
|
|
@@ -84,7 +84,7 @@ class LocalBM25Repository(BM25Repository):
|
|
|
84
84
|
async with aiofiles.open(self.index_path / SNIPPET_IDS_FILE, "w") as f:
|
|
85
85
|
await f.write(json.dumps(self.snippet_ids))
|
|
86
86
|
|
|
87
|
-
async def search(self, request:
|
|
87
|
+
async def search(self, request: SearchRequest) -> list[SearchResult]:
|
|
88
88
|
"""Search documents using BM25."""
|
|
89
89
|
if request.top_k == 0:
|
|
90
90
|
self.log.warning("Top k is 0, returning empty list")
|
|
@@ -117,13 +117,21 @@ class LocalBM25Repository(BM25Repository):
|
|
|
117
117
|
k=top_k,
|
|
118
118
|
)
|
|
119
119
|
self.log.debug("Raw results", results=results, scores=scores)
|
|
120
|
-
return [
|
|
121
|
-
BM25SearchResult(snippet_id=int(result), score=float(score))
|
|
122
|
-
for result, score in zip(results[0], scores[0], strict=False)
|
|
123
|
-
if score > 0.0
|
|
124
|
-
]
|
|
125
120
|
|
|
126
|
-
|
|
121
|
+
# Filter results by snippet_ids if provided
|
|
122
|
+
filtered_results = []
|
|
123
|
+
for result, score in zip(results[0], scores[0], strict=False):
|
|
124
|
+
snippet_id = int(result)
|
|
125
|
+
if score > 0.0 and (
|
|
126
|
+
request.snippet_ids is None or snippet_id in request.snippet_ids
|
|
127
|
+
):
|
|
128
|
+
filtered_results.append(
|
|
129
|
+
SearchResult(snippet_id=snippet_id, score=float(score))
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
return filtered_results
|
|
133
|
+
|
|
134
|
+
async def delete_documents(self, request: DeleteRequest) -> None:
|
|
127
135
|
"""Delete documents from the index."""
|
|
128
136
|
# request parameter is unused as deletion is not supported
|
|
129
137
|
# ruff: noqa: ARG002
|