warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315) hide show
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/marching.cu CHANGED
@@ -1,8 +1,6 @@
1
1
  #include "warp.h"
2
2
  #include "cuda_util.h"
3
-
4
- #include "thrust/device_ptr.h"
5
- #include "thrust/sort.h"
3
+ #include "scan.h"
6
4
 
7
5
  namespace wp {
8
6
 
@@ -162,13 +160,17 @@ namespace wp {
162
160
  };
163
161
 
164
162
 
165
- // ---------------------------------------------------------------------------------------
166
- struct MarchingCubes
167
- {
168
- MarchingCubes()
163
+ // ---------------------------------------------------------------------------------------
164
+ struct MarchingCubes
165
+ {
166
+ MarchingCubes()
169
167
  {
170
- memset(this, 0, sizeof(MarchingCubes));
171
- }
168
+ memset(this, 0, sizeof(MarchingCubes));
169
+ first_cell_vert = nullptr;
170
+ first_cell_tri = nullptr;
171
+ cell_verts = nullptr;
172
+ context = nullptr;
173
+ }
172
174
 
173
175
  __device__ __host__ int cell_index(int xi, int yi, int zi) const
174
176
  {
@@ -181,169 +183,169 @@ namespace wp {
181
183
  xi = cell_index / ny;
182
184
  }
183
185
 
184
- // grid
186
+ // grid
185
187
  int nx;
186
188
  int ny;
187
189
  int nz;
188
190
 
189
- int* first_cell_vert;
190
- int* first_cell_tri;
191
- int* cell_verts;
191
+ int* first_cell_vert;
192
+ int* first_cell_tri;
193
+ int* cell_verts;
192
194
 
193
195
  int num_cells;
194
196
  int max_cells;
195
197
 
196
198
  void* context;
197
- };
198
-
199
-
200
- // -----------------------------------------------------------------------------------
201
- __global__ void count_cell_verts(MarchingCubes mc, const float* density, float threshold)
202
- {
203
- int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
204
- if (cell_index >= mc.num_cells)
205
- return;
206
-
207
- int xi, yi, zi;
208
- mc.cell_coord(cell_index, xi, yi, zi);
209
-
210
- mc.first_cell_vert[cell_index] = 0;
211
- if (xi >= mc.nx - 1 || yi >= mc.ny - 1 || zi >= mc.nz - 1)
212
- return;
213
-
214
- float d0 = density[cell_index];
215
- float dx = density[mc.cell_index(xi + 1, yi, zi)];
216
- float dy = density[mc.cell_index(xi, yi + 1, zi)];
217
- float dz = density[mc.cell_index(xi, yi, zi + 1)];
218
-
219
- int num = 0;
220
- if ((d0 <= threshold && dx >= threshold) || (dx <= threshold && d0 >= threshold))
221
- num++;
222
- if ((d0 <= threshold && dy >= threshold) || (dy <= threshold && d0 >= threshold))
223
- num++;
224
- if ((d0 <= threshold && dz >= threshold) || (dz <= threshold && d0 >= threshold))
225
- num++;
226
-
227
- mc.first_cell_vert[cell_index] = num;
228
- }
229
-
230
- // -----------------------------------------------------------------------------------
231
- __global__ void create_cell_verts(MarchingCubes mc, vec3* __restrict__ vertices, vec3* normals, const float* __restrict__ density, float threshold)
232
- {
233
- int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
234
- if (cell_index >= mc.num_cells)
235
- return;
236
-
237
- int xi, yi, zi;
238
- mc.cell_coord(cell_index, xi, yi, zi);
239
- if (xi >= mc.nx - 1 || yi >= mc.ny - 1 || zi >= mc.nz - 1)
240
- return;
241
-
242
- vec3 p = vec3(xi + 0.5f, yi + 0.5f, zi + 0.5f);
243
-
244
- float d0 = density[cell_index];
245
- float ds[3];
246
- ds[0] = density[mc.cell_index(xi + 1, yi, zi)];
247
- ds[1] = density[mc.cell_index(xi, yi + 1, zi)];
248
- ds[2] = density[mc.cell_index(xi, yi, zi + 1)];
249
-
250
- // vec3 n0 = densityNormal[cell_index];
251
- // vec3 ns[3];
252
- // ns[0] = densityNormal[mc.cell_index(xi + 1, yi, zi)];
253
- // ns[1] = densityNormal[mc.cell_index(xi, yi + 1, zi)];
254
- // ns[2] = densityNormal[mc.cell_index(xi, yi, zi + 1)];
255
-
256
- int first = mc.first_cell_vert[cell_index];
257
-
258
- for (int dim = 0; dim < 3; dim++)
199
+ };
200
+
201
+
202
+ // -----------------------------------------------------------------------------------
203
+ __global__ void count_cell_verts(MarchingCubes mc, const float* density, float threshold)
204
+ {
205
+ int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
206
+ if (cell_index >= mc.num_cells)
207
+ return;
208
+
209
+ int xi, yi, zi;
210
+ mc.cell_coord(cell_index, xi, yi, zi);
211
+
212
+ mc.first_cell_vert[cell_index] = 0;
213
+ if (xi >= mc.nx - 1 || yi >= mc.ny - 1 || zi >= mc.nz - 1)
214
+ return;
215
+
216
+ float d0 = density[cell_index];
217
+ float dx = density[mc.cell_index(xi + 1, yi, zi)];
218
+ float dy = density[mc.cell_index(xi, yi + 1, zi)];
219
+ float dz = density[mc.cell_index(xi, yi, zi + 1)];
220
+
221
+ int num = 0;
222
+ if ((d0 <= threshold && dx >= threshold) || (dx <= threshold && d0 >= threshold))
223
+ num++;
224
+ if ((d0 <= threshold && dy >= threshold) || (dy <= threshold && d0 >= threshold))
225
+ num++;
226
+ if ((d0 <= threshold && dz >= threshold) || (dz <= threshold && d0 >= threshold))
227
+ num++;
228
+
229
+ mc.first_cell_vert[cell_index] = num;
230
+ }
231
+
232
+ // -----------------------------------------------------------------------------------
233
+ __global__ void create_cell_verts(MarchingCubes mc, vec3* __restrict__ vertices, vec3* normals, const float* __restrict__ density, float threshold)
234
+ {
235
+ int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
236
+ if (cell_index >= mc.num_cells)
237
+ return;
238
+
239
+ int xi, yi, zi;
240
+ mc.cell_coord(cell_index, xi, yi, zi);
241
+ if (xi >= mc.nx - 1 || yi >= mc.ny - 1 || zi >= mc.nz - 1)
242
+ return;
243
+
244
+ vec3 p = vec3(xi + 0.5f, yi + 0.5f, zi + 0.5f);
245
+
246
+ float d0 = density[cell_index];
247
+ float ds[3];
248
+ ds[0] = density[mc.cell_index(xi + 1, yi, zi)];
249
+ ds[1] = density[mc.cell_index(xi, yi + 1, zi)];
250
+ ds[2] = density[mc.cell_index(xi, yi, zi + 1)];
251
+
252
+ // vec3 n0 = densityNormal[cell_index];
253
+ // vec3 ns[3];
254
+ // ns[0] = densityNormal[mc.cell_index(xi + 1, yi, zi)];
255
+ // ns[1] = densityNormal[mc.cell_index(xi, yi + 1, zi)];
256
+ // ns[2] = densityNormal[mc.cell_index(xi, yi, zi + 1)];
257
+
258
+ int first = mc.first_cell_vert[cell_index];
259
+
260
+ for (int dim = 0; dim < 3; dim++)
259
261
  {
260
- float d = ds[dim];
261
- mc.cell_verts[3 * cell_index + dim] = 0;
262
+ float d = ds[dim];
263
+ mc.cell_verts[3 * cell_index + dim] = 0;
262
264
 
263
- if ((d0 <= threshold && d >= threshold) || (d <= threshold && d0 >= threshold))
265
+ if ((d0 <= threshold && d >= threshold) || (d <= threshold && d0 >= threshold))
264
266
  {
265
- float t = (d != d0) ? clamp((threshold - d0) / (d - d0), 0.0f, 1.0f) : 0.5f;
266
- int id = first++;
267
-
268
- vec3 off;
269
- off[dim] = t;
270
- vertices[id] = p + off;
271
-
272
- // vec3 n = normalize(n0 + t * (ns[dim] - n0));
273
- // normals[id] = -n;
274
-
275
- mc.cell_verts[3 * cell_index + dim] = id;
276
- }
277
- }
278
- }
279
-
280
- // -----------------------------------------------------------------------------------
281
- __global__ void count_cell_tris(MarchingCubes mc, const float* __restrict__ density, float threshold)
282
- {
283
- int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
284
- if (cell_index >= mc.num_cells)
285
- return;
286
-
287
- int xi, yi, zi;
288
- mc.cell_coord(cell_index, xi, yi, zi);
289
-
290
- mc.first_cell_tri[cell_index] = 0;
291
- if (xi >= mc.nx - 2 || yi >= mc.ny - 2 || zi >= mc.nz - 2)
292
- return;
293
-
294
- int code = 0;
295
- for (int i = 0; i < 8; i++) {
296
- int cxi = xi + marchingCubeCorners[i][0];
297
- int cyi = yi + marchingCubeCorners[i][1];
298
- int czi = zi + marchingCubeCorners[i][2];
299
-
300
- if (density[mc.cell_index(cxi, cyi, czi)] >= threshold)
301
- code |= (1 << i);
302
- }
303
-
304
- mc.first_cell_tri[cell_index] = firstMarchingCubesId[code + 1] - firstMarchingCubesId[code];
305
- }
306
-
307
- // -----------------------------------------------------------------------------------
308
- __global__ void create_cell_tris(MarchingCubes mc, const float* __restrict__ density, int* __restrict__ triangles, float threshold)
309
- {
310
- int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
311
- if (cell_index >= mc.num_cells)
312
- return;
313
-
314
- int xi, yi, zi;
315
- mc.cell_coord(cell_index, xi, yi, zi);
316
- if (xi >= mc.nx - 2 || yi >= mc.ny - 2 || zi >= mc.nz - 2)
317
- return;
318
-
319
- int code = 0;
320
- for (int i = 0; i < 8; i++)
267
+ float t = (d != d0) ? clamp((threshold - d0) / (d - d0), 0.0f, 1.0f) : 0.5f;
268
+ int id = first++;
269
+
270
+ vec3 off;
271
+ off[dim] = t;
272
+ vertices[id] = p + off;
273
+
274
+ // vec3 n = normalize(n0 + t * (ns[dim] - n0));
275
+ // normals[id] = -n;
276
+
277
+ mc.cell_verts[3 * cell_index + dim] = id;
278
+ }
279
+ }
280
+ }
281
+
282
+ // -----------------------------------------------------------------------------------
283
+ __global__ void count_cell_tris(MarchingCubes mc, const float* __restrict__ density, float threshold)
284
+ {
285
+ int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
286
+ if (cell_index >= mc.num_cells)
287
+ return;
288
+
289
+ int xi, yi, zi;
290
+ mc.cell_coord(cell_index, xi, yi, zi);
291
+
292
+ mc.first_cell_tri[cell_index] = 0;
293
+ if (xi >= mc.nx - 2 || yi >= mc.ny - 2 || zi >= mc.nz - 2)
294
+ return;
295
+
296
+ int code = 0;
297
+ for (int i = 0; i < 8; i++) {
298
+ int cxi = xi + marchingCubeCorners[i][0];
299
+ int cyi = yi + marchingCubeCorners[i][1];
300
+ int czi = zi + marchingCubeCorners[i][2];
301
+
302
+ if (density[mc.cell_index(cxi, cyi, czi)] >= threshold)
303
+ code |= (1 << i);
304
+ }
305
+
306
+ mc.first_cell_tri[cell_index] = firstMarchingCubesId[code + 1] - firstMarchingCubesId[code];
307
+ }
308
+
309
+ // -----------------------------------------------------------------------------------
310
+ __global__ void create_cell_tris(MarchingCubes mc, const float* __restrict__ density, int* __restrict__ triangles, float threshold)
311
+ {
312
+ int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
313
+ if (cell_index >= mc.num_cells)
314
+ return;
315
+
316
+ int xi, yi, zi;
317
+ mc.cell_coord(cell_index, xi, yi, zi);
318
+ if (xi >= mc.nx - 2 || yi >= mc.ny - 2 || zi >= mc.nz - 2)
319
+ return;
320
+
321
+ int code = 0;
322
+ for (int i = 0; i < 8; i++)
321
323
  {
322
- int cxi = xi + marchingCubeCorners[i][0];
323
- int cyi = yi + marchingCubeCorners[i][1];
324
- int czi = zi + marchingCubeCorners[i][2];
324
+ int cxi = xi + marchingCubeCorners[i][0];
325
+ int cyi = yi + marchingCubeCorners[i][1];
326
+ int czi = zi + marchingCubeCorners[i][2];
325
327
 
326
- if (density[mc.cell_index(cxi, cyi, czi)] >= threshold)
327
- code |= (1 << i);
328
- }
328
+ if (density[mc.cell_index(cxi, cyi, czi)] >= threshold)
329
+ code |= (1 << i);
330
+ }
329
331
 
