python_misc_utils-0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. py_misc_utils/__init__.py +0 -0
  2. py_misc_utils/abs_timeout.py +12 -0
  3. py_misc_utils/alog.py +311 -0
  4. py_misc_utils/app_main.py +179 -0
  5. py_misc_utils/archive_streamer.py +112 -0
  6. py_misc_utils/assert_checks.py +118 -0
  7. py_misc_utils/ast_utils.py +121 -0
  8. py_misc_utils/async_manager.py +189 -0
  9. py_misc_utils/break_control.py +63 -0
  10. py_misc_utils/buffered_iterator.py +35 -0
  11. py_misc_utils/cached_file.py +507 -0
  12. py_misc_utils/call_limiter.py +26 -0
  13. py_misc_utils/call_result_selector.py +13 -0
  14. py_misc_utils/cleanups.py +85 -0
  15. py_misc_utils/cmd.py +97 -0
  16. py_misc_utils/compression.py +116 -0
  17. py_misc_utils/cond_waiter.py +13 -0
  18. py_misc_utils/context_base.py +18 -0
  19. py_misc_utils/context_managers.py +67 -0
  20. py_misc_utils/core_utils.py +577 -0
  21. py_misc_utils/daemon_process.py +252 -0
  22. py_misc_utils/data_cache.py +46 -0
  23. py_misc_utils/date_utils.py +90 -0
  24. py_misc_utils/debug.py +24 -0
  25. py_misc_utils/dyn_modules.py +50 -0
  26. py_misc_utils/dynamod.py +103 -0
  27. py_misc_utils/env_config.py +35 -0
  28. py_misc_utils/executor.py +239 -0
  29. py_misc_utils/file_overwrite.py +29 -0
  30. py_misc_utils/fin_wrap.py +77 -0
  31. py_misc_utils/fp_utils.py +47 -0
  32. py_misc_utils/fs/__init__.py +0 -0
  33. py_misc_utils/fs/file_fs.py +127 -0
  34. py_misc_utils/fs/ftp_fs.py +242 -0
  35. py_misc_utils/fs/gcs_fs.py +196 -0
  36. py_misc_utils/fs/http_fs.py +241 -0
  37. py_misc_utils/fs/s3_fs.py +417 -0
  38. py_misc_utils/fs_base.py +133 -0
  39. py_misc_utils/fs_utils.py +207 -0
  40. py_misc_utils/gcs_fs.py +169 -0
  41. py_misc_utils/gen_indices.py +54 -0
  42. py_misc_utils/gfs.py +371 -0
  43. py_misc_utils/git_repo.py +77 -0
  44. py_misc_utils/global_namespace.py +110 -0
  45. py_misc_utils/http_async_fetcher.py +139 -0
  46. py_misc_utils/http_server.py +196 -0
  47. py_misc_utils/http_utils.py +143 -0
  48. py_misc_utils/img_utils.py +20 -0
  49. py_misc_utils/infix_op.py +20 -0
  50. py_misc_utils/inspect_utils.py +205 -0
  51. py_misc_utils/iostream.py +21 -0
  52. py_misc_utils/iter_file.py +117 -0
  53. py_misc_utils/key_wrap.py +46 -0
  54. py_misc_utils/lazy_import.py +25 -0
  55. py_misc_utils/lockfile.py +164 -0
  56. py_misc_utils/mem_size.py +64 -0
  57. py_misc_utils/mirror_from.py +72 -0
  58. py_misc_utils/mmap.py +16 -0
  59. py_misc_utils/module_utils.py +196 -0
  60. py_misc_utils/moving_average.py +19 -0
  61. py_misc_utils/msgpack_streamer.py +26 -0
  62. py_misc_utils/multi_wait.py +24 -0
  63. py_misc_utils/multiprocessing.py +102 -0
  64. py_misc_utils/named_array.py +224 -0
  65. py_misc_utils/no_break.py +46 -0
  66. py_misc_utils/no_except.py +32 -0
  67. py_misc_utils/np_ml_framework.py +184 -0
  68. py_misc_utils/np_utils.py +346 -0
  69. py_misc_utils/ntuple_utils.py +38 -0
  70. py_misc_utils/num_utils.py +54 -0
  71. py_misc_utils/obj.py +73 -0
  72. py_misc_utils/object_cache.py +100 -0
  73. py_misc_utils/object_tracker.py +88 -0
  74. py_misc_utils/ordered_set.py +71 -0
  75. py_misc_utils/osfd.py +27 -0
  76. py_misc_utils/packet.py +22 -0
  77. py_misc_utils/parquet_streamer.py +69 -0
  78. py_misc_utils/pd_utils.py +254 -0
  79. py_misc_utils/periodic_task.py +61 -0
  80. py_misc_utils/pickle_wrap.py +121 -0
  81. py_misc_utils/pipeline.py +98 -0
  82. py_misc_utils/remap_pickle.py +50 -0
  83. py_misc_utils/resource_manager.py +155 -0
  84. py_misc_utils/rnd_utils.py +56 -0
  85. py_misc_utils/run_once.py +19 -0
  86. py_misc_utils/scheduler.py +135 -0
  87. py_misc_utils/select_params.py +300 -0
  88. py_misc_utils/signal.py +141 -0
  89. py_misc_utils/skl_utils.py +270 -0
  90. py_misc_utils/split.py +147 -0
  91. py_misc_utils/state.py +53 -0
  92. py_misc_utils/std_module.py +56 -0
  93. py_misc_utils/stream_dataframe.py +176 -0
  94. py_misc_utils/streamed_file.py +144 -0
  95. py_misc_utils/tempdir.py +79 -0
  96. py_misc_utils/template_replace.py +51 -0
  97. py_misc_utils/tensor_stream.py +269 -0
  98. py_misc_utils/thread_context.py +33 -0
  99. py_misc_utils/throttle.py +30 -0
  100. py_misc_utils/time_trigger.py +18 -0
  101. py_misc_utils/timegen.py +11 -0
  102. py_misc_utils/traceback.py +49 -0
  103. py_misc_utils/tracking_executor.py +91 -0
  104. py_misc_utils/transform_array.py +42 -0
  105. py_misc_utils/uncompress.py +35 -0
  106. py_misc_utils/url_fetcher.py +157 -0
  107. py_misc_utils/utils.py +538 -0
  108. py_misc_utils/varint.py +50 -0
  109. py_misc_utils/virt_array.py +52 -0
  110. py_misc_utils/weak_call.py +33 -0
  111. py_misc_utils/work_results.py +100 -0
  112. py_misc_utils/writeback_file.py +43 -0
  113. python_misc_utils-0.2.dist-info/METADATA +36 -0
  114. python_misc_utils-0.2.dist-info/RECORD +117 -0
  115. python_misc_utils-0.2.dist-info/WHEEL +5 -0
  116. python_misc_utils-0.2.dist-info/licenses/LICENSE +13 -0
  117. python_misc_utils-0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,19 @@
