kodit 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic; see the registry's advisory for this release for more details.

Files changed (70)
  1. kodit/_version.py +2 -2
  2. kodit/app.py +10 -12
  3. kodit/application/factories/server_factory.py +78 -11
  4. kodit/application/services/commit_indexing_application_service.py +188 -31
  5. kodit/application/services/enrichment_query_service.py +95 -0
  6. kodit/config.py +3 -3
  7. kodit/domain/enrichments/__init__.py +1 -0
  8. kodit/domain/enrichments/architecture/__init__.py +1 -0
  9. kodit/domain/enrichments/architecture/architecture.py +20 -0
  10. kodit/domain/enrichments/architecture/physical/__init__.py +1 -0
  11. kodit/domain/enrichments/architecture/physical/discovery_notes.py +14 -0
  12. kodit/domain/enrichments/architecture/physical/formatter.py +11 -0
  13. kodit/domain/enrichments/architecture/physical/physical.py +17 -0
  14. kodit/domain/enrichments/development/__init__.py +1 -0
  15. kodit/domain/enrichments/development/development.py +18 -0
  16. kodit/domain/enrichments/development/snippet/__init__.py +1 -0
  17. kodit/domain/enrichments/development/snippet/snippet.py +21 -0
  18. kodit/domain/enrichments/enricher.py +17 -0
  19. kodit/domain/enrichments/enrichment.py +39 -0
  20. kodit/domain/enrichments/request.py +12 -0
  21. kodit/domain/enrichments/response.py +11 -0
  22. kodit/domain/enrichments/usage/__init__.py +1 -0
  23. kodit/domain/enrichments/usage/api_docs.py +19 -0
  24. kodit/domain/enrichments/usage/usage.py +18 -0
  25. kodit/domain/protocols.py +7 -6
  26. kodit/domain/services/enrichment_service.py +9 -30
  27. kodit/domain/services/physical_architecture_service.py +182 -0
  28. kodit/domain/tracking/__init__.py +1 -0
  29. kodit/domain/tracking/resolution_service.py +81 -0
  30. kodit/domain/tracking/trackable.py +21 -0
  31. kodit/domain/value_objects.py +6 -23
  32. kodit/infrastructure/api/v1/dependencies.py +15 -0
  33. kodit/infrastructure/api/v1/routers/commits.py +81 -0
  34. kodit/infrastructure/api/v1/routers/repositories.py +99 -0
  35. kodit/infrastructure/api/v1/schemas/enrichment.py +29 -0
  36. kodit/infrastructure/cloning/git/git_python_adaptor.py +71 -4
  37. kodit/infrastructure/enricher/__init__.py +1 -0
  38. kodit/infrastructure/enricher/enricher_factory.py +53 -0
  39. kodit/infrastructure/{enrichment/litellm_enrichment_provider.py → enricher/litellm_enricher.py} +20 -33
  40. kodit/infrastructure/{enrichment/local_enrichment_provider.py → enricher/local_enricher.py} +19 -24
  41. kodit/infrastructure/enricher/null_enricher.py +36 -0
  42. kodit/infrastructure/mappers/enrichment_mapper.py +83 -0
  43. kodit/infrastructure/mappers/snippet_mapper.py +20 -22
  44. kodit/infrastructure/physical_architecture/__init__.py +1 -0
  45. kodit/infrastructure/physical_architecture/detectors/__init__.py +1 -0
  46. kodit/infrastructure/physical_architecture/detectors/docker_compose_detector.py +336 -0
  47. kodit/infrastructure/physical_architecture/formatters/__init__.py +1 -0
  48. kodit/infrastructure/physical_architecture/formatters/narrative_formatter.py +149 -0
  49. kodit/infrastructure/slicing/api_doc_extractor.py +836 -0
  50. kodit/infrastructure/slicing/ast_analyzer.py +1128 -0
  51. kodit/infrastructure/slicing/slicer.py +56 -391
  52. kodit/infrastructure/sqlalchemy/enrichment_v2_repository.py +118 -0
  53. kodit/infrastructure/sqlalchemy/entities.py +46 -38
  54. kodit/infrastructure/sqlalchemy/git_branch_repository.py +22 -11
  55. kodit/infrastructure/sqlalchemy/git_commit_repository.py +23 -14
  56. kodit/infrastructure/sqlalchemy/git_repository.py +27 -17
  57. kodit/infrastructure/sqlalchemy/git_tag_repository.py +22 -11
  58. kodit/infrastructure/sqlalchemy/snippet_v2_repository.py +101 -106
  59. kodit/migrations/versions/19f8c7faf8b9_add_generic_enrichment_type.py +260 -0
  60. kodit/utils/dump_config.py +361 -0
  61. kodit/utils/dump_openapi.py +5 -6
  62. {kodit-0.5.0.dist-info → kodit-0.5.2.dist-info}/METADATA +1 -1
  63. {kodit-0.5.0.dist-info → kodit-0.5.2.dist-info}/RECORD +67 -32
  64. kodit/infrastructure/enrichment/__init__.py +0 -1
  65. kodit/infrastructure/enrichment/enrichment_factory.py +0 -52
  66. kodit/infrastructure/enrichment/null_enrichment_provider.py +0 -19
  67. kodit/infrastructure/{enrichment → enricher}/utils.py +0 -0
  68. {kodit-0.5.0.dist-info → kodit-0.5.2.dist-info}/WHEEL +0 -0
  69. {kodit-0.5.0.dist-info → kodit-0.5.2.dist-info}/entry_points.txt +0 -0
  70. {kodit-0.5.0.dist-info → kodit-0.5.2.dist-info}/licenses/LICENSE +0 -0
