offwork 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- offwork/__init__.py +167 -0
- offwork/__main__.py +770 -0
- offwork/_venv.py +174 -0
- offwork/core/__init__.py +15 -0
- offwork/core/errors.py +83 -0
- offwork/core/models.py +174 -0
- offwork/core/pairing.py +389 -0
- offwork/core/progress.py +91 -0
- offwork/core/signing.py +91 -0
- offwork/core/task.py +520 -0
- offwork/core/token.py +184 -0
- offwork/core/version.py +10 -0
- offwork/graph/__init__.py +5 -0
- offwork/graph/analyzer.py +637 -0
- offwork/graph/decorator.py +87 -0
- offwork/graph/graph.py +995 -0
- offwork/graph/store.py +500 -0
- offwork/graph/tracing.py +429 -0
- offwork/py.typed +0 -0
- offwork/typing.py +48 -0
- offwork/worker/__init__.py +18 -0
- offwork/worker/backends/__init__.py +3 -0
- offwork/worker/backends/base.py +149 -0
- offwork/worker/backends/http.py +237 -0
- offwork/worker/backends/local.py +452 -0
- offwork/worker/backends/rabbitmq.py +410 -0
- offwork/worker/backends/redis.py +175 -0
- offwork/worker/deps.py +365 -0
- offwork/worker/remote.py +793 -0
- offwork/worker/result.py +276 -0
- offwork/worker/sandbox/Dockerfile +24 -0
- offwork/worker/sandbox/__init__.py +18 -0
- offwork/worker/sandbox/_protocol.py +50 -0
- offwork/worker/sandbox/docker.py +438 -0
- offwork/worker/sandbox/guest_agent.py +622 -0
- offwork/worker/schedule.py +26 -0
- offwork/worker/worker.py +263 -0
- offwork-0.4.0.dist-info/METADATA +143 -0
- offwork-0.4.0.dist-info/RECORD +42 -0
- offwork-0.4.0.dist-info/WHEEL +4 -0
- offwork-0.4.0.dist-info/entry_points.txt +3 -0
- offwork-0.4.0.dist-info/licenses/LICENSE +661 -0
offwork/core/version.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
from importlib.metadata import PackageNotFoundError, version as _pkg_version

# Version string reported when no installed "offwork" distribution is found.
_FALLBACK_VERSION = "0.4.0"

try:
    # Prefer the version recorded in the installed distribution's metadata.
    _VERSION: str = _pkg_version("offwork")
except PackageNotFoundError:
    # Not installed as a package (e.g. running from source checkout).
    _VERSION = _FALLBACK_VERSION
