python-misc-utils 0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117) hide show
  1. py_misc_utils/__init__.py +0 -0
  2. py_misc_utils/abs_timeout.py +12 -0
  3. py_misc_utils/alog.py +311 -0
  4. py_misc_utils/app_main.py +179 -0
  5. py_misc_utils/archive_streamer.py +112 -0
  6. py_misc_utils/assert_checks.py +118 -0
  7. py_misc_utils/ast_utils.py +121 -0
  8. py_misc_utils/async_manager.py +189 -0
  9. py_misc_utils/break_control.py +63 -0
  10. py_misc_utils/buffered_iterator.py +35 -0
  11. py_misc_utils/cached_file.py +507 -0
  12. py_misc_utils/call_limiter.py +26 -0
  13. py_misc_utils/call_result_selector.py +13 -0
  14. py_misc_utils/cleanups.py +85 -0
  15. py_misc_utils/cmd.py +97 -0
  16. py_misc_utils/compression.py +116 -0
  17. py_misc_utils/cond_waiter.py +13 -0
  18. py_misc_utils/context_base.py +18 -0
  19. py_misc_utils/context_managers.py +67 -0
  20. py_misc_utils/core_utils.py +577 -0
  21. py_misc_utils/daemon_process.py +252 -0
  22. py_misc_utils/data_cache.py +46 -0
  23. py_misc_utils/date_utils.py +90 -0
  24. py_misc_utils/debug.py +24 -0
  25. py_misc_utils/dyn_modules.py +50 -0
  26. py_misc_utils/dynamod.py +103 -0
  27. py_misc_utils/env_config.py +35 -0
  28. py_misc_utils/executor.py +239 -0
  29. py_misc_utils/file_overwrite.py +29 -0
  30. py_misc_utils/fin_wrap.py +77 -0
  31. py_misc_utils/fp_utils.py +47 -0
  32. py_misc_utils/fs/__init__.py +0 -0
  33. py_misc_utils/fs/file_fs.py +127 -0
  34. py_misc_utils/fs/ftp_fs.py +242 -0
  35. py_misc_utils/fs/gcs_fs.py +196 -0
  36. py_misc_utils/fs/http_fs.py +241 -0
  37. py_misc_utils/fs/s3_fs.py +417 -0
  38. py_misc_utils/fs_base.py +133 -0
  39. py_misc_utils/fs_utils.py +207 -0
  40. py_misc_utils/gcs_fs.py +169 -0
  41. py_misc_utils/gen_indices.py +54 -0
  42. py_misc_utils/gfs.py +371 -0
  43. py_misc_utils/git_repo.py +77 -0
  44. py_misc_utils/global_namespace.py +110 -0
  45. py_misc_utils/http_async_fetcher.py +139 -0
  46. py_misc_utils/http_server.py +196 -0
  47. py_misc_utils/http_utils.py +143 -0
  48. py_misc_utils/img_utils.py +20 -0
  49. py_misc_utils/infix_op.py +20 -0
  50. py_misc_utils/inspect_utils.py +205 -0
  51. py_misc_utils/iostream.py +21 -0
  52. py_misc_utils/iter_file.py +117 -0
  53. py_misc_utils/key_wrap.py +46 -0
  54. py_misc_utils/lazy_import.py +25 -0
  55. py_misc_utils/lockfile.py +164 -0
  56. py_misc_utils/mem_size.py +64 -0
  57. py_misc_utils/mirror_from.py +72 -0
  58. py_misc_utils/mmap.py +16 -0
  59. py_misc_utils/module_utils.py +196 -0
  60. py_misc_utils/moving_average.py +19 -0
  61. py_misc_utils/msgpack_streamer.py +26 -0
  62. py_misc_utils/multi_wait.py +24 -0
  63. py_misc_utils/multiprocessing.py +102 -0
  64. py_misc_utils/named_array.py +224 -0
  65. py_misc_utils/no_break.py +46 -0
  66. py_misc_utils/no_except.py +32 -0
  67. py_misc_utils/np_ml_framework.py +184 -0
  68. py_misc_utils/np_utils.py +346 -0
  69. py_misc_utils/ntuple_utils.py +38 -0
  70. py_misc_utils/num_utils.py +54 -0
  71. py_misc_utils/obj.py +73 -0
  72. py_misc_utils/object_cache.py +100 -0
  73. py_misc_utils/object_tracker.py +88 -0
  74. py_misc_utils/ordered_set.py +71 -0
  75. py_misc_utils/osfd.py +27 -0
  76. py_misc_utils/packet.py +22 -0
  77. py_misc_utils/parquet_streamer.py +69 -0
  78. py_misc_utils/pd_utils.py +254 -0
  79. py_misc_utils/periodic_task.py +61 -0
  80. py_misc_utils/pickle_wrap.py +121 -0
  81. py_misc_utils/pipeline.py +98 -0
  82. py_misc_utils/remap_pickle.py +50 -0
  83. py_misc_utils/resource_manager.py +155 -0
  84. py_misc_utils/rnd_utils.py +56 -0
  85. py_misc_utils/run_once.py +19 -0
  86. py_misc_utils/scheduler.py +135 -0
  87. py_misc_utils/select_params.py +300 -0
  88. py_misc_utils/signal.py +141 -0
  89. py_misc_utils/skl_utils.py +270 -0
  90. py_misc_utils/split.py +147 -0
  91. py_misc_utils/state.py +53 -0
  92. py_misc_utils/std_module.py +56 -0
  93. py_misc_utils/stream_dataframe.py +176 -0
  94. py_misc_utils/streamed_file.py +144 -0
  95. py_misc_utils/tempdir.py +79 -0
  96. py_misc_utils/template_replace.py +51 -0
  97. py_misc_utils/tensor_stream.py +269 -0
  98. py_misc_utils/thread_context.py +33 -0
  99. py_misc_utils/throttle.py +30 -0
  100. py_misc_utils/time_trigger.py +18 -0
  101. py_misc_utils/timegen.py +11 -0
  102. py_misc_utils/traceback.py +49 -0
  103. py_misc_utils/tracking_executor.py +91 -0
  104. py_misc_utils/transform_array.py +42 -0
  105. py_misc_utils/uncompress.py +35 -0
  106. py_misc_utils/url_fetcher.py +157 -0
  107. py_misc_utils/utils.py +538 -0
  108. py_misc_utils/varint.py +50 -0
  109. py_misc_utils/virt_array.py +52 -0
  110. py_misc_utils/weak_call.py +33 -0
  111. py_misc_utils/work_results.py +100 -0
  112. py_misc_utils/writeback_file.py +43 -0
  113. python_misc_utils-0.2.dist-info/METADATA +36 -0
  114. python_misc_utils-0.2.dist-info/RECORD +117 -0
  115. python_misc_utils-0.2.dist-info/WHEEL +5 -0
  116. python_misc_utils-0.2.dist-info/licenses/LICENSE +13 -0
  117. python_misc_utils-0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,300 @@
