kodit 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic. Click here for more details.

Files changed (118) hide show
  1. kodit/_version.py +2 -2
  2. kodit/application/__init__.py +1 -0
  3. kodit/application/commands/__init__.py +1 -0
  4. kodit/application/commands/snippet_commands.py +22 -0
  5. kodit/application/services/__init__.py +1 -0
  6. kodit/application/services/indexing_application_service.py +387 -0
  7. kodit/application/services/snippet_application_service.py +149 -0
  8. kodit/cli.py +118 -82
  9. kodit/database.py +0 -22
  10. kodit/domain/__init__.py +1 -0
  11. kodit/{source/source_models.py → domain/entities.py} +88 -19
  12. kodit/domain/enums.py +9 -0
  13. kodit/domain/errors.py +5 -0
  14. kodit/domain/interfaces.py +27 -0
  15. kodit/domain/repositories.py +95 -0
  16. kodit/domain/services/__init__.py +1 -0
  17. kodit/domain/services/bm25_service.py +124 -0
  18. kodit/domain/services/embedding_service.py +155 -0
  19. kodit/domain/services/enrichment_service.py +48 -0
  20. kodit/domain/services/ignore_service.py +45 -0
  21. kodit/domain/services/indexing_service.py +203 -0
  22. kodit/domain/services/snippet_extraction_service.py +89 -0
  23. kodit/domain/services/source_service.py +85 -0
  24. kodit/domain/value_objects.py +215 -0
  25. kodit/infrastructure/__init__.py +1 -0
  26. kodit/infrastructure/bm25/__init__.py +1 -0
  27. kodit/infrastructure/bm25/bm25_factory.py +28 -0
  28. kodit/{bm25/local_bm25.py → infrastructure/bm25/local_bm25_repository.py} +33 -22
  29. kodit/{bm25/vectorchord_bm25.py → infrastructure/bm25/vectorchord_bm25_repository.py} +40 -35
  30. kodit/infrastructure/cloning/__init__.py +1 -0
  31. kodit/infrastructure/cloning/folder/__init__.py +1 -0
  32. kodit/infrastructure/cloning/folder/factory.py +128 -0
  33. kodit/infrastructure/cloning/folder/working_copy.py +38 -0
  34. kodit/infrastructure/cloning/git/__init__.py +1 -0
  35. kodit/infrastructure/cloning/git/factory.py +147 -0
  36. kodit/infrastructure/cloning/git/working_copy.py +32 -0
  37. kodit/infrastructure/cloning/metadata.py +127 -0
  38. kodit/infrastructure/embedding/__init__.py +1 -0
  39. kodit/infrastructure/embedding/embedding_factory.py +87 -0
  40. kodit/infrastructure/embedding/embedding_providers/__init__.py +1 -0
  41. kodit/infrastructure/embedding/embedding_providers/batching.py +93 -0
  42. kodit/infrastructure/embedding/embedding_providers/hash_embedding_provider.py +79 -0
  43. kodit/infrastructure/embedding/embedding_providers/local_embedding_provider.py +129 -0
  44. kodit/infrastructure/embedding/embedding_providers/openai_embedding_provider.py +113 -0
  45. kodit/infrastructure/embedding/local_vector_search_repository.py +114 -0
  46. kodit/{embedding/vectorchord_vector_search_service.py → infrastructure/embedding/vectorchord_vector_search_repository.py} +65 -46
  47. kodit/infrastructure/enrichment/__init__.py +1 -0
  48. kodit/{enrichment → infrastructure/enrichment}/enrichment_factory.py +28 -12
  49. kodit/infrastructure/enrichment/legacy_enrichment_models.py +42 -0
  50. kodit/{enrichment/enrichment_provider → infrastructure/enrichment}/local_enrichment_provider.py +38 -26
  51. kodit/infrastructure/enrichment/null_enrichment_provider.py +25 -0
  52. kodit/infrastructure/enrichment/openai_enrichment_provider.py +89 -0
  53. kodit/infrastructure/git/__init__.py +1 -0
  54. kodit/{source/git.py → infrastructure/git/git_utils.py} +10 -2
  55. kodit/infrastructure/ignore/__init__.py +1 -0
  56. kodit/{source/ignore.py → infrastructure/ignore/ignore_pattern_provider.py} +23 -6
  57. kodit/infrastructure/indexing/__init__.py +1 -0
  58. kodit/infrastructure/indexing/fusion_service.py +55 -0
  59. kodit/infrastructure/indexing/index_repository.py +291 -0
  60. kodit/infrastructure/indexing/indexing_factory.py +113 -0
  61. kodit/infrastructure/snippet_extraction/__init__.py +1 -0
  62. kodit/infrastructure/snippet_extraction/language_detection_service.py +39 -0
  63. kodit/infrastructure/snippet_extraction/snippet_extraction_factory.py +95 -0
  64. kodit/infrastructure/snippet_extraction/snippet_query_provider.py +45 -0
  65. kodit/{snippets/method_snippets.py → infrastructure/snippet_extraction/tree_sitter_snippet_extractor.py} +123 -61
  66. kodit/infrastructure/sqlalchemy/__init__.py +1 -0
  67. kodit/{embedding → infrastructure/sqlalchemy}/embedding_repository.py +40 -26
  68. kodit/infrastructure/sqlalchemy/file_repository.py +78 -0
  69. kodit/infrastructure/sqlalchemy/repository.py +133 -0
  70. kodit/infrastructure/sqlalchemy/snippet_repository.py +79 -0
  71. kodit/infrastructure/ui/__init__.py +1 -0
  72. kodit/infrastructure/ui/progress.py +127 -0
  73. kodit/{util → infrastructure/ui}/spinner.py +19 -4
  74. kodit/mcp.py +51 -28
  75. kodit/migrations/env.py +1 -4
  76. kodit/reporting.py +78 -0
  77. {kodit-0.2.4.dist-info → kodit-0.2.6.dist-info}/METADATA +1 -1
  78. kodit-0.2.6.dist-info/RECORD +100 -0
  79. kodit/bm25/__init__.py +0 -1
  80. kodit/bm25/keyword_search_factory.py +0 -17
  81. kodit/bm25/keyword_search_service.py +0 -34
  82. kodit/embedding/__init__.py +0 -1
  83. kodit/embedding/embedding_factory.py +0 -69
  84. kodit/embedding/embedding_models.py +0 -28
  85. kodit/embedding/embedding_provider/__init__.py +0 -1
  86. kodit/embedding/embedding_provider/embedding_provider.py +0 -92
  87. kodit/embedding/embedding_provider/hash_embedding_provider.py +0 -86
  88. kodit/embedding/embedding_provider/local_embedding_provider.py +0 -96
  89. kodit/embedding/embedding_provider/openai_embedding_provider.py +0 -73
  90. kodit/embedding/local_vector_search_service.py +0 -87
  91. kodit/embedding/vector_search_service.py +0 -55
  92. kodit/enrichment/__init__.py +0 -1
  93. kodit/enrichment/enrichment_provider/__init__.py +0 -1
  94. kodit/enrichment/enrichment_provider/enrichment_provider.py +0 -36
  95. kodit/enrichment/enrichment_provider/openai_enrichment_provider.py +0 -79
  96. kodit/enrichment/enrichment_service.py +0 -45
  97. kodit/indexing/__init__.py +0 -1
  98. kodit/indexing/fusion.py +0 -67
  99. kodit/indexing/indexing_models.py +0 -43
  100. kodit/indexing/indexing_repository.py +0 -216
  101. kodit/indexing/indexing_service.py +0 -344
  102. kodit/snippets/__init__.py +0 -1
  103. kodit/snippets/languages/__init__.py +0 -53
  104. kodit/snippets/snippets.py +0 -50
  105. kodit/source/__init__.py +0 -1
  106. kodit/source/source_factories.py +0 -356
  107. kodit/source/source_repository.py +0 -169
  108. kodit/source/source_service.py +0 -150
  109. kodit/util/__init__.py +0 -1
  110. kodit-0.2.4.dist-info/RECORD +0 -71
  111. /kodit/{snippets → infrastructure/snippet_extraction}/languages/csharp.scm +0 -0
  112. /kodit/{snippets → infrastructure/snippet_extraction}/languages/go.scm +0 -0
  113. /kodit/{snippets → infrastructure/snippet_extraction}/languages/javascript.scm +0 -0
  114. /kodit/{snippets → infrastructure/snippet_extraction}/languages/python.scm +0 -0
  115. /kodit/{snippets → infrastructure/snippet_extraction}/languages/typescript.scm +0 -0
  116. {kodit-0.2.4.dist-info → kodit-0.2.6.dist-info}/WHEEL +0 -0
  117. {kodit-0.2.4.dist-info → kodit-0.2.6.dist-info}/entry_points.txt +0 -0
  118. {kodit-0.2.4.dist-info → kodit-0.2.6.dist-info}/licenses/LICENSE +0 -0