1
+ from . import assert_checks as tas
2
+ from . import num_utils as nu
3
+
4
+
5
class MovingAverage:
  """Exponentially-weighted moving average.

  The blend `factor` must lie in [0, 1]; `value` starts at `init` (or at the
  first sample when `init` is None) and is re-blended on every update().
  """

  def __init__(self, factor, init=None):
    tas.check(factor >= 0.0 and factor <= 1.0, msg=f'Invalid factor: {factor:.4e}')
    self._factor = factor
    self.value = init

  def update(self, value):
    """Folds a new sample into the average and returns the current value."""
    if self.value is None:
      # First sample (no initial value given): adopt it as-is.
      self.value = value
    else:
      self.value = nu.mix(self.value, value, self._factor)

    return self.value
19
+
@@ -0,0 +1,26 @@
1
+ import msgpack
2
+
3
+ from . import alog
4
+ from . import gfs
5
+ from . import utils as ut
6
+
7
+
8
class MsgPackStreamer:
  """Iterates msgpack records read from a (possibly remote) URL.

  Records are dicts; keys listed in `rename_columns` are renamed on the fly.
  Extra keyword arguments are forwarded to gfs.open().
  """

  def __init__(self, url, rename_columns=None, **kwargs):
    self._url = url
    self._rename_columns = ut.value_or(rename_columns, dict())
    self._kwargs = kwargs

  def generate(self):
    """Yields the decoded records, one at a time."""
    with gfs.open(self._url, mode='rb', **self._kwargs) as stream:
      for recd in msgpack.Unpacker(stream):
        # Apply the requested key renames in place before yielding.
        for key, name in self._rename_columns.items():
          recd[name] = recd.pop(key)

        yield recd

  def __iter__(self):
    return self.generate()