1
+ import collections
2
+ import functools
3
+ import multiprocessing as mp
4
+ import multiprocessing.pool as mp_pool
5
+ import os
6
+ import pickle
7
+ import re
8
+ import subprocess
9
+
10
+ import numpy as np
11
+
12
+ from . import alog
13
+ from . import gfs
14
+ from . import np_utils as npu
15
+ from . import utils as ut
16
+
17
+
18
# A search candidate: `pid` identifies the parent point it was derived from,
# `idx` is the per-parameter index vector into the sorted parameter values.
Point = collections.namedtuple('Point', ['pid', 'idx'])


# Number of points scored per batch (see Selector._fetch_scores()).
_SCORES_X_RUN = ut.getenv('SCORES_X_RUN', dtype=int, defval=10)
# Over-generation factor bounding candidate sampling attempts (see
# Selector._generating()).
_KGEN_EXTRA = ut.getenv('KGEN_EXTRA', dtype=float, defval=2)
23
+
24
+
25
+ def _norm_params(params):
26
+ nparams = dict()
27
+ for k, v in params.items():
28
+ if not isinstance(v, np.ndarray):
29
+ v = np.array(v)
30
+ nparams[k] = np.sort(v)
31
+
32
+ return nparams
33
+
34
+
35
+ def _get_space(params):
36
+ skeys = sorted(params.keys())
37
+
38
+ return skeys, [len(params[k]) for k in skeys]
39
+
40
+
41
+ def _mkdelta(idx, space, delta_std):
42
+ # Sample around the index.
43
+ rng = np.random.default_rng()
44
+
45
+ aspace = np.array(space)
46
+ delta = np.array(idx) + rng.standard_normal(len(idx)) * aspace * delta_std
47
+ delta = np.rint(delta).astype(np.int32)
48
+
49
+ return np.clip(delta, np.zeros_like(delta), aspace - 1)
50
+
51
+
52
+ def _mp_score_fn(score_fn, params):
53
+ return score_fn(**params)
54
+
55
+
56
+ def _is_worth_gain(pscore, score, min_gain_pct):
57
+ pabs = np.abs(pscore)
58
+ if np.isclose(pabs, 0):
59
+ return score > pscore
60
+
61
+ delta = (score - pscore) / pabs
62
+
63
+ return delta >= min_gain_pct
64
+
65
+
66
class Selector:
  """Iterative hyper-parameter search over a discrete grid of values.

  Candidate points are scored in batches, the best ones are perturbed to
  generate new candidates, and random points are mixed in at every round to
  keep exploring the space.
  """

  def __init__(self, params, seeds=None):
    self.nparams = _norm_params(params)
    self.processed = set()
    self.scores_db = collections.defaultdict(list)
    self.best_score, self.best_idx, self.best_param = None, None, None
    self.pid_scores = dict()
    self.blanks = 0
    self.skeys, self.space = _get_space(self.nparams)
    self.pts, self.cpid = [], 0
    self.current_scores, self.processed_scores = [], 0
    if seeds:
      self._load_seeds(seeds)

  def _load_seeds(self, seeds):
    # Map each seed configuration to its closest grid point, skipping the
    # ones which have already been processed.
    for sd in seeds:
      idx = np.zeros(len(self.space), dtype=np.int32)
      for i, key in enumerate(self.skeys):
        idx[i] = np.argmin(np.abs(self.nparams[key] - sd[key]))

      if self._is_processed(idx):
        alog.info(f'Seed already processed: {sd}')
      else:
        self.pts.append(Point(self.cpid, idx))
        self.cpid += 1

  def _register_scores(self, xparams, scores):
    # Append every (params, score) pair to the tabular scores database.
    for params, score in zip(xparams, scores):
      alog.debug0(f'SCORE: {score} -- {params}')

      for name, value in params.items():
        self.scores_db[name].append(value)
      self.scores_db['SCORE'].append(score)

  def _make_param(self, idx):
    # Parameters are kept as numpy arrays, but plucked values are returned
    # in Python scalar form.
    return {key: self.nparams[key][idx[i]].item()
            for i, key in enumerate(self.skeys)}

  def _score_slice(self, pts, score_fn, n_jobs=None, mp_ctx=None):
    xparams = [self._make_param(pt.idx) for pt in pts]

    n_jobs = os.cpu_count() if n_jobs is None else n_jobs
    if n_jobs == 1:
      # Score inline, avoiding the multiprocessing overhead.
      scores = [score_fn(**p) for p in xparams]
    else:
      context = mp.get_context(mp_ctx if mp_ctx is not None else mp.get_start_method())
      scorer = functools.partial(_mp_score_fn, score_fn)
      with mp_pool.Pool(processes=n_jobs if n_jobs > 0 else None, context=context) as pool:
        scores = list(pool.map(scorer, xparams))

    self._register_scores(xparams, scores)

    return scores

  def _fetch_scores(self, score_fn, n_jobs=None, mp_ctx=None, scores_x_run=None,
                    status_path=None):
    scores_x_run = scores_x_run or _SCORES_X_RUN

    for i in range(self.processed_scores, len(self.pts), scores_x_run):
      batch = self.pts[i: i + scores_x_run]

      scores = self._score_slice(batch, score_fn, n_jobs=n_jobs, mp_ctx=mp_ctx)

      self.current_scores.extend(scores)
      self.processed_scores += len(batch)

      alog.info(f'Processed {self.processed_scores}/{len(self.pts)}: '
                f'{ut.format(sorted(self.current_scores, reverse=True), ".6e")}')

      if status_path is not None:
        # Checkpoint after every batch so an interrupted run can be resumed.
        self.save_status(status_path)

  def _select_top_n(self, top_n, min_pid_gain_pct):
    # np.argsort() has no "reverse" option, so flip its output to get a
    # descending sort of the scores.
    order = np.flip(np.argsort(self.current_scores))

    pseen, fsidx = set(), []
    for i in order:
      pt = self.pts[i]
      if pt.pid in pseen:
        continue
      pseen.add(pt.pid)
      pscore = self.pid_scores.get(pt.pid)
      if pscore is None or _is_worth_gain(pscore, self.current_scores[i], min_pid_gain_pct):
        self.pid_scores[pt.pid] = self.current_scores[i]
        fsidx.append(i)
        # Indices are visited in descending score order, so the first top_n
        # selected ones are the best available.
        if len(fsidx) >= top_n:
          break

    return fsidx

  def _is_processed(self, idx):
    # Index vectors are tracked by their raw bytes representation.
    return idx.tobytes() in self.processed

  def _generating(self, dest, count):
    # Yield until *dest* has *count* entries, giving up after a bounded
    # number of attempts (duplicates can make exact generation impossible).
    max_attempts = round(count * _KGEN_EXTRA)
    n = 0
    while len(dest) < count and n < max_attempts:
      yield n
      n += 1

  def _randgen(self, count):
    rng = np.random.default_rng()
    high = np.array(self.space, dtype=np.int32)
    low = np.zeros_like(high)

    rpoints = []
    for _ in self._generating(rpoints, count):
      idx = rng.integers(low, high)
      if not self._is_processed(idx):
        rpoints.append(Point(self.cpid, idx))
        self.cpid += 1

    return rpoints

  def _select_deltas(self, pt, delta_spacek, delta_std):
    # delta_spacek: None selects one delta per dimension, a positive value
    # is a per-dimension multiplier, a negative value an absolute count.
    if delta_spacek is None:
      num_deltas = len(self.space)
    elif delta_spacek > 0:
      num_deltas = int(np.ceil(len(self.space) * delta_spacek))
    else:
      num_deltas = int(np.rint(-delta_spacek))

    deltas = []
    for _ in self._generating(deltas, num_deltas):
      idx = _mkdelta(pt.idx, self.space, delta_std)
      if not self._is_processed(idx):
        deltas.append(Point(pt.pid, idx))

    return deltas

  def __call__(self, score_fn,
               status_path=None,
               delta_spacek=None,
               delta_std=0.1,
               top_n=10,
               explore_pct=0.05,
               rnd_pct=0.2,
               min_pid_gain_pct=0.01,
               max_blanks_pct=0.1,
               scores_x_run=None,
               n_jobs=None,
               mp_ctx=None):
    """Run the search loop until the exploration budget is exhausted.

    Args:
      score_fn: Callable scoring a parameter configuration (higher is better).
      status_path: Optional path where the selector state is checkpointed.
      delta_spacek: Controls how many perturbations are generated per selected
        point (see _select_deltas()).
      delta_std: Noise scale (relative to dimension size) used to perturb points.
      top_n: Number of best points selected at each round.
      explore_pct: Maximum fraction of the configuration space to explore.
      rnd_pct: Fraction of random points mixed into every new generation.
      min_pid_gain_pct: Minimum relative gain for a parent to be re-selected.
      max_blanks_pct: Fraction of the exploration budget tolerated without
        improvements before stopping.
      scores_x_run: Scoring batch size (defaults to $SCORES_X_RUN).
      n_jobs: Number of parallel scoring processes (None uses cpu_count()).
      mp_ctx: Optional multiprocessing context name.
    """
    alog.debug0(f'{len(self.space)} parameters, {np.prod(self.space)} configurations')

    if not self.pts:
      self.pts.extend(self._randgen(top_n))

    max_explore = int(np.prod(self.space) * explore_pct)
    max_blanks = int(max_explore * max_blanks_pct)
    while self.pts and len(self.processed) < max_explore and self.blanks < max_blanks:
      alog.debug0(f'{len(self.pts)} points, {len(self.processed)} processed '
                  f'(max {max_explore}), {self.blanks}/{max_blanks} blanks')

      self._fetch_scores(score_fn,
                         n_jobs=n_jobs,
                         mp_ctx=mp_ctx,
                         scores_x_run=scores_x_run,
                         status_path=status_path)

      fsidx = self._select_top_n(top_n, min_pid_gain_pct)

      score = self.current_scores[fsidx[0]]
      if self.best_score is None or score > self.best_score:
        self.best_score = score
        self.best_idx = self.pts[fsidx[0]].idx
        self.best_param = self._make_param(self.best_idx)
        self.blanks = 0

        alog.debug0(f'BestScore = {self.best_score:.5e}\tParam = {self.best_param}')
      else:
        self.blanks += len(self.pts)
        alog.info(f'Score not improved (current run top {score}, best {self.best_score})')

      for pt in self.pts:
        self.processed.add(pt.idx.tobytes())

      # Sample around best points ...
      next_pts = []
      for i in fsidx:
        next_pts.extend(self._select_deltas(self.pts[i], delta_spacek, delta_std))

      # ... and randomly add ones in search of better scores.
      rnd_count = max(top_n, int(rnd_pct * len(next_pts)))
      next_pts.extend(self._randgen(rnd_count))

      self.pts = next_pts
      self.current_scores, self.processed_scores = [], 0

  def save_status(self, path):
    """Pickle the whole selector state to *path*."""
    with gfs.open(path, mode='wb') as sfd:
      pickle.dump(self, sfd, protocol=ut.pickle_proto())

  @staticmethod
  def load_status(path):
    """Reload a selector previously saved with save_status()."""
    with gfs.open(path, mode='rb') as sfd:
      return pickle.load(sfd)
