experimaestro 2.0.0a8__py3-none-any.whl → 2.0.0b8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of experimaestro might be problematic. Click here for more details.

Files changed (122)
  1. experimaestro/__init__.py +10 -11
  2. experimaestro/annotations.py +167 -206
  3. experimaestro/cli/__init__.py +278 -7
  4. experimaestro/cli/filter.py +42 -74
  5. experimaestro/cli/jobs.py +157 -106
  6. experimaestro/cli/refactor.py +249 -0
  7. experimaestro/click.py +0 -1
  8. experimaestro/commandline.py +19 -3
  9. experimaestro/connectors/__init__.py +20 -1
  10. experimaestro/connectors/local.py +12 -0
  11. experimaestro/core/arguments.py +182 -46
  12. experimaestro/core/identifier.py +107 -6
  13. experimaestro/core/objects/__init__.py +6 -0
  14. experimaestro/core/objects/config.py +542 -25
  15. experimaestro/core/objects/config_walk.py +20 -0
  16. experimaestro/core/serialization.py +91 -34
  17. experimaestro/core/subparameters.py +164 -0
  18. experimaestro/core/types.py +175 -38
  19. experimaestro/exceptions.py +26 -0
  20. experimaestro/experiments/cli.py +111 -25
  21. experimaestro/generators.py +50 -9
  22. experimaestro/huggingface.py +3 -1
  23. experimaestro/launcherfinder/parser.py +29 -0
  24. experimaestro/launchers/__init__.py +26 -1
  25. experimaestro/launchers/direct.py +12 -0
  26. experimaestro/launchers/slurm/base.py +154 -2
  27. experimaestro/mkdocs/metaloader.py +0 -1
  28. experimaestro/mypy.py +452 -7
  29. experimaestro/notifications.py +63 -13
  30. experimaestro/progress.py +0 -2
  31. experimaestro/rpyc.py +0 -1
  32. experimaestro/run.py +19 -6
  33. experimaestro/scheduler/base.py +510 -125
  34. experimaestro/scheduler/dependencies.py +43 -28
  35. experimaestro/scheduler/dynamic_outputs.py +259 -130
  36. experimaestro/scheduler/experiment.py +256 -31
  37. experimaestro/scheduler/interfaces.py +501 -0
  38. experimaestro/scheduler/jobs.py +216 -206
  39. experimaestro/scheduler/remote/__init__.py +31 -0
  40. experimaestro/scheduler/remote/client.py +874 -0
  41. experimaestro/scheduler/remote/protocol.py +467 -0
  42. experimaestro/scheduler/remote/server.py +423 -0
  43. experimaestro/scheduler/remote/sync.py +144 -0
  44. experimaestro/scheduler/services.py +323 -23
  45. experimaestro/scheduler/state_db.py +437 -0
  46. experimaestro/scheduler/state_provider.py +2766 -0
  47. experimaestro/scheduler/state_sync.py +891 -0
  48. experimaestro/scheduler/workspace.py +52 -10
  49. experimaestro/scriptbuilder.py +7 -0
  50. experimaestro/server/__init__.py +147 -57
  51. experimaestro/server/data/index.css +0 -125
  52. experimaestro/server/data/index.css.map +1 -1
  53. experimaestro/server/data/index.js +194 -58
  54. experimaestro/server/data/index.js.map +1 -1
  55. experimaestro/settings.py +44 -5
  56. experimaestro/sphinx/__init__.py +3 -3
  57. experimaestro/taskglobals.py +20 -0
  58. experimaestro/tests/conftest.py +80 -0
  59. experimaestro/tests/core/test_generics.py +2 -2
  60. experimaestro/tests/identifier_stability.json +45 -0
  61. experimaestro/tests/launchers/bin/sacct +6 -2
  62. experimaestro/tests/launchers/bin/sbatch +4 -2
  63. experimaestro/tests/launchers/test_slurm.py +80 -0
  64. experimaestro/tests/tasks/test_dynamic.py +231 -0
  65. experimaestro/tests/test_cli_jobs.py +615 -0
  66. experimaestro/tests/test_deprecated.py +630 -0
  67. experimaestro/tests/test_environment.py +200 -0
  68. experimaestro/tests/test_file_progress_integration.py +1 -1
  69. experimaestro/tests/test_forward.py +3 -3
  70. experimaestro/tests/test_identifier.py +372 -41
  71. experimaestro/tests/test_identifier_stability.py +458 -0
  72. experimaestro/tests/test_instance.py +3 -3
  73. experimaestro/tests/test_multitoken.py +442 -0
  74. experimaestro/tests/test_mypy.py +433 -0
  75. experimaestro/tests/test_objects.py +312 -5
  76. experimaestro/tests/test_outputs.py +2 -2
  77. experimaestro/tests/test_param.py +8 -12
  78. experimaestro/tests/test_partial_paths.py +231 -0
  79. experimaestro/tests/test_progress.py +0 -48
  80. experimaestro/tests/test_remote_state.py +671 -0
  81. experimaestro/tests/test_resumable_task.py +480 -0
  82. experimaestro/tests/test_serializers.py +141 -1
  83. experimaestro/tests/test_state_db.py +434 -0
  84. experimaestro/tests/test_subparameters.py +160 -0
  85. experimaestro/tests/test_tags.py +136 -0
  86. experimaestro/tests/test_tasks.py +107 -121
  87. experimaestro/tests/test_token_locking.py +252 -0
  88. experimaestro/tests/test_tokens.py +17 -13
  89. experimaestro/tests/test_types.py +123 -1
  90. experimaestro/tests/test_workspace_triggers.py +158 -0
  91. experimaestro/tests/token_reschedule.py +4 -2
  92. experimaestro/tests/utils.py +2 -2
  93. experimaestro/tokens.py +154 -57
  94. experimaestro/tools/diff.py +1 -1
  95. experimaestro/tui/__init__.py +8 -0
  96. experimaestro/tui/app.py +2395 -0
  97. experimaestro/tui/app.tcss +353 -0
  98. experimaestro/tui/log_viewer.py +228 -0
  99. experimaestro/utils/__init__.py +23 -0
  100. experimaestro/utils/environment.py +148 -0
  101. experimaestro/utils/git.py +129 -0
  102. experimaestro/utils/resources.py +1 -1
  103. experimaestro/version.py +34 -0
  104. {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b8.dist-info}/METADATA +68 -38
  105. experimaestro-2.0.0b8.dist-info/RECORD +187 -0
  106. {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b8.dist-info}/WHEEL +1 -1
  107. experimaestro-2.0.0b8.dist-info/entry_points.txt +16 -0
  108. experimaestro/compat.py +0 -6
  109. experimaestro/core/objects.pyi +0 -221
  110. experimaestro/server/data/0c35d18bf06992036b69.woff2 +0 -0
  111. experimaestro/server/data/219aa9140e099e6c72ed.woff2 +0 -0
  112. experimaestro/server/data/3a4004a46a653d4b2166.woff +0 -0
  113. experimaestro/server/data/3baa5b8f3469222b822d.woff +0 -0
  114. experimaestro/server/data/4d73cb90e394b34b7670.woff +0 -0
  115. experimaestro/server/data/4ef4218c522f1eb6b5b1.woff2 +0 -0
  116. experimaestro/server/data/5d681e2edae8c60630db.woff +0 -0
  117. experimaestro/server/data/6f420cf17cc0d7676fad.woff2 +0 -0
  118. experimaestro/server/data/c380809fd3677d7d6903.woff2 +0 -0
  119. experimaestro/server/data/f882956fd323fd322f31.woff +0 -0
  120. experimaestro-2.0.0a8.dist-info/RECORD +0 -166
  121. experimaestro-2.0.0a8.dist-info/entry_points.txt +0 -17
  122. {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b8.dist-info}/licenses/LICENSE +0 -0
