experimaestro 2.0.0b8__py3-none-any.whl → 2.0.0b17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of experimaestro might be problematic. Click here for more details.

Files changed (152) hide show
  1. experimaestro/__init__.py +12 -5
  2. experimaestro/cli/__init__.py +239 -126
  3. experimaestro/cli/filter.py +48 -23
  4. experimaestro/cli/jobs.py +253 -71
  5. experimaestro/cli/refactor.py +1 -2
  6. experimaestro/commandline.py +7 -4
  7. experimaestro/connectors/__init__.py +9 -1
  8. experimaestro/connectors/local.py +43 -3
  9. experimaestro/core/arguments.py +18 -18
  10. experimaestro/core/identifier.py +11 -11
  11. experimaestro/core/objects/config.py +96 -39
  12. experimaestro/core/objects/config_walk.py +3 -3
  13. experimaestro/core/{subparameters.py → partial.py} +16 -16
  14. experimaestro/core/partial_lock.py +394 -0
  15. experimaestro/core/types.py +12 -15
  16. experimaestro/dynamic.py +290 -0
  17. experimaestro/experiments/__init__.py +6 -2
  18. experimaestro/experiments/cli.py +217 -50
  19. experimaestro/experiments/configuration.py +24 -0
  20. experimaestro/generators.py +5 -5
  21. experimaestro/ipc.py +118 -1
  22. experimaestro/launcherfinder/__init__.py +2 -2
  23. experimaestro/launcherfinder/registry.py +6 -7
  24. experimaestro/launcherfinder/specs.py +2 -9
  25. experimaestro/launchers/slurm/__init__.py +2 -2
  26. experimaestro/launchers/slurm/base.py +62 -0
  27. experimaestro/locking.py +957 -1
  28. experimaestro/notifications.py +89 -201
  29. experimaestro/progress.py +63 -366
  30. experimaestro/rpyc.py +0 -2
  31. experimaestro/run.py +29 -2
  32. experimaestro/scheduler/__init__.py +8 -1
  33. experimaestro/scheduler/base.py +629 -53
  34. experimaestro/scheduler/dependencies.py +20 -16
  35. experimaestro/scheduler/experiment.py +732 -167
  36. experimaestro/scheduler/interfaces.py +316 -101
  37. experimaestro/scheduler/jobs.py +58 -20
  38. experimaestro/scheduler/remote/adaptive_sync.py +265 -0
  39. experimaestro/scheduler/remote/client.py +171 -117
  40. experimaestro/scheduler/remote/protocol.py +8 -193
  41. experimaestro/scheduler/remote/server.py +95 -71
  42. experimaestro/scheduler/services.py +53 -28
  43. experimaestro/scheduler/state_provider.py +663 -2430
  44. experimaestro/scheduler/state_status.py +1247 -0
  45. experimaestro/scheduler/transient.py +31 -0
  46. experimaestro/scheduler/workspace.py +1 -1
  47. experimaestro/scheduler/workspace_state_provider.py +1273 -0
  48. experimaestro/scriptbuilder.py +4 -4
  49. experimaestro/settings.py +36 -0
  50. experimaestro/tests/conftest.py +33 -5
  51. experimaestro/tests/connectors/bin/executable.py +1 -1
  52. experimaestro/tests/fixtures/pre_experiment/experiment_check_env.py +16 -0
  53. experimaestro/tests/fixtures/pre_experiment/experiment_check_mock.py +14 -0
  54. experimaestro/tests/fixtures/pre_experiment/experiment_simple.py +12 -0
  55. experimaestro/tests/fixtures/pre_experiment/pre_setup_env.py +5 -0
  56. experimaestro/tests/fixtures/pre_experiment/pre_setup_error.py +3 -0
  57. experimaestro/tests/fixtures/pre_experiment/pre_setup_mock.py +8 -0
  58. experimaestro/tests/launchers/bin/test.py +1 -0
  59. experimaestro/tests/launchers/test_slurm.py +9 -9
  60. experimaestro/tests/partial_reschedule.py +46 -0
  61. experimaestro/tests/restart.py +3 -3
  62. experimaestro/tests/restart_main.py +1 -0
  63. experimaestro/tests/scripts/notifyandwait.py +1 -0
  64. experimaestro/tests/task_partial.py +38 -0
  65. experimaestro/tests/task_tokens.py +2 -2
  66. experimaestro/tests/tasks/test_dynamic.py +6 -6
  67. experimaestro/tests/test_dependencies.py +3 -3
  68. experimaestro/tests/test_deprecated.py +15 -15
  69. experimaestro/tests/test_dynamic_locking.py +317 -0
  70. experimaestro/tests/test_environment.py +24 -14
  71. experimaestro/tests/test_experiment.py +171 -36
  72. experimaestro/tests/test_identifier.py +25 -25
  73. experimaestro/tests/test_identifier_stability.py +3 -5
  74. experimaestro/tests/test_multitoken.py +2 -4
  75. experimaestro/tests/{test_subparameters.py → test_partial.py} +25 -25
  76. experimaestro/tests/test_partial_paths.py +81 -138
  77. experimaestro/tests/test_pre_experiment.py +219 -0
  78. experimaestro/tests/test_progress.py +2 -8
  79. experimaestro/tests/test_remote_state.py +560 -99
  80. experimaestro/tests/test_stray_jobs.py +261 -0
  81. experimaestro/tests/test_tasks.py +1 -2
  82. experimaestro/tests/test_token_locking.py +52 -67
  83. experimaestro/tests/test_tokens.py +5 -6
  84. experimaestro/tests/test_transient.py +225 -0
  85. experimaestro/tests/test_workspace_state_provider.py +768 -0
  86. experimaestro/tests/token_reschedule.py +1 -3
  87. experimaestro/tests/utils.py +2 -7
  88. experimaestro/tokens.py +227 -372
  89. experimaestro/tools/diff.py +1 -0
  90. experimaestro/tools/documentation.py +4 -5
  91. experimaestro/tools/jobs.py +1 -2
  92. experimaestro/tui/app.py +438 -1966
  93. experimaestro/tui/app.tcss +162 -0
  94. experimaestro/tui/dialogs.py +172 -0
  95. experimaestro/tui/log_viewer.py +253 -3
  96. experimaestro/tui/messages.py +137 -0
  97. experimaestro/tui/utils.py +54 -0
  98. experimaestro/tui/widgets/__init__.py +23 -0
  99. experimaestro/tui/widgets/experiments.py +468 -0
  100. experimaestro/tui/widgets/global_services.py +238 -0
  101. experimaestro/tui/widgets/jobs.py +972 -0
  102. experimaestro/tui/widgets/log.py +156 -0
  103. experimaestro/tui/widgets/orphans.py +363 -0
  104. experimaestro/tui/widgets/runs.py +185 -0
  105. experimaestro/tui/widgets/services.py +314 -0
  106. experimaestro/tui/widgets/stray_jobs.py +528 -0
  107. experimaestro/utils/__init__.py +1 -1
  108. experimaestro/utils/environment.py +105 -22
  109. experimaestro/utils/fswatcher.py +124 -0
  110. experimaestro/utils/jobs.py +1 -2
  111. experimaestro/utils/jupyter.py +1 -2
  112. experimaestro/utils/logging.py +72 -0
  113. experimaestro/version.py +2 -2
  114. experimaestro/webui/__init__.py +9 -0
  115. experimaestro/webui/app.py +117 -0
  116. experimaestro/{server → webui}/data/index.css +66 -11
  117. experimaestro/webui/data/index.css.map +1 -0
  118. experimaestro/{server → webui}/data/index.js +82763 -87217
  119. experimaestro/webui/data/index.js.map +1 -0
  120. experimaestro/webui/routes/__init__.py +5 -0
  121. experimaestro/webui/routes/auth.py +53 -0
  122. experimaestro/webui/routes/proxy.py +117 -0
  123. experimaestro/webui/server.py +200 -0
  124. experimaestro/webui/state_bridge.py +152 -0
  125. experimaestro/webui/websocket.py +413 -0
  126. {experimaestro-2.0.0b8.dist-info → experimaestro-2.0.0b17.dist-info}/METADATA +5 -6
  127. experimaestro-2.0.0b17.dist-info/RECORD +219 -0
  128. experimaestro/cli/progress.py +0 -269
  129. experimaestro/scheduler/state.py +0 -75
  130. experimaestro/scheduler/state_db.py +0 -437
  131. experimaestro/scheduler/state_sync.py +0 -891
  132. experimaestro/server/__init__.py +0 -467
  133. experimaestro/server/data/index.css.map +0 -1
  134. experimaestro/server/data/index.js.map +0 -1
  135. experimaestro/tests/test_cli_jobs.py +0 -615
  136. experimaestro/tests/test_file_progress.py +0 -425
  137. experimaestro/tests/test_file_progress_integration.py +0 -477
  138. experimaestro/tests/test_state_db.py +0 -434
  139. experimaestro-2.0.0b8.dist-info/RECORD +0 -187
  140. /experimaestro/{server → webui}/data/1815e00441357e01619e.ttf +0 -0
  141. /experimaestro/{server → webui}/data/2463b90d9a316e4e5294.woff2 +0 -0
  142. /experimaestro/{server → webui}/data/2582b0e4bcf85eceead0.ttf +0 -0
  143. /experimaestro/{server → webui}/data/89999bdf5d835c012025.woff2 +0 -0
  144. /experimaestro/{server → webui}/data/914997e1bdfc990d0897.ttf +0 -0
  145. /experimaestro/{server → webui}/data/c210719e60948b211a12.woff2 +0 -0
  146. /experimaestro/{server → webui}/data/favicon.ico +0 -0
  147. /experimaestro/{server → webui}/data/index.html +0 -0
  148. /experimaestro/{server → webui}/data/login.html +0 -0
  149. /experimaestro/{server → webui}/data/manifest.json +0 -0
  150. {experimaestro-2.0.0b8.dist-info → experimaestro-2.0.0b17.dist-info}/WHEEL +0 -0
  151. {experimaestro-2.0.0b8.dist-info → experimaestro-2.0.0b17.dist-info}/entry_points.txt +0 -0
  152. {experimaestro-2.0.0b8.dist-info → experimaestro-2.0.0b17.dist-info}/licenses/LICENSE +0 -0