274
+
275
+
276
+ _SCORE_TAG = 'SPSCORE'
277
+ _SCORE_FMT = os.getenv('SPSCORE_FMT', 'f')
278
+
279
def format_score(s):
  """Render score *s* as a tagged marker that match_score() can parse back."""
  return '[{}={:{}}]'.format(_SCORE_TAG, s, _SCORE_FMT)
281
+
282
+
283
def match_score(data):
  """Extract all scores embedded by format_score() within *data*.

  Returns:
    A list of floats, one per tagged marker, in order of appearance.
  """
  # Raw f-string: the original code built the pattern with a non-raw f-string
  # containing "\[", an invalid escape sequence (DeprecationWarning today,
  # an error in future Python versions).
  matches = re.findall(rf'\[{_SCORE_TAG}=([^]]+)\]', data)

  return [float(m) for m in matches]
287
+
288
+
289
def run_score_process(cmdline):
  """Run *cmdline* and return (scores, output).

  Scores are extracted from the combined stdout/stderr process output via
  match_score(); the raw (decoded) output is returned alongside. Failures
  are logged and re-raised.
  """
  try:
    output = subprocess.check_output(cmdline, stderr=subprocess.STDOUT)
  except subprocess.CalledProcessError as ex:
    alog.exception(ex, exmsg=f'Error while running scoring process: {ex.output.decode()}')
    raise

  if isinstance(output, bytes):
    output = output.decode()

  return match_score(output), output
