returnn 1.20250830.114445__py3-none-any.whl → 1.20250902.10950__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic. Click here for more details.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/tensor/_dim_extra.py +301 -533
- returnn/torch/util/diagnose_gpu.py +65 -31
- {returnn-1.20250830.114445.dist-info → returnn-1.20250902.10950.dist-info}/METADATA +1 -1
- {returnn-1.20250830.114445.dist-info → returnn-1.20250902.10950.dist-info}/RECORD +9 -9
- {returnn-1.20250830.114445.dist-info → returnn-1.20250902.10950.dist-info}/LICENSE +0 -0
- {returnn-1.20250830.114445.dist-info → returnn-1.20250902.10950.dist-info}/WHEEL +0 -0
- {returnn-1.20250830.114445.dist-info → returnn-1.20250902.10950.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
version = '1.20250830.114445'
|
|
2
|
-
long_version = '1.
|
|
1
|
+
version = '1.20250902.010950'
|
|
2
|
+
long_version = '1.20250902.010950+git.9d5debf'
|
returnn/tensor/_dim_extra.py
CHANGED
|
@@ -2082,7 +2082,7 @@ class _DimMixin:
|
|
|
2082
2082
|
:return: self + other. note that this is not commutative, i.e. different from other + self.
|
|
2083
2083
|
:rtype: Dim
|
|
2084
2084
|
"""
|
|
2085
|
-
if
|
|
2085
|
+
if _is_const_dim_value(other, 0):
|
|
2086
2086
|
return self
|
|
2087
2087
|
cache_key = ("add", other)
|
|
2088
2088
|
cache = self.get_same_base()._make_extra().cache_dim_math
|
|
@@ -2090,11 +2090,11 @@ class _DimMixin:
|
|
|
2090
2090
|
if cache_entry:
|
|
2091
2091
|
cache_entry.complete_dyn_size()
|
|
2092
2092
|
return cache_entry
|
|
2093
|
-
|
|
2094
|
-
|
|
2095
|
-
|
|
2096
|
-
cache[cache_key] =
|
|
2097
|
-
return
|
|
2093
|
+
res = _MathFindMatchingAdditive(start=self, right=True, other=other).search_and_maybe_replace()
|
|
2094
|
+
if not res:
|
|
2095
|
+
res = _math_get_dim_via_bin_op([self, other], "add")
|
|
2096
|
+
cache[cache_key] = res
|
|
2097
|
+
return res
|
|
2098
2098
|
|
|
2099
2099
|
def __radd__(self: Dim, other):
|
|
2100
2100
|
"""
|
|
@@ -2102,7 +2102,7 @@ class _DimMixin:
|
|
|
2102
2102
|
:return: other + self
|
|
2103
2103
|
:rtype: Dim
|
|
2104
2104
|
"""
|
|
2105
|
-
if
|
|
2105
|
+
if _is_const_dim_value(other, 0):
|
|
2106
2106
|
return self
|
|
2107
2107
|
cache_key = ("add_left", other)
|
|
2108
2108
|
cache = self.get_same_base()._make_extra().cache_dim_math
|
|
@@ -2110,18 +2110,18 @@ class _DimMixin:
|
|
|
2110
2110
|
if cache_entry:
|
|
2111
2111
|
cache_entry.complete_dyn_size()
|
|
2112
2112
|
return cache_entry
|
|
2113
|
-
|
|
2114
|
-
|
|
2115
|
-
|
|
2116
|
-
cache[cache_key] =
|
|
2117
|
-
return
|
|
2113
|
+
res = _MathFindMatchingAdditive(start=self, right=False, other=other).search_and_maybe_replace()
|
|
2114
|
+
if not res:
|
|
2115
|
+
res = _math_get_dim_via_bin_op([other, self], "add")
|
|
2116
|
+
cache[cache_key] = res
|
|
2117
|
+
return res
|
|
2118
2118
|
|
|
2119
2119
|
def __sub__(self, other):
|
|
2120
2120
|
"""
|
|
2121
2121
|
:param Dim|int other:
|
|
2122
2122
|
:rtype: Dim
|
|
2123
2123
|
"""
|
|
2124
|
-
if
|
|
2124
|
+
if _is_const_dim_value(other, 0):
|
|
2125
2125
|
return self
|
|
2126
2126
|
return self.sub_right(other)
|
|
2127
2127
|
|
|
@@ -2131,19 +2131,24 @@ class _DimMixin:
|
|
|
2131
2131
|
:return: self - other
|
|
2132
2132
|
:rtype: Dim
|
|
2133
2133
|
"""
|
|
2134
|
-
if
|
|
2134
|
+
if _is_const_dim_value(other, 0):
|
|
2135
2135
|
return self
|
|
2136
|
+
if (
|
|
2137
|
+
self.derived_from_op
|
|
2138
|
+
and self.derived_from_op.kind == "add"
|
|
2139
|
+
and len(self.derived_from_op.inputs) == 2
|
|
2140
|
+
and _dim_or_const_equal(other, self.derived_from_op.inputs[1])
|
|
2141
|
+
):
|
|
2142
|
+
return self.derived_from_op.inputs[0]
|
|
2136
2143
|
cache_key = ("sub", other)
|
|
2137
2144
|
cache = self.get_same_base()._make_extra().cache_dim_math
|
|
2138
2145
|
cache_entry = cache.get(cache_key, None)
|
|
2139
2146
|
if cache_entry:
|
|
2140
2147
|
cache_entry.complete_dyn_size()
|
|
2141
2148
|
return cache_entry
|
|
2142
|
-
|
|
2143
|
-
|
|
2144
|
-
|
|
2145
|
-
cache[cache_key] = dim
|
|
2146
|
-
return dim
|
|
2149
|
+
res = self + (-other)
|
|
2150
|
+
cache[cache_key] = res
|
|
2151
|
+
return res
|
|
2147
2152
|
|
|
2148
2153
|
def sub_left(self: Dim, other):
|
|
2149
2154
|
"""
|
|
@@ -2151,44 +2156,55 @@ class _DimMixin:
|
|
|
2151
2156
|
:return: (-other) + self
|
|
2152
2157
|
:rtype: Dim
|
|
2153
2158
|
"""
|
|
2154
|
-
if
|
|
2159
|
+
if _is_const_dim_value(other, 0):
|
|
2155
2160
|
return self
|
|
2161
|
+
if (
|
|
2162
|
+
self.derived_from_op
|
|
2163
|
+
and self.derived_from_op.kind == "add"
|
|
2164
|
+
and len(self.derived_from_op.inputs) == 2
|
|
2165
|
+
and _dim_or_const_equal(other, self.derived_from_op.inputs[0])
|
|
2166
|
+
):
|
|
2167
|
+
return self.derived_from_op.inputs[1]
|
|
2156
2168
|
cache_key = ("sub_left", other)
|
|
2157
2169
|
cache = self.get_same_base()._make_extra().cache_dim_math
|
|
2158
2170
|
cache_entry = cache.get(cache_key, None)
|
|
2159
2171
|
if cache_entry:
|
|
2160
2172
|
cache_entry.complete_dyn_size()
|
|
2161
2173
|
return cache_entry
|
|
2162
|
-
|
|
2163
|
-
|
|
2164
|
-
|
|
2165
|
-
cache[cache_key] = dim
|
|
2166
|
-
return dim
|
|
2174
|
+
res = (-other) + self
|
|
2175
|
+
cache[cache_key] = res
|
|
2176
|
+
return res
|
|
2167
2177
|
|
|
2168
2178
|
def __mul__(self: Dim, other):
|
|
2169
2179
|
"""
|
|
2170
2180
|
:param Dim|int other:
|
|
2171
2181
|
:rtype: Dim
|
|
2172
2182
|
"""
|
|
2183
|
+
if isinstance(other, _d.Dim) and other.is_constant_static_dim():
|
|
2184
|
+
other = other.dimension # makes matching easier
|
|
2173
2185
|
if isinstance(other, int) and other == 1:
|
|
2174
2186
|
return self
|
|
2187
|
+
if self.is_constant_static_dim() and isinstance(other, _d.Dim):
|
|
2188
|
+
return self.dimension * other # use rmul
|
|
2175
2189
|
cache_key = ("mul", other)
|
|
2176
2190
|
cache = self.get_same_base()._make_extra().cache_dim_math
|
|
2177
2191
|
cache_entry = cache.get(cache_key, None)
|
|
2178
2192
|
if cache_entry:
|
|
2179
2193
|
cache_entry.complete_dyn_size()
|
|
2180
2194
|
return cache_entry
|
|
2181
|
-
|
|
2182
|
-
|
|
2183
|
-
|
|
2184
|
-
cache[cache_key] =
|
|
2185
|
-
return
|
|
2195
|
+
res = _math_find_matching_mult(start=self, right=True, other=other)
|
|
2196
|
+
if not res:
|
|
2197
|
+
res = _math_get_dim_via_bin_op([self, other], "mul")
|
|
2198
|
+
cache[cache_key] = res
|
|
2199
|
+
return res
|
|
2186
2200
|
|
|
2187
2201
|
def __rmul__(self: Dim, other):
|
|
2188
2202
|
"""
|
|
2189
2203
|
:param Dim|int other:
|
|
2190
2204
|
:rtype: Dim
|
|
2191
2205
|
"""
|
|
2206
|
+
if isinstance(other, _d.Dim) and other.is_constant_static_dim():
|
|
2207
|
+
other = other.dimension # makes matching easier
|
|
2192
2208
|
if isinstance(other, int) and other == 1:
|
|
2193
2209
|
return self
|
|
2194
2210
|
cache_key = ("mul_left", other)
|
|
@@ -2197,38 +2213,45 @@ class _DimMixin:
|
|
|
2197
2213
|
if cache_entry:
|
|
2198
2214
|
cache_entry.complete_dyn_size()
|
|
2199
2215
|
return cache_entry
|
|
2200
|
-
|
|
2201
|
-
|
|
2202
|
-
|
|
2203
|
-
cache[cache_key] =
|
|
2204
|
-
return
|
|
2216
|
+
res = _math_find_matching_mult(start=self, right=False, other=other)
|
|
2217
|
+
if not res:
|
|
2218
|
+
res = _math_get_dim_via_bin_op([other, self], "mul")
|
|
2219
|
+
cache[cache_key] = res
|
|
2220
|
+
return res
|
|
2205
2221
|
|
|
2206
2222
|
def __floordiv__(self: Dim, other):
|
|
2207
2223
|
"""
|
|
2208
2224
|
:param Dim|int other:
|
|
2209
2225
|
:rtype: Dim
|
|
2210
2226
|
"""
|
|
2227
|
+
if isinstance(other, _d.Dim) and other.is_constant_static_dim():
|
|
2228
|
+
other = other.dimension # makes matching easier
|
|
2211
2229
|
if isinstance(other, int) and other == 1:
|
|
2212
2230
|
return self
|
|
2231
|
+
if (
|
|
2232
|
+
self.derived_from_op
|
|
2233
|
+
and self.derived_from_op.kind == "mul"
|
|
2234
|
+
and len(self.derived_from_op.inputs) == 2
|
|
2235
|
+
and _dim_or_const_equal(other, self.derived_from_op.inputs[1])
|
|
2236
|
+
):
|
|
2237
|
+
return self.derived_from_op.inputs[0]
|
|
2213
2238
|
cache_key = ("floordiv", other)
|
|
2214
2239
|
cache = self.get_same_base()._make_extra().cache_dim_math
|
|
2215
2240
|
cache_entry = cache.get(cache_key, None)
|
|
2216
2241
|
if cache_entry:
|
|
2217
2242
|
cache_entry.complete_dyn_size()
|
|
2218
2243
|
return cache_entry
|
|
2219
|
-
|
|
2220
|
-
|
|
2221
|
-
|
|
2222
|
-
cache[cache_key] =
|
|
2223
|
-
return
|
|
2244
|
+
res = _math_find_matching_div(start=self, right=True, other=other, kind="floordiv")
|
|
2245
|
+
if not res:
|
|
2246
|
+
res = _math_get_dim_via_bin_op([self, other], "floordiv")
|
|
2247
|
+
cache[cache_key] = res
|
|
2248
|
+
return res
|
|
2224
2249
|
|
|
2225
2250
|
def __truediv__(self, other):
|
|
2226
2251
|
"""
|
|
2227
2252
|
:param Dim|int other:
|
|
2228
2253
|
:rtype: Dim
|
|
2229
2254
|
"""
|
|
2230
|
-
if isinstance(other, int) and other == 1:
|
|
2231
|
-
return self
|
|
2232
2255
|
return self.div_right(other)
|
|
2233
2256
|
|
|
2234
2257
|
def div_left(self: Dim, other):
|
|
@@ -2236,76 +2259,112 @@ class _DimMixin:
|
|
|
2236
2259
|
:param Dim|int other:
|
|
2237
2260
|
:rtype: Dim
|
|
2238
2261
|
"""
|
|
2262
|
+
if isinstance(other, _d.Dim) and other.is_constant_static_dim():
|
|
2263
|
+
other = other.dimension # makes matching easier
|
|
2239
2264
|
if isinstance(other, int) and other == 1:
|
|
2240
2265
|
return self
|
|
2266
|
+
if (
|
|
2267
|
+
self.derived_from_op
|
|
2268
|
+
and self.derived_from_op.kind == "mul"
|
|
2269
|
+
and len(self.derived_from_op.inputs) == 2
|
|
2270
|
+
and _dim_or_const_equal(other, self.derived_from_op.inputs[0])
|
|
2271
|
+
):
|
|
2272
|
+
return self.derived_from_op.inputs[1]
|
|
2241
2273
|
cache_key = ("truediv_left", other)
|
|
2242
2274
|
cache = self.get_same_base()._make_extra().cache_dim_math
|
|
2243
2275
|
cache_entry = cache.get(cache_key, None)
|
|
2244
2276
|
if cache_entry:
|
|
2245
2277
|
cache_entry.complete_dyn_size()
|
|
2246
2278
|
return cache_entry
|
|
2247
|
-
|
|
2248
|
-
|
|
2249
|
-
|
|
2250
|
-
cache[cache_key] =
|
|
2251
|
-
return
|
|
2279
|
+
res = _math_find_matching_div(start=self, right=True, other=other, kind="truediv_left")
|
|
2280
|
+
if not res:
|
|
2281
|
+
res = _math_get_dim_via_bin_op([self, other], "truediv_left")
|
|
2282
|
+
cache[cache_key] = res
|
|
2283
|
+
return res
|
|
2252
2284
|
|
|
2253
2285
|
def div_right(self: Dim, other):
|
|
2254
2286
|
"""
|
|
2255
2287
|
:param Dim|int other:
|
|
2256
2288
|
:rtype: Dim
|
|
2257
2289
|
"""
|
|
2290
|
+
if isinstance(other, _d.Dim) and other.is_constant_static_dim():
|
|
2291
|
+
other = other.dimension # makes matching easier
|
|
2258
2292
|
if isinstance(other, int) and other == 1:
|
|
2259
2293
|
return self
|
|
2294
|
+
if (
|
|
2295
|
+
self.derived_from_op
|
|
2296
|
+
and self.derived_from_op.kind == "mul"
|
|
2297
|
+
and len(self.derived_from_op.inputs) == 2
|
|
2298
|
+
and _dim_or_const_equal(other, self.derived_from_op.inputs[1])
|
|
2299
|
+
):
|
|
2300
|
+
return self.derived_from_op.inputs[0]
|
|
2260
2301
|
cache_key = ("truediv", other)
|
|
2261
2302
|
cache = self.get_same_base()._make_extra().cache_dim_math
|
|
2262
2303
|
cache_entry = cache.get(cache_key, None)
|
|
2263
2304
|
if cache_entry:
|
|
2264
2305
|
cache_entry.complete_dyn_size()
|
|
2265
2306
|
return cache_entry
|
|
2266
|
-
|
|
2267
|
-
|
|
2268
|
-
|
|
2269
|
-
cache[cache_key] =
|
|
2270
|
-
return
|
|
2307
|
+
res = _math_find_matching_div(start=self, right=True, other=other, kind="truediv")
|
|
2308
|
+
if not res:
|
|
2309
|
+
res = _math_get_dim_via_bin_op([self, other], "truediv")
|
|
2310
|
+
cache[cache_key] = res
|
|
2311
|
+
return res
|
|
2271
2312
|
|
|
2272
2313
|
def ceildiv_left(self: Dim, other):
|
|
2273
2314
|
"""
|
|
2274
2315
|
:param Dim|int other:
|
|
2275
2316
|
:rtype: Dim
|
|
2276
2317
|
"""
|
|
2318
|
+
if isinstance(other, _d.Dim) and other.is_constant_static_dim():
|
|
2319
|
+
other = other.dimension # makes matching easier
|
|
2277
2320
|
if isinstance(other, int) and other == 1:
|
|
2278
2321
|
return self
|
|
2322
|
+
if (
|
|
2323
|
+
self.derived_from_op
|
|
2324
|
+
and self.derived_from_op.kind == "mul"
|
|
2325
|
+
and len(self.derived_from_op.inputs) == 2
|
|
2326
|
+
and _dim_or_const_equal(other, self.derived_from_op.inputs[0])
|
|
2327
|
+
):
|
|
2328
|
+
return self.derived_from_op.inputs[1]
|
|
2279
2329
|
cache_key = ("ceildiv_left", other)
|
|
2280
2330
|
cache = self.get_same_base()._make_extra().cache_dim_math
|
|
2281
2331
|
cache_entry = cache.get(cache_key, None)
|
|
2282
2332
|
if cache_entry:
|
|
2283
2333
|
cache_entry.complete_dyn_size()
|
|
2284
2334
|
return cache_entry
|
|
2285
|
-
|
|
2286
|
-
|
|
2287
|
-
|
|
2288
|
-
cache[cache_key] =
|
|
2289
|
-
return
|
|
2335
|
+
res = _math_find_matching_div(start=self, right=True, other=other, kind="ceildiv_left")
|
|
2336
|
+
if not res:
|
|
2337
|
+
res = _math_get_dim_via_bin_op([self, other], "ceildiv_left")
|
|
2338
|
+
cache[cache_key] = res
|
|
2339
|
+
return res
|
|
2290
2340
|
|
|
2291
2341
|
def ceildiv_right(self: Dim, other):
|
|
2292
2342
|
"""
|
|
2293
2343
|
:param Dim|int other:
|
|
2294
2344
|
:rtype: Dim
|
|
2295
2345
|
"""
|
|
2346
|
+
if isinstance(other, _d.Dim) and other.is_constant_static_dim():
|
|
2347
|
+
other = other.dimension # makes matching easier
|
|
2296
2348
|
if isinstance(other, int) and other == 1:
|
|
2297
2349
|
return self
|
|
2350
|
+
if (
|
|
2351
|
+
self.derived_from_op
|
|
2352
|
+
and self.derived_from_op.kind == "mul"
|
|
2353
|
+
and len(self.derived_from_op.inputs) == 2
|
|
2354
|
+
and _dim_or_const_equal(other, self.derived_from_op.inputs[1])
|
|
2355
|
+
):
|
|
2356
|
+
return self.derived_from_op.inputs[0]
|
|
2298
2357
|
cache_key = ("ceildiv", other)
|
|
2299
2358
|
cache = self.get_same_base()._make_extra().cache_dim_math
|
|
2300
2359
|
cache_entry = cache.get(cache_key, None)
|
|
2301
2360
|
if cache_entry:
|
|
2302
2361
|
cache_entry.complete_dyn_size()
|
|
2303
2362
|
return cache_entry
|
|
2304
|
-
|
|
2305
|
-
|
|
2306
|
-
|
|
2307
|
-
cache[cache_key] =
|
|
2308
|
-
return
|
|
2363
|
+
res = _math_find_matching_div(start=self, right=True, other=other, kind="ceildiv")
|
|
2364
|
+
if not res:
|
|
2365
|
+
res = _math_get_dim_via_bin_op([self, other], "ceildiv")
|
|
2366
|
+
cache[cache_key] = res
|
|
2367
|
+
return res
|
|
2309
2368
|
|
|
2310
2369
|
def __neg__(self):
|
|
2311
2370
|
"""
|
|
@@ -2348,6 +2407,7 @@ def _make_constant_static_dim(value, kind=None):
|
|
|
2348
2407
|
:param Entity|None kind:
|
|
2349
2408
|
:rtype: Dim
|
|
2350
2409
|
"""
|
|
2410
|
+
assert isinstance(value, int)
|
|
2351
2411
|
return _d.Dim(
|
|
2352
2412
|
dimension=value,
|
|
2353
2413
|
kind=kind or DimTypes.Unspecified,
|
|
@@ -2357,30 +2417,188 @@ def _make_constant_static_dim(value, kind=None):
|
|
|
2357
2417
|
)
|
|
2358
2418
|
|
|
2359
2419
|
|
|
2360
|
-
def
|
|
2361
|
-
|
|
2362
|
-
|
|
2363
|
-
|
|
2364
|
-
|
|
2365
|
-
|
|
2366
|
-
|
|
2367
|
-
|
|
2368
|
-
|
|
2369
|
-
|
|
2370
|
-
|
|
2371
|
-
|
|
2372
|
-
|
|
2420
|
+
def _dim_or_const_equal(a: Union[Dim, int], b: Union[Dim, int]) -> bool:
|
|
2421
|
+
if isinstance(a, int):
|
|
2422
|
+
if isinstance(b, int):
|
|
2423
|
+
return a == b
|
|
2424
|
+
elif isinstance(b, _d.Dim):
|
|
2425
|
+
return a == b.dimension and b.is_constant_static_dim()
|
|
2426
|
+
else:
|
|
2427
|
+
raise TypeError(f"unexpected b type {type(b)}")
|
|
2428
|
+
elif isinstance(a, _d.Dim):
|
|
2429
|
+
if a.is_constant_static_dim():
|
|
2430
|
+
if isinstance(b, int):
|
|
2431
|
+
return a.dimension == b
|
|
2432
|
+
elif isinstance(b, _d.Dim):
|
|
2433
|
+
return a.dimension == b.dimension and b.is_constant_static_dim()
|
|
2434
|
+
else:
|
|
2435
|
+
raise TypeError(f"unexpected b type {type(b)}")
|
|
2436
|
+
else:
|
|
2437
|
+
if isinstance(b, int):
|
|
2438
|
+
return False
|
|
2439
|
+
elif isinstance(b, _d.Dim):
|
|
2440
|
+
return a == b
|
|
2441
|
+
else:
|
|
2442
|
+
raise TypeError(f"unexpected b type {type(b)}")
|
|
2443
|
+
else:
|
|
2444
|
+
raise TypeError(f"unexpected a type {type(a)}")
|
|
2445
|
+
|
|
2446
|
+
|
|
2447
|
+
_BinOps = {
|
|
2448
|
+
"add": operator.add,
|
|
2449
|
+
"mul": operator.mul,
|
|
2450
|
+
"sub": operator.sub,
|
|
2451
|
+
"floordiv": operator.floordiv,
|
|
2452
|
+
"truediv": operator.floordiv,
|
|
2453
|
+
"truediv_left": operator.floordiv,
|
|
2454
|
+
"ceildiv": lambda a, b: -(-a // b),
|
|
2455
|
+
"ceildiv_left": lambda a, b: -(-a // b),
|
|
2456
|
+
}
|
|
2457
|
+
|
|
2458
|
+
_BinOpStrs = {
|
|
2459
|
+
"add": "+",
|
|
2460
|
+
"mul": "*",
|
|
2461
|
+
"sub": "-",
|
|
2462
|
+
"floordiv": "//",
|
|
2463
|
+
"truediv": "/",
|
|
2464
|
+
"truediv_left": " /l ",
|
|
2465
|
+
"ceildiv": "/",
|
|
2466
|
+
"ceildiv_left": " /l ",
|
|
2467
|
+
}
|
|
2468
|
+
|
|
2469
|
+
|
|
2470
|
+
def _math_get_dim_via_bin_op(dims: Sequence[Union[Dim, int]], op_kind: str) -> Dim:
|
|
2471
|
+
dims = [d if isinstance(d, _d.Dim) else _make_constant_static_dim(d) for d in dims]
|
|
2472
|
+
if all(d.dimension is not None for d in dims):
|
|
2473
|
+
op = _BinOps[op_kind]
|
|
2474
|
+
dim_value = dims[0].dimension
|
|
2475
|
+
for d in dims[1:]:
|
|
2476
|
+
dim_value = op(dim_value, d.dimension)
|
|
2373
2477
|
else:
|
|
2374
2478
|
dim_value = None
|
|
2375
|
-
|
|
2376
|
-
kind=_get_merged_dim_kind(
|
|
2377
|
-
|
|
2479
|
+
if all(d.is_constant_static_dim() for d in dims):
|
|
2480
|
+
return _make_constant_static_dim(dim_value, kind=_get_merged_dim_kind(dims))
|
|
2481
|
+
desc = _BinOpStrs[op_kind].join(_get_description(d) for d in dims)
|
|
2482
|
+
if op_kind.startswith("ceildiv"):
|
|
2483
|
+
desc = f"⌈{desc}⌉"
|
|
2484
|
+
return _d.Dim(
|
|
2485
|
+
kind=_get_merged_dim_kind(dims),
|
|
2486
|
+
description=desc,
|
|
2378
2487
|
dimension=dim_value,
|
|
2379
|
-
derived_from_op=Op(kind=op_kind, inputs=
|
|
2380
|
-
derived_from_tag=_representative_tag(
|
|
2488
|
+
derived_from_op=Op(kind=op_kind, inputs=list(dims)),
|
|
2489
|
+
derived_from_tag=_representative_tag(dims),
|
|
2381
2490
|
)
|
|
2382
|
-
|
|
2383
|
-
|
|
2491
|
+
|
|
2492
|
+
|
|
2493
|
+
def _is_const_dim_value(d: Union[Dim, int], value: int) -> bool:
|
|
2494
|
+
if isinstance(d, int):
|
|
2495
|
+
return d == value
|
|
2496
|
+
elif isinstance(d, _d.Dim):
|
|
2497
|
+
return d.is_constant_static_dim() and d.dimension == value
|
|
2498
|
+
else:
|
|
2499
|
+
raise TypeError(f"unexpected type {type(d)}")
|
|
2500
|
+
|
|
2501
|
+
|
|
2502
|
+
class _MathFindMatchingAdditive:
|
|
2503
|
+
def __init__(self, start: Dim, *, max_depth: int = 2, right: bool, other: Union[int, Dim]):
|
|
2504
|
+
self.start = start
|
|
2505
|
+
self.max_depth = max_depth
|
|
2506
|
+
self.right = right
|
|
2507
|
+
self.other = other
|
|
2508
|
+
|
|
2509
|
+
def _check_and_maybe_replace(self, candidate: Dim) -> Optional[Dim]:
|
|
2510
|
+
"""
|
|
2511
|
+
Check and return potential replacement for candidate, when adding `other` to it.
|
|
2512
|
+
"""
|
|
2513
|
+
other = self.other
|
|
2514
|
+
if isinstance(other, int) or other.is_constant_static_dim():
|
|
2515
|
+
if candidate.is_constant_static_dim():
|
|
2516
|
+
return _math_get_dim_via_bin_op([candidate, other] if self.right else [other, candidate], "add")
|
|
2517
|
+
return None
|
|
2518
|
+
if candidate == other:
|
|
2519
|
+
return candidate.__rmul__(2)
|
|
2520
|
+
c_op = candidate.derived_from_op
|
|
2521
|
+
if c_op and c_op.kind == "mul" and len(c_op.inputs) == 2 and c_op.inputs[1] == other:
|
|
2522
|
+
factor = (c_op.inputs[0] + 1) if self.right else (1 + c_op.inputs[0])
|
|
2523
|
+
if factor.is_constant_static_dim():
|
|
2524
|
+
if factor.dimension == 0:
|
|
2525
|
+
return factor
|
|
2526
|
+
factor = factor.dimension
|
|
2527
|
+
return factor * c_op.inputs[1]
|
|
2528
|
+
o_op = other.derived_from_op
|
|
2529
|
+
if not o_op or o_op.kind != "mul" or len(o_op.inputs) != 2:
|
|
2530
|
+
return None
|
|
2531
|
+
o_base, other = o_op.inputs # continue checking this
|
|
2532
|
+
if candidate == other:
|
|
2533
|
+
factor = (1 + o_base) if self.right else (o_base + 1)
|
|
2534
|
+
if factor.is_constant_static_dim():
|
|
2535
|
+
if factor.dimension == 0:
|
|
2536
|
+
return factor
|
|
2537
|
+
factor = factor.dimension
|
|
2538
|
+
return factor * candidate
|
|
2539
|
+
if c_op and c_op.kind == "mul" and len(c_op.inputs) == 2 and c_op.inputs[1] == other:
|
|
2540
|
+
factor = (c_op.inputs[0] + o_base) if self.right else (o_base + c_op.inputs[0])
|
|
2541
|
+
if factor.is_constant_static_dim():
|
|
2542
|
+
if factor.dimension == 0:
|
|
2543
|
+
return factor
|
|
2544
|
+
factor = factor.dimension
|
|
2545
|
+
return factor * c_op.inputs[1]
|
|
2546
|
+
return None
|
|
2547
|
+
|
|
2548
|
+
def search_and_maybe_replace(self) -> Optional[Dim]:
|
|
2549
|
+
"""search"""
|
|
2550
|
+
cur = self.start
|
|
2551
|
+
depth = 0
|
|
2552
|
+
history = []
|
|
2553
|
+
while True:
|
|
2554
|
+
res_cur = self._check_and_maybe_replace(cur)
|
|
2555
|
+
if res_cur:
|
|
2556
|
+
if depth > 0 and res_cur.is_constant_static_dim() and res_cur.dimension == 0:
|
|
2557
|
+
res_cur = history.pop(-1)
|
|
2558
|
+
res = res_cur
|
|
2559
|
+
for h in reversed(history):
|
|
2560
|
+
res = _math_get_dim_via_bin_op([h, res] if self.right else [res, h], "add")
|
|
2561
|
+
return res
|
|
2562
|
+
depth += 1
|
|
2563
|
+
if depth > self.max_depth:
|
|
2564
|
+
return None
|
|
2565
|
+
op = cur.derived_from_op
|
|
2566
|
+
if not op or op.kind != "add" or len(op.inputs) != 2:
|
|
2567
|
+
return None
|
|
2568
|
+
cur = op.inputs[1 if self.right else 0]
|
|
2569
|
+
hist = op.inputs[0 if self.right else 1]
|
|
2570
|
+
history.append(hist)
|
|
2571
|
+
|
|
2572
|
+
|
|
2573
|
+
def _math_find_matching_mult(start: Dim, other: Union[int, Dim], *, right: bool) -> Optional[Dim]:
|
|
2574
|
+
if (isinstance(other, int) or other.is_constant_static_dim()) and start.is_constant_static_dim():
|
|
2575
|
+
return _math_get_dim_via_bin_op([start, other] if right else [other, start], "mul")
|
|
2576
|
+
c_op = start.derived_from_op
|
|
2577
|
+
if c_op and c_op.kind == "mul" and len(c_op.inputs) == 2:
|
|
2578
|
+
if right:
|
|
2579
|
+
return c_op.inputs[0] * (c_op.inputs[1] * other)
|
|
2580
|
+
else:
|
|
2581
|
+
return (other * c_op.inputs[0]) * c_op.inputs[1]
|
|
2582
|
+
return None
|
|
2583
|
+
|
|
2584
|
+
|
|
2585
|
+
_DivKindToMeth: Dict[str, Callable[[Dim, Dim], Dim]] = {
|
|
2586
|
+
"truediv": _DimMixin.div_right,
|
|
2587
|
+
"truediv_left": _DimMixin.div_left,
|
|
2588
|
+
"ceildiv": _DimMixin.ceildiv_right,
|
|
2589
|
+
"ceildiv_left": _DimMixin.ceildiv_left,
|
|
2590
|
+
"floordiv": _DimMixin.__floordiv__,
|
|
2591
|
+
}
|
|
2592
|
+
|
|
2593
|
+
|
|
2594
|
+
def _math_find_matching_div(start: Dim, other: Union[int, Dim], *, right: bool, kind: str) -> Optional[Dim]:
|
|
2595
|
+
if (isinstance(other, int) or other.is_constant_static_dim()) and start.is_constant_static_dim():
|
|
2596
|
+
return _math_get_dim_via_bin_op([start, other] if right else [other, start], kind)
|
|
2597
|
+
c_op = start.derived_from_op
|
|
2598
|
+
if c_op and c_op.kind == kind and len(c_op.inputs) == 2:
|
|
2599
|
+
meth = _DivKindToMeth[kind]
|
|
2600
|
+
return meth(c_op.inputs[0], c_op.inputs[1] * other if right else other * c_op.inputs[1])
|
|
2601
|
+
return None
|
|
2384
2602
|
|
|
2385
2603
|
|
|
2386
2604
|
class Op:
|
|
@@ -2437,457 +2655,7 @@ def _get_description(dim, brackets=True):
|
|
|
2437
2655
|
return "unnamed_%s_dim%s" % (dim.kind, dim.dimension if dim.dimension is not None else "?")
|
|
2438
2656
|
|
|
2439
2657
|
|
|
2440
|
-
|
|
2441
|
-
"""
|
|
2442
|
-
represents sth like a * b * c
|
|
2443
|
-
"""
|
|
2444
|
-
|
|
2445
|
-
@classmethod
|
|
2446
|
-
def from_dim(cls, dim: Dim) -> _OpMultTerm:
|
|
2447
|
-
"""
|
|
2448
|
-
:param dim:
|
|
2449
|
-
:return: op mult term
|
|
2450
|
-
"""
|
|
2451
|
-
dim = dim.get_same_base()
|
|
2452
|
-
if dim.dimension == 1 and dim.is_constant_static_dim():
|
|
2453
|
-
return cls.one()
|
|
2454
|
-
if dim.derived_from_op and dim.derived_from_op.kind == "mul":
|
|
2455
|
-
return cls(list(dim.derived_from_op.inputs))
|
|
2456
|
-
return cls([dim])
|
|
2457
|
-
|
|
2458
|
-
@classmethod
|
|
2459
|
-
def from_dim_factors(cls, dims: List[Dim]) -> _OpMultTerm:
|
|
2460
|
-
"""from dim factors"""
|
|
2461
|
-
res = cls.one()
|
|
2462
|
-
for d in dims:
|
|
2463
|
-
res.extend_mul_div_(d, kind="mul", right=True)
|
|
2464
|
-
return res
|
|
2465
|
-
|
|
2466
|
-
@classmethod
|
|
2467
|
-
def one(cls) -> _OpMultTerm:
|
|
2468
|
-
"""1"""
|
|
2469
|
-
return cls([])
|
|
2470
|
-
|
|
2471
|
-
def __init__(self, terms: List[Dim]):
|
|
2472
|
-
self.terms = terms
|
|
2473
|
-
|
|
2474
|
-
def __hash__(self):
|
|
2475
|
-
return hash(tuple(self.terms))
|
|
2476
|
-
|
|
2477
|
-
def __eq__(self, other):
|
|
2478
|
-
"""
|
|
2479
|
-
:param _OpMultTerm other:
|
|
2480
|
-
"""
|
|
2481
|
-
if isinstance(other, _OpMultTerm):
|
|
2482
|
-
return self.terms == other.terms
|
|
2483
|
-
return False
|
|
2484
|
-
|
|
2485
|
-
def __ne__(self, other):
|
|
2486
|
-
return not self.__eq__(other)
|
|
2487
|
-
|
|
2488
|
-
def __repr__(self):
|
|
2489
|
-
return "Dim._OpMultTerm(%r)" % (self.terms,)
|
|
2490
|
-
|
|
2491
|
-
@property
|
|
2492
|
-
def dimension(self) -> Optional[int]:
|
|
2493
|
-
"""static dim or None"""
|
|
2494
|
-
dim = 1
|
|
2495
|
-
for part in self.terms:
|
|
2496
|
-
if part.dimension is None:
|
|
2497
|
-
return None
|
|
2498
|
-
dim *= part.dimension
|
|
2499
|
-
return dim
|
|
2500
|
-
|
|
2501
|
-
def base_term(self) -> Dim:
|
|
2502
|
-
"""base term (Dim)"""
|
|
2503
|
-
assert self.terms
|
|
2504
|
-
return self.terms[-1]
|
|
2505
|
-
|
|
2506
|
-
def is_one(self) -> bool:
|
|
2507
|
-
"""is 1"""
|
|
2508
|
-
return not self.terms
|
|
2509
|
-
|
|
2510
|
-
def is_constant_static_dim(self) -> bool:
|
|
2511
|
-
"""is constant static dim"""
|
|
2512
|
-
if not self.terms:
|
|
2513
|
-
return True
|
|
2514
|
-
return all(term.is_constant_static_dim() for term in self.terms)
|
|
2515
|
-
|
|
2516
|
-
def copy(self) -> _OpMultTerm:
|
|
2517
|
-
"""copy"""
|
|
2518
|
-
return _OpMultTerm(list(self.terms))
|
|
2519
|
-
|
|
2520
|
-
def negative(self) -> _OpMultTerm:
|
|
2521
|
-
"""negative"""
|
|
2522
|
-
if self.terms and self.terms[0].is_constant_static_dim() and self.terms[0].dimension == -1:
|
|
2523
|
-
return _OpMultTerm(self.terms[1:])
|
|
2524
|
-
res = self.copy()
|
|
2525
|
-
res.extend_mul_div_(_make_constant_static_dim(-1), kind="mul", right=False)
|
|
2526
|
-
return res
|
|
2527
|
-
|
|
2528
|
-
def divisible(self, other, right):
|
|
2529
|
-
"""
|
|
2530
|
-
:param Dim other:
|
|
2531
|
-
:param bool right:
|
|
2532
|
-
:return: whether we can divide other, without remainder
|
|
2533
|
-
:rtype: bool
|
|
2534
|
-
"""
|
|
2535
|
-
if not self.terms:
|
|
2536
|
-
return False
|
|
2537
|
-
if other.derived_from_op and other.derived_from_op.kind == "mul":
|
|
2538
|
-
tmp = self.copy()
|
|
2539
|
-
for term in other.derived_from_op.inputs if right else reversed(other.derived_from_op.inputs):
|
|
2540
|
-
if not tmp.divisible(term, right=right):
|
|
2541
|
-
return False
|
|
2542
|
-
tmp.extend_mul_div_(term, kind="truediv", right=right)
|
|
2543
|
-
return True
|
|
2544
|
-
most_recent_term = self.terms[-1 if right else 0]
|
|
2545
|
-
if other == most_recent_term:
|
|
2546
|
-
return True
|
|
2547
|
-
if most_recent_term.dimension is not None and other.dimension is not None:
|
|
2548
|
-
if most_recent_term.dimension % other.dimension == 0:
|
|
2549
|
-
return True
|
|
2550
|
-
return False
|
|
2551
|
-
|
|
2552
|
-
def can_simplify(self, other, kind, right):
|
|
2553
|
-
"""
|
|
2554
|
-
:param Dim other:
|
|
2555
|
-
:param str kind:
|
|
2556
|
-
:param bool right:
|
|
2557
|
-
:return: whether we can simplify when applying this operation
|
|
2558
|
-
:rtype: bool
|
|
2559
|
-
"""
|
|
2560
|
-
if other.derived_from_op and other.derived_from_op.kind == "mul":
|
|
2561
|
-
tmp = self.copy()
|
|
2562
|
-
for term in other.derived_from_op.inputs if right else reversed(other.derived_from_op.inputs):
|
|
2563
|
-
if not tmp.can_simplify(term, kind=kind, right=right):
|
|
2564
|
-
return False
|
|
2565
|
-
tmp.extend_mul_div_(term, kind=kind, right=right)
|
|
2566
|
-
return True
|
|
2567
|
-
idx = self._simplify_term_idx(other, kind=kind, right=right)
|
|
2568
|
-
return idx is not None
|
|
2569
|
-
|
|
2570
|
-
def _simplify_term_idx(self, other, kind, right):
|
|
2571
|
-
"""
|
|
2572
|
-
:param Dim other:
|
|
2573
|
-
:param str kind:
|
|
2574
|
-
:param bool right:
|
|
2575
|
-
:return: index of term to simplify
|
|
2576
|
-
:rtype: int|None
|
|
2577
|
-
"""
|
|
2578
|
-
if not self.terms:
|
|
2579
|
-
return None
|
|
2580
|
-
if kind == "mul":
|
|
2581
|
-
# We want (b * a) // b != a.
|
|
2582
|
-
# However, we want h * (2 * a // h) == 2 * a.
|
|
2583
|
-
# So, for `mul`, and only for `mul`, check all terms, whether we can simplify some division-term.
|
|
2584
|
-
for i, term in reversed(list(enumerate(self.terms))) if right else enumerate(self.terms):
|
|
2585
|
-
assert isinstance(term, _d.Dim)
|
|
2586
|
-
if term.derived_from_op:
|
|
2587
|
-
if term.derived_from_op.kind == "truediv_" + ("right" if right else "left"):
|
|
2588
|
-
if term.derived_from_op.inputs[-1] == other:
|
|
2589
|
-
return i
|
|
2590
|
-
if other.derived_from_op:
|
|
2591
|
-
if other.derived_from_op.kind == "truediv_" + ("right" if not right else "left"):
|
|
2592
|
-
if other.derived_from_op.inputs[-1] == term:
|
|
2593
|
-
return i
|
|
2594
|
-
if term.is_constant_static_dim() and other.is_constant_static_dim():
|
|
2595
|
-
return i
|
|
2596
|
-
# For the last/first term, extra checks.
|
|
2597
|
-
i = len(self.terms) - 1 if right else 0
|
|
2598
|
-
term = self.terms[i]
|
|
2599
|
-
if kind.endswith("div") and other == term:
|
|
2600
|
-
return i
|
|
2601
|
-
op_kind = kind + "_" + ("right" if right else "left")
|
|
2602
|
-
if term.derived_from_op and term.derived_from_op.kind == op_kind:
|
|
2603
|
-
return i
|
|
2604
|
-
return None
|
|
2605
|
-
|
|
2606
|
-
def extend_mul_div_(self, other, kind, right):
|
|
2607
|
-
"""
|
|
2608
|
-
:param Dim other:
|
|
2609
|
-
:param str kind:
|
|
2610
|
-
:param bool right:
|
|
2611
|
-
"""
|
|
2612
|
-
assert kind in {"mul", "floordiv", "truediv", "ceildiv"}
|
|
2613
|
-
if other.is_constant_static_dim() and other.dimension == 1:
|
|
2614
|
-
return
|
|
2615
|
-
if not self.terms:
|
|
2616
|
-
if kind == "mul":
|
|
2617
|
-
self.terms.append(other)
|
|
2618
|
-
elif kind.endswith("div"):
|
|
2619
|
-
self.terms = [_OpMultTerm.new_div_dim(self.as_dim(), other, kind=kind, right=right)]
|
|
2620
|
-
return
|
|
2621
|
-
if other.derived_from_op and other.derived_from_op.kind == "mul":
|
|
2622
|
-
for term in other.derived_from_op.inputs if right else reversed(other.derived_from_op.inputs):
|
|
2623
|
-
self.extend_mul_div_(term, kind=kind, right=right)
|
|
2624
|
-
return
|
|
2625
|
-
idx = self._simplify_term_idx(other, kind=kind, right=right)
|
|
2626
|
-
if idx is not None:
|
|
2627
|
-
term = self.terms[idx]
|
|
2628
|
-
assert isinstance(term, _d.Dim)
|
|
2629
|
-
if kind.endswith("div") and other == term:
|
|
2630
|
-
self.terms.pop(idx)
|
|
2631
|
-
return
|
|
2632
|
-
if kind == "mul" and term.derived_from_op:
|
|
2633
|
-
if term.derived_from_op.kind == "truediv_" + ("right" if right else "left"):
|
|
2634
|
-
if term.derived_from_op.inputs[-1] == other:
|
|
2635
|
-
self.terms[idx] = term.derived_from_op.inputs[0]
|
|
2636
|
-
return
|
|
2637
|
-
if kind == "mul" and other.derived_from_op:
|
|
2638
|
-
if other.derived_from_op.kind == "truediv_" + ("right" if not right else "left"):
|
|
2639
|
-
if other.derived_from_op.inputs[-1] == term:
|
|
2640
|
-
self.terms[idx] = other.derived_from_op.inputs[0]
|
|
2641
|
-
return
|
|
2642
|
-
if term.is_constant_static_dim() and other.is_constant_static_dim():
|
|
2643
|
-
if kind == "mul":
|
|
2644
|
-
if term.dimension * other.dimension == 1:
|
|
2645
|
-
self.terms.pop(idx)
|
|
2646
|
-
return
|
|
2647
|
-
self.terms[idx] = _make_constant_static_dim(term.dimension * other.dimension, kind=term.kind)
|
|
2648
|
-
return
|
|
2649
|
-
if kind.endswith("div") and term.dimension % other.dimension == 0:
|
|
2650
|
-
self.terms[idx] = _make_constant_static_dim(term.dimension // other.dimension, kind=term.kind)
|
|
2651
|
-
return
|
|
2652
|
-
# Fallback with generic handling.
|
|
2653
|
-
op_kind = kind + "_" + ("right" if right else "left")
|
|
2654
|
-
if kind.endswith("div") and term.derived_from_op and term.derived_from_op.kind == op_kind:
|
|
2655
|
-
numerator = term.derived_from_op.inputs[0]
|
|
2656
|
-
denominator = term.derived_from_op.inputs[1]
|
|
2657
|
-
self.terms[idx] = _OpMultTerm.new_div_dim(numerator, denominator * other, kind=kind, right=right)
|
|
2658
|
-
return
|
|
2659
|
-
if kind.endswith("div"):
|
|
2660
|
-
self.terms = [_OpMultTerm.new_div_dim(self.as_dim(), other, kind=kind, right=right)]
|
|
2661
|
-
return
|
|
2662
|
-
if kind == "mul":
|
|
2663
|
-
if right:
|
|
2664
|
-
self.terms.append(other)
|
|
2665
|
-
else:
|
|
2666
|
-
self.terms.insert(0, other)
|
|
2667
|
-
return
|
|
2668
|
-
assert False
|
|
2669
|
-
|
|
2670
|
-
@classmethod
|
|
2671
|
-
def new_div_dim(cls, numerator, denominator, kind, right):
|
|
2672
|
-
"""
|
|
2673
|
-
:param Dim numerator:
|
|
2674
|
-
:param Dim denominator:
|
|
2675
|
-
:param str kind: "floordiv" or "ceildiv" or "truediv"
|
|
2676
|
-
:param bool right:
|
|
2677
|
-
:rtype: Dim
|
|
2678
|
-
"""
|
|
2679
|
-
dim_value = None
|
|
2680
|
-
a = numerator.dimension
|
|
2681
|
-
b = denominator.dimension
|
|
2682
|
-
if a is not None and b is not None:
|
|
2683
|
-
if kind == "floordiv":
|
|
2684
|
-
dim_value = a // b
|
|
2685
|
-
elif kind == "ceildiv":
|
|
2686
|
-
dim_value = -(-a // b)
|
|
2687
|
-
if a % b == 0 and right:
|
|
2688
|
-
kind = "floordiv" # for nicer description, and does not matter
|
|
2689
|
-
elif kind == "truediv":
|
|
2690
|
-
if a % b != 0:
|
|
2691
|
-
raise ValueError(
|
|
2692
|
-
"%s truediv %s only allowed if the result is an integer" % (numerator, denominator)
|
|
2693
|
-
)
|
|
2694
|
-
dim_value = a // b
|
|
2695
|
-
if right:
|
|
2696
|
-
kind = "floordiv" # for nicer description, and does not matter
|
|
2697
|
-
else:
|
|
2698
|
-
raise ValueError("invalid kind %r" % (kind,))
|
|
2699
|
-
# noinspection PyProtectedMember
|
|
2700
|
-
cache, cache_key, res = numerator._cache_dim_math_get(kind + ("" if right else "_left"), denominator)
|
|
2701
|
-
if res:
|
|
2702
|
-
return res
|
|
2703
|
-
if kind == "floordiv" and right:
|
|
2704
|
-
description = "%s//%s" % (_get_description(numerator), _get_description(denominator))
|
|
2705
|
-
elif kind == "ceildiv" and right:
|
|
2706
|
-
description = "⌈%s/%s⌉" % (_get_description(numerator), _get_description(denominator))
|
|
2707
|
-
else:
|
|
2708
|
-
description = "%s_%s(%s, %s)" % (
|
|
2709
|
-
kind,
|
|
2710
|
-
"right" if right else "left",
|
|
2711
|
-
_get_description(numerator, brackets=False),
|
|
2712
|
-
_get_description(denominator, brackets=False),
|
|
2713
|
-
)
|
|
2714
|
-
op_kind = kind
|
|
2715
|
-
if a is not None and b is not None and a % b == 0:
|
|
2716
|
-
op_kind = "truediv" # makes some other checks simpler
|
|
2717
|
-
op_kind += "_" + ("right" if right else "left")
|
|
2718
|
-
res = _d.Dim(
|
|
2719
|
-
description=description,
|
|
2720
|
-
kind=numerator.kind,
|
|
2721
|
-
dimension=dim_value,
|
|
2722
|
-
derived_from_op=Op(kind=op_kind, inputs=[numerator, denominator]),
|
|
2723
|
-
derived_from_tag=numerator,
|
|
2724
|
-
)
|
|
2725
|
-
cache[cache_key] = res
|
|
2726
|
-
return res
|
|
2727
|
-
|
|
2728
|
-
def as_dim(self):
|
|
2729
|
-
"""
|
|
2730
|
-
:rtype: Dim
|
|
2731
|
-
"""
|
|
2732
|
-
if self.is_one():
|
|
2733
|
-
return _make_constant_static_dim(1)
|
|
2734
|
-
if len(self.terms) == 1:
|
|
2735
|
-
return self.terms[0]
|
|
2736
|
-
res = self.terms[0]
|
|
2737
|
-
for operand in self.terms[1:]:
|
|
2738
|
-
res = _math_get_dim_via_bin_op(res, operand, "mul")
|
|
2739
|
-
return res
|
|
2740
|
-
|
|
2741
|
-
|
|
2742
|
-
class _OpLinearTerm:
|
|
2743
|
-
"""
|
|
2744
|
-
Linear combination of :class:`_OpMultTerm`.
|
|
2745
|
-
Represents sth like a * b + c of :class:`Dim`.
|
|
2746
|
-
"""
|
|
2747
|
-
|
|
2748
|
-
@classmethod
|
|
2749
|
-
def from_dim(cls, dim: Dim) -> _OpLinearTerm:
|
|
2750
|
-
"""from dim"""
|
|
2751
|
-
res = cls.zero()
|
|
2752
|
-
res.extend_add_sub_(dim, kind="add", right=True)
|
|
2753
|
-
return res
|
|
2754
|
-
|
|
2755
|
-
@classmethod
|
|
2756
|
-
def zero(cls) -> _OpLinearTerm:
|
|
2757
|
-
"""0"""
|
|
2758
|
-
return _OpLinearTerm([])
|
|
2759
|
-
|
|
2760
|
-
def __init__(self, terms: List[_OpMultTerm]):
|
|
2761
|
-
self.terms = terms
|
|
2762
|
-
|
|
2763
|
-
def __hash__(self):
|
|
2764
|
-
return hash(tuple(self.terms))
|
|
2765
|
-
|
|
2766
|
-
def __eq__(self, other):
|
|
2767
|
-
if isinstance(other, _OpLinearTerm):
|
|
2768
|
-
return self.terms == other.terms
|
|
2769
|
-
return False
|
|
2770
|
-
|
|
2771
|
-
def __ne__(self, other):
|
|
2772
|
-
return not self.__eq__(other)
|
|
2773
|
-
|
|
2774
|
-
def as_dim(self) -> Dim:
|
|
2775
|
-
"""as dim"""
|
|
2776
|
-
if self.is_zero():
|
|
2777
|
-
return _make_constant_static_dim(0)
|
|
2778
|
-
if len(self.terms) == 1:
|
|
2779
|
-
return self.terms[0].as_dim()
|
|
2780
|
-
res = self.terms[0].as_dim()
|
|
2781
|
-
for operand in self.terms[1:]:
|
|
2782
|
-
res = _math_get_dim_via_bin_op(res, operand.as_dim(), "add")
|
|
2783
|
-
return res
|
|
2784
|
-
|
|
2785
|
-
def __repr__(self):
|
|
2786
|
-
return "Dim._OpLinearTerm(%r)" % (self.terms,)
|
|
2787
|
-
|
|
2788
|
-
def is_zero(self):
|
|
2789
|
-
"""
|
|
2790
|
-
:rtype: bool
|
|
2791
|
-
"""
|
|
2792
|
-
return not self.terms
|
|
2793
|
-
|
|
2794
|
-
def extend_add_sub_(self, other, kind, right):
|
|
2795
|
-
"""
|
|
2796
|
-
:param Dim|int other:
|
|
2797
|
-
:param str kind: "add" or "sub"
|
|
2798
|
-
:param bool right: or left. right means self + other, left means other + self
|
|
2799
|
-
"""
|
|
2800
|
-
assert kind in {"add", "sub"}
|
|
2801
|
-
other = self._make_dim(other, kind=kind)
|
|
2802
|
-
if other.is_constant_static_dim() and other.dimension == 0:
|
|
2803
|
-
return
|
|
2804
|
-
if other.derived_from_op and other.derived_from_op.kind == "add":
|
|
2805
|
-
for other_ in other.derived_from_op.inputs if right else reversed(other.derived_from_op.inputs):
|
|
2806
|
-
self.extend_add_sub_(other_, kind=kind, right=right)
|
|
2807
|
-
return
|
|
2808
|
-
term = _OpMultTerm.from_dim(other)
|
|
2809
|
-
neg_term = term.negative()
|
|
2810
|
-
if kind == "sub":
|
|
2811
|
-
term, neg_term = neg_term, term
|
|
2812
|
-
most_recent_term = self.terms[-1 if right else 0] if self.terms else None
|
|
2813
|
-
if most_recent_term:
|
|
2814
|
-
if most_recent_term == neg_term:
|
|
2815
|
-
self.terms.pop(-1 if right else 0)
|
|
2816
|
-
return
|
|
2817
|
-
if most_recent_term.is_constant_static_dim() and term.is_constant_static_dim():
|
|
2818
|
-
self.terms[-1 if right else 0] = _OpMultTerm.from_dim(
|
|
2819
|
-
_make_constant_static_dim(most_recent_term.dimension + term.dimension, kind=other.kind)
|
|
2820
|
-
)
|
|
2821
|
-
return
|
|
2822
|
-
if most_recent_term.terms and term.terms and most_recent_term.terms[-1] == term.terms[-1]:
|
|
2823
|
-
# Merge terms
|
|
2824
|
-
a = _OpMultTerm.from_dim_factors(most_recent_term.terms[:-1]).as_dim()
|
|
2825
|
-
b = _OpMultTerm.from_dim_factors(term.terms[:-1]).as_dim()
|
|
2826
|
-
if a.is_constant_static_dim() and not b.is_constant_static_dim():
|
|
2827
|
-
a = a.dimension
|
|
2828
|
-
elif b.is_constant_static_dim() and not a.is_constant_static_dim():
|
|
2829
|
-
b = b.dimension
|
|
2830
|
-
res = _OpMultTerm.from_dim((a + b) if right else (b + a))
|
|
2831
|
-
res.extend_mul_div_(term.terms[-1], kind="mul", right=True)
|
|
2832
|
-
self.terms[-1 if right else 0] = res
|
|
2833
|
-
return
|
|
2834
|
-
if right:
|
|
2835
|
-
self.terms.append(term)
|
|
2836
|
-
else:
|
|
2837
|
-
self.terms.insert(0, term)
|
|
2838
|
-
|
|
2839
|
-
def extend_mul_div_(self, other, kind, right):
|
|
2840
|
-
"""
|
|
2841
|
-
:param Dim|int other:
|
|
2842
|
-
:param str kind: "mul" or "ceildiv"
|
|
2843
|
-
:param bool right: or left. right means self * other, left means other * self
|
|
2844
|
-
"""
|
|
2845
|
-
assert kind in {"mul", "floordiv", "truediv", "ceildiv"}
|
|
2846
|
-
other = self._make_dim(other, kind=kind)
|
|
2847
|
-
if kind == "mul" and right:
|
|
2848
|
-
if not all(term.can_simplify(other, kind=kind, right=right) for term in self.terms):
|
|
2849
|
-
# Do it the other way around
|
|
2850
|
-
self.terms, other = _OpLinearTerm.from_dim(other).terms, self.as_dim()
|
|
2851
|
-
right = False
|
|
2852
|
-
if other.is_constant_static_dim() and other.dimension == 1:
|
|
2853
|
-
return
|
|
2854
|
-
if kind.endswith("div") and len(self.terms) >= 2:
|
|
2855
|
-
if any(not term.divisible(other, right=right) for term in self.terms):
|
|
2856
|
-
self.terms = [
|
|
2857
|
-
_OpMultTerm.from_dim(_OpMultTerm.new_div_dim(self.as_dim(), other, kind=kind, right=right))
|
|
2858
|
-
]
|
|
2859
|
-
return
|
|
2860
|
-
for term in self.terms:
|
|
2861
|
-
term.extend_mul_div_(other, kind=kind, right=right)
|
|
2862
|
-
|
|
2863
|
-
def _make_dim(self, other, kind):
|
|
2864
|
-
"""
|
|
2865
|
-
:param Dim|int other:
|
|
2866
|
-
:param str kind:
|
|
2867
|
-
:rtype: Dim
|
|
2868
|
-
"""
|
|
2869
|
-
if isinstance(other, int):
|
|
2870
|
-
base_tag = self.representative_tag()
|
|
2871
|
-
return _make_constant_static_dim(other, kind=base_tag.kind if base_tag else None)
|
|
2872
|
-
elif isinstance(other, _d.Dim):
|
|
2873
|
-
return other.get_same_base()
|
|
2874
|
-
else:
|
|
2875
|
-
raise TypeError("%s %s %s invalid for type %s" % (self, kind, other, type(other)))
|
|
2876
|
-
|
|
2877
|
-
def representative_tag(self):
|
|
2878
|
-
"""
|
|
2879
|
-
:rtype: Dim|None
|
|
2880
|
-
"""
|
|
2881
|
-
terms = [_representative_tag(term.terms) for term in self.terms]
|
|
2882
|
-
return _representative_tag([term for term in terms if term])
|
|
2883
|
-
|
|
2884
|
-
|
|
2885
|
-
def _get_merged_dim_kind(dim_tags):
|
|
2886
|
-
"""
|
|
2887
|
-
:param list[Dim]|tuple[Dim] dim_tags:
|
|
2888
|
-
:return: dim kind
|
|
2889
|
-
:rtype: Entity
|
|
2890
|
-
"""
|
|
2658
|
+
def _get_merged_dim_kind(dim_tags: Sequence[Dim]) -> Entity:
|
|
2891
2659
|
if any(tag.is_batch_dim() for tag in dim_tags):
|
|
2892
2660
|
return DimTypes.Batch
|
|
2893
2661
|
elif any(tag.is_feature_dim() for tag in dim_tags):
|
|
@@ -8,6 +8,10 @@ import os
|
|
|
8
8
|
import sys
|
|
9
9
|
import gc
|
|
10
10
|
import subprocess
|
|
11
|
+
import signal
|
|
12
|
+
import time
|
|
13
|
+
import contextlib
|
|
14
|
+
import multiprocessing
|
|
11
15
|
import torch
|
|
12
16
|
from returnn.util.better_exchook import better_exchook
|
|
13
17
|
from returnn.util.basic import human_bytes_size
|
|
@@ -26,36 +30,39 @@ def print_available_devices(*, file: Optional[TextIO] = None):
|
|
|
26
30
|
print("CUDA_VISIBLE_DEVICES is set to %r." % os.environ["CUDA_VISIBLE_DEVICES"], file=file)
|
|
27
31
|
cuda_visible_devs = dict(enumerate([int(d) for d in os.environ["CUDA_VISIBLE_DEVICES"].split(",") if d]))
|
|
28
32
|
else:
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
33
|
+
with timeout("torch.cuda.is_available()"):
|
|
34
|
+
if torch.cuda.is_available():
|
|
35
|
+
print("CUDA_VISIBLE_DEVICES is not set.", file=file)
|
|
36
|
+
|
|
37
|
+
with timeout("torch.cuda.is_available()"):
|
|
38
|
+
if not torch.cuda.is_available():
|
|
39
|
+
print("(CUDA not available)", file=file)
|
|
40
|
+
return
|
|
41
|
+
|
|
42
|
+
print("Available CUDA devices:", file=file)
|
|
43
|
+
count = torch.cuda.device_count()
|
|
44
|
+
if cuda_visible_devs is not None and len(cuda_visible_devs) != count:
|
|
45
|
+
print(
|
|
46
|
+
f"(Mismatch between CUDA device count {count}"
|
|
47
|
+
f" and CUDA_VISIBLE_DEVICES {cuda_visible_devs} count {len(cuda_visible_devs)}?)",
|
|
48
|
+
file=file,
|
|
49
|
+
)
|
|
50
|
+
for i in range(count):
|
|
51
|
+
print(f" {i + 1}/{count}: cuda:{i}", file=file)
|
|
52
|
+
props = torch.cuda.get_device_properties(i)
|
|
53
|
+
print(f" name: {props.name}", file=file)
|
|
54
|
+
print(f" total_memory: {human_bytes_size(props.total_memory)}", file=file)
|
|
55
|
+
print(f" capability: {props.major}.{props.minor}", file=file)
|
|
56
|
+
if cuda_visible_devs is not None:
|
|
57
|
+
if len(cuda_visible_devs) == count:
|
|
58
|
+
dev_idx_s = cuda_visible_devs[i]
|
|
52
59
|
else:
|
|
53
|
-
dev_idx_s =
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
print("(
|
|
60
|
+
dev_idx_s = "?"
|
|
61
|
+
else:
|
|
62
|
+
dev_idx_s = i
|
|
63
|
+
print(f" device_index: {dev_idx_s}", file=file)
|
|
64
|
+
if not count:
|
|
65
|
+
print(" (None)", file=file)
|
|
59
66
|
|
|
60
67
|
|
|
61
68
|
def print_using_cuda_device_report(dev: Union[str, torch.device], *, file: Optional[TextIO] = None):
|
|
@@ -108,7 +115,7 @@ def diagnose_no_gpu() -> List[str]:
|
|
|
108
115
|
except Exception as exc:
|
|
109
116
|
print("nvidia-smi failed:", exc)
|
|
110
117
|
better_exchook(*sys.exc_info(), debugshell=False)
|
|
111
|
-
res.append(
|
|
118
|
+
res.append("nvidia-smi failed")
|
|
112
119
|
|
|
113
120
|
return res
|
|
114
121
|
|
|
@@ -152,4 +159,31 @@ def garbage_collect():
|
|
|
152
159
|
f"alloc {human_bytes_size(torch.cuda.memory_allocated())}",
|
|
153
160
|
f"reserved {human_bytes_size(torch.cuda.memory_reserved())}",
|
|
154
161
|
]
|
|
155
|
-
print(
|
|
162
|
+
print("CUDA memory usage after triggered GC:", " ".join(stats))
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
@contextlib.contextmanager
|
|
166
|
+
def timeout(info: str, *, seconds: int = 30):
|
|
167
|
+
"""
|
|
168
|
+
Note: don't use signal handlers (e.g. signal.alarm) because unfortunately
|
|
169
|
+
potential hanging funcs will block the main thread and thus block the signal handler from executing.
|
|
170
|
+
Thus, we use a subprocess.
|
|
171
|
+
|
|
172
|
+
:param seconds:
|
|
173
|
+
:param info:
|
|
174
|
+
"""
|
|
175
|
+
proc = multiprocessing.Process(
|
|
176
|
+
target=_timeout_handler, kwargs={"seconds": seconds, "proc_id": os.getpid(), "info": info}
|
|
177
|
+
)
|
|
178
|
+
proc.start()
|
|
179
|
+
try:
|
|
180
|
+
yield
|
|
181
|
+
finally:
|
|
182
|
+
proc.terminate()
|
|
183
|
+
proc.join()
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _timeout_handler(*, seconds: Union[float, int], proc_id: int, info: str):
|
|
187
|
+
time.sleep(seconds)
|
|
188
|
+
print(f"ERROR: {info}: Timeout handler after {seconds} seconds, killing proc {proc_id}.", file=sys.stderr)
|
|
189
|
+
os.kill(proc_id, signal.SIGABRT)
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
returnn/PKG-INFO,sha256=
|
|
1
|
+
returnn/PKG-INFO,sha256=GVal7eVN_obo9mfdhPK2WvH2MzSm51cFZJChHEsF2XU,5214
|
|
2
2
|
returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
|
|
3
3
|
returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
|
|
4
4
|
returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
|
|
5
5
|
returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
|
|
6
|
-
returnn/_setup_info_generated.py,sha256=
|
|
6
|
+
returnn/_setup_info_generated.py,sha256=jTlsQFAqLqFgm0UJ0uWltcnLf69QwqOK0yV4Slt-2Is,77
|
|
7
7
|
returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
|
|
8
8
|
returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
|
|
9
9
|
returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
|
|
@@ -154,7 +154,7 @@ returnn/sprint/extern_interface.py,sha256=l-v1X-Yg0UpTFe7Y3c4FwWOqpSNuv9Oy5EzqlK
|
|
|
154
154
|
returnn/sprint/interface.py,sha256=1j5SB0V8hSW8A5song9ciZtcBnZoKKfNipk9ezOIMuA,36491
|
|
155
155
|
returnn/tensor/README.md,sha256=X6BqcRLrPLPnwF9yR69uqIFrMnNluj9pBkOPHwNgzuo,501
|
|
156
156
|
returnn/tensor/__init__.py,sha256=on6j5PEOQpck50UcsR4nJzJSDmoVy34z1Oq4efv6Ax0,154
|
|
157
|
-
returnn/tensor/_dim_extra.py,sha256=
|
|
157
|
+
returnn/tensor/_dim_extra.py,sha256=rwtDR5WRS8wqgKj4WkPaWtaKa8UJYTrS76ZhX0W5bP4,115580
|
|
158
158
|
returnn/tensor/_tensor_extra.py,sha256=gbSl6HMtn8WFYloanew_RaNNwx3eCpnKv3UfCkntJiQ,164923
|
|
159
159
|
returnn/tensor/_tensor_mixin_base.py,sha256=H5z86I0NejxrSgMH1c5oXQzBqS6L9HpvP4y7oegBaSc,643
|
|
160
160
|
returnn/tensor/_tensor_op_overloads.py,sha256=HklwuTBjy7mH_665VKaCUdu-oC3aa7Uz1ZQiCz4jeZc,5448
|
|
@@ -227,7 +227,7 @@ returnn/torch/util/README.md,sha256=AW-6ueWhgcwDcm57md6sm227QXNkvLnlRLwaH7NlS-w,
|
|
|
227
227
|
returnn/torch/util/__init__.py,sha256=AOXYUjzPm0XrzFJCPAXo9Jj_FvqD1XH3FfKtho80Vl8,26
|
|
228
228
|
returnn/torch/util/array_.py,sha256=ell3VZvn01SLtF9Pw2fvPzFNO-XDQ7tSB9VCrVSKmSA,2556
|
|
229
229
|
returnn/torch/util/debug_inf_nan.py,sha256=fmzSSTJJyLf7i5yDWRHLeDI0gxvadeqLE8RxMuSHx_4,6398
|
|
230
|
-
returnn/torch/util/diagnose_gpu.py,sha256=
|
|
230
|
+
returnn/torch/util/diagnose_gpu.py,sha256=_yswLmwR8Q2rCsv2jI5FUQNBT__453jBmiWYwazdu20,6808
|
|
231
231
|
returnn/torch/util/exception_helper.py,sha256=_SqxTD5F-GDY2eR4uRALyUTJwt0ytcbJGB_w38RJMBA,4320
|
|
232
232
|
returnn/torch/util/gradient_checkpoint.py,sha256=iLy-FB65DC8O6LxzmMvFjnSdpIVpko87ppIvRKAbtpQ,27995
|
|
233
233
|
returnn/torch/util/module.py,sha256=MXHIrF9Isu575DDJIa81212ULKwdqu1oOLxDVZecVSk,1693
|
|
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
|
|
|
253
253
|
returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
|
|
254
254
|
returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
|
|
255
255
|
returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
|
|
256
|
-
returnn-1.
|
|
257
|
-
returnn-1.
|
|
258
|
-
returnn-1.
|
|
259
|
-
returnn-1.
|
|
260
|
-
returnn-1.
|
|
256
|
+
returnn-1.20250902.10950.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
257
|
+
returnn-1.20250902.10950.dist-info/METADATA,sha256=GVal7eVN_obo9mfdhPK2WvH2MzSm51cFZJChHEsF2XU,5214
|
|
258
|
+
returnn-1.20250902.10950.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
259
|
+
returnn-1.20250902.10950.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
260
|
+
returnn-1.20250902.10950.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|