kodit 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kodit might be problematic. Click here for more details.
- kodit/_version.py +2 -2
- kodit/alembic/env.py +0 -2
- kodit/app.py +8 -8
- kodit/bm25/__init__.py +1 -0
- kodit/bm25/bm25.py +71 -0
- kodit/cli.py +87 -33
- kodit/config.py +86 -2
- kodit/database.py +37 -56
- kodit/indexing/repository.py +11 -0
- kodit/indexing/service.py +26 -16
- kodit/logging.py +20 -18
- kodit/mcp.py +16 -4
- kodit/retreival/repository.py +32 -0
- kodit/retreival/service.py +41 -3
- kodit/snippets/__init__.py +1 -0
- kodit/snippets/languages/__init__.py +53 -0
- kodit/snippets/languages/csharp.scm +12 -0
- kodit/snippets/languages/python.scm +22 -0
- kodit/snippets/method_snippets.py +120 -0
- kodit/snippets/snippets.py +48 -0
- kodit/sources/service.py +3 -5
- {kodit-0.1.4.dist-info → kodit-0.1.5.dist-info}/METADATA +6 -2
- kodit-0.1.5.dist-info/RECORD +40 -0
- kodit/sse.py +0 -61
- kodit-0.1.4.dist-info/RECORD +0 -33
- {kodit-0.1.4.dist-info → kodit-0.1.5.dist-info}/WHEEL +0 -0
- {kodit-0.1.4.dist-info → kodit-0.1.5.dist-info}/entry_points.txt +0 -0
- {kodit-0.1.4.dist-info → kodit-0.1.5.dist-info}/licenses/LICENSE +0 -0
kodit/logging.py
CHANGED
|
@@ -11,6 +11,8 @@ import structlog
|
|
|
11
11
|
from posthog import Posthog
|
|
12
12
|
from structlog.types import EventDict
|
|
13
13
|
|
|
14
|
+
from kodit.config import Config
|
|
15
|
+
|
|
14
16
|
log = structlog.get_logger(__name__)
|
|
15
17
|
|
|
16
18
|
|
|
@@ -27,14 +29,8 @@ class LogFormat(Enum):
|
|
|
27
29
|
JSON = "json"
|
|
28
30
|
|
|
29
31
|
|
|
30
|
-
def configure_logging(log_level: str, log_format: LogFormat) -> None:
|
|
31
|
-
"""Configure logging for the application.
|
|
32
|
-
|
|
33
|
-
Args:
|
|
34
|
-
json_logs: Whether to use JSON format for logs
|
|
35
|
-
log_level: The minimum log level to display
|
|
36
|
-
|
|
37
|
-
"""
|
|
32
|
+
def configure_logging(config: Config) -> None:
|
|
33
|
+
"""Configure logging for the application."""
|
|
38
34
|
timestamper = structlog.processors.TimeStamper(fmt="iso")
|
|
39
35
|
|
|
40
36
|
shared_processors: list[structlog.types.Processor] = [
|
|
@@ -48,7 +44,7 @@ def configure_logging(log_level: str, log_format: LogFormat) -> None:
|
|
|
48
44
|
structlog.processors.StackInfoRenderer(),
|
|
49
45
|
]
|
|
50
46
|
|
|
51
|
-
if log_format == LogFormat.JSON:
|
|
47
|
+
if config.log_format == LogFormat.JSON:
|
|
52
48
|
# Format the exception only for JSON logs, as we want to pretty-print them
|
|
53
49
|
# when using the ConsoleRenderer
|
|
54
50
|
shared_processors.append(structlog.processors.format_exc_info)
|
|
@@ -64,7 +60,7 @@ def configure_logging(log_level: str, log_format: LogFormat) -> None:
|
|
|
64
60
|
)
|
|
65
61
|
|
|
66
62
|
log_renderer: structlog.types.Processor
|
|
67
|
-
if log_format == LogFormat.JSON:
|
|
63
|
+
if config.log_format == LogFormat.JSON:
|
|
68
64
|
log_renderer = structlog.processors.JSONRenderer()
|
|
69
65
|
else:
|
|
70
66
|
log_renderer = structlog.dev.ConsoleRenderer()
|
|
@@ -86,18 +82,23 @@ def configure_logging(log_level: str, log_format: LogFormat) -> None:
|
|
|
86
82
|
handler.setFormatter(formatter)
|
|
87
83
|
root_logger = logging.getLogger()
|
|
88
84
|
root_logger.addHandler(handler)
|
|
89
|
-
root_logger.setLevel(log_level.upper())
|
|
85
|
+
root_logger.setLevel(config.log_level.upper())
|
|
90
86
|
|
|
91
87
|
# Configure uvicorn loggers to use our structlog setup
|
|
88
|
+
# Uvicorn spits out loads of exception logs when sse server doesn't shut down
|
|
89
|
+
# gracefully, so we hide them unless in DEBUG mode
|
|
92
90
|
for _log in ["uvicorn", "uvicorn.error", "uvicorn.access"]:
|
|
93
|
-
|
|
94
|
-
|
|
91
|
+
if root_logger.getEffectiveLevel() == logging.DEBUG:
|
|
92
|
+
logging.getLogger(_log).handlers.clear()
|
|
93
|
+
logging.getLogger(_log).propagate = True
|
|
94
|
+
else:
|
|
95
|
+
logging.getLogger(_log).disabled = True
|
|
95
96
|
|
|
96
97
|
# Configure SQLAlchemy loggers to use our structlog setup
|
|
97
98
|
for _log in ["sqlalchemy.engine", "alembic"]:
|
|
98
99
|
engine_logger = logging.getLogger(_log)
|
|
99
100
|
engine_logger.setLevel(logging.WARNING) # Hide INFO logs by default
|
|
100
|
-
if log_level.upper() == "DEBUG":
|
|
101
|
+
if config.log_level.upper() == "DEBUG":
|
|
101
102
|
engine_logger.setLevel(
|
|
102
103
|
logging.DEBUG
|
|
103
104
|
) # Only show all logs when in DEBUG mode
|
|
@@ -142,10 +143,11 @@ def get_mac_address() -> str:
|
|
|
142
143
|
return f"{mac:012x}" if mac != uuid.getnode() else str(uuid.uuid4())
|
|
143
144
|
|
|
144
145
|
|
|
145
|
-
def
|
|
146
|
-
"""
|
|
147
|
-
|
|
148
|
-
|
|
146
|
+
def configure_telemetry(config: Config) -> None:
|
|
147
|
+
"""Configure telemetry for the application."""
|
|
148
|
+
if config.disable_telemetry:
|
|
149
|
+
structlog.stdlib.get_logger(__name__).info("Telemetry has been disabled")
|
|
150
|
+
posthog.disabled = True
|
|
149
151
|
|
|
150
152
|
|
|
151
153
|
def log_event(event: str, properties: dict[str, Any] | None = None) -> None:
|
kodit/mcp.py
CHANGED
|
@@ -4,10 +4,11 @@ from pathlib import Path
|
|
|
4
4
|
from typing import Annotated
|
|
5
5
|
|
|
6
6
|
import structlog
|
|
7
|
-
from mcp.server.fastmcp import FastMCP
|
|
7
|
+
from fastmcp import FastMCP
|
|
8
8
|
from pydantic import Field
|
|
9
9
|
|
|
10
|
-
from kodit.
|
|
10
|
+
from kodit._version import version
|
|
11
|
+
from kodit.config import get_config
|
|
11
12
|
from kodit.retreival.repository import RetrievalRepository, RetrievalResult
|
|
12
13
|
from kodit.retreival.service import RetrievalRequest, RetrievalService
|
|
13
14
|
|
|
@@ -62,7 +63,11 @@ async def retrieve_relevant_snippets(
|
|
|
62
63
|
file_contents=related_file_contents,
|
|
63
64
|
)
|
|
64
65
|
|
|
65
|
-
|
|
66
|
+
# Must avoid running migrations because that runs in a separate event loop,
|
|
67
|
+
# mcp no-likey
|
|
68
|
+
config = get_config()
|
|
69
|
+
db = config.get_db(run_migrations=False)
|
|
70
|
+
async with db.get_session() as session:
|
|
66
71
|
log.debug("Creating retrieval repository")
|
|
67
72
|
retrieval_repository = RetrievalRepository(
|
|
68
73
|
session=session,
|
|
@@ -70,6 +75,7 @@ async def retrieve_relevant_snippets(
|
|
|
70
75
|
|
|
71
76
|
log.debug("Creating retrieval service")
|
|
72
77
|
retrieval_service = RetrievalService(
|
|
78
|
+
config=config,
|
|
73
79
|
repository=retrieval_repository,
|
|
74
80
|
)
|
|
75
81
|
|
|
@@ -82,7 +88,7 @@ async def retrieve_relevant_snippets(
|
|
|
82
88
|
)
|
|
83
89
|
log.debug("Input", input_query=input_query)
|
|
84
90
|
retrieval_request = RetrievalRequest(
|
|
85
|
-
|
|
91
|
+
keywords=keywords,
|
|
86
92
|
)
|
|
87
93
|
log.debug("Retrieving snippets")
|
|
88
94
|
snippets = await retrieval_service.retrieve(request=retrieval_request)
|
|
@@ -108,3 +114,9 @@ def input_fusion(
|
|
|
108
114
|
def output_fusion(snippets: list[RetrievalResult]) -> str:
|
|
109
115
|
"""Fuse the snippets into a single output."""
|
|
110
116
|
return "\n\n".join(f"{snippet.uri}\n{snippet.content}" for snippet in snippets)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
@mcp.tool()
|
|
120
|
+
async def get_version() -> str:
|
|
121
|
+
"""Get the version of the kodit project."""
|
|
122
|
+
return version
|
kodit/retreival/repository.py
CHANGED
|
@@ -74,3 +74,35 @@ class RetrievalRepository:
|
|
|
74
74
|
)
|
|
75
75
|
for snippet, file in results
|
|
76
76
|
]
|
|
77
|
+
|
|
78
|
+
async def list_snippet_ids(self) -> list[int]:
|
|
79
|
+
"""List all snippet IDs.
|
|
80
|
+
|
|
81
|
+
Returns:
|
|
82
|
+
A list of all snippets.
|
|
83
|
+
|
|
84
|
+
"""
|
|
85
|
+
query = select(Snippet.id)
|
|
86
|
+
rows = await self.session.execute(query)
|
|
87
|
+
return list(rows.scalars().all())
|
|
88
|
+
|
|
89
|
+
async def list_snippets_by_ids(self, ids: list[int]) -> list[RetrievalResult]:
|
|
90
|
+
"""List snippets by IDs.
|
|
91
|
+
|
|
92
|
+
Returns:
|
|
93
|
+
A list of snippets.
|
|
94
|
+
|
|
95
|
+
"""
|
|
96
|
+
query = (
|
|
97
|
+
select(Snippet, File)
|
|
98
|
+
.where(Snippet.id.in_(ids))
|
|
99
|
+
.join(File, Snippet.file_id == File.id)
|
|
100
|
+
)
|
|
101
|
+
rows = await self.session.execute(query)
|
|
102
|
+
return [
|
|
103
|
+
RetrievalResult(
|
|
104
|
+
uri=file.uri,
|
|
105
|
+
content=snippet.content,
|
|
106
|
+
)
|
|
107
|
+
for snippet, file in rows.all()
|
|
108
|
+
]
|
kodit/retreival/service.py
CHANGED
|
@@ -1,14 +1,18 @@
|
|
|
1
1
|
"""Retrieval service."""
|
|
2
2
|
|
|
3
3
|
import pydantic
|
|
4
|
+
import structlog
|
|
4
5
|
|
|
6
|
+
from kodit.bm25.bm25 import BM25Service
|
|
7
|
+
from kodit.config import Config
|
|
5
8
|
from kodit.retreival.repository import RetrievalRepository, RetrievalResult
|
|
6
9
|
|
|
7
10
|
|
|
8
11
|
class RetrievalRequest(pydantic.BaseModel):
|
|
9
12
|
"""Request for a retrieval."""
|
|
10
13
|
|
|
11
|
-
|
|
14
|
+
keywords: list[str]
|
|
15
|
+
top_k: int = 10
|
|
12
16
|
|
|
13
17
|
|
|
14
18
|
class Snippet(pydantic.BaseModel):
|
|
@@ -21,10 +25,44 @@ class Snippet(pydantic.BaseModel):
|
|
|
21
25
|
class RetrievalService:
|
|
22
26
|
"""Service for retrieving relevant data."""
|
|
23
27
|
|
|
24
|
-
def __init__(self, repository: RetrievalRepository) -> None:
|
|
28
|
+
def __init__(self, config: Config, repository: RetrievalRepository) -> None:
|
|
25
29
|
"""Initialize the retrieval service."""
|
|
26
30
|
self.repository = repository
|
|
31
|
+
self.log = structlog.get_logger(__name__)
|
|
32
|
+
self.bm25 = BM25Service(config)
|
|
33
|
+
|
|
34
|
+
async def _load_bm25_index(self) -> None:
|
|
35
|
+
"""Load the BM25 index."""
|
|
27
36
|
|
|
28
37
|
async def retrieve(self, request: RetrievalRequest) -> list[RetrievalResult]:
|
|
29
38
|
"""Retrieve relevant data."""
|
|
30
|
-
|
|
39
|
+
snippet_ids = await self.repository.list_snippet_ids()
|
|
40
|
+
|
|
41
|
+
# Gather results for each keyword
|
|
42
|
+
result_ids: list[tuple[int, float]] = []
|
|
43
|
+
for keyword in request.keywords:
|
|
44
|
+
results = self.bm25.retrieve(snippet_ids, keyword, request.top_k)
|
|
45
|
+
result_ids.extend(results)
|
|
46
|
+
|
|
47
|
+
if len(result_ids) == 0:
|
|
48
|
+
return []
|
|
49
|
+
|
|
50
|
+
# Sort results by score
|
|
51
|
+
result_ids.sort(key=lambda x: x[1], reverse=True)
|
|
52
|
+
|
|
53
|
+
self.log.debug(
|
|
54
|
+
"Retrieval results",
|
|
55
|
+
total_results=len(result_ids),
|
|
56
|
+
max_score=result_ids[0][1],
|
|
57
|
+
min_score=result_ids[-1][1],
|
|
58
|
+
median_score=result_ids[len(result_ids) // 2][1],
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
# Don't return zero score results
|
|
62
|
+
result_ids = [x for x in result_ids if x[1] > 0]
|
|
63
|
+
|
|
64
|
+
# Build final list of doc ids up to top_k
|
|
65
|
+
final_doc_ids = [x[0] for x in result_ids[: request.top_k]]
|
|
66
|
+
|
|
67
|
+
# Get snippets from database
|
|
68
|
+
return await self.repository.list_snippets_by_ids(final_doc_ids)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Extract method snippets from source code."""
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
"""Detect the language of a file."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import cast
|
|
5
|
+
|
|
6
|
+
from tree_sitter_language_pack import SupportedLanguage
|
|
7
|
+
|
|
8
|
+
# Mapping of file extensions to programming languages
|
|
9
|
+
LANGUAGE_MAP: dict[str, str] = {
|
|
10
|
+
# JavaScript/TypeScript
|
|
11
|
+
"js": "javascript",
|
|
12
|
+
"jsx": "javascript",
|
|
13
|
+
"ts": "typescript",
|
|
14
|
+
"tsx": "typescript",
|
|
15
|
+
# Python
|
|
16
|
+
"py": "python",
|
|
17
|
+
# Rust
|
|
18
|
+
"rs": "rust",
|
|
19
|
+
# Go
|
|
20
|
+
"go": "go",
|
|
21
|
+
# C/C++
|
|
22
|
+
"cpp": "cpp",
|
|
23
|
+
"hpp": "cpp",
|
|
24
|
+
"c": "c",
|
|
25
|
+
"h": "c",
|
|
26
|
+
# C#
|
|
27
|
+
"cs": "csharp",
|
|
28
|
+
# Ruby
|
|
29
|
+
"rb": "ruby",
|
|
30
|
+
# Java
|
|
31
|
+
"java": "java",
|
|
32
|
+
# PHP
|
|
33
|
+
"php": "php",
|
|
34
|
+
# Swift
|
|
35
|
+
"swift": "swift",
|
|
36
|
+
# Kotlin
|
|
37
|
+
"kt": "kotlin",
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def detect_language(file_path: Path) -> SupportedLanguage:
|
|
42
|
+
"""Detect the language of a file."""
|
|
43
|
+
suffix = file_path.suffix.removeprefix(".").lower()
|
|
44
|
+
msg = f"Unsupported language for file suffix: {suffix}"
|
|
45
|
+
lang = LANGUAGE_MAP.get(suffix)
|
|
46
|
+
if lang is None:
|
|
47
|
+
raise ValueError(msg)
|
|
48
|
+
|
|
49
|
+
# Try to cast the language to a SupportedLanguage
|
|
50
|
+
try:
|
|
51
|
+
return cast("SupportedLanguage", lang)
|
|
52
|
+
except Exception as e:
|
|
53
|
+
raise ValueError(msg) from e
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
(function_definition
|
|
2
|
+
name: (identifier) @function.name
|
|
3
|
+
body: (block) @function.body
|
|
4
|
+
) @function.def
|
|
5
|
+
|
|
6
|
+
(class_definition
|
|
7
|
+
name: (identifier) @class.name
|
|
8
|
+
) @class.def
|
|
9
|
+
|
|
10
|
+
(import_statement
|
|
11
|
+
name: (dotted_name (identifier) @import.name))
|
|
12
|
+
|
|
13
|
+
(import_from_statement
|
|
14
|
+
module_name: (dotted_name (identifier) @import.from))
|
|
15
|
+
|
|
16
|
+
(identifier) @ident
|
|
17
|
+
|
|
18
|
+
(assignment
|
|
19
|
+
left: (identifier) @assignment.lhs)
|
|
20
|
+
|
|
21
|
+
(parameters
|
|
22
|
+
(identifier) @param.name)
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""Extract method snippets from source code."""
|
|
2
|
+
|
|
3
|
+
from tree_sitter import Node, Query
|
|
4
|
+
from tree_sitter_language_pack import SupportedLanguage, get_language, get_parser
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class MethodSnippets:
|
|
8
|
+
"""Extract method snippets from source code."""
|
|
9
|
+
|
|
10
|
+
def __init__(self, language: SupportedLanguage, query: str) -> None:
|
|
11
|
+
"""Initialize the MethodSnippets class."""
|
|
12
|
+
self.language = get_language(language)
|
|
13
|
+
self.parser = get_parser(language)
|
|
14
|
+
self.query = Query(self.language, query)
|
|
15
|
+
|
|
16
|
+
def _get_leaf_functions(
|
|
17
|
+
self, captures_by_name: dict[str, list[Node]]
|
|
18
|
+
) -> list[Node]:
|
|
19
|
+
"""Return all leaf functions in the AST."""
|
|
20
|
+
return [
|
|
21
|
+
node
|
|
22
|
+
for node in captures_by_name.get("function.body", [])
|
|
23
|
+
if self._is_leaf_function(captures_by_name, node)
|
|
24
|
+
]
|
|
25
|
+
|
|
26
|
+
def _is_leaf_function(
|
|
27
|
+
self, captures_by_name: dict[str, list[Node]], node: Node
|
|
28
|
+
) -> bool:
|
|
29
|
+
"""Return True if the node is a leaf function."""
|
|
30
|
+
for other in captures_by_name.get("function.body", []):
|
|
31
|
+
if other == node: # Skip self
|
|
32
|
+
continue
|
|
33
|
+
# if other is inside node, it's not a leaf function
|
|
34
|
+
if other.start_byte >= node.start_byte and other.end_byte <= node.end_byte:
|
|
35
|
+
return False
|
|
36
|
+
return True
|
|
37
|
+
|
|
38
|
+
def _get_imports(self, captures_by_name: dict[str, list[Node]]) -> list[Node]:
|
|
39
|
+
"""Return all imports in the AST."""
|
|
40
|
+
return captures_by_name.get("import.name", []) + captures_by_name.get(
|
|
41
|
+
"import.from", []
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
def _classes_and_functions(
|
|
45
|
+
self, captures_by_name: dict[str, list[Node]]
|
|
46
|
+
) -> list[int]:
|
|
47
|
+
"""Return all classes and functions in the AST."""
|
|
48
|
+
return [
|
|
49
|
+
node.id
|
|
50
|
+
for node in {
|
|
51
|
+
*captures_by_name.get("function.def", []),
|
|
52
|
+
*captures_by_name.get("class.def", []),
|
|
53
|
+
}
|
|
54
|
+
]
|
|
55
|
+
|
|
56
|
+
def _get_ancestors(
|
|
57
|
+
self, captures_by_name: dict[str, list[Node]], node: Node
|
|
58
|
+
) -> list[Node]:
|
|
59
|
+
"""Return all ancestors of the node."""
|
|
60
|
+
valid_ancestors = self._classes_and_functions(captures_by_name)
|
|
61
|
+
ancestors = []
|
|
62
|
+
parent = node.parent
|
|
63
|
+
while parent:
|
|
64
|
+
if parent.id in valid_ancestors:
|
|
65
|
+
ancestors.append(parent)
|
|
66
|
+
parent = parent.parent
|
|
67
|
+
return ancestors
|
|
68
|
+
|
|
69
|
+
def extract(self, source_code: bytes) -> list[str]:
|
|
70
|
+
"""Extract method snippets from source code."""
|
|
71
|
+
tree = self.parser.parse(source_code)
|
|
72
|
+
|
|
73
|
+
captures_by_name = self.query.captures(tree.root_node)
|
|
74
|
+
|
|
75
|
+
lines = source_code.decode().splitlines()
|
|
76
|
+
|
|
77
|
+
# Find all leaf functions
|
|
78
|
+
leaf_functions = self._get_leaf_functions(captures_by_name)
|
|
79
|
+
|
|
80
|
+
# Find all imports
|
|
81
|
+
imports = self._get_imports(captures_by_name)
|
|
82
|
+
|
|
83
|
+
results = []
|
|
84
|
+
|
|
85
|
+
# For each leaf function, find all lines this function is dependent on
|
|
86
|
+
for func_node in leaf_functions:
|
|
87
|
+
all_lines_to_keep = set()
|
|
88
|
+
|
|
89
|
+
ancestors = self._get_ancestors(captures_by_name, func_node)
|
|
90
|
+
|
|
91
|
+
# Add self to keep
|
|
92
|
+
all_lines_to_keep.update(
|
|
93
|
+
range(func_node.start_point[0], func_node.end_point[0] + 1)
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
# Add imports to keep
|
|
97
|
+
for import_node in imports:
|
|
98
|
+
all_lines_to_keep.update(
|
|
99
|
+
range(import_node.start_point[0], import_node.end_point[0] + 1)
|
|
100
|
+
)
|
|
101
|
+
|
|
102
|
+
# Add ancestors to keep
|
|
103
|
+
for node in ancestors:
|
|
104
|
+
# Get the first line of the node for now
|
|
105
|
+
start = node.start_point[0]
|
|
106
|
+
end = node.start_point[0]
|
|
107
|
+
all_lines_to_keep.update(range(start, end + 1))
|
|
108
|
+
|
|
109
|
+
pseudo_code = []
|
|
110
|
+
for i, line in enumerate(lines):
|
|
111
|
+
if i in all_lines_to_keep:
|
|
112
|
+
pseudo_code.append(line)
|
|
113
|
+
|
|
114
|
+
results.append("\n".join(pseudo_code))
|
|
115
|
+
|
|
116
|
+
# If there are no results, then return the entire file
|
|
117
|
+
if not results:
|
|
118
|
+
return [source_code.decode()]
|
|
119
|
+
|
|
120
|
+
return results
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""Generate snippets from a file."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
from kodit.snippets.languages import detect_language
|
|
7
|
+
from kodit.snippets.method_snippets import MethodSnippets
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
|
|
11
|
+
class Snippet:
|
|
12
|
+
"""A snippet of code."""
|
|
13
|
+
|
|
14
|
+
text: str
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class SnippetService:
|
|
18
|
+
"""Factory for generating snippets from a file.
|
|
19
|
+
|
|
20
|
+
This is required because there's going to be multiple ways to generate snippets.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
def __init__(self) -> None:
|
|
24
|
+
"""Initialize the snippet factory."""
|
|
25
|
+
self.language_dir = Path(__file__).parent / "languages"
|
|
26
|
+
|
|
27
|
+
def snippets_for_file(self, file_path: Path) -> list[Snippet]:
|
|
28
|
+
"""Generate snippets from a file."""
|
|
29
|
+
language = detect_language(file_path)
|
|
30
|
+
|
|
31
|
+
try:
|
|
32
|
+
query_path = self.language_dir / f"{language}.scm"
|
|
33
|
+
with query_path.open() as f:
|
|
34
|
+
query = f.read()
|
|
35
|
+
except Exception as e:
|
|
36
|
+
msg = f"Unsupported language: {file_path}"
|
|
37
|
+
raise ValueError(msg) from e
|
|
38
|
+
|
|
39
|
+
method_analser = MethodSnippets(language, query)
|
|
40
|
+
|
|
41
|
+
try:
|
|
42
|
+
file_bytes = file_path.read_bytes()
|
|
43
|
+
except Exception as e:
|
|
44
|
+
msg = f"Failed to read file: {file_path}"
|
|
45
|
+
raise ValueError(msg) from e
|
|
46
|
+
|
|
47
|
+
method_snippets = method_analser.extract(file_bytes)
|
|
48
|
+
return [Snippet(text=snippet) for snippet in method_snippets]
|
kodit/sources/service.py
CHANGED
|
@@ -18,12 +18,9 @@ import structlog
|
|
|
18
18
|
from tqdm import tqdm
|
|
19
19
|
from uritools import isuri, urisplit
|
|
20
20
|
|
|
21
|
-
from kodit.config import DATA_DIR
|
|
22
21
|
from kodit.sources.models import File, Source
|
|
23
22
|
from kodit.sources.repository import SourceRepository
|
|
24
23
|
|
|
25
|
-
CLONE_DIR = DATA_DIR / "clones"
|
|
26
|
-
|
|
27
24
|
|
|
28
25
|
class SourceView(pydantic.BaseModel):
|
|
29
26
|
"""View model for displaying source information.
|
|
@@ -53,13 +50,14 @@ class SourceService:
|
|
|
53
50
|
SourceRepository), and provides a clean API for source management.
|
|
54
51
|
"""
|
|
55
52
|
|
|
56
|
-
def __init__(self, repository: SourceRepository) -> None:
|
|
53
|
+
def __init__(self, clone_dir: Path, repository: SourceRepository) -> None:
|
|
57
54
|
"""Initialize the source service.
|
|
58
55
|
|
|
59
56
|
Args:
|
|
60
57
|
repository: The repository instance to use for database operations.
|
|
61
58
|
|
|
62
59
|
"""
|
|
60
|
+
self.clone_dir = clone_dir
|
|
63
61
|
self.repository = repository
|
|
64
62
|
self.log = structlog.get_logger(__name__)
|
|
65
63
|
|
|
@@ -129,7 +127,7 @@ class SourceService:
|
|
|
129
127
|
raise ValueError(msg)
|
|
130
128
|
|
|
131
129
|
# Clone into a local directory
|
|
132
|
-
clone_path = CLONE_DIR / directory.as_posix().replace("/", "_")
|
|
130
|
+
clone_path = self.clone_dir / directory.as_posix().replace("/", "_")
|
|
133
131
|
clone_path.mkdir(parents=True, exist_ok=True)
|
|
134
132
|
|
|
135
133
|
# Copy all files recursively, preserving directory structure, ignoring hidden
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: kodit
|
|
3
|
-
Version: 0.1.4
|
|
3
|
+
Version: 0.1.5
|
|
4
4
|
Summary: Code indexing for better AI code generation
|
|
5
5
|
Project-URL: Homepage, https://docs.helixml.tech/kodit/
|
|
6
6
|
Project-URL: Documentation, https://docs.helixml.tech/kodit/
|
|
@@ -22,18 +22,22 @@ Requires-Dist: aiosqlite>=0.20.0
|
|
|
22
22
|
Requires-Dist: alembic>=1.15.2
|
|
23
23
|
Requires-Dist: asgi-correlation-id>=4.3.4
|
|
24
24
|
Requires-Dist: better-exceptions>=0.3.3
|
|
25
|
+
Requires-Dist: bm25s[core]>=0.2.12
|
|
25
26
|
Requires-Dist: click>=8.1.8
|
|
26
27
|
Requires-Dist: colorama>=0.4.6
|
|
27
28
|
Requires-Dist: dotenv>=0.9.9
|
|
28
29
|
Requires-Dist: fastapi[standard]>=0.115.12
|
|
30
|
+
Requires-Dist: fastmcp>=2.3.3
|
|
29
31
|
Requires-Dist: httpx-retries>=0.3.2
|
|
30
32
|
Requires-Dist: httpx>=0.28.1
|
|
31
|
-
Requires-Dist: mcp>=1.6.0
|
|
32
33
|
Requires-Dist: posthog>=4.0.1
|
|
34
|
+
Requires-Dist: pydantic-settings>=2.9.1
|
|
33
35
|
Requires-Dist: pytable-formatter>=0.1.1
|
|
34
36
|
Requires-Dist: sqlalchemy[asyncio]>=2.0.40
|
|
35
37
|
Requires-Dist: structlog>=25.3.0
|
|
36
38
|
Requires-Dist: tdqm>=0.0.1
|
|
39
|
+
Requires-Dist: tree-sitter-language-pack>=0.7.3
|
|
40
|
+
Requires-Dist: tree-sitter>=0.24.0
|
|
37
41
|
Requires-Dist: uritools>=5.0.0
|
|
38
42
|
Description-Content-Type: text/markdown
|
|
39
43
|
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
kodit/.gitignore,sha256=ztkjgRwL9Uud1OEi36hGQeDGk3OLK1NfDEO8YqGYy8o,11
|
|
2
|
+
kodit/__init__.py,sha256=aEKHYninUq1yh6jaNfvJBYg-6fenpN132nJt1UU6Jxs,59
|
|
3
|
+
kodit/_version.py,sha256=Y4jy4bEMmwl_qNPCmiMFnlQ2ofMoqyG37hp8uwI3m10,511
|
|
4
|
+
kodit/app.py,sha256=TdPpCN4ucOElKHwDebfKgeVJ9xexdfpzpk6hnDH69vM,703
|
|
5
|
+
kodit/cli.py,sha256=CjmiRaJ-SdfCMYlVQGnxPSsoX5j3ix4fN3OLVc5EYkY,7473
|
|
6
|
+
kodit/config.py,sha256=18dhSYaE-ut2qXrBRKuCqLXeBCLEXw2y1Uw4lieMPwY,2682
|
|
7
|
+
kodit/database.py,sha256=NnAluOj_JHjnj5MeKuU9LApgSzik2kru1bQl-24vHkc,2272
|
|
8
|
+
kodit/logging.py,sha256=P1D9flYnvYxPw-DyOGyiv3y30x0gHPwdk6VJS29YHus,5269
|
|
9
|
+
kodit/mcp.py,sha256=O24O_GFzwwv5E-uBFoW_zZlSigeNSigaCj0s1xOmP8M,3855
|
|
10
|
+
kodit/middleware.py,sha256=NHLrqq20ZtPTE9esX9HD3z7EKi56_QTFxBlkdq0JDzQ,2138
|
|
11
|
+
kodit/alembic/README,sha256=ISVtAOvqvKk_5ThM5ioJE-lMkvf9IbknFUFVU_vPma4,58
|
|
12
|
+
kodit/alembic/__init__.py,sha256=lP5MuwlyWRMO6UcDWnQcQ3G-GYHcFb6rl9gYPHJ1sjo,40
|
|
13
|
+
kodit/alembic/env.py,sha256=IXhl7yvURSycs2v_3pd14Sr8_zGRfYXlQwWby1abfuk,2290
|
|
14
|
+
kodit/alembic/script.py.mako,sha256=zWziKtiwYKEWuwPV_HBNHwa9LCT45_bi01-uSNFaOOE,703
|
|
15
|
+
kodit/alembic/versions/85155663351e_initial.py,sha256=Cg7zlF871o9ShV5rQMQ1v7hRV7fI59veDY9cjtTrs-8,3306
|
|
16
|
+
kodit/alembic/versions/__init__.py,sha256=9-lHzptItTzq_fomdIRBegQNm4Znx6pVjwD4MiqRIdo,36
|
|
17
|
+
kodit/bm25/__init__.py,sha256=j8zyriNWhbwE5Lbybzg1hQAhANlU9mKHWw4beeUR6og,19
|
|
18
|
+
kodit/bm25/bm25.py,sha256=V0_byhV4kVnI3E-PBNsc4rBjQsDuZo1bt1uQKnywLS8,2283
|
|
19
|
+
kodit/indexing/__init__.py,sha256=cPyi2Iej3G1JFWlWr7X80_UrsMaTu5W5rBwgif1B3xo,75
|
|
20
|
+
kodit/indexing/models.py,sha256=sZIhGwvL4Dw0QTWFxrjfWctSLkAoDT6fv5DlGz8-Fr8,1258
|
|
21
|
+
kodit/indexing/repository.py,sha256=kvAlNfMSQYboF0TB1huw2qoBdLJ4UsEPiM7ZG-e6rrg,4300
|
|
22
|
+
kodit/indexing/service.py,sha256=N8QhrAvqhIHOgSlT9Jc786rjcVjMwiyiMTZr7mNA8D8,5431
|
|
23
|
+
kodit/retreival/__init__.py,sha256=33PhJU-3gtsqYq6A1UkaLNKbev_Zee9Lq6dYC59-CsA,69
|
|
24
|
+
kodit/retreival/repository.py,sha256=1lqGgJHsBmvMGMzEYa-hrdXg2q7rqtYPl1cvBb7jMRE,3119
|
|
25
|
+
kodit/retreival/service.py,sha256=g6iwM2FMxrL8WjtWnZdKdxKpfn6b0ThBmOdLWd7AKKQ,2011
|
|
26
|
+
kodit/snippets/__init__.py,sha256=-2coNoCRjTixU9KcP6alpmt7zqf37tCRWH3D7FPJ8dg,48
|
|
27
|
+
kodit/snippets/method_snippets.py,sha256=EVHhSNWahAC5nSXv9fWVFJY2yq25goHdCSCuENC07F8,4145
|
|
28
|
+
kodit/snippets/snippets.py,sha256=QumvhltWoxXw41SyKb-RbSvAr3m6V3lUy9n0AI8jcto,1409
|
|
29
|
+
kodit/snippets/languages/__init__.py,sha256=Bj5KKZSls2MQ8ZY1S_nHg447MgGZW-2WZM-oq6vjwwA,1187
|
|
30
|
+
kodit/snippets/languages/csharp.scm,sha256=gbBN4RiV1FBuTJF6orSnDFi8H9JwTw-d4piLJYsWUsc,222
|
|
31
|
+
kodit/snippets/languages/python.scm,sha256=ee85R9PBzwye3IMTE7-iVoKWd_ViU3EJISTyrFGrVeo,429
|
|
32
|
+
kodit/sources/__init__.py,sha256=1NTZyPdjThVQpZO1Mp1ColVsS7sqYanOVLqnoqV9Ipo,83
|
|
33
|
+
kodit/sources/models.py,sha256=xb42CaNDO1CUB8SIW-xXMrB6Ji8cFw-yeJ550xBEg9Q,2398
|
|
34
|
+
kodit/sources/repository.py,sha256=mGJrHWH6Uo8YABdoojHFbzaf_jW-2ywJpAHIa1gnc3U,3401
|
|
35
|
+
kodit/sources/service.py,sha256=cBCxnOQKwGNi2e13_3Vue8MylAaUxb9XG4IgM636la0,6712
|
|
36
|
+
kodit-0.1.5.dist-info/METADATA,sha256=N4fIBAIREHOujaDvEVj83fNvyeB8D7HaLXNcpeVdNJY,2181
|
|
37
|
+
kodit-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
38
|
+
kodit-0.1.5.dist-info/entry_points.txt,sha256=hoTn-1aKyTItjnY91fnO-rV5uaWQLQ-Vi7V5et2IbHY,40
|
|
39
|
+
kodit-0.1.5.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
40
|
+
kodit-0.1.5.dist-info/RECORD,,
|
kodit/sse.py
DELETED
|
@@ -1,61 +0,0 @@
|
|
|
1
|
-
"""Server-Sent Events (SSE) implementation for kodit."""
|
|
2
|
-
|
|
3
|
-
from collections.abc import Coroutine
|
|
4
|
-
from typing import Any
|
|
5
|
-
|
|
6
|
-
from fastapi import Request
|
|
7
|
-
from mcp.server.fastmcp import FastMCP
|
|
8
|
-
from mcp.server.session import ServerSession
|
|
9
|
-
from mcp.server.sse import SseServerTransport
|
|
10
|
-
from starlette.applications import Starlette
|
|
11
|
-
from starlette.routing import Mount, Route
|
|
12
|
-
|
|
13
|
-
####################################################################################
|
|
14
|
-
# Temporary monkeypatch which avoids crashing when a POST message is received
|
|
15
|
-
# before a connection has been initialized, e.g: after a deployment.
|
|
16
|
-
old__received_request = ServerSession._received_request # noqa: SLF001
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
async def _received_request(self: ServerSession, *args: Any, **kwargs: Any) -> None:
|
|
20
|
-
"""Handle a received request, catching RuntimeError to avoid crashes.
|
|
21
|
-
|
|
22
|
-
This is a temporary monkeypatch to avoid crashing when a POST message is
|
|
23
|
-
received before a connection has been initialized, e.g: after a deployment.
|
|
24
|
-
"""
|
|
25
|
-
try:
|
|
26
|
-
return await old__received_request(self, *args, **kwargs)
|
|
27
|
-
except RuntimeError:
|
|
28
|
-
pass
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
# pylint: disable-next=protected-access
|
|
32
|
-
ServerSession._received_request = _received_request # noqa: SLF001
|
|
33
|
-
####################################################################################
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
def create_sse_server(mcp: FastMCP) -> Starlette:
|
|
37
|
-
"""Create a Starlette app that handles SSE connections and message handling."""
|
|
38
|
-
transport = SseServerTransport("/messages/")
|
|
39
|
-
|
|
40
|
-
# Define handler functions
|
|
41
|
-
async def handle_sse(request: Request) -> Coroutine[Any, Any, None]:
|
|
42
|
-
"""Handle SSE connections."""
|
|
43
|
-
async with transport.connect_sse(
|
|
44
|
-
request.scope,
|
|
45
|
-
request.receive,
|
|
46
|
-
request._send, # noqa: SLF001
|
|
47
|
-
) as streams:
|
|
48
|
-
await mcp._mcp_server.run( # noqa: SLF001
|
|
49
|
-
streams[0],
|
|
50
|
-
streams[1],
|
|
51
|
-
mcp._mcp_server.create_initialization_options(), # noqa: SLF001
|
|
52
|
-
)
|
|
53
|
-
|
|
54
|
-
# Create Starlette routes for SSE and message handling
|
|
55
|
-
routes = [
|
|
56
|
-
Route("/sse/", endpoint=handle_sse),
|
|
57
|
-
Mount("/messages/", app=transport.handle_post_message),
|
|
58
|
-
]
|
|
59
|
-
|
|
60
|
-
# Create a Starlette app
|
|
61
|
-
return Starlette(routes=routes)
|