300
+
@@ -0,0 +1,141 @@
1
+ import collections
2
+ import signal as sgn
3
+ import threading
4
+
5
+ from . import core_utils as cu
6
+ from . import global_namespace as gns
7
+ from . import traceback as tb
8
+
9
+
10
+ _Handler = collections.namedtuple('Handler', 'handler, prio')
11
+
12
+ MAX_PRIO = 0
13
+ MIN_PRIO = 99
14
+ STD_PRIO = MIN_PRIO
15
+ CALL_NEXT = 0
16
+ HANDLED = 1
17
+
18
+
19
+ class _SignalRegistry:
20
+
21
+ def __init__(self):
22
+ self.lock = threading.Lock()
23
+ self.handlers = dict()
24
+ self.prev_handlers = dict()
25
+
26
+ def signal(self, sig, handler, prio=None):
27
+ prio = STD_PRIO if prio is None else prio
28
+
29
+ with self.lock:
30
+ handlers = self.handlers.get(sig, ())
31
+ handlers += (_Handler(handler, prio),)
32
+ self.handlers[sig] = tuple(sorted(handlers, key=lambda h: h.prio))
33
+
34
+ if sig not in self.prev_handlers:
35
+ self.prev_handlers[sig] = sgn.signal(sig, _handler)
36
+
37
+ def unsignal(self, sig, uhandler):
38
+ handlers, dropped = [], 0
39
+ with self.lock:
40
+ for handler in self.handlers.get(sig, ()):
41
+ if handler.handler != uhandler:
42
+ handlers.append(handler)
43
+ else:
44
+ dropped += 1
45
+
46
+ if dropped:
47
+ self.handlers[sig] = tuple(handlers)
48
+ if not handlers:
49
+ sgn.signal(sig, self.prev_handlers.pop(sig))
50
+
51
+ return dropped
52
+
53
+ def sig_handler(self, sig, frame):
54
+ with self.lock:
55
+ handlers = self.handlers.get(sig, ())
56
+ prev_handler = self.prev_handlers.get(sig)
57
+
58
+ for handler in handlers:
59
+ hres = handler.handler(sig, frame)
60
+ if hres == HANDLED:
61
+ return
62
+
63
+ if callable(prev_handler):
64
+ prev_handler(sig, frame)
65
+ else:
66
+ handler = sgn.getsignal(sig)
67
+ if callable(handler):
68
+ handler(sig, frame)
69
+
70
+
71
+ def _parent_fn(sreg):
72
+ return sreg.prev_handlers
73
+
74
+
75
def _child_fn(prev_handlers):
  # Fork hook (child side): restore the original handlers and start over
  # with a fresh registry (the parent's registrations do not carry over).
  for sig, prev_handler in prev_handlers.items():
    sgn.signal(sig, prev_handler)

  return _SignalRegistry()