@@ -14,7 +14,7 @@ from experimaestro import (
14
14
  field,
15
15
  Task,
16
16
  LightweightTask,
17
- subparameters,
17
+ partial,
18
18
  param_group,
19
19
  )
20
20
  from experimaestro.core.objects import (
@@ -742,7 +742,7 @@ def test_identifier_field_default_vs_ignore_default():
742
742
  )
743
743
 
744
744
 
745
- # --- Test partial identifiers (subparameters) ---
745
+ # --- Test partial identifiers (partial) ---
746
746
 
747
747
 
748
748
  # Define parameter groups at module level
@@ -751,7 +751,7 @@ model_group = param_group("model")
751
751
 
752
752
 
753
753
  def get_partial_identifier(config, sp):
754
- """Helper to get partial identifier for a config and subparameters"""
754
+ """Helper to get partial identifier for a config and partial"""
755
755
  return config.__xpm__.get_partial_identifier(sp).all
756
756
 
757
757
 
@@ -759,7 +759,7 @@ def test_partial_identifier_excludes_grouped_params():
759
759
  """Test that partial identifier excludes parameters in excluded groups"""
760
760
 
761
761
  class ConfigWithGroups(Config):
762
- checkpoints = subparameters(exclude_groups=[iter_group])
762
+ checkpoints = partial(exclude_groups=[iter_group])
763
763
  max_iter: Param[int] = field(groups=[iter_group])
