ato 2.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ato/__init__.py ADDED
@@ -0,0 +1 @@
1
# Package version string exposed as ``ato.__version__``.
# NOTE(review): the distribution this file ships in is named "ato 2.0.4", but
# this string reads '1.12.0' — confirm which one is authoritative before
# relying on ``__version__`` at runtime.
__version__ = '1.12.0'
ato/adict.py ADDED
@@ -0,0 +1,582 @@
1
+ import hashlib
2
+ import importlib.util
3
+ import json
4
+ import sys
5
+ import types
6
+ import warnings
7
+ from collections.abc import MutableMapping as GenericMapping
8
+
9
+ import toml
10
+ import yaml
11
+ import os
12
+ from copy import deepcopy as dcp
13
+ from functools import wraps
14
+ from types import MappingProxyType
15
+ from typing import Mapping, MutableMapping, Callable
16
+
17
+ from ato import xyz
18
+
19
# Config-file extensions tried, in this order, when ADict.compose_hierarchy
# looks for ``<root>/<name><ext>``.
ALLOWED_EXTS = (
    '.yaml',
    '.yml',
    '.json',
    '.toml',
    '.xyz',
)
20
+
21
+
22
# decorate internal methods in ADict
def mutate_attribute(fn):
    """Allow ``fn`` to assign real instance attributes on its first argument.

    While the wrapped call runs, ``ctx._mutate_attribute`` is True. ADict's
    ``__setattr__``/``__delattr__`` check this flag to decide between touching
    the instance itself and touching its stored items.

    The previous flag value is restored when the call finishes, even if ``fn``
    raises. A failing call therefore cannot leave the flag stuck at True, where
    later ``obj.key = value`` would silently become a real attribute. A nested
    decorated call also cannot switch the flag off while an outer one is still
    running.
    """
    @wraps(fn)
    def decorator(*args, **kwargs):
        ctx = args[0]
        try:
            # object.__getattribute__ never falls back to __getattr__. This
            # matters because ADict.__getattr__ would recurse on a
            # half-initialised instance that has no _data yet.
            previous = object.__getattribute__(ctx, '_mutate_attribute')
        except AttributeError:
            previous = False
        object.__setattr__(ctx, '_mutate_attribute', True)
        try:
            return fn(*args, **kwargs)
        finally:
            object.__setattr__(ctx, '_mutate_attribute', previous)
    return decorator
32
+
33
+
34
class Dict(GenericMapping):
    """Mutable mapping that keeps its items in a plain ``dict`` at ``self._data``.

    It is structured like ``collections.UserDict`` (which uses ``.data``).
    Subclasses customise behaviour by overriding the item dunders. The
    dict-only conveniences (``|``, ``|=``, ``copy``, ``fromkeys``) are
    provided here.
    """

    def __init__(self, mapping=None, /, **kwargs):
        """Create the backing dict, then add ``mapping`` and ``kwargs`` via ``update``."""
        self._data = dict()
        if mapping is not None:
            self.update(mapping)
        if kwargs:
            self.update(kwargs)

    def __len__(self):
        return len(self._data)

    def __getitem__(self, key):
        """Return the stored value, falling back to a subclass ``__missing__`` hook."""
        if key in self._data:
            return self._data[key]
        if hasattr(self.__class__, "__missing__"):
            return self.__class__.__missing__(self, key)
        raise KeyError(key)

    def __setitem__(self, key, item):
        self._data[key] = item

    def __delitem__(self, key):
        del self._data[key]

    def __iter__(self):
        return iter(self._data)

    # Membership checks stored keys only, so a __missing__ hook does not make
    # every key look present.
    def __contains__(self, key):
        return key in self._data

    # The methods below exist on dict but are not supplied by MutableMapping.
    def __repr__(self):
        return repr(self._data)

    def __or__(self, other):
        """Return a new instance holding ``self`` merged with ``other`` (other wins)."""
        if isinstance(other, Dict):
            return self.__class__(self._data | other._data)
        if isinstance(other, dict):
            return self.__class__(self._data | other)
        return NotImplemented

    def __ror__(self, other):
        """Return a new instance holding ``other`` merged with ``self`` (self wins)."""
        if isinstance(other, Dict):
            return self.__class__(other._data | self._data)
        if isinstance(other, dict):
            return self.__class__(other | self._data)
        return NotImplemented

    def __ior__(self, other):
        """Merge ``other`` into the backing dict in place (bypasses ``__setitem__``)."""
        if isinstance(other, Dict):
            self._data |= other._data
        else:
            self._data |= other
        return self

    def __copy__(self):
        """Shallow copy: a new ``_data`` dict with the same values.

        Other instance attributes are shared by reference.
        """
        inst = self.__class__.__new__(self.__class__)
        inst.__dict__.update(self.__dict__)
        # Write through __dict__ so descriptors / subclass __setattr__ are not triggered.
        inst.__dict__["_data"] = self.__dict__["_data"].copy()
        return inst

    def copy(self):
        """Return a shallow copy of this mapping (same class, same values)."""
        if self.__class__ is Dict:
            return Dict(self._data.copy())
        import copy
        # Do not temporarily reassign ``self._data`` the way UserDict.copy does.
        # Subclasses that route attribute assignment to items (ADict) would
        # store that assignment as a '_data' key on the original.
        # __copy__ already duplicates _data without touching self.
        return copy.copy(self)

    @classmethod
    def fromkeys(cls, iterable, value=None):
        """Build an instance mapping every key in ``iterable`` to ``value``."""
        d = cls()
        for key in iterable:
            d[key] = value
        return d