|
|
@@ -0,0 +1,637 @@
|
|
|
1
|
+
"""AST-based source capture, import extraction, and dependency detection."""
|
|
2
|
+
|
|
3
|
+
import ast
|
|
4
|
+
import inspect
|
|
5
|
+
import logging
|
|
6
|
+
import textwrap
|
|
7
|
+
import warnings
|
|
8
|
+
import importlib
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from collections.abc import Callable
|
|
11
|
+
|
|
12
|
+
from offwork.core.models import ImportInfo, FunctionNode
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _is_trace_decorator(node: ast.expr) -> bool:
    """Return True if *node* is an offwork ``trace`` decorator expression.

    Recognised forms (mirroring the bare/attribute handling this module uses
    for ``install_package_as`` and ``worker_only_import``):

    * ``@trace`` and ``@trace(...)``
    * ``@offwork.trace`` and ``@offwork.trace(...)``

    Attribute access on any other object (e.g. ``@tracer.trace``) is NOT
    matched, so unrelated ``trace`` decorators from other libraries are left
    in place.
    """
    # ``@trace(...)`` wraps the decorator reference in a Call node.
    target = node.func if isinstance(node, ast.Call) else node
    if isinstance(target, ast.Name):
        return target.id == "trace"
    if isinstance(target, ast.Attribute):
        return (
            target.attr == "trace"
            and isinstance(target.value, ast.Name)
            and target.value.id == "offwork"
        )
    return False
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def get_function_source(func: Callable[..., object]) -> str:
    """Return the dedented source of *func* with its ``@trace`` decorators removed.

    The source comes from :func:`inspect.getsource` and is dedented so that
    methods and nested functions parse as top-level code. When the first
    parsed statement is a (possibly async) function definition, every
    physical line spanned by a decorator accepted by ``_is_trace_decorator``
    is dropped; all other decorators and lines are kept verbatim.
    """
    source = textwrap.dedent(inspect.getsource(func))
    first_stmt = ast.parse(source).body[0]

    # 1-based line numbers occupied by @trace decorator expressions.
    skip: set[int] = set()
    if isinstance(first_stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
        for deco in filter(_is_trace_decorator, first_stmt.decorator_list):
            last_line = deco.end_lineno or deco.lineno
            skip.update(range(deco.lineno, last_line + 1))

    if skip:
        kept = [
            text
            for lineno, text in enumerate(source.splitlines(keepends=True), start=1)
            if lineno not in skip
        ]
        source = "".join(kept)

    logger.debug(
        "Extracted source for %s (%d lines)",
        func.__qualname__,
        source.count("\n"),
    )
    return source
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def get_module_imports(func: Callable[..., object]) -> list[ImportInfo]:
    """Extract all top-level import bindings from the module where *func* is defined.

    Only statements directly at module top level are scanned:

    * plain ``import x`` / ``from x import y`` statements;
    * imports placed directly inside ``with install_package_as("pkg"):``,
      tagged with that package name;
    * imports placed directly inside ``with worker_only_import([pkg]):``,
      tagged ``worker_only=True`` (plus the package name when one is given).

    The module file is decoded with :func:`tokenize.open`, which honours a
    PEP 263 coding cookie or UTF-8 BOM and otherwise defaults to UTF-8. These
    are the same rules the interpreter used to compile the module. The
    platform locale encoding is not used, because it can mis-decode or reject
    non-ASCII source.

    Raises:
        TypeError: from :func:`inspect.getfile` when *func* has no source file.
        OSError: if the source file cannot be opened.
        SyntaxError: on an invalid coding cookie or unparsable source.
    """
    import tokenize  # only needed to decode the module file correctly

    source_file = inspect.getfile(func)
    with tokenize.open(source_file) as fh:
        source_text = fh.read()
    tree = ast.parse(source_text)
    imports: list[ImportInfo] = []

    for node in tree.body:
        if isinstance(node, ast.Import):
            imports.extend(_extract_import(node))

        elif isinstance(node, ast.ImportFrom):
            imports.extend(_extract_import_from(node))

        elif isinstance(node, ast.With):
            package = _parse_install_package_as(node)
            if package is not None:
                for child in node.body:
                    if isinstance(child, ast.Import):
                        imports.extend(_extract_import(child, package))
                    elif isinstance(child, ast.ImportFrom):
                        imports.extend(_extract_import_from(child, package))
                continue
            wo = _parse_worker_only_import(node)
            if wo is not False:
                # ``True`` means "no package given"; a str names the package.
                wo_package = wo if isinstance(wo, str) else None
                for child in node.body:
                    if isinstance(child, ast.Import):
                        imports.extend(_extract_import(child, wo_package, worker_only=True))
                    elif isinstance(child, ast.ImportFrom):
                        imports.extend(_extract_import_from(child, wo_package, worker_only=True))

    logger.debug(
        "Found %d import bindings in %s", len(imports), source_file
    )
    return imports
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def get_module_assignments(func: Callable[..., object]) -> dict[str, str]:
    """Extract top-level variable assignments from the module where *func* is defined.

    Returns a dict mapping each assigned name to the source text of the
    statement that assigns it. ``ast.unparse`` is used when the exact source
    segment is unavailable.

    * Only statements directly at module top level are scanned, so anything
      nested in ``if``/``try``/``with`` blocks (including
      ``if TYPE_CHECKING:``) is ignored, as are function/class definitions.
    * Only plain-name targets are recorded. In ``a = b = 1`` both names map
      to the same statement; tuple unpacking such as ``a, b = ...`` is not
      recorded.
    * Bare annotations without a value (``x: int``) are skipped.
    * Names starting with ``__`` are skipped.
    * If a name is assigned more than once, the last assignment wins.

    The module file is decoded with :func:`tokenize.open` (PEP 263 cookie or
    BOM, defaulting to UTF-8), matching how the interpreter read it, rather
    than with the platform locale encoding.

    Raises:
        TypeError: from :func:`inspect.getfile` when *func* has no source file.
        OSError: if the source file cannot be opened.
        SyntaxError: on an invalid coding cookie or unparsable source.
    """
    import tokenize  # only needed to decode the module file correctly

    source_file = inspect.getfile(func)
    with tokenize.open(source_file) as fh:
        source_text = fh.read()
    tree = ast.parse(source_text)
    assignments: dict[str, str] = {}

    for node in tree.body:
        # Simple assignment: x = ...  (every plain-name target of x = y = ...)
        if isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Name) and not target.id.startswith("__"):
                    assignments[target.id] = ast.get_source_segment(source_text, node) or ast.unparse(node)

        # Annotated assignment with a value: x: int = ...
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            if isinstance(node.target, ast.Name) and not node.target.id.startswith("__"):
                assignments[node.target.id] = ast.get_source_segment(source_text, node) or ast.unparse(node)

    logger.debug(
        "Found %d module-level assignments in %s",
        len(assignments),
        source_file,
    )
    return assignments
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _is_install_package_as_call(expr: ast.expr) -> bool:
    """Tell whether *expr* refers to ``install_package_as``.

    A bare name (``install_package_as``) matches, as does any attribute
    access whose final component is ``install_package_as`` (for example
    ``offwork.install_package_as``).
    """
    if isinstance(expr, ast.Attribute):
        return expr.attr == "install_package_as"
    return isinstance(expr, ast.Name) and expr.id == "install_package_as"
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _parse_install_package_as(node: ast.With) -> str | None:
    """Return the package name for a ``with install_package_as("pkg"):`` block.

    Both the bare (``install_package_as("foo")``) and attribute
    (``offwork.install_package_as("foo")``) spellings are recognised. The
    ``with`` must have exactly one context manager, called with exactly one
    positional string-literal argument. Any other shape yields ``None``.
    """
    if len(node.items) != 1:
        return None
    call = node.items[0].context_expr
    if not isinstance(call, ast.Call) or not _is_install_package_as_call(call.func):
        return None
    if len(call.args) != 1:
        return None
    arg = call.args[0]
    if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
        return arg.value
    return None
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _is_worker_only_import_call(expr: ast.expr) -> bool:
    """Tell whether *expr* refers to ``worker_only_import``.

    A bare name matches, as does any attribute access whose final component
    is ``worker_only_import`` (for example ``offwork.worker_only_import``).
    """
    if isinstance(expr, ast.Name):
        name = expr.id
    elif isinstance(expr, ast.Attribute):
        name = expr.attr
    else:
        return False
    return name == "worker_only_import"
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _parse_worker_only_import(node: ast.With) -> str | bool:
    """Detect ``with worker_only_import([package]):`` blocks.

    Returns:
        ``False`` when *node* is not such a block (more than one context
        manager, not a call, or a call to something else). Otherwise, the
        package string when the call has exactly one positional string-literal
        argument, and ``True`` for any other argument shape, including no
        arguments at all.
    """
    if len(node.items) != 1:
        return False
    call = node.items[0].context_expr
    if not (isinstance(call, ast.Call) and _is_worker_only_import_call(call.func)):
        return False
    args = call.args
    # Only a lone string literal is taken as the package name.
    if len(args) == 1 and isinstance(args[0], ast.Constant) and isinstance(args[0].value, str):
        return args[0].value
    return True
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _extract_import(
    node: ast.Import,
    package: str | None = None,
    worker_only: bool = False,
) -> list[ImportInfo]:
    """Split an ``import a, b.c as d`` statement into one ImportInfo per alias.

    The bound name is the alias when one is given; otherwise it is the first
    dotted component (``import a.b`` binds ``a``). Each entry's statement is
    a standalone single-alias ``import`` line regenerated with ``ast.unparse``.
    *package* and *worker_only* are copied onto every entry.
    """
    return [
        ImportInfo(
            statement=ast.unparse(ast.Import(names=[alias])),
            bound_name=alias.asname or alias.name.partition(".")[0],
            package=package,
            worker_only=worker_only,
        )
        for alias in node.names
    ]
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _extract_import_from(
    node: ast.ImportFrom,
    package: str | None = None,
    worker_only: bool = False,
) -> list[ImportInfo]:
    """Split a ``from ... import ...`` statement into one ImportInfo per bound name.

    Regular aliases become standalone single-name ``from`` statements that
    keep the original relative-import level.

    Star imports (``from mod import *``) are expanded by importing *mod* at
    analysis time and listing ``__all__``, or every non-underscore name from
    ``dir()`` when ``__all__`` is absent. Two cases are skipped with a
    warning: relative star imports (``from . import *``,
    ``from .mod import *``), and star imports whose module raises
    ImportError.

    *package* and *worker_only* are copied onto every entry.
    """
    result: list[ImportInfo] = []
    for alias in node.names:
        if alias.name != "*":
            stmt = ast.unparse(
                ast.ImportFrom(module=node.module, names=[alias], level=node.level)
            )
            result.append(ImportInfo(
                statement=stmt, bound_name=alias.asname or alias.name,
                package=package, worker_only=worker_only,
            ))
            continue

        if node.level or node.module is None:
            # importlib.import_module(node.module) would resolve ``.mod``
            # against the top-level namespace (the wrong module), and the
            # generated ``from mod import X`` lines would lose the leading
            # dots, so relative star imports cannot be expanded here.
            warnings.warn(
                f"Relative star import '{ast.unparse(node)}' is not supported",
                stacklevel=2,
            )
            continue

        try:
            star_mod = importlib.import_module(node.module)
            # Reading __all__ / dir() may trigger a module-level __getattr__
            # (PEP 562) that performs lazy imports, so it stays inside the try.
            exported: list[str]
            if hasattr(star_mod, "__all__"):
                exported = list(star_mod.__all__)
            else:
                exported = [n for n in dir(star_mod) if not n.startswith("_")]
        except ImportError:
            warnings.warn(
                f"Cannot resolve 'from {node.module} import *': "
                "module not importable",
                stacklevel=2,
            )
            continue

        logger.debug(
            "Resolved 'from %s import *': %d names",
            node.module,
            len(exported),
        )
        for export_name in exported:
            result.append(ImportInfo(
                statement=f"from {node.module} import {export_name}",
                bound_name=export_name,
                package=package, worker_only=worker_only,
            ))
    return result
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def get_used_names(func_source: str) -> set[str]:
    """Return every identifier that appears as an ``ast.Name`` node in *func_source*.

    No load/store filtering is applied, so locally assigned names are
    included too. Attribute chains such as ``np.zeros`` contribute only
    their root name (``np``). Parameter and function names are ``ast.arg`` /
    definition fields rather than ``ast.Name``, so they appear only where
    they are referenced in the body.
    """
    names: set[str] = set()
    for node in ast.walk(ast.parse(func_source)):
        if isinstance(node, ast.Name):
            names.add(node.id)
    return names
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def has_super_call(func_source: str) -> bool:
    """Tell whether *func_source* contains a call to the bare name ``super``.

    Only direct calls such as ``super()`` or ``super(Cls, self)`` count. A
    bare reference (``s = super``) or an attribute call (``obj.super()``)
    does not.
    """
    return any(
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == "super"
        for node in ast.walk(ast.parse(func_source))
    )
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def get_class_bases_from_source(
    cls: type,
) -> tuple[list[str], dict[str, str]]:
    """Read the base classes and class keywords from *cls*'s source.

    Returns:
        ``(bases, keywords)``. *bases* lists each base expression as
        unparsed source text, with ``object`` omitted. *keywords* maps class
        keyword names to unparsed values, e.g. ``{"metaclass": "ABCMeta"}``;
        ``**kwargs`` entries are skipped. ``([], {})`` is returned when the
        source is unavailable (built-in or dynamically created classes), does
        not parse, or contains no top-level class named ``cls.__name__``.
    """
    try:
        source = textwrap.dedent(inspect.getsource(cls))
    except (OSError, TypeError):
        return [], {}
    try:
        module = ast.parse(source)
    except SyntaxError:
        return [], {}

    class_def = next(
        (
            stmt
            for stmt in module.body
            if isinstance(stmt, ast.ClassDef) and stmt.name == cls.__name__
        ),
        None,
    )
    if class_def is None:
        return [], {}

    base_names = [text for text in map(ast.unparse, class_def.bases) if text != "object"]
    class_kwargs = {
        kw.arg: ast.unparse(kw.value)
        for kw in class_def.keywords
        if kw.arg is not None
    }
    return base_names, class_kwargs
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def get_class_attrs(cls: type) -> tuple[list[str], list[str]]:
    """Collect the non-method body statements and decorators of *cls* from its source.

    Returns:
        ``(attrs, decorators)``. *attrs* holds the source text of every
        class-body statement that is not a (possibly async) function
        definition, e.g. assignments, annotated assignments, docstrings and
        nested classes. *decorators* holds each class decorator unparsed,
        without the ``@``. ``([], [])`` is returned when the source is
        unavailable, does not parse, or has no top-level class named
        ``cls.__name__``.
    """
    try:
        source = textwrap.dedent(inspect.getsource(cls))
    except (OSError, TypeError):
        return [], []
    try:
        module = ast.parse(source)
    except SyntaxError:
        return [], []

    for stmt in module.body:
        if isinstance(stmt, ast.ClassDef) and stmt.name == cls.__name__:
            break
    else:
        return [], []

    body_sources: list[str] = []
    for child in stmt.body:
        if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue  # methods are not part of the returned attrs
        segment = ast.get_source_segment(source, child)
        # NOTE(review): get_source_segment is called without padded=True, so
        # the segment's first line has no leading whitespace and dedent leaves
        # multi-line statements unchanged. Confirm this is intended.
        body_sources.append(ast.unparse(child) if segment is None else textwrap.dedent(segment))
    return body_sources, [ast.unparse(deco) for deco in stmt.decorator_list]
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
def find_bare_calls(func_source: str) -> set[str]:
    """Return the names of everything called directly by name (``foo(...)``).

    Attribute calls such as ``obj.method()`` are not included.
    """
    found: set[str] = set()
    for node in ast.walk(ast.parse(func_source)):
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            found.add(node.func.id)
    return found
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def find_self_calls(func_source: str) -> set[str]:
    """Return the method names invoked directly on ``self`` or ``cls``.

    Only one-level calls such as ``self.run()`` / ``cls.build()`` match.
    Deeper chains like ``self.x.run()`` do not.
    """
    return {
        node.func.attr
        for node in ast.walk(ast.parse(func_source))
        if isinstance(node, ast.Call)
        and isinstance(node.func, ast.Attribute)
        and isinstance(node.func.value, ast.Name)
        and node.func.value.id in ("self", "cls")
    }
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def filter_imports(
    all_imports: list[ImportInfo], used_names: set[str]
) -> list[ImportInfo]:
    """Return the imports whose ``bound_name`` is in *used_names*, preserving order."""
    kept = list(filter(lambda imp: imp.bound_name in used_names, all_imports))
    logger.debug(
        "Filtered imports: %d/%d retained", len(kept), len(all_imports)
    )
    return kept
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def hoist_closure_vars(source: str, closure_vars: dict[str, str]) -> str:
    """Turn each closure variable into a keyword-only parameter with a default.

    *source* must start with a (possibly async) function definition. Each
    ``name -> repr_value`` pair is appended as a keyword-only argument whose
    default is *repr_value* parsed as a Python expression. The result is
    regenerated with :func:`ast.unparse`, so comments and original formatting
    are not preserved. When *closure_vars* is empty, *source* is returned
    unchanged.
    """
    if not closure_vars:
        return source
    module = ast.parse(source)
    fn = module.body[0]
    assert isinstance(fn, (ast.FunctionDef, ast.AsyncFunctionDef))
    signature = fn.args
    for name, expr_text in closure_vars.items():
        default = ast.parse(expr_text, mode="eval").body
        signature.kwonlyargs.append(ast.arg(arg=name))
        signature.kw_defaults.append(default)
    return ast.unparse(module)
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def hoist_closure_func_refs(
    source: str,
    closure_func_refs: dict[str, str],
    nodes: dict[str, FunctionNode],
) -> str:
    """Turn closure references to functions into keyword-only parameters.

    For every ``var_name -> qualified_name`` entry, a keyword-only argument
    ``var_name`` is appended whose default is a bare name reference. The name
    is the registered node's ``name`` when *qualified_name* is in *nodes*,
    otherwise the last dotted component of *qualified_name*. *source* must
    start with a (possibly async) function definition. The result is
    regenerated with :func:`ast.unparse`. With no refs, *source* is returned
    unchanged.
    """
    if not closure_func_refs:
        return source
    module = ast.parse(source)
    fn = module.body[0]
    assert isinstance(fn, (ast.FunctionDef, ast.AsyncFunctionDef))
    for var_name, qname in closure_func_refs.items():
        if qname in nodes:
            target = nodes[qname].name
        else:
            target = qname.rpartition(".")[2]
        fn.args.kwonlyargs.append(ast.arg(arg=var_name))
        fn.args.kw_defaults.append(ast.Name(id=target, ctx=ast.Load()))
    return ast.unparse(module)
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
def _resolve_owner_class(qualname: str) -> str | None:
|
|
409
|
+
"""Extract the owning class name from a function's __qualname__, or None."""
|
|
410
|
+
parts = qualname.rsplit(".", 1)
|
|
411
|
+
if len(parts) == 1:
|
|
412
|
+
return None
|
|
413
|
+
prefix = parts[0]
|
|
414
|
+
if "<locals>" not in prefix:
|
|
415
|
+
return prefix
|
|
416
|
+
# For nested classes like "outer.<locals>.MyClass.__init__",
|
|
417
|
+
# extract the class name after the last "<locals>." segment.
|
|
418
|
+
after_locals = prefix.rsplit("<locals>.", 1)[-1]
|
|
419
|
+
# If there's still a class name (not empty, not another scope marker)
|
|
420
|
+
if after_locals and "<" not in after_locals:
|
|
421
|
+
return after_locals
|
|
422
|
+
return None
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
def _extract_annotation_type_names(annotation: ast.expr) -> set[str]:
    """Extract potential class names from a type annotation AST node.

    Every ``Name`` id and every ``Attribute`` attr found anywhere in the
    annotation is collected, so ``pkg.Foo | None`` yields ``{"pkg", "Foo"}``.

    String forward references (PEP 484), e.g. ``x: "Foo"`` or
    ``list["pkg.Bar"]``, are parsed as expressions and scanned the same way.
    Strings that are not valid expressions (``Literal["some text"]``) are
    ignored. A string that happens to parse (``Literal["foo"]``) contributes
    its name as well; this only widens the candidate set.
    """
    names: set[str] = set()
    for node in ast.walk(annotation):
        if isinstance(node, ast.Name):
            names.add(node.id)
        elif isinstance(node, ast.Attribute):
            names.add(node.attr)
        elif isinstance(node, ast.Constant) and isinstance(node.value, str):
            try:
                parsed = ast.parse(node.value.strip(), mode="eval").body
            except (SyntaxError, ValueError):
                # Not an expression (or contains NUL bytes on older Pythons).
                continue
            names |= _extract_annotation_type_names(parsed)
    return names
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
def _resolve_root_var(node: ast.expr) -> str | None:
|
|
437
|
+
"""Walk up through subscripts/attributes to find the root variable name."""
|
|
438
|
+
while True:
|
|
439
|
+
if isinstance(node, ast.Name):
|
|
440
|
+
return node.id
|
|
441
|
+
if isinstance(node, (ast.Subscript, ast.Attribute)):
|
|
442
|
+
node = node.value
|
|
443
|
+
else:
|
|
444
|
+
return None
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
def _prefer_same_module(
    matches: list[FunctionNode], func_module: str
) -> FunctionNode:
    """Return the first match whose ``module`` equals *func_module*, else ``matches[0]``.

    *matches* must be non-empty; an empty list raises IndexError.
    """
    for candidate in matches:
        if candidate.module == func_module:
            return candidate
    return matches[0]
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
def _classify_calls(
    tree: ast.Module,
) -> tuple[set[str], set[str], set[tuple[str, str]]]:
    """Sort every call in *tree* into one of three buckets.

    Returns:
        ``(bare, self_calls, obj_method_calls)``. *bare* holds names called
        directly (``foo()``). *self_calls* holds method names whose receiver
        chain is rooted at ``self`` or ``cls`` (``self.a.run()`` contributes
        ``"run"``). *obj_method_calls* holds ``(root_var, method)`` pairs for
        every other name-rooted receiver. Calls whose receiver is not rooted
        at a plain name (e.g. ``f().g()``) are dropped.
    """
    bare: set[str] = set()
    on_self: set[str] = set()
    on_objects: set[tuple[str, str]] = set()

    for call in (n for n in ast.walk(tree) if isinstance(n, ast.Call)):
        callee = call.func
        if isinstance(callee, ast.Name):
            bare.add(callee.id)
            continue
        if not isinstance(callee, ast.Attribute):
            continue
        root = _resolve_root_var(callee.value)
        if root is None:
            continue
        if root in ("self", "cls"):
            on_self.add(callee.attr)
        else:
            on_objects.add((root, callee.attr))

    return bare, on_self, on_objects
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def _extract_param_types(tree: ast.Module) -> dict[str, set[str]]:
    """Map the first function def's annotated parameters to candidate type names.

    Only the first (possibly async) function definition yielded by
    :func:`ast.walk` is inspected. ``ast.walk`` is breadth-first, so this is
    the outermost one. Positional-only, regular and keyword-only parameters
    with an annotation are included. ``*args`` / ``**kwargs`` are skipped
    because their runtime values are a tuple/dict, not instances of the
    annotated type.

    Returns:
        ``{param_name: type_names}``, or ``{}`` when there is no function def
        or none of its parameters is annotated.
    """
    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        params = node.args
        param_types: dict[str, set[str]] = {}
        # posonlyargs (``def f(x: Foo, /)``) were previously ignored.
        for arg in (*params.posonlyargs, *params.args, *params.kwonlyargs):
            if arg.annotation is not None:
                param_types[arg.arg] = _extract_annotation_type_names(arg.annotation)
        if param_types:
            logger.debug("Type annotations: %s", {k: sorted(v) for k, v in param_types.items()})
        return param_types
    return {}
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def _resolve_bare_calls(
    called_names: set[str],
    func_module: str,
    registry: dict[str, FunctionNode],
) -> list[str]:
    """Match bare calls (``helper()``) against registered functions.

    For each called name the first rule that produces a match wins:

    1. nodes whose ``name`` equals the called name (one is chosen, preferring
       *func_module*);
    2. a constructor call: the ``__init__`` node whose ``owner_class`` equals
       the called name (one is chosen, preferring *func_module*);
    3. a class without a registered ``__init__``: every node whose
       ``owner_class`` equals the called name, so the whole class is emitted.
    """
    deps: list[str] = []
    nodes = list(registry.values())
    for called in called_names:
        by_name = [n for n in nodes if n.name == called]
        if by_name:
            picked = _prefer_same_module(by_name, func_module)
            logger.debug("Bare call %s() -> %s", called, picked.qualified_name)
            deps.append(picked.qualified_name)
            continue

        # ClassName() -> ClassName.__init__
        inits = [n for n in nodes if n.name == "__init__" and n.owner_class == called]
        if inits:
            picked = _prefer_same_module(inits, func_module)
            logger.debug("Constructor call %s() -> %s", called, picked.qualified_name)
            deps.append(picked.qualified_name)
            continue

        # No user-defined __init__ (inherits object.__init__): link every
        # registered method so the whole class block is emitted.
        for member in (n for n in nodes if n.owner_class == called):
            logger.debug("Bare class call %s() -> %s", called, member.qualified_name)
            deps.append(member.qualified_name)
    return deps
