python-misc-utils 0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. py_misc_utils/__init__.py +0 -0
  2. py_misc_utils/abs_timeout.py +12 -0
  3. py_misc_utils/alog.py +311 -0
  4. py_misc_utils/app_main.py +179 -0
  5. py_misc_utils/archive_streamer.py +112 -0
  6. py_misc_utils/assert_checks.py +118 -0
  7. py_misc_utils/ast_utils.py +121 -0
  8. py_misc_utils/async_manager.py +189 -0
  9. py_misc_utils/break_control.py +63 -0
  10. py_misc_utils/buffered_iterator.py +35 -0
  11. py_misc_utils/cached_file.py +507 -0
  12. py_misc_utils/call_limiter.py +26 -0
  13. py_misc_utils/call_result_selector.py +13 -0
  14. py_misc_utils/cleanups.py +85 -0
  15. py_misc_utils/cmd.py +97 -0
  16. py_misc_utils/compression.py +116 -0
  17. py_misc_utils/cond_waiter.py +13 -0
  18. py_misc_utils/context_base.py +18 -0
  19. py_misc_utils/context_managers.py +67 -0
  20. py_misc_utils/core_utils.py +577 -0
  21. py_misc_utils/daemon_process.py +252 -0
  22. py_misc_utils/data_cache.py +46 -0
  23. py_misc_utils/date_utils.py +90 -0
  24. py_misc_utils/debug.py +24 -0
  25. py_misc_utils/dyn_modules.py +50 -0
  26. py_misc_utils/dynamod.py +103 -0
  27. py_misc_utils/env_config.py +35 -0
  28. py_misc_utils/executor.py +239 -0
  29. py_misc_utils/file_overwrite.py +29 -0
  30. py_misc_utils/fin_wrap.py +77 -0
  31. py_misc_utils/fp_utils.py +47 -0
  32. py_misc_utils/fs/__init__.py +0 -0
  33. py_misc_utils/fs/file_fs.py +127 -0
  34. py_misc_utils/fs/ftp_fs.py +242 -0
  35. py_misc_utils/fs/gcs_fs.py +196 -0
  36. py_misc_utils/fs/http_fs.py +241 -0
  37. py_misc_utils/fs/s3_fs.py +417 -0
  38. py_misc_utils/fs_base.py +133 -0
  39. py_misc_utils/fs_utils.py +207 -0
  40. py_misc_utils/gcs_fs.py +169 -0
  41. py_misc_utils/gen_indices.py +54 -0
  42. py_misc_utils/gfs.py +371 -0
  43. py_misc_utils/git_repo.py +77 -0
  44. py_misc_utils/global_namespace.py +110 -0
  45. py_misc_utils/http_async_fetcher.py +139 -0
  46. py_misc_utils/http_server.py +196 -0
  47. py_misc_utils/http_utils.py +143 -0
  48. py_misc_utils/img_utils.py +20 -0
  49. py_misc_utils/infix_op.py +20 -0
  50. py_misc_utils/inspect_utils.py +205 -0
  51. py_misc_utils/iostream.py +21 -0
  52. py_misc_utils/iter_file.py +117 -0
  53. py_misc_utils/key_wrap.py +46 -0
  54. py_misc_utils/lazy_import.py +25 -0
  55. py_misc_utils/lockfile.py +164 -0
  56. py_misc_utils/mem_size.py +64 -0
  57. py_misc_utils/mirror_from.py +72 -0
  58. py_misc_utils/mmap.py +16 -0
  59. py_misc_utils/module_utils.py +196 -0
  60. py_misc_utils/moving_average.py +19 -0
  61. py_misc_utils/msgpack_streamer.py +26 -0
  62. py_misc_utils/multi_wait.py +24 -0
  63. py_misc_utils/multiprocessing.py +102 -0
  64. py_misc_utils/named_array.py +224 -0
  65. py_misc_utils/no_break.py +46 -0
  66. py_misc_utils/no_except.py +32 -0
  67. py_misc_utils/np_ml_framework.py +184 -0
  68. py_misc_utils/np_utils.py +346 -0
  69. py_misc_utils/ntuple_utils.py +38 -0
  70. py_misc_utils/num_utils.py +54 -0
  71. py_misc_utils/obj.py +73 -0
  72. py_misc_utils/object_cache.py +100 -0
  73. py_misc_utils/object_tracker.py +88 -0
  74. py_misc_utils/ordered_set.py +71 -0
  75. py_misc_utils/osfd.py +27 -0
  76. py_misc_utils/packet.py +22 -0
  77. py_misc_utils/parquet_streamer.py +69 -0
  78. py_misc_utils/pd_utils.py +254 -0
  79. py_misc_utils/periodic_task.py +61 -0
  80. py_misc_utils/pickle_wrap.py +121 -0
  81. py_misc_utils/pipeline.py +98 -0
  82. py_misc_utils/remap_pickle.py +50 -0
  83. py_misc_utils/resource_manager.py +155 -0
  84. py_misc_utils/rnd_utils.py +56 -0
  85. py_misc_utils/run_once.py +19 -0
  86. py_misc_utils/scheduler.py +135 -0
  87. py_misc_utils/select_params.py +300 -0
  88. py_misc_utils/signal.py +141 -0
  89. py_misc_utils/skl_utils.py +270 -0
  90. py_misc_utils/split.py +147 -0
  91. py_misc_utils/state.py +53 -0
  92. py_misc_utils/std_module.py +56 -0
  93. py_misc_utils/stream_dataframe.py +176 -0
  94. py_misc_utils/streamed_file.py +144 -0
  95. py_misc_utils/tempdir.py +79 -0
  96. py_misc_utils/template_replace.py +51 -0
  97. py_misc_utils/tensor_stream.py +269 -0
  98. py_misc_utils/thread_context.py +33 -0
  99. py_misc_utils/throttle.py +30 -0
  100. py_misc_utils/time_trigger.py +18 -0
  101. py_misc_utils/timegen.py +11 -0
  102. py_misc_utils/traceback.py +49 -0
  103. py_misc_utils/tracking_executor.py +91 -0
  104. py_misc_utils/transform_array.py +42 -0
  105. py_misc_utils/uncompress.py +35 -0
  106. py_misc_utils/url_fetcher.py +157 -0
  107. py_misc_utils/utils.py +538 -0
  108. py_misc_utils/varint.py +50 -0
  109. py_misc_utils/virt_array.py +52 -0
  110. py_misc_utils/weak_call.py +33 -0
  111. py_misc_utils/work_results.py +100 -0
  112. py_misc_utils/writeback_file.py +43 -0
  113. python_misc_utils-0.2.dist-info/METADATA +36 -0
  114. python_misc_utils-0.2.dist-info/RECORD +117 -0
  115. python_misc_utils-0.2.dist-info/WHEEL +5 -0
  116. python_misc_utils-0.2.dist-info/licenses/LICENSE +13 -0
  117. python_misc_utils-0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,157 @@
