params-proto 3.1.2__py3-none-any.whl → 3.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -57,6 +57,32 @@ def _normalize_class_name(class_name: str) -> str:
57
57
  return class_name.replace("-", "").replace("_", "").lower()
58
58
 
59
59
 
60
def _get_required_fields(cls) -> List[str]:
    """Return the names of the fields a caller must supply to build ``cls``.

    Two kinds of classes are handled:

    * Dataclasses: a field is required when it is a parameter of the
      generated ``__init__`` (``field.init`` is true) and has neither a
      ``default`` nor a ``default_factory``.
    * Any other class: a name is required when it appears in the class's
      ``__annotations__`` but no attribute of that name can be found on the
      class (i.e. there is no class-level default).

    Args:
        cls: The class to inspect.

    Returns:
        Required field names, in declaration order. Empty when the class
        declares no fields or every field has a default.
    """
    import dataclasses

    if dataclasses.is_dataclass(cls):
        missing = dataclasses.MISSING
        return [
            field.name
            for field in dataclasses.fields(cls)
            # Fields declared with init=False are not constructor parameters:
            # reporting them as required would make callers demand (and then
            # pass) a keyword argument that __init__ rejects with TypeError.
            if field.init
            and field.default is missing
            and field.default_factory is missing
        ]

    # Regular class: an annotation with no matching class attribute has no
    # class-level default, so the value has to come from the caller.
    annotations = getattr(cls, "__annotations__", {})
    return [name for name in annotations if not hasattr(cls, name)]
84
+
85
+
60
86
  def _match_class_by_name(name: str, classes: list) -> Union[type, None]:
