mirascope 2.0.0a1__py3-none-any.whl → 2.0.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205) hide show
  1. mirascope/__init__.py +2 -2
  2. mirascope/api/__init__.py +6 -0
  3. mirascope/api/_generated/README.md +207 -0
  4. mirascope/api/_generated/__init__.py +85 -0
  5. mirascope/api/_generated/client.py +155 -0
  6. mirascope/api/_generated/core/__init__.py +52 -0
  7. mirascope/api/_generated/core/api_error.py +23 -0
  8. mirascope/api/_generated/core/client_wrapper.py +58 -0
  9. mirascope/api/_generated/core/datetime_utils.py +30 -0
  10. mirascope/api/_generated/core/file.py +70 -0
  11. mirascope/api/_generated/core/force_multipart.py +16 -0
  12. mirascope/api/_generated/core/http_client.py +619 -0
  13. mirascope/api/_generated/core/http_response.py +55 -0
  14. mirascope/api/_generated/core/jsonable_encoder.py +102 -0
  15. mirascope/api/_generated/core/pydantic_utilities.py +310 -0
  16. mirascope/api/_generated/core/query_encoder.py +60 -0
  17. mirascope/api/_generated/core/remove_none_from_dict.py +11 -0
  18. mirascope/api/_generated/core/request_options.py +35 -0
  19. mirascope/api/_generated/core/serialization.py +282 -0
  20. mirascope/api/_generated/docs/__init__.py +4 -0
  21. mirascope/api/_generated/docs/client.py +95 -0
  22. mirascope/api/_generated/docs/raw_client.py +132 -0
  23. mirascope/api/_generated/environment.py +9 -0
  24. mirascope/api/_generated/errors/__init__.py +7 -0
  25. mirascope/api/_generated/errors/bad_request_error.py +15 -0
  26. mirascope/api/_generated/health/__init__.py +7 -0
  27. mirascope/api/_generated/health/client.py +96 -0
  28. mirascope/api/_generated/health/raw_client.py +129 -0
  29. mirascope/api/_generated/health/types/__init__.py +8 -0
  30. mirascope/api/_generated/health/types/health_check_response.py +24 -0
  31. mirascope/api/_generated/health/types/health_check_response_status.py +5 -0
  32. mirascope/api/_generated/reference.md +167 -0
  33. mirascope/api/_generated/traces/__init__.py +55 -0
  34. mirascope/api/_generated/traces/client.py +162 -0
  35. mirascope/api/_generated/traces/raw_client.py +168 -0
  36. mirascope/api/_generated/traces/types/__init__.py +95 -0
  37. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item.py +36 -0
  38. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource.py +31 -0
  39. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item.py +25 -0
  40. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value.py +54 -0
  41. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value_array_value.py +23 -0
  42. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value_kvlist_value.py +28 -0
  43. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value_kvlist_value_values_item.py +24 -0
  44. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item.py +35 -0
  45. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope.py +35 -0
  46. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item.py +27 -0
  47. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value.py +54 -0
  48. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value_array_value.py +23 -0
  49. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value_kvlist_value.py +28 -0
  50. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value_kvlist_value_values_item.py +24 -0
  51. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item.py +60 -0
  52. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item.py +29 -0
  53. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value.py +54 -0
  54. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value_array_value.py +23 -0
  55. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value_kvlist_value.py +28 -0
  56. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value_kvlist_value_values_item.py +24 -0
  57. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_status.py +24 -0
  58. mirascope/api/_generated/traces/types/traces_create_response.py +27 -0
  59. mirascope/api/_generated/traces/types/traces_create_response_partial_success.py +28 -0
  60. mirascope/api/_generated/types/__init__.py +21 -0
  61. mirascope/api/_generated/types/http_api_decode_error.py +31 -0
  62. mirascope/api/_generated/types/http_api_decode_error_tag.py +5 -0
  63. mirascope/api/_generated/types/issue.py +44 -0
  64. mirascope/api/_generated/types/issue_tag.py +17 -0
  65. mirascope/api/_generated/types/property_key.py +7 -0
  66. mirascope/api/_generated/types/property_key_tag.py +29 -0
  67. mirascope/api/_generated/types/property_key_tag_tag.py +5 -0
  68. mirascope/api/client.py +255 -0
  69. mirascope/api/settings.py +81 -0
  70. mirascope/llm/__init__.py +41 -11
  71. mirascope/llm/calls/calls.py +81 -57
  72. mirascope/llm/calls/decorator.py +121 -115
  73. mirascope/llm/content/__init__.py +3 -2
  74. mirascope/llm/context/_utils.py +19 -6
  75. mirascope/llm/exceptions.py +30 -16
  76. mirascope/llm/formatting/_utils.py +9 -5
  77. mirascope/llm/formatting/format.py +2 -2
  78. mirascope/llm/formatting/from_call_args.py +2 -2
  79. mirascope/llm/messages/message.py +13 -5
  80. mirascope/llm/models/__init__.py +2 -2
  81. mirascope/llm/models/models.py +189 -81
  82. mirascope/llm/prompts/__init__.py +13 -12
  83. mirascope/llm/prompts/_utils.py +27 -24
  84. mirascope/llm/prompts/decorator.py +133 -204
  85. mirascope/llm/prompts/prompts.py +424 -0
  86. mirascope/llm/prompts/protocols.py +25 -59
  87. mirascope/llm/providers/__init__.py +38 -0
  88. mirascope/llm/{clients → providers}/_missing_import_stubs.py +8 -6
  89. mirascope/llm/providers/anthropic/__init__.py +24 -0
  90. mirascope/llm/{clients → providers}/anthropic/_utils/decode.py +5 -4
  91. mirascope/llm/{clients → providers}/anthropic/_utils/encode.py +31 -10
  92. mirascope/llm/providers/anthropic/model_id.py +40 -0
  93. mirascope/llm/{clients/anthropic/clients.py → providers/anthropic/provider.py} +33 -418
  94. mirascope/llm/{clients → providers}/base/__init__.py +3 -3
  95. mirascope/llm/{clients → providers}/base/_utils.py +10 -7
  96. mirascope/llm/{clients/base/client.py → providers/base/base_provider.py} +255 -126
  97. mirascope/llm/providers/google/__init__.py +21 -0
  98. mirascope/llm/{clients → providers}/google/_utils/decode.py +6 -4
  99. mirascope/llm/{clients → providers}/google/_utils/encode.py +30 -24
  100. mirascope/llm/providers/google/model_id.py +28 -0
  101. mirascope/llm/providers/google/provider.py +438 -0
  102. mirascope/llm/providers/load_provider.py +48 -0
  103. mirascope/llm/providers/mlx/__init__.py +24 -0
  104. mirascope/llm/providers/mlx/_utils.py +107 -0
  105. mirascope/llm/providers/mlx/encoding/__init__.py +8 -0
  106. mirascope/llm/providers/mlx/encoding/base.py +69 -0
  107. mirascope/llm/providers/mlx/encoding/transformers.py +131 -0
  108. mirascope/llm/providers/mlx/mlx.py +237 -0
  109. mirascope/llm/providers/mlx/model_id.py +17 -0
  110. mirascope/llm/providers/mlx/provider.py +411 -0
  111. mirascope/llm/providers/model_id.py +16 -0
  112. mirascope/llm/providers/openai/__init__.py +6 -0
  113. mirascope/llm/providers/openai/completions/__init__.py +20 -0
  114. mirascope/llm/{clients/openai/responses → providers/openai/completions}/_utils/__init__.py +2 -0
  115. mirascope/llm/{clients → providers}/openai/completions/_utils/decode.py +5 -3
  116. mirascope/llm/{clients → providers}/openai/completions/_utils/encode.py +33 -23
  117. mirascope/llm/providers/openai/completions/provider.py +456 -0
  118. mirascope/llm/providers/openai/model_id.py +31 -0
  119. mirascope/llm/providers/openai/model_info.py +246 -0
  120. mirascope/llm/providers/openai/provider.py +386 -0
  121. mirascope/llm/providers/openai/responses/__init__.py +21 -0
  122. mirascope/llm/{clients → providers}/openai/responses/_utils/decode.py +5 -3
  123. mirascope/llm/{clients → providers}/openai/responses/_utils/encode.py +28 -17
  124. mirascope/llm/providers/openai/responses/provider.py +470 -0
  125. mirascope/llm/{clients → providers}/openai/shared/_utils.py +7 -3
  126. mirascope/llm/providers/provider_id.py +13 -0
  127. mirascope/llm/providers/provider_registry.py +167 -0
  128. mirascope/llm/responses/base_response.py +10 -5
  129. mirascope/llm/responses/base_stream_response.py +10 -5
  130. mirascope/llm/responses/response.py +24 -13
  131. mirascope/llm/responses/root_response.py +7 -12
  132. mirascope/llm/responses/stream_response.py +35 -23
  133. mirascope/llm/tools/__init__.py +9 -2
  134. mirascope/llm/tools/_utils.py +12 -3
  135. mirascope/llm/tools/decorator.py +10 -10
  136. mirascope/llm/tools/protocols.py +4 -4
  137. mirascope/llm/tools/tool_schema.py +44 -9
  138. mirascope/llm/tools/tools.py +12 -11
  139. mirascope/ops/__init__.py +156 -0
  140. mirascope/ops/_internal/__init__.py +5 -0
  141. mirascope/ops/_internal/closure.py +1118 -0
  142. mirascope/ops/_internal/configuration.py +126 -0
  143. mirascope/ops/_internal/context.py +76 -0
  144. mirascope/ops/_internal/exporters/__init__.py +26 -0
  145. mirascope/ops/_internal/exporters/exporters.py +342 -0
  146. mirascope/ops/_internal/exporters/processors.py +104 -0
  147. mirascope/ops/_internal/exporters/types.py +165 -0
  148. mirascope/ops/_internal/exporters/utils.py +29 -0
  149. mirascope/ops/_internal/instrumentation/__init__.py +8 -0
  150. mirascope/ops/_internal/instrumentation/llm/__init__.py +8 -0
  151. mirascope/ops/_internal/instrumentation/llm/encode.py +238 -0
  152. mirascope/ops/_internal/instrumentation/llm/gen_ai_types/__init__.py +38 -0
  153. mirascope/ops/_internal/instrumentation/llm/gen_ai_types/gen_ai_input_messages.py +31 -0
  154. mirascope/ops/_internal/instrumentation/llm/gen_ai_types/gen_ai_output_messages.py +38 -0
  155. mirascope/ops/_internal/instrumentation/llm/gen_ai_types/gen_ai_system_instructions.py +18 -0
  156. mirascope/ops/_internal/instrumentation/llm/gen_ai_types/shared.py +100 -0
  157. mirascope/ops/_internal/instrumentation/llm/llm.py +1288 -0
  158. mirascope/ops/_internal/propagation.py +198 -0
  159. mirascope/ops/_internal/protocols.py +51 -0
  160. mirascope/ops/_internal/session.py +139 -0
  161. mirascope/ops/_internal/spans.py +232 -0
  162. mirascope/ops/_internal/traced_calls.py +371 -0
  163. mirascope/ops/_internal/traced_functions.py +394 -0
  164. mirascope/ops/_internal/tracing.py +276 -0
  165. mirascope/ops/_internal/types.py +13 -0
  166. mirascope/ops/_internal/utils.py +75 -0
  167. mirascope/ops/_internal/versioned_calls.py +512 -0
  168. mirascope/ops/_internal/versioned_functions.py +346 -0
  169. mirascope/ops/_internal/versioning.py +303 -0
  170. mirascope/ops/exceptions.py +21 -0
  171. {mirascope-2.0.0a1.dist-info → mirascope-2.0.0a3.dist-info}/METADATA +77 -1
  172. mirascope-2.0.0a3.dist-info/RECORD +206 -0
  173. {mirascope-2.0.0a1.dist-info → mirascope-2.0.0a3.dist-info}/WHEEL +1 -1
  174. mirascope/graphs/__init__.py +0 -22
  175. mirascope/graphs/finite_state_machine.py +0 -625
  176. mirascope/llm/agents/__init__.py +0 -15
  177. mirascope/llm/agents/agent.py +0 -97
  178. mirascope/llm/agents/agent_template.py +0 -45
  179. mirascope/llm/agents/decorator.py +0 -176
  180. mirascope/llm/calls/base_call.py +0 -33
  181. mirascope/llm/clients/__init__.py +0 -34
  182. mirascope/llm/clients/anthropic/__init__.py +0 -25
  183. mirascope/llm/clients/anthropic/model_ids.py +0 -8
  184. mirascope/llm/clients/google/__init__.py +0 -20
  185. mirascope/llm/clients/google/clients.py +0 -853
  186. mirascope/llm/clients/google/model_ids.py +0 -15
  187. mirascope/llm/clients/openai/__init__.py +0 -25
  188. mirascope/llm/clients/openai/completions/__init__.py +0 -28
  189. mirascope/llm/clients/openai/completions/_utils/model_features.py +0 -81
  190. mirascope/llm/clients/openai/completions/clients.py +0 -833
  191. mirascope/llm/clients/openai/completions/model_ids.py +0 -8
  192. mirascope/llm/clients/openai/responses/__init__.py +0 -26
  193. mirascope/llm/clients/openai/responses/_utils/model_features.py +0 -87
  194. mirascope/llm/clients/openai/responses/clients.py +0 -832
  195. mirascope/llm/clients/openai/responses/model_ids.py +0 -8
  196. mirascope/llm/clients/providers.py +0 -175
  197. mirascope-2.0.0a1.dist-info/RECORD +0 -102
  198. /mirascope/llm/{clients → providers}/anthropic/_utils/__init__.py +0 -0
  199. /mirascope/llm/{clients → providers}/base/kwargs.py +0 -0
  200. /mirascope/llm/{clients → providers}/base/params.py +0 -0
  201. /mirascope/llm/{clients → providers}/google/_utils/__init__.py +0 -0
  202. /mirascope/llm/{clients → providers}/google/message.py +0 -0
  203. /mirascope/llm/{clients/openai/completions → providers/openai/responses}/_utils/__init__.py +0 -0
  204. /mirascope/llm/{clients → providers}/openai/shared/__init__.py +0 -0
  205. {mirascope-2.0.0a1.dist-info → mirascope-2.0.0a3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1118 @@
1
+ from __future__ import annotations
2
+
3
+ import ast
4
+ import gc
5
+ import hashlib
6
+ import importlib.metadata
7
+ import importlib.util
8
+ import inspect
9
+ import logging
10
+ import os
11
+ import site
12
+ import subprocess
13
+ import sys
14
+ import tempfile
15
+ from collections.abc import Callable
16
+ from dataclasses import dataclass
17
+ from functools import cached_property, lru_cache
18
+ from pathlib import Path
19
+ from textwrap import dedent
20
+ from types import ModuleType
21
+ from typing import Annotated, Any, TypedDict, TypeVar, cast, get_args, get_origin
22
+
23
+ import libcst as cst
24
+ import libcst.matchers as m
25
+ from libcst import MaybeSentinel
26
+ from packaging.markers import default_environment
27
+ from packaging.requirements import Requirement
28
+
29
+ from ..exceptions import ClosureComputationError
30
+
31
# Module-level logger.
logger = logging.getLogger(__name__)

# Bound type variable so helpers can return the same concrete compound-statement
# node type they receive (e.g. a ``cst.FunctionDef`` in, a ``cst.FunctionDef`` out).
_BaseCompoundStatementT = TypeVar(
    "_BaseCompoundStatementT",
    bound=cst.BaseCompoundStatement,
)
36
+
37
+
38
+ def _is_third_party(module: ModuleType, site_packages: set[str]) -> bool:
39
+ """Returns True if the module is a third-party or standard library module."""
40
+ module_file = getattr(module, "__file__", None)
41
+ return (
42
+ module.__name__ == "mirascope"
43
+ or module.__name__.startswith("mirascope.")
44
+ or module.__name__ in sys.stdlib_module_names
45
+ or module_file is None
46
+ or any(
47
+ str(Path(module_file).resolve()).startswith(site_pkg)
48
+ for site_pkg in site_packages
49
+ )
50
+ )
51
+
52
+
53
class _RemoveDocstringTransformer(cst.CSTTransformer):
    """CST transformer to remove docstrings from functions and classes.

    When ``exclude_fn_body`` is True, bodies are dropped instead. Every function body
    becomes ``...`` and every class body becomes ``pass``.
    """

    def __init__(self, exclude_fn_body: bool) -> None:
        super().__init__()
        self.exclude_fn_body = exclude_fn_body

    @staticmethod
    def _ellipsis_expr() -> cst.Expr:
        """Returns a bare ``...`` expression statement."""
        return cst.Expr(
            value=cst.Ellipsis(
                lpar=[],
                rpar=[],
            ),
            semicolon=MaybeSentinel.DEFAULT,
        )

    @staticmethod
    def _ellipsis_suite() -> cst.SimpleStatementSuite:
        """Returns a one-line ``...`` body rendered directly after the colon.

        A suite is used rather than a bare ``cst.Expr`` as the body because the suite
        is terminated by a newline. With a bare expression, a following sibling
        statement (e.g. the next method of a class) would be emitted on the same
        line, making the generated code invalid. The empty leading whitespace keeps
        the rendering as ``def f():...``.
        """
        return cst.SimpleStatementSuite(
            body=[_RemoveDocstringTransformer._ellipsis_expr()],
            leading_whitespace=cst.SimpleWhitespace(""),
        )

    @staticmethod
    def _remove_first_docstring(
        node: _BaseCompoundStatementT,
    ) -> _BaseCompoundStatementT:
        """Returns the node with the first docstring removed from its body.

        Both indented bodies and one-line bodies (``def f(): "doc"``) are handled.
        If nothing is left after removing the docstring, the body becomes ``...``
        so the definition stays syntactically valid.
        """
        body = node.body
        stmts = list(body.body)
        is_one_line = isinstance(body, cst.SimpleStatementSuite)
        # A one-line suite holds small statements directly, while an indented block
        # wraps each of them in a SimpleStatementLine.
        docstring_matcher = (
            m.Expr(value=m.SimpleString())
            if is_one_line
            else m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())])
        )
        if stmts and m.matches(stmts[0], docstring_matcher):
            stmts.pop(0)

        if stmts:
            return node.with_changes(body=body.with_changes(body=stmts))
        if is_one_line:
            return node.with_changes(
                body=body.with_changes(
                    body=[_RemoveDocstringTransformer._ellipsis_expr()]
                )
            )
        return node.with_changes(body=_RemoveDocstringTransformer._ellipsis_suite())

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Returns the function definition with docstring removed or body replaced with ellipsis."""
        if self.exclude_fn_body:
            return updated_node.with_changes(body=self._ellipsis_suite())

        return self._remove_first_docstring(updated_node)

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        """Returns the class definition with docstring removed or body replaced with pass."""
        if self.exclude_fn_body:
            body = updated_node.body
            # A one-line suite may only contain small statements, so ``pass`` goes in
            # directly; an indented block needs it wrapped in a statement line.
            replacement = (
                [cst.Pass()]
                if isinstance(body, cst.SimpleStatementSuite)
                else [cst.SimpleStatementLine([cst.Pass()])]
            )
            return updated_node.with_changes(body=body.with_changes(body=replacement))

        return self._remove_first_docstring(updated_node)
115
+
116
+
117
def _clean_source_code(
    fn: Callable[..., Any] | type,
    *,
    exclude_fn_body: bool = False,
) -> str:
    """Returns the source code of a function or class with docstrings optionally removed.

    Docstrings are stripped unless ``MIRASCOPE_VERSIONING_INCLUDE_DOCSTRINGS`` is set to
    ``1``, ``true`` or ``yes`` (case-insensitive). ``exclude_fn_body=True`` always takes
    effect regardless of that setting: function bodies become ``...`` and class bodies
    become ``pass``.

    Args:
        fn: The function or class whose source is read.
        exclude_fn_body: When True, keep only the definition headers.

    Returns:
        The dedented, processed source with trailing whitespace stripped.

    Raises:
        OSError: If ``inspect.getsource`` cannot retrieve the source.
        TypeError: If ``inspect.getsource`` rejects ``fn`` (e.g. a built-in).
    """
    source = dedent(inspect.getsource(fn))
    docstr_flag = os.getenv("MIRASCOPE_VERSIONING_INCLUDE_DOCSTRINGS", "false").lower()
    # Body exclusion is implemented by the transformer, so the raw-source shortcut is
    # only valid when the full body is wanted.
    if docstr_flag in ("1", "true", "yes") and not exclude_fn_body:
        return source.rstrip()

    transformer = _RemoveDocstringTransformer(exclude_fn_body=exclude_fn_body)
    return cst.parse_module(source).visit(transformer).code.rstrip()
136
+
137
+
138
@dataclass(frozen=True)
class _AttributePath:
    """Represents a parsed attribute access path like 'module.class.method'."""

    components: list[str]
    """Ordered list from base to final attribute (e.g., ['module', 'class', 'method'])."""

    @property
    def base_name(self) -> str:
        """Returns the base module or object name."""
        return self.components[0] if self.components else ""

    @property
    def last_attribute(self) -> str:
        """Returns the last attribute in the chain."""
        return self.components[-1] if self.components else ""

    @property
    def full_path(self) -> str:
        """Returns the complete dotted path."""
        return ".".join(self.components)

    @classmethod
    def from_ast_node(cls, node: ast.AST) -> _AttributePath | None:
        """Creates an `_AttributePath` from an AST node.

        Follows ``ast.Attribute`` links down to the base ``ast.Name``, stepping
        through ``ast.Call`` nodes along the way. For example, ``foo(x).bar`` yields
        ``['foo', 'bar']``.

        Args:
            node: An AST node (typically ast.Name, ast.Attribute, or ast.Call).

        Returns:
            `_AttributePath` with parsed components, or None when the chain does not
            bottom out in a plain name (e.g. ``"".join``, ``x[0].attr``, ``(a + b).c``).
            Such expressions have no resolvable base name.
        """
        components: list[str] = []
        current = node

        while isinstance(current, (ast.Attribute, ast.Call)):
            if isinstance(current, ast.Attribute):
                components.append(current.attr)
                current = current.value
            else:
                current = current.func

        if not isinstance(current, ast.Name):
            return None

        components.append(current.id)
        components.reverse()
        return cls(components=components)

    def __bool__(self) -> bool:
        """Returns True if components exist."""
        return bool(self.components)
191
+
192
+
193
class _NameCollector(ast.NodeVisitor):
    """AST visitor that collects all names used in a piece of code."""

    def __init__(self) -> None:
        self.used_names: list[str] = []

    def visit_Name(self, node: ast.Name) -> None:
        """Collects name nodes."""
        self.used_names.append(node.id)

    def visit_Call(self, node: ast.Call) -> None:
        """Collects function names from call nodes."""
        if isinstance(node.func, ast.Name):
            self.used_names.append(node.func.id)
        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        """Collects attribute access chains and the names nested inside them.

        The full dotted path (e.g. ``mod.cls.method``) and its base name are recorded.
        The chain itself is not re-visited, since that would add partial paths such as
        ``mod.cls``. The sub-expressions it skips over are visited, though:

        - arguments of calls inside the chain (``foo(bar()).baz`` records ``bar``),
        - a non-name base (``items[i].x`` records ``items`` and ``i``).
        """
        attr_path = _AttributePath.from_ast_node(node)
        if attr_path:
            self.used_names.append(attr_path.full_path)
            self.used_names.append(attr_path.base_name)

        current: ast.AST = node
        while isinstance(current, (ast.Attribute, ast.Call)):
            if isinstance(current, ast.Attribute):
                current = current.value
            else:
                for arg in current.args:
                    self.visit(arg)
                for keyword in current.keywords:
                    self.visit(keyword.value)
                current = current.func
        # A plain-name base is already covered by ``base_name`` above.
        if not isinstance(current, ast.Name):
            self.visit(current)
215
+
216
+
217
+ class _ImportCollector(ast.NodeVisitor):
218
+ """AST visitor that collects import statements based on used names."""
219
+
220
+ def __init__(self, used_names: list[str], site_packages: set[str]) -> None:
221
+ self.imports: set[str] = set()
222
+ self.user_defined_imports: set[str] = set()
223
+ self.used_names = used_names
224
+ self.site_packages = site_packages
225
+ self.alias_map: dict[str, str] = {}
226
+
227
+ def _is_used_import(self, import_name: str) -> bool:
228
+ """Returns whether an import with the given name is used in the code."""
229
+ return import_name in self.used_names or any(
230
+ u.startswith(f"{import_name}.") for u in self.used_names
231
+ )
232
+
233
+ def _add_import(
234
+ self, import_statement: str, is_third_party: bool, alias: str | None = None
235
+ ) -> None:
236
+ """Adds import statement to appropriate collection."""
237
+ if alias:
238
+ self.alias_map[alias] = import_statement
239
+
240
+ if is_third_party:
241
+ self.imports.add(import_statement)
242
+ else:
243
+ self.user_defined_imports.add(import_statement)
244
+
245
+ def visit_Import(self, node: ast.Import) -> None:
246
+ """Collects import statements."""
247
+ for name in node.names:
248
+ full_module_name = name.name
249
+ base_module_name = name.name.split(".")[0]
250
+ try:
251
+ module = __import__(base_module_name)
252
+ except ImportError:
253
+ module = None
254
+ import_name = name.asname or base_module_name
255
+
256
+ if not self._is_used_import(import_name):
257
+ continue
258
+
259
+ is_third_party = (
260
+ _is_third_party(module, self.site_packages) if module else False
261
+ )
262
+
263
+ if alias := name.asname:
264
+ import_statement = f"import {full_module_name} as {alias}"
265
+ else:
266
+ import_statement = f"import {full_module_name}"
267
+
268
+ self._add_import(import_statement, is_third_party, name.asname)
269
+
270
+ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
271
+ """Collects from-import statements."""
272
+ if not (module := node.module):
273
+ return
274
+
275
+ try:
276
+ is_third_party = _is_third_party(
277
+ __import__(module.split(".")[0]), self.site_packages
278
+ )
279
+ except ImportError:
280
+ module = "." * node.level + module
281
+ is_third_party = False
282
+
283
+ for name in node.names:
284
+ import_name = name.asname or name.name
285
+
286
+ if not self._is_used_import(import_name):
287
+ continue
288
+
289
+ if alias := name.asname:
290
+ import_statement = f"from {module} import {name.name} as {alias}"
291
+ else:
292
+ import_statement = f"from {module} import {name.name}"
293
+
294
+ self._add_import(import_statement, is_third_party, name.asname)
295
+
296
+
297
class _LocalAssignmentCollector(ast.NodeVisitor):
    """AST visitor that collects local variable assignments."""

    def __init__(self) -> None:
        self.assignments: set[str] = set()

    def _collect_target(self, target: ast.expr) -> None:
        """Adds every plain name bound by an assignment target.

        Recurses into tuple/list unpacking and starred targets. Attribute and
        subscript targets (``obj.x = ...``, ``d[k] = ...``) bind no new name and are
        ignored.
        """
        if isinstance(target, ast.Name):
            self.assignments.add(target.id)
        elif isinstance(target, (ast.Tuple, ast.List)):
            for element in target.elts:
                self._collect_target(element)
        elif isinstance(target, ast.Starred):
            self._collect_target(target.value)

    def visit_Assign(self, node: ast.Assign) -> None:
        """Collects variable names from assignment statements.

        Covers chained (``a = b = 1``) and unpacking (``a, *b = xs``) targets.
        """
        for target in node.targets:
            self._collect_target(target)
        self.generic_visit(node)

    def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
        """Collects variable names from annotated assignment statements."""
        if isinstance(node.target, ast.Name):
            self.assignments.add(node.target.id)
        self.generic_visit(node)
314
+
315
+
316
class _GlobalAssignmentCollector(ast.NodeVisitor):
    """AST visitor that collects global assignments used in function.

    Only assignments outside every function (``def`` and ``async def``) and class
    body are considered. The source text of each assignment that binds a name in
    ``used_names`` is appended once to ``assignments``.
    """

    def __init__(self, used_names: list[str], source: str) -> None:
        self.used_names = used_names
        self.source = source
        self.assignments: list[str] = []
        self.current_function: ast.FunctionDef | ast.AsyncFunctionDef | None = None
        self.current_class: ast.ClassDef | None = None

    @staticmethod
    def _bound_names(target: ast.expr) -> list[str]:
        """Returns the plain names bound by an assignment target (incl. unpacking)."""
        if isinstance(target, ast.Name):
            return [target.id]
        if isinstance(target, (ast.Tuple, ast.List)):
            return [
                name
                for element in target.elts
                for name in _GlobalAssignmentCollector._bound_names(element)
            ]
        if isinstance(target, ast.Starred):
            return _GlobalAssignmentCollector._bound_names(target.value)
        return []

    def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
        """Tracks function scope while visiting."""
        old_function = self.current_function
        self.current_function = node
        self.generic_visit(node)
        self.current_function = old_function

    # NodeVisitor dispatches on the node class name; without this alias, assignments
    # inside ``async def`` bodies would be mistaken for globals.
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Tracks class scope while visiting."""
        old_class = self.current_class
        self.current_class = node
        self.generic_visit(node)
        self.current_class = old_class

    def visit_Assign(self, node: ast.Assign) -> None:
        """Collects global assignment statements (once, even with several used targets)."""
        if self.current_function is not None or self.current_class is not None:
            return
        if any(
            name in self.used_names
            for target in node.targets
            for name in self._bound_names(target)
        ):
            code = ast.get_source_segment(self.source, node)
            if code is not None:
                self.assignments.append(code)

    def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
        """Collects global annotated assignment statements."""
        if self.current_function is not None or self.current_class is not None:
            return
        if isinstance(node.target, ast.Name) and node.target.id in self.used_names:
            code = ast.get_source_segment(self.source, node)
            if code is not None:
                self.assignments.append(code)