330
- int firstIn = firstMarchingCubesId[code];
331
- int num = firstMarchingCubesId[code + 1] - firstIn;
332
- int firstOut = mc.first_cell_tri[cell_index];
332
+ int firstIn = firstMarchingCubesId[code];
333
+ int num = firstMarchingCubesId[code + 1] - firstIn;
334
+ int firstOut = mc.first_cell_tri[cell_index];
333
335
 
334
- for (int i = 0; i < num; i++)
336
+ for (int i = 0; i < num; i++)
335
337
  {
336
- int eid = marchingCubesIds[firstIn + i];
338
+ int eid = marchingCubesIds[firstIn + i];
337
339
 
338
- int exi = xi + marchingCubesEdgeLocations[eid][0];
339
- int eyi = yi + marchingCubesEdgeLocations[eid][1];
340
- int ezi = zi + marchingCubesEdgeLocations[eid][2];
341
- int edgeNr = marchingCubesEdgeLocations[eid][3];
340
+ int exi = xi + marchingCubesEdgeLocations[eid][0];
341
+ int eyi = yi + marchingCubesEdgeLocations[eid][1];
342
+ int ezi = zi + marchingCubesEdgeLocations[eid][2];
343
+ int edgeNr = marchingCubesEdgeLocations[eid][3];
342
344
 
343
- int id = mc.cell_verts[3 * mc.cell_index(exi, eyi, ezi) + edgeNr];
344
- triangles[firstOut + i] = id;
345
- }
346
- }
345
+ int id = mc.cell_verts[3 * mc.cell_index(exi, eyi, ezi) + edgeNr];
346
+ triangles[firstOut + i] = id;
347
+ }
348
+ }
347
349
 