80
+
81
+
82
def _create_fn():
  # Default factory for the global registry namespace variable (_SIGREG).
  return _SignalRegistry()
84
+
85
+
86
# Process-global registry handle, managed by the global namespace module so
# that forked children reset their dispositions (see _parent_fn/_child_fn).
_SIGREG = gns.Var(f'{__name__}.SIGREG',
                  parent_fn=_parent_fn,
                  child_fn=_child_fn,
                  defval=_create_fn)
90
+
91
def _sig_registry():
  """Return the process-wide _SignalRegistry instance."""
  return gns.get(_SIGREG)
93
+
94
+
95
def _handler(sig, frame):
  # The handler actually installed via the signal module; forwards to the
  # process-wide registry for prioritized dispatch.
  _sig_registry().sig_handler(sig, frame)
99
+
100
+
101
def trigger(sig, frame=None):
  """Synthesize a dispatch of *sig*, using the caller's frame when none given."""
  _handler(sig, frame or tb.get_frame(n=1))
103
+
104
+
105
def signal(sig, handler, prio=None):
  """Register *handler* for *sig* on the global registry.

  See _SignalRegistry.signal() for the priority semantics.
  """
  _sig_registry().signal(sig, handler, prio=prio)
109
+
110
+
111
def unsignal(sig, uhandler):
  """Unregister *uhandler* from *sig*; returns the number of entries dropped."""
  return _sig_registry().unsignal(sig, uhandler)