358
+
359
+
360
def _collect_parameter_names(tree: ast.Module) -> set[str]:
    """Returns set of all parameter names from functions in the AST.

    Includes positional-only, regular, ``*args``, keyword-only and ``**kwargs``
    parameters of both ``def`` and ``async def`` functions at any nesting depth.
    """
    params: set[str] = set()
    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        arguments = node.args
        for arg in (*arguments.posonlyargs, *arguments.args, *arguments.kwonlyargs):
            params.add(arg.arg)
        for special in (arguments.vararg, arguments.kwarg):
            if special is not None:
                params.add(special.arg)
    return params
374
+
375
+
376
def _extract_types(annotation: Any) -> set[type]:  # noqa: ANN401
    """Recursively collects the classes referenced by a type annotation.

    Generic aliases (``list[Foo]``, ``Foo | None``, ...) are unpacked argument by
    argument. For ``Annotated[T, ...]`` only ``T`` is inspected and the metadata is
    ignored. Classes for which ``_is_stdlib_or_builtin`` returns True are skipped.
    """
    origin = get_origin(annotation)
    if origin is None:
        # Leaf: keep it only if it is a class outside the stdlib/builtins.
        if isinstance(annotation, type) and not _is_stdlib_or_builtin(annotation):
            return {annotation}
        return set()

    type_args = get_args(annotation)
    if origin is Annotated:
        return _extract_types(type_args[0])

    collected: set[type] = set()
    for type_arg in type_args:
        collected.update(_extract_types(type_arg))
    return collected
