experimaestro 2.0.0a8__py3-none-any.whl → 2.0.0b4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of experimaestro might be problematic. Click here for more details.
- experimaestro/__init__.py +10 -11
- experimaestro/annotations.py +167 -206
- experimaestro/cli/__init__.py +130 -5
- experimaestro/cli/filter.py +42 -74
- experimaestro/cli/jobs.py +157 -106
- experimaestro/cli/refactor.py +249 -0
- experimaestro/click.py +0 -1
- experimaestro/commandline.py +19 -3
- experimaestro/connectors/__init__.py +20 -1
- experimaestro/connectors/local.py +12 -0
- experimaestro/core/arguments.py +182 -46
- experimaestro/core/identifier.py +107 -6
- experimaestro/core/objects/__init__.py +6 -0
- experimaestro/core/objects/config.py +542 -25
- experimaestro/core/objects/config_walk.py +20 -0
- experimaestro/core/serialization.py +91 -34
- experimaestro/core/subparameters.py +164 -0
- experimaestro/core/types.py +175 -38
- experimaestro/exceptions.py +26 -0
- experimaestro/experiments/cli.py +107 -25
- experimaestro/generators.py +50 -9
- experimaestro/huggingface.py +3 -1
- experimaestro/launcherfinder/parser.py +29 -0
- experimaestro/launchers/__init__.py +26 -1
- experimaestro/launchers/direct.py +12 -0
- experimaestro/launchers/slurm/base.py +154 -2
- experimaestro/mkdocs/metaloader.py +0 -1
- experimaestro/mypy.py +452 -7
- experimaestro/notifications.py +63 -13
- experimaestro/progress.py +0 -2
- experimaestro/rpyc.py +0 -1
- experimaestro/run.py +19 -6
- experimaestro/scheduler/base.py +489 -125
- experimaestro/scheduler/dependencies.py +43 -28
- experimaestro/scheduler/dynamic_outputs.py +259 -130
- experimaestro/scheduler/experiment.py +225 -30
- experimaestro/scheduler/interfaces.py +474 -0
- experimaestro/scheduler/jobs.py +216 -206
- experimaestro/scheduler/services.py +186 -12
- experimaestro/scheduler/state_db.py +388 -0
- experimaestro/scheduler/state_provider.py +2345 -0
- experimaestro/scheduler/state_sync.py +834 -0
- experimaestro/scheduler/workspace.py +52 -10
- experimaestro/scriptbuilder.py +7 -0
- experimaestro/server/__init__.py +147 -57
- experimaestro/server/data/index.css +0 -125
- experimaestro/server/data/index.css.map +1 -1
- experimaestro/server/data/index.js +194 -58
- experimaestro/server/data/index.js.map +1 -1
- experimaestro/settings.py +44 -5
- experimaestro/sphinx/__init__.py +3 -3
- experimaestro/taskglobals.py +20 -0
- experimaestro/tests/conftest.py +80 -0
- experimaestro/tests/core/test_generics.py +2 -2
- experimaestro/tests/identifier_stability.json +45 -0
- experimaestro/tests/launchers/bin/sacct +6 -2
- experimaestro/tests/launchers/bin/sbatch +4 -2
- experimaestro/tests/launchers/test_slurm.py +80 -0
- experimaestro/tests/tasks/test_dynamic.py +231 -0
- experimaestro/tests/test_cli_jobs.py +615 -0
- experimaestro/tests/test_deprecated.py +630 -0
- experimaestro/tests/test_environment.py +200 -0
- experimaestro/tests/test_file_progress_integration.py +1 -1
- experimaestro/tests/test_forward.py +3 -3
- experimaestro/tests/test_identifier.py +372 -41
- experimaestro/tests/test_identifier_stability.py +458 -0
- experimaestro/tests/test_instance.py +3 -3
- experimaestro/tests/test_multitoken.py +442 -0
- experimaestro/tests/test_mypy.py +433 -0
- experimaestro/tests/test_objects.py +312 -5
- experimaestro/tests/test_outputs.py +2 -2
- experimaestro/tests/test_param.py +8 -12
- experimaestro/tests/test_partial_paths.py +231 -0
- experimaestro/tests/test_progress.py +0 -48
- experimaestro/tests/test_resumable_task.py +480 -0
- experimaestro/tests/test_serializers.py +141 -1
- experimaestro/tests/test_state_db.py +434 -0
- experimaestro/tests/test_subparameters.py +160 -0
- experimaestro/tests/test_tags.py +136 -0
- experimaestro/tests/test_tasks.py +107 -121
- experimaestro/tests/test_token_locking.py +252 -0
- experimaestro/tests/test_tokens.py +17 -13
- experimaestro/tests/test_types.py +123 -1
- experimaestro/tests/test_workspace_triggers.py +158 -0
- experimaestro/tests/token_reschedule.py +4 -2
- experimaestro/tests/utils.py +2 -2
- experimaestro/tokens.py +154 -57
- experimaestro/tools/diff.py +1 -1
- experimaestro/tui/__init__.py +8 -0
- experimaestro/tui/app.py +2303 -0
- experimaestro/tui/app.tcss +353 -0
- experimaestro/tui/log_viewer.py +228 -0
- experimaestro/utils/__init__.py +23 -0
- experimaestro/utils/environment.py +148 -0
- experimaestro/utils/git.py +129 -0
- experimaestro/utils/resources.py +1 -1
- experimaestro/version.py +34 -0
- {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b4.dist-info}/METADATA +68 -38
- experimaestro-2.0.0b4.dist-info/RECORD +181 -0
- {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b4.dist-info}/WHEEL +1 -1
- experimaestro-2.0.0b4.dist-info/entry_points.txt +16 -0
- experimaestro/compat.py +0 -6
- experimaestro/core/objects.pyi +0 -221
- experimaestro/server/data/0c35d18bf06992036b69.woff2 +0 -0
- experimaestro/server/data/219aa9140e099e6c72ed.woff2 +0 -0
- experimaestro/server/data/3a4004a46a653d4b2166.woff +0 -0
- experimaestro/server/data/3baa5b8f3469222b822d.woff +0 -0
- experimaestro/server/data/4d73cb90e394b34b7670.woff +0 -0
- experimaestro/server/data/4ef4218c522f1eb6b5b1.woff2 +0 -0
- experimaestro/server/data/5d681e2edae8c60630db.woff +0 -0
- experimaestro/server/data/6f420cf17cc0d7676fad.woff2 +0 -0
- experimaestro/server/data/c380809fd3677d7d6903.woff2 +0 -0
- experimaestro/server/data/f882956fd323fd322f31.woff +0 -0
- experimaestro-2.0.0a8.dist-info/RECORD +0 -166
- experimaestro-2.0.0a8.dist-info/entry_points.txt +0 -17
- {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b4.dist-info}/licenses/LICENSE +0 -0
|
@@ -3,8 +3,20 @@ from . import Launcher
|
|
|
3
3
|
|
|
4
4
|
|
|
5
5
|
class DirectLauncher(Launcher):
    """Launcher running tasks as plain local processes.

    This is the default launcher: tasks are executed on the local machine,
    without any job scheduler, as Python subprocesses.

    :param connector: The connector to use (defaults to LocalConnector)
    """

    def scriptbuilder(self):
        """Return a new PythonScriptBuilder for the task script."""
        builder = PythonScriptBuilder()
        return builder

    def launcher_info_code(self) -> str:
        """Return no launcher setup code.

        A local process has no scheduler-imposed time limit, so an empty
        string is returned.
        """
        no_setup_code = ""
        return no_setup_code

    def __str__(self):
        return "DirectLauncher({})".format(self.connector)