26
+
@@ -0,0 +1,24 @@
1
+ import threading
2
+
3
+
4
class MultiWait:
  """Barrier-like helper: wait() unblocks once signal() has been called a
  cumulative `count` times (signals beyond `count` are clamped).

  wait() returns True when the count was reached, False on timeout.
  """

  def __init__(self, count):
    self._count = count
    self._sigcount = 0
    self._lock = threading.Lock()
    self._cond = threading.Condition(lock=self._lock)

  def signal(self, n=1):
    """Accounts `n` signals, waking all waiters once `count` is reached."""
    with self._lock:
      self._sigcount = min(self._sigcount + n, self._count)
      if self._sigcount == self._count:
        self._cond.notify_all()

  def wait(self, timeout=None):
    """Blocks until `count` signals have arrived (or `timeout` expires).

    Uses Condition.wait_for() so the predicate is re-checked on every wakeup;
    a bare Condition.wait() would return True on a spurious wakeup even
    though the signal count had not been reached.
    """
    with self._lock:
      return self._cond.wait_for(lambda: self._sigcount >= self._count,
                                 timeout=timeout)
24
+
@@ -0,0 +1,102 @@
1
+ import functools
2
+ import gc
3
+ import multiprocessing
4
+ import signal
5
+ import sys
6
+
7
+ from . import alog
8
+ from . import cleanups
9
+ from . import global_namespace as gns
10
+
11
+
12
class TerminationError(Exception):
  """Raised by the signal handler when SIGTERM is delivered to the process."""
14
+
15
+ def _sig_handler(sig, frame):
16
+ if sig == signal.SIGINT:
17
+ raise KeyboardInterrupt()
18
+ elif sig == signal.SIGTERM:
19
+ raise TerminationError()
20
+
21
+
22
def _process_setup():
  """Installs the SIGINT/SIGTERM translation handlers in the child process."""
  for sig in (signal.SIGINT, signal.SIGTERM):
    signal.signal(sig, _sig_handler)
25
+
26
+
27
def _process_cleanup():
  """Runs the registered cleanups and quiesces signals before process exit."""
  cleanups.run()
  gc.collect()
  # From here on {INT, TERM} are ignored, so late signals cannot interrupt
  # the interpreter teardown.
  for sig in (signal.SIGINT, signal.SIGTERM):
    signal.signal(sig, signal.SIG_IGN)
33
+
34
+
35
_GNS_KEY = 'gns'

def _wrap_procfn_parent(method):
  """Captures the parent-side context (start method + global namespace)."""
  return {'method': method, _GNS_KEY: gns.parent_switch(method)}
42
+
43
+
44
def _wrap_procfn_child(method, pctx):
  """Child-side: restores the parent global namespace, when one was captured."""
  captured = pctx.pop(_GNS_KEY, None)
  if captured is not None:
    gns.child_switch(method, captured)

  return pctx
50
+
51
+
52
_CONTEXT_KEY = '_parent_context'

def _capture_parent_context(method, kwargs):
  """Stashes the captured parent context into the child kwargs and returns them."""
  kwargs[_CONTEXT_KEY] = _wrap_procfn_parent(method)

  return kwargs
59
+
60
+
61
def _apply_child_context(kwargs):
  """Pops and applies the parent context stored by _capture_parent_context()."""
  ctx = kwargs.pop(_CONTEXT_KEY)
  _wrap_procfn_child(ctx['method'], ctx)

  return kwargs
