torch-einops-utils 0.0.19__tar.gz → 0.0.21__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torch_einops_utils-0.0.19 → torch_einops_utils-0.0.21}/PKG-INFO +1 -1
- {torch_einops_utils-0.0.19 → torch_einops_utils-0.0.21}/pyproject.toml +1 -1
- {torch_einops_utils-0.0.19 → torch_einops_utils-0.0.21}/tests/test_utils.py +1 -0
- {torch_einops_utils-0.0.19 → torch_einops_utils-0.0.21}/torch_einops_utils/torch_einops_utils.py +14 -3
- {torch_einops_utils-0.0.19 → torch_einops_utils-0.0.21}/.github/workflows/python-publish.yml +0 -0
- {torch_einops_utils-0.0.19 → torch_einops_utils-0.0.21}/.github/workflows/test.yml +0 -0
- {torch_einops_utils-0.0.19 → torch_einops_utils-0.0.21}/.gitignore +0 -0
- {torch_einops_utils-0.0.19 → torch_einops_utils-0.0.21}/LICENSE +0 -0
- {torch_einops_utils-0.0.19 → torch_einops_utils-0.0.21}/README.md +0 -0
- {torch_einops_utils-0.0.19 → torch_einops_utils-0.0.21}/tests/test_save_load.py +0 -0
- {torch_einops_utils-0.0.19 → torch_einops_utils-0.0.21}/torch_einops_utils/__init__.py +0 -0
- {torch_einops_utils-0.0.19 → torch_einops_utils-0.0.21}/torch_einops_utils/save_load.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: torch-einops-utils
|
|
3
|
-
Version: 0.0.19
|
|
3
|
+
Version: 0.0.21
|
|
4
4
|
Summary: Personal utility functions
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/torch-einops-utils/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/torch-einops-utils
|
|
@@ -210,6 +210,7 @@ def test_safe_functions():
|
|
|
210
210
|
assert safe_stack([None]) is None
|
|
211
211
|
assert (safe_stack([t1]) == t1).all()
|
|
212
212
|
assert (safe_stack([t1, None]) == t1).all()
|
|
213
|
+
assert safe_stack([t1]).shape == (1, 2, 3)
|
|
213
214
|
assert safe_stack([t1, t2]).shape == (2, 2, 3)
|
|
214
215
|
|
|
215
216
|
assert safe_cat([]) is None
|
{torch_einops_utils-0.0.19 → torch_einops_utils-0.0.21}/torch_einops_utils/torch_einops_utils.py
RENAMED
|
@@ -23,6 +23,9 @@ def identity(t, *args, **kwargs):
|
|
|
23
23
|
def first(arr):
|
|
24
24
|
return arr[0]
|
|
25
25
|
|
|
26
|
+
def compact(arr):
|
|
27
|
+
return [*filter(exists, arr)]
|
|
28
|
+
|
|
26
29
|
def maybe(fn):
|
|
27
30
|
|
|
28
31
|
if not exists(fn):
|
|
@@ -40,7 +43,7 @@ def maybe(fn):
|
|
|
40
43
|
def safe(fn):
|
|
41
44
|
@wraps(fn)
|
|
42
45
|
def inner(tensors, *args, **kwargs):
|
|
43
|
-
tensors =
|
|
46
|
+
tensors = compact(tensors)
|
|
44
47
|
|
|
45
48
|
if len(tensors) == 0:
|
|
46
49
|
return None
|
|
@@ -160,8 +163,12 @@ def align_dims_left(
|
|
|
160
163
|
|
|
161
164
|
# cat and stack
|
|
162
165
|
|
|
163
|
-
@safe
|
|
164
166
|
def safe_stack(tensors, dim = 0):
|
|
167
|
+
tensors = compact(tensors)
|
|
168
|
+
|
|
169
|
+
if len(tensors) == 0:
|
|
170
|
+
return None
|
|
171
|
+
|
|
165
172
|
return stack(tensors, dim = dim)
|
|
166
173
|
|
|
167
174
|
@safe
|
|
@@ -237,7 +244,8 @@ def pad_sequence(
|
|
|
237
244
|
value = 0.,
|
|
238
245
|
left = False,
|
|
239
246
|
dim_stack = 0,
|
|
240
|
-
return_lens = False
|
|
247
|
+
return_lens = False,
|
|
248
|
+
pad_lens = False # returns padding length instead of sequence lengths
|
|
241
249
|
):
|
|
242
250
|
if len(tensors) == 0:
|
|
243
251
|
return None
|
|
@@ -255,6 +263,9 @@ def pad_sequence(
|
|
|
255
263
|
if not return_lens:
|
|
256
264
|
return stacked
|
|
257
265
|
|
|
266
|
+
if pad_lens:
|
|
267
|
+
lens = max_len - lens
|
|
268
|
+
|
|
258
269
|
return stacked, lens
|
|
259
270
|
|
|
260
271
|
# tree flatten with inverse
|
{torch_einops_utils-0.0.19 → torch_einops_utils-0.0.21}/.github/workflows/python-publish.yml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|