61
87
  """Match a string to one of the Union classes.
62
88
 
@@ -164,12 +190,31 @@ def parse_cli_args(wrapper) -> Dict[str, Any]:
164
190
  is_bool = annotation == bool
165
191
  prefix_params[kebab_key] = (singleton, param_name, annotation, is_bool)
166
192
 
193
+ # Build unprefixed union attribute map for classes NOT decorated with @proto.prefix
194
+ # Maps attr-name -> (union_param_name, attr_name_underscore)
195
+ # Classes in _SINGLETONS are @proto.prefix decorated and require prefixed attrs
196
+ unprefixed_attrs = {}
197
+ for kebab_name, (param_name, union_classes) in union_params.items():
198
+ for cls in union_classes:
199
+ # Skip if class is a @proto.prefix singleton (requires prefixed attrs)
200
+ is_prefix_class = cls in _SINGLETONS.values()
201
+ if is_prefix_class:
202
+ continue
203
+ if hasattr(cls, "__annotations__"):
204
+ for attr_name in cls.__annotations__:
205
+ kebab_attr = attr_name.replace("_", "-")
206
+ # Map to the union param (first one wins if multiple unions have same attr)
207
+ if kebab_attr not in unprefixed_attrs:
208
+ unprefixed_attrs[kebab_attr] = (param_name, attr_name)
209
+
167
210
  # Parse arguments
168
211
  result = {}
169
212
  prefix_values = {} # (singleton, param_name) -> value
170
213
  positional_values = []
171
214
  union_selections = {} # param_name -> selected_class
172
215
  union_attrs = {} # (param_name, attr_name) -> value
216
+ union_positional = {} # param_name -> [positional_args] for subcommand fields
217
+ current_union_param = None # Track which union we're collecting positionals for
173
218
 
174
219
  args = sys.argv[1:]
175
220
  i = 0
@@ -336,6 +381,17 @@ def parse_cli_args(wrapper) -> Dict[str, Any]:
336
381
  i += 2
337
382
  continue
338
383
 
384
+ # Check unprefixed union attrs when cli_prefix=False
385
+ if key in unprefixed_attrs:
386
+ union_param_name, attr_name = unprefixed_attrs[key]
387
+ # Get the value
388
+ if i + 1 >= len(args):
389
+ raise SystemExit(f"error: argument --{key} requires a value")
390
+ value_str = args[i + 1]
391
+ union_attrs[(union_param_name, attr_name)] = value_str
392
+ i += 2
393
+ continue
394
+
339
395
  # Unknown argument
340
396
  raise SystemExit(f"error: unrecognized argument: {arg}")
341
397
 
@@ -349,12 +405,18 @@ def parse_cli_args(wrapper) -> Dict[str, Any]:
349
405
  selected_class = _match_class_by_name(arg, union_classes)
350
406
  if selected_class:
351
407
  union_selections[param_name] = selected_class
408
+ current_union_param = param_name # Track for following positionals
409
+ union_positional[param_name] = []
352
410
  matched_union = True
353
411
  i += 1
354
412
  break
355
413
 
356
414
  if not matched_union:
357
- positional_values.append(arg)
415
+ # If we have a current union, add positional to its list
416
+ if current_union_param is not None:
417
+ union_positional[current_union_param].append(arg)
418
+ else:
419
+ positional_values.append(arg)
358
420
  i += 1
359
421
 
360
422
  # Assign positional arguments to required parameters
@@ -405,6 +467,33 @@ def parse_cli_args(wrapper) -> Dict[str, Any]:
405
467
  # No annotations, treat as string
406
468
  attrs[attr_name] = value_str
407
469
 
470
+ # Assign positional args to required fields of the selected class
471
+ if param_name in union_positional and union_positional[param_name]:
472
+ positionals = union_positional[param_name]
473
+ required_fields = _get_required_fields(selected_class)
474
+
475
+ for field_idx, field_name in enumerate(required_fields):
476
+ if field_name in attrs:
477
+ # Already set by named arg, skip
478
+ continue
479
+ if field_idx < len(positionals):
480
+ # Get type annotation for conversion
481
+ if hasattr(selected_class, "__annotations__"):
482
+ field_type = selected_class.__annotations__.get(field_name, str)
483
+ try:
484
+ attrs[field_name] = _convert_type(positionals[field_idx], field_type)
485
+ except (ValueError, TypeError):
486
+ raise SystemExit(
487
+ f"error: invalid value for {field_name}: {positionals[field_idx]}"
488
+ )
489
+ else:
490
+ attrs[field_name] = positionals[field_idx]
491
+
492
+ # Check for extra positional args
493
+ if len(positionals) > len(required_fields):
494
+ extra = positionals[len(required_fields):]
495
+ raise SystemExit(f"error: unrecognized arguments: {' '.join(extra)}")
496
+
408
497
  # If selected_class is a proto.prefix singleton, merge its overrides
409
498
  from params_proto.proto import _SINGLETONS, ptype
410
499
 
@@ -419,6 +508,14 @@ def parse_cli_args(wrapper) -> Dict[str, Any]:
419
508
  attrs[key] = value
420
509
  break
421
510
 
511
+ # Check for missing required fields
512
+ required_fields = _get_required_fields(selected_class)
513
+ for field_name in required_fields:
514
+ if field_name not in attrs:
515
+ raise SystemExit(
516
+ f"error: {selected_class.__name__} requires argument: {field_name}"
517
+ )
518
+
422
519
  # Instantiate the class with collected attributes
423
520
  try:
424
521
  instance = selected_class(**attrs)
@@ -143,6 +143,77 @@ class ParameterIterator:
143
143
  """Return number of configs (materializes list)."""
144
144
  return len(self.list)
145
145
 
146
def save(self, filename="sweep.jsonl", overwrite=True, verbose=True):
    """
    Write the configurations in ``self.list`` to a JSON Lines file.

    Each configuration is serialized with ``json.dumps`` and written on its
    own line.

    Args:
        filename: Destination path (str or PathLike)
        overwrite: If True the file is opened with mode "w"; otherwise it is
            opened with mode "a+" so new lines are added after existing ones
        verbose: If True, print a confirmation after writing

    Example:
        configs = piter @ {"lr": [0.001, 0.01]} * {"batch_size": [32, 64]}
        configs.save("experiment.jsonl")
    """
    import json
    import os
    from urllib import parse

    # Normalize PathLike objects to a plain string path.
    if hasattr(os, "fspath"):
        target = os.fspath(filename)
    else:
        target = str(filename)
    records = self.list

    mode = "w" if overwrite else "a+"
    with open(target, mode) as out:
        out.writelines(json.dumps(record) + "\n" for record in records)

    if not verbose:
        return

    try:
        # termcolor is optional; the plain message below is used without it.
        from termcolor import colored as c

        parts = (
            c("saved", "blue"),
            c(len(records), "green"),
            c("items to", "blue"),
            target,
            ".",
            "file://" + parse.quote(os.path.realpath(target)),
        )
        print(*parts)
    except ImportError:
        print(f"Saved {len(records)} items to {target}")
185
+
186
@staticmethod
def load(filename="sweep.jsonl"):
    """
    Read configurations from a JSON Lines file.

    Every line is stripped of surrounding whitespace; empty lines and lines
    beginning with "//" are ignored, and each remaining line is parsed with
    ``json.loads``.

    Args:
        filename: Path of the file to read (str or PathLike)

    Returns:
        ParameterIterator built from an iterator over the parsed entries

    Example:
        configs = ParameterIterator.load("experiment.jsonl")
        for config in configs:
            run_experiment(**config)
    """
    import json
    import os

    # Normalize PathLike objects to a plain string path.
    source = os.fspath(filename) if hasattr(os, "fspath") else str(filename)

    with open(source, "r") as handle:
        stripped = (raw.strip() for raw in handle)
        records = [
            json.loads(text)
            for text in stripped
            if text and not text.startswith("//")
        ]

    return ParameterIterator(iter(records))
216
+
146
217
 
147
218
  class PiterFactory:
148
219
  """