1
+ import os
2
+ import queue
3
+ import threading
4
+
5
+ from . import alog
6
+ from . import assert_checks as tas
7
+ from . import file_overwrite as fow
8
+ from . import gfs
9
+ from . import tempdir as tmpd
10
+ from . import utils as ut
11
+ from . import work_results as wres
12
+
13
+
14
def resolve_url(fss, url, fs_kwargs):
  """Return the (filesystem, normalized path) pair for url, caching filesystems in fss."""
  proto = gfs.get_proto(url)
  cached_fs = fss.get(proto)
  if cached_fs is not None:
    return cached_fs, cached_fs.norm_url(url)

  new_fs, fpath = gfs.resolve_fs(url, **fs_kwargs)
  # Register the filesystem under every ID it answers to, so later lookups hit the cache.
  for fsid in new_fs.IDS:
    fss[fsid] = new_fs

  return new_fs, fpath
25
+
26
+
27
def fetcher(path, fs_kwargs, uqueue, rqueue):
  """Worker loop: pull URLs from uqueue, download each into a work file under
  `path`, and report completion (success or failure) on rqueue.

  An empty/falsy URL acts as the shutdown sentinel.
  """
  fss = dict()
  while True:
    url = uqueue.get()
    if not url:
      break

    wpath = wres.work_path(path, url)
    alog.verbose(f'Fetching "{url}"')
    try:
      fs, fpath = resolve_url(fss, url, fs_kwargs)
      with wres.write_result(wpath) as fd:
        for chunk in fs.get_file(fpath):
          fd.write(chunk)
    except Exception as ex:
      # Record the failure so waiters can retrieve the error instead of hanging.
      wres.write_error(wpath, ex, workid=url)
    finally:
      rqueue.put(url)
