kodit 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kodit might be problematic. Click here for more details.
- kodit/_version.py +2 -2
- kodit/app.py +2 -0
- kodit/application/factories/server_factory.py +58 -32
- kodit/application/services/code_search_application_service.py +89 -12
- kodit/application/services/commit_indexing_application_service.py +527 -195
- kodit/application/services/enrichment_query_service.py +311 -43
- kodit/application/services/indexing_worker_service.py +1 -1
- kodit/application/services/queue_service.py +15 -10
- kodit/application/services/sync_scheduler.py +2 -1
- kodit/domain/enrichments/architecture/architecture.py +1 -1
- kodit/domain/enrichments/architecture/database_schema/__init__.py +1 -0
- kodit/domain/enrichments/architecture/database_schema/database_schema.py +17 -0
- kodit/domain/enrichments/architecture/physical/physical.py +1 -1
- kodit/domain/enrichments/development/development.py +1 -1
- kodit/domain/enrichments/development/snippet/snippet.py +12 -5
- kodit/domain/enrichments/enrichment.py +31 -4
- kodit/domain/enrichments/history/__init__.py +1 -0
- kodit/domain/enrichments/history/commit_description/__init__.py +1 -0
- kodit/domain/enrichments/history/commit_description/commit_description.py +17 -0
- kodit/domain/enrichments/history/history.py +18 -0
- kodit/domain/enrichments/usage/api_docs.py +1 -1
- kodit/domain/enrichments/usage/usage.py +1 -1
- kodit/domain/entities/git.py +30 -25
- kodit/domain/factories/git_repo_factory.py +20 -5
- kodit/domain/protocols.py +60 -125
- kodit/domain/services/embedding_service.py +14 -16
- kodit/domain/services/git_repository_service.py +60 -38
- kodit/domain/services/git_service.py +18 -11
- kodit/domain/tracking/resolution_service.py +6 -16
- kodit/domain/value_objects.py +6 -9
- kodit/infrastructure/api/v1/dependencies.py +12 -3
- kodit/infrastructure/api/v1/query_params.py +27 -0
- kodit/infrastructure/api/v1/routers/commits.py +91 -85
- kodit/infrastructure/api/v1/routers/repositories.py +53 -37
- kodit/infrastructure/api/v1/routers/search.py +1 -1
- kodit/infrastructure/api/v1/schemas/enrichment.py +14 -0
- kodit/infrastructure/api/v1/schemas/repository.py +1 -1
- kodit/infrastructure/cloning/git/git_python_adaptor.py +41 -0
- kodit/infrastructure/database_schema/__init__.py +1 -0
- kodit/infrastructure/database_schema/database_schema_detector.py +268 -0
- kodit/infrastructure/slicing/api_doc_extractor.py +0 -2
- kodit/infrastructure/sqlalchemy/embedding_repository.py +44 -34
- kodit/infrastructure/sqlalchemy/enrichment_association_repository.py +73 -0
- kodit/infrastructure/sqlalchemy/enrichment_v2_repository.py +145 -97
- kodit/infrastructure/sqlalchemy/entities.py +12 -116
- kodit/infrastructure/sqlalchemy/git_branch_repository.py +52 -244
- kodit/infrastructure/sqlalchemy/git_commit_repository.py +35 -324
- kodit/infrastructure/sqlalchemy/git_file_repository.py +70 -0
- kodit/infrastructure/sqlalchemy/git_repository.py +60 -230
- kodit/infrastructure/sqlalchemy/git_tag_repository.py +53 -240
- kodit/infrastructure/sqlalchemy/query.py +331 -0
- kodit/infrastructure/sqlalchemy/repository.py +203 -0
- kodit/infrastructure/sqlalchemy/task_repository.py +79 -58
- kodit/infrastructure/sqlalchemy/task_status_repository.py +45 -52
- kodit/migrations/versions/4b1a3b2c8fa5_refactor_git_tracking.py +190 -0
- {kodit-0.5.4.dist-info → kodit-0.5.6.dist-info}/METADATA +1 -1
- {kodit-0.5.4.dist-info → kodit-0.5.6.dist-info}/RECORD +60 -50
- kodit/infrastructure/mappers/enrichment_mapper.py +0 -83
- kodit/infrastructure/mappers/git_mapper.py +0 -193
- kodit/infrastructure/mappers/snippet_mapper.py +0 -104
- kodit/infrastructure/sqlalchemy/snippet_v2_repository.py +0 -479
- {kodit-0.5.4.dist-info → kodit-0.5.6.dist-info}/WHEEL +0 -0
- {kodit-0.5.4.dist-info → kodit-0.5.6.dist-info}/entry_points.txt +0 -0
- {kodit-0.5.4.dist-info → kodit-0.5.6.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
"""Database schema detector for discovering database schemas in a repository."""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import ClassVar
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DatabaseSchemaDetector:
    """Detects database schemas from various sources in a repository.

    Three kinds of evidence are collected by globbing the repository tree:
    migration files, standalone SQL files and ORM model files. Table names
    are pulled out of SQL/migration files with ``CREATE_TABLE_PATTERN``;
    model names are pulled out of ORM files with the language-specific
    patterns below. The result is rendered as a Markdown report.

    Every file is reported and scanned at most once, even when it matches
    several of the (deliberately overlapping) glob patterns.
    """

    # File patterns to look for
    MIGRATION_PATTERNS: ClassVar[list[str]] = [
        "**/migrations/**/*.sql",
        "**/migrations/**/*.py",
        "**/migrate/**/*.sql",
        "**/migrate/**/*.go",
        "**/db/migrate/**/*.rb",
        "**/alembic/versions/**/*.py",
        "**/liquibase/**/*.xml",
        "**/flyway/**/*.sql",
    ]

    SQL_FILE_PATTERNS: ClassVar[list[str]] = [
        "**/*.sql",
        "**/schema/**/*.sql",
        "**/schemas/**/*.sql",
        "**/database/**/*.sql",
        "**/db/**/*.sql",
    ]

    ORM_MODEL_PATTERNS: ClassVar[list[str]] = [
        "**/models/**/*.py",  # SQLAlchemy, Django
        "**/models/**/*.go",  # GORM
        "**/entities/**/*.py",  # SQLAlchemy
        "**/entities/**/*.ts",  # TypeORM
        "**/entities/**/*.js",  # TypeORM/Sequelize
    ]

    # Regex patterns for schema detection.
    # Any number of leading ``qualifier.`` parts (schema / database) is skipped
    # so that ``public.users`` yields ``users`` rather than ``public``.
    # Backtick, double-quote and square-bracket quoting are all accepted.
    CREATE_TABLE_PATTERN = re.compile(
        r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?"
        r"(?:[`\"\[]?\w+[`\"\]]?\.)*[`\"\[]?(\w+)[`\"\]]?",
        re.IGNORECASE,
    )

    SQLALCHEMY_MODEL_PATTERN = re.compile(
        r"class\s+(\w+)\s*\([^)]*(?:Base|Model|db\.Model)[^)]*\):",
        re.MULTILINE,
    )

    GORM_MODEL_PATTERN = re.compile(
        r"type\s+(\w+)\s+struct\s*{[^}]*gorm\.Model",
        re.MULTILINE | re.DOTALL,
    )

    TYPEORM_ENTITY_PATTERN = re.compile(
        r"@Entity\([^)]*\)\s*(?:export\s+)?class\s+(\w+)",
        re.MULTILINE,
    )

    # How many entries each report section lists before it is truncated.
    _MAX_LISTED_FILES: ClassVar[int] = 10
    _MAX_LISTED_MODELS: ClassVar[int] = 5

    async def discover_schemas(self, repo_path: Path) -> str:
        """Discover database schemas and generate a structured report.

        Args:
            repo_path: Root directory of the repository to scan.

        Returns:
            A Markdown report describing the detected tables, migration
            files, SQL files and ORM models.

        """
        findings: dict[str, set[str] | list[str] | list[dict] | None] = {
            "tables": set(),
            "migration_files": [],
            "sql_files": [],
            "orm_models": [],
            "orm_type": None,
        }

        # Migrations must be collected first: the SQL pass skips anything
        # that has already been classified as a migration.
        await self._detect_migrations(repo_path, findings)
        await self._detect_sql_files(repo_path, findings)
        await self._detect_orm_models(repo_path, findings)

        return self._generate_report(findings)

    def _matching_files(
        self, repo_path: Path, patterns: list[str]
    ) -> list[tuple[Path, str]]:
        """Return unique ``(path, repo-relative path)`` pairs matching patterns.

        Patterns may overlap, so each file is returned only once. Matches of
        each pattern are sorted so the result does not depend on the order in
        which the filesystem happens to enumerate directory entries.
        """
        seen: set[str] = set()
        matches: list[tuple[Path, str]] = []
        for pattern in patterns:
            for file_path in sorted(repo_path.glob(pattern)):
                if not file_path.is_file():
                    continue
                rel_path = str(file_path.relative_to(repo_path))
                if rel_path in seen:
                    continue
                seen.add(rel_path)
                matches.append((file_path, rel_path))
        return matches

    async def _detect_migrations(self, repo_path: Path, findings: dict) -> None:
        """Record migration files and the tables they create in ``findings``."""
        already_listed = set(findings["migration_files"])
        for file_path, rel_path in self._matching_files(
            repo_path, self.MIGRATION_PATTERNS
        ):
            if rel_path in already_listed:
                continue
            findings["migration_files"].append(rel_path)
            # Try to extract table names from migrations
            await self._extract_tables_from_file(file_path, findings)

    async def _detect_sql_files(self, repo_path: Path, findings: dict) -> None:
        """Record non-migration SQL files and the tables they create."""
        # Skip files already counted as migrations or already listed as SQL.
        already_listed = set(findings["migration_files"]) | set(
            findings["sql_files"]
        )
        for file_path, rel_path in self._matching_files(
            repo_path, self.SQL_FILE_PATTERNS
        ):
            if rel_path in already_listed:
                continue
            findings["sql_files"].append(rel_path)
            await self._extract_tables_from_file(file_path, findings)

    async def _detect_orm_models(self, repo_path: Path, findings: dict) -> None:
        """Record ORM model files and add their model names to the tables."""
        already_listed = {entry["file"] for entry in findings["orm_models"]}
        for file_path, rel_path in self._matching_files(
            repo_path, self.ORM_MODEL_PATTERNS
        ):
            if rel_path in already_listed:
                continue
            models = await self._extract_orm_models(file_path)
            if models:
                findings["orm_models"].append({"file": rel_path, "models": models})
                findings["tables"].update(models)

    async def _extract_tables_from_file(self, file_path: Path, findings: dict) -> None:
        """Add table names from ``CREATE TABLE`` statements to ``findings``."""
        try:
            content = file_path.read_text(encoding="utf-8", errors="ignore")
        except (OSError, UnicodeDecodeError):
            # Best effort: an unreadable file simply contributes no tables.
            return

        findings["tables"].update(
            match.group(1) for match in self.CREATE_TABLE_PATTERN.finditer(content)
        )

    async def _extract_orm_models(self, file_path: Path) -> list[str]:
        """Extract ORM model names from a model file.

        The pattern is chosen from the file suffix: ``.py`` (SQLAlchemy or
        Django), ``.go`` (GORM), ``.ts``/``.js`` (TypeORM). Any other suffix,
        or a file that cannot be read, yields an empty list.
        """
        pattern_by_suffix = {
            ".py": self.SQLALCHEMY_MODEL_PATTERN,
            ".go": self.GORM_MODEL_PATTERN,
            ".ts": self.TYPEORM_ENTITY_PATTERN,
            ".js": self.TYPEORM_ENTITY_PATTERN,
        }
        pattern = pattern_by_suffix.get(file_path.suffix)
        if pattern is None:
            return []

        try:
            content = file_path.read_text(encoding="utf-8", errors="ignore")
        except (OSError, UnicodeDecodeError):
            # Best effort: an unreadable file simply contributes no models.
            return []

        return [match.group(1) for match in pattern.finditer(content)]

    def _listed_files(self, files: list[str]) -> list[str]:
        """Return bullet lines for ``files``, truncated with a "more" marker."""
        limit = self._MAX_LISTED_FILES
        bullets = [f"- {name}" for name in files[:limit]]
        if len(files) > limit:
            bullets.append(f"- ... and {len(files) - limit} more")
        return bullets

    def _orm_section(self, orm_models: list[dict]) -> list[str]:
        """Return the report lines of the ORM models section."""
        file_limit = self._MAX_LISTED_FILES
        model_limit = self._MAX_LISTED_MODELS

        section = [
            f"## ORM Models ({len(orm_models)} files)",
            "",
            "ORM models detected, suggesting object-relational mapping:",
        ]
        for orm_info in orm_models[:file_limit]:
            models = orm_info["models"]
            section.append(f"- {orm_info['file']}: {', '.join(models[:model_limit])}")
            if len(models) > model_limit:
                section.append(f" (and {len(models) - model_limit} more models)")
        if len(orm_models) > file_limit:
            section.append(f"- ... and {len(orm_models) - file_limit} more files")
        section.append("")
        return section

    def _inferred_information(
        self, migration_files: list[str], orm_models: list[dict]
    ) -> list[str]:
        """Return bullet lines guessing the migration framework and ORM used."""
        inferred: list[str] = []

        # Only the first matching framework is reported; order matters.
        if any("alembic" in name for name in migration_files):
            inferred.append("- Migration framework: Alembic (Python/SQLAlchemy)")
        elif any("django" in name for name in migration_files) or any(
            "migrations" in name and name.endswith(".py") for name in migration_files
        ):
            inferred.append("- Migration framework: Django Migrations")
        elif any(name.endswith(".go") for name in migration_files):
            inferred.append(
                "- Migration framework: Go-based migrations (golang-migrate)"
            )
        elif any("flyway" in name for name in migration_files):
            inferred.append("- Migration framework: Flyway")
        elif any("liquibase" in name for name in migration_files):
            inferred.append("- Migration framework: Liquibase")

        model_files = [entry["file"] for entry in orm_models]
        if any(name.endswith(".py") for name in model_files):
            inferred.append("- ORM: Python (likely SQLAlchemy or Django ORM)")
        if any(name.endswith(".go") for name in model_files):
            inferred.append("- ORM: Go (likely GORM)")
        if any(name.endswith((".ts", ".js")) for name in model_files):
            inferred.append(
                "- ORM: TypeScript/JavaScript (likely TypeORM or Sequelize)"
            )

        return inferred

    def _generate_report(self, findings: dict) -> str:
        """Generate a structured Markdown report of database schema findings.

        Args:
            findings: Mapping with the keys ``tables``, ``migration_files``,
                ``sql_files`` and ``orm_models`` as built by
                ``discover_schemas``.

        Returns:
            The report text. When nothing was found, the report consists of
            the title and a single "nothing detected" sentence.

        """
        tables = findings["tables"]
        migration_files = findings["migration_files"]
        sql_files = findings["sql_files"]
        orm_models = findings["orm_models"]

        # Summary
        lines = ["# Database Schema Discovery Report", ""]

        if not (tables or migration_files or sql_files or orm_models):
            lines.append("No database schemas detected in this repository.")
            return "\n".join(lines)

        # Tables/Entities found
        if tables:
            lines.append(f"## Detected Tables/Entities ({len(tables)})")
            lines.append("")
            lines.extend(f"- {table}" for table in sorted(tables))
            lines.append("")

        # Migration files
        if migration_files:
            lines.append(f"## Migration Files ({len(migration_files)})")
            lines.append("")
            lines.append(
                "Database migrations detected, suggesting schema evolution over time:"
            )
            lines.extend(self._listed_files(migration_files))
            lines.append("")

        # SQL files
        if sql_files:
            lines.append(f"## SQL Schema Files ({len(sql_files)})")
            lines.append("")
            lines.extend(self._listed_files(sql_files))
            lines.append("")

        # ORM models
        if orm_models:
            lines.extend(self._orm_section(orm_models))

        # Inferred database type
        lines.append("## Inferred Information")
        lines.append("")
        lines.extend(self._inferred_information(migration_files, orm_models))

        return "\n".join(lines)
|
|
@@ -39,7 +39,6 @@ class APIDocExtractor:
|
|
|
39
39
|
self,
|
|
40
40
|
files: list[GitFile],
|
|
41
41
|
language: str,
|
|
42
|
-
commit_sha: str,
|
|
43
42
|
include_private: bool = False, # noqa: FBT001, FBT002
|
|
44
43
|
) -> list[APIDocEnrichment]:
|
|
45
44
|
"""Extract API documentation enrichments from files.
|
|
@@ -93,7 +92,6 @@ class APIDocExtractor:
|
|
|
93
92
|
)
|
|
94
93
|
|
|
95
94
|
enrichment = APIDocEnrichment(
|
|
96
|
-
entity_id=commit_sha,
|
|
97
95
|
language=language,
|
|
98
96
|
content=markdown_content,
|
|
99
97
|
)
|
|
@@ -1,12 +1,15 @@
|
|
|
1
1
|
"""SQLAlchemy implementation of embedding repository."""
|
|
2
2
|
|
|
3
3
|
from collections.abc import Callable
|
|
4
|
+
from typing import Any
|
|
4
5
|
|
|
5
6
|
import numpy as np
|
|
6
7
|
from sqlalchemy import select
|
|
7
8
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
8
9
|
|
|
9
10
|
from kodit.infrastructure.sqlalchemy.entities import Embedding, EmbeddingType
|
|
11
|
+
from kodit.infrastructure.sqlalchemy.query import FilterOperator, QueryBuilder
|
|
12
|
+
from kodit.infrastructure.sqlalchemy.repository import SqlAlchemyRepository
|
|
10
13
|
from kodit.infrastructure.sqlalchemy.unit_of_work import SqlAlchemyUnitOfWork
|
|
11
14
|
|
|
12
15
|
|
|
@@ -17,68 +20,75 @@ def create_embedding_repository(
|
|
|
17
20
|
return SqlAlchemyEmbeddingRepository(session_factory=session_factory)
|
|
18
21
|
|
|
19
22
|
|
|
20
|
-
class SqlAlchemyEmbeddingRepository:
|
|
23
|
+
class SqlAlchemyEmbeddingRepository(SqlAlchemyRepository[Embedding, Embedding]):
|
|
21
24
|
"""SQLAlchemy implementation of embedding repository."""
|
|
22
25
|
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
+
@property
|
|
27
|
+
def db_entity_type(self) -> type[Embedding]:
|
|
28
|
+
"""The SQLAlchemy model type."""
|
|
29
|
+
return Embedding
|
|
30
|
+
|
|
31
|
+
@staticmethod
|
|
32
|
+
def to_domain(db_entity: Embedding) -> Embedding:
|
|
33
|
+
"""Map database entity to domain entity."""
|
|
34
|
+
return db_entity
|
|
35
|
+
|
|
36
|
+
@staticmethod
|
|
37
|
+
def to_db(domain_entity: Embedding) -> Embedding:
|
|
38
|
+
"""Map domain entity to database entity."""
|
|
39
|
+
return domain_entity
|
|
40
|
+
|
|
41
|
+
def _get_id(self, entity: Embedding) -> Any:
|
|
42
|
+
"""Extract ID from domain entity."""
|
|
43
|
+
return entity.id
|
|
26
44
|
|
|
27
45
|
async def create_embedding(self, embedding: Embedding) -> None:
|
|
28
46
|
"""Create a new embedding record in the database."""
|
|
29
|
-
|
|
30
|
-
session.add(embedding)
|
|
47
|
+
await self.save(embedding)
|
|
31
48
|
|
|
32
49
|
async def get_embedding_by_snippet_id_and_type(
|
|
33
50
|
self, snippet_id: int, embedding_type: EmbeddingType
|
|
34
51
|
) -> Embedding | None:
|
|
35
52
|
"""Get an embedding by its snippet ID and type."""
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
53
|
+
query = (
|
|
54
|
+
QueryBuilder()
|
|
55
|
+
.filter("snippet_id", FilterOperator.EQ, snippet_id)
|
|
56
|
+
.filter("type", FilterOperator.EQ, embedding_type)
|
|
57
|
+
)
|
|
58
|
+
results = await self.find(query)
|
|
59
|
+
return results[0] if results else None
|
|
43
60
|
|
|
44
61
|
async def list_embeddings_by_type(
|
|
45
62
|
self, embedding_type: EmbeddingType
|
|
46
63
|
) -> list[Embedding]:
|
|
47
64
|
"""List all embeddings of a given type."""
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
result = await session.execute(query)
|
|
51
|
-
return list(result.scalars())
|
|
65
|
+
query = QueryBuilder().filter("type", FilterOperator.EQ, embedding_type)
|
|
66
|
+
return await self.find(query)
|
|
52
67
|
|
|
53
68
|
async def delete_embeddings_by_snippet_id(self, snippet_id: str) -> None:
|
|
54
69
|
"""Delete all embeddings for a snippet."""
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
for embedding in embeddings:
|
|
60
|
-
await session.delete(embedding)
|
|
70
|
+
query = QueryBuilder().filter("snippet_id", FilterOperator.EQ, snippet_id)
|
|
71
|
+
embeddings = await self.find(query)
|
|
72
|
+
for embedding in embeddings:
|
|
73
|
+
await self.delete(embedding)
|
|
61
74
|
|
|
62
75
|
async def list_embeddings_by_snippet_ids_and_type(
|
|
63
76
|
self, snippet_ids: list[str], embedding_type: EmbeddingType
|
|
64
77
|
) -> list[Embedding]:
|
|
65
78
|
"""Get all embeddings for the given snippet IDs."""
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
return list(result.scalars())
|
|
79
|
+
query = (
|
|
80
|
+
QueryBuilder()
|
|
81
|
+
.filter("snippet_id", FilterOperator.IN, snippet_ids)
|
|
82
|
+
.filter("type", FilterOperator.EQ, embedding_type)
|
|
83
|
+
)
|
|
84
|
+
return await self.find(query)
|
|
73
85
|
|
|
74
86
|
async def get_embeddings_by_snippet_ids(
|
|
75
87
|
self, snippet_ids: list[str]
|
|
76
88
|
) -> list[Embedding]:
|
|
77
89
|
"""Get all embeddings for the given snippet IDs."""
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
result = await session.execute(query)
|
|
81
|
-
return list(result.scalars())
|
|
90
|
+
query = QueryBuilder().filter("snippet_id", FilterOperator.IN, snippet_ids)
|
|
91
|
+
return await self.find(query)
|
|
82
92
|
|
|
83
93
|
async def list_semantic_results(
|
|
84
94
|
self,
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""Enrichment association repository."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
|
|
5
|
+
import structlog
|
|
6
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
7
|
+
|
|
8
|
+
from kodit.domain.enrichments.enrichment import (
|
|
9
|
+
EnrichmentAssociation,
|
|
10
|
+
)
|
|
11
|
+
from kodit.domain.protocols import EnrichmentAssociationRepository
|
|
12
|
+
from kodit.infrastructure.sqlalchemy import entities as db_entities
|
|
13
|
+
from kodit.infrastructure.sqlalchemy.repository import SqlAlchemyRepository
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def create_enrichment_association_repository(
    session_factory: Callable[[], AsyncSession],
) -> EnrichmentAssociationRepository:
    """Create an enrichment association repository.

    Args:
        session_factory: Callable returning an ``AsyncSession``; it is passed
            straight through to the repository constructor.

    Returns:
        A ``SQLAlchemyEnrichmentAssociationRepository`` built on that factory.

    """
    repository = SQLAlchemyEnrichmentAssociationRepository(
        session_factory=session_factory
    )
    return repository
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SQLAlchemyEnrichmentAssociationRepository(
    SqlAlchemyRepository[EnrichmentAssociation, db_entities.EnrichmentAssociation],
    EnrichmentAssociationRepository,
):
    """SQLAlchemy-backed repository for enrichment associations.

    Supplies the pieces the generic ``SqlAlchemyRepository`` base is
    parameterised with: the database model type, the ID accessor and the
    mappers between the domain and database representations.
    """

    def __init__(self, session_factory: Callable[[], AsyncSession]) -> None:
        """Initialize the repository with a session factory and a logger."""
        super().__init__(session_factory=session_factory)
        self._log = structlog.get_logger(__name__)

    @property
    def db_entity_type(self) -> type[db_entities.EnrichmentAssociation]:
        """Return the SQLAlchemy model class handled by this repository."""
        return db_entities.EnrichmentAssociation

    def _get_id(self, entity: EnrichmentAssociation) -> int | None:
        """Return the ID stored on the given enrichment association."""
        return entity.id

    @staticmethod
    def to_domain(
        db_entity: db_entities.EnrichmentAssociation,
    ) -> EnrichmentAssociation:
        """Build a domain association from its database row."""
        association = EnrichmentAssociation(
            enrichment_id=db_entity.enrichment_id,
            entity_type=db_entity.entity_type,
            entity_id=db_entity.entity_id,
            id=db_entity.id,
        )
        return association

    @staticmethod
    def to_db(
        domain_entity: EnrichmentAssociation,
    ) -> db_entities.EnrichmentAssociation:
        """Build a database row from a domain association.

        The ID is copied only when the domain entity carries one. Both
        timestamps are set to the current UTC time on every call.
        """
        from datetime import UTC, datetime

        record = db_entities.EnrichmentAssociation(
            enrichment_id=domain_entity.enrichment_id,
            entity_type=domain_entity.entity_type,
            entity_id=domain_entity.entity_id,
        )

        primary_key = domain_entity.id
        if primary_key is not None:
            record.id = primary_key

        # Always set timestamps since domain entity doesn't track them
        stamp = datetime.now(UTC)
        record.created_at = stamp
        record.updated_at = stamp
        return record
|