returnn 1.20250830.114445__py3-none-any.whl → 1.20250902.10950__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the respective public registry.

Potentially problematic release.


This version of returnn might be problematic. Click here for more details.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250830.114445
3
+ Version: 1.20250902.10950
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -1,2 +1,2 @@
1
- version = '1.20250830.114445'
2
- long_version = '1.20250830.114445+git.24547d9'
1
+ version = '1.20250902.010950'
2
+ long_version = '1.20250902.010950+git.9d5debf'
@@ -2082,7 +2082,7 @@ class _DimMixin:
2082
2082
  :return: self + other. note that this is not commutative, i.e. different from other + self.
2083
2083
  :rtype: Dim
2084
2084
  """
2085
- if isinstance(other, int) and other == 0:
2085
+ if _is_const_dim_value(other, 0):
2086
2086
  return self
2087
2087
  cache_key = ("add", other)
2088
2088
  cache = self.get_same_base()._make_extra().cache_dim_math
@@ -2090,11 +2090,11 @@ class _DimMixin:
2090
2090
  if cache_entry:
2091
2091
  cache_entry.complete_dyn_size()
2092
2092
  return cache_entry
2093
- term = _OpLinearTerm.from_dim(self)
2094
- term.extend_add_sub_(other, kind="add", right=True)
2095
- dim = term.as_dim()
2096
- cache[cache_key] = dim
2097
- return dim
2093
+ res = _MathFindMatchingAdditive(start=self, right=True, other=other).search_and_maybe_replace()
2094
+ if not res:
2095
+ res = _math_get_dim_via_bin_op([self, other], "add")
2096
+ cache[cache_key] = res
2097
+ return res
2098
2098
 
2099
2099
  def __radd__(self: Dim, other):
2100
2100
  """
@@ -2102,7 +2102,7 @@ class _DimMixin:
2102
2102
  :return: other + self
2103
2103
  :rtype: Dim
2104
2104
  """
2105
- if isinstance(other, int) and other == 0:
2105
+ if _is_const_dim_value(other, 0):
2106
2106
  return self
2107
2107
  cache_key = ("add_left", other)
2108
2108
  cache = self.get_same_base()._make_extra().cache_dim_math
@@ -2110,18 +2110,18 @@ class _DimMixin:
2110
2110
  if cache_entry:
2111
2111
  cache_entry.complete_dyn_size()
2112
2112
  return cache_entry
2113
- term = _OpLinearTerm.from_dim(self)
2114
- term.extend_add_sub_(other, kind="add", right=False)
2115
- dim = term.as_dim()
2116
- cache[cache_key] = dim
2117
- return dim
2113
+ res = _MathFindMatchingAdditive(start=self, right=False, other=other).search_and_maybe_replace()
2114
+ if not res:
2115
+ res = _math_get_dim_via_bin_op([other, self], "add")
2116
+ cache[cache_key] = res
2117
+ return res
2118
2118
 
2119
2119
  def __sub__(self, other):
2120
2120
  """
2121
2121
  :param Dim|int other:
2122
2122
  :rtype: Dim
2123
2123
  """
2124
- if isinstance(other, int) and other == 0:
2124
+ if _is_const_dim_value(other, 0):
2125
2125
  return self
2126
2126
  return self.sub_right(other)
2127
2127
 
@@ -2131,19 +2131,24 @@ class _DimMixin:
2131
2131
  :return: self - other
2132
2132
  :rtype: Dim
2133
2133
  """
2134
- if isinstance(other, int) and other == 0:
2134
+ if _is_const_dim_value(other, 0):
2135
2135
  return self
2136
+ if (
2137
+ self.derived_from_op
2138
+ and self.derived_from_op.kind == "add"
2139
+ and len(self.derived_from_op.inputs) == 2
2140
+ and _dim_or_const_equal(other, self.derived_from_op.inputs[1])
2141
+ ):
2142
+ return self.derived_from_op.inputs[0]
2136
2143
  cache_key = ("sub", other)
2137
2144
  cache = self.get_same_base()._make_extra().cache_dim_math
2138
2145
  cache_entry = cache.get(cache_key, None)
2139
2146
  if cache_entry:
2140
2147
  cache_entry.complete_dyn_size()
2141
2148
  return cache_entry
2142
- term = _OpLinearTerm.from_dim(self)
2143
- term.extend_add_sub_(other, kind="sub", right=True)
2144
- dim = term.as_dim()
2145
- cache[cache_key] = dim
2146
- return dim
2149
+ res = self + (-other)
2150
+ cache[cache_key] = res
2151
+ return res
2147
2152
 
2148
2153
  def sub_left(self: Dim, other):
2149
2154
  """
@@ -2151,44 +2156,55 @@ class _DimMixin:
2151
2156
  :return: (-other) + self
2152
2157
  :rtype: Dim
2153
2158
  """
2154
- if isinstance(other, int) and other == 0:
2159
+ if _is_const_dim_value(other, 0):
2155
2160
  return self
2161
+ if (
2162
+ self.derived_from_op
2163
+ and self.derived_from_op.kind == "add"
2164
+ and len(self.derived_from_op.inputs) == 2
2165
+ and _dim_or_const_equal(other, self.derived_from_op.inputs[0])
2166
+ ):
2167
+ return self.derived_from_op.inputs[1]
2156
2168
  cache_key = ("sub_left", other)
2157
2169
  cache = self.get_same_base()._make_extra().cache_dim_math
2158
2170
  cache_entry = cache.get(cache_key, None)
2159
2171
  if cache_entry:
2160
2172
  cache_entry.complete_dyn_size()
2161
2173
  return cache_entry
2162
- term = _OpLinearTerm.from_dim(self)
2163
- term.extend_add_sub_(other, kind="sub", right=False)
2164
- dim = term.as_dim()
2165
- cache[cache_key] = dim
2166
- return dim
2174
+ res = (-other) + self
2175
+ cache[cache_key] = res
2176
+ return res
2167
2177
 
2168
2178
  def __mul__(self: Dim, other):
2169
2179
  """
2170
2180
  :param Dim|int other:
2171
2181
  :rtype: Dim
2172
2182
  """
2183
+ if isinstance(other, _d.Dim) and other.is_constant_static_dim():
2184
+ other = other.dimension # makes matching easier
2173
2185
  if isinstance(other, int) and other == 1:
2174
2186
  return self
2187
+ if self.is_constant_static_dim() and isinstance(other, _d.Dim):
2188
+ return self.dimension * other # use rmul
2175
2189
  cache_key = ("mul", other)
2176
2190
  cache = self.get_same_base()._make_extra().cache_dim_math
2177
2191
  cache_entry = cache.get(cache_key, None)
2178
2192
  if cache_entry:
2179
2193
  cache_entry.complete_dyn_size()
2180
2194
  return cache_entry
2181
- term = _OpLinearTerm.from_dim(self)
2182
- term.extend_mul_div_(other, kind="mul", right=True)
2183
- dim = term.as_dim()
2184
- cache[cache_key] = dim
2185
- return dim
2195
+ res = _math_find_matching_mult(start=self, right=True, other=other)
2196
+ if not res:
2197
+ res = _math_get_dim_via_bin_op([self, other], "mul")
2198
+ cache[cache_key] = res
2199
+ return res
2186
2200
 
