returnn 1.20251027.232712__py3-none-any.whl → 1.20260119.15400__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. returnn/PKG-INFO +2 -2
  2. returnn/__old_mod_loader__.py +26 -2
  3. returnn/_setup_info_generated.py +2 -2
  4. returnn/datasets/lm.py +130 -42
  5. returnn/datasets/meta.py +93 -43
  6. returnn/datasets/postprocessing.py +597 -108
  7. returnn/datasets/util/vocabulary.py +90 -0
  8. returnn/frontend/__init__.py +1 -0
  9. returnn/frontend/_backend.py +41 -0
  10. returnn/frontend/_native/__init__.py +22 -0
  11. returnn/frontend/_numpy_backend.py +7 -0
  12. returnn/frontend/_utils.py +1 -1
  13. returnn/frontend/array_.py +48 -2
  14. returnn/frontend/assert_.py +35 -0
  15. returnn/frontend/attention.py +54 -20
  16. returnn/frontend/conv.py +273 -54
  17. returnn/frontend/device.py +14 -1
  18. returnn/frontend/encoder/conformer.py +20 -0
  19. returnn/frontend/encoder/transformer.py +2 -0
  20. returnn/frontend/loss.py +222 -3
  21. returnn/frontend/math_.py +54 -14
  22. returnn/native_op.cpp +182 -172
  23. returnn/native_op.py +36 -31
  24. returnn/sprint/cache.py +12 -13
  25. returnn/tensor/_dim_extra.py +7 -7
  26. returnn/tensor/_tensor_extra.py +10 -10
  27. returnn/tensor/utils.py +8 -5
  28. returnn/tf/frontend_layers/_backend.py +7 -3
  29. returnn/tf/layers/basic.py +27 -40
  30. returnn/tf/native_op.py +27 -63
  31. returnn/tf/network.py +1 -1
  32. returnn/tf/util/basic.py +22 -197
  33. returnn/torch/engine.py +157 -6
  34. returnn/torch/frontend/_backend.py +280 -29
  35. returnn/torch/frontend/bridge.py +61 -0
  36. returnn/torch/frontend/compile_helper.py +106 -0
  37. returnn/torch/util/array_.py +30 -0
  38. returnn/torch/util/assert_.py +122 -0
  39. returnn/torch/util/exception_helper.py +7 -1
  40. returnn/torch/util/native_op.py +885 -0
  41. returnn/torch/util/native_op_code_compiler.py +308 -0
  42. returnn/util/basic.py +6 -7
  43. returnn/util/better_exchook.py +4 -0
  44. returnn/util/cuda_env.py +332 -0
  45. returnn/util/debug.py +12 -2
  46. returnn/util/file_cache.py +15 -1
  47. returnn/util/fsa.py +17 -13
  48. returnn/util/native_code_compiler.py +104 -47
  49. returnn/util/task_system.py +1 -1
  50. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/METADATA +2 -2
  51. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/RECORD +54 -48
  52. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/WHEEL +1 -1
  53. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/LICENSE +0 -0
  54. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/top_level.txt +0 -0
returnn/torch/util/assert_.py (new file)
@@ -0,0 +1,122 @@
+"""
+Async device assertion utility.
+"""
+
+from __future__ import annotations
+
+import threading
+from textwrap import dedent
+from queue import Queue
+import torch
+
+
+def assert_(cond: torch.Tensor, message: str):
+    """
+    Does a device-side assertion.
+    For CPU, this will directly check the condition and raise an error if false.
+    For CUDA devices, this runs asynchronously on a separate thread (to avoid pin_memory in the current thread),
+    and non-blocking (does not trigger a CUDA sync).
+    """
+    if cond.device.type == "cpu":
+        if not cond.item():
+            raise AssertionError(message)
+        return
+    elif cond.device.type == "cuda":
+        # This triggers the lazy initialization on first call
+        _CudaAsyncWorker().push(cond, message)
+    else:
+        raise NotImplementedError(f"assert_ not implemented for device type: {cond.device.type}")
+
+
+def _get_ext():
+    global _ext
+    if _ext:
+        return _ext
+
+    from .native_op_code_compiler import OpCodeCompiler
+
+    compiler = OpCodeCompiler(
+        "async_assert_ext", use_cuda_if_available=True, code=_cpp_source + _cuda_source, is_python_module=True
+    )
+    _ext = compiler.load_module()
+    return _ext
+
+
+_ext = None
+
+_cpp_source = dedent("""\
+    #include <torch/extension.h>
+
+    void async_assert_cuda(const at::Tensor& cond, const at::Tensor& msg_tensor);
+
+    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+      m.def("async_assert_cuda", torch::wrap_pybind_function(async_assert_cuda), "Asynchronous CUDA assert");
+    }
+    """)

+_cuda_source = dedent("""\
+    #include <torch/types.h>
+    #include <cuda.h>
+    #include <cuda_runtime.h>
+    #include <torch/extension.h>
+    #include <ATen/cuda/CUDAContext.h>
+    #include <c10/cuda/CUDACachingAllocator.h>
+    #include <assert.h>
+
+    __global__ void assert_kernel(const bool* cond, const char* msg) {
+      if (blockIdx.x == 0 && threadIdx.x == 0) {
+        if (!(*cond)) {
+          printf("\\n[GPU ASSERT FAILED]: %s\\n", msg);
+          assert(false);
+        }
+      }
+    }
+
+    void async_assert_cuda(const at::Tensor& cond, const at::Tensor& msg_tensor) {
+      auto stream = at::cuda::getCurrentCUDAStream();
+
+      // Safety: Protect memory from GC while the kernel is in flight
+      c10::cuda::CUDACachingAllocator::recordStream(cond.storage().data_ptr(), stream);
+      c10::cuda::CUDACachingAllocator::recordStream(msg_tensor.storage().data_ptr(), stream);
+
+      assert_kernel<<<1, 1, 0, stream>>>(
+        cond.data_ptr<bool>(),
+        (const char*)msg_tensor.data_ptr<uint8_t>()
+      );
+    }
+    """)
+
+
+class _CudaAsyncWorker:
+    _instance = None
+    _lock = threading.Lock()
+
+    def __new__(cls):
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super(_CudaAsyncWorker, cls).__new__(cls)
+                cls._instance._init_worker()
+            return cls._instance
+
+    def _init_worker(self):
+        self.queue = Queue()
+        self.thread = threading.Thread(target=self._loop, daemon=True)
+        self.thread.start()
+
+    def _loop(self):
+        while True:
+            cond, message_str, stream = self.queue.get()
+
+            # Use the actual Stream object context
+            with torch.cuda.stream(stream):
+                # Convert string to pinned tensor (avoiding read-only NP view)
+                msg_bytes = list(message_str.encode("utf-8"))
+                msg_cpu = torch.tensor(msg_bytes, dtype=torch.uint8, pin_memory=True)
+                msg_gpu = msg_cpu.to("cuda", non_blocking=True)
+
+                # Call JIT-compiled function
+                _get_ext().async_assert_cuda(cond, msg_gpu)
+
+    def push(self, cond: torch.Tensor, message: str):
+        """push to queue"""
+        self.queue.put((cond, message, torch.cuda.current_stream()))
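
For orientation, here is a minimal usage sketch of the new helper (not part of the diff; the tensor, shape, and message are made up for illustration):

    import torch
    from returnn.torch.util.assert_ import assert_

    x = torch.randn(8, 16, device="cuda" if torch.cuda.is_available() else "cpu")
    # On CPU, assert_ evaluates the condition immediately and raises AssertionError if it is false.
    # On CUDA, the check is queued on a background worker thread and runs device-side,
    # without forcing a CUDA synchronization in the calling thread.
    assert_((x == x).all(), "x contains NaN values")

If the condition is false on a CUDA device, the device-side assert in the kernel fires, so a later CUDA call in the program reports a device-side assertion failure rather than the `assert_` call itself raising.
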
returnn/torch/util/exception_helper.py
@@ -71,7 +71,13 @@ def help_on_torch_exception(
     if not count_frames:
         exc_ext.append("(No module call frames.)")
 
-    if len(exc.args) == 1 and isinstance(exc.args[0], str) and not always_direct_print:
+    if (
+        # KeyError formatting would be wrong, showing `KeyError: "enc_spatial_dim\n\nStep idx: 0\..."`
+        not isinstance(exc, KeyError)
+        and len(exc.args) == 1
+        and isinstance(exc.args[0], str)
+        and not always_direct_print
+    ):
         exc.args = ("\n".join([exc.args[0], ""] + exc_ext),)
     else:
         for msg in exc_ext:
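
For context on why this change special-cases KeyError: Python formats a single-argument KeyError using repr() of that argument, so multi-line context appended to exc.args[0] would be displayed as literal \n escapes instead of line breaks. A small illustrative snippet (not part of the diff; the key name is taken from the comment above):

    try:
        raise KeyError("enc_spatial_dim" + "\n\nStep idx: 0")
    except KeyError as exc:
        print(f"{type(exc).__name__}: {exc}")
    # prints: KeyError: 'enc_spatial_dim\n\nStep idx: 0'   (escapes shown literally, not as line breaks)

Excluding KeyError keeps the extra hints on the else branch, where they are printed directly instead of being folded into the exception message.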