115
+
116
+
117
class Signals:
  """Context manager registering handlers on enter, removing them on exit.

  *sig* may be a signal number, a list/tuple of numbers, or a comma-separated
  string of names (e.g. 'int, term'). *handler* may be a single callable
  (applied to every signal) or a list/tuple matching *sig* one-to-one.
  """

  def __init__(self, sig, handler, prio=None):
    if isinstance(sig, str):
      sig = [getattr(sgn, f'SIG{s.upper()}') for s in cu.splitstrip(sig, ',')]
    elif not isinstance(sig, (list, tuple)):
      sig = [sig]
    if not isinstance(handler, (list, tuple)):
      handler = [handler] * len(sig)

    self._sigs = tuple(zip(sig, handler))
    self._prio = prio

  def __enter__(self):
    for sig, handler in self._sigs:
      signal(sig, handler, prio=self._prio)

    return self

  def __exit__(self, *exc):
    for sig, handler in self._sigs:
      unsignal(sig, handler)

    # NOTE(review): returning True suppresses any exception raised inside the
    # with-block — confirm this is intentional.
    return True
141
+
@@ -0,0 +1,270 @@
1
+ import os
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import sklearn.decomposition
6
+ import sklearn.neighbors
7
+ import sklearn.preprocessing
8
+ import sklearn.random_projection
9
+ import sklearn.utils.extmath
10
+
11
+ from . import alog
12
+ from . import assert_checks as tas
13
+ from . import utils as pyu
14
+
15
+
16
class Quantizer:
  """Symmetric quantizer mapping features to integer bins in [-nbins, nbins].

  Bins are computed on the standardized value (X - mean) / std, scaled so
  that `std_k` standard deviations span `nbins` bins.
  """

  def __init__(self, nbins=3, std_k=2.0, eps=1e-6):
    self._nbins = nbins
    self._std_k = std_k
    # Lower bound applied to the standard deviation to avoid division by
    # (near) zero on constant features.
    self._eps = eps
    self._mean = None
    self._std = None

  def fit(self, X, *args):
    """Compute the per-feature mean and (floored) standard deviation."""
    self._mean = np.mean(X, axis=0, keepdims=True)
    std = np.std(X, axis=0, keepdims=True)
    self._std = np.where(std > self._eps, std, self._eps)

    return self

  def transform(self, X):
    """Quantize X into integer-valued bins clipped to [-nbins, nbins]."""
    bins = np.round(self._nbins * (X - self._mean) / (self._std_k * self._std))

    return np.clip(bins, -self._nbins, self._nbins)

  def fit_transform(self, X, *args):
    return self.fit(X).transform(X)
39
+
40
+
41
# Registry mapping spec names (see parse_transform()) to transformer
# factories; values are classes instantiated with the parsed spec config.
_TRANSFORMERS = {
    'QTZ': Quantizer,
    'STD_SCALE': sklearn.preprocessing.StandardScaler,
    'MIN_MAX': sklearn.preprocessing.MinMaxScaler,
    'MAX_ABS': sklearn.preprocessing.MaxAbsScaler,
    'ROBUST': sklearn.preprocessing.RobustScaler,
    'QUANT': sklearn.preprocessing.QuantileTransformer,
    'POWER': sklearn.preprocessing.PowerTransformer,
    'NORM': sklearn.preprocessing.Normalizer,
    'PCA': sklearn.decomposition.PCA,
    'FICA': sklearn.decomposition.FastICA,
    'INCPCA': sklearn.decomposition.IncrementalPCA,
    'KBINDIS': sklearn.preprocessing.KBinsDiscretizer,
    'POLYF': sklearn.preprocessing.PolynomialFeatures,
    'PWR': sklearn.preprocessing.PowerTransformer,
    'SPLN': sklearn.preprocessing.SplineTransformer,
    'GRPRJ': sklearn.random_projection.GaussianRandomProjection,
    'SRPRJ': sklearn.random_projection.SparseRandomProjection,
}
60
+
61
def parse_transform(trs_spec):
  """Instantiate a transformer from a 'NAME[:config]' spec string.

  The optional config part is parsed into keyword arguments handed to the
  transformer factory registered under NAME in _TRANSFORMERS.
  """
  trs, sep, config = trs_spec.partition(':')
  spec_cfg = pyu.parse_config(config) if sep else dict()

  alog.debug0(f'Parsed Transformer: {trs}\t{spec_cfg}')

  trs_fn = _TRANSFORMERS.get(trs)
  tas.check_is_not_none(trs_fn, msg=f'Unknown transformation requested: {trs_spec}')

  return trs_fn(**spec_cfg)