116
+
117
+
118
class ADict(Dict):
    """Attribute-accessible, optionally freezable configuration mapping.

    Items can be read and written as ``cfg['lr']`` or as ``cfg.lr``. Mapping
    values, including those inside lists and tuples, are converted to this
    class on assignment, so attribute access works at every nesting level.

    * ``default=`` (constructor keyword): a value, or a zero-argument callable,
      used to create missing keys on read while the dict is not frozen.
    * ``freeze()`` / ``defrost()``: while frozen, item writes and deletes are
      silently ignored and reads return deep copies.
    * ``accessed_keys``: the keys read through ``__getitem__``; used by
      :meth:`get_minimal_config`.

    Real instance attributes (``_data``, ``_frozen``, ``_default``, ...) are
    assigned only inside methods wrapped by ``mutate_attribute``. Any other
    attribute assignment becomes an item assignment.
    """

    @mutate_attribute
    def __init__(self, *args, **kwargs):
        """Merge positional mappings (left to right) and keyword items.

        Args:
            *args: Mappings, or anything ``dict()`` accepts (e.g. pairs).
            **kwargs: Extra items. The keyword ``default`` is reserved: it is
                removed and used as the missing-key default, not stored.

        Raises:
            TypeError: If a positional argument cannot be turned into a dict.
        """
        if 'default' in kwargs:
            self._default = kwargs.pop('default')
            self._is_default_defined = True
        else:
            self._default = None
            self._is_default_defined = False
        mappings = dict()
        for mapping in args:
            if not isinstance(mapping, (dict, MutableMapping)):
                try:
                    mapping = dict(mapping)
                except (AttributeError, TypeError, ValueError):
                    raise TypeError(
                        f'Any of positional arguments must be able to converted to key-value type, '
                        f'but {mapping} is not.'
                    )
            mappings.update(mapping)
        self._frozen = False
        self._accessed_keys = set()
        super().__init__(mappings, **kwargs)

    @property
    def frozen(self):
        """True while writes/deletes are ignored and reads return deep copies."""
        return self._frozen

    @property
    def accessed_keys(self):
        """The live set of keys read through ``__getitem__`` (not a copy)."""
        return self._accessed_keys

    def get_minimal_config(self):
        """Return a new ADict holding only the keys that have been read.

        Nested values of this class are reduced recursively in the same way.
        """
        new_config = ADict()
        for key in self.accessed_keys:
            value = self.__getitem__(key)
            if isinstance(value, self.__class__):
                new_config[key] = value.get_minimal_config()
            else:
                new_config[key] = value
        return new_config

    def __getitem__(self, names):
        """Look up one key (a str) or several keys (any other iterable of keys).

        A missing str key is created from the default when one is defined and
        the dict is not frozen. Otherwise KeyError is raised. Every key read is
        recorded in ``accessed_keys``. While frozen, the result is a deep copy.
        """
        if isinstance(names, str):
            if names in self._data:
                value = self._data[names]
            elif self._is_default_defined and not self.frozen:
                value = self.get_default()
                self._data[names] = value
            else:
                raise KeyError(f'The key "{names}" does not exist.')
            self._accessed_keys.add(names)
        else:
            value = [self.__getitem__(name) for name in names]
            self._accessed_keys.update(names)
        if self.frozen:
            value = dcp(value)
        return value

    def __setitem__(self, names, values):
        """Store ``values`` under ``names``; this is silently a no-op while frozen.

        ``names`` is a single str key or an iterable of keys. With several keys,
        a list/tuple of ``values`` is zipped against them. Any other value is
        assigned to every key. Mappings (also inside lists/tuples) are
        converted to this class.
        """
        if self.frozen:
            return
        if isinstance(names, str):
            if isinstance(values, Mapping):
                # Pass the mapping positionally, not as **values. Otherwise a
                # nested key literally named 'default' would be swallowed as
                # the child's default instead of being stored.
                values = self.__class__(values)
            elif isinstance(values, (list, tuple)):
                values = [self.__class__(value) if isinstance(value, Mapping) else value for value in values]
            super().__setitem__(names, values)
        elif isinstance(values, (list, tuple)):
            for name, value in zip(names, values):
                self.__setitem__(name, value)
        else:
            for name in names:
                self.__setitem__(name, values)

    def __getattr__(self, name):
        """Fall back to item lookup when normal attribute lookup fails.

        Raises:
            AttributeError: If the item lookup raises KeyError.
        """
        try:
            return object.__getattribute__(self, name)
        except AttributeError:
            try:
                return self.__getitem__(name)
            except KeyError:
                raise AttributeError(name)

    def __setattr__(self, name, value):
        """Set a real attribute inside ``mutate_attribute`` methods, otherwise an item."""
        if self._mutate_attribute:
            object.__setattr__(self, name, value)
        else:
            self.__setitem__(name, value)

    def __delattr__(self, names):
        """Delete a real attribute inside ``mutate_attribute`` methods, otherwise item(s)."""
        if self._mutate_attribute:
            object.__delattr__(self, names)
        elif isinstance(names, str):
            self.__delitem__(names)
        else:
            for name in names:
                self.__delitem__(name)

    def __deepcopy__(self, memo=None):
        """Deep-copy the items and the default into a new, unfrozen instance."""
        # Forward memo so shared/cyclic references inside the data are kept.
        mappings = dcp(self._data, memo)
        kwargs = dict()
        if self._is_default_defined:
            kwargs.update(default=self._default)
        return self.__class__(mappings, **kwargs)

    def __getstate__(self):
        """Pickle state: a deep copy of the instance dict minus the transient flag."""
        state = dcp(self.__dict__)
        state.pop('_mutate_attribute')
        return state

    @mutate_attribute
    def __setstate__(self, state):
        """Restore the attributes produced by ``__getstate__``."""
        self._data = state.pop('_data')
        for k, v in state.items():
            object.__setattr__(self, k, v)

    @mutate_attribute
    def set_default(self, default=None):
        """Define ``default`` (a value or zero-arg callable) for missing keys.

        This runs under ``mutate_attribute`` so the two fields are stored as
        attributes, not as config items.
        """
        self._default = default
        self._is_default_defined = True

    @mutate_attribute
    def remove_default(self):
        """Stop creating missing keys; later reads of absent keys raise KeyError."""
        self._is_default_defined = False
        self._default = None

    def get_default(self):
        """Return a fresh default: the callable's result, or a deep copy of the value.

        Raises:
            ValueError: If no default is defined.
        """
        if self._is_default_defined:
            _default = object.__getattribute__(self, '_default')
            if callable(_default):
                return _default()
            else:
                return dcp(_default)
        else:
            raise ValueError('Default value is not defined.')

    def get(self, name, default=None):
        """Return ``self[name]`` if present; otherwise fall back.

        When a default is defined, the fallback is a fresh default, which is
        not stored; KeyError is raised instead if the dict is frozen. Without
        a default, the fallback is ``default``.
        """
        if name in self:
            return self.__getitem__(name)
        elif self._is_default_defined:
            if self._frozen:
                raise KeyError(f'The key "{name}" does not exist.')
            return self.get_default()
        else:
            return default

    def __delitem__(self, key):
        """Delete ``key``; this is silently a no-op while frozen."""
        if not self.frozen:
            super().__delitem__(key)

    def pop(self, name, default=None):
        """Return ``self.get(name, default)`` and delete ``name`` if present."""
        value = self.get(name, default)
        if name in self:
            self.__delitem__(name)
        return value

    def raw(self, name, default=None):
        """Return a new instance ``{'key': name, 'value': self.get(name, default)}``."""
        value = self.get(name, default)
        item = self.__class__(key=name, value=value)
        return item

    @mutate_attribute
    def filter(self, fn: Callable):
        """Keep only items where ``fn(key, value)`` is truthy (in place; ignores frozen)."""
        data = dict()
        for key, value in self.items():
            if fn(key, value):
                data[key] = value
        self._data = data

    def get_value_by_name(self, name):
        """Follow a dotted path such as ``'model.optimizer.lr'`` through nested mappings."""
        keys = name.split('.')
        value = self._data
        for key in keys:
            value = value[key]
        return value

    @mutate_attribute
    def freeze(self):
        """Make the dict read-only (writes ignored, reads deep-copied); returns self."""
        self._frozen = True
        return self

    @mutate_attribute
    def defrost(self):
        """Undo :meth:`freeze`; returns self."""
        self._frozen = False
        return self

    @mutate_attribute
    def update(self, __m=None, recurrent=False, **kwargs):
        """Assign items from ``__m`` and then from ``kwargs``; no-op while frozen.

        Args:
            __m: Optional mapping (or iterable of pairs), applied first.
            recurrent: When True, a mapping value whose key already holds a
                mapping is merged into it recursively instead of replacing it.
            **kwargs: Further items, applied after ``__m``.

        Returns:
            self.
        """
        if self.frozen:
            return self
        # Iterate over pairs rather than expanding **__m. Otherwise a config key
        # named 'recurrent' would collide with this method's own parameter.
        pairs = list(dict(__m).items()) if __m is not None else []
        pairs.extend(kwargs.items())
        for key, value in pairs:
            if recurrent and isinstance(value, Mapping) and key in self:
                target = self.__getitem__(key)
                if isinstance(target, ADict):
                    target.update(value, recurrent=True)
                    continue
                if isinstance(target, Mapping):
                    # A plain mapping (e.g. stored by load()): merge into a converted copy.
                    value = self.__class__(target).update(value, recurrent=True)
                # A non-mapping target is simply replaced.
            self.__setitem__(key, value)
        return self

    @mutate_attribute
    def update_if_absent(self, __m=None, recurrent=False, **kwargs):
        """Add items whose keys are missing, never overwriting; no-op while frozen.

        If a key already holds a mapping and the new value is also a mapping,
        the missing nested keys are always filled in recursively.
        ``recurrent`` is accepted for symmetry with :meth:`update` and does not
        change this.

        Returns:
            self.
        """
        if self.frozen:
            return self
        pairs = list(dict(__m).items()) if __m is not None else []
        pairs.extend(kwargs.items())
        for key, value in pairs:
            if key not in self:
                self.__setitem__(key, value)
            elif isinstance(value, Mapping):
                target = self.__getitem__(key)
                if isinstance(target, ADict):
                    target.update_if_absent(value, recurrent=True)
                elif isinstance(target, Mapping):
                    self.__setitem__(key, self.__class__(target).update_if_absent(value, recurrent=True))
        return self

    def get_structural_mapping(self, key, value):
        """Flatten ``value`` into ``(dotted_path, type_name)`` pairs for its leaves."""
        if key is None:
            key = ""
        items = []
        if isinstance(value, (MutableMapping, dict)):
            for k, v in value.items():
                concat = k if key == '' else f'{key}.{k}'
                items += self.get_structural_mapping(concat, v)
        else:
            return [(key, type(value).__name__)]
        return items

    def get_structural_repr(self):
        """Return a flat instance mapping each dotted leaf path to its type name."""
        structural_repr = self.__class__()
        for k, v in self.get_structural_mapping("", self):
            structural_repr[k] = v
        return structural_repr

    def get_structural_hash(self):
        """SHA-1 hex digest of the sorted ``[path:type]`` leaves (values are ignored)."""
        structural_repr = self.get_structural_repr()
        structural_repr = list(structural_repr.items())
        structural_repr.sort(key=lambda x: x[0])
        structural_hash = ''
        for k, v in structural_repr:
            structural_hash += f'[{k}:{v}]'
        return str(dcp(hashlib.sha1(structural_hash.encode('utf-8')).hexdigest()))

    @mutate_attribute
    def convert_to_immutable(self):
        """Wrap ``_data`` in a read-only MappingProxyType view."""
        self._data = MappingProxyType(self._data)

    @mutate_attribute
    def json(self):
        """Return ``to_dict()`` serialised with ``json.dumps``."""
        return json.dumps(self.to_dict())

    def clone(self):
        """Deep copy via ``__deepcopy__``; the clone is not frozen."""
        return dcp(self)

    def to_dict(self):
        """Recursively convert to plain dicts (tuples of values become lists)."""
        data = dict()
        for key, value in self.items():
            if isinstance(value, Mapping):
                # Positional construction keeps a nested 'default' key intact.
                data[key] = self.__class__(value).to_dict()
            elif isinstance(value, (list, tuple)):
                data[key] = [self.__class__(item).to_dict() if isinstance(item, Mapping) else item for item in value]
            else:
                data[key] = value
        return data

    def to_xyz(self, format_dict=None):
        """Serialise ``to_dict()`` with the project's ``xyz.dumps``."""
        return xyz.dumps(self.to_dict(), format_dict=format_dict)

    @classmethod
    def from_file(cls, path):
        """Load a config file into a new instance (or a list of them).

        Supported: .yaml/.yml, .toml, .json (an object, or an array of objects),
        .jsonl, .xyz, and .py (see :meth:`compile_from_file`).

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            ValueError: If the extension is not supported.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f'{path} does not exist.')
        ext = os.path.splitext(path)[1].lower()
        if ext in ('.yml', '.yaml'):
            # NOTE: FullLoader is not the safe loader; do not use on untrusted files.
            with open(path, 'rb') as f:
                obj = yaml.load(f, Loader=yaml.FullLoader)
            # An empty YAML document loads as None; treat it as an empty config.
            return cls(obj if obj is not None else {})
        if ext == '.toml':
            with open(path, 'r') as f:
                return cls(toml.load(f))
        if ext == '.json':
            with open(path, 'r') as f:
                obj = json.load(f)
            if isinstance(obj, list):
                return [cls(item) for item in obj]
            return cls(obj)
        if ext == '.jsonl':
            # JSON Lines holds one JSON value per line, which json.load cannot
            # parse. A file holding a single JSON array is also accepted.
            with open(path, 'r') as f:
                text = f.read()
            try:
                records = json.loads(text)
            except json.JSONDecodeError:
                records = None
            if not isinstance(records, list):
                records = [json.loads(line) for line in text.splitlines() if line.strip()]
            return [cls(item) for item in records]
        if ext == '.xyz':
            obj = xyz.load(path)
            if isinstance(obj, list):
                return [cls(item) for item in obj]
            return cls(obj)
        if ext == '.py':
            obj = cls.compile_from_file(path)
            if isinstance(obj, list):
                return [cls(item) for item in obj]
            return cls(obj)
        raise ValueError(f'{ext} is not a valid file extension.')

    @classmethod
    def compile_from_file(cls, path):
        """Execute a Python config file and collect its top-level names.

        Names starting with ``__`` are skipped, as are modules and plain
        functions. Classes and all other values are kept.
        NOTE: this runs arbitrary code from ``path``; only use trusted files.
        """
        config_name = os.path.splitext(os.path.basename(path))[0]
        spec = importlib.util.spec_from_file_location(config_name, path)
        config_module = importlib.util.module_from_spec(spec)
        # Register the module before executing it, as the importlib recipe does,
        # so code in the config can find its own module in sys.modules.
        # Afterwards, restore any entry that existed before, so a config named
        # e.g. ``json.py`` cannot evict the real module.
        missing = object()
        previous = sys.modules.get(config_name, missing)
        sys.modules[config_name] = config_module
        try:
            spec.loader.exec_module(config_module)
        finally:
            if previous is missing:
                sys.modules.pop(config_name, None)
            else:
                sys.modules[config_name] = previous
        config = {
            name: value
            for name, value in config_module.__dict__.items()
            if not name.startswith('__') and not isinstance(value, (types.ModuleType, types.FunctionType))
        }
        # Positional, so a top-level variable called ``default`` stays a key.
        return cls(config)

    def mm_like_update(self, **kwargs):
        """Merge ``kwargs`` into self, honouring a ``'_delete_'`` marker.

        A mapping value is merged recursively into an existing mapping, unless
        it contains a ``'_delete_'`` key. In that case the marker is removed
        from the value (in place) and the value replaces the existing one.
        """
        for key, value in kwargs.items():
            if isinstance(value, MutableMapping):
                recurrent = '_delete_' not in value
                if not recurrent:
                    del value['_delete_']
                if key in self and isinstance(self[key], MutableMapping) and recurrent:
                    self[key].mm_like_update(**value)
                else:
                    self[key] = value
            else:
                self[key] = value

    @classmethod
    def from_mm_config(cls, path):
        """Load ``path`` on top of the files listed in its ``'_base_'`` key.

        ``'_base_'`` may be a str or a list. Each base path is used as given if
        it exists; otherwise it is resolved against the directory of ``path``.
        Bases are loaded recursively and merged in order, then the file's own
        items are merged on top.
        """
        config = cls.from_file(path)
        mm_like_config = cls()
        if '_base_' in config:
            base_paths = config.pop('_base_')
            if isinstance(base_paths, str):
                base_paths = [base_paths]
            base_paths = [
                os.path.join(os.path.dirname(os.path.realpath(path)), base_path)
                if not os.path.exists(base_path) else base_path
                for base_path in base_paths
            ]
            base_configs = [cls.from_mm_config(base_path) for base_path in base_paths]
        else:
            base_configs = []
        for base_config in base_configs:
            mm_like_config.mm_like_update(**base_config)
        mm_like_config.mm_like_update(**config)
        return mm_like_config

    @mutate_attribute
    def load(self, path, **kwargs):
        """Replace ``_data`` in place with the file's parsed content.

        Nested dicts are stored as parsed (not converted) and the frozen flag is
        not checked. ``kwargs`` are forwarded to ``json.load`` only.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            ValueError: If the extension is not supported.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f'{path} does not exist.')
        ext = os.path.splitext(path)[1].lower()
        if ext in ('.yml', '.yaml'):
            with open(path, 'rb') as f:
                data = yaml.load(f, Loader=yaml.FullLoader)
            # An empty YAML file yields None; keep _data a dict.
            self._data = data if data is not None else {}
        elif ext == '.toml':
            # Mirrors from_file()/dump(), which already support TOML.
            with open(path, 'r') as f:
                self._data = toml.load(f)
        elif ext == '.json':
            with open(path, 'r') as f:
                self._data = json.load(f, **kwargs)
        elif ext == '.xyz':
            self._data = xyz.load(path)
        elif ext == '.py':
            self._data = self.compile_from_file(path).to_dict()
        else:
            raise ValueError(f'{ext} is not a valid file extension.')

    @mutate_attribute
    def load_mm_config(self, path):
        """Update self (non-recursively) with :meth:`from_mm_config` of ``path``."""
        self.update(self.from_mm_config(path).to_dict())

    @classmethod
    def compose_hierarchy(
        cls,
        root,
        config_filename='config',
        select=None,
        overrides=None,
        on_missing='error',
        required=None
    ):
        """Compose a config from a base file, selected group files and overrides.

        Args:
            root: Directory containing the config files.
            config_filename: Base file name without extension. The first
                existing ``<root>/<config_filename><ext>`` is loaded, trying
                the extensions in ALLOWED_EXTS order.
            select: ``{group: option or list of options}``. Each option loads
                ``<root>/<group>/<option><ext>`` and shallow-updates the config.
            overrides: ``{'dotted.key': value}`` assignments, applied last.
                Missing or non-ADict intermediate levels are replaced by ADicts.
            on_missing: 'error' raises for a missing file, 'warn' emits a
                warning, and any other value ignores it.
            required: Dotted keys that must exist in the final config.

        Raises:
            FileNotFoundError: A file is missing and ``on_missing == 'error'``.
            KeyError: A required key is absent.
        """
        select = select or {}
        loaded_paths = []  # (label, path) of each loaded file; collected but not returned
        config = cls()

        def _load_first(_base_path, *, label):
            # Load the first existing <root>/<_base_path><ext>; return True if none exists.
            for _ext in ALLOWED_EXTS:
                _config_path = os.path.join(root, _base_path + _ext)
                if os.path.exists(_config_path):
                    config.update(cls.from_file(_config_path))
                    loaded_paths.append((label, _config_path))
                    return False
            return True

        # _load_first returns True when the file is MISSING.
        if _load_first(config_filename, label='base'):
            if on_missing == 'error':
                raise FileNotFoundError(f'{config_filename}.[yaml|yml|json|toml|xyz] not found in {root}')
            elif on_missing == 'warn':
                warnings.warn(f'missing: {config_filename}')
        for group, options in select.items():
            options = options if isinstance(options, (list, tuple)) else [options]
            for option in options:
                is_missing = _load_first(os.path.join(group, str(option)), label=f'{group}:{option}')
                if is_missing:
                    if on_missing == 'error':
                        raise FileNotFoundError(f'{group}/{option} not found')
                    elif on_missing == 'warn':
                        warnings.warn(f'missing: {group}/{option}')
        for key, value in (overrides or {}).items():
            if isinstance(value, Mapping):
                value = ADict(value)
            current_config = config
            sub_keys = key.split('.')
            for sub_key in sub_keys[:-1]:
                if sub_key not in current_config or not isinstance(current_config[sub_key], ADict):
                    current_config[sub_key] = ADict()
                current_config = current_config[sub_key]
            current_config[sub_keys[-1]] = value
        for is_required in (required or []):
            current_config = config
            for key in is_required.split('.'):
                # A non-mapping intermediate value cannot contain the next key.
                # `in` on a str would do a substring match, and indexing would fail.
                if not isinstance(current_config, Mapping) or key not in current_config:
                    raise KeyError(f"Missing required key: {is_required}")
                current_config = current_config[key]
        return config

    def dump(self, path, **kwargs):
        """Write ``to_dict()`` to ``path`` (yaml/yml/toml/json/xyz), creating parent dirs.

        Returns whatever the underlying dump function returns.

        Raises:
            ValueError: If the extension is not supported.
        """
        dir_path = os.path.dirname(os.path.realpath(path))
        os.makedirs(dir_path, exist_ok=True)
        ext = os.path.splitext(path)[1].lower()
        if ext in ('.yml', '.yaml'):
            with open(path, 'w') as f:
                return yaml.dump(self.to_dict(), f, Dumper=yaml.Dumper, **kwargs)
        elif ext == '.toml':
            with open(path, 'w') as f:
                return toml.dump(self.to_dict(), f, **kwargs)
        elif ext == '.json':
            with open(path, 'w') as f:
                return json.dump(self.to_dict(), f, **kwargs)
        elif ext == '.xyz':
            return xyz.dump(self.to_dict(), path, **kwargs)
        else:
            raise ValueError(f'{ext} is not a valid file extension.')

    def replace_keys(self, src_keys, tgt_keys):
        """Move the values of ``src_keys`` to ``tgt_keys`` (matched pairwise).

        Raises:
            IndexError: If the two key lists differ in length.
            KeyError: If a source key does not exist.
        """
        if len(src_keys) != len(tgt_keys):
            raise IndexError(f'Source and target keys cannot be mapped: {len(src_keys)} != {len(tgt_keys)}')
        for src_key in src_keys:
            if src_key not in self._data:
                raise KeyError(f'The key {src_key} does not exist.')
        if self.frozen:
            # Writes are ignored while frozen, so popping first would drop the values.
            return
        self.__setitem__(tgt_keys, [self._data.pop(src_key) for src_key in src_keys])

    @classmethod
    def auto(cls):
        """Return an instance whose missing keys become nested ``auto()`` dicts (``cfg.a.b = 1``)."""
        return cls(default=lambda: cls.auto())
581
+
582
+
@@ -0,0 +1,8 @@
1
class BaseLogger:
    """Minimal base that holds the ``config`` object it was built with."""

    def __init__(self, config):
        # Stored by reference, with no copy or validation.
        self.config = config
4
+
5
+
6
class BaseFinder:
    """Minimal base that holds the ``config`` object it was built with."""

    def __init__(self, config):
        # Stored by reference, with no copy or validation.
        self.config = config
File without changes