66
+
67
+
68
def procfn_wrap(procfn, *args, **kwargs):
  """Top-level child process entry point.

  Installs the signal handlers, runs `procfn`, and guarantees cleanup.
  SIGINT/SIGTERM (surfaced as exceptions by _sig_handler) result in a clean
  exit(1); any other exception is logged and re-raised.
  """
  try:
    _process_setup()

    return procfn(*args, **kwargs)
  except (KeyboardInterrupt, TerminationError):
    sys.exit(1)
  except Exception as ex:
    alog.exception(ex, exmsg='Exception while running process function')
    raise
  finally:
    _process_cleanup()
80
+
81
+
82
def _wrapped_procfn(procfn, *args, **kwargs):
  """Adapter run inside the child: applies the parent context, then calls procfn."""
  return procfn(*args, **_apply_child_context(kwargs))
86
+
87
+
88
def create_process(procfn, args=(), kwargs=None, context=None, daemon=None):
  """Creates a multiprocessing.Process wired with signal handling, cleanups
  and parent global-namespace propagation.

  `context` may be None (default context), a start-method name, or an
  already-built multiprocessing context object.
  """
  if isinstance(context, str):
    mpctx = multiprocessing.get_context(method=context)
  else:
    mpctx = context if context is not None else multiprocessing

  # Copy so the caller's kwargs dict is never mutated.
  child_kwargs = dict() if kwargs is None else kwargs.copy()
  child_kwargs = _capture_parent_context(mpctx.get_start_method(), child_kwargs)
  target = functools.partial(procfn_wrap, _wrapped_procfn, procfn, *args, **child_kwargs)

  return mpctx.Process(target=target, daemon=daemon)
102
+
@@ -0,0 +1,224 @@
1
+ import array
2
+ import re
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ from . import assert_checks as tas
8
+ from . import core_utils as cu
9
+ from . import np_utils as npu
10
+
11
+
12
# Format codes whose values are not numeric ('x' opaque, 'S' string).
_NOT_NUMERIC = 'xS'


class Field:
  """Single named column of a NamedArray.

  Numeric formats are backed by a compact array.array; non-numeric ones
  ('x' and 'S') fall back to a plain list. A format may carry a repeat
  count (e.g. '3d'), in which case each logical row holds `size` values
  stored flat. String values are interned through the shared `str_tbl`.
  """

  __slots__ = ('name', 'data', 'size', 'fmt', 'str_tbl')

  def __init__(self, name, data, size, fmt, str_tbl):
    self.name = name
    self.data = data
    self.size = size
    self.fmt = fmt
    self.str_tbl = str_tbl

  def np_array(self):
    """Returns the column as a numpy array, reshaped (N, size) when size > 1."""
    arr = np.array(self.data)
    if self.size > 1:
      return arr.reshape(-1, self.size)

    return arr

  def append(self, arg):
    """Appends one logical row (a scalar, or `size` values)."""
    if self.fmt == 'S':
      arg = self.stringify(arg)

    if self.size == 1:
      self.data.append(arg)
    else:
      assert self.size == len(arg), f'{self.name}({self.size}) vs. {len(arg)}'
      self.data.extend(arg)

  def extend(self, arg):
    """Appends a flat sequence of values (a whole number of rows)."""
    assert len(arg) % self.size == 0, f'{self.name}({self.size}) vs. {len(arg)}'
    if self.fmt == 'S':
      arg = self.stringify(arg)

    if isinstance(arg, np.ndarray) and isinstance(self.data, array.array):
      # Fast path: blit the numpy buffer straight into the array storage,
      # converting dtype first when needed.
      nptype = np.dtype(self.data.typecode)
      if nptype != arg.dtype:
        arg = arg.astype(nptype)
      self.data.frombytes(arg.tobytes())
    else:
      self.data.extend(arg)

  def stringify(self, arg):
    """Interns string value(s) through the shared string table."""
    if self.size == 1:
      return self.str_tbl.add(arg)

    return tuple(self.str_tbl.add(x) for x in arg)

  def __len__(self):
    return len(self.data) // self.size

  def __getitem__(self, i):
    if self.size == 1:
      return self.data[i]
    if isinstance(i, int):
      offset = i * self.size

      return self.data[offset: offset + self.size]

    # Slice access: return a list of per-row chunks.
    start, stop, step = (x * self.size for x in i.indices(len(self)))

    return [self.data[n: n + self.size] for n in range(start, stop, step)]

  @staticmethod
  def create(name, fmt, str_tbl):
    """Builds a Field from a format spec like 'd' or '3d'."""
    m = re.match(r'(\d+)([a-zA-Z])', fmt)
    size, efmt = (int(m.group(1)), m.group(2)) if m else (1, fmt)
    data = [] if efmt in _NOT_NUMERIC else array.array(efmt)

    return Field(name, data, size, efmt, str_tbl)
