kodit 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kodit might be problematic. Click here for more details.
- kodit/_version.py +2 -2
- kodit/app.py +53 -23
- kodit/application/factories/reporting_factory.py +6 -2
- kodit/application/factories/server_factory.py +311 -0
- kodit/application/services/code_search_application_service.py +144 -0
- kodit/application/services/commit_indexing_application_service.py +543 -0
- kodit/application/services/indexing_worker_service.py +13 -44
- kodit/application/services/queue_service.py +24 -3
- kodit/application/services/reporting.py +0 -2
- kodit/application/services/sync_scheduler.py +15 -31
- kodit/cli.py +2 -753
- kodit/cli_utils.py +2 -9
- kodit/config.py +1 -94
- kodit/database.py +38 -1
- kodit/domain/{entities.py → entities/__init__.py} +50 -195
- kodit/domain/entities/git.py +190 -0
- kodit/domain/factories/__init__.py +1 -0
- kodit/domain/factories/git_repo_factory.py +76 -0
- kodit/domain/protocols.py +263 -64
- kodit/domain/services/bm25_service.py +5 -1
- kodit/domain/services/embedding_service.py +3 -0
- kodit/domain/services/git_repository_service.py +429 -0
- kodit/domain/services/git_service.py +300 -0
- kodit/domain/services/task_status_query_service.py +2 -2
- kodit/domain/value_objects.py +83 -114
- kodit/infrastructure/api/client/__init__.py +0 -2
- kodit/infrastructure/api/v1/__init__.py +0 -4
- kodit/infrastructure/api/v1/dependencies.py +92 -46
- kodit/infrastructure/api/v1/routers/__init__.py +0 -6
- kodit/infrastructure/api/v1/routers/commits.py +271 -0
- kodit/infrastructure/api/v1/routers/queue.py +2 -2
- kodit/infrastructure/api/v1/routers/repositories.py +282 -0
- kodit/infrastructure/api/v1/routers/search.py +31 -14
- kodit/infrastructure/api/v1/schemas/__init__.py +0 -24
- kodit/infrastructure/api/v1/schemas/commit.py +96 -0
- kodit/infrastructure/api/v1/schemas/context.py +2 -0
- kodit/infrastructure/api/v1/schemas/repository.py +128 -0
- kodit/infrastructure/api/v1/schemas/search.py +12 -9
- kodit/infrastructure/api/v1/schemas/snippet.py +58 -0
- kodit/infrastructure/api/v1/schemas/tag.py +31 -0
- kodit/infrastructure/api/v1/schemas/task_status.py +2 -0
- kodit/infrastructure/bm25/local_bm25_repository.py +16 -4
- kodit/infrastructure/bm25/vectorchord_bm25_repository.py +68 -52
- kodit/infrastructure/cloning/git/git_python_adaptor.py +467 -0
- kodit/infrastructure/cloning/git/working_copy.py +1 -1
- kodit/infrastructure/embedding/embedding_factory.py +3 -2
- kodit/infrastructure/embedding/local_vector_search_repository.py +1 -1
- kodit/infrastructure/embedding/vectorchord_vector_search_repository.py +111 -84
- kodit/infrastructure/enrichment/litellm_enrichment_provider.py +19 -26
- kodit/infrastructure/indexing/fusion_service.py +1 -1
- kodit/infrastructure/mappers/git_mapper.py +193 -0
- kodit/infrastructure/mappers/snippet_mapper.py +106 -0
- kodit/infrastructure/mappers/task_mapper.py +5 -44
- kodit/infrastructure/reporting/log_progress.py +8 -5
- kodit/infrastructure/reporting/telemetry_progress.py +21 -0
- kodit/infrastructure/slicing/slicer.py +32 -31
- kodit/infrastructure/sqlalchemy/embedding_repository.py +43 -23
- kodit/infrastructure/sqlalchemy/entities.py +394 -158
- kodit/infrastructure/sqlalchemy/git_branch_repository.py +263 -0
- kodit/infrastructure/sqlalchemy/git_commit_repository.py +337 -0
- kodit/infrastructure/sqlalchemy/git_repository.py +252 -0
- kodit/infrastructure/sqlalchemy/git_tag_repository.py +257 -0
- kodit/infrastructure/sqlalchemy/snippet_v2_repository.py +484 -0
- kodit/infrastructure/sqlalchemy/task_repository.py +29 -23
- kodit/infrastructure/sqlalchemy/task_status_repository.py +24 -12
- kodit/infrastructure/sqlalchemy/unit_of_work.py +10 -14
- kodit/mcp.py +12 -30
- kodit/migrations/env.py +1 -0
- kodit/migrations/versions/04b80f802e0c_foreign_key_review.py +100 -0
- kodit/migrations/versions/7f15f878c3a1_add_new_git_entities.py +690 -0
- kodit/migrations/versions/f9e5ef5e688f_add_git_commits_number.py +43 -0
- kodit/py.typed +0 -0
- kodit/utils/dump_openapi.py +7 -4
- kodit/utils/path_utils.py +29 -0
- {kodit-0.4.3.dist-info → kodit-0.5.0.dist-info}/METADATA +3 -3
- kodit-0.5.0.dist-info/RECORD +137 -0
- kodit/application/factories/code_indexing_factory.py +0 -195
- kodit/application/services/auto_indexing_service.py +0 -99
- kodit/application/services/code_indexing_application_service.py +0 -410
- kodit/domain/services/index_query_service.py +0 -70
- kodit/domain/services/index_service.py +0 -269
- kodit/infrastructure/api/client/index_client.py +0 -57
- kodit/infrastructure/api/v1/routers/indexes.py +0 -164
- kodit/infrastructure/api/v1/schemas/index.py +0 -101
- kodit/infrastructure/bm25/bm25_factory.py +0 -28
- kodit/infrastructure/cloning/__init__.py +0 -1
- kodit/infrastructure/cloning/metadata.py +0 -98
- kodit/infrastructure/mappers/index_mapper.py +0 -345
- kodit/infrastructure/reporting/tdqm_progress.py +0 -38
- kodit/infrastructure/slicing/language_detection_service.py +0 -18
- kodit/infrastructure/sqlalchemy/index_repository.py +0 -646
- kodit-0.4.3.dist-info/RECORD +0 -125
- {kodit-0.4.3.dist-info → kodit-0.5.0.dist-info}/WHEEL +0 -0
- {kodit-0.4.3.dist-info → kodit-0.5.0.dist-info}/entry_points.txt +0 -0
- {kodit-0.4.3.dist-info → kodit-0.5.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,43 +1,10 @@
|
|
|
1
1
|
"""Task mapper for the task queue."""
|
|
2
2
|
|
|
3
|
-
from typing import ClassVar
|
|
4
|
-
|
|
5
3
|
from kodit.domain.entities import Task
|
|
6
|
-
from kodit.domain.value_objects import
|
|
4
|
+
from kodit.domain.value_objects import TaskOperation
|
|
7
5
|
from kodit.infrastructure.sqlalchemy import entities as db_entities
|
|
8
6
|
|
|
9
7
|
|
|
10
|
-
class TaskTypeMapper:
|
|
11
|
-
"""Maps between domain QueuedTaskType and SQLAlchemy TaskType."""
|
|
12
|
-
|
|
13
|
-
# Map TaskType enum to QueuedTaskType
|
|
14
|
-
TASK_TYPE_MAPPING: ClassVar[dict[db_entities.TaskType, TaskType]] = {
|
|
15
|
-
db_entities.TaskType.INDEX_UPDATE: TaskType.INDEX_UPDATE,
|
|
16
|
-
}
|
|
17
|
-
|
|
18
|
-
@staticmethod
|
|
19
|
-
def to_domain_type(task_type: db_entities.TaskType) -> TaskType:
|
|
20
|
-
"""Convert SQLAlchemy TaskType to domain QueuedTaskType."""
|
|
21
|
-
if task_type not in TaskTypeMapper.TASK_TYPE_MAPPING:
|
|
22
|
-
raise ValueError(f"Unknown task type: {task_type}")
|
|
23
|
-
return TaskTypeMapper.TASK_TYPE_MAPPING[task_type]
|
|
24
|
-
|
|
25
|
-
@staticmethod
|
|
26
|
-
def from_domain_type(task_type: TaskType) -> db_entities.TaskType:
|
|
27
|
-
"""Convert domain QueuedTaskType to SQLAlchemy TaskType."""
|
|
28
|
-
if task_type not in TaskTypeMapper.TASK_TYPE_MAPPING.values():
|
|
29
|
-
raise ValueError(f"Unknown task type: {task_type}")
|
|
30
|
-
|
|
31
|
-
# Find value in TASK_TYPE_MAPPING
|
|
32
|
-
return next(
|
|
33
|
-
(
|
|
34
|
-
db_task_type
|
|
35
|
-
for db_task_type, domain_task_type in TaskTypeMapper.TASK_TYPE_MAPPING.items() # noqa: E501
|
|
36
|
-
if domain_task_type == task_type
|
|
37
|
-
)
|
|
38
|
-
)
|
|
39
|
-
|
|
40
|
-
|
|
41
8
|
class TaskMapper:
|
|
42
9
|
"""Maps between domain QueuedTask and SQLAlchemy Task entities.
|
|
43
10
|
|
|
@@ -52,13 +19,12 @@ class TaskMapper:
|
|
|
52
19
|
Since QueuedTask doesn't have status fields, we store processing
|
|
53
20
|
state in the payload.
|
|
54
21
|
"""
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
22
|
+
if record.type not in TaskOperation.__members__.values():
|
|
23
|
+
raise ValueError(f"Unknown operation: {record.type}")
|
|
58
24
|
# The dedup_key becomes the id in the domain entity
|
|
59
25
|
return Task(
|
|
60
26
|
id=record.dedup_key, # Use dedup_key as the unique identifier
|
|
61
|
-
type=
|
|
27
|
+
type=TaskOperation(record.type),
|
|
62
28
|
priority=record.priority,
|
|
63
29
|
payload=record.payload or {},
|
|
64
30
|
created_at=record.created_at,
|
|
@@ -68,14 +34,9 @@ class TaskMapper:
|
|
|
68
34
|
@staticmethod
|
|
69
35
|
def from_domain_task(task: Task) -> db_entities.Task:
|
|
70
36
|
"""Convert domain QueuedTask to SQLAlchemy Task record."""
|
|
71
|
-
if task.type not in TaskTypeMapper.TASK_TYPE_MAPPING.values():
|
|
72
|
-
raise ValueError(f"Unknown task type: {task.type}")
|
|
73
|
-
|
|
74
|
-
# Find value in TASK_TYPE_MAPPING
|
|
75
|
-
task_type = TaskTypeMapper.from_domain_type(task.type)
|
|
76
37
|
return db_entities.Task(
|
|
77
38
|
dedup_key=task.id,
|
|
78
|
-
type=
|
|
39
|
+
type=task.type.value,
|
|
79
40
|
payload=task.payload,
|
|
80
41
|
priority=task.priority,
|
|
81
42
|
)
|
|
@@ -22,13 +22,16 @@ class LoggingReportingModule(ReportingModule):
|
|
|
22
22
|
async def on_change(self, progress: TaskStatus) -> None:
|
|
23
23
|
"""On step changed."""
|
|
24
24
|
current_time = datetime.now(UTC)
|
|
25
|
-
time_since_last_log = current_time - self._last_log_time
|
|
26
25
|
step = progress
|
|
27
26
|
|
|
28
|
-
if
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
27
|
+
if step.state == ReportingState.FAILED:
|
|
28
|
+
self._log.exception(
|
|
29
|
+
step.operation,
|
|
30
|
+
state=step.state,
|
|
31
|
+
completion_percent=step.completion_percent,
|
|
32
|
+
error=step.error,
|
|
33
|
+
)
|
|
34
|
+
else:
|
|
32
35
|
self._log.info(
|
|
33
36
|
step.operation,
|
|
34
37
|
state=step.state,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Log progress using telemetry."""
|
|
2
|
+
|
|
3
|
+
import structlog
|
|
4
|
+
|
|
5
|
+
from kodit.domain.entities import TaskStatus
|
|
6
|
+
from kodit.domain.protocols import ReportingModule
|
|
7
|
+
from kodit.log import log_event
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TelemetryProgressReportingModule(ReportingModule):
|
|
11
|
+
"""Database progress reporting module."""
|
|
12
|
+
|
|
13
|
+
def __init__(self) -> None:
|
|
14
|
+
"""Initialize the logging reporting module."""
|
|
15
|
+
self._log = structlog.get_logger(__name__)
|
|
16
|
+
|
|
17
|
+
async def on_change(self, progress: TaskStatus) -> None:
|
|
18
|
+
"""On step changed."""
|
|
19
|
+
log_event(
|
|
20
|
+
progress.operation,
|
|
21
|
+
)
|
|
@@ -14,7 +14,7 @@ import structlog
|
|
|
14
14
|
from tree_sitter import Node, Parser, Tree
|
|
15
15
|
from tree_sitter_language_pack import get_language
|
|
16
16
|
|
|
17
|
-
from kodit.domain.entities import
|
|
17
|
+
from kodit.domain.entities.git import GitFile, SnippetV2
|
|
18
18
|
from kodit.domain.value_objects import LanguageMapping
|
|
19
19
|
|
|
20
20
|
|
|
@@ -149,9 +149,9 @@ class Slicer:
|
|
|
149
149
|
"""Initialize an empty slicer."""
|
|
150
150
|
self.log = structlog.get_logger(__name__)
|
|
151
151
|
|
|
152
|
-
def
|
|
153
|
-
self, files: list[
|
|
154
|
-
) -> list[
|
|
152
|
+
def extract_snippets_from_git_files( # noqa: C901
|
|
153
|
+
self, files: list[GitFile], language: str = "python"
|
|
154
|
+
) -> list[SnippetV2]:
|
|
155
155
|
"""Extract code snippets from a list of files.
|
|
156
156
|
|
|
157
157
|
Args:
|
|
@@ -187,10 +187,10 @@ class Slicer:
|
|
|
187
187
|
raise RuntimeError(f"Failed to load {language} parser: {e}") from e
|
|
188
188
|
|
|
189
189
|
# Create mapping from Paths to File objects and extract paths
|
|
190
|
-
path_to_file_map: dict[Path,
|
|
190
|
+
path_to_file_map: dict[Path, GitFile] = {}
|
|
191
191
|
file_paths: list[Path] = []
|
|
192
192
|
for file in files:
|
|
193
|
-
file_path = file.
|
|
193
|
+
file_path = Path(file.path)
|
|
194
194
|
|
|
195
195
|
# Validate file matches language
|
|
196
196
|
if not self._file_matches_language(file_path.suffix, language):
|
|
@@ -225,7 +225,7 @@ class Slicer:
|
|
|
225
225
|
self._build_reverse_call_graph(state)
|
|
226
226
|
|
|
227
227
|
# Extract snippets for all functions
|
|
228
|
-
snippets = []
|
|
228
|
+
snippets: list[SnippetV2] = []
|
|
229
229
|
for qualified_name in state.def_index:
|
|
230
230
|
snippet_content = self._get_snippet(
|
|
231
231
|
qualified_name,
|
|
@@ -234,7 +234,7 @@ class Slicer:
|
|
|
234
234
|
{"max_depth": 2, "max_functions": 8},
|
|
235
235
|
)
|
|
236
236
|
if "not found" not in snippet_content:
|
|
237
|
-
snippet = self.
|
|
237
|
+
snippet = self._create_snippet_entity_from_git_files(
|
|
238
238
|
qualified_name, snippet_content, language, state, path_to_file_map
|
|
239
239
|
)
|
|
240
240
|
snippets.append(snippet)
|
|
@@ -247,8 +247,8 @@ class Slicer:
|
|
|
247
247
|
return False
|
|
248
248
|
|
|
249
249
|
try:
|
|
250
|
-
return (
|
|
251
|
-
|
|
250
|
+
return language == LanguageMapping.get_language_for_extension(
|
|
251
|
+
file_extension
|
|
252
252
|
)
|
|
253
253
|
except ValueError:
|
|
254
254
|
# Extension not supported, so it doesn't match any language
|
|
@@ -614,7 +614,8 @@ class Slicer:
|
|
|
614
614
|
if callers:
|
|
615
615
|
snippet_lines.append("")
|
|
616
616
|
snippet_lines.append("# === USAGE EXAMPLES ===")
|
|
617
|
-
|
|
617
|
+
# Show up to 2 examples, sorted for deterministic order
|
|
618
|
+
for caller in sorted(callers)[:2]:
|
|
618
619
|
call_line = self._find_function_call_line(
|
|
619
620
|
caller, function_name, state, file_contents
|
|
620
621
|
)
|
|
@@ -625,37 +626,37 @@ class Slicer:
|
|
|
625
626
|
|
|
626
627
|
return "\n".join(snippet_lines)
|
|
627
628
|
|
|
628
|
-
def
|
|
629
|
+
def _create_snippet_entity_from_git_files(
|
|
629
630
|
self,
|
|
630
631
|
qualified_name: str,
|
|
631
632
|
snippet_content: str,
|
|
632
633
|
language: str,
|
|
633
634
|
state: AnalyzerState,
|
|
634
|
-
path_to_file_map: dict[Path,
|
|
635
|
-
) ->
|
|
635
|
+
path_to_file_map: dict[Path, GitFile],
|
|
636
|
+
) -> SnippetV2:
|
|
636
637
|
"""Create a Snippet domain entity from extracted content."""
|
|
637
638
|
# Determine all files that this snippet derives from
|
|
638
|
-
derives_from_files = self.
|
|
639
|
+
derives_from_files = self._find_source_files_for_snippet_from_git_files(
|
|
639
640
|
qualified_name, snippet_content, state, path_to_file_map
|
|
640
641
|
)
|
|
641
642
|
|
|
642
643
|
# Create the snippet entity
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
644
|
+
return SnippetV2(
|
|
645
|
+
derives_from=derives_from_files,
|
|
646
|
+
content=snippet_content,
|
|
647
|
+
extension=language,
|
|
648
|
+
sha=SnippetV2.compute_sha(snippet_content),
|
|
649
|
+
)
|
|
649
650
|
|
|
650
|
-
def
|
|
651
|
+
def _find_source_files_for_snippet_from_git_files(
|
|
651
652
|
self,
|
|
652
653
|
qualified_name: str,
|
|
653
654
|
snippet_content: str,
|
|
654
655
|
state: AnalyzerState,
|
|
655
|
-
path_to_file_map: dict[Path,
|
|
656
|
-
) -> list[
|
|
656
|
+
path_to_file_map: dict[Path, GitFile],
|
|
657
|
+
) -> list[GitFile]:
|
|
657
658
|
"""Find all source files that a snippet derives from."""
|
|
658
|
-
source_files: list[
|
|
659
|
+
source_files: list[GitFile] = []
|
|
659
660
|
source_file_paths: set[Path] = set()
|
|
660
661
|
|
|
661
662
|
# Add the primary function's file
|
|
@@ -835,7 +836,7 @@ class Slicer:
|
|
|
835
836
|
# Add direct dependencies
|
|
836
837
|
to_visit.extend(
|
|
837
838
|
(callee, depth + 1)
|
|
838
|
-
for callee in state.call_graph.get(current, set())
|
|
839
|
+
for callee in sorted(state.call_graph.get(current, set()))
|
|
839
840
|
if callee not in visited and callee in state.def_index
|
|
840
841
|
)
|
|
841
842
|
|
|
@@ -850,26 +851,26 @@ class Slicer:
|
|
|
850
851
|
in_degree: dict[str, int] = defaultdict(int)
|
|
851
852
|
graph: dict[str, set[str]] = defaultdict(set)
|
|
852
853
|
|
|
853
|
-
for func in functions:
|
|
854
|
-
for callee in state.call_graph.get(func, set()):
|
|
854
|
+
for func in sorted(functions):
|
|
855
|
+
for callee in sorted(state.call_graph.get(func, set())):
|
|
855
856
|
if callee in functions:
|
|
856
857
|
graph[func].add(callee)
|
|
857
858
|
in_degree[callee] += 1
|
|
858
859
|
|
|
859
860
|
# Find roots
|
|
860
|
-
queue = [f for f in functions if in_degree[f] == 0]
|
|
861
|
+
queue = [f for f in sorted(functions) if in_degree[f] == 0]
|
|
861
862
|
result = []
|
|
862
863
|
|
|
863
864
|
while queue:
|
|
864
865
|
current = queue.pop(0)
|
|
865
866
|
result.append(current)
|
|
866
|
-
for neighbor in graph[current]:
|
|
867
|
+
for neighbor in sorted(graph[current]):
|
|
867
868
|
in_degree[neighbor] -= 1
|
|
868
869
|
if in_degree[neighbor] == 0:
|
|
869
870
|
queue.append(neighbor)
|
|
870
871
|
|
|
871
872
|
# Add any remaining (cycles)
|
|
872
|
-
for func in functions:
|
|
873
|
+
for func in sorted(functions):
|
|
873
874
|
if func not in result:
|
|
874
875
|
result.append(func)
|
|
875
876
|
|
|
@@ -14,59 +14,79 @@ def create_embedding_repository(
|
|
|
14
14
|
session_factory: Callable[[], AsyncSession],
|
|
15
15
|
) -> "SqlAlchemyEmbeddingRepository":
|
|
16
16
|
"""Create an embedding repository."""
|
|
17
|
-
|
|
18
|
-
return SqlAlchemyEmbeddingRepository(uow)
|
|
17
|
+
return SqlAlchemyEmbeddingRepository(session_factory=session_factory)
|
|
19
18
|
|
|
20
19
|
|
|
21
20
|
class SqlAlchemyEmbeddingRepository:
|
|
22
21
|
"""SQLAlchemy implementation of embedding repository."""
|
|
23
22
|
|
|
24
|
-
def __init__(self,
|
|
23
|
+
def __init__(self, session_factory: Callable[[], AsyncSession]) -> None:
|
|
25
24
|
"""Initialize the SQLAlchemy embedding repository."""
|
|
26
|
-
self.
|
|
25
|
+
self.session_factory = session_factory
|
|
27
26
|
|
|
28
27
|
async def create_embedding(self, embedding: Embedding) -> None:
|
|
29
28
|
"""Create a new embedding record in the database."""
|
|
30
|
-
async with self.
|
|
31
|
-
|
|
29
|
+
async with SqlAlchemyUnitOfWork(self.session_factory) as session:
|
|
30
|
+
session.add(embedding)
|
|
32
31
|
|
|
33
32
|
async def get_embedding_by_snippet_id_and_type(
|
|
34
33
|
self, snippet_id: int, embedding_type: EmbeddingType
|
|
35
34
|
) -> Embedding | None:
|
|
36
35
|
"""Get an embedding by its snippet ID and type."""
|
|
37
|
-
async with self.
|
|
36
|
+
async with SqlAlchemyUnitOfWork(self.session_factory) as session:
|
|
38
37
|
query = select(Embedding).where(
|
|
39
38
|
Embedding.snippet_id == snippet_id,
|
|
40
39
|
Embedding.type == embedding_type,
|
|
41
40
|
)
|
|
42
|
-
result = await
|
|
41
|
+
result = await session.execute(query)
|
|
43
42
|
return result.scalar_one_or_none()
|
|
44
43
|
|
|
45
44
|
async def list_embeddings_by_type(
|
|
46
45
|
self, embedding_type: EmbeddingType
|
|
47
46
|
) -> list[Embedding]:
|
|
48
47
|
"""List all embeddings of a given type."""
|
|
49
|
-
async with self.
|
|
48
|
+
async with SqlAlchemyUnitOfWork(self.session_factory) as session:
|
|
50
49
|
query = select(Embedding).where(Embedding.type == embedding_type)
|
|
51
|
-
result = await
|
|
50
|
+
result = await session.execute(query)
|
|
52
51
|
return list(result.scalars())
|
|
53
52
|
|
|
54
|
-
async def delete_embeddings_by_snippet_id(self, snippet_id:
|
|
53
|
+
async def delete_embeddings_by_snippet_id(self, snippet_id: str) -> None:
|
|
55
54
|
"""Delete all embeddings for a snippet."""
|
|
56
|
-
async with self.
|
|
55
|
+
async with SqlAlchemyUnitOfWork(self.session_factory) as session:
|
|
57
56
|
query = select(Embedding).where(Embedding.snippet_id == snippet_id)
|
|
58
|
-
result = await
|
|
57
|
+
result = await session.execute(query)
|
|
59
58
|
embeddings = result.scalars().all()
|
|
60
59
|
for embedding in embeddings:
|
|
61
|
-
await
|
|
60
|
+
await session.delete(embedding)
|
|
61
|
+
|
|
62
|
+
async def list_embeddings_by_snippet_ids_and_type(
|
|
63
|
+
self, snippet_ids: list[str], embedding_type: EmbeddingType
|
|
64
|
+
) -> list[Embedding]:
|
|
65
|
+
"""Get all embeddings for the given snippet IDs."""
|
|
66
|
+
async with SqlAlchemyUnitOfWork(self.session_factory) as session:
|
|
67
|
+
query = select(Embedding).where(
|
|
68
|
+
Embedding.snippet_id.in_(snippet_ids),
|
|
69
|
+
Embedding.type == embedding_type,
|
|
70
|
+
)
|
|
71
|
+
result = await session.execute(query)
|
|
72
|
+
return list(result.scalars())
|
|
73
|
+
|
|
74
|
+
async def get_embeddings_by_snippet_ids(
|
|
75
|
+
self, snippet_ids: list[str]
|
|
76
|
+
) -> list[Embedding]:
|
|
77
|
+
"""Get all embeddings for the given snippet IDs."""
|
|
78
|
+
async with SqlAlchemyUnitOfWork(self.session_factory) as session:
|
|
79
|
+
query = select(Embedding).where(Embedding.snippet_id.in_(snippet_ids))
|
|
80
|
+
result = await session.execute(query)
|
|
81
|
+
return list(result.scalars())
|
|
62
82
|
|
|
63
83
|
async def list_semantic_results(
|
|
64
84
|
self,
|
|
65
85
|
embedding_type: EmbeddingType,
|
|
66
86
|
embedding: list[float],
|
|
67
87
|
top_k: int = 10,
|
|
68
|
-
snippet_ids: list[
|
|
69
|
-
) -> list[tuple[
|
|
88
|
+
snippet_ids: list[str] | None = None,
|
|
89
|
+
) -> list[tuple[str, float]]:
|
|
70
90
|
"""List semantic results using cosine similarity.
|
|
71
91
|
|
|
72
92
|
This implementation fetches all embeddings of the given type and computes
|
|
@@ -97,8 +117,8 @@ class SqlAlchemyEmbeddingRepository:
|
|
|
97
117
|
return self._get_top_k_results(similarities, embeddings, top_k)
|
|
98
118
|
|
|
99
119
|
async def _list_embedding_values(
|
|
100
|
-
self, embedding_type: EmbeddingType, snippet_ids: list[
|
|
101
|
-
) -> list[tuple[
|
|
120
|
+
self, embedding_type: EmbeddingType, snippet_ids: list[str] | None = None
|
|
121
|
+
) -> list[tuple[str, list[float]]]:
|
|
102
122
|
"""List all embeddings of a given type from the database.
|
|
103
123
|
|
|
104
124
|
Args:
|
|
@@ -109,7 +129,7 @@ class SqlAlchemyEmbeddingRepository:
|
|
|
109
129
|
List of (snippet_id, embedding) tuples
|
|
110
130
|
|
|
111
131
|
"""
|
|
112
|
-
async with self.
|
|
132
|
+
async with SqlAlchemyUnitOfWork(self.session_factory) as session:
|
|
113
133
|
query = select(Embedding.snippet_id, Embedding.embedding).where(
|
|
114
134
|
Embedding.type == embedding_type
|
|
115
135
|
)
|
|
@@ -118,11 +138,11 @@ class SqlAlchemyEmbeddingRepository:
|
|
|
118
138
|
if snippet_ids is not None:
|
|
119
139
|
query = query.where(Embedding.snippet_id.in_(snippet_ids))
|
|
120
140
|
|
|
121
|
-
rows = await
|
|
141
|
+
rows = await session.execute(query)
|
|
122
142
|
return [tuple(row) for row in rows.all()] # Convert Row objects to tuples
|
|
123
143
|
|
|
124
144
|
def _prepare_vectors(
|
|
125
|
-
self, embeddings: list[tuple[
|
|
145
|
+
self, embeddings: list[tuple[str, list[float]]], query_embedding: list[float]
|
|
126
146
|
) -> tuple[np.ndarray, np.ndarray]:
|
|
127
147
|
"""Convert embeddings to numpy arrays.
|
|
128
148
|
|
|
@@ -191,9 +211,9 @@ class SqlAlchemyEmbeddingRepository:
|
|
|
191
211
|
def _get_top_k_results(
|
|
192
212
|
self,
|
|
193
213
|
similarities: np.ndarray,
|
|
194
|
-
embeddings: list[tuple[
|
|
214
|
+
embeddings: list[tuple[str, list[float]]],
|
|
195
215
|
top_k: int,
|
|
196
|
-
) -> list[tuple[
|
|
216
|
+
) -> list[tuple[str, float]]:
|
|
197
217
|
"""Get top-k results by similarity score.
|
|
198
218
|
|
|
199
219
|
Args:
|