764
764
  learning_rate: Param[float]
765
765
 
@@ -772,16 +772,16 @@ def test_partial_identifier_excludes_grouped_params():
772
772
  # Partial identifiers should be the same (max_iter is excluded)
773
773
  pid1 = get_partial_identifier(c1, ConfigWithGroups.checkpoints)
774
774
  pid2 = get_partial_identifier(c2, ConfigWithGroups.checkpoints)
775
- assert (
776
- pid1 == pid2
777
- ), "Partial identifiers should match when only excluded params differ"
775
+ assert pid1 == pid2, (
776
+ "Partial identifiers should match when only excluded params differ"
777
+ )
778
778
 
779
779
 
780
780
  def test_partial_identifier_includes_ungrouped_params():
781
781
  """Test that partial identifier includes parameters not in excluded groups"""
782
782
 
783
783
  class ConfigWithGroups(Config):
784
- checkpoints = subparameters(exclude_groups=[iter_group])
784
+ checkpoints = partial(exclude_groups=[iter_group])
785
785
  max_iter: Param[int] = field(groups=[iter_group])
786
786
  learning_rate: Param[float]
787
787
 
@@ -791,9 +791,9 @@ def test_partial_identifier_includes_ungrouped_params():
791
791
  # Partial identifiers should differ (learning_rate is not excluded)
792
792
  pid1 = get_partial_identifier(c1, ConfigWithGroups.checkpoints)