47
+
48
+
49
class UrlFetcher:
  """Threaded URL downloader.

  start() spins up daemon worker threads running fetcher(); enqueue() feeds
  them URLs; each downloaded result is written as a per-URL work file (see
  the work_results module) under a scratch directory, and completions are
  reported through a result queue consumed by wait() and iter_results().
  """

  def __init__(self, path=None, num_workers=None, fs_kwargs=None):
    fs_kwargs = fs_kwargs or dict()
    # Provide a default fetch timeout (FETCHER_TIMEO env var, 10s fallback)
    # without overriding a caller-supplied value.
    fs_kwargs = ut.dict_setmissing(
      fs_kwargs,
      timeout=ut.getenv('FETCHER_TIMEO', dtype=float, defval=10.0),
    )

    self._ctor_path = path
    self._path = None
    # Downloads are I/O bound, so the pool is heavily over-subscribed
    # relative to the CPU count.
    self._num_workers = num_workers or max(os.cpu_count() * 4, 128)
    self._fs_kwargs = fs_kwargs
    self._uqueue = self._rqueue = None
    self._workers = []
    self._pending = set()

  def start(self):
    """Create the work directory (when none was given) and launch the workers."""
    if self._ctor_path is None:
      self._path = tmpd.fastfs_dir()
    else:
      self._path = self._ctor_path

    self._uqueue = queue.Queue()
    self._rqueue = queue.Queue()
    for i in range(self._num_workers):
      worker = threading.Thread(
        target=fetcher,
        args=(self._path, self._fs_kwargs, self._uqueue, self._rqueue),
        daemon=True,
      )
      worker.start()
      self._workers.append(worker)

  def shutdown(self):
    """Stop all workers, then drop queues and (owned) scratch storage."""
    alog.verbose(f'Sending shutdowns down the queue')
    # An empty URL is the per-worker shutdown sentinel (see fetcher()).
    for _ in range(len(self._workers)):
      self._uqueue.put('')

    alog.verbose(f'Joining fetcher workers')
    for worker in self._workers:
      worker.join()

    self._uqueue = self._rqueue = None
    self._workers = []

    # Only remove the work directory when start() created it (i.e. the
    # caller did not supply one).
    if self._path != self._ctor_path:
      gfs.rmtree(self._path, ignore_errors=True)

    self._path = None
    self._pending = set()

  def enqueue(self, *urls):
    """Queue URLs for download; returns a {url: work_hash} map.

    Falsy URLs are silently skipped (an empty URL would be read as the
    worker shutdown sentinel).
    """
    wmap = dict()
    for url in urls:
      if url:
        self._uqueue.put(url)
        self._pending.add(url)
        wmap[url] = wres.work_hash(url)

    return wmap

  def wait(self, url):
    """Block until url's work file exists, then return (and consume) its result.

    Results for other URLs drained while waiting are only retired from the
    pending set; their work files stay on disk for iter_results()/wait().
    """
    wpath = wres.work_path(self._path, url)
    if not os.path.isfile(wpath):
      tas.check(url in self._pending, msg=f'URL already retired: {url}')

      while self._pending:
        rurl = self._rqueue.get()
        self._pending.discard(rurl)
        if rurl == url:
          break

    try:
      return wres.get_work(wpath)
    finally:
      # Work files are single-use: remove even when get_work() raises.
      os.remove(wpath)

  def iter_results(self, max_results=None, block=True, timeout=None):
    """Yield (url, work_data) pairs as downloads complete.

    Stops when nothing is pending, when the queue read times out / would
    block, or after max_results items.
    """
    count = 0
    while self._pending:
      try:
        rurl = self._rqueue.get(block=block, timeout=timeout)

        self._pending.discard(rurl)
        wpath = wres.work_path(self._path, rurl)

        wdata = wres.load_work(wpath)

        os.remove(wpath)

        yield rurl, wdata
      except queue.Empty:
        break

      count += 1
      if max_results is not None and count >= max_results:
        break

  def __enter__(self):
    self.start()

    return self

  def __exit__(self, *exc):
    self.shutdown()

    # Never suppress exceptions raised inside the with-block.
    return False
