python-misc-utils 0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py_misc_utils/__init__.py +0 -0
- py_misc_utils/abs_timeout.py +12 -0
- py_misc_utils/alog.py +311 -0
- py_misc_utils/app_main.py +179 -0
- py_misc_utils/archive_streamer.py +112 -0
- py_misc_utils/assert_checks.py +118 -0
- py_misc_utils/ast_utils.py +121 -0
- py_misc_utils/async_manager.py +189 -0
- py_misc_utils/break_control.py +63 -0
- py_misc_utils/buffered_iterator.py +35 -0
- py_misc_utils/cached_file.py +507 -0
- py_misc_utils/call_limiter.py +26 -0
- py_misc_utils/call_result_selector.py +13 -0
- py_misc_utils/cleanups.py +85 -0
- py_misc_utils/cmd.py +97 -0
- py_misc_utils/compression.py +116 -0
- py_misc_utils/cond_waiter.py +13 -0
- py_misc_utils/context_base.py +18 -0
- py_misc_utils/context_managers.py +67 -0
- py_misc_utils/core_utils.py +577 -0
- py_misc_utils/daemon_process.py +252 -0
- py_misc_utils/data_cache.py +46 -0
- py_misc_utils/date_utils.py +90 -0
- py_misc_utils/debug.py +24 -0
- py_misc_utils/dyn_modules.py +50 -0
- py_misc_utils/dynamod.py +103 -0
- py_misc_utils/env_config.py +35 -0
- py_misc_utils/executor.py +239 -0
- py_misc_utils/file_overwrite.py +29 -0
- py_misc_utils/fin_wrap.py +77 -0
- py_misc_utils/fp_utils.py +47 -0
- py_misc_utils/fs/__init__.py +0 -0
- py_misc_utils/fs/file_fs.py +127 -0
- py_misc_utils/fs/ftp_fs.py +242 -0
- py_misc_utils/fs/gcs_fs.py +196 -0
- py_misc_utils/fs/http_fs.py +241 -0
- py_misc_utils/fs/s3_fs.py +417 -0
- py_misc_utils/fs_base.py +133 -0
- py_misc_utils/fs_utils.py +207 -0
- py_misc_utils/gcs_fs.py +169 -0
- py_misc_utils/gen_indices.py +54 -0
- py_misc_utils/gfs.py +371 -0
- py_misc_utils/git_repo.py +77 -0
- py_misc_utils/global_namespace.py +110 -0
- py_misc_utils/http_async_fetcher.py +139 -0
- py_misc_utils/http_server.py +196 -0
- py_misc_utils/http_utils.py +143 -0
- py_misc_utils/img_utils.py +20 -0
- py_misc_utils/infix_op.py +20 -0
- py_misc_utils/inspect_utils.py +205 -0
- py_misc_utils/iostream.py +21 -0
- py_misc_utils/iter_file.py +117 -0
- py_misc_utils/key_wrap.py +46 -0
- py_misc_utils/lazy_import.py +25 -0
- py_misc_utils/lockfile.py +164 -0
- py_misc_utils/mem_size.py +64 -0
- py_misc_utils/mirror_from.py +72 -0
- py_misc_utils/mmap.py +16 -0
- py_misc_utils/module_utils.py +196 -0
- py_misc_utils/moving_average.py +19 -0
- py_misc_utils/msgpack_streamer.py +26 -0
- py_misc_utils/multi_wait.py +24 -0
- py_misc_utils/multiprocessing.py +102 -0
- py_misc_utils/named_array.py +224 -0
- py_misc_utils/no_break.py +46 -0
- py_misc_utils/no_except.py +32 -0
- py_misc_utils/np_ml_framework.py +184 -0
- py_misc_utils/np_utils.py +346 -0
- py_misc_utils/ntuple_utils.py +38 -0
- py_misc_utils/num_utils.py +54 -0
- py_misc_utils/obj.py +73 -0
- py_misc_utils/object_cache.py +100 -0
- py_misc_utils/object_tracker.py +88 -0
- py_misc_utils/ordered_set.py +71 -0
- py_misc_utils/osfd.py +27 -0
- py_misc_utils/packet.py +22 -0
- py_misc_utils/parquet_streamer.py +69 -0
- py_misc_utils/pd_utils.py +254 -0
- py_misc_utils/periodic_task.py +61 -0
- py_misc_utils/pickle_wrap.py +121 -0
- py_misc_utils/pipeline.py +98 -0
- py_misc_utils/remap_pickle.py +50 -0
- py_misc_utils/resource_manager.py +155 -0
- py_misc_utils/rnd_utils.py +56 -0
- py_misc_utils/run_once.py +19 -0
- py_misc_utils/scheduler.py +135 -0
- py_misc_utils/select_params.py +300 -0
- py_misc_utils/signal.py +141 -0
- py_misc_utils/skl_utils.py +270 -0
- py_misc_utils/split.py +147 -0
- py_misc_utils/state.py +53 -0
- py_misc_utils/std_module.py +56 -0
- py_misc_utils/stream_dataframe.py +176 -0
- py_misc_utils/streamed_file.py +144 -0
- py_misc_utils/tempdir.py +79 -0
- py_misc_utils/template_replace.py +51 -0
- py_misc_utils/tensor_stream.py +269 -0
- py_misc_utils/thread_context.py +33 -0
- py_misc_utils/throttle.py +30 -0
- py_misc_utils/time_trigger.py +18 -0
- py_misc_utils/timegen.py +11 -0
- py_misc_utils/traceback.py +49 -0
- py_misc_utils/tracking_executor.py +91 -0
- py_misc_utils/transform_array.py +42 -0
- py_misc_utils/uncompress.py +35 -0
- py_misc_utils/url_fetcher.py +157 -0
- py_misc_utils/utils.py +538 -0
- py_misc_utils/varint.py +50 -0
- py_misc_utils/virt_array.py +52 -0
- py_misc_utils/weak_call.py +33 -0
- py_misc_utils/work_results.py +100 -0
- py_misc_utils/writeback_file.py +43 -0
- python_misc_utils-0.2.dist-info/METADATA +36 -0
- python_misc_utils-0.2.dist-info/RECORD +117 -0
- python_misc_utils-0.2.dist-info/WHEEL +5 -0
- python_misc_utils-0.2.dist-info/licenses/LICENSE +13 -0
- python_misc_utils-0.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
import functools
|
|
3
|
+
import io
|
|
4
|
+
import os
|
|
5
|
+
import pathlib
|
|
6
|
+
import shutil
|
|
7
|
+
import tempfile
|
|
8
|
+
|
|
9
|
+
from . import alog
|
|
10
|
+
from . import assert_checks as tas
|
|
11
|
+
from . import osfd
|
|
12
|
+
from . import rnd_utils as rngu
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def home():
    """Return the current user's home directory as a pathlib.Path."""
    home_dir = pathlib.Path.home()
    return home_dir
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def maybe_remove(path):
    """Remove the file at `path`, doing nothing if it does not exist.

    Uses try-and-ignore instead of an exists() check. This avoids the race
    where the file vanishes between the check and the removal. It also means a
    dangling symlink (for which os.path.exists() returns False) is removed too.

    Args:
      path: The file (or symlink) path to remove.
    """
    with contextlib.suppress(FileNotFoundError):
        os.remove(path)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def is_binary(fd):
    """Tell whether the file object `fd` deals in bytes rather than text.

    If the object exposes a `mode` attribute, the presence of 'b' in it decides.
    Otherwise the object is considered binary unless it is an io.TextIOBase.
    """
    mode = getattr(fd, 'mode', None)
    if mode is None:
        return not isinstance(fd, io.TextIOBase)

    return 'b' in mode
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def link_or_copy(src_path, dest_path):
    """Make `dest_path` provide the content of `src_path` as cheaply as possible.

    Tries, in order:
      1. A hard link.
      2. A symbolic link.
      3. A full copy (shutil.copyfile() followed by shutil.copystat()).

    Args:
      src_path: The existing source file.
      dest_path: The path to create.

    Returns:
      `dest_path`.
    """
    try:
        os.link(src_path, dest_path)

        return dest_path
    except OSError as ex:
        alog.debug(f'Hardlink failed from "{src_path}" to "{dest_path}", trying symlink. '
                   f'Error was: {ex}')

    # NOTE(review): a relative src_path inside a symlink is resolved relative to
    # dest_path's directory; confirm callers pass absolute paths.
    try:
        os.symlink(src_path, dest_path)

        return dest_path
    except OSError as ex:
        # Bind this handler's own exception. The name bound by the previous
        # handler is deleted when that handler exits, so referencing it here
        # would raise instead of falling back to the copy.
        alog.debug(f'Symlink failed from "{src_path}" to "{dest_path}", going to copy. '
                   f'Error was: {ex}')

    shutil.copyfile(src_path, dest_path)
    shutil.copystat(src_path, dest_path)

    return dest_path
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def is_newer_file(path, other):
    """Return True if `path` was modified strictly after `other`."""
    path_mtime = os.stat(path).st_mtime
    other_mtime = os.stat(other).st_mtime

    return other_mtime < path_mtime
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def os_opener(*args, **kwargs):
    """Build an `opener=` callable for open() that forwards to os.open().

    The result is a functools.partial over os.open(). Positional arguments
    given here end up *before* the (path, flags) pair that open() passes.
    In practice only keyword arguments (e.g. mode=, dir_fd=) are meaningful.
    """
    opener = functools.partial(os.open, *args, **kwargs)

    return opener
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def safe_rmtree(path, **kwargs):
    """Delete the directory tree at `path`.

    The tree is first renamed to a randomized sibling name (via temp_path()),
    then deleted from there. Presumably this makes `path` disappear in one step,
    so a half-deleted tree is never visible under the original name.

    Args:
      path: The directory to delete.
      **kwargs: Forwarded to shutil.rmtree(). If `ignore_errors` is truthy, a
        missing `path` is silently ignored.
    """
    doomed = temp_path(nspath=path)
    try:
        os.rename(path, doomed)
    except FileNotFoundError:
        if kwargs.get('ignore_errors', False):
            return
        raise

    shutil.rmtree(doomed, **kwargs)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def readall(fd):
    """Read the whole content of a file, starting from offset 0.

    Args:
      fd: Either a path (str), opened read-only for the duration of the call,
        or an already open OS-level file descriptor (int). For a descriptor,
        the file position is moved to 0 before reading.

    Returns:
      The bytes from offset 0 up to the size reported by os.stat() at call time.
      Fewer bytes are returned if EOF is hit first.
    """
    if isinstance(fd, str):
        with osfd.OsFd(fd, os.O_RDONLY) as ifd:
            return readall(ifd)

    size = os.stat(fd).st_size
    os.lseek(fd, 0, os.SEEK_SET)

    # A single os.read() may return fewer bytes than requested (e.g. Linux caps
    # one read at ~2GB), so keep reading until the expected size is reached or
    # EOF is hit.
    chunks, remaining = [], size
    while remaining > 0:
        data = os.read(fd, remaining)
        if not data:
            break
        chunks.append(data)
        remaining -= len(data)

    return b''.join(chunks)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def stat(path):
    """Best-effort os.stat().

    Args:
      path: The path (or anything os.stat() accepts) to stat.

    Returns:
      The os.stat_result, or None if the path cannot be stat-ed (missing,
      permission denied, invalid path value such as one with a NUL byte).
    """
    try:
        return os.stat(path)
    except (OSError, ValueError):
        # Catch only the expected failures. A bare `except:` would also trap
        # KeyboardInterrupt/SystemExit.
        return None
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def path_split(path):
    """Split `path` into all of its components, outermost first.

    Example: '/a/b/c' -> ['/', 'a', 'b', 'c'] and 'a/b' -> ['a', 'b'].
    A trailing separator yields a final '' component.
    """
    components = []
    while True:
        head, tail = os.path.split(path)
        if head == path:
            components.append(head)
            break
        if tail == path:
            components.append(tail)
            break
        components.append(tail)
        path = head

    return components[::-1]
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def drop_ext(path, exts):
    """Strip the extension of `path` if it is one of `exts`.

    Args:
      path: The path to process.
      exts: A single extension string (e.g. '.txt'), or a container of them.

    Returns:
      `path` without its extension when it matches, otherwise `path` unchanged.
    """
    base, ext = os.path.splitext(path)
    # Wrap a lone string so membership is exact matching, not substring search.
    candidates = (exts,) if isinstance(exts, str) else exts

    return base if ext in candidates else path
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def normpath(path):
    """Return an absolute, normalized version of `path`.

    '~' and environment variables are expanded first.
    """
    expanded = os.path.expandvars(os.path.expanduser(path))

    return os.path.normpath(os.path.abspath(expanded))
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def localfs_mount(path):
    """Walk up from `path` to the nearest mount point.

    Returns the first ancestor (or `path` itself) for which os.path.ismount()
    is true. If none is found, returns the top of the walk, i.e. the path
    whose dirname is itself ('/' for absolute paths, '' for relative ones).
    """
    parent = os.path.dirname(path)
    while path != parent and not os.path.ismount(path):
        path, parent = parent, os.path.dirname(parent)

    return path