391
+
392
+
393
+ def _is_stdlib_or_builtin(obj: Any) -> bool: # noqa: ANN401
394
+ """Returns True if object is from standard library or builtins."""
395
+ if not hasattr(obj, "__module__"):
396
+ return False
397
+
398
+ module_name = obj.__module__
399
+ if not module_name:
400
+ return False
401
+
402
+ return (
403
+ module_name in sys.stdlib_module_names
404
+ or module_name.startswith("collections.")
405
+ or module_name.startswith("typing.")
406
+ or module_name in {"abc", "typing", "builtins", "_collections_abc"}
407
+ )
408
+
409
+
410
+ class _DefinitionCollector(ast.NodeVisitor):
411
+ """AST visitor that collects function and class definitions referenced in code."""
412
+
413
+ def __init__(
414
+ self, module: ModuleType, used_names: list[str], site_packages: set[str]
415
+ ) -> None:
416
+ self.module = module
417
+ self.used_names = used_names
418
+ self.site_packages = site_packages
419
+ self.definitions_to_include: list[Callable[..., Any] | type] = []
420
+ self.definitions_to_analyze: list[Callable[..., Any] | type] = []
421
+ self.imports: set[str] = set()
422
+
423
+ def visit_Name(self, node: ast.Name) -> None:
424
+ """Collects named references to callable definitions."""
425
+ if node.id in self.used_names:
426
+ candidate = getattr(self.module, node.id, None)
427
+ if callable(candidate) and not _is_stdlib_or_builtin(candidate):
428
+ self.definitions_to_include.append(candidate)
429
+ self.generic_visit(node)
430
+
431
+ def _process_decorator(self, decorator_node: ast.AST) -> None:
432
+ """Processes a decorator node to extract its definition."""
433
+ if isinstance(decorator_node, ast.Name):
434
+ if decorator_func := getattr(self.module, decorator_node.id, None):
435
+ self.definitions_to_include.append(decorator_func)
436
+ elif isinstance(decorator_node, ast.Attribute):
437
+ attr_path = _AttributePath.from_ast_node(decorator_node)
438
+ if attr_path:
439
+ base_module = getattr(self.module, attr_path.base_name, None)
440
+ if (
441
+ attr_path.full_path in self.used_names
442
+ and base_module
443
+ and (
444
+ definition := getattr(
445
+ base_module, attr_path.last_attribute, None
446
+ )
447
+ )
448
+ ):
449
+ self.definitions_to_include.append(definition)
450
+
451
+ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
452
+ """Collects function definitions and their decorators."""
453
+ for decorator_node in node.decorator_list:
454
+ self._process_decorator(decorator_node)
455
+
456
+ nested_func = getattr(self.module, node.name, None)
457
+ if nested_func:
458
+ self.definitions_to_analyze.append(nested_func)
459
+
460
+ self.generic_visit(node)
461
+
462
+ def visit_ClassDef(self, node: ast.ClassDef) -> None:
463
+ """Collects class definitions and their type annotations."""
464
+ if class_def := getattr(self.module, node.name, None):
465
+ self.definitions_to_analyze.append(class_def)
466
+ if hasattr(class_def, "__annotations__"):
467
+ for ann in class_def.__annotations__.values():
468
+ for candidate in _extract_types(ann):
469
+ if (
470
+ isinstance(candidate, type)
471
+ and candidate.__module__ == class_def.__module__
472
+ and candidate.__module__ != "builtins"
473
+ ) and candidate not in self.definitions_to_include:
474
+ self.definitions_to_include.append(candidate)
475
+ for item in node.body:
476
+ if isinstance(item, ast.FunctionDef) and (
477
+ definition := getattr(class_def, item.name, None)
478
+ ):
479
+ self.definitions_to_analyze.append(definition)
480
+ self.generic_visit(node)
481
+
482
def _process_name_or_attribute(self, node: ast.AST) -> None:
    """Records the user-level object that a name or dotted attribute refers to.

    * ``ast.Name``: the identifier is looked up on ``self.module``.
    * ``ast.Attribute``: only handled when its full dotted path is listed in
      ``self.used_names`` and its root name is not a third-party module; the
      path components are then resolved one by one starting at
      ``self.module``.

    A resolved object that is truthy, has ``__name__``, is not stdlib/builtin
    (and, for attributes, is not itself a module) is appended to
    ``self.definitions_to_include``. Other node types are ignored.

    Args:
        node: Expression node to inspect (a call target, argument, or keyword
            value).
    """
    if isinstance(node, ast.Name):
        found = getattr(self.module, node.id, None)
        if (
            found
            and hasattr(found, "__name__")
            and not _is_stdlib_or_builtin(found)
        ):
            self.definitions_to_include.append(found)
        return

    if not isinstance(node, ast.Attribute):
        return

    path = _AttributePath.from_ast_node(node)
    if not path or path.full_path not in self.used_names:
        return

    # Attribute chains rooted at a third-party module are not user code.
    root = getattr(self.module, path.base_name, None)
    if (
        root
        and isinstance(root, ModuleType)
        and _is_third_party(root, self.site_packages)
    ):
        return

    # Follow the dotted path from the module, stopping at the first gap.
    resolved = self.module
    for part in path.components:
        resolved = getattr(resolved, part, None)
        if resolved is None:
            break

    if not resolved:
        return
    if not hasattr(resolved, "__name__"):
        return
    if _is_stdlib_or_builtin(resolved):
        return
    if isinstance(resolved, ModuleType):
        return
    self.definitions_to_include.append(resolved)
