yanex 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
yanex/core/config.py ADDED
@@ -0,0 +1,587 @@
1
+ """
2
+ Configuration management for experiments.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import copy
8
+ import re
9
+ from pathlib import Path
10
+ from typing import Any, Optional, Union
11
+
12
+ import yaml
13
+
14
+ from ..utils.exceptions import ConfigError
15
+ from ..utils.validation import validate_config_data
16
+
17
+
18
def load_yaml_config(config_path: Path) -> dict[str, Any]:
    """
    Load configuration from YAML file.

    An empty document (``yaml.safe_load`` returning ``None``) is treated as an
    empty configuration. Every other non-mapping document is rejected,
    including falsy ones such as ``[]``, ``0``, ``false`` or ``""``.

    Args:
        config_path: Path to YAML configuration file

    Returns:
        Configuration dictionary, as returned by ``validate_config_data``

    Raises:
        ConfigError: If the path does not exist, is not a file, cannot be
            read or parsed, or its top-level value is not a dictionary
    """
    if not config_path.exists():
        raise ConfigError(f"Configuration file not found: {config_path}")

    if not config_path.is_file():
        raise ConfigError(f"Configuration path is not a file: {config_path}")

    try:
        with config_path.open("r", encoding="utf-8") as f:
            config_data = yaml.safe_load(f)
    except yaml.YAMLError as e:
        raise ConfigError(f"Failed to parse YAML config: {e}") from e
    except Exception as e:
        raise ConfigError(f"Failed to read config file: {e}") from e

    # Default only the empty-document case. Using `or {}` here would also
    # replace [], 0, "" and false with {} and skip the type check below.
    if config_data is None:
        config_data = {}

    if not isinstance(config_data, dict):
        raise ConfigError(
            f"Configuration must be a dictionary, got {type(config_data)}"
        )

    return validate_config_data(config_data)
51
+
52
+
53
def save_yaml_config(config_data: dict[str, Any], config_path: Path) -> None:
    """
    Write a configuration dictionary to a YAML file.

    The data is first passed to ``validate_config_data``. Missing parent
    directories of ``config_path`` are created, then the data is dumped with
    ``default_flow_style=False``, sorted keys and an indent of 2.

    Args:
        config_data: Configuration dictionary to save
        config_path: Path where to save the configuration

    Raises:
        ConfigError: If creating the directory or writing the file fails
    """
    validate_config_data(config_data)

    dump_options = {"default_flow_style": False, "sort_keys": True, "indent": 2}

    try:
        # The target directory may not exist yet
        config_path.parent.mkdir(parents=True, exist_ok=True)

        with config_path.open("w", encoding="utf-8") as stream:
            yaml.safe_dump(config_data, stream, **dump_options)
    except Exception as err:
        raise ConfigError(f"Failed to save config file: {err}") from err
76
+
77
+
78
def parse_param_overrides(param_strings: list[str]) -> dict[str, Any]:
    """
    Turn CLI "key=value" strings into a (possibly nested) dictionary.

    Each string is split on its first "=". The key and value are stripped of
    surrounding whitespace, the value is converted by
    ``_parse_parameter_value``, and the result is stored via
    ``_set_nested_key`` so dotted keys create nested dictionaries.

    Args:
        param_strings: List of "key=value" strings

    Returns:
        Dictionary of parsed parameters

    Raises:
        ConfigError: If a string has no "=" or its key is empty
    """
    parsed: dict[str, Any] = {}

    for raw in param_strings:
        name, separator, raw_value = raw.partition("=")
        if not separator:
            raise ConfigError(
                f"Invalid parameter format: {raw}. Expected 'key=value'"
            )

        name = name.strip()
        if not name:
            raise ConfigError(f"Empty parameter key in: {raw}")

        # Dotted names such as "model.learning_rate" become nested dicts
        _set_nested_key(parsed, name, _parse_parameter_value(raw_value.strip()))

    return parsed
113
+
114
+
115
+ def _parse_parameter_value(value_str: str) -> Any:
116
+ """
117
+ Parse parameter value string to appropriate Python type.
118
+
119
+ Supports sweep syntax: range(), linspace(), logspace(), list()
120
+
121
+ Args:
122
+ value_str: String value to parse
123
+
124
+ Returns:
125
+ Parsed value with appropriate type (including SweepParameter instances)
126
+ """
127
+ value_str = value_str.strip()
128
+
129
+ # Handle empty string
130
+ if not value_str:
131
+ return ""
132
+
133
+ # Handle null/none
134
+ if value_str.lower() in ("null", "none", "~"):
135
+ return None
136
+
137
+ # Check for sweep syntax first
138
+ sweep_result = _parse_sweep_syntax(value_str)
139
+ if sweep_result is not None:
140
+ return sweep_result
141
+
142
+ # Try to parse as number first (before booleans)
143
+ try:
144
+ # Try integer first
145
+ if "." not in value_str and "e" not in value_str.lower():
146
+ return int(value_str)
147
+ else:
148
+ return float(value_str)
149
+ except ValueError:
150
+ pass
151
+
152
+ # Handle boolean values (after numbers, so "1" and "0" are treated as numbers)
153
+ if value_str.lower() in ("true", "yes", "on"):
154
+ return True
155
+ if value_str.lower() in ("false", "no", "off"):
156
+ return False
157
+
158
+ # Try to parse as JSON-like structures
159
+ if value_str.startswith("[") and value_str.endswith("]"):
160
+ try:
161
+ # Simple list parsing (comma-separated)
162
+ content = value_str[1:-1].strip()
163
+ if not content:
164
+ return []
165
+ items = [
166
+ _parse_parameter_value(item.strip()) for item in content.split(",")
167
+ ]
168
+ return items
169
+ except Exception:
170
+ pass
171
+
172
+ # Return as string
173
+ return value_str
174
+
175
+
176
+ def _parse_sweep_syntax(value_str: str) -> Optional[SweepParameter]:
177
+ """
178
+ Parse sweep syntax into SweepParameter objects.
179
+
180
+ Supported syntax:
181
+ - range(start, stop, step)
182
+ - linspace(start, stop, count)
183
+ - logspace(start, stop, count)
184
+ - list(item1, item2, ...)
185
+
186
+ Args:
187
+ value_str: String value to parse
188
+
189
+ Returns:
190
+ SweepParameter instance or None if not sweep syntax
191
+ """
192
+ # Regular expressions for sweep function parsing
193
+ range_pattern = r"range\(\s*([^,]+)\s*,\s*([^,]+)\s*,\s*([^,]+)\s*\)"
194
+ linspace_pattern = r"linspace\(\s*([^,]+)\s*,\s*([^,]+)\s*,\s*([^,]+)\s*\)"
195
+ logspace_pattern = r"logspace\(\s*([^,]+)\s*,\s*([^,]+)\s*,\s*([^,]+)\s*\)"
196
+ list_pattern = r"list\(\s*([^)]*)\s*\)"
197
+
198
+ # Try range() syntax
199
+ match = re.match(range_pattern, value_str)
200
+ if match:
201
+ try:
202
+ start = _parse_numeric_value(match.group(1))
203
+ stop = _parse_numeric_value(match.group(2))
204
+ step = _parse_numeric_value(match.group(3))
205
+ return RangeSweep(start, stop, step)
206
+ except Exception as e:
207
+ raise ConfigError(f"Invalid range() syntax: {value_str}. Error: {e}")
208
+
209
+ # Try linspace() syntax
210
+ match = re.match(linspace_pattern, value_str)
211
+ if match:
212
+ try:
213
+ start = _parse_numeric_value(match.group(1))
214
+ stop = _parse_numeric_value(match.group(2))
215
+ count = int(_parse_numeric_value(match.group(3)))
216
+ return LinspaceSweep(start, stop, count)
217
+ except Exception as e:
218
+ raise ConfigError(f"Invalid linspace() syntax: {value_str}. Error: {e}")
219
+
220
+ # Try logspace() syntax
221
+ match = re.match(logspace_pattern, value_str)
222
+ if match:
223
+ try:
224
+ start = _parse_numeric_value(match.group(1))
225
+ stop = _parse_numeric_value(match.group(2))
226
+ count = int(_parse_numeric_value(match.group(3)))
227
+ return LogspaceSweep(start, stop, count)
228
+ except Exception as e:
229
+ raise ConfigError(f"Invalid logspace() syntax: {value_str}. Error: {e}")
230
+
231
+ # Try list() syntax
232
+ match = re.match(list_pattern, value_str)
233
+ if match:
234
+ try:
235
+ content = match.group(1).strip()
236
+ if not content:
237
+ raise ConfigError("List sweep cannot be empty")
238
+
239
+ # Parse comma-separated items
240
+ items = []
241
+ for item_str in content.split(","):
242
+ item_str = item_str.strip()
243
+ if not item_str:
244
+ continue
245
+ # Parse each item as a regular parameter value (recursive, but won't match sweep syntax)
246
+ parsed_item = _parse_non_sweep_value(item_str)
247
+ items.append(parsed_item)
248
+
249
+ if not items:
250
+ raise ConfigError("List sweep cannot be empty")
251
+
252
+ return ListSweep(items)
253
+ except Exception as e:
254
+ raise ConfigError(f"Invalid list() syntax: {value_str}. Error: {e}")
255
+
256
+ return None
257
+
258
+
259
+ def _parse_numeric_value(value_str: str) -> Union[int, float]:
260
+ """Parse a string as a numeric value (int or float)."""
261
+ value_str = value_str.strip()
262
+
263
+ try:
264
+ # Try integer first
265
+ if "." not in value_str and "e" not in value_str.lower():
266
+ return int(value_str)
267
+ else:
268
+ return float(value_str)
269
+ except ValueError:
270
+ raise ConfigError(f"Expected numeric value, got: {value_str}")
271
+
272
+
273
+ def _parse_non_sweep_value(value_str: str) -> Any:
274
+ """Parse parameter value without sweep syntax detection."""
275
+ value_str = value_str.strip()
276
+
277
+ # Handle quoted strings
278
+ if (value_str.startswith('"') and value_str.endswith('"')) or (
279
+ value_str.startswith("'") and value_str.endswith("'")
280
+ ):
281
+ return value_str[1:-1]
282
+
283
+ # Handle null/none
284
+ if value_str.lower() in ("null", "none", "~"):
285
+ return None
286
+
287
+ # Try to parse as number
288
+ try:
289
+ if "." not in value_str and "e" not in value_str.lower():
290
+ return int(value_str)
291
+ else:
292
+ return float(value_str)
293
+ except ValueError:
294
+ pass
295
+
296
+ # Handle boolean values
297
+ if value_str.lower() in ("true", "yes", "on"):
298
+ return True
299
+ if value_str.lower() in ("false", "no", "off"):
300
+ return False
301
+
302
+ # Return as string
303
+ return value_str
304
+
305
+
306
+ def _set_nested_key(config_dict: dict[str, Any], key: str, value: Any) -> None:
307
+ """
308
+ Set nested key in configuration dictionary.
309
+
310
+ Args:
311
+ config_dict: Configuration dictionary to modify
312
+ key: Potentially nested key (e.g., "model.learning_rate")
313
+ value: Value to set
314
+ """
315
+ keys = key.split(".")
316
+ current = config_dict
317
+
318
+ # Navigate to the nested location
319
+ for key_part in keys[:-1]:
320
+ if key_part not in current:
321
+ current[key_part] = {}
322
+ elif not isinstance(current[key_part], dict):
323
+ # If intermediate key exists but is not a dict, override it
324
+ current[key_part] = {}
325
+ current = current[key_part]
326
+
327
+ # Set the final value
328
+ current[keys[-1]] = value
329
+
330
+
331
def merge_configs(
    base_config: dict[str, Any], override_config: dict[str, Any]
) -> dict[str, Any]:
    """
    Combine two configuration dictionaries into a new one.

    Neither input is modified: the result starts as a deep copy of the base,
    and values taken from the override are deep-copied as well.

    Args:
        base_config: Base configuration dictionary
        override_config: Override configuration dictionary

    Returns:
        Merged configuration dictionary

    Note:
        Override values win. When both sides hold a dictionary under the
        same key, the two are merged key by key instead of being replaced.
    """
    merged = copy.deepcopy(base_config)

    # Work list of (destination, source) dict pairs still to be merged
    pending = [(merged, override_config)]
    while pending:
        target, source = pending.pop()
        for name, incoming in source.items():
            existing = target.get(name)
            if isinstance(existing, dict) and isinstance(incoming, dict):
                pending.append((existing, incoming))
            else:
                target[name] = copy.deepcopy(incoming)

    return merged
358
+
359
+
360
def resolve_config(
    config_path: Optional[Path] = None,
    param_overrides: Optional[list[str]] = None,
    default_config_name: str = "config.yaml",
) -> dict[str, Any]:
    """
    Build the final configuration from a file plus CLI overrides.

    When ``config_path`` is None, a file named ``default_config_name`` in the
    current working directory is used if it exists. With no file at all the
    base configuration is empty. Overrides, if any, are parsed and merged on
    top of the file contents.

    Args:
        config_path: Path to configuration file
        param_overrides: List of parameter override strings
        default_config_name: Default config filename to look for

    Returns:
        Resolved configuration dictionary

    Raises:
        ConfigError: If configuration cannot be resolved
    """
    chosen_path = config_path
    if chosen_path is None:
        # Fall back to the default file in the working directory, if present
        candidate = Path.cwd() / default_config_name
        if candidate.exists():
            chosen_path = candidate

    resolved = {} if chosen_path is None else load_yaml_config(chosen_path)

    if not param_overrides:
        return resolved

    return merge_configs(resolved, parse_param_overrides(param_overrides))
397
+
398
+
399
+ # Parameter Sweep Classes and Functions
400
+
401
+
402
class SweepParameter:
    """Common base type for all parameter sweep definitions."""

    def generate_values(self) -> list[Any]:
        """Return the concrete values of this sweep; subclasses must override."""
        raise NotImplementedError()
408
+
409
+
410
class RangeSweep(SweepParameter):
    """
    Range-based parameter sweep: range(start, stop, step)

    Values run from ``start`` towards ``stop`` in increments of ``step``;
    ``stop`` itself is excluded. ``step`` may be negative for a descending
    range.
    """

    def __init__(
        self, start: Union[int, float], stop: Union[int, float], step: Union[int, float]
    ):
        """
        Store the bounds and validate the step.

        Raises:
            ConfigError: If ``step`` is zero or points away from ``stop``
        """
        self.start = start
        self.stop = stop
        self.step = step

        if step == 0:
            raise ConfigError("Range step cannot be zero")
        if (stop - start) * step < 0:
            raise ConfigError("Range step direction doesn't match start/stop values")

    def generate_values(self) -> list[Union[int, float]]:
        """
        Generate range values.

        Each value is computed as ``start + i * step`` instead of by repeated
        addition. Repeated addition accumulates floating-point error, so a
        sweep such as range(0, 1, 0.1) would drift to 0.9999999999999999 and
        emit an extra value just below ``stop``.
        """
        ascending = self.step > 0
        values = []
        index = 0
        current = self.start

        while (current < self.stop) if ascending else (current > self.stop):
            values.append(current)
            index += 1
            current = self.start + index * self.step

        return values

    def __repr__(self) -> str:
        return f"RangeSweep({self.start}, {self.stop}, {self.step})"
443
+
444
+
445
class LinspaceSweep(SweepParameter):
    """Evenly spaced parameter sweep: linspace(start, stop, count)"""

    def __init__(self, start: Union[int, float], stop: Union[int, float], count: int):
        """
        Store the bounds and the number of points.

        Raises:
            ConfigError: If ``count`` is not positive
        """
        self.start, self.stop, self.count = start, stop, count

        if count <= 0:
            raise ConfigError("Linspace count must be positive")

    def generate_values(self) -> list[float]:
        """Return ``count`` values spaced evenly from ``start`` towards ``stop``."""
        if self.count == 1:
            # A single point has no spacing; it is just the start value
            return [float(self.start)]

        spacing = (self.stop - self.start) / (self.count - 1)

        points = []
        for index in range(self.count):
            points.append(self.start + index * spacing)
        return points

    def __repr__(self) -> str:
        return f"LinspaceSweep({self.start}, {self.stop}, {self.count})"
466
+
467
+
468
class LogspaceSweep(SweepParameter):
    """
    Logarithmic parameter sweep: logspace(start, stop, count)

    ``start`` and ``stop`` are base-10 exponents; the generated values are
    10 raised to evenly spaced exponents between them.
    """

    def __init__(self, start: Union[int, float], stop: Union[int, float], count: int):
        """
        Store the exponent bounds and the number of points.

        Raises:
            ConfigError: If ``count`` is not positive
        """
        self.start, self.stop, self.count = start, stop, count

        if count <= 0:
            raise ConfigError("Logspace count must be positive")

    def generate_values(self) -> list[float]:
        """Return ``count`` powers of ten with evenly spaced exponents."""
        if self.count == 1:
            # A single point has no spacing; it is just 10**start
            return [10.0**self.start]

        exponent_step = (self.stop - self.start) / (self.count - 1)

        points = []
        for index in range(self.count):
            points.append(10.0 ** (self.start + index * exponent_step))
        return points

    def __repr__(self) -> str:
        return f"LogspaceSweep({self.start}, {self.stop}, {self.count})"
489
+
490
+
491
class ListSweep(SweepParameter):
    """Sweep over an explicitly enumerated set of values: list(item1, item2, ...)"""

    def __init__(self, items: list[Any]):
        """
        Store the given items.

        Raises:
            ConfigError: If ``items`` is empty
        """
        if items:
            self.items = items
            return
        raise ConfigError("List sweep cannot be empty")

    def generate_values(self) -> list[Any]:
        """Return a shallow copy of the stored items."""
        return self.items.copy()

    def __repr__(self) -> str:
        return f"ListSweep({self.items})"
505
+
506
+
507
def has_sweep_parameters(config: dict[str, Any]) -> bool:
    """
    Report whether a configuration holds at least one sweep parameter.

    Only dictionary values are searched, descending into nested
    dictionaries. Values inside lists or other containers are not inspected.

    Args:
        config: Configuration dictionary to check

    Returns:
        True if any values are SweepParameter instances
    """

    def contains_sweep(mapping: dict[str, Any]) -> bool:
        return any(
            isinstance(entry, SweepParameter)
            or (isinstance(entry, dict) and contains_sweep(entry))
            for entry in mapping.values()
        )

    return contains_sweep(config)
528
+
529
+
530
def expand_parameter_sweeps(config: dict[str, Any]) -> list[dict[str, Any]]:
    """
    Expand parameter sweeps into individual configurations.

    Generates cross-product of all sweep parameters while keeping regular
    parameters. Sweeps are searched for in the top-level dictionary and in
    nested dictionaries only.

    Args:
        config: Configuration dictionary potentially containing SweepParameter instances

    Returns:
        List of configuration dictionaries with sweep parameters expanded.
        When the configuration has no sweeps, the list holds the input
        dictionary itself (not a copy); otherwise every entry is an
        independent deep copy.

    Example:
        Input: {"lr": RangeSweep(0.01, 0.03, 0.01), "batch_size": 32}
        Output: [
            {"lr": 0.01, "batch_size": 32},
            {"lr": 0.02, "batch_size": 32}
        ]
    """
    import itertools

    # (path, sweep) pairs. A path is a tuple of the actual dictionary keys
    # leading to the sweep. Joining keys into a dotted string and re-splitting
    # it later would misplace values for keys that contain "." or are not
    # strings.
    sweep_params = []

    def find_sweeps(d: dict[str, Any], path: tuple = ()) -> None:
        for key, value in d.items():
            current_path = (*path, key)

            if isinstance(value, SweepParameter):
                sweep_params.append((current_path, value))
            elif isinstance(value, dict):
                find_sweeps(value, current_path)

    find_sweeps(config)

    if not sweep_params:
        return [config]

    sweep_paths, sweep_objects = zip(*sweep_params)
    sweep_value_lists = [sweep_obj.generate_values() for sweep_obj in sweep_objects]

    # One configuration per element of the cross-product of all sweep values
    expanded_configs = []
    for value_combination in itertools.product(*sweep_value_lists):
        expanded_config = copy.deepcopy(config)

        for path, value in zip(sweep_paths, value_combination):
            # Every intermediate key exists and maps to a dict: the path was
            # recorded while walking this same structure.
            target = expanded_config
            for key in path[:-1]:
                target = target[key]
            target[path[-1]] = value

        expanded_configs.append(expanded_config)

    return expanded_configs
@@ -0,0 +1,16 @@
1
+ """
2
+ Core constants used throughout yanex.
3
+ """
4
+
5
# Every status an experiment can be in, in declaration order
EXPERIMENT_STATUSES = [
    "created", "running", "completed",
    "failed", "cancelled", "staged",
]

# Same statuses as a set, for fast membership tests
EXPERIMENT_STATUSES_SET = {*EXPERIMENT_STATUSES}