torch-einops-utils 0.0.19__tar.gz → 0.0.21__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torch-einops-utils
-Version: 0.0.19
+Version: 0.0.21
 Summary: Personal utility functions
 Project-URL: Homepage, https://pypi.org/project/torch-einops-utils/
 Project-URL: Repository, https://github.com/lucidrains/torch-einops-utils
@@ -1,6 +1,6 @@
 [project]
 name = "torch-einops-utils"
-version = "0.0.19"
+version = "0.0.21"
 description = "Personal utility functions"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -210,6 +210,7 @@ def test_safe_functions():
     assert safe_stack([None]) is None
     assert (safe_stack([t1]) == t1).all()
     assert (safe_stack([t1, None]) == t1).all()
+    assert safe_stack([t1]).shape == (1, 2, 3)
     assert safe_stack([t1, t2]).shape == (2, 2, 3)
 
     assert safe_cat([]) is None
@@ -23,6 +23,9 @@ def identity(t, *args, **kwargs):
 def first(arr):
     return arr[0]
 
+def compact(arr):
+    return [*filter(exists, arr)]
+
 def maybe(fn):
 
     if not exists(fn):
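
The new `compact` helper is just a named version of the `None`-filtering idiom it replaces further down. A self-contained sketch of its behavior, restating the two one-liners from this hunk for illustration (assuming `exists(v)` is the usual `v is not None` predicate):

def exists(v):
    return v is not None

def compact(arr):
    return [*filter(exists, arr)]

compact([1, None, 2, None])   # [1, 2] - None entries dropped, order preserved
compact([None, None])         # []     - nothing survives filtering
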
@@ -40,7 +43,7 @@ def maybe(fn):
 def safe(fn):
     @wraps(fn)
     def inner(tensors, *args, **kwargs):
-        tensors = [*filter(exists, tensors)]
+        tensors = compact(tensors)
 
         if len(tensors) == 0:
             return None
@@ -160,8 +163,12 @@ def align_dims_left(
 
 # cat and stack
 
-@safe
 def safe_stack(tensors, dim = 0):
+    tensors = compact(tensors)
+
+    if len(tensors) == 0:
+        return None
+
     return stack(tensors, dim = dim)
 
 @safe
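
With `@safe` removed, `safe_stack` now does its own `None` filtering and always routes non-empty input through `stack`, so a single surviving tensor still gains a leading stack dimension (the behavior pinned by the new test assertion above). A minimal usage sketch, with the import path assumed:

import torch
from torch_einops_utils import safe_stack   # assumed import path

t1 = torch.randn(2, 3)
t2 = torch.randn(2, 3)

safe_stack([None])                   # None - nothing left after filtering
safe_stack([t1]).shape               # torch.Size([1, 2, 3]) - leading stack dim kept
safe_stack([t1, None, t2]).shape     # torch.Size([2, 2, 3]) - Nones dropped, then stacked
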
@@ -237,7 +244,8 @@ def pad_sequence(
     value = 0.,
     left = False,
     dim_stack = 0,
-    return_lens = False
+    return_lens = False,
+    pad_lens = False # returns padding length instead of sequence lengths
 ):
     if len(tensors) == 0:
         return None
@@ -255,6 +263,9 @@ def pad_sequence(
     if not return_lens:
         return stacked
 
+    if pad_lens:
+        lens = max_len - lens
+
     return stacked, lens
 
 # tree flatten with inverse
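
The new `pad_lens` flag only takes effect when `return_lens = True`: instead of the per-sequence lengths, the second return value becomes the amount of padding added to each sequence (`max_len - lens`). A minimal usage sketch, assuming the import path below and that `lens` otherwise holds the original sequence lengths:

import torch
from torch_einops_utils import pad_sequence   # assumed import path

seqs = [torch.randn(3), torch.randn(5)]

padded, lens = pad_sequence(seqs, return_lens = True)
# lens -> original sequence lengths, 3 and 5

padded, pads = pad_sequence(seqs, return_lens = True, pad_lens = True)
# pads -> padding added per sequence (max_len - lens), 2 and 0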