517
+
518
def visit_Call(self, node: ast.Call) -> None:
    """Inspects a call expression for referenced definitions.

    The callee, each positional argument, and each keyword argument's value
    are handed (in that order) to ``_process_name_or_attribute``; the call's
    child nodes are visited afterwards.

    Args:
        node: The call expression being visited.
    """
    operands = [node.func, *node.args, *(kw.value for kw in node.keywords)]
    for operand in operands:
        self._process_name_or_attribute(operand)
    self.generic_visit(node)
526
+
527
+
528
class _QualifiedNameRewriter(cst.CSTTransformer):
    """CST transformer that rewrites qualified names to simple names for local definitions.

    Two rewrites are applied:

    * an attribute chain whose last component is in ``local_names``
      (``pkg.mod.helper``) is collapsed to that bare name (``helper``);
    * every ``Name`` node whose value is an alias introduced by an aliased
      ``from ... import X as Y`` statement in ``user_defined_imports`` is
      renamed from ``Y`` back to ``X``.
    """

    def __init__(self, local_names: set[str], user_defined_imports: set[str]) -> None:
        """Initializes the rewriter.

        Args:
            local_names: Names of the top-level definitions emitted in the
                closure; references to them are simplified.
            user_defined_imports: Import statement strings; only aliased
                ``from`` imports contribute to the de-aliasing map.
        """
        super().__init__()
        self.local_names: set[str] = local_names
        self.alias_mapping: dict[str, str] = self._build_alias_mapping(
            user_defined_imports
        )

    @staticmethod
    def _build_alias_mapping(import_stmts: set[str]) -> dict[str, str]:
        """Returns ``{alias: original_name}`` for aliased ``from`` imports.

        Statements are parsed with :mod:`ast` rather than split on spaces, so
        ``from m import a, b as c`` correctly maps ``c`` to ``b`` (instead of
        the invalid ``"a,"``), and every alias of a multi-name import is
        recorded. Plain ``import m as n`` statements and strings that do not
        parse are ignored.
        """
        mapping: dict[str, str] = {}
        for import_stmt in import_stmts:
            if not import_stmt.startswith("from "):
                continue
            try:
                tree = ast.parse(import_stmt)
            except SyntaxError:
                continue
            for stmt in tree.body:
                if not isinstance(stmt, ast.ImportFrom):
                    continue
                for alias in stmt.names:
                    if alias.asname is not None:
                        mapping[alias.asname] = alias.name
        return mapping

    def _gather_attribute_chain(self, node: cst.Attribute | cst.Name) -> list[str]:
        """Returns the chain of attribute names from an attribute node.

        ``a.b.c`` yields ``["a", "b", "c"]``. When the innermost value is not
        a ``Name`` (e.g. a call), only the attribute parts are returned.
        """
        names = []
        current = node

        while isinstance(current, cst.Attribute):
            names.append(current.attr.value)
            current = current.value

        if isinstance(current, cst.Name):
            names.append(current.value)

        return list(reversed(names))

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.Name | cst.Attribute:
        """Returns simplified name if attribute refers to local definition."""
        names = self._gather_attribute_chain(updated_node)
        if names and names[-1] in self.local_names:
            return cst.Name(value=names[-1])

        return updated_node

    def leave_Name(self, original_node: cst.Name, updated_node: cst.Name) -> cst.Name:
        """Returns de-aliased name if it was imported with an alias.

        Surrounding parentheses of the original name are preserved.
        """
        if updated_node.value in self.alias_mapping:
            return cst.Name(
                value=self.alias_mapping[updated_node.value],
                lpar=updated_node.lpar,
                rpar=updated_node.rpar,
            )
        return updated_node