157
+
py_misc_utils/utils.py ADDED
@@ -0,0 +1,538 @@
1
+ import array
2
+ import collections
3
+ import datetime
4
+ import inspect
5
+ import json
6
+ import math
7
+ import os
8
+ import pickle
9
+ import re
10
+ import sys
11
+ import time
12
+ import types
13
+ import yaml
14
+
15
+ import numpy as np
16
+
17
+ from . import alog
18
+ from . import assert_checks as tas
19
+ from . import core_utils as cu
20
+ from . import file_overwrite as fow
21
+ from . import gfs
22
+ from . import inspect_utils as iu
23
+ from . import mmap as mm
24
+ from . import obj
25
+ from . import split as sp
26
+ from . import template_replace as tr
27
+ from . import traceback as tb
28
+
29
+
30
class _None:
  """Sentinel type: its instance stands for "no value", distinguishable from None."""

  def __repr__(self):
    return 'NONE'


# Module-wide sentinel used by the lookup helpers below to tell "missing"
# apart from a stored None value.
_NONE = _None()
37
+
38
+
39
def pickle_proto():
  """Return the pickle protocol to use (PICKLE_PROTO env var, else the highest available)."""
  return getenv('PICKLE_PROTO', dtype=int, defval=pickle.HIGHEST_PROTOCOL)
41
+
42
+
43
def fname():
  """Return the name of the calling function (one frame up the stack)."""
  # NOTE: frame-depth sensitive — must stay a direct call into tb.get_frame(1).
  return tb.get_frame(1).f_code.co_name
45
+
46
+
47
def _stri(obj, seen, ffmt, dexp):
  """Recursive pretty-printer behind stri().

  `seen` maps id(obj) -> rendered string; a None entry marks an object whose
  rendering is currently in progress, so cycles render as '...'.
  `ffmt` is the float format spec, `dexp` enables __dict__ expansion.
  """
  oid = id(obj)
  sres = seen.get(oid, _NONE)
  if sres is None:
    # Being rendered higher up the stack: cycle.
    return '...'
  elif sres is not _NONE:
    # Already rendered: reuse the cached string.
    return sres

  # Mark as in-progress before recursing.
  seen[oid] = None
  if isinstance(obj, str):
    obj_str = obj.replace('"', '\\"')
    result = f'"{obj_str}"'
  elif isinstance(obj, float):
    result = f'{obj:{ffmt}}'
  elif isinstance(obj, bytes):
    # NOTE(review): assumes the bytes decode cleanly as UTF-8 — decode() raises otherwise.
    result = obj.decode()
  elif cu.is_namedtuple(obj):
    result = str(obj)
  elif cu.is_sequence(obj):
    sl = ', '.join(_stri(x, seen, ffmt, dexp) for x in obj)

    # Lists render with [], every other sequence with ().
    result = '[' + sl + ']' if isinstance(obj, list) else '(' + sl + ')'
  elif cu.isdict(obj):
    result = '{' + ', '.join(f'{k}={_stri(v, seen, ffmt, dexp)}' for k, v in obj.items()) + '}'
  elif dexp and hasattr(obj, '__dict__'):
    # Drop the braces around the __dict__ output, and use the "Classname(...)" format.
    drepr = _stri(obj.__dict__, seen, ffmt, dexp)
    result = f'{iu.cname(obj)}({drepr[1: -1]})'
  else:
    result = str(obj)

  seen[oid] = result

  return result
81
+
82
+
83
def stri(l, float_fmt=None, dict_expand=False):
  """Render l as a compact human-readable string (see _stri)."""
  fmt = float_fmt if float_fmt else '.3e'
  return _stri(l, dict(), fmt, dict_expand)
85
+
86
+
87
def repr_fmt(obj, *fields, repr_none=False, sep=', '):
  """Format selected attributes of obj as 'name=value' pairs joined by sep.

  Field specs support '!name' (show even when None) and 'name=alias'
  (an empty alias prints the bare value).
  """
  chunks = []
  for field in expand_strings(*fields):
    show_none, label = repr_none, field

    m = re.match(r'([!])?(\w+)\s*(=\s*(\w*))?', field)
    if m:
      field = m.group(2)
      show_none = m.group(1) == '!'
      label = m.group(4) if m.group(4) is not None else field

    value = getattr(obj, field, None)
    if value is not None or show_none:
      chunks.append(f'{label}={value}' if label else str(value))

  return sep.join(chunks)
105
+
106
+
107
def mget(d, *args, as_dict=False):
  """Fetch multiple keys from d; returns a tuple of values (or a dict), None for misses."""
  names = expand_strings(*args)
  if as_dict:
    return {name: d.get(name) for name in names}
  return tuple(d.get(name) for name in names)
113
+
114
+
115
def getvar(obj, name, defval=None):
  """Uniform lookup with default: dict key for dicts, attribute otherwise."""
  if cu.isdict(obj):
    return obj.get(name, defval)
  return getattr(obj, name, defval)
