pyglove 0.4.5.dev202410100808__py3-none-any.whl → 0.4.5.dev202410160809__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyglove/core/__init__.py CHANGED
@@ -277,6 +277,7 @@ ObjectFactory = patching.ObjectFactory
277
277
 
278
278
  from pyglove.core import object_utils
279
279
  KeyPath = object_utils.KeyPath
280
+ KeyPathSet = object_utils.KeyPathSet
280
281
  MISSING_VALUE = object_utils.MISSING_VALUE
281
282
 
282
283
  Formattable = object_utils.Formattable
@@ -89,6 +89,7 @@ from pyglove.core.object_utils.formatting import kvlist_str
89
89
  from pyglove.core.object_utils.formatting import quote_if_str
90
90
  from pyglove.core.object_utils.formatting import maybe_markdown_quote
91
91
  from pyglove.core.object_utils.formatting import comma_delimited_str
92
+ from pyglove.core.object_utils.formatting import camel_to_snake
92
93
  from pyglove.core.object_utils.formatting import auto_plural
93
94
  from pyglove.core.object_utils.formatting import BracketType
94
95
  from pyglove.core.object_utils.formatting import bracket_chars
@@ -100,6 +101,7 @@ from pyglove.core.object_utils.formatting import repr_format
100
101
 
101
102
  # Value location.
102
103
  from pyglove.core.object_utils.value_location import KeyPath
104
+ from pyglove.core.object_utils.value_location import KeyPathSet
103
105
  from pyglove.core.object_utils.value_location import StrKey
104
106
  from pyglove.core.object_utils.value_location import message_on_path
105
107
 
@@ -438,6 +438,22 @@ def maybe_markdown_quote(s: str, markdown: bool = True) -> str:
438
438
  return f'```\n{s}\n```'
439
439
 
440
440
 
441
def camel_to_snake(text: str, separator: str = '_') -> str:
  """Returns the snake case version of a camel case string.

  A word boundary is placed before an uppercase character when either:
    * the preceding character is not uppercase (e.g. 'fooBar' -> 'foo_bar',
      'ABC123Meta' -> 'abc123_meta'), or
    * the following character is lowercase, which ends a run of capitals
      such as an acronym (e.g. 'AIMessage' -> 'ai_message').

  Args:
    text: The camel case (or Pascal case) string to convert.
    separator: The string inserted between words. It is kept verbatim
      (not lowercased).

  Returns:
    The lowercased words of `text` joined by `separator`. An empty input
    yields an empty string.
  """
  chunks = []
  chunk_start = 0
  length = len(text)
  for i, c in enumerate(text):
    if not c.isupper():
      continue
    # Look at the actual previous character. Tracking the "last uppercase
    # index" starting from 0 wrongly treated a lowercase first character as
    # uppercase (e.g. 'aB' became 'ab' instead of 'a_b').
    follows_non_upper = i > 0 and not text[i - 1].isupper()
    ends_upper_run = i + 1 < length and text[i + 1].islower()
    if follows_non_upper or ends_upper_run:
      chunks.append(text[chunk_start:i])
      chunk_start = i
  chunks.append(text[chunk_start:])
  # Empty chunks arise when a boundary falls at index 0 (e.g. 'Foo'); drop
  # them. Lowercase each word rather than the joined result so that the
  # caller's separator is preserved as given.
  return separator.join(chunk.lower() for chunk in chunks if chunk)
455
+
456
+
441
457
  def printv(v: Any, **kwargs):
442
458
  """Prints formatted value."""
443
459
  fs = kwargs.pop('file', sys.stdout)
@@ -74,6 +74,15 @@ class StringHelperTest(unittest.TestCase):
74
74
  self.assertNotEqual(raw, formatting.RawText('abcd'))
75
75
  self.assertNotEqual(raw, 'abcd')
76
76
 