577
+
578
+
579
+ def _get_class_from_unbound_method(method: Callable[..., Any]) -> type | None:
580
+ """Returns the class that contains the given unbound method."""
581
+ qualname = method.__qualname__
582
+ parts = qualname.split(".")
583
+ class_qualname = ".".join(parts[:-1])
584
+
585
+ for obj in gc.get_objects():
586
+ try:
587
+ object_is_type = isinstance(obj, type)
588
+ except Exception:
589
+ continue
590
+ if object_is_type and getattr(obj, "__qualname__", None) == class_qualname:
591
+ return obj
592
+ return None
593
+
594
+
595
def _clean_source_from_string(source: str, exclude_fn_body: bool = False) -> str:
    """Normalizes a Python source snippet through the docstring-removal pass.

    The text is dedented, parsed with libcst, transformed by
    ``_RemoveDocstringTransformer`` (``exclude_fn_body`` is forwarded to it),
    and regenerated.

    Args:
        source: Python source text, possibly indented.
        exclude_fn_body: Passed through to ``_RemoveDocstringTransformer``.

    Returns:
        The regenerated source with trailing whitespace stripped.
    """
    parsed = cst.parse_module(dedent(source))
    stripper = _RemoveDocstringTransformer(exclude_fn_body=exclude_fn_body)
    cleaned = parsed.visit(stripper)
    return cleaned.code.rstrip()