|
|
529
|
+
|
|
530
|
+
|
|
531
|
+
def _resolve_self_calls(
    self_calls: set[str],
    owner_class: str,
    func_module: str,
    registry: dict[str, FunctionNode],
    existing_deps: list[str],
) -> list[str]:
    """Match ``self.method()`` / ``cls.method()`` calls to methods of *owner_class*.

    Only nodes whose ``owner_class`` equals *owner_class* exactly are
    considered, so methods inherited from other classes are not found.
    Targets already in *existing_deps* are not returned again.
    """
    found: list[str] = []
    for method_name in self_calls:
        same_class = [
            fn for fn in registry.values()
            if fn.owner_class == owner_class and fn.name == method_name
        ]
        if not same_class:
            continue
        qname = _prefer_same_module(same_class, func_module).qualified_name
        if qname in existing_deps:
            continue
        logger.debug("self/cls call .%s() -> %s", method_name, qname)
        found.append(qname)
    return found
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
def _build_class_method_index(
    registry: dict[str, FunctionNode],
) -> dict[str, dict[str, str]]:
    """Index registered methods as ``{simple_class_name: {method_name: qualified_name}}``.

    The simple class name is the last dotted part of ``owner_class``.
    Functions without an owner class are skipped. When two registered
    methods share a simple class name and method name, the later one in
    registry order wins.
    """
    index: dict[str, dict[str, str]] = {}
    for fn in registry.values():
        if fn.owner_class is None:
            continue
        short = fn.owner_class.rpartition(".")[2]
        bucket = index.get(short)
        if bucket is None:
            bucket = index[short] = {}
        bucket[fn.name] = fn.qualified_name
    return index
