python_misc_utils-0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117) hide show
  1. py_misc_utils/__init__.py +0 -0
  2. py_misc_utils/abs_timeout.py +12 -0
  3. py_misc_utils/alog.py +311 -0
  4. py_misc_utils/app_main.py +179 -0
  5. py_misc_utils/archive_streamer.py +112 -0
  6. py_misc_utils/assert_checks.py +118 -0
  7. py_misc_utils/ast_utils.py +121 -0
  8. py_misc_utils/async_manager.py +189 -0
  9. py_misc_utils/break_control.py +63 -0
  10. py_misc_utils/buffered_iterator.py +35 -0
  11. py_misc_utils/cached_file.py +507 -0
  12. py_misc_utils/call_limiter.py +26 -0
  13. py_misc_utils/call_result_selector.py +13 -0
  14. py_misc_utils/cleanups.py +85 -0
  15. py_misc_utils/cmd.py +97 -0
  16. py_misc_utils/compression.py +116 -0
  17. py_misc_utils/cond_waiter.py +13 -0
  18. py_misc_utils/context_base.py +18 -0
  19. py_misc_utils/context_managers.py +67 -0
  20. py_misc_utils/core_utils.py +577 -0
  21. py_misc_utils/daemon_process.py +252 -0
  22. py_misc_utils/data_cache.py +46 -0
  23. py_misc_utils/date_utils.py +90 -0
  24. py_misc_utils/debug.py +24 -0
  25. py_misc_utils/dyn_modules.py +50 -0
  26. py_misc_utils/dynamod.py +103 -0
  27. py_misc_utils/env_config.py +35 -0
  28. py_misc_utils/executor.py +239 -0
  29. py_misc_utils/file_overwrite.py +29 -0
  30. py_misc_utils/fin_wrap.py +77 -0
  31. py_misc_utils/fp_utils.py +47 -0
  32. py_misc_utils/fs/__init__.py +0 -0
  33. py_misc_utils/fs/file_fs.py +127 -0
  34. py_misc_utils/fs/ftp_fs.py +242 -0
  35. py_misc_utils/fs/gcs_fs.py +196 -0
  36. py_misc_utils/fs/http_fs.py +241 -0
  37. py_misc_utils/fs/s3_fs.py +417 -0
  38. py_misc_utils/fs_base.py +133 -0
  39. py_misc_utils/fs_utils.py +207 -0
  40. py_misc_utils/gcs_fs.py +169 -0
  41. py_misc_utils/gen_indices.py +54 -0
  42. py_misc_utils/gfs.py +371 -0
  43. py_misc_utils/git_repo.py +77 -0
  44. py_misc_utils/global_namespace.py +110 -0
  45. py_misc_utils/http_async_fetcher.py +139 -0
  46. py_misc_utils/http_server.py +196 -0
  47. py_misc_utils/http_utils.py +143 -0
  48. py_misc_utils/img_utils.py +20 -0
  49. py_misc_utils/infix_op.py +20 -0
  50. py_misc_utils/inspect_utils.py +205 -0
  51. py_misc_utils/iostream.py +21 -0
  52. py_misc_utils/iter_file.py +117 -0
  53. py_misc_utils/key_wrap.py +46 -0
  54. py_misc_utils/lazy_import.py +25 -0
  55. py_misc_utils/lockfile.py +164 -0
  56. py_misc_utils/mem_size.py +64 -0
  57. py_misc_utils/mirror_from.py +72 -0
  58. py_misc_utils/mmap.py +16 -0
  59. py_misc_utils/module_utils.py +196 -0
  60. py_misc_utils/moving_average.py +19 -0
  61. py_misc_utils/msgpack_streamer.py +26 -0
  62. py_misc_utils/multi_wait.py +24 -0
  63. py_misc_utils/multiprocessing.py +102 -0
  64. py_misc_utils/named_array.py +224 -0
  65. py_misc_utils/no_break.py +46 -0
  66. py_misc_utils/no_except.py +32 -0
  67. py_misc_utils/np_ml_framework.py +184 -0
  68. py_misc_utils/np_utils.py +346 -0
  69. py_misc_utils/ntuple_utils.py +38 -0
  70. py_misc_utils/num_utils.py +54 -0
  71. py_misc_utils/obj.py +73 -0
  72. py_misc_utils/object_cache.py +100 -0
  73. py_misc_utils/object_tracker.py +88 -0
  74. py_misc_utils/ordered_set.py +71 -0
  75. py_misc_utils/osfd.py +27 -0
  76. py_misc_utils/packet.py +22 -0
  77. py_misc_utils/parquet_streamer.py +69 -0
  78. py_misc_utils/pd_utils.py +254 -0
  79. py_misc_utils/periodic_task.py +61 -0
  80. py_misc_utils/pickle_wrap.py +121 -0
  81. py_misc_utils/pipeline.py +98 -0
  82. py_misc_utils/remap_pickle.py +50 -0
  83. py_misc_utils/resource_manager.py +155 -0
  84. py_misc_utils/rnd_utils.py +56 -0
  85. py_misc_utils/run_once.py +19 -0
  86. py_misc_utils/scheduler.py +135 -0
  87. py_misc_utils/select_params.py +300 -0
  88. py_misc_utils/signal.py +141 -0
  89. py_misc_utils/skl_utils.py +270 -0
  90. py_misc_utils/split.py +147 -0
  91. py_misc_utils/state.py +53 -0
  92. py_misc_utils/std_module.py +56 -0
  93. py_misc_utils/stream_dataframe.py +176 -0
  94. py_misc_utils/streamed_file.py +144 -0
  95. py_misc_utils/tempdir.py +79 -0
  96. py_misc_utils/template_replace.py +51 -0
  97. py_misc_utils/tensor_stream.py +269 -0
  98. py_misc_utils/thread_context.py +33 -0
  99. py_misc_utils/throttle.py +30 -0
  100. py_misc_utils/time_trigger.py +18 -0
  101. py_misc_utils/timegen.py +11 -0
  102. py_misc_utils/traceback.py +49 -0
  103. py_misc_utils/tracking_executor.py +91 -0
  104. py_misc_utils/transform_array.py +42 -0
  105. py_misc_utils/uncompress.py +35 -0
  106. py_misc_utils/url_fetcher.py +157 -0
  107. py_misc_utils/utils.py +538 -0
  108. py_misc_utils/varint.py +50 -0
  109. py_misc_utils/virt_array.py +52 -0
  110. py_misc_utils/weak_call.py +33 -0
  111. py_misc_utils/work_results.py +100 -0
  112. py_misc_utils/writeback_file.py +43 -0
  113. python_misc_utils-0.2.dist-info/METADATA +36 -0
  114. python_misc_utils-0.2.dist-info/RECORD +117 -0
  115. python_misc_utils-0.2.dist-info/WHEEL +5 -0
  116. python_misc_utils-0.2.dist-info/licenses/LICENSE +13 -0
  117. python_misc_utils-0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,71 @@