117
+
118
+
119
def dict_subset(d, *keys):
  """Return a dict with only the given keys, silently skipping keys absent from d."""
  result = dict()
  for key in expand_strings(*keys):
    value = d.get(key, _NONE)
    if value is not _NONE:
      result[key] = value

  return result
128
+
129
+
130
def dict_setmissing(d, **kwargs):
  """Return a dict where kwargs provide defaults and entries of d take precedence."""
  merged = dict(kwargs)
  merged.update(d)
  return merged
134
+
135
+
136
def pop_kwargs(kwargs, names, args_key=None):
  """Pop the named values out of kwargs.

  When kwargs holds a nested dict under args_key (default '_'), the values
  are read from that dict instead (without popping the individual names).
  """
  nested = kwargs.pop(args_key or '_', None)
  if nested is None:
    values = [kwargs.pop(name, None) for name in expand_strings(names)]
  else:
    values = [nested.get(name) for name in expand_strings(names)]

  return tuple(values)
144
+
145
+
146
def resplit(csstr, sep):
  """Split csstr on sep (a regex fragment), swallowing surrounding whitespace."""
  pattern = r'\s*' + sep + r'\s*'
  return sp.split(csstr, pattern)
148
+
149
+
150
def comma_split(csstr):
  """Split a comma-separated string, trimming whitespace around each comma."""
  return sp.split(csstr, r'\s*,\s*')
152
+
153
+
154
def ws_split(data):
  """Split data on runs of whitespace."""
  return sp.split(data, r'\s+')
156
+
157
+
158
def expand_strings(*args):
  """Flatten args into a tuple of strings, comma-splitting any non-sequence argument."""
  expanded = []
  for arg in args:
    expanded.extend(arg if cu.is_sequence(arg) else comma_split(arg))

  return tuple(expanded)
167
+
168
+
169
def name_values(base_name, values):
  """Pair values with names.

  Scalars and single-element sequences get base_name itself; longer
  sequences get enumerated 'base_name.i' names. Returns a tuple of pairs.
  """
  if not isinstance(values, (list, tuple)):
    return ((base_name, values),)
  if len(values) == 1:
    return ((base_name, values[0]),)

  return tuple((f'{base_name}.{i}', v) for i, v in enumerate(values))
181
+
182
+
183
def write_config(cfg, dest, **kwargs):
  """Dump cfg as YAML to dest (atomically, via FileOverwrite).

  kwargs are forwarded to yaml.dump; default_flow_style defaults to False.
  """
  # Bug fix: the previous code read default_flow_style with kwargs.get() and
  # then also forwarded **kwargs, so a caller-supplied default_flow_style was
  # passed to yaml.dump twice, raising TypeError. setdefault avoids that.
  kwargs.setdefault('default_flow_style', False)

  with fow.FileOverwrite(dest, mode='wt') as df:
    yaml.dump(cfg, df, **kwargs)
188
+
189
+
190
def config_to_string(cfg, **kwargs):
  """Return cfg serialized as a YAML string; kwargs are forwarded to yaml.dump."""
  # Bug fix: kwargs.get() plus **kwargs forwarding duplicated the
  # default_flow_style keyword when the caller supplied it (TypeError).
  kwargs.setdefault('default_flow_style', False)

  return yaml.dump(cfg, **kwargs)
194
+
195
+
196
def parse_config(cfg):
  """Parse a YAML config: inline YAML when cfg starts with '[' or '{', else a path to read."""
  if re.match(r'[\[\{]', cfg):
    data = cfg
  else:
    # It must be either a dictionary in YAML format, or a valid path.
    with gfs.open(cfg, mode='r') as fd:
      data = fd.read()

  return yaml.safe_load(data)
205
+
206
+
207
def load_config(path, extra=None):
  """Load a YAML config from path, overlaying the non-None entries of extra."""
  cfgd = parse_config(path)

  for key, value in (extra or dict()).items():
    if value is not None:
      cfgd[key] = value

  return cfgd
216
+
217
+
218
def fatal(msg, exc=RuntimeError):
  """Raise exc with msg via alog.xraise (stacklevel=2 attributes it to the caller)."""
  alog.xraise(exc, msg, stacklevel=2)
