warp-lang 1.0.0b2__py3-none-win_amd64.whl → 1.0.0b6__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the registry's advisory page for more details.

Files changed (271)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.dll +0 -0
  57. warp/bin/warp.dll +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/bin/warp-clang.so +0 -0
  257. warp/bin/warp.so +0 -0
  258. warp/fem/field/discrete_field.py +0 -80
  259. warp/fem/space/nodal_function_space.py +0 -233
  260. warp/tests/test_all.py +0 -223
  261. warp/tests/test_array_scan.py +0 -60
  262. warp/tests/test_base.py +0 -208
  263. warp/tests/test_unresolved_func.py +0 -7
  264. warp/tests/test_unresolved_symbol.py +0 -7
  265. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  266. warp_lang-1.0.0b2.dist-info/RECORD +0 -380
  267. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  268. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  269. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  270. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  271. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/native/sparse.cu CHANGED
@@ -27,40 +27,29 @@ CUDA_CALLABLE uint32_t bsr_get_col(const BsrRowCol &row_col) {
27
27
 
28
28
  // Cached temporary storage
29
29
  struct BsrFromTripletsTemp {
30
- // Temp work buffers
31
- int nnz = 0;
32
- int *block_indices = NULL;
33
-
34
- BsrRowCol *combined_row_col = NULL;
35
-
30
+
31
+ int *count_buffer = NULL;
36
32
  cudaEvent_t host_sync_event = NULL;
37
33
 
38
- void ensure_fits(size_t size) {
39
-
40
- if (size > nnz) {
41
- size = std::max(2 * size, (static_cast<size_t>(nnz) * 3) / 2);
42
-
43
- free_device(WP_CURRENT_CONTEXT, block_indices);
44
- free_device(WP_CURRENT_CONTEXT, combined_row_col);
45
-
46
- // Factor 2 for in / out versions , +1 for count
47
- block_indices = static_cast<int *>(
48
- alloc_device(WP_CURRENT_CONTEXT, (2 * size + 1) * sizeof(int)));
49
- combined_row_col = static_cast<BsrRowCol *>(
50
- alloc_device(WP_CURRENT_CONTEXT, 2 * size * sizeof(BsrRowCol)));
34
+ BsrFromTripletsTemp()
35
+ : count_buffer(static_cast<int*>(alloc_pinned(sizeof(int))))
36
+ {
37
+ cudaEventCreateWithFlags(&host_sync_event, cudaEventDisableTiming);
38
+ }
39
+
40
+ ~BsrFromTripletsTemp()
41
+ {
42
+ cudaEventDestroy(host_sync_event);
43
+ free_pinned(count_buffer);
44
+ }
51
45
 
52
- nnz = size;
53
- }
46
+ BsrFromTripletsTemp(const BsrFromTripletsTemp&) = delete;
47
+ BsrFromTripletsTemp& operator=(const BsrFromTripletsTemp&) = delete;
54
48
 
55
- if (host_sync_event == NULL) {
56
- cudaEventCreateWithFlags(&host_sync_event, cudaEventDisableTiming);
57
- }
58
- }
59
49
  };
60
50
 
61
51
  // map temp buffers to CUDA contexts
62
- static std::unordered_map<void *, BsrFromTripletsTemp>
63
- g_bsr_from_triplets_temp_map;
52
+ static std::unordered_map<void *, BsrFromTripletsTemp> g_bsr_from_triplets_temp_map;
64
53
 
65
54
  template <typename T> struct BsrBlockIsNotZero {
66
55
  int block_size;
@@ -145,24 +134,22 @@ int bsr_matrix_from_triplets_device(const int rows_per_block,
145
134
  const int block_size = rows_per_block * cols_per_block;
146
135
 
147
136
  void *context = cuda_context_get_current();
137
+ ContextGuard guard(context);
148
138
 
149
139
  // Per-context cached temporary buffers
150
- PinnedTemporaryBuffer &pinned_temp = g_pinned_temp_buffer_map[context];
151
140
  BsrFromTripletsTemp &bsr_temp = g_bsr_from_triplets_temp_map[context];
152
141
 
153
- ContextGuard guard(context);
154
-
155
142
  cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
156
- bsr_temp.ensure_fits(nnz);
157
143
 
158
- cub::DoubleBuffer<int> d_keys(bsr_temp.block_indices,
159
- bsr_temp.block_indices + nnz);
160
- cub::DoubleBuffer<BsrRowCol> d_values(bsr_temp.combined_row_col,
161
- bsr_temp.combined_row_col + nnz);
144
+ ScopedTemporary<int> block_indices(context, 2*nnz);
145
+ ScopedTemporary<BsrRowCol> combined_row_col(context, 2*nnz);
162
146
 
163
- int *d_nz_triplet_count = bsr_temp.block_indices + 2 * nnz;
164
- pinned_temp.ensure_fits(sizeof(int));
165
- int *pinned_count = static_cast<int *>(pinned_temp.buffer);
147
+ cub::DoubleBuffer<int> d_keys(block_indices.buffer(),
148
+ block_indices.buffer() + nnz);
149
+ cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(),
150
+ combined_row_col.buffer() + nnz);
151
+
152
+ int *p_nz_triplet_count = bsr_temp.count_buffer;
166
153
 
167
154
  wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_block_indices, nnz,
168
155
  (nnz, d_keys.Current()));
@@ -170,33 +157,29 @@ int bsr_matrix_from_triplets_device(const int rows_per_block,
170
157
  if (tpl_values) {
171
158
 
172
159
  // Remove zero blocks
173
- size_t buff_size = 0;
174
- BsrBlockIsNotZero<T> isNotZero{block_size, tpl_values};
175
- check_cuda(cub::DeviceSelect::If(nullptr, buff_size, d_keys.Current(),
176
- d_keys.Alternate(), d_nz_triplet_count,
177
- nnz, isNotZero, stream));
178
- void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
179
- check_cuda(cub::DeviceSelect::If(
180
- temp_buffer, buff_size, d_keys.Current(), d_keys.Alternate(),
181
- d_nz_triplet_count, nnz, isNotZero, stream));
182
- free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
160
+ {
161
+ size_t buff_size = 0;
162
+ BsrBlockIsNotZero<T> isNotZero{block_size, tpl_values};
163
+ check_cuda(cub::DeviceSelect::If(nullptr, buff_size, d_keys.Current(),
164
+ d_keys.Alternate(), p_nz_triplet_count,
165
+ nnz, isNotZero, stream));
166
+ ScopedTemporary<> temp(context, buff_size);
167
+ check_cuda(cub::DeviceSelect::If(
168
+ temp.buffer(), buff_size, d_keys.Current(), d_keys.Alternate(),
169
+ p_nz_triplet_count, nnz, isNotZero, stream));
170
+ }
171
+ cudaEventRecord(bsr_temp.host_sync_event, stream);
183
172
 
184
173
  // switch current/alternate in double buffer
185
174
  d_keys.selector ^= 1;
186
175
 
187
- // Copy number of remaining items to host, needed for further launches
188
- memcpy_d2h(WP_CURRENT_CONTEXT, pinned_count, d_nz_triplet_count,
189
- sizeof(int));
190
- cudaEventRecord(bsr_temp.host_sync_event, stream);
191
176
  } else {
192
- *pinned_count = nnz;
193
- memcpy_h2d(WP_CURRENT_CONTEXT, d_nz_triplet_count, pinned_count,
194
- sizeof(int));
177
+ *p_nz_triplet_count = nnz;
195
178
  }
196
179
 
197
180
  // Combine rows and columns so we can sort on them both
198
181
  wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_row_col, nnz,
199
- (d_nz_triplet_count, d_keys.Current(), tpl_rows, tpl_columns,
182
+ (p_nz_triplet_count, d_keys.Current(), tpl_rows, tpl_columns,
200
183
  d_values.Current()));
201
184
 
202
185
  if (tpl_values) {
@@ -204,30 +187,31 @@ int bsr_matrix_from_triplets_device(const int rows_per_block,
204
187
  cudaEventSynchronize(bsr_temp.host_sync_event);
205
188
  }
206
189
 
207
- const int nz_triplet_count = *pinned_count;
190
+ const int nz_triplet_count = *p_nz_triplet_count;
208
191
 
209
192
  // Sort
210
- size_t buff_size = 0;
211
- check_cuda(cub::DeviceRadixSort::SortPairs(
212
- nullptr, buff_size, d_values, d_keys, nz_triplet_count, 0, 64, stream));
213
- void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
214
-
215
- check_cuda(cub::DeviceRadixSort::SortPairs(temp_buffer, buff_size,
216
- d_values, d_keys, nz_triplet_count,
217
- 0, 64, stream));
218
- free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
193
+ {
194
+ size_t buff_size = 0;
195
+ check_cuda(cub::DeviceRadixSort::SortPairs(
196
+ nullptr, buff_size, d_values, d_keys, nz_triplet_count, 0, 64, stream));
197
+ ScopedTemporary<> temp(context, buff_size);
198
+ check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size,
199
+ d_values, d_keys, nz_triplet_count,
200
+ 0, 64, stream));
201
+ }
219
202
 
220
203
  // Runlength encode row-col sequences
221
- check_cuda(cub::DeviceRunLengthEncode::Encode(
222
- nullptr, buff_size, d_values.Current(), d_values.Alternate(),
223
- d_keys.Alternate(), d_nz_triplet_count, nz_triplet_count, stream));
224
- temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
225
- check_cuda(cub::DeviceRunLengthEncode::Encode(
226
- temp_buffer, buff_size, d_values.Current(), d_values.Alternate(),
227
- d_keys.Alternate(), d_nz_triplet_count, nz_triplet_count, stream));
228
- free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
229
-
230
- memcpy_d2h(WP_CURRENT_CONTEXT, pinned_count, d_nz_triplet_count, sizeof(int));
204
+ {
205
+ size_t buff_size = 0;
206
+ check_cuda(cub::DeviceRunLengthEncode::Encode(
207
+ nullptr, buff_size, d_values.Current(), d_values.Alternate(),
208
+ d_keys.Alternate(), p_nz_triplet_count, nz_triplet_count, stream));
209
+ ScopedTemporary<> temp(context, buff_size);
210
+ check_cuda(cub::DeviceRunLengthEncode::Encode(
211
+ temp.buffer(), buff_size, d_values.Current(), d_values.Alternate(),
212
+ d_keys.Alternate(), p_nz_triplet_count, nz_triplet_count, stream));
213
+ }
214
+
231
215
  cudaEventRecord(bsr_temp.host_sync_event, stream);
232
216
 
233
217
  // Now we have the following:
@@ -237,14 +221,16 @@ int bsr_matrix_from_triplets_device(const int rows_per_block,
237
221
  // d_keys.Alternate(): repeated block-row count
238
222
 
239
223
  // Scan repeated block counts
240
- check_cuda(cub::DeviceScan::InclusiveSum(
241
- nullptr, buff_size, d_keys.Alternate(), d_keys.Alternate(),
242
- nz_triplet_count, stream));
243
- temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
244
- check_cuda(cub::DeviceScan::InclusiveSum(
245
- temp_buffer, buff_size, d_keys.Alternate(), d_keys.Alternate(),
246
- nz_triplet_count, stream));
247
- free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
224
+ {
225
+ size_t buff_size = 0;
226
+ check_cuda(cub::DeviceScan::InclusiveSum(
227
+ nullptr, buff_size, d_keys.Alternate(), d_keys.Alternate(),
228
+ nz_triplet_count, stream));
229
+ ScopedTemporary<> temp(context, buff_size);
230
+ check_cuda(cub::DeviceScan::InclusiveSum(
231
+ temp.buffer(), buff_size, d_keys.Alternate(), d_keys.Alternate(),
232
+ nz_triplet_count, stream));
233
+ }
248
234
 
249
235
  // While we're at it, zero the bsr offsets buffer
250
236
  memset_device(WP_CURRENT_CONTEXT, bsr_offsets, 0,
@@ -252,7 +238,7 @@ int bsr_matrix_from_triplets_device(const int rows_per_block,
252
238
 
253
239
  // Wait for number of compressed blocks
254
240
  cudaEventSynchronize(bsr_temp.host_sync_event);
255
- const int compressed_nnz = *pinned_count;
241
+ const int compressed_nnz = *p_nz_triplet_count;
256
242
 
257
243
  // We have all we need to accumulate our repeated blocks
258
244
  wp_launch_device(WP_CURRENT_CONTEXT, bsr_merge_blocks, compressed_nnz,
@@ -261,13 +247,15 @@ int bsr_matrix_from_triplets_device(const int rows_per_block,
261
247
  bsr_offsets, bsr_columns, bsr_values));
262
248
 
263
249
  // Last, prefix sum the row block counts
264
- check_cuda(cub::DeviceScan::InclusiveSum(nullptr, buff_size, bsr_offsets,
265
- bsr_offsets, row_count + 1, stream));
266
- temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
267
- check_cuda(cub::DeviceScan::InclusiveSum(temp_buffer, buff_size,
268
- bsr_offsets, bsr_offsets,
269
- row_count + 1, stream));
270
- free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
250
+ {
251
+ size_t buff_size = 0;
252
+ check_cuda(cub::DeviceScan::InclusiveSum(nullptr, buff_size, bsr_offsets,
253
+ bsr_offsets, row_count + 1, stream));
254
+ ScopedTemporary<> temp(context, buff_size);
255
+ check_cuda(cub::DeviceScan::InclusiveSum(temp.buffer(), buff_size,
256
+ bsr_offsets, bsr_offsets,
257
+ row_count + 1, stream));
258
+ }
271
259
 
272
260
  return compressed_nnz;
273
261
  }
@@ -350,117 +338,75 @@ bsr_transpose_blocks(const int nnz, const int block_size,
350
338
  }
351
339
 
352
340
  template <typename T>
353
- void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
354
- int col_count, int nnz, const int *bsr_offsets,
355
- const int *bsr_columns, const T *bsr_values,
356
- int *transposed_bsr_offsets,
357
- int *transposed_bsr_columns,
358
- T *transposed_bsr_values) {
359
-
360
- const int block_size = rows_per_block * cols_per_block;
361
-
362
- void *context = cuda_context_get_current();
363
-
364
- // Per-context cached temporary buffer
365
- BsrFromTripletsTemp &bsr_temp = g_bsr_from_triplets_temp_map[context];
366
-
367
- ContextGuard guard(context);
368
-
369
- cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
370
- bsr_temp.ensure_fits(nnz);
371
-
372
- // Zero the transposed offsets
373
- memset_device(WP_CURRENT_CONTEXT, transposed_bsr_offsets, 0,
374
- (col_count + 1) * sizeof(int));
375
-
376
- cub::DoubleBuffer<int> d_keys(bsr_temp.block_indices,
377
- bsr_temp.block_indices + nnz);
378
- cub::DoubleBuffer<BsrRowCol> d_values(bsr_temp.combined_row_col,
379
- bsr_temp.combined_row_col + nnz);
380
-
381
- wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_fill_row_col, nnz,
382
- (nnz, row_count, bsr_offsets, bsr_columns, d_keys.Current(),
383
- d_values.Current(), transposed_bsr_offsets));
341
+ void
342
+ launch_bsr_transpose_blocks(const int nnz, const int block_size,
343
+ const int rows_per_block, const int cols_per_block,
344
+ const int *block_indices,
345
+ const BsrRowCol *transposed_indices,
346
+ const T *bsr_values,
347
+ int *transposed_bsr_columns, T *transposed_bsr_values) {
384
348
 
385
- // Sort blocks
386
- size_t buff_size = 0;
387
- check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values,
388
- d_keys, nnz, 0, 64, stream));
389
- void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
390
- check_cuda(cub::DeviceRadixSort::SortPairs(
391
- temp_buffer, buff_size, d_values, d_keys, nnz, 0, 64, stream));
392
-
393
- // Prefix sum the trasnposed row block counts
394
- check_cuda(cub::DeviceScan::InclusiveSum(
395
- nullptr, buff_size, transposed_bsr_offsets, transposed_bsr_offsets,
396
- col_count + 1, stream));
397
- temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
398
- check_cuda(cub::DeviceScan::InclusiveSum(
399
- temp_buffer, buff_size, transposed_bsr_offsets,
400
- transposed_bsr_offsets, col_count + 1, stream));
401
-
402
- // Move and transpose invidual blocks
403
- switch (row_count) {
349
+ switch (rows_per_block) {
404
350
  case 1:
405
- switch (col_count) {
351
+ switch (cols_per_block) {
406
352
  case 1:
407
353
  wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
408
354
  (nnz, block_size, BsrBlockTransposer<1, 1, T>{},
409
- d_keys.Current(), d_values.Current(), bsr_values,
355
+ block_indices, transposed_indices, bsr_values,
410
356
  transposed_bsr_columns, transposed_bsr_values));
411
357
  return;
412
358
  case 2:
413
359
  wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
414
360
  (nnz, block_size, BsrBlockTransposer<1, 2, T>{},
415
- d_keys.Current(), d_values.Current(), bsr_values,
361
+ block_indices, transposed_indices, bsr_values,
416
362
  transposed_bsr_columns, transposed_bsr_values));
417
363
  return;
418
364
  case 3:
419
365
  wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
420
366
  (nnz, block_size, BsrBlockTransposer<1, 3, T>{},
421
- d_keys.Current(), d_values.Current(), bsr_values,
367
+ block_indices, transposed_indices, bsr_values,
422
368
  transposed_bsr_columns, transposed_bsr_values));
423
369
  return;
424
370
  }
425
371
  case 2:
426
- switch (col_count) {
372
+ switch (cols_per_block) {
427
373
  case 1:
428
374
  wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
429
375
  (nnz, block_size, BsrBlockTransposer<2, 1, T>{},
430
- d_keys.Current(), d_values.Current(), bsr_values,
376
+ block_indices, transposed_indices, bsr_values,
431
377
  transposed_bsr_columns, transposed_bsr_values));
432
378
  return;
433
379
  case 2:
434
380
  wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
435
381
  (nnz, block_size, BsrBlockTransposer<2, 2, T>{},
436
- d_keys.Current(), d_values.Current(), bsr_values,
382
+ block_indices, transposed_indices, bsr_values,
437
383
  transposed_bsr_columns, transposed_bsr_values));
438
384
  return;
439
385
  case 3:
440
386
  wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
441
387
  (nnz, block_size, BsrBlockTransposer<2, 3, T>{},
442
- d_keys.Current(), d_values.Current(), bsr_values,
388
+ block_indices, transposed_indices, bsr_values,
443
389
  transposed_bsr_columns, transposed_bsr_values));