77
def test_camel_to_snake(self):
  """Checks camel_to_snake on plain, acronym-prefixed and digit names."""
  expectations = [
      ('foo', 'foo'),
      ('Foo', 'foo'),
      ('FooBar', 'foo_bar'),
      ('AI', 'ai'),
      ('AIMessage', 'ai_message'),
      ('ABCMeta', 'abc_meta'),
      ('ABC123Meta', 'abc123_meta'),
  ]
  for camel_name, snake_name in expectations:
    self.assertEqual(formatting.camel_to_snake(camel_name), snake_name)
85
+
77
86
  def test_special_format_support(self):
78
87
 
79
88
  class NewLine:
@@ -14,9 +14,9 @@
14
14
  """Handling locations in a hierarchical object."""
15
15
 
16
16
  import abc
17
- import copy
17
+ import copy as copy_lib
18
18
  import operator
19
- from typing import Any, Callable, List, Optional, Union
19
+ from typing import Any, Callable, Iterable, Iterator, List, Optional, Union
20
20
  from pyglove.core.object_utils import formatting
21
21
 
22
22
 
@@ -187,7 +187,7 @@ class KeyPath(formatting.Formattable):
187
187
  @property
188
188
  def keys(self) -> List[Any]:
189
189
  """A list of keys in this path."""
190
- return copy.copy(self._keys)
190
+ return copy_lib.copy(self._keys)
191
191
 
192
192
  @property
193
193
  def key(self) -> Any:
@@ -287,6 +287,10 @@ class KeyPath(formatting.Formattable):
287
287
  return self
288
288
  if isinstance(other, str):
289
289
  other = KeyPath.parse(other)
290
+ elif isinstance(other, KeyPathSet):
291
+ other = other.copy()
292
+ other.rebase(self)
293
+ return other
290
294
  elif not isinstance(other, KeyPath):
291
295
  other = KeyPath(other)
292
296
  assert isinstance(other, KeyPath)
@@ -570,6 +574,245 @@ class KeyPath(formatting.Formattable):
570
574
  return comparison(self.key, other.key)
571
575
 
572
576
 
577
class KeyPathSet(formatting.Formattable):
  """A set of KeyPaths backed by a trie-like nested dict.

  Each trie node is a dict mapping a path key to its child node. The special
  key '$' (mapped to True) marks that the path ending at that node is a member
  of the set. For example, ``{'a': {'$': True, 'b': {'$': True}}}`` holds
  ``KeyPath(['a'])`` and ``KeyPath(['a', 'b'])``.

  NOTE(review): a path key that is literally the string '$' would collide
  with the membership marker.
  """

  def __init__(
      self,
      paths: Optional[Iterable[KeyPath]] = None,
      *,
      include_intermediate: bool = False
  ):
    """Constructor.

    Args:
      paths: Optional paths to add to the set.
      include_intermediate: If True, every proper prefix of each path
        (including the empty root path) is added as well.
    """
    self._trie = {}
    if paths:
      for path in paths:
        self.add(path, include_intermediate=include_intermediate)

  def add(
      self,
      path: Union[str, int, KeyPath],
      include_intermediate: bool = False,
  ) -> bool:
    """Adds a path to the set.

    Args:
      path: The path to add; converted with `KeyPath.from_value`.
      include_intermediate: If True, every proper prefix of `path` (including
        the empty root path) is also added to the set.

    Returns:
      True if the set changed, False if everything was already present.
    """
    path = KeyPath.from_value(path)
    node = self._trie
    updated = False
    for key in path.keys:
      # Mark the prefix ending at `node` regardless of whether the child node
      # already exists, so the result does not depend on insertion order.
      if include_intermediate and '$' not in node:
        node['$'] = True
        updated = True
      if key not in node:
        node[key] = {}
        updated = True
      node = node[key]

    assert isinstance(node, dict), node
    if '$' not in node:
      node['$'] = True
      updated = True
    return updated

  def remove(self, path: Union[str, int, KeyPath]) -> bool:
    """Removes a path from the set, pruning trie nodes left empty.

    Args:
      path: The path to remove; converted with `KeyPath.from_value`.

    Returns:
      True if the path was a member and got removed, otherwise False.
    """
    path = KeyPath.from_value(path)
    keys = path.keys
    stack = [self._trie]
    for key in keys:
      if key not in stack[-1]:
        return False
      value = stack[-1][key]
      assert isinstance(value, dict), value
      stack.append(value)

    if '$' in stack[-1]:
      stack[-1].pop('$')
      stack.pop(-1)
      assert len(stack) == len(keys), (keys, stack)
      # Walk back towards the root, deleting child nodes that became empty.
      for key, parent_node in zip(reversed(keys), reversed(stack)):
        if not parent_node[key]:
          del parent_node[key]
      return True
    return False

  def __contains__(self, path: Union[str, int, KeyPath]) -> bool:
    """Returns True if the path is in the set."""
    path = KeyPath.from_value(path)
    node = self._trie
    for key in path.keys:
      if key not in node:
        return False
      node = node[key]
    return '$' in node

  def __bool__(self) -> bool:
    """Returns True if the set is not empty."""
    return bool(self._trie)

  def __iter__(self) -> Iterator[KeyPath]:
    """Iterates all paths in the set via a depth-first trie traversal."""
    def _traverse(node, keys):
      for k, v in node.items():
        if k == '$':
          # Pass a snapshot: `keys` keeps being mutated as traversal goes on.
          yield KeyPath(list(keys))
        else:
          keys.append(k)
          yield from _traverse(v, keys)
          keys.pop(-1)
    return _traverse(self._trie, [])

  def __eq__(self, other: Any) -> bool:
    return isinstance(other, KeyPathSet) and self._trie == other._trie

  def __ne__(self, other: Any) -> bool:
    return not self.__eq__(other)

  def has_prefix(self, root_path: Union[int, str, KeyPath]) -> bool:
    """Returns True if some path in the set starts with (or equals) `root_path`.

    An empty set has no paths, so this returns False for it even when
    `root_path` is the empty path.
    """
    root_path = KeyPath.from_value(root_path)
    node = self._trie
    for key in root_path.keys:
      if key not in node:
        return False
      node = node[key]
    return bool(node)

  def __add__(self, other: 'KeyPathSet') -> 'KeyPathSet':
    return self.union(other, copy=True)

  def rebase(
      self,
      root_path: Union[int, str, KeyPath],
  ) -> None:
    """Prepends `root_path` to every path in the set, in place.

    Args:
      root_path: The prefix path to prepend.
    """
    root_path = KeyPath.from_value(root_path)
    if not self._trie:
      # Wrapping an empty trie would leave empty intermediate nodes behind,
      # making the empty set truthy and unequal to `KeyPathSet()`.
      return
    root = self._trie
    for key in reversed(root_path.keys):
      root = {key: root}
    self._trie = root

  def clear(self) -> None:
    """Clears the set.

    NOTE(review): on a view returned by `subtree`, this clears the node
    shared with the parent set.
    """
    self._trie.clear()

  def copy(self) -> 'KeyPathSet':
    """Returns a deep copy of the set."""
    return copy_lib.deepcopy(self)

  def difference_update(self, other: 'KeyPathSet') -> None:
    """Removes the paths in the other set from the current set, in place."""
    def _remove_same(target_dict, src_dict):
      # Returns True if `target_dict` became empty and should be pruned.
      keys_to_remove = []
      for key, value in target_dict.items():
        if key in src_dict:
          if key == '$' or _remove_same(value, src_dict[key]):
            keys_to_remove.append(key)
      for key in keys_to_remove:
        del target_dict[key]
      return not target_dict
    _remove_same(self._trie, other._trie)  # pylint: disable=protected-access

  def difference(
      self, other: 'KeyPathSet',
  ) -> 'KeyPathSet':
    """Returns a new set with the paths of `other` removed from this set."""
    x = self.copy()
    x.difference_update(other)
    return x

  def intersection_update(self, other: 'KeyPathSet') -> None:
    """Keeps only the paths that are also in the other set, in place."""
    def _remove_diff(target_dict, src_dict):
      keys_to_remove = []
      for key, value in target_dict.items():
        if key not in src_dict:
          keys_to_remove.append(key)
        elif key != '$':
          _remove_diff(value, src_dict[key])
          # Prune child nodes that no longer hold any path.
          if not value:
            keys_to_remove.append(key)
      for key in keys_to_remove:
        del target_dict[key]
    _remove_diff(self._trie, other._trie)  # pylint: disable=protected-access

  def intersection(self, other: 'KeyPathSet') -> 'KeyPathSet':
    """Returns a new set with the paths present in both sets."""
    x = self.copy()
    x.intersection_update(other)
    return x

  def update(self, other: 'KeyPathSet') -> None:
    """Adds all paths of the other set to the current set, in place."""
    def _merge(target_dict, src_dict):
      for key, value in src_dict.items():
        if key != '$' and key in target_dict:
          _merge(target_dict[key], value)
        else:
          # Deep-copy so that later edits of either set stay independent.
          target_dict[key] = copy_lib.deepcopy(value)
    _merge(self._trie, other._trie)  # pylint: disable=protected-access

  def union(
      self, other: 'KeyPathSet', copy: bool = False) -> 'KeyPathSet':
    """Returns a new set with the paths of both sets.

    Args:
      other: The set to union with.
      copy: Accepted for API compatibility but currently unused; a new set is
        always returned and `self` is left unmodified.

    Returns:
      The union KeyPathSet.
    """
    del copy  # Unused; see docstring.
    x = self.copy()
    x.update(other)
    return x

  def subtree(
      self,
      root_path: Union[int, str, KeyPath],
  ) -> Optional['KeyPathSet']:
    """Returns the relative paths of the sub-tree rooted at the given path.

    Args:
      root_path: A KeyPath for the root of the sub-tree.

    Returns:
      A KeyPathSet that contains all the child paths of the given root path.
      Please note that the returned value shares the same trie as the current
      value. So addition/removal of paths in the returned value will also
      affect the current value. If `root_path` is not a prefix stored in the
      trie, None will be returned. For the empty root path, `self` is
      returned.
    """
    root_path = KeyPath.from_value(root_path)
    # Presumably an empty KeyPath is falsy; verify against KeyPath.__bool__.
    if not root_path:
      return self
    node = self._trie
    for key in root_path.keys:
      if key not in node:
        return None
      node = node[key]
    ret = KeyPathSet()
    ret._trie = node  # pylint: disable=protected-access
    return ret

  def format(self, *args, **kwargs) -> str:
    """Formats the set, labelled with the class name.

    Positional arguments are ignored; keyword arguments are forwarded to
    `formatting.kvlist_str`.
    """
    return formatting.kvlist_str(
        [
            ('', list(self), [])
        ],
        label=self.__class__.__name__,
        **kwargs
    )

  @classmethod
  def from_value(
      cls,
      value: Union[Iterable[Union[int, str, KeyPath]], 'KeyPathSet'],
      include_intermediate: bool = False,
  ) -> 'KeyPathSet':
    """Returns a KeyPathSet from a compatible value.

    Args:
      value: A KeyPathSet (returned as-is, not copied), or a list/set/tuple of
        path-compatible values.
      include_intermediate: Forwarded to the constructor when building a new
        set; ignored when `value` is already a KeyPathSet.

    Raises:
      ValueError: If `value` is of an unsupported type.
    """
    if isinstance(value, KeyPathSet):
      return value
    if isinstance(value, (list, set, tuple)):
      return cls(value, include_intermediate=include_intermediate)
    raise ValueError(
        f'Cannot convert {value!r} to KeyPathSet. '
        f'Expected a list, set, tuple, or KeyPathSet.'
    )
814
+
815
+
573
816
  class StrKey(metaclass=abc.ABCMeta):
574
817
  """Interface for classes whose instances can be treated as str in ``KeyPath``.
575
818