1
+ import copy
2
+
3
+
4
class OrderedSet:
  """A set that remembers the order in which values were first added.

  Each value maps to a monotonically increasing sequence number, which is
  used to yield values back in insertion order.
  """

  def __init__(self, init=None):
    # Maps value -> insertion sequence number.
    self._data = dict()
    self._seqno = 0

    for value in init or ():
      self.add(value)

  def add(self, value):
    """Adds value if missing, and returns its sequence number."""
    n = self._data.get(value)
    if n is None:
      self._data[value] = n = self._seqno
      self._seqno += 1

    return n

  def remove(self, value):
    """Removes value, raising KeyError if not present."""
    self._data.pop(value)

  def discard(self, value):
    """Removes value if present, no-op otherwise."""
    self._data.pop(value, None)

  def pop(self):
    """Removes and returns the most recently inserted value."""
    value, n = self._data.popitem()

    return value

  def clear(self):
    self._data = dict()
    self._seqno = 0

  def __len__(self):
    return len(self._data)

  def values(self):
    """Yields the values in insertion order."""
    return (y[0] for y in sorted(self._data.items(), key=lambda x: x[1]))

  def __iter__(self):
    return iter(self.values())

  def __contains__(self, value):
    return value in self._data

  def __copy__(self):
    # The default copy.copy() shallow-copies __dict__, which would share
    # the underlying dict between the copy and the source, so mutating the
    # copy would mutate the original. Give the copy independent storage.
    nos = OrderedSet()
    nos._data = self._data.copy()
    nos._seqno = self._seqno

    return nos

  def union(self, *others):
    """Returns a new OrderedSet with the values of self followed by others.

    BUGFIX: this used to rely on the default shallow copy, which shared
    self._data with the result, so union() mutated the original set.
    """
    nos = copy.copy(self)
    for other in others:
      for value in other:
        nos.add(value)

    return nos

  def intersection(self, *others):
    """Returns a new OrderedSet with values present in self and all others."""
    nos = OrderedSet()
    for value in self.values():
      if all(value in other for other in others):
        nos.add(value)

    return nos

  def difference(self, *others):
    """Returns a new OrderedSet with values of self not present in any other."""
    nos = OrderedSet()
    for value in self.values():
      if not any(value in other for other in others):
        nos.add(value)

    return nos