444
390
  return;
445
391
  }
446
392
  case 3:
447
- switch (col_count) {
393
+ switch (cols_per_block) {
448
394
  case 1:
449
395
  wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
450
396
  (nnz, block_size, BsrBlockTransposer<3, 1, T>{},
451
- d_keys.Current(), d_values.Current(), bsr_values,
397
+ block_indices, transposed_indices, bsr_values,
452
398
  transposed_bsr_columns, transposed_bsr_values));
453
399
  return;
454
400
  case 2:
455
401
  wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
456
402
  (nnz, block_size, BsrBlockTransposer<3, 2, T>{},
457
- d_keys.Current(), d_values.Current(), bsr_values,
403
+ block_indices, transposed_indices, bsr_values,
458
404
  transposed_bsr_columns, transposed_bsr_values));
459
405
  return;
460
406
  case 3:
461
407
  wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
462
408
  (nnz, block_size, BsrBlockTransposer<3, 3, T>{},
463
- d_keys.Current(), d_values.Current(), bsr_values,
409
+ block_indices, transposed_indices, bsr_values,
464
410
  transposed_bsr_columns, transposed_bsr_values));
465
411
  return;
466
412
  }
@@ -470,10 +416,72 @@ void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
470
416
  WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
471
417
  (nnz, block_size,
472
418
  BsrBlockTransposer<-1, -1, T>{rows_per_block, cols_per_block},
473
- d_keys.Current(), d_values.Current(), bsr_values, transposed_bsr_columns,
419
+ block_indices, transposed_indices, bsr_values, transposed_bsr_columns,
474
420
  transposed_bsr_values));
