returnn-1.20251027.224345-py3-none-any.whl → returnn-1.20260109.93428-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of returnn might be problematic.
Files changed (37)
  1. returnn/PKG-INFO +2 -2
  2. returnn/_setup_info_generated.py +2 -2
  3. returnn/config.py +1 -1
  4. returnn/datasets/lm.py +20 -0
  5. returnn/datasets/meta.py +93 -43
  6. returnn/datasets/postprocessing.py +597 -108
  7. returnn/datasets/util/vocabulary.py +90 -0
  8. returnn/frontend/array_.py +46 -0
  9. returnn/frontend/attention.py +54 -20
  10. returnn/frontend/conv.py +273 -54
  11. returnn/frontend/device.py +14 -1
  12. returnn/frontend/encoder/conformer.py +20 -0
  13. returnn/frontend/encoder/transformer.py +2 -0
  14. returnn/frontend/loss.py +40 -1
  15. returnn/frontend/math_.py +54 -14
  16. returnn/native_op.cpp +80 -0
  17. returnn/sprint/cache.py +12 -13
  18. returnn/tensor/utils.py +7 -4
  19. returnn/tf/frontend_layers/_backend.py +4 -3
  20. returnn/tf/layers/basic.py +15 -39
  21. returnn/tf/native_op.py +11 -58
  22. returnn/tf/network.py +1 -1
  23. returnn/tf/util/basic.py +19 -0
  24. returnn/torch/engine.py +37 -3
  25. returnn/torch/frontend/_backend.py +135 -13
  26. returnn/torch/frontend/bridge.py +61 -0
  27. returnn/torch/util/exception_helper.py +7 -1
  28. returnn/util/basic.py +3 -6
  29. returnn/util/better_exchook.py +4 -0
  30. returnn/util/debug.py +11 -2
  31. returnn/util/file_cache.py +15 -1
  32. returnn/util/task_system.py +1 -1
  33. {returnn-1.20251027.224345.dist-info → returnn-1.20260109.93428.dist-info}/METADATA +2 -2
  34. {returnn-1.20251027.224345.dist-info → returnn-1.20260109.93428.dist-info}/RECORD +37 -37
  35. {returnn-1.20251027.224345.dist-info → returnn-1.20260109.93428.dist-info}/LICENSE +0 -0
  36. {returnn-1.20251027.224345.dist-info → returnn-1.20260109.93428.dist-info}/WHEEL +0 -0
  37. {returnn-1.20251027.224345.dist-info → returnn-1.20260109.93428.dist-info}/top_level.txt +0 -0
returnn/native_op.cpp CHANGED
@@ -206,6 +206,14 @@ Ndarray* Ndarray_Copy(const Ndarray* self) {
 
 #include "tensorflow/core/public/version.h"
 
+#ifndef TF_MAJOR_VERSION
+#error "TF_MAJOR_VERSION is not defined!"
+#endif
+
+#ifndef TF_MINOR_VERSION
+#error "TF_MINOR_VERSION is not defined!"
+#endif
+
 #if (TF_MAJOR_VERSION == 1 && TF_MINOR_VERSION >= 6) || (TF_MAJOR_VERSION > 1)
 #define TF_issue_6602_workaround 0
 #define TWOD_LSTM_SUPPORT 1
@@ -402,6 +410,9 @@ static void tf_cuda_sgemm_batched(
 
 
 #else // CUDA
+
+#ifdef HAVE_CUSTOM_BLAS
+
 /*
 // matrices are in column-major form
 int sgemm_(char *transa, char *transb,
@@ -419,6 +430,75 @@ static void tf_cuda_sgemm_batched(
     sgemm_(&transa, &transb, \
       &m_, &n_, &k_, alpha, A, &lda_, B, &ldb_, beta, C, &ldc_); \
   }
+
+#else // HAVE_CUSTOM_BLAS
+
+template<typename T>
+static void tf_cpu_sgemm(
+    OpKernelContext* context,
+    char transa_, char transb_,
+    int m, int n, int k,
+    const T* alpha_ptr, const T* a_ptr, int lda,
+    const T* b_ptr, int ldb, const T* beta_ptr,
+    T* c_ptr, int ldc)
+{
+  if (m <= 0 || n <= 0 || k <= 0) return;
+
+  auto d = context->eigen_cpu_device();
+  const T alpha = *alpha_ptr;
+  const T beta = *beta_ptr;
+
+  bool transa = (transa_ == 'T' || transa_ == 't' || transa_ == 'C' || transa_ == 'c');
+  bool transb = (transb_ == 'T' || transb_ == 't' || transb_ == 'C' || transb_ == 'c');
+
+  // 1. Map as COLUMN-MAJOR
+  // Physical rows (height) for the Map is always the leading dimension (lda, ldb, ldc)
+  typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::ColMajor>, Eigen::Unaligned> ConstMap;
+  typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::ColMajor>, Eigen::Unaligned> MutableMap;
+
+  // Logical height/width of slices before any transposition
+  int a_slice_rows = transa ? k : m;
+  int a_slice_cols = transa ? m : k;
+  int b_slice_rows = transb ? n : k;
+  int b_slice_cols = transb ? k : n;
+
+  // Map and Slice
+  auto a = ConstMap(a_ptr, lda, a_slice_cols).slice(
+    Eigen::array<Eigen::Index, 2>({0, 0}),
+    Eigen::array<Eigen::Index, 2>({(Eigen::Index)a_slice_rows, (Eigen::Index)a_slice_cols}));
+
+  auto b = ConstMap(b_ptr, ldb, b_slice_cols).slice(
+    Eigen::array<Eigen::Index, 2>({0, 0}),
+    Eigen::array<Eigen::Index, 2>({(Eigen::Index)b_slice_rows, (Eigen::Index)b_slice_cols}));
+
+  auto c = MutableMap(c_ptr, ldc, n).slice(
+    Eigen::array<Eigen::Index, 2>({0, 0}),
+    Eigen::array<Eigen::Index, 2>({(Eigen::Index)m, (Eigen::Index)n}));
+
+  // 2. Define Contraction Pairs based on Transposition
+  // Column-Major Matrix Mult: (M x K) * (K x N)
+  // Standard: Contract Axis 1 of A with Axis 0 of B
+  // If A is Transposed: A is (K x M), contract Axis 0 of A
+  // If B is Transposed: B is (N x K), contract Axis 1 of B
+  Eigen::array<Eigen::IndexPair<int>, 1> pairs;
+  pairs[0] = Eigen::IndexPair<int>(transa ? 0 : 1, transb ? 1 : 0);
+
+  // 3. Execution
+  if (alpha == T(1) && beta == T(0)) {
+    c.device(d) = a.contract(b, pairs);
+  } else if (alpha == T(1) && beta == T(1)) {
+    c.device(d) += a.contract(b, pairs);
+  } else {
+    c.device(d) = a.contract(b, pairs) * alpha + c * beta;
+  }
+}
+
+#define Ndarray_sgemm(\
+  transpose_A, transpose_B, \
+  m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) \
+  tf_cpu_sgemm<float>(context, transpose_A, transpose_B, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
+
+#endif // HAVE_CUSTOM_BLAS
 #endif // CUDA
 
 // See Context struct below.
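
For reference, the new Eigen-based tf_cpu_sgemm path implements the usual column-major GEMM contract, C[:m,:n] = alpha * op(A) @ op(B) + beta * C[:m,:n], where op() optionally transposes and lda/ldb/ldc are the physical leading dimensions. A minimal NumPy sketch of that contract (the name sgemm_reference and this helper are purely illustrative, not part of the package; buffers are assumed to be contiguous 1-D NumPy arrays so the write into C is in place):

    import numpy as np

    def sgemm_reference(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc):
        """Column-major GEMM reference: C[:m,:n] = alpha*op(A)@op(B) + beta*C[:m,:n] (illustration only)."""
        def as_colmajor(buf, ld, cols):
            # Element (i, j) of a column-major matrix with leading dimension ld lives at buf[i + j*ld].
            return np.asarray(buf).reshape(cols, ld).T

        ta = transa in "TtCc"
        tb = transb in "TtCc"
        a_rows, a_cols = (k, m) if ta else (m, k)   # stored shape of A before op()
        b_rows, b_cols = (n, k) if tb else (k, n)   # stored shape of B before op()

        a_mat = as_colmajor(a, lda, a_cols)[:a_rows, :]
        b_mat = as_colmajor(b, ldb, b_cols)[:b_rows, :]
        c_mat = as_colmajor(c, ldc, n)              # view into the caller's C buffer

        op_a = a_mat.T if ta else a_mat             # (m, k)
        op_b = b_mat.T if tb else b_mat             # (k, n)
        c_mat[:m, :n] = alpha * (op_a @ op_b) + beta * c_mat[:m, :n]

The Eigen kernel above expresses the same computation as a tensor contraction over the appropriate axes and runs it on the TF CPU device, which is the same code path TF itself uses for CPU matmul.
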
returnn/sprint/cache.py CHANGED
@@ -7,10 +7,9 @@ This module is about reading (maybe later also writing) the Sprint archive format
 """
 
 from __future__ import annotations
-from typing import List, Optional, Tuple
+from typing import Optional, List, Tuple, Dict
 import sys
 import os
-import typing
 import array
 from struct import pack, unpack
 import numpy
@@ -212,7 +211,7 @@ class FileArchive:
     def __init__(self, filename, must_exists=True, encoding="ascii"):
         self.encoding = encoding
 
-        self.ft = {}  # type: typing.Dict[str,FileInfo]
+        self.ft: Dict[str, FileInfo] = {}
         if os.path.exists(filename):
             self.allophones = []
             self.f = open(filename, "rb")
@@ -334,8 +333,8 @@ class FileArchive:
         # print(typ)
         assert type_ == "vector-f32"
         count = self.read_U32()
-        data = [None] * count  # type: typing.List[typing.Optional[numpy.ndarray]]
-        time_ = [None] * count  # type: typing.List[typing.Optional[numpy.ndarray]]
+        data: List[Optional[numpy.ndarray]] = [None] * count
+        time_: List[Optional[numpy.ndarray]] = [None] * count
         for i in range(count):
             size = self.read_U32()
             data[i] = self.read_v("f", size)  # size x f32
@@ -450,7 +449,7 @@ class FileArchive:
         a = array.array("b")
         a.fromfile(self.f, comp)
         # unpack
-        b = zlib.decompress(a.tostring(), 15 + 32)
+        b = zlib.decompress(a.tobytes(), 15 + 32)
         # substitute self.f by an anonymous memmap file object
         # restore original file handle after we're done
         backup_f = self.f
@@ -575,17 +574,17 @@ class FileArchiveBundle:
         :param str encoding: encoding used in the files
         """
         # filename -> FileArchive
-        self.archives = {}  # type: typing.Dict[str,FileArchive]
+        self.archives: Dict[str, FileArchive] = {}
         # archive content file -> FileArchive
-        self.files = {}  # type: typing.Dict[str,FileArchive]
+        self.files: Dict[str, FileArchive] = {}
         self._short_seg_names = {}
         if filename is not None:
             self.add_bundle(filename=filename, encoding=encoding)
 
-    def add_bundle(self, filename, encoding="ascii"):
+    def add_bundle(self, filename: str, encoding: str = "ascii"):
         """
-        :param str filename: bundle
-        :param str encoding:
+        :param filename: bundle
+        :param encoding:
         """
         file_dir = os.path.dirname(filename) or "."
         for line in open(filename).read().splitlines():
@@ -837,7 +836,7 @@ class MixtureSet:
         """
         a = array.array("b")
        a.fromfile(self.f, length)
-        return a.tostring().decode(encoding)
+        return a.tobytes().decode(encoding)
 
     def read_f32(self):
         """
@@ -1003,7 +1002,7 @@ class WordBoundaries:
         """
         a = array.array("b")
         a.fromfile(self.f, length)
-        return a.tostring().decode(encoding)
+        return a.tobytes().decode(encoding)
 
     def __init__(self, filename):
         """
returnn/tensor/utils.py CHANGED
@@ -36,11 +36,14 @@ def tensor_fill_random_numpy_(
     *,
     min_val: int = 0,
     max_val: Optional[int] = None,
-    rnd: numpy.random.RandomState,
+    rnd: Optional[numpy.random.RandomState] = None,
     dyn_dim_max_sizes: Optional[Dict[Dim, int]] = None,
     dyn_dim_min_sizes: Optional[Dict[Dim, int]] = None,
 ) -> bool:
     """fill. return whether sth was filled"""
+    if rnd is None:
+        # noinspection PyUnresolvedReferences,PyProtectedMember
+        rnd = numpy.random.mtrand._rand
     if dyn_dim_max_sizes is None:
         dyn_dim_max_sizes = {}
     if dyn_dim_min_sizes is None:
@@ -59,7 +62,7 @@
             continue
         if tensor_fill_random_numpy_(
             dim.dyn_size_ext,
-            min_val=dyn_dim_min_sizes.get(dim, 2),
+            min_val=dyn_dim_min_sizes.get(dim, min(2, dyn_dim_max_sizes.get(dim, 2))),
             max_val=dyn_dim_max_sizes.get(dim, None),
             rnd=rnd,
             dyn_dim_max_sizes=dyn_dim_max_sizes,
@@ -98,8 +101,8 @@
         if max_val is None:
             max_val = rnd.randint(5, 20)
         if x.sparse_dim and x.sparse_dim.dimension is not None:
-            max_val = x.sparse_dim.dimension
-        x.raw_tensor = rnd.randint(min_val, max_val, size=shape, dtype=x.dtype)
+            max_val = x.sparse_dim.dimension - 1
+        x.raw_tensor = rnd.randint(min_val, max_val + 1, size=shape, dtype=x.dtype)
     elif x.dtype == "bool":
         x.raw_tensor = rnd.randint(0, 2, size=shape, dtype=x.dtype)
     elif x.dtype.startswith("float"):
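
The randint change is about inclusivity: numpy's RandomState.randint() excludes its upper bound, so treating max_val as an inclusive maximum means sampling from [min_val, max_val + 1), and for a sparse dim of size D the valid labels are 0 .. D - 1. A small standalone check of the new behaviour:

    import numpy

    rnd = numpy.random.RandomState(42)

    # For a sparse dim of size D, valid label values are 0 .. D - 1, hence max_val = D - 1,
    # and randint gets max_val + 1 as its (exclusive) upper bound.
    sparse_dim_size = 5
    max_val = sparse_dim_size - 1
    labels = rnd.randint(0, max_val + 1, size=(8,), dtype="int32")
    assert labels.min() >= 0 and labels.max() <= sparse_dim_size - 1
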
returnn/tf/frontend_layers/_backend.py CHANGED
@@ -944,7 +944,6 @@ class ReturnnLayersBackend(Backend[Layer]):
         """
         assert mask.dtype == "bool"
         assert set(mask.dims) == set(dims)
-        assert set(mask.dims).issubset(set(tensor.dims))
         if not out_dim:
             out_dim = Dim(None, name="mask")
         return (
@@ -1067,14 +1066,16 @@ class ReturnnLayersBackend(Backend[Layer]):
             s = filter_size[i].dimension if not strides else strides[i]
             if filter_size[i].dimension == s == 1 or (s == 1 and padding.lower() == "same"):
                 out_spatial_dims[i] = in_spatial_dims[i]
-        layer_dict = {
+        assert all(size.is_static() for size in filter_size)
+        layer_dict: Dict[str, Any] = {
             "class": "transposed_conv",
             "from": source,
             "in_dim": in_dim,
             "in_spatial_dims": in_spatial_dims,
             "out_dim": out_dim,
             "out_spatial_dims": out_spatial_dims,
-            "filter_size": filter_size,
+            "filter_size": [size.dimension for size in filter_size],
+            "filter_perm": list(filter_size) + [out_dim, in_dim],
             "padding": padding,
         }
         if remove_padding:
returnn/tf/layers/basic.py CHANGED
@@ -2741,7 +2741,7 @@ class BooleanMaskLayer(LayerBase):
         tensor = self.sources[0].output
         remaining_dims = [d for d in tensor.dims if d not in dims]
         tensor_templ = tensor.copy_template_new_dim_tags(tuple(dims) + tuple(remaining_dims))
-        tensor = tensor.copy_compatible_to(tensor_templ, add_dims=False)
+        tensor = tensor.copy_compatible_to(tensor_templ, unbroadcast=True)
         mask_templ = mask.output.copy_template_new_dim_tags(new_dim_tags=tuple(dims))
         mask_ = mask.output.copy_compatible_to(mask_templ, add_dims=False)
         self.output.raw_tensor = tf.boolean_mask(tensor.raw_tensor, mask=mask_.raw_tensor)
@@ -7371,7 +7371,7 @@ class TransposedConvLayer(_ConcatInputLayer):
         """
         from returnn.tf.util.basic import get_initializer, get_activation_function, get_shape
 
-        super(TransposedConvLayer, self).__init__(**kwargs)
+        super(TransposedConvLayer, self).__init__(in_dim=in_dim, **kwargs)
         out_dim  # noqa  # via get_out_data_from_opts
         assert not self.input_data.sparse
         assert self.input_data.have_batch_axis()
@@ -7516,7 +7516,10 @@
     ):
         """
         Determines output length of a transposed convolution given input length.
-        Copied from conv_utils.deconv_output_length, adapted with simplification.
+
+        Copied from TF/Keras conv_utils.deconv_output_length
+        (https://github.com/tensorflow/tensorflow/blob/5912f51d580551e5cee2cfde4cb882594b4d3e60/tensorflow/python/keras/utils/conv_utils.py#L140),
+        adapted with simplification.
 
         Also see :func:`ConvLayer.calc_out_dim`.
 
@@ -7533,44 +7536,17 @@
         """
         if out_dim and out_dim.is_dim_known():
             return out_dim.get_dim_value()
-        assert padding in {"same", "valid", "full"}
-
-        # Get the dilated kernel size
-        filter_size = filter_size + (filter_size - 1) * (dilation - 1)
 
-        if stride != 1:
-            input_length = input_length * stride
+        import returnn.frontend as rf
 
-        # Infer length if output padding is None, else compute the exact length
-        if output_padding is None:
-            if padding == "valid":
-                if isinstance(input_length, Dim):
-                    length = input_length + max(filter_size - stride, 0)
-                else:
-                    length = tf_util.simplify_add(input_length, max(filter_size - stride, 0))
-            elif padding == "full":
-                if isinstance(input_length, Dim):
-                    length = input_length - (stride + filter_size - 2)
-                else:
-                    length = tf_util.simplify_add(input_length, -(stride + filter_size - 2))
-            elif padding == "same":
-                length = input_length
-            else:
-                raise Exception("invalid padding %r" % (padding,))
-        else:  # output_padding
-            if padding == "same":
-                pad = filter_size // 2
-            elif padding == "valid":
-                pad = 0
-            elif padding == "full":
-                pad = filter_size - 1
-            else:
-                raise Exception("invalid padding %r" % (padding,))
-            if isinstance(input_length, Dim):
-                length = input_length + (-stride + filter_size - 2 * pad + output_padding)
-            else:
-                length = tf_util.simplify_add(input_length, -stride + filter_size - 2 * pad + output_padding)
-        return length
+        return rf.calc_transposed_conv_out_length(
+            input_length,
+            filter_size=filter_size,
+            padding=padding,
+            output_padding=output_padding,
+            stride=stride,
+            dilation_rate=dilation,
+        )
 
     @classmethod
     def get_out_data_from_opts(
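
The removed helper body is the Keras-style deconv output-length formula, which the delegated rf.calc_transposed_conv_out_length is expected to reproduce. A plain-integer restatement of that formula (a standalone sketch; it ignores the Dim/tf_util.simplify_add handling of the original code):

    def deconv_output_length(input_length, filter_size, padding, output_padding=None, stride=1, dilation=1):
        """Output length of a 1D transposed convolution, mirroring the removed TF-layer code (ints only)."""
        assert padding in {"same", "valid", "full"}
        filter_size = filter_size + (filter_size - 1) * (dilation - 1)  # dilated kernel size
        input_length = input_length * stride
        if output_padding is None:
            if padding == "valid":
                return input_length + max(filter_size - stride, 0)
            if padding == "full":
                return input_length - (stride + filter_size - 2)
            return input_length  # "same"
        pad = {"same": filter_size // 2, "valid": 0, "full": filter_size - 1}[padding]
        return input_length - stride + filter_size - 2 * pad + output_padding

    # Example: upsampling by factor 2 with a kernel of size 4 and "valid" padding.
    assert deconv_output_length(10, filter_size=4, padding="valid", stride=2) == 22
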
returnn/tf/native_op.py CHANGED
@@ -528,77 +528,30 @@ class OpMaker:
     def _make_mod(self):
         if self.cache_key in self.mod_cache:
             return self.mod_cache[self.cache_key]
-        from returnn.util.basic import find_lib
-
-        # Note about BLAS linkage:
-        # TensorFlow (or its Eigen lib) likely has linked against some BLAS lib itself.
-        # For our CPU code, we directly call some BLAS functions such as `sgemm_`.
-        # On platforms where there is a flat namespace (e.g. Mac),
-        # it probably is not needed to explicitly link it again for this module.
-        # In other cases, it's probably needed, but it's not so clear which lib has the
-        # right symbols (e.g. the `sgemm_` symbol).
+
+        # Note about BLAS / matmul:
+        # Earlier, we assumed that TensorFlow/Eigen used BLAS internally,
+        # and our code directly called BLAS sgemm_, so we needed to link directly to BLAS.
+        # Now, by default, we use the underlying Eigen library,
+        # which is the same code path that TF also uses for CPU matmul.
+        # Only if an explicit BLAS library is specified, we use that instead.
         ld_flags = []
-        have_blas_lib = False
+        c_macro_defines = {}
 
         if self.blas_lib is not None and os.path.exists(self.blas_lib):
             path = os.path.dirname(self.blas_lib)
             if path == "":
                 path = "."
             ld_flags += ["-L%s" % path, "-l:%s" % os.path.basename(self.blas_lib)]
-            have_blas_lib = True
-        if not have_blas_lib and self.search_for_runtime_blas:
-            from returnn.util.basic import find_sgemm_libs_from_runtime
-
-            libs = find_sgemm_libs_from_runtime()
-            if libs:
-                numpy_libs = [fn for fn in libs if "/numpy/.libs/" in fn]
-                if numpy_libs:
-                    # Prefer Numpy; move to front.
-                    libs = numpy_libs + [fn for fn in libs if fn not in numpy_libs]
-                if self.blas_lib is not None:
-                    libs = [lib for lib in libs if self.blas_lib in lib]
-                for fn in libs:
-                    ld_flags += ["-L%s" % os.path.dirname(fn), "-l:%s" % os.path.basename(fn)]
-                    have_blas_lib = True
-        if not have_blas_lib and self.search_for_numpy_blas:
-            # Find related Numpy libs.
-            # Numpy usually comes with OpenBlas, and Numpy is probably loaded anyway.
-            # Even do this before the other libs below, as it is likely
-            # that this OpenBlas lib is correctly initialized already.
-            import numpy
-
-            numpy_dir = os.path.dirname(numpy.__file__)
-            if os.path.exists("%s/.libs" % numpy_dir):
-                ld_flags += ["-L%s/.libs" % numpy_dir]
-                from glob import glob
-
-                for f in glob("%s/.libs/*.so" % numpy_dir):
-                    f = os.path.basename(f)
-                    if self.blas_lib is not None and self.blas_lib not in f:
-                        continue
-                    if f.startswith("lib"):
-                        f = f[3:]
-                    if f.endswith(".so"):
-                        f = f[:-3]
-                    ld_flags += ["-l%s" % f]
-                    have_blas_lib = True
-        if not have_blas_lib and self.search_for_system_blas:
-            # Try to just link against blas/f77blas
-            # (both can potentially have the symbol) if it finds the lib.
-            if find_lib("blas"):
-                ld_flags += ["-lblas"]
-                have_blas_lib = True
-            if find_lib("f77blas"):
-                ld_flags += ["-lf77blas"]
-                have_blas_lib = True
-        if not have_blas_lib:
-            print("WARNING: OpMaker: no BLAS lib found")
+            c_macro_defines["HAVE_CUSTOM_BLAS"] = "1"
+
         comp = tf_util.OpCodeCompiler(
             base_name=self.name,
             code_version=self.description.code_version,
             code=self._make_code(),
             include_deps=[self.support_native_op_cpp_filename],
             ld_flags=ld_flags,
+            c_macro_defines=c_macro_defines,
             use_cuda_if_available=self.with_cuda,
             log_stream=self.log_stream,
             **dict(self.compiler_opts),
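
With this change, a BLAS library is only linked (and HAVE_CUSTOM_BLAS defined) when an explicit blas_lib is configured; otherwise the Eigen path in native_op.cpp is used. As an illustration of the usual convention for such a macro dict, a define like this typically ends up as a -D flag on the compiler command line (this sketch is not the actual OpCodeCompiler/NativeCodeCompiler implementation):

    def render_macro_defines(c_macro_defines):
        """Render {"HAVE_CUSTOM_BLAS": "1"} as ["-DHAVE_CUSTOM_BLAS=1"] (illustrative convention only)."""
        return ["-D%s=%s" % (key, value) for key, value in sorted(c_macro_defines.items())]

    assert render_macro_defines({"HAVE_CUSTOM_BLAS": "1"}) == ["-DHAVE_CUSTOM_BLAS=1"]
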
returnn/tf/network.py CHANGED
@@ -4428,7 +4428,7 @@ def help_on_tf_exception(
                 data = extern_data.data[data_key]
                 info += ", %s" % data
             print("  %r: %s" % (key, info), file=file)
-            if data and data.sparse:
+            if data is not None and data.sparse:
                 if v_minmax[0] < 0 or v_minmax[1] >= data.dim:
                     print("  WARNING, invalid label for data", data, file=file)
     elif feed_dict is None:
returnn/tf/util/basic.py CHANGED
@@ -2784,6 +2784,10 @@ class CudaEnv:
             self.cuda_path = None
             if self.verbose_find_cuda:
                 print("CUDA disabled via env DISABLE_CUDA.")
+        elif os.environ.get("CUDA_VISIBLE_DEVICES", None) in ["", "-1"]:
+            self.cuda_path = None
+            if self.verbose_find_cuda:
+                print(f"CUDA disabled via env CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']!r}.")
         else:
             self.cuda_path = self._find_cuda_path()
             if self.verbose_find_cuda:
@@ -3020,6 +3024,21 @@ class OpCodeCompiler(NativeCodeCompiler):
             ld_flags += tf.sysconfig.get_link_flags()
         elif have_min_tf_version((1, 4)):
             ld_flags += ["-L%s" % tf.sysconfig.get_lib(), "-ltensorflow_framework"]
+        if have_min_tf_version((2, 20)):
+            # TF 2.20 removed TF_MAJOR_VERSION and co from version.h,
+            # and one is supposed to define these macros externally.
+            # Also, release_version.h was added to define TF_VERSION_STRING based on this (if needed).
+            # https://github.com/tensorflow/tensorflow/commit/c8f0e0620e5678d0f165a07e64114024a966ab7f
+            major, minor, patch = tf.__version__.split(".", 2)
+            patch, suffix = patch.split("-", 1) if "-" in patch else (patch, "")
+            c_macro_defines.update(
+                {
+                    "TF_MAJOR_VERSION": major,
+                    "TF_MINOR_VERSION": minor,
+                    "TF_PATCH_VERSION": patch,
+                    "TF_VERSION_SUFFIX": suffix,
+                }
+            )
         use_cxx11_abi = getattr(getattr(tf, "sysconfig", tf), "CXX11_ABI_FLAG", getattr(tf, "CXX11_ABI_FLAG", False))
         super(OpCodeCompiler, self).__init__(
             include_paths=include_paths,
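
The version-splitting logic added for TF >= 2.20 can be checked on literal version strings, without importing TensorFlow:

    # How the added OpCodeCompiler code derives the TF version macros from tf.__version__.
    for version, expected in [
        ("2.20.0", ("2", "20", "0", "")),
        ("2.20.0-rc1", ("2", "20", "0", "rc1")),
    ]:
        major, minor, patch = version.split(".", 2)
        patch, suffix = patch.split("-", 1) if "-" in patch else (patch, "")
        assert (major, minor, patch, suffix) == expected
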
returnn/torch/engine.py CHANGED
@@ -532,7 +532,7 @@ class Engine(EngineBase):
             for key, val in eval_info.items():
                 self._tensorboard_writer.add_scalar(f"train/{key}", val, global_step=self.global_train_step)
             self._tensorboard_writer.add_scalar(
-                f"train/learning_rate",
+                "train/learning_rate",
                 self._updater.get_effective_learning_rate(),
                 global_step=self.global_train_step,
             )
@@ -930,7 +930,7 @@
         if not os.path.exists(filename) and os.path.exists(model_epoch_filename):
             filename = model_epoch_filename
         print("Load model %s" % (filename,), file=log.v4)
-        checkpoint_state = torch.load(filename, map_location=self._device)
+        checkpoint_state = _torch_load(filename, device=self._device)
         if epoch is None:
             epoch = checkpoint_state.get("epoch", self._start_epoch or 1)
         step = checkpoint_state.get("step", 1)
@@ -1030,7 +1030,7 @@
                 print("(No relevant parameters matching.)", file=log.v3)
                 continue
             print(f"Pre-load weights for key '{preload_key}' from {opts['filename']}", file=log.v3)
-            preload_model_state = torch.load(opts["filename"], map_location=self._device)
+            preload_model_state = _torch_load(opts["filename"], device=self._device)
             if opts.get("checkpoint_key", "model") is not None:
                 # This can be used if an external checkpoint saves a checkpoint a different structure that just the
                 # model state dict. E.g., if a checkpoint is created using
@@ -1063,6 +1063,28 @@
             preload_model_state_keys = set(preload_model_state.keys())
             loaded_state_keys.update(preload_model_state.keys())
             missing_keys.difference_update(preload_model_state.keys())
+
+            custom_missing_load_func = opts.get("custom_missing_load_func")
+            if custom_missing_load_func:
+                custom_missing_vars_map = {}
+                for var_name in missing_keys_preload:
+                    var_shape = self._pt_model.state_dict()[var_name].shape
+                    var_val = custom_missing_load_func(
+                        name=var_name,
+                        shape=var_shape,
+                        preload_model_state=preload_model_state,
+                        **util.get_fwd_compat_kwargs(),
+                    )
+                    if var_val is not None:
+                        assert var_val.shape == var_shape
+                        custom_missing_vars_map[var_name] = var_val
+                preload_model_state.update(custom_missing_vars_map)
+                missing_keys_preload, unexpected_keys_preload = self._pt_model.load_state_dict(
+                    preload_model_state, strict=False
+                )
+                loaded_state_keys.update(preload_model_state.keys())
+                missing_keys.difference_update(preload_model_state.keys())
+
             del preload_model_state
             gc.collect()
 
@@ -1700,3 +1722,15 @@ def _get_total_grad_norm(model: torch.nn.Module, p: float) -> float:
             p=p,
         ).item()
     )
+
+
+def _torch_load(filename: Union[str, os.PathLike], *, device: str) -> Dict[str, Any]:
+    # Might resolve PtCheckpoint or Sisyphus Path objects or so.
+    filename = os.fspath(filename)
+
+    if filename.endswith(".safetensors"):
+        from safetensors.torch import load_file as safetensors_load
+
+        return safetensors_load(filename, device=device)
+
+    return torch.load(filename, map_location=device)
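
The new custom_missing_load_func hook is looked up from the pre-load options (opts.get("custom_missing_load_func")) and called once per parameter that is still missing after the regular pre-load, with name, shape and the already-loaded preload_model_state plus forward-compatibility kwargs; returning None leaves the parameter missing. A sketch of such a callback (the zero-init policy here is just an example, not something the package prescribes):

    import torch

    def custom_missing_load_func(*, name, shape, preload_model_state, **_fwd_compat):
        """Return a tensor of the given shape to fill a missing parameter, or None to skip it
        (signature sketch matching the call site in the diff)."""
        # Example policy: zero-initialize missing bias parameters, leave everything else untouched.
        if name.endswith(".bias"):
            return torch.zeros(shape)
        return None
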