py_misc_utils/osfd.py ADDED
@@ -0,0 +1,27 @@
1
+ import os
2
+
3
+
4
class OsFd:
  """Context manager around os.open()/os.close().

  When the managed body raises, the file can optionally be removed; by
  default it is removed only if it did not exist before the open.
  """

  def __init__(self, path, flags, remove_on_error=None, **kwargs):
    self._path = path
    self._flags = flags
    self._kwargs = kwargs
    if remove_on_error is None:
      # Default policy: only clean up files this context is about to create.
      remove_on_error = not os.path.exists(path)
    self._remove_on_error = remove_on_error

  def __enter__(self):
    self._fd = os.open(self._path, self._flags, **self._kwargs)

    return self._fd

  def __exit__(self, *exc):
    os.close(self._fd)
    failed = any(ex is not None for ex in exc)
    if failed and self._remove_on_error:
      try:
        os.remove(self._path)
      except OSError:
        pass

    # Never suppress the in-flight exception.
    return False
@@ -0,0 +1,22 @@
1
+ import os
2
+ import struct
3
+
4
+ from . import iostream as ios
5
+
6
+
7
# Fixed little-endian 64-bit length prefix used to frame packets.
_SIZE_PACKER = struct.Struct('<Q')

def write_packet(fd, data):
  """Writes data to fd, preceded by its 64-bit little-endian length."""
  stream = ios.IOStream(fd)
  stream.write(_SIZE_PACKER.pack(len(data)) + data)
+
15
def read_packet(fd):
  """Reads one length-prefixed packet from fd and returns its payload."""
  stream = ios.IOStream(fd)
  size_data = stream.read(_SIZE_PACKER.size)
  payload_size, = _SIZE_PACKER.unpack(size_data)

  return stream.read(payload_size)
+
@@ -0,0 +1,69 @@
1
+ import contextlib
2
+ import functools
3
+
4
+ import pyarrow.parquet as pq
5
+
6
+ from . import alog
7
+ from . import fin_wrap as fw
8
+ from . import gfs
9
+ from . import url_fetcher as urlf
10
+ from . import utils as ut
11
+
12
+
13
class ParquetStreamer:
  """Streams records out of a (possibly remote) Parquet file in batches.

  Columns listed in load_columns are treated as URLs whose content is
  prefetched and stored in the record under the mapped name, while
  rename_columns simply renames record keys.
  """

  def __init__(self, url,
               batch_size=128,
               load_columns=None,
               rename_columns=None,
               num_workers=None,
               **kwargs):
    self._url = url
    self._batch_size = batch_size
    self._load_columns = ut.value_or(load_columns, dict())
    self._rename_columns = ut.value_or(rename_columns, dict())
    self._num_workers = num_workers
    self._kwargs = kwargs

  def _fetcher(self):
    # A URL fetcher is only needed when there are URL columns to download.
    if not self._load_columns:
      return contextlib.nullcontext()

    return urlf.UrlFetcher(num_workers=self._num_workers,
                           fs_kwargs=self._kwargs)

  def _prefetch(self, fetcher, recs):
    # Enqueue every URL of the batch up front so the downloads overlap the
    # per-record processing done by the caller.
    if fetcher is None:
      return
    for recd in recs:
      for key in self._load_columns.keys():
        fetcher.enqueue(recd[key])

  def _transform(self, fetcher, recd):
    # Resolve the prefetched URL columns, then apply the renames.
    if fetcher is not None:
      for key, name in self._load_columns.items():
        recd[name] = fetcher.wait(recd[key])

    for key, name in self._rename_columns.items():
      recd[name] = recd.pop(key)

    return recd

  def generate(self):
    """Yields one transformed record dict per Parquet row."""
    with (self._fetcher() as fetcher,
          gfs.open(self._url, mode='rb', **self._kwargs) as stream):
      pqfd = pq.ParquetFile(stream)
      for batch in pqfd.iter_batches(batch_size=self._batch_size):
        recs = batch.to_pylist()

        self._prefetch(fetcher, recs)
        for recd in recs:
          try:
            yield self._transform(fetcher, recd)
          except GeneratorExit:
            # Generator shutdown must propagate, never be swallowed.
            raise
          except Exception as ex:
            # Best-effort streaming: log and skip records which fail.
            alog.verbose(f'Unable to create parquet entry ({recd}): {ex}')

  def __iter__(self):
    return self.generate()