@@ -1,17 +1,132 @@
1
- """Extract method snippets from source code."""
1
+ """Infrastructure implementation using tree-sitter for method extraction."""
2
+
3
+ from pathlib import Path
4
+ from typing import cast
2
5
 
3
6
  from tree_sitter import Node, Query
4
7
  from tree_sitter_language_pack import SupportedLanguage, get_language, get_parser
5
8
 
9
+ from kodit.domain.services.snippet_extraction_service import SnippetExtractor
10
+ from kodit.infrastructure.snippet_extraction.snippet_query_provider import (
11
+ SnippetQueryProvider,
12
+ )
13
+
14
+
15
+ class TreeSitterSnippetExtractor(SnippetExtractor):
16
+ """Infrastructure implementation using tree-sitter for method extraction."""
17
+
18
+ def __init__(self, query_provider: SnippetQueryProvider) -> None:
19
+ """Initialize the tree-sitter snippet extractor.
20
+
21
+ Args:
22
+ query_provider: Provider for snippet queries
23
+
24
+ """
25
+ self.query_provider = query_provider
26
+
27
+ async def extract(self, file_path: Path, language: str) -> list[str]:
28
+ """Extract snippets using tree-sitter parsing.
29
+
30
+ Args:
31
+ file_path: Path to the file to extract snippets from
32
+ language: The programming language of the file
33
+
34
+ Returns:
35
+ List of extracted code snippets
36
+
37
+ Raises:
38
+ ValueError: If the file cannot be read or language is not supported
39
+
40
+ """
41
+ try:
42
+ # Get the query for the language
43
+ query = await self.query_provider.get_query(language)
44
+ except FileNotFoundError as e:
45
+ raise ValueError(f"Unsupported language: {file_path}") from e
46
+
47
+ # Get parser and language for tree-sitter
48
+ try:
49
+ tree_sitter_language = get_language(cast("SupportedLanguage", language))
50
+ parser = get_parser(cast("SupportedLanguage", language))
51
+ except Exception as e:
52
+ raise ValueError(f"Unsupported language: {file_path}") from e
53
+
54
+ # Create query object
55
+ query_obj = Query(tree_sitter_language, query)
56
+
57
+ # Read file content
58
+ try:
59
+ file_bytes = file_path.read_bytes()
60
+ except Exception as e:
61
+ raise ValueError(f"Failed to read file: {file_path}") from e
6
62
 
