cdxcore 0.1.6__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cdxcore might be problematic. Click here for more details.

@@ -0,0 +1,750 @@
1
+ """
2
+ Overview
3
+ --------
4
+
5
+ A simple extension to standard dictionaries which allows accessing elements of the dictionary with "."
6
+ notation. The purpose is a functional-programming style pattern for generating complex objects::
7
+
8
+ from cdxcore.pretty import PrettyObject
9
+ pdct = PrettyObject(z=1)
10
+
11
+ pdct.num_samples = 1000
12
+ pdct.num_batches = 100
13
+ pdct.method = "signature"
14
+
15
+ This, of course, also works with any class derived from ``object``.
16
+ The class :class:`cdxcore.pretty.PrettyObject` adds:
17
+
18
+ * Implements all relevant dictionary protocols, so objects of type :class:`cdxcore.pretty.PrettyObject` can
19
+ (nearly always) be passed where dictionaries are expected:
20
+
21
+ * A :class:`cdxcore.pretty.PrettyObject` object supports standard dictionary semantics in addition to member attribute
22
+ access.
23
+ That means you can use ``pdct['num_samples']`` as well as ``pdct.num_samples``.
24
+ You can mix standard dictionary notation with member attribute notation::
25
+
26
+ print(pdct["num_samples"]) # -> prints "1000"
27
+ pdct["test"] = 1 # sets pdct.test to 1
28
+
29
+ * Iterations work just like for dictionaries; for example::
30
+
31
+ for k,v in pdct.items():
32
+ print( k, v)
33
+
34
+ * Applying ``str`` and ``repr`` to objects of type :class:`cdxcore.pretty.PrettyObject` will return dictionary-type
35
+ results, so for example ``print(pdct)`` of the above prints ``{'z': 1, 'num_samples': 1000, 'num_batches': 100, 'method': 'signature'}``.
36
+
37
+ * The :attr:`cdxcore.pretty.PrettyObject.at_pos` attribute allows accessing elements of the ordered dictionary
38
+ by position:
39
+
40
+ * ``cdxcore.pretty.PrettyObject.at_pos[i]`` returns the `i` th element.
41
+
42
+ * ``cdxcore.pretty.PrettyObject.at_pos.keys[i]`` returns the `i` th key.
43
+
44
+ * ``cdxcore.pretty.PrettyObject.at_pos.items[i]`` returns the `i` th item.
45
+
46
+ For example::
47
+
48
+ print(pdct.at_pos[3]) # -> prints "signature"
49
+ print(pdct.at_pos.keys[3]) # -> prints "method"
50
+
51
+ * You can assign member functions. The following works as expected::
52
+
53
+ pdct.f = lambda self, y: self.z*y
54
+
55
+ (to assign a static function which does not refer to ``self``, use ``pdct['g'] = lambda z: z``).
56
+
57
+ **Dataclasses**
58
+
59
+ :mod:`dataclasses` rely on default values of any member being "frozen" objects, which most user-defined objects and
60
+ :class:`cdxcore.pretty.PrettyObject` objects are not.
61
+ This limitation applies as well to `flax <https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html>`__ modules.
62
+ To use non-frozen default values, :class:`cdxcore.pretty.PrettyObject` wraps the required data factory into its
63
+ :meth:`cdxcore.pretty.PrettyObject.as_field` function::
64
+
65
+ from cdxcore.pretty import PrettyObject
66
+ from dataclasses import dataclass
67
+
68
+ @dataclass
69
+ class Data:
70
+ data : PrettyObject = PrettyObject(x=2).as_field()
71
+
72
+ def f(self):
73
+ return self.data.x
74
+
75
+ d = Data() # default constructor used.
76
+ d.f()
77
+
78
+
79
+ Import
80
+ ------
81
+ .. code-block:: python
82
+
83
+ from cdxcore.pretty import PrettyObject as pdct
84
+ """
85
+
86
+ from collections import OrderedDict
87
+ import dataclasses as dataclasses
88
+ from dataclasses import Field
89
+ import types as types
90
+ from collections.abc import Mapping, MutableMapping, Sequence
91
+
92
class __No_Default_dummy():
    """ Type of the :data:`no_default` sentinel, which marks "no default value was passed". """
    def __repr__(self):
        return "no_default"

