kodit 0.4.2__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.

Files changed (100)
  1. kodit/_version.py +2 -2
  2. kodit/app.py +59 -24
  3. kodit/application/factories/reporting_factory.py +16 -7
  4. kodit/application/factories/server_factory.py +311 -0
  5. kodit/application/services/code_search_application_service.py +144 -0
  6. kodit/application/services/commit_indexing_application_service.py +543 -0
  7. kodit/application/services/indexing_worker_service.py +13 -46
  8. kodit/application/services/queue_service.py +24 -3
  9. kodit/application/services/reporting.py +70 -54
  10. kodit/application/services/sync_scheduler.py +15 -31
  11. kodit/cli.py +2 -763
  12. kodit/cli_utils.py +2 -9
  13. kodit/config.py +3 -96
  14. kodit/database.py +38 -1
  15. kodit/domain/entities/__init__.py +276 -0
  16. kodit/domain/entities/git.py +190 -0
  17. kodit/domain/factories/__init__.py +1 -0
  18. kodit/domain/factories/git_repo_factory.py +76 -0
  19. kodit/domain/protocols.py +270 -46
  20. kodit/domain/services/bm25_service.py +5 -1
  21. kodit/domain/services/embedding_service.py +3 -0
  22. kodit/domain/services/git_repository_service.py +429 -0
  23. kodit/domain/services/git_service.py +300 -0
  24. kodit/domain/services/task_status_query_service.py +19 -0
  25. kodit/domain/value_objects.py +113 -147
  26. kodit/infrastructure/api/client/__init__.py +0 -2
  27. kodit/infrastructure/api/v1/__init__.py +0 -4
  28. kodit/infrastructure/api/v1/dependencies.py +105 -44
  29. kodit/infrastructure/api/v1/routers/__init__.py +0 -6
  30. kodit/infrastructure/api/v1/routers/commits.py +271 -0
  31. kodit/infrastructure/api/v1/routers/queue.py +2 -2
  32. kodit/infrastructure/api/v1/routers/repositories.py +282 -0
  33. kodit/infrastructure/api/v1/routers/search.py +31 -14
  34. kodit/infrastructure/api/v1/schemas/__init__.py +0 -24
  35. kodit/infrastructure/api/v1/schemas/commit.py +96 -0
  36. kodit/infrastructure/api/v1/schemas/context.py +2 -0
  37. kodit/infrastructure/api/v1/schemas/repository.py +128 -0
  38. kodit/infrastructure/api/v1/schemas/search.py +12 -9
  39. kodit/infrastructure/api/v1/schemas/snippet.py +58 -0
  40. kodit/infrastructure/api/v1/schemas/tag.py +31 -0
  41. kodit/infrastructure/api/v1/schemas/task_status.py +41 -0
  42. kodit/infrastructure/bm25/local_bm25_repository.py +16 -4
  43. kodit/infrastructure/bm25/vectorchord_bm25_repository.py +68 -52
  44. kodit/infrastructure/cloning/git/git_python_adaptor.py +467 -0
  45. kodit/infrastructure/cloning/git/working_copy.py +10 -3
  46. kodit/infrastructure/embedding/embedding_factory.py +3 -2
  47. kodit/infrastructure/embedding/local_vector_search_repository.py +1 -1
  48. kodit/infrastructure/embedding/vectorchord_vector_search_repository.py +111 -84
  49. kodit/infrastructure/enrichment/litellm_enrichment_provider.py +19 -26
  50. kodit/infrastructure/enrichment/local_enrichment_provider.py +41 -30
  51. kodit/infrastructure/indexing/fusion_service.py +1 -1
  52. kodit/infrastructure/mappers/git_mapper.py +193 -0
  53. kodit/infrastructure/mappers/snippet_mapper.py +106 -0
  54. kodit/infrastructure/mappers/task_mapper.py +5 -44
  55. kodit/infrastructure/mappers/task_status_mapper.py +85 -0
  56. kodit/infrastructure/reporting/db_progress.py +23 -0
  57. kodit/infrastructure/reporting/log_progress.py +13 -38
  58. kodit/infrastructure/reporting/telemetry_progress.py +21 -0
  59. kodit/infrastructure/slicing/slicer.py +32 -31
  60. kodit/infrastructure/sqlalchemy/embedding_repository.py +43 -23
  61. kodit/infrastructure/sqlalchemy/entities.py +428 -131
  62. kodit/infrastructure/sqlalchemy/git_branch_repository.py +263 -0
  63. kodit/infrastructure/sqlalchemy/git_commit_repository.py +337 -0
  64. kodit/infrastructure/sqlalchemy/git_repository.py +252 -0
  65. kodit/infrastructure/sqlalchemy/git_tag_repository.py +257 -0
  66. kodit/infrastructure/sqlalchemy/snippet_v2_repository.py +484 -0
  67. kodit/infrastructure/sqlalchemy/task_repository.py +29 -23
  68. kodit/infrastructure/sqlalchemy/task_status_repository.py +91 -0
  69. kodit/infrastructure/sqlalchemy/unit_of_work.py +10 -14
  70. kodit/mcp.py +12 -26
  71. kodit/migrations/env.py +1 -1
  72. kodit/migrations/versions/04b80f802e0c_foreign_key_review.py +100 -0
  73. kodit/migrations/versions/7f15f878c3a1_add_new_git_entities.py +690 -0
  74. kodit/migrations/versions/b9cd1c3fd762_add_task_status.py +77 -0
  75. kodit/migrations/versions/f9e5ef5e688f_add_git_commits_number.py +43 -0
  76. kodit/py.typed +0 -0
  77. kodit/utils/dump_openapi.py +7 -4
  78. kodit/utils/path_utils.py +29 -0
  79. {kodit-0.4.2.dist-info → kodit-0.5.0.dist-info}/METADATA +3 -3
  80. kodit-0.5.0.dist-info/RECORD +137 -0
  81. kodit/application/factories/code_indexing_factory.py +0 -193
  82. kodit/application/services/auto_indexing_service.py +0 -103
  83. kodit/application/services/code_indexing_application_service.py +0 -393
  84. kodit/domain/entities.py +0 -323
  85. kodit/domain/services/index_query_service.py +0 -70
  86. kodit/domain/services/index_service.py +0 -267
  87. kodit/infrastructure/api/client/index_client.py +0 -57
  88. kodit/infrastructure/api/v1/routers/indexes.py +0 -119
  89. kodit/infrastructure/api/v1/schemas/index.py +0 -101
  90. kodit/infrastructure/bm25/bm25_factory.py +0 -28
  91. kodit/infrastructure/cloning/__init__.py +0 -1
  92. kodit/infrastructure/cloning/metadata.py +0 -98
  93. kodit/infrastructure/mappers/index_mapper.py +0 -345
  94. kodit/infrastructure/reporting/tdqm_progress.py +0 -73
  95. kodit/infrastructure/slicing/language_detection_service.py +0 -18
  96. kodit/infrastructure/sqlalchemy/index_repository.py +0 -646
  97. kodit-0.4.2.dist-info/RECORD +0 -119
  98. {kodit-0.4.2.dist-info → kodit-0.5.0.dist-info}/WHEEL +0 -0
  99. {kodit-0.4.2.dist-info → kodit-0.5.0.dist-info}/entry_points.txt +0 -0
  100. {kodit-0.4.2.dist-info → kodit-0.5.0.dist-info}/licenses/LICENSE +0 -0
kodit/infrastructure/embedding/vectorchord_vector_search_repository.py

@@ -1,10 +1,10 @@
  """VectorChord vector search repository implementation."""

- from collections.abc import AsyncGenerator
- from typing import Any, Literal
+ from collections.abc import AsyncGenerator, Callable
+ from typing import Literal

  import structlog
- from sqlalchemy import Result, TextClause, text
+ from sqlalchemy import text
  from sqlalchemy.ext.asyncio import AsyncSession

  from kodit.domain.services.embedding_service import (
@@ -19,6 +19,7 @@ from kodit.domain.value_objects import (
  SearchResult,
  )
  from kodit.infrastructure.sqlalchemy.entities import EmbeddingType
+ from kodit.infrastructure.sqlalchemy.unit_of_work import SqlAlchemyUnitOfWork

  # SQL Queries
  CREATE_VCHORD_EXTENSION = """
@@ -72,6 +73,10 @@ CHECK_VCHORD_EMBEDDING_EXISTS = """
  SELECT EXISTS(SELECT 1 FROM {TABLE_NAME} WHERE snippet_id = :snippet_id)
  """

+ CHECK_VCHORD_EMBEDDING_EXISTS_MULTIPLE = """
+ SELECT snippet_id FROM {TABLE_NAME} WHERE snippet_id = ANY(:snippet_ids)
+ """
+
  TaskName = Literal["code", "text"]


@@ -80,8 +85,8 @@ class VectorChordVectorSearchRepository(VectorSearchRepository):

  def __init__(
  self,
+ session_factory: Callable[[], AsyncSession],
  task_name: TaskName,
- session: AsyncSession,
  embedding_provider: EmbeddingProvider,
  ) -> None:
  """Initialize the VectorChord vector search repository.
@@ -93,7 +98,7 @@ class VectorChordVectorSearchRepository(VectorSearchRepository):

  """
  self.embedding_provider = embedding_provider
- self._session = session
+ self.session_factory = session_factory
  self._initialized = False
  self.table_name = f"vectorchord_{task_name}_embeddings"
  self.index_name = f"{self.table_name}_idx"
@@ -111,12 +116,12 @@ class VectorChordVectorSearchRepository(VectorSearchRepository):

  async def _create_extensions(self) -> None:
  """Create the necessary extensions."""
- await self._session.execute(text(CREATE_VCHORD_EXTENSION))
- await self._commit()
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
+ await session.execute(text(CREATE_VCHORD_EXTENSION))

  async def _create_tables(self) -> None:
  """Create the necessary tables."""
- req = EmbeddingRequest(snippet_id=0, text="dimension")
+ req = EmbeddingRequest(snippet_id="0", text="dimension")
  vector_dim: list[float] | None = None
  async for batch in self.embedding_provider.embed([req]):
  if batch:
@@ -125,79 +130,85 @@ class VectorChordVectorSearchRepository(VectorSearchRepository):
  if vector_dim is None:
  msg = "Failed to obtain embedding dimension from provider"
  raise RuntimeError(msg)
- await self._session.execute(
- text(
- f"""CREATE TABLE IF NOT EXISTS {self.table_name} (
- id SERIAL PRIMARY KEY,
- snippet_id INT NOT NULL UNIQUE,
- embedding VECTOR({len(vector_dim)}) NOT NULL
- );"""
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
+ await session.execute(
+ text(
+ f"""CREATE TABLE IF NOT EXISTS {self.table_name} (
+ id SERIAL PRIMARY KEY,
+ snippet_id VARCHAR(255) NOT NULL UNIQUE,
+ embedding VECTOR({len(vector_dim)}) NOT NULL
+ );"""
+ )
  )
- )
- await self._session.execute(
- text(
- CREATE_VCHORD_INDEX.format(
- TABLE_NAME=self.table_name, INDEX_NAME=self.index_name
+ await session.execute(
+ text(
+ CREATE_VCHORD_INDEX.format(
+ TABLE_NAME=self.table_name, INDEX_NAME=self.index_name
+ )
  )
  )
- )
- result = await self._session.execute(
- text(CHECK_VCHORD_EMBEDDING_DIMENSION.format(TABLE_NAME=self.table_name))
- )
- vector_dim_from_db = result.scalar_one()
- if vector_dim_from_db != len(vector_dim):
- msg = (
- f"Embedding vector dimension does not match database, "
- f"please delete your index: {vector_dim_from_db} != {len(vector_dim)}"
+ result = await session.execute(
+ text(
+ CHECK_VCHORD_EMBEDDING_DIMENSION.format(TABLE_NAME=self.table_name)
+ )
  )
- raise ValueError(msg)
- await self._commit()
-
- async def _execute(
- self, query: TextClause, param_list: list[Any] | dict[str, Any] | None = None
- ) -> Result:
- """Execute a query."""
- if not self._initialized:
- await self._initialize()
- return await self._session.execute(query, param_list)
-
- async def _commit(self) -> None:
- """Commit the session."""
- await self._session.commit()
+ vector_dim_from_db = result.scalar_one()
+ if vector_dim_from_db != len(vector_dim):
+ msg = (
+ f"Embedding vector dimension does not match database, please "
+ f"delete your index: {vector_dim_from_db} != {len(vector_dim)}"
+ )
+ raise ValueError(msg)

  async def index_documents(
  self, request: IndexRequest
  ) -> AsyncGenerator[list[IndexResult], None]:
  """Index documents for vector search."""
+ if not self._initialized:
+ await self._initialize()
+
  if not request.documents:
  yield []

+ # Search for existing embeddings
+ existing_ids = await self._get_existing_ids(
+ [doc.snippet_id for doc in request.documents]
+ )
+ new_documents = [
+ doc for doc in request.documents if doc.snippet_id not in existing_ids
+ ]
+ if not new_documents:
+ self.log.info("No new documents to index")
+ return
+
  # Convert to embedding requests
- requests = [
+ embedding_requests = [
  EmbeddingRequest(snippet_id=doc.snippet_id, text=doc.text)
- for doc in request.documents
+ for doc in new_documents
  ]

- async for batch in self.embedding_provider.embed(requests):
- await self._execute(
- text(INSERT_QUERY.format(TABLE_NAME=self.table_name)),
- [
- {
- "snippet_id": result.snippet_id,
- "embedding": str(result.embedding),
- }
- for result in batch
- ],
- )
- await self._commit()
- yield [IndexResult(snippet_id=result.snippet_id) for result in batch]
+ async for batch in self.embedding_provider.embed(embedding_requests):
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
+ await session.execute(
+ text(INSERT_QUERY.format(TABLE_NAME=self.table_name)),
+ [
+ {
+ "snippet_id": result.snippet_id,
+ "embedding": str(result.embedding),
+ }
+ for result in batch
+ ],
+ )
+ yield [IndexResult(snippet_id=result.snippet_id) for result in batch]

  async def search(self, request: SearchRequest) -> list[SearchResult]:
  """Search documents using vector similarity."""
+ if not self._initialized:
+ await self._initialize()
  if not request.query or not request.query.strip():
  return []

- req = EmbeddingRequest(snippet_id=0, text=request.query)
+ req = EmbeddingRequest(snippet_id="0", text=request.query)
  embedding_vec: list[float] | None = None
  async for batch in self.embedding_provider.embed([req]):
  if batch:
@@ -207,39 +218,55 @@ class VectorChordVectorSearchRepository(VectorSearchRepository):
  if not embedding_vec:
  return []

- # Use filtered query if snippet_ids are provided
- if request.snippet_ids is not None:
- result = await self._execute(
- text(SEARCH_QUERY_WITH_FILTER.format(TABLE_NAME=self.table_name)),
- {
- "query": str(embedding_vec),
- "top_k": request.top_k,
- "snippet_ids": request.snippet_ids,
- },
- )
- else:
- result = await self._execute(
- text(SEARCH_QUERY.format(TABLE_NAME=self.table_name)),
- {"query": str(embedding_vec), "top_k": request.top_k},
- )
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
+ # Use filtered query if snippet_ids are provided
+ if request.snippet_ids is not None:
+ result = await session.execute(
+ text(SEARCH_QUERY_WITH_FILTER.format(TABLE_NAME=self.table_name)),
+ {
+ "query": str(embedding_vec),
+ "top_k": request.top_k,
+ "snippet_ids": request.snippet_ids,
+ },
+ )
+ else:
+ result = await session.execute(
+ text(SEARCH_QUERY.format(TABLE_NAME=self.table_name)),
+ {"query": str(embedding_vec), "top_k": request.top_k},
+ )

- rows = result.mappings().all()
+ rows = result.mappings().all()

- return [
- SearchResult(snippet_id=row["snippet_id"], score=row["score"])
- for row in rows
- ]
+ return [
+ SearchResult(snippet_id=row["snippet_id"], score=row["score"])
+ for row in rows
+ ]

  async def has_embedding(
  self, snippet_id: int, embedding_type: EmbeddingType
  ) -> bool:
  """Check if a snippet has an embedding."""
+ if not self._initialized:
+ await self._initialize()
  # For VectorChord, we check if the snippet exists in the table
  # Note: embedding_type is ignored since VectorChord uses separate
  # tables per task
  # ruff: noqa: ARG002
- result = await self._execute(
- text(CHECK_VCHORD_EMBEDDING_EXISTS.format(TABLE_NAME=self.table_name)),
- {"snippet_id": snippet_id},
- )
- return bool(result.scalar())
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
+ result = await session.execute(
+ text(CHECK_VCHORD_EMBEDDING_EXISTS.format(TABLE_NAME=self.table_name)),
+ {"snippet_id": snippet_id},
+ )
+ return bool(result.scalar())
+
+ async def _get_existing_ids(self, snippet_ids: list[str]) -> set[str]:
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
+ result = await session.execute(
+ text(
+ CHECK_VCHORD_EMBEDDING_EXISTS_MULTIPLE.format(
+ TABLE_NAME=self.table_name
+ )
+ ),
+ {"snippet_ids": snippet_ids},
+ )
+ return {row[0] for row in result.fetchall()}
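
Note on the change above: the repository no longer holds a single AsyncSession and commits manually; every operation now opens a short-lived session from session_factory and scopes it with SqlAlchemyUnitOfWork, and snippet IDs move from integers to strings (the column becomes VARCHAR(255)). The sketch below is only an illustration of the unit-of-work behaviour the new code appears to rely on (commit on success, roll back on error); the real SqlAlchemyUnitOfWork in kodit.infrastructure.sqlalchemy.unit_of_work is not shown in this diff and may differ.

# Illustrative only: an async unit of work with the behaviour the new code assumes.
from collections.abc import Callable

from sqlalchemy.ext.asyncio import AsyncSession


class UnitOfWorkSketch:
    def __init__(self, session_factory: Callable[[], AsyncSession]) -> None:
        self._session_factory = session_factory
        self._session: AsyncSession | None = None

    async def __aenter__(self) -> AsyncSession:
        # A fresh session per block of work instead of one shared session.
        self._session = self._session_factory()
        return self._session

    async def __aexit__(self, exc_type, exc, tb) -> None:
        if self._session is None:
            return
        try:
            if exc_type is None:
                await self._session.commit()  # replaces the removed _commit() helper
            else:
                await self._session.rollback()  # failures no longer leave a dirty session
        finally:
            await self._session.close()

Because each method owns its own session, the old _execute/_commit helpers and the long-lived session passed into __init__ become unnecessary.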
kodit/infrastructure/enrichment/litellm_enrichment_provider.py

@@ -128,32 +128,25 @@ class LiteLLMEnrichmentProvider(EnrichmentProvider):
  snippet_id=request.snippet_id,
  text="",
  )
- try:
- messages = [
- {
- "role": "system",
- "content": ENRICHMENT_SYSTEM_PROMPT,
- },
- {"role": "user", "content": request.text},
- ]
- response = await self._call_chat_completion(messages)
- content = (
- response.get("choices", [{}])[0]
- .get("message", {})
- .get("content", "")
- )
- # Remove thinking tags from the response
- cleaned_content = clean_thinking_tags(content or "")
- return EnrichmentResponse(
- snippet_id=request.snippet_id,
- text=cleaned_content,
- )
- except Exception as e:
- self.log.exception("Error enriching request", error=str(e))
- return EnrichmentResponse(
- snippet_id=request.snippet_id,
- text="",
- )
+ messages = [
+ {
+ "role": "system",
+ "content": ENRICHMENT_SYSTEM_PROMPT,
+ },
+ {"role": "user", "content": request.text},
+ ]
+ response = await self._call_chat_completion(messages)
+ content = (
+ response.get("choices", [{}])[0]
+ .get("message", {})
+ .get("content", "")
+ )
+ # Remove thinking tags from the response
+ cleaned_content = clean_thinking_tags(content or "")
+ return EnrichmentResponse(
+ snippet_id=request.snippet_id,
+ text=cleaned_content,
+ )

  # Create tasks for all requests
  tasks = [process_request(request) for request in requests]
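
With the try/except removed from process_request, an enrichment failure now raises out of the coroutine instead of being swallowed into an empty EnrichmentResponse. How the surrounding code gathers these tasks is not shown in this hunk; the snippet below is a hypothetical caller-side sketch of one way such failures could be surfaced per task (names are illustrative, not from kodit).

# Hypothetical caller-side sketch: collect results while keeping per-task failures visible.
import asyncio


async def gather_enrichments(tasks: list) -> list:
    results = await asyncio.gather(*tasks, return_exceptions=True)
    responses = []
    for result in results:
        if isinstance(result, BaseException):
            # One place to decide whether to log, retry, or fail the whole batch.
            raise result
        responses.append(result)
    return responses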
kodit/infrastructure/enrichment/local_enrichment_provider.py

@@ -1,7 +1,9 @@
  """Local enrichment provider implementation."""

+ import asyncio
  import os
  from collections.abc import AsyncGenerator
+ from typing import Any

  import structlog
  import tiktoken
@@ -60,23 +62,26 @@ class LocalEnrichmentProvider(EnrichmentProvider):
  self.log.warning("No valid requests for enrichment")
  return

- from transformers.models.auto.modeling_auto import (
- AutoModelForCausalLM,
- )
- from transformers.models.auto.tokenization_auto import AutoTokenizer
-
- if self.tokenizer is None:
- self.tokenizer = AutoTokenizer.from_pretrained(
- self.model_name, padding_side="left"
- )
- if self.model is None:
- os.environ["TOKENIZERS_PARALLELISM"] = "false" # Avoid warnings
- self.model = AutoModelForCausalLM.from_pretrained(
- self.model_name,
- torch_dtype="auto",
- trust_remote_code=True,
- device_map="auto",
+ def _init_model() -> None:
+ from transformers.models.auto.modeling_auto import (
+ AutoModelForCausalLM,
  )
+ from transformers.models.auto.tokenization_auto import AutoTokenizer
+
+ if self.tokenizer is None:
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ self.model_name, padding_side="left"
+ )
+ if self.model is None:
+ os.environ["TOKENIZERS_PARALLELISM"] = "false" # Avoid warnings
+ self.model = AutoModelForCausalLM.from_pretrained(
+ self.model_name,
+ torch_dtype="auto",
+ trust_remote_code=True,
+ device_map="auto",
+ )
+
+ await asyncio.to_thread(_init_model)

  # Prepare prompts
  prompts = [
@@ -96,20 +101,26 @@ class LocalEnrichmentProvider(EnrichmentProvider):
  ]

  for prompt in prompts:
- model_inputs = self.tokenizer( # type: ignore[misc]
- prompt["text"],
- return_tensors="pt",
- padding=True,
- truncation=True,
- ).to(self.model.device) # type: ignore[attr-defined]
- generated_ids = self.model.generate( # type: ignore[attr-defined]
- **model_inputs, max_new_tokens=self.context_window
- )
- input_ids = model_inputs["input_ids"][0]
- output_ids = generated_ids[0][len(input_ids) :].tolist()
- content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip( # type: ignore[attr-defined]
- "\n"
- )
+
+ def process_prompt(prompt: dict[str, Any]) -> str:
+ model_inputs = self.tokenizer( # type: ignore[misc]
+ prompt["text"],
+ return_tensors="pt",
+ padding=True,
+ truncation=True,
+ ).to(self.model.device) # type: ignore[attr-defined]
+ generated_ids = self.model.generate( # type: ignore[attr-defined]
+ **model_inputs, max_new_tokens=self.context_window
+ )
+ input_ids = model_inputs["input_ids"][0]
+ output_ids = generated_ids[0][len(input_ids) :].tolist()
+ return self.tokenizer.decode( # type: ignore[attr-defined]
+ output_ids, skip_special_tokens=True
+ ).strip( # type: ignore[attr-defined]
+ "\n"
+ )
+
+ content = await asyncio.to_thread(process_prompt, prompt)
  # Remove thinking tags from the response
  cleaned_content = clean_thinking_tags(content)
  yield EnrichmentResponse(
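
The local provider now pushes the blocking transformers work (model and tokenizer loading, then generation) onto a worker thread with asyncio.to_thread, so the event loop keeps serving other coroutines while a prompt is processed. Below is a minimal, self-contained sketch of that pattern; blocking_generate is a placeholder standing in for the real tokenizer and model.generate() calls, not kodit code.

# Self-contained sketch of the asyncio.to_thread pattern used above.
import asyncio
import time


def blocking_generate(prompt: str) -> str:
    time.sleep(0.1)  # stands in for CPU/GPU-bound generation
    return prompt.upper()


async def enrich_all(prompts: list[str]) -> list[str]:
    results: list[str] = []
    for prompt in prompts:
        # The event loop stays free while the worker thread runs the blocking call.
        results.append(await asyncio.to_thread(blocking_generate, prompt))
    return results


print(asyncio.run(enrich_all(["hello", "world"])))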
kodit/infrastructure/indexing/fusion_service.py

@@ -2,7 +2,7 @@

  from collections import defaultdict

- from kodit.domain.services.index_query_service import FusionService
+ from kodit.domain.protocols import FusionService
  from kodit.domain.value_objects import FusionRequest, FusionResult


kodit/infrastructure/mappers/git_mapper.py (new file)

@@ -0,0 +1,193 @@
+ """Mapping between domain Git entities and SQLAlchemy entities."""
+
+ from collections import defaultdict
+ from pathlib import Path
+
+ from pydantic import AnyUrl
+
+ import kodit.domain.entities.git as domain_git_entities
+ from kodit.infrastructure.sqlalchemy import entities as db_entities
+
+
+ class GitMapper:
+ """Mapper for converting between domain Git entities and database entities."""
+
+ def to_domain_commits(
+ self,
+ db_commits: list[db_entities.GitCommit],
+ db_commit_files: list[db_entities.GitCommitFile],
+ ) -> list[domain_git_entities.GitCommit]:
+ """Convert SQLAlchemy GitCommit to domain GitCommit."""
+ commit_files_map = defaultdict(list)
+ for file in db_commit_files:
+ commit_files_map[file.commit_sha].append(file.blob_sha)
+
+ commit_domain_files_map = defaultdict(list)
+ for file in db_commit_files:
+ commit_domain_files_map[file.commit_sha].append(
+ domain_git_entities.GitFile(
+ created_at=file.created_at,
+ blob_sha=file.blob_sha,
+ path=file.path,
+ mime_type=file.mime_type,
+ size=file.size,
+ extension=file.extension,
+ )
+ )
+
+ domain_commits = []
+ for db_commit in db_commits:
+ domain_commit = domain_git_entities.GitCommit(
+ created_at=db_commit.created_at,
+ updated_at=db_commit.updated_at,
+ commit_sha=db_commit.commit_sha,
+ date=db_commit.date,
+ message=db_commit.message,
+ parent_commit_sha=db_commit.parent_commit_sha,
+ files=commit_domain_files_map[db_commit.commit_sha],
+ author=db_commit.author,
+ )
+ domain_commits.append(domain_commit)
+ return domain_commits
+
+ def to_domain_branches(
+ self,
+ db_branches: list[db_entities.GitBranch],
+ domain_commits: list[domain_git_entities.GitCommit],
+ ) -> list[domain_git_entities.GitBranch]:
+ """Convert SQLAlchemy GitBranch to domain GitBranch."""
+ commit_map = {commit.commit_sha: commit for commit in domain_commits}
+ domain_branches = []
+ for db_branch in db_branches:
+ if db_branch.head_commit_sha not in commit_map:
+ raise ValueError(
+ f"Commit {db_branch.head_commit_sha} for "
+ f"branch {db_branch.name} not found in commits: {commit_map.keys()}"
+ )
+ domain_branch = domain_git_entities.GitBranch(
+ repo_id=db_branch.repo_id,
+ name=db_branch.name,
+ created_at=db_branch.created_at,
+ updated_at=db_branch.updated_at,
+ head_commit=commit_map[db_branch.head_commit_sha],
+ )
+ domain_branches.append(domain_branch)
+ return domain_branches
+
+ def to_domain_tags(
+ self,
+ db_tags: list[db_entities.GitTag],
+ domain_commits: list[domain_git_entities.GitCommit],
+ ) -> list[domain_git_entities.GitTag]:
+ """Convert SQLAlchemy GitTag to domain GitTag."""
+ commit_map = {commit.commit_sha: commit for commit in domain_commits}
+ domain_tags = []
+ for db_tag in db_tags:
+ if db_tag.target_commit_sha not in commit_map:
+ raise ValueError(
+ f"Commit {db_tag.target_commit_sha} for tag {db_tag.name} not found"
+ )
+ domain_tag = domain_git_entities.GitTag(
+ created_at=db_tag.created_at,
+ updated_at=db_tag.updated_at,
+ repo_id=db_tag.repo_id,
+ name=db_tag.name,
+ target_commit=commit_map[db_tag.target_commit_sha],
+ )
+ domain_tags.append(domain_tag)
+ return domain_tags
+
+ def to_domain_tracking_branch(
+ self,
+ db_tracking_branch: db_entities.GitTrackingBranch | None,
+ db_tracking_branch_entity: db_entities.GitBranch | None,
+ domain_commits: list[domain_git_entities.GitCommit],
+ ) -> domain_git_entities.GitBranch | None:
+ """Convert SQLAlchemy GitTrackingBranch to domain GitBranch."""
+ if db_tracking_branch is None or db_tracking_branch_entity is None:
+ return None
+
+ commit_map = {commit.commit_sha: commit for commit in domain_commits}
+ if db_tracking_branch_entity.head_commit_sha not in commit_map:
+ raise ValueError(
+ f"Commit {db_tracking_branch_entity.head_commit_sha} for "
+ f"tracking branch {db_tracking_branch.name} not found"
+ )
+
+ return domain_git_entities.GitBranch(
+ repo_id=db_tracking_branch_entity.repo_id,
+ name=db_tracking_branch_entity.name,
+ created_at=db_tracking_branch_entity.created_at,
+ updated_at=db_tracking_branch_entity.updated_at,
+ head_commit=commit_map[db_tracking_branch_entity.head_commit_sha],
+ )
+
+ def to_domain_git_repo( # noqa: PLR0913
+ self,
+ db_repo: db_entities.GitRepo,
+ db_tracking_branch_entity: db_entities.GitBranch | None,
+ db_commits: list[db_entities.GitCommit],
+ db_tags: list[db_entities.GitTag],
+ db_commit_files: list[db_entities.GitCommitFile],
+ db_tracking_branch: db_entities.GitTrackingBranch | None,
+ ) -> domain_git_entities.GitRepo:
+ """Convert SQLAlchemy GitRepo to domain GitRepo."""
+ # Build commits needed for tags and tracking branch
+ domain_commits = self.to_domain_commits(
+ db_commits=db_commits, db_commit_files=db_commit_files
+ )
+ self.to_domain_tags(
+ db_tags=db_tags, domain_commits=domain_commits
+ )
+ tracking_branch = self.to_domain_tracking_branch(
+ db_tracking_branch=db_tracking_branch,
+ db_tracking_branch_entity=db_tracking_branch_entity,
+ domain_commits=domain_commits,
+ )
+
+ from kodit.domain.factories.git_repo_factory import GitRepoFactory
+
+ return GitRepoFactory.create_from_components(
+ repo_id=db_repo.id,
+ created_at=db_repo.created_at,
+ updated_at=db_repo.updated_at,
+ sanitized_remote_uri=AnyUrl(db_repo.sanitized_remote_uri),
+ remote_uri=AnyUrl(db_repo.remote_uri),
+ tracking_branch=tracking_branch,
+ cloned_path=Path(db_repo.cloned_path) if db_repo.cloned_path else None,
+ last_scanned_at=db_repo.last_scanned_at,
+ num_commits=db_repo.num_commits,
+ num_branches=db_repo.num_branches,
+ num_tags=db_repo.num_tags,
+ )
+
+ def to_domain_commit_index(
+ self,
+ db_commit_index: db_entities.CommitIndex,
+ snippets: list[domain_git_entities.SnippetV2],
+ ) -> domain_git_entities.CommitIndex:
+ """Convert SQLAlchemy CommitIndex to domain CommitIndex."""
+ return domain_git_entities.CommitIndex(
+ commit_sha=db_commit_index.commit_sha,
+ created_at=db_commit_index.created_at,
+ updated_at=db_commit_index.updated_at,
+ snippets=snippets,
+ status=domain_git_entities.IndexStatus(db_commit_index.status),
+ indexed_at=db_commit_index.indexed_at,
+ error_message=db_commit_index.error_message,
+ files_processed=db_commit_index.files_processed,
+ processing_time_seconds=float(db_commit_index.processing_time_seconds),
+ )
+
+ def from_domain_commit_index(
+ self, domain_commit_index: domain_git_entities.CommitIndex
+ ) -> db_entities.CommitIndex:
+ """Convert domain CommitIndex to SQLAlchemy CommitIndex."""
+ return db_entities.CommitIndex(
+ commit_sha=domain_commit_index.commit_sha,
+ status=domain_commit_index.status,
+ indexed_at=domain_commit_index.indexed_at,
+ error_message=domain_commit_index.error_message,
+ files_processed=domain_commit_index.files_processed,
+ processing_time_seconds=domain_commit_index.processing_time_seconds,
+ )
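
GitMapper.to_domain_commits groups the flat GitCommitFile rows by commit SHA before building the domain aggregates. The small, self-contained sketch below shows just that group-then-build pattern; FileRow and Commit are illustrative stand-ins, not kodit types.

# Illustrative only: stand-ins for db_entities.GitCommitFile and the domain GitCommit.
from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class FileRow:
    commit_sha: str
    path: str


@dataclass
class Commit:
    commit_sha: str
    files: list[FileRow] = field(default_factory=list)


def build_commits(commit_shas: list[str], file_rows: list[FileRow]) -> list[Commit]:
    # Bucket child rows by their parent key, then assemble each aggregate once.
    files_by_commit: dict[str, list[FileRow]] = defaultdict(list)
    for row in file_rows:
        files_by_commit[row.commit_sha].append(row)
    return [Commit(sha, files_by_commit[sha]) for sha in commit_shas]


print(build_commits(["abc123"], [FileRow("abc123", "README.md")]))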