mirascope 2.0.0a1__py3-none-any.whl → 2.0.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mirascope/__init__.py +2 -2
- mirascope/api/__init__.py +6 -0
- mirascope/api/_generated/README.md +207 -0
- mirascope/api/_generated/__init__.py +85 -0
- mirascope/api/_generated/client.py +155 -0
- mirascope/api/_generated/core/__init__.py +52 -0
- mirascope/api/_generated/core/api_error.py +23 -0
- mirascope/api/_generated/core/client_wrapper.py +58 -0
- mirascope/api/_generated/core/datetime_utils.py +30 -0
- mirascope/api/_generated/core/file.py +70 -0
- mirascope/api/_generated/core/force_multipart.py +16 -0
- mirascope/api/_generated/core/http_client.py +619 -0
- mirascope/api/_generated/core/http_response.py +55 -0
- mirascope/api/_generated/core/jsonable_encoder.py +102 -0
- mirascope/api/_generated/core/pydantic_utilities.py +310 -0
- mirascope/api/_generated/core/query_encoder.py +60 -0
- mirascope/api/_generated/core/remove_none_from_dict.py +11 -0
- mirascope/api/_generated/core/request_options.py +35 -0
- mirascope/api/_generated/core/serialization.py +282 -0
- mirascope/api/_generated/docs/__init__.py +4 -0
- mirascope/api/_generated/docs/client.py +95 -0
- mirascope/api/_generated/docs/raw_client.py +132 -0
- mirascope/api/_generated/environment.py +9 -0
- mirascope/api/_generated/errors/__init__.py +7 -0
- mirascope/api/_generated/errors/bad_request_error.py +15 -0
- mirascope/api/_generated/health/__init__.py +7 -0
- mirascope/api/_generated/health/client.py +96 -0
- mirascope/api/_generated/health/raw_client.py +129 -0
- mirascope/api/_generated/health/types/__init__.py +8 -0
- mirascope/api/_generated/health/types/health_check_response.py +24 -0
- mirascope/api/_generated/health/types/health_check_response_status.py +5 -0
- mirascope/api/_generated/reference.md +167 -0
- mirascope/api/_generated/traces/__init__.py +55 -0
- mirascope/api/_generated/traces/client.py +162 -0
- mirascope/api/_generated/traces/raw_client.py +168 -0
- mirascope/api/_generated/traces/types/__init__.py +95 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item.py +36 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource.py +31 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item.py +25 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value.py +54 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value_array_value.py +23 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value_kvlist_value.py +28 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value_kvlist_value_values_item.py +24 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item.py +35 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope.py +35 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item.py +27 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value.py +54 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value_array_value.py +23 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value_kvlist_value.py +28 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value_kvlist_value_values_item.py +24 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item.py +60 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item.py +29 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value.py +54 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value_array_value.py +23 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value_kvlist_value.py +28 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value_kvlist_value_values_item.py +24 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_status.py +24 -0
- mirascope/api/_generated/traces/types/traces_create_response.py +27 -0
- mirascope/api/_generated/traces/types/traces_create_response_partial_success.py +28 -0
- mirascope/api/_generated/types/__init__.py +21 -0
- mirascope/api/_generated/types/http_api_decode_error.py +31 -0
- mirascope/api/_generated/types/http_api_decode_error_tag.py +5 -0
- mirascope/api/_generated/types/issue.py +44 -0
- mirascope/api/_generated/types/issue_tag.py +17 -0
- mirascope/api/_generated/types/property_key.py +7 -0
- mirascope/api/_generated/types/property_key_tag.py +29 -0
- mirascope/api/_generated/types/property_key_tag_tag.py +5 -0
- mirascope/api/client.py +255 -0
- mirascope/api/settings.py +81 -0
- mirascope/llm/__init__.py +41 -11
- mirascope/llm/calls/calls.py +81 -57
- mirascope/llm/calls/decorator.py +121 -115
- mirascope/llm/content/__init__.py +3 -2
- mirascope/llm/context/_utils.py +19 -6
- mirascope/llm/exceptions.py +30 -16
- mirascope/llm/formatting/_utils.py +9 -5
- mirascope/llm/formatting/format.py +2 -2
- mirascope/llm/formatting/from_call_args.py +2 -2
- mirascope/llm/messages/message.py +13 -5
- mirascope/llm/models/__init__.py +2 -2
- mirascope/llm/models/models.py +189 -81
- mirascope/llm/prompts/__init__.py +13 -12
- mirascope/llm/prompts/_utils.py +27 -24
- mirascope/llm/prompts/decorator.py +133 -204
- mirascope/llm/prompts/prompts.py +424 -0
- mirascope/llm/prompts/protocols.py +25 -59
- mirascope/llm/providers/__init__.py +38 -0
- mirascope/llm/{clients → providers}/_missing_import_stubs.py +8 -6
- mirascope/llm/providers/anthropic/__init__.py +24 -0
- mirascope/llm/{clients → providers}/anthropic/_utils/decode.py +5 -4
- mirascope/llm/{clients → providers}/anthropic/_utils/encode.py +31 -10
- mirascope/llm/providers/anthropic/model_id.py +40 -0
- mirascope/llm/{clients/anthropic/clients.py → providers/anthropic/provider.py} +33 -418
- mirascope/llm/{clients → providers}/base/__init__.py +3 -3
- mirascope/llm/{clients → providers}/base/_utils.py +10 -7
- mirascope/llm/{clients/base/client.py → providers/base/base_provider.py} +255 -126
- mirascope/llm/providers/google/__init__.py +21 -0
- mirascope/llm/{clients → providers}/google/_utils/decode.py +6 -4
- mirascope/llm/{clients → providers}/google/_utils/encode.py +30 -24
- mirascope/llm/providers/google/model_id.py +28 -0
- mirascope/llm/providers/google/provider.py +438 -0
- mirascope/llm/providers/load_provider.py +48 -0
- mirascope/llm/providers/mlx/__init__.py +24 -0
- mirascope/llm/providers/mlx/_utils.py +107 -0
- mirascope/llm/providers/mlx/encoding/__init__.py +8 -0
- mirascope/llm/providers/mlx/encoding/base.py +69 -0
- mirascope/llm/providers/mlx/encoding/transformers.py +131 -0
- mirascope/llm/providers/mlx/mlx.py +237 -0
- mirascope/llm/providers/mlx/model_id.py +17 -0
- mirascope/llm/providers/mlx/provider.py +411 -0
- mirascope/llm/providers/model_id.py +16 -0
- mirascope/llm/providers/openai/__init__.py +6 -0
- mirascope/llm/providers/openai/completions/__init__.py +20 -0
- mirascope/llm/{clients/openai/responses → providers/openai/completions}/_utils/__init__.py +2 -0
- mirascope/llm/{clients → providers}/openai/completions/_utils/decode.py +5 -3
- mirascope/llm/{clients → providers}/openai/completions/_utils/encode.py +33 -23
- mirascope/llm/providers/openai/completions/provider.py +456 -0
- mirascope/llm/providers/openai/model_id.py +31 -0
- mirascope/llm/providers/openai/model_info.py +246 -0
- mirascope/llm/providers/openai/provider.py +386 -0
- mirascope/llm/providers/openai/responses/__init__.py +21 -0
- mirascope/llm/{clients → providers}/openai/responses/_utils/decode.py +5 -3
- mirascope/llm/{clients → providers}/openai/responses/_utils/encode.py +28 -17
- mirascope/llm/providers/openai/responses/provider.py +470 -0
- mirascope/llm/{clients → providers}/openai/shared/_utils.py +7 -3
- mirascope/llm/providers/provider_id.py +13 -0
- mirascope/llm/providers/provider_registry.py +167 -0
- mirascope/llm/responses/base_response.py +10 -5
- mirascope/llm/responses/base_stream_response.py +10 -5
- mirascope/llm/responses/response.py +24 -13
- mirascope/llm/responses/root_response.py +7 -12
- mirascope/llm/responses/stream_response.py +35 -23
- mirascope/llm/tools/__init__.py +9 -2
- mirascope/llm/tools/_utils.py +12 -3
- mirascope/llm/tools/decorator.py +10 -10
- mirascope/llm/tools/protocols.py +4 -4
- mirascope/llm/tools/tool_schema.py +44 -9
- mirascope/llm/tools/tools.py +12 -11
- mirascope/ops/__init__.py +156 -0
- mirascope/ops/_internal/__init__.py +5 -0
- mirascope/ops/_internal/closure.py +1118 -0
- mirascope/ops/_internal/configuration.py +126 -0
- mirascope/ops/_internal/context.py +76 -0
- mirascope/ops/_internal/exporters/__init__.py +26 -0
- mirascope/ops/_internal/exporters/exporters.py +342 -0
- mirascope/ops/_internal/exporters/processors.py +104 -0
- mirascope/ops/_internal/exporters/types.py +165 -0
- mirascope/ops/_internal/exporters/utils.py +29 -0
- mirascope/ops/_internal/instrumentation/__init__.py +8 -0
- mirascope/ops/_internal/instrumentation/llm/__init__.py +8 -0
- mirascope/ops/_internal/instrumentation/llm/encode.py +238 -0
- mirascope/ops/_internal/instrumentation/llm/gen_ai_types/__init__.py +38 -0
- mirascope/ops/_internal/instrumentation/llm/gen_ai_types/gen_ai_input_messages.py +31 -0
- mirascope/ops/_internal/instrumentation/llm/gen_ai_types/gen_ai_output_messages.py +38 -0
- mirascope/ops/_internal/instrumentation/llm/gen_ai_types/gen_ai_system_instructions.py +18 -0
- mirascope/ops/_internal/instrumentation/llm/gen_ai_types/shared.py +100 -0
- mirascope/ops/_internal/instrumentation/llm/llm.py +1288 -0
- mirascope/ops/_internal/propagation.py +198 -0
- mirascope/ops/_internal/protocols.py +51 -0
- mirascope/ops/_internal/session.py +139 -0
- mirascope/ops/_internal/spans.py +232 -0
- mirascope/ops/_internal/traced_calls.py +371 -0
- mirascope/ops/_internal/traced_functions.py +394 -0
- mirascope/ops/_internal/tracing.py +276 -0
- mirascope/ops/_internal/types.py +13 -0
- mirascope/ops/_internal/utils.py +75 -0
- mirascope/ops/_internal/versioned_calls.py +512 -0
- mirascope/ops/_internal/versioned_functions.py +346 -0
- mirascope/ops/_internal/versioning.py +303 -0
- mirascope/ops/exceptions.py +21 -0
- {mirascope-2.0.0a1.dist-info → mirascope-2.0.0a3.dist-info}/METADATA +77 -1
- mirascope-2.0.0a3.dist-info/RECORD +206 -0
- {mirascope-2.0.0a1.dist-info → mirascope-2.0.0a3.dist-info}/WHEEL +1 -1
- mirascope/graphs/__init__.py +0 -22
- mirascope/graphs/finite_state_machine.py +0 -625
- mirascope/llm/agents/__init__.py +0 -15
- mirascope/llm/agents/agent.py +0 -97
- mirascope/llm/agents/agent_template.py +0 -45
- mirascope/llm/agents/decorator.py +0 -176
- mirascope/llm/calls/base_call.py +0 -33
- mirascope/llm/clients/__init__.py +0 -34
- mirascope/llm/clients/anthropic/__init__.py +0 -25
- mirascope/llm/clients/anthropic/model_ids.py +0 -8
- mirascope/llm/clients/google/__init__.py +0 -20
- mirascope/llm/clients/google/clients.py +0 -853
- mirascope/llm/clients/google/model_ids.py +0 -15
- mirascope/llm/clients/openai/__init__.py +0 -25
- mirascope/llm/clients/openai/completions/__init__.py +0 -28
- mirascope/llm/clients/openai/completions/_utils/model_features.py +0 -81
- mirascope/llm/clients/openai/completions/clients.py +0 -833
- mirascope/llm/clients/openai/completions/model_ids.py +0 -8
- mirascope/llm/clients/openai/responses/__init__.py +0 -26
- mirascope/llm/clients/openai/responses/_utils/model_features.py +0 -87
- mirascope/llm/clients/openai/responses/clients.py +0 -832
- mirascope/llm/clients/openai/responses/model_ids.py +0 -8
- mirascope/llm/clients/providers.py +0 -175
- mirascope-2.0.0a1.dist-info/RECORD +0 -102
- /mirascope/llm/{clients → providers}/anthropic/_utils/__init__.py +0 -0
- /mirascope/llm/{clients → providers}/base/kwargs.py +0 -0
- /mirascope/llm/{clients → providers}/base/params.py +0 -0
- /mirascope/llm/{clients → providers}/google/_utils/__init__.py +0 -0
- /mirascope/llm/{clients → providers}/google/message.py +0 -0
- /mirascope/llm/{clients/openai/completions → providers/openai/responses}/_utils/__init__.py +0 -0
- /mirascope/llm/{clients → providers}/openai/shared/__init__.py +0 -0
- {mirascope-2.0.0a1.dist-info → mirascope-2.0.0a3.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1118 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import ast
|
|
4
|
+
import gc
|
|
5
|
+
import hashlib
|
|
6
|
+
import importlib.metadata
|
|
7
|
+
import importlib.util
|
|
8
|
+
import inspect
|
|
9
|
+
import logging
|
|
10
|
+
import os
|
|
11
|
+
import site
|
|
12
|
+
import subprocess
|
|
13
|
+
import sys
|
|
14
|
+
import tempfile
|
|
15
|
+
from collections.abc import Callable
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
from functools import cached_property, lru_cache
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from textwrap import dedent
|
|
20
|
+
from types import ModuleType
|
|
21
|
+
from typing import Annotated, Any, TypedDict, TypeVar, cast, get_args, get_origin
|
|
22
|
+
|
|
23
|
+
import libcst as cst
|
|
24
|
+
import libcst.matchers as m
|
|
25
|
+
from libcst import MaybeSentinel
|
|
26
|
+
from packaging.markers import default_environment
|
|
27
|
+
from packaging.requirements import Requirement
|
|
28
|
+
|
|
29
|
+
from ..exceptions import ClosureComputationError
|
|
30
|
+
|
|
31
|
+
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)