71
+
72
+
73
class TransformPipeline:
  """Chain of sklearn-style transformers applied in sequence."""

  def __init__(self, transformers):
    # Entries may be transformer instances or spec strings (see
    # parse_transform()).
    self._transformers = [parse_transform(trs) if isinstance(trs, str) else trs
                          for trs in transformers or []]

  @property
  def transformers(self):
    return tuple(self._transformers)

  @property
  def supports_partial_fit(self):
    # True only when every stage exposes partial_fit().
    return all([hasattr(trs, 'partial_fit') for trs in self._transformers])

  @property
  def supports_fit(self):
    # A plain fit() cannot propagate transformed data between stages, so it
    # is only meaningful with at most one transformer.
    return len(self._transformers) <= 1

  def add(self, transformer):
    """Append a transformer stage to the pipeline."""
    self._transformers.append(transformer)

  def transform(self, x):
    """Apply every (already fitted) stage in order."""
    tx = x
    for trs in self._transformers:
      tx = trs.transform(tx)

    return tx

  def fit_transform(self, x):
    """Fit each stage on the output of the previous one, returning the result."""
    tx = x
    for trs in self._transformers:
      tx = trs.fit_transform(tx)

    return tx

  def fit(self, x):
    tas.check(self.supports_fit,
              msg=f'Only pipelines with a single transformer can call fit()')
    for trs in self._transformers:
      trs.fit(x)

    return self

  def partial_fit_transform(self, x):
    """Incrementally fit each stage, then transform through it."""
    tx = x
    for trs in self._transformers:
      trs.partial_fit(tx)
      tx = trs.transform(tx)

    return tx
126
+
127
+
128
def fit(m, x, y, **kwargs):
  """Fit model *m* on (x, y), adapting single-output 2D targets.

  Some SciKit Learn models (ie. KNeighborsClassifier/KNeighborsRegressor),
  although supporting multi-output targets, insist on special casing the
  1-output case, requiring a 1D (N,) vector instead of a (N, 1) tensor.
  """
  if len(y.shape) == 2 and y.shape[-1] == 1:
    y = np.squeeze(pyu.to_numpy(y), axis=-1)

  return m.fit(x, y, **kwargs)
137
+
138
+
139
def predict(m, x, **kwargs):
  """Run m.predict(x) and normalize the result to a 2D (N, outputs) array.

  Some SciKit Learn models (ie. KNeighborsClassifier/KNeighborsRegressor),
  although supporting multi-output targets, emit 1D (N,) vectors for the
  single-output case; those are reshaped to (N, 1) for uniform handling.
  """
  y = m.predict(x, **kwargs)

  return y.reshape(-1, 1) if y.ndim <= 1 else y
146
+
147
+
148
def predict_proba(m, x, classes=None):
  """Return m.predict_proba(x), optionally restricted to *classes* columns.

  Multi-output models return a list/tuple of per-output probability arrays;
  the class column selection is then applied to each of them.
  """

  def extract_probs(p, cls):
    # Pull the requested class columns and stack them side by side.
    preds = [p[:, c].reshape(-1, 1) for c in cls]
    return np.concatenate(preds, axis=-1)

  probs = m.predict_proba(x)
  if classes is None:
    return probs

  if isinstance(probs, (list, tuple)):
    return [extract_probs(p, classes) for p in probs]

  return extract_probs(probs, classes)
