params-proto 3.2.0__py3-none-any.whl → 3.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- params_proto/cli/cli_parse.py +70 -1
- {params_proto-3.2.0.dist-info → params_proto-3.2.1.dist-info}/METADATA +1 -1
- {params_proto-3.2.0.dist-info → params_proto-3.2.1.dist-info}/RECORD +5 -5
- {params_proto-3.2.0.dist-info → params_proto-3.2.1.dist-info}/WHEEL +0 -0
- {params_proto-3.2.0.dist-info → params_proto-3.2.1.dist-info}/licenses/LICENSE.md +0 -0
params_proto/cli/cli_parse.py
CHANGED
|
@@ -57,6 +57,32 @@ def _normalize_class_name(class_name: str) -> str:
|
|
|
57
57
|
return class_name.replace("-", "").replace("_", "").lower()
|
|
58
58
|
|
|
59
59
|
|
|
60
|
+
def _get_required_fields(cls) -> List[str]:
    """Get list of required field names (fields without defaults) in order.

    A field is "required" when it has no default value, so a value must be
    supplied when the class is instantiated.

    For dataclasses, this inspects ``dataclasses.fields(cls)`` and reports
    fields that have neither ``default`` nor ``default_factory``. Fields
    declared with ``init=False`` are skipped: they are not ``__init__``
    parameters, so they can never be supplied as constructor arguments.

    For other classes, this inspects ``cls.__annotations__`` and reports
    annotated names that have no class-level value.

    Args:
        cls: The class whose required fields should be listed.

    Returns:
        Names of the required fields, in the order they are declared.
    """
    import dataclasses

    if dataclasses.is_dataclass(cls):
        return [
            field.name
            for field in dataclasses.fields(cls)
            # init=False fields are not constructor parameters: passing them to
            # cls(...) as keyword arguments raises TypeError, so they must not
            # be demanded from the user.
            if field.init
            and field.default is dataclasses.MISSING
            and field.default_factory is dataclasses.MISSING
        ]

    # For regular classes, an annotation with no class-level value has no default.
    annotations = getattr(cls, "__annotations__", {})
    return [name for name in annotations if not hasattr(cls, name)]