475
421
  }
476
422
 
423
+ template <typename T>
424
+ void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
425
+ int col_count, int nnz, const int *bsr_offsets,
426
+ const int *bsr_columns, const T *bsr_values,
427
+ int *transposed_bsr_offsets,
428
+ int *transposed_bsr_columns,
429
+ T *transposed_bsr_values) {
430
+
431
+ const int block_size = rows_per_block * cols_per_block;
432
+
433
+ void *context = cuda_context_get_current();
434
+ ContextGuard guard(context);
435
+
436
+ cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
437
+
438
+ // Zero the transposed offsets
439
+ memset_device(WP_CURRENT_CONTEXT, transposed_bsr_offsets, 0,
440
+ (col_count + 1) * sizeof(int));
441
+
442
+ ScopedTemporary<int> block_indices(context, 2*nnz);
443
+ ScopedTemporary<BsrRowCol> combined_row_col(context, 2*nnz);
444
+
445
+ cub::DoubleBuffer<int> d_keys(block_indices.buffer(),
446
+ block_indices.buffer() + nnz);
447
+ cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(),
448
+ combined_row_col.buffer() + nnz);
449
+
450
+ wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_fill_row_col, nnz,
451
+ (nnz, row_count, bsr_offsets, bsr_columns, d_keys.Current(),
452
+ d_values.Current(), transposed_bsr_offsets));
453
+
454
+ // Sort blocks
455
+ {
456
+ size_t buff_size = 0;
457
+ check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values,
458
+ d_keys, nnz, 0, 64, stream));
459
+ void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
460
+ ScopedTemporary<> temp(context, buff_size);
461
+ check_cuda(cub::DeviceRadixSort::SortPairs(
462
+ temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
463
+ }
464
+
465
+ // Prefix sum the transposed row block counts
466
+ {
467
+ size_t buff_size = 0;
468
+ check_cuda(cub::DeviceScan::InclusiveSum(
469
+ nullptr, buff_size, transposed_bsr_offsets, transposed_bsr_offsets,
470
+ col_count + 1, stream));
471
+ ScopedTemporary<> temp(context, buff_size);
472
+ check_cuda(cub::DeviceScan::InclusiveSum(
473
+ temp.buffer(), buff_size, transposed_bsr_offsets,
474
+ transposed_bsr_offsets, col_count + 1, stream));
475
+ }
476
+
477
+ // Move and transpose individual blocks
478
+ launch_bsr_transpose_blocks(
479
+ nnz, block_size,
480
+ rows_per_block, cols_per_block,
481
+ d_keys.Current(), d_values.Current(), bsr_values, transposed_bsr_columns,
482
+ transposed_bsr_values);
483
+ }
484
+
477
485
  } // namespace