7
- class MethodSnippets:
8
- """Extract method snippets from source code."""
63
+ # Parse and extract snippets
64
+ tree = parser.parse(file_bytes)
65
+ captures_by_name = query_obj.captures(tree.root_node)
66
+ lines = file_bytes.decode().splitlines()
9
67
 
10
- def __init__(self, language: SupportedLanguage, query: str) -> None:
11
- """Initialize the MethodSnippets class."""
12
- self.language = get_language(language)
13
- self.parser = get_parser(language)
14
- self.query = Query(self.language, query)
68
+ # Extract snippets using the existing logic
69
+ snippets = self._extract_snippets_from_captures(captures_by_name, lines)
70
+
71
+ # If there are no results, return the entire file
72
+ if not snippets:
73
+ return [file_bytes.decode()]
74
+
75
+ return snippets
76
+
77
+ def _extract_snippets_from_captures(
78
+ self, captures_by_name: dict[str, list[Node]], lines: list[str]
79
+ ) -> list[str]:
80
+ """Extract snippets from tree-sitter captures.
81
+
82
+ Args:
83
+ captures_by_name: Captures organized by name
84
+ lines: Lines of the source file
85
+
86
+ Returns:
87
+ List of extracted code snippets
88
+
89
+ """
90
+ # Find all leaf functions
91
+ leaf_functions = self._get_leaf_functions(captures_by_name)
92
+
93
+ # Find all imports
94
+ imports = self._get_imports(captures_by_name)
95
+
96
+ results = []
97
+
98
+ # For each leaf function, find all lines this function is dependent on
99
+ for func_node in leaf_functions:
100
+ all_lines_to_keep = set()
101
+
102
+ ancestors = self._get_ancestors(captures_by_name, func_node)
103
+
104
+ # Add self to keep
105
+ all_lines_to_keep.update(
106
+ range(func_node.start_point[0], func_node.end_point[0] + 1)
107
+ )
108
+
109
+ # Add imports to keep
110
+ for import_node in imports:
111
+ all_lines_to_keep.update(
112
+ range(import_node.start_point[0], import_node.end_point[0] + 1)
113
+ )
114
+
115
+ # Add ancestors to keep
116
+ for node in ancestors:
117
+ # Get the first line of the node for now
118
+ start = node.start_point[0]
119
+ end = node.start_point[0]
120
+ all_lines_to_keep.update(range(start, end + 1))
121
+
122
+ pseudo_code = []
123
+ for i, line in enumerate(lines):
124
+ if i in all_lines_to_keep:
125
+ pseudo_code.append(line)
126
+
127
+ results.append("\n".join(pseudo_code))
128
+
129
+ return results
15
130
 