@@ -3,8 +3,20 @@ from . import Launcher
3
3
 
4
4
 
5
5
class DirectLauncher(Launcher):
    """Launcher that runs tasks on the local machine, without a job scheduler.

    :param connector: the connector this launcher runs through
    """

    def scriptbuilder(self):
        """Return a new PythonScriptBuilder used to write the task script."""
        builder = PythonScriptBuilder()
        return builder

    def launcher_info_code(self) -> str:
        """Return the launcher-info setup code.

        Direct launches have no time limit to report, so no code is needed.
        """
        no_setup_code = ""
        return no_setup_code

    def __str__(self):
        return f"DirectLauncher({self.connector})"
@@ -7,6 +7,7 @@ from typing import (
7
7
  List,
8
8
  Optional,
9
9
  Tuple,
10
+ TYPE_CHECKING,
10
11
  get_type_hints,
11
12
  )
12
13
  from experimaestro.connectors.local import LocalConnector
@@ -20,7 +21,7 @@ from experimaestro.launcherfinder.registry import (
20
21
  from experimaestro.utils import ThreadingCondition
21
22
  from experimaestro.tests.connectors.utils import OutputCaptureHandler
22
23
  from experimaestro.utils.asyncio import asyncThreadcheck
23
- from experimaestro.compat import cached_property
24
+ from functools import cached_property
24
25
  from experimaestro.launchers import Launcher
25
26
  from experimaestro.scriptbuilder import PythonScriptBuilder
26
27
  from experimaestro.connectors import (
@@ -32,8 +33,131 @@ from experimaestro.connectors import (
32
33
  RedirectType,
33
34
  )
34
35
 
36
+ if TYPE_CHECKING:
37
+ from experimaestro.scheduler.jobs import JobState
38
+
35
39
  logger = logging.getLogger("xpm.slurm")
36
40
 
41
+ # Cached job end time (absolute timestamp).
42
+ # Only used when a task is running within a SLURM job.
43
+ _slurm_job_end_time: Optional[float] = None
44
+
45
+
46
class SlurmLauncherInformation:
    """Launcher information for SLURM jobs, used during task execution.

    Lets a running task ask how much wall-clock time its SLURM job has left,
    by querying ``squeue`` for the job named in the ``SLURM_JOB_ID``
    environment variable.

    :param binpath: directory containing the SLURM binaries (``squeue``)
    """

    def __init__(self, binpath: str = "/usr/bin"):
        self.binpath = Path(binpath)

    def remaining_time(self) -> Optional[float]:
        """Return the remaining time in seconds before the SLURM job times out.

        The first successful squeue query caches the job's absolute end time
        in the module-level ``_slurm_job_end_time``; later calls are computed
        from that cache without running squeue again.

        Returns:
            The remaining time in seconds (never negative), or None when no
            SLURM_JOB_ID is set, the job has no time limit, or the query fails.
        """
        import os
        import time

        global _slurm_job_end_time

        # Fast path: end time already known
        if _slurm_job_end_time is not None:
            return max(0.0, _slurm_job_end_time - time.time())

        job_id = os.environ.get("SLURM_JOB_ID")
        if not job_id:
            logger.debug("No SLURM_JOB_ID in environment, cannot get remaining time")
            return None

        remaining_seconds = self._query_remaining_time(job_id)
        if remaining_seconds is None:
            return None

        # Cache the absolute end time so later calls need no subprocess
        _slurm_job_end_time = time.time() + remaining_seconds
        return remaining_seconds

    def _query_remaining_time(self, job_id: str) -> Optional[float]:
        """Query squeue for the time left (``%L``) of ``job_id``, in seconds.

        Returns None on error, timeout, empty output, or when SLURM reports
        no usable limit.
        """
        import subprocess

        try:
            result = subprocess.run(
                [
                    f"{self.binpath}/squeue",
                    "--job",
                    job_id,
                    "--format=%L",
                    "--noheader",
                ],
                capture_output=True,
                text=True,
                timeout=30,
            )

            if result.returncode != 0:
                logger.warning(
                    "squeue returned error code %d: %s",
                    result.returncode,
                    result.stderr,
                )
                return None

            # Keep only the first non-empty line in case squeue prints several
            lines = result.stdout.strip().splitlines()
            time_str = lines[0].strip() if lines else ""
            if not time_str:
                return None

            # "UNLIMITED" (and other no-limit markers) are handled by the parser
            return self._parse_slurm_time(time_str)
        except subprocess.TimeoutExpired:
            logger.warning("Timeout querying squeue for remaining time")
            return None
        except Exception as e:
            logger.warning("Error querying SLURM remaining time: %s", e)
            return None

    @staticmethod
    def _parse_slurm_time(time_str: str) -> Optional[float]:
        """Parse a SLURM duration into seconds.

        Accepted formats:
        - ``D-HH:MM:SS`` and ``HH:MM:SS``
        - ``MM:SS`` and ``SS`` (without a day part, the last field is seconds)
        - ``D-HH`` and ``D-HH:MM`` (with a day part, the fields after the dash
          start with hours, as in SLURM's time syntax)

        Returns:
            Time in seconds, or None for ``UNLIMITED``/``NOT_SET``/``INVALID``
            or if parsing fails.
        """
        text = time_str.strip()
        # Values squeue prints when there is no usable limit: not an error
        if text.upper() in ("UNLIMITED", "NOT_SET", "INVALID"):
            return None

        try:
            days = 0
            clock = text
            has_days = "-" in clock
            if has_days:
                days_str, clock = clock.split("-", 1)
                days = int(days_str)

            parts = [int(part) for part in clock.split(":")]
            if len(parts) > 3:
                logger.warning("Could not parse SLURM time: %s", text)
                return None

            # Missing fields are zero: trailing ones (minutes/seconds) after a
            # day part, leading ones (hours/minutes) otherwise.
            padding = [0] * (3 - len(parts))
            hours, minutes, seconds = (parts + padding) if has_days else (padding + parts)

            return float(days * 86400 + hours * 3600 + minutes * 60 + seconds)
        except (ValueError, IndexError) as e:
            # Report the full input, not the day-stripped remainder
            logger.warning("Could not parse SLURM time '%s': %s", text, e)
            return None
160
+
37
161
 
38
162
  class SlurmJobState:
39
163
  start: str
@@ -176,14 +300,34 @@ class BatchSlurmProcess(Process):
176
300
  def __init__(self, launcher: "SlurmLauncher", jobid: str):
177
301
  self.launcher = launcher
178
302
  self.jobid = jobid
303
+ self._last_state: Optional[SlurmJobState] = None
179
304
 
180
305
  def wait(self):
181
306
  with SlurmProcessWatcher.get(self.launcher) as watcher:
182
307
  while True:
183
308
  state = watcher.getjob(self.jobid)
184
309
  if state and state.finished():
310
+ self._last_state = state
185
311
  return 0 if state.slurm_state == "COMPLETED" else 1
186
312
 
313
+ def get_job_state(self, code: int) -> "JobState":
314
+ """Convert SLURM exit code to JobState, detecting timeouts"""
315
+ from experimaestro.scheduler.jobs import (
316
+ JobState,
317
+ JobStateError,
318
+ JobFailureStatus,
319
+ )
320
+
321
+ if code == 0:
322
+ return JobState.DONE
323
+
324
+ # Check if this was a SLURM timeout
325
+ if self._last_state and self._last_state.slurm_state == "TIMEOUT":
326
+ logger.info("SLURM job %s timed out", self.jobid)
327
+ return JobStateError(JobFailureStatus.TIMEOUT)
328
+
329
+ return JobState.ERROR
330
+
187
331
  async def aio_state(self, timeout: float | None = None) -> ProcessState:
188
332
  def check():
189
333
  with SlurmProcessWatcher.get(self.launcher) as watcher:
@@ -432,7 +576,7 @@ class SlurmLauncher(Launcher):
432
576
  def scriptbuilder(self):
433
577
  """Returns the script builder
434
578
 
435
- We assume *nix, but should be changed to PythonScriptBuilder when working
579
+ We assume Unix, but should be changed to PythonScriptBuilder when working
436
580
  """
437
581
  return SlurmScriptBuilder(self)
438
582
 
@@ -442,6 +586,14 @@ class SlurmLauncher(Launcher):
442
586
  By default, returns the associated connector builder"""
443
587
  return SlurmProcessBuilder(self)
444
588
 
589
+ def launcher_info_code(self) -> str:
590
+ """Returns Python code to set up launcher info during task execution."""
591
+ return (
592
+ " from experimaestro.launchers.slurm import SlurmLauncherInformation\n"
593
+ " from experimaestro import taskglobals\n"
594
+ f' taskglobals.Env.instance().launcher_info = SlurmLauncherInformation(binpath="{self.binpath}")\n'
595
+ )
596
+
445
597
 
446
598
  class SlurmScriptBuilder(PythonScriptBuilder):
447
599
  def __init__(self, launcher: SlurmLauncher, pythonpath=None):
@@ -2,7 +2,6 @@
2
2
  # when building documentation
3
3
 
4
4
  import sys
5
- import re
6
5
  import importlib.abc
7
6
  import importlib.machinery
8
7
 
experimaestro/mypy.py CHANGED
@@ -1,15 +1,460 @@
1
- from mypy.plugin import Plugin
1
+ """Mypy plugin for experimaestro.
2
2
 
3
+ This plugin provides type hints support for experimaestro's Config system,
4
+ particularly for the Config.C pattern and proper parameter type inference.
3
5
 
4
- class ExperimaestroPlugin(Plugin):
5
- """Just do nothing for now"""
6
+ The plugin handles:
7
+ - Config.C, Config.XPMConfig, Config.XPMValue class properties
8
+ - Adding __init__ with proper Param field signatures
9
+ - Adding ConfigMixin to the class hierarchy for method access
10
+ - Handling task_outputs return type for submit()
6
11
 
7
- def get_class_decorator_hook(self, tada):
8
- pass
12
+ Usage in mypy.ini:
13
+ [mypy]
14
+ plugins = experimaestro.mypy
15
+
16
+ Or in pyproject.toml:
17
+ [tool.mypy]
18
+ plugins = ["experimaestro.mypy"]
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ from typing import Callable, List, Optional
24
+
25
+ from mypy.nodes import (
26
+ TypeInfo,
27
+ Var,
28
+ Argument,
29
+ ARG_NAMED_OPT,
30
+ ARG_NAMED,
31
+ )
32
+ from mypy.plugin import Plugin, ClassDefContext
33
+ from mypy.plugins.common import add_attribute_to_class, add_method_to_class
34
+ from mypy.types import (
35
+ Instance,
36
+ TypeType,
37
+ NoneType,
38
+ )
39
+ from mypy.mro import calculate_mro, MroError
40
+
41
def _fullnames(modules, names):
    """Return the set of every ``module.name`` combination."""
    return {f"{module}.{name}" for module in modules for name in names}


# Modules through which the Config classes are importable
_CONFIG_MODULES = (
    "experimaestro.core.objects.config",
    "experimaestro",
    "experimaestro.core.objects",
)

# Modules through which the argument annotations are importable
_ARGUMENT_MODULES = ("experimaestro.core.arguments", "experimaestro")

# Full names of Config and its subclasses that need C/XPMConfig attributes
CONFIG_FULLNAMES = _fullnames(
    _CONFIG_MODULES, ("Config", "LightweightTask", "Task", "ResumableTask")
)

# ConfigMixin full name for method inheritance
CONFIGMIXIN_FULLNAME = "experimaestro.core.objects.config.ConfigMixin"

# Full names for Param annotations (required by default)
PARAM_FULLNAMES = _fullnames(_ARGUMENT_MODULES, ("Param",))

# Full names for Meta/Option annotations (always optional, ignored in identifier)
META_FULLNAMES = _fullnames(_ARGUMENT_MODULES, ("Meta", "Option"))

# Full names for Constant annotations (excluded from __init__)
CONSTANT_FULLNAMES = _fullnames(_ARGUMENT_MODULES, ("Constant",))
79
+
80
+
81
def is_config_subclass(info: TypeInfo) -> bool:
    """Tell whether ``info`` is one of the Config classes or derives from one.

    ``info`` itself is checked explicitly before its MRO is scanned.

    Args:
        info: The TypeInfo to check

    Returns:
        True if the type is Config (or an alias/subclass listed in
        CONFIG_FULLNAMES) or has one of them in its MRO
    """
    return info.fullname in CONFIG_FULLNAMES or any(
        ancestor.fullname in CONFIG_FULLNAMES for ancestor in info.mro
    )
96
+
97
+
98
# Fields to skip when building the __init__ signature: the Config.C-style
# class attributes plus experimaestro's internal bookkeeping attributes
SKIP_FIELDS = {"C", "XPMConfig", "XPMValue"} | {
    "__xpm__",
    "__xpmtype__",
    "__xpmid__",
    "_deprecated_from",
}
108
+
109
+
110
def _is_config_class(base: TypeInfo) -> bool:
    """Return True when a Config class (see CONFIG_FULLNAMES) is in ``base``'s MRO.

    Used to keep only user-defined Config subclasses when collecting fields.
    """
    return any(ancestor.fullname in CONFIG_FULLNAMES for ancestor in base.mro)
119
+
120
+
121
def _get_annotation_type_str(name: str, base: TypeInfo) -> Optional[str]:
    """Return the annotation of field ``name`` in class ``base`` as a string.

    Sources, in order:
    1. The class body AST. For an assignment to ``name``, use its unanalyzed
       annotation (which keeps the original annotation text), then its
       analyzed type.
    2. The type of the ``Var`` registered under ``name`` in the class symbols.

    Returns:
        The annotation string, or None if none can be found.
    """
    # Imported once here rather than on every loop iteration
    from mypy.nodes import AssignmentStmt, NameExpr

    if base.defn is not None:
        for stmt in base.defn.defs.body:
            if not isinstance(stmt, AssignmentStmt):
                continue
            assigns_name = any(
                isinstance(lvalue, NameExpr) and lvalue.name == name
                for lvalue in stmt.lvalues
            )
            if not assigns_name:
                continue
            # Prefer the unanalyzed type: it preserves the original annotation
            if stmt.unanalyzed_type is not None:
                return str(stmt.unanalyzed_type)
            if stmt.type is not None:
                return str(stmt.type)

    # Fall back to the symbol table entry
    sym = base.names.get(name)
    if sym is not None and isinstance(sym.node, Var) and sym.node.type is not None:
        return str(sym.node.type)

    return None
154
+
155
+
156
def _is_constant_field(name: str, base: TypeInfo) -> bool:
    """Check if field ``name`` of ``base`` is declared as Constant[T].

    Constant fields are excluded from __init__.
    """
    import re

    type_str = _get_annotation_type_str(name, base)
    if type_str is None:
        return False

    # mypy prints unanalyzed types with "?" markers (e.g. "Constant?[str?]");
    # drop them and compare case-insensitively.
    normalized = type_str.lower().replace("?", "")

    # Require a word boundary before "constant[" so that generics whose name
    # merely ends in "constant" (e.g. "MyConstant[int]") do not match.
    if re.search(r"\bconstant\[", normalized):
        return True
    return any(fullname.lower() in normalized for fullname in CONSTANT_FULLNAMES)
176
+
177
+
178
def _is_meta_field(name: str, base: TypeInfo) -> bool:
    """Check if field ``name`` of ``base`` is declared as Meta[T] or Option[T].

    Meta fields are always optional in __init__.
    """
    import re

    type_str = _get_annotation_type_str(name, base)
    if type_str is None:
        return False

    # mypy prints unanalyzed types with "?" markers (e.g. "Meta?[Path?]");
    # drop them and compare case-insensitively.
    normalized = type_str.lower().replace("?", "")

    # Require a word boundary so that generics whose name merely ends in
    # "meta"/"option" (e.g. "TypeMeta[...]", "MyOption[...]") do not match.
    if re.search(r"\b(?:meta|option)\[", normalized):
        return True
    return any(fullname.lower() in normalized for fullname in META_FULLNAMES)
198
+
199
+
200
def _get_param_fields(info: TypeInfo) -> List[tuple]:
    """Extract the Param and Meta fields of a class and its bases.

    Returns a list of ``(name, type, has_default)`` tuples. Base-class fields
    come first, in declaration order.

    Which fields are kept:
    - Only fields from Config subclasses are included, so attributes from
      other bases such as nn.Module are ignored.
    - The experimaestro base classes and ConfigMixin are skipped.
    - Private names and names in SKIP_FIELDS are excluded.
    - Constant fields are excluded.

    When a subclass redeclares a field, its declaration (type and default)
    replaces the inherited one while keeping the field's original position.
    Redeclaring it as a Constant removes it.
    """
    # name -> (name, type, has_default). Dicts keep insertion order, so an
    # override updates its entry in place instead of moving it to the end.
    fields: dict = {}

    # Walk from the most generic base to ``info`` itself so that more derived
    # declarations are applied last and take precedence.
    for base in reversed(info.mro):
        if base.fullname == "builtins.object":
            continue
        if base.fullname in CONFIG_FULLNAMES or base.fullname == CONFIGMIXIN_FULLNAME:
            # Framework classes: we only want user-defined fields
            continue
        # Skip bases like nn.Module that don't inherit from Config
        if not _is_config_class(base):
            continue

        for name, sym in base.names.items():
            if name in SKIP_FIELDS or name.startswith("_"):
                continue
            var = sym.node
            if not isinstance(var, Var) or var.type is None:
                continue

            if _is_constant_field(name, base):
                # Constants are not __init__ parameters, even when overriding
                fields.pop(name, None)
                continue

            # Meta fields are always optional; Param fields only with a default
            has_default = var.has_explicit_value or _is_meta_field(name, base)
            fields[name] = (name, var.type, has_default)

    return list(fields.values())
255
+
256
+
257
def _add_init_method(ctx: ClassDefContext) -> None:
    """Synthesize a keyword-only ``__init__`` from the class's Param fields.

    Fields with a default (or Meta fields) become optional keyword arguments;
    the others become required keyword arguments. No ``__init__`` is added
    when there is no such field.
    """
    init_args = [
        Argument(
            variable=Var(field_name, field_type),
            type_annotation=field_type,
            initializer=None,
            kind=ARG_NAMED_OPT if optional else ARG_NAMED,
        )
        for field_name, field_type, optional in _get_param_fields(ctx.cls.info)
    ]

    if not init_args:
        # Nothing to declare: leave __init__ untouched
        return

    add_method_to_class(ctx.api, ctx.cls, "__init__", init_args, NoneType())
289
+
290
+
291
def _get_task_outputs_return_type(info: TypeInfo) -> Optional[Instance]:
    """Return the declared return type of ``task_outputs`` for a class, if any.

    If the class, or a user-defined ancestor, defines ``task_outputs``, then
    submit() should return that type instead of Self. The MRO is walked so
    that subclasses keep the type of an inherited ``task_outputs``. The
    experimaestro base classes, ConfigMixin and object are skipped so that
    framework-level definitions never replace Self.

    Returns:
        The return type of the nearest ``task_outputs`` definition. Returns
        None if there is none, if it is not a plain function, or if it is
        unannotated.
    """
    from mypy.nodes import FuncDef
    from mypy.types import CallableType

    # Before the MRO is computed, only the class itself can be inspected
    for base in info.mro or [info]:
        if (
            base.fullname in CONFIG_FULLNAMES
            or base.fullname == CONFIGMIXIN_FULLNAME
            or base.fullname == "builtins.object"
        ):
            continue
        sym = base.names.get("task_outputs")
        if sym is None:
            continue
        # The nearest definition decides, even if unusable: an override must
        # not fall back to an ancestor's signature.
        node = sym.node
        if isinstance(node, FuncDef) and isinstance(node.type, CallableType):
            return node.type.ret_type
        return None
    return None


def _add_configmixin_to_bases(ctx: ClassDefContext) -> None:
    """Add ConfigMixin to the class bases if it is not already in the MRO.

    This lets mypy see all ConfigMixin methods on Config subclasses. The
    change is best-effort. If ConfigMixin cannot be looked up, or the MRO
    cannot be computed with it, the class is left as it was.
    """
    info = ctx.cls.info

    if any(base.fullname == CONFIGMIXIN_FULLNAME for base in info.mro):
        return  # Already has ConfigMixin

    try:
        configmixin_sym = ctx.api.lookup_fully_qualified_or_none(CONFIGMIXIN_FULLNAME)
    except Exception:
        # Best-effort: without ConfigMixin we only lose its method hints
        return
    if configmixin_sym is None or not isinstance(configmixin_sym.node, TypeInfo):
        return

    # ConfigMixin is not in the MRO, so it is not a direct base either
    info.bases.append(Instance(configmixin_sym.node, []))
    try:
        calculate_mro(info)
    except MroError:
        # The hierarchy cannot be linearized with ConfigMixin: roll back
        info.bases.pop()
    except Exception:
        # Any other failure: roll back rather than leave bases and MRO out of sync
        info.bases.pop()


def _add_submit_method(ctx: ClassDefContext) -> None:
    """Add a submit() method returning Self or the task_outputs return type.

    Mirrors the runtime signature of ConfigMixin.submit::

        def submit(self, *, workspace=None, launcher=None, run_mode=None,
                   init_tasks=[], max_retries=None)

    Every keyword argument is optional and typed as Any.
    """
    from mypy.types import AnyType, TypeOfAny

    info = ctx.cls.info

    outputs_type = _get_task_outputs_return_type(info)
    return_type = Instance(info, []) if outputs_type is None else outputs_type

    any_type = AnyType(TypeOfAny.explicit)
    submit_args = [
        Argument(
            variable=Var(arg_name, any_type),
            type_annotation=any_type,
            initializer=None,
            kind=ARG_NAMED_OPT,
        )
        for arg_name in ("workspace", "launcher", "run_mode", "init_tasks", "max_retries")
    ]

    add_method_to_class(ctx.api, ctx.cls, "submit", submit_args, return_type)


def _process_config_class(ctx: ClassDefContext) -> None:
    """Process a Config subclass to add type hints.

    This adds:
    - ConfigMixin to the class hierarchy, for method access
    - C, XPMConfig and XPMValue as class attributes typed Type[Self]
    - an __init__ method with the Param field signatures
    - a submit() method returning Self (or the task_outputs return type)
    """
    info = ctx.cls.info

    _add_configmixin_to_bases(ctx)

    # Type[Self] for this class
    type_type = TypeType(Instance(info, []))

    # Only add attributes the class does not define itself
    for attr_name in ("C", "XPMConfig", "XPMValue"):
        if attr_name not in info.names:
            add_attribute_to_class(ctx.api, ctx.cls, attr_name, type_type)

    _add_init_method(ctx)
    _add_submit_method(ctx)


class ExperimaestroPlugin(Plugin):
    """Mypy plugin for experimaestro type hints.

    For every class deriving, directly or indirectly, from Config, the plugin:
    - adds ConfigMixin to the class hierarchy;
    - types Config.C, Config.XPMConfig and Config.XPMValue as Type[Self];
    - synthesizes a keyword-only __init__ from the Param/Meta fields;
    - synthesizes submit() returning Self or the task_outputs return type.
    """

    def get_base_class_hook(
        self, fullname: str
    ) -> Callable[[ClassDefContext], None] | None:
        """Return the class processor when ``fullname`` is a Config class.

        mypy calls this hook with the full name of each direct base class of
        a class definition. User-defined Config subclasses must therefore be
        recognized too. Otherwise ``class B(A)`` with ``A(Config)`` would not
        be processed and would miss its own fields in __init__.
        """
        if fullname in CONFIG_FULLNAMES:
            return _process_config_class

        sym = self.lookup_fully_qualified(fullname)
        # NOTE(review): if the base's MRO is not computed yet, it is not
        # recognized and the class is left unprocessed (as before this change)
        if (
            sym is not None
            and isinstance(sym.node, TypeInfo)
            and is_config_subclass(sym.node)
        ):
            return _process_config_class
        return None


def plugin(_version: str):
    """Entry point for mypy plugin.

    Args:
        _version: The mypy version string (unused but required by mypy API)

    Returns:
        The ExperimaestroPlugin class
    """
    return ExperimaestroPlugin