220
+
221
+
222
def assert_instance(msg, t, ta):
  """Raise ValueError (via alog.xraise) when t is not an instance of ta (a type or type list)."""
  if isinstance(t, ta):
    return

  if isinstance(ta, (list, tuple)):
    expected = 'one of (' + ', '.join(iu.cname(x) for x in ta) + ')'
  else:
    expected = f'a {iu.cname(ta)}'

  alog.xraise(ValueError, f'{msg}: {iu.cname(t)} is not ' + expected)
233
+
234
+
235
def make_object(**kwargs):
  """Build an attribute-access object from keyword arguments."""
  return obj.Obj(**kwargs)
237
+
238
+
239
def make_object_recursive(**kwargs):
  """Recursively convert nested dicts into attribute-access objects."""
  converted = {
    key: make_object_recursive(**value) if cu.isdict(value) else value
    for key, value in kwargs.items()
  }

  return make_object(**converted)
245
+
246
+
247
def locals_capture(locs, exclude='self'):
  """Snapshot a locals() dict into an object; 'self' (and any excluded names) are dropped."""
  dropped = set(expand_strings(exclude, 'self'))

  kept = {name: value for name, value in locs.items() if name not in dropped}
  return make_object(**kept)
251
+
252
+
253
def sreplace(rex, data, mapfn, nmapfn=None, join=True):
  """Regex-driven string rewrite.

  Every match of rex in data has its group(1) mapped through mapfn; the text
  between matches goes through nmapfn (identity when omitted).

  Returns the joined string, or the list of parts when join is False.
  """
  if nmapfn is None:
    # Bug fix: the previous code referenced an undefined name `ident` here,
    # raising NameError whenever nmapfn was omitted.
    nmapfn = lambda s: s

  lastpos, parts = 0, []
  for m in re.finditer(rex, data):
    start, end = m.span()
    if start > lastpos:
      parts.append(nmapfn(data[lastpos: start]))

    lastpos = end
    parts.append(mapfn(m.group(1)))

  if lastpos < len(data):
    parts.append(nmapfn(data[lastpos:]))

  return ''.join(parts) if join else parts
270
+
271
+
272
def as_sequence(v, t=tuple):
  """Coerce v into sequence type t (or into t[0] when t is a list/tuple of accepted types)."""
  if isinstance(t, (list, tuple)):
    # t lists acceptable container types; return v untouched if it already matches.
    if any(isinstance(v, st) for st in t):
      return v

    ctor = t[0]
    # Generators are consumed into the container; anything else is wrapped as one item.
    return ctor(v) if isinstance(v, types.GeneratorType) else ctor([v])

  if isinstance(v, t):
    return v

  return t(v) if cu.is_sequence(v) else t([v])
284
+
285
+
286
def format(seq, fmt):
  """Format every element of seq with the given format spec, keeping the container type."""
  template = '{:' + fmt + '}'

  return type(seq)(template.format(item) for item in seq)
290
+
291
+
292
def value_or(v, defval):
  """Return v unless it is None, in which case return defval."""
  return defval if v is None else v
294
+
295
+
296
def dict_rget(sdict, path, defval=None, sep='/'):
  """Walk nested dicts by path (a sep-delimited string or a key sequence); defval on a miss."""
  keys = path if isinstance(path, (list, tuple)) else path.strip(sep).split(sep)

  node = sdict
  for key in keys:
    if not cu.isdict(node):
      return defval
    node = node.get(key, defval)

  return node
307
+
308
+
309
def make_index_dict(vals):
  """Map each value to its position within vals."""
  return {value: index for index, value in enumerate(vals)}
311
+
312
+
313
def append_index_dict(xlist, xdict, value):
  """Append value to xlist and record its index within xdict."""
  xdict[value] = len(xlist)
  xlist.append(value)
316
+
317
+
318
def compile(code, syms, env=None, vals=None, lookup_fn=None, delim=None):
  """Exec `code` (after optional template substitution) and return the named symbols.

  Returns a tuple with env[s] for each symbol in syms (comma-string or sequence);
  missing symbols come back as None.
  NOTE(review): this exec()s arbitrary code — only ever call it on trusted input.
  """
  # Note that objects compiled with this API cannot be pickled.
  # If that is a requirement, use the dynamod module.
  env = value_or(env, dict())
  if vals is not None or lookup_fn is not None:
    code = tr.template_replace(code, vals=vals, lookup_fn=lookup_fn, delim=delim)

  exec(code, env)

  return tuple(env.get(s) for s in expand_strings(syms))
