params-proto 3.1.2__py3-none-any.whl → 3.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- params_proto/cli/cli_parse.py +98 -1
- params_proto/hyper/sweep.py +71 -0
- params_proto/proto.py +28 -16
- {params_proto-3.1.2.dist-info → params_proto-3.2.1.dist-info}/METADATA +1 -1
- {params_proto-3.1.2.dist-info → params_proto-3.2.1.dist-info}/RECORD +7 -7
- {params_proto-3.1.2.dist-info → params_proto-3.2.1.dist-info}/WHEEL +0 -0
- {params_proto-3.1.2.dist-info → params_proto-3.2.1.dist-info}/licenses/LICENSE.md +0 -0
params_proto/cli/cli_parse.py
CHANGED
|
@@ -57,6 +57,32 @@ def _normalize_class_name(class_name: str) -> str:
|
|
|
57
57
|
return class_name.replace("-", "").replace("_", "").lower()
|
|
58
58
|
|
|
59
59
|
|
|
60
|
+
def _get_required_fields(cls) -> List[str]:
|
|
61
|
+
"""Get list of required field names (fields without defaults) in order."""
|
|
62
|
+
import dataclasses
|
|
63
|
+
|
|
64
|
+
required = []
|
|
65
|
+
|
|
66
|
+
# Check if it's a dataclass
|
|
67
|
+
if dataclasses.is_dataclass(cls):
|
|
68
|
+
for field in dataclasses.fields(cls):
|
|
69
|
+
has_default = (
|
|
70
|
+
field.default is not dataclasses.MISSING
|
|
71
|
+
or field.default_factory is not dataclasses.MISSING
|
|
72
|
+
)
|
|
73
|
+
if not has_default:
|
|
74
|
+
required.append(field.name)
|
|
75
|
+
else:
|
|
76
|
+
# For regular classes, check annotations and class-level defaults
|
|
77
|
+
annotations = getattr(cls, "__annotations__", {})
|
|
78
|
+
for name in annotations:
|
|
79
|
+
if not hasattr(cls, name):
|
|
80
|
+
# No class-level default
|
|
81
|
+
required.append(name)
|
|
82
|
+
|
|
83
|
+
return required
|
|
84
|
+
|
|
85
|
+
|
|
60
86
|
def _match_class_by_name(name: str, classes: list) -> Union[type, None]:
|
|
61
87
|
"""Match a string to one of the Union classes.
|
|
62
88
|
|
|
@@ -164,12 +190,31 @@ def parse_cli_args(wrapper) -> Dict[str, Any]:
|
|
|
164
190
|
is_bool = annotation == bool
|
|
165
191
|
prefix_params[kebab_key] = (singleton, param_name, annotation, is_bool)
|
|
166
192
|
|
|
193
|
+
# Build unprefixed union attribute map for classes NOT decorated with @proto.prefix
|
|
194
|
+
# Maps attr-name -> (union_param_name, attr_name_underscore)
|
|
195
|
+
# Classes in _SINGLETONS are @proto.prefix decorated and require prefixed attrs
|
|
196
|
+
unprefixed_attrs = {}
|
|
197
|
+
for kebab_name, (param_name, union_classes) in union_params.items():
|
|
198
|
+
for cls in union_classes:
|
|
199
|
+
# Skip if class is a @proto.prefix singleton (requires prefixed attrs)
|
|
200
|
+
is_prefix_class = cls in _SINGLETONS.values()
|
|
201
|
+
if is_prefix_class:
|
|
202
|
+
continue
|
|
203
|
+
if hasattr(cls, "__annotations__"):
|
|
204
|
+
for attr_name in cls.__annotations__:
|
|
205
|
+
kebab_attr = attr_name.replace("_", "-")
|
|
206
|
+
# Map to the union param (first one wins if multiple unions have same attr)
|
|
207
|
+
if kebab_attr not in unprefixed_attrs:
|
|
208
|
+
unprefixed_attrs[kebab_attr] = (param_name, attr_name)
|
|
209
|
+
|
|
167
210
|
# Parse arguments
|
|
168
211
|
result = {}
|
|
169
212
|
prefix_values = {} # (singleton, param_name) -> value
|
|
170
213
|
positional_values = []
|
|
171
214
|
union_selections = {} # param_name -> selected_class
|
|
172
215
|
union_attrs = {} # (param_name, attr_name) -> value
|
|
216
|
+
union_positional = {} # param_name -> [positional_args] for subcommand fields
|
|
217
|
+
current_union_param = None # Track which union we're collecting positionals for
|
|
173
218
|
|
|
174
219
|
args = sys.argv[1:]
|
|
175
220
|
i = 0
|
|
@@ -336,6 +381,17 @@ def parse_cli_args(wrapper) -> Dict[str, Any]:
|
|
|
336
381
|
i += 2
|
|
337
382
|
continue
|
|
338
383
|
|
|
384
|
+
# Check unprefixed union attrs when cli_prefix=False
|
|
385
|
+
if key in unprefixed_attrs:
|
|
386
|
+
union_param_name, attr_name = unprefixed_attrs[key]
|
|
387
|
+
# Get the value
|
|
388
|
+
if i + 1 >= len(args):
|
|
389
|
+
raise SystemExit(f"error: argument --{key} requires a value")
|
|
390
|
+
value_str = args[i + 1]
|
|
391
|
+
union_attrs[(union_param_name, attr_name)] = value_str
|
|
392
|
+
i += 2
|
|
393
|
+
continue
|
|
394
|
+
|
|
339
395
|
# Unknown argument
|
|
340
396
|
raise SystemExit(f"error: unrecognized argument: {arg}")
|
|
341
397
|
|
|
@@ -349,12 +405,18 @@ def parse_cli_args(wrapper) -> Dict[str, Any]:
|
|
|
349
405
|
selected_class = _match_class_by_name(arg, union_classes)
|
|
350
406
|
if selected_class:
|
|
351
407
|
union_selections[param_name] = selected_class
|
|
408
|
+
current_union_param = param_name # Track for following positionals
|
|
409
|
+
union_positional[param_name] = []
|
|
352
410
|
matched_union = True
|
|
353
411
|
i += 1
|
|
354
412
|
break
|
|
355
413
|
|
|
356
414
|
if not matched_union:
|
|
357
|
-
|
|
415
|
+
# If we have a current union, add positional to its list
|
|
416
|
+
if current_union_param is not None:
|
|
417
|
+
union_positional[current_union_param].append(arg)
|
|
418
|
+
else:
|
|
419
|
+
positional_values.append(arg)
|
|
358
420
|
i += 1
|
|
359
421
|
|
|
360
422
|
# Assign positional arguments to required parameters
|
|
@@ -405,6 +467,33 @@ def parse_cli_args(wrapper) -> Dict[str, Any]:
|
|
|
405
467
|
# No annotations, treat as string
|
|
406
468
|
attrs[attr_name] = value_str
|
|
407
469
|
|
|
470
|
+
# Assign positional args to required fields of the selected class
|
|
471
|
+
if param_name in union_positional and union_positional[param_name]:
|
|
472
|
+
positionals = union_positional[param_name]
|
|
473
|
+
required_fields = _get_required_fields(selected_class)
|
|
474
|
+
|
|
475
|
+
for field_idx, field_name in enumerate(required_fields):
|
|
476
|
+
if field_name in attrs:
|
|
477
|
+
# Already set by named arg, skip
|
|
478
|
+
continue
|
|
479
|
+
if field_idx < len(positionals):
|
|
480
|
+
# Get type annotation for conversion
|
|
481
|
+
if hasattr(selected_class, "__annotations__"):
|
|
482
|
+
field_type = selected_class.__annotations__.get(field_name, str)
|
|
483
|
+
try:
|
|
484
|
+
attrs[field_name] = _convert_type(positionals[field_idx], field_type)
|
|
485
|
+
except (ValueError, TypeError):
|
|
486
|
+
raise SystemExit(
|
|
487
|
+
f"error: invalid value for {field_name}: {positionals[field_idx]}"
|
|
488
|
+
)
|
|
489
|
+
else:
|
|
490
|
+
attrs[field_name] = positionals[field_idx]
|
|
491
|
+
|
|
492
|
+
# Check for extra positional args
|
|
493
|
+
if len(positionals) > len(required_fields):
|
|
494
|
+
extra = positionals[len(required_fields):]
|
|
495
|
+
raise SystemExit(f"error: unrecognized arguments: {' '.join(extra)}")
|
|
496
|
+
|
|
408
497
|
# If selected_class is a proto.prefix singleton, merge its overrides
|
|
409
498
|
from params_proto.proto import _SINGLETONS, ptype
|
|
410
499
|
|
|
@@ -419,6 +508,14 @@ def parse_cli_args(wrapper) -> Dict[str, Any]:
|
|
|
419
508
|
attrs[key] = value
|
|
420
509
|
break
|
|
421
510
|
|
|
511
|
+
# Check for missing required fields
|
|
512
|
+
required_fields = _get_required_fields(selected_class)
|
|
513
|
+
for field_name in required_fields:
|
|
514
|
+
if field_name not in attrs:
|
|
515
|
+
raise SystemExit(
|
|
516
|
+
f"error: {selected_class.__name__} requires argument: {field_name}"
|
|
517
|
+
)
|
|
518
|
+
|
|
422
519
|
# Instantiate the class with collected attributes
|
|
423
520
|
try:
|
|
424
521
|
instance = selected_class(**attrs)
|
params_proto/hyper/sweep.py
CHANGED
|
@@ -143,6 +143,77 @@ class ParameterIterator:
|
|
|
143
143
|
"""Return number of configs (materializes list)."""
|
|
144
144
|
return len(self.list)
|
|
145
145
|
|
|
146
|
+
def save(self, filename="sweep.jsonl", overwrite=True, verbose=True):
    """
    Write every configuration of this sweep to a JSON Lines file.

    Each configuration is serialized with ``json.dumps`` and written on its
    own line.

    Args:
        filename: Destination path (str or PathLike)
        overwrite: If True, truncate the file first; if False, append to it
        verbose: If True, print how many items were written and where

    Example:
        configs = piter @ {"lr": [0.001, 0.01]} * {"batch_size": [32, 64]}
        configs.save("experiment.jsonl")
    """
    import json
    import os
    from urllib import parse

    # Normalize str / PathLike to a plain path string.
    path = getattr(os, "fspath", str)(filename)
    # Read the configurations once, before touching the file, so the printed
    # count matches exactly what was written.
    items = self.list
    mode = "w" if overwrite else "a+"

    with open(path, mode) as out:
        out.writelines(json.dumps(config) + "\n" for config in items)

    if not verbose:
        return

    try:
        from termcolor import colored as c

        link = "file://" + parse.quote(os.path.realpath(path))
        print(
            c("saved", "blue"),
            c(len(items), "green"),
            c("items to", "blue"),
            path,
            ".",
            link,
        )
    except ImportError:
        # termcolor is optional; fall back to a plain message.
        print(f"Saved {len(items)} items to {path}")
|
|
185
|
+
|
|
186
|
+
@staticmethod
def load(filename="sweep.jsonl"):
    """
    Load parameter configurations from a JSONL file.

    Each non-blank line is parsed with ``json.loads``. Lines starting with
    ``//`` are treated as comments and skipped. The file is decoded as UTF-8,
    the encoding JSON text uses, rather than the platform's locale default.
    A leading UTF-8 byte-order mark is tolerated.

    Args:
        filename: Path to input file (str or PathLike)

    Returns:
        ParameterIterator with loaded configurations

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
        json.JSONDecodeError: If a non-comment line is not valid JSON.

    Example:
        configs = ParameterIterator.load("experiment.jsonl")
        for config in configs:
            run_experiment(**config)
    """
    import json
    import os

    # Convert Path objects to string
    filename_str = os.fspath(filename) if hasattr(os, "fspath") else str(filename)

    configs = []
    # "utf-8-sig" decodes plain UTF-8 and also strips a BOM, which would
    # otherwise make json.loads reject the first line.
    with open(filename_str, "r", encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            # Skip blank lines and "//" comment lines.
            if not line or line.startswith("//"):
                continue
            configs.append(json.loads(line))

    return ParameterIterator(iter(configs))
|
|
216
|
+
|
|
146
217
|
|
|
147
218
|
class PiterFactory:
|
|
148
219
|
"""
|
params_proto/proto.py
CHANGED
|
@@ -669,20 +669,32 @@ class ptype(type):
|
|
|
669
669
|
|
|
670
670
|
# Get the original class
|
|
671
671
|
original_cls = type.__getattribute__(cls, "__proto_original_class__")
|
|
672
|
+
annotations = getattr(cls, "__proto_annotations__", {})
|
|
672
673
|
|
|
673
|
-
#
|
|
674
|
-
|
|
674
|
+
# Check if this is a dataclass (has generated __init__ that accepts kwargs)
|
|
675
|
+
is_dataclass = hasattr(original_cls, "__dataclass_fields__")
|
|
675
676
|
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
677
|
+
if is_dataclass:
|
|
678
|
+
# For dataclasses: use the constructor directly
|
|
679
|
+
instance = original_cls(**final_kwargs)
|
|
680
|
+
else:
|
|
681
|
+
# For regular classes: create instance and set attributes manually
|
|
682
|
+
instance = object.__new__(original_cls)
|
|
683
|
+
for name in annotations.keys():
|
|
684
|
+
if name in final_kwargs:
|
|
685
|
+
setattr(instance, name, final_kwargs[name])
|
|
686
|
+
elif hasattr(cls, "__proto_defaults__") and name in cls.__proto_defaults__:
|
|
687
|
+
setattr(instance, name, cls.__proto_defaults__[name])
|
|
688
|
+
else:
|
|
689
|
+
# Required field
|
|
690
|
+
setattr(instance, name, None)
|
|
691
|
+
# Call __post_init__ if defined (dataclasses call it in __init__)
|
|
692
|
+
if hasattr(instance, '__post_init__'):
|
|
693
|
+
instance.__post_init__()
|
|
694
|
+
|
|
695
|
+
# Update the instance's class to the decorated class
|
|
696
|
+
# This allows isinstance(instance, DecoratedClass) to work
|
|
697
|
+
object.__setattr__(instance, "__class__", cls)
|
|
686
698
|
|
|
687
699
|
# Copy methods from original class and wrap to return self
|
|
688
700
|
for name in dir(original_cls):
|
|
@@ -722,10 +734,6 @@ class ptype(type):
|
|
|
722
734
|
|
|
723
735
|
setattr(instance, name, make_wrapper(method))
|
|
724
736
|
|
|
725
|
-
# Call __post_init__ if defined (like dataclasses)
|
|
726
|
-
if hasattr(instance, '__post_init__'):
|
|
727
|
-
instance.__post_init__()
|
|
728
|
-
|
|
729
737
|
return instance
|
|
730
738
|
|
|
731
739
|
|
|
@@ -1013,6 +1021,10 @@ def cli(obj: Any = None, *, prog: str = None):
|
|
|
1013
1021
|
"""
|
|
1014
1022
|
Set up an object as a CLI entry point.
|
|
1015
1023
|
|
|
1024
|
+
By default, subcommand attributes don't require prefix (--epochs works).
|
|
1025
|
+
If the subcommand class is decorated with @proto.prefix, prefix is required
|
|
1026
|
+
(--config.epochs).
|
|
1027
|
+
|
|
1016
1028
|
Args:
|
|
1017
1029
|
obj: The class, function, or Union type to setup as CLI.
|
|
1018
1030
|
If None, returns a decorator.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: params-proto
|
|
3
|
-
Version: 3.1.2
|
|
3
|
+
Version: 3.2.1
|
|
4
4
|
Summary: Modern Hyper Parameter Management for Machine Learning
|
|
5
5
|
Project-URL: Homepage, https://github.com/geyang/params-proto
|
|
6
6
|
Project-URL: Documentation, https://params-proto.readthedocs.io
|
|
@@ -3,15 +3,15 @@ params_proto/app.py,sha256=UySpd1op3M44Szk6Ekyn0fJcnZsQvMTMPdaEybwWsLE,19
|
|
|
3
3
|
params_proto/documentation.py,sha256=mIqmcwGWo8tM1BuNzLIwVTzdbQ3qyPus7yWTaOce4dM,8091
|
|
4
4
|
params_proto/envvar.py,sha256=A87jxSAQ2tjbKLbrm96lblV90zNdtBGCSV6QRe2DrgA,8398
|
|
5
5
|
params_proto/parse_env_template.py,sha256=mXTvKpNhT2jGr3HpwKw42shd18O0QACmSJn6yWMDdKA,1298
|
|
6
|
-
params_proto/proto.py,sha256=
|
|
6
|
+
params_proto/proto.py,sha256=AGUHrU0OmWxqkYvJJqR6SzQTxZz_Zis0XYYaiFentmo,39077
|
|
7
7
|
params_proto/type_utils.py,sha256=x68rL5m76ZFRKsCRgH_i_4vLpt6ldWEsEAalgacFIH8,7364
|
|
8
8
|
params_proto/cli/__init__.py,sha256=sLpN3GmaBqd_d0J0nvUNOeGlV74_-jQGW0nDUU34tjA,493
|
|
9
9
|
params_proto/cli/ansi_help.py,sha256=-1gzbvOpi9GjPlqgiINOYQAfIstzg0-ukv1se88TYCQ,10967
|
|
10
|
-
params_proto/cli/cli_parse.py,sha256=
|
|
10
|
+
params_proto/cli/cli_parse.py,sha256=kVBSgCCVZk32NOfPfyj-pYWXBLyqitFZzsfme8LspcE,18835
|
|
11
11
|
params_proto/cli/help_gen.py,sha256=Iv9MWC7TJT4_OUWozTfCr8-Nmp_-K8Ohoim_dtsN5AY,12921
|
|
12
12
|
params_proto/hyper/__init__.py,sha256=4zMnKk9H7NPlaTTRzbL2MC7anzwkBbd2_kW51aYhCPs,157
|
|
13
13
|
params_proto/hyper/proxies.py,sha256=OMiaKK-gQx-zT1xeCmZevBSDgWUwwkzz0n54A5_wC60,4492
|
|
14
|
-
params_proto/hyper/sweep.py,sha256=
|
|
14
|
+
params_proto/hyper/sweep.py,sha256=QUjsNdpJUoZU2WMF022LRqTh9rx5dgabpRCfs895q2Q,31444
|
|
15
15
|
params_proto/v1/__init__.py,sha256=NGYZ6Iqicc5M6iyWT6N8FsD0iGLl2by5yZUIsHKhjXw,48
|
|
16
16
|
params_proto/v1/hyper.py,sha256=zFzViWtSkQdqDJXuan33X2OZwKSHHY39Q5HSNPXl0iQ,2883
|
|
17
17
|
params_proto/v1/params_proto.py,sha256=g2TMTG0SXyp01gsvd9EO42m28Hr2aS79xzOnMeH_WVk,8728
|
|
@@ -20,7 +20,7 @@ params_proto/v2/hyper.py,sha256=onBAkT8Ja8IkeHEOq1AwCdTuBzAnthIe766ZE0lAy-M,1146
|
|
|
20
20
|
params_proto/v2/partial.py,sha256=_ovi4NY8goYgHurfYt1OV0E9DSMXGYucjMVIyG1Q_xc,983
|
|
21
21
|
params_proto/v2/proto.py,sha256=KvinzgzwRQr2bHDNtrU7App2kgAyB-SEfBe4SNYceh0,18995
|
|
22
22
|
params_proto/v2/utils.py,sha256=5EWvwboZDTsCYfzSED_J6RVFyNLIlf95nIu4p_ZSVxA,3540
|
|
23
|
-
params_proto-3.1.
|
|
24
|
-
params_proto-3.1.
|
|
25
|
-
params_proto-3.1.
|
|
26
|
-
params_proto-3.1.
|
|
23
|
+
params_proto-3.2.1.dist-info/METADATA,sha256=IhhYdm8bYV3OEnUYmaRNo2Wb_8Sm58dCEBuuBIa0nbg,8991
|
|
24
|
+
params_proto-3.2.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
25
|
+
params_proto-3.2.1.dist-info/licenses/LICENSE.md,sha256=c2qSYi9tUMZtzj9SEsMeKhub5LJUmHwBtDLiIMM5b6U,1526
|
|
26
|
+
params_proto-3.2.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|