|
|
132
|
+
path = parent_path
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def enum_chunks(stream, chunk_size=16 * 1024**2):
    """Yield the content of `stream` in chunks of at most `chunk_size` items.

    Reading stops at the first falsy read result (empty bytes/str, or None).
    A short read is NOT treated as end-of-stream, since raw or unbuffered
    streams (pipes, sockets, ...) may return fewer bytes than requested before
    the real end.

    Args:
      stream: Any object with a `read(size)` method.
      chunk_size: The maximum size requested on each read.
    """
    while (data := stream.read(chunk_size)):
        yield data
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def du(path, follow_symlinks=None, visited=None):
    """Compute the total size in bytes of the regular files under `path`.

    Args:
      path: The directory to scan (recursively).
      follow_symlinks: When truthy (anything but None/False), symlinks to files
        and directories are followed.
      visited: Set of `(st_dev, st_ino)` identities of directories already
        scanned. It is used internally by the recursion to avoid scanning the
        same directory twice (e.g. through symlink cycles).

    Returns:
      The sum of `st_size` of all the files found.
    """
    visited = set() if visited is None else visited
    follow_symlinks = follow_symlinks not in (None, False)

    # Track directories by identity rather than by path string. With symlinks
    # followed, a cycle yields ever-new path strings for the same directory.
    sres = os.stat(path)
    dir_id = (sres.st_dev, sres.st_ino)
    if dir_id in visited:
        return 0
    visited.add(dir_id)

    size, dirs = 0, []
    with os.scandir(path) as sdit:
        for de in sdit:
            if de.is_file(follow_symlinks=follow_symlinks):
                size += de.stat().st_size
            elif de.is_dir(follow_symlinks=follow_symlinks):
                dirs.append(de.name)

    for dname in dirs:
        size += du(os.path.join(path, dname),
                   follow_symlinks=follow_symlinks,
                   visited=visited)

    return size
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def find_path(name, paths, checkfn=os.path.exists):
    """Return the first `os.path.join(p, name)` (p in `paths`) accepted by `checkfn`.

    Candidates are checked lazily, in order. Returns None when none matches.
    """
    candidates = (os.path.join(base, name) for base in paths)

    return next((cpath for cpath in candidates if checkfn(cpath)), None)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