@@ -0,0 +1,254 @@
1
+ import array
2
+ import collections
3
+ import datetime
4
+ import os
5
+ import re
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+ from . import alog
11
+ from . import assert_checks as tas
12
+ from . import core_utils as cu
13
+ from . import gfs
14
+ from . import np_utils as npu
15
+ from . import utils as ut
16
+
17
+
18
def get_df_columns(df, discards=None):
  """Returns the df column names, excluding any listed in discards."""
  excluded = discards or {}

  return [name for name in df.columns if name not in excluded]
+
24
def get_typed_columns(df, type_fn, discards=None):
  """Returns the df columns (minus discards) whose dtype satisfies type_fn."""
  return [c for c in get_df_columns(df, discards=discards)
          if type_fn(df[c].dtype)]
+
32
+
33
def re_select_columns(df, re_cols):
  """Returns df columns matching (via re.match) any pattern in re_cols."""
  return [c for c in df.columns
          if any(re.match(rc, c) for rc in re_cols)]
+
43
+
44
def read_csv(path, rows_sample=100, dtype=None, args=None):
  """Reads a CSV file (possibly remote, via gfs) into a DataFrame.

  Args:
    path: The file path/URL to read.
    rows_sample: Number of rows sampled to detect numeric columns when a
      single (non-dict) dtype is given.
    dtype: None (no conversion), a dict mapping column -> dtype, or a single
      dtype applied to all numeric columns.
    args: Extra keyword arguments forwarded to pandas.read_csv().
  """
  with gfs.open(path, mode='r') as fd:
    args = dict() if args is None else args
    if args.get('index_col') is None:
      # Copy before mutating so the caller's dict is left untouched.
      args = args.copy()
      fields = ut.comma_split(fd.readline())
      fd.seek(0)
      # If 'index_col' is not specified, we use column 0 if its name is empty, otherwise
      # we disable it setting it to False.
      args['index_col'] = False if fields[0] else 0

    if dtype is None:
      return pd.read_csv(fd, **args)
    if cu.isdict(dtype):
      dtype = {c: np.dtype(t) for c, t in dtype.items()}
    else:
      # A single dtype: sample the file to find the numeric columns it
      # should apply to, then rewind for the real read.
      df_test = pd.read_csv(fd, nrows=rows_sample, **args)
      fd.seek(0)
      dtype = {c: dtype for c in get_typed_columns(df_test, npu.is_numeric)}

    return pd.read_csv(fd, dtype=dtype, **args)