2187
2201
  def __rmul__(self: Dim, other):
2188
2202
  """
2189
2203
  :param Dim|int other:
2190
2204
  :rtype: Dim
2191
2205
  """
2206
+ if isinstance(other, _d.Dim) and other.is_constant_static_dim():
2207
+ other = other.dimension # makes matching easier
2192
2208
  if isinstance(other, int) and other == 1:
2193
2209
  return self
2194
2210
  cache_key = ("mul_left", other)
@@ -2197,38 +2213,45 @@ class _DimMixin:
2197
2213
  if cache_entry:
2198
2214
  cache_entry.complete_dyn_size()
2199
2215
  return cache_entry
2200
- term = _OpLinearTerm.from_dim(self)
2201
- term.extend_mul_div_(other, kind="mul", right=False)
2202
- dim = term.as_dim()
2203
- cache[cache_key] = dim
2204
- return dim
2216
+ res = _math_find_matching_mult(start=self, right=False, other=other)
2217
+ if not res:
2218
+ res = _math_get_dim_via_bin_op([other, self], "mul")
2219
+ cache[cache_key] = res
2220
+ return res
2205
2221
 
2206
2222
  def __floordiv__(self: Dim, other):
2207
2223
  """
2208
2224
  :param Dim|int other:
2209
2225
  :rtype: Dim
2210
2226
  """
2227
+ if isinstance(other, _d.Dim) and other.is_constant_static_dim():
2228
+ other = other.dimension # makes matching easier
2211
2229
  if isinstance(other, int) and other == 1:
2212
2230
  return self
2231
+ if (
2232
+ self.derived_from_op
2233
+ and self.derived_from_op.kind == "mul"
2234
+ and len(self.derived_from_op.inputs) == 2
2235
+ and _dim_or_const_equal(other, self.derived_from_op.inputs[1])
2236
+ ):
2237
+ return self.derived_from_op.inputs[0]
2213
2238
  cache_key = ("floordiv", other)
2214
2239
  cache = self.get_same_base()._make_extra().cache_dim_math
2215
2240
  cache_entry = cache.get(cache_key, None)
2216
2241
  if cache_entry:
2217
2242
  cache_entry.complete_dyn_size()
2218
2243
  return cache_entry
2219
- term = _OpLinearTerm.from_dim(self)
2220
- term.extend_mul_div_(other, kind="floordiv", right=True)
2221
- dim = term.as_dim()
2222
- cache[cache_key] = dim
2223
- return dim
2244
+ res = _math_find_matching_div(start=self, right=True, other=other, kind="floordiv")
2245
+ if not res:
2246
+ res = _math_get_dim_via_bin_op([self, other], "floordiv")
2247
+ cache[cache_key] = res
2248
+ return res
2224
2249
 
2225
2250
  def __truediv__(self, other):
2226
2251
  """
2227
2252
  :param Dim|int other:
2228
2253
  :rtype: Dim
2229
2254
  """
2230
- if isinstance(other, int) and other == 1:
2231
- return self
2232
2255
  return self.div_right(other)
2233
2256
 
2234
2257
  def div_left(self: Dim, other):
@@ -2236,76 +2259,112 @@ class _DimMixin:
2236
2259
  :param Dim|int other:
2237
2260
  :rtype: Dim
2238
2261
  """
2262
+ if isinstance(other, _d.Dim) and other.is_constant_static_dim():
2263
+ other = other.dimension # makes matching easier
2239
2264
  if isinstance(other, int) and other == 1:
2240
2265
  return self
2266
+ if (
2267
+ self.derived_from_op
2268
+ and self.derived_from_op.kind == "mul"
2269
+ and len(self.derived_from_op.inputs) == 2
2270
+ and _dim_or_const_equal(other, self.derived_from_op.inputs[0])
2271
+ ):
2272
+ return self.derived_from_op.inputs[1]
2241
2273
  cache_key = ("truediv_left", other)
2242
2274
  cache = self.get_same_base()._make_extra().cache_dim_math
2243
2275
  cache_entry = cache.get(cache_key, None)
2244
2276
  if cache_entry:
2245
2277
  cache_entry.complete_dyn_size()
2246
2278
  return cache_entry
2247
- term = _OpLinearTerm.from_dim(self)
2248
- term.extend_mul_div_(other, kind="truediv", right=False)
2249
- dim = term.as_dim()
2250
- cache[cache_key] = dim
2251
- return dim
2279
+ res = _math_find_matching_div(start=self, right=True, other=other, kind="truediv_left")
2280
+ if not res:
2281
+ res = _math_get_dim_via_bin_op([self, other], "truediv_left")
2282
+ cache[cache_key] = res
2283
+ return res
2252
2284
 
2253
2285
  def div_right(self: Dim, other):
2254
2286
  """
2255
2287
  :param Dim|int other:
2256
2288
  :rtype: Dim
2257
2289
  """
2290
+ if isinstance(other, _d.Dim) and other.is_constant_static_dim():
2291
+ other = other.dimension # makes matching easier
2258
2292
  if isinstance(other, int) and other == 1:
2259
2293
  return self
2294
+ if (
2295
+ self.derived_from_op
2296
+ and self.derived_from_op.kind == "mul"
2297
+ and len(self.derived_from_op.inputs) == 2
2298
+ and _dim_or_const_equal(other, self.derived_from_op.inputs[1])
2299
+ ):
2300
+ return self.derived_from_op.inputs[0]
2260
2301
  cache_key = ("truediv", other)
2261
2302
  cache = self.get_same_base()._make_extra().cache_dim_math
2262
2303
  cache_entry = cache.get(cache_key, None)
2263
2304
  if cache_entry:
2264
2305
  cache_entry.complete_dyn_size()
2265
2306
  return cache_entry
2266
- term = _OpLinearTerm.from_dim(self)
2267
- term.extend_mul_div_(other, kind="truediv", right=True)
2268
- dim = term.as_dim()
2269
- cache[cache_key] = dim
2270
- return dim
2307
+ res = _math_find_matching_div(start=self, right=True, other=other, kind="truediv")
2308
+ if not res:
2309
+ res = _math_get_dim_via_bin_op([self, other], "truediv")
2310
+ cache[cache_key] = res
2311
+ return res
2271
2312
 
2272
2313
  def ceildiv_left(self: Dim, other):
2273
2314
  """
2274
2315
  :param Dim|int other:
2275
2316
  :rtype: Dim