# Length of the random part of temporary names (overridable via TMPFN_RNDSIZE).
_TMPFN_RNDSIZE = int(os.environ.get('TMPFN_RNDSIZE', 10))


def temp_path(nspath=None, nsdir=None, rndsize=_TMPFN_RNDSIZE):
    """Generate a randomized temporary path name (nothing is created on disk).

    Args:
      nspath: If given, the result is a sibling of it, with the random string
        inserted before its extension: 'dir/name.ext' -> 'dir/name.RND.ext'.
      nsdir: When `nspath` is None, the directory for the name. Defaults to
        tempfile.gettempdir(). The name is then 'RND.tmp'.
      rndsize: Length of the random string.

    Returns:
      The generated path string.
    """
    if nspath is None:
        tmp_dir = nsdir if nsdir is not None else tempfile.gettempdir()
        return os.path.join(tmp_dir, f'{rngu.rand_string(rndsize)}.tmp')

    stem, suffix = os.path.splitext(nspath)

    return f'{stem}.{rngu.rand_string(rndsize)}{suffix}'
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
# This does FileOverwrite() task (although locally limited) but here we do not
# pull that dependency to allow inclusion in this module (which allows none).
@contextlib.contextmanager
def atomic_write(path, mode='wb', create_parents=False):
    """Context manager writing `path` all-or-nothing.

    The with-body writes to a temporary sibling file (named via temp_path()).
    If the body completes, the file is closed and moved over `path` with
    os.replace(). If the body raises, the temporary file is closed and removed
    and `path` is left untouched.

    Args:
      path: The destination file.
      mode: The open() mode for the temporary file.
      create_parents: Whether to create the missing parent directories of `path`.

    Yields:
      The open file object to write into.
    """
    tpath = temp_path(nspath=path)

    # A bare file name has an empty dirname, and os.makedirs('') raises.
    if create_parents and (parent := os.path.dirname(path)):
        os.makedirs(parent, exist_ok=True)

    fd = open(tpath, mode=mode)
    try:
        yield fd
        fd.close()
        fd = None
    finally:
        if fd is not None:
            # The body (or the close) failed: discard the partial output.
            fd.close()
            os.remove(tpath)
        else:
            os.replace(tpath, path)