328
+
329
+
330
def run(path, fnname, *args, **kwargs):
  """Compile the file at path and invoke its function fnname with args/kwargs."""
  compile_args, = pop_kwargs(kwargs, 'compile_args')
  cargs = compile_args if compile_args is not None else dict()

  fn, = compile(mm.file_view(path), fnname, **cargs)

  return fn(*args, **kwargs)
336
+
337
+
338
def unpack_n(l, n, defval=None):
  """Return exactly n items from l as a tuple, padding with defval when l is shorter.

  Bug fix: the previous code concatenated the tuple returned by as_sequence()
  with a list ([defval] * pad), raising TypeError whenever padding was needed;
  coerce to a list instead.
  """
  l = as_sequence(l, t=list)

  return tuple(l[:n] if len(l) >= n else l + [defval] * (n - len(l)))
342
+
343
+
344
def getenv(name, dtype=None, defval=None):
  """Read an environment variable, with optional type conversion and (callable) default."""
  # os.getenv expects the default value to be a string, so cannot be passed in there.
  value = os.getenv(name)
  if value is None:
    value = defval() if callable(defval) else defval
  if value is None:
    return None

  return cu.to_type(value, dtype) if dtype is not None else value
351
+
352
+
353
def env(name, defval, vtype=None):
  """Shorthand for getenv() with a required default value."""
  return getenv(name, dtype=vtype, defval=defval)
355
+
356
+
357
def envs(*args, as_dict=False):
  """Fetch multiple environment variables (tuple, or dict when as_dict=True)."""
  return mget(os.environ, *args, as_dict=as_dict)
359
+
360
+
361
def import_env(dest, *args):
  """Copy the named environment variables into dest, inferring each value's type."""
  for name, raw in envs(*args, as_dict=True).items():
    dest[name] = cu.infer_value(raw)

  return dest
367
+
368
+
369
def map_env(g, prefix=''):
  """Override entries of dict g with same-named (prefixed) environment variables,
  converting each to the type of the existing value. Returns g (mutated)."""
  overrides = dict()
  for key, current in g.items():
    env_value = getenv(f'{prefix}{key}', dtype=type(current))
    if env_value is not None:
      overrides[key] = env_value

  g.update(overrides)

  return g
379
+
380
+
381
# Squeeze directions: MAJOR drops leading size-1 dims, MINOR drops trailing ones.
MAJOR = 1
MINOR = -1

def squeeze(shape, keep_dims=0, sdir=MAJOR):
  """Strip size-1 dimensions from one end of shape, never going below keep_dims dims."""
  dims = list(shape)
  if sdir == MAJOR:
    while len(dims) > keep_dims and dims[0] == 1:
      del dims[0]
  elif sdir == MINOR:
    while len(dims) > keep_dims and dims[-1] == 1:
      del dims[-1]
  else:
    alog.xraise(ValueError, f'Unknown squeeze direction: {sdir}')

  return type(shape)(dims)
396
+
397
+
398
def flat2shape(data, shape):
  """Nest a flat sequence into the given shape.

  For an Mx...xK input shape, return a M elements (nested) list.
  """
  tas.check_eq(len(data), np.prod(shape),
               msg=f'Shape {shape} is unsuitable for a {len(data)} long array')

  for n in reversed(shape[1:]):
    # Bug fix: the previous range(len(data), n) was always empty, so no
    # nesting ever happened; chunking needs strides of n from 0 to len(data).
    data = [data[i: i + n] for i in range(0, len(data), n)]

  return data
407
+
408
+
409
def shape2flat(data, shape):
  """Flatten nested sequence data (whose nesting matches shape) into a flat tuple."""
  for _ in range(len(shape) - 1):
    tas.check(hasattr(data, '__iter__'), msg=f'Wrong data type: {type(data)}')
    ndata = []
    for av in data:
      # Bug fix: the error message previously reported type(data) instead of
      # the offending element's type.
      tas.check(hasattr(av, '__iter__'), msg=f'Wrong data type: {type(av)}')
      ndata.extend(av)

    data = ndata

  tas.check_eq(len(data), np.prod(shape),
               msg=f'Shape {shape} is unsuitable for a {len(data)} long array')

  return tuple(data)
423
+
424
+
425
def binary_reduce(parts, reduce_fn):
  """Reduce parts pairwise (tournament style) until a single value remains."""
  while len(parts) > 1:
    merged, start = [], 0
    if len(parts) % 2:
      # Odd count: carry the first element over unchanged.
      merged.append(parts[0])
      start = 1

    for i in range(start, len(parts) - 1, 2):
      merged.append(reduce_fn(parts[i], parts[i + 1]))

    parts = merged

  return parts[0]