2276
2317
  """
2318
+ if isinstance(other, _d.Dim) and other.is_constant_static_dim():
2319
+ other = other.dimension # makes matching easier
2277
2320
  if isinstance(other, int) and other == 1:
2278
2321
  return self
2322
+ if (
2323
+ self.derived_from_op
2324
+ and self.derived_from_op.kind == "mul"
2325
+ and len(self.derived_from_op.inputs) == 2
2326
+ and _dim_or_const_equal(other, self.derived_from_op.inputs[0])
2327
+ ):
2328
+ return self.derived_from_op.inputs[1]
2279
2329
  cache_key = ("ceildiv_left", other)
2280
2330
  cache = self.get_same_base()._make_extra().cache_dim_math
2281
2331
  cache_entry = cache.get(cache_key, None)
2282
2332
  if cache_entry:
2283
2333
  cache_entry.complete_dyn_size()
2284
2334
  return cache_entry
2285
- term = _OpLinearTerm.from_dim(self)
2286
- term.extend_mul_div_(other, kind="ceildiv", right=False)
2287
- dim = term.as_dim()
2288
- cache[cache_key] = dim
2289
- return dim
2335
+ res = _math_find_matching_div(start=self, right=True, other=other, kind="ceildiv_left")
2336
+ if not res:
2337
+ res = _math_get_dim_via_bin_op([self, other], "ceildiv_left")
2338
+ cache[cache_key] = res
2339
+ return res
2290
2340
 
2291
2341
  def ceildiv_right(self: Dim, other):
2292
2342
  """
2293
2343
  :param Dim|int other:
2294
2344
  :rtype: Dim
2295
2345
  """
2346
+ if isinstance(other, _d.Dim) and other.is_constant_static_dim():
2347
+ other = other.dimension # makes matching easier
2296
2348
  if isinstance(other, int) and other == 1:
2297
2349
  return self
2350
+ if (
2351
+ self.derived_from_op
2352
+ and self.derived_from_op.kind == "mul"
2353
+ and len(self.derived_from_op.inputs) == 2
2354
+ and _dim_or_const_equal(other, self.derived_from_op.inputs[1])
2355
+ ):
2356
+ return self.derived_from_op.inputs[0]
2298
2357
  cache_key = ("ceildiv", other)
2299
2358
  cache = self.get_same_base()._make_extra().cache_dim_math
2300
2359
  cache_entry = cache.get(cache_key, None)
2301
2360
  if cache_entry:
2302
2361
  cache_entry.complete_dyn_size()
2303
2362
  return cache_entry
2304
- term = _OpLinearTerm.from_dim(self)
2305
- term.extend_mul_div_(other, kind="ceildiv", right=True)
2306
- dim = term.as_dim()
2307
- cache[cache_key] = dim
2308
- return dim
2363
+ res = _math_find_matching_div(start=self, right=True, other=other, kind="ceildiv")
2364
+ if not res:
2365
+ res = _math_get_dim_via_bin_op([self, other], "ceildiv")
2366
+ cache[cache_key] = res
2367
+ return res
2309
2368
 
2310
2369
  def __neg__(self):
2311
2370
  """
@@ -2348,6 +2407,7 @@ def _make_constant_static_dim(value, kind=None):
2348
2407
  :param Entity|None kind:
2349
2408
  :rtype: Dim