16
131
  def _get_leaf_functions(
17
132
  self, captures_by_name: dict[str, list[Node]]
@@ -65,56 +180,3 @@ class MethodSnippets:
65
180
  ancestors.append(parent)
66
181
  parent = parent.parent
67
182
  return ancestors
68
-
69
- def extract(self, source_code: bytes) -> list[str]:
70
- """Extract method snippets from source code."""
71
- tree = self.parser.parse(source_code)
72
-
73
- captures_by_name = self.query.captures(tree.root_node)
74
-
75
- lines = source_code.decode().splitlines()
76
-
77
- # Find all leaf functions
78
- leaf_functions = self._get_leaf_functions(captures_by_name)
79
-
80
- # Find all imports
81
- imports = self._get_imports(captures_by_name)
82
-
83
- results = []
84
-
85
- # For each leaf function, find all lines this function is dependent on
86
- for func_node in leaf_functions:
87
- all_lines_to_keep = set()
88
-
89
- ancestors = self._get_ancestors(captures_by_name, func_node)
90
-
91
- # Add self to keep
92
- all_lines_to_keep.update(
93
- range(func_node.start_point[0], func_node.end_point[0] + 1)
94
- )
95
-
96
- # Add imports to keep
97
- for import_node in imports:
98
- all_lines_to_keep.update(
99
- range(import_node.start_point[0], import_node.end_point[0] + 1)
100
- )
101
-
102
- # Add ancestors to keep
103
- for node in ancestors:
104
- # Get the first line of the node for now
105
- start = node.start_point[0]
106
- end = node.start_point[0]
107
- all_lines_to_keep.update(range(start, end + 1))
108
-
109
- pseudo_code = []
110
- for i, line in enumerate(lines):
111
- if i in all_lines_to_keep:
112
- pseudo_code.append(line)
113
-
114
- results.append("\n".join(pseudo_code))
115
-
116
- # If there are no results, then return the entire file
117
- if not results:
118
- return [source_code.decode()]
119
-
120
- return results
@@ -0,0 +1 @@
1
+ """SQLAlchemy infrastructure."""
@@ -1,39 +1,35 @@
1
- """Repository for managing embeddings."""
1
+ """SQLAlchemy implementation of embedding repository."""
2
2
 
3
3
  import numpy as np
4
4
  from sqlalchemy import select
5
5
  from sqlalchemy.ext.asyncio import AsyncSession
6
6
 
7
- from kodit.embedding.embedding_models import Embedding, EmbeddingType
7
+ from kodit.domain.entities import Embedding, EmbeddingType
8
8
 
9
9
 
10
- class EmbeddingRepository:
11
- """Repository for managing embeddings.
10
+ class SqlAlchemyEmbeddingRepository:
11
+ """SQLAlchemy implementation of embedding repository."""
12
12
 
13
- This class provides methods for creating and retrieving embeddings from the
14
- database. It handles the low-level database operations and transaction management.
15
-
16
- Args:
17
- session: The SQLAlchemy async session to use for database operations.
13
+ def __init__(self, session: AsyncSession) -> None:
14
+ """Initialize the SQLAlchemy embedding repository.
18
15
 
19
- """
16
+ Args:
17
+ session: The SQLAlchemy async session to use for database operations
20
18
 
21
- def __init__(self, session: AsyncSession) -> None:
22
- """Initialize the embedding repository."""
19
+ """
23
20
  self.session = session
24
21
 
25
22
  async def create_embedding(self, embedding: Embedding) -> Embedding:
26
23
  """Create a new embedding record in the database.
27
24
 
28
25
  Args:
29
- embedding: The Embedding instance to create.
26
+ embedding: The Embedding instance to create
30
27
 
31
28
  Returns:
32
- The created Embedding instance.
29
+ The created Embedding instance
33
30
 
34
31
  """
35
32
  self.session.add(embedding)
36
- await self.session.commit()
37
33
  return embedding
38
34
 
39
35
  async def get_embedding_by_snippet_id_and_type(
@@ -42,11 +38,11 @@ class EmbeddingRepository:
42
38
  """Get an embedding by its snippet ID and type.
43
39
 
44
40
  Args:
45
- snippet_id: The ID of the snippet to get the embedding for.
46
- embedding_type: The type of embedding to get.
41
+ snippet_id: The ID of the snippet to get the embedding for
42
+ embedding_type: The type of embedding to get
47
43
 
48
44
  Returns:
49
- The Embedding instance if found, None otherwise.
45
+ The Embedding instance if found, None otherwise
50
46
 
51
47
  """
52
48
  query = select(Embedding).where(
@@ -62,10 +58,10 @@ class EmbeddingRepository:
62
58
  """List all embeddings of a given type.
63
59
 
64
60
  Args:
65
- embedding_type: The type of embeddings to list.
61
+ embedding_type: The type of embeddings to list
66
62
 
67
63
  Returns:
68
- A list of Embedding instances.
64
+ A list of Embedding instances
69
65
 
70
66
  """
71
67
  query = select(Embedding).where(Embedding.type == embedding_type)
@@ -76,7 +72,7 @@ class EmbeddingRepository:
76
72
  """Delete all embeddings for a snippet.
77
73
 
78
74
  Args:
79
- snippet_id: The ID of the snippet to delete embeddings for.
75
+ snippet_id: The ID of the snippet to delete embeddings for
80
76
 
81
77
  """
82
78
  query = select(Embedding).where(Embedding.snippet_id == snippet_id)
@@ -84,7 +80,6 @@ class EmbeddingRepository:
84
80
  embeddings = result.scalars().all()
85
81
  for embedding in embeddings:
86
82
  await self.session.delete(embedding)
87
- await self.session.commit()
88
83
 
89
84
  async def list_semantic_results(
90
85
  self, embedding_type: EmbeddingType, embedding: list[float], top_k: int = 10
@@ -181,7 +176,27 @@ class EmbeddingRepository:
181
176
  """
182
177
  stored_norms = np.linalg.norm(stored_vecs, axis=1)
183
178
  query_norm = np.linalg.norm(query_vec)
184
- return np.dot(stored_vecs, query_vec) / (stored_norms * query_norm)
179
+
180
+ # Handle zero vectors to avoid division by zero
181
+ if query_norm == 0:
182
+ # If query vector is zero, return zeros for all similarities
183
+ return np.zeros(len(stored_vecs))
184
+
185
+ # Handle stored vectors with zero norms
186
+ zero_stored_mask = stored_norms == 0
187
+ similarities = np.zeros(len(stored_vecs))
188
+
189
+ # Only compute similarities for non-zero stored vectors
190
+ non_zero_mask = ~zero_stored_mask
191
+ if np.any(non_zero_mask):
192
+ non_zero_stored_vecs = stored_vecs[non_zero_mask]
193
+ non_zero_stored_norms = stored_norms[non_zero_mask]
194
+ non_zero_similarities = np.dot(non_zero_stored_vecs, query_vec) / (
195
+ non_zero_stored_norms * query_norm
196
+ )
197
+ similarities[non_zero_mask] = non_zero_similarities
198
+
199
+ return similarities
185
200
 
186
201
  def _get_top_k_results(
187
202
  self,
@@ -200,7 +215,6 @@ class EmbeddingRepository:
200
215
  List of (snippet_id, similarity_score) tuples
201
216
 
202
217
  """
218
+ # Get indices of top-k similarities
203
219
  top_indices = np.argsort(similarities)[::-1][:top_k]
204
- return [
205
- (embeddings[i][0], float(similarities[i])) for i in top_indices
206
- ] # Use index 0 to get snippet_id
220
+ return [(embeddings[i][0], float(similarities[i])) for i in top_indices]
@@ -0,0 +1,78 @@
1
+ """SQLAlchemy implementation of file repository."""
2
+
3
+ from collections.abc import Sequence
4
+
5
+ from sqlalchemy import select
6
+ from sqlalchemy.ext.asyncio import AsyncSession
7
+
8
+ from kodit.domain.entities import File, Index
9
+ from kodit.domain.repositories import FileRepository
10
+
11
+
12
+ class SqlAlchemyFileRepository(FileRepository):
13
+ """SQLAlchemy implementation of file repository."""
14
+
15
+ def __init__(self, session: AsyncSession) -> None:
16
+ """Initialize the SQLAlchemy file repository.
17
+
18
+ Args:
19
+ session: The SQLAlchemy async session to use for database operations
20
+
21
+ """
22
+ self.session = session
23
+
24
+ async def get(self, id: int) -> File | None: # noqa: A002
25
+ """Get a file by ID."""
26
+ return await self.session.get(File, id)
27
+
28
+ async def save(self, entity: File) -> File:
29
+ """Save entity."""
30
+ self.session.add(entity)
31
+ return entity
32
+
33
+ async def delete(self, id: int) -> None: # noqa: A002
34
+ """Delete entity by ID."""
35
+ file = await self.get(id)
36
+ if file:
37
+ await self.session.delete(file)
38
+
39
+ async def list(self) -> Sequence[File]:
40
+ """List all entities."""
41
+ return (await self.session.scalars(select(File))).all()
42
+
43
+ async def get_files_for_index(self, index_id: int) -> Sequence[File]:
44
+ """Get all files for an index.
45
+
46
+ Args:
47
+ index_id: The ID of the index to get files for
48
+
49
+ Returns:
50
+ A list of File instances
51
+
52
+ """
53
+ # Get the index first to find its source_id
54
+ index_query = select(Index).where(Index.id == index_id)
55
+ index_result = await self.session.execute(index_query)
56
+ index = index_result.scalar_one_or_none()
57
+
58
+ if not index:
59
+ return []
60
+
61
+ # Get all files for the source
62
+ query = select(File).where(File.source_id == index.source_id)
63
+ result = await self.session.execute(query)
64
+ return list(result.scalars())
65
+
66
+ async def get_by_id(self, file_id: int) -> File | None:
67
+ """Get a file by ID.
68
+
69
+ Args:
70
+ file_id: The ID of the file to retrieve
71
+
72
+ Returns:
73
+ The File instance if found, None otherwise
74
+
75
+ """
76
+ query = select(File).where(File.id == file_id)
77
+ result = await self.session.execute(query)
78
+ return result.scalar_one_or_none()
@@ -0,0 +1,133 @@
1
+ """SQLAlchemy repository."""
2
+
3
+ from collections.abc import Sequence
4
+ from typing import cast
5
+
6
+ from sqlalchemy import select
7
+ from sqlalchemy.ext.asyncio import AsyncSession
8
+
9
+ from kodit.domain.entities import Author, AuthorFileMapping, File, Source, SourceType
10
+ from kodit.domain.repositories import AuthorRepository, SourceRepository
11
+
12
+
13
+ class SqlAlchemySourceRepository(SourceRepository):
14
+ """SQLAlchemy source repository."""
15
+
16
+ def __init__(self, session: AsyncSession) -> None:
17
+ """Initialize the repository."""
18
+ self._session = session
19
+
20
+ async def get(self, id: int) -> Source | None: # noqa: A002
21
+ """Get a source by ID."""
22
+ return await self._session.get(Source, id)
23
+
24
+ async def save(self, entity: Source) -> Source:
25
+ """Save entity."""
26
+ self._session.add(entity)
27
+ return entity
28
+
29
+ async def delete(self, id: int) -> None: # noqa: A002
30
+ """Delete entity by ID."""
31
+ source = await self.get(id)
32
+ if source:
33
+ await self._session.delete(source)
34
+
35
+ async def list(self) -> Sequence[Source]:
36
+ """List all entities."""
37
+ stmt = select(Source)
38
+ return (await self._session.scalars(stmt)).all()
39
+
40
+ async def get_by_uri(self, uri: str) -> Source | None:
41
+ """Get a source by URI."""
42
+ stmt = select(Source).where(Source.uri == uri)
43
+ return cast("Source | None", await self._session.scalar(stmt))
44
+
45
+ async def list_by_type(
46
+ self, source_type: SourceType | None = None
47
+ ) -> Sequence[Source]:
48
+ """List sources by type."""
49
+ stmt = select(Source)
50
+ if source_type is not None:
51
+ stmt = stmt.where(Source.type == source_type)
52
+ return (await self._session.scalars(stmt)).all()
53
+
54
+ async def create_file(self, file: File) -> File:
55
+ """Create a new file record."""
56
+ self._session.add(file)
57
+ return file
58
+
59
+ async def upsert_author(self, author: Author) -> Author:
60
+ """Create a new author or return existing one if email already exists."""
61
+ # First check if author already exists with same name and email
62
+ stmt = select(Author).where(
63
+ Author.name == author.name, Author.email == author.email
64
+ )
65
+ existing_author = cast("Author | None", await self._session.scalar(stmt))
66
+
67
+ if existing_author:
68
+ return existing_author
69
+
70
+ # Author doesn't exist, create new one
71
+ self._session.add(author)
72
+ return author
73
+
74
+ async def upsert_author_file_mapping(
75
+ self, mapping: AuthorFileMapping
76
+ ) -> AuthorFileMapping:
77
+ """Create a new author file mapping or return existing one if already exists."""
78
+ # First check if mapping already exists with same author_id and file_id
79
+ stmt = select(AuthorFileMapping).where(
80
+ AuthorFileMapping.author_id == mapping.author_id,
81
+ AuthorFileMapping.file_id == mapping.file_id,
82
+ )
83
+ existing_mapping = cast(
84
+ "AuthorFileMapping | None", await self._session.scalar(stmt)
85
+ )
86
+
87
+ if existing_mapping:
88
+ return existing_mapping
89
+
90
+ # Mapping doesn't exist, create new one
91
+ self._session.add(mapping)
92
+ return mapping
93
+
94
+
95
+ class SqlAlchemyAuthorRepository(AuthorRepository):
96
+ """SQLAlchemy author repository."""
97
+
98
+ def __init__(self, session: AsyncSession) -> None:
99
+ """Initialize the repository."""
100
+ self._session = session
101
+
102
+ async def get(self, id: int) -> Author | None: # noqa: A002
103
+ """Get an author by ID."""
104
+ return await self._session.get(Author, id)
105
+
106
+ async def save(self, entity: Author) -> Author:
107
+ """Save entity."""
108
+ self._session.add(entity)
109
+ return entity
110
+
111
+ async def delete(self, id: int) -> None: # noqa: A002
112
+ """Delete entity by ID."""
113
+ author = await self.get(id)
114
+ if author:
115
+ await self._session.delete(author)
116
+
117
+ async def list(self) -> Sequence[Author]:
118
+ """List authors."""
119
+ return (await self._session.scalars(select(Author))).all()
120
+
121
+ async def get_by_name(self, name: str) -> Author | None:
122
+ """Get an author by name."""
123
+ return cast(
124
+ "Author | None",
125
+ await self._session.scalar(select(Author).where(Author.name == name)),
126
+ )
127
+
128
+ async def get_by_email(self, email: str) -> Author | None:
129
+ """Get an author by email."""
130
+ return cast(
131
+ "Author | None",
132
+ await self._session.scalar(select(Author).where(Author.email == email)),
133
+ )
@@ -0,0 +1,79 @@
1
+ """SQLAlchemy implementation of snippet repository."""
2
+
3
+ from collections.abc import Sequence
4
+
5
+ from sqlalchemy import delete, select
6
+ from sqlalchemy.ext.asyncio import AsyncSession
7
+
8
+ from kodit.domain.entities import Snippet
9
+ from kodit.domain.repositories import SnippetRepository
10
+
11
+
12
+ class SqlAlchemySnippetRepository(SnippetRepository):
13
+ """SQLAlchemy implementation of snippet repository."""
14
+
15
+ def __init__(self, session: AsyncSession) -> None:
16
+ """Initialize the SQLAlchemy snippet repository.
17
+
18
+ Args:
19
+ session: The SQLAlchemy async session to use for database operations
20
+
21
+ """
22
+ self.session = session
23
+
24
+ async def get(self, id: int) -> Snippet | None: # noqa: A002
25
+ """Get a snippet by ID."""
26
+ return await self.session.get(Snippet, id)
27
+
28
+ async def save(self, entity: Snippet) -> Snippet:
29
+ """Save entity."""
30
+ self.session.add(entity)
31
+ return entity
32
+
33
+ async def delete(self, id: int) -> None: # noqa: A002
34
+ """Delete entity by ID."""
35
+ snippet = await self.get(id)
36
+ if snippet:
37
+ await self.session.delete(snippet)
38
+
39
+ async def list(self) -> Sequence[Snippet]:
40
+ """List all entities."""
41
+ return (await self.session.scalars(select(Snippet))).all()
42
+
43
+ async def get_by_id(self, snippet_id: int) -> Snippet | None:
44
+ """Get a snippet by ID.
45
+
46
+ Args:
47
+ snippet_id: The ID of the snippet to retrieve
48
+
49
+ Returns:
50
+ The Snippet instance if found, None otherwise
51
+
52
+ """
53
+ query = select(Snippet).where(Snippet.id == snippet_id)
54
+ result = await self.session.execute(query)
55
+ return result.scalar_one_or_none()
56
+
57
+ async def get_by_index(self, index_id: int) -> Sequence[Snippet]:
58
+ """Get all snippets for an index.
59
+
60
+ Args:
61
+ index_id: The ID of the index to get snippets for
62
+
63
+ Returns:
64
+ A list of Snippet instances
65
+
66
+ """
67
+ query = select(Snippet).where(Snippet.index_id == index_id)
68
+ result = await self.session.execute(query)
69
+ return list(result.scalars())
70
+
71
+ async def delete_by_index(self, index_id: int) -> None:
72
+ """Delete all snippets for an index.
73
+
74
+ Args:
75
+ index_id: The ID of the index to delete snippets for
76
+
77
+ """
78
+ query = delete(Snippet).where(Snippet.index_id == index_id)
79
+ await self.session.execute(query)
@@ -0,0 +1 @@
1
+ """UI infrastructure module."""