|
|
84
|
+
|
|
85
|
+
|
|
60
86
|
def _match_class_by_name(name: str, classes: list) -> Union[type, None]:
|
|
61
87
|
"""Match a string to one of the Union classes.
|
|
62
88
|
|
|
@@ -187,6 +213,8 @@ def parse_cli_args(wrapper) -> Dict[str, Any]:
|
|
|
187
213
|
positional_values = []
|
|
188
214
|
union_selections = {} # param_name -> selected_class
|
|
189
215
|
union_attrs = {} # (param_name, attr_name) -> value
|
|
216
|
+
union_positional = {} # param_name -> [positional_args] for subcommand fields
|
|
217
|
+
current_union_param = None # Track which union we're collecting positionals for
|
|
190
218
|
|
|
191
219
|
args = sys.argv[1:]
|
|
192
220
|
i = 0
|
|
@@ -377,12 +405,18 @@ def parse_cli_args(wrapper) -> Dict[str, Any]:
|
|
|
377
405
|
selected_class = _match_class_by_name(arg, union_classes)
|
|
378
406
|
if selected_class:
|
|
379
407
|
union_selections[param_name] = selected_class
|
|
408
|
+
current_union_param = param_name # Track for following positionals
|
|
409
|
+
union_positional[param_name] = []
|
|
380
410
|
matched_union = True
|
|
381
411
|
i += 1
|
|
382
412
|
break
|
|
383
413
|
|
|
384
414
|
if not matched_union:
|
|
385
|
-
|
|
415
|
+
# If we have a current union, add positional to its list
|
|
416
|
+
if current_union_param is not None:
|
|
417
|
+
union_positional[current_union_param].append(arg)
|
|
418
|
+
else:
|
|
419
|
+
positional_values.append(arg)
|
|
386
420
|
i += 1
|
|
387
421
|
|
|
388
422
|
# Assign positional arguments to required parameters
|
|
@@ -433,6 +467,33 @@ def parse_cli_args(wrapper) -> Dict[str, Any]:
|
|
|
433
467
|
# No annotations, treat as string
|
|
434
468
|
attrs[attr_name] = value_str
|
|
435
469
|
|
|
470
|
+
# Assign positional args to required fields of the selected class
|
|
471
|
+
if param_name in union_positional and union_positional[param_name]:
|
|
472
|
+
positionals = union_positional[param_name]
|
|
473
|
+
required_fields = _get_required_fields(selected_class)
|
|
474
|
+
|
|
475
|
+
for field_idx, field_name in enumerate(required_fields):
|
|
476
|
+
if field_name in attrs:
|
|
477
|
+
# Already set by named arg, skip
|
|
478
|
+
continue
|
|
479
|
+
if field_idx < len(positionals):
|
|
480
|
+
# Get type annotation for conversion
|
|
481
|
+
if hasattr(selected_class, "__annotations__"):
|
|
482
|
+
field_type = selected_class.__annotations__.get(field_name, str)
|
|
483
|
+
try:
|
|
484
|
+
attrs[field_name] = _convert_type(positionals[field_idx], field_type)
|
|
485
|
+
except (ValueError, TypeError):
|
|
486
|
+
raise SystemExit(
|
|
487
|
+
f"error: invalid value for {field_name}: {positionals[field_idx]}"
|
|
488
|
+
)
|
|
489
|
+
else:
|
|
490
|
+
attrs[field_name] = positionals[field_idx]
|
|
491
|
+
|
|
492
|
+
# Check for extra positional args
|
|
493
|
+
if len(positionals) > len(required_fields):
|
|
494
|
+
extra = positionals[len(required_fields):]
|
|
495
|
+
raise SystemExit(f"error: unrecognized arguments: {' '.join(extra)}")
|
|
496
|
+
|
|
436
497
|
# If selected_class is a proto.prefix singleton, merge its overrides
|
|
437
498
|
from params_proto.proto import _SINGLETONS, ptype
|
|
438
499
|
|
|
@@ -447,6 +508,14 @@ def parse_cli_args(wrapper) -> Dict[str, Any]:
|
|
|
447
508
|
attrs[key] = value
|
|
448
509
|
break
|
|
449
510
|
|
|
511
|
+
# Check for missing required fields
|
|
512
|
+
required_fields = _get_required_fields(selected_class)
|
|
513
|
+
for field_name in required_fields:
|
|
514
|
+
if field_name not in attrs:
|
|
515
|
+
raise SystemExit(
|
|
516
|
+
f"error: {selected_class.__name__} requires argument: {field_name}"
|
|
517
|
+
)
|
|
518
|
+
|
|
450
519
|
# Instantiate the class with collected attributes
|
|
451
520
|
try:
|
|
452
521
|
instance = selected_class(**attrs)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: params-proto
|
|
3
|
-
Version: 3.2.0
|
|
3
|
+
Version: 3.2.1
|
|
4
4
|
Summary: Modern Hyper Parameter Management for Machine Learning
|
|
5
5
|
Project-URL: Homepage, https://github.com/geyang/params-proto
|
|
6
6
|
Project-URL: Documentation, https://params-proto.readthedocs.io
|
|
@@ -7,7 +7,7 @@ params_proto/proto.py,sha256=AGUHrU0OmWxqkYvJJqR6SzQTxZz_Zis0XYYaiFentmo,39077
|
|
|
7
7
|
params_proto/type_utils.py,sha256=x68rL5m76ZFRKsCRgH_i_4vLpt6ldWEsEAalgacFIH8,7364
|
|
8
8
|
params_proto/cli/__init__.py,sha256=sLpN3GmaBqd_d0J0nvUNOeGlV74_-jQGW0nDUU34tjA,493
|
|
9
9
|
params_proto/cli/ansi_help.py,sha256=-1gzbvOpi9GjPlqgiINOYQAfIstzg0-ukv1se88TYCQ,10967
|
|
10
|
-
params_proto/cli/cli_parse.py,sha256=
|
|
10
|
+
params_proto/cli/cli_parse.py,sha256=kVBSgCCVZk32NOfPfyj-pYWXBLyqitFZzsfme8LspcE,18835
|
|
11
11
|
params_proto/cli/help_gen.py,sha256=Iv9MWC7TJT4_OUWozTfCr8-Nmp_-K8Ohoim_dtsN5AY,12921
|
|
12
12
|
params_proto/hyper/__init__.py,sha256=4zMnKk9H7NPlaTTRzbL2MC7anzwkBbd2_kW51aYhCPs,157
|
|
13
13
|
params_proto/hyper/proxies.py,sha256=OMiaKK-gQx-zT1xeCmZevBSDgWUwwkzz0n54A5_wC60,4492
|
|
@@ -20,7 +20,7 @@ params_proto/v2/hyper.py,sha256=onBAkT8Ja8IkeHEOq1AwCdTuBzAnthIe766ZE0lAy-M,1146
|
|
|
20
20
|
params_proto/v2/partial.py,sha256=_ovi4NY8goYgHurfYt1OV0E9DSMXGYucjMVIyG1Q_xc,983
|
|
21
21
|
params_proto/v2/proto.py,sha256=KvinzgzwRQr2bHDNtrU7App2kgAyB-SEfBe4SNYceh0,18995
|
|
22
22
|
params_proto/v2/utils.py,sha256=5EWvwboZDTsCYfzSED_J6RVFyNLIlf95nIu4p_ZSVxA,3540
|
|
23
|
-
params_proto-3.2.
|
|
24
|
-
params_proto-3.2.
|
|
25
|
-
params_proto-3.2.
|
|
26
|
-
params_proto-3.2.
|
|
23
|
+
params_proto-3.2.1.dist-info/METADATA,sha256=IhhYdm8bYV3OEnUYmaRNo2Wb_8Sm58dCEBuuBIa0nbg,8991
|
|
24
|
+
params_proto-3.2.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
25
|
+
params_proto-3.2.1.dist-info/licenses/LICENSE.md,sha256=c2qSYi9tUMZtzj9SEsMeKhub5LJUmHwBtDLiIMM5b6U,1526
|
|
26
|
+
params_proto-3.2.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|