89
+
90
+
91
class NamedArray:
  """Column-oriented store of named, typed sequences.

  Construction accepts either a names list/CSV plus a parallel `fmt`
  specification, or combined "name=fmt" entries when `fmt` is None. Formats
  follow array.array typecodes, optionally prefixed by a repeat count
  (e.g. '3d'); 'x' and 'S' mark non-numeric columns.
  """

  def __init__(self, names, fmt=None):
    fnames = re.split(r'\s*,\s*', names) if isinstance(names, str) else names
    if fmt is None:
      # Formats embedded in the names, as "name=fmt" entries.
      cnames, ffmt = [], []
      for name_format in fnames:
        m = re.match(r'(\w+)\s*=\s*(\w+)$', name_format)
        tas.check(m, msg=f'Invalid name=format specification: {name_format}')
        cnames.append(m.group(1))
        ffmt.append(m.group(2))

      fnames = cnames
    else:
      ffmt = re.findall(r'\d*[a-zA-Z]', fmt) if isinstance(fmt, str) else fmt
      tas.check_eq(len(fnames), len(ffmt),
                   msg=f'Mismatching names and format sizes: {fnames} vs {ffmt}')

    str_tbl = cu.StringTable()
    fields = {fname: Field.create(fname, ffield, str_tbl)
              for fname, ffield in zip(fnames, ffmt)}

    self._fields = fields
    self._fieldseq = tuple(fields.values())
    self._str_tbl = str_tbl

  def add_column(self, name, fmt, data):
    """Adds a new column populated from `data` (flattened when a numpy array)."""
    tas.check(name not in self._fields, msg=f'Column "{name}" already exists')

    field = Field.create(name, fmt, self._str_tbl)
    field.extend(data.flatten() if isinstance(data, np.ndarray) else data)

    self._fields[name] = field
    self._fieldseq += (field,)

  def append(self, *args):
    """Appends one row, positionally matched to the columns."""
    for field, arg in zip(self._fieldseq, args):
      field.append(arg)

  def extend(self, other):
    """Appends all the rows of another NamedArray (positionally matched)."""
    for field, ofield in zip(self._fieldseq, other._fieldseq):
      field.extend(ofield.data)

  def append_extend(self, *args):
    """Extends each column with the corresponding flat sequence in `args`."""
    for field, arg in zip(self._fieldseq, args):
      field.extend(arg)

  def kwappend(self, **kwargs):
    """Appends one row, with values matched to columns by name."""
    for name, field in self._fields.items():
      field.append(kwargs[name])

  def get_tuple_item(self, i):
    """Returns row `i` as a tuple, one entry per column."""
    return tuple(field[i] for field in self._fieldseq)

  def get_arrays(self, names=None):
    """Returns numpy arrays for `names` (all the columns when not given)."""
    if names:
      return tuple(self._fields[name].np_array() for name in names)

    return tuple(field.np_array() for field in self._fieldseq)

  def get_array(self, name):
    """Returns the numpy array of the given column."""
    return self._fields[name].np_array()

  def data(self):
    """Returns a dict mapping column name to its numpy array."""
    return {field.name: field.np_array() for field in self._fieldseq}

  def to_numpy(self, dtype=None):
    """Returns a (rows, columns) numpy matrix with the whole content.

    Only valid when every column is numeric and scalar (size == 1).
    """
    for field in self._fieldseq:
      tas.check(field.fmt not in _NOT_NUMERIC,
                msg=f'Only purely numeric arrays can be converted to ' \
                f'numpy: {field.name}={field.fmt}')
      tas.check_eq(field.size, 1,
                   msg=f'Only fields with size==1 can be converted to ' \
                   f'numpy: {field.name}={field.size}')

    if dtype is None:
      dtype = npu.infer_np_dtype(self.dtypes())

    na = np.empty(self.shape, dtype=dtype)
    for i, field in enumerate(self._fieldseq):
      na[:, i] = field.data

    return na

  def __len__(self):
    # All columns must agree on the row count; verify while computing it.
    size = None
    for field in self._fieldseq:
      fsize = len(field)
      if size is None:
        size = fsize
      else:
        tas.check_eq(fsize, size,
                     msg=f'Unmatching size for "{field.name}": {fsize} vs. {size}')

    return size or 0

  def __getitem__(self, i):
    return {field.name: field[i] for field in self._fieldseq}

  def __array__(self, dtype=None):
    return self.to_numpy(dtype=dtype)

  def to_dataframe(self):
    """Returns a pandas DataFrame with the content (multi-valued columns as lists)."""
    df_data = {name: arr if arr.ndim == 1 else arr.tolist()
               for name, arr in self.data().items()}

    return pd.DataFrame(data=df_data)

  @property
  def shape(self):
    return (len(self), len(self._fieldseq))

  def dtypes(self):
    """Returns the per-column numpy dtypes (object for non-numeric columns)."""
    return tuple(np.dtype('O') if field.fmt in _NOT_NUMERIC else np.dtype(field.fmt)
                 for field in self._fieldseq)