348
350
  // -------------------------
349
351
  void marching_cubes_resize(MarchingCubes& mc, int nx, int ny, int nz)
@@ -444,10 +446,7 @@ WP_API int marching_cubes_surface_device(
444
446
  int num_last;
445
447
  memcpy_d2h(WP_CURRENT_CONTEXT, &num_last, &mc.first_cell_vert[mc.num_cells - 1], sizeof(int));
446
448
 
447
- thrust::exclusive_scan(
448
- thrust::device_ptr<int>(mc.first_cell_vert),
449
- thrust::device_ptr<int>(mc.first_cell_vert + mc.num_cells),
450
- thrust::device_ptr<int>(mc.first_cell_vert));
449
+ scan_device(mc.first_cell_vert, mc.first_cell_vert, mc.num_cells, false);
451
450
 
452
451
  int num_verts;
453
452
  memcpy_d2h(WP_CURRENT_CONTEXT, &num_verts, &mc.first_cell_vert[mc.num_cells - 1], sizeof(int));
@@ -472,10 +471,7 @@ WP_API int marching_cubes_surface_device(
472
471
 
473
472
  memcpy_d2h(WP_CURRENT_CONTEXT, &num_last, &mc.first_cell_tri[mc.num_cells - 1], sizeof(int));
474
473
 
475
- thrust::exclusive_scan(
476
- thrust::device_ptr<int>(mc.first_cell_tri),
477
- thrust::device_ptr<int>(mc.first_cell_tri + mc.num_cells),
478
- thrust::device_ptr<int>(mc.first_cell_tri));
474
+ scan_device(mc.first_cell_tri, mc.first_cell_tri, mc.num_cells, false);
479
475
 
480
476
 
481
477
  int num_indices;
warp/native/mat.h CHANGED
@@ -21,7 +21,9 @@ struct quat_t;
21
21
  template<unsigned Rows, unsigned Cols, typename Type>
22
22
  struct mat_t
23
23
  {
24
- inline mat_t() = default;
24
+ inline CUDA_CALLABLE mat_t()
25
+ : data()
26
+ {}
25
27
 
26
28
  inline CUDA_CALLABLE mat_t(Type s)
27
29
  {
@@ -30,6 +32,14 @@ struct mat_t
30
32
  data[i][j] = s;
31
33
  }
32
34
 
35
+ template <typename OtherType>
36
+ inline explicit CUDA_CALLABLE mat_t(const mat_t<Rows, Cols, OtherType>& other)
37
+ {
38
+ for (unsigned i=0; i < Rows; ++i)
39
+ for (unsigned j=0; j < Cols; ++j)
40
+ data[i][j] = other.data[i][j];
41
+ }
42
+
33
43
  inline CUDA_CALLABLE mat_t(vec_t<2,Type> c0, vec_t<2,Type> c1)
34
44
  {
35
45
  data[0][0] = c0[0];
@@ -185,7 +195,7 @@ struct mat_t
185
195
  }
186
196
 
187
197
  // row major storage assumed to be compatible with PyTorch
188
- Type data[Rows][Cols] = {};
198
+ Type data[Rows][Cols];
189
199
  };
190
200
 
191
201
 
@@ -290,7 +300,19 @@ inline CUDA_CALLABLE mat_t<Rows,Cols,Type> atomic_max(mat_t<Rows,Cols,Type> * ad
290
300
  }
291
301
 
292
302
  template<unsigned Rows, unsigned Cols, typename Type>
293
- inline CUDA_CALLABLE vec_t<Cols,Type> index(const mat_t<Rows,Cols,Type>& m, int row)
303
+ inline CUDA_CALLABLE void adj_atomic_minmax(
304
+ mat_t<Rows,Cols,Type> *addr,
305
+ mat_t<Rows,Cols,Type> *adj_addr,
306
+ const mat_t<Rows,Cols,Type> &value,
307
+ mat_t<Rows,Cols,Type> &adj_value)
308
+ {
309
+ for (unsigned i=0; i < Rows; ++i)
310
+ for (unsigned j=0; j < Cols; ++j)
311
+ adj_atomic_minmax(&addr->data[i][j], &adj_addr->data[i][j], value.data[i][j], adj_value.data[i][j]);
312
+ }
313
+
314
+ template<unsigned Rows, unsigned Cols, typename Type>
315
+ inline CUDA_CALLABLE vec_t<Cols,Type> extract(const mat_t<Rows,Cols,Type>& m, int row)
294
316
  {
295
317
  vec_t<Cols,Type> ret;
296
318
  for(unsigned i=0; i < Cols; ++i)
@@ -301,7 +323,7 @@ inline CUDA_CALLABLE vec_t<Cols,Type> index(const mat_t<Rows,Cols,Type>& m, int
301
323
  }
302
324
 
303
325
  template<unsigned Rows, unsigned Cols, typename Type>
304
- inline CUDA_CALLABLE Type index(const mat_t<Rows,Cols,Type>& m, int row, int col)
326
+ inline CUDA_CALLABLE Type extract(const mat_t<Rows,Cols,Type>& m, int row, int col)
305
327
  {
306
328
  #ifndef NDEBUG
307
329
  if (row < 0 || row >= Rows)
@@ -319,7 +341,7 @@ inline CUDA_CALLABLE Type index(const mat_t<Rows,Cols,Type>& m, int row, int col
319
341
  }
320
342
 
321
343
  template<unsigned Rows, unsigned Cols, typename Type>
322
- inline CUDA_CALLABLE void indexset(mat_t<Rows,Cols,Type>& m, int row, vec_t<Cols, Type> value)
344
+ inline CUDA_CALLABLE vec_t<Cols, Type>* index(mat_t<Rows,Cols,Type>& m, int row)
323
345
  {
324
346
  #ifndef NDEBUG
325
347
  if (row < 0 || row >= Rows)
@@ -329,12 +351,11 @@ inline CUDA_CALLABLE void indexset(mat_t<Rows,Cols,Type>& m, int row, vec_t<Cols
329
351
  }
330
352
  #endif
331
353
 
332
- for(unsigned i=0; i < Cols; ++i)
333
- m.data[row][i] = value[i];
354
+ return reinterpret_cast<vec_t<Cols, Type>*>(&m.data[row]);
334
355
  }
335
356
 
336
357
  template<unsigned Rows, unsigned Cols, typename Type>
337
- inline CUDA_CALLABLE void indexset(mat_t<Rows,Cols,Type>& m, int row, int col, Type value)
358
+ inline CUDA_CALLABLE Type* index(mat_t<Rows,Cols,Type>& m, int row, int col)
338
359
  {
339
360
  #ifndef NDEBUG
340
361
  if (row < 0 || row >= Rows)
@@ -348,18 +369,19 @@ inline CUDA_CALLABLE void indexset(mat_t<Rows,Cols,Type>& m, int row, int col, T
348
369
  assert(0);
349
370
  }
350
371
  #endif
351
- m.data[row][col] = value;
372
+
373
+ return &m.data[row][col];
352
374
  }
353
375
 
354
376
  template<unsigned Rows, unsigned Cols, typename Type>
355
- inline CUDA_CALLABLE void adj_indexset(const mat_t<Rows,Cols,Type>& m, int row, const vec_t<Cols, Type>& value,
377
+ inline CUDA_CALLABLE void adj_index(const mat_t<Rows,Cols,Type>& m, int row,
356
378
  const mat_t<Rows,Cols,Type>& adj_m, int adj_row, const vec_t<Cols, Type>& adj_value)
357
379
  {
358
380
  // nop
359
381
  }
360
382
 
361
383
  template<unsigned Rows, unsigned Cols, typename Type>
362
- inline CUDA_CALLABLE void adj_indexset(const mat_t<Rows,Cols,Type>& m, int row, int col, Type value,
384
+ inline CUDA_CALLABLE void adj_index(const mat_t<Rows,Cols,Type>& m, int row, int col,
363
385
  const mat_t<Rows,Cols,Type>& adj_m, int adj_row, int adj_col, Type adj_value)
364
386
  {
365
387
  // nop
@@ -417,7 +439,22 @@ inline CUDA_CALLABLE mat_t<Rows,Cols,Type> div(const mat_t<Rows,Cols,Type>& a, T
417
439
  }
418
440
  }
419
441
 
420
- return t;
442
+ return t;
443
+ }
444
+
445
+ template<unsigned Rows, unsigned Cols, typename Type>
446
+ inline CUDA_CALLABLE mat_t<Rows,Cols,Type> div(Type b, const mat_t<Rows,Cols,Type>& a)
447
+ {
448
+ mat_t<Rows,Cols,Type> t;
449
+ for (unsigned i=0; i < Rows; ++i)
450
+ {
451
+ for (unsigned j=0; j < Cols; ++j)
452
+ {
453
+ t.data[i][j] = b / a.data[i][j];
454
+ }
455
+ }
456
+
457
+ return t;
421
458
  }
422
459
 
423
460
  template<unsigned Rows, unsigned Cols, typename Type>
@@ -432,7 +469,7 @@ inline CUDA_CALLABLE mat_t<Rows,Cols,Type> mul(const mat_t<Rows,Cols,Type>& a, T
432
469
  }
433
470
  }
434
471
 
435
- return t;
472
+ return t;
436
473
  }
437
474
 
438
475
  template<unsigned Rows, unsigned Cols, typename Type>
@@ -465,6 +502,17 @@ inline CUDA_CALLABLE vec_t<Rows,Type> mul(const mat_t<Rows,Cols,Type>& a, const
465
502
  return r;
466
503
  }
467
504
 
505
+ template<unsigned Rows, unsigned Cols, typename Type>
506
+ inline CUDA_CALLABLE vec_t<Cols,Type> mul(const vec_t<Rows,Type>& b, const mat_t<Rows,Cols,Type>& a)
507
+ {
508
+ vec_t<Cols,Type> r = a.get_row(0)*b[0];
509
+ for( unsigned i=1; i < Rows; ++i )
510
+ {
511
+ r += a.get_row(i)*b[i];
512
+ }
513
+ return r;
514
+ }
515
+
468
516
  template<unsigned Rows, unsigned Cols, unsigned ColsOut, typename Type>
469
517
  inline CUDA_CALLABLE mat_t<Rows,ColsOut,Type> mul(const mat_t<Rows,Cols,Type>& a, const mat_t<Cols,ColsOut,Type>& b)
470
518
  {
@@ -608,6 +656,17 @@ inline CUDA_CALLABLE Type trace(const mat_t<Rows,Rows,Type>& m)
608
656
  return ret;
609
657
  }
610
658
 
659
+ template<unsigned Rows, typename Type>
660
+ inline CUDA_CALLABLE vec_t<Rows, Type> get_diag(const mat_t<Rows,Rows,Type>& m)
661
+ {
662
+ vec_t<Rows, Type> ret;
663
+ for( unsigned i=0; i < Rows; ++i )
664
+ {
665
+ ret[i] = m.data[i][i];
666
+ }
667
+ return ret;
668
+ }
669
+
611
670
  // Only implementing inverses for 2x2, 3x3 and 4x4 matrices for now...
612
671
  template<typename Type>
613
672
  inline CUDA_CALLABLE mat_t<2,2,Type> inverse(const mat_t<2,2,Type>& m)
@@ -842,14 +901,14 @@ inline CUDA_CALLABLE vec_t<3,Type> transform_vector(const mat_t<4,4,Type>& m, co
842
901
  }
843
902
 
844
903
  template<unsigned Rows, unsigned Cols, typename Type>
845
- inline CUDA_CALLABLE void adj_index(const mat_t<Rows,Cols,Type>& m, int row, mat_t<Rows,Cols,Type>& adj_m, int& adj_row, const vec_t<Cols,Type>& adj_ret)
904
+ inline CUDA_CALLABLE void adj_extract(const mat_t<Rows,Cols,Type>& m, int row, mat_t<Rows,Cols,Type>& adj_m, int& adj_row, const vec_t<Cols,Type>& adj_ret)
846
905
  {
847
906
  for( unsigned col=0; col < Cols; ++col )
848
907
  adj_m.data[row][col] += adj_ret[col];
849
908
  }
850
909
 
851
910
  template<unsigned Rows, unsigned Cols, typename Type>
852
- inline void CUDA_CALLABLE adj_index(const mat_t<Rows,Cols,Type>& m, int row, int col, mat_t<Rows,Cols,Type>& adj_m, int& adj_row, int& adj_col, Type adj_ret)
911
+ inline void CUDA_CALLABLE adj_extract(const mat_t<Rows,Cols,Type>& m, int row, int col, mat_t<Rows,Cols,Type>& adj_m, int& adj_row, int& adj_col, Type adj_ret)
853
912
  {
854
913
  #ifndef NDEBUG
855
914
  if (row < 0 || row > Rows)
@@ -913,6 +972,20 @@ inline CUDA_CALLABLE void adj_div(const mat_t<Rows,Cols,Type>& a, Type s, mat_t<
913
972
  }
914
973
  }
915
974
 
975
+ template<unsigned Rows, unsigned Cols, typename Type>
976
+ inline CUDA_CALLABLE void adj_div(Type s, const mat_t<Rows,Cols,Type>& a, Type& adj_s, mat_t<Rows,Cols,Type>& adj_a, const mat_t<Rows,Cols,Type>& adj_ret)
977
+ {
978
+ adj_s -= tensordot(a , adj_ret)/ (s * s); // - a / s^2
979
+
980
+ for (unsigned i=0; i < Rows; ++i)
981
+ {
982
+ for (unsigned j=0; j < Cols; ++j)
983
+ {
984
+ adj_a.data[i][j] += s / adj_ret.data[i][j];
985
+ }
986
+ }
987
+ }
988
+
916
989
  template<unsigned Rows, unsigned Cols, typename Type>
917
990
  inline CUDA_CALLABLE void adj_mul(const mat_t<Rows,Cols,Type>& a, Type b, mat_t<Rows,Cols,Type>& adj_a, Type& adj_b, const mat_t<Rows,Cols,Type>& adj_ret)
918
991
  {
@@ -946,6 +1019,13 @@ inline CUDA_CALLABLE void adj_mul(const mat_t<Rows,Cols,Type>& a, const vec_t<Co
946
1019
  adj_b += mul(transpose(a), adj_ret);
947
1020
  }
948
1021
 
1022
+ template<unsigned Rows, unsigned Cols, typename Type>
1023
+ inline CUDA_CALLABLE void adj_mul(const vec_t<Rows,Type>& b, const mat_t<Rows,Cols,Type>& a, vec_t<Rows,Type>& adj_b, mat_t<Rows,Cols,Type>& adj_a, const vec_t<Cols,Type>& adj_ret)
1024
+ {
1025
+ adj_a += outer(b, adj_ret);
1026
+ adj_b += mul(adj_ret, transpose(a));
1027
+ }
1028
+
949
1029
  template<unsigned Rows, unsigned Cols, unsigned ColsOut, typename Type>
950
1030
  inline CUDA_CALLABLE void adj_mul(const mat_t<Rows,Cols,Type>& a, const mat_t<Cols,ColsOut,Type>& b, mat_t<Rows,Cols,Type>& adj_a, mat_t<Cols,ColsOut,Type>& adj_b, const mat_t<Rows,ColsOut,Type>& adj_ret)
951
1031
  {
@@ -973,6 +1053,13 @@ inline CUDA_CALLABLE void adj_diag(const vec_t<Rows,Type>& d, vec_t<Rows,Type>&
973
1053
  adj_d[i] += adj_ret.data[i][i];
974
1054
  }
975
1055
 
1056
+ template<unsigned Rows, typename Type>
1057
+ inline CUDA_CALLABLE void adj_get_diag(const mat_t<Rows,Rows,Type>& m, mat_t<Rows,Rows,Type>& adj_m, const vec_t<Rows,Type>& adj_ret)
1058
+ {
1059
+ for (unsigned i=0; i < Rows; ++i)
1060
+ adj_m.data[i][i] += adj_ret[i];
1061
+ }
1062
+
976
1063
  template<typename Type>
977
1064
  inline CUDA_CALLABLE void adj_determinant(const mat_t<2,2,Type>& m, mat_t<2,2,Type>& adj_m, Type adj_ret)
978
1065
  {
@@ -1079,10 +1166,10 @@ inline CUDA_CALLABLE void adj_determinant(const mat_t<4,4,Type>& m, mat_t<4,4,Ty
1079
1166
  }
1080
1167
 
1081
1168
  template<unsigned Rows, typename Type>
1082
- inline CUDA_CALLABLE void adj_inverse(const mat_t<Rows,Rows,Type>& m, mat_t<Rows,Rows,Type>& adj_m, const mat_t<Rows,Rows,Type>& adj_ret)
1169
+ inline CUDA_CALLABLE void adj_inverse(const mat_t<Rows,Rows,Type>& m, mat_t<Rows,Rows,Type>& ret, mat_t<Rows,Rows,Type>& adj_m, const mat_t<Rows,Rows,Type>& adj_ret)
1083
1170
  {
1084
1171
  // todo: how to cache this from the forward pass?
1085
- mat_t<Rows,Rows,Type> invt = transpose(inverse(m));
1172
+ mat_t<Rows,Rows,Type> invt = transpose(ret);
1086
1173
 
1087
1174
  // see https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf 2.2.3
1088
1175
  adj_m -= mul(mul(invt, adj_ret), invt);
@@ -1124,10 +1211,10 @@ inline CUDA_CALLABLE void adj_cw_mul(const mat_t<Rows,Cols,Type>& a, const mat_t
1124
1211
  }
1125
1212
 
1126
1213
  template<unsigned Rows, unsigned Cols, typename Type>
1127
- inline CUDA_CALLABLE void adj_cw_div(const mat_t<Rows,Cols,Type>& a, const mat_t<Rows,Cols,Type>& b, mat_t<Rows,Cols,Type>& adj_a, mat_t<Rows,Cols,Type>& adj_b, const mat_t<Rows,Cols,Type>& adj_ret)
1214
+ inline CUDA_CALLABLE void adj_cw_div(const mat_t<Rows,Cols,Type>& a, const mat_t<Rows,Cols,Type>& b, mat_t<Rows,Cols,Type>& ret, mat_t<Rows,Cols,Type>& adj_a, mat_t<Rows,Cols,Type>& adj_b, const mat_t<Rows,Cols,Type>& adj_ret)
1128
1215
  {
1129
1216
  adj_a += cw_div(adj_ret, b);
1130
- adj_b -= cw_mul(adj_ret, cw_div(cw_div(a, b), b));
1217
+ adj_b -= cw_mul(adj_ret, cw_div(ret, b));
1131
1218
  }
1132
1219
 
1133
1220
  // adjoint for the constant constructor:
@@ -1143,6 +1230,19 @@ inline CUDA_CALLABLE void adj_mat_t(Type s, Type& adj_s, const mat_t<Rows, Cols,
1143
1230
  }
1144
1231
  }
1145
1232
 
1233
+ // adjoint for the casting constructor:
1234
+ template<unsigned Rows, unsigned Cols, typename Type, typename OtherType>
1235
+ inline CUDA_CALLABLE void adj_mat_t(const mat_t<Rows, Cols, OtherType>& other, mat_t<Rows, Cols, OtherType>& adj_other, const mat_t<Rows, Cols, Type>& adj_ret)
1236
+ {
1237
+ for (unsigned i=0; i < Rows; ++i)
1238
+ {
1239
+ for (unsigned j=0; j < Cols; ++j)
1240
+ {
1241
+ adj_other.data[i][j] += adj_ret.data[i][j];
1242
+ }
1243
+ }
1244
+ }
1245
+
1146
1246
  // adjoint for the initializer_array scalar constructor:
1147
1247
  template<unsigned Rows, unsigned Cols, typename Type>
1148
1248
  inline CUDA_CALLABLE void adj_mat_t(const initializer_array<Rows * Cols, Type> &cmps, const initializer_array<Rows * Cols, Type*> &adj_cmps, const mat_t<Rows, Cols, Type>& adj_ret)
warp/native/matnn.h CHANGED
@@ -248,7 +248,7 @@ CUDA_CALLABLE inline void adj_mlp(const array_t<float>& weights, const array_t<f
248
248
  tmp += weights.data[i*n + j]*x.data[index + b*j];
249
249
  }
250
250
 
251
- // adjoint w.r.t to acivation
251
+ // adjoint w.r.t to activation
252
252
  float adj_f = 0.0f;
253
253
 
254
254
  if (adj_out.data)
@@ -313,7 +313,7 @@ CUDA_CALLABLE inline void adj_mlp(const array_t<float>& weights, const array_t<f
313
313
  // tmp += weights[i*n + j]*x[index + b*j];
314
314
  // }
315
315
 
316
- // // adjoint w.r.t to acivation
316
+ // // adjoint w.r.t to activation
317
317
  // float adj_f = 0.0f;
318
318
  // adj_activation(tmp, adj_f, adj_out[index + b*i]);
319
319