no_default = __No_Default_dummy()

class PrettyObject(MutableMapping):
    """
    Ordered dictionary which allows accessing its members with member notation.

    Example::

        from cdxcore.pretty import PrettyObject
        pdct = PrettyObject()
        pdct.x = 1
        pdct['y'] = 2
        print( pdct['x'], pdct.y ) # -> prints 1 2

    The object mimics a dictionary::

        print(pdct) # -> '{'x': 1, 'y': 2}'

        u = dict( pdct )
        print(u) # -> {'x': 1, 'y': 2}

        u = { k: 2*v for k,v in pdct.items() }
        print(u) # -> {'x': 2, 'y': 4}

        l = list( pdct )
        print(l) # -> ['x', 'y']

    **Keys and Attributes**

    The items of a ``PrettyObject`` are exactly the attributes stored on the instance
    (its ``__dict__``). Methods and properties of the class such as ``keys``, ``copy`` or
    ``at_pos`` are *not* items: ``pdct['keys']`` raises :class:`KeyError` and
    ``'keys' in pdct`` is ``False``. Avoid such names as keys; for example
    ``pdct.keys = 1`` hides the ``keys`` method.

    Names starting with ``'__'`` are stored as plain attributes. In particular, functions
    assigned to such names are *not* converted into member functions::

        pdct = PrettyObject()
        setattr(pdct, '__f', lambda x: x)
        pdct['__f'](1)  # -> 1

    **Access by Index Position**

    :class:`cdxcore.pretty.PrettyObject` retains order of construction. To access its members
    by index position, use the :attr:`cdxcore.pretty.PrettyObject.at_pos` attribute::

        print(pdct.at_pos[1])              # -> prints "2"
        print(pdct.at_pos.keys[1])         # -> prints "y"
        print(list(pdct.at_pos.items[:2])) # -> prints "[('x', 1), ('y', 2)]"

    **Assigning Member Functions**

    ``PrettyObject`` objects also allow assigning bona fide member functions by a simple semantic of the form::

        pdct = PrettyObject(b=2)
        pdct.mult_b = lambda self, x: self.b*x
        pdct.mult_b(3) # -> 6

    Calling ``pdct.mult_b(3)`` with above ``pdct`` will return `6` as expected.
    To assign static member functions, use the ``[]`` operator.
    The reason for this is as follows: consider::

        def mult( a, b ):
            return a*b
        pdct = PrettyObject()
        pdct.mult = mult
        pdct.mult(3,4) --> produces an error as three arguments are passed: self, 3, and 4

    In this case, use::

        pdct = PrettyObject()
        pdct['mult'] = mult
        pdct.mult(3,4) --> 12

    You can also pass member functions to the constructor::

        p = PrettyObject( f=lambda self, x: self.y*x, y=2)
        p.f(3) # -> 6

    **Operators**

    Objects of type :class:`cdxcore.pretty.PrettyObject` support the following operators:

    * Comparison operators ``==`` and ``!=`` test for equality of keys and values. Unlike for dictionaries,
      comparisons are performed *in order*. That means ``PrettyObject(x=1,y=2)`` and ``PrettyObject(y=2,x=1)``
      are *not* equal. An object which is not a :class:`collections.abc.Mapping` is never equal.

    * Subset/superset operators ``<=`` and ``>=`` test whether ``self`` is a subset/superset of another
      mapping, including values.

    * ``a | b`` returns the union of ``a`` and ``b`` as a new :class:`cdxcore.pretty.PrettyObject`; one of
      the two may be a plain mapping. Elements of ``b`` overwrite any elements of ``a`` if they
      are present in both. The order of the new dictionary is determined by the order of appearance of keys
      in first ``a`` and then ``b``. That means that in all but trivial cases ``a|b != b|a``.

    The ``|=`` operator is a short-cut for :meth:`cdxcore.pretty.PrettyObject.update`.
    """
    def __init__(self, copy : Mapping = None, **kwargs):
        """
        Construct the object with the same semantics as dictionary construction.

        Since Python 3.6 `dictionaries preserve the order <https://docs.python.org/3/whatsnew/3.6.html#whatsnew36-compactdict>`__
        in which they were constructed; so does therefore PrettyObject.
        Note that plain dictionaries nevertheless compare order-invariantly,
        i.e. ``{'x':1, 'y':2}`` tests equal to ``{'y':2, 'x':1}``, whereas ``PrettyObject`` does not.

        Parameters
        ----------
        copy : Mapping, iterable of ``(key, value)`` pairs, or `None`
            If present, shallow copy elements of this mapping.
        **kwargs
            Add key/value pairs directly provided to the constructor.
        """
        self.update(copy, **kwargs)

    def __getitem__(self, key):
        """ Return the item ``key``; raises :class:`KeyError` if ``key`` is not an item of ``self``. """
        # look only at instance data: class methods/properties must not appear as items,
        # consistent with __contains__, __iter__ and __len__
        return self.__dict__[key]

    def __setitem__(self,key,value):
        """
        Route ``self[key] = value`` to the base class ``__setattr__`` method.
        This way you can assign static functions using ``[]``, while assigning
        functions using ``.`` creates member functions.
        """
        try:
            super().__setattr__(key, value)
        except AttributeError as e:
            # e.g. a read-only property such as 'at_pos'
            raise KeyError(key,*e.args) from e

    def __delitem__(self,key):
        """ Delete the item ``key``; raises :class:`KeyError` if it is not present. """
        del self.__dict__[key]
    def __iter__(self):
        return self.__dict__.__iter__()
    def __reversed__(self):
        return self.__dict__.__reversed__()
    def __sizeof__(self):
        return self.__dict__.__sizeof__()
    def __contains__(self, key):
        return self.__dict__.__contains__(key)
    def __len__(self):
        return self.__dict__.__len__()

    # allow assigning functions with ``self``
    def __setattr__(self, key, value):
        """
        ``__setattr__`` converts function assignments to member functions bound to ``self``.
        Names starting with ``'__'`` are stored unchanged.
        """
        if key[:2] == "__":
            super().__setattr__(key, value)
            return
        if isinstance(value,types.FunctionType):
            # bind function to this object
            value = types.MethodType(value,self)
        elif isinstance(value,types.MethodType):
            # re-point the method to the current instance
            value = types.MethodType(value.__func__,self)
        super().__setattr__(key, value)

    # dictionary
    def copy(self, **kwargs):
        """ Return a shallow copy of `self`, optionally updated with ``kwargs``. """
        return PrettyObject(self,**kwargs)

    def get(self, key, default = no_default ):
        """
        Like :meth:`dict.get`, except that a missing ``key`` raises :class:`KeyError`
        when no ``default`` is provided.
        """
        # identity test: '==' would invoke arbitrary __eq__ implementations of 'default'
        if default is no_default:
            return self.__dict__[key]
        return self.__dict__.get(key, default)

    def pop(self, key, default = no_default ):
        """ Equivalent to :meth:`dict.pop`: returns ``default`` for a missing ``key`` if provided, otherwise raises :class:`KeyError`. """
        if default is no_default:
            return self.__dict__.pop(key)
        return self.__dict__.pop(key, default)

    def setdefault( self, key, default=None ):
        """ Equivalent to :meth:`dict.setdefault`. """
        if key not in self.__dict__:
            setattr(self, key, default)
        # return the stored value, i.e. a bound method if 'default' was a function
        return self.__dict__[key]

    def update(self, other : Mapping = None, **kwargs):
        """
        Equivalent to :meth:`dict.update`, but returns ``self``.

        ``other`` may be a mapping (anything with ``items()``) or an iterable of ``(key, value)`` pairs.
        """
        if not other is None:
            pairs = other.items() if hasattr(other, "items") else other
            for k, v in pairs:
                setattr(self, k, v)
        for k, v in kwargs.items():
            setattr(self, k, v)
        return self

    # behave like a dictionary
    def keys(self):
        """ Equivalent to :meth:`dict.keys` """
        return self.__dict__.keys()
    def items(self):
        """ Equivalent to :meth:`dict.items` """
        return self.__dict__.items()
    def values(self):
        """ Equivalent to :meth:`dict.values` """
        return self.__dict__.values()

    # update
    def __ior__(self, other):
        return self.update(other)
    def __or__(self, other):
        if not isinstance(other, Mapping):
            return NotImplemented
        result = self.copy()
        result.update(other)
        return result
    def __ror__(self, other):
        # other | self: 'other' comes first and 'self' overwrites it
        if not isinstance(other, Mapping):
            return NotImplemented
        result = PrettyObject(other)
        result.update(self)
        return result

    # dictionary comparison
    def __eq__(self, other):
        """
        Comparison operator. Unlike dictionary comparison, this comparison operator
        respects order.
        """
        if not isinstance(other, Mapping):
            return NotImplemented
        if len(self) != len(other):
            return False
        for k1, k2 in zip( self, other ):
            if not k1==k2:
                return False
        for v1, v2 in zip( self.values(), other.values() ):
            if not v1==v2:
                return False
        return True
    def __le__(self, other):
        """
        Subset operator i.e. if ``self`` is contained in ``other``, including values.
        """
        if not isinstance(other, Mapping):
            return NotImplemented
        for k, v in self.items():
            if not k in other:
                return False
            if not v == other[k]:
                return False
        return True
    def __ge__(self, other):
        """
        Superset operator i.e. if ``self`` is a superset of ``other``, including values.
        """
        # implemented directly: delegating to 'other <= self' recurses forever
        # when 'other' (e.g. a dict) does not implement '<='
        if not isinstance(other, Mapping):
            return NotImplemented
        for k, v in other.items():
            if not k in self:
                return False
            if not v == self[k]:
                return False
        return True

    def __ne__(self, other):
        """
        Negation of ``==``. Unlike dictionary comparison, this comparison operator
        respects order.
        """
        equal = self.__eq__(other)
        return equal if equal is NotImplemented else not equal
    __neq__ = __ne__  # backward-compatible alias of the previously misspelled name

    # print representation
    def __repr__(self):
        return f"PrettyObject({self.__dict__.__repr__()})"
    def __str__(self):
        return self.__dict__.__str__()

    # data classes
    def as_field(self) -> Field:
        """
        This function provides support for :class:`dataclasses.dataclass` fields
        with ``PrettyObject`` default values.

        Recent versions of :mod:`dataclasses` reject unhashable default values such as
        ``PrettyObject``; a ``default_factory`` has to be provided instead.
        ``as_field`` returns a :class:`dataclasses.Field` whose factory returns a
        shallow copy of ``self`` (of type ``PrettyObject``). Each dataclass instance
        therefore gets its own object, and mutating it affects neither ``self`` nor
        other instances.

        Usage is as follows::

            from dataclasses import dataclass
            @dataclass
            class A:
                data : PrettyObject = PrettyObject(x=2).as_field()

            a = A()
            print(a.data.x)  # -> "2"
            a = A(data=PrettyObject(x=3))
            print(a.data.x)  # -> "3"
        """
        def factory():
            return self.copy()
        return dataclasses.field( default_factory=factory )

    @property
    def at_pos(self):
        """
        Elementary access to the data contained in `self` by ordinal position. The ordinal
        position of an element is determined by the order of addition to the dictionary.

        * ``at_pos[position]`` returns an element or elements at an ordinal position:

          * It returns a single element if 'position' refers to only one field.
          * If 'position' is a slice then the respective list of fields is returned.

        * ``at_pos.keys[position]`` returns the key or keys at 'position'.

        * ``at_pos.values[position]`` returns the value or values at 'position'.

        * ``at_pos.items[position]`` returns the tuple ``(key, element)`` or a list thereof for `position`.

        You can also write data using the `item` notation:

        * ``at_pos[position] = item`` assigns an item to an ordinal position:

          * If 'position' refers to a single element, 'item' must be that item.
          * If 'position' is a slice then 'item' must resolve to a list of the required size;
            otherwise a :class:`ValueError` is raised.
        """
        owner = self

        class Access(Sequence):
            """
            Wrapper object to allow index access for at_pos. Keys are re-read on every access.
            """
            def __getitem__(_, position):
                key = _.keys[position]
                if isinstance(key, list):
                    return [ owner[k] for k in key ]
                return owner[key]

            def __setitem__(_, position, item ):
                key = _.keys[position]
                if not isinstance(key, list):
                    owner[key] = item
                    return
                item = list(item)
                if len(item) != len(key):
                    raise ValueError(f"at_pos: cannot assign {len(item)} element(s) to {len(key)} position(s)")
                for k, i in zip(key, item):
                    owner[k] = i

            def __len__(_):
                return len(owner)

            def __iter__(_):
                for key in owner:
                    yield owner[key]

            @property
            def keys(_) -> list:
                """ Returns the list of keys in construction order """
                return list(owner.keys())

            @property
            def values(_) -> list:
                """ Returns the list of values in construction order """
                return list(owner.values())

            @property
            def items(_) -> Sequence:
                """ Returns the sequence of key, value pairs in construction order """
                class ItemAccess(Sequence):
                    def __init__(_x):
                        _x.keys = list(owner.keys())
                    def __getitem__(_x, position):
                        key = _x.keys[position]
                        if isinstance(key, list):
                            return [ (k, owner[k]) for k in key ]
                        return (key, owner[key])
                    def __len__(_x):
                        return len(_x.keys)
                    def __iter__(_x):
                        for key in _x.keys:
                            yield key, owner[key]
                return ItemAccess()

        return Access()
458
+
459
if False:
    # NOTE(review): this branch never executes (``if False``), so neither ``PrettyDict``
    # nor ``PrettyField`` exists at runtime; ``PrettyObject`` above is the live
    # implementation. The code is kept for reference only. Latent defects to fix
    # before re-enabling it are marked with NOTE(review) comments below:
    #   * PrettyDict.__getattr__ returns self[key], so a missing non-dunder attribute
    #     raises KeyError instead of AttributeError (hasattr()/getattr(obj, name, default)
    #     only catch AttributeError).
    #   * PrettyDict.__call__ has its default test inverted.
    #   * PrettyField.__getattr__ calls object.__getattr__, which does not exist.
    #   * PrettyField.__eq__ reads other.pdct, which is a dictionary lookup of key 'pdct'.
    class PrettyDict(OrderedDict):
        """
        **Deprecated. Recommendation is to use :class:`cdxcore.pretty.PrettyObject`**

        Ordered dictionary which allows accessing its members with member notation, e.g.::

            from cdxcore.pretty import PrettyDict
            pdct = PrettyDict()
            pdct.x = 1
            x = pdct.x

        *IMPORTANT*
        Attributes starting with '__' are assumed to be existing object attributes
        and cannot be overwritten.
        In other words::

            pdct = PrettyDict()
            pdct.__x = 1
            _ = pdct['__x'] <- throws an exception

        (This conventions allows re-use of general operator handling: otherwise
        access to, say, ``__add__`` would trigger a ``KeyError``.)

        **Dataclasses**

        :mod:`dataclasses` have difficulties with using directly derived dictionaries.
        This applies as well to ``flax`` modules.
        For fields in dataclasses use :class:`cdxcore.pretty.PrettyField`::

            from cdxbasics.prettydict import PrettyField
            from dataclasses import dataclass

            @dataclass
            class Data:
                ...
                data : PrettyField = PrettyField.Field()

                def f(self):
                    return self.data.x

            p = PrettyDict(x=1)
            d = Data( p.as_field() )
            f.f()

        **Assigning member functions**

        `PrettyDict` objects also allow assigning bona fide member functions by a simple semantic of the form::

            def mult_b( self, x ):
                return self.b * x
            pdct = PrettyDict()
            pdct = mult_a
            pdct.mult_a(3)

        Calling ``pdct.mult_a(3)`` with above `pdct` will return `6` as expected. This only works when using the member synthax for assigning values
        to a pretty dictionary; use the standard ``[]`` operator to assign static functions to ``self``.

        The reason for this is as follows: consider::

            def mult( a, b ):
                return a*b
            pdct = PrettyDict()
            pdct.mult = mult
            pdct.mult(3,4) --> produces am error as three arguments as are passed if we count 'self'

        In this case, use::

            pdct = PrettyDict()
            pdct['mult'] = mult
            pdct.mult(3,4) --> 12

        **Functions passed to the Constructor**

        The constructor works like an item assignment, i.e.::

            def mult( a, b ):
                return a*b
            pdct = PrettyDict(mult=mult)
            pdct.mult(3,4) --> 12

        """

        def __getattr__(self, key : str):
            """ Equivalent to self[key] """
            if key[:2] == "__": raise AttributeError(key) # you cannot treat private members as dictionary members
            # NOTE(review): a missing key raises KeyError here, not AttributeError
            return self[key]
        def __delattr__(self, key : str):
            """ Equivalent to del self[key] """
            if key[:2] == "__": raise AttributeError(key) # you cannot treat private members as dictionary members
            del self[key]
        def __setattr__(self, key : str, value):
            """ Equivalent to self[key] = value; functions are bound to ``self`` first """
            if key[:2] == "__":
                OrderedDict.__setattr__(self, key, value)
                return
            if isinstance(value,types.FunctionType):
                # bind function to this object
                value = types.MethodType(value,self)
            elif isinstance(value,types.MethodType):
                # re-point the method to the current instance
                value = types.MethodType(value.__func__,self)
            self[key] = value

        def __str__(self):
            """ Return standard dictionary string """
            return dict(self).__str__()

        def __call__(self, key : str, default = no_default ):
            """
            Short-cut for :func:`dict.get`.
            """
            # NOTE(review): condition is inverted - an explicit 'default' is ignored, and
            # without one the sentinel 'no_default' is passed to get() as the default
            return self.get(key) if default != no_default else self.get(key,default)

        def copy(self) -> object:
            """
            Return copy of ``self``
            """
            return PrettyDict(self)

        def as_field(self) -> Field:
            """
            Returns a :func:`dataclasses.field` whose ``default_factory`` returns ``self``
            (the same object for every dataclass instance), for use in :mod:`dataclasses`.
            """
            def factory():
                return self
            return dataclasses.field( default_factory=factory )

        @property
        def at_pos(self):
            """
            Elementary access to the data contained in `self`:

            * ``at_pos[position]`` returns an element or elements at an ordinal position:

              * It returns a single element if 'position' refers to only one field.
              * If 'position' is a slice then the respecitve list of fields is returned

            * ``at_pos.keys[position]`` returns the key or keys at 'position'

            * ``at_pos.items[position]`` returns the tuple ``(key, element)`` or a list thereof for `position`

            You can also write data using the `attribute` notation:

            * ``at_pos[position] = item`` assigns an item or an ordinal position:

              * If 'position' refers to a single element, 'item' must be that item
              * If 'position' is a slice then 'item' must resolve to a list of the required size
            """

            class Access:
                """
                Wrapper object to allow index access for at_pos
                """
                def __init__(self):
                    # cache for the key list; filled lazily by 'keys'
                    self.__keys = None

                def __getitem__(_, position):
                    key = _.keys[position]
                    return self[key] if not isinstance(key,list) else [ self[k] for k in key ]
                def __setitem__(_, position, item ):
                    key = _.keys[position]
                    if not isinstance(key,list):
                        self[key] = item
                    else:
                        for k, i in zip(key, item):
                            self[k] = i
                @property
                def keys(_) -> list:
                    """ Returns the list of keys of the original dictionary """
                    if _.__keys is None:
                        _.__keys = list(self.keys())
                    return _.__keys
                @property
                def items(_) -> list:
                    """ Returns an object giving positional access to ``(key, value)`` pairs of the original dictionary """
                    class ItemAccess(object):
                        def __getitem__(_x, position):
                            key = _.keys[position]
                            return (key, self[key]) if not isinstance(key,list) else [ (k,self[k]) for k in key ]
                    return ItemAccess()

            return Access()

        # pickling
        def __getstate__(self):
            """ Return state to pickle """
            return self.__dict__
        def __setstate__(self, state):
            """ Restore pickle """
            self.__dict__.update(state)

    class PrettyField(object):
        """
        Wraps :class:`dataclasses.field` for :class:`cdxcore.pretty.PrettyDict` objects.

        Useful for Flax::

            import dataclasses as dataclasses
            import jax.numpy as jnp
            import jax as jax
            from options.cdxbasics.config import Config, ConfigField
            import types as types

            class A( nn.Module ):
                pdct : PrettyField = PrettyField.Field()

                def setup(self):
                    self.dense = nn.Dense(1)

                def __call__(self, x):
                    a = self.pdct.a # <-- basic access to 'a'
                    return self.dense(x)*a

            r = PrettyDict(a=1.)
            a = A( r.as_field() )

            key1, key2 = jax.random.split(jax.random.key(0))
            x = jnp.zeros((10,10))
            param = a.init( key1, x )
            y = a.apply( param, x )
        """
        def __init__(self, pdct : PrettyDict = None, **kwargs):
            """
            Initialize with an input dictionary and potential overwrites
            """
            if not pdct is None:
                if type(pdct).__name__ == type(self).__name__ and len(kwargs) == 0:
                    # copy
                    self.__pdct = PrettyDict( pdct.__pdct )
                    return
                if not isinstance(pdct, Mapping): raise ValueError("'pdct' must be a Mapping")
                self.__pdct = PrettyDict(pdct)
                self.__pdct.update(kwargs)
            else:
                self.__pdct = PrettyDict(**kwargs)
            # wrap nested PrettyDict values into PrettyField, recursing into other mappings
            def rec(x):
                for k, v in x.items():
                    # NOTE(review): the tuple lists PrettyDict twice; presumably another type was intended
                    if isinstance(v, (PrettyDict, PrettyDict)):
                        x[k] = PrettyField(v)
                    elif isinstance(v, Mapping):
                        rec(v)
            rec(self.__pdct)

        def as_dict(self) -> PrettyDict:
            """ Return copy of underlying dictionary """
            return PrettyDict( self.__pdct )

        # data classes

        # mimic the underlying dictionary
        # -------------------------------

        def __getattr__(self, key):
            if key[:2] == "__":
                # NOTE(review): 'object' has no '__getattr__'; this raises an AttributeError about
                # 'object' rather than about 'key'
                return object.__getattr__(self,key)
            return self.__pdct.__getattr__(key)
        def __getitem__(self, key):
            return self.__pdct[key]
        def __call__(self, *kargs, **kwargs):
            return self.__pdct(*kargs, **kwargs)
        def __eq__(self, other):
            if type(other).__name__ == "PrettyDict":
                return self.__pdct == other
            else:
                # NOTE(review): 'other.pdct' goes through __getattr__ and looks up the KEY 'pdct'
                # in the wrapped dictionary; it does not reach the wrapped dictionary itself
                return self.__pdct == other.pdct
        def keys(self):
            """ :meta private: """
            return self.__pdct.keys()
        def items(self):
            """ :meta private: """
            return self.__pdct.items()
        def values(self):
            """ :meta private: """
            return self.__pdct.values()
        def __hash__(self):
            # order-independent XOR of key and value hashes
            h = 0
            for k, v in self.items():
                h ^= hash(k) ^ hash(v)
            return h
        def __iter__(self):
            return self.__pdct.__iter__()
        def __contains__(self, key):
            return self.__pdct.__contains__(key)
        def __len__(self):
            return self.__pdct.__len__()
        def __str__(self):
            return self.__pdct.__str__()
        def __repr__(self):
            return self.__pdct.__repr__()
750
+