+
66
+
67
def save_dataframe(df, path, **kwargs):
  """Saves df to path, with the format chosen by the file extension.

  Supported extensions are '.pkl' (pickle) and '.csv'; only the kwargs
  relevant to the chosen writer are forwarded.

  Raises:
    RuntimeError: If the path extension is not supported.
  """
  _, ext = os.path.splitext(os.path.basename(path))
  if ext == '.pkl':
    args = ut.dict_subset(kwargs, ('compression', 'protocol', 'storage_options'))
    if 'protocol' not in args:
      # Use the library-wide default pickle protocol unless overridden.
      args['protocol'] = ut.pickle_proto()
    with gfs.open(path, mode='wb') as fd:
      df.to_pickle(fd, **args)
  elif ext == '.csv':
    args = ut.dict_subset(kwargs, ('float_format', 'columns', 'header', 'index',
                                   'index_label', 'mode', 'encoding', 'quoting',
                                   'quotechar', 'line_terminator', 'chunksize',
                                   'date_format', 'doublequote', 'escapechar',
                                   'decimal', 'compression', 'error', 'storage_options'))

    # For CSV file, unless otherwise specified, and the index has no name, drop
    # the index column as it adds no value to the output (it's simply a sequential).
    if not df.index.name:
      args = ut.dict_setmissing(args, index=None)

    with gfs.open(path, mode='w') as fd:
      df.to_csv(fd, **args)
  else:
    alog.xraise(RuntimeError, f'Unknown extension: {ext}')
+
92
+
93
def load_dataframe(path, **kwargs):
  """Loads a DataFrame from path, with the format chosen by the extension.

  Supported extensions are '.pkl' (pickle) and '.csv'. For CSV,
  'rows_sample' and 'dtype' are consumed here and handed to read_csv();
  the remaining recognized kwargs are forwarded to pandas.

  Raises:
    RuntimeError: If the path extension is not supported.
  """
  _, ext = os.path.splitext(os.path.basename(path))
  if ext == '.pkl':
    with gfs.open(path, mode='rb') as fd:
      return pd.read_pickle(fd)
  elif ext == '.csv':
    rows_sample = kwargs.pop('rows_sample', 100)
    dtype = kwargs.pop('dtype', None)
    # Whitelist of pandas.read_csv() keyword arguments.
    args = ut.dict_subset(kwargs, ('sep', 'delimiter', 'header', 'names',
                                   'index_col', 'usecols', 'squeeze', 'prefix',
                                   'mangle_dupe_cols', 'dtype', 'engine',
                                   'converters', 'true_values', 'false_values',
                                   'skipinitialspace', 'skiprows', 'skipfooter',
                                   'nrows', 'na_values', 'keep_default_na', 'na_filter',
                                   'verbose', 'skip_blank_lines', 'parse_dates',
                                   'infer_datetime_format', 'keep_date_col', 'date_parser',
                                   'dayfirst', 'cache_dates', 'iterator', 'chunksize',
                                   'compression', 'thousands', 'decimal', 'lineterminator',
                                   'quotechar', 'quoting', 'doublequote', 'escapechar',
                                   'comment', 'encoding', 'dialect', 'error_bad_lines',
                                   'warn_bad_lines', 'delim_whitespace', 'low_memory',
                                   'memory_map', 'float_precision', 'storage_options'))

    return read_csv(path, rows_sample=rows_sample, dtype=dtype,
                    args=args)
  else:
    alog.xraise(RuntimeError, f'Unknown extension: {ext}')
+
121
+
122
def to_npdict(df, reset_index=False, dtype=None, no_convert=()):
  """Converts df into a dict mapping column name -> numpy array.

  With reset_index=True a named index becomes a regular column. When dtype
  is given, columns not listed in no_convert are converted via npu.astype().
  """
  if reset_index and df.index.name:
    df = df.reset_index()

  cdata = dict()
  for name in df.columns:
    values = df[name].to_numpy()
    if dtype is not None and name not in no_convert:
      values = npu.astype(values, name, dtype)
    cdata[name] = values

  return cdata
+
135
+
136
def load_dataframe_as_npdict(path, reset_index=False, dtype=None, no_convert=()):
  """Loads a dataframe from path and converts it to a dict of numpy arrays."""
  return to_npdict(load_dataframe(path),
                   reset_index=reset_index,
                   dtype=dtype,
                   no_convert=no_convert)
+
141
+
142
def column_or_index(df, name, numpy=True):
  """Looks name up among df columns first, then as the index name.

  Returns the matching data (as numpy array when numpy=True), or None when
  name matches neither a column nor the index.
  """
  data = df.get(name)
  if data is None and df.index.name == name:
    data = df.index
  if data is None:
    return None

  return data.to_numpy() if numpy else data