437
+
438
+
439
def stringify(s):
  """Rewrite s via cu.data_rewrite, converting every leaf value to str.

  NOTE(review): rwfn returns None for list/tuple/dict values — presumably
  cu.data_rewrite treats a None return as "recurse into the container";
  confirm against core_utils before relying on this.
  """
  def rwfn(v):
    if not (isinstance(v, (list, tuple)) or cu.isdict(v)):
      return str(v)

  return cu.data_rewrite(s, rwfn)
445
+
446
+
447
def seq_rewrite(seq, sd):
  """Map each element through the sd substitution dict, keeping the container type."""
  return type(seq)(sd.get(item, item) for item in seq)
449
+
450
+
451
def dfetch(d, *args):
  """Fetch the named keys from d as a tuple; missing keys raise KeyError."""
  return tuple(d[name] for name in args)
453
+
454
+
455
def numel(t):
  """Element count of t: product of its shape when available, else len(t)."""
  shape = cu.get_property(t, 'shape')

  return len(t) if shape is None else np.prod(shape)
459
+
460
+
461
def scale_data(data, base_data, scale):
  """Relative change of data vs base_data, expressed in units of scale."""
  ratio = (data - base_data) / base_data
  return ratio * scale
463
+
464
+
465
# Unsigned array.array typecodes paired with their per-item byte sizes,
# ordered from smallest to largest.
_ARRAY_SIZES = tuple((array.array(c).itemsize, c) for c in 'B,H,I,L,Q'.split(','))

def array_code(size):
  """Return the smallest unsigned array.array typecode able to hold `size` distinct values."""
  needed = math.ceil(math.log2(size)) / 8
  for itemsize, code in _ARRAY_SIZES:
    if itemsize >= needed:
      return code

  alog.xraise(ValueError,
              f'Size {size} too big to fit inside any array integer types')
475
+
476
+
477
def checked_remove(l, o):
  """Remove o from list l; return True when removed, False when absent."""
  try:
    l.remove(o)
    return True
  except ValueError:
    return False
484
+
485
+
486
def sleep_until(date, msg=None):
  """Sleep until the given datetime (no-op when it is in the past); optionally log msg first."""
  now = datetime.datetime.now(tz=date.tzinfo)
  remaining = date.timestamp() - now.timestamp()
  if remaining > 0:
    if msg:
      alog.info(msg)
    time.sleep(remaining)
492
+
493
+
494
def parse_dict(data):
  """Parse a YAML mapping string (safe_load — no arbitrary object construction)."""
  return yaml.safe_load(data)
496
+
497
+
498
def parse_args(in_args):
  """Parse 'a, k=v, ...' (comma string or sequence) into (args, kwargs), YAML-decoding values."""
  seq_args = comma_split(in_args) if isinstance(in_args, str) else in_args

  args, kwargs = [], dict()
  for arg in seq_args:
    pieces = cu.separate(arg, '=')
    if len(pieces) == 1:
      args.append(yaml.safe_load(pieces[0]))
    elif len(pieces) == 2:
      key, raw = pieces
      kwargs[key] = yaml.safe_load(raw)
    else:
      alog.xraise(ValueError, f'Syntax error: {arg}')

  return args, kwargs
512
+
513
+
514
def state_update(path, **kwargs):
  """Read the pickled state dict at path, merge kwargs into it (rewriting the
  file when kwargs is non-empty), and return the resulting state.

  NOTE(review): pickle.load on the file contents — path must be trusted.
  """
  if sfile := gfs.maybe_open(path, mode='rb'):
    with sfile as fd:
      state = pickle.load(fd)
  else:
    # No existing state file: start fresh.
    state = dict()

  if kwargs:
    state.update(kwargs)
    with fow.FileOverwrite(path, mode='wb') as f:
      pickle.dump(state, f, protocol=pickle_proto())

  return state
527
+
528
+
529
def copy_inplace(dest, src):
  """Overwrite dest's state with src's (__dict__ update, or slot-by-slot copy); returns dest."""
  dest_dict = getattr(dest, '__dict__', None)
  if dest_dict is None:
    # Slotted class: copy attribute by attribute.
    for slot in iu.class_slots(dest):
      setattr(dest, slot, getattr(src, slot, None))
  else:
    dest_dict.update(src.__dict__)

  return dest
538
+