params_proto/proto.py CHANGED
@@ -669,20 +669,32 @@ class ptype(type):
669
669
 
670
670
  # Get the original class
671
671
  original_cls = type.__getattribute__(cls, "__proto_original_class__")
672
+ annotations = getattr(cls, "__proto_annotations__", {})
672
673
 
673
- # Create instance
674
- instance = object.__new__(original_cls)
674
+ # Check if this is a dataclass (has generated __init__ that accepts kwargs)
675
+ is_dataclass = hasattr(original_cls, "__dataclass_fields__")
675
676
 
676
- # Set attributes
677
- annotations = getattr(cls, "__proto_annotations__", {})
678
- for name in annotations.keys():
679
- if name in final_kwargs:
680
- setattr(instance, name, final_kwargs[name])
681
- elif hasattr(cls, "__proto_defaults__") and name in cls.__proto_defaults__:
682
- setattr(instance, name, cls.__proto_defaults__[name])
683
- else:
684
- # Required field
685
- setattr(instance, name, None)
677
+ if is_dataclass:
678
+ # For dataclasses: use the constructor directly
679
+ instance = original_cls(**final_kwargs)
680
+ else:
681
+ # For regular classes: create instance and set attributes manually
682
+ instance = object.__new__(original_cls)
683
+ for name in annotations.keys():
684
+ if name in final_kwargs:
685
+ setattr(instance, name, final_kwargs[name])
686
+ elif hasattr(cls, "__proto_defaults__") and name in cls.__proto_defaults__:
687
+ setattr(instance, name, cls.__proto_defaults__[name])
688
+ else:
689
+ # Required field
690
+ setattr(instance, name, None)
691
+ # Call __post_init__ if defined (dataclasses call it in __init__)
692
+ if hasattr(instance, '__post_init__'):
693
+ instance.__post_init__()
694
+
695
+ # Update the instance's class to the decorated class
696
+ # This allows isinstance(instance, DecoratedClass) to work
697
+ object.__setattr__(instance, "__class__", cls)
686
698
 
687
699
  # Copy methods from original class and wrap to return self
688
700
  for name in dir(original_cls):
@@ -722,10 +734,6 @@ class ptype(type):
722
734
 
723
735
  setattr(instance, name, make_wrapper(method))
724
736
 
725
- # Call __post_init__ if defined (like dataclasses)
726
- if hasattr(instance, '__post_init__'):
727
- instance.__post_init__()
728
-
729
737
  return instance
730
738
 
731
739
 