|
|
@@ -7,6 +7,7 @@ from typing import (
|
|
|
7
7
|
List,
|
|
8
8
|
Optional,
|
|
9
9
|
Tuple,
|
|
10
|
+
TYPE_CHECKING,
|
|
10
11
|
get_type_hints,
|
|
11
12
|
)
|
|
12
13
|
from experimaestro.connectors.local import LocalConnector
|
|
@@ -20,7 +21,7 @@ from experimaestro.launcherfinder.registry import (
|
|
|
20
21
|
from experimaestro.utils import ThreadingCondition
|
|
21
22
|
from experimaestro.tests.connectors.utils import OutputCaptureHandler
|
|
22
23
|
from experimaestro.utils.asyncio import asyncThreadcheck
|
|
23
|
-
from
|
|
24
|
+
from functools import cached_property
|
|
24
25
|
from experimaestro.launchers import Launcher
|
|
25
26
|
from experimaestro.scriptbuilder import PythonScriptBuilder
|
|
26
27
|
from experimaestro.connectors import (
|
|
@@ -32,8 +33,131 @@ from experimaestro.connectors import (
|
|
|
32
33
|
RedirectType,
|
|
33
34
|
)
|
|
34
35
|
|
|
36
|
+
if TYPE_CHECKING:
|
|
37
|
+
from experimaestro.scheduler.jobs import JobState
|
|
38
|
+
|
|
35
39
|
logger = logging.getLogger("xpm.slurm")

# Cached job end time (absolute timestamp, on the time.time() scale).
# Only used when a task is running within a SLURM job. Being module-level, it
# is shared by every SlurmLauncherInformation instance of the process.
_slurm_job_end_time: Optional[float] = None


class SlurmLauncherInformation:
    """Launcher information for SLURM jobs, used during task execution.

    :param binpath: directory holding the SLURM binaries (``squeue`` is run
        from there)
    """

    def __init__(self, binpath: str = "/usr/bin"):
        self.binpath = Path(binpath)

    def remaining_time(self) -> Optional[float]:
        """Return the remaining time in seconds before the SLURM job times out.

        The first successful call asks ``squeue`` for the remaining time of
        the job named by the ``SLURM_JOB_ID`` environment variable. It then
        caches the corresponding absolute end time. Later calls only subtract
        the current time (clamped at zero).

        Returns:
            The remaining time in seconds, or None when it cannot be
            determined. That is the case when ``SLURM_JOB_ID`` is unset, the
            job has no time limit, or the query/parse fails.
        """
        import os
        import time

        global _slurm_job_end_time

        # Use cached end time if available
        if _slurm_job_end_time is not None:
            return max(0.0, _slurm_job_end_time - time.time())

        # Query SLURM for remaining time and compute end time
        job_id = os.environ.get("SLURM_JOB_ID")
        if not job_id:
            logger.debug("No SLURM_JOB_ID in environment, cannot get remaining time")
            return None

        remaining_seconds = self._query_remaining_time(job_id)
        if remaining_seconds is None:
            return None

        # Cache the absolute end time
        _slurm_job_end_time = time.time() + remaining_seconds
        return remaining_seconds

    def _query_remaining_time(self, job_id: str) -> Optional[float]:
        """Query ``squeue`` for the remaining time of ``job_id``.

        Best effort: any failure (non-zero exit, timeout, missing binary,
        unparsable output) is logged and yields None. ``UNLIMITED`` also
        yields None.
        """
        import subprocess

        try:
            result = subprocess.run(
                [
                    f"{self.binpath}/squeue",
                    "--job",
                    job_id,
                    "--format=%L",
                    "--noheader",
                ],
                capture_output=True,
                text=True,
                timeout=30,
            )

            if result.returncode != 0:
                logger.warning(
                    "squeue returned error code %d: %s",
                    result.returncode,
                    result.stderr,
                )
                return None

            # squeue may print several lines (e.g. for job arrays); only the
            # first one is meaningful here, and parsing the joined text fails
            lines = result.stdout.strip().splitlines()
            time_str = lines[0].strip() if lines else ""
            if not time_str or time_str == "UNLIMITED":
                return None

            return self._parse_slurm_time(time_str)
        except subprocess.TimeoutExpired:
            logger.warning("Timeout querying squeue for remaining time")
            return None
        except Exception as e:
            # Deliberately broad: this is an optional, best-effort query
            logger.warning("Error querying SLURM remaining time: %s", e)
            return None

    @staticmethod
    def _parse_slurm_time(time_str: str) -> Optional[float]:
        """Parse a SLURM duration to seconds.

        Accepted formats:
        - D-HH:MM:SS, D-HH:MM, D-HH (with a day count)
        - HH:MM:SS (hours:minutes:seconds)
        - MM:SS (minutes:seconds)
        - SS (seconds)

        With a day count, the fields after the dash start with hours, so
        ``1-02`` means one day and two hours. Without a day count, they end
        with seconds.

        Returns:
            Time in seconds, or None if parsing fails
        """
        try:
            days = 0
            clock = time_str
            has_days = "-" in clock
            if has_days:
                days_str, clock = clock.split("-", 1)
                days = int(days_str)

            values = [int(part) for part in clock.split(":")]
            if not 1 <= len(values) <= 3:
                logger.warning("Could not parse SLURM time: %s", time_str)
                return None

            missing = [0] * (3 - len(values))
            if has_days:
                # D-HH[:MM[:SS]]: missing fields are minutes/seconds
                values = values + missing
            else:
                # [[HH:]MM:]SS: missing fields are hours/minutes
                values = missing + values
            hours, minutes, seconds = values

            return float(days * 86400 + hours * 3600 + minutes * 60 + seconds)
        except (ValueError, IndexError) as e:
            logger.warning("Could not parse SLURM time '%s': %s", time_str, e)
            return None
|
|
160
|
+
|
|
37
161
|
|
|
38
162
|
class SlurmJobState:
|
|
39
163
|
start: str
|
|
@@ -176,14 +300,34 @@ class BatchSlurmProcess(Process):
|
|
|
176
300
|
def __init__(self, launcher: "SlurmLauncher", jobid: str):
    """Handle on a SLURM batch job.

    :param launcher: the SLURM launcher this job belongs to
    :param jobid: the SLURM job identifier
    """
    self.launcher, self.jobid = launcher, jobid
    # Final SLURM state seen by wait(); lets get_job_state() detect timeouts
    self._last_state: Optional[SlurmJobState] = None
|
|
179
304
|
|
|
180
305
|
def wait(self):
    """Poll the SLURM process watcher until the job reaches a final state.

    Whether ``getjob`` blocks between polls is up to SlurmProcessWatcher.
    The final state is stored in ``self._last_state`` for get_job_state().

    :return: 0 when SLURM reports ``COMPLETED``, 1 for any other final state
    """
    with SlurmProcessWatcher.get(self.launcher) as watcher:
        while True:
            job_state = watcher.getjob(self.jobid)
            if not (job_state and job_state.finished()):
                continue
            self._last_state = job_state
            return 0 if job_state.slurm_state == "COMPLETED" else 1
|
|
186
312
|
|
|
313
|
+
def get_job_state(self, code: int) -> "JobState":
    """Translate the exit code returned by wait() into a JobState.

    - ``0`` gives ``JobState.DONE``.
    - A non-zero code whose last observed SLURM state is ``TIMEOUT`` gives
      ``JobStateError(JobFailureStatus.TIMEOUT)``, and an info message is
      logged.
    - Any other failure gives ``JobState.ERROR``.

    NOTE(review): in the code visible here, ``_last_state`` is only recorded
    by wait(). If the exit code comes from elsewhere, a timeout is reported
    as a plain error.
    """
    from experimaestro.scheduler.jobs import (
        JobFailureStatus,
        JobState,
        JobStateError,
    )

    if code != 0:
        final_state = self._last_state
        if final_state and final_state.slurm_state == "TIMEOUT":
            logger.info("SLURM job %s timed out", self.jobid)
            return JobStateError(JobFailureStatus.TIMEOUT)
        return JobState.ERROR

    return JobState.DONE
|
|
330
|
+
|
|
187
331
|
async def aio_state(self, timeout: float | None = None) -> ProcessState:
|
|
188
332
|
def check():
|
|
189
333
|
with SlurmProcessWatcher.get(self.launcher) as watcher:
|
|
@@ -432,7 +576,7 @@ class SlurmLauncher(Launcher):
|
|
|
432
576
|
def scriptbuilder(self):
    """Return the script builder used for jobs of this launcher.

    The builder is a SlurmScriptBuilder bound to this launcher.

    TODO: this assumes a Unix environment; switch to a PythonScriptBuilder
    when that assumption no longer holds.
    """
    builder = SlurmScriptBuilder(self)
    return builder
|
|
438
582
|
|
|
@@ -442,6 +586,14 @@ class SlurmLauncher(Launcher):
|
|
|
442
586
|
By default, returns the associated connector builder"""
|
|
443
587
|
return SlurmProcessBuilder(self)
|
|
444
588
|
|
|
589
|
+
def launcher_info_code(self) -> str:
    """Return Python code that sets up launcher info during task execution.

    The snippet imports SlurmLauncherInformation and taskglobals, then
    assigns ``taskglobals.Env.instance().launcher_info`` to a
    SlurmLauncherInformation built from this launcher's ``binpath``.

    Each line carries a leading indent so it can be embedded inside an
    indented block of the generated script. NOTE(review): the indent must
    match the script template; confirm.

    The path is embedded with ``repr`` rather than between hard-coded double
    quotes, so quotes or backslashes in it cannot break the generated code.
    """
    binpath_literal = repr(str(self.binpath))
    return (
        " from experimaestro.launchers.slurm import SlurmLauncherInformation\n"
        " from experimaestro import taskglobals\n"
        f" taskglobals.Env.instance().launcher_info = SlurmLauncherInformation(binpath={binpath_literal})\n"
    )
|
|
596
|
+
|
|
445
597
|
|
|
446
598
|
class SlurmScriptBuilder(PythonScriptBuilder):
|
|
447
599
|
def __init__(self, launcher: SlurmLauncher, pythonpath=None):
|
experimaestro/mypy.py
CHANGED
|
@@ -1,15 +1,460 @@
|
|
|
1
|
-
|
|
1
|
+
"""Mypy plugin for experimaestro.
|
|
2
2
|
|
|
3
|
+
This plugin provides type hints support for experimaestro's Config system,
|
|
4
|
+
particularly for the Config.C pattern and proper parameter type inference.
|
|
3
5
|
|
|
4
|
-
|
|
5
|
-
|
|
6
|
+
The plugin handles:
|
|
7
|
+
- Config.C, Config.XPMConfig, Config.XPMValue class properties
|
|
8
|
+
- Adding __init__ with proper Param field signatures
|
|
9
|
+
- Adding ConfigMixin to the class hierarchy for method access
|
|
10
|
+
- Handling task_outputs return type for submit()
|
|
6
11
|
|
|
7
|
-
|
|
8
|
-
|
|
12
|
+
Usage in mypy.ini or pyproject.toml:
|
|
13
|
+
[mypy]
|
|
14
|
+
plugins = experimaestro.mypy
|
|
15
|
+
|
|
16
|
+
Or in pyproject.toml:
|
|
17
|
+
[tool.mypy]
|
|
18
|
+
plugins = ["experimaestro.mypy"]
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
from typing import Callable, List, Optional
|
|
24
|
+
|
|
25
|
+
from mypy.nodes import (
|
|
26
|
+
TypeInfo,
|
|
27
|
+
Var,
|
|
28
|
+
Argument,
|
|
29
|
+
ARG_NAMED_OPT,
|
|
30
|
+
ARG_NAMED,
|
|
31
|
+
)
|
|
32
|
+
from mypy.plugin import Plugin, ClassDefContext
|
|
33
|
+
from mypy.plugins.common import add_attribute_to_class, add_method_to_class
|
|
34
|
+
from mypy.types import (
|
|
35
|
+
Instance,
|
|
36
|
+
TypeType,
|
|
37
|
+
NoneType,
|
|
38
|
+
)
|
|
39
|
+
from mypy.mro import calculate_mro, MroError
|
|
40
|
+
|
|
41
|
+
def _qualified_names(prefixes, names):
    """Return the set of ``prefix.name`` strings for every prefix/name pair."""
    return {f"{prefix}.{name}" for prefix in prefixes for name in names}


# Modules through which the experimaestro Config classes can be referenced
_CONFIG_MODULES = (
    "experimaestro.core.objects.config",
    "experimaestro",
    "experimaestro.core.objects",
)

# Modules through which the argument annotations can be referenced
_ARGUMENT_MODULES = ("experimaestro.core.arguments", "experimaestro")

# Full names of Config and its subclasses that need C/XPMConfig attributes
CONFIG_FULLNAMES = _qualified_names(
    _CONFIG_MODULES, ("Config", "LightweightTask", "Task", "ResumableTask")
)

# ConfigMixin full name for method inheritance
CONFIGMIXIN_FULLNAME = "experimaestro.core.objects.config.ConfigMixin"

# Full names for Param annotations (required by default)
PARAM_FULLNAMES = _qualified_names(_ARGUMENT_MODULES, ("Param",))

# Full names for Meta/Option annotations (always optional, ignored in identifier)
META_FULLNAMES = _qualified_names(_ARGUMENT_MODULES, ("Meta", "Option"))

# Full names for Constant annotations (excluded from __init__)
CONSTANT_FULLNAMES = _qualified_names(_ARGUMENT_MODULES, ("Constant",))
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def is_config_subclass(info: TypeInfo) -> bool:
    """Tell whether ``info`` is one of the Config classes or derives from one.

    ``info`` itself is tested first, so the answer is right even when its
    MRO has not been computed yet.

    Args:
        info: The TypeInfo to check

    Returns:
        True if the type is Config (or an alias listed in CONFIG_FULLNAMES)
        or one of its subclasses
    """
    candidates = [info, *info.mro]
    return any(candidate.fullname in CONFIG_FULLNAMES for candidate in candidates)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
# Fields never turned into __init__ parameters: the class-level type
# accessors added by this plugin, experimaestro's internal ``__xpm*``
# attributes, and the ``_deprecated_from`` marker.
SKIP_FIELDS = {"C", "XPMConfig", "XPMValue"} | {
    "__xpm__",
    "__xpmtype__",
    "__xpmid__",
    "_deprecated_from",
}
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _is_config_class(base: TypeInfo) -> bool:
    """Tell whether a Config class appears in the MRO of ``base``.

    Used to keep only user-defined Config subclasses (and not unrelated
    bases) when collecting fields.
    """
    return any(ancestor.fullname in CONFIG_FULLNAMES for ancestor in base.mro)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _get_annotation_type_str(name: str, base: TypeInfo) -> Optional[str]:
    """Get the type annotation string for field ``name`` declared in ``base``.

    Sources are tried in order:

    1. The class body AST. The first assignment to ``name`` that carries a
       type is used: its ``unanalyzed_type`` (which preserves the original
       annotation) if set, else its analyzed type.
    2. The type of the ``Var`` bound to ``name`` in the class symbol table.

    NOTE(review): for classes restored from mypy's incremental cache, the
    class body may be unavailable. Only the analyzed type (source 2) would
    then be used, which may no longer spell ``Param``/``Meta``/``Constant``.
    Confirm.

    Returns:
        The annotation as a string, or None if none was found.
    """
    # Imported once here rather than on every loop iteration
    from mypy.nodes import AssignmentStmt, NameExpr

    if base.defn is not None:
        for stmt in base.defn.defs.body:
            if not isinstance(stmt, AssignmentStmt):
                continue
            for lvalue in stmt.lvalues:
                if not (isinstance(lvalue, NameExpr) and lvalue.name == name):
                    continue
                # Prefer the unanalyzed type: it keeps wrapper names
                if stmt.unanalyzed_type is not None:
                    return str(stmt.unanalyzed_type)
                if stmt.type is not None:
                    return str(stmt.type)

    # Fall back to the symbol's type
    sym = base.names.get(name)
    if sym is not None and isinstance(sym.node, Var) and sym.node.type is not None:
        return str(sym.node.type)

    return None
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _is_constant_field(name: str, base: TypeInfo) -> bool:
    """Check if field ``name`` of ``base`` is declared as ``Constant[T]``.

    Constant fields should be excluded from __init__.

    The check works on the annotation text from _get_annotation_type_str.
    mypy marks unbound names with ``?`` (e.g. ``Constant?[str?]``), so those
    markers are removed first. ``Constant`` must then appear as a whole name,
    either bare or dotted (``experimaestro.Constant[...]``). Previously a
    plain substring test also matched user types such as ``MyConstant[...]``.
    """
    import re

    type_str = _get_annotation_type_str(name, base)
    if type_str is None:
        return False

    normalized = type_str.lower().replace("?", "")

    # "constant[" not preceded by an identifier character ("." is allowed, so
    # qualified spellings like "experimaestro.constant[" still match)
    if re.search(r"(?<!\w)constant\[", normalized):
        return True

    # Fully qualified references, possibly without a subscript
    return any(
        re.search(rf"(?<![\w.]){re.escape(fullname.lower())}(?!\w)", normalized)
        for fullname in CONSTANT_FULLNAMES
    )
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _is_meta_field(name: str, base: TypeInfo) -> bool:
    """Check if field ``name`` of ``base`` is declared as ``Meta[T]`` or ``Option[T]``.

    Meta fields should always be optional in __init__.

    The check works on the annotation text from _get_annotation_type_str.
    mypy marks unbound names with ``?`` (e.g. ``Meta?[Path?]``), so those
    markers are removed first. ``Meta``/``Option`` must then appear as whole
    names, bare or dotted. A plain substring test would also match user types
    such as ``FieldOption[...]`` and wrongly make a required Param optional.
    ``Optional[...]`` never matches, since ``option`` must be followed by ``[``.
    """
    import re

    type_str = _get_annotation_type_str(name, base)
    if type_str is None:
        return False

    normalized = type_str.lower().replace("?", "")

    # "meta[" / "option[" not preceded by an identifier character
    if re.search(r"(?<!\w)(?:meta|option)\[", normalized):
        return True

    # Fully qualified references, possibly without a subscript
    return any(
        re.search(rf"(?<![\w.]){re.escape(fullname.lower())}(?!\w)", normalized)
        for fullname in META_FULLNAMES
    )
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _get_param_fields(info: TypeInfo) -> List[tuple]:
    """Extract Param and Meta fields from a class and its bases.

    Returns list of (name, type, has_default) tuples.

    The MRO is walked from the most basic class down to ``info``. Fields keep
    the position of their first declaration (inherited ones first). A
    redeclaration in a subclass replaces the inherited entry, so its type
    wins and adding a default makes the field optional. The original
    "first seen wins" walk silently ignored such overrides.

    Rules:
    - A field that was optional in a base class stays optional.
    - A ``Constant`` redeclaration removes the field, since Constant fields
      should not be in __init__.
    - Only Config subclasses contribute, so attributes of other bases such
      as nn.Module are ignored.
    - The experimaestro base classes and ConfigMixin are skipped.
    - Names in SKIP_FIELDS, private names, non-variables and untyped
      variables are skipped.
    """
    # name -> (name, type, has_default); dicts keep first-insertion order
    fields: dict = {}

    for base in reversed(info.mro):
        if base.fullname == "builtins.object":
            continue
        if base.fullname in CONFIG_FULLNAMES or base.fullname == CONFIGMIXIN_FULLNAME:
            # Framework classes: we only want user-defined fields
            continue
        if not _is_config_class(base):
            # Skips bases like nn.Module that don't inherit from Config
            continue

        for name, sym in base.names.items():
            if name in SKIP_FIELDS or name.startswith("_"):
                continue
            var = sym.node
            if not isinstance(var, Var) or var.type is None:
                continue

            if _is_constant_field(name, base):
                fields.pop(name, None)
                continue

            # Meta fields are always optional; Param fields only with a
            # default. An inherited default is kept when redeclared.
            previous = fields.get(name)
            has_default = (
                var.has_explicit_value
                or _is_meta_field(name, base)
                or (previous is not None and previous[2])
            )
            fields[name] = (name, var.type, has_default)

    return list(fields.values())
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def _add_init_method(ctx: ClassDefContext) -> None:
    """Attach a keyword-only ``__init__`` mirroring the class parameters.

    Every field returned by _get_param_fields becomes a keyword-only
    argument, optional when the field has a default (or is Meta/Option).
    When there is no such field, nothing is added and the inherited
    ``__init__`` is left in place.
    """
    init_args = [
        Argument(
            variable=Var(field_name, field_type),
            type_annotation=field_type,
            initializer=None,
            kind=ARG_NAMED_OPT if optional else ARG_NAMED,
        )
        for field_name, field_type, optional in _get_param_fields(ctx.cls.info)
    ]
    if not init_args:
        return

    add_method_to_class(ctx.api, ctx.cls, "__init__", init_args, NoneType())
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def _get_task_outputs_return_type(info: TypeInfo) -> Optional[Instance]:
    """Return the declared return type of ``task_outputs``, if any.

    If the class defines task_outputs, submit() should return that type
    instead of Self.

    Lookup rules:
    - The method is searched along the MRO, so an implementation inherited
      from a user-defined base is honoured. Python would call that inherited
      method too.
    - experimaestro's own base classes and ConfigMixin are not considered.
    - Decorated definitions are unwrapped to the underlying function.
    - The first definition found decides. When it has no callable type
      (e.g. unannotated), None is returned.

    NOTE(review): despite the ``Instance`` annotation, any mypy type can be
    returned (whatever the method's ``ret_type`` is).
    """
    from mypy.nodes import Decorator, FuncDef
    from mypy.types import CallableType

    # Before the MRO is computed, fall back to the class itself
    for base in info.mro or [info]:
        if (
            base.fullname in CONFIG_FULLNAMES
            or base.fullname == CONFIGMIXIN_FULLNAME
            or base.fullname == "builtins.object"
        ):
            continue

        sym = base.names.get("task_outputs")
        if sym is None:
            continue

        node = sym.node
        if isinstance(node, Decorator):
            node = node.func
        if isinstance(node, FuncDef) and isinstance(node.type, CallableType):
            return node.type.ret_type
        return None

    return None
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _add_configmixin_to_bases(ctx: ClassDefContext) -> None:
    """Make ConfigMixin a base of the analysed class when it is missing.

    With ConfigMixin in the hierarchy, mypy sees its methods on Config
    subclasses. This is best effort:
    - nothing happens if ConfigMixin is already in the MRO or listed in the
      bases, or cannot be resolved;
    - if appending it makes the MRO impossible, the appended base is
      removed again;
    - any other error is swallowed.
    """
    info = ctx.cls.info

    if any(ancestor.fullname == CONFIGMIXIN_FULLNAME for ancestor in info.mro):
        return  # Already has ConfigMixin

    try:
        mixin_sym = ctx.api.lookup_fully_qualified_or_none(CONFIGMIXIN_FULLNAME)
        if mixin_sym is None or not isinstance(mixin_sym.node, TypeInfo):
            return
        mixin_type = Instance(mixin_sym.node, [])

        already_listed = any(
            isinstance(b, Instance) and b.type.fullname == CONFIGMIXIN_FULLNAME
            for b in info.bases
        )
        if already_listed:
            return

        info.bases.append(mixin_type)
        try:
            calculate_mro(info)
        except MroError:
            # The extra base makes the hierarchy unlinearizable: undo it
            info.bases.pop()
    except Exception:
        # Best effort: continue without adding ConfigMixin
        pass
|
|
12
352
|
|
|
13
353
|
|
|
14
|
-
def
|
|
354
|
+
def _add_submit_method(ctx: ClassDefContext) -> None:
    """Add a ``submit()`` method returning Self (or the task_outputs type).

    All arguments are optional keyword-only and typed as Any, mirroring
    ConfigMixin's signature:

        def submit(self, *, workspace=None, launcher=None, run_mode=None,
                   init_tasks=[], max_retries=None)
    """
    from mypy.types import AnyType, TypeOfAny

    info = ctx.cls.info
    outputs_type = _get_task_outputs_return_type(info)
    return_type = Instance(info, []) if outputs_type is None else outputs_type

    any_type = AnyType(TypeOfAny.explicit)
    keyword_names = ("workspace", "launcher", "run_mode", "init_tasks", "max_retries")
    submit_args = [
        Argument(
            variable=Var(keyword, any_type),
            type_annotation=any_type,
            initializer=None,
            kind=ARG_NAMED_OPT,
        )
        for keyword in keyword_names
    ]

    add_method_to_class(ctx.api, ctx.cls, "submit", submit_args, return_type)
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def _process_config_class(ctx: ClassDefContext) -> None:
    """Enrich a Config subclass with the members mypy cannot infer.

    In order, this:
    1. adds ConfigMixin to the class hierarchy (method access);
    2. declares C, XPMConfig and XPMValue as Type[Self] class attributes,
       unless the class already defines them;
    3. synthesizes __init__ from the Param/Meta fields;
    4. synthesizes submit() returning Self (or the task_outputs type).
    """
    _add_configmixin_to_bases(ctx)

    info = ctx.cls.info
    self_class_type = TypeType(Instance(info, []))
    for accessor in ("C", "XPMConfig", "XPMValue"):
        if accessor in info.names:
            continue
        add_attribute_to_class(ctx.api, ctx.cls, accessor, self_class_type)

    _add_init_method(ctx)
    _add_submit_method(ctx)
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
class ExperimaestroPlugin(Plugin):
    """Mypy plugin for experimaestro type hints.

    Classes deriving from an experimaestro Config class are processed by
    _process_config_class, which:
    - adds ConfigMixin to the class hierarchy;
    - exposes C, XPMConfig and XPMValue as Type[Self];
    - adds an __init__ method with proper Param field signatures;
    - adds a submit() method returning Self or the task_outputs type.
    """

    def get_base_class_hook(
        self, fullname: str
    ) -> Callable[[ClassDefContext], None] | None:
        """Return the processing hook when ``fullname`` is a Config class.

        Besides the experimaestro base classes, user-defined Config
        subclasses are recognized too. This way ``class B(A)``, with
        ``class A(Config)``, gets its own __init__ covering B's fields.

        NOTE(review): mypy appears to call this hook only with the full names
        of the *direct* bases written in a class statement, which is why the
        symbol is looked up here. Confirm against the supported mypy versions.
        """
        if fullname in CONFIG_FULLNAMES:
            return _process_config_class

        sym = self.lookup_fully_qualified(fullname)
        if (
            sym is not None
            and isinstance(sym.node, TypeInfo)
            and is_config_subclass(sym.node)
        ):
            return _process_config_class

        return None
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
def plugin(_version: str) -> type[ExperimaestroPlugin]:
    """Return the plugin class for mypy to instantiate.

    mypy calls this entry point with its version string, which is not used.
    """
    plugin_class = ExperimaestroPlugin
    return plugin_class
|