602
+
603
+
604
def _get_class_source_from_method(method: Callable[..., Any]) -> str:
    """Get the cleaned source code of the class that owns ``method``.

    Args:
        method: The method whose containing class source is wanted.

    Returns:
        The class source as produced by ``_clean_source_from_string``.

    Raises:
        ValueError: If no owning class can be located for ``method``.
    """
    owner = _get_class_from_unbound_method(method)
    if owner is not None:
        return _clean_source_from_string(inspect.getsource(owner))
    raise ValueError("Cannot determine class from method via gc")
621
+
622
+
623
+ class _DependencyCollector:
624
+ """Collects dependencies, imports, and source code for function closure."""
625
+
626
+ def __init__(self) -> None:
627
+ self.imports: set[str] = set()
628
+ self.fn_internal_imports: set[str] = set()
629
+ self.user_defined_imports: set[str] = set()
630
+ self.assignments: list[str] = []
631
+ self.source_code: list[str] = []
632
+ self.visited_functions: set[str] = set()
633
+ self.site_packages: set[str] = {
634
+ str(Path(p).resolve()) for p in site.getsitepackages()
635
+ }
636
+ self._last_import_collector: _ImportCollector | None = None
637
+
638
+ def _collect_assignments_and_imports(
639
+ self,
640
+ fn_tree: ast.Module,
641
+ module_tree: ast.Module,
642
+ used_names: list[str],
643
+ module_source: str,
644
+ ) -> None:
645
+ """Collects global assignments and their required imports."""
646
+ local_assignment_collector = _LocalAssignmentCollector()
647
+ local_assignment_collector.visit(fn_tree)
648
+ local_assignments = local_assignment_collector.assignments
649
+
650
+ parameter_names = _collect_parameter_names(fn_tree)
651
+
652
+ global_assignment_collector = _GlobalAssignmentCollector(
653
+ used_names, module_source
654
+ )
655
+ global_assignment_collector.visit(module_tree)
656
+
657
+ for global_assignment in global_assignment_collector.assignments:
658
+ tree = ast.parse(global_assignment)
659
+ stmt = cast(ast.Assign | ast.AnnAssign, tree.body[0])
660
+ if isinstance(stmt, ast.Assign):
661
+ var_name = cast(ast.Name, stmt.targets[0]).id
662
+ else:
663
+ var_name = cast(ast.Name, stmt.target).id
664
+
665
+ if var_name in parameter_names:
666
+ continue
667
+
668
+ if var_name not in used_names or var_name in local_assignments:
669
+ continue
670
+
671
+ self.assignments.append(global_assignment)
672
+
673
+ name_collector = _NameCollector()
674
+ name_collector.visit(tree)
675
+ import_collector = _ImportCollector(
676
+ name_collector.used_names, self.site_packages
677
+ )
678
+ import_collector.visit(module_tree)
679
+ self.imports.update(import_collector.imports)
680
+ self.user_defined_imports.update(import_collector.user_defined_imports)
681
+
682
+ @staticmethod
683
+ def _extract_definition(
684
+ definition: Callable[..., Any] | type | property,
685
+ ) -> Callable[..., Any] | type | None:
686
+ """Returns the actual definition from decorators and properties."""
687
+ if isinstance(definition, property):
688
+ return definition.fget
689
+
690
+ if isinstance(definition, cached_property) or (
691
+ hasattr(definition, "func")
692
+ and getattr(definition, "__name__", None) is None
693
+ ):
694
+ # For Python 3.13+
695
+ return definition.func # pyright: ignore[reportFunctionMemberAccess] # pragma: no cover
696
+
697
+ return definition
698
+
699
+ def _get_source_code(self, definition: Callable[..., Any] | type) -> str | None:
700
+ """Returns the source code for a definition."""
701
+ if definition.__qualname__ in self.visited_functions:
702
+ return None
703
+ self.visited_functions.add(definition.__qualname__)
704
+
705
+ if "." in definition.__qualname__ and inspect.getmodule(definition) is not None:
706
+ try:
707
+ return _get_class_source_from_method(definition)
708
+ except ValueError:
709
+ return _clean_source_code(definition)
710
+
711
+ return _clean_source_code(definition)
712
+
713
+ def _process_imports(
714
+ self,
715
+ module_tree: ast.Module,
716
+ used_names: list[str],
717
+ source: str,
718
+ ) -> None:
719
+ """Process and categorize imports."""
720
+ import_collector = _ImportCollector(used_names, self.site_packages)
721
+ import_collector.visit(module_tree)
722
+
723
+ new_imports = {
724
+ import_stmt
725
+ for import_stmt in import_collector.imports
726
+ if import_stmt not in source
727
+ }
728
+
729
+ self.imports.update(new_imports)
730
+ self.fn_internal_imports.update(import_collector.imports - new_imports)
731
+ self.user_defined_imports.update(import_collector.user_defined_imports)
732
+
733
+ def _process_definitions(
734
+ self, fn_tree: ast.Module, module: ModuleType, used_names: list[str]
735
+ ) -> None:
736
+ """Process nested definitions and dependencies."""
737
+ definition_collector = _DefinitionCollector(
738
+ module, used_names, self.site_packages
739
+ )
740
+ definition_collector.visit(fn_tree)
741
+
742
+ for definition in definition_collector.definitions_to_include:
743
+ self._collect_imports_and_source_code(definition, True)
744
+
745
+ for definition in definition_collector.definitions_to_analyze:
746
+ self._collect_imports_and_source_code(definition, False)
747
+
748
+ def _collect_imports_and_source_code(
749
+ self,
750
+ definition: Callable[..., Any] | type | property,
751
+ include_source: bool,
752
+ ) -> None:
753
+ """Collects imports and optionally source code for a definition."""
754
+ try:
755
+ if _is_stdlib_or_builtin(definition) or isinstance(definition, ModuleType):
756
+ return
757
+
758
+ # property(fget=None) is not reachable via current code paths but kept as guard
759
+ if (
760
+ isinstance(definition, property) and definition.fget is None
761
+ ): # pragma: no cover
762
+ return
763
+
764
+ extracted_definition = _DependencyCollector._extract_definition(definition)
765
+ # Same guard as above; kept for defensive coding
766
+ if extracted_definition is None: # pragma: no cover
767
+ return
768
+
769
+ source = self._get_source_code(extracted_definition)
770
+ if source is None:
771
+ return
772
+
773
+ module = inspect.getmodule(extracted_definition)
774
+ if not module or _is_third_party(module, self.site_packages):
775
+ return
776
+
777
+ module_source = inspect.getsource(module)
778
+ module_tree = ast.parse(module_source)
779
+ fn_tree = ast.parse(source)
780
+
781
+ name_collector = _NameCollector()
782
+ name_collector.visit(fn_tree)
783
+ used_names = list(dict.fromkeys(name_collector.used_names))
784
+
785
+ self._process_imports(module_tree, used_names, source)
786
+
787
+ if include_source:
788
+ for import_stmt in self.user_defined_imports:
789
+ source = source.replace(import_stmt, "")
790
+ self.source_code.insert(0, source)
791
+
792
+ self._collect_assignments_and_imports(
793
+ fn_tree, module_tree, used_names, module_source
794
+ )
795
+
796
+ self._process_definitions(fn_tree, module, used_names)
797
+
798
+ except (OSError, TypeError) as e:
799
+ logger.debug(f"Failed to collect imports for {definition}: {e}")
800
+
801
+ @staticmethod
802
+ def _collect_required_dependencies(imports: set[str]) -> dict[str, dict[str, Any]]:
803
+ """Returns package dependencies required by the import statements."""
804
+ stdlib_modules = set(sys.stdlib_module_names)
805
+ installed_packages = {
806
+ dist.name: dist for dist in importlib.metadata.distributions()
807
+ }
808
+ import_to_dist = importlib.metadata.packages_distributions()
809
+
810
+ dependencies = {}
811
+ imported_dists = {}
812
+ imported_roots = set()
813
+
814
+ for import_stmt in imports:
815
+ parts = import_stmt.strip().split()
816
+ root_module = parts[1].split(".")[0]
817
+ if root_module in stdlib_modules:
818
+ continue
819
+
820
+ imported_roots.add(root_module)
821
+
822
+ dist_names = import_to_dist.get(root_module, [root_module])
823
+ for dist_name in dist_names:
824
+ if dist_name not in installed_packages:
825
+ continue
826
+
827
+ dist = installed_packages[dist_name]
828
+ imported_dists.setdefault(dist_name, dist)
829
+ if dist_name not in dependencies:
830
+ dependencies[dist_name] = {
831
+ "version": dist.version,
832
+ "extras": None,
833
+ }
834
+ break
835
+
836
+ if not imported_dists:
837
+ return {}
838
+
839
+ dist_to_modules = {}
840
+ for module_name, dist_names in import_to_dist.items():
841
+ for dist_name in dist_names:
842
+ dist_to_modules.setdefault(dist_name, set()).add(module_name)
843
+
844
+ base_env = cast(dict[str, str], default_environment().copy())
845
+ base_env["extra"] = ""
846
+ extra_env_cache = {}
847
+
848
+ def _env_for_extra(extra: str) -> dict[str, str]:
849
+ if extra not in extra_env_cache:
850
+ env = cast(dict[str, str], default_environment().copy())
851
+ env["extra"] = extra
852
+ extra_env_cache[extra] = env
853
+ return extra_env_cache[extra]
854
+
855
+ base_requirements = {}
856
+ extra_requirements = {}
857
+
858
+ for dist_name, dist in imported_dists.items():
859
+ base_reqs = set()
860
+ extras_map = {
861
+ extra: set() for extra in dist.metadata.get_all("Provides-Extra", [])
862
+ }
863
+ requirements = dist.requires or []
864
+ for requirement_str in requirements:
865
+ req = Requirement(requirement_str)
866
+ marker = req.marker
867
+ if marker is None or marker.evaluate(base_env):
868
+ base_reqs.add(req.name)
869
+ continue
870
+
871
+ for extra in extras_map:
872
+ if marker.evaluate(_env_for_extra(extra)):
873
+ extras_map[extra].add(req.name)
874
+
875
+ base_requirements[dist_name] = base_reqs
876
+ extra_requirements[dist_name] = extras_map
877
+
878
+ provided_requirements = set()
879
+ for reqs in base_requirements.values():
880
+ provided_requirements.update(reqs)
881
+ provided_requirements.update(imported_dists.keys())
882
+
883
+ for dist_name in sorted(imported_dists):
884
+ extras_to_keep = []
885
+ apply_usage_gate = not dist_name.startswith("mirascope")
886
+ for extra, deps in extra_requirements[dist_name].items():
887
+ if not deps:
888
+ continue
889
+
890
+ if apply_usage_gate and not any(
891
+ dist_to_modules.get(dep, set()) & imported_roots for dep in deps
892
+ ):
893
+ continue
894
+
895
+ missing = [dep for dep in deps if dep not in provided_requirements]
896
+ if missing:
897
+ extras_to_keep.append(extra)
898
+ provided_requirements.update(deps)
899
+
900
+ dependencies[dist_name]["extras"] = extras_to_keep or None
901
+
902
+ return dependencies
903
+
904
+ @classmethod
905
+ def _map_child_to_parent(
906
+ cls,
907
+ child_to_parent: dict[ast.AST, ast.AST | None],
908
+ node: ast.AST,
909
+ parent: ast.AST | None = None,
910
+ ) -> None:
911
+ """Maps each AST node to its parent node."""
912
+ child_to_parent[node] = parent
913
+ for _, value in ast.iter_fields(node):
914
+ if isinstance(value, list):
915
+ for child in value:
916
+ if isinstance(child, ast.AST):
917
+ cls._map_child_to_parent(child_to_parent, child, node)
918
+ elif isinstance(value, ast.AST):
919
+ cls._map_child_to_parent(child_to_parent, value, node)
920
+
921
+ def _extract_local_names(self, code_blocks: list[str]) -> set[str]:
922
+ """Extracts names of locally defined functions and classes."""
923
+ local_names = set()
924
+
925
+ for code in code_blocks:
926
+ tree = ast.parse(code)
927
+ child_to_parent = {}
928
+ self._map_child_to_parent(child_to_parent, tree)
929
+
930
+ for node in ast.walk(tree):
931
+ if isinstance(node, ast.FunctionDef | ast.ClassDef):
932
+ parent = child_to_parent.get(node)
933
+ if isinstance(parent, ast.Module):
934
+ local_names.add(node.name)
935
+
936
+ return local_names
937
+
938
+ @staticmethod
939
+ def _rewrite_code_blocks(
940
+ code_blocks: list[str], rewriter: _QualifiedNameRewriter
941
+ ) -> list[str]:
942
+ """Rewrites code blocks with simplified names."""
943
+ rewritten = []
944
+ for code in code_blocks:
945
+ tree = cst.parse_module(code)
946
+ new_tree = tree.visit(rewriter)
947
+ rewritten.append(new_tree.code)
948
+ return rewritten
949
+
950
+ def collect(
951
+ self, fn: Callable[..., Any]
952
+ ) -> tuple[list[str], list[str], list[str], dict[str, dict[str, Any]]]:
953
+ """Collects all components needed for function closure.
954
+
955
+ Args:
956
+ fn: The function to collect closure information for.
957
+
958
+ Returns:
959
+ A tuple containing:
960
+ - List of import statements
961
+ - List of assignment statements
962
+ - List of source code blocks
963
+ - Dictionary of required dependencies
964
+ """
965
+ self._collect_imports_and_source_code(fn, True)
966
+
967
+ local_names = self._extract_local_names(self.source_code + self.assignments)
968
+ rewriter = _QualifiedNameRewriter(local_names, self.user_defined_imports)
969
+
970
+ assignments = self._rewrite_code_blocks(self.assignments, rewriter)
971
+ source_code = self._rewrite_code_blocks(self.source_code, rewriter)
972
+
973
+ required_dependencies = _DependencyCollector._collect_required_dependencies(
974
+ self.imports | self.fn_internal_imports
975
+ )
976
+
977
+ return (
978
+ list(self.imports),
979
+ list(dict.fromkeys(assignments)),
980
+ source_code,
981
+ required_dependencies,
982
+ )
983
+
984
+
985
def _run_ruff(code: str) -> str:
    """Returns formatted code using ruff formatter.

    The code is written to a temporary ``.py`` file, import-sorted with
    ``ruff check --select=I001 --fix`` and then formatted with
    ``ruff format --line-length=88`` (both ``--isolated`` so no project
    configuration is used). The temporary file is always removed, including
    when writing it fails.

    The file is written and read as UTF-8 explicitly: Python source defaults
    to UTF-8 and that is how ruff reads it, whereas the platform default
    encoding (e.g. cp1252 on Windows) could fail on, or mangle, non-ASCII
    characters in ``code``.

    Args:
        code: Python source code to format.

    Returns:
        The import-sorted, formatted source code.

    Raises:
        subprocess.CalledProcessError: If a ruff invocation exits with an
            unexpected status (``ruff check`` may exit 0 or 1).
        FileNotFoundError: If the ``ruff`` executable cannot be found.
    """
    tmp_file = tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False, encoding="utf-8"
    )
    tmp_path = Path(tmp_file.name)

    try:
        # Close before invoking ruff so the content is flushed to disk.
        with tmp_file:
            tmp_file.write(code)

        proc = subprocess.run(
            ["ruff", "check", "--isolated", "--select=I001", "--fix", str(tmp_path)],
            capture_output=True,
            text=True,
        )

        # Exit code 1 only means violations remained; anything else is an error.
        if proc.returncode not in (0, 1):
            raise subprocess.CalledProcessError(
                proc.returncode, proc.args, output=proc.stdout, stderr=proc.stderr
            )

        subprocess.run(
            ["ruff", "format", "--isolated", "--line-length=88", str(tmp_path)],
            check=True,
            capture_output=True,
            text=True,
        )
        return tmp_path.read_text(encoding="utf-8")
    finally:
        tmp_path.unlink(missing_ok=True)