@@ -1013,6 +1021,10 @@ def cli(obj: Any = None, *, prog: str = None):
1013
1021
  """
1014
1022
  Set up an object as a CLI entry point.
1015
1023
 
1024
+ By default, subcommand attributes don't require prefix (--epochs works).
1025
+ If the subcommand class is decorated with @proto.prefix, prefix is required
1026
+ (--config.epochs).
1027
+
1016
1028
  Args:
1017
1029
  obj: The class, function, or Union type to setup as CLI.
1018
1030
  If None, returns a decorator.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: params-proto
3
- Version: 3.1.2
3
+ Version: 3.2.1
4
4
  Summary: Modern Hyper Parameter Management for Machine Learning
5
5
  Project-URL: Homepage, https://github.com/geyang/params-proto
6
6
  Project-URL: Documentation, https://params-proto.readthedocs.io
@@ -3,15 +3,15 @@ params_proto/app.py,sha256=UySpd1op3M44Szk6Ekyn0fJcnZsQvMTMPdaEybwWsLE,19
3
3
  params_proto/documentation.py,sha256=mIqmcwGWo8tM1BuNzLIwVTzdbQ3qyPus7yWTaOce4dM,8091
4
4
  params_proto/envvar.py,sha256=A87jxSAQ2tjbKLbrm96lblV90zNdtBGCSV6QRe2DrgA,8398
5
5
  params_proto/parse_env_template.py,sha256=mXTvKpNhT2jGr3HpwKw42shd18O0QACmSJn6yWMDdKA,1298
6
- params_proto/proto.py,sha256=81q5EyPYtH7uHCT0L2ydiin5Na2kfJ2VO3LmDfu8AXM,38386
6
+ params_proto/proto.py,sha256=AGUHrU0OmWxqkYvJJqR6SzQTxZz_Zis0XYYaiFentmo,39077
7
7
  params_proto/type_utils.py,sha256=x68rL5m76ZFRKsCRgH_i_4vLpt6ldWEsEAalgacFIH8,7364
8
8
  params_proto/cli/__init__.py,sha256=sLpN3GmaBqd_d0J0nvUNOeGlV74_-jQGW0nDUU34tjA,493
9
9
  params_proto/cli/ansi_help.py,sha256=-1gzbvOpi9GjPlqgiINOYQAfIstzg0-ukv1se88TYCQ,10967
10
- params_proto/cli/cli_parse.py,sha256=aRXkPdfyRFJeXb6W6NwQ1bkkkhjUQlvin1k4GaUtbFg,14848
10
+ params_proto/cli/cli_parse.py,sha256=kVBSgCCVZk32NOfPfyj-pYWXBLyqitFZzsfme8LspcE,18835
11
11
  params_proto/cli/help_gen.py,sha256=Iv9MWC7TJT4_OUWozTfCr8-Nmp_-K8Ohoim_dtsN5AY,12921
12
12
  params_proto/hyper/__init__.py,sha256=4zMnKk9H7NPlaTTRzbL2MC7anzwkBbd2_kW51aYhCPs,157
13
13
  params_proto/hyper/proxies.py,sha256=OMiaKK-gQx-zT1xeCmZevBSDgWUwwkzz0n54A5_wC60,4492
14
- params_proto/hyper/sweep.py,sha256=F6uac1o7_zzahgvnwcuqwn-7C48Xl-RlHKsrk6wulsU,29475
14
+ params_proto/hyper/sweep.py,sha256=QUjsNdpJUoZU2WMF022LRqTh9rx5dgabpRCfs895q2Q,31444
15
15
  params_proto/v1/__init__.py,sha256=NGYZ6Iqicc5M6iyWT6N8FsD0iGLl2by5yZUIsHKhjXw,48
16
16
  params_proto/v1/hyper.py,sha256=zFzViWtSkQdqDJXuan33X2OZwKSHHY39Q5HSNPXl0iQ,2883
17
17
  params_proto/v1/params_proto.py,sha256=g2TMTG0SXyp01gsvd9EO42m28Hr2aS79xzOnMeH_WVk,8728
@@ -20,7 +20,7 @@ params_proto/v2/hyper.py,sha256=onBAkT8Ja8IkeHEOq1AwCdTuBzAnthIe766ZE0lAy-M,1146
20
20
  params_proto/v2/partial.py,sha256=_ovi4NY8goYgHurfYt1OV0E9DSMXGYucjMVIyG1Q_xc,983
21
21
  params_proto/v2/proto.py,sha256=KvinzgzwRQr2bHDNtrU7App2kgAyB-SEfBe4SNYceh0,18995
22
22
  params_proto/v2/utils.py,sha256=5EWvwboZDTsCYfzSED_J6RVFyNLIlf95nIu4p_ZSVxA,3540
23
- params_proto-3.1.2.dist-info/METADATA,sha256=vbXuWaBI3YY_BHQfuhWaebllCwVD5K0Q_myC_XFNku4,8991
24
- params_proto-3.1.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
25
- params_proto-3.1.2.dist-info/licenses/LICENSE.md,sha256=c2qSYi9tUMZtzj9SEsMeKhub5LJUmHwBtDLiIMM5b6U,1526
26
- params_proto-3.1.2.dist-info/RECORD,,
23
+ params_proto-3.2.1.dist-info/METADATA,sha256=IhhYdm8bYV3OEnUYmaRNo2Wb_8Sm58dCEBuuBIa0nbg,8991
24
+ params_proto-3.2.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
25
+ params_proto-3.2.1.dist-info/licenses/LICENSE.md,sha256=c2qSYi9tUMZtzj9SEsMeKhub5LJUmHwBtDLiIMM5b6U,1526
26
+ params_proto-3.2.1.dist-info/RECORD,,