|
|
207
|
+
|
py_misc_utils/gcs_fs.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
import collections
|
|
2
|
+
import os
|
|
3
|
+
import stat as st
|
|
4
|
+
|
|
5
|
+
import google.cloud.storage as gcs
|
|
6
|
+
|
|
7
|
+
from . import alog
|
|
8
|
+
from . import fs_base as fsb
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# https://cloud.google.com/appengine/docs/legacy/standard/python/googlecloudstorageclient/read-write-to-cloud-storage
|
|
12
|
+
# https://cloud.google.com/python/docs/reference/storage/1.44.0/client
|
|
13
|
+
# https://cloud.google.com/python/docs/reference/storage/1.44.0/blobs#google.cloud.storage.blob.Blob
|
|
14
|
+
|
|
15
|
+
class GcsFs:
    """Filesystem-like access to the objects of a single GCS bucket.

    Object names are used as paths. "Directories" are not stored objects; they
    are derived from the '/' separators found in object names.
    """

    def __init__(self, bucket):
        # Bucket name, passed as-is to the storage client calls.
        self.bucket = bucket
        self._client = gcs.Client()

    def _blob_stat(self, blob, base_path=None):
        """Build a fsb.DirEntry describing `blob`.

        Without `base_path`, the entry describes the blob itself as a regular
        file named after the last '/'-separated part of the object name.

        With `base_path` (a normalized prefix, '' or ending with '/'), the entry
        describes the immediate child of `base_path` that contains the blob:
        - a directory (size 0, no etag) if the blob lives deeper;
        - a regular file otherwise.

        Returns None if the blob is not under `base_path`, or if it is the object
        named exactly `base_path` (e.g. a 'dir/' folder placeholder).
        """
        name = blob.name
        if base_path is not None:
            if not name.startswith(base_path):
                return None
            name = name[len(base_path):]
            if not name:
                # Object named exactly like the listed prefix: not a child entry.
                return None
            spos = name.find('/')
            if spos > 0:
                name = name[: spos]
                mode = st.S_IFDIR
                size, etag = 0, None
            else:
                mode = st.S_IFREG
                size, etag = blob.size, blob.etag

            path = base_path + name
        else:
            spos = name.rfind('/')
            if spos >= 0:
                name = name[spos + 1:]
            mode = st.S_IFREG
            path = blob.name
            size, etag = blob.size, blob.etag

        return fsb.DirEntry(name=name,
                            path=path,
                            etag=etag,
                            st_mode=mode,
                            st_size=size,
                            st_ctime=blob.time_created.timestamp(),
                            st_mtime=blob.updated.timestamp())

    def _norm_path(self, path):
        """Turn `path` into a listing prefix: '' for root, else ending with '/'."""
        if path:
            if path == '/':
                path = ''
            else:
                path = path + '/' if not path.endswith('/') else path

        return path

    def listdir(self, path):
        """Yield the fsb.DirEntry of the immediate children of `path`.

        Directory entries get the minimum ctime and maximum mtime of the blobs
        seen under them. Entries are sorted by (st_mode, name), which puts
        directories (S_IFDIR) before regular files (S_IFREG).
        """
        npath = self._norm_path(path)

        dentries = dict()
        for blob in self._client.list_blobs(self.bucket, prefix=npath):
            if (dentry := self._blob_stat(blob, base_path=npath)) is not None:
                xdentry = dentries.get(dentry.name)
                if xdentry is not None:
                    dentry = dentry._replace(st_ctime=min(dentry.st_ctime, xdentry.st_ctime),
                                             st_mtime=max(dentry.st_mtime, xdentry.st_mtime))

                dentries[dentry.name] = dentry

        sorted_dentries = sorted(dentries.items(), key=lambda x: (x[1].st_mode, x[0]))
        for _, dentry in sorted_dentries:
            yield dentry

    def open(self, path, mode='rb'):
        """Return the file-like object of `blob.open(mode)` for `path`."""
        bucket = self._client.bucket(self.bucket)
        blob = bucket.blob(path)

        return blob.open(mode)

    def upload(self, path, source):
        """Write to `path` every data chunk produced by the iterable `source`."""
        bucket = self._client.bucket(self.bucket)
        blob = bucket.blob(path)

        with blob.open('wb') as fd:
            for data in source:
                fd.write(data)

    def download(self, path, chunk_size=32 * 1024**2):
        """Yield the content of `path` in chunks of up to `chunk_size` bytes.

        A read shorter than `chunk_size` is taken as the end of the object.
        NOTE(review): this relies on the blob reader returning full reads
        before EOF.
        """
        bucket = self._client.bucket(self.bucket)
        blob = bucket.blob(path)

        with blob.open('rb') as fd:
            while True:
                data = fd.read(chunk_size)
                if data:
                    yield data
                if chunk_size > len(data):
                    break

    def pread(self, path, offset, size):
        """Return up to `size` raw bytes of `path` starting at `offset`."""
        # The download range is inclusive (end = offset + size - 1). An empty
        # read would produce an invalid range with end < start, so serve it
        # locally.
        if size <= 0:
            return b''

        bucket = self._client.bucket(self.bucket)
        blob = bucket.blob(path)

        return blob.download_as_bytes(start=offset, end=offset + size - 1, raw_download=True)

    def exists(self, path):
        """Return whether an object named exactly `path` exists."""
        bucket = self._client.bucket(self.bucket)
        blob = bucket.blob(path)

        return blob.exists()

    def stat(self, path):
        """Return a fsb.DirEntry for `path`, or None if nothing is found.

        An existing object is reported as a regular file. Otherwise, if `path`
        has children, it is reported as a directory. Its ctime/mtime are the
        minimum/maximum over the children entries.
        """
        bucket = self._client.bucket(self.bucket)
        blob = bucket.get_blob(path)
        if blob is not None:
            return self._blob_stat(blob)

        ctime = mtime = None
        for de in self.listdir(path):
            if ctime is None or de.st_ctime < ctime:
                ctime = de.st_ctime
            if mtime is None or de.st_mtime > mtime:
                mtime = de.st_mtime

        if ctime is not None and mtime is not None:
            bpath = path[: -1] if path.endswith('/') else path
            name = os.path.basename(bpath)

            return fsb.DirEntry(name=name,
                                path=bpath,
                                st_mode=st.S_IFDIR,
                                st_size=0,
                                st_ctime=ctime,
                                st_mtime=mtime)

        return None

    def remove(self, path):
        """Delete the object `path`."""
        bucket = self._client.bucket(self.bucket)
        blob = bucket.blob(path)
        blob.delete()

    def rename(self, src_path, dest_path):
        """Move an object by copying it to `dest_path` and deleting `src_path`.

        This is two separate operations, so it is not atomic.
        """
        bucket = self._client.bucket(self.bucket)
        src_blob = bucket.blob(src_path)

        bucket.copy_blob(src_blob, bucket, dest_path)
        bucket.delete_blob(src_path)

    def rmtree(self, path, ignore_errors=None):
        """Delete every object whose name starts with the normalized `path`.

        A `path` of '' or '/' selects the whole bucket. Per-object failures are
        logged. They are re-raised unless `ignore_errors` is neither None nor
        False.
        """
        npath = self._norm_path(path)

        for blob in self._client.list_blobs(self.bucket, prefix=npath):
            try:
                blob.delete()
            except Exception as ex:
                alog.debug(f'Failed to remove "{blob.name}" from "{self.bucket}": {ex}')
                if ignore_errors in (None, False):
                    raise

    def copy(self, src_path, dest_path):
        """Copy the object `src_path` to `dest_path` within the bucket."""
        bucket = self._client.bucket(self.bucket)
        src_blob = bucket.blob(src_path)

        bucket.copy_blob(src_blob, bucket, dest_path)