+
149
+
150
def columns_transform(df, cols, tfn):
  """Applies tfn in place to each named column (or the index) of df.

  tfn is called as tfn(name, data, index=is_index), and its return value
  replaces the column (or index) content. Raises RuntimeError when a name
  matches neither a column nor the index.
  """
  for name in cols:
    data = df.get(name)
    if data is not None:
      df[name] = tfn(name, data, index=False)
    elif df.index.name == name:
      new_index = tfn(name, df.index, index=True)
      df.index = pd.Index(data=new_index, name=df.index.name)
    else:
      alog.xraise(RuntimeError, f'Unable to find column or index named "{name}"')

  return df
+
162
+
163
def get_columns_index(df):
  """Returns a (column name -> position dict, column names list) pair."""
  cols = list(df.columns)

  return ut.make_index_dict(cols), cols
+
168
+
169
def concat_dataframes(files, **kwargs):
  # Loads every file and concatenates them into a single DataFrame,
  # returning None when files is empty.
  # NOTE(review): kwargs is forwarded both to load_dataframe() and to
  # pd.concat(); only keys valid for (or ignored by) both are safe to
  # pass -- confirm against callers before passing loader-only options.
  dfs = []
  for path in files:
    df = load_dataframe(path, **kwargs)
    dfs.append(df)

  return pd.concat(dfs, **kwargs) if dfs else None
+
177
+
178
def get_dataframe_groups(df, cols, cols_transforms=None):
  """Groups df row positions by the tuple of values of cols.

  Pandas groupby() is extremely slow when there are many groups as it builds
  a DataFrame for each group. This simply collects the rows associated with each
  tuple values representing the grouped columns.
  Row numbers must be strictly ascending within each group, do NOT change that!

  Args:
    df: The input dataframe.
    cols: The column names whose values form the grouping key.
    cols_transforms: Optional dict mapping a column name to a function used
      to transform that column's values before building the key tuple.

  Returns:
    A dict mapping key tuples to array.array('L') of row positions.
  """
  groups = collections.defaultdict(lambda: array.array('L'))
  if cols_transforms:
    # BUGFIX: this used to read cols_transforms.get(c, pycu.ident), but no
    # "pycu" module is imported in this file, so any call with
    # cols_transforms raised NameError. Use an inline identity instead.
    tcols = [(df[c], cols_transforms.get(c, lambda x: x)) for c in cols]
    for i in range(len(df)):
      k = tuple([f(d[i]) for d, f in tcols])
      groups[k].append(i)
  else:
    cdata = [df[c] for c in cols]
    for i in range(len(df)):
      k = tuple([d[i] for d in cdata])
      groups[k].append(i)

  return groups
+
197
+
198
def limit_per_group(df, cols, limit):
  """Keeps at most |limit| rows for each group of cols values.

  A positive limit keeps the first rows of each group, a negative one keeps
  the last rows.
  """
  mask = np.full(len(df), False)
  for rows in get_dataframe_groups(df, cols).values():
    selected = rows[: limit] if limit > 0 else rows[limit:]
    mask[selected] = True

  return df[mask]
+
211
+
212
def dataframe_column_rewrite(df, name, fn):
  """Replaces, in place, a df column (or its index) with fn(values).

  fn receives the current values as a numpy array. Raises RuntimeError when
  name matches neither a column nor the index.
  """
  data = df.get(name)
  if data is not None:
    df[name] = fn(data.to_numpy())
    return
  if df.index.name == name:
    nvalues = fn(df.index.to_numpy())
    # Rebuild the index with the same concrete Index type and name.
    df.index = type(df.index)(data=nvalues, name=df.index.name)
  else:
    alog.xraise(RuntimeError, f'No column or index named "{name}"')
+
222
+
223
def correlate(df, col, top_n=None):
  """Returns (column, correlation) pairs ranked by |correlation| with col.

  At most top_n pairs are returned (all of them when top_n is None).
  """
  ranked = df.corrwith(df[col]).sort_values(key=lambda x: abs(x), ascending=False)
  count = len(ranked) if top_n is None else min(top_n, len(ranked))

  return tuple((ranked.index[i], ranked[i]) for i in range(count))
+
230
+
231
def type_convert_dataframe(df, types):
  """Converts df columns in place to the dtypes in types (name -> dtype).

  DataFrame.astype() does not have an 'inplace' argument, hence the
  per-column loop.
  """
  for name, dtype in types.items():
    df[name] = df[name].astype(dtype)

  return df