# Type variable bounded by libcst compound statements, so a helper that takes
# a node of some compound-statement type can declare it returns that same type.
_BaseCompoundStatementT = TypeVar(
    "_BaseCompoundStatementT",
    bound=cst.BaseCompoundStatement,
)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _is_third_party(module: ModuleType, site_packages: set[str]) -> bool:
|
|
39
|
+
"""Returns True if the module is a third-party or standard library module."""
|
|
40
|
+
module_file = getattr(module, "__file__", None)
|
|
41
|
+
return (
|
|
42
|
+
module.__name__ == "mirascope"
|
|
43
|
+
or module.__name__.startswith("mirascope.")
|
|
44
|
+
or module.__name__ in sys.stdlib_module_names
|
|
45
|
+
or module_file is None
|
|
46
|
+
or any(
|
|
47
|
+
str(Path(module_file).resolve()).startswith(site_pkg)
|
|
48
|
+
for site_pkg in site_packages
|
|
49
|
+
)
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class _RemoveDocstringTransformer(cst.CSTTransformer):
    """CST transformer that strips docstrings from functions and classes.

    When `exclude_fn_body` is set, whole bodies are dropped instead: function
    bodies are replaced with `...` and class bodies with `pass`.
    """

    def __init__(self, exclude_fn_body: bool) -> None:
        super().__init__()
        self.exclude_fn_body = exclude_fn_body

    @staticmethod
    def _ellipsis_expr() -> cst.Expr:
        """Returns a bare `...` expression used as a placeholder body."""
        return cst.Expr(
            value=cst.Ellipsis(lpar=[], rpar=[]),
            semicolon=MaybeSentinel.DEFAULT,
        )

    @staticmethod
    def _remove_first_docstring(
        node: _BaseCompoundStatementT,
    ) -> _BaseCompoundStatementT:
        """Returns `node` without a leading docstring statement in its body.

        A docstring is recognised as a first statement line that consists of a
        single simple-string expression. If removing it leaves the body empty,
        a `...` placeholder is put in its place.
        """
        suite = node.body
        remaining = list(suite.body)

        if remaining and m.matches(
            remaining[0],
            m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())]),
        ):
            del remaining[0]

        if not remaining:
            placeholder = _RemoveDocstringTransformer._ellipsis_expr()
            if m.matches(suite, m.IndentedBlock()):
                # An indented block that only held a docstring is replaced
                # wholesale by the bare placeholder expression.
                return node.with_changes(body=placeholder)
            remaining = [placeholder]

        return node.with_changes(body=suite.with_changes(body=remaining))

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Returns the function with its docstring removed, or its body replaced by `...`."""
        if not self.exclude_fn_body:
            return self._remove_first_docstring(updated_node)
        return updated_node.with_changes(body=self._ellipsis_expr())

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        """Returns the class with its docstring removed, or its body replaced by `pass`."""
        if not self.exclude_fn_body:
            return self._remove_first_docstring(updated_node)

        only_pass = cst.SimpleStatementLine([cst.Pass()])
        return updated_node.with_changes(
            body=updated_node.body.with_changes(body=[only_pass])
        )
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _clean_source_code(
|
|
118
|
+
fn: Callable[..., Any] | type,
|
|
119
|
+
*,
|
|
120
|
+
exclude_fn_body: bool = False,
|
|
121
|
+
) -> str:
|
|
122
|
+
"""Returns the source code of a function or class with docstrings optionally removed."""
|
|
123
|
+
source = dedent(inspect.getsource(fn))
|
|
124
|
+
docstr_flag = os.getenv("MIRASCOPE_VERSIONING_INCLUDE_DOCSTRINGS", "false").lower()
|
|
125
|
+
if docstr_flag in ("1", "true", "yes"):
|
|
126
|
+
return source.rstrip()
|
|
127
|
+
module = cst.parse_module(source)
|
|
128
|
+
|
|
129
|
+
transformer = _RemoveDocstringTransformer(exclude_fn_body=exclude_fn_body)
|
|
130
|
+
new_module = module.visit(transformer)
|
|
131
|
+
|
|
132
|
+
code = new_module.code
|
|
133
|
+
code = code.rstrip()
|
|
134
|
+
|
|
135
|
+
return code
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
@dataclass(frozen=True)
|
|
139
|
+
class _AttributePath:
|
|
140
|
+
"""Represents a parsed attribute access path like 'module.class.method'."""
|
|
141
|
+
|
|
142
|
+
components: list[str]
|
|
143
|
+
"""Ordered list from base to final attribute (e.g., ['module', 'class', 'method'])."""
|
|
144
|
+
|
|
145
|
+
@property
|
|
146
|
+
def base_name(self) -> str:
|
|
147
|
+
"""Returns the base module or object name."""
|
|
148
|
+
return self.components[0] if self.components else ""
|
|
149
|
+
|
|
150
|
+
@property
|
|
151
|
+
def last_attribute(self) -> str:
|
|
152
|
+
"""Returns the last attribute in the chain."""
|
|
153
|
+
return self.components[-1] if self.components else ""
|
|
154
|
+
|
|
155
|
+
@property
|
|
156
|
+
def full_path(self) -> str:
|
|
157
|
+
"""Returns the complete dotted path."""
|
|
158
|
+
return ".".join(self.components)
|
|
159
|
+
|
|
160
|
+
@classmethod
|
|
161
|
+
def from_ast_node(cls, node: ast.AST) -> _AttributePath | None:
|
|
162
|
+
"""Creates an `_AttributePath` from an AST node.
|
|
163
|
+
|
|
164
|
+
Args:
|
|
165
|
+
node: An AST node (typically ast.Name, ast.Attribute, or ast.Call).
|
|
166
|
+
|
|
167
|
+
Returns:
|
|
168
|
+
`_AttributePath` with parsed components, or None if parsing fails.
|
|
169
|
+
"""
|
|
170
|
+
components = []
|
|
171
|
+
current = node
|
|
172
|
+
|
|
173
|
+
while True:
|
|
174
|
+
if isinstance(current, ast.Attribute):
|
|
175
|
+
components.append(current.attr)
|
|
176
|
+
current = current.value
|
|
177
|
+
elif isinstance(current, ast.Call):
|
|
178
|
+
current = current.func
|
|
179
|
+
elif isinstance(current, ast.Name):
|
|
180
|
+
components.append(current.id)
|
|
181
|
+
break
|
|
182
|
+
else:
|
|
183
|
+
break
|
|
184
|
+
|
|
185
|
+
components.reverse()
|
|
186
|
+
return cls(components=components)
|
|
187
|
+
|
|
188
|
+
def __bool__(self) -> bool:
|
|
189
|
+
"""Returns True if components exist."""
|
|
190
|
+
return bool(self.components)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class _NameCollector(ast.NodeVisitor):
|
|
194
|
+
"""AST visitor that collects all names used in a piece of code."""
|
|
195
|
+
|
|
196
|
+
def __init__(self) -> None:
|
|
197
|
+
self.used_names: list[str] = []
|
|
198
|
+
|
|
199
|
+
def visit_Name(self, node: ast.Name) -> None:
|
|
200
|
+
"""Collects name nodes."""
|
|
201
|
+
self.used_names.append(node.id)
|
|
202
|
+
|
|
203
|
+
def visit_Call(self, node: ast.Call) -> None:
|
|
204
|
+
"""Collects function names from call nodes."""
|
|
205
|
+
if isinstance(node.func, ast.Name):
|
|
206
|
+
self.used_names.append(node.func.id)
|
|
207
|
+
self.generic_visit(node)
|
|
208
|
+
|
|
209
|
+
def visit_Attribute(self, node: ast.Attribute) -> None:
|
|
210
|
+
"""Collects attribute access chains."""
|
|
211
|
+
attr_path = _AttributePath.from_ast_node(node)
|
|
212
|
+
if attr_path:
|
|
213
|
+
self.used_names.append(attr_path.full_path)
|
|
214
|
+
self.used_names.append(attr_path.base_name)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
class _ImportCollector(ast.NodeVisitor):
|
|
218
|
+
"""AST visitor that collects import statements based on used names."""
|
|
219
|
+
|
|
220
|
+
def __init__(self, used_names: list[str], site_packages: set[str]) -> None:
|
|
221
|
+
self.imports: set[str] = set()
|
|
222
|
+
self.user_defined_imports: set[str] = set()
|
|
223
|
+
self.used_names = used_names
|
|
224
|
+
self.site_packages = site_packages
|
|
225
|
+
self.alias_map: dict[str, str] = {}
|
|
226
|
+
|
|
227
|
+
def _is_used_import(self, import_name: str) -> bool:
|
|
228
|
+
"""Returns whether an import with the given name is used in the code."""
|
|
229
|
+
return import_name in self.used_names or any(
|
|
230
|
+
u.startswith(f"{import_name}.") for u in self.used_names
|
|
231
|
+
)
|
|
232
|
+
|
|
233
|
+
def _add_import(
|
|
234
|
+
self, import_statement: str, is_third_party: bool, alias: str | None = None
|
|
235
|
+
) -> None:
|
|
236
|
+
"""Adds import statement to appropriate collection."""
|
|
237
|
+
if alias:
|
|
238
|
+
self.alias_map[alias] = import_statement
|
|
239
|
+
|
|
240
|
+
if is_third_party:
|
|
241
|
+
self.imports.add(import_statement)
|
|
242
|
+
else:
|
|
243
|
+
self.user_defined_imports.add(import_statement)
|
|
244
|
+
|
|
245
|
+
def visit_Import(self, node: ast.Import) -> None:
|
|
246
|
+
"""Collects import statements."""
|
|
247
|
+
for name in node.names:
|
|
248
|
+
full_module_name = name.name
|
|
249
|
+
base_module_name = name.name.split(".")[0]
|
|
250
|
+
try:
|
|
251
|
+
module = __import__(base_module_name)
|
|
252
|
+
except ImportError:
|
|
253
|
+
module = None
|
|
254
|
+
import_name = name.asname or base_module_name
|
|
255
|
+
|
|
256
|
+
if not self._is_used_import(import_name):
|
|
257
|
+
continue
|
|
258
|
+
|
|
259
|
+
is_third_party = (
|
|
260
|
+
_is_third_party(module, self.site_packages) if module else False
|
|
261
|
+
)
|
|
262
|
+
|
|
263
|
+
if alias := name.asname:
|
|
264
|
+
import_statement = f"import {full_module_name} as {alias}"
|
|
265
|
+
else:
|
|
266
|
+
import_statement = f"import {full_module_name}"
|
|
267
|
+
|
|
268
|
+
self._add_import(import_statement, is_third_party, name.asname)
|
|
269
|
+
|
|
270
|
+
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
|
|
271
|
+
"""Collects from-import statements."""
|
|
272
|
+
if not (module := node.module):
|
|
273
|
+
return
|
|
274
|
+
|
|
275
|
+
try:
|
|
276
|
+
is_third_party = _is_third_party(
|
|
277
|
+
__import__(module.split(".")[0]), self.site_packages
|
|
278
|
+
)
|
|
279
|
+
except ImportError:
|
|
280
|
+
module = "." * node.level + module
|
|
281
|
+
is_third_party = False
|
|
282
|
+
|
|
283
|
+
for name in node.names:
|
|
284
|
+
import_name = name.asname or name.name
|
|
285
|
+
|
|
286
|
+
if not self._is_used_import(import_name):
|
|
287
|
+
continue
|
|
288
|
+
|
|
289
|
+
if alias := name.asname:
|
|
290
|
+
import_statement = f"from {module} import {name.name} as {alias}"
|
|
291
|
+
else:
|
|
292
|
+
import_statement = f"from {module} import {name.name}"
|
|
293
|
+
|
|
294
|
+
self._add_import(import_statement, is_third_party, name.asname)
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
class _LocalAssignmentCollector(ast.NodeVisitor):
|
|
298
|
+
"""AST visitor that collects local variable assignments."""
|
|
299
|
+
|
|
300
|
+
def __init__(self) -> None:
|
|
301
|
+
self.assignments: set[str] = set()
|
|
302
|
+
|
|
303
|
+
def visit_Assign(self, node: ast.Assign) -> None:
|
|
304
|
+
"""Collects variable names from assignment statements."""
|
|
305
|
+
if isinstance(node.targets[0], ast.Name):
|
|
306
|
+
self.assignments.add(node.targets[0].id)
|
|
307
|
+
self.generic_visit(node)
|
|
308
|
+
|
|
309
|
+
def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
|
|
310
|
+
"""Collects variable names from annotated assignment statements."""
|
|
311
|
+
if isinstance(node.target, ast.Name):
|
|
312
|
+
self.assignments.add(node.target.id)
|
|
313
|
+
self.generic_visit(node)
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
class _GlobalAssignmentCollector(ast.NodeVisitor):
|
|
317
|
+
"""AST visitor that collects global assignments used in function."""
|
|
318
|
+
|
|
319
|
+
def __init__(self, used_names: list[str], source: str) -> None:
|
|
320
|
+
self.used_names = used_names
|
|
321
|
+
self.source = source
|
|
322
|
+
self.assignments: list[str] = []
|
|
323
|
+
self.current_function = None
|
|
324
|
+
self.current_class = None
|
|
325
|
+
|
|
326
|
+
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
|
|
327
|
+
"""Tracks function scope while visiting."""
|
|
328
|
+
old_function = self.current_function
|
|
329
|
+
self.current_function = node
|
|
330
|
+
self.generic_visit(node)
|
|
331
|
+
self.current_function = old_function
|
|
332
|
+
|
|
333
|
+
def visit_ClassDef(self, node: ast.ClassDef) -> None:
|
|
334
|
+
"""Tracks class scope while visiting."""
|
|
335
|
+
old_class = self.current_class
|
|
336
|
+
self.current_class = node
|
|
337
|
+
self.generic_visit(node)
|
|
338
|
+
self.current_class = old_class
|
|
339
|
+
|
|
340
|
+
def visit_Assign(self, node: ast.Assign) -> None:
|
|
341
|
+
"""Collects global assignment statements."""
|
|
342
|
+
if self.current_function is not None or self.current_class is not None:
|
|
343
|
+
return
|
|
344
|
+
for target in node.targets:
|
|
345
|
+
if isinstance(target, ast.Name) and target.id in self.used_names:
|
|
346
|
+
code = ast.get_source_segment(self.source, node)
|
|
347
|
+
if code is not None:
|
|
348
|
+
self.assignments.append(code)
|
|
349
|
+
|
|
350
|
+
def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
|
|
351
|
+
"""Collects global annotated assignment statements."""
|
|
352
|
+
if self.current_function is not None or self.current_class is not None:
|
|
353
|
+
return
|
|
354
|
+
if isinstance(node.target, ast.Name) and node.target.id in self.used_names:
|
|
355
|
+
code = ast.get_source_segment(self.source, node)
|
|
356
|
+
if code is not None:
|
|
357
|
+
self.assignments.append(code)
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def _collect_parameter_names(tree: ast.Module) -> set[str]:
|
|
361
|
+
"""Returns set of all parameter names from functions in the AST."""
|
|
362
|
+
params = set()
|
|
363
|
+
for node in ast.walk(tree):
|
|
364
|
+
if isinstance(node, ast.FunctionDef):
|
|
365
|
+
for arg in node.args.args:
|
|
366
|
+
params.add(arg.arg)
|
|
367
|
+
for arg in node.args.kwonlyargs:
|
|
368
|
+
params.add(arg.arg)
|
|
369
|
+
if node.args.vararg:
|
|
370
|
+
params.add(node.args.vararg.arg)
|
|
371
|
+
if node.args.kwarg:
|
|
372
|
+
params.add(node.args.kwarg.arg)
|
|
373
|
+
return params
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
def _extract_types(annotation: Any) -> set[type]: # noqa: ANN401
|
|
377
|
+
"""Returns set of types found in a type annotation."""
|
|
378
|
+
types_found: set[type] = set()
|
|
379
|
+
origin = get_origin(annotation)
|
|
380
|
+
|
|
381
|
+
if origin is not None:
|
|
382
|
+
if origin is Annotated:
|
|
383
|
+
base_annotation, *_ = get_args(annotation)
|
|
384
|
+
types_found |= _extract_types(base_annotation)
|
|
385
|
+
else:
|
|
386
|
+
for arg in get_args(annotation):
|
|
387
|
+
types_found |= _extract_types(arg)
|
|
388
|
+
elif isinstance(annotation, type) and not _is_stdlib_or_builtin(annotation):
|
|
389
|
+
types_found.add(annotation)
|
|
390
|
+
return types_found
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def _is_stdlib_or_builtin(obj: Any) -> bool: # noqa: ANN401
|
|
394
|
+
"""Returns True if object is from standard library or builtins."""
|
|
395
|
+
if not hasattr(obj, "__module__"):
|
|
396
|
+
return False
|
|
397
|
+
|
|
398
|
+
module_name = obj.__module__
|
|
399
|
+
if not module_name:
|
|
400
|
+
return False
|
|
401
|
+
|
|
402
|
+
return (
|
|
403
|
+
module_name in sys.stdlib_module_names
|
|
404
|
+
or module_name.startswith("collections.")
|
|
405
|
+
or module_name.startswith("typing.")
|
|
406
|
+
or module_name in {"abc", "typing", "builtins", "_collections_abc"}
|
|
407
|
+
)
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
class _DefinitionCollector(ast.NodeVisitor):
|
|
411
|
+
"""AST visitor that collects function and class definitions referenced in code."""
|
|
412
|
+
|
|
413
|
+
def __init__(
|
|
414
|
+
self, module: ModuleType, used_names: list[str], site_packages: set[str]
|
|
415
|
+
) -> None:
|
|
416
|
+
self.module = module
|
|
417
|
+
self.used_names = used_names
|
|
418
|
+
self.site_packages = site_packages
|
|
419
|
+
self.definitions_to_include: list[Callable[..., Any] | type] = []
|
|
420
|
+
self.definitions_to_analyze: list[Callable[..., Any] | type] = []
|
|
421
|
+
self.imports: set[str] = set()
|
|
422
|
+
|
|
423
|
+
def visit_Name(self, node: ast.Name) -> None:
|
|
424
|
+
"""Collects named references to callable definitions."""
|
|
425
|
+
if node.id in self.used_names:
|
|
426
|
+
candidate = getattr(self.module, node.id, None)
|
|
427
|
+
if callable(candidate) and not _is_stdlib_or_builtin(candidate):
|
|
428
|
+
self.definitions_to_include.append(candidate)
|
|
429
|
+
self.generic_visit(node)
|
|
430
|
+
|
|
431
|
+
def _process_decorator(self, decorator_node: ast.AST) -> None:
|
|
432
|
+
"""Processes a decorator node to extract its definition."""
|
|
433
|
+
if isinstance(decorator_node, ast.Name):
|
|
434
|
+
if decorator_func := getattr(self.module, decorator_node.id, None):
|
|
435
|
+
self.definitions_to_include.append(decorator_func)
|
|
436
|
+
elif isinstance(decorator_node, ast.Attribute):
|
|
437
|
+
attr_path = _AttributePath.from_ast_node(decorator_node)
|
|
438
|
+
if attr_path:
|
|
439
|
+
base_module = getattr(self.module, attr_path.base_name, None)
|
|
440
|
+
if (
|
|
441
|
+
attr_path.full_path in self.used_names
|
|
442
|
+
and base_module
|
|
443
|
+
and (
|
|
444
|
+
definition := getattr(
|
|
445
|
+
base_module, attr_path.last_attribute, None
|
|
446
|
+
)
|
|
447
|
+
)
|
|
448
|
+
):
|
|
449
|
+
self.definitions_to_include.append(definition)
|
|
450
|
+
|
|
451
|
+
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
|
|
452
|
+
"""Collects function definitions and their decorators."""
|
|
453
|
+
for decorator_node in node.decorator_list:
|
|
454
|
+
self._process_decorator(decorator_node)
|
|
455
|
+
|
|
456
|
+
nested_func = getattr(self.module, node.name, None)
|
|
457
|
+
if nested_func:
|
|
458
|
+
self.definitions_to_analyze.append(nested_func)
|
|
459
|
+
|
|
460
|
+
self.generic_visit(node)
|
|
461
|
+
|
|
462
|
+
def visit_ClassDef(self, node: ast.ClassDef) -> None:
    """Handle a class definition found while walking the tree.

    When ``self.module`` exposes a truthy attribute with the class's name:

    * the class itself is queued in ``self.definitions_to_analyze``;
    * every type extracted from its annotations that lives in the same module
      as the class (and is not a builtin) is queued once in
      ``self.definitions_to_include``;
    * every plain ``def`` in the class body whose name resolves to a truthy
      attribute of the class is queued in ``self.definitions_to_analyze``.

    The class's children are always visited afterwards.
    """
    klass = getattr(self.module, node.name, None)
    if klass:
        self.definitions_to_analyze.append(klass)

        if hasattr(klass, "__annotations__"):
            for annotation in klass.__annotations__.values():
                for annotated in _extract_types(annotation):
                    same_module_type = (
                        isinstance(annotated, type)
                        and annotated.__module__ == klass.__module__
                        and annotated.__module__ != "builtins"
                    )
                    if (
                        same_module_type
                        and annotated not in self.definitions_to_include
                    ):
                        self.definitions_to_include.append(annotated)

        for member in node.body:
            if not isinstance(member, ast.FunctionDef):
                continue
            if attribute := getattr(klass, member.name, None):
                self.definitions_to_analyze.append(attribute)

    self.generic_visit(node)
|
|
481
|
+
|
|
482
|
+
def _process_name_or_attribute(self, node: ast.AST) -> None:
    """Queue the definition that a name or dotted expression resolves to.

    Names are looked up on ``self.module``. Dotted expressions are only
    considered when their full path is in ``self.used_names`` and their base is
    not a third-party module; they are resolved by walking the attribute
    components starting from ``self.module``. A resolved object is queued in
    ``self.definitions_to_include`` when it is truthy, has a ``__name__``, is
    not flagged by ``_is_stdlib_or_builtin`` and (for dotted expressions) is not
    itself a module. Other node kinds are ignored.
    """
    if isinstance(node, ast.Name):
        resolved = getattr(self.module, node.id, None)
        if (
            resolved
            and hasattr(resolved, "__name__")
            and not _is_stdlib_or_builtin(resolved)
        ):
            self.definitions_to_include.append(resolved)
        return

    if not isinstance(node, ast.Attribute):
        return

    dotted = _AttributePath.from_ast_node(node)
    if not dotted or dotted.full_path not in self.used_names:
        return

    root = getattr(self.module, dotted.base_name, None)
    if (
        root
        and isinstance(root, ModuleType)
        and _is_third_party(root, self.site_packages)
    ):
        return

    # Walk the dotted path one attribute at a time; stop at the first gap.
    resolved = self.module
    for part in dotted.components:
        resolved = getattr(resolved, part, None)
        if resolved is None:
            break

    if not resolved or not hasattr(resolved, "__name__"):
        return
    if _is_stdlib_or_builtin(resolved) or isinstance(resolved, ModuleType):
        return
    self.definitions_to_include.append(resolved)
|
|
517
|
+
|
|
518
|
+
def visit_Call(self, node: ast.Call) -> None:
    """Resolve everything a call expression references.

    The callee, each positional argument and each keyword value are passed, in
    that order, to ``_process_name_or_attribute``; the call's children are then
    visited.
    """
    operands = [node.func, *node.args, *(kw.value for kw in node.keywords)]
    for operand in operands:
        self._process_name_or_attribute(operand)
    self.generic_visit(node)
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
class _QualifiedNameRewriter(cst.CSTTransformer):
    """CST transformer that shortens qualified names and undoes import aliases.

    Dotted expressions whose final attribute is a locally defined name are
    replaced by that bare name, and names imported via ``from ... import X as Y``
    are rewritten from ``Y`` back to ``X``.
    """

    def __init__(self, local_names: set[str], user_defined_imports: set[str]) -> None:
        """Store the local names and build the alias -> original-name mapping.

        Only statements starting with ``"from "`` that contain an ``as`` token
        (and at least four space-separated tokens) contribute to the mapping.
        """
        super().__init__()
        self.local_names: set[str] = local_names
        self.alias_mapping = {}
        for statement in user_defined_imports:
            if not statement.startswith("from "):
                continue
            tokens = statement.split(" ")
            if len(tokens) < 4 or "as" not in tokens:
                continue
            imported = tokens[tokens.index("import") + 1]
            alias_name = tokens[tokens.index("as") + 1]
            self.alias_mapping[alias_name] = imported

    def _gather_attribute_chain(self, node: cst.Attribute | cst.Name) -> list[str]:
        """Return the dotted components of ``node`` from left to right."""
        reversed_chain = []
        cursor = node

        # Attribute nodes nest leftwards, so names are collected right-to-left.
        while isinstance(cursor, cst.Attribute):
            reversed_chain.append(cursor.attr.value)
            cursor = cursor.value

        if isinstance(cursor, cst.Name):
            reversed_chain.append(cursor.value)

        return reversed_chain[::-1]

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> cst.Name | cst.Attribute:
        """Collapse the attribute to a bare name when its last part is local."""
        chain = self._gather_attribute_chain(updated_node)
        if not chain or chain[-1] not in self.local_names:
            return updated_node
        return cst.Name(value=chain[-1])

    def leave_Name(self, original_node: cst.Name, updated_node: cst.Name) -> cst.Name:
        """Replace an import alias with the originally imported name.

        Parentheses attached to the name are carried over to the replacement.
        """
        original_name = self.alias_mapping.get(updated_node.value)
        if original_name is None:
            return updated_node
        return cst.Name(
            value=original_name,
            lpar=updated_node.lpar,
            rpar=updated_node.rpar,
        )
|
|
577
|
+
|
|
578
|
+
|
|
579
|
+
def _get_class_from_unbound_method(method: Callable[..., Any]) -> type | None:
|
|
580
|
+
"""Returns the class that contains the given unbound method."""
|
|
581
|
+
qualname = method.__qualname__
|
|
582
|
+
parts = qualname.split(".")
|
|
583
|
+
class_qualname = ".".join(parts[:-1])
|
|
584
|
+
|
|
585
|
+
for obj in gc.get_objects():
|
|
586
|
+
try:
|
|
587
|
+
object_is_type = isinstance(obj, type)
|
|
588
|
+
except Exception:
|
|
589
|
+
continue
|
|
590
|
+
if object_is_type and getattr(obj, "__qualname__", None) == class_qualname:
|
|
591
|
+
return obj
|
|
592
|
+
return None
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
def _clean_source_from_string(source: str, exclude_fn_body: bool = False) -> str:
    """Return ``source`` after dedenting and docstring-transformer cleanup.

    The text is dedented, parsed with libcst, run through
    ``_RemoveDocstringTransformer`` (which receives ``exclude_fn_body``), and
    returned with trailing whitespace stripped.
    """
    parsed = cst.parse_module(dedent(source))
    cleaned = parsed.visit(
        _RemoveDocstringTransformer(exclude_fn_body=exclude_fn_body)
    )
    return cleaned.code.rstrip()
|
|
602
|
+
|
|
603
|
+
|
|
604
|
+
def _get_class_source_from_method(method: Callable[..., Any]) -> str:
    """Return the cleaned source of the class that owns ``method``.

    Args:
        method: The method whose containing class source is wanted.

    Returns:
        The class source as produced by ``inspect.getsource`` and cleaned by
        ``_clean_source_from_string``.

    Raises:
        ValueError: If no owning class can be found for the method.
    """
    owner = _get_class_from_unbound_method(method)
    if owner is not None:
        return _clean_source_from_string(inspect.getsource(owner))
    raise ValueError("Cannot determine class from method via gc")
|
|
621
|
+
|
|
622
|
+
|
|
623
|
+
class _DependencyCollector:
|
|
624
|
+
"""Collects dependencies, imports, and source code for function closure."""
|
|
625
|
+
|
|
626
|
+
def __init__(self) -> None:
|
|
627
|
+
self.imports: set[str] = set()
|
|
628
|
+
self.fn_internal_imports: set[str] = set()
|
|
629
|
+
self.user_defined_imports: set[str] = set()
|
|
630
|
+
self.assignments: list[str] = []
|
|
631
|
+
self.source_code: list[str] = []
|
|
632
|
+
self.visited_functions: set[str] = set()
|
|
633
|
+
self.site_packages: set[str] = {
|
|
634
|
+
str(Path(p).resolve()) for p in site.getsitepackages()
|
|
635
|
+
}
|
|
636
|
+
self._last_import_collector: _ImportCollector | None = None
|
|
637
|
+
|
|
638
|
+
def _collect_assignments_and_imports(
|
|
639
|
+
self,
|
|
640
|
+
fn_tree: ast.Module,
|
|
641
|
+
module_tree: ast.Module,
|
|
642
|
+
used_names: list[str],
|
|
643
|
+
module_source: str,
|
|
644
|
+
) -> None:
|
|
645
|
+
"""Collects global assignments and their required imports."""
|
|
646
|
+
local_assignment_collector = _LocalAssignmentCollector()
|
|
647
|
+
local_assignment_collector.visit(fn_tree)
|
|
648
|
+
local_assignments = local_assignment_collector.assignments
|
|
649
|
+
|
|
650
|
+
parameter_names = _collect_parameter_names(fn_tree)
|
|
651
|
+
|
|
652
|
+
global_assignment_collector = _GlobalAssignmentCollector(
|
|
653
|
+
used_names, module_source
|
|
654
|
+
)
|
|
655
|
+
global_assignment_collector.visit(module_tree)
|
|
656
|
+
|
|
657
|
+
for global_assignment in global_assignment_collector.assignments:
|
|
658
|
+
tree = ast.parse(global_assignment)
|
|
659
|
+
stmt = cast(ast.Assign | ast.AnnAssign, tree.body[0])
|
|
660
|
+
if isinstance(stmt, ast.Assign):
|
|
661
|
+
var_name = cast(ast.Name, stmt.targets[0]).id
|
|
662
|
+
else:
|
|
663
|
+
var_name = cast(ast.Name, stmt.target).id
|
|
664
|
+
|
|
665
|
+
if var_name in parameter_names:
|
|
666
|
+
continue
|
|
667
|
+
|
|
668
|
+
if var_name not in used_names or var_name in local_assignments:
|
|
669
|
+
continue
|
|
670
|
+
|
|
671
|
+
self.assignments.append(global_assignment)
|
|
672
|
+
|
|
673
|
+
name_collector = _NameCollector()
|
|
674
|
+
name_collector.visit(tree)
|
|
675
|
+
import_collector = _ImportCollector(
|
|
676
|
+
name_collector.used_names, self.site_packages
|
|
677
|
+
)
|
|
678
|
+
import_collector.visit(module_tree)
|
|
679
|
+
self.imports.update(import_collector.imports)
|
|
680
|
+
self.user_defined_imports.update(import_collector.user_defined_imports)
|
|
681
|
+
|
|
682
|
+
@staticmethod
|
|
683
|
+
def _extract_definition(
|
|
684
|
+
definition: Callable[..., Any] | type | property,
|
|
685
|
+
) -> Callable[..., Any] | type | None:
|
|
686
|
+
"""Returns the actual definition from decorators and properties."""
|
|
687
|
+
if isinstance(definition, property):
|
|
688
|
+
return definition.fget
|
|
689
|
+
|
|
690
|
+
if isinstance(definition, cached_property) or (
|
|
691
|
+
hasattr(definition, "func")
|
|
692
|
+
and getattr(definition, "__name__", None) is None
|
|
693
|
+
):
|
|
694
|
+
# For Python 3.13+
|
|
695
|
+
return definition.func # pyright: ignore[reportFunctionMemberAccess] # pragma: no cover
|
|
696
|
+
|
|
697
|
+
return definition
|
|
698
|
+
|
|
699
|
+
def _get_source_code(self, definition: Callable[..., Any] | type) -> str | None:
|
|
700
|
+
"""Returns the source code for a definition."""
|
|
701
|
+
if definition.__qualname__ in self.visited_functions:
|
|
702
|
+
return None
|
|
703
|
+
self.visited_functions.add(definition.__qualname__)
|
|
704
|
+
|
|
705
|
+
if "." in definition.__qualname__ and inspect.getmodule(definition) is not None:
|
|
706
|
+
try:
|
|
707
|
+
return _get_class_source_from_method(definition)
|
|
708
|
+
except ValueError:
|
|
709
|
+
return _clean_source_code(definition)
|
|
710
|
+
|
|
711
|
+
return _clean_source_code(definition)
|
|
712
|
+
|
|
713
|
+
def _process_imports(
|
|
714
|
+
self,
|
|
715
|
+
module_tree: ast.Module,
|
|
716
|
+
used_names: list[str],
|
|
717
|
+
source: str,
|
|
718
|
+
) -> None:
|
|
719
|
+
"""Process and categorize imports."""
|
|
720
|
+
import_collector = _ImportCollector(used_names, self.site_packages)
|
|
721
|
+
import_collector.visit(module_tree)
|
|
722
|
+
|
|
723
|
+
new_imports = {
|
|
724
|
+
import_stmt
|
|
725
|
+
for import_stmt in import_collector.imports
|
|
726
|
+
if import_stmt not in source
|
|
727
|
+
}
|
|
728
|
+
|
|
729
|
+
self.imports.update(new_imports)
|
|
730
|
+
self.fn_internal_imports.update(import_collector.imports - new_imports)
|
|
731
|
+
self.user_defined_imports.update(import_collector.user_defined_imports)
|
|
732
|
+
|
|
733
|
+
def _process_definitions(
|
|
734
|
+
self, fn_tree: ast.Module, module: ModuleType, used_names: list[str]
|
|
735
|
+
) -> None:
|
|
736
|
+
"""Process nested definitions and dependencies."""
|
|
737
|
+
definition_collector = _DefinitionCollector(
|
|
738
|
+
module, used_names, self.site_packages
|
|
739
|
+
)
|
|
740
|
+
definition_collector.visit(fn_tree)
|
|
741
|
+
|
|
742
|
+
for definition in definition_collector.definitions_to_include:
|
|
743
|
+
self._collect_imports_and_source_code(definition, True)
|
|
744
|
+
|
|
745
|
+
for definition in definition_collector.definitions_to_analyze:
|
|
746
|
+
self._collect_imports_and_source_code(definition, False)
|
|
747
|
+
|
|
748
|
+
def _collect_imports_and_source_code(
|
|
749
|
+
self,
|
|
750
|
+
definition: Callable[..., Any] | type | property,
|
|
751
|
+
include_source: bool,
|
|
752
|
+
) -> None:
|
|
753
|
+
"""Collects imports and optionally source code for a definition."""
|
|
754
|
+
try:
|
|
755
|
+
if _is_stdlib_or_builtin(definition) or isinstance(definition, ModuleType):
|
|
756
|
+
return
|
|
757
|
+
|
|
758
|
+
# property(fget=None) is not reachable via current code paths but kept as guard
|
|
759
|
+
if (
|
|
760
|
+
isinstance(definition, property) and definition.fget is None
|
|
761
|
+
): # pragma: no cover
|
|
762
|
+
return
|
|
763
|
+
|
|
764
|
+
extracted_definition = _DependencyCollector._extract_definition(definition)
|
|
765
|
+
# Same guard as above; kept for defensive coding
|
|
766
|
+
if extracted_definition is None: # pragma: no cover
|
|
767
|
+
return
|
|
768
|
+
|
|
769
|
+
source = self._get_source_code(extracted_definition)
|
|
770
|
+
if source is None:
|
|
771
|
+
return
|
|
772
|
+
|
|
773
|
+
module = inspect.getmodule(extracted_definition)
|
|
774
|
+
if not module or _is_third_party(module, self.site_packages):
|
|
775
|
+
return
|
|
776
|
+
|
|
777
|
+
module_source = inspect.getsource(module)
|
|
778
|
+
module_tree = ast.parse(module_source)
|
|
779
|
+
fn_tree = ast.parse(source)
|
|
780
|
+
|
|
781
|
+
name_collector = _NameCollector()
|
|
782
|
+
name_collector.visit(fn_tree)
|
|
783
|
+
used_names = list(dict.fromkeys(name_collector.used_names))
|
|
784
|
+
|
|
785
|
+
self._process_imports(module_tree, used_names, source)
|
|
786
|
+
|
|
787
|
+
if include_source:
|
|
788
|
+
for import_stmt in self.user_defined_imports:
|
|
789
|
+
source = source.replace(import_stmt, "")
|
|
790
|
+
self.source_code.insert(0, source)
|
|
791
|
+
|
|
792
|
+
self._collect_assignments_and_imports(
|
|
793
|
+
fn_tree, module_tree, used_names, module_source
|
|
794
|
+
)
|
|
795
|
+
|
|
796
|
+
self._process_definitions(fn_tree, module, used_names)
|
|
797
|
+
|
|
798
|
+
except (OSError, TypeError) as e:
|
|
799
|
+
logger.debug(f"Failed to collect imports for {definition}: {e}")
|
|
800
|
+
|
|
801
|
+
@staticmethod
|
|
802
|
+
def _collect_required_dependencies(imports: set[str]) -> dict[str, dict[str, Any]]:
|
|
803
|
+
"""Returns package dependencies required by the import statements."""
|
|
804
|
+
stdlib_modules = set(sys.stdlib_module_names)
|
|
805
|
+
installed_packages = {
|
|
806
|
+
dist.name: dist for dist in importlib.metadata.distributions()
|
|
807
|
+
}
|
|
808
|
+
import_to_dist = importlib.metadata.packages_distributions()
|
|
809
|
+
|
|
810
|
+
dependencies = {}
|
|
811
|
+
imported_dists = {}
|
|
812
|
+
imported_roots = set()
|
|
813
|
+
|
|
814
|
+
for import_stmt in imports:
|
|
815
|
+
parts = import_stmt.strip().split()
|
|
816
|
+
root_module = parts[1].split(".")[0]
|
|
817
|
+
if root_module in stdlib_modules:
|
|
818
|
+
continue
|
|
819
|
+
|
|
820
|
+
imported_roots.add(root_module)
|
|
821
|
+
|
|
822
|
+
dist_names = import_to_dist.get(root_module, [root_module])
|
|
823
|
+
for dist_name in dist_names:
|
|
824
|
+
if dist_name not in installed_packages:
|
|
825
|
+
continue
|
|
826
|
+
|
|
827
|
+
dist = installed_packages[dist_name]
|
|
828
|
+
imported_dists.setdefault(dist_name, dist)
|
|
829
|
+
if dist_name not in dependencies:
|
|
830
|
+
dependencies[dist_name] = {
|
|
831
|
+
"version": dist.version,
|
|
832
|
+
"extras": None,
|
|
833
|
+
}
|
|
834
|
+
break
|
|
835
|
+
|
|
836
|
+
if not imported_dists:
|
|
837
|
+
return {}
|
|
838
|
+
|
|
839
|
+
dist_to_modules = {}
|
|
840
|
+
for module_name, dist_names in import_to_dist.items():
|
|
841
|
+
for dist_name in dist_names:
|
|
842
|
+
dist_to_modules.setdefault(dist_name, set()).add(module_name)
|
|
843
|
+
|
|
844
|
+
base_env = cast(dict[str, str], default_environment().copy())
|
|
845
|
+
base_env["extra"] = ""
|
|
846
|
+
extra_env_cache = {}
|
|
847
|
+
|
|
848
|
+
def _env_for_extra(extra: str) -> dict[str, str]:
|
|
849
|
+
if extra not in extra_env_cache:
|
|
850
|
+
env = cast(dict[str, str], default_environment().copy())
|
|
851
|
+
env["extra"] = extra
|
|
852
|
+
extra_env_cache[extra] = env
|
|
853
|
+
return extra_env_cache[extra]
|
|
854
|
+
|
|
855
|
+
base_requirements = {}
|
|
856
|
+
extra_requirements = {}
|
|
857
|
+
|
|
858
|
+
for dist_name, dist in imported_dists.items():
|
|
859
|
+
base_reqs = set()
|
|
860
|
+
extras_map = {
|
|
861
|
+
extra: set() for extra in dist.metadata.get_all("Provides-Extra", [])
|
|
862
|
+
}
|
|
863
|
+
requirements = dist.requires or []
|
|
864
|
+
for requirement_str in requirements:
|
|
865
|
+
req = Requirement(requirement_str)
|
|
866
|
+
marker = req.marker
|
|
867
|
+
if marker is None or marker.evaluate(base_env):
|
|
868
|
+
base_reqs.add(req.name)
|
|
869
|
+
continue
|
|
870
|
+
|
|
871
|
+
for extra in extras_map:
|
|
872
|
+
if marker.evaluate(_env_for_extra(extra)):
|
|
873
|
+
extras_map[extra].add(req.name)
|
|
874
|
+
|
|
875
|
+
base_requirements[dist_name] = base_reqs
|
|
876
|
+
extra_requirements[dist_name] = extras_map
|
|
877
|
+
|
|
878
|
+
provided_requirements = set()
|
|
879
|
+
for reqs in base_requirements.values():
|
|
880
|
+
provided_requirements.update(reqs)
|
|
881
|
+
provided_requirements.update(imported_dists.keys())
|
|
882
|
+
|
|
883
|
+
for dist_name in sorted(imported_dists):
|
|
884
|
+
extras_to_keep = []
|
|
885
|
+
apply_usage_gate = not dist_name.startswith("mirascope")
|
|
886
|
+
for extra, deps in extra_requirements[dist_name].items():
|
|
887
|
+
if not deps:
|
|
888
|
+
continue
|
|
889
|
+
|
|
890
|
+
if apply_usage_gate and not any(
|
|
891
|
+
dist_to_modules.get(dep, set()) & imported_roots for dep in deps
|
|
892
|
+
):
|
|
893
|
+
continue
|
|
894
|
+
|
|
895
|
+
missing = [dep for dep in deps if dep not in provided_requirements]
|
|
896
|
+
if missing:
|
|
897
|
+
extras_to_keep.append(extra)
|
|
898
|
+
provided_requirements.update(deps)
|
|
899
|
+
|
|
900
|
+
dependencies[dist_name]["extras"] = extras_to_keep or None
|
|
901
|
+
|
|
902
|
+
return dependencies
|
|
903
|
+
|
|
904
|
+
@classmethod
|
|
905
|
+
def _map_child_to_parent(
|
|
906
|
+
cls,
|
|
907
|
+
child_to_parent: dict[ast.AST, ast.AST | None],
|
|
908
|
+
node: ast.AST,
|
|
909
|
+
parent: ast.AST | None = None,
|
|
910
|
+
) -> None:
|
|
911
|
+
"""Maps each AST node to its parent node."""
|
|
912
|
+
child_to_parent[node] = parent
|
|
913
|
+
for _, value in ast.iter_fields(node):
|
|
914
|
+
if isinstance(value, list):
|
|
915
|
+
for child in value:
|
|
916
|
+
if isinstance(child, ast.AST):
|
|
917
|
+
cls._map_child_to_parent(child_to_parent, child, node)
|
|
918
|
+
elif isinstance(value, ast.AST):
|
|
919
|
+
cls._map_child_to_parent(child_to_parent, value, node)
|
|
920
|
+
|
|
921
|
+
def _extract_local_names(self, code_blocks: list[str]) -> set[str]:
|
|
922
|
+
"""Extracts names of locally defined functions and classes."""
|
|
923
|
+
local_names = set()
|
|
924
|
+
|
|
925
|
+
for code in code_blocks:
|
|
926
|
+
tree = ast.parse(code)
|
|
927
|
+
child_to_parent = {}
|
|
928
|
+
self._map_child_to_parent(child_to_parent, tree)
|
|
929
|
+
|
|
930
|
+
for node in ast.walk(tree):
|
|
931
|
+
if isinstance(node, ast.FunctionDef | ast.ClassDef):
|
|
932
|
+
parent = child_to_parent.get(node)
|
|
933
|
+
if isinstance(parent, ast.Module):
|
|
934
|
+
local_names.add(node.name)
|
|
935
|
+
|
|
936
|
+
return local_names
|
|
937
|
+
|
|
938
|
+
@staticmethod
|
|
939
|
+
def _rewrite_code_blocks(
|
|
940
|
+
code_blocks: list[str], rewriter: _QualifiedNameRewriter
|
|
941
|
+
) -> list[str]:
|
|
942
|
+
"""Rewrites code blocks with simplified names."""
|
|
943
|
+
rewritten = []
|
|
944
|
+
for code in code_blocks:
|
|
945
|
+
tree = cst.parse_module(code)
|
|
946
|
+
new_tree = tree.visit(rewriter)
|
|
947
|
+
rewritten.append(new_tree.code)
|
|
948
|
+
return rewritten
|
|
949
|
+
|
|
950
|
+
def collect(
|
|
951
|
+
self, fn: Callable[..., Any]
|
|
952
|
+
) -> tuple[list[str], list[str], list[str], dict[str, dict[str, Any]]]:
|
|
953
|
+
"""Collects all components needed for function closure.
|
|
954
|
+
|
|
955
|
+
Args:
|
|
956
|
+
fn: The function to collect closure information for.
|
|
957
|
+
|
|
958
|
+
Returns:
|
|
959
|
+
A tuple containing:
|
|
960
|
+
- List of import statements
|
|
961
|
+
- List of assignment statements
|
|
962
|
+
- List of source code blocks
|
|
963
|
+
- Dictionary of required dependencies
|
|
964
|
+
"""
|
|
965
|
+
self._collect_imports_and_source_code(fn, True)
|
|
966
|
+
|
|
967
|
+
local_names = self._extract_local_names(self.source_code + self.assignments)
|
|
968
|
+
rewriter = _QualifiedNameRewriter(local_names, self.user_defined_imports)
|
|
969
|
+
|
|
970
|
+
assignments = self._rewrite_code_blocks(self.assignments, rewriter)
|
|
971
|
+
source_code = self._rewrite_code_blocks(self.source_code, rewriter)
|
|
972
|
+
|
|
973
|
+
required_dependencies = _DependencyCollector._collect_required_dependencies(
|
|
974
|
+
self.imports | self.fn_internal_imports
|
|
975
|
+
)
|
|
976
|
+
|
|
977
|
+
return (
|
|
978
|
+
list(self.imports),
|
|
979
|
+
list(dict.fromkeys(assignments)),
|
|
980
|
+
source_code,
|
|
981
|
+
required_dependencies,
|
|
982
|
+
)
|
|
983
|
+
|
|
984
|
+
|
|
985
|
+
def _run_ruff(code: str) -> str:
    """Returns formatted code using ruff formatter.

    The code is written to a temporary ``.py`` file, its imports are sorted
    with ``ruff check --select=I001 --fix``, the file is formatted with
    ``ruff format`` and the result is read back. The temporary file is written
    and read as UTF-8 so the bytes handed to ruff do not depend on the
    platform's locale encoding, and it is removed even when writing fails.

    Args:
        code: Python source to process.

    Returns:
        The processed source text.

    Raises:
        subprocess.CalledProcessError: If ``ruff check`` exits with a status
            other than 0 or 1, or ``ruff format`` exits non-zero.
        OSError: If the ``ruff`` executable cannot be started or the temporary
            file cannot be written or read.
    """
    tmp_file = tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False, encoding="utf-8"
    )
    tmp_path = Path(tmp_file.name)

    try:
        # Close the file before ruff touches it; a failed write still reaches
        # the cleanup below.
        with tmp_file:
            tmp_file.write(code)

        proc = subprocess.run(
            ["ruff", "check", "--isolated", "--select=I001", "--fix", str(tmp_path)],
            capture_output=True,
            text=True,
        )

        # Exit status 1 is tolerated alongside 0; anything else is an error.
        if proc.returncode not in (0, 1):
            raise subprocess.CalledProcessError(
                proc.returncode, proc.args, output=proc.stdout, stderr=proc.stderr
            )

        subprocess.run(
            ["ruff", "format", "--isolated", "--line-length=88", str(tmp_path)],
            check=True,
            capture_output=True,
            text=True,
        )
        return tmp_path.read_text(encoding="utf-8")
    finally:
        # missing_ok keeps cleanup from masking the original error.
        tmp_path.unlink(missing_ok=True)
|
|
1013
|
+
|
|
1014
|
+
|
|
1015
|
+
def get_qualified_name(fn: Callable[..., Any]) -> str:
    """Return the simplified qualified name of a function.

    For a function defined inside another function, this is everything after
    the last ``'<locals>.'`` marker of its ``__qualname__``. Otherwise it is the
    last non-empty dot-separated component, or the unchanged ``__qualname__``
    when there is no such component.

    Args:
        fn: The function to get the qualified name from.

    Returns:
        The simplified qualified name of the function.
    """
    qualname = fn.__qualname__
    locals_marker = "<locals>."
    if locals_marker in qualname:
        return qualname.rpartition(locals_marker)[2]
    return next(
        (segment for segment in reversed(qualname.split(".")) if segment),
        qualname,
    )
|
|
1033
|
+
|
|
1034
|
+
|
|
1035
|
+
class DependencyInfo(TypedDict):
    """Represents the dependency information for a closure.

    One instance describes a single package dependency; ``Closure.dependencies``
    maps a dependency name to one of these.
    """

    # Version string of the dependency.
    version: str
    """The version of the dependency."""

    # Names of the extras needed, or ``None`` when there are none.
    extras: list[str] | None
    """The extras required for the dependency."""
|
|
1043
|
+
|
|
1044
|
+
|
|
1045
|
+
@dataclass(frozen=True, kw_only=True)
class Closure:
    """Represents the closure of a function.

    Instances are immutable and are normally built with ``Closure.from_fn``.
    """

    name: str
    """The name of the function."""

    signature: str
    """The signature of the function."""

    docstring: str | None
    """The docstring of the function."""

    code: str
    """The code of the function."""

    hash: str
    """The hash of the closure."""

    signature_hash: str
    """The hash of the function signature (determines major version X)."""

    dependencies: dict[str, DependencyInfo]
    """The dependencies of the closure."""

    @classmethod
    @lru_cache(maxsize=128)
    def from_fn(cls, fn: Callable[..., Any]) -> Closure:
        """Create a closure from a function.

        The collected imports, assignments and source blocks are joined into
        one module text and formatted with ruff; ``hash`` is the SHA-256 of that
        formatted text. ``signature`` is the function's source without its body,
        formatted the same way, and ``signature_hash`` is its SHA-256. Results
        are cached (up to 128 entries) per class and function.

        Args:
            fn: The function to analyze

        Returns:
            Closure: The closure of the function.

        Raises:
            ClosureComputationError: if the closure cannot be computed properly,
                i.e. ruff fails or cannot be run for either the code or the
                signature. The underlying error is attached as ``__cause__``.
        """
        collector = _DependencyCollector()
        imports, assignments, source_code, dependencies = collector.collect(fn)
        code = "{imports}\n\n{assignments}\n\n{source_code}".format(
            imports="\n".join(imports),
            assignments="\n".join(assignments),
            source_code="\n\n".join(source_code),
        )
        qualified_name = get_qualified_name(fn)
        signature_source = _clean_source_code(fn, exclude_fn_body=True)
        try:
            # Both ruff invocations are guarded so that a formatter failure is
            # always reported as ClosureComputationError.
            formatted_code = _run_ruff(code)
            signature = _run_ruff(signature_source).strip()
        except (subprocess.CalledProcessError, FileNotFoundError, OSError) as error:
            raise ClosureComputationError(qualified_name=qualified_name) from error

        hash_value = hashlib.sha256(formatted_code.encode("utf-8")).hexdigest()
        signature_hash = hashlib.sha256(signature.encode("utf-8")).hexdigest()

        return cls(
            name=qualified_name,
            docstring=inspect.getdoc(fn),
            signature=signature,
            code=formatted_code,
            hash=hash_value,
            signature_hash=signature_hash,
            dependencies={
                name: DependencyInfo(
                    version=dep_info["version"],
                    extras=dep_info.get("extras"),
                )
                for name, dep_info in dependencies.items()
            },
        )
|
|
1116
|
+
|
|
1117
|
+
|
|
1118
|
+
__all__ = ["Closure", "DependencyInfo"]
|