162
+
163
+
164
+ def _get_weights(dist, weights):
165
+ if weights in (None, "uniform"):
166
+ return np.ones_like(dist)
167
+ elif weights == "distance":
168
+ # if user attempts to classify a point that was zero distance from one
169
+ # or more training points, those training points are weighted as 1.0
170
+ # and the other points as 0.0
171
+ if dist.dtype is np.dtype:
172
+ for point_dist_i, point_dist in enumerate(dist):
173
+ # check if point_dist is iterable
174
+ # (ex: RadiusNeighborClassifier.predict may set an element of
175
+ # dist to 1e-6 to represent an 'outlier')
176
+ if hasattr(point_dist, "__contains__") and 0.0 in point_dist:
177
+ dist[point_dist_i] = point_dist == 0.0
178
+ else:
179
+ dist[point_dist_i] = 1.0 / point_dist
180
+ else:
181
+ with np.errstate(divide="ignore"):
182
+ dist = 1.0 / dist
183
+ inf_mask = np.isinf(dist)
184
+ inf_row = np.any(inf_mask, axis=1)
185
+ dist[inf_row] = inf_mask[inf_row]
186
+ return dist
187
+ elif callable(weights):
188
+ return weights(dist)
189
+ else:
190
+ alog.xraise(ValueError, f'Unrecognized "weights" value: {weights}')
191
+
192
+
193
class WeighedKNNClassifier:
  """KNeighborsClassifier variant supporting per-sample weights.

  sklearn's KNeighborsClassifier fit() has no sample_weight support; this
  wrapper stores the sample weights and re-implements predict() and
  predict_proba() combining them with the distance-based neighbor weights.
  When no sample weights are given it defers entirely to the wrapped model.
  """

  def __init__(self, **kwargs):
    self._sample_weight = None
    self._neigh = sklearn.neighbors.KNeighborsClassifier(**kwargs)

  def fit(self, X, y, sample_weight=None):
    """Fit the underlying KNN model, remembering the per-sample weights."""
    self._sample_weight = sample_weight
    self._neigh.fit(X, y)

    return self

  def _prepare_predict(self, X):
    # Gather what the weighted prediction paths need: neighbor distances and
    # indices, targets/classes normalized to the 2D multi-output layout, the
    # distance weights, and the sample weights reshaped to match.
    neigh_dist, neigh_ind = self._neigh.kneighbors(X)

    classes = self._neigh.classes_
    y = self._neigh._y
    if not self._neigh.outputs_2d_:
      # Normalize the single-output case to the multi-output layout.
      y = y.reshape(-1, 1)
      classes = [classes]
    cweights = _get_weights(neigh_dist, self._neigh.weights)
    sample_weight = self._sample_weight
    if sample_weight.ndim < y.ndim:
      sample_weight = sample_weight.reshape(-1, 1)

    return pyu.make_object(neigh_dist=neigh_dist,
                           neigh_ind=neigh_ind,
                           y=y,
                           classes=classes,
                           cweights=cweights,
                           sample_weight=sample_weight)

  def predict(self, X):
    """Predict labels, using sample-weighted mode when weights are set."""
    if self._sample_weight is None:
      return self._neigh.predict(X)

    pp = self._prepare_predict(X)

    y_pred = np.empty((len(X), len(pp.classes)), dtype=pp.classes[0].dtype)
    for k, k_classes in enumerate(pp.classes):
      # The effective weight of each neighbor is its sample weight scaled by
      # the distance weight.
      weights = pp.sample_weight[pp.neigh_ind, k] * pp.cweights
      mode, _ = sklearn.utils.extmath.weighted_mode(pp.y[pp.neigh_ind, k], weights,
                                                    axis=1)

      mode = np.asarray(mode.ravel(), dtype=np.intp)
      y_pred[:, k] = k_classes.take(mode)

    if not self._neigh.outputs_2d_:
      y_pred = y_pred.ravel()

    return y_pred

  def predict_proba(self, X):
    """Predict class probabilities, folding in sample weights when present."""
    if self._sample_weight is None:
      return self._neigh.predict_proba(X)

    pp = self._prepare_predict(X)

    all_rows = np.arange(len(X))
    probabilities = []
    for k, k_classes in enumerate(pp.classes):
      pred_labels = pp.y[:, k][pp.neigh_ind]
      proba_k = np.zeros((len(X), k_classes.size))

      ksweight = pp.sample_weight[:, k]
      for i, idx in enumerate(pred_labels.T):
        # Accumulate every neighbor's weight into its predicted class bucket.
        # NOTE(review): ksweight is indexed per training sample while
        # pp.cweights[:, i] is per query, so this broadcast only lines up when
        # the query count matches the training set size. Should this be
        # pp.sample_weight[pp.neigh_ind[:, i], k]? Confirm.
        proba_k[all_rows, idx] += ksweight * pp.cweights[:, i]

      # Normalize rows to sum to one, guarding all-zero rows.
      normalizer = proba_k.sum(axis=1)[:, np.newaxis]
      normalizer[normalizer == 0.0] = 1.0

      probabilities.append(proba_k / normalizer)

    if not self._neigh.outputs_2d_:
      probabilities = probabilities[0]

    return probabilities
+