1013
+
1014
+
1015
def get_qualified_name(fn: Callable[..., Any]) -> str:
    """Return the simplified qualified name of a function.

    For functions defined inside another function, everything after the last
    ``<locals>.`` marker is returned (which may itself be dotted, e.g.
    ``Inner.method``). Otherwise the last non-empty dot-separated segment of
    ``__qualname__`` is returned, or the qualname unchanged if it has none.

    Args:
        fn: The function to get the qualified name from.

    Returns:
        The simplified qualified name of the function.
    """
    qualname = fn.__qualname__
    marker = "<locals>."
    if marker in qualname:
        return qualname.rsplit(marker, 1)[-1]

    segments = list(filter(None, qualname.split(".")))
    if not segments:
        return qualname
    return segments[-1]
1033
+
1034
+
1035
+ class DependencyInfo(TypedDict):
1036
+ """Represents the dependency information for a closure."""
1037
+
1038
+ version: str
1039
+ """The version of the dependency."""
1040
+
1041
+ extras: list[str] | None
1042
+ """The extras required for the dependency."""
1043
+
1044
+
1045
@dataclass(frozen=True, kw_only=True)
class Closure:
    """Represents the closure of a function."""

    name: str
    """The name of the function."""

    signature: str
    """The signature of the function."""

    docstring: str | None
    """The docstring of the function."""

    code: str
    """The code of the function."""

    hash: str
    """The hash of the closure."""

    signature_hash: str
    """The hash of the function signature (determines major version X)."""

    dependencies: dict[str, DependencyInfo]
    """The dependencies of the closure."""

    @classmethod
    @lru_cache(maxsize=128)
    def from_fn(cls, fn: Callable[..., Any]) -> Closure:
        """Create a closure from a function.

        The collected imports, referenced module-level assignments and
        user-defined source are joined into one module, formatted with ruff
        and hashed with SHA-256. The function's signature is formatted and
        hashed separately. Results are memoized per ``(cls, fn)`` by
        ``lru_cache`` (at most 128 entries).

        Args:
            fn: The function to analyze

        Returns:
            Closure: The closure of the function.

        Raises:
            ClosureComputationError: if the closure cannot be computed properly,
                including when formatting either the closure code or the
                signature with ruff fails.
        """
        collector = _DependencyCollector()
        imports, assignments, source_code, dependencies = collector.collect(fn)
        code = "{imports}\n\n{assignments}\n\n{source_code}".format(
            imports="\n".join(imports),
            assignments="\n".join(assignments),
            source_code="\n\n".join(source_code),
        )
        qualified_name = get_qualified_name(fn)

        def _format(source: str) -> str:
            # Both ruff invocations report failure the same way, so callers
            # only have to handle ClosureComputationError; the original error
            # is kept as the cause for debugging.
            try:
                return _run_ruff(source)
            except (subprocess.CalledProcessError, FileNotFoundError, OSError) as e:
                raise ClosureComputationError(qualified_name=qualified_name) from e

        formatted_code = _format(code)
        hash_value = hashlib.sha256(formatted_code.encode("utf-8")).hexdigest()

        signature = _format(_clean_source_code(fn, exclude_fn_body=True)).strip()
        signature_hash = hashlib.sha256(signature.encode("utf-8")).hexdigest()

        return cls(
            name=qualified_name,
            docstring=inspect.getdoc(fn),
            signature=signature,
            code=formatted_code,
            hash=hash_value,
            signature_hash=signature_hash,
            dependencies={
                name: DependencyInfo(
                    version=dep_info["version"],
                    extras=dep_info.get("extras"),
                )
                for name, dep_info in dependencies.items()
            },
        )
1116
+
1117
+
1118
# Names exported by ``from <module> import *``.
__all__ = [
    "Closure",
    "DependencyInfo",
]