+
238
+
239
def dataframe_rows_select(df, indices):
  """Returns the df rows at the given positional indices."""
  selected_labels = df.index[indices]

  return df.loc[selected_labels]
+
242
+
243
def sorted_index(df, col):
  """Returns df sorted by col, with col installed as the index.

  When col is already the index name, only the sort is performed.
  """
  if df.index.name == col:
    return df.sort_values(col)

  return df.sort_values(col, ignore_index=True).set_index(col)
+
251
+
252
def datetime_to_epoch(data):
  """Converts a datetime-like series to float seconds since the Unix epoch."""
  epoch = datetime.datetime(1970, 1, 1)

  return (pd.to_datetime(data) - epoch).dt.total_seconds()
+
@@ -0,0 +1,61 @@
1
+ import threading
2
+ import weakref
3
+
4
+ from . import alog
5
+ from . import scheduler as sch
6
+ from . import weak_call as wcall
7
+
8
+
9
class PeriodicTask:
  """Runs a callback at a fixed interval on a (shared) scheduler.

  The callback is held through a weak reference (WeakCall), so the task
  does not keep its target alive; the task stops re-issuing itself once
  the target has been garbage collected.
  """

  def __init__(self, name, periodic_fn, period, scheduler=None, stop_on_error=None):
    # name: Only used in error logging.
    # periodic_fn: Callable invoked every `period` seconds (weakly referenced).
    # scheduler: Defaults to the process-wide shared scheduler.
    # stop_on_error: Defaults to True (stop re-issuing when periodic_fn raises).
    self._name = name
    self._periodic_fn = wcall.WeakCall(periodic_fn)
    self._period = period
    self._scheduler = scheduler or sch.common_scheduler()
    self._stop_on_error = stop_on_error in (None, True)
    self._lock = threading.Lock()
    # Currently scheduled event, or None when the task is stopped.
    self._event = None
    # Set by _runner once the in-flight invocation has finished.
    self._completed_event = None

  def start(self):
    """Schedules the first run (no-op if already started). Returns self."""
    with self._lock:
      if self._event is None:
        self._schedule()

    return self

  def stop(self):
    """Cancels the task, waiting for any in-flight invocation to finish."""
    completed_event = None
    with self._lock:
      if self._event is not None:
        # If "events" is empty, we were not able to cancel the task, so it will be
        # in flight, and we need to wait for it to complete before exiting.
        events = self._scheduler.cancel(self._event)
        completed_event = self._completed_event if not events else None
        self._event = None

    # Wait outside the lock, since _runner needs the lock to signal completion.
    if completed_event is not None:
      completed_event.wait()

  def _schedule(self):
    # Must be called with self._lock held.
    self._event = self._scheduler.enter(self._period, self._runner)
    self._completed_event = threading.Event()

  def _runner(self):
    # Scheduler callback: run the user function and decide whether to re-issue.
    re_issue = True
    try:
      # WeakCall returns the GONE sentinel once the target has been collected.
      if self._periodic_fn() is wcall.GONE:
        re_issue = False
    except Exception as ex:
      alog.exception(ex, exmsg=f'Exception while running periodic task "{self._name}"')
      re_issue = not self._stop_on_error
    finally:
      with self._lock:
        self._completed_event.set()
        # self._event is None when stop() raced us; do not re-schedule then.
        if self._event is not None:
          if re_issue:
            self._schedule()
          else:
            self._event = None
@@ -0,0 +1,121 @@
1
+ import importlib
2
+ import os
3
+ import pickle
4
+ import sys
5
+
6
+ from . import alog
7
+ from . import core_utils as cu
8
+ from . import inspect_utils as iu
9
+ from . import std_module as stdm
10
+
11
+
12
# Root module names whose objects are considered safely picklable on the
# receiving side and are therefore never wrapped by PickleWrap.
KNOWN_MODULES = {
  'builtins',
  'numpy',
  'pandas',
  'torch',
  cu.root_module(__name__),
}

def add_known_module(modname):
  """Registers the root of modname as known (its objects won't be wrapped)."""
  KNOWN_MODULES.add(cu.root_module(modname))