793
793
  pid2 = get_partial_identifier(c2, ConfigWithGroups.checkpoints)
794
- assert (
795
- pid1 != pid2
796
- ), "Partial identifiers should differ when non-excluded params differ"
794
+ assert pid1 != pid2, (
795
+ "Partial identifiers should differ when non-excluded params differ"
796
+ )
797
797
 
798
798
 
799
799
  def test_partial_identifier_matches_config_without_excluded():
@@ -801,7 +801,7 @@ def test_partial_identifier_matches_config_without_excluded():
801
801
 
802
802
  class ConfigWithIter(Config):
803
803
  __xpmid__ = "test.partial_identifier.config"
804
- checkpoints = subparameters(exclude_groups=[iter_group])
804
+ checkpoints = partial(exclude_groups=[iter_group])
805
805
  max_iter: Param[int] = field(groups=[iter_group])
806
806
  learning_rate: Param[float]
807
807
 
@@ -815,16 +815,16 @@ def test_partial_identifier_matches_config_without_excluded():
815
815
  # The partial identifier of c_with should match full identifier of c_without
816
816
  pid = get_partial_identifier(c_with, ConfigWithIter.checkpoints)
817
817
  full_id = getidentifier(c_without)
818
- assert (
819
- pid == full_id
820
- ), "Partial identifier should match config without excluded fields"
818
+ assert pid == full_id, (
819
+ "Partial identifier should match config without excluded fields"
820
+ )
821
821
 
822
822
 
823
823
  def test_partial_identifier_multiple_groups():
824
824
  """Test partial identifier with parameter in multiple groups"""
825
825
 
826
826
  class ConfigMultiGroup(Config):
827
- checkpoints = subparameters(exclude_groups=[iter_group])
827
+ checkpoints = partial(exclude_groups=[iter_group])
828
828
  # This parameter is in both groups - should be excluded if any group is excluded
829
829
  x: Param[int] = field(groups=[iter_group, model_group])
830
830
  y: Param[float]
@@ -835,9 +835,9 @@ def test_partial_identifier_multiple_groups():
835
835
  # Partial identifiers should be the same (x is in iter_group which is excluded)
836
836
  pid1 = get_partial_identifier(c1, ConfigMultiGroup.checkpoints)
837
837
  pid2 = get_partial_identifier(c2, ConfigMultiGroup.checkpoints)
838
- assert (
839
- pid1 == pid2
840
- ), "Partial identifiers should match when param is in any excluded group"
838
+ assert pid1 == pid2, (
839
+ "Partial identifiers should match when param is in any excluded group"
840
+ )
841
841
 
842
842
 
843
843
  def test_partial_identifier_include_overrides_exclude():
@@ -845,7 +845,7 @@ def test_partial_identifier_include_overrides_exclude():
845
845
 
846
846
  class ConfigIncludeOverride(Config):
847
847
  # iter_group is excluded but also included, so it should NOT be excluded
848
- partial = subparameters(
848
+ partial = partial(
849
849
  exclude_groups=[iter_group, model_group], include_groups=[iter_group]
850
850
  )
851
851
  x: Param[int] = field(groups=[iter_group])
@@ -873,7 +873,7 @@ def test_partial_identifier_exclude_all():
873
873
 
874
874
  class ConfigExcludeAll(Config):
875
875
  # Exclude all, but include model_group
876
- partial = subparameters(exclude_all=True, include_groups=[model_group])
876
+ partial = partial(exclude_all=True, include_groups=[model_group])
877
877
  x: Param[int] = field(groups=[iter_group])
878
878
  y: Param[int] = field(groups=[model_group])
879
879
  z: Param[float] # No group
@@ -895,16 +895,16 @@ def test_partial_identifier_exclude_all():
895
895
  assert pid1 != pid3, "y is included - should have different partial ID"
896
896
 
897
897
  # z is excluded (no group, exclude_all=True) - same partial ID
898
- assert (
899
- pid1 == pid4
900
- ), "z (no group) is excluded by exclude_all - should have same partial ID"
898
+ assert pid1 == pid4, (
899
+ "z (no group) is excluded by exclude_all - should have same partial ID"
900
+ )
901
901
 
902
902
 
903
903
  def test_partial_identifier_exclude_no_group():
904
904
  """Test exclude_no_group option"""
905
905
 
906
906
  class ConfigExcludeNoGroup(Config):
907
- partial = subparameters(exclude_no_group=True)
907
+ partial = partial(exclude_no_group=True)
908
908
  x: Param[int] = field(groups=[iter_group])
909
909
  y: Param[float] # No group
910
910
 
@@ -1,3 +1,4 @@
1
+ # ruff: noqa: T201 - This module uses print for CLI output when run as script
1
2
  # Tests for identifier stability across versions
2
3
 
3
4
  import json
@@ -395,8 +396,7 @@ def test_identifier_stability():
395
396
 
396
397
  if not reference_file.exists():
397
398
  raise FileNotFoundError(
398
- f"Reference file {reference_file} not found. "
399
- f"Run 'python {__file__}' to generate it."
399
+ f"Reference file {reference_file} not found. Run 'python {__file__}' to generate it."
400
400
  )
401
401
 
402
402
  # Load reference identifiers
@@ -417,9 +417,7 @@ def test_identifier_stability():
417
417
  )