|
|
169
|
+
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# Types accepted as a single integer index: Python ints (bool included) and
# NumPy integer scalars such as np.int64.
_INT_INDEX_TYPES = (int, np.integer)


def _axis_gen(tidx, size, idx, shape):
    """Return the indices selected by `tidx` on one axis of length `size`."""
    if isinstance(tidx, _INT_INDEX_TYPES):
        return (tidx,)
    if isinstance(tidx, slice):
        return range(*tidx.indices(size))
    if hasattr(tidx, '__iter__'):
        return tidx

    # Unknown index kinds (e.g. None/np.newaxis) are rejected instead of
    # silently selecting the whole axis.
    raise ValueError(f'Wrong index {idx} for shape {shape}')


def _tuple_gens(shape, idx):
    """Return one index iterable per axis of `shape` for the tuple index `idx`."""
    epos = [i for i, tidx in enumerate(idx) if tidx is Ellipsis]
    if len(epos) > 1:
        raise ValueError(f'Wrong index {idx} for shape {shape}')

    # An Ellipsis does not consume an axis by itself.
    assert len(shape) >= len(idx) - len(epos), f'{shape} vs. {idx}'

    if epos:
        lead, trail = idx[: epos[0]], idx[epos[0] + 1:]
    else:
        lead, trail = idx, ()

    gens = [None] * len(shape)
    for i, tidx in enumerate(lead):
        gens[i] = _axis_gen(tidx, shape[i], idx, shape)

    # Indices after the Ellipsis are aligned to the trailing axes.
    for i, tidx in enumerate(trail, start=len(shape) - len(trail)):
        gens[i] = _axis_gen(tidx, shape[i], idx, shape)

    # Axes not covered by `idx` are fully selected.
    return [range(size) if gen is None else gen for gen, size in zip(gens, shape)]