478
486
 
479
487
  int bsr_matrix_from_triplets_float_device(
warp/native/spatial.h CHANGED
@@ -265,13 +265,13 @@ inline CUDA_CALLABLE Type tensordot(const transform_t<Type>& a, const transform_
265
265
  }
266
266
 
267
267
  template<typename Type>
268
- inline CUDA_CALLABLE Type index(const transform_t<Type>& t, int i)
268
+ inline CUDA_CALLABLE Type extract(const transform_t<Type>& t, int i)
269
269
  {
270
270
  return t[i];
271
271
  }
272
272
 
273
273
  template<typename Type>
274
- inline void CUDA_CALLABLE adj_index(const transform_t<Type>& t, int i, transform_t<Type>& adj_t, int& adj_i, Type adj_ret)
274
+ inline void CUDA_CALLABLE adj_extract(const transform_t<Type>& t, int i, transform_t<Type>& adj_t, int& adj_i, Type adj_ret)
275
275
  {
276
276
  adj_t[i] += adj_ret;
277
277
  }
warp/native/temp_buffer.h CHANGED
@@ -1,26 +1,30 @@
1
1
 
2
2
  #pragma once
3
3
 
4
- #include "warp.h"
5
4
  #include "cuda_util.h"
5
+ #include "warp.h"
6
6
 
7
7
  #include <unordered_map>
8
8
 
9
- struct PinnedTemporaryBuffer
9
+ template <typename T = char> struct ScopedTemporary
10
10
  {
11
- void *buffer = NULL;
12
- size_t buffer_size = 0;
13
11
 
14
- void ensure_fits(size_t size)
12
+ ScopedTemporary(void *context, size_t size)
13
+ : m_context(context), m_buffer(static_cast<T*>(alloc_temp_device(m_context, size * sizeof(T))))
15
14
  {
16
- if (size > buffer_size)
17
- {
18
- free_pinned(buffer);
19
- buffer = alloc_pinned(size);
20
- buffer_size = size;
21
- }
22
15
  }
23
- };
24
16
 
25
- // map temp buffers to CUDA contexts
26
- static std::unordered_map<void *, PinnedTemporaryBuffer> g_pinned_temp_buffer_map;
17
+ ~ScopedTemporary()
18
+ {
19
+ free_temp_device(m_context, m_buffer);
20
+ }
21
+
22
+ T *buffer() const
23
+ {
24
+ return m_buffer;
25
+ }
26
+
27
+ private:
28
+ void *m_context;
29
+ T *m_buffer;
30
+ };