418
418
  elif current_id != expected_id:
419
419
  mismatches.append(
420
- f" - {name}: MISMATCH\n"
421
- f" Expected: {expected_id}\n"
422
- f" Current: {current_id}"
420
+ f" - {name}: MISMATCH\n Expected: {expected_id}\n Current: {current_id}"
423
421
  )
424
422
 
425
423
  # Check for removed configurations
@@ -45,13 +45,11 @@ class TokenUsageTracker:
45
45
 
46
46
  if self.memory_used > self.memory_limit:
47
47
  self.violations.append(
48
- f"Memory limit exceeded: {self.memory_used} > {self.memory_limit} "
49
- f"(task {task_id})"
48
+ f"Memory limit exceeded: {self.memory_used} > {self.memory_limit} (task {task_id})"
50
49
  )
51
50
  if self.cpu_used > self.cpu_limit:
52
51
  self.violations.append(
53
- f"CPU limit exceeded: {self.cpu_used} > {self.cpu_limit} "
54
- f"(task {task_id})"
52
+ f"CPU limit exceeded: {self.cpu_used} > {self.cpu_limit} (task {task_id})"
55
53
  )
56
54
 
57
55
  logger.debug(
@@ -1,14 +1,14 @@
1
- """Tests for subparameters (partial identifier computation)"""
1
+ """Tests for partial (partial identifier computation)"""
2
2
 
3
3
  from experimaestro import (
4
4
  Config,
5
5
  Task,
6
6
  Param,
7
7
  field,
8
- subparameters,
8
+ partial,
9
9
  param_group,
10
10
  ParameterGroup,
11
- Subparameters,
11
+ Partial,
12
12
  )
13
13
 
14
14
 
@@ -17,8 +17,8 @@ iter_group = param_group("iter")
17
17
  model_group = param_group("model")
18
18
 
19
19
 
20
- class TestSubparametersBasic:
21
- """Test basic subparameters functionality"""
20
+ class TestPartialBasic:
21
+ """Test basic partial functionality"""
22
22
 
23
23
  def test_param_group_creation(self):
24
24
  """Test creating parameter groups"""
@@ -42,38 +42,38 @@ class TestSubparametersBasic:
42
42
  s2 = {group1, group3}
43
43
  assert len(s2) == 2
44
44
 
45
- def test_subparameters_creation(self):
46
- """Test creating subparameters"""
47
- sp = subparameters(exclude_groups=[iter_group])
48
- assert isinstance(sp, Subparameters)
45
+ def test_partials_creation(self):
46
+ """Test creating partial"""
47
+ sp = partial(exclude_groups=[iter_group])
48
+ assert isinstance(sp, Partial)
49
49
  assert iter_group in sp.exclude_groups
50
50
 
51
- def test_subparameters_is_excluded(self):
51
+ def test_partials_is_excluded(self):
52
52
  """Test is_excluded method"""
53
- sp = subparameters(exclude_groups=[iter_group])
53
+ sp = partial(exclude_groups=[iter_group])
54
54
  assert sp.is_excluded({iter_group}) is True
55
55
  assert sp.is_excluded({model_group}) is False
56
56
  assert sp.is_excluded(set()) is False
57
57
 
58
- def test_subparameters_include_overrides_exclude(self):
58
+ def test_partials_include_overrides_exclude(self):
59
59
  """Test that include_groups overrides exclude_groups"""
60
- sp = subparameters(
60
+ sp = partial(
61
61
  exclude_groups=[iter_group, model_group], include_groups=[iter_group]
62
62
  )
63
63
  # iter_group is in both, but include wins
64
64
  assert sp.is_excluded({iter_group}) is False
65
65
  assert sp.is_excluded({model_group}) is True
66
66
 
67
- def test_subparameters_exclude_all(self):
67
+ def test_partials_exclude_all(self):
68
68
  """Test exclude_all option"""
69
- sp = subparameters(exclude_all=True, include_groups=[model_group])
69
+ sp = partial(exclude_all=True, include_groups=[model_group])
70
70
  assert sp.is_excluded({iter_group}) is True
71
71
  assert sp.is_excluded({model_group}) is False
72
72
  assert sp.is_excluded(set()) is True
73
73
 
74
- def test_subparameters_exclude_no_group(self):
74
+ def test_partials_exclude_no_group(self):
75
75
  """Test exclude_no_group option"""
76
- sp = subparameters(exclude_no_group=True)
76
+ sp = partial(exclude_no_group=True)
77
77
  assert sp.is_excluded(set()) is True
78
78
  assert sp.is_excluded({iter_group}) is False
79
79
 
@@ -94,24 +94,24 @@ class TestPartialIdentifiers:
94
94
  assert iter_group in xpmtype.arguments["x"].groups
95
95
  assert len(xpmtype.arguments["y"].groups) == 0
96
96
 
97
- def test_subparameters_collected_in_objecttype(self):
98
- """Test that subparameters are collected in ObjectType"""
97
+ def test_partials_collected_in_objecttype(self):
98
+ """Test that partial are collected in ObjectType"""
99
99
 
100
100
  class MyTask(Task):
101
- checkpoints = subparameters(exclude_groups=[iter_group])
101
+ checkpoints = partial(exclude_groups=[iter_group])
102
102
  x: Param[int]
103
103
 
104
104
  xpmtype = MyTask.__getxpmtype__()
105
105
  xpmtype.__initialize__()
106
106
 
107
- assert "checkpoints" in xpmtype._subparameters
108
- assert xpmtype._subparameters["checkpoints"].name == "checkpoints"
107
+ assert "checkpoints" in xpmtype._partials
108
+ assert xpmtype._partials["checkpoints"].name == "checkpoints"
109
109
 
110
110
  def test_partial_identifier_same_when_excluded_differs(self):
111
111
  """Test that partial identifiers are the same when only excluded params differ"""
112
112
 
113
113
  class MyTask(Task):
114
- checkpoints = subparameters(exclude_groups=[iter_group])
114
+ checkpoints = partial(exclude_groups=[iter_group])
115
115
  max_iter: Param[int] = field(groups=[iter_group])
116
116
  learning_rate: Param[float]
117
117
 
@@ -130,7 +130,7 @@ class TestPartialIdentifiers:
130
130
  """Test that partial identifiers differ when included params differ"""
131
131
 
132
132
  class MyTask(Task):
133
- checkpoints = subparameters(exclude_groups=[iter_group])
133
+ checkpoints = partial(exclude_groups=[iter_group])
134
134
  max_iter: Param[int] = field(groups=[iter_group])
135
135
  learning_rate: Param[float]
136
136
 
@@ -146,7 +146,7 @@ class TestPartialIdentifiers:
146
146
  """Test partial identifiers with parameters in multiple groups"""
147
147
 
148
148
  class MyTask(Task):
149
- checkpoints = subparameters(exclude_groups=[iter_group])
149
+ checkpoints = partial(exclude_groups=[iter_group])
150
150
  # This parameter is in both groups
151
151
  x: Param[int] = field(groups=[iter_group, model_group])
152
152
  y: Param[float]
@@ -7,7 +7,7 @@ from experimaestro import (
7
7
  Meta,
8
8
  field,
9
9
  PathGenerator,
10
- subparameters,
10
+ partial,
11
11
  param_group,
12
12
  )
13
13
  from experimaestro.scheduler import JobState
@@ -20,10 +20,10 @@ iter_group = param_group("iter")
20
20
 
21
21
 
22
22
  class TaskWithPartial(Task):
23
- """Task that uses subparameters for partial paths"""
23
+ """Task that uses partial for partial paths"""
24
24
 
25
- # Define a subparameters set
26
- checkpoints = subparameters(exclude_groups=[iter_group])
25
+ # Define a partial set
26
+ checkpoints = partial(exclude_groups=[iter_group])
27
27
 
28
28
  # Parameter in iter_group - excluded from partial identifier
29
29
  max_iter: Param[int] = field(groups=[iter_group])
@@ -93,139 +93,82 @@ def test_partial_path_different_for_different_params():
93
93
  assert task1.checkpoint_path != task2.checkpoint_path
94
94
 
95
95
 
96
- def test_partial_registered_in_database():
97
- """Test that partials are registered in the database when jobs are submitted"""
98
- from experimaestro.scheduler.state_provider import WorkspaceStateProvider
99
- from experimaestro.scheduler.state_db import PartialModel, JobPartialModel
100
-
101
- with TemporaryDirectory(prefix="xpm", suffix="partial_db") as workdir:
102
- with TemporaryExperiment("partial_db", workdir=workdir, maxwait=30) as xp:
103
- task = TaskWithPartial.C(max_iter=100, learning_rate=0.1).submit()
104
-
105
- assert task.__xpm__.job.state == JobState.DONE
106
-
107
- # Get the state provider and check database
108
- # Note: Must use read_only=False since the experiment left a singleton
109
- # with read_only=False that hasn't been closed yet
110
- provider = WorkspaceStateProvider.get_instance(workdir, read_only=False)
111
-
112
- try:
113
- with provider.workspace_db.bind_ctx([PartialModel, JobPartialModel]):
114
- # Check that partial is registered
115
- partials = list(PartialModel.select())
116
- assert len(partials) == 1
117
- assert partials[0].subparameters_name == "checkpoints"
118
-
119
- # Check that job is linked to partial
120
- job_partials = list(JobPartialModel.select())
121
- assert len(job_partials) == 1
122
- assert job_partials[0].partial_id == partials[0].partial_id
123
- assert job_partials[0].experiment_id == xp.workdir.name
124
- finally:
125
- provider.close()
126
-
127
-
128
- def test_orphan_partial_cleanup():
129
- """Test that orphan partials are cleaned up when jobs are deleted"""
130
- from experimaestro.scheduler.state_provider import WorkspaceStateProvider
131
- from experimaestro.scheduler.state_db import PartialModel, JobPartialModel
132
-
133
- with TemporaryDirectory(prefix="xpm", suffix="partial_cleanup") as workdir:
134
- with TemporaryExperiment("partial_cleanup", workdir=workdir, maxwait=30) as xp:
135
- task = TaskWithPartial.C(max_iter=100, learning_rate=0.1).submit()
136
-
137
- assert task.__xpm__.job.state == JobState.DONE
138
- checkpoint_path = task.checkpoint_path
139
-
140
- # Verify partial path exists
141
- assert checkpoint_path.exists()
142
-
143
- # Get the state provider
144
- provider = WorkspaceStateProvider.get_instance(workdir, read_only=False)
145
-
146
- try:
147
- # Delete the job
148
- with provider.workspace_db.bind_ctx([PartialModel, JobPartialModel]):
149
- job_partials = list(JobPartialModel.select())
150
- assert len(job_partials) == 1
151
-
152
- # Delete job (this also removes job-partial link)
153
- provider.delete_job(
154
- task.__xpm__.job.identifier,
155
- xp.workdir.name,
156
- xp.run_id,
157
- )
158
-
159
- # Now the partial should be orphaned
160
- orphans = provider.get_orphan_partials()
161
- assert len(orphans) == 1
162
-
163
- # Cleanup orphan partials
164
- deleted = provider.cleanup_orphan_partials(perform=True)
165
- assert len(deleted) == 1
166
-
167
- # Verify partial directory is deleted
168
- assert not checkpoint_path.exists()
169
-
170
- # Verify partial is removed from database
171
- with provider.workspace_db.bind_ctx([PartialModel]):
172
- partials = list(PartialModel.select())
173
- assert len(partials) == 0
174
- finally:
175
- provider.close()
176
-
177
-
178
- def test_shared_partial_not_orphaned():
179
- """Test that partials shared by multiple jobs are not orphaned until all jobs deleted"""
180
- from experimaestro.scheduler.state_provider import WorkspaceStateProvider
181
-
182
- with TemporaryDirectory(prefix="xpm", suffix="partial_shared_cleanup") as workdir:
183
- with TemporaryExperiment(
184
- "partial_shared_cleanup", workdir=workdir, maxwait=30
185
- ) as xp:
186
- # Submit two tasks with same learning_rate (same partial)
187
- task1 = TaskWithPartial.C(max_iter=100, learning_rate=0.1).submit()
188
- task2 = TaskWithPartial.C(max_iter=200, learning_rate=0.1).submit()
189
-
190
- assert task1.__xpm__.job.state == JobState.DONE
191
- assert task2.__xpm__.job.state == JobState.DONE
192
-
193
- # They share the same partial path
194
- checkpoint_path = task1.checkpoint_path
195
- assert checkpoint_path == task2.checkpoint_path
196
- assert checkpoint_path.exists()
197
-
198
- provider = WorkspaceStateProvider.get_instance(workdir, read_only=False)
96
+ def test_partial_concurrent_processes():
97
+ """Test that two processes competing for the same partial are serialized.
98
+
99
+ Similar to test_token_reschedule but for partial locking:
100
+ - Two tasks with different x (excluded param) share the same partial
101
+ - They should run sequentially (one after the other)
102
+ """
103
+ import sys
104
+ import subprocess
105
+ import logging
106
+ import time
107
+ import pytest
108
+ from .utils import TemporaryDirectory, timeout, get_times_frompath
109
+
110
+ with TemporaryDirectory("partial_reschedule") as workdir:
111
+ lockingpath = workdir / "lockingpath"
112
+
113
+ command = [
114
+ sys.executable,
115
+ Path(__file__).parent / "partial_reschedule.py",
116
+ workdir,
117
+ ]
118
+
119
+ ready1 = workdir / "ready.1"
120
+ time1 = workdir / "time.1"
121
+ p1 = subprocess.Popen(
122
+ command + ["1", str(lockingpath), str(ready1), str(time1)]
123
+ )
124
+
125
+ ready2 = workdir / "ready.2"
126
+ time2 = workdir / "time.2"
127
+ p2 = subprocess.Popen(
128
+ command + ["2", str(lockingpath), str(ready2), str(time2)]
129
+ )
199
130
 
200
131
  try:
201
- # Delete first job
202
- provider.delete_job(
203
- task1.__xpm__.job.identifier,
204
- xp.workdir.name,
205
- xp.run_id,
206
- )
207
-
208
- # Partial should NOT be orphaned (still used by task2)
209
- orphans = provider.get_orphan_partials()
210
- assert len(orphans) == 0
211
-
212
- # Partial directory should still exist
213
- assert checkpoint_path.exists()
214
-
215
- # Delete second job
216
- provider.delete_job(
217
- task2.__xpm__.job.identifier,
218
- xp.workdir.name,
219
- xp.run_id,
220
- )
221
-
222
- # Now partial should be orphaned
223
- orphans = provider.get_orphan_partials()
224
- assert len(orphans) == 1
225
-
226
- # Cleanup
227
- deleted = provider.cleanup_orphan_partials(perform=True)
228
- assert len(deleted) == 1
229
- assert not checkpoint_path.exists()
230
- finally:
231
- provider.close()
132
+ with timeout(30):
133
+ logging.info("Waiting for both experiments to be ready")
134
+ # Wait that both processes are ready
135
+ while not ready1.is_file():
136
+ time.sleep(0.01)
137
+ while not ready2.is_file():
138
+ time.sleep(0.01)
139
+
140
+ # Create the locking path to allow tasks to finish
141
+ logging.info(
142
+ "Both processes are ready: allowing tasks to finish by writing in %s",
143
+ lockingpath,
144
+ )
145
+ lockingpath.write_text("Let's go")
146
+
147
+ # Waiting for the output
148
+ logging.info("Waiting for XP1 to finish (%s)", time1)
149
+ while not time1.is_file():
150
+ time.sleep(0.01)
151
+ logging.info("Experiment 1 finished")
152
+
153
+ logging.info("Waiting for XP2 to finish")
154
+ while not time2.is_file():
155
+ time.sleep(0.01)
156
+ logging.info("Experiment 2 finished")
157
+
158
+ time1_val = get_times_frompath(time1)
159
+ time2_val = get_times_frompath(time2)
160
+
161
+ logging.info("%s vs %s", time1_val, time2_val)
162
+ # One should have finished before the other started
163
+ # (they share the same partial, so only one can run at a time)
164
+ assert time1_val > time2_val or time2_val > time1_val
165
+ except TimeoutError:
166
+ p1.terminate()
167
+ p2.terminate()
168
+ pytest.fail("Timeout")
169
+
170
+ except Exception:
171
+ logging.warning("Other exception: killing processes (just in case)")
172
+ p1.terminate()
173
+ p2.terminate()
174
+ raise