def generate(shape, idx):
    """Enumerate the index tuples selected by `idx` over an array of `shape`.

    Args:
      shape: The sequence of axis lengths.
      idx: One of:
        - an integer, a slice, a list or an np.ndarray (only for 1-D shapes);
        - a tuple of per-axis integers, slices or iterables of indices.
          The tuple may hold one Ellipsis, and missing axes are fully selected.

    Returns:
      An itertools.product iterator of index tuples (one entry per axis).
      Integer indices are passed through as-is (negative values are not
      normalized).

    Raises:
      ValueError: On unsupported index types or more than one Ellipsis.
      AssertionError: When `idx` does not fit `shape`.
    """
    if isinstance(idx, _INT_INDEX_TYPES):
        assert len(shape) == 1, f'{shape}'

        gens = [(idx,)]
    elif isinstance(idx, slice):
        assert len(shape) == 1, f'{shape}'

        gens = [range(*idx.indices(shape[0]))]
    elif isinstance(idx, (list, np.ndarray)):
        assert len(shape) == 1, f'{shape}'

        gens = [idx]
    elif isinstance(idx, tuple):
        gens = _tuple_gens(shape, idx)
    else:
        raise ValueError(f'Wrong index {idx} for shape {shape}')

    return itertools.product(*gens)
|
|
54
|
+
|