224
+
@@ -0,0 +1,46 @@
1
+ import signal
2
+
3
+ from . import core_utils as cu
4
+ from . import signal as sgn
5
+
6
+
7
class NoBreak:
  """Context manager deferring SIGINT/SIGTERM delivery.

  While active, the selected signals are only recorded; on exit, when
  `exit_trigger` is set, the recorded signals are re-triggered so their
  normal effect happens after the protected section.
  """

  SIGMAP = {
      'INT': signal.SIGINT,
      'TERM': signal.SIGTERM,
  }

  def __init__(self, sigs=None, exit_trigger=False):
    if sigs is None:
      sigs = self.SIGMAP.keys()
    elif isinstance(sigs, str):
      sigs = cu.splitstrip(sigs, ',')

    self._signals = tuple(self.SIGMAP[sig] for sig in sigs)
    self._exit_trigger = exit_trigger

  def __enter__(self):
    self._signal_received = []
    for sig in self._signals:
      sgn.signal(sig, self._handler, prio=sgn.STD_PRIO)

    return self

  def _handler(self, sig, frame):
    # Record, do not deliver: the signal is (optionally) replayed at exit.
    self._signal_received.append((sig, frame))

    return sgn.HANDLED

  def __exit__(self, *exc):
    for sig in self._signals:
      sgn.unsignal(sig, self._handler)

    if self._exit_trigger:
      for sig, frame in self._signal_received:
        sgn.trigger(sig, frame)

    return False
46
+
@@ -0,0 +1,32 @@
1
+ # Using logging module directly here to avoid import the alog module, since this
2
+ # is supposed to be a lower level module with minimal/no local dependencies.
3
+ import logging
4
+ import traceback
5
+
6
+
7
def no_except(fn, *args, **kwargs):
  """Invokes fn(*args, **kwargs), trading exceptions for return values.

  Returns the function result on success; on failure the caught Exception
  object is returned, after being logged (with traceback) at WARNING level.
  """
  try:
    return fn(*args, **kwargs)
  except Exception as ex:
    tb = traceback.format_exc()
    logging.warning(f'Exception while running function: {ex}\n{tb}')

    return ex
17
+
18
+
19
def qno_except(fn, *args, **kwargs):
  """Quiet variant of no_except(): returns the caught exception, no logging."""
  try:
    return fn(*args, **kwargs)
  except Exception as ex:
    return ex
