experimaestro 1.11.1__py3-none-any.whl → 2.0.0b4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of experimaestro might be problematic. Click here for more details.
- experimaestro/__init__.py +10 -11
- experimaestro/annotations.py +167 -206
- experimaestro/cli/__init__.py +140 -16
- experimaestro/cli/filter.py +42 -74
- experimaestro/cli/jobs.py +157 -106
- experimaestro/cli/progress.py +269 -0
- experimaestro/cli/refactor.py +249 -0
- experimaestro/click.py +0 -1
- experimaestro/commandline.py +19 -3
- experimaestro/connectors/__init__.py +22 -3
- experimaestro/connectors/local.py +12 -0
- experimaestro/core/arguments.py +192 -37
- experimaestro/core/identifier.py +127 -12
- experimaestro/core/objects/__init__.py +6 -0
- experimaestro/core/objects/config.py +702 -285
- experimaestro/core/objects/config_walk.py +24 -6
- experimaestro/core/serialization.py +91 -34
- experimaestro/core/serializers.py +1 -8
- experimaestro/core/subparameters.py +164 -0
- experimaestro/core/types.py +198 -83
- experimaestro/exceptions.py +26 -0
- experimaestro/experiments/cli.py +107 -25
- experimaestro/generators.py +50 -9
- experimaestro/huggingface.py +3 -1
- experimaestro/launcherfinder/parser.py +29 -0
- experimaestro/launcherfinder/registry.py +3 -3
- experimaestro/launchers/__init__.py +26 -1
- experimaestro/launchers/direct.py +12 -0
- experimaestro/launchers/slurm/base.py +154 -2
- experimaestro/mkdocs/base.py +6 -8
- experimaestro/mkdocs/metaloader.py +0 -1
- experimaestro/mypy.py +452 -7
- experimaestro/notifications.py +75 -16
- experimaestro/progress.py +404 -0
- experimaestro/rpyc.py +0 -1
- experimaestro/run.py +19 -6
- experimaestro/scheduler/__init__.py +18 -1
- experimaestro/scheduler/base.py +504 -959
- experimaestro/scheduler/dependencies.py +43 -28
- experimaestro/scheduler/dynamic_outputs.py +259 -130
- experimaestro/scheduler/experiment.py +582 -0
- experimaestro/scheduler/interfaces.py +474 -0
- experimaestro/scheduler/jobs.py +485 -0
- experimaestro/scheduler/services.py +186 -12
- experimaestro/scheduler/signal_handler.py +32 -0
- experimaestro/scheduler/state.py +1 -1
- experimaestro/scheduler/state_db.py +388 -0
- experimaestro/scheduler/state_provider.py +2345 -0
- experimaestro/scheduler/state_sync.py +834 -0
- experimaestro/scheduler/workspace.py +52 -10
- experimaestro/scriptbuilder.py +7 -0
- experimaestro/server/__init__.py +153 -32
- experimaestro/server/data/index.css +0 -125
- experimaestro/server/data/index.css.map +1 -1
- experimaestro/server/data/index.js +194 -58
- experimaestro/server/data/index.js.map +1 -1
- experimaestro/settings.py +47 -6
- experimaestro/sphinx/__init__.py +3 -3
- experimaestro/taskglobals.py +20 -0
- experimaestro/tests/conftest.py +80 -0
- experimaestro/tests/core/test_generics.py +2 -2
- experimaestro/tests/identifier_stability.json +45 -0
- experimaestro/tests/launchers/bin/sacct +6 -2
- experimaestro/tests/launchers/bin/sbatch +4 -2
- experimaestro/tests/launchers/common.py +2 -2
- experimaestro/tests/launchers/test_slurm.py +80 -0
- experimaestro/tests/restart.py +1 -1
- experimaestro/tests/tasks/all.py +7 -0
- experimaestro/tests/tasks/test_dynamic.py +231 -0
- experimaestro/tests/test_checkers.py +2 -2
- experimaestro/tests/test_cli_jobs.py +615 -0
- experimaestro/tests/test_dependencies.py +11 -17
- experimaestro/tests/test_deprecated.py +630 -0
- experimaestro/tests/test_environment.py +200 -0
- experimaestro/tests/test_experiment.py +3 -3
- experimaestro/tests/test_file_progress.py +425 -0
- experimaestro/tests/test_file_progress_integration.py +477 -0
- experimaestro/tests/test_forward.py +3 -3
- experimaestro/tests/test_generators.py +93 -0
- experimaestro/tests/test_identifier.py +520 -169
- experimaestro/tests/test_identifier_stability.py +458 -0
- experimaestro/tests/test_instance.py +16 -21
- experimaestro/tests/test_multitoken.py +442 -0
- experimaestro/tests/test_mypy.py +433 -0
- experimaestro/tests/test_objects.py +314 -30
- experimaestro/tests/test_outputs.py +8 -8
- experimaestro/tests/test_param.py +22 -26
- experimaestro/tests/test_partial_paths.py +231 -0
- experimaestro/tests/test_progress.py +2 -50
- experimaestro/tests/test_resumable_task.py +480 -0
- experimaestro/tests/test_serializers.py +141 -60
- experimaestro/tests/test_state_db.py +434 -0
- experimaestro/tests/test_subparameters.py +160 -0
- experimaestro/tests/test_tags.py +151 -15
- experimaestro/tests/test_tasks.py +137 -160
- experimaestro/tests/test_token_locking.py +252 -0
- experimaestro/tests/test_tokens.py +25 -19
- experimaestro/tests/test_types.py +133 -11
- experimaestro/tests/test_validation.py +19 -19
- experimaestro/tests/test_workspace_triggers.py +158 -0
- experimaestro/tests/token_reschedule.py +5 -3
- experimaestro/tests/utils.py +2 -2
- experimaestro/tokens.py +154 -57
- experimaestro/tools/diff.py +8 -1
- experimaestro/tui/__init__.py +8 -0
- experimaestro/tui/app.py +2303 -0
- experimaestro/tui/app.tcss +353 -0
- experimaestro/tui/log_viewer.py +228 -0
- experimaestro/typingutils.py +11 -2
- experimaestro/utils/__init__.py +23 -0
- experimaestro/utils/environment.py +148 -0
- experimaestro/utils/git.py +129 -0
- experimaestro/utils/resources.py +1 -1
- experimaestro/version.py +34 -0
- {experimaestro-1.11.1.dist-info → experimaestro-2.0.0b4.dist-info}/METADATA +70 -39
- experimaestro-2.0.0b4.dist-info/RECORD +181 -0
- {experimaestro-1.11.1.dist-info → experimaestro-2.0.0b4.dist-info}/WHEEL +1 -1
- experimaestro-2.0.0b4.dist-info/entry_points.txt +16 -0
- experimaestro/compat.py +0 -6
- experimaestro/core/objects.pyi +0 -225
- experimaestro/server/data/0c35d18bf06992036b69.woff2 +0 -0
- experimaestro/server/data/219aa9140e099e6c72ed.woff2 +0 -0
- experimaestro/server/data/3a4004a46a653d4b2166.woff +0 -0
- experimaestro/server/data/3baa5b8f3469222b822d.woff +0 -0
- experimaestro/server/data/4d73cb90e394b34b7670.woff +0 -0
- experimaestro/server/data/4ef4218c522f1eb6b5b1.woff2 +0 -0
- experimaestro/server/data/5d681e2edae8c60630db.woff +0 -0
- experimaestro/server/data/6f420cf17cc0d7676fad.woff2 +0 -0
- experimaestro/server/data/c380809fd3677d7d6903.woff2 +0 -0
- experimaestro/server/data/f882956fd323fd322f31.woff +0 -0
- experimaestro-1.11.1.dist-info/RECORD +0 -158
- experimaestro-1.11.1.dist-info/entry_points.txt +0 -17
- {experimaestro-1.11.1.dist-info → experimaestro-2.0.0b4.dist-info/licenses}/LICENSE +0 -0
experimaestro/mypy.py
CHANGED
|
@@ -1,15 +1,460 @@
|
|
|
1
|
-
|
|
1
|
+
"""Mypy plugin for experimaestro.
|
|
2
2
|
|
|
3
|
+
This plugin provides type hints support for experimaestro's Config system,
|
|
4
|
+
particularly for the Config.C pattern and proper parameter type inference.
|
|
3
5
|
|
|
4
|
-
|
|
5
|
-
|
|
6
|
+
The plugin handles:
|
|
7
|
+
- Config.C, Config.XPMConfig, Config.XPMValue class properties
|
|
8
|
+
- Adding __init__ with proper Param field signatures
|
|
9
|
+
- Adding ConfigMixin to the class hierarchy for method access
|
|
10
|
+
- Handling task_outputs return type for submit()
|
|
6
11
|
|
|
7
|
-
|
|
8
|
-
|
|
12
|
+
Usage in mypy.ini or setup.cfg:
|
|
13
|
+
[mypy]
|
|
14
|
+
plugins = experimaestro.mypy
|
|
15
|
+
|
|
16
|
+
Or in pyproject.toml:
|
|
17
|
+
[tool.mypy]
|
|
18
|
+
plugins = ["experimaestro.mypy"]
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
from typing import Callable, List, Optional
|
|
24
|
+
|
|
25
|
+
from mypy.nodes import (
|
|
26
|
+
TypeInfo,
|
|
27
|
+
Var,
|
|
28
|
+
Argument,
|
|
29
|
+
ARG_NAMED_OPT,
|
|
30
|
+
ARG_NAMED,
|
|
31
|
+
)
|
|
32
|
+
from mypy.plugin import Plugin, ClassDefContext
|
|
33
|
+
from mypy.plugins.common import add_attribute_to_class, add_method_to_class
|
|
34
|
+
from mypy.types import (
|
|
35
|
+
Instance,
|
|
36
|
+
TypeType,
|
|
37
|
+
NoneType,
|
|
38
|
+
)
|
|
39
|
+
from mypy.mro import calculate_mro, MroError
|
|
40
|
+
|
|
41
|
+
# Experimaestro base classes: classes deriving from them need the C /
# XPMConfig attributes. Each class is listed under every import path it is
# exposed through (defining module, ``experimaestro.core.objects`` package,
# and top-level ``experimaestro`` package).
CONFIG_FULLNAMES = {
    f"{module}.{class_name}"
    for module in (
        "experimaestro.core.objects.config",
        "experimaestro.core.objects",
        "experimaestro",
    )
    for class_name in ("Config", "LightweightTask", "Task", "ResumableTask")
}

# ConfigMixin is grafted onto Config subclasses so mypy sees its methods
CONFIGMIXIN_FULLNAME = "experimaestro.core.objects.config.ConfigMixin"

# Argument annotation markers are importable from their defining module and
# from the top-level package
_ARGUMENT_MODULES = ("experimaestro.core.arguments", "experimaestro")

# Param[T]: required by default
PARAM_FULLNAMES = {f"{module}.Param" for module in _ARGUMENT_MODULES}

# Meta[T] / Option[T]: always optional, ignored in identifier
META_FULLNAMES = {
    f"{module}.{marker}"
    for module in _ARGUMENT_MODULES
    for marker in ("Meta", "Option")
}

# Constant[T]: excluded from __init__
CONSTANT_FULLNAMES = {f"{module}.Constant" for module in _ARGUMENT_MODULES}
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def is_config_subclass(info: TypeInfo) -> bool:
    """Tell whether a class is an experimaestro Config class or derives from one.

    Args:
        info: mypy type information of the class to inspect

    Returns:
        True when ``info`` itself, or any class in its MRO, is listed in
        CONFIG_FULLNAMES; False otherwise
    """
    return info.fullname in CONFIG_FULLNAMES or any(
        ancestor.fullname in CONFIG_FULLNAMES for ancestor in info.mro
    )
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
# Names that must never become __init__ parameters: the class-level accessors
# exposed on Config classes, plus experimaestro's internal attributes.
_ACCESSOR_FIELDS = ("C", "XPMConfig", "XPMValue")
_INTERNAL_FIELDS = ("__xpm__", "__xpmtype__", "__xpmid__", "_deprecated_from")
SKIP_FIELDS = set(_ACCESSOR_FIELDS + _INTERNAL_FIELDS)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _is_config_class(base: TypeInfo) -> bool:
    """Return True when ``base`` is, or derives from, a known Config class.

    Only the MRO of ``base`` is consulted (mypy lists the class itself first
    in its MRO). This is what filters out non-Config bases such as
    ``nn.Module`` when collecting fields.
    """
    return any(ancestor.fullname in CONFIG_FULLNAMES for ancestor in base.mro)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _get_annotation_type_str(name: str, base: TypeInfo) -> Optional[str]:
    """Get the type annotation string for a field.

    Tries multiple sources to find the original annotation:

    1. The class body AST: the first assignment to ``name`` that carries a
       type, preferring its ``unanalyzed_type`` (the annotation as written)
       over its analyzed ``type``
    2. The type recorded on the ``name`` variable of the class symbol table

    Args:
        name: Field name
        base: Class whose own definition is inspected

    Returns:
        The annotation rendered as a string, or None if none was found
    """
    if base.defn is not None:
        # Loop-invariant: import the AST node classes once
        from mypy.nodes import AssignmentStmt, NameExpr

        for stmt in base.defn.defs.body:
            if not isinstance(stmt, AssignmentStmt):
                continue
            assigns_name = any(
                isinstance(lvalue, NameExpr) and lvalue.name == name
                for lvalue in stmt.lvalues
            )
            if not assigns_name:
                continue

            # unanalyzed_type preserves the original annotation
            # (e.g. "Meta?[Path?]"); fall back to the analyzed type
            annotation = (
                stmt.unanalyzed_type
                if stmt.unanalyzed_type is not None
                else stmt.type
            )
            if annotation is not None:
                return str(annotation)

    # Fall back to checking the symbol's type
    sym = base.names.get(name)
    if sym is not None and isinstance(sym.node, Var) and sym.node.type is not None:
        return str(sym.node.type)

    return None
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _has_annotation_marker(
    type_str: str,
    markers: tuple[str, ...],
    fullnames: set[str] | tuple[str, ...] = (),
) -> bool:
    """Check whether an annotation string applies one of the given markers.

    Args:
        type_str: Annotation as rendered by mypy, e.g. ``"Meta?[Path?]"``
            (unanalyzed) or ``"experimaestro.Meta[pathlib.Path]"``
        markers: Lower-case marker names, e.g. ``("meta", "option")``
        fullnames: Fully qualified marker names that also count when they
            appear without type arguments

    Returns:
        True if a marker is used as a generic (``marker[``) and matched as
        a whole identifier, so that user types such as ``MyConstant[int]``
        or ``FormatOption[int]`` do not count. Also True if one of
        ``fullnames`` appears as a whole dotted name.
    """
    import re

    # Normalize type string - remove optional markers (?), since mypy
    # renders unanalyzed names like "Constant?[str?]"
    normalized = type_str.lower().replace("?", "")

    if markers:
        alternatives = "|".join(re.escape(marker) for marker in markers)
        # (?<!\w): the marker must start an identifier; a preceding "." is
        # allowed so that qualified names like "experimaestro.meta[" match
        if re.search(rf"(?<!\w)(?:{alternatives})\[", normalized):
            return True

    return any(
        re.search(rf"(?<![\w.]){re.escape(fullname.lower())}(?!\w)", normalized)
        for fullname in fullnames
    )


def _is_constant_field(name: str, base: TypeInfo) -> bool:
    """Check if a field is declared as Constant[T].

    Constant fields should be excluded from __init__.
    """
    type_str = _get_annotation_type_str(name, base)
    if type_str is None:
        return False
    return _has_annotation_marker(type_str, ("constant",), CONSTANT_FULLNAMES)


def _is_meta_field(name: str, base: TypeInfo) -> bool:
    """Check if a field is declared as Meta[T] or Option[T].

    Meta fields should always be optional in __init__.
    """
    type_str = _get_annotation_type_str(name, base)
    if type_str is None:
        return False
    return _has_annotation_marker(type_str, ("meta", "option"), META_FULLNAMES)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _get_param_fields(info: TypeInfo) -> List[tuple]:
    """Extract Param and Meta fields from a class and its bases.

    Returns list of (name, type, has_default) tuples, ordered by the first
    declaration of each field (base classes first).

    Only includes fields from Config subclasses to avoid picking up
    attributes from other base classes like nn.Module.
    Excludes Constant fields which should not be in __init__.

    When a subclass redeclares an inherited field, the subclass declaration
    wins, as for Python attribute lookup. The field keeps its original
    position in the signature.
    """
    # name -> (name, type, has_default). Re-assigning an existing key keeps
    # its insertion position, which preserves base-first ordering.
    fields: dict = {}

    # Walk the MRO from the most basic class down to ``info`` itself, so that
    # more derived declarations are processed last and override earlier ones
    for base in reversed(info.mro):
        if base.fullname == "builtins.object":
            continue
        if base.fullname in CONFIG_FULLNAMES or base.fullname == CONFIGMIXIN_FULLNAME:
            # Skip Config/Task base classes and ConfigMixin - we only want
            # user-defined fields
            continue
        if not _is_config_class(base):
            # This skips bases like nn.Module that don't inherit from Config
            continue

        for name, sym in base.names.items():
            # Skip plugin/framework names and private/dunder fields
            if name in SKIP_FIELDS or name.startswith("_"):
                continue
            var = sym.node
            if not isinstance(var, Var):
                continue

            inherited = fields.get(name)

            if var.type is None:
                # Unannotated re-assignment (e.g. ``x = 3`` overriding an
                # inherited param): keep the inherited type, record the default
                if inherited is not None and var.has_explicit_value:
                    fields[name] = (name, inherited[1], True)
                continue

            if _is_constant_field(name, base):
                # Constant fields are not in __init__, even when a base
                # class declared the same name as a Param
                fields.pop(name, None)
                continue

            # Meta fields are always optional; Param fields are optional if
            # they have a default here or an inherited one
            has_default = (
                var.has_explicit_value
                or _is_meta_field(name, base)
                or (inherited is not None and inherited[2])
            )
            fields[name] = (name, var.type, has_default)

    return list(fields.values())
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def _add_init_method(ctx: ClassDefContext) -> None:
    """Synthesize a keyword-only ``__init__`` from the class's Param/Meta fields.

    Fields without a default become required keyword arguments
    (``ARG_NAMED``). Fields with a default, and Meta fields, become optional
    ones (``ARG_NAMED_OPT``). When the class has no such field, nothing is
    added and the class's existing ``__init__`` resolution is left untouched.
    """
    init_args = [
        Argument(
            variable=Var(field_name, field_type),
            type_annotation=field_type,
            initializer=None,
            kind=ARG_NAMED_OPT if optional else ARG_NAMED,
        )
        for field_name, field_type, optional in _get_param_fields(ctx.cls.info)
    ]

    if not init_args:
        return

    add_method_to_class(ctx.api, ctx.cls, "__init__", init_args, NoneType())
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def _get_task_outputs_return_type(info: TypeInfo) -> Optional[Instance]:
    """Return the declared return type of ``task_outputs`` for a class.

    If the class, or one of its ancestors, defines task_outputs, submit()
    should return that type instead of Self. The nearest definition in the
    MRO is used, matching Python's method lookup. Experimaestro's own base
    classes and ConfigMixin are skipped, so that a library-level definition
    (if any) never replaces the Self fallback.

    Args:
        info: Class being processed

    Returns:
        The return type of the nearest ``task_outputs`` if it is a plain,
        annotated function; None otherwise
    """
    from mypy.nodes import FuncDef
    from mypy.types import CallableType

    # mypy's MRO starts with ``info`` itself; fall back to it if not computed
    for base in info.mro or [info]:
        if base.fullname in CONFIG_FULLNAMES or base.fullname == CONFIGMIXIN_FULLNAME:
            continue
        sym = base.names.get("task_outputs")
        if sym is None:
            continue

        # The nearest definition wins, even if we cannot interpret it
        node = sym.node
        if isinstance(node, FuncDef) and isinstance(node.type, CallableType):
            return node.type.ret_type
        return None

    return None
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _add_configmixin_to_bases(ctx: ClassDefContext) -> None:
    """Add ConfigMixin to the class bases if not already present.

    This allows mypy to see all ConfigMixin methods on Config subclasses.
    The change is best effort. If ConfigMixin cannot be resolved, or the
    MRO cannot be recomputed with it, the class's bases are left as they
    were.
    """
    info = ctx.cls.info

    if any(base.fullname == CONFIGMIXIN_FULLNAME for base in info.mro):
        return  # Already has ConfigMixin

    try:
        configmixin_sym = ctx.api.lookup_fully_qualified_or_none(CONFIGMIXIN_FULLNAME)
    except Exception:
        # If lookup fails, continue without adding ConfigMixin
        return
    if configmixin_sym is None or not isinstance(configmixin_sym.node, TypeInfo):
        return

    # Only a base appended here may be rolled back below
    appended = not any(
        isinstance(b, Instance) and b.type.fullname == CONFIGMIXIN_FULLNAME
        for b in info.bases
    )
    if appended:
        info.bases.append(Instance(configmixin_sym.node, []))

    try:
        calculate_mro(info)
    except MroError:
        # The hierarchy cannot be linearized with ConfigMixin: undo our change
        if appended:
            info.bases.pop()
    except Exception:
        # Best effort: never crash mypy from here, but do not leave a
        # half-applied base behind either
        if appended:
            info.bases.pop()
|
|
12
352
|
|
|
13
353
|
|
|
14
|
-
def
|
|
354
|
+
def _add_submit_method(ctx: ClassDefContext) -> None:
    """Add submit() method that returns Self (or task_outputs return type).

    The actual submit() signature from ConfigMixin:
        def submit(self, *, workspace=None, launcher=None, run_mode=None,
                   init_tasks=[], max_retries=None)

    All of these are exposed as optional keyword-only ``Any`` arguments.
    Self is built with the class's own type variables, so for generic
    Config classes the return type follows the instance's type arguments.
    """
    from mypy.types import AnyType, TypeOfAny
    from mypy.typevars import fill_typevars

    info = ctx.cls.info

    # submit() returns task_outputs return type if defined, otherwise Self
    task_outputs_type = _get_task_outputs_return_type(info)
    if task_outputs_type is not None:
        return_type = task_outputs_type
    else:
        # Instance(info, []) would carry the wrong number of type arguments
        # for generic classes; fill_typevars uses the class's type variables
        return_type = fill_typevars(info)

    any_type = AnyType(TypeOfAny.explicit)
    submit_args = [
        Argument(
            variable=Var(arg_name, any_type),
            type_annotation=any_type,
            initializer=None,
            kind=ARG_NAMED_OPT,
        )
        for arg_name in (
            "workspace",
            "launcher",
            "run_mode",
            "init_tasks",
            "max_retries",
        )
    ]

    add_method_to_class(ctx.api, ctx.cls, "submit", submit_args, return_type)
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def _process_config_class(ctx: ClassDefContext) -> None:
    """Process a Config subclass to add type hints.

    This adds:
    - ConfigMixin to the class hierarchy for method access
    - C, XPMConfig, XPMValue as class attributes returning Type[Self]
    - An __init__ method with proper Param field signatures
    - A submit() method that returns Self (or task_outputs return type)
    """
    from mypy.typevars import fill_typevars_with_any

    info = ctx.cls.info

    # Add ConfigMixin to bases for method access (tag, instance, etc.)
    _add_configmixin_to_bases(ctx)

    # Type[Self] for this class. For generic classes the type arguments are
    # filled with Any, so the Instance always has the right number of
    # arguments (Instance(info, []) would not).
    type_type = TypeType(fill_typevars_with_any(info))

    # Add C, XPMConfig, XPMValue as class attributes returning the class type,
    # unless the class already defines them
    for attr_name in ("C", "XPMConfig", "XPMValue"):
        if attr_name not in info.names:
            add_attribute_to_class(ctx.api, ctx.cls, attr_name, type_type)

    # Add __init__ with proper field signatures
    _add_init_method(ctx)

    # Add submit() method that returns Self (or task_outputs type)
    _add_submit_method(ctx)
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
class ExperimaestroPlugin(Plugin):
    """Mypy plugin for experimaestro type hints.

    For every class deriving from an experimaestro Config class, this plugin:

    - adds ConfigMixin to the class hierarchy so its methods are visible
    - adds C, XPMConfig and XPMValue as class attributes typed Type[Self]
    - adds an __init__ with keyword-only Param/Meta arguments
    - adds a submit() returning Self (or the task_outputs return type)
    """

    def get_base_class_hook(
        self, fullname: str
    ) -> Callable[[ClassDefContext], None] | None:
        """Return the Config-processing hook for Config base classes.

        mypy calls this hook with the full name of each base class as it is
        declared in a class definition. For ``class B(A)`` where ``A``
        derives from ``Config``, ``fullname`` is ``A``'s name rather than
        ``Config``'s. So besides the well-known experimaestro classes, any
        base whose MRO contains one of them is accepted too.

        Args:
            fullname: Fully qualified name of a declared base class

        Returns:
            ``_process_config_class`` for Config bases, None otherwise
        """
        if fullname in CONFIG_FULLNAMES:
            return _process_config_class

        # User-defined Config subclasses used as base classes
        sym = self.lookup_fully_qualified(fullname)
        if (
            sym is not None
            and isinstance(sym.node, TypeInfo)
            and is_config_subclass(sym.node)
        ):
            return _process_config_class
        return None
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
def plugin(_version: str) -> type[Plugin]:
    """Entry point for mypy plugin.

    Args:
        _version: The mypy version string (unused but required by mypy API)

    Returns:
        The ExperimaestroPlugin class (the class itself, not an instance)
    """
    return ExperimaestroPlugin
|
experimaestro/notifications.py
CHANGED
|
@@ -12,6 +12,7 @@ from tqdm.auto import tqdm as std_tqdm
|
|
|
12
12
|
|
|
13
13
|
from .utils import logger
|
|
14
14
|
from experimaestro.taskglobals import Env as TaskEnv
|
|
15
|
+
from .progress import FileBasedProgressReporter
|
|
15
16
|
|
|
16
17
|
# --- Progress and other notifications
|
|
17
18
|
|
|
@@ -41,7 +42,7 @@ class LevelInformation:
|
|
|
41
42
|
return result
|
|
42
43
|
|
|
43
44
|
def __repr__(self) -> str:
    # Progress is shown as a percentage, truncated to one decimal place
    percentage = int(self.progress * 1000) / 10
    return f"[{self.level}] {self.desc} {percentage}%"
|
|
45
46
|
|
|
46
47
|
|
|
47
48
|
class ListenerInformation:
|
|
@@ -79,10 +80,14 @@ class Reporter(threading.Thread):
|
|
|
79
80
|
self.progress_threshold = 0.01
|
|
80
81
|
self.cv = threading.Condition()
|
|
81
82
|
|
|
83
|
+
# File-based progress reporter
|
|
84
|
+
self.file_reporter = FileBasedProgressReporter(task_path=path)
|
|
85
|
+
|
|
82
86
|
def stop(self):
    """Flag the reporter as stopping and wake any thread waiting on ``self.cv``."""
    self.stopping = True
    with self.cv:
        # notify_all() is the non-deprecated spelling of notifyAll()
        self.cv.notify_all()
|
|
86
91
|
|
|
87
92
|
@staticmethod
|
|
88
93
|
def isfatal_httperror(e: Exception, info: ListenerInformation) -> bool:
|
|
@@ -110,14 +115,27 @@ class Reporter(threading.Thread):
|
|
|
110
115
|
|
|
111
116
|
def check_urls(self):
    """Check whether we have new schedulers to notify

    Each file in ``self.path`` holds one notification URL. The file is read,
    registered in ``self.urls`` under its file name, then deleted. The
    directory is only scanned when its modification time is newer than the
    last check. A missing directory (e.g. removed during cleanup) is ignored.
    """
    try:
        mtime = os.path.getmtime(self.path)
    except OSError:
        # Path was deleted (FileNotFoundError is a subclass of OSError)
        return

    if mtime <= self.lastcheck:
        return

    try:
        for f in self.path.iterdir():
            self.urls[f.name] = ListenerInformation(f.read_text().strip())
            logger.info("Added new notification URL: %s", self.urls[f.name].url)
            f.unlink()
    except FileNotFoundError:
        # The directory (or one of its files) disappeared during the scan;
        # lastcheck is left untouched so a later call retries if needed
        return

    # Record the mtime observed *before* scanning: a URL file created during
    # the scan bumps the directory mtime past it, so the next call rescans
    # instead of missing that file
    self.lastcheck = mtime
|
|
121
139
|
|
|
122
140
|
def run(self):
|
|
123
141
|
logger.info("Running notification thread")
|
|
@@ -186,7 +204,7 @@ class Reporter(threading.Thread):
|
|
|
186
204
|
try:
|
|
187
205
|
with urlopen(url) as _:
|
|
188
206
|
logger.debug(
|
|
189
|
-
"EOJ
|
|
207
|
+
"EOJ notification sent for %s",
|
|
190
208
|
baseurl,
|
|
191
209
|
)
|
|
192
210
|
except Exception:
|
|
@@ -194,6 +212,8 @@ class Reporter(threading.Thread):
|
|
|
194
212
|
"Could not report EOJ",
|
|
195
213
|
)
|
|
196
214
|
|
|
215
|
+
self.file_reporter.eoj()
|
|
216
|
+
|
|
197
217
|
def set_progress(
|
|
198
218
|
self, progress: float, level: int, desc: Optional[str], console=False
|
|
199
219
|
):
|
|
@@ -212,6 +232,8 @@ class Reporter(threading.Thread):
|
|
|
212
232
|
self.levels[level].desc = desc
|
|
213
233
|
self.levels[level].progress = progress
|
|
214
234
|
|
|
235
|
+
self.file_reporter.set_progress(progress, level, desc)
|
|
236
|
+
|
|
215
237
|
self.cv.notify_all()
|
|
216
238
|
|
|
217
239
|
INSTANCE: ClassVar[Optional["Reporter"]] = None
|
|
@@ -227,12 +249,21 @@ class Reporter(threading.Thread):
|
|
|
227
249
|
|
|
228
250
|
|
|
229
251
|
def progress(value: float, level=0, desc: Optional[str] = None, console=False):
|
|
230
|
-
"""
|
|
252
|
+
"""Report task progress to the experimaestro server.
|
|
231
253
|
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
254
|
+
Call this function from within a running task to report progress.
|
|
255
|
+
Progress is displayed in the web UI and TUI monitors.
|
|
256
|
+
|
|
257
|
+
Example::
|
|
258
|
+
|
|
259
|
+
for i, batch in enumerate(dataloader):
|
|
260
|
+
train(batch)
|
|
261
|
+
progress(i / len(dataloader), desc="Training")
|
|
262
|
+
|
|
263
|
+
:param value: Progress value between 0.0 and 1.0
|
|
264
|
+
:param level: Nesting level for nested progress bars (default: 0)
|
|
265
|
+
:param desc: Optional description of the current operation
|
|
266
|
+
:param console: If True, also print to console when no server is available
|
|
236
267
|
"""
|
|
237
268
|
if TaskEnv.instance().slave:
|
|
238
269
|
# Skip if in a slave process
|
|
@@ -246,8 +277,21 @@ def report_eoj():
|
|
|
246
277
|
|
|
247
278
|
|
|
248
279
|
class xpm_tqdm(std_tqdm):
|
|
249
|
-
"""
|
|
250
|
-
|
|
280
|
+
"""Experimaestro-aware tqdm progress bar.
|
|
281
|
+
|
|
282
|
+
A drop-in replacement for ``tqdm`` that automatically reports progress
|
|
283
|
+
to the experimaestro server. Use this instead of the standard ``tqdm``
|
|
284
|
+
in your task's ``execute()`` method.
|
|
285
|
+
|
|
286
|
+
Example::
|
|
287
|
+
|
|
288
|
+
from experimaestro import tqdm
|
|
289
|
+
|
|
290
|
+
class MyTask(Task):
|
|
291
|
+
def execute(self):
|
|
292
|
+
for batch in tqdm(dataloader, desc="Training"):
|
|
293
|
+
train(batch)
|
|
294
|
+
"""
|
|
251
295
|
|
|
252
296
|
def __init__(self, iterable=None, file=None, *args, **kwargs):
|
|
253
297
|
# Report progress bar
|
|
@@ -270,14 +314,29 @@ class xpm_tqdm(std_tqdm):
|
|
|
270
314
|
|
|
271
315
|
|
|
272
316
|
@overload
def tqdm(**kwargs) -> xpm_tqdm: ...


# NOTE(review): this overload only accepts Iterator[T], so type checkers
# reject plain iterables such as range(...) or a dataloader (as used in the
# docstring example below); Iterable[T] is probably intended - confirm and
# import it at module level.
@overload
def tqdm(iterable: Optional[Iterator[T]] = None, **kwargs) -> Iterator[T]: ...


def tqdm(*args, **kwargs):
    """Build an experimaestro-aware progress bar.

    Drop-in replacement for ``tqdm.tqdm``: every argument is forwarded to
    :class:`xpm_tqdm`, which reports progress to the experimaestro server.
    Meant to be used inside a task's ``execute()`` method.

    Example::

        from experimaestro import tqdm

        for epoch in tqdm(range(100), desc="Epochs"):
            for batch in tqdm(dataloader, desc="Batches"):
                train(batch)

    :param iterable: Iterable to wrap (optional)
    :param kwargs: Additional arguments passed to tqdm
    :return: A progress bar iterator
    """
    progress_bar = xpm_tqdm(*args, **kwargs)  # type: ignore
    return progress_bar
|