kodit 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kodit might be problematic. Click here for more details.
- kodit/_version.py +2 -2
- kodit/application/__init__.py +1 -0
- kodit/application/commands/__init__.py +1 -0
- kodit/application/commands/snippet_commands.py +22 -0
- kodit/application/services/__init__.py +1 -0
- kodit/application/services/indexing_application_service.py +387 -0
- kodit/application/services/snippet_application_service.py +149 -0
- kodit/cli.py +118 -82
- kodit/database.py +0 -22
- kodit/domain/__init__.py +1 -0
- kodit/{source/source_models.py → domain/entities.py} +88 -19
- kodit/domain/enums.py +9 -0
- kodit/domain/errors.py +5 -0
- kodit/domain/interfaces.py +27 -0
- kodit/domain/repositories.py +95 -0
- kodit/domain/services/__init__.py +1 -0
- kodit/domain/services/bm25_service.py +124 -0
- kodit/domain/services/embedding_service.py +155 -0
- kodit/domain/services/enrichment_service.py +48 -0
- kodit/domain/services/ignore_service.py +45 -0
- kodit/domain/services/indexing_service.py +203 -0
- kodit/domain/services/snippet_extraction_service.py +89 -0
- kodit/domain/services/source_service.py +85 -0
- kodit/domain/value_objects.py +215 -0
- kodit/infrastructure/__init__.py +1 -0
- kodit/infrastructure/bm25/__init__.py +1 -0
- kodit/infrastructure/bm25/bm25_factory.py +28 -0
- kodit/{bm25/local_bm25.py → infrastructure/bm25/local_bm25_repository.py} +33 -22
- kodit/{bm25/vectorchord_bm25.py → infrastructure/bm25/vectorchord_bm25_repository.py} +40 -35
- kodit/infrastructure/cloning/__init__.py +1 -0
- kodit/infrastructure/cloning/folder/__init__.py +1 -0
- kodit/infrastructure/cloning/folder/factory.py +128 -0
- kodit/infrastructure/cloning/folder/working_copy.py +38 -0
- kodit/infrastructure/cloning/git/__init__.py +1 -0
- kodit/infrastructure/cloning/git/factory.py +147 -0
- kodit/infrastructure/cloning/git/working_copy.py +32 -0
- kodit/infrastructure/cloning/metadata.py +127 -0
- kodit/infrastructure/embedding/__init__.py +1 -0
- kodit/infrastructure/embedding/embedding_factory.py +87 -0
- kodit/infrastructure/embedding/embedding_providers/__init__.py +1 -0
- kodit/infrastructure/embedding/embedding_providers/batching.py +93 -0
- kodit/infrastructure/embedding/embedding_providers/hash_embedding_provider.py +79 -0
- kodit/infrastructure/embedding/embedding_providers/local_embedding_provider.py +129 -0
- kodit/infrastructure/embedding/embedding_providers/openai_embedding_provider.py +113 -0
- kodit/infrastructure/embedding/local_vector_search_repository.py +114 -0
- kodit/{embedding/vectorchord_vector_search_service.py → infrastructure/embedding/vectorchord_vector_search_repository.py} +65 -46
- kodit/infrastructure/enrichment/__init__.py +1 -0
- kodit/{enrichment → infrastructure/enrichment}/enrichment_factory.py +28 -12
- kodit/infrastructure/enrichment/legacy_enrichment_models.py +42 -0
- kodit/{enrichment/enrichment_provider → infrastructure/enrichment}/local_enrichment_provider.py +38 -26
- kodit/infrastructure/enrichment/null_enrichment_provider.py +25 -0
- kodit/infrastructure/enrichment/openai_enrichment_provider.py +89 -0
- kodit/infrastructure/git/__init__.py +1 -0
- kodit/{source/git.py → infrastructure/git/git_utils.py} +10 -2
- kodit/infrastructure/ignore/__init__.py +1 -0
- kodit/{source/ignore.py → infrastructure/ignore/ignore_pattern_provider.py} +23 -6
- kodit/infrastructure/indexing/__init__.py +1 -0
- kodit/infrastructure/indexing/fusion_service.py +55 -0
- kodit/infrastructure/indexing/index_repository.py +291 -0
- kodit/infrastructure/indexing/indexing_factory.py +113 -0
- kodit/infrastructure/snippet_extraction/__init__.py +1 -0
- kodit/infrastructure/snippet_extraction/language_detection_service.py +39 -0
- kodit/infrastructure/snippet_extraction/snippet_extraction_factory.py +95 -0
- kodit/infrastructure/snippet_extraction/snippet_query_provider.py +45 -0
- kodit/{snippets/method_snippets.py → infrastructure/snippet_extraction/tree_sitter_snippet_extractor.py} +123 -61
- kodit/infrastructure/sqlalchemy/__init__.py +1 -0
- kodit/{embedding → infrastructure/sqlalchemy}/embedding_repository.py +40 -26
- kodit/infrastructure/sqlalchemy/file_repository.py +78 -0
- kodit/infrastructure/sqlalchemy/repository.py +133 -0
- kodit/infrastructure/sqlalchemy/snippet_repository.py +79 -0
- kodit/infrastructure/ui/__init__.py +1 -0
- kodit/infrastructure/ui/progress.py +127 -0
- kodit/{util → infrastructure/ui}/spinner.py +19 -4
- kodit/mcp.py +51 -28
- kodit/migrations/env.py +1 -4
- kodit/reporting.py +78 -0
- {kodit-0.2.4.dist-info → kodit-0.2.6.dist-info}/METADATA +1 -1
- kodit-0.2.6.dist-info/RECORD +100 -0
- kodit/bm25/__init__.py +0 -1
- kodit/bm25/keyword_search_factory.py +0 -17
- kodit/bm25/keyword_search_service.py +0 -34
- kodit/embedding/__init__.py +0 -1
- kodit/embedding/embedding_factory.py +0 -69
- kodit/embedding/embedding_models.py +0 -28
- kodit/embedding/embedding_provider/__init__.py +0 -1
- kodit/embedding/embedding_provider/embedding_provider.py +0 -92
- kodit/embedding/embedding_provider/hash_embedding_provider.py +0 -86
- kodit/embedding/embedding_provider/local_embedding_provider.py +0 -96
- kodit/embedding/embedding_provider/openai_embedding_provider.py +0 -73
- kodit/embedding/local_vector_search_service.py +0 -87
- kodit/embedding/vector_search_service.py +0 -55
- kodit/enrichment/__init__.py +0 -1
- kodit/enrichment/enrichment_provider/__init__.py +0 -1
- kodit/enrichment/enrichment_provider/enrichment_provider.py +0 -36
- kodit/enrichment/enrichment_provider/openai_enrichment_provider.py +0 -79
- kodit/enrichment/enrichment_service.py +0 -45
- kodit/indexing/__init__.py +0 -1
- kodit/indexing/fusion.py +0 -67
- kodit/indexing/indexing_models.py +0 -43
- kodit/indexing/indexing_repository.py +0 -216
- kodit/indexing/indexing_service.py +0 -344
- kodit/snippets/__init__.py +0 -1
- kodit/snippets/languages/__init__.py +0 -53
- kodit/snippets/snippets.py +0 -50
- kodit/source/__init__.py +0 -1
- kodit/source/source_factories.py +0 -356
- kodit/source/source_repository.py +0 -169
- kodit/source/source_service.py +0 -150
- kodit/util/__init__.py +0 -1
- kodit-0.2.4.dist-info/RECORD +0 -71
- /kodit/{snippets → infrastructure/snippet_extraction}/languages/csharp.scm +0 -0
- /kodit/{snippets → infrastructure/snippet_extraction}/languages/go.scm +0 -0
- /kodit/{snippets → infrastructure/snippet_extraction}/languages/javascript.scm +0 -0
- /kodit/{snippets → infrastructure/snippet_extraction}/languages/python.scm +0 -0
- /kodit/{snippets → infrastructure/snippet_extraction}/languages/typescript.scm +0 -0
- {kodit-0.2.4.dist-info → kodit-0.2.6.dist-info}/WHEEL +0 -0
- {kodit-0.2.4.dist-info → kodit-0.2.6.dist-info}/entry_points.txt +0 -0
- {kodit-0.2.4.dist-info → kodit-0.2.6.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,17 +1,132 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Infrastructure implementation using tree-sitter for method extraction."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import cast
|
|
2
5
|
|
|
3
6
|
from tree_sitter import Node, Query
|
|
4
7
|
from tree_sitter_language_pack import SupportedLanguage, get_language, get_parser
|
|
5
8
|
|
|
9
|
+
from kodit.domain.services.snippet_extraction_service import SnippetExtractor
|
|
10
|
+
from kodit.infrastructure.snippet_extraction.snippet_query_provider import (
|
|
11
|
+
SnippetQueryProvider,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TreeSitterSnippetExtractor(SnippetExtractor):
|
|
16
|
+
"""Infrastructure implementation using tree-sitter for method extraction."""
|
|
17
|
+
|
|
18
|
+
def __init__(self, query_provider: SnippetQueryProvider) -> None:
    """Initialize the tree-sitter snippet extractor.

    Args:
        query_provider: Provider that `extract` awaits (`get_query(language)`)
            to obtain the tree-sitter query for a language

    """
    # Kept on the instance; nothing is loaded or parsed at construction time.
    self.query_provider = query_provider
|
|
26
|
+
|
|
27
|
+
async def extract(self, file_path: Path, language: str) -> list[str]:
    """Extract snippets from a file using tree-sitter parsing.

    Args:
        file_path: Path to the file to extract snippets from
        language: The programming language of the file

    Returns:
        List of extracted code snippets. When no snippets are extracted the
        whole file content is returned as a single snippet.

    Raises:
        ValueError: If the language is not supported, or the file cannot be
            read or decoded with the default (UTF-8) codec

    """
    try:
        # Get the query for the language
        query = await self.query_provider.get_query(language)
    except FileNotFoundError as e:
        # Report the language itself; the path is kept only as context.
        raise ValueError(
            f"Unsupported language: {language} ({file_path})"
        ) from e

    # Get parser and language for tree-sitter
    try:
        tree_sitter_language = get_language(cast("SupportedLanguage", language))
        parser = get_parser(cast("SupportedLanguage", language))
    except Exception as e:
        raise ValueError(
            f"Unsupported language: {language} ({file_path})"
        ) from e

    # Create query object
    query_obj = Query(tree_sitter_language, query)

    # Read file content
    try:
        file_bytes = file_path.read_bytes()
    except Exception as e:
        raise ValueError(f"Failed to read file: {file_path}") from e

    # Decode once, up front: the text is needed both for line slicing and for
    # the whole-file fallback, and an undecodable file should fail with the
    # same "Failed to read file" error before any parsing work is done.
    try:
        source_text = file_bytes.decode()
    except UnicodeDecodeError as e:
        raise ValueError(f"Failed to read file: {file_path}") from e

    # Parse and extract snippets
    tree = parser.parse(file_bytes)
    captures_by_name = query_obj.captures(tree.root_node)
    lines = source_text.splitlines()

    snippets = self._extract_snippets_from_captures(captures_by_name, lines)

    # If there are no results, return the entire file
    if not snippets:
        return [source_text]

    return snippets
|
|
76
|
+
|
|
77
|
+
def _extract_snippets_from_captures(
    self, captures_by_name: dict[str, list[Node]], lines: list[str]
) -> list[str]:
    """Build one snippet per leaf function from tree-sitter captures.

    Each snippet is the source lines of the leaf function itself, every
    import node, and the first line of each ancestor node, emitted in file
    order and joined with newlines.

    Args:
        captures_by_name: Captures organized by name
        lines: Lines of the source file

    Returns:
        List of extracted code snippets, one per leaf function

    """
    leaf_nodes = self._get_leaf_functions(captures_by_name)
    import_nodes = self._get_imports(captures_by_name)

    snippets = []
    for func_node in leaf_nodes:
        enclosing = self._get_ancestors(captures_by_name, func_node)

        # Row indices (0-based) that belong to this snippet: the function body...
        wanted_rows = set(
            range(func_node.start_point[0], func_node.end_point[0] + 1)
        )

        # ...every import statement...
        for import_node in import_nodes:
            first_row = import_node.start_point[0]
            last_row = import_node.end_point[0]
            wanted_rows.update(range(first_row, last_row + 1))

        # ...and, for now, only the opening line of each ancestor.
        wanted_rows.update(node.start_point[0] for node in enclosing)

        kept = [text for row, text in enumerate(lines) if row in wanted_rows]
        snippets.append("\n".join(kept))

    return snippets
|
|
15
130
|
|
|
16
131
|
def _get_leaf_functions(
|
|
17
132
|
self, captures_by_name: dict[str, list[Node]]
|
|
@@ -65,56 +180,3 @@ class MethodSnippets:
|
|
|
65
180
|
ancestors.append(parent)
|
|
66
181
|
parent = parent.parent
|
|
67
182
|
return ancestors
|
|
68
|
-
|
|
69
|
-
def extract(self, source_code: bytes) -> list[str]:
|
|
70
|
-
"""Extract method snippets from source code."""
|
|
71
|
-
tree = self.parser.parse(source_code)
|
|
72
|
-
|
|
73
|
-
captures_by_name = self.query.captures(tree.root_node)
|
|
74
|
-
|
|
75
|
-
lines = source_code.decode().splitlines()
|
|
76
|
-
|
|
77
|
-
# Find all leaf functions
|
|
78
|
-
leaf_functions = self._get_leaf_functions(captures_by_name)
|
|
79
|
-
|
|
80
|
-
# Find all imports
|
|
81
|
-
imports = self._get_imports(captures_by_name)
|
|
82
|
-
|
|
83
|
-
results = []
|
|
84
|
-
|
|
85
|
-
# For each leaf function, find all lines this function is dependent on
|
|
86
|
-
for func_node in leaf_functions:
|
|
87
|
-
all_lines_to_keep = set()
|
|
88
|
-
|
|
89
|
-
ancestors = self._get_ancestors(captures_by_name, func_node)
|
|
90
|
-
|
|
91
|
-
# Add self to keep
|
|
92
|
-
all_lines_to_keep.update(
|
|
93
|
-
range(func_node.start_point[0], func_node.end_point[0] + 1)
|
|
94
|
-
)
|
|
95
|
-
|
|
96
|
-
# Add imports to keep
|
|
97
|
-
for import_node in imports:
|
|
98
|
-
all_lines_to_keep.update(
|
|
99
|
-
range(import_node.start_point[0], import_node.end_point[0] + 1)
|
|
100
|
-
)
|
|
101
|
-
|
|
102
|
-
# Add ancestors to keep
|
|
103
|
-
for node in ancestors:
|
|
104
|
-
# Get the first line of the node for now
|
|
105
|
-
start = node.start_point[0]
|
|
106
|
-
end = node.start_point[0]
|
|
107
|
-
all_lines_to_keep.update(range(start, end + 1))
|
|
108
|
-
|
|
109
|
-
pseudo_code = []
|
|
110
|
-
for i, line in enumerate(lines):
|
|
111
|
-
if i in all_lines_to_keep:
|
|
112
|
-
pseudo_code.append(line)
|
|
113
|
-
|
|
114
|
-
results.append("\n".join(pseudo_code))
|
|
115
|
-
|
|
116
|
-
# If there are no results, then return the entire file
|
|
117
|
-
if not results:
|
|
118
|
-
return [source_code.decode()]
|
|
119
|
-
|
|
120
|
-
return results
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""SQLAlchemy infrastructure."""
|
|
@@ -1,39 +1,35 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""SQLAlchemy implementation of embedding repository."""
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
from sqlalchemy import select
|
|
5
5
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
6
6
|
|
|
7
|
-
from kodit.
|
|
7
|
+
from kodit.domain.entities import Embedding, EmbeddingType
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
class
|
|
11
|
-
"""
|
|
10
|
+
class SqlAlchemyEmbeddingRepository:
|
|
11
|
+
"""SQLAlchemy implementation of embedding repository."""
|
|
12
12
|
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
Args:
|
|
17
|
-
session: The SQLAlchemy async session to use for database operations.
|
|
13
|
+
def __init__(self, session: AsyncSession) -> None:
|
|
14
|
+
"""Initialize the SQLAlchemy embedding repository.
|
|
18
15
|
|
|
19
|
-
|
|
16
|
+
Args:
|
|
17
|
+
session: The SQLAlchemy async session to use for database operations
|
|
20
18
|
|
|
21
|
-
|
|
22
|
-
"""Initialize the embedding repository."""
|
|
19
|
+
"""
|
|
23
20
|
self.session = session
|
|
24
21
|
|
|
25
22
|
async def create_embedding(self, embedding: Embedding) -> Embedding:
|
|
26
23
|
"""Create a new embedding record in the database.
|
|
27
24
|
|
|
28
25
|
Args:
|
|
29
|
-
embedding: The Embedding instance to create
|
|
26
|
+
embedding: The Embedding instance to create
|
|
30
27
|
|
|
31
28
|
Returns:
|
|
32
|
-
The created Embedding instance
|
|
29
|
+
The created Embedding instance
|
|
33
30
|
|
|
34
31
|
"""
|
|
35
32
|
self.session.add(embedding)
|
|
36
|
-
await self.session.commit()
|
|
37
33
|
return embedding
|
|
38
34
|
|
|
39
35
|
async def get_embedding_by_snippet_id_and_type(
|
|
@@ -42,11 +38,11 @@ class EmbeddingRepository:
|
|
|
42
38
|
"""Get an embedding by its snippet ID and type.
|
|
43
39
|
|
|
44
40
|
Args:
|
|
45
|
-
snippet_id: The ID of the snippet to get the embedding for
|
|
46
|
-
embedding_type: The type of embedding to get
|
|
41
|
+
snippet_id: The ID of the snippet to get the embedding for
|
|
42
|
+
embedding_type: The type of embedding to get
|
|
47
43
|
|
|
48
44
|
Returns:
|
|
49
|
-
The Embedding instance if found, None otherwise
|
|
45
|
+
The Embedding instance if found, None otherwise
|
|
50
46
|
|
|
51
47
|
"""
|
|
52
48
|
query = select(Embedding).where(
|
|
@@ -62,10 +58,10 @@ class EmbeddingRepository:
|
|
|
62
58
|
"""List all embeddings of a given type.
|
|
63
59
|
|
|
64
60
|
Args:
|
|
65
|
-
embedding_type: The type of embeddings to list
|
|
61
|
+
embedding_type: The type of embeddings to list
|
|
66
62
|
|
|
67
63
|
Returns:
|
|
68
|
-
A list of Embedding instances
|
|
64
|
+
A list of Embedding instances
|
|
69
65
|
|
|
70
66
|
"""
|
|
71
67
|
query = select(Embedding).where(Embedding.type == embedding_type)
|
|
@@ -76,7 +72,7 @@ class EmbeddingRepository:
|
|
|
76
72
|
"""Delete all embeddings for a snippet.
|
|
77
73
|
|
|
78
74
|
Args:
|
|
79
|
-
snippet_id: The ID of the snippet to delete embeddings for
|
|
75
|
+
snippet_id: The ID of the snippet to delete embeddings for
|
|
80
76
|
|
|
81
77
|
"""
|
|
82
78
|
query = select(Embedding).where(Embedding.snippet_id == snippet_id)
|
|
@@ -84,7 +80,6 @@ class EmbeddingRepository:
|
|
|
84
80
|
embeddings = result.scalars().all()
|
|
85
81
|
for embedding in embeddings:
|
|
86
82
|
await self.session.delete(embedding)
|
|
87
|
-
await self.session.commit()
|
|
88
83
|
|
|
89
84
|
async def list_semantic_results(
|
|
90
85
|
self, embedding_type: EmbeddingType, embedding: list[float], top_k: int = 10
|
|
@@ -181,7 +176,27 @@ class EmbeddingRepository:
|
|
|
181
176
|
"""
|
|
182
177
|
stored_norms = np.linalg.norm(stored_vecs, axis=1)
|
|
183
178
|
query_norm = np.linalg.norm(query_vec)
|
|
184
|
-
|
|
179
|
+
|
|
180
|
+
# Handle zero vectors to avoid division by zero
|
|
181
|
+
if query_norm == 0:
|
|
182
|
+
# If query vector is zero, return zeros for all similarities
|
|
183
|
+
return np.zeros(len(stored_vecs))
|
|
184
|
+
|
|
185
|
+
# Handle stored vectors with zero norms
|
|
186
|
+
zero_stored_mask = stored_norms == 0
|
|
187
|
+
similarities = np.zeros(len(stored_vecs))
|
|
188
|
+
|
|
189
|
+
# Only compute similarities for non-zero stored vectors
|
|
190
|
+
non_zero_mask = ~zero_stored_mask
|
|
191
|
+
if np.any(non_zero_mask):
|
|
192
|
+
non_zero_stored_vecs = stored_vecs[non_zero_mask]
|
|
193
|
+
non_zero_stored_norms = stored_norms[non_zero_mask]
|
|
194
|
+
non_zero_similarities = np.dot(non_zero_stored_vecs, query_vec) / (
|
|
195
|
+
non_zero_stored_norms * query_norm
|
|
196
|
+
)
|
|
197
|
+
similarities[non_zero_mask] = non_zero_similarities
|
|
198
|
+
|
|
199
|
+
return similarities
|
|
185
200
|
|
|
186
201
|
def _get_top_k_results(
|
|
187
202
|
self,
|
|
@@ -200,7 +215,6 @@ class EmbeddingRepository:
|
|
|
200
215
|
List of (snippet_id, similarity_score) tuples
|
|
201
216
|
|
|
202
217
|
"""
|
|
218
|
+
# Get indices of top-k similarities
|
|
203
219
|
top_indices = np.argsort(similarities)[::-1][:top_k]
|
|
204
|
-
return [
|
|
205
|
-
(embeddings[i][0], float(similarities[i])) for i in top_indices
|
|
206
|
-
] # Use index 0 to get snippet_id
|
|
220
|
+
return [(embeddings[i][0], float(similarities[i])) for i in top_indices]
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
"""SQLAlchemy implementation of file repository."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
|
|
5
|
+
from sqlalchemy import select
|
|
6
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
7
|
+
|
|
8
|
+
from kodit.domain.entities import File, Index
|
|
9
|
+
from kodit.domain.repositories import FileRepository
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SqlAlchemyFileRepository(FileRepository):
    """File repository backed by a SQLAlchemy async session."""

    def __init__(self, session: AsyncSession) -> None:
        """Create the repository.

        Args:
            session: The SQLAlchemy async session used for all database access

        """
        self.session = session

    async def get(self, id: int) -> File | None:  # noqa: A002
        """Look up a file by primary key via the session."""
        return await self.session.get(File, id)

    async def save(self, entity: File) -> File:
        """Add the file to the session and hand it back (no commit here)."""
        self.session.add(entity)
        return entity

    async def delete(self, id: int) -> None:  # noqa: A002
        """Delete the file with the given ID, if there is one."""
        file = await self.get(id)
        if not file:
            return
        await self.session.delete(file)

    async def list(self) -> Sequence[File]:
        """Return every file."""
        rows = await self.session.scalars(select(File))
        return rows.all()

    async def get_files_for_index(self, index_id: int) -> Sequence[File]:
        """Get all files belonging to the source of an index.

        Args:
            index_id: The ID of the index to get files for

        Returns:
            A list of File instances; empty when the index does not exist

        """
        # Resolve the index first: files are linked to it through source_id.
        index_rows = await self.session.scalars(
            select(Index).where(Index.id == index_id)
        )
        index = index_rows.one_or_none()
        if not index:
            return []

        file_rows = await self.session.scalars(
            select(File).where(File.source_id == index.source_id)
        )
        return list(file_rows)

    async def get_by_id(self, file_id: int) -> File | None:
        """Get a file by ID using an explicit SELECT.

        Args:
            file_id: The ID of the file to retrieve

        Returns:
            The File instance if found, None otherwise

        """
        rows = await self.session.scalars(select(File).where(File.id == file_id))
        return rows.one_or_none()
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
"""SQLAlchemy repository."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from typing import cast
|
|
5
|
+
|
|
6
|
+
from sqlalchemy import select
|
|
7
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
8
|
+
|
|
9
|
+
from kodit.domain.entities import Author, AuthorFileMapping, File, Source, SourceType
|
|
10
|
+
from kodit.domain.repositories import AuthorRepository, SourceRepository
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SqlAlchemySourceRepository(SourceRepository):
    """Source repository backed by a SQLAlchemy async session."""

    def __init__(self, session: AsyncSession) -> None:
        """Store the async session used for all database access."""
        self._session = session

    async def get(self, id: int) -> Source | None:  # noqa: A002
        """Look up a source by primary key via the session."""
        return await self._session.get(Source, id)

    async def save(self, entity: Source) -> Source:
        """Add the source to the session and hand it back (no commit here)."""
        self._session.add(entity)
        return entity

    async def delete(self, id: int) -> None:  # noqa: A002
        """Delete the source with the given ID, if there is one."""
        source = await self.get(id)
        if not source:
            return
        await self._session.delete(source)

    async def list(self) -> Sequence[Source]:
        """Return every source."""
        rows = await self._session.scalars(select(Source))
        return rows.all()

    async def get_by_uri(self, uri: str) -> Source | None:
        """Return the first source whose URI equals ``uri``, or None."""
        found = await self._session.scalar(select(Source).where(Source.uri == uri))
        return cast("Source | None", found)

    async def list_by_type(
        self, source_type: SourceType | None = None
    ) -> Sequence[Source]:
        """List sources, restricted to ``source_type`` when one is given."""
        stmt = (
            select(Source)
            if source_type is None
            else select(Source).where(Source.type == source_type)
        )
        rows = await self._session.scalars(stmt)
        return rows.all()

    async def create_file(self, file: File) -> File:
        """Add a new file record to the session and return it."""
        self._session.add(file)
        return file

    async def upsert_author(self, author: Author) -> Author:
        """Return the stored author with the same name and email, else add this one.

        An existing row is returned unchanged; it is not updated from ``author``.
        """
        lookup = select(Author).where(
            Author.name == author.name, Author.email == author.email
        )
        existing_author = cast("Author | None", await self._session.scalar(lookup))
        if not existing_author:
            # No match on (name, email): register the new author.
            self._session.add(author)
            return author
        return existing_author

    async def upsert_author_file_mapping(
        self, mapping: AuthorFileMapping
    ) -> AuthorFileMapping:
        """Return the stored mapping for this author/file pair, else add this one."""
        lookup = select(AuthorFileMapping).where(
            AuthorFileMapping.author_id == mapping.author_id,
            AuthorFileMapping.file_id == mapping.file_id,
        )
        existing_mapping = cast(
            "AuthorFileMapping | None", await self._session.scalar(lookup)
        )
        if not existing_mapping:
            # No match on (author_id, file_id): register the new mapping.
            self._session.add(mapping)
            return mapping
        return existing_mapping
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class SqlAlchemyAuthorRepository(AuthorRepository):
    """Author repository backed by a SQLAlchemy async session."""

    def __init__(self, session: AsyncSession) -> None:
        """Store the async session used for all database access."""
        self._session = session

    async def get(self, id: int) -> Author | None:  # noqa: A002
        """Look up an author by primary key via the session."""
        return await self._session.get(Author, id)

    async def save(self, entity: Author) -> Author:
        """Add the author to the session and hand it back (no commit here)."""
        self._session.add(entity)
        return entity

    async def delete(self, id: int) -> None:  # noqa: A002
        """Delete the author with the given ID, if there is one."""
        author = await self.get(id)
        if not author:
            return
        await self._session.delete(author)

    async def list(self) -> Sequence[Author]:
        """Return every author."""
        rows = await self._session.scalars(select(Author))
        return rows.all()

    async def get_by_name(self, name: str) -> Author | None:
        """Return the first author whose name equals ``name``, or None."""
        stmt = select(Author).where(Author.name == name)
        found = await self._session.scalar(stmt)
        return cast("Author | None", found)

    async def get_by_email(self, email: str) -> Author | None:
        """Return the first author whose email equals ``email``, or None."""
        stmt = select(Author).where(Author.email == email)
        found = await self._session.scalar(stmt)
        return cast("Author | None", found)
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
"""SQLAlchemy implementation of snippet repository."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
|
|
5
|
+
from sqlalchemy import delete, select
|
|
6
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
7
|
+
|
|
8
|
+
from kodit.domain.entities import Snippet
|
|
9
|
+
from kodit.domain.repositories import SnippetRepository
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SqlAlchemySnippetRepository(SnippetRepository):
    """Snippet repository backed by a SQLAlchemy async session."""

    def __init__(self, session: AsyncSession) -> None:
        """Create the repository.

        Args:
            session: The SQLAlchemy async session used for all database access

        """
        self.session = session

    async def get(self, id: int) -> Snippet | None:  # noqa: A002
        """Look up a snippet by primary key via the session."""
        return await self.session.get(Snippet, id)

    async def save(self, entity: Snippet) -> Snippet:
        """Add the snippet to the session and hand it back (no commit here)."""
        self.session.add(entity)
        return entity

    async def delete(self, id: int) -> None:  # noqa: A002
        """Delete the snippet with the given ID, if there is one."""
        snippet = await self.get(id)
        if not snippet:
            return
        await self.session.delete(snippet)

    async def list(self) -> Sequence[Snippet]:
        """Return every snippet."""
        rows = await self.session.scalars(select(Snippet))
        return rows.all()

    async def get_by_id(self, snippet_id: int) -> Snippet | None:
        """Get a snippet by ID using an explicit SELECT.

        Args:
            snippet_id: The ID of the snippet to retrieve

        Returns:
            The Snippet instance if found, None otherwise

        """
        rows = await self.session.scalars(
            select(Snippet).where(Snippet.id == snippet_id)
        )
        return rows.one_or_none()

    async def get_by_index(self, index_id: int) -> Sequence[Snippet]:
        """Get all snippets for an index.

        Args:
            index_id: The ID of the index to get snippets for

        Returns:
            A list of Snippet instances

        """
        rows = await self.session.scalars(
            select(Snippet).where(Snippet.index_id == index_id)
        )
        return list(rows)

    async def delete_by_index(self, index_id: int) -> None:
        """Delete all snippets for an index with a single DELETE statement.

        Args:
            index_id: The ID of the index to delete snippets for

        """
        stmt = delete(Snippet).where(Snippet.index_id == index_id)
        await self.session.execute(stmt)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""UI infrastructure module."""
|