warp-lang 1.5.1-py3-none-macosx_10_13_universal2.whl → 1.6.1-py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of warp-lang might be problematic.

Files changed (131)
  1. warp/__init__.py +5 -0
  2. warp/autograd.py +414 -191
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +40 -12
  6. warp/build_dll.py +13 -6
  7. warp/builtins.py +1077 -481
  8. warp/codegen.py +250 -122
  9. warp/config.py +65 -21
  10. warp/context.py +500 -149
  11. warp/examples/assets/square_cloth.usd +0 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +27 -18
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
  14. warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
  15. warp/examples/core/example_marching_cubes.py +1 -1
  16. warp/examples/core/example_mesh.py +1 -1
  17. warp/examples/core/example_torch.py +18 -34
  18. warp/examples/core/example_wave.py +1 -1
  19. warp/examples/fem/example_apic_fluid.py +1 -0
  20. warp/examples/fem/example_mixed_elasticity.py +1 -1
  21. warp/examples/optim/example_bounce.py +1 -1
  22. warp/examples/optim/example_cloth_throw.py +1 -1
  23. warp/examples/optim/example_diffray.py +4 -15
  24. warp/examples/optim/example_drone.py +1 -1
  25. warp/examples/optim/example_softbody_properties.py +392 -0
  26. warp/examples/optim/example_trajectory.py +1 -3
  27. warp/examples/optim/example_walker.py +5 -0
  28. warp/examples/sim/example_cartpole.py +0 -2
  29. warp/examples/sim/example_cloth_self_contact.py +314 -0
  30. warp/examples/sim/example_granular_collision_sdf.py +4 -5
  31. warp/examples/sim/example_jacobian_ik.py +0 -2
  32. warp/examples/sim/example_quadruped.py +5 -2
  33. warp/examples/tile/example_tile_cholesky.py +79 -0
  34. warp/examples/tile/example_tile_convolution.py +2 -2
  35. warp/examples/tile/example_tile_fft.py +2 -2
  36. warp/examples/tile/example_tile_filtering.py +3 -3
  37. warp/examples/tile/example_tile_matmul.py +4 -4
  38. warp/examples/tile/example_tile_mlp.py +12 -12
  39. warp/examples/tile/example_tile_nbody.py +191 -0
  40. warp/examples/tile/example_tile_walker.py +319 -0
  41. warp/math.py +147 -0
  42. warp/native/array.h +12 -0
  43. warp/native/builtin.h +0 -1
  44. warp/native/bvh.cpp +149 -70
  45. warp/native/bvh.cu +287 -68
  46. warp/native/bvh.h +195 -85
  47. warp/native/clang/clang.cpp +6 -2
  48. warp/native/crt.h +1 -0
  49. warp/native/cuda_util.cpp +35 -0
  50. warp/native/cuda_util.h +5 -0
  51. warp/native/exports.h +40 -40
  52. warp/native/intersect.h +17 -0
  53. warp/native/mat.h +57 -3
  54. warp/native/mathdx.cpp +19 -0
  55. warp/native/mesh.cpp +25 -8
  56. warp/native/mesh.cu +153 -101
  57. warp/native/mesh.h +482 -403
  58. warp/native/quat.h +40 -0
  59. warp/native/solid_angle.h +7 -0
  60. warp/native/sort.cpp +85 -0
  61. warp/native/sort.cu +34 -0
  62. warp/native/sort.h +3 -1
  63. warp/native/spatial.h +11 -0
  64. warp/native/tile.h +1189 -664
  65. warp/native/tile_reduce.h +8 -6
  66. warp/native/vec.h +41 -0
  67. warp/native/warp.cpp +8 -1
  68. warp/native/warp.cu +263 -40
  69. warp/native/warp.h +19 -5
  70. warp/optim/linear.py +22 -4
  71. warp/render/render_opengl.py +132 -59
  72. warp/render/render_usd.py +10 -2
  73. warp/sim/__init__.py +6 -1
  74. warp/sim/collide.py +289 -32
  75. warp/sim/import_urdf.py +20 -5
  76. warp/sim/integrator_euler.py +25 -7
  77. warp/sim/integrator_featherstone.py +147 -35
  78. warp/sim/integrator_vbd.py +842 -40
  79. warp/sim/model.py +173 -112
  80. warp/sim/render.py +2 -2
  81. warp/stubs.py +249 -116
  82. warp/tape.py +28 -30
  83. warp/tests/aux_test_module_unload.py +15 -0
  84. warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
  85. warp/tests/test_array.py +100 -0
  86. warp/tests/test_assert.py +242 -0
  87. warp/tests/test_codegen.py +14 -61
  88. warp/tests/test_collision.py +8 -8
  89. warp/tests/test_examples.py +16 -1
  90. warp/tests/test_grad_debug.py +87 -2
  91. warp/tests/test_hash_grid.py +1 -1
  92. warp/tests/test_ipc.py +116 -0
  93. warp/tests/test_launch.py +77 -26
  94. warp/tests/test_mat.py +213 -168
  95. warp/tests/test_math.py +47 -1
  96. warp/tests/test_matmul.py +11 -7
  97. warp/tests/test_matmul_lite.py +4 -4
  98. warp/tests/test_mesh.py +84 -60
  99. warp/tests/test_mesh_query_aabb.py +165 -0
  100. warp/tests/test_mesh_query_point.py +328 -286
  101. warp/tests/test_mesh_query_ray.py +134 -121
  102. warp/tests/test_mlp.py +2 -2
  103. warp/tests/test_operators.py +43 -0
  104. warp/tests/test_overwrite.py +6 -5
  105. warp/tests/test_quat.py +77 -0
  106. warp/tests/test_reload.py +29 -0
  107. warp/tests/test_sim_grad_bounce_linear.py +204 -0
  108. warp/tests/test_static.py +16 -0
  109. warp/tests/test_tape.py +25 -0
  110. warp/tests/test_tile.py +134 -191
  111. warp/tests/test_tile_load.py +399 -0
  112. warp/tests/test_tile_mathdx.py +61 -8
  113. warp/tests/test_tile_mlp.py +17 -17
  114. warp/tests/test_tile_reduce.py +24 -18
  115. warp/tests/test_tile_shared_memory.py +66 -17
  116. warp/tests/test_tile_view.py +165 -0
  117. warp/tests/test_torch.py +35 -0
  118. warp/tests/test_utils.py +36 -24
  119. warp/tests/test_vec.py +110 -0
  120. warp/tests/unittest_suites.py +29 -4
  121. warp/tests/unittest_utils.py +30 -11
  122. warp/thirdparty/unittest_parallel.py +5 -2
  123. warp/types.py +419 -111
  124. warp/utils.py +9 -5
  125. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/METADATA +86 -45
  126. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/RECORD +129 -118
  127. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/WHEEL +1 -1
  128. warp/examples/benchmarks/benchmark_tile.py +0 -179
  129. warp/native/tile_gemm.h +0 -341
  130. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/LICENSE.md +0 -0
  131. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/top_level.txt +0 -0
warp/native/tile_reduce.h CHANGED
@@ -86,7 +86,7 @@ auto tile_reduce_impl(Op f, Tile& t)
     using T = typename Tile::Type;
 
     auto input = t.copy_to_register();
-    auto output = tile_register_t<T, 1, 1>();
+    auto output = tile_register_t<T, tile_layout_register_t<tile_shape_t<1>>>();
 
     const int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1)/WP_TILE_WARP_SIZE;
     const int warp_index = threadIdx.x/WP_TILE_WARP_SIZE;
@@ -94,19 +94,21 @@ auto tile_reduce_impl(Op f, Tile& t)
 
     T thread_sum = input.data[0];
 
+    using Layout = typename decltype(input)::Layout;
+
     // thread reduction
     WP_PRAGMA_UNROLL
-    for (int i=1; i < input.NumRegs; ++i)
+    for (int i=1; i < Layout::NumRegs; ++i)
     {
-        int linear = t.index(i);
-        if (!Tile::Aligned && linear >= Tile::Size)
+        int linear = Layout::linear_from_register(i);
+        if (!Layout::valid(linear))
            break;
 
        thread_sum = f(thread_sum, input.data[i]);
    }
 
    // ensure that only threads with at least one valid item participate in the reduction
-   unsigned int mask = __ballot_sync(__activemask(), t.index(0) < Tile::Size);
+   unsigned int mask = __ballot_sync(__activemask(), Layout::valid(Layout::linear_from_register(0)));
 
    // warp reduction
    T warp_sum = warp_reduce(thread_sum, f, mask);
@@ -177,7 +179,7 @@ void adj_tile_sum(Tile& t, Tile& adj_t, AdjTile& adj_ret)
     WP_TILE_SYNC();
 
     // broadcast scalar across input dimensions (note zero strides)
-    auto adj_ret_reg = tile_shared_t<T, Tile::M, Tile::N, 0, 0>(&scratch, NULL).copy_to_register();
+    auto adj_ret_reg = tile_shared_t<T, tile_layout_strided_t<typename Tile::Layout::Shape, tile_stride_t<0, 0>>>(&scratch, NULL).copy_to_register();
     adj_t.grad_add(adj_ret_reg);
 }
 
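Note on the `warp_reduce(thread_sum, f, mask)` call: its definition is outside this hunk, but a shuffle-based warp reduction is the standard shape for it. A minimal CUDA sketch under that assumption (illustrative only, not Warp's actual implementation):

template <typename T, typename Op>
__device__ T warp_reduce_sketch(T value, Op f, unsigned int mask)
{
    // tree reduction over the lanes named in `mask`; after log2(32)
    // steps, lane 0 of the warp holds the combined value
    for (int offset = 16; offset > 0; offset /= 2)
        value = f(value, __shfl_down_sync(mask, value, offset));
    return value;
}

The `mask` computed above via `__ballot_sync` is what keeps threads whose registers hold no valid tile elements out of the shuffle exchange.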
warp/native/vec.h CHANGED
@@ -495,6 +495,37 @@ inline CUDA_CALLABLE void adj_indexref(vec_t<Length, Type>* v, int idx,
     // nop
 }
 
+
+template<unsigned Length, typename Type>
+inline CUDA_CALLABLE void augassign_add(vec_t<Length, Type>& v, int idx, Type value)
+{
+    v[idx] += value;
+}
+
+
+template<unsigned Length, typename Type>
+inline CUDA_CALLABLE void adj_augassign_add(vec_t<Length, Type>& v, int idx, Type value,
+                                            vec_t<Length, Type>& adj_v, int adj_idx, Type& adj_value)
+{
+    adj_value += adj_v[idx];
+}
+
+
+template<unsigned Length, typename Type>
+inline CUDA_CALLABLE void augassign_sub(vec_t<Length, Type>& v, int idx, Type value)
+{
+    v[idx] -= value;
+}
+
+
+template<unsigned Length, typename Type>
+inline CUDA_CALLABLE void adj_augassign_sub(vec_t<Length, Type>& v, int idx, Type value,
+                                            vec_t<Length, Type>& adj_v, int adj_idx, Type& adj_value)
+{
+    adj_value -= adj_v[idx];
+}
+
+
 template<unsigned Length, typename Type>
 inline CUDA_CALLABLE vec_t<Length, Type> assign(vec_t<Length, Type>& v, int idx, Type value)
 {
@@ -1311,5 +1342,15 @@ inline CUDA_CALLABLE void adj_vec4(float s, float& adj_s, const vec4& adj_ret)
     adj_vec_t(s, adj_s, adj_ret);
 }
 
+template<unsigned Length, typename Type>
+CUDA_CALLABLE inline int len(const vec_t<Length, Type>& x)
+{
+    return Length;
+}
+
+template<unsigned Length, typename Type>
+CUDA_CALLABLE inline void adj_len(const vec_t<Length, Type>& x, vec_t<Length, Type>& adj_x, const int& adj_ret)
+{
+}
 
 } // namespace wp
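The new `augassign_add`/`adj_augassign_add` pair is the forward/reverse implementation of `v[i] += x` inside kernels: the forward op mutates one component, and the adjoint routes whatever gradient sits at that component back into the right-hand side (the new `len`/`adj_len` pair is simpler still, since the length carries no gradient). A standalone C++ sketch of the same contract, with a plain struct in place of Warp's `vec_t` and a slightly simplified signature (illustrative only):

#include <cstdio>

struct vec3 { float c[3]; float& operator[](int i) { return c[i]; } };

// forward: v[idx] += value
void augassign_add(vec3& v, int idx, float value) { v[idx] += value; }

// reverse: d(loss)/d(value) accumulates the gradient stored at v[idx],
// mirroring adj_augassign_add in vec.h
void adj_augassign_add(vec3& v, int idx, float value, vec3& adj_v, float& adj_value)
{
    adj_value += adj_v[idx];
}

int main()
{
    vec3 v = {{1.0f, 2.0f, 3.0f}};
    augassign_add(v, 1, 0.5f);             // v[1] == 2.5f

    vec3 adj_v = {{0.0f, 4.0f, 0.0f}};     // incoming gradient w.r.t. v
    float adj_value = 0.0f;
    adj_augassign_add(v, 1, 0.5f, adj_v, adj_value);
    printf("%f %f\n", v[1], adj_value);    // 2.500000 4.000000
}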
warp/native/warp.cpp CHANGED
@@ -992,6 +992,7 @@ WP_API int cuda_device_get_pci_bus_id(int ordinal) { return -1; }
 WP_API int cuda_device_get_pci_device_id(int ordinal) { return -1; }
 WP_API int cuda_device_is_uva(int ordinal) { return 0; }
 WP_API int cuda_device_is_mempool_supported(int ordinal) { return 0; }
+WP_API int cuda_device_is_ipc_supported(int ordinal) { return 0; }
 WP_API int cuda_device_set_mempool_release_threshold(int ordinal, uint64_t threshold) { return 0; }
 WP_API uint64_t cuda_device_get_mempool_release_threshold(int ordinal) { return 0; }
 WP_API void cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_mem) {}
@@ -1015,6 +1016,12 @@ WP_API int cuda_set_peer_access_enabled(void* target_context, void* peer_context
 WP_API int cuda_is_mempool_access_enabled(int target_ordinal, int peer_ordinal) { return 0; }
 WP_API int cuda_set_mempool_access_enabled(int target_ordinal, int peer_ordinal, int enable) { return 0; }
 
+WP_API void cuda_ipc_get_mem_handle(void* ptr, char* out_buffer) {}
+WP_API void* cuda_ipc_open_mem_handle(void* context, char* handle) { return NULL; }
+WP_API void cuda_ipc_close_mem_handle(void* ptr) {}
+WP_API void cuda_ipc_get_event_handle(void* context, void* event, char* out_buffer) {}
+WP_API void* cuda_ipc_open_event_handle(void* context, char* handle) { return NULL; }
+
 WP_API void* cuda_stream_create(void* context, int priority) { return NULL; }
 WP_API void cuda_stream_destroy(void* context, void* stream) {}
 WP_API void cuda_stream_register(void* context, void* stream) {}
@@ -1038,7 +1045,7 @@ WP_API bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret
 WP_API bool cuda_graph_launch(void* graph, void* stream) { return false; }
 WP_API bool cuda_graph_destroy(void* context, void* graph) { return false; }
 
-WP_API size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes) { return 0; }
+WP_API size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes, int* ltoir_input_types) { return 0; }
 
 WP_API void* cuda_load_module(void* context, const char* ptx) { return NULL; }
 WP_API void cuda_unload_module(void* context, void* module) {}
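These stubs back CPU-only builds, where every CUDA query degrades to a "not supported" default. A sketch of the intended call-site pattern (`device_ptr` is a placeholder for a pointer returned by the CUDA allocator):

// On a CPU-only build cuda_device_is_ipc_supported() always returns 0,
// so the IPC export path is skipped cleanly.
char handle_buffer[64];  // CUDA IPC handles are 64 opaque bytes
if (cuda_device_is_ipc_supported(/*ordinal*/ 0))
    cuda_ipc_get_mem_handle(device_ptr, handle_buffer);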
warp/native/warp.cu CHANGED
@@ -36,6 +36,7 @@
 #define check_nvjitlink(handle, code) (check_nvjitlink_result(handle, code, __FILE__, __LINE__))
 #define check_cufftdx(code) (check_cufftdx_result(code, __FILE__, __LINE__))
 #define check_cublasdx(code) (check_cublasdx_result(code, __FILE__, __LINE__))
+#define check_cusolver(code) (check_cusolver_result(code, __FILE__, __LINE__))
 #define CHECK_ANY(code) \
 { \
     do { \
@@ -63,6 +64,15 @@
         } \
     } while(0); \
 }
+#define CHECK_CUSOLVER(code) \
+{ \
+    do { \
+        bool out = (check_cusolver(code)); \
+        if(!out) { \
+            return out; \
+        } \
+    } while(0); \
+}
 
 bool check_nvrtc_result(nvrtcResult result, const char* file, int line)
 {
@@ -136,6 +146,7 @@ struct DeviceInfo
     int arch = 0;
     int is_uva = 0;
     int is_mempool_supported = 0;
+    int is_ipc_supported = -1;
     int max_smem_bytes = 0;
     CUcontext primary_context = NULL;
 };
@@ -187,6 +198,13 @@ struct FreeInfo
     bool is_async = false;
 };
 
+// Information used when deferring module unloading.
+struct ModuleInfo
+{
+    void* context = NULL;
+    void* module = NULL;
+};
+
 static std::unordered_map<CUfunction, std::string> g_kernel_names;
 
 // cached info for all devices, indexed by ordinal
@@ -214,6 +232,9 @@ static std::unordered_map<void*, GraphAllocInfo> g_graph_allocs;
 // Call free_deferred_allocs() to release.
 static std::vector<FreeInfo> g_deferred_free_list;
 
+// Modules that cannot be unloaded immediately get queued here.
+// Call unload_deferred_modules() to release.
+static std::vector<ModuleInfo> g_deferred_module_list;
 
 void cuda_set_context_restore_policy(bool always_restore)
 {
@@ -250,6 +271,21 @@ int cuda_init()
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_device_id, CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, device));
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_uva, CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device));
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_mempool_supported, CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, device));
+#ifdef CUDA_VERSION
+#if CUDA_VERSION >= 12000
+        int device_attribute_integrated = 0;
+        check_cu(cuDeviceGetAttribute_f(&device_attribute_integrated, CU_DEVICE_ATTRIBUTE_INTEGRATED, device));
+        if (device_attribute_integrated == 0)
+        {
+            check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_ipc_supported, CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED, device));
+        }
+        else
+        {
+            // integrated devices do not support CUDA IPC
+            g_devices[i].is_ipc_supported = 0;
+        }
+#endif
+#endif
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].max_smem_bytes, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device));
         int major = 0;
         int minor = 0;
@@ -410,6 +446,31 @@ static int free_deferred_allocs(void* context = NULL)
     return num_freed_allocs;
 }
 
+static int unload_deferred_modules(void* context = NULL)
+{
+    if (g_deferred_module_list.empty() || !g_captures.empty())
+        return 0;
+
+    int num_unloaded_modules = 0;
+    for (auto it = g_deferred_module_list.begin(); it != g_deferred_module_list.end(); /*noop*/)
+    {
+        // free the module if it matches the given context or if the context is unspecified
+        const ModuleInfo& module_info = *it;
+        if (module_info.context == context || !context)
+        {
+            cuda_unload_module(module_info.context, module_info.module);
+            ++num_unloaded_modules;
+            it = g_deferred_module_list.erase(it);
+        }
+        else
+        {
+            ++it;
+        }
+    }
+
+    return num_unloaded_modules;
+}
+
 static void CUDART_CB on_graph_destroy(void* user_data)
 {
     if (!user_data)
@@ -1756,6 +1817,13 @@ int cuda_device_is_mempool_supported(int ordinal)
     return 0;
 }
 
+int cuda_device_is_ipc_supported(int ordinal)
+{
+    if (ordinal >= 0 && ordinal < int(g_devices.size()))
+        return g_devices[ordinal].is_ipc_supported;
+    return 0;
+}
+
 int cuda_device_set_mempool_release_threshold(int ordinal, uint64_t threshold)
 {
     if (ordinal < 0 || ordinal > int(g_devices.size()))
@@ -1920,6 +1988,8 @@ void cuda_context_synchronize(void* context)
         check_cu(cuCtxSynchronize_f());
     }
 
+    unload_deferred_modules(context);
+
     // check_cuda(cudaDeviceGraphMemTrim(cuda_context_get_device_ordinal(context)));
 }
 
@@ -2221,6 +2291,52 @@ int cuda_set_mempool_access_enabled(int target_ordinal, int peer_ordinal, int en
     return 1; // success
 }
 
+void cuda_ipc_get_mem_handle(void* ptr, char* out_buffer) {
+    CUipcMemHandle memHandle;
+    check_cu(cuIpcGetMemHandle_f(&memHandle, (CUdeviceptr)ptr));
+    memcpy(out_buffer, memHandle.reserved, CU_IPC_HANDLE_SIZE);
+}
+
+void* cuda_ipc_open_mem_handle(void* context, char* handle) {
+    ContextGuard guard(context);
+
+    CUipcMemHandle memHandle;
+    memcpy(memHandle.reserved, handle, CU_IPC_HANDLE_SIZE);
+
+    CUdeviceptr device_ptr;
+
+    // Strangely, the CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS flag is required
+    if (check_cu(cuIpcOpenMemHandle_f(&device_ptr, memHandle, CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS)))
+        return (void*) device_ptr;
+    else
+        return NULL;
+}
+
+void cuda_ipc_close_mem_handle(void* ptr) {
+    check_cu(cuIpcCloseMemHandle_f((CUdeviceptr) ptr));
+}
+
+void cuda_ipc_get_event_handle(void* context, void* event, char* out_buffer) {
+    ContextGuard guard(context);
+
+    CUipcEventHandle eventHandle;
+    check_cu(cuIpcGetEventHandle_f(&eventHandle, static_cast<CUevent>(event)));
+    memcpy(out_buffer, eventHandle.reserved, CU_IPC_HANDLE_SIZE);
+}
+
+void* cuda_ipc_open_event_handle(void* context, char* handle) {
+    ContextGuard guard(context);
+
+    CUipcEventHandle eventHandle;
+    memcpy(eventHandle.reserved, handle, CU_IPC_HANDLE_SIZE);
+
+    CUevent event;
+
+    if (check_cu(cuIpcOpenEventHandle_f(&event, eventHandle)))
+        return event;
+    else
+        return NULL;
+}
 
 void* cuda_stream_create(void* context, int priority)
 {
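The five wrappers above are thin shims over the CUDA driver's IPC API. A minimal two-process sketch using the raw driver calls (error handling elided; shipping the 64 handle bytes between processes over a pipe or socket is out of scope here):

// --- exporting process ---
CUdeviceptr buf;
cuMemAlloc(&buf, 1 << 20);         // IPC requires cuMemAlloc'd device memory
CUipcMemHandle h;
cuIpcGetMemHandle(&h, buf);        // h.reserved holds 64 opaque bytes
// ...send h.reserved to the peer process

// --- importing process ---
CUipcMemHandle h2;                 // ...filled from the received bytes
CUdeviceptr mapped;
cuIpcOpenMemHandle(&mapped, h2, CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS);
// read/write through `mapped`, then release the mapping:
cuIpcCloseMemHandle(mapped);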
@@ -2542,7 +2658,10 @@ bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret
 
     // process deferred free list if no more captures are ongoing
     if (g_captures.empty())
+    {
         free_deferred_allocs();
+        unload_deferred_modules();
+    }
 
     if (graph_ret)
         *graph_ret = graph_exec;
@@ -2614,7 +2733,7 @@ bool write_file(const char* data, size_t size, std::string filename, const char*
 }
 #endif
 
-size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes)
+size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes, int* ltoir_input_types)
 {
     // use file extension to determine whether to output PTX or CUBIN
     const char* output_ext = strrchr(output_path, '.');
@@ -2675,8 +2794,13 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
         //opts.push_back("--device-debug");
     }
     else
+    {
         opts.push_back("--define-macro=NDEBUG");
 
+        if (lineinfo)
+            opts.push_back("--generate-line-info");
+    }
+
     if (verify_fp)
         opts.push_back("--define-macro=WP_VERIFY_FP");
     else
@@ -2685,6 +2809,11 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
     if (fast_math)
         opts.push_back("--use_fast_math");
 
+    if (fuse_fp)
+        opts.push_back("--fmad=true");
+    else
+        opts.push_back("--fmad=false");
+
     char include_cutlass[max_path];
     sprintf(include_cutlass, "--include-path=%s/cutlass/include", include_dir);
     opts.push_back(include_cutlass);
@@ -2712,7 +2841,7 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
     res = nvrtcCreateProgram(
         &prog,         // prog
         cuda_src,      // buffer
-        NULL,          // name
+        program_name,  // name
         0,             // numHeaders
         NULL,          // headers
         NULL);         // includeNames
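Passing a real `program_name` instead of `NULL` mainly improves NVRTC diagnostics: the name shows up in compiler logs and in the line info emitted by the new `lineinfo` path. A minimal standalone sketch of the same NVRTC options (illustrative, not Warp's full option set; `cuda_src` is a placeholder):

#include <nvrtc.h>
#include <vector>

void compile_sketch()
{
    const char* cuda_src = "__global__ void k() {}";

    nvrtcProgram prog;
    nvrtcCreateProgram(&prog, cuda_src, "my_module.cu", 0, NULL, NULL);

    const char* opts[] = {
        "--gpu-architecture=compute_86",
        "--generate-line-info",  // what the new `lineinfo` flag turns on
        "--fmad=false",          // what `fuse_fp == false` turns on
    };
    nvrtcCompileProgram(prog, 3, opts);

    size_t ptx_size = 0;
    nvrtcGetPTXSize(prog, &ptx_size);
    std::vector<char> ptx(ptx_size);
    nvrtcGetPTX(prog, ptx.data());
    nvrtcDestroyProgram(&prog);
}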
@@ -2793,6 +2922,10 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
     if (num_ltoirs > 0)
     {
 #if WP_ENABLE_MATHDX
+        if(ltoir_input_types == nullptr || ltoirs == nullptr || ltoir_sizes == nullptr) {
+            fprintf(stderr, "Warp error: num_ltoirs > 0 but ltoir_input_types, ltoirs or ltoir_sizes are NULL\n");
+            return size_t(-1);
+        }
         nvJitLinkHandle handle;
         std::vector<const char *> lopts = {"-dlto", arch_opt_lto};
         if (use_ptx) {
@@ -2820,11 +2953,26 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
         }
         for(size_t ltoidx = 0; ltoidx < num_ltoirs; ltoidx++)
         {
+            nvJitLinkInputType input_type = static_cast<nvJitLinkInputType>(ltoir_input_types[ltoidx]);
+            const char* ext = ".unknown";
+            switch(input_type) {
+                case NVJITLINK_INPUT_CUBIN:
+                    ext = ".cubin";
+                    break;
+                case NVJITLINK_INPUT_LTOIR:
+                    ext = ".ltoir";
+                    break;
+                case NVJITLINK_INPUT_FATBIN:
+                    ext = ".fatbin";
+                    break;
+                default:
+                    break;
+            }
             if(std::getenv("WARP_DUMP_LTOIR"))
             {
-                write_file(ltoirs[ltoidx], ltoir_sizes[ltoidx], std::string("lto_online_") + std::to_string(ltoidx) + ".ltoir", "wb");
+                write_file(ltoirs[ltoidx], ltoir_sizes[ltoidx], std::string("lto_online_") + std::to_string(ltoidx) + ext, "wb");
             }
-            if(!check_nvjitlink(handle, nvJitLinkAddData(handle, NVJITLINK_INPUT_LTOIR, ltoirs[ltoidx], ltoir_sizes[ltoidx], "lto_online"))) // External LTOIR
+            if(!check_nvjitlink(handle, nvJitLinkAddData(handle, input_type, ltoirs[ltoidx], ltoir_sizes[ltoidx], "lto_online"))) // External LTOIR
             {
                 res = nvrtcResult(-1);
             }
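With a per-input `nvJitLinkInputType` the same loop can now link LTOIR, CUBIN, and FATBIN blobs in one pass. For reference, a minimal standalone nvJitLink flow with one LTOIR input (CUDA 12+; error handling elided; `lto_blob`/`lto_size` are placeholders for one of the `ltoirs[]` entries):

#include <nvJitLink.h>
#include <vector>

void link_sketch(const void* lto_blob, size_t lto_size)
{
    const char* opts[] = {"-dlto", "-arch=sm_86"};
    nvJitLinkHandle h;
    nvJitLinkCreate(&h, 2, opts);

    // the input type can now differ per blob: _LTOIR, _CUBIN, or _FATBIN
    nvJitLinkAddData(h, NVJITLINK_INPUT_LTOIR, lto_blob, lto_size, "lto_online");
    nvJitLinkComplete(h);

    size_t cubin_size = 0;
    nvJitLinkGetLinkedCubinSize(h, &cubin_size);
    std::vector<char> cubin(cubin_size);
    nvJitLinkGetLinkedCubin(h, cubin.data());
    nvJitLinkDestroy(&h);
}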
@@ -2871,9 +3019,9 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
 }
 
 #if WP_ENABLE_MATHDX
-bool check_cufftdx_result(commonDxStatusType result, const char* file, int line)
+bool check_cufftdx_result(commondxStatusType result, const char* file, int line)
 {
-    if (result != commonDxStatusType::COMMONDX_SUCCESS) {
+    if (result != commondxStatusType::COMMONDX_SUCCESS) {
         fprintf(stderr, "libmathdx cuFFTDx error: %d on %s:%d\n", (int)result, file, line);
         return false;
     } else {
@@ -2881,9 +3029,9 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
     }
 }
 
-bool check_cublasdx_result(commonDxStatusType result, const char* file, int line)
+bool check_cublasdx_result(commondxStatusType result, const char* file, int line)
 {
-    if (result != commonDxStatusType::COMMONDX_SUCCESS) {
+    if (result != commondxStatusType::COMMONDX_SUCCESS) {
         fprintf(stderr, "libmathdx cuBLASDx error: %d on %s:%d\n", (int)result, file, line);
         return false;
     } else {
@@ -2891,6 +3039,16 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
     }
 }
 
+bool check_cusolver_result(commondxStatusType result, const char* file, int line)
+{
+    if (result != commondxStatusType::COMMONDX_SUCCESS) {
+        fprintf(stderr, "libmathdx cuSOLVER error: %d on %s:%d\n", (int)result, file, line);
+        return false;
+    } else {
+        return true;
+    }
+}
+
 bool cuda_compile_fft(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int size, int elements_per_thread, int direction, int precision, int* shared_memory_size)
 {
 
@@ -2904,35 +3062,35 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
 
     bool res = true;
     cufftdxHandle h;
-    CHECK_CUFFTDX(cufftDxCreate(&h));
+    CHECK_CUFFTDX(cufftdxCreate(&h));
 
     // CUFFTDX_API_BLOCK_LMEM means each thread starts with a subset of the data
-    CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_API, cufftDxApi::CUFFTDX_API_BLOCK_LMEM));
-    CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_EXECUTION, commonDxExecution::COMMONDX_EXECUTION_BLOCK));
-    CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_SIZE, (long long)size));
-    CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_DIRECTION, (cufftDxDirection)direction));
-    CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_PRECISION, (commonDxPrecision)precision));
-    CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_SM, (long long)(arch * 10)));
-    CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_ELEMENTS_PER_THREAD, (long long)(elements_per_thread)));
-    CHECK_CUFFTDX(cufftDxSetOperatorInt64(h, cufftDxOperatorType::CUFFTDX_OPERATOR_FFTS_PER_BLOCK, 1));
+    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_API, cufftdxApi::CUFFTDX_API_BLOCK_LMEM));
+    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
+    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_SIZE, (long long)size));
+    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_DIRECTION, (cufftdxDirection)direction));
+    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_PRECISION, (commondxPrecision)precision));
+    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_SM, (long long)(arch * 10)));
+    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_ELEMENTS_PER_THREAD, (long long)(elements_per_thread)));
+    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_FFTS_PER_BLOCK, 1));
 
-    CHECK_CUFFTDX(cufftDxSetOptionStr(h, commonDxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
+    CHECK_CUFFTDX(cufftdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
 
     size_t lto_size = 0;
-    CHECK_CUFFTDX(cufftDxGetLTOIRSize(h, &lto_size));
+    CHECK_CUFFTDX(cufftdxGetLTOIRSize(h, &lto_size));
 
     std::vector<char> lto(lto_size);
-    CHECK_CUFFTDX(cufftDxGetLTOIR(h, lto.size(), lto.data()));
+    CHECK_CUFFTDX(cufftdxGetLTOIR(h, lto.size(), lto.data()));
 
     long long int smem = 0;
-    CHECK_CUFFTDX(cufftDxGetTraitInt64(h, cufftDxTraitType::CUFFTDX_TRAIT_SHARED_MEMORY_SIZE, &smem));
+    CHECK_CUFFTDX(cufftdxGetTraitInt64(h, cufftdxTraitType::CUFFTDX_TRAIT_SHARED_MEMORY_SIZE, &smem));
     *shared_memory_size = (int)smem;
 
     if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
         res = false;
     }
 
-    CHECK_CUFFTDX(cufftDxDestroy(h));
+    CHECK_CUFFTDX(cufftdxDestroy(h));
 
     return res;
 }
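Given the signature shown in this hunk's context, a call site for an offline-compiled FFT tile might look like the following (argument values are illustrative; the integer encodings for direction and precision come from libmathdx and are placeholders here):

int shared_memory_size = 0;
bool ok = cuda_compile_fft(
    "fft_256.ltoir",       // ltoir_output_path
    "wp_fft_256",          // symbol_name
    0, nullptr, nullptr,   // num_include_dirs, include_dirs, mathdx_include_dir
    86,                    // arch (sm_86)
    256,                   // size
    8,                     // elements_per_thread
    0,                     // direction (placeholder enum value)
    1,                     // precision (placeholder enum value)
    &shared_memory_size);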
@@ -2949,38 +3107,92 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
 
     bool res = true;
     cublasdxHandle h;
-    CHECK_CUBLASDX(cublasDxCreate(&h));
+    CHECK_CUBLASDX(cublasdxCreate(&h));
 
-    CHECK_CUBLASDX(cublasDxSetOperatorInt64(h, cublasDxOperatorType::CUBLASDX_OPERATOR_FUNCTION, cublasDxFunction::CUBLASDX_FUNCTION_MM));
-    CHECK_CUBLASDX(cublasDxSetOperatorInt64(h, cublasDxOperatorType::CUBLASDX_OPERATOR_EXECUTION, commonDxExecution::COMMONDX_EXECUTION_BLOCK));
-    CHECK_CUBLASDX(cublasDxSetOperatorInt64(h, cublasDxOperatorType::CUBLASDX_OPERATOR_API, cublasDxApi::CUBLASDX_API_BLOCK_SMEM));
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_FUNCTION, cublasdxFunction::CUBLASDX_FUNCTION_MM));
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_API, cublasdxApi::CUBLASDX_API_BLOCK_SMEM));
     std::array<long long int, 3> precisions = {precision_A, precision_B, precision_C};
-    CHECK_CUBLASDX(cublasDxSetOperatorInt64Array(h, cublasDxOperatorType::CUBLASDX_OPERATOR_PRECISION, 3, precisions.data()));
-    CHECK_CUBLASDX(cublasDxSetOperatorInt64(h, cublasDxOperatorType::CUBLASDX_OPERATOR_SM, (long long)(arch * 10)));
-    CHECK_CUBLASDX(cublasDxSetOperatorInt64(h, cublasDxOperatorType::CUBLASDX_OPERATOR_TYPE, (cublasDxType)type));
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_PRECISION, 3, precisions.data()));
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SM, (long long)(arch * 10)));
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_TYPE, (cublasdxType)type));
     std::array<long long int, 3> block_dim = {num_threads, 1, 1};
-    CHECK_CUBLASDX(cublasDxSetOperatorInt64Array(h, cublasDxOperatorType::CUBLASDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
     std::array<long long int, 3> size = {M, N, K};
-    CHECK_CUBLASDX(cublasDxSetOperatorInt64Array(h, cublasDxOperatorType::CUBLASDX_OPERATOR_SIZE, size.size(), size.data()));
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SIZE, size.size(), size.data()));
     std::array<long long int, 3> arrangement = {arrangement_A, arrangement_B, arrangement_C};
-    CHECK_CUBLASDX(cublasDxSetOperatorInt64Array(h, cublasDxOperatorType::CUBLASDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));
 
-    CHECK_CUBLASDX(cublasDxSetOptionStr(h, commonDxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
+    CHECK_CUBLASDX(cublasdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
 
     size_t lto_size = 0;
-    CHECK_CUBLASDX(cublasDxGetLTOIRSize(h, &lto_size));
+    CHECK_CUBLASDX(cublasdxGetLTOIRSize(h, &lto_size));
 
     std::vector<char> lto(lto_size);
-    CHECK_CUBLASDX(cublasDxGetLTOIR(h, lto.size(), lto.data()));
+    CHECK_CUBLASDX(cublasdxGetLTOIR(h, lto.size(), lto.data()));
 
     if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
         res = false;
     }
 
-    CHECK_CUBLASDX(cublasDxDestroy(h));
+    CHECK_CUBLASDX(cublasdxDestroy(h));
 
     return res;
 }
+
+bool cuda_compile_solver(const char* fatbin_output_path, const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int function, int precision, int fill_mode, int num_threads)
+{
+
+    CHECK_ANY(ltoir_output_path != nullptr);
+    CHECK_ANY(symbol_name != nullptr);
+    CHECK_ANY(mathdx_include_dir == nullptr);
+    CHECK_ANY(num_include_dirs == 0);
+    CHECK_ANY(include_dirs == nullptr);
+
+    bool res = true;
+
+    cusolverHandle h { 0 };
+    CHECK_CUSOLVER(cusolverCreate(&h));
+    long long int size[2] = {M, N};
+    long long int block_dim[3] = {num_threads, 1, 1};
+    CHECK_CUSOLVER(cusolverSetOperatorInt64Array(h, cusolverOperatorType::CUSOLVER_OPERATOR_SIZE, 2, size));
+    CHECK_CUSOLVER(cusolverSetOperatorInt64Array(h, cusolverOperatorType::CUSOLVER_OPERATOR_BLOCK_DIM, 3, block_dim));
+    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_TYPE, cusolverType::CUSOLVER_TYPE_REAL));
+    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_API, cusolverApi::CUSOLVER_API_BLOCK_SMEM));
+    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_FUNCTION, (cusolverFunction)function));
+    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
+    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_PRECISION, (commondxPrecision)precision));
+    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_FILL_MODE, (cusolverFillMode)fill_mode));
+    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_SM, (long long)(arch * 10)));
+
+    CHECK_CUSOLVER(cusolverSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
+
+    size_t lto_size = 0;
+    CHECK_CUSOLVER(cusolverGetLTOIRSize(h, &lto_size));
+
+    std::vector<char> lto(lto_size);
+    CHECK_CUSOLVER(cusolverGetLTOIR(h, lto.size(), lto.data()));
+
+    // This fatbin is universal, ie it is the same for any instantations of a cusolver device function
+    size_t fatbin_size = 0;
+    CHECK_CUSOLVER(cusolverGetUniversalFATBINSize(h, &fatbin_size));
+
+    std::vector<char> fatbin(fatbin_size);
+    CHECK_CUSOLVER(cusolverGetUniversalFATBIN(h, fatbin.size(), fatbin.data()));
+
+    if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
+        res = false;
+    }
+
+    if(!write_file(fatbin.data(), fatbin.size(), fatbin_output_path, "wb")) {
+        res = false;
+    }
+
+    CHECK_CUSOLVER(cusolverDestroy(h));
+
+    return res;
+}
+
 #endif
 
 void* cuda_load_module(void* context, const char* path)
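A hedged call-site sketch for the new `cuda_compile_solver`, following the constraints the function itself asserts (no include dirs, no mathdx dir); the integer encodings for function, precision, and fill mode belong to libmathdx and are placeholders here:

bool ok = cuda_compile_solver(
    "chol.fatbin",         // fatbin_output_path (universal across instantiations)
    "chol_32.ltoir",       // ltoir_output_path
    "wp_cholesky_32",      // symbol_name
    0, nullptr, nullptr,   // num_include_dirs, include_dirs, mathdx_include_dir
    86,                    // arch (sm_86)
    32, 32,                // M, N
    /*function*/ 0, /*precision*/ 1, /*fill_mode*/ 0,
    /*num_threads*/ 128);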
@@ -3104,9 +3316,20 @@ void* cuda_load_module(void* context, const char* path)
 
 void cuda_unload_module(void* context, void* module)
 {
-    ContextGuard guard(context);
-
-    check_cu(cuModuleUnload_f((CUmodule)module));
+    // ensure there are no graph captures in progress
+    if (g_captures.empty())
+    {
+        ContextGuard guard(context);
+        check_cu(cuModuleUnload_f((CUmodule)module));
+    }
+    else
+    {
+        // defer until graph capture completes
+        ModuleInfo module_info;
+        module_info.context = context ? context : get_current_context();
+        module_info.module = module;
+        g_deferred_module_list.push_back(module_info);
+    }
 }
 
 
@@ -3154,7 +3377,7 @@ size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_block
     if (block_dim <= 0)
     {
 #if defined(_DEBUG)
-        fprintf(stderr, "Warp warning: Launch got block_dim %d. Setting to 256.\n", dim, block_dim);
+        fprintf(stderr, "Warp warning: Launch got block_dim %d. Setting to 256.\n", block_dim);
 #endif
         block_dim = 256;
     }