2350
2409
  """
2410
+ assert isinstance(value, int)
2351
2411
  return _d.Dim(
2352
2412
  dimension=value,
2353
2413
  kind=kind or DimTypes.Unspecified,
@@ -2357,30 +2417,188 @@ def _make_constant_static_dim(value, kind=None):
2357
2417
  )
2358
2418
 
2359
2419
 
2360
- def _math_get_dim_via_bin_op(a: Dim, b: Dim, op_kind: str) -> Dim:
2361
- assert op_kind in {"add", "mul"}
2362
- op_kind_ = op_kind
2363
- a_, b_ = a, b
2364
- if a.is_constant_static_dim() and not b.is_constant_static_dim():
2365
- op_kind_ = op_kind + "_left"
2366
- a_, b_ = b, a
2367
- # noinspection PyProtectedMember
2368
- cache, cache_key, res = a_._cache_dim_math_get(op_kind_, b_)
2369
- if res:
2370
- return res
2371
- if a.dimension is not None and b.dimension is not None:
2372
- dim_value = getattr(operator, op_kind)(a.dimension, b.dimension)
2420
+ def _dim_or_const_equal(a: Union[Dim, int], b: Union[Dim, int]) -> bool:
2421
+ if isinstance(a, int):
2422
+ if isinstance(b, int):
2423
+ return a == b
2424
+ elif isinstance(b, _d.Dim):
2425
+ return a == b.dimension and b.is_constant_static_dim()
2426
+ else:
2427
+ raise TypeError(f"unexpected b type {type(b)}")
2428
+ elif isinstance(a, _d.Dim):
2429
+ if a.is_constant_static_dim():
2430
+ if isinstance(b, int):
2431
+ return a.dimension == b
2432
+ elif isinstance(b, _d.Dim):
2433
+ return a.dimension == b.dimension and b.is_constant_static_dim()
2434
+ else:
2435
+ raise TypeError(f"unexpected b type {type(b)}")
2436
+ else:
2437
+ if isinstance(b, int):
2438
+ return False
2439
+ elif isinstance(b, _d.Dim):
2440
+ return a == b
2441
+ else:
2442
+ raise TypeError(f"unexpected b type {type(b)}")
2443
+ else:
2444
+ raise TypeError(f"unexpected a type {type(a)}")
2445
+
2446
+
2447
+ _BinOps = {
2448
+ "add": operator.add,
2449
+ "mul": operator.mul,
2450
+ "sub": operator.sub,
2451
+ "floordiv": operator.floordiv,
2452
+ "truediv": operator.floordiv,
2453
+ "truediv_left": operator.floordiv,
2454
+ "ceildiv": lambda a, b: -(-a // b),
2455
+ "ceildiv_left": lambda a, b: -(-a // b),
2456
+ }
2457
+
2458
+ _BinOpStrs = {
2459
+ "add": "+",
2460
+ "mul": "*",
2461
+ "sub": "-",
2462
+ "floordiv": "//",
2463
+ "truediv": "/",
2464
+ "truediv_left": " /l ",
2465
+ "ceildiv": "/",
2466
+ "ceildiv_left": " /l ",
2467
+ }
2468
+
2469
+
2470
+ def _math_get_dim_via_bin_op(dims: Sequence[Union[Dim, int]], op_kind: str) -> Dim:
2471
+ dims = [d if isinstance(d, _d.Dim) else _make_constant_static_dim(d) for d in dims]
2472
+ if all(d.dimension is not None for d in dims):
2473
+ op = _BinOps[op_kind]
2474
+ dim_value = dims[0].dimension
2475
+ for d in dims[1:]:
2476
+ dim_value = op(dim_value, d.dimension)
2373
2477
  else:
2374
2478
  dim_value = None
2375
- res = _d.Dim(
2376
- kind=_get_merged_dim_kind((a, b)),
2377
- description=_get_description(a) + {"add": "+", "mul": "*"}[op_kind] + _get_description(b),
2479
+ if all(d.is_constant_static_dim() for d in dims):
2480
+ return _make_constant_static_dim(dim_value, kind=_get_merged_dim_kind(dims))
2481
+ desc = _BinOpStrs[op_kind].join(_get_description(d) for d in dims)
2482
+ if op_kind.startswith("ceildiv"):
2483
+ desc = f"⌈{desc}⌉"
2484
+ return _d.Dim(
2485
+ kind=_get_merged_dim_kind(dims),
2486
+ description=desc,
2378
2487
  dimension=dim_value,
2379
- derived_from_op=Op(kind=op_kind, inputs=[a, b]),
2380
- derived_from_tag=_representative_tag((a, b)),
2488
+ derived_from_op=Op(kind=op_kind, inputs=list(dims)),
2489
+ derived_from_tag=_representative_tag(dims),
2381
2490
  )
2382
- cache[cache_key] = res
2383
- return res
2491
+
2492
+
2493
+ def _is_const_dim_value(d: Union[Dim, int], value: int) -> bool:
2494
+ if isinstance(d, int):
2495
+ return d == value
2496
+ elif isinstance(d, _d.Dim):
2497
+ return d.is_constant_static_dim() and d.dimension == value
2498
+ else:
2499
+ raise TypeError(f"unexpected type {type(d)}")
2500
+
2501
+
2502
+ class _MathFindMatchingAdditive:
2503
+ def __init__(self, start: Dim, *, max_depth: int = 2, right: bool, other: Union[int, Dim]):
2504
+ self.start = start
2505
+ self.max_depth = max_depth
2506
+ self.right = right
2507
+ self.other = other
2508
+
2509
+ def _check_and_maybe_replace(self, candidate: Dim) -> Optional[Dim]:
2510
+ """
2511
+ Check and return potential replacement for candidate, when adding `other` to it.
2512
+ """
2513
+ other = self.other
2514
+ if isinstance(other, int) or other.is_constant_static_dim():
2515
+ if candidate.is_constant_static_dim():
2516
+ return _math_get_dim_via_bin_op([candidate, other] if self.right else [other, candidate], "add")
2517
+ return None
2518
+ if candidate == other:
2519
+ return candidate.__rmul__(2)
2520
+ c_op = candidate.derived_from_op
2521
+ if c_op and c_op.kind == "mul" and len(c_op.inputs) == 2 and c_op.inputs[1] == other:
2522
+ factor = (c_op.inputs[0] + 1) if self.right else (1 + c_op.inputs[0])
2523
+ if factor.is_constant_static_dim():
2524
+ if factor.dimension == 0:
2525
+ return factor
2526
+ factor = factor.dimension
2527
+ return factor * c_op.inputs[1]
2528
+ o_op = other.derived_from_op
2529
+ if not o_op or o_op.kind != "mul" or len(o_op.inputs) != 2:
2530
+ return None
2531
+ o_base, other = o_op.inputs # continue checking this
2532
+ if candidate == other:
2533
+ factor = (1 + o_base) if self.right else (o_base + 1)
2534
+ if factor.is_constant_static_dim():
2535
+ if factor.dimension == 0:
2536
+ return factor
2537
+ factor = factor.dimension
2538
+ return factor * candidate
2539
+ if c_op and c_op.kind == "mul" and len(c_op.inputs) == 2 and c_op.inputs[1] == other:
2540
+ factor = (c_op.inputs[0] + o_base) if self.right else (o_base + c_op.inputs[0])
2541
+ if factor.is_constant_static_dim():
2542
+ if factor.dimension == 0:
2543
+ return factor
2544
+ factor = factor.dimension
2545
+ return factor * c_op.inputs[1]
2546
+ return None
2547
+
2548
+ def search_and_maybe_replace(self) -> Optional[Dim]:
2549
+ """search"""
2550
+ cur = self.start
2551
+ depth = 0
2552
+ history = []
2553
+ while True:
2554
+ res_cur = self._check_and_maybe_replace(cur)
2555
+ if res_cur:
2556
+ if depth > 0 and res_cur.is_constant_static_dim() and res_cur.dimension == 0:
2557
+ res_cur = history.pop(-1)
2558
+ res = res_cur
2559
+ for h in reversed(history):
2560
+ res = _math_get_dim_via_bin_op([h, res] if self.right else [res, h], "add")
2561
+ return res
2562
+ depth += 1
2563
+ if depth > self.max_depth:
2564
+ return None
2565
+ op = cur.derived_from_op
2566
+ if not op or op.kind != "add" or len(op.inputs) != 2:
2567
+ return None
2568
+ cur = op.inputs[1 if self.right else 0]
2569
+ hist = op.inputs[0 if self.right else 1]
2570
+ history.append(hist)
2571
+
2572
+
2573
+ def _math_find_matching_mult(start: Dim, other: Union[int, Dim], *, right: bool) -> Optional[Dim]:
2574
+ if (isinstance(other, int) or other.is_constant_static_dim()) and start.is_constant_static_dim():
2575
+ return _math_get_dim_via_bin_op([start, other] if right else [other, start], "mul")
2576
+ c_op = start.derived_from_op
2577
+ if c_op and c_op.kind == "mul" and len(c_op.inputs) == 2:
2578
+ if right:
2579
+ return c_op.inputs[0] * (c_op.inputs[1] * other)
2580
+ else:
2581
+ return (other * c_op.inputs[0]) * c_op.inputs[1]
2582
+ return None
2583
+
2584
+
2585
+ _DivKindToMeth: Dict[str, Callable[[Dim, Dim], Dim]] = {
2586
+ "truediv": _DimMixin.div_right,
2587
+ "truediv_left": _DimMixin.div_left,
2588
+ "ceildiv": _DimMixin.ceildiv_right,
2589
+ "ceildiv_left": _DimMixin.ceildiv_left,
2590
+ "floordiv": _DimMixin.__floordiv__,
2591
+ }
2592
+
2593
+
2594
+ def _math_find_matching_div(start: Dim, other: Union[int, Dim], *, right: bool, kind: str) -> Optional[Dim]:
2595
+ if (isinstance(other, int) or other.is_constant_static_dim()) and start.is_constant_static_dim():
2596
+ return _math_get_dim_via_bin_op([start, other] if right else [other, start], kind)
2597
+ c_op = start.derived_from_op
2598
+ if c_op and c_op.kind == kind and len(c_op.inputs) == 2:
2599
+ meth = _DivKindToMeth[kind]
2600
+ return meth(c_op.inputs[0], c_op.inputs[1] * other if right else other * c_op.inputs[1])
2601
+ return None
2384
2602
 
2385
2603
 
2386
2604
  class Op:
@@ -2437,457 +2655,7 @@ def _get_description(dim, brackets=True):
2437
2655
  return "unnamed_%s_dim%s" % (dim.kind, dim.dimension if dim.dimension is not None else "?")
2438
2656
 
2439
2657
 
2440
- class _OpMultTerm:
2441
- """
2442
- represents sth like a * b * c
2443
- """
2444
-
2445
- @classmethod
2446
- def from_dim(cls, dim: Dim) -> _OpMultTerm:
2447
- """
2448
- :param dim:
2449
- :return: op mult term
2450
- """
2451
- dim = dim.get_same_base()
2452
- if dim.dimension == 1 and dim.is_constant_static_dim():
2453
- return cls.one()
2454
- if dim.derived_from_op and dim.derived_from_op.kind == "mul":
2455
- return cls(list(dim.derived_from_op.inputs))
2456
- return cls([dim])
2457
-
2458
- @classmethod
2459
- def from_dim_factors(cls, dims: List[Dim]) -> _OpMultTerm:
2460
- """from dim factors"""
2461
- res = cls.one()
2462
- for d in dims:
2463
- res.extend_mul_div_(d, kind="mul", right=True)
2464
- return res
2465
-
2466
- @classmethod
2467
- def one(cls) -> _OpMultTerm:
2468
- """1"""
2469
- return cls([])
2470
-
2471
- def __init__(self, terms: List[Dim]):
2472
- self.terms = terms
2473
-
2474
- def __hash__(self):
2475
- return hash(tuple(self.terms))
2476
-
2477
- def __eq__(self, other):
2478
- """
2479
- :param _OpMultTerm other:
2480
- """
2481
- if isinstance(other, _OpMultTerm):
2482
- return self.terms == other.terms
2483
- return False
2484
-
2485
- def __ne__(self, other):
2486
- return not self.__eq__(other)
2487
-
2488
- def __repr__(self):
2489
- return "Dim._OpMultTerm(%r)" % (self.terms,)
2490
-
2491
- @property
2492
- def dimension(self) -> Optional[int]:
2493
- """static dim or None"""
2494
- dim = 1
2495
- for part in self.terms:
2496
- if part.dimension is None:
2497
- return None
2498
- dim *= part.dimension
2499
- return dim
2500
-
2501
- def base_term(self) -> Dim:
2502
- """base term (Dim)"""
2503
- assert self.terms
2504
- return self.terms[-1]
2505
-
2506
- def is_one(self) -> bool:
2507
- """is 1"""
2508
- return not self.terms
2509
-
2510
- def is_constant_static_dim(self) -> bool:
2511
- """is constant static dim"""
2512
- if not self.terms:
2513
- return True
2514
- return all(term.is_constant_static_dim() for term in self.terms)
2515
-
2516
- def copy(self) -> _OpMultTerm:
2517
- """copy"""
2518
- return _OpMultTerm(list(self.terms))
2519
-
2520
- def negative(self) -> _OpMultTerm:
2521
- """negative"""
2522
- if self.terms and self.terms[0].is_constant_static_dim() and self.terms[0].dimension == -1:
2523
- return _OpMultTerm(self.terms[1:])
2524
- res = self.copy()
2525
- res.extend_mul_div_(_make_constant_static_dim(-1), kind="mul", right=False)
2526
- return res
2527
-
2528
- def divisible(self, other, right):
2529
- """
2530
- :param Dim other:
2531
- :param bool right:
2532
- :return: whether we can divide other, without remainder
2533
- :rtype: bool
2534
- """
2535
- if not self.terms:
2536
- return False
2537
- if other.derived_from_op and other.derived_from_op.kind == "mul":
2538
- tmp = self.copy()
2539
- for term in other.derived_from_op.inputs if right else reversed(other.derived_from_op.inputs):
2540
- if not tmp.divisible(term, right=right):
2541
- return False
2542
- tmp.extend_mul_div_(term, kind="truediv", right=right)
2543
- return True
2544
- most_recent_term = self.terms[-1 if right else 0]
2545
- if other == most_recent_term:
2546
- return True
2547
- if most_recent_term.dimension is not None and other.dimension is not None:
2548
- if most_recent_term.dimension % other.dimension == 0:
2549
- return True
2550
- return False
2551
-
2552
- def can_simplify(self, other, kind, right):
2553
- """
2554
- :param Dim other:
2555
- :param str kind:
2556
- :param bool right:
2557
- :return: whether we can simplify when applying this operation
2558
- :rtype: bool
2559
- """
2560
- if other.derived_from_op and other.derived_from_op.kind == "mul":
2561
- tmp = self.copy()
2562
- for term in other.derived_from_op.inputs if right else reversed(other.derived_from_op.inputs):
2563
- if not tmp.can_simplify(term, kind=kind, right=right):
2564
- return False
2565
- tmp.extend_mul_div_(term, kind=kind, right=right)
2566
- return True
2567
- idx = self._simplify_term_idx(other, kind=kind, right=right)
2568
- return idx is not None
2569
-
2570
- def _simplify_term_idx(self, other, kind, right):
2571
- """
2572
- :param Dim other:
2573
- :param str kind:
2574
- :param bool right:
2575
- :return: index of term to simplify
2576
- :rtype: int|None
2577
- """
2578
- if not self.terms:
2579
- return None
2580
- if kind == "mul":
2581
- # We want (b * a) // b != a.
2582
- # However, we want h * (2 * a // h) == 2 * a.
2583
- # So, for `mul`, and only for `mul`, check all terms, whether we can simplify some division-term.
2584
- for i, term in reversed(list(enumerate(self.terms))) if right else enumerate(self.terms):
2585
- assert isinstance(term, _d.Dim)
2586
- if term.derived_from_op:
2587
- if term.derived_from_op.kind == "truediv_" + ("right" if right else "left"):
2588
- if term.derived_from_op.inputs[-1] == other:
2589
- return i
2590
- if other.derived_from_op:
2591
- if other.derived_from_op.kind == "truediv_" + ("right" if not right else "left"):
2592
- if other.derived_from_op.inputs[-1] == term:
2593
- return i
2594
- if term.is_constant_static_dim() and other.is_constant_static_dim():
2595
- return i
2596
- # For the last/first term, extra checks.
2597
- i = len(self.terms) - 1 if right else 0
2598
- term = self.terms[i]
2599
- if kind.endswith("div") and other == term:
2600
- return i
2601
- op_kind = kind + "_" + ("right" if right else "left")
2602
- if term.derived_from_op and term.derived_from_op.kind == op_kind:
2603
- return i
2604
- return None
2605
-
2606
- def extend_mul_div_(self, other, kind, right):
2607
- """
2608
- :param Dim other:
2609
- :param str kind:
2610
- :param bool right:
2611
- """
2612
- assert kind in {"mul", "floordiv", "truediv", "ceildiv"}
2613
- if other.is_constant_static_dim() and other.dimension == 1:
2614
- return
2615
- if not self.terms:
2616
- if kind == "mul":
2617
- self.terms.append(other)
2618
- elif kind.endswith("div"):
2619
- self.terms = [_OpMultTerm.new_div_dim(self.as_dim(), other, kind=kind, right=right)]
2620
- return
2621
- if other.derived_from_op and other.derived_from_op.kind == "mul":
2622
- for term in other.derived_from_op.inputs if right else reversed(other.derived_from_op.inputs):
2623
- self.extend_mul_div_(term, kind=kind, right=right)
2624
- return
2625
- idx = self._simplify_term_idx(other, kind=kind, right=right)
2626
- if idx is not None:
2627
- term = self.terms[idx]
2628
- assert isinstance(term, _d.Dim)
2629
- if kind.endswith("div") and other == term:
2630
- self.terms.pop(idx)
2631
- return
2632
- if kind == "mul" and term.derived_from_op:
2633
- if term.derived_from_op.kind == "truediv_" + ("right" if right else "left"):
2634
- if term.derived_from_op.inputs[-1] == other:
2635
- self.terms[idx] = term.derived_from_op.inputs[0]
2636
- return
2637
- if kind == "mul" and other.derived_from_op:
2638
- if other.derived_from_op.kind == "truediv_" + ("right" if not right else "left"):
2639
- if other.derived_from_op.inputs[-1] == term:
2640
- self.terms[idx] = other.derived_from_op.inputs[0]
2641
- return
2642
- if term.is_constant_static_dim() and other.is_constant_static_dim():
2643
- if kind == "mul":
2644
- if term.dimension * other.dimension == 1:
2645
- self.terms.pop(idx)
2646
- return
2647
- self.terms[idx] = _make_constant_static_dim(term.dimension * other.dimension, kind=term.kind)
2648
- return
2649
- if kind.endswith("div") and term.dimension % other.dimension == 0:
2650
- self.terms[idx] = _make_constant_static_dim(term.dimension // other.dimension, kind=term.kind)
2651
- return
2652
- # Fallback with generic handling.
2653
- op_kind = kind + "_" + ("right" if right else "left")
2654
- if kind.endswith("div") and term.derived_from_op and term.derived_from_op.kind == op_kind:
2655
- numerator = term.derived_from_op.inputs[0]
2656
- denominator = term.derived_from_op.inputs[1]
2657
- self.terms[idx] = _OpMultTerm.new_div_dim(numerator, denominator * other, kind=kind, right=right)
2658
- return
2659
- if kind.endswith("div"):
2660
- self.terms = [_OpMultTerm.new_div_dim(self.as_dim(), other, kind=kind, right=right)]
2661
- return
2662
- if kind == "mul":
2663
- if right:
2664
- self.terms.append(other)
2665
- else:
2666
- self.terms.insert(0, other)
2667
- return
2668
- assert False
2669
-
2670
- @classmethod
2671
- def new_div_dim(cls, numerator, denominator, kind, right):
2672
- """
2673
- :param Dim numerator:
2674
- :param Dim denominator:
2675
- :param str kind: "floordiv" or "ceildiv" or "truediv"
2676
- :param bool right:
2677
- :rtype: Dim
2678
- """
2679
- dim_value = None
2680
- a = numerator.dimension
2681
- b = denominator.dimension
2682
- if a is not None and b is not None:
2683
- if kind == "floordiv":
2684
- dim_value = a // b
2685
- elif kind == "ceildiv":
2686
- dim_value = -(-a // b)
2687
- if a % b == 0 and right:
2688
- kind = "floordiv" # for nicer description, and does not matter
2689
- elif kind == "truediv":
2690
- if a % b != 0:
2691
- raise ValueError(
2692
- "%s truediv %s only allowed if the result is an integer" % (numerator, denominator)
2693
- )
2694
- dim_value = a // b
2695
- if right:
2696
- kind = "floordiv" # for nicer description, and does not matter
2697
- else:
2698
- raise ValueError("invalid kind %r" % (kind,))
2699
- # noinspection PyProtectedMember
2700
- cache, cache_key, res = numerator._cache_dim_math_get(kind + ("" if right else "_left"), denominator)
2701
- if res:
2702
- return res
2703
- if kind == "floordiv" and right:
2704
- description = "%s//%s" % (_get_description(numerator), _get_description(denominator))
2705
- elif kind == "ceildiv" and right:
2706
- description = "⌈%s/%s⌉" % (_get_description(numerator), _get_description(denominator))
2707
- else:
2708
- description = "%s_%s(%s, %s)" % (
2709
- kind,
2710
- "right" if right else "left",
2711
- _get_description(numerator, brackets=False),
2712
- _get_description(denominator, brackets=False),
2713
- )
2714
- op_kind = kind
2715
- if a is not None and b is not None and a % b == 0:
2716
- op_kind = "truediv" # makes some other checks simpler
2717
- op_kind += "_" + ("right" if right else "left")
2718
- res = _d.Dim(
2719
- description=description,
2720
- kind=numerator.kind,
2721
- dimension=dim_value,
2722
- derived_from_op=Op(kind=op_kind, inputs=[numerator, denominator]),
2723
- derived_from_tag=numerator,
2724
- )
2725
- cache[cache_key] = res
2726
- return res
2727
-
2728
- def as_dim(self):
2729
- """
2730
- :rtype: Dim
2731
- """
2732
- if self.is_one():
2733
- return _make_constant_static_dim(1)
2734
- if len(self.terms) == 1:
2735
- return self.terms[0]
2736
- res = self.terms[0]
2737
- for operand in self.terms[1:]:
2738
- res = _math_get_dim_via_bin_op(res, operand, "mul")
2739
- return res
2740
-
2741
-
2742
- class _OpLinearTerm:
2743
- """
2744
- Linear combination of :class:`_OpMultTerm`.
2745
- Represents sth like a * b + c of :class:`Dim`.
2746
- """
2747
-
2748
- @classmethod
2749
- def from_dim(cls, dim: Dim) -> _OpLinearTerm:
2750
- """from dim"""
2751
- res = cls.zero()
2752
- res.extend_add_sub_(dim, kind="add", right=True)
2753
- return res
2754
-
2755
- @classmethod
2756
- def zero(cls) -> _OpLinearTerm:
2757
- """0"""
2758
- return _OpLinearTerm([])
2759
-
2760
- def __init__(self, terms: List[_OpMultTerm]):
2761
- self.terms = terms
2762
-
2763
- def __hash__(self):
2764
- return hash(tuple(self.terms))
2765
-
2766
- def __eq__(self, other):
2767
- if isinstance(other, _OpLinearTerm):
2768
- return self.terms == other.terms
2769
- return False
2770
-
2771
- def __ne__(self, other):
2772
- return not self.__eq__(other)
2773
-
2774
- def as_dim(self) -> Dim:
2775
- """as dim"""
2776
- if self.is_zero():
2777
- return _make_constant_static_dim(0)
2778
- if len(self.terms) == 1:
2779
- return self.terms[0].as_dim()
2780
- res = self.terms[0].as_dim()
2781
- for operand in self.terms[1:]:
2782
- res = _math_get_dim_via_bin_op(res, operand.as_dim(), "add")
2783
- return res
2784
-
2785
- def __repr__(self):
2786
- return "Dim._OpLinearTerm(%r)" % (self.terms,)
2787
-
2788
- def is_zero(self):
2789
- """
2790
- :rtype: bool
2791
- """
2792
- return not self.terms
2793
-
2794
- def extend_add_sub_(self, other, kind, right):
2795
- """
2796
- :param Dim|int other:
2797
- :param str kind: "add" or "sub"
2798
- :param bool right: or left. right means self + other, left means other + self
2799
- """
2800
- assert kind in {"add", "sub"}
2801
- other = self._make_dim(other, kind=kind)
2802
- if other.is_constant_static_dim() and other.dimension == 0:
2803
- return
2804
- if other.derived_from_op and other.derived_from_op.kind == "add":
2805
- for other_ in other.derived_from_op.inputs if right else reversed(other.derived_from_op.inputs):
2806
- self.extend_add_sub_(other_, kind=kind, right=right)
2807
- return
2808
- term = _OpMultTerm.from_dim(other)
2809
- neg_term = term.negative()
2810
- if kind == "sub":
2811
- term, neg_term = neg_term, term
2812
- most_recent_term = self.terms[-1 if right else 0] if self.terms else None
2813
- if most_recent_term:
2814
- if most_recent_term == neg_term:
2815
- self.terms.pop(-1 if right else 0)
2816
- return
2817
- if most_recent_term.is_constant_static_dim() and term.is_constant_static_dim():
2818
- self.terms[-1 if right else 0] = _OpMultTerm.from_dim(
2819
- _make_constant_static_dim(most_recent_term.dimension + term.dimension, kind=other.kind)
2820
- )
2821
- return
2822
- if most_recent_term.terms and term.terms and most_recent_term.terms[-1] == term.terms[-1]:
2823
- # Merge terms
2824
- a = _OpMultTerm.from_dim_factors(most_recent_term.terms[:-1]).as_dim()
2825
- b = _OpMultTerm.from_dim_factors(term.terms[:-1]).as_dim()
2826
- if a.is_constant_static_dim() and not b.is_constant_static_dim():
2827
- a = a.dimension
2828
- elif b.is_constant_static_dim() and not a.is_constant_static_dim():
2829
- b = b.dimension
2830
- res = _OpMultTerm.from_dim((a + b) if right else (b + a))
2831
- res.extend_mul_div_(term.terms[-1], kind="mul", right=True)
2832
- self.terms[-1 if right else 0] = res
2833
- return
2834
- if right:
2835
- self.terms.append(term)
2836
- else:
2837
- self.terms.insert(0, term)
2838
-
2839
- def extend_mul_div_(self, other, kind, right):
2840
- """
2841
- :param Dim|int other:
2842
- :param str kind: "mul" or "ceildiv"
2843
- :param bool right: or left. right means self * other, left means other * self
2844
- """
2845
- assert kind in {"mul", "floordiv", "truediv", "ceildiv"}
2846
- other = self._make_dim(other, kind=kind)
2847
- if kind == "mul" and right:
2848
- if not all(term.can_simplify(other, kind=kind, right=right) for term in self.terms):
2849
- # Do it the other way around
2850
- self.terms, other = _OpLinearTerm.from_dim(other).terms, self.as_dim()
2851
- right = False
2852
- if other.is_constant_static_dim() and other.dimension == 1:
2853
- return
2854
- if kind.endswith("div") and len(self.terms) >= 2:
2855
- if any(not term.divisible(other, right=right) for term in self.terms):
2856
- self.terms = [
2857
- _OpMultTerm.from_dim(_OpMultTerm.new_div_dim(self.as_dim(), other, kind=kind, right=right))
2858
- ]
2859
- return
2860
- for term in self.terms:
2861
- term.extend_mul_div_(other, kind=kind, right=right)
2862
-
2863
- def _make_dim(self, other, kind):
2864
- """
2865
- :param Dim|int other:
2866
- :param str kind:
2867
- :rtype: Dim
2868
- """
2869
- if isinstance(other, int):
2870
- base_tag = self.representative_tag()
2871
- return _make_constant_static_dim(other, kind=base_tag.kind if base_tag else None)
2872
- elif isinstance(other, _d.Dim):
2873
- return other.get_same_base()
2874
- else:
2875
- raise TypeError("%s %s %s invalid for type %s" % (self, kind, other, type(other)))
2876
-
2877
- def representative_tag(self):
2878
- """
2879
- :rtype: Dim|None
2880
- """
2881
- terms = [_representative_tag(term.terms) for term in self.terms]
2882
- return _representative_tag([term for term in terms if term])
2883
-
2884
-
2885
- def _get_merged_dim_kind(dim_tags):
2886
- """
2887
- :param list[Dim]|tuple[Dim] dim_tags:
2888
- :return: dim kind
2889
- :rtype: Entity
2890
- """
2658
+ def _get_merged_dim_kind(dim_tags: Sequence[Dim]) -> Entity:
2891
2659
  if any(tag.is_batch_dim() for tag in dim_tags):
2892
2660
  return DimTypes.Batch
2893
2661
  elif any(tag.is_feature_dim() for tag in dim_tags):
@@ -8,6 +8,10 @@ import os
8
8
  import sys
9
9
  import gc
10
10
  import subprocess
11
+ import signal
12
+ import time
13
+ import contextlib
14
+ import multiprocessing
11
15
  import torch
12
16
  from returnn.util.better_exchook import better_exchook
13
17
  from returnn.util.basic import human_bytes_size
@@ -26,36 +30,39 @@ def print_available_devices(*, file: Optional[TextIO] = None):
26
30
  print("CUDA_VISIBLE_DEVICES is set to %r." % os.environ["CUDA_VISIBLE_DEVICES"], file=file)
27
31
  cuda_visible_devs = dict(enumerate([int(d) for d in os.environ["CUDA_VISIBLE_DEVICES"].split(",") if d]))
28
32
  else:
29
- if torch.cuda.is_available():
30
- print("CUDA_VISIBLE_DEVICES is not set.", file=file)
31
-
32
- if torch.cuda.is_available():
33
- print("Available CUDA devices:")
34
- count = torch.cuda.device_count()
35
- if cuda_visible_devs is not None and len(cuda_visible_devs) != count:
36
- print(
37
- f"(Mismatch between CUDA device count {count}"
38
- f" and CUDA_VISIBLE_DEVICES {cuda_visible_devs} count {len(cuda_visible_devs)}?)",
39
- file=file,
40
- )
41
- for i in range(count):
42
- print(f" {i + 1}/{count}: cuda:{i}", file=file)
43
- props = torch.cuda.get_device_properties(i)
44
- print(f" name: {props.name}", file=file)
45
- print(f" total_memory: {human_bytes_size(props.total_memory)}", file=file)
46
- print(f" capability: {props.major}.{props.minor}", file=file)
47
- if cuda_visible_devs is not None:
48
- if len(cuda_visible_devs) == count:
49
- dev_idx_s = cuda_visible_devs[i]
50
- else:
51
- dev_idx_s = "?"
33
+ with timeout("torch.cuda.is_available()"):
34
+ if torch.cuda.is_available():
35
+ print("CUDA_VISIBLE_DEVICES is not set.", file=file)
36
+
37
+ with timeout("torch.cuda.is_available()"):
38
+ if not torch.cuda.is_available():
39
+ print("(CUDA not available)", file=file)
40
+ return
41
+
42
+ print("Available CUDA devices:", file=file)
43
+ count = torch.cuda.device_count()
44
+ if cuda_visible_devs is not None and len(cuda_visible_devs) != count:
45
+ print(
46
+ f"(Mismatch between CUDA device count {count}"
47
+ f" and CUDA_VISIBLE_DEVICES {cuda_visible_devs} count {len(cuda_visible_devs)}?)",
48
+ file=file,
49
+ )
50
+ for i in range(count):
51
+ print(f" {i + 1}/{count}: cuda:{i}", file=file)
52
+ props = torch.cuda.get_device_properties(i)
53
+ print(f" name: {props.name}", file=file)
54
+ print(f" total_memory: {human_bytes_size(props.total_memory)}", file=file)
55
+ print(f" capability: {props.major}.{props.minor}", file=file)
56
+ if cuda_visible_devs is not None:
57
+ if len(cuda_visible_devs) == count:
58
+ dev_idx_s = cuda_visible_devs[i]
52
59
  else:
53
- dev_idx_s = i
54
- print(f" device_index: {dev_idx_s}", file=file)
55
- if not count:
56
- print(" (None)")
57
- else:
58
- print("(CUDA not available)")
60
+ dev_idx_s = "?"
61
+ else:
62
+ dev_idx_s = i
63
+ print(f" device_index: {dev_idx_s}", file=file)
64
+ if not count:
65
+ print(" (None)", file=file)
59
66
 
60
67
 
61
68
  def print_using_cuda_device_report(dev: Union[str, torch.device], *, file: Optional[TextIO] = None):
@@ -108,7 +115,7 @@ def diagnose_no_gpu() -> List[str]:
108
115
  except Exception as exc:
109
116
  print("nvidia-smi failed:", exc)
110
117
  better_exchook(*sys.exc_info(), debugshell=False)
111
- res.append(f"nvidia-smi failed")
118
+ res.append("nvidia-smi failed")
112
119
 
113
120
  return res
114
121
 
@@ -152,4 +159,31 @@ def garbage_collect():
152
159
  f"alloc {human_bytes_size(torch.cuda.memory_allocated())}",
153
160
  f"reserved {human_bytes_size(torch.cuda.memory_reserved())}",
154
161
  ]
155
- print(f"CUDA memory usage after triggered GC:", " ".join(stats))
162
+ print("CUDA memory usage after triggered GC:", " ".join(stats))
163
+
164
+
165
+ @contextlib.contextmanager
166
+ def timeout(info: str, *, seconds: int = 30):
167
+ """
168
+ Note: don't use signal handlers (e.g. signal.alarm) because unfortunately
169
+ potential hanging funcs will block the main thread and thus block the signal handler from executing.
170
+ Thus, we use a subprocess.
171
+
172
+ :param seconds:
173
+ :param info:
174
+ """
175
+ proc = multiprocessing.Process(
176
+ target=_timeout_handler, kwargs={"seconds": seconds, "proc_id": os.getpid(), "info": info}
177
+ )
178
+ proc.start()
179
+ try:
180
+ yield
181
+ finally:
182
+ proc.terminate()
183
+ proc.join()
184
+
185
+
186
+ def _timeout_handler(*, seconds: Union[float, int], proc_id: int, info: str):
187
+ time.sleep(seconds)
188
+ print(f"ERROR: {info}: Timeout handler after {seconds} seconds, killing proc {proc_id}.", file=sys.stderr)
189
+ os.kill(proc_id, signal.SIGABRT)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250830.114445
3
+ Version: 1.20250902.10950
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -1,9 +1,9 @@
1
- returnn/PKG-INFO,sha256=PZOQvfJJKkUgAsvh_nvFeAdQOFCWpBuj2T4euwap8VA,5215
1
+ returnn/PKG-INFO,sha256=GVal7eVN_obo9mfdhPK2WvH2MzSm51cFZJChHEsF2XU,5214
2
2
  returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
3
3
  returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
4
4
  returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
5
5
  returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
6
- returnn/_setup_info_generated.py,sha256=ukzF4nRM4yknegwGHq4ktCyLc7tP9fasq5bpzDq0Tvg,77
6
+ returnn/_setup_info_generated.py,sha256=jTlsQFAqLqFgm0UJ0uWltcnLf69QwqOK0yV4Slt-2Is,77
7
7
  returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
8
8
  returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
9
9
  returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -154,7 +154,7 @@ returnn/sprint/extern_interface.py,sha256=l-v1X-Yg0UpTFe7Y3c4FwWOqpSNuv9Oy5EzqlK
154
154
  returnn/sprint/interface.py,sha256=1j5SB0V8hSW8A5song9ciZtcBnZoKKfNipk9ezOIMuA,36491
155
155
  returnn/tensor/README.md,sha256=X6BqcRLrPLPnwF9yR69uqIFrMnNluj9pBkOPHwNgzuo,501
156
156
  returnn/tensor/__init__.py,sha256=on6j5PEOQpck50UcsR4nJzJSDmoVy34z1Oq4efv6Ax0,154
157
- returnn/tensor/_dim_extra.py,sha256=VN7Smn1Q0Y0DO7GSPM-aJUhp_jy5pzSMJbPkCk6JnqY,123448
157
+ returnn/tensor/_dim_extra.py,sha256=rwtDR5WRS8wqgKj4WkPaWtaKa8UJYTrS76ZhX0W5bP4,115580
158
158
  returnn/tensor/_tensor_extra.py,sha256=gbSl6HMtn8WFYloanew_RaNNwx3eCpnKv3UfCkntJiQ,164923
159
159
  returnn/tensor/_tensor_mixin_base.py,sha256=H5z86I0NejxrSgMH1c5oXQzBqS6L9HpvP4y7oegBaSc,643
160
160
  returnn/tensor/_tensor_op_overloads.py,sha256=HklwuTBjy7mH_665VKaCUdu-oC3aa7Uz1ZQiCz4jeZc,5448
@@ -227,7 +227,7 @@ returnn/torch/util/README.md,sha256=AW-6ueWhgcwDcm57md6sm227QXNkvLnlRLwaH7NlS-w,
227
227
  returnn/torch/util/__init__.py,sha256=AOXYUjzPm0XrzFJCPAXo9Jj_FvqD1XH3FfKtho80Vl8,26
228
228
  returnn/torch/util/array_.py,sha256=ell3VZvn01SLtF9Pw2fvPzFNO-XDQ7tSB9VCrVSKmSA,2556
229
229
  returnn/torch/util/debug_inf_nan.py,sha256=fmzSSTJJyLf7i5yDWRHLeDI0gxvadeqLE8RxMuSHx_4,6398
230
- returnn/torch/util/diagnose_gpu.py,sha256=PYMmSk7iQ-jC3RXKKNXlYx1Q744C0LXqz0SB6ympwQg,5844
230
+ returnn/torch/util/diagnose_gpu.py,sha256=_yswLmwR8Q2rCsv2jI5FUQNBT__453jBmiWYwazdu20,6808
231
231
  returnn/torch/util/exception_helper.py,sha256=_SqxTD5F-GDY2eR4uRALyUTJwt0ytcbJGB_w38RJMBA,4320
232
232
  returnn/torch/util/gradient_checkpoint.py,sha256=iLy-FB65DC8O6LxzmMvFjnSdpIVpko87ppIvRKAbtpQ,27995
233
233
  returnn/torch/util/module.py,sha256=MXHIrF9Isu575DDJIa81212ULKwdqu1oOLxDVZecVSk,1693
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
253
253
  returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
254
254
  returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
255
255
  returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
256
- returnn-1.20250830.114445.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
- returnn-1.20250830.114445.dist-info/METADATA,sha256=PZOQvfJJKkUgAsvh_nvFeAdQOFCWpBuj2T4euwap8VA,5215
258
- returnn-1.20250830.114445.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
- returnn-1.20250830.114445.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
- returnn-1.20250830.114445.dist-info/RECORD,,
256
+ returnn-1.20250902.10950.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
+ returnn-1.20250902.10950.dist-info/METADATA,sha256=GVal7eVN_obo9mfdhPK2WvH2MzSm51cFZJChHEsF2XU,5214
258
+ returnn-1.20250902.10950.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
+ returnn-1.20250902.10950.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
+ returnn-1.20250902.10950.dist-info/RECORD,,