params-proto 3.2.3__py3-none-any.whl → 3.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -83,6 +83,101 @@ def _get_required_fields(cls) -> List[str]:
83
83
  return required
84
84
 
85
85
 
86
+ def _is_nested_dataclass(annotation) -> bool:
87
+ """Check if annotation is a nested dataclass type."""
88
+ import dataclasses
89
+
90
+ # Skip primitive types and common non-dataclass types
91
+ if annotation in {int, str, float, bool, list, dict, tuple, set, Path, type(None)}:
92
+ return False
93
+
94
+ # Check if it's a type and a dataclass
95
+ if isinstance(annotation, type) and dataclasses.is_dataclass(annotation):
96
+ return True
97
+
98
+ return False
99
+
100
+
101
def _get_nested_attrs(cls, prefix: str = "") -> Dict[str, tuple]:
    """Recursively flatten the annotated fields of *cls* into CLI option names.

    Fields whose annotation is a nested dataclass (per ``_is_nested_dataclass``)
    are expanded recursively, with ``"<field>."`` appended to the prefix; every
    other annotated field becomes a leaf entry.

    Args:
        cls: Class whose ``__annotations__`` are inspected; an object without
            that attribute yields an empty dict.
            NOTE(review): only what ``cls.__annotations__`` exposes is seen --
            depending on the Python version, fields inherited from base classes
            may be missing. Confirm whether inherited fields should be exposed.
        prefix: Kebab-case dotted prefix prepended to every key, normally
            ending in ``"."`` (e.g. ``"model."``); empty at the top level.

    Returns:
        Dict mapping dotted-kebab-name -> (dotted_underscore_path, leaf_type)

    Example for TrainConfig with nested ModelConfig:
        {
            "model.hidden-size": ("model.hidden_size", int),
            "model.num-layers": ("model.num_layers", int),
        }
    """
    # Kebab segments are produced from identifiers via "_" -> "-" (identifiers
    # never contain "-"), so reversing the replacement on the prefix recovers
    # the underscore path. Computed once rather than per attribute.
    underscore_prefix = prefix.replace("-", "_")

    result = {}
    for attr_name, attr_type in getattr(cls, "__annotations__", {}).items():
        full_kebab = prefix + attr_name.replace("_", "-")
        if _is_nested_dataclass(attr_type):
            # Descend: the nested class's fields get "<full_kebab>." prepended.
            result.update(_get_nested_attrs(attr_type, prefix=f"{full_kebab}."))
        else:
            result[full_kebab] = (underscore_prefix + attr_name, attr_type)
    return result
131
+
132
+
133
+ def _set_nested_value(d: dict, path: str, value: Any) -> None:
134
+ """Set a value in a nested dict using dot notation path.
135
+
136
+ Example: _set_nested_value({}, "model.hidden_size", 512)
137
+ Results in {"model": {"hidden_size": 512}}
138
+ """
139
+ parts = path.split(".")
140
+ current = d
141
+ for part in parts[:-1]:
142
+ if part not in current:
143
+ current[part] = {}
144
+ current = current[part]
145
+ current[parts[-1]] = value
146
+
147
+
148
def _build_nested_instance(cls, flat_attrs: dict, nested_attrs: dict):
    """Build a dataclass instance with nested dataclass fields.

    Args:
        cls: The dataclass class to instantiate; called as ``cls(**kwargs)``.
        flat_attrs: Keyword values passed to ``cls`` unchanged.
        nested_attrs: Maps a field name of ``cls`` to a dict of values for that
            field's nested dataclass, possibly several levels deep. Only keys
            whose annotation on ``cls`` is a nested dataclass are used; any
            other key is ignored.

    Returns:
        Instance of cls with nested dataclasses properly constructed. A built
        nested value overrides a same-named key in ``flat_attrs``.

    Raises:
        Whatever ``cls(**kwargs)`` (or a nested class constructor) raises, e.g.
        TypeError for a missing required or unknown keyword argument.

    NOTE(review): each nested dataclass is constructed from scratch as
    ``attr_type(**values)``, so nested fields not given on the command line take
    that class's own defaults, not values a parent ``default_factory`` would set.
    """
    final_attrs = dict(flat_attrs)
    annotations = getattr(cls, "__annotations__", {})

    for attr_name, attr_type in annotations.items():
        if attr_name not in nested_attrs or not _is_nested_dataclass(attr_type):
            continue

        child_annotations = getattr(attr_type, "__annotations__", {})
        child_flat = {}
        child_nested = {}
        for key, value in nested_attrs[attr_name].items():
            # Recurse only into child fields that are themselves nested
            # dataclasses. Splitting on ``isinstance(value, dict)`` alone would
            # route a dict-valued leaf (e.g. a field annotated ``dict``) into
            # child_nested, where the recursive call never uses it and the value
            # would be silently dropped.
            if isinstance(value, dict) and _is_nested_dataclass(child_annotations.get(key)):
                child_nested[key] = value
            else:
                child_flat[key] = value
        final_attrs[attr_name] = _build_nested_instance(attr_type, child_flat, child_nested)

    return cls(**final_attrs)
