kodit 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic. Click here for more details.

Files changed (95)
  1. kodit/_version.py +2 -2
  2. kodit/app.py +53 -23
  3. kodit/application/factories/reporting_factory.py +6 -2
  4. kodit/application/factories/server_factory.py +311 -0
  5. kodit/application/services/code_search_application_service.py +144 -0
  6. kodit/application/services/commit_indexing_application_service.py +543 -0
  7. kodit/application/services/indexing_worker_service.py +13 -44
  8. kodit/application/services/queue_service.py +24 -3
  9. kodit/application/services/reporting.py +0 -2
  10. kodit/application/services/sync_scheduler.py +15 -31
  11. kodit/cli.py +2 -753
  12. kodit/cli_utils.py +2 -9
  13. kodit/config.py +1 -94
  14. kodit/database.py +38 -1
  15. kodit/domain/{entities.py → entities/__init__.py} +50 -195
  16. kodit/domain/entities/git.py +190 -0
  17. kodit/domain/factories/__init__.py +1 -0
  18. kodit/domain/factories/git_repo_factory.py +76 -0
  19. kodit/domain/protocols.py +263 -64
  20. kodit/domain/services/bm25_service.py +5 -1
  21. kodit/domain/services/embedding_service.py +3 -0
  22. kodit/domain/services/git_repository_service.py +429 -0
  23. kodit/domain/services/git_service.py +300 -0
  24. kodit/domain/services/task_status_query_service.py +2 -2
  25. kodit/domain/value_objects.py +83 -114
  26. kodit/infrastructure/api/client/__init__.py +0 -2
  27. kodit/infrastructure/api/v1/__init__.py +0 -4
  28. kodit/infrastructure/api/v1/dependencies.py +92 -46
  29. kodit/infrastructure/api/v1/routers/__init__.py +0 -6
  30. kodit/infrastructure/api/v1/routers/commits.py +271 -0
  31. kodit/infrastructure/api/v1/routers/queue.py +2 -2
  32. kodit/infrastructure/api/v1/routers/repositories.py +282 -0
  33. kodit/infrastructure/api/v1/routers/search.py +31 -14
  34. kodit/infrastructure/api/v1/schemas/__init__.py +0 -24
  35. kodit/infrastructure/api/v1/schemas/commit.py +96 -0
  36. kodit/infrastructure/api/v1/schemas/context.py +2 -0
  37. kodit/infrastructure/api/v1/schemas/repository.py +128 -0
  38. kodit/infrastructure/api/v1/schemas/search.py +12 -9
  39. kodit/infrastructure/api/v1/schemas/snippet.py +58 -0
  40. kodit/infrastructure/api/v1/schemas/tag.py +31 -0
  41. kodit/infrastructure/api/v1/schemas/task_status.py +2 -0
  42. kodit/infrastructure/bm25/local_bm25_repository.py +16 -4
  43. kodit/infrastructure/bm25/vectorchord_bm25_repository.py +68 -52
  44. kodit/infrastructure/cloning/git/git_python_adaptor.py +467 -0
  45. kodit/infrastructure/cloning/git/working_copy.py +1 -1
  46. kodit/infrastructure/embedding/embedding_factory.py +3 -2
  47. kodit/infrastructure/embedding/local_vector_search_repository.py +1 -1
  48. kodit/infrastructure/embedding/vectorchord_vector_search_repository.py +111 -84
  49. kodit/infrastructure/enrichment/litellm_enrichment_provider.py +19 -26
  50. kodit/infrastructure/indexing/fusion_service.py +1 -1
  51. kodit/infrastructure/mappers/git_mapper.py +193 -0
  52. kodit/infrastructure/mappers/snippet_mapper.py +106 -0
  53. kodit/infrastructure/mappers/task_mapper.py +5 -44
  54. kodit/infrastructure/reporting/log_progress.py +8 -5
  55. kodit/infrastructure/reporting/telemetry_progress.py +21 -0
  56. kodit/infrastructure/slicing/slicer.py +32 -31
  57. kodit/infrastructure/sqlalchemy/embedding_repository.py +43 -23
  58. kodit/infrastructure/sqlalchemy/entities.py +394 -158
  59. kodit/infrastructure/sqlalchemy/git_branch_repository.py +263 -0
  60. kodit/infrastructure/sqlalchemy/git_commit_repository.py +337 -0
  61. kodit/infrastructure/sqlalchemy/git_repository.py +252 -0
  62. kodit/infrastructure/sqlalchemy/git_tag_repository.py +257 -0
  63. kodit/infrastructure/sqlalchemy/snippet_v2_repository.py +484 -0
  64. kodit/infrastructure/sqlalchemy/task_repository.py +29 -23
  65. kodit/infrastructure/sqlalchemy/task_status_repository.py +24 -12
  66. kodit/infrastructure/sqlalchemy/unit_of_work.py +10 -14
  67. kodit/mcp.py +12 -30
  68. kodit/migrations/env.py +1 -0
  69. kodit/migrations/versions/04b80f802e0c_foreign_key_review.py +100 -0
  70. kodit/migrations/versions/7f15f878c3a1_add_new_git_entities.py +690 -0
  71. kodit/migrations/versions/f9e5ef5e688f_add_git_commits_number.py +43 -0
  72. kodit/py.typed +0 -0
  73. kodit/utils/dump_openapi.py +7 -4
  74. kodit/utils/path_utils.py +29 -0
  75. {kodit-0.4.3.dist-info → kodit-0.5.0.dist-info}/METADATA +3 -3
  76. kodit-0.5.0.dist-info/RECORD +137 -0
  77. kodit/application/factories/code_indexing_factory.py +0 -195
  78. kodit/application/services/auto_indexing_service.py +0 -99
  79. kodit/application/services/code_indexing_application_service.py +0 -410
  80. kodit/domain/services/index_query_service.py +0 -70
  81. kodit/domain/services/index_service.py +0 -269
  82. kodit/infrastructure/api/client/index_client.py +0 -57
  83. kodit/infrastructure/api/v1/routers/indexes.py +0 -164
  84. kodit/infrastructure/api/v1/schemas/index.py +0 -101
  85. kodit/infrastructure/bm25/bm25_factory.py +0 -28
  86. kodit/infrastructure/cloning/__init__.py +0 -1
  87. kodit/infrastructure/cloning/metadata.py +0 -98
  88. kodit/infrastructure/mappers/index_mapper.py +0 -345
  89. kodit/infrastructure/reporting/tdqm_progress.py +0 -38
  90. kodit/infrastructure/slicing/language_detection_service.py +0 -18
  91. kodit/infrastructure/sqlalchemy/index_repository.py +0 -646
  92. kodit-0.4.3.dist-info/RECORD +0 -125
  93. {kodit-0.4.3.dist-info → kodit-0.5.0.dist-info}/WHEEL +0 -0
  94. {kodit-0.4.3.dist-info → kodit-0.5.0.dist-info}/entry_points.txt +0 -0
  95. {kodit-0.4.3.dist-info → kodit-0.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,43 +1,10 @@
1
1
  """Task mapper for the task queue."""
2
2
 
3
- from typing import ClassVar
4
-
5
3
  from kodit.domain.entities import Task
6
- from kodit.domain.value_objects import TaskType
4
+ from kodit.domain.value_objects import TaskOperation
7
5
  from kodit.infrastructure.sqlalchemy import entities as db_entities
8
6
 
9
7
 
10
- class TaskTypeMapper:
11
- """Maps between domain QueuedTaskType and SQLAlchemy TaskType."""
12
-
13
- # Map TaskType enum to QueuedTaskType
14
- TASK_TYPE_MAPPING: ClassVar[dict[db_entities.TaskType, TaskType]] = {
15
- db_entities.TaskType.INDEX_UPDATE: TaskType.INDEX_UPDATE,
16
- }
17
-
18
- @staticmethod
19
- def to_domain_type(task_type: db_entities.TaskType) -> TaskType:
20
- """Convert SQLAlchemy TaskType to domain QueuedTaskType."""
21
- if task_type not in TaskTypeMapper.TASK_TYPE_MAPPING:
22
- raise ValueError(f"Unknown task type: {task_type}")
23
- return TaskTypeMapper.TASK_TYPE_MAPPING[task_type]
24
-
25
- @staticmethod
26
- def from_domain_type(task_type: TaskType) -> db_entities.TaskType:
27
- """Convert domain QueuedTaskType to SQLAlchemy TaskType."""
28
- if task_type not in TaskTypeMapper.TASK_TYPE_MAPPING.values():
29
- raise ValueError(f"Unknown task type: {task_type}")
30
-
31
- # Find value in TASK_TYPE_MAPPING
32
- return next(
33
- (
34
- db_task_type
35
- for db_task_type, domain_task_type in TaskTypeMapper.TASK_TYPE_MAPPING.items() # noqa: E501
36
- if domain_task_type == task_type
37
- )
38
- )
39
-
40
-
41
8
  class TaskMapper:
42
9
  """Maps between domain QueuedTask and SQLAlchemy Task entities.
43
10
 
@@ -52,13 +19,12 @@ class TaskMapper:
52
19
  Since QueuedTask doesn't have status fields, we store processing
53
20
  state in the payload.
54
21
  """
55
- # Get the task type
56
- task_type = TaskTypeMapper.to_domain_type(record.type)
57
-
22
+ if record.type not in TaskOperation.__members__.values():
23
+ raise ValueError(f"Unknown operation: {record.type}")
58
24
  # The dedup_key becomes the id in the domain entity
59
25
  return Task(
60
26
  id=record.dedup_key, # Use dedup_key as the unique identifier
61
- type=task_type,
27
+ type=TaskOperation(record.type),
62
28
  priority=record.priority,
63
29
  payload=record.payload or {},
64
30
  created_at=record.created_at,
@@ -68,14 +34,9 @@ class TaskMapper:
68
34
  @staticmethod
69
35
  def from_domain_task(task: Task) -> db_entities.Task:
70
36
  """Convert domain QueuedTask to SQLAlchemy Task record."""
71
- if task.type not in TaskTypeMapper.TASK_TYPE_MAPPING.values():
72
- raise ValueError(f"Unknown task type: {task.type}")
73
-
74
- # Find value in TASK_TYPE_MAPPING
75
- task_type = TaskTypeMapper.from_domain_type(task.type)
76
37
  return db_entities.Task(
77
38
  dedup_key=task.id,
78
- type=task_type,
39
+ type=task.type.value,
79
40
  payload=task.payload,
80
41
  priority=task.priority,
81
42
  )
@@ -22,13 +22,16 @@ class LoggingReportingModule(ReportingModule):
22
22
  async def on_change(self, progress: TaskStatus) -> None:
23
23
  """On step changed."""
24
24
  current_time = datetime.now(UTC)
25
- time_since_last_log = current_time - self._last_log_time
26
25
  step = progress
27
26
 
28
- if (
29
- step.state != ReportingState.IN_PROGRESS
30
- or time_since_last_log >= self.config.log_time_interval
31
- ):
27
+ if step.state == ReportingState.FAILED:
28
+ self._log.exception(
29
+ step.operation,
30
+ state=step.state,
31
+ completion_percent=step.completion_percent,
32
+ error=step.error,
33
+ )
34
+ else:
32
35
  self._log.info(
33
36
  step.operation,
34
37
  state=step.state,
@@ -0,0 +1,21 @@
1
+ """Log progress using telemetry."""
2
+
3
+ import structlog
4
+
5
+ from kodit.domain.entities import TaskStatus
6
+ from kodit.domain.protocols import ReportingModule
7
+ from kodit.log import log_event
8
+
9
+
10
+ class TelemetryProgressReportingModule(ReportingModule):
11
+ """Database progress reporting module."""
12
+
13
+ def __init__(self) -> None:
14
+ """Initialize the logging reporting module."""
15
+ self._log = structlog.get_logger(__name__)
16
+
17
+ async def on_change(self, progress: TaskStatus) -> None:
18
+ """On step changed."""
19
+ log_event(
20
+ progress.operation,
21
+ )
@@ -14,7 +14,7 @@ import structlog
14
14
  from tree_sitter import Node, Parser, Tree
15
15
  from tree_sitter_language_pack import get_language
16
16
 
17
- from kodit.domain.entities import File, Snippet
17
+ from kodit.domain.entities.git import GitFile, SnippetV2
18
18
  from kodit.domain.value_objects import LanguageMapping
19
19
 
20
20
 
@@ -149,9 +149,9 @@ class Slicer:
149
149
  """Initialize an empty slicer."""
150
150
  self.log = structlog.get_logger(__name__)
151
151
 
152
- def extract_snippets( # noqa: C901
153
- self, files: list[File], language: str = "python"
154
- ) -> list[Snippet]:
152
+ def extract_snippets_from_git_files( # noqa: C901
153
+ self, files: list[GitFile], language: str = "python"
154
+ ) -> list[SnippetV2]:
155
155
  """Extract code snippets from a list of files.
156
156
 
157
157
  Args:
@@ -187,10 +187,10 @@ class Slicer:
187
187
  raise RuntimeError(f"Failed to load {language} parser: {e}") from e
188
188
 
189
189
  # Create mapping from Paths to File objects and extract paths
190
- path_to_file_map: dict[Path, File] = {}
190
+ path_to_file_map: dict[Path, GitFile] = {}
191
191
  file_paths: list[Path] = []
192
192
  for file in files:
193
- file_path = file.as_path()
193
+ file_path = Path(file.path)
194
194
 
195
195
  # Validate file matches language
196
196
  if not self._file_matches_language(file_path.suffix, language):
@@ -225,7 +225,7 @@ class Slicer:
225
225
  self._build_reverse_call_graph(state)
226
226
 
227
227
  # Extract snippets for all functions
228
- snippets = []
228
+ snippets: list[SnippetV2] = []
229
229
  for qualified_name in state.def_index:
230
230
  snippet_content = self._get_snippet(
231
231
  qualified_name,
@@ -234,7 +234,7 @@ class Slicer:
234
234
  {"max_depth": 2, "max_functions": 8},
235
235
  )
236
236
  if "not found" not in snippet_content:
237
- snippet = self._create_snippet_entity(
237
+ snippet = self._create_snippet_entity_from_git_files(
238
238
  qualified_name, snippet_content, language, state, path_to_file_map
239
239
  )
240
240
  snippets.append(snippet)
@@ -247,8 +247,8 @@ class Slicer:
247
247
  return False
248
248
 
249
249
  try:
250
- return (
251
- language == LanguageMapping.get_language_for_extension(file_extension)
250
+ return language == LanguageMapping.get_language_for_extension(
251
+ file_extension
252
252
  )
253
253
  except ValueError:
254
254
  # Extension not supported, so it doesn't match any language
@@ -614,7 +614,8 @@ class Slicer:
614
614
  if callers:
615
615
  snippet_lines.append("")
616
616
  snippet_lines.append("# === USAGE EXAMPLES ===")
617
- for caller in list(callers)[:2]: # Show up to 2 examples
617
+ # Show up to 2 examples, sorted for deterministic order
618
+ for caller in sorted(callers)[:2]:
618
619
  call_line = self._find_function_call_line(
619
620
  caller, function_name, state, file_contents
620
621
  )
@@ -625,37 +626,37 @@ class Slicer:
625
626
 
626
627
  return "\n".join(snippet_lines)
627
628
 
628
- def _create_snippet_entity(
629
+ def _create_snippet_entity_from_git_files(
629
630
  self,
630
631
  qualified_name: str,
631
632
  snippet_content: str,
632
633
  language: str,
633
634
  state: AnalyzerState,
634
- path_to_file_map: dict[Path, File],
635
- ) -> Snippet:
635
+ path_to_file_map: dict[Path, GitFile],
636
+ ) -> SnippetV2:
636
637
  """Create a Snippet domain entity from extracted content."""
637
638
  # Determine all files that this snippet derives from
638
- derives_from_files = self._find_source_files_for_snippet(
639
+ derives_from_files = self._find_source_files_for_snippet_from_git_files(
639
640
  qualified_name, snippet_content, state, path_to_file_map
640
641
  )
641
642
 
642
643
  # Create the snippet entity
643
- snippet = Snippet(derives_from=derives_from_files)
644
-
645
- # Add the original content
646
- snippet.add_original_content(snippet_content, language)
647
-
648
- return snippet
644
+ return SnippetV2(
645
+ derives_from=derives_from_files,
646
+ content=snippet_content,
647
+ extension=language,
648
+ sha=SnippetV2.compute_sha(snippet_content),
649
+ )
649
650
 
650
- def _find_source_files_for_snippet(
651
+ def _find_source_files_for_snippet_from_git_files(
651
652
  self,
652
653
  qualified_name: str,
653
654
  snippet_content: str,
654
655
  state: AnalyzerState,
655
- path_to_file_map: dict[Path, File],
656
- ) -> list[File]:
656
+ path_to_file_map: dict[Path, GitFile],
657
+ ) -> list[GitFile]:
657
658
  """Find all source files that a snippet derives from."""
658
- source_files: list[File] = []
659
+ source_files: list[GitFile] = []
659
660
  source_file_paths: set[Path] = set()
660
661
 
661
662
  # Add the primary function's file
@@ -835,7 +836,7 @@ class Slicer:
835
836
  # Add direct dependencies
836
837
  to_visit.extend(
837
838
  (callee, depth + 1)
838
- for callee in state.call_graph.get(current, set())
839
+ for callee in sorted(state.call_graph.get(current, set()))
839
840
  if callee not in visited and callee in state.def_index
840
841
  )
841
842
 
@@ -850,26 +851,26 @@ class Slicer:
850
851
  in_degree: dict[str, int] = defaultdict(int)
851
852
  graph: dict[str, set[str]] = defaultdict(set)
852
853
 
853
- for func in functions:
854
- for callee in state.call_graph.get(func, set()):
854
+ for func in sorted(functions):
855
+ for callee in sorted(state.call_graph.get(func, set())):
855
856
  if callee in functions:
856
857
  graph[func].add(callee)
857
858
  in_degree[callee] += 1
858
859
 
859
860
  # Find roots
860
- queue = [f for f in functions if in_degree[f] == 0]
861
+ queue = [f for f in sorted(functions) if in_degree[f] == 0]
861
862
  result = []
862
863
 
863
864
  while queue:
864
865
  current = queue.pop(0)
865
866
  result.append(current)
866
- for neighbor in graph[current]:
867
+ for neighbor in sorted(graph[current]):
867
868
  in_degree[neighbor] -= 1
868
869
  if in_degree[neighbor] == 0:
869
870
  queue.append(neighbor)
870
871
 
871
872
  # Add any remaining (cycles)
872
- for func in functions:
873
+ for func in sorted(functions):
873
874
  if func not in result:
874
875
  result.append(func)
875
876
 
@@ -14,59 +14,79 @@ def create_embedding_repository(
14
14
  session_factory: Callable[[], AsyncSession],
15
15
  ) -> "SqlAlchemyEmbeddingRepository":
16
16
  """Create an embedding repository."""
17
- uow = SqlAlchemyUnitOfWork(session_factory=session_factory)
18
- return SqlAlchemyEmbeddingRepository(uow)
17
+ return SqlAlchemyEmbeddingRepository(session_factory=session_factory)
19
18
 
20
19
 
21
20
  class SqlAlchemyEmbeddingRepository:
22
21
  """SQLAlchemy implementation of embedding repository."""
23
22
 
24
- def __init__(self, uow: SqlAlchemyUnitOfWork) -> None:
23
+ def __init__(self, session_factory: Callable[[], AsyncSession]) -> None:
25
24
  """Initialize the SQLAlchemy embedding repository."""
26
- self.uow = uow
25
+ self.session_factory = session_factory
27
26
 
28
27
  async def create_embedding(self, embedding: Embedding) -> None:
29
28
  """Create a new embedding record in the database."""
30
- async with self.uow:
31
- self.uow.session.add(embedding)
29
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
30
+ session.add(embedding)
32
31
 
33
32
  async def get_embedding_by_snippet_id_and_type(
34
33
  self, snippet_id: int, embedding_type: EmbeddingType
35
34
  ) -> Embedding | None:
36
35
  """Get an embedding by its snippet ID and type."""
37
- async with self.uow:
36
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
38
37
  query = select(Embedding).where(
39
38
  Embedding.snippet_id == snippet_id,
40
39
  Embedding.type == embedding_type,
41
40
  )
42
- result = await self.uow.session.execute(query)
41
+ result = await session.execute(query)
43
42
  return result.scalar_one_or_none()
44
43
 
45
44
  async def list_embeddings_by_type(
46
45
  self, embedding_type: EmbeddingType
47
46
  ) -> list[Embedding]:
48
47
  """List all embeddings of a given type."""
49
- async with self.uow:
48
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
50
49
  query = select(Embedding).where(Embedding.type == embedding_type)
51
- result = await self.uow.session.execute(query)
50
+ result = await session.execute(query)
52
51
  return list(result.scalars())
53
52
 
54
- async def delete_embeddings_by_snippet_id(self, snippet_id: int) -> None:
53
+ async def delete_embeddings_by_snippet_id(self, snippet_id: str) -> None:
55
54
  """Delete all embeddings for a snippet."""
56
- async with self.uow:
55
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
57
56
  query = select(Embedding).where(Embedding.snippet_id == snippet_id)
58
- result = await self.uow.session.execute(query)
57
+ result = await session.execute(query)
59
58
  embeddings = result.scalars().all()
60
59
  for embedding in embeddings:
61
- await self.uow.session.delete(embedding)
60
+ await session.delete(embedding)
61
+
62
+ async def list_embeddings_by_snippet_ids_and_type(
63
+ self, snippet_ids: list[str], embedding_type: EmbeddingType
64
+ ) -> list[Embedding]:
65
+ """Get all embeddings for the given snippet IDs."""
66
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
67
+ query = select(Embedding).where(
68
+ Embedding.snippet_id.in_(snippet_ids),
69
+ Embedding.type == embedding_type,
70
+ )
71
+ result = await session.execute(query)
72
+ return list(result.scalars())
73
+
74
+ async def get_embeddings_by_snippet_ids(
75
+ self, snippet_ids: list[str]
76
+ ) -> list[Embedding]:
77
+ """Get all embeddings for the given snippet IDs."""
78
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
79
+ query = select(Embedding).where(Embedding.snippet_id.in_(snippet_ids))
80
+ result = await session.execute(query)
81
+ return list(result.scalars())
62
82
 
63
83
  async def list_semantic_results(
64
84
  self,
65
85
  embedding_type: EmbeddingType,
66
86
  embedding: list[float],
67
87
  top_k: int = 10,
68
- snippet_ids: list[int] | None = None,
69
- ) -> list[tuple[int, float]]:
88
+ snippet_ids: list[str] | None = None,
89
+ ) -> list[tuple[str, float]]:
70
90
  """List semantic results using cosine similarity.
71
91
 
72
92
  This implementation fetches all embeddings of the given type and computes
@@ -97,8 +117,8 @@ class SqlAlchemyEmbeddingRepository:
97
117
  return self._get_top_k_results(similarities, embeddings, top_k)
98
118
 
99
119
  async def _list_embedding_values(
100
- self, embedding_type: EmbeddingType, snippet_ids: list[int] | None = None
101
- ) -> list[tuple[int, list[float]]]:
120
+ self, embedding_type: EmbeddingType, snippet_ids: list[str] | None = None
121
+ ) -> list[tuple[str, list[float]]]:
102
122
  """List all embeddings of a given type from the database.
103
123
 
104
124
  Args:
@@ -109,7 +129,7 @@ class SqlAlchemyEmbeddingRepository:
109
129
  List of (snippet_id, embedding) tuples
110
130
 
111
131
  """
112
- async with self.uow:
132
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
113
133
  query = select(Embedding.snippet_id, Embedding.embedding).where(
114
134
  Embedding.type == embedding_type
115
135
  )
@@ -118,11 +138,11 @@ class SqlAlchemyEmbeddingRepository:
118
138
  if snippet_ids is not None:
119
139
  query = query.where(Embedding.snippet_id.in_(snippet_ids))
120
140
 
121
- rows = await self.uow.session.execute(query)
141
+ rows = await session.execute(query)
122
142
  return [tuple(row) for row in rows.all()] # Convert Row objects to tuples
123
143
 
124
144
  def _prepare_vectors(
125
- self, embeddings: list[tuple[int, list[float]]], query_embedding: list[float]
145
+ self, embeddings: list[tuple[str, list[float]]], query_embedding: list[float]
126
146
  ) -> tuple[np.ndarray, np.ndarray]:
127
147
  """Convert embeddings to numpy arrays.
128
148
 
@@ -191,9 +211,9 @@ class SqlAlchemyEmbeddingRepository:
191
211
  def _get_top_k_results(
192
212
  self,
193
213
  similarities: np.ndarray,
194
- embeddings: list[tuple[int, list[float]]],
214
+ embeddings: list[tuple[str, list[float]]],
195
215
  top_k: int,
196
- ) -> list[tuple[int, float]]:
216
+ ) -> list[tuple[str, float]]:
197
217
  """Get top-k results by similarity score.
198
218
 
199
219
  Args: