kodit 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic; see the package's registry page for more details.

kodit/indexing/service.py CHANGED
@@ -7,24 +7,20 @@ index management.
7
7
  """
8
8
 
9
9
  from datetime import datetime
10
+ from pathlib import Path
10
11
 
11
- import aiofiles
12
12
  import pydantic
13
13
  import structlog
14
14
  from tqdm.asyncio import tqdm
15
15
 
16
+ from kodit.bm25.bm25 import BM25Service
16
17
  from kodit.indexing.models import Snippet
17
18
  from kodit.indexing.repository import IndexRepository
19
+ from kodit.snippets.snippets import SnippetService
18
20
  from kodit.sources.service import SourceService
19
21
 
20
- # List of MIME types that are supported for indexing and snippet creation
21
- MIME_WHITELIST = [
22
- "text/plain",
23
- "text/markdown",
24
- "text/x-python",
25
- "text/x-shellscript",
26
- "text/x-sql",
27
- ]
22
# MIME types that IndexService._create_snippets skips without attempting
# snippet extraction (every other file is handed to the snippet service).
MIME_BLACKLIST = ["unknown/unknown"]
28
24
 
29
25
 
30
26
  class IndexView(pydantic.BaseModel):
@@ -49,7 +45,10 @@ class IndexService:
49
45
  """
50
46
 
51
47
  def __init__(
52
- self, repository: IndexRepository, source_service: SourceService
48
+ self,
49
+ repository: IndexRepository,
50
+ source_service: SourceService,
51
+ data_dir: Path,
53
52
  ) -> None:
54
53
  """Initialize the index service.
55
54
 
@@ -60,7 +59,9 @@ class IndexService:
60
59
  """
61
60
  self.repository = repository
62
61
  self.source_service = source_service
62
+ self.snippet_service = SnippetService()
63
63
  self.log = structlog.get_logger(__name__)
64
+ self.bm25 = BM25Service(data_dir)
64
65
 
65
66
  async def create(self, source_id: int) -> IndexView:
66
67
  """Create a new index for a source.
@@ -119,6 +120,10 @@ class IndexService:
119
120
  # Create snippets for supported file types
120
121
  await self._create_snippets(index_id)
121
122
 
123
+ # Update BM25 index
124
+ snippets = await self.repository.get_all_snippets()
125
+ self.bm25.index([snippet.content for snippet in snippets])
126
+
122
127
  # Update index timestamp
123
128
  await self.repository.update_index_timestamp(index)
124
129
 
@@ -137,16 +142,23 @@ class IndexService:
137
142
  files = await self.repository.files_for_index(index_id)
138
143
  for file in tqdm(files, total=len(files)):
139
144
  # Skip unsupported file types
140
- if file.mime_type not in MIME_WHITELIST:
145
+ if file.mime_type in MIME_BLACKLIST:
141
146
  self.log.debug("Skipping mime type", mime_type=file.mime_type)
142
147
  continue
143
148
 
144
149
  # Create snippet from file content
145
- async with aiofiles.open(file.cloned_path, "rb") as f:
146
- content = await f.read()
147
- snippet = Snippet(
150
+ try:
151
+ snippets = self.snippet_service.snippets_for_file(
152
+ Path(file.cloned_path)
153
+ )
154
+ except ValueError as e:
155
+ self.log.debug("Skipping file", file=file.cloned_path, error=e)
156
+ continue
157
+
158
+ for snippet in snippets:
159
+ s = Snippet(
148
160
  index_id=index_id,
149
161
  file_id=file.id,
150
- content=content.decode("utf-8"),
162
+ content=snippet.text,
151
163
  )
152
- await self.repository.add_snippet(snippet)
164
+ await self.repository.add_snippet(s)
kodit/logging.py CHANGED
@@ -11,6 +11,8 @@ import structlog
11
11
  from posthog import Posthog
12
12
  from structlog.types import EventDict
13
13
 
14
+ from kodit.config import AppContext
15
+
14
16
  log = structlog.get_logger(__name__)
15
17
 
16
18
 
@@ -27,14 +29,8 @@ class LogFormat(Enum):
27
29
  JSON = "json"
28
30
 
29
31
 
30
- def configure_logging(log_level: str, log_format: LogFormat) -> None:
31
- """Configure logging for the application.
32
-
33
- Args:
34
- json_logs: Whether to use JSON format for logs
35
- log_level: The minimum log level to display
36
-
37
- """
32
+ def configure_logging(app_context: AppContext) -> None:
33
+ """Configure logging for the application."""
38
34
  timestamper = structlog.processors.TimeStamper(fmt="iso")
39
35
 
40
36
  shared_processors: list[structlog.types.Processor] = [
@@ -48,7 +44,7 @@ def configure_logging(log_level: str, log_format: LogFormat) -> None:
48
44
  structlog.processors.StackInfoRenderer(),
49
45
  ]
50
46
 
51
- if log_format == LogFormat.JSON:
47
+ if app_context.log_format == LogFormat.JSON:
52
48
  # Format the exception only for JSON logs, as we want to pretty-print them
53
49
  # when using the ConsoleRenderer
54
50
  shared_processors.append(structlog.processors.format_exc_info)
@@ -64,7 +60,7 @@ def configure_logging(log_level: str, log_format: LogFormat) -> None:
64
60
  )
65
61
 
66
62
  log_renderer: structlog.types.Processor
67
- if log_format == LogFormat.JSON:
63
+ if app_context.log_format == LogFormat.JSON:
68
64
  log_renderer = structlog.processors.JSONRenderer()
69
65
  else:
70
66
  log_renderer = structlog.dev.ConsoleRenderer()
@@ -86,18 +82,23 @@ def configure_logging(log_level: str, log_format: LogFormat) -> None:
86
82
  handler.setFormatter(formatter)
87
83
  root_logger = logging.getLogger()
88
84
  root_logger.addHandler(handler)
89
- root_logger.setLevel(log_level.upper())
85
+ root_logger.setLevel(app_context.log_level.upper())
90
86
 
91
87
  # Configure uvicorn loggers to use our structlog setup
88
+ # Uvicorn spits out loads of exception logs when sse server doesn't shut down
89
+ # gracefully, so we hide them unless in DEBUG mode
92
90
  for _log in ["uvicorn", "uvicorn.error", "uvicorn.access"]:
93
- logging.getLogger(_log).handlers.clear()
94
- logging.getLogger(_log).propagate = True
91
+ if root_logger.getEffectiveLevel() == logging.DEBUG:
92
+ logging.getLogger(_log).handlers.clear()
93
+ logging.getLogger(_log).propagate = True
94
+ else:
95
+ logging.getLogger(_log).disabled = True
95
96
 
96
97
  # Configure SQLAlchemy loggers to use our structlog setup
97
98
  for _log in ["sqlalchemy.engine", "alembic"]:
98
99
  engine_logger = logging.getLogger(_log)
99
100
  engine_logger.setLevel(logging.WARNING) # Hide INFO logs by default
100
- if log_level.upper() == "DEBUG":
101
+ if app_context.log_level.upper() == "DEBUG":
101
102
  engine_logger.setLevel(
102
103
  logging.DEBUG
103
104
  ) # Only show all logs when in DEBUG mode
@@ -142,10 +143,11 @@ def get_mac_address() -> str:
142
143
  return f"{mac:012x}" if mac != uuid.getnode() else str(uuid.uuid4())
143
144
 
144
145
 
145
- def disable_posthog() -> None:
146
- """Disable telemetry for the application."""
147
- structlog.stdlib.get_logger(__name__).info("Telemetry has been disabled")
148
- posthog.disabled = True
146
def configure_telemetry(app_context: AppContext) -> None:
    """Switch PostHog telemetry off when the application context requests it.

    When ``app_context.disable_telemetry`` is falsy this function does nothing,
    leaving the PostHog client's ``disabled`` flag as it was.
    """
    if not app_context.disable_telemetry:
        return
    structlog.stdlib.get_logger(__name__).info("Telemetry has been disabled")
    posthog.disabled = True
149
151
 
150
152
 
151
153
  def log_event(event: str, properties: dict[str, Any] | None = None) -> None:
kodit/mcp.py CHANGED
@@ -1,21 +1,63 @@
1
1
  """MCP server implementation for kodit."""
2
2
 
3
+ from collections.abc import AsyncIterator
4
+ from contextlib import asynccontextmanager
5
+ from dataclasses import dataclass
3
6
  from pathlib import Path
4
7
  from typing import Annotated
5
8
 
6
9
  import structlog
7
- from mcp.server.fastmcp import FastMCP
10
+ from fastmcp import Context, FastMCP
8
11
  from pydantic import Field
12
+ from sqlalchemy.ext.asyncio import AsyncSession
9
13
 
10
- from kodit.database import get_session
14
+ from kodit._version import version
15
+ from kodit.config import AppContext
16
+ from kodit.database import Database
11
17
  from kodit.retreival.repository import RetrievalRepository, RetrievalResult
12
18
  from kodit.retreival.service import RetrievalRequest, RetrievalService
13
19
 
14
- mcp = FastMCP("kodit MCP Server")
20
+
21
@dataclass
class MCPContext:
    """Context for the MCP server.

    Yielded by ``mcp_lifespan`` and read by tools through
    ``ctx.request_context.lifespan_context``.
    """

    # Database session opened from the shared Database's session factory.
    session: AsyncSession
    # Data directory from AppContext.get_data_dir(); handed to BM25-backed services.
    data_dir: Path
27
+
28
+
29
# Database shared by every lifespan invocation in this process; created lazily
# by the first request because the lifespan itself is re-entered per request.
_mcp_db: Database | None = None


@asynccontextmanager
async def mcp_lifespan(_: FastMCP) -> AsyncIterator[MCPContext]:
    """Lifespan for the MCP server.

    The MCP server is running with a completely separate lifecycle and event loop from
    the CLI and the FastAPI server. Therefore, we must carefully reconstruct the
    application context. uvicorn does not pass through CLI args, so we must rely on
    parsing env vars set in the CLI.

    This lifespan is recreated for each request. See:
    https://github.com/jlowin/fastmcp/issues/166

    Since they don't provide a good way to handle global state, we must use a
    global variable to store the database connection.

    Yields:
        An MCPContext holding a session from the shared Database and the
        configured data directory.

    NOTE(review): the ``is None`` check and the assignment are separated by an
    ``await``, so two concurrent first requests could each create a Database;
    consider guarding creation with an asyncio.Lock if that matters.
    """
    global _mcp_db  # noqa: PLW0603
    settings = AppContext()
    if _mcp_db is None:
        _mcp_db = await settings.get_db()
    async with _mcp_db.session_factory() as db_session:
        context = MCPContext(session=db_session, data_dir=settings.get_data_dir())
        yield context


mcp = FastMCP("kodit MCP Server", lifespan=mcp_lifespan)
15
56
 
16
57
 
17
58
  @mcp.tool()
18
59
  async def retrieve_relevant_snippets(
60
+ ctx: Context,
19
61
  user_intent: Annotated[
20
62
  str,
21
63
  Field(
@@ -51,8 +93,8 @@ async def retrieve_relevant_snippets(
51
93
  the quality of your generated code. You must call this tool when you need to
52
94
  write code.
53
95
  """