179
+
180
+
86
181
  def _match_class_by_name(name: str, classes: list) -> Union[type, None]:
87
182
  """Match a string to one of the Union classes.
88
183
 
@@ -193,6 +288,7 @@ def parse_cli_args(wrapper) -> Dict[str, Any]:
193
288
  # Build unprefixed union attribute map for classes NOT decorated with @proto.prefix
194
289
  # Maps attr-name -> (union_param_name, attr_name_underscore)
195
290
  # Classes in _SINGLETONS are @proto.prefix decorated and require prefixed attrs
291
+ # Also includes nested dataclass attributes (e.g., "model.hidden-size")
196
292
  unprefixed_attrs = {}
197
293
  for kebab_name, (param_name, union_classes) in union_params.items():
198
294
  for cls in union_classes:
@@ -201,12 +297,20 @@ def parse_cli_args(wrapper) -> Dict[str, Any]:
201
297
  if is_prefix_class:
202
298
  continue
203
299
  if hasattr(cls, "__annotations__"):
204
- for attr_name in cls.__annotations__:
300
+ for attr_name, attr_type in cls.__annotations__.items():
205
301
  kebab_attr = attr_name.replace("_", "-")
206
302
  # Map to the union param (first one wins if multiple unions have same attr)
207
303
  if kebab_attr not in unprefixed_attrs:
208
304
  unprefixed_attrs[kebab_attr] = (param_name, attr_name)
209
305
 
306
+ # Check for nested dataclass and add its attributes
307
+ if _is_nested_dataclass(attr_type):
308
+ nested_attrs = _get_nested_attrs(attr_type, prefix=f"{kebab_attr}.")
309
+ for nested_kebab, (nested_path, nested_type) in nested_attrs.items():
310
+ full_path = f"{attr_name}.{nested_path.split('.', 1)[1] if '.' in nested_path else nested_path}"
311
+ if nested_kebab not in unprefixed_attrs:
312
+ unprefixed_attrs[nested_kebab] = (param_name, full_path)
313
+
210
314
  # Parse arguments
211
315
  result = {}
212
316
  prefix_values = {} # (singleton, param_name) -> value
@@ -451,21 +555,65 @@ def parse_cli_args(wrapper) -> Dict[str, Any]:
451
555
  # Instantiate Union classes with collected attributes
452
556
  for param_name, selected_class in union_selections.items():
453
557
  # Collect attributes for this Union parameter
454
- attrs = {}
455
- for (union_param, attr_name), value_str in union_attrs.items():
558
+ # Separate flat (top-level) and nested attributes
559
+ flat_attrs = {}
560
+ nested_attrs = {} # For nested dataclass fields
561
+
562
+ for (union_param, attr_path), value_str in union_attrs.items():
456
563
  if union_param == param_name:
457
- # Get the type annotation for this attribute
458
- if hasattr(selected_class, "__annotations__"):
459
- attr_type = selected_class.__annotations__.get(attr_name, str)
460
- try:
461
- attrs[attr_name] = _convert_type(value_str, attr_type)
462
- except (ValueError, TypeError):
463
- raise SystemExit(
464
- f"error: invalid value for --{param_name.replace('_', '-')}.{attr_name.replace('_', '-')}: {value_str}"
465
- )
564
+ # Check if this is a nested path (contains dots)
565
+ if "." in attr_path:
566
+ # Nested attribute like "model.hidden_size"
567
+ parts = attr_path.split(".")
568
+ top_level = parts[0]
569
+ rest_path = ".".join(parts[1:])
570
+
571
+ # Get the type of the nested field
572
+ if hasattr(selected_class, "__annotations__"):
573
+ top_type = selected_class.__annotations__.get(top_level)
574
+ if top_type and _is_nested_dataclass(top_type):
575
+ # Find the leaf type by traversing the path
576
+ current_type = top_type
577
+ for part in parts[1:]:
578
+ if hasattr(current_type, "__annotations__"):
579
+ current_type = current_type.__annotations__.get(part, str)
580
+ else:
581
+ current_type = str
582
+ break
583
+
584
+ try:
585
+ value = _convert_type(value_str, current_type)
586
+ except (ValueError, TypeError):
587
+ raise SystemExit(
588
+ f"error: invalid value for --{attr_path.replace('_', '-')}: {value_str}"
589
+ )
590
+
591
+ # Store in nested structure
592
+ if top_level not in nested_attrs:
593
+ nested_attrs[top_level] = {}
594
+ _set_nested_value(nested_attrs[top_level], rest_path, value)
595
+ continue
596
+
597
+ # Fallback: treat as string
598
+ if top_level not in nested_attrs:
599
+ nested_attrs[top_level] = {}
600
+ _set_nested_value(nested_attrs[top_level], rest_path, value_str)
466
601
  else:
467
- # No annotations, treat as string
468
- attrs[attr_name] = value_str
602
+ # Top-level attribute
603
+ if hasattr(selected_class, "__annotations__"):
604
+ attr_type = selected_class.__annotations__.get(attr_path, str)
605
+ try:
606
+ flat_attrs[attr_path] = _convert_type(value_str, attr_type)
607
+ except (ValueError, TypeError):
608
+ raise SystemExit(
609
+ f"error: invalid value for --{param_name.replace('_', '-')}.{attr_path.replace('_', '-')}: {value_str}"
610
+ )
611
+ else:
612
+ # No annotations, treat as string
613
+ flat_attrs[attr_path] = value_str
614
+
615
+ # Merge flat_attrs into attrs for compatibility with existing code
616
+ attrs = flat_attrs
469
617
 
470
618
  # Assign positional args to required fields of the selected class
471
619
  if param_name in union_positional and union_positional[param_name]:
@@ -508,17 +656,21 @@ def parse_cli_args(wrapper) -> Dict[str, Any]:
508
656
  attrs[key] = value
509
657
  break
510
658
 
511
- # Check for missing required fields
659
+ # Check for missing required fields (only check top-level, nested have defaults)
512
660
  required_fields = _get_required_fields(selected_class)
513
661
  for field_name in required_fields:
514
- if field_name not in attrs:
662
+ if field_name not in attrs and field_name not in nested_attrs:
515
663
  raise SystemExit(
516
664
  f"error: {selected_class.__name__} requires argument: {field_name}"
517
665
  )
518
666
 
519
667
  # Instantiate the class with collected attributes
520
668
  try:
521
- instance = selected_class(**attrs)
669
+ if nested_attrs:
670
+ # Build with nested dataclass support
671
+ instance = _build_nested_instance(selected_class, attrs, nested_attrs)
672
+ else:
673
+ instance = selected_class(**attrs)
522
674
  result[param_name] = instance
523
675
  except TypeError as e:
524
676
  raise SystemExit(f"error: failed to instantiate {selected_class.__name__}: {e}")
params_proto/proto.py CHANGED
@@ -694,59 +694,9 @@ class ptype(type):
694
694
 
695
695
  # Update the instance's class to the decorated class
696
696
  # This allows isinstance(instance, DecoratedClass) to work
697
+ # Since cls is a subclass of original_cls, methods are inherited naturally
697
698
  object.__setattr__(instance, "__class__", cls)
698
699
 
699
- # Copy methods from original class and wrap to return self
700
- for name in dir(original_cls):
701
- # Skip proto fields (fields are handled above)
702
- if name in annotations:
703
- continue
704
-
705
- # For dunder methods, only copy user-defined ones (not inherited from object/type)
706
- if name.startswith("__"):
707
- # Check if this dunder method is user-defined (not from object or type)
708
- is_user_defined = False
709
- for klass in original_cls.__mro__:
710
- if klass is object or klass is type:
711
- break
712
- if name in klass.__dict__:
713
- is_user_defined = True
714
- break
715
- if not is_user_defined:
716
- continue
717
-
718
- # Check raw descriptor in MRO to detect staticmethod/classmethod (handles inheritance)
719
- raw_attr = None
720
- for klass in original_cls.__mro__:
721
- if name in klass.__dict__:
722
- raw_attr = klass.__dict__[name]
723
- break
724
-
725
- attr = getattr(original_cls, name)
726
-
727
- # Only process actual methods (staticmethod, classmethod, or function)
728
- if isinstance(raw_attr, staticmethod):
729
- # For staticmethod, use directly (no binding needed)
730
- method = attr
731
- elif isinstance(raw_attr, classmethod) or inspect.isfunction(raw_attr) or inspect.ismethod(attr):
732
- # For instance methods and classmethods, bind to instance
733
- # Note: classmethods bound to instance is intentional for @proto
734
- # semantics where instances have all attributes accessible
735
- method = attr.__get__(instance, original_cls)
736
- else:
737
- # Not a method (e.g., _EnvVar, property, or other callable), skip
738
- continue
739
-
740
- # Wrap it to return self if it returns None
741
- def make_wrapper(m):
742
- def wrapper(*args, **kwargs):
743
- result = m(*args, **kwargs)
744
- return instance if result is None else result
745
-
746
- return wrapper
747
-
748
- setattr(instance, name, make_wrapper(method))
749
-
750
700
  return instance
751
701
 
752
702
 
@@ -869,30 +819,32 @@ def proto(
869
819
  else:
870
820
  metaclass = ptype
871
821
 
872
- # Recreate the class with ptype as its metaclass
873
- # Collect class namespace (attributes and methods)
874
- namespace = {}
875
- for key in dir(obj):
876
- if not key.startswith("__") or key in ("__annotations__", "__module__", "__qualname__", "__doc__"):
877
- try:
878
- # Use __dict__ to preserve classmethod/staticmethod descriptors
879
- # getattr() would return bound methods instead of descriptors
880
- if key in obj.__dict__:
881
- namespace[key] = obj.__dict__[key]
882
- else:
883
- namespace[key] = getattr(obj, key)
884
- except AttributeError:
885
- pass
822
+ # Create new class with metaclass as subclass of original
823
+ # Since new class inherits from obj, methods are inherited naturally.
824
+ # We only need to provide:
825
+ # - Module/qualname metadata
826
+ # - Annotations (so annotated fields are visible on the class)
827
+ # - Resolved default values (with EnvVars resolved)
828
+ namespace = {
829
+ "__module__": obj.__module__,
830
+ "__qualname__": obj.__qualname__,
831
+ "__doc__": obj.__doc__,
832
+ "__annotations__": annotations,
833
+ }
886
834
 
887
- # Replace _EnvVar objects with resolved values from defaults
888
- # This ensures the descriptor doesn't interfere with class attribute access
889
- for key, value in defaults.items():
890
- namespace[key] = value
835
+ # Add resolved default values (EnvVars are already resolved in defaults dict)
836
+ # Also set None for annotated fields without defaults so they're accessible
837
+ for key in annotations.keys():
838
+ if key in defaults:
839
+ namespace[key] = defaults[key]
840
+ else:
841
+ namespace[key] = None
891
842
 
892
- # Create new class with metaclass
893
- # IMPORTANT: Use (obj,) as bases to make new class a SUBCLASS of original.
894
- # This ensures super() works correctly - the original class is in the MRO,
895
- # so Python's super() validation passes when checking isinstance(self, original_class).
843
+ # Create new class as SUBCLASS of original.
844
+ # This ensures:
845
+ # 1. super() works correctly (original class is in MRO)
846
+ # 2. Methods, staticmethods, classmethods are inherited naturally
847
+ # 3. isinstance(instance, DecoratedClass) works
896
848
  new_cls = metaclass(
897
849
  obj.__name__,
898
850
  (obj,),
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: params-proto
3
- Version: 3.2.3
3
+ Version: 3.3.0
4
4
  Summary: Modern Hyper Parameter Management for Machine Learning
5
5
  Project-URL: Homepage, https://github.com/geyang/params-proto
6
6
  Project-URL: Documentation, https://params-proto.readthedocs.io
@@ -3,11 +3,11 @@ params_proto/app.py,sha256=UySpd1op3M44Szk6Ekyn0fJcnZsQvMTMPdaEybwWsLE,19
3
3
  params_proto/documentation.py,sha256=mIqmcwGWo8tM1BuNzLIwVTzdbQ3qyPus7yWTaOce4dM,8091
4
4
  params_proto/envvar.py,sha256=A87jxSAQ2tjbKLbrm96lblV90zNdtBGCSV6QRe2DrgA,8398
5
5
  params_proto/parse_env_template.py,sha256=mXTvKpNhT2jGr3HpwKw42shd18O0QACmSJn6yWMDdKA,1298
6
- params_proto/proto.py,sha256=CYLLEKT8mDcxi25q75wIjHuoS3U_7cMvWuqFRmeTYAQ,39996
6
+ params_proto/proto.py,sha256=9l_WFj3yfKbW74Jo-247rnJzNsinQHZ2wqW9K158Zeo,38081
7
7
  params_proto/type_utils.py,sha256=x68rL5m76ZFRKsCRgH_i_4vLpt6ldWEsEAalgacFIH8,7364
8
8
  params_proto/cli/__init__.py,sha256=sLpN3GmaBqd_d0J0nvUNOeGlV74_-jQGW0nDUU34tjA,493
9
9
  params_proto/cli/ansi_help.py,sha256=-1gzbvOpi9GjPlqgiINOYQAfIstzg0-ukv1se88TYCQ,10967
10
- params_proto/cli/cli_parse.py,sha256=kVBSgCCVZk32NOfPfyj-pYWXBLyqitFZzsfme8LspcE,18835
10
+ params_proto/cli/cli_parse.py,sha256=jdywU2wDFsHCShMozvOLe2EUbl7V1Q7wWBMmZvewYSA,24401
11
11
  params_proto/cli/help_gen.py,sha256=Iv9MWC7TJT4_OUWozTfCr8-Nmp_-K8Ohoim_dtsN5AY,12921
12
12
  params_proto/hyper/__init__.py,sha256=4zMnKk9H7NPlaTTRzbL2MC7anzwkBbd2_kW51aYhCPs,157
13
13
  params_proto/hyper/proxies.py,sha256=OMiaKK-gQx-zT1xeCmZevBSDgWUwwkzz0n54A5_wC60,4492
@@ -20,7 +20,7 @@ params_proto/v2/hyper.py,sha256=onBAkT8Ja8IkeHEOq1AwCdTuBzAnthIe766ZE0lAy-M,1146
20
20
  params_proto/v2/partial.py,sha256=_ovi4NY8goYgHurfYt1OV0E9DSMXGYucjMVIyG1Q_xc,983
21
21
  params_proto/v2/proto.py,sha256=-BTTwFOhJsaPnujwFIDh14QMB8r_ZdridK9I2Jkqd_U,19228
22
22
  params_proto/v2/utils.py,sha256=5EWvwboZDTsCYfzSED_J6RVFyNLIlf95nIu4p_ZSVxA,3540
23
- params_proto-3.2.3.dist-info/METADATA,sha256=bKvvnv82Bs4UkhnDvsNyWXoGPjhX7fDf0U9_IiCoXJA,8959
24
- params_proto-3.2.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
25
- params_proto-3.2.3.dist-info/licenses/LICENSE.md,sha256=c2qSYi9tUMZtzj9SEsMeKhub5LJUmHwBtDLiIMM5b6U,1526
26
- params_proto-3.2.3.dist-info/RECORD,,
23
+ params_proto-3.3.0.dist-info/METADATA,sha256=evqSOimTwLNTpF2OkokqEPbJf-GRQ6o4gVx9denyqGk,8959
24
+ params_proto-3.3.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
25
+ params_proto-3.3.0.dist-info/licenses/LICENSE.md,sha256=c2qSYi9tUMZtzj9SEsMeKhub5LJUmHwBtDLiIMM5b6U,1526
26
+ params_proto-3.3.0.dist-info/RECORD,,