kodit-0.1.4-py3-none-any.whl → kodit-0.1.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kodit might be problematic. Click here for more details.
- kodit/_version.py +2 -2
- kodit/alembic/env.py +5 -4
- kodit/app.py +13 -9
- kodit/bm25/__init__.py +1 -0
- kodit/bm25/bm25.py +71 -0
- kodit/cli.py +124 -38
- kodit/config.py +94 -2
- kodit/database.py +41 -57
- kodit/indexing/repository.py +11 -0
- kodit/indexing/service.py +28 -16
- kodit/logging.py +20 -18
- kodit/mcp.py +84 -34
- kodit/middleware.py +16 -0
- kodit/retreival/repository.py +32 -0
- kodit/retreival/service.py +42 -3
- kodit/snippets/__init__.py +1 -0
- kodit/snippets/languages/__init__.py +53 -0
- kodit/snippets/languages/csharp.scm +12 -0
- kodit/snippets/languages/python.scm +22 -0
- kodit/snippets/method_snippets.py +120 -0
- kodit/snippets/snippets.py +48 -0
- kodit/sources/service.py +3 -5
- {kodit-0.1.4.dist-info → kodit-0.1.6.dist-info}/METADATA +6 -2
- kodit-0.1.6.dist-info/RECORD +40 -0
- kodit/sse.py +0 -61
- kodit-0.1.4.dist-info/RECORD +0 -33
- {kodit-0.1.4.dist-info → kodit-0.1.6.dist-info}/WHEEL +0 -0
- {kodit-0.1.4.dist-info → kodit-0.1.6.dist-info}/entry_points.txt +0 -0
- {kodit-0.1.4.dist-info → kodit-0.1.6.dist-info}/licenses/LICENSE +0 -0
kodit/indexing/service.py
CHANGED
|
@@ -7,24 +7,20 @@ index management.
|
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
9
|
from datetime import datetime
|
|
10
|
+
from pathlib import Path
|
|
10
11
|
|
|
11
|
-
import aiofiles
|
|
12
12
|
import pydantic
|
|
13
13
|
import structlog
|
|
14
14
|
from tqdm.asyncio import tqdm
|
|
15
15
|
|
|
16
|
+
from kodit.bm25.bm25 import BM25Service
|
|
16
17
|
from kodit.indexing.models import Snippet
|
|
17
18
|
from kodit.indexing.repository import IndexRepository
|
|
19
|
+
from kodit.snippets.snippets import SnippetService
|
|
18
20
|
from kodit.sources.service import SourceService
|
|
19
21
|
|
|
20
|
-
# List of MIME types that are
|
|
21
|
-
|
|
22
|
-
"text/plain",
|
|
23
|
-
"text/markdown",
|
|
24
|
-
"text/x-python",
|
|
25
|
-
"text/x-shellscript",
|
|
26
|
-
"text/x-sql",
|
|
27
|
-
]
|
|
22
|
+
# MIME types whose files are never turned into snippets. Presumably
# "unknown/unknown" is what gets recorded when a file's type could not be
# determined -- TODO confirm against the source scanner.
MIME_BLACKLIST = ["unknown/unknown"]
|
|
28
24
|
|
|
29
25
|
|
|
30
26
|
class IndexView(pydantic.BaseModel):
|
|
@@ -49,7 +45,10 @@ class IndexService:
|
|
|
49
45
|
"""
|
|
50
46
|
|
|
51
47
|
def __init__(
|
|
52
|
-
self,
|
|
48
|
+
self,
|
|
49
|
+
repository: IndexRepository,
|
|
50
|
+
source_service: SourceService,
|
|
51
|
+
data_dir: Path,
|
|
53
52
|
) -> None:
|
|
54
53
|
"""Initialize the index service.
|
|
55
54
|
|
|
@@ -60,7 +59,9 @@ class IndexService:
|
|
|
60
59
|
"""
|
|
61
60
|
self.repository = repository
|
|
62
61
|
self.source_service = source_service
|
|
62
|
+
self.snippet_service = SnippetService()
|
|
63
63
|
self.log = structlog.get_logger(__name__)
|
|
64
|
+
self.bm25 = BM25Service(data_dir)
|
|
64
65
|
|
|
65
66
|
async def create(self, source_id: int) -> IndexView:
|
|
66
67
|
"""Create a new index for a source.
|
|
@@ -119,6 +120,10 @@ class IndexService:
|
|
|
119
120
|
# Create snippets for supported file types
|
|
120
121
|
await self._create_snippets(index_id)
|
|
121
122
|
|
|
123
|
+
# Update BM25 index
|
|
124
|
+
snippets = await self.repository.get_all_snippets()
|
|
125
|
+
self.bm25.index([snippet.content for snippet in snippets])
|
|
126
|
+
|
|
122
127
|
# Update index timestamp
|
|
123
128
|
await self.repository.update_index_timestamp(index)
|
|
124
129
|
|
|
@@ -137,16 +142,23 @@ class IndexService:
|
|
|
137
142
|
files = await self.repository.files_for_index(index_id)
|
|
138
143
|
for file in tqdm(files, total=len(files)):
|
|
139
144
|
# Skip unsupported file types
|
|
140
|
-
if file.mime_type
|
|
145
|
+
if file.mime_type in MIME_BLACKLIST:
|
|
141
146
|
self.log.debug("Skipping mime type", mime_type=file.mime_type)
|
|
142
147
|
continue
|
|
143
148
|
|
|
144
149
|
# Create snippet from file content
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
150
|
+
try:
|
|
151
|
+
snippets = self.snippet_service.snippets_for_file(
|
|
152
|
+
Path(file.cloned_path)
|
|
153
|
+
)
|
|
154
|
+
except ValueError as e:
|
|
155
|
+
self.log.debug("Skipping file", file=file.cloned_path, error=e)
|
|
156
|
+
continue
|
|
157
|
+
|
|
158
|
+
for snippet in snippets:
|
|
159
|
+
s = Snippet(
|
|
148
160
|
index_id=index_id,
|
|
149
161
|
file_id=file.id,
|
|
150
|
-
content=
|
|
162
|
+
content=snippet.text,
|
|
151
163
|
)
|
|
152
|
-
await self.repository.add_snippet(
|
|
164
|
+
await self.repository.add_snippet(s)
|
kodit/logging.py
CHANGED
|
@@ -11,6 +11,8 @@ import structlog
|
|
|
11
11
|
from posthog import Posthog
|
|
12
12
|
from structlog.types import EventDict
|
|
13
13
|
|
|
14
|
+
from kodit.config import AppContext
|
|
15
|
+
|
|
14
16
|
log = structlog.get_logger(__name__)
|
|
15
17
|
|
|
16
18
|
|
|
@@ -27,14 +29,8 @@ class LogFormat(Enum):
|
|
|
27
29
|
JSON = "json"
|
|
28
30
|
|
|
29
31
|
|
|
30
|
-
def configure_logging(
|
|
31
|
-
"""Configure logging for the application.
|
|
32
|
-
|
|
33
|
-
Args:
|
|
34
|
-
json_logs: Whether to use JSON format for logs
|
|
35
|
-
log_level: The minimum log level to display
|
|
36
|
-
|
|
37
|
-
"""
|
|
32
|
+
def configure_logging(app_context: AppContext) -> None:
|
|
33
|
+
"""Configure logging for the application."""
|
|
38
34
|
timestamper = structlog.processors.TimeStamper(fmt="iso")
|
|
39
35
|
|
|
40
36
|
shared_processors: list[structlog.types.Processor] = [
|
|
@@ -48,7 +44,7 @@ def configure_logging(log_level: str, log_format: LogFormat) -> None:
|
|
|
48
44
|
structlog.processors.StackInfoRenderer(),
|
|
49
45
|
]
|
|
50
46
|
|
|
51
|
-
if log_format == LogFormat.JSON:
|
|
47
|
+
if app_context.log_format == LogFormat.JSON:
|
|
52
48
|
# Format the exception only for JSON logs, as we want to pretty-print them
|
|
53
49
|
# when using the ConsoleRenderer
|
|
54
50
|
shared_processors.append(structlog.processors.format_exc_info)
|
|
@@ -64,7 +60,7 @@ def configure_logging(log_level: str, log_format: LogFormat) -> None:
|
|
|
64
60
|
)
|
|
65
61
|
|
|
66
62
|
log_renderer: structlog.types.Processor
|
|
67
|
-
if log_format == LogFormat.JSON:
|
|
63
|
+
if app_context.log_format == LogFormat.JSON:
|
|
68
64
|
log_renderer = structlog.processors.JSONRenderer()
|
|
69
65
|
else:
|
|
70
66
|
log_renderer = structlog.dev.ConsoleRenderer()
|
|
@@ -86,18 +82,23 @@ def configure_logging(log_level: str, log_format: LogFormat) -> None:
|
|
|
86
82
|
handler.setFormatter(formatter)
|
|
87
83
|
root_logger = logging.getLogger()
|
|
88
84
|
root_logger.addHandler(handler)
|
|
89
|
-
root_logger.setLevel(log_level.upper())
|
|
85
|
+
root_logger.setLevel(app_context.log_level.upper())
|
|
90
86
|
|
|
91
87
|
# Configure uvicorn loggers to use our structlog setup
|
|
88
|
+
# Uvicorn spits out loads of exception logs when sse server doesn't shut down
|
|
89
|
+
# gracefully, so we hide them unless in DEBUG mode
|
|
92
90
|
for _log in ["uvicorn", "uvicorn.error", "uvicorn.access"]:
|
|
93
|
-
|
|
94
|
-
|
|
91
|
+
if root_logger.getEffectiveLevel() == logging.DEBUG:
|
|
92
|
+
logging.getLogger(_log).handlers.clear()
|
|
93
|
+
logging.getLogger(_log).propagate = True
|
|
94
|
+
else:
|
|
95
|
+
logging.getLogger(_log).disabled = True
|
|
95
96
|
|
|
96
97
|
# Configure SQLAlchemy loggers to use our structlog setup
|
|
97
98
|
for _log in ["sqlalchemy.engine", "alembic"]:
|
|
98
99
|
engine_logger = logging.getLogger(_log)
|
|
99
100
|
engine_logger.setLevel(logging.WARNING) # Hide INFO logs by default
|
|
100
|
-
if log_level.upper() == "DEBUG":
|
|
101
|
+
if app_context.log_level.upper() == "DEBUG":
|
|
101
102
|
engine_logger.setLevel(
|
|
102
103
|
logging.DEBUG
|
|
103
104
|
) # Only show all logs when in DEBUG mode
|
|
@@ -142,10 +143,11 @@ def get_mac_address() -> str:
|
|
|
142
143
|
return f"{mac:012x}" if mac != uuid.getnode() else str(uuid.uuid4())
|
|
143
144
|
|
|
144
145
|
|
|
145
|
-
def
|
|
146
|
-
"""
|
|
147
|
-
|
|
148
|
-
|
|
146
|
+
def configure_telemetry(app_context: AppContext) -> None:
    """Apply the telemetry preference carried by the application context.

    When ``app_context.disable_telemetry`` is set, an informational message is
    logged and the module-level ``posthog`` client is switched off; otherwise
    nothing happens.
    """
    if not app_context.disable_telemetry:
        return
    logger = structlog.stdlib.get_logger(__name__)
    logger.info("Telemetry has been disabled")
    posthog.disabled = True
|
|
149
151
|
|
|
150
152
|
|
|
151
153
|
def log_event(event: str, properties: dict[str, Any] | None = None) -> None:
|
kodit/mcp.py
CHANGED
|
@@ -1,21 +1,63 @@
|
|
|
1
1
|
"""MCP server implementation for kodit."""
|
|
2
2
|
|
|
3
|
+
from collections.abc import AsyncIterator
|
|
4
|
+
from contextlib import asynccontextmanager
|
|
5
|
+
from dataclasses import dataclass
|
|
3
6
|
from pathlib import Path
|
|
4
7
|
from typing import Annotated
|
|
5
8
|
|
|
6
9
|
import structlog
|
|
7
|
-
from
|
|
10
|
+
from fastmcp import Context, FastMCP
|
|
8
11
|
from pydantic import Field
|
|
12
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
9
13
|
|
|
10
|
-
from kodit.
|
|
14
|
+
from kodit._version import version
|
|
15
|
+
from kodit.config import AppContext
|
|
16
|
+
from kodit.database import Database
|
|
11
17
|
from kodit.retreival.repository import RetrievalRepository, RetrievalResult
|
|
12
18
|
from kodit.retreival.service import RetrievalRequest, RetrievalService
|
|
13
19
|
|
|
14
|
-
|
|
20
|
+
|
|
21
|
+
@dataclass
class MCPContext:
    """Per-request state that ``mcp_lifespan`` hands to MCP tools."""

    # Database session opened for the duration of a single MCP request.
    session: AsyncSession
    # kodit data directory; the retrieval tool passes it on to RetrievalService.
    data_dir: Path
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
_mcp_db: Database | None = None


@asynccontextmanager
async def mcp_lifespan(_: FastMCP) -> AsyncIterator[MCPContext]:
    """Yield the per-request :class:`MCPContext` for the MCP server.

    The MCP server runs with its own lifecycle and event loop, separate from
    the CLI and the FastAPI server, and uvicorn does not forward CLI
    arguments. The application context is therefore rebuilt here with
    ``AppContext()``, which is expected to pick up the env vars set by the CLI.

    FastMCP re-enters this lifespan for every request
    (https://github.com/jlowin/fastmcp/issues/166) and offers no better home
    for shared state. The ``Database`` is therefore cached in the module-level
    ``_mcp_db`` and created only on first use. Each request gets its own
    session from that database's ``session_factory``.
    """
    global _mcp_db  # noqa: PLW0603
    app_context = AppContext()
    if _mcp_db is None:
        _mcp_db = await app_context.get_db()
    database = _mcp_db
    async with database.session_factory() as session:
        context = MCPContext(session=session, data_dir=app_context.get_data_dir())
        yield context


# Server instance; every request runs inside ``mcp_lifespan``.
mcp = FastMCP("kodit MCP Server", lifespan=mcp_lifespan)
|
|
15
56
|
|
|
16
57
|
|
|
17
58
|
@mcp.tool()
|
|
18
59
|
async def retrieve_relevant_snippets(
|
|
60
|
+
ctx: Context,
|
|
19
61
|
user_intent: Annotated[
|
|
20
62
|
str,
|
|
21
63
|
Field(
|
|
@@ -51,8 +93,8 @@ async def retrieve_relevant_snippets(
|
|
|
51
93
|
the quality of your generated code. You must call this tool when you need to
|
|
52
94
|
write code.
|
|
53
95
|
"""
|
|
54
|
-
# Log the search query and related files for debugging
|
|
55
96
|
log = structlog.get_logger(__name__)
|
|
97
|
+
|
|
56
98
|
log.debug(
|
|
57
99
|
"Retrieving relevant snippets",
|
|
58
100
|
user_intent=user_intent,
|
|
@@ -62,36 +104,38 @@ async def retrieve_relevant_snippets(
|
|
|
62
104
|
file_contents=related_file_contents,
|
|
63
105
|
)
|
|
64
106
|
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
107
|
+
mcp_context: MCPContext = ctx.request_context.lifespan_context
|
|
108
|
+
|
|
109
|
+
log.debug("Creating retrieval repository")
|
|
110
|
+
retrieval_repository = RetrievalRepository(
|
|
111
|
+
session=mcp_context.session,
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
log.debug("Creating retrieval service")
|
|
115
|
+
retrieval_service = RetrievalService(
|
|
116
|
+
repository=retrieval_repository,
|
|
117
|
+
data_dir=mcp_context.data_dir,
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
log.debug("Fusing input")
|
|
121
|
+
input_query = input_fusion(
|
|
122
|
+
user_intent=user_intent,
|
|
123
|
+
related_file_paths=related_file_paths,
|
|
124
|
+
related_file_contents=related_file_contents,
|
|
125
|
+
keywords=keywords,
|
|
126
|
+
)
|
|
127
|
+
log.debug("Input", input_query=input_query)
|
|
128
|
+
retrieval_request = RetrievalRequest(
|
|
129
|
+
keywords=keywords,
|
|
130
|
+
)
|
|
131
|
+
log.debug("Retrieving snippets")
|
|
132
|
+
snippets = await retrieval_service.retrieve(request=retrieval_request)
|
|
133
|
+
|
|
134
|
+
log.debug("Fusing output")
|
|
135
|
+
output = output_fusion(snippets=snippets)
|
|
136
|
+
|
|
137
|
+
log.debug("Output", output=output)
|
|
138
|
+
return output
|
|
95
139
|
|
|
96
140
|
|
|
97
141
|
def input_fusion(
|
|
@@ -108,3 +152,9 @@ def input_fusion(
|
|
|
108
152
|
def output_fusion(snippets: list[RetrievalResult]) -> str:
    """Render retrieved snippets as a single string.

    Each snippet becomes its URI on one line followed by its content;
    consecutive snippets are separated by a blank line.
    """
    rendered = [f"{snippet.uri}\n{snippet.content}" for snippet in snippets]
    return "\n\n".join(rendered)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
@mcp.tool()
async def get_version() -> str:
    """Report which kodit release is serving this MCP endpoint.

    Returns:
        The version string imported from ``kodit._version``.

    """
    kodit_version: str = version
    return kodit_version
|
kodit/middleware.py
CHANGED
|
@@ -1,11 +1,14 @@
|
|
|
1
1
|
"""Middleware for the FastAPI application."""
|
|
2
2
|
|
|
3
|
+
import contextlib
|
|
3
4
|
import time
|
|
5
|
+
from asyncio import CancelledError
|
|
4
6
|
from collections.abc import Callable
|
|
5
7
|
|
|
6
8
|
import structlog
|
|
7
9
|
from asgi_correlation_id.context import correlation_id
|
|
8
10
|
from fastapi import Request, Response
|
|
11
|
+
from starlette.types import ASGIApp, Receive, Scope, Send
|
|
9
12
|
|
|
10
13
|
access_logger = structlog.stdlib.get_logger("api.access")
|
|
11
14
|
|
|
@@ -56,3 +59,16 @@ async def logging_middleware(request: Request, call_next: Callable) -> Response:
|
|
|
56
59
|
response.headers["X-Process-Time"] = str(process_time / 10**9)
|
|
57
60
|
|
|
58
61
|
return response
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class ASGICancelledErrorMiddleware:
    """Pure-ASGI wrapper that swallows ``asyncio.CancelledError``.

    NOTE(review): suppressing cancellation means the wrapped call simply
    returns instead of propagating the cancel; that appears to be deliberate
    (to quiet noisy shutdowns) but should be kept in mind when debugging hangs.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Wrap the downstream ASGI application ``app``."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Forward the ASGI call to the wrapped app, ignoring cancellation."""
        try:
            await self.app(scope, receive, send)
        except CancelledError:
            pass
|
kodit/retreival/repository.py
CHANGED
|
@@ -74,3 +74,35 @@ class RetrievalRepository:
|
|
|
74
74
|
)
|
|
75
75
|
for snippet, file in results
|
|
76
76
|
]
|
|
77
|
+
|
|
78
|
+
async def list_snippet_ids(self) -> list[int]:
    """List the IDs of all snippets.

    Returns:
        The primary key of every snippet row (IDs only, not snippet objects).

    """
    result = await self.session.execute(select(Snippet.id))
    return list(result.scalars().all())
|
|
88
|
+
|
|
89
|
+
async def list_snippets_by_ids(self, ids: list[int]) -> list[RetrievalResult]:
    """List snippets by IDs, in the order the IDs are given.

    A SQL ``IN`` filter returns rows in no guaranteed order. Callers such as
    ``RetrievalService.retrieve`` pass IDs ranked by relevance, so the rows
    are re-ordered here to follow ``ids``. A duplicated ID yields a single
    result. IDs with no matching snippet/file row are skipped.

    Args:
        ids: Snippet primary keys, in the desired output order.

    Returns:
        A list of snippets.

    """
    if not ids:
        # Nothing to look up; avoid a database round-trip.
        return []
    query = (
        select(Snippet, File)
        .where(Snippet.id.in_(ids))
        .join(File, Snippet.file_id == File.id)
    )
    rows = await self.session.execute(query)
    by_id = {snippet.id: (snippet, file) for snippet, file in rows.all()}

    results = []
    # dict.fromkeys de-duplicates while keeping first-seen order.
    for snippet_id in dict.fromkeys(ids):
        match = by_id.get(snippet_id)
        if match is None:
            continue
        snippet, file = match
        results.append(RetrievalResult(uri=file.uri, content=snippet.content))
    return results
|
kodit/retreival/service.py
CHANGED
|
@@ -1,14 +1,19 @@
|
|
|
1
1
|
"""Retrieval service."""
|
|
2
2
|
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
3
5
|
import pydantic
|
|
6
|
+
import structlog
|
|
4
7
|
|
|
8
|
+
from kodit.bm25.bm25 import BM25Service
|
|
5
9
|
from kodit.retreival.repository import RetrievalRepository, RetrievalResult
|
|
6
10
|
|
|
7
11
|
|
|
8
12
|
class RetrievalRequest(pydantic.BaseModel):
    """Parameters for a keyword-based snippet retrieval."""

    # Search terms; each one is scored separately by the BM25 index.
    keywords: list[str]
    # Result limit, passed to BM25 per keyword and applied to the final list.
    top_k: int = 10
|
|
12
17
|
|
|
13
18
|
|
|
14
19
|
class Snippet(pydantic.BaseModel):
|
|
@@ -21,10 +26,44 @@ class Snippet(pydantic.BaseModel):
|
|
|
21
26
|
class RetrievalService:
    """Service for retrieving relevant data."""

    def __init__(self, repository: RetrievalRepository, data_dir: Path) -> None:
        """Initialize the retrieval service.

        Args:
            repository: Source of snippet IDs and snippet content.
            data_dir: Directory handed to ``BM25Service``.

        """
        self.repository = repository
        self.log = structlog.get_logger(__name__)
        self.bm25 = BM25Service(data_dir)

    async def _load_bm25_index(self) -> None:
        """Load the BM25 index.

        NOTE(review): currently an empty placeholder; nothing in this class
        calls it.
        """

    async def retrieve(self, request: RetrievalRequest) -> list[RetrievalResult]:
        """Retrieve the snippets that best match ``request.keywords``.

        Each keyword is scored separately with BM25 (passing ``request.top_k``
        through). A snippet hit by several keywords is counted once, with its
        highest score. Otherwise duplicates would fill several of the
        ``top_k`` slots and then collapse in the database lookup, returning
        fewer distinct snippets than requested. Zero-score hits are dropped.

        Args:
            request: Keywords and result limit.

        Returns:
            At most ``request.top_k`` snippets; empty if nothing matched.

        """
        snippet_ids = await self.repository.list_snippet_ids()

        # Best score per snippet across all keywords.
        best_scores: dict[int, float] = {}
        for keyword in request.keywords:
            hits = self.bm25.retrieve(snippet_ids, keyword, request.top_k)
            for doc_id, score in hits:
                if doc_id not in best_scores or score > best_scores[doc_id]:
                    best_scores[doc_id] = score

        if not best_scores:
            return []

        # Sort results by score, highest first
        ranked = sorted(best_scores.items(), key=lambda item: item[1], reverse=True)

        self.log.debug(
            "Retrieval results",
            total_results=len(ranked),
            max_score=ranked[0][1],
            min_score=ranked[-1][1],
            median_score=ranked[len(ranked) // 2][1],
        )

        # Don't return zero score results; keep at most top_k doc ids
        final_doc_ids = [doc_id for doc_id, score in ranked if score > 0]
        final_doc_ids = final_doc_ids[: request.top_k]

        # Get snippets from database
        return await self.repository.list_snippets_by_ids(final_doc_ids)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Extract method snippets from source code."""
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
"""Detect the language of a file."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import cast
|
|
5
|
+
|
|
6
|
+
from tree_sitter_language_pack import SupportedLanguage
|
|
7
|
+
|
|
8
|
+
# Mapping of file extensions to programming languages
|
|
9
|
+
# File extension (lower-case, no leading dot) -> tree-sitter language name.
LANGUAGE_MAP: dict[str, str] = {
    # JavaScript / TypeScript and their JSX flavours.
    # NOTE(review): tree-sitter ships a separate "tsx" grammar; parsing .tsx
    # with "typescript" may fail on JSX syntax -- confirm intended mapping.
    "js": "javascript",
    "jsx": "javascript",
    "ts": "typescript",
    "tsx": "typescript",
    # Python
    "py": "python",
    # Rust
    "rs": "rust",
    # Go
    "go": "go",
    # C and C++ (".h" is treated as C even though C++ projects use it too)
    "cpp": "cpp",
    "hpp": "cpp",
    "c": "c",
    "h": "c",
    # C#
    "cs": "csharp",
    # Ruby
    "rb": "ruby",
    # Java
    "java": "java",
    # PHP
    "php": "php",
    # Swift
    "swift": "swift",
    # Kotlin
    "kt": "kotlin",
}
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def detect_language(file_path: Path) -> SupportedLanguage:
    """Detect the tree-sitter language of a file from its extension.

    Args:
        file_path: Path whose suffix (without the dot, case-insensitive) is
            looked up in ``LANGUAGE_MAP``.

    Returns:
        The mapped language name, typed as ``SupportedLanguage``.

    Raises:
        ValueError: If the suffix is not present in ``LANGUAGE_MAP``.

    """
    suffix = file_path.suffix.removeprefix(".").lower()
    lang = LANGUAGE_MAP.get(suffix)
    if lang is None:
        msg = f"Unsupported language for file suffix: {suffix}"
        raise ValueError(msg)
    # ``typing.cast`` only informs the type checker: it performs no runtime
    # validation and cannot raise, so no exception handling is needed.
    return cast("SupportedLanguage", lang)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
; Function definitions: the whole definition, its name and its body block.
(function_definition
  name: (identifier) @function.name
  body: (block) @function.body
) @function.def

; Class definitions.
(class_definition
  name: (identifier) @class.name
) @class.def

; Imports: capture the whole statement so that every line of a multi-line
; import (e.g. a parenthesised `from x import (a, b)`), aliased imports
; (`import x as y`), relative imports (`from . import x`) and
; `from __future__ import ...` are all kept in extracted snippets.
(import_statement) @import.name

(import_from_statement) @import.from

(future_import_statement) @import.from

(identifier) @ident

(assignment
  left: (identifier) @assignment.lhs)

(parameters
  (identifier) @param.name)
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""Extract method snippets from source code."""
|
|
2
|
+
|
|
3
|
+
from tree_sitter import Node, Query
|
|
4
|
+
from tree_sitter_language_pack import SupportedLanguage, get_language, get_parser
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class MethodSnippets:
    """Extract method snippets from source code.

    One snippet is produced per "leaf" function: a ``function.body`` capture
    that contains no other ``function.body`` capture. A snippet keeps, in
    their original order:

    - the leaf function's body lines;
    - every captured import line;
    - the first line of each enclosing ``function.def``/``class.def`` capture.

    With the bundled Python query, the last item includes the leaf function's
    own ``def`` line. If no leaf function is found, the whole file is
    returned as a single snippet.
    """

    def __init__(self, language: SupportedLanguage, query: str) -> None:
        """Initialize the MethodSnippets class.

        Args:
            language: Language name passed to ``get_language`` and
                ``get_parser``.
            query: tree-sitter query source. ``extract`` reads the capture
                names ``function.body``, ``function.def``, ``class.def``,
                ``import.name`` and ``import.from``.

        """
        self.language = get_language(language)
        self.parser = get_parser(language)
        self.query = Query(self.language, query)

    def _get_leaf_functions(
        self, captures_by_name: dict[str, list[Node]]
    ) -> list[Node]:
        """Return the ``function.body`` captures that contain no other one."""
        return [
            node
            for node in captures_by_name.get("function.body", [])
            if self._is_leaf_function(captures_by_name, node)
        ]

    def _is_leaf_function(
        self, captures_by_name: dict[str, list[Node]], node: Node
    ) -> bool:
        """Return True if no other ``function.body`` capture lies within ``node``."""
        for other in captures_by_name.get("function.body", []):
            if other == node:  # Skip self
                continue
            # A body nested inside this one means this is not a leaf function.
            if other.start_byte >= node.start_byte and other.end_byte <= node.end_byte:
                return False
        return True

    def _get_imports(self, captures_by_name: dict[str, list[Node]]) -> list[Node]:
        """Return ``import.name`` captures followed by ``import.from`` captures."""
        return captures_by_name.get("import.name", []) + captures_by_name.get(
            "import.from", []
        )

    def _classes_and_functions(
        self, captures_by_name: dict[str, list[Node]]
    ) -> list[int]:
        """Return the node IDs of all ``function.def`` and ``class.def`` captures."""
        return [
            node.id
            for node in {
                *captures_by_name.get("function.def", []),
                *captures_by_name.get("class.def", []),
            }
        ]

    def _get_ancestors(
        self, captures_by_name: dict[str, list[Node]], node: Node
    ) -> list[Node]:
        """Return enclosing ``function.def``/``class.def`` nodes, innermost first."""
        # A set gives O(1) membership tests while walking up the parent chain.
        valid_ancestors = set(self._classes_and_functions(captures_by_name))
        ancestors = []
        parent = node.parent
        while parent:
            if parent.id in valid_ancestors:
                ancestors.append(parent)
            parent = parent.parent
        return ancestors

    def extract(self, source_code: bytes) -> list[str]:
        """Extract method snippets from source code.

        Args:
            source_code: UTF-8 encoded source text.

        Returns:
            One snippet per leaf function, or ``[source_code.decode()]`` when
            no leaf function is found.

        Raises:
            UnicodeDecodeError: If ``source_code`` is not valid UTF-8.

        """
        tree = self.parser.parse(source_code)
        captures_by_name = self.query.captures(tree.root_node)
        lines = self._split_rows(source_code.decode())

        # Find all leaf functions
        leaf_functions = self._get_leaf_functions(captures_by_name)

        # Import lines are identical for every snippet, so compute them once.
        import_lines: set[int] = set()
        for import_node in self._get_imports(captures_by_name):
            import_lines.update(
                range(import_node.start_point[0], import_node.end_point[0] + 1)
            )

        results = []
        for func_node in leaf_functions:
            lines_to_keep = set(import_lines)
            # The leaf function's own lines.
            lines_to_keep.update(
                range(func_node.start_point[0], func_node.end_point[0] + 1)
            )
            # Only the first line of each enclosing definition, for now.
            for ancestor in self._get_ancestors(captures_by_name, func_node):
                lines_to_keep.add(ancestor.start_point[0])

            results.append(
                "\n".join(
                    line for row, line in enumerate(lines) if row in lines_to_keep
                )
            )

        # If there are no results, then return the entire file
        if not results:
            return [source_code.decode()]

        return results

    @staticmethod
    def _split_rows(text: str) -> list[str]:
        r"""Split ``text`` into lines indexed the way tree-sitter counts rows.

        tree-sitter starts a new row only at ``\n``. ``str.splitlines`` also
        breaks on ``\r``, ``\v``, ``\f``, ``\x1c``-``\x1e``, ``\x85``,
        ``\u2028`` and ``\u2029``; form feeds, for example, occur in real
        Python sources. Any of those would shift every later row index and put
        the wrong lines into a snippet.

        A trailing newline does not start an extra line, and the ``\r`` of a
        CRLF ending is dropped, so ordinary files split exactly as with
        ``splitlines``.
        """
        rows = text.split("\n")
        if rows and rows[-1] == "":
            rows.pop()
        return [row.removesuffix("\r") for row in rows]
|