|
|
565
|
+
|
|
566
|
+
|
|
567
|
+
def _resolve_obj_method_calls(
    obj_method_calls: set[tuple[str, str]],
    param_types: dict[str, set[str]],
    registry: dict[str, FunctionNode],
    existing_deps: list[str],
) -> list[str]:
    """Match ``obj.method()`` calls via type annotations or an unambiguous name.

    Each resolved target is returned at most once and is never one already
    present in *existing_deps*. Different receivers (``a.run()`` and
    ``b.run()``) can resolve to the same method, and each should produce a
    single dependency.
    """
    class_methods = _build_class_method_index(registry)
    deps: list[str] = []
    seen = set(existing_deps)

    for var_name, method_name in obj_method_calls:
        resolved = _resolve_single_obj_method(
            var_name, method_name, param_types, class_methods, registry,
        )
        if resolved is None or resolved in seen:
            continue
        seen.add(resolved)
        deps.append(resolved)

    return deps
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
def _resolve_single_obj_method(
    var_name: str,
    method_name: str,
    param_types: dict[str, set[str]],
    class_methods: dict[str, dict[str, str]],
    registry: dict[str, FunctionNode],
) -> str | None:
    """Resolve a single ``obj.method()`` call to a qualified name.

    When *var_name* is an annotated parameter, only its annotation's type
    names are consulted, and ``None`` is returned if none of them defines
    *method_name*. Otherwise the call is resolved only when exactly one
    registered method with that name exists across all classes.
    """
    if var_name in param_types:
        # Sorted: param_types values are sets of str whose iteration order
        # varies between runs (hash randomisation). Without sorting, a union
        # annotation like ``Foo | Bar`` with both classes defining the method
        # would resolve nondeterministically.
        for type_name in sorted(param_types[var_name]):
            methods = class_methods.get(type_name)
            if methods and method_name in methods:
                resolved = methods[method_name]
                logger.debug("%s.%s() resolved via annotation -> %s", var_name, method_name, resolved)
                return resolved
        return None

    candidates = [
        node.qualified_name for node in registry.values()
        if node.name == method_name and node.owner_class is not None
    ]
    if len(candidates) == 1:
        logger.debug("%s.%s() resolved via unambiguous match -> %s", var_name, method_name, candidates[0])
        return candidates[0]
    return None
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
def detect_traced_dependencies(
    func_source: str,
    func_module: str,
    registry: dict[str, FunctionNode],
    owner_class: str | None = None,
) -> list[str]:
    """Find qualified names of registered functions called in *func_source*.

    Three kinds of calls are resolved against *registry*:

    * bare calls ``helper()``, including ``ClassName()`` constructor calls;
    * ``self.method()`` / ``cls.method()``, only when *owner_class* is given;
    * ``obj.method()``, via parameter type annotations or an unambiguous
      method-name match.

    Returns:
        Sorted qualified names, each listed once. The resolvers can reach the
        same target by different routes: a bare ``run()`` matches any node
        named ``run``, including methods, and a bare ``X()`` on a class
        without ``__init__`` pulls in all of X's methods. The combined result
        is therefore de-duplicated.
    """
    tree = ast.parse(func_source)
    called_names, self_calls, obj_method_calls = _classify_calls(tree)
    param_types = _extract_param_types(tree)

    deps = _resolve_bare_calls(called_names, func_module, registry)

    if owner_class and self_calls:
        deps.extend(_resolve_self_calls(
            self_calls, owner_class, func_module, registry, deps,
        ))

    if obj_method_calls:
        deps.extend(_resolve_obj_method_calls(
            obj_method_calls, param_types, registry, deps,
        ))

    return sorted(set(deps))
|