yanex 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yanex/__init__.py +74 -0
- yanex/api.py +507 -0
- yanex/cli/__init__.py +3 -0
- yanex/cli/_utils.py +114 -0
- yanex/cli/commands/__init__.py +3 -0
- yanex/cli/commands/archive.py +177 -0
- yanex/cli/commands/compare.py +320 -0
- yanex/cli/commands/confirm.py +198 -0
- yanex/cli/commands/delete.py +203 -0
- yanex/cli/commands/list.py +243 -0
- yanex/cli/commands/run.py +625 -0
- yanex/cli/commands/show.py +560 -0
- yanex/cli/commands/unarchive.py +177 -0
- yanex/cli/commands/update.py +282 -0
- yanex/cli/filters/__init__.py +8 -0
- yanex/cli/filters/base.py +286 -0
- yanex/cli/filters/time_utils.py +178 -0
- yanex/cli/formatters/__init__.py +7 -0
- yanex/cli/formatters/console.py +325 -0
- yanex/cli/main.py +45 -0
- yanex/core/__init__.py +3 -0
- yanex/core/comparison.py +549 -0
- yanex/core/config.py +587 -0
- yanex/core/constants.py +16 -0
- yanex/core/environment.py +146 -0
- yanex/core/git_utils.py +153 -0
- yanex/core/manager.py +555 -0
- yanex/core/storage.py +682 -0
- yanex/ui/__init__.py +1 -0
- yanex/ui/compare_table.py +524 -0
- yanex/utils/__init__.py +3 -0
- yanex/utils/exceptions.py +70 -0
- yanex/utils/validation.py +165 -0
- yanex-0.1.0.dist-info/METADATA +251 -0
- yanex-0.1.0.dist-info/RECORD +39 -0
- yanex-0.1.0.dist-info/WHEEL +5 -0
- yanex-0.1.0.dist-info/entry_points.txt +2 -0
- yanex-0.1.0.dist-info/licenses/LICENSE +21 -0
- yanex-0.1.0.dist-info/top_level.txt +1 -0
yanex/core/config.py
ADDED
@@ -0,0 +1,587 @@
|
|
1
|
+
"""
|
2
|
+
Configuration management for experiments.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from __future__ import annotations
|
6
|
+
|
7
|
+
import copy
|
8
|
+
import re
|
9
|
+
from pathlib import Path
|
10
|
+
from typing import Any, Optional, Union
|
11
|
+
|
12
|
+
import yaml
|
13
|
+
|
14
|
+
from ..utils.exceptions import ConfigError
|
15
|
+
from ..utils.validation import validate_config_data
|
16
|
+
|
17
|
+
|
18
|
+
def load_yaml_config(config_path: Path) -> dict[str, Any]:
    """
    Load configuration from YAML file.

    An empty document (or a file holding only comments) is treated as an
    empty configuration. Any other top-level value that is not a mapping
    (a list, a scalar such as ``0`` or ``false``) is rejected.

    Args:
        config_path: Path to YAML configuration file

    Returns:
        Configuration dictionary, as returned by ``validate_config_data``

    Raises:
        ConfigError: If config file cannot be loaded or parsed, or its top
            level is not a mapping
    """
    if not config_path.exists():
        raise ConfigError(f"Configuration file not found: {config_path}")

    if not config_path.is_file():
        raise ConfigError(f"Configuration path is not a file: {config_path}")

    try:
        with config_path.open("r", encoding="utf-8") as f:
            config_data = yaml.safe_load(f)
    except yaml.YAMLError as e:
        raise ConfigError(f"Failed to parse YAML config: {e}") from e
    except Exception as e:
        raise ConfigError(f"Failed to read config file: {e}") from e

    # Only an empty document (parsed as None) means "no configuration".
    # A blanket `or {}` would also have accepted [], 0, false and "" as an
    # empty config instead of reporting them as invalid.
    if config_data is None:
        config_data = {}

    if not isinstance(config_data, dict):
        raise ConfigError(
            f"Configuration must be a dictionary, got {type(config_data)}"
        )

    return validate_config_data(config_data)
|
51
|
+
|
52
|
+
|
53
|
+
def save_yaml_config(config_data: dict[str, Any], config_path: Path) -> None:
    """
    Save configuration to YAML file.

    The data is serialized to a string before the target file is opened.
    If the data cannot be represented as YAML, an existing file is
    therefore left untouched instead of being truncated by the ``"w"`` open.

    Args:
        config_data: Configuration dictionary to save
        config_path: Path where to save the configuration

    Raises:
        ConfigError: If config cannot be serialized or saved
    """
    validate_config_data(config_data)

    try:
        # Serialize first (safe_dump returns a str when no stream is given),
        # so a representation error happens before anything touches disk.
        serialized = yaml.safe_dump(
            config_data, default_flow_style=False, sort_keys=True, indent=2
        )

        # Ensure parent directory exists
        config_path.parent.mkdir(parents=True, exist_ok=True)

        with config_path.open("w", encoding="utf-8") as f:
            f.write(serialized)
    except Exception as e:
        raise ConfigError(f"Failed to save config file: {e}") from e
|
76
|
+
|
77
|
+
|
78
|
+
def parse_param_overrides(param_strings: list[str]) -> dict[str, Any]:
    """
    Turn CLI ``key=value`` strings into a (possibly nested) override dict.

    Only the first ``=`` separates key from value, so values may themselves
    contain ``=``. Dotted keys such as ``model.learning_rate`` produce nested
    dictionaries.

    Args:
        param_strings: List of "key=value" strings

    Returns:
        Dictionary of parsed parameters

    Raises:
        ConfigError: If an entry has no ``=`` or its key is blank
    """
    overrides: dict[str, Any] = {}

    for raw in param_strings:
        name, separator, raw_value = raw.partition("=")
        if not separator:
            raise ConfigError(
                f"Invalid parameter format: {raw}. Expected 'key=value'"
            )

        name = name.strip()
        if not name:
            raise ConfigError(f"Empty parameter key in: {raw}")

        # Value typing (numbers, booleans, sweeps, ...) is delegated to the parser.
        _set_nested_key(overrides, name, _parse_parameter_value(raw_value.strip()))

    return overrides
|
113
|
+
|
114
|
+
|
115
|
+
def _parse_parameter_value(value_str: str) -> Any:
    """
    Parse parameter value string to appropriate Python type.

    Supports sweep syntax: range(), linspace(), logspace(), list()

    Resolution order (first match wins): empty string -> "", null/none/~ ->
    None, sweep syntax, int/float, boolean words, bracketed list, and finally
    the stripped string itself.

    Bracketed lists are split only on top-level commas. Nested lists such as
    "[1, [2, 3]]" therefore parse to [1, [2, 3]] instead of being torn apart
    at the inner comma.

    Args:
        value_str: String value to parse

    Returns:
        Parsed value with appropriate type (including SweepParameter instances)
    """
    value_str = value_str.strip()

    # Handle empty string
    if not value_str:
        return ""

    # Handle null/none
    if value_str.lower() in ("null", "none", "~"):
        return None

    # Check for sweep syntax first
    sweep_result = _parse_sweep_syntax(value_str)
    if sweep_result is not None:
        return sweep_result

    # Try to parse as number first (before booleans, so "1"/"0" stay numeric)
    try:
        if "." not in value_str and "e" not in value_str.lower():
            return int(value_str)
        return float(value_str)
    except ValueError:
        pass

    # Handle boolean values
    if value_str.lower() in ("true", "yes", "on"):
        return True
    if value_str.lower() in ("false", "no", "off"):
        return False

    # Bracketed, comma-separated list; items are parsed recursively.
    if value_str.startswith("[") and value_str.endswith("]"):
        content = value_str[1:-1].strip()
        if not content:
            return []
        try:
            return [
                _parse_parameter_value(item)
                for item in _split_top_level_items(content)
            ]
        except ConfigError:
            # An item that looks like malformed sweep syntax: fall back to
            # returning the whole value as a plain string (best effort).
            pass

    # Return as string
    return value_str


def _split_top_level_items(content: str) -> list[str]:
    """Split ``content`` on commas that are not nested inside ``[...]``."""
    items: list[str] = []
    depth = 0
    start = 0
    for index, char in enumerate(content):
        if char == "[":
            depth += 1
        elif char == "]":
            # Clamp at zero so a stray "]" does not disable splitting.
            depth = max(depth - 1, 0)
        elif char == "," and depth == 0:
            items.append(content[start:index])
            start = index + 1
    items.append(content[start:])
    return items
|
174
|
+
|
175
|
+
|
176
|
+
def _parse_sweep_syntax(value_str: str) -> Optional[SweepParameter]:
    """
    Parse sweep syntax into SweepParameter objects.

    Supported syntax (the whole value must be exactly one such call; a value
    with extra text before or after the call is not treated as a sweep):
    - range(start, stop, step)
    - linspace(start, stop, count)
    - logspace(start, stop, count)
    - list(item1, item2, ...)

    Args:
        value_str: String value to parse

    Returns:
        SweepParameter instance or None if not sweep syntax

    Raises:
        ConfigError: If the value is a recognised sweep call with invalid
            arguments (non-numeric bounds, a fractional or non-positive
            count, an empty list, ...)
    """
    value_str = value_str.strip()

    # Regular expressions for sweep function parsing
    range_pattern = r"range\(\s*([^,]+)\s*,\s*([^,]+)\s*,\s*([^,]+)\s*\)"
    linspace_pattern = r"linspace\(\s*([^,]+)\s*,\s*([^,]+)\s*,\s*([^,]+)\s*\)"
    logspace_pattern = r"logspace\(\s*([^,]+)\s*,\s*([^,]+)\s*,\s*([^,]+)\s*\)"
    list_pattern = r"list\(\s*([^)]*)\s*\)"

    # fullmatch (not match): "range(1, 5, 1)xyz" must not silently become a
    # sweep with the trailing text ignored.

    # Try range() syntax
    match = re.fullmatch(range_pattern, value_str)
    if match:
        try:
            start = _parse_numeric_value(match.group(1))
            stop = _parse_numeric_value(match.group(2))
            step = _parse_numeric_value(match.group(3))
            return RangeSweep(start, stop, step)
        except Exception as e:
            raise ConfigError(
                f"Invalid range() syntax: {value_str}. Error: {e}"
            ) from e

    # Try linspace() syntax
    match = re.fullmatch(linspace_pattern, value_str)
    if match:
        try:
            start = _parse_numeric_value(match.group(1))
            stop = _parse_numeric_value(match.group(2))
            count = _parse_sweep_count(match.group(3))
            return LinspaceSweep(start, stop, count)
        except Exception as e:
            raise ConfigError(
                f"Invalid linspace() syntax: {value_str}. Error: {e}"
            ) from e

    # Try logspace() syntax
    match = re.fullmatch(logspace_pattern, value_str)
    if match:
        try:
            start = _parse_numeric_value(match.group(1))
            stop = _parse_numeric_value(match.group(2))
            count = _parse_sweep_count(match.group(3))
            return LogspaceSweep(start, stop, count)
        except Exception as e:
            raise ConfigError(
                f"Invalid logspace() syntax: {value_str}. Error: {e}"
            ) from e

    # Try list() syntax
    match = re.fullmatch(list_pattern, value_str)
    if match:
        try:
            content = match.group(1).strip()
            if not content:
                raise ConfigError("List sweep cannot be empty")

            # Parse comma-separated items; empty items are skipped
            items = []
            for item_str in content.split(","):
                item_str = item_str.strip()
                if not item_str:
                    continue
                # Items are plain values: nested sweep syntax is not recognised
                items.append(_parse_non_sweep_value(item_str))

            if not items:
                raise ConfigError("List sweep cannot be empty")

            return ListSweep(items)
        except Exception as e:
            raise ConfigError(
                f"Invalid list() syntax: {value_str}. Error: {e}"
            ) from e

    return None


def _parse_sweep_count(count_str: str) -> int:
    """Parse a linspace()/logspace() point count, rejecting fractional values."""
    count = _parse_numeric_value(count_str)
    if isinstance(count, float):
        if not count.is_integer():
            # int() would silently truncate e.g. 2.5 to 2 points.
            raise ConfigError(
                f"Sweep count must be an integer, got: {count_str.strip()}"
            )
        count = int(count)
    return count
|
257
|
+
|
258
|
+
|
259
|
+
def _parse_numeric_value(value_str: str) -> Union[int, float]:
    """
    Convert ``value_str`` (surrounding whitespace ignored) to an int or float.

    A value with no ``.`` and no ``e``/``E`` is parsed as an int; anything
    else is parsed as a float.

    Raises:
        ConfigError: If the text is not a valid number of the chosen kind
    """
    text = value_str.strip()
    looks_integral = "." not in text and "e" not in text.lower()

    try:
        return int(text) if looks_integral else float(text)
    except ValueError:
        raise ConfigError(f"Expected numeric value, got: {text}")
|
271
|
+
|
272
|
+
|
273
|
+
def _parse_non_sweep_value(value_str: str) -> Any:
    """
    Parse one plain value (used for list() sweep items); never yields a sweep.

    Checks, in order: matching single/double quotes (stripped and returned
    verbatim), null/none/~, int/float, boolean words, and finally the
    stripped text itself.
    """
    text = value_str.strip()

    # Quoted: same quote character at both ends
    if text[:1] in ('"', "'") and text.endswith(text[0]):
        return text[1:-1]

    lowered = text.lower()
    if lowered in ("null", "none", "~"):
        return None

    try:
        if "." in text or "e" in lowered:
            return float(text)
        return int(text)
    except ValueError:
        pass

    if lowered in ("true", "yes", "on"):
        return True
    if lowered in ("false", "no", "off"):
        return False

    return text
|
304
|
+
|
305
|
+
|
306
|
+
def _set_nested_key(config_dict: dict[str, Any], key: str, value: Any) -> None:
    """
    Assign ``value`` at a dot-separated path inside ``config_dict`` (in place).

    Missing intermediate levels are created as empty dicts. An intermediate
    level that exists but is not a dict is replaced by an empty dict.

    Args:
        config_dict: Configuration dictionary to modify
        key: Potentially nested key (e.g., "model.learning_rate")
        value: Value to set
    """
    *parents, leaf = key.split(".")

    node = config_dict
    for part in parents:
        child = node.get(part)
        if not isinstance(child, dict):
            child = {}
            node[part] = child
        node = child

    node[leaf] = value
|
329
|
+
|
330
|
+
|
331
|
+
def merge_configs(
    base_config: dict[str, Any], override_config: dict[str, Any]
) -> dict[str, Any]:
    """
    Overlay ``override_config`` onto a deep copy of ``base_config``.

    When both sides hold a dict under the same key, the two are merged
    recursively. Otherwise the override value (deep-copied) replaces the
    base value. Neither input is modified.

    Args:
        base_config: Base configuration dictionary
        override_config: Override configuration dictionary

    Returns:
        Merged configuration dictionary
    """
    merged = copy.deepcopy(base_config)

    def _overlay(target: dict[str, Any], source: dict[str, Any]) -> None:
        for key, incoming in source.items():
            existing = target.get(key)
            if isinstance(existing, dict) and isinstance(incoming, dict):
                _overlay(existing, incoming)
            else:
                # Copy so the result never aliases the caller's override data
                target[key] = copy.deepcopy(incoming)

    _overlay(merged, override_config)
    return merged
|
358
|
+
|
359
|
+
|
360
|
+
def resolve_config(
    config_path: Optional[Path] = None,
    param_overrides: Optional[list[str]] = None,
    default_config_name: str = "config.yaml",
) -> dict[str, Any]:
    """
    Build the effective configuration: file contents, then CLI overrides.

    If ``config_path`` is None, ``default_config_name`` in the current
    working directory is used when it exists. Without any file, resolution
    starts from an empty dict.

    Args:
        config_path: Path to configuration file
        param_overrides: List of parameter override strings
        default_config_name: Default config filename to look for

    Returns:
        Resolved configuration dictionary

    Raises:
        ConfigError: If configuration cannot be resolved
    """
    if config_path is None:
        candidate = Path.cwd() / default_config_name
        config_path = candidate if candidate.exists() else None

    config = load_yaml_config(config_path) if config_path is not None else {}

    if param_overrides:
        config = merge_configs(config, parse_param_overrides(param_overrides))

    return config
|
397
|
+
|
398
|
+
|
399
|
+
# Parameter Sweep Classes and Functions
|
400
|
+
|
401
|
+
|
402
|
+
class SweepParameter:
    """
    Common base for sweep definitions (range/linspace/logspace/list).

    Subclasses override ``generate_values``; the base implementation only
    signals that it has not been overridden.
    """

    def generate_values(self) -> list[Any]:
        """Return the concrete values this sweep expands to."""
        raise NotImplementedError
|
408
|
+
|
409
|
+
|
410
|
+
class RangeSweep(SweepParameter):
    """Range-based parameter sweep: range(start, stop, step)

    Like Python's ``range``, ``stop`` is exclusive, but float bounds and
    steps are also accepted.
    """

    def __init__(
        self, start: Union[int, float], stop: Union[int, float], step: Union[int, float]
    ):
        self.start = start
        self.stop = stop
        self.step = step

        if step == 0:
            raise ConfigError("Range step cannot be zero")
        if (stop - start) * step < 0:
            raise ConfigError("Range step direction doesn't match start/stop values")
        if start == stop:
            # An empty sweep would silently expand to zero experiments; reject
            # it, consistent with ListSweep rejecting an empty item list.
            raise ConfigError("Range sweep cannot be empty (start equals stop)")

    def generate_values(self) -> list[Union[int, float]]:
        """Generate range values.

        All-integer arguments defer to the built-in ``range``. Otherwise the
        number of points is computed up front and each value is
        ``start + i * step``. Repeatedly adding a float step accumulates
        rounding error and can emit an extra point just below ``stop``
        (e.g. range(0, 1, 0.1) would yield 11 values).
        """
        import math

        if all(isinstance(v, int) for v in (self.start, self.stop, self.step)):
            return list(range(self.start, self.stop, self.step))

        span = (self.stop - self.start) / self.step
        nearest = round(span)
        # A span that is a whole number up to float noise is treated as exact,
        # so the exclusive stop value is not emitted.
        if nearest > 0 and math.isclose(span, nearest, rel_tol=1e-9):
            count = nearest
        else:
            count = math.ceil(span)
        return [self.start + i * self.step for i in range(count)]

    def __repr__(self) -> str:
        return f"RangeSweep({self.start}, {self.stop}, {self.step})"
|
443
|
+
|
444
|
+
|
445
|
+
class LinspaceSweep(SweepParameter):
    """Linear space parameter sweep: linspace(start, stop, count)

    Produces ``count`` evenly spaced values from ``start`` to ``stop``
    inclusive.
    """

    def __init__(self, start: Union[int, float], stop: Union[int, float], count: int):
        self.start = start
        self.stop = stop
        self.count = count

        if count <= 0:
            raise ConfigError("Linspace count must be positive")

    def generate_values(self) -> list[float]:
        """Generate linearly spaced values (as floats)."""
        if self.count == 1:
            return [float(self.start)]

        step = (self.stop - self.start) / (self.count - 1)
        values = [self.start + i * step for i in range(self.count - 1)]
        # Pin the endpoint to ``stop`` exactly: start + (count - 1) * step can
        # be off by a rounding error (e.g. 0.6999999999999998 instead of 0.7).
        values.append(float(self.stop))
        return values

    def __repr__(self) -> str:
        return f"LinspaceSweep({self.start}, {self.stop}, {self.count})"
|
466
|
+
|
467
|
+
|
468
|
+
class LogspaceSweep(SweepParameter):
    """Logarithmic space parameter sweep: logspace(start, stop, count)

    ``start`` and ``stop`` are base-10 exponents. Produces ``count`` values
    from ``10**start`` to ``10**stop`` inclusive, evenly spaced in exponent.
    """

    def __init__(self, start: Union[int, float], stop: Union[int, float], count: int):
        self.start = start
        self.stop = stop
        self.count = count

        if count <= 0:
            raise ConfigError("Logspace count must be positive")

    def generate_values(self) -> list[float]:
        """Generate logarithmically spaced values (as floats)."""
        if self.count == 1:
            return [10.0**self.start]

        step = (self.stop - self.start) / (self.count - 1)
        values = [10.0 ** (self.start + i * step) for i in range(self.count - 1)]
        # Compute the endpoint from ``stop`` directly so accumulated rounding
        # in start + (count - 1) * step cannot shift it.
        values.append(10.0**self.stop)
        return values

    def __repr__(self) -> str:
        return f"LogspaceSweep({self.start}, {self.stop}, {self.count})"
|
489
|
+
|
490
|
+
|
491
|
+
class ListSweep(SweepParameter):
    """
    Sweep over an explicit, caller-provided list: list(item1, item2, ...)

    The list is stored as given (not copied). ``generate_values`` returns a
    shallow copy, so callers cannot change the stored list through it.
    """

    def __init__(self, items: list[Any]):
        # Reject empty (or otherwise falsy) input before storing anything
        if not items:
            raise ConfigError("List sweep cannot be empty")
        self.items = items

    def generate_values(self) -> list[Any]:
        """Return a shallow copy of the stored items."""
        snapshot = self.items.copy()
        return snapshot

    def __repr__(self) -> str:
        return "ListSweep({})".format(self.items)
|
505
|
+
|
506
|
+
|
507
|
+
def has_sweep_parameters(config: dict[str, Any]) -> bool:
    """
    Report whether any value in ``config`` is a sweep definition.

    Nested dictionaries are searched recursively. Other containers (lists,
    tuples, ...) are not inspected.

    Args:
        config: Configuration dictionary to inspect

    Returns:
        True as soon as a SweepParameter instance is found, otherwise False
    """

    def _contains_sweep(mapping: dict[str, Any]) -> bool:
        return any(
            isinstance(entry, SweepParameter)
            or (isinstance(entry, dict) and _contains_sweep(entry))
            for entry in mapping.values()
        )

    return _contains_sweep(config)
|
528
|
+
|
529
|
+
|
530
|
+
def expand_parameter_sweeps(config: dict[str, Any]) -> list[dict[str, Any]]:
    """
    Expand parameter sweeps into individual configurations.

    Generates cross-product of all sweep parameters while keeping regular parameters.

    Sweep locations are tracked as tuples of the actual dictionary keys, not
    as "."-joined strings. Keys that themselves contain "." or are not
    strings (e.g. integer YAML keys) are therefore replaced in place, rather
    than being re-split into a different nested path.

    Args:
        config: Configuration dictionary potentially containing SweepParameter instances

    Returns:
        List of configuration dictionaries with sweep parameters expanded.
        If ``config`` contains no sweep parameters, a one-element list holding
        the original ``config`` object (not a copy) is returned.

    Example:
        Input: {"lr": ListSweep([0.01, 0.02]), "batch_size": 32}
        Output: [
            {"lr": 0.01, "batch_size": 32},
            {"lr": 0.02, "batch_size": 32}
        ]
    """
    import itertools

    # (key_path, sweep) pairs; key_path is a tuple of the real dict keys.
    sweep_params: list[tuple[tuple[Any, ...], SweepParameter]] = []

    def find_sweeps(d: dict[Any, Any], path: tuple[Any, ...] = ()) -> None:
        for key, value in d.items():
            current_path = (*path, key)
            if isinstance(value, SweepParameter):
                sweep_params.append((current_path, value))
            elif isinstance(value, dict):
                find_sweeps(value, current_path)

    find_sweeps(config)

    if not sweep_params:
        return [config]

    sweep_paths = [path for path, _ in sweep_params]
    sweep_value_lists = [sweep.generate_values() for _, sweep in sweep_params]

    # Generate cross-product of all sweep parameter values
    expanded_configs = []
    for value_combination in itertools.product(*sweep_value_lists):
        expanded_config = copy.deepcopy(config)

        for path, value in zip(sweep_paths, value_combination):
            # Walk the copy using the original keys; find_sweeps only recorded
            # paths whose intermediate nodes are dicts, so every step exists.
            node = expanded_config
            for key in path[:-1]:
                node = node[key]
            node[path[-1]] = value

        expanded_configs.append(expanded_config)

    return expanded_configs
|
yanex/core/constants.py
ADDED
@@ -0,0 +1,16 @@
|
|
1
|
+
"""
|
2
|
+
Core constants used throughout yanex.
|
3
|
+
"""
|
4
|
+
|
5
|
+
# Every status value an experiment may be recorded with.
EXPERIMENT_STATUSES = [
    "created",
    "running",
    "completed",
    "failed",
    "cancelled",
    "staged",
]

# The same statuses as a set, for O(1) membership checks.
EXPERIMENT_STATUSES_SET = {*EXPERIMENT_STATUSES}
|