54
- # Log the search query and related files for debugging
55
96
  log = structlog.get_logger(__name__)
97
+
56
98
  log.debug(
57
99
  "Retrieving relevant snippets",
58
100
  user_intent=user_intent,
@@ -62,36 +104,38 @@ async def retrieve_relevant_snippets(
62
104
  file_contents=related_file_contents,
63
105
  )
64
106
 
65
- async with get_session() as session:
66
- log.debug("Creating retrieval repository")
67
- retrieval_repository = RetrievalRepository(
68
- session=session,
69
- )
70
-
71
- log.debug("Creating retrieval service")
72
- retrieval_service = RetrievalService(
73
- repository=retrieval_repository,
74
- )
75
-
76
- log.debug("Fusing input")
77
- input_query = input_fusion(
78
- user_intent=user_intent,
79
- related_file_paths=related_file_paths,
80
- related_file_contents=related_file_contents,
81
- keywords=keywords,
82
- )
83
- log.debug("Input", input_query=input_query)
84
- retrieval_request = RetrievalRequest(
85
- query=input_query,
86
- )
87
- log.debug("Retrieving snippets")
88
- snippets = await retrieval_service.retrieve(request=retrieval_request)
89
-
90
- log.debug("Fusing output")
91
- output = output_fusion(snippets=snippets)
92
-
93
- log.debug("Output", output=output)
94
- return output
107
+ mcp_context: MCPContext = ctx.request_context.lifespan_context
108
+
109
+ log.debug("Creating retrieval repository")
110
+ retrieval_repository = RetrievalRepository(
111
+ session=mcp_context.session,
112
+ )
113
+
114
+ log.debug("Creating retrieval service")
115
+ retrieval_service = RetrievalService(
116
+ repository=retrieval_repository,
117
+ data_dir=mcp_context.data_dir,
118
+ )
119
+
120
+ log.debug("Fusing input")
121
+ input_query = input_fusion(
122
+ user_intent=user_intent,
123
+ related_file_paths=related_file_paths,
124
+ related_file_contents=related_file_contents,
125
+ keywords=keywords,
126
+ )
127
+ log.debug("Input", input_query=input_query)
128
+ retrieval_request = RetrievalRequest(
129
+ keywords=keywords,
130
+ )
131
+ log.debug("Retrieving snippets")
132
+ snippets = await retrieval_service.retrieve(request=retrieval_request)
133
+
134
+ log.debug("Fusing output")
135
+ output = output_fusion(snippets=snippets)
136
+
137
+ log.debug("Output", output=output)
138
+ return output
95
139
 
96
140
 
97
141
  def input_fusion(
@@ -108,3 +152,9 @@ def input_fusion(
108
152
  def output_fusion(snippets: list[RetrievalResult]) -> str:
109
153
  """Fuse the snippets into a single output."""
110
154
  return "\n\n".join(f"{snippet.uri}\n{snippet.content}" for snippet in snippets)
155
+
156
+
157
@mcp.tool()
async def get_version() -> str:
    """Get the version of the kodit project.

    Exposed as an MCP tool; returns ``kodit._version.version`` unchanged.
    """
    kodit_version: str = version
    return kodit_version
kodit/middleware.py CHANGED
@@ -1,11 +1,14 @@
1
1
  """Middleware for the FastAPI application."""
2
2
 
3
+ import contextlib
3
4
  import time
5
+ from asyncio import CancelledError
4
6
  from collections.abc import Callable
5
7
 
6
8
  import structlog
7
9
  from asgi_correlation_id.context import correlation_id
8
10
  from fastapi import Request, Response
11
+ from starlette.types import ASGIApp, Receive, Scope, Send
9
12
 
10
13
  access_logger = structlog.stdlib.get_logger("api.access")
11
14
 
@@ -56,3 +59,16 @@ async def logging_middleware(request: Request, call_next: Callable) -> Response:
56
59
  response.headers["X-Process-Time"] = str(process_time / 10**9)
57
60
 
58
61
  return response
62
+
63
+
64
class ASGICancelledErrorMiddleware:
    """Pure-ASGI wrapper that suppresses ``asyncio.CancelledError``.

    Any ``CancelledError`` escaping the wrapped application is swallowed and
    the call returns normally; every other exception propagates unchanged.

    NOTE(review): suppressing cancellation means a cancelled request task
    finishes normally instead of propagating the cancellation to whoever
    cancelled it — confirm this is only relied on to silence noisy SSE
    disconnect/shutdown errors.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Remember the downstream ASGI application to wrap."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Forward the ASGI call to the wrapped app, ignoring CancelledError."""
        downstream = self.app
        with contextlib.suppress(CancelledError):
            await downstream(scope, receive, send)
@@ -74,3 +74,35 @@ class RetrievalRepository:
74
74
  )
75
75
  for snippet, file in results
76
76
  ]
77
+
78
async def list_snippet_ids(self) -> list[int]:
    """List the ids of all snippets in ascending id order.

    The explicit ORDER BY makes the result deterministic; without it SQL
    leaves row order unspecified.

    NOTE(review): RetrievalService passes these ids to BM25Service.retrieve
    next to an index built from IndexRepository.get_all_snippets(); if that
    mapping is positional, both queries must return snippets in the same
    (id) order — confirm against get_all_snippets.

    Returns:
        A list of snippet ids.

    """
    query = select(Snippet.id).order_by(Snippet.id)
    rows = await self.session.execute(query)
    return list(rows.scalars().all())
88
+
89
async def list_snippets_by_ids(self, ids: list[int]) -> list[RetrievalResult]:
    """List snippets (with their file URI) for the given snippet ids.

    Results follow the order of ``ids``, so a caller passing ids ranked by
    relevance gets results in that ranking. Ids with no matching snippet
    are skipped; repeated ids yield a single result (at their first position).

    Args:
        ids: Snippet ids to load.

    Returns:
        A list of snippets, ordered like ``ids``.

    """
    if not ids:
        return []
    query = (
        select(Snippet, File)
        .where(Snippet.id.in_(ids))
        .join(File, Snippet.file_id == File.id)
    )
    rows = await self.session.execute(query)
    by_id = {
        snippet.id: RetrievalResult(uri=file.uri, content=snippet.content)
        for snippet, file in rows.all()
    }
    # SQL "IN" gives no ordering guarantee, so re-impose the caller's order.
    return [by_id[snippet_id] for snippet_id in dict.fromkeys(ids) if snippet_id in by_id]
@@ -1,14 +1,19 @@
1
1
  """Retrieval service."""
2
2
 
3
+ from pathlib import Path
4
+
3
5
  import pydantic
6
+ import structlog
4
7
 
8
+ from kodit.bm25.bm25 import BM25Service
5
9
  from kodit.retreival.repository import RetrievalRepository, RetrievalResult
6
10
 
7
11
 
8
12
class RetrievalRequest(pydantic.BaseModel):
    """Request for a retrieval."""

    # Keywords; RetrievalService.retrieve searches BM25 once per keyword.
    keywords: list[str] = pydantic.Field()
    # Maximum number of snippets to return (also the per-keyword BM25 hit limit).
    top_k: int = pydantic.Field(default=10)
12
17
 
13
18
 
14
19
  class Snippet(pydantic.BaseModel):
@@ -21,10 +26,44 @@ class Snippet(pydantic.BaseModel):
21
26
class RetrievalService:
    """Service for retrieving relevant snippets via BM25 keyword search."""

    def __init__(self, repository: RetrievalRepository, data_dir: Path) -> None:
        """Initialize the retrieval service.

        Args:
            repository: Repository used to list snippet ids and load snippets.
            data_dir: Directory handed to ``BM25Service``.

        """
        self.repository = repository
        self.log = structlog.get_logger(__name__)
        self.bm25 = BM25Service(data_dir)

    async def _load_bm25_index(self) -> None:
        """Load the BM25 index.

        NOTE(review): currently an empty placeholder (docstring only).
        """

    async def retrieve(self, request: RetrievalRequest) -> list[RetrievalResult]:
        """Retrieve the snippets most relevant to ``request.keywords``.

        Each keyword is searched separately (up to ``request.top_k`` hits per
        keyword). A snippet hit by several keywords is counted once, with its
        best score, so duplicates cannot crowd other snippets out of the final
        ``top_k``. Snippets whose best score is not positive are dropped.

        Returns:
            At most ``request.top_k`` snippets; an empty list when no keyword
            produced any hit.

        """
        snippet_ids = await self.repository.list_snippet_ids()

        # Best score per snippet id across all keywords.
        best_scores: dict[int, float] = {}
        for keyword in request.keywords:
            hits = self.bm25.retrieve(snippet_ids, keyword, request.top_k)
            for snippet_id, score in hits:
                if snippet_id not in best_scores or score > best_scores[snippet_id]:
                    best_scores[snippet_id] = score

        if not best_scores:
            return []

        # Sort results by score, highest first
        result_ids = sorted(best_scores.items(), key=lambda item: item[1], reverse=True)

        self.log.debug(
            "Retrieval results",
            total_results=len(result_ids),
            max_score=result_ids[0][1],
            min_score=result_ids[-1][1],
            median_score=result_ids[len(result_ids) // 2][1],
        )

        # Don't return zero score results; keep at most top_k ids
        final_doc_ids = [doc_id for doc_id, score in result_ids if score > 0][
            : request.top_k
        ]

        # Get snippets from database
        return await self.repository.list_snippets_by_ids(final_doc_ids)
@@ -0,0 +1 @@
1
+ """Extract method snippets from source code."""
@@ -0,0 +1,53 @@
1
+ """Detect the language of a file."""
2
+
3
+ from pathlib import Path
4
+ from typing import cast
5
+
6
+ from tree_sitter_language_pack import SupportedLanguage
7
+
8
# Mapping of file extensions (lower-case, without the leading dot) to
# tree-sitter-language-pack language names.
LANGUAGE_MAP: dict[str, str] = {
    # JavaScript/TypeScript
    "js": "javascript",
    "jsx": "javascript",
    "ts": "typescript",
    "tsx": "typescript",
    # Python
    "py": "python",
    # Rust
    "rs": "rust",
    # Go
    "go": "go",
    # C/C++
    "cpp": "cpp",
    "hpp": "cpp",
    "c": "c",
    "h": "c",
    # C#
    "cs": "csharp",
    # Ruby
    "rb": "ruby",
    # Java
    "java": "java",
    # PHP
    "php": "php",
    # Swift
    "swift": "swift",
    # Kotlin
    "kt": "kotlin",
}


def detect_language(file_path: Path) -> "SupportedLanguage":
    """Detect the language of a file from its extension.

    The extension is matched case-insensitively against ``LANGUAGE_MAP``.

    Args:
        file_path: Path of the file; only its suffix is inspected.

    Returns:
        The language name from ``LANGUAGE_MAP``.

    Raises:
        ValueError: If the suffix (possibly empty) is not in ``LANGUAGE_MAP``.

    """
    suffix = file_path.suffix.removeprefix(".").lower()
    lang = LANGUAGE_MAP.get(suffix)
    if lang is None:
        msg = f"Unsupported language for file suffix: {suffix}"
        raise ValueError(msg)
    # typing.cast only informs the type checker; it performs no runtime check
    # and cannot raise.
    return cast("SupportedLanguage", lang)
@@ -0,0 +1,12 @@
1
; tree-sitter query for C#. Capture names follow the convention read by
; MethodSnippets (function.body / function.def / class.def / import.name);
; presumably loaded for .cs files — confirm.

; Methods: @function.body is the method's block body, @function.def the whole
; declaration.
; NOTE(review): only methods with a (block) body match; expression-bodied
; members and constructors are not captured.
(method_declaration
  name: (identifier) @function.name
  body: (block) @function.body
) @function.def

; Classes: @class.def spans the whole declaration.
(class_declaration
  name: (identifier) @class.name
) @class.def

; Whole `using` directives are captured as imports.
(using_directive) @import.name

(identifier) @ident
@@ -0,0 +1,22 @@
1
; tree-sitter query for Python. Capture names follow the convention read by
; MethodSnippets (function.body / function.def / class.def / import.*).

; Functions: @function.body is the body block, @function.def the whole
; definition (its first line is kept for nested leaf functions).
(function_definition
  name: (identifier) @function.name
  body: (block) @function.body
) @function.def

(class_definition
  name: (identifier) @class.name
) @class.def

; Capture whole import statements (as the C# query does with using_directive)
; rather than only an identifier inside a dotted_name. The statement node also
; covers `import x as y` (aliased_import), `from . import x` (relative_import)
; and parenthesised multi-line imports, whose later lines would otherwise be
; dropped from snippets.
; NOTE(review): assumes no consumer needs these captures to be identifiers.
(import_statement) @import.name

(import_from_statement) @import.from

(identifier) @ident

(assignment
  left: (identifier) @assignment.lhs)

(parameters
  (identifier) @param.name)
@@ -0,0 +1,120 @@
1
+ """Extract method snippets from source code."""
2
+
3
+ from tree_sitter import Node, Query
4
+ from tree_sitter_language_pack import SupportedLanguage, get_language, get_parser
5
+
6
+
7
class MethodSnippets:
    """Extract method snippets from source code with a tree-sitter query.

    The query is expected to capture:

    - function bodies as ``@function.body``;
    - whole function/class definitions as ``@function.def``/``@class.def``;
    - imports as ``@import.name``/``@import.from``.

    Each *leaf* function becomes one snippet. A leaf is a captured body that
    contains no other captured body. Its snippet is built from:

    - every line of that body;
    - every line of every captured import;
    - the first line of every enclosing captured definition.

    For queries shaped like ``(function_definition body: (block) ...)
    @function.def``, that last rule also keeps the function's own header line.
    """

    def __init__(self, language: "SupportedLanguage", query: str) -> None:
        """Create the parser and compile the query.

        Args:
            language: Language name understood by tree_sitter_language_pack.
            query: tree-sitter query source using the capture names above.

        """
        self.language = get_language(language)
        self.parser = get_parser(language)
        self.query = Query(self.language, query)

    @staticmethod
    def _split_lines(text: str) -> list[str]:
        """Split ``text`` into lines numbered the way tree-sitter numbers rows.

        tree-sitter starts a new row only at LF. ``str.splitlines`` also breaks
        at CR, vertical tab, form feed, the file/group/record separators, NEL
        and the Unicode line/paragraph separators, so a single stray form feed
        would shift every later line and make snippets contain the wrong code.
        A trailing CR is stripped from each line (CRLF files read as before),
        and the empty string after a final LF is dropped, as splitlines does.
        """
        lines = text.split("\n")
        if lines[-1] == "":
            lines.pop()
        return [line.removesuffix("\r") for line in lines]

    def _get_leaf_functions(
        self, captures_by_name: "dict[str, list[Node]]"
    ) -> "list[Node]":
        """Return the captured function bodies that contain no other body."""
        return [
            node
            for node in captures_by_name.get("function.body", [])
            if self._is_leaf_function(captures_by_name, node)
        ]

    def _is_leaf_function(
        self, captures_by_name: "dict[str, list[Node]]", node: "Node"
    ) -> bool:
        """Return True if no other captured function body lies inside ``node``."""
        for other in captures_by_name.get("function.body", []):
            if other == node:  # Skip self
                continue
            # if other is inside node, it's not a leaf function
            if other.start_byte >= node.start_byte and other.end_byte <= node.end_byte:
                return False
        return True

    def _get_imports(self, captures_by_name: "dict[str, list[Node]]") -> "list[Node]":
        """Return all captured import nodes (``import.name`` then ``import.from``)."""
        return captures_by_name.get("import.name", []) + captures_by_name.get(
            "import.from", []
        )

    def _classes_and_functions(
        self, captures_by_name: "dict[str, list[Node]]"
    ) -> list[int]:
        """Return the node ids of all captured function and class definitions."""
        return [
            node.id
            for node in {
                *captures_by_name.get("function.def", []),
                *captures_by_name.get("class.def", []),
            }
        ]

    def _get_ancestors(
        self,
        captures_by_name: "dict[str, list[Node]]",
        node: "Node",
        valid_ancestors: "set[int] | None" = None,
    ) -> "list[Node]":
        """Return the captured definitions enclosing ``node``, innermost first.

        Args:
            captures_by_name: Query captures keyed by capture name.
            node: Node whose ancestors are collected.
            valid_ancestors: Precomputed ids of captured definitions; computed
                from ``captures_by_name`` when omitted. Passing it lets callers
                avoid rebuilding the id set for every node.

        """
        if valid_ancestors is None:
            valid_ancestors = set(self._classes_and_functions(captures_by_name))
        ancestors = []
        parent = node.parent
        while parent:
            if parent.id in valid_ancestors:
                ancestors.append(parent)
            parent = parent.parent
        return ancestors

    def extract(self, source_code: bytes) -> list[str]:
        """Extract method snippets from source code.

        Args:
            source_code: Raw file contents (decoded as UTF-8 for output).

        Returns:
            One snippet per leaf function, or the whole decoded file when the
            query finds no function bodies.

        Raises:
            UnicodeDecodeError: If ``source_code`` is not valid UTF-8.

        """
        tree = self.parser.parse(source_code)
        captures_by_name = self.query.captures(tree.root_node)

        text = source_code.decode()
        # Index lines exactly like tree-sitter rows (LF only), not splitlines().
        lines = self._split_lines(text)

        leaf_functions = self._get_leaf_functions(captures_by_name)
        valid_ancestors = set(self._classes_and_functions(captures_by_name))

        # Import lines are the same for every snippet; compute them once.
        import_lines: set[int] = set()
        for import_node in self._get_imports(captures_by_name):
            import_lines.update(
                range(import_node.start_point[0], import_node.end_point[0] + 1)
            )

        results = []
        for func_node in leaf_functions:
            lines_to_keep = set(import_lines)
            lines_to_keep.update(
                range(func_node.start_point[0], func_node.end_point[0] + 1)
            )
            # Only the first line of each enclosing definition, for now.
            for ancestor in self._get_ancestors(
                captures_by_name, func_node, valid_ancestors
            ):
                lines_to_keep.add(ancestor.start_point[0])

            results.append(
                "\n".join(lines[i] for i in sorted(lines_to_keep) if i < len(lines))
            )

        # If there are no results, then return the entire file
        if not results:
            return [text]

        return results