+
23
+
24
def _needs_wrap(obj):
  """Tells whether obj comes from an unknown module and must be pickle-wrapped."""
  objmod = iu.moduleof(obj)
  if objmod is None:
    # Unknown origin: be conservative and wrap.
    return True

  root = cu.root_module(objmod)

  return root not in KNOWN_MODULES and not stdm.is_std_module(root)
+
33
+
34
def _wrap(obj, pickle_module):
  """Recursively replaces objects from unknown modules with PickleWrap proxies.

  Containers (list/tuple/dict) are rebuilt only when at least one nested
  element actually got wrapped; otherwise the original object is returned
  unchanged (identity-preserving fast path).
  """
  wrapped = 0
  if isinstance(obj, (list, tuple)):
    wobj = []
    for v in obj:
      wv = _wrap(v, pickle_module)
      if wv is not v:
        wrapped += 1

      wobj.append(wv)

    # NOTE(review): type(obj)(wobj) assumes the sequence type accepts a
    # single iterable argument (not true for e.g. namedtuples) -- confirm.
    return type(obj)(wobj) if wrapped else obj
  elif cu.isdict(obj):
    wobj = type(obj)()
    for k, v in obj.items():
      # Both keys and values may need wrapping.
      wk = _wrap(k, pickle_module)
      if wk is not k:
        wrapped += 1
      wv = _wrap(v, pickle_module)
      if wv is not v:
        wrapped += 1

      wobj[wk] = wv

    return wobj if wrapped else obj
  elif isinstance(obj, PickleWrap):
    # Already wrapped; leave as is.
    return obj
  elif _needs_wrap(obj):
    return PickleWrap(obj, pickle_module=pickle_module)
  else:
    return obj
+
66
+
67
def _unwrap(obj, pickle_module):
  """Inverse of _wrap(): recursively restores PickleWrap proxies.

  Containers (list/tuple/dict) are rebuilt only when at least one nested
  element was actually unwrapped; otherwise the original object is
  returned unchanged. A proxy whose payload cannot be reloaded here is
  left wrapped (and the failure logged).
  """
  unwrapped = 0
  if isinstance(obj, (list, tuple)):
    uwobj = []
    for v in obj:
      wv = _unwrap(v, pickle_module)
      if wv is not v:
        unwrapped += 1

      uwobj.append(wv)

    return type(obj)(uwobj) if unwrapped else obj
  elif cu.isdict(obj):
    uwobj = type(obj)()
    for k, v in obj.items():
      # Both keys and values may have been wrapped.
      wk = _unwrap(k, pickle_module)
      if wk is not k:
        unwrapped += 1
      wv = _unwrap(v, pickle_module)
      if wv is not v:
        unwrapped += 1

      uwobj[wk] = wv

    return uwobj if unwrapped else obj
  elif isinstance(obj, PickleWrap):
    try:
      return obj.load(pickle_module=pickle_module)
    except Exception as ex:
      # Best-effort: the unpickling module may be missing here.
      alog.debug(f'Unable to reload pickle-wrapped data ({obj.wrapped_class()}): {ex}')
      return obj
  else:
    return obj
+
101
+
102
def wrap(obj, pickle_module=pickle):
  """Returns obj with unknown-module parts replaced by PickleWrap proxies."""
  return _wrap(obj, pickle_module=pickle_module)
+
105
+
106
def unwrap(obj, pickle_module=pickle):
  """Returns obj with PickleWrap proxies restored to their real objects."""
  return _unwrap(obj, pickle_module=pickle_module)
+
109
+
110
class PickleWrap:
  """Holds an object in pickled-bytes form, remembering its class name.

  Useful to ferry objects whose module may not be importable at the
  destination: the payload stays opaque until load() is called.
  """

  def __init__(self, obj, pickle_module=pickle):
    # Record the qualified class name for diagnostics, then serialize.
    self._class = iu.qual_name(obj)
    self._data = pickle_module.dumps(obj)

  def wrapped_class(self):
    """Returns the qualified class name of the wrapped object."""
    return self._class

  def load(self, pickle_module=pickle):
    """Unpickles and returns the wrapped object."""
    return pickle_module.loads(self._data)