@@ -8,14 +8,19 @@ from collections import defaultdict
8
8
  from collections.abc import Generator
9
9
  from dataclasses import dataclass, field
10
10
  from pathlib import Path
11
- from typing import Any, ClassVar
11
+ from typing import Any
12
12
 
13
13
  import structlog
14
14
  from tree_sitter import Node, Parser, Tree
15
- from tree_sitter_language_pack import get_language
16
15
 
17
16
  from kodit.domain.entities.git import GitFile, SnippetV2
18
17
  from kodit.domain.value_objects import LanguageMapping
18
+ from kodit.infrastructure.slicing.ast_analyzer import (
19
+ ASTAnalyzer,
20
+ FunctionDefinition,
21
+ LanguageConfig,
22
+ ParsedFile,
23
+ )
19
24
 
20
25
 
21
26
  @dataclass
@@ -43,105 +48,6 @@ class AnalyzerState:
43
48
  )
44
49
 
45
50
 
46
- class LanguageConfig:
47
- """Language-specific configuration."""
48
-
49
- CONFIGS: ClassVar[dict[str, dict[str, Any]]] = {
50
- "python": {
51
- "function_nodes": ["function_definition"],
52
- "method_nodes": [],
53
- "call_node": "call",
54
- "import_nodes": ["import_statement", "import_from_statement"],
55
- "extension": ".py",
56
- "name_field": None, # Use identifier child
57
- },
58
- "java": {
59
- "function_nodes": ["method_declaration"],
60
- "method_nodes": [],
61
- "call_node": "method_invocation",
62
- "import_nodes": ["import_declaration"],
63
- "extension": ".java",
64
- "name_field": None,
65
- },
66
- "c": {
67
- "function_nodes": ["function_definition"],
68
- "method_nodes": [],
69
- "call_node": "call_expression",
70
- "import_nodes": ["preproc_include"],
71
- "extension": ".c",
72
- "name_field": "declarator",
73
- },
74
- "cpp": {
75
- "function_nodes": ["function_definition"],
76
- "method_nodes": [],
77
- "call_node": "call_expression",
78
- "import_nodes": ["preproc_include", "using_declaration"],
79
- "extension": ".cpp",
80
- "name_field": "declarator",
81
- },
82
- "rust": {
83
- "function_nodes": ["function_item"],
84
- "method_nodes": [],
85
- "call_node": "call_expression",
86
- "import_nodes": ["use_declaration", "extern_crate_declaration"],
87
- "extension": ".rs",
88
- "name_field": "name",
89
- },
90
- "go": {
91
- "function_nodes": ["function_declaration"],
92
- "method_nodes": ["method_declaration"],
93
- "call_node": "call_expression",
94
- "import_nodes": ["import_declaration"],
95
- "extension": ".go",
96
- "name_field": None,
97
- },
98
- "javascript": {
99
- "function_nodes": [
100
- "function_declaration",
101
- "function_expression",
102
- "arrow_function",
103
- ],
104
- "method_nodes": [],
105
- "call_node": "call_expression",
106
- "import_nodes": ["import_statement", "import_declaration"],
107
- "extension": ".js",
108
- "name_field": None,
109
- },
110
- "csharp": {
111
- "function_nodes": ["method_declaration"],
112
- "method_nodes": ["constructor_declaration"],
113
- "call_node": "invocation_expression",
114
- "import_nodes": ["using_directive"],
115
- "extension": ".cs",
116
- "name_field": None,
117
- },
118
- "html": {
119
- "function_nodes": ["script_element", "style_element"],
120
- "method_nodes": ["element"], # Elements with id/class attributes
121
- "call_node": "attribute",
122
- "import_nodes": ["script_element", "element"], # script and link elements
123
- "extension": ".html",
124
- "name_field": None,
125
- },
126
- "css": {
127
- "function_nodes": ["rule_set", "keyframes_statement"],
128
- "method_nodes": ["media_statement"],
129
- "call_node": "call_expression",
130
- "import_nodes": ["import_statement"],
131
- "extension": ".css",
132
- "name_field": None,
133
- },
134
- }
135
-
136
- # Aliases
137
- CONFIGS["c++"] = CONFIGS["cpp"]
138
- CONFIGS["typescript"] = CONFIGS["javascript"]
139
- CONFIGS["ts"] = CONFIGS["javascript"]
140
- CONFIGS["js"] = CONFIGS["javascript"]
141
- CONFIGS["c#"] = CONFIGS["csharp"]
142
- CONFIGS["cs"] = CONFIGS["csharp"]
143
-
144
-
145
51
  class Slicer:
146
52
  """Slicer that extracts code snippets from files."""
147
53
 
@@ -149,7 +55,7 @@ class Slicer:
149
55
  """Initialize an empty slicer."""
150
56
  self.log = structlog.get_logger(__name__)
151
57
 
152
- def extract_snippets_from_git_files( # noqa: C901
58
+ def extract_snippets_from_git_files(
153
59
  self, files: list[GitFile], language: str = "python"
154
60
  ) -> list[SnippetV2]:
155
61
  """Extract code snippets from a list of files.
@@ -171,24 +77,15 @@ class Slicer:
171
77
 
172
78
  language = language.lower()
173
79
 
174
- # Get language configuration
175
- if language not in LanguageConfig.CONFIGS:
176
- self.log.debug("Skipping", language=language)
177
- return []
178
-
179
- config = LanguageConfig.CONFIGS[language]
180
-
181
- # Initialize tree-sitter
182
- tree_sitter_name = self._get_tree_sitter_language_name(language)
80
+ # Initialize ASTAnalyzer
183
81
  try:
184
- ts_language = get_language(tree_sitter_name) # type: ignore[arg-type]
185
- parser = Parser(ts_language)
186
- except Exception as e:
187
- raise RuntimeError(f"Failed to load {language} parser: {e}") from e
82
+ analyzer = ASTAnalyzer(language)
83
+ except ValueError:
84
+ self.log.debug("Skipping unsupported language", language=language)
85
+ return []
188
86
 
189
- # Create mapping from Paths to File objects and extract paths
87
+ # Validate files
190
88
  path_to_file_map: dict[Path, GitFile] = {}
191
- file_paths: list[Path] = []
192
89
  for file in files:
193
90
  file_path = Path(file.path)
194
91
 
@@ -201,30 +98,26 @@ class Slicer:
201
98
  raise FileNotFoundError(f"File not found: {file_path}")
202
99
 
203
100
  path_to_file_map[file_path] = file
204
- file_paths.append(file_path)
205
101
 
206
- # Initialize state
207
- state = AnalyzerState(parser=parser)
208
- state.files = file_paths
209
- file_contents: dict[Path, str] = {}
102
+ # Parse files and extract definitions using ASTAnalyzer
103
+ parsed_files = analyzer.parse_files(files)
104
+ if not parsed_files:
105
+ return []
210
106
 
211
- # Parse all files
212
- for file_path in file_paths:
213
- try:
214
- with file_path.open("rb") as f:
215
- source_code = f.read()
216
- tree = state.parser.parse(source_code)
217
- state.asts[file_path] = tree
218
- except OSError:
219
- # Skip files that can't be parsed
220
- continue
107
+ functions, _, _ = analyzer.extract_definitions(
108
+ parsed_files, include_private=True
109
+ )
110
+
111
+ # Build state from ASTAnalyzer results
112
+ state = self._build_state_from_ast_analyzer(parsed_files, functions)
113
+ config = LanguageConfig.CONFIGS[language]
221
114
 
222
- # Build indexes
223
- self._build_definition_and_import_indexes(state, config, language)
115
+ # Build call graph and snippets (Slicer-specific logic)
224
116
  self._build_call_graph(state, config)
225
117
  self._build_reverse_call_graph(state)
226
118
 
227
119
  # Extract snippets for all functions
120
+ file_contents: dict[Path, str] = {}
228
121
  snippets: list[SnippetV2] = []
229
122
  for qualified_name in state.def_index:
230
123
  snippet_content = self._get_snippet(
@@ -254,55 +147,35 @@ class Slicer:
254
147
  # Extension not supported, so it doesn't match any language
255
148
  return False
256
149
 
257
- def _get_tree_sitter_language_name(self, language: str) -> str:
258
- """Map user language names to tree-sitter language names."""
259
- mapping = {
260
- "c++": "cpp",
261
- "c": "c",
262
- "cpp": "cpp",
263
- "java": "java",
264
- "rust": "rust",
265
- "python": "python",
266
- "go": "go",
267
- "javascript": "javascript",
268
- "typescript": "typescript",
269
- "js": "javascript",
270
- "ts": "typescript",
271
- "csharp": "csharp",
272
- "c#": "csharp",
273
- "cs": "csharp",
274
- "html": "html",
275
- "css": "css",
276
- }
277
- return mapping.get(language, language)
278
-
279
- def _build_definition_and_import_indexes(
280
- self, state: AnalyzerState, config: dict[str, Any], language: str
281
- ) -> None:
282
- """Build definition and import indexes."""
283
- for file_path, tree in state.asts.items():
284
- # Build definition index
285
- for node in self._walk_tree(tree.root_node):
286
- if self._is_function_definition(node, config):
287
- qualified_name = self._qualify_name(
288
- node, file_path, config, language
289
- )
290
- if qualified_name:
291
- span = (node.start_byte, node.end_byte)
292
- state.def_index[qualified_name] = FunctionInfo(
293
- file=file_path,
294
- node=node,
295
- span=span,
296
- qualified_name=qualified_name,
297
- )
298
-
299
- # Build import map
300
- file_imports = {}
301
- for node in self._walk_tree(tree.root_node):
302
- if self._is_import_statement(node, config):
303
- imports = self._extract_imports(node)
304
- file_imports.update(imports)
305
- state.imports[file_path] = file_imports
150
+ def _build_state_from_ast_analyzer(
151
+ self,
152
+ parsed_files: list["ParsedFile"],
153
+ functions: list["FunctionDefinition"],
154
+ ) -> AnalyzerState:
155
+ """Build AnalyzerState from ASTAnalyzer results."""
156
+ # Create a dummy parser (not used for new parsing)
157
+ from tree_sitter_language_pack import get_language
158
+
159
+ ts_language = get_language("python")
160
+ parser = Parser(ts_language)
161
+
162
+ state = AnalyzerState(parser=parser)
163
+
164
+ # Populate files and ASTs from ParsedFile objects
165
+ for parsed in parsed_files:
166
+ state.files.append(parsed.path)
167
+ state.asts[parsed.path] = parsed.tree
168
+
169
+ # Populate def_index from FunctionDefinition objects
170
+ for func_def in functions:
171
+ state.def_index[func_def.qualified_name] = FunctionInfo(
172
+ file=func_def.file,
173
+ node=func_def.node,
174
+ span=func_def.span,
175
+ qualified_name=func_def.qualified_name,
176
+ )
177
+
178
+ return state
306
179
 
307
180
  def _build_call_graph(self, state: AnalyzerState, config: dict[str, Any]) -> None:
308
181
  """Build call graph from function definitions."""
@@ -338,214 +211,6 @@ class Slicer:
338
211
  # Add children to queue
339
212
  queue.extend(current.children)
340
213
 
341
- def _is_function_definition(self, node: Node, config: dict[str, Any]) -> bool:
342
- """Check if node is a function definition."""
343
- return node.type in (config["function_nodes"] + config["method_nodes"])
344
-
345
- def _is_import_statement(self, node: Node, config: dict[str, Any]) -> bool:
346
- """Check if node is an import statement."""
347
- return node.type in config["import_nodes"]
348
-
349
- def _extract_function_name(
350
- self, node: Node, config: dict[str, Any], language: str
351
- ) -> str | None:
352
- """Extract function name from a function definition node."""
353
- if language == "html":
354
- return self._extract_html_element_name(node)
355
- if language == "css":
356
- return self._extract_css_rule_name(node)
357
- if language == "go" and node.type == "method_declaration":
358
- return self._extract_go_method_name(node)
359
- if language in ["c", "cpp"] and config["name_field"]:
360
- return self._extract_c_cpp_function_name(node, config)
361
- if language == "rust" and config["name_field"]:
362
- return self._extract_rust_function_name(node, config)
363
- return self._extract_default_function_name(node)
364
-
365
- def _extract_go_method_name(self, node: Node) -> str | None:
366
- """Extract method name from Go method declaration."""
367
- for child in node.children:
368
- if child.type == "field_identifier" and child.text is not None:
369
- return child.text.decode("utf-8")
370
- return None
371
-
372
- def _extract_c_cpp_function_name(
373
- self, node: Node, config: dict[str, Any]
374
- ) -> str | None:
375
- """Extract function name from C/C++ function definition."""
376
- declarator = node.child_by_field_name(config["name_field"])
377
- if not declarator:
378
- return None
379
-
380
- if declarator.type == "function_declarator":
381
- for child in declarator.children:
382
- if child.type == "identifier" and child.text is not None:
383
- return child.text.decode("utf-8")
384
- elif declarator.type == "identifier" and declarator.text is not None:
385
- return declarator.text.decode("utf-8")
386
- return None
387
-
388
- def _extract_rust_function_name(
389
- self, node: Node, config: dict[str, Any]
390
- ) -> str | None:
391
- """Extract function name from Rust function definition."""
392
- name_node = node.child_by_field_name(config["name_field"])
393
- if name_node and name_node.type == "identifier" and name_node.text is not None:
394
- return name_node.text.decode("utf-8")
395
- return None
396
-
397
- def _extract_html_element_name(self, node: Node) -> str | None:
398
- """Extract meaningful name from HTML element."""
399
- if node.type == "script_element":
400
- return "script"
401
- if node.type == "style_element":
402
- return "style"
403
- if node.type == "element":
404
- return self._extract_html_element_info(node)
405
- return None
406
-
407
- def _extract_html_element_info(self, node: Node) -> str | None:
408
- """Extract element info with ID or class."""
409
- for child in node.children:
410
- if child.type == "start_tag":
411
- tag_name = self._get_tag_name(child)
412
- element_id = self._get_element_id(child)
413
- class_name = self._get_element_class(child)
414
-
415
- if element_id:
416
- return f"{tag_name or 'element'}#{element_id}"
417
- if class_name:
418
- return f"{tag_name or 'element'}.{class_name}"
419
- if tag_name:
420
- return tag_name
421
- return None
422
-
423
- def _get_tag_name(self, start_tag: Node) -> str | None:
424
- """Get tag name from start_tag node."""
425
- for child in start_tag.children:
426
- if child.type == "tag_name" and child.text:
427
- try:
428
- return child.text.decode("utf-8")
429
- except UnicodeDecodeError:
430
- return None
431
- return None
432
-
433
- def _get_element_id(self, start_tag: Node) -> str | None:
434
- """Get element ID from start_tag node."""
435
- return self._get_attribute_value(start_tag, "id")
436
-
437
- def _get_element_class(self, start_tag: Node) -> str | None:
438
- """Get first class name from start_tag node."""
439
- class_value = self._get_attribute_value(start_tag, "class")
440
- return class_value.split()[0] if class_value else None
441
-
442
- def _get_attribute_value(self, start_tag: Node, attr_name: str) -> str | None:
443
- """Get attribute value from start_tag node."""
444
- for child in start_tag.children:
445
- if child.type == "attribute":
446
- name = self._get_attr_name(child)
447
- if name == attr_name:
448
- return self._get_attr_value(child)
449
- return None
450
-
451
- def _get_attr_name(self, attr_node: Node) -> str | None:
452
- """Get attribute name."""
453
- for child in attr_node.children:
454
- if child.type == "attribute_name" and child.text:
455
- try:
456
- return child.text.decode("utf-8")
457
- except UnicodeDecodeError:
458
- return None
459
- return None
460
-
461
- def _get_attr_value(self, attr_node: Node) -> str | None:
462
- """Get attribute value."""
463
- for child in attr_node.children:
464
- if child.type == "quoted_attribute_value":
465
- for val_child in child.children:
466
- if val_child.type == "attribute_value" and val_child.text:
467
- try:
468
- return val_child.text.decode("utf-8")
469
- except UnicodeDecodeError:
470
- return None
471
- return None
472
-
473
- def _extract_css_rule_name(self, node: Node) -> str | None:
474
- """Extract meaningful name from CSS rule."""
475
- if node.type == "rule_set":
476
- return self._extract_css_selector(node)
477
- if node.type == "keyframes_statement":
478
- return self._extract_keyframes_name(node)
479
- if node.type == "media_statement":
480
- return "@media"
481
- return None
482
-
483
- def _extract_css_selector(self, rule_node: Node) -> str | None:
484
- """Extract CSS selector from rule_set."""
485
- for child in rule_node.children:
486
- if child.type == "selectors":
487
- selector_parts = []
488
- for selector_child in child.children:
489
- part = self._get_selector_part(selector_child)
490
- if part:
491
- selector_parts.append(part)
492
- if selector_parts:
493
- return "".join(selector_parts[:2]) # First couple selectors
494
- return None
495
-
496
- def _get_selector_part(self, selector_node: Node) -> str | None:
497
- """Get a single selector part."""
498
- if selector_node.type == "class_selector":
499
- return self._extract_class_selector(selector_node)
500
- if selector_node.type == "id_selector":
501
- return self._extract_id_selector(selector_node)
502
- if selector_node.type == "type_selector" and selector_node.text:
503
- return selector_node.text.decode("utf-8")
504
- return None
505
-
506
- def _extract_class_selector(self, node: Node) -> str | None:
507
- """Extract class selector name."""
508
- for child in node.children:
509
- if child.type == "class_name":
510
- for name_child in child.children:
511
- if name_child.type == "identifier" and name_child.text:
512
- return f".{name_child.text.decode('utf-8')}"
513
- return None
514
-
515
- def _extract_id_selector(self, node: Node) -> str | None:
516
- """Extract ID selector name."""
517
- for child in node.children:
518
- if child.type == "id_name":
519
- for name_child in child.children:
520
- if name_child.type == "identifier" and name_child.text:
521
- return f"#{name_child.text.decode('utf-8')}"
522
- return None
523
-
524
- def _extract_keyframes_name(self, node: Node) -> str | None:
525
- """Extract keyframes animation name."""
526
- for child in node.children:
527
- if child.type == "keyframes_name" and child.text:
528
- return f"@keyframes-{child.text.decode('utf-8')}"
529
- return None
530
-
531
- def _extract_default_function_name(self, node: Node) -> str | None:
532
- """Extract function name using default identifier search."""
533
- for child in node.children:
534
- if child.type == "identifier" and child.text is not None:
535
- return child.text.decode("utf-8")
536
- return None
537
-
538
- def _qualify_name(
539
- self, node: Node, file_path: Path, config: dict[str, Any], language: str
540
- ) -> str | None:
541
- """Create qualified name for a function node."""
542
- function_name = self._extract_function_name(node, config, language)
543
- if not function_name:
544
- return None
545
-
546
- module_name = file_path.stem
547
- return f"{module_name}.{function_name}"
548
-
549
214
  def _get_file_content(self, file_path: Path, file_contents: dict[Path, str]) -> str:
550
215
  """Get cached file content."""
551
216
  if file_path not in file_contents:
@@ -0,0 +1,118 @@
1
+ """EnrichmentV2 repository."""
2
+
3
+ from collections.abc import Callable, Sequence
4
+
5
+ import structlog
6
+ from sqlalchemy import delete, select
7
+ from sqlalchemy.ext.asyncio import AsyncSession
8
+
9
+ from kodit.domain.enrichments.enrichment import EnrichmentV2
10
+ from kodit.infrastructure.mappers.enrichment_mapper import EnrichmentMapper
11
+ from kodit.infrastructure.sqlalchemy import entities as db_entities
12
+ from kodit.infrastructure.sqlalchemy.unit_of_work import SqlAlchemyUnitOfWork
13
+
14
+
15
+ class EnrichmentV2Repository:
16
+ """Repository for managing enrichments and their associations."""
17
+
18
+ def __init__(
19
+ self,
20
+ session_factory: Callable[[], AsyncSession],
21
+ ) -> None:
22
+ """Initialize the repository."""
23
+ self.session_factory = session_factory
24
+ self.mapper = EnrichmentMapper()
25
+ self.log = structlog.get_logger(__name__)
26
+
27
+ async def enrichments_for_entity_type(
28
+ self,
29
+ entity_type: str,
30
+ entity_ids: list[str],
31
+ ) -> list[EnrichmentV2]:
32
+ """Get all enrichments for multiple entities of the same type."""
33
+ if not entity_ids:
34
+ return []
35
+
36
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
37
+ stmt = (
38
+ select(
39
+ db_entities.EnrichmentV2,
40
+ db_entities.EnrichmentAssociation.entity_id,
41
+ )
42
+ .join(db_entities.EnrichmentAssociation)
43
+ .where(
44
+ db_entities.EnrichmentAssociation.entity_type == entity_type,
45
+ db_entities.EnrichmentAssociation.entity_id.in_(entity_ids),
46
+ )
47
+ )
48
+
49
+ result = await session.execute(stmt)
50
+ rows = result.all()
51
+
52
+ return [
53
+ self.mapper.to_domain(db_enrichment, entity_type, entity_id)
54
+ for db_enrichment, entity_id in rows
55
+ ]
56
+
57
+ async def bulk_save_enrichments(
58
+ self,
59
+ enrichments: Sequence[EnrichmentV2],
60
+ ) -> None:
61
+ """Bulk save enrichments with their associations."""
62
+ if not enrichments:
63
+ return
64
+
65
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
66
+ enrichment_records = []
67
+ for enrichment in enrichments:
68
+ db_enrichment = db_entities.EnrichmentV2(
69
+ type=enrichment.type,
70
+ subtype=enrichment.subtype,
71
+ content=enrichment.content,
72
+ )
73
+ session.add(db_enrichment)
74
+ enrichment_records.append((enrichment, db_enrichment))
75
+
76
+ await session.flush()
77
+
78
+ for enrichment, db_enrichment in enrichment_records:
79
+ db_association = db_entities.EnrichmentAssociation(
80
+ enrichment_id=db_enrichment.id,
81
+ entity_type=enrichment.entity_type_key(),
82
+ entity_id=enrichment.entity_id,
83
+ )
84
+ session.add(db_association)
85
+
86
+ async def bulk_delete_enrichments(
87
+ self,
88
+ entity_type: str,
89
+ entity_ids: list[str],
90
+ ) -> None:
91
+ """Bulk delete enrichments for multiple entities of the same type."""
92
+ if not entity_ids:
93
+ return
94
+
95
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
96
+ stmt = select(db_entities.EnrichmentAssociation.enrichment_id).where(
97
+ db_entities.EnrichmentAssociation.entity_type == entity_type,
98
+ db_entities.EnrichmentAssociation.entity_id.in_(entity_ids),
99
+ )
100
+ result = await session.execute(stmt)
101
+ enrichment_ids = result.scalars().all()
102
+
103
+ if enrichment_ids:
104
+ await session.execute(
105
+ delete(db_entities.EnrichmentV2).where(
106
+ db_entities.EnrichmentV2.id.in_(enrichment_ids)
107
+ )
108
+ )
109
+
110
+ async def delete_enrichment(self, enrichment_id: int) -> bool:
111
+ """Delete a specific enrichment by ID."""
112
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
113
+ result = await session.execute(
114
+ delete(db_entities.EnrichmentV2).where(
115
+ db_entities.EnrichmentV2.id == enrichment_id
116
+ )
117
+ )
118
+ return result.rowcount > 0