24
+
25
+
26
def xwrap_fn(fn, *args, **kwargs):
  """Binds fn and its arguments into a zero-argument callable that never
  raises (exceptions are returned, via no_except())."""

  def runner():
    return no_except(fn, *args, **kwargs)

  return runner
32
+
@@ -0,0 +1,184 @@
1
+ import collections
2
+ import os
3
+
4
+
5
+ _MODULES = dict()
6
+
7
+
8
+ def _register_attr(mod, name, attr):
9
+ xa = getattr(mod, name, None)
10
+ if xa is not None:
11
+ raise RuntimeError(f'Attribute "{name}" already exists: {xa}')
12
+
13
+ setattr(mod, name, attr)
14
+
15
+
16
+ def _register(name, mod, checkfn, fromfn, attrs):
17
+ _MODULES[name] = mod
18
+ mod.__npml_name = name
19
+ mod.__npml_check = checkfn
20
+ mod.__npml_from = fromfn
21
+
22
+ for attr_name, attr_value in attrs.items():
23
+ _register_attr(mod, attr_name, attr_value)
24
+
25
+
26
+ def _parse_priorities():
27
+ prefs = os.getenv('NPML_ORDER', 'torch,np,jax,tf').split(',')
28
+
29
+ return {mod: len(prefs) - i for i, mod in enumerate(prefs)}
30
+
31
+
32
# Numpy
try:
  import numpy as np

  def _np_from(mod, t, tref):
    # Converts `t` (owned by framework `mod`) into a numpy array.
    # The `mod is not None` guard is load-bearing: frameworks that failed
    # to import are None, so `mod is torch` would spuriously match when
    # both `mod` and `torch` are None.
    if mod is not None:
      if mod is torch or mod is tfnp:
        return t.numpy()

    return np.asarray(t)

  def _np_check(t):
    # True when `t` is already a numpy array.
    return isinstance(t, np.ndarray)

  _register('np', np, _np_check, _np_from,
            {
              'item': lambda t: t.item(),
              'tolist': lambda t: t.tolist(),
            })
except ImportError:
  # Numpy not installed; leave the sentinel so identity checks still work.
  np = None
53
+
54
+
55
# PyTorch
try:
  import torch
  from torch.utils import dlpack as torch_dlpack

  def _torch_from(mod, t, tref):
    # Converts `t` (owned by framework `mod`) to a torch tensor on the same
    # device as the reference tensor `tref`. Cross-framework conversions go
    # through DLPack to avoid host round-trips where possible.
    # The `mod is not None` guard matters: frameworks that failed to import
    # are None, and `mod is np` would spuriously match a None mod.
    if mod is not None:
      if mod is np:
        return torch.from_numpy(t).to(tref.device)
      if mod is jaxnp:
        return torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(t)).to(tref.device)
      if mod is tfnp:
        return torch_dlpack.from_dlpack(tf_dlpack.to_dlpack(t)).to(tref.device)

    return torch.tensor(t).to(tref.device)

  def _torch_check(t):
    # True when `t` is a torch tensor.
    return isinstance(t, torch.Tensor)

  _register('torch', torch, _torch_check, _torch_from,
            {
              'item': lambda t: t.item(),
              'tolist': lambda t: t.tolist(),
            })
except ImportError:
  # PyTorch not installed; leave the sentinel so identity checks still work.
  torch = None
81
+
82
+
83
# JAX
try:
  import jax
  from jax import dlpack as jax_dlpack
  import jax.numpy as jaxnp

  def _jaxdev(t):
    # A jax.Array may live on multiple devices; use the first as "the" device.
    return next(iter(t.devices()))

  def _jax_from(mod, t, tref):
    # Converts `t` (owned by framework `mod`) to a jax array placed on the
    # device of the reference tensor `tref`, via DLPack where possible.
    # The `mod is not None` guard matters: frameworks that failed to import
    # are None, and `mod is torch` would spuriously match a None mod.
    if mod is not None:
      if mod is torch:
        return jax.device_put(jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(t)), _jaxdev(tref))
      if mod is tfnp:
        return jax.device_put(jax_dlpack.from_dlpack(tf_dlpack.to_dlpack(t)), _jaxdev(tref))

    return jax.device_put(jaxnp.asarray(t), _jaxdev(tref))

  def _jax_check(t):
    # True when `t` is a jax array.
    return isinstance(t, jax.Array)

  _register('jax', jaxnp, _jax_check, _jax_from,
            {
              'item': lambda t: t.item(),
              'tolist': lambda t: t.tolist(),
            })
except ImportError:
  # JAX not installed; leave the sentinel so identity checks still work.
  jaxnp = None
111
+
112
+
113
# Tensorflow
try:
  import tensorflow as tf
  import tensorflow.experimental.dlpack as tf_dlpack
  import tensorflow.experimental.numpy as tfnp

  def _tf_from(mod, t, tref):
    # Converts `t` (owned by framework `mod`) to a TF tensor, pinned to the
    # reference tensor's device via the tf.device() scope.
    # The `mod is not None` guard matters: frameworks that failed to import
    # are None, and `mod is torch` would spuriously match a None mod.
    if mod is not None:
      if mod is torch:
        with tf.device(tref.device):
          return tf_dlpack.from_dlpack(torch_dlpack.to_dlpack(t))
      if mod is jaxnp:
        with tf.device(tref.device):
          return tf_dlpack.from_dlpack(jax_dlpack.to_dlpack(t))

    with tf.device(tref.device):
      return tf.convert_to_tensor(t)

  def _tf_check(t):
    # True when `t` is a TF tensor.
    return tf.is_tensor(t)

  # Make TF tensors expose numpy-like methods/promotion rules.
  tfnp.experimental_enable_numpy_behavior()
  _register('tf', tfnp, _tf_check, _tf_from,
            {
              'item': lambda t: t.item(),
              'tolist': lambda t: t.tolist(),
            })
except ImportError:
  # Tensorflow not installed; leave the sentinel so identity checks still work.
  tfnp = None
142
+
143
+
144
# Framework used by resolve() when none of the arguments belongs to a known
# framework; override with $NPML_DEFAULT.
_DEFAULT_MODULE = os.getenv('NPML_DEFAULT', 'np')

if _DEFAULT_MODULE not in _MODULES:
  raise RuntimeError(f'Unable to find default Numpy ML module: {_DEFAULT_MODULE}')


_MODULES_PRIORITY = _parse_priorities()
# Registered modules sorted highest-priority first; names absent from
# $NPML_ORDER sink to the end (priority -1).
_MODULES_SEQ = tuple(_MODULES[x] for x in sorted(list(_MODULES.keys()),
                                                 key=lambda m: _MODULES_PRIORITY.get(m, -1),
                                                 reverse=True))
154
+
155
+
156
def _get_module(t):
  """Returns the registered framework module owning tensor `t`, else None."""
  for mod in _MODULES_SEQ:
    if mod.__npml_check(t):
      return mod

  return None
160
+
161
+
162
def resolve(*args):
  """Picks the highest-priority framework among `args` and converts all
  arguments to it.

  Returns (module, converted_args). When no argument belongs to a known
  framework, the default module is returned with the args untouched.
  """
  mods = []
  tprio, tmod, tref = float('-inf'), None, None
  for t in args:
    mod = _get_module(t)
    mods.append(mod)
    if mod is not None:
      # A registered framework missing from $NPML_ORDER gets priority -1
      # instead of raising KeyError, so it can still win over "no framework
      # at all" (tprio starts at -inf).
      prio = _MODULES_PRIORITY.get(mod.__npml_name, -1)
      if prio > tprio:
        tmod, tprio, tref = mod, prio, t

  if tmod is None:
    return _MODULES[_DEFAULT_MODULE], args

  rargs = list(args)
  for i, mod in enumerate(mods):
    if tmod is not mod:
      rargs[i] = tmod.__npml_from(mod, args[i], tref)

  return tmod, tuple(rargs)
184
+