cdxcore 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cdxcore might be problematic. Click here for more details.
- cdxcore/__init__.py +15 -0
- cdxcore/config.py +1633 -0
- cdxcore/crman.py +105 -0
- cdxcore/deferred.py +220 -0
- cdxcore/dynaplot.py +1155 -0
- cdxcore/filelock.py +430 -0
- cdxcore/jcpool.py +411 -0
- cdxcore/logger.py +319 -0
- cdxcore/np.py +1098 -0
- cdxcore/npio.py +270 -0
- cdxcore/prettydict.py +388 -0
- cdxcore/prettyobject.py +64 -0
- cdxcore/sharedarray.py +285 -0
- cdxcore/subdir.py +2963 -0
- cdxcore/uniquehash.py +970 -0
- cdxcore/util.py +1041 -0
- cdxcore/verbose.py +403 -0
- cdxcore/version.py +402 -0
- cdxcore-0.1.5.dist-info/METADATA +1418 -0
- cdxcore-0.1.5.dist-info/RECORD +30 -0
- cdxcore-0.1.5.dist-info/WHEEL +5 -0
- cdxcore-0.1.5.dist-info/licenses/LICENSE +21 -0
- cdxcore-0.1.5.dist-info/top_level.txt +4 -0
- conda/conda_exists.py +10 -0
- conda/conda_modify_yaml.py +42 -0
- tests/_cdxbasics.py +1086 -0
- tests/test_uniquehash.py +469 -0
- tests/test_util.py +329 -0
- up/git_message.py +7 -0
- up/pip_modify_setup.py +55 -0
cdxcore/npio.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Numpy fast IO
|
|
3
|
+
Hans Buehler 2023
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from .util import fmt_digits, fmt as txtfmt
|
|
7
|
+
import numpy as np
|
|
8
|
+
import numba as numba#NOQA
|
|
9
|
+
import warnings as warnings
|
|
10
|
+
|
|
11
|
+
def error( text, *args, exception = RuntimeError, **kwargs ):
    """
    Raise 'exception' with a message built by util.fmt(text, *args, **kwargs).
    This function never returns normally.
    """
    message = txtfmt(text, *args, **kwargs)
    raise exception( message )
|
|
13
|
+
def verify( cond, text, *args, exception = RuntimeError, **kwargs ):
    """
    Raise 'exception' with message util.fmt(text, *args, **kwargs) unless 'cond' is truthy.
    Returns None when 'cond' holds.
    """
    if cond:
        return
    error( text, *args, exception=exception, **kwargs )
|
|
16
|
+
def warn( text, *args, warning=RuntimeWarning, stack_level=1, **kwargs ):
    """
    Issue a warning whose message is built by util.fmt(text, *args, **kwargs).

    Parameters
    ----------
    text, *args, **kwargs : passed to util.fmt to build the message
    warning : warning category; defaults to the builtin RuntimeWarning
              (the 'warnings' module has no attribute 'RuntimeWarning')
    stack_level : stack level counted from the caller of warn(); 1 attributes the
                  warning to the line that called warn()
    """
    # warnings.warn takes 'stacklevel'; +1 skips the frame of this wrapper itself
    warnings.warn( txtfmt(text, *args, **kwargs), warning, stacklevel=stack_level+1 )
|
|
18
|
+
def warn_if( cond, text, *args, warning=RuntimeWarning, stack_level=1, **kwargs ):
    """
    Issue a warning via warn() if 'cond' is truthy; see warn() for the parameters.
    'stack_level' is counted from the caller of warn_if().
    """
    if cond:
        # +1 accounts for this additional wrapper frame
        warn( text, *args, warning=warning, stack_level=stack_level+1, **kwargs )
|
|
21
|
+
|
|
22
|
+
# One-byte type codes of the numpy dtypes supported by tofile()/fromfile(), keyed by str(dtype).
# The code is stored in the file header, so existing codes must never be changed or reused;
# new types are added with fresh codes.
# NOTE(review): str(dtype) of datetime/timedelta arrays includes the unit (e.g. 'datetime64[ns]'),
# so the two unit-less entries below are not matched by a str(array.dtype) lookup -- confirm.
dtype_map = {
    "bool"       : 0,
    "int8"       : 1,
    "int16"      : 2,
    "int32"      : 3,
    "int64"      : 4,
    "uint16"     : 5,
    "uint32"     : 6,
    "uint64"     : 7,
    "float16"    : 8,
    "float32"    : 9,
    "float64"    : 10,
    "complex64"  : 11,
    "complex128" : 12,
    "datetime64" : 13,
    "timedelta64": 14,
    "uint8"      : 15,  # fresh code: codes 0-14 keep their meaning for existing files
}
# Inverse of dtype_map: type code -> dtype name, used when reading a file header.
dtype_rev = { v:k for k,v in dtype_map.items() }
|
|
40
|
+
|
|
41
|
+
def _write_int(f,x,lbytes):
|
|
42
|
+
x = int(x).to_bytes(lbytes,"big")
|
|
43
|
+
w = f.write( x )
|
|
44
|
+
if w != len(x):
|
|
45
|
+
raise IOError(f"could only write {w} bytes, not {len(x)}.")
|
|
46
|
+
|
|
47
|
+
def _tofile(f, array : np.ndarray, dtype_map : dict ):
    """
    Write 'array' to the open binary file 'f': a header followed by the raw data in C order.

    Header layout (all integers unsigned big-endian):
        2 bytes        number of dimensions
        4 bytes each   length of each dimension
        1 byte         type code dtype_map[str(array.dtype)]

    The data is written in chunks of at most 1GB each.

    Raises
    ------
    KeyError if str(array.dtype) is not a key of 'dtype_map'.
    OverflowError if a header value does not fit into its field.
    IOError if not all bytes could be written.
    """
    array = np.asarray( array )
    dtypec = dtype_map[ str(array.dtype) ]
    shape = tuple( int(i) for i in array.shape )
    length = int(array.size)   # number of elements (np.product was removed in NumPy 2.0)
    array = np.reshape( array, (length,) )  # no copy for C-contiguous input; other layouts are copied into C order
    dsize = int(array.itemsize)
    max_size = int(1024*1024*1024//dsize)   # elements per chunk of at most 1GB

    # write shape
    _write_int( f, len(shape), 2 )
    for n in shape:
        _write_int( f, n, 4 )
    # write dtype
    _write_int( f, dtypec, 1 )

    # write object
    saved = 0
    for s in range(0, length, max_size):
        e = min(s+max_size, length)
        nw = f.write( array.data[s:e] )
        if nw != (e-s)*dsize:
            # OSError does not accept keyword arguments; build the message explicitly
            raise IOError(f"could only write {nw} of {(e-s)*dsize} bytes.")
        saved += nw
    if saved != length*dsize:
        raise IOError(f"could only write {saved} of {length*dsize} bytes.")
|
|
76
|
+
|
|
77
|
+
def tofile( file,
            array : np.ndarray, *,
            buffering : int = -1
            ):
    """
    Write 'array' into file using a binary format.
    Data is written in chunks of at most 1GB so that files exceeding 2GB can be written
    even when unbuffered (a single unbuffered write() on Linux is limited to about 2GB).
    This function will only work with the types contained in 'dtype_map'.

    Parameters
    ----------
    file : file name passed to open() or an open file handle
    array : numpy or sharedarray
    buffering : see open(). Use 0 to turn off buffering.

    Raises
    ------
    IOError if not all bytes could be written; the underlying error is chained as __cause__.
    KeyError if the dtype of 'array' is not contained in 'dtype_map'.
    """
    if isinstance(file, str):
        with open( file, "wb", buffering=buffering ) as f:
            return tofile(f, array, buffering=buffering)
    f = file
    del file

    if not array.data.contiguous:
        warn("Array is not 'contiguous'. Is that an issue??")
        array = np.ascontiguousarray( array, dtype=array.dtype )

    try:
        _tofile(f, array=array, dtype_map=dtype_map )
    except IOError as e:
        # file objects without a 'name' (e.g. io.BytesIO) must not break error reporting
        name = getattr(f, "name", repr(f))
        # chain with 'from': OSError(msg, e) would interpret 'msg' as an errno and garble the message
        raise IOError(f"Could not write all {fmt_digits(array.nbytes)} bytes to {name}: {str(e)}") from e
|
|
106
|
+
|
|
107
|
+
def _read_int(f, lbytes) -> int:
|
|
108
|
+
x = f.read(lbytes)
|
|
109
|
+
if len(x) != lbytes:
|
|
110
|
+
raise IOError(f"could only read {len(x)} bytes not {lbytes}.")
|
|
111
|
+
x = int.from_bytes(x,"big")
|
|
112
|
+
return int(x)
|
|
113
|
+
|
|
114
|
+
def _readfromfile( f, array ):
|
|
115
|
+
# split into chunks
|
|
116
|
+
shape = array.shape
|
|
117
|
+
length = int( np.product( array.shape, dtype=np.uint64 ) )
|
|
118
|
+
array = np.reshape( array, (length,) )
|
|
119
|
+
dsize = int(array.itemsize)
|
|
120
|
+
max_size = int(1024*1024*1024//dsize)
|
|
121
|
+
num = int(length-1)//max_size+1
|
|
122
|
+
read = 0
|
|
123
|
+
# read
|
|
124
|
+
for j in range(num):
|
|
125
|
+
s = j*max_size
|
|
126
|
+
e = min(s+max_size, length)
|
|
127
|
+
nr = f.readinto( array.data[s:e] )
|
|
128
|
+
if nr != (e-s)*dsize:
|
|
129
|
+
raise IOError(f"could only read {fmt_digits(nr)} of {fmt_digits((e-s)*dsize)} bytes.")
|
|
130
|
+
read += nr
|
|
131
|
+
if read != length*dsize:
|
|
132
|
+
raise IOError(f"could only read {fmt_digits(read)} of {fmt_digits(length*dsize)} bytes.")
|
|
133
|
+
return np.reshape( array, shape ) # no copy
|
|
134
|
+
|
|
135
|
+
def _readheader(f):
    """
    Read the header written by _tofile from the binary file 'f'.

    Returns
    -------
    (shape, dtype): 'shape' is a tuple of ints, 'dtype' the dtype name looked up in 'dtype_rev'.
    """
    ndim = _read_int(f, 2)
    dims = []
    for _ in range(ndim):
        dims.append( int(_read_int(f, 4)) )
    code = _read_int(f, 1)
    return tuple(dims), dtype_rev[code]
|
|
143
|
+
|
|
144
|
+
def readfromfile( file,
                  target : np.ndarray, *,
                  read_only : bool = False,
                  buffering : int = -1,
                  validate_dtype : type = None,
                  validate_shape : tuple = None
                  ) -> np.ndarray:
    """
    Read array from disk into an existing array or into a new array.
    See readinto and fromfile for a simpler interface.

    Parameters
    ----------
    file : file name passed to open(), or a file handle from open()
    target : either an array, or a function which returns an array for a given shape and dtype.
             The function is called as target(shape=shape, dtype=dtype), for example
                def create( shape, dtype ):
                    return np.empty( shape, dtype )
    read_only : whether to clear the 'writeable' flag of the array
    buffering : see open(); -1 is the default, 0 for no buffering.
    validate_dtype: if specified, check that the array on disk has this dtype.
                    Anything accepted by np.dtype() works, e.g. "float32", np.float32 or np.dtype("float32").
    validate_shape: if specified, check that the array on disk has this shape (a tuple, or a list of ints).

    Returns
    -------
    The array

    Raises
    ------
    IOError if the file is truncated, fails validation, or does not match an array 'target'.
    ValueError if a 'target' function returns None or an array of the wrong shape or dtype.
    """
    if isinstance(file, str):
        with open( file, "rb", buffering=buffering ) as f:
            return readfromfile( f, target,
                                 read_only=read_only,
                                 buffering=buffering,
                                 validate_dtype=validate_dtype,
                                 validate_shape=validate_shape )
    f = file
    del file
    # file objects without a 'name' (e.g. io.BytesIO) must not break error reporting
    name = getattr(f, "name", repr(f))

    # read shape
    shape, dtype = _readheader(f)

    if validate_dtype is not None:
        # compare as numpy dtypes so that types such as np.float32 match the stored name "float32"
        try:
            dtype_ok = np.dtype(validate_dtype) == np.dtype(dtype)
        except (TypeError, ValueError):
            dtype_ok = False   # not a valid dtype specification: treat as a mismatch
        if not dtype_ok:
            raise IOError(f"Failed to read {name}: found type {dtype} expected {validate_dtype}.")
    if validate_shape is not None:
        expected_shape = tuple(validate_shape) if isinstance(validate_shape, list) else validate_shape
        if expected_shape != shape:
            raise IOError(f"Failed to read {name}: found shape {shape} expected {validate_shape}.")

    # handle array
    if isinstance(target, np.ndarray):
        if target.shape != shape or target.dtype.base != dtype:
            raise IOError(f"File {name} read error: expected shape {target.shape}/{str(target.dtype)} but found {shape}/{str(dtype)}.")
        array = target
    else:
        array = target( shape=shape, dtype=dtype )
        # explicit checks rather than 'assert', which is stripped under 'python -O'
        if array is None:
            raise ValueError("'target' function returned None")
        if array.shape != shape or array.dtype != dtype:
            raise ValueError(f"'target' function returned wrong array; shape: {array.shape} {shape}; dtype: {array.dtype} {dtype}")
    del target

    try:
        _readfromfile(f, array)
    except IOError as e:
        # chain with 'from': OSError(msg, e) would interpret 'msg' as an errno and garble the message
        raise IOError(f"Cannot read from {name}: {str(e)}") from e
    if read_only:
        array.flags.writeable = False

    assert array.flags.writeable == (not read_only), ("Internal flag error", array.flags.writeable, read_only, not read_only )
    return array
|
|
209
|
+
|
|
210
|
+
def read_shape_dtype( file, buffering : int = -1 ) -> tuple:
    """
    Read only the header (shape and dtype) of a file written by tofile().

    Parameters
    ----------
    file : file name passed to open(), or a file handle from open()
    buffering : see open(); used only when 'file' is a file name.

    Returns
    -------
    (shape, dtype)
    """
    if not isinstance(file, str):
        return _readheader(file)
    with open( file, "rb", buffering=buffering ) as handle:
        return _readheader(handle)
|
|
226
|
+
|
|
227
|
+
def readinto( file, array : np.ndarray, *, read_only : bool = False, buffering : int = -1 ):
    """
    Read an array from disk into the existing 'array'.
    The receiving array must have the same shape and dtype as the array on disk.

    Parameters
    ----------
    file : file name passed to open(), or an open file
    array : target array to write into. It must have the same shape and dtype as the source data.
    read_only : whether to clear the 'writeable' flag of the array after the file was read
    buffering : see open(); -1 is the default, 0 for no buffering.

    Returns
    -------
    The array.
    """
    result = readfromfile( file, array, read_only=read_only, buffering=buffering )
    return result
|
|
244
|
+
|
|
245
|
+
def fromfile( file, *, validate_dtype = None, validate_shape = None, read_only : bool = False, buffering : int = -1 ) -> np.ndarray:
    """
    Read an array from disk into a newly allocated numpy array.
    Use sharedarray.shared_fromfile() to create a shared array instead.

    Parameters
    ----------
    file : file name passed to open(), or an open file
    validate_dtype: if specified, check that the array has the specified dtype
    validate_shape: if specified, check that the array has the specified shape
    read_only: if True, clears the 'writeable' flag of the returned array
    buffering : see open(); -1 is the default, 0 for no buffering.

    Returns
    -------
    Newly created numpy array
    """
    def _allocate( shape, dtype ):
        # called by readfromfile with the shape and dtype found in the file header
        return np.empty( shape=shape, dtype=dtype )

    return readfromfile( file, _allocate,
                         validate_dtype=validate_dtype,
                         validate_shape=validate_shape,
                         read_only=read_only,
                         buffering=buffering )
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
|
cdxcore/prettydict.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
"""
|
|
2
|
+
prettydict
|
|
3
|
+
Dictionaries with member syntax access
|
|
4
|
+
Hans Buehler 2022
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from collections import OrderedDict
|
|
8
|
+
from sortedcontainers import SortedDict
|
|
9
|
+
import dataclasses as dataclasses
|
|
10
|
+
from dataclasses import Field
|
|
11
|
+
import types as types
|
|
12
|
+
from collections.abc import Mapping
|
|
13
|
+
|
|
14
|
+
class PrettyDict(dict):
    """
    Dictionary whose entries can also be read, written and deleted with attribute syntax:

        pdct = PrettyDict()
        pdct.x = 1
        x = pdct.x              # same as pdct['x']

    Plain functions (and bound methods) assigned with attribute syntax are bound to the
    dictionary, so they behave like methods:

        def mult_x(self, a):
            return self.x * a
        pdct.mult_x = mult_x
        pdct.mult_x(2)          # --> 2

    To store a function without binding it, assign it with [] instead:

        def mult(a,b):
            return a*b
        pdct['mult'] = mult
        pdct.mult(1,3)          # --> 3

    IMPORTANT
    Names starting with '__' are ordinary Python attributes and never dictionary entries:

        pdct = PrettyDict()
        pdct.__x = 1
        _ = pdct['__x']         # raises KeyError

    This keeps Python's own handling of special names working.
    Reading a missing entry with attribute syntax raises KeyError, like pdct[key].
    """

    def __getattr__(self, key : str):
        """ Return self[key]; only reached when normal attribute lookup fails """
        if key.startswith("__"):
            raise AttributeError(key)   # special names are never dictionary entries
        return self[key]

    def __delattr__(self, key : str):
        """ Remove the entry 'key', i.e. del self[key] """
        if key.startswith("__"):
            raise AttributeError(key)
        del self[key]

    def __setattr__(self, key : str, value):
        """ Store 'value' as self[key], binding functions and re-binding methods to this dictionary """
        if key.startswith("__"):
            return dict.__setattr__(self, key, value)
        if isinstance(value, types.MethodType):
            value = types.MethodType(value.__func__, self)   # re-point the method to this instance
        elif isinstance(value, types.FunctionType):
            value = types.MethodType(value, self)            # bind the function to this instance
        self[key] = value

    def __call__(self, key : str, *default):
        """ Return self.get(key), or self.get(key, default) when exactly one default is given """
        if not default:
            return self.get(key)
        if len(default) == 1:
            return self.get(key, default[0])
        raise NotImplementedError("Cannot pass more than one default parameter.")

    def as_field(self) -> Field:
        """
        Return a PrettyDictField wrapper around this dictionary for use in dataclasses.
        See the PrettyDictField documentation for an example.
        """
        return PrettyDictField(self)
|
|
73
|
+
|
|
74
|
+
class PrettyOrderedDict(OrderedDict):
    """
    Ordered dictionary which allows accessing its members with member notation, e.g.
        pdct = PrettyOrderedDict()
        pdct.x = 1
        x = pdct.x

    Functions will be made members, i.e the following works as expected
        def mult_x(self, a):
            return self.x * a
        pdct.mult_x = mult_x
        pdct.mult_x(2) --> 2

    To assign a static member use []:
        def mult(a,b):
            return a*b
        pdct['mult'] = mult
        pdct.mult(1,3) --> 3

    Elements, keys and (key, element) pairs can also be accessed by ordinal position, see at_pos.

    IMPORTANT
    Attributes starting with '__' are handled as standard attributes.
    In other words,
        pdct = PrettyOrderedDict()
        pdct.__x = 1
        _ = pdct['__x'] <- throws an exception
    This allows re-use of general operator handling.
    """

    def __getattr__(self, key : str):
        """ Equivalent to self[key]; raises KeyError if 'key' is not present """
        if key[:2] == "__": raise AttributeError(key) # you cannot treat private members as dictionary members
        return self[key]
    def __delattr__(self, key : str):
        """ Equivalent to del self[key] """
        if key[:2] == "__": raise AttributeError(key) # you cannot treat private members as dictionary members
        del self[key]
    def __setattr__(self, key : str, value):
        """ Equivalent to self[key] = value; functions and methods are bound to this object """
        if key[:2] == "__":
            return OrderedDict.__setattr__(self, key, value)
        if isinstance(value,types.FunctionType):
            # bind function to this object
            value = types.MethodType(value,self)
        elif isinstance(value,types.MethodType):
            # re-point the method to the current instance
            value = types.MethodType(value.__func__,self)
        self[key] = value
    def __call__(self, key : str, *default):
        """ Equivalent of self.get(key) or self.get(key,default) """
        if len(default) > 1:
            raise NotImplementedError("Cannot pass more than one default parameter.")
        return self.get(key,default[0]) if len(default) == 1 else self.get(key)

    # pickling
    def __getstate__(self):
        """ Return the instance attribute dictionary as the state to pickle """
        return self.__dict__
    def __setstate__(self, state):
        """ Restore instance attributes from a pickled state """
        self.__dict__.update(state)

    def as_field(self) -> Field:
        """
        Returns a PrettyDictField wrapper around 'self' for use in dataclasses
        See PrettyDictField documentation for an example
        """
        return PrettyDictField(self)

    @property
    def at_pos(self):
        """
        Access by ordinal position.

        Element access
            at_pos[position] returns the element at 'position', or a list of elements if
            'position' is a slice.
            at_pos[position] = item assigns 'item' to the element at 'position'.
            If 'position' is a slice, 'item' must be an iterable with exactly as many elements
            as the slice selects; otherwise ValueError is raised and nothing is assigned.

        Key access
            at_pos.keys[position] returns the key or keys at 'position'.
            at_pos.items[position] returns the tuple (key, element), or a list thereof, for 'position'.

        The list of keys is captured the first time the returned object needs it, so do not keep
        the returned object across insertions or deletions.
        """
        class Access:
            """
            Wrapper object to allow index access for at_pos
            """
            def __init__(_):
                _.__keys = None   # list of keys, created lazily by 'keys'

            def __getitem__(_, position):
                key = _.keys[position]
                return self[key] if not isinstance(key,list) else [ self[k] for k in key ]

            def __setitem__(_, position, item ):
                key = _.keys[position]
                if not isinstance(key,list):
                    self[key] = item
                    return
                # slice assignment: sizes must match as documented; zip() would silently truncate
                items = list(item)
                if len(items) != len(key):
                    raise ValueError(f"Cannot assign {len(items)} item(s) to {len(key)} position(s).")
                for k, i in zip(key, items):
                    self[k] = i

            @property
            def keys(_) -> list:
                """ Returns the list of keys of the original dictionary """
                if _.__keys is None:
                    _.__keys = list(self.keys())
                return _.__keys

            @property
            def items(_):
                """ Returns an object whose [position] yields (key, element), or a list of such tuples for a slice """
                class ItemAccess(object):
                    def __getitem__(_x, position):
                        key = _.keys[position]
                        return (key, self[key]) if not isinstance(key,list) else [ (k,self[k]) for k in key ]
                return ItemAccess()

        return Access()

    @property
    def key_at_pos(self) -> list:
        """
        key_at_pos returns the list of keys, hence key_at_pos[i] returns the ith key
        """
        return list(self)

    @property
    def item_at_pos(self) -> list:
        """
        Returns the list of keys; currently identical to key_at_pos.
        NOTE(review): the name suggests this was meant to return elements or (key, element)
        pairs -- confirm the intended behavior before relying on it.
        """
        return list(self)
|
|
208
|
+
|
|
209
|
+
# Short alias: pdct(...) creates a PrettyOrderedDict.
pdct = PrettyOrderedDict
|
|
210
|
+
|
|
211
|
+
class BETA_PrettySortedDict(SortedDict):
    """
    *NOT WORKING WELL*
    Sorted dictionary which allows accessing its members with member notation, e.g.
        pdct = BETA_PrettySortedDict()
        pdct.x = 1
        x = pdct.x

    Functions will be made members, i.e the following works as expected
        def mult_x(self, a):
            return self.x * a
        pdct.mult_x = mult_x
        pdct.mult_x(2) --> 2

    To assign a static member use []:
        def mult(a,b):
            return a*b
        pdct['mult'] = mult
        pdct.mult(1,3) --> 3

    IMPORTANT
    Attributes starting with '_' (one underscore) are handled as standard attributes.
    In other words,
        pdct = BETA_PrettySortedDict()
        pdct._x = 1
        _ = pdct['_x'] <- throws an exception
    This allows re-use of general operator handling.
    The reason the sorted class disallows '_' (as opposed to the other two classes which merely disallow '__')
    is that SortedDict() uses protected members.

    NOTE(review): SortedDict's own constructor may also assign public (non '_') instance attributes;
    __setattr__ below would route those into the dictionary instead, which is a likely reason this
    class is marked as not working well. Confirm against the sortedcontainers version in use.
    """

    def __getattr__(self, key : str):
        """ Equivalent to self[key] """
        if key[:1] == "_": raise AttributeError(key) # you cannot treat protected or private members as dictionary members
        return self[key]
    def __delattr__(self, key : str):
        """ Equivalent to del self[key]; names starting with '_' are deleted as ordinary attributes """
        if key[:1] == "_":
            # mirror __setattr__: '_' names are stored as object attributes, not as dictionary entries
            return SortedDict.__delattr__(self, key)
        del self[key]
    def __setattr__(self, key : str, value):
        """ Equivalent to self[key] = value; functions and methods are bound to this object """
        if key[:1] == "_":
            return SortedDict.__setattr__(self, key, value)
        if isinstance(value,types.FunctionType):
            # bind function to this object
            value = types.MethodType(value,self)
        elif isinstance(value,types.MethodType):
            # re-point the method to the current instance
            value = types.MethodType(value.__func__,self)
        self[key] = value
    def __call__(self, key : str, *default):
        """ Equivalent of self.get(key) or self.get(key,default) """
        if len(default) > 1:
            raise NotImplementedError("Cannot pass more than one default parameter.")
        return self.get(key,default[0]) if len(default) == 1 else self.get(key)
|
|
266
|
+
|
|
267
|
+
class PrettyDictField(object):
    """
    Simplistic 'read only' wrapper for PrettyOrderedDict objects, for use as a dataclass field.
    Useful for Flax. Example (assumes flax.linen is imported as nn):

        import jax.numpy as jnp
        import jax as jax

        class A( nn.Module ):
            pdct : PrettyDictField = PrettyDictField.Field()

            def setup(self):
                self.dense = nn.Dense(1)

            def __call__(self, x):
                a = self.pdct.a
                return self.dense(x)*a

        r = PrettyOrderedDict(a=1.)
        a = A( r.as_field() )

        key1, key2 = jax.random.split(jax.random.key(0))
        x = jnp.zeros((10,10))
        param = a.init( key1, x )
        y = a.apply( param, x )

    The class will traverse pretty dictionaries of pretty dictionaries correctly.
    However, it has some limitations as it does not handle custom lists of pretty dicts.
    """
    def __init__(self, pdct : PrettyOrderedDict = None, **kwargs):
        """
        Initialize from an optional input dictionary plus keyword overwrites.

        Parameters
        ----------
        pdct : a Mapping, another PrettyDictField, or None. Its contents are (shallow) copied.
        kwargs : additional entries which overwrite those of 'pdct'.

        Nested PrettyDict/PrettyOrderedDict values are wrapped as PrettyDictField. Nested plain
        mappings containing such values are copied before being modified, so the caller's
        objects are never altered.
        """
        import copy

        if pdct is None:
            self.__pdct = PrettyOrderedDict(**kwargs)
        elif type(pdct).__name__ == type(self).__name__:
            # another PrettyDictField: its nested values are already wrapped
            self.__pdct = PrettyOrderedDict( pdct.__pdct )
            if len(kwargs) == 0:
                return
            self.__pdct.update(kwargs)
        else:
            if not isinstance(pdct, Mapping): raise ValueError("'pdct' must be a Mapping")
            self.__pdct = PrettyOrderedDict(pdct)
            self.__pdct.update(kwargs)

        def wrap(value):
            """ Return 'value' with nested pretty dictionaries wrapped; mappings are copied only if they change """
            if isinstance(value, (PrettyDict, PrettyOrderedDict)):
                return PrettyDictField(value)
            if not isinstance(value, Mapping):
                return value
            result = value
            for k, v in value.items():
                w = wrap(v)
                if w is v:
                    continue
                if result is value:
                    result = copy.copy(value)   # never modify the caller's mapping
                result[k] = w
            return result

        # the top level is our own copy and may be modified in place
        for k, v in list(self.__pdct.items()):
            self.__pdct[k] = wrap(v)

    def as_dict(self) -> PrettyOrderedDict:
        """ Return copy of underlying dictionary """
        return PrettyOrderedDict( self.__pdct )

    def as_field(self) -> Field:
        """
        Returns a PrettyDictField wrapper around self for use in dataclasses.
        This function makes a (shallow enough) copy of the current field.
        It is present so iterative applications of as_field() are convenient.
        """
        return PrettyDictField(self)

    @staticmethod
    def default():
        """ Returns an empty PrettyDictField """
        return PrettyDictField()

    @staticmethod
    def Field( default : PrettyOrderedDict = None, **kwargs):
        """
        Returns a dataclasses.field for PrettyDictField.

        Each dataclass instance receives a new PrettyDictField built from a copy of 'default'
        (a mapping, a PrettyDictField, or None) updated with 'kwargs'.
        'default' itself is never modified.
        """
        if default is None and len(kwargs) == 0:
            return dataclasses.field( default_factory=PrettyDictField )

        def factory():
            values = PrettyDictField(default).as_dict()   # copy of 'default'; empty if None
            values.update(kwargs)
            return PrettyDictField(values)
        return dataclasses.field( default_factory=factory )

    # mimic the underlying dictionary
    # -------------------------------

    def __getattr__(self, key):
        """ Attribute access to the entries of the wrapped dictionary """
        if key[:2] == "__" or key == "_PrettyDictField__pdct":
            # Special names are never entries. The mangled name of the wrapped dictionary only
            # reaches this point while it is not set (e.g. on an instance created without
            # __init__); without this guard the lookup below would recurse forever.
            raise AttributeError(key)
        return self.__pdct.__getattr__(key)
    def __getitem__(self, key):
        return self.__pdct[key]
    def __call__(self, *kargs, **kwargs):
        return self.__pdct(*kargs, **kwargs)
    def __eq__(self, other):
        """ Compare the wrapped dictionary with another PrettyDictField or with a mapping """
        if type(other).__name__ == type(self).__name__:
            return self.__pdct == other.__pdct
        if isinstance(other, Mapping):
            return self.__pdct == other
        return NotImplemented
    def keys(self):
        return self.__pdct.keys()
    def items(self):
        return self.__pdct.items()
    def values(self):
        return self.__pdct.values()
    def __hash__(self):
        # order-independent XOR over key and value hashes; requires hashable values
        h = 0
        for k, v in self.items():
            h ^= hash(k) ^ hash(v)
        return h
    def __iter__(self):
        return self.__pdct.__iter__()
    def __contains__(self, key):
        return self.__pdct.__contains__(key)
    def __len__(self):
        return self.__pdct.__len__()
    def __str__(self):
        return self.__pdct.__str__()
    def __repr__(self):
        return self.__pdct.__repr__()
|
|
388
|
+
|