warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/marching.cu CHANGED
@@ -1,8 +1,6 @@
1
1
  #include "warp.h"
2
2
  #include "cuda_util.h"
3
-
4
- #include "thrust/device_ptr.h"
5
- #include "thrust/sort.h"
3
+ #include "scan.h"
6
4
 
7
5
  namespace wp {
8
6
 
@@ -162,13 +160,17 @@ namespace wp {
162
160
  };
163
161
 
164
162
 
165
- // ---------------------------------------------------------------------------------------
166
- struct MarchingCubes
167
- {
168
- MarchingCubes()
163
+ // ---------------------------------------------------------------------------------------
164
+ struct MarchingCubes
165
+ {
166
+ MarchingCubes()
169
167
  {
170
- memset(this, 0, sizeof(MarchingCubes));
171
- }
168
+ memset(this, 0, sizeof(MarchingCubes));
169
+ first_cell_vert = nullptr;
170
+ first_cell_tri = nullptr;
171
+ cell_verts = nullptr;
172
+ context = nullptr;
173
+ }
172
174
 
173
175
  __device__ __host__ int cell_index(int xi, int yi, int zi) const
174
176
  {
@@ -181,169 +183,169 @@ namespace wp {
181
183
  xi = cell_index / ny;
182
184
  }
183
185
 
184
- // grid
186
+ // grid
185
187
  int nx;
186
188
  int ny;
187
189
  int nz;
188
190
 
189
- int* first_cell_vert;
190
- int* first_cell_tri;
191
- int* cell_verts;
191
+ int* first_cell_vert;
192
+ int* first_cell_tri;
193
+ int* cell_verts;
192
194
 
193
195
  int num_cells;
194
196
  int max_cells;
195
197
 
196
198
  void* context;
197
- };
198
-
199
-
200
- // -----------------------------------------------------------------------------------
201
- __global__ void count_cell_verts(MarchingCubes mc, const float* density, float threshold)
202
- {
203
- int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
204
- if (cell_index >= mc.num_cells)
205
- return;
206
-
207
- int xi, yi, zi;
208
- mc.cell_coord(cell_index, xi, yi, zi);
209
-
210
- mc.first_cell_vert[cell_index] = 0;
211
- if (xi >= mc.nx - 1 || yi >= mc.ny - 1 || zi >= mc.nz - 1)
212
- return;
213
-
214
- float d0 = density[cell_index];
215
- float dx = density[mc.cell_index(xi + 1, yi, zi)];
216
- float dy = density[mc.cell_index(xi, yi + 1, zi)];
217
- float dz = density[mc.cell_index(xi, yi, zi + 1)];
218
-
219
- int num = 0;
220
- if ((d0 <= threshold && dx >= threshold) || (dx <= threshold && d0 >= threshold))
221
- num++;
222
- if ((d0 <= threshold && dy >= threshold) || (dy <= threshold && d0 >= threshold))
223
- num++;
224
- if ((d0 <= threshold && dz >= threshold) || (dz <= threshold && d0 >= threshold))
225
- num++;
226
-
227
- mc.first_cell_vert[cell_index] = num;
228
- }
229
-
230
- // -----------------------------------------------------------------------------------
231
- __global__ void create_cell_verts(MarchingCubes mc, vec3* __restrict__ vertices, vec3* normals, const float* __restrict__ density, float threshold)
232
- {
233
- int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
234
- if (cell_index >= mc.num_cells)
235
- return;
236
-
237
- int xi, yi, zi;
238
- mc.cell_coord(cell_index, xi, yi, zi);
239
- if (xi >= mc.nx - 1 || yi >= mc.ny - 1 || zi >= mc.nz - 1)
240
- return;
241
-
242
- vec3 p = vec3(xi + 0.5f, yi + 0.5f, zi + 0.5f);
243
-
244
- float d0 = density[cell_index];
245
- float ds[3];
246
- ds[0] = density[mc.cell_index(xi + 1, yi, zi)];
247
- ds[1] = density[mc.cell_index(xi, yi + 1, zi)];
248
- ds[2] = density[mc.cell_index(xi, yi, zi + 1)];
249
-
250
- // vec3 n0 = densityNormal[cell_index];
251
- // vec3 ns[3];
252
- // ns[0] = densityNormal[mc.cell_index(xi + 1, yi, zi)];
253
- // ns[1] = densityNormal[mc.cell_index(xi, yi + 1, zi)];
254
- // ns[2] = densityNormal[mc.cell_index(xi, yi, zi + 1)];
255
-
256
- int first = mc.first_cell_vert[cell_index];
257
-
258
- for (int dim = 0; dim < 3; dim++)
199
+ };
200
+
201
+
202
+ // -----------------------------------------------------------------------------------
203
+ __global__ void count_cell_verts(MarchingCubes mc, const float* density, float threshold)
204
+ {
205
+ int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
206
+ if (cell_index >= mc.num_cells)
207
+ return;
208
+
209
+ int xi, yi, zi;
210
+ mc.cell_coord(cell_index, xi, yi, zi);
211
+
212
+ mc.first_cell_vert[cell_index] = 0;
213
+ if (xi >= mc.nx - 1 || yi >= mc.ny - 1 || zi >= mc.nz - 1)
214
+ return;
215
+
216
+ float d0 = density[cell_index];
217
+ float dx = density[mc.cell_index(xi + 1, yi, zi)];
218
+ float dy = density[mc.cell_index(xi, yi + 1, zi)];
219
+ float dz = density[mc.cell_index(xi, yi, zi + 1)];
220
+
221
+ int num = 0;
222
+ if ((d0 <= threshold && dx >= threshold) || (dx <= threshold && d0 >= threshold))
223
+ num++;
224
+ if ((d0 <= threshold && dy >= threshold) || (dy <= threshold && d0 >= threshold))
225
+ num++;
226
+ if ((d0 <= threshold && dz >= threshold) || (dz <= threshold && d0 >= threshold))
227
+ num++;
228
+
229
+ mc.first_cell_vert[cell_index] = num;
230
+ }
231
+
232
+ // -----------------------------------------------------------------------------------
233
+ __global__ void create_cell_verts(MarchingCubes mc, vec3* __restrict__ vertices, vec3* normals, const float* __restrict__ density, float threshold)
234
+ {
235
+ int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
236
+ if (cell_index >= mc.num_cells)
237
+ return;
238
+
239
+ int xi, yi, zi;
240
+ mc.cell_coord(cell_index, xi, yi, zi);
241
+ if (xi >= mc.nx - 1 || yi >= mc.ny - 1 || zi >= mc.nz - 1)
242
+ return;
243
+
244
+ vec3 p = vec3(xi + 0.5f, yi + 0.5f, zi + 0.5f);
245
+
246
+ float d0 = density[cell_index];
247
+ float ds[3];
248
+ ds[0] = density[mc.cell_index(xi + 1, yi, zi)];
249
+ ds[1] = density[mc.cell_index(xi, yi + 1, zi)];
250
+ ds[2] = density[mc.cell_index(xi, yi, zi + 1)];
251
+
252
+ // vec3 n0 = densityNormal[cell_index];
253
+ // vec3 ns[3];
254
+ // ns[0] = densityNormal[mc.cell_index(xi + 1, yi, zi)];
255
+ // ns[1] = densityNormal[mc.cell_index(xi, yi + 1, zi)];
256
+ // ns[2] = densityNormal[mc.cell_index(xi, yi, zi + 1)];
257
+
258
+ int first = mc.first_cell_vert[cell_index];
259
+
260
+ for (int dim = 0; dim < 3; dim++)
259
261
  {
260
- float d = ds[dim];
261
- mc.cell_verts[3 * cell_index + dim] = 0;
262
+ float d = ds[dim];
263
+ mc.cell_verts[3 * cell_index + dim] = 0;
262
264
 
263
- if ((d0 <= threshold && d >= threshold) || (d <= threshold && d0 >= threshold))
265
+ if ((d0 <= threshold && d >= threshold) || (d <= threshold && d0 >= threshold))
264
266
  {
265
- float t = (d != d0) ? clamp((threshold - d0) / (d - d0), 0.0f, 1.0f) : 0.5f;
266
- int id = first++;
267
-
268
- vec3 off;
269
- off[dim] = t;
270
- vertices[id] = p + off;
271
-
272
- // vec3 n = normalize(n0 + t * (ns[dim] - n0));
273
- // normals[id] = -n;
274
-
275
- mc.cell_verts[3 * cell_index + dim] = id;
276
- }
277
- }
278
- }
279
-
280
- // -----------------------------------------------------------------------------------
281
- __global__ void count_cell_tris(MarchingCubes mc, const float* __restrict__ density, float threshold)
282
- {
283
- int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
284
- if (cell_index >= mc.num_cells)
285
- return;
286
-
287
- int xi, yi, zi;
288
- mc.cell_coord(cell_index, xi, yi, zi);
289
-
290
- mc.first_cell_tri[cell_index] = 0;
291
- if (xi >= mc.nx - 2 || yi >= mc.ny - 2 || zi >= mc.nz - 2)
292
- return;
293
-
294
- int code = 0;
295
- for (int i = 0; i < 8; i++) {
296
- int cxi = xi + marchingCubeCorners[i][0];
297
- int cyi = yi + marchingCubeCorners[i][1];
298
- int czi = zi + marchingCubeCorners[i][2];
299
-
300
- if (density[mc.cell_index(cxi, cyi, czi)] >= threshold)
301
- code |= (1 << i);
302
- }
303
-
304
- mc.first_cell_tri[cell_index] = firstMarchingCubesId[code + 1] - firstMarchingCubesId[code];
305
- }
306
-
307
- // -----------------------------------------------------------------------------------
308
- __global__ void create_cell_tris(MarchingCubes mc, const float* __restrict__ density, int* __restrict__ triangles, float threshold)
309
- {
310
- int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
311
- if (cell_index >= mc.num_cells)
312
- return;
313
-
314
- int xi, yi, zi;
315
- mc.cell_coord(cell_index, xi, yi, zi);
316
- if (xi >= mc.nx - 2 || yi >= mc.ny - 2 || zi >= mc.nz - 2)
317
- return;
318
-
319
- int code = 0;
320
- for (int i = 0; i < 8; i++)
267
+ float t = (d != d0) ? clamp((threshold - d0) / (d - d0), 0.0f, 1.0f) : 0.5f;
268
+ int id = first++;
269
+
270
+ vec3 off;
271
+ off[dim] = t;
272
+ vertices[id] = p + off;
273
+
274
+ // vec3 n = normalize(n0 + t * (ns[dim] - n0));
275
+ // normals[id] = -n;
276
+
277
+ mc.cell_verts[3 * cell_index + dim] = id;
278
+ }
279
+ }
280
+ }
281
+
282
+ // -----------------------------------------------------------------------------------
283
+ __global__ void count_cell_tris(MarchingCubes mc, const float* __restrict__ density, float threshold)
284
+ {
285
+ int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
286
+ if (cell_index >= mc.num_cells)
287
+ return;
288
+
289
+ int xi, yi, zi;
290
+ mc.cell_coord(cell_index, xi, yi, zi);
291
+
292
+ mc.first_cell_tri[cell_index] = 0;
293
+ if (xi >= mc.nx - 2 || yi >= mc.ny - 2 || zi >= mc.nz - 2)
294
+ return;
295
+
296
+ int code = 0;
297
+ for (int i = 0; i < 8; i++) {
298
+ int cxi = xi + marchingCubeCorners[i][0];
299
+ int cyi = yi + marchingCubeCorners[i][1];
300
+ int czi = zi + marchingCubeCorners[i][2];
301
+
302
+ if (density[mc.cell_index(cxi, cyi, czi)] >= threshold)
303
+ code |= (1 << i);
304
+ }
305
+
306
+ mc.first_cell_tri[cell_index] = firstMarchingCubesId[code + 1] - firstMarchingCubesId[code];
307
+ }
308
+
309
+ // -----------------------------------------------------------------------------------
310
+ __global__ void create_cell_tris(MarchingCubes mc, const float* __restrict__ density, int* __restrict__ triangles, float threshold)
311
+ {
312
+ int cell_index = blockIdx.x * blockDim.x + threadIdx.x;
313
+ if (cell_index >= mc.num_cells)
314
+ return;
315
+
316
+ int xi, yi, zi;
317
+ mc.cell_coord(cell_index, xi, yi, zi);
318
+ if (xi >= mc.nx - 2 || yi >= mc.ny - 2 || zi >= mc.nz - 2)
319
+ return;
320
+
321
+ int code = 0;
322
+ for (int i = 0; i < 8; i++)
321
323
  {
322
- int cxi = xi + marchingCubeCorners[i][0];
323
- int cyi = yi + marchingCubeCorners[i][1];
324
- int czi = zi + marchingCubeCorners[i][2];
324
+ int cxi = xi + marchingCubeCorners[i][0];
325
+ int cyi = yi + marchingCubeCorners[i][1];
326
+ int czi = zi + marchingCubeCorners[i][2];
325
327
 
326
- if (density[mc.cell_index(cxi, cyi, czi)] >= threshold)
327
- code |= (1 << i);
328
- }
328
+ if (density[mc.cell_index(cxi, cyi, czi)] >= threshold)
329
+ code |= (1 << i);
330
+ }
329
331
 
330
- int firstIn = firstMarchingCubesId[code];
331
- int num = firstMarchingCubesId[code + 1] - firstIn;
332
- int firstOut = mc.first_cell_tri[cell_index];
332
+ int firstIn = firstMarchingCubesId[code];
333
+ int num = firstMarchingCubesId[code + 1] - firstIn;
334
+ int firstOut = mc.first_cell_tri[cell_index];
333
335
 
334
- for (int i = 0; i < num; i++)
336
+ for (int i = 0; i < num; i++)
335
337
  {
336
- int eid = marchingCubesIds[firstIn + i];
338
+ int eid = marchingCubesIds[firstIn + i];
337
339
 
338
- int exi = xi + marchingCubesEdgeLocations[eid][0];
339
- int eyi = yi + marchingCubesEdgeLocations[eid][1];
340
- int ezi = zi + marchingCubesEdgeLocations[eid][2];
341
- int edgeNr = marchingCubesEdgeLocations[eid][3];
340
+ int exi = xi + marchingCubesEdgeLocations[eid][0];
341
+ int eyi = yi + marchingCubesEdgeLocations[eid][1];
342
+ int ezi = zi + marchingCubesEdgeLocations[eid][2];
343
+ int edgeNr = marchingCubesEdgeLocations[eid][3];
342
344
 
343
- int id = mc.cell_verts[3 * mc.cell_index(exi, eyi, ezi) + edgeNr];
344
- triangles[firstOut + i] = id;
345
- }
346
- }
345
+ int id = mc.cell_verts[3 * mc.cell_index(exi, eyi, ezi) + edgeNr];
346
+ triangles[firstOut + i] = id;
347
+ }
348
+ }
347
349
 
348
350
  // -------------------------
349
351
  void marching_cubes_resize(MarchingCubes& mc, int nx, int ny, int nz)
@@ -444,10 +446,7 @@ WP_API int marching_cubes_surface_device(
444
446
  int num_last;
445
447
  memcpy_d2h(WP_CURRENT_CONTEXT, &num_last, &mc.first_cell_vert[mc.num_cells - 1], sizeof(int));
446
448
 
447
- thrust::exclusive_scan(
448
- thrust::device_ptr<int>(mc.first_cell_vert),
449
- thrust::device_ptr<int>(mc.first_cell_vert + mc.num_cells),
450
- thrust::device_ptr<int>(mc.first_cell_vert));
449
+ scan_device(mc.first_cell_vert, mc.first_cell_vert, mc.num_cells, false);
451
450
 
452
451
  int num_verts;
453
452
  memcpy_d2h(WP_CURRENT_CONTEXT, &num_verts, &mc.first_cell_vert[mc.num_cells - 1], sizeof(int));
@@ -472,10 +471,7 @@ WP_API int marching_cubes_surface_device(
472
471
 
473
472
  memcpy_d2h(WP_CURRENT_CONTEXT, &num_last, &mc.first_cell_tri[mc.num_cells - 1], sizeof(int));
474
473
 
475
- thrust::exclusive_scan(
476
- thrust::device_ptr<int>(mc.first_cell_tri),
477
- thrust::device_ptr<int>(mc.first_cell_tri + mc.num_cells),
478
- thrust::device_ptr<int>(mc.first_cell_tri));
474
+ scan_device(mc.first_cell_tri, mc.first_cell_tri, mc.num_cells, false);
479
475
 
480
476
 
481
477
  int num_indices;
warp/native/mat.h CHANGED
@@ -21,7 +21,9 @@ struct quat_t;
21
21
  template<unsigned Rows, unsigned Cols, typename Type>
22
22
  struct mat_t
23
23
  {
24
- inline mat_t() = default;
24
+ inline CUDA_CALLABLE mat_t()
25
+ : data()
26
+ {}
25
27
 
26
28
  inline CUDA_CALLABLE mat_t(Type s)
27
29
  {
@@ -193,7 +195,7 @@ struct mat_t
193
195
  }
194
196
 
195
197
  // row major storage assumed to be compatible with PyTorch
196
- Type data[Rows][Cols] = {};
198
+ Type data[Rows][Cols];
197
199
  };
198
200
 
199
201
 
@@ -298,7 +300,19 @@ inline CUDA_CALLABLE mat_t<Rows,Cols,Type> atomic_max(mat_t<Rows,Cols,Type> * ad
298
300
  }
299
301
 
300
302
  template<unsigned Rows, unsigned Cols, typename Type>
301
- inline CUDA_CALLABLE vec_t<Cols,Type> index(const mat_t<Rows,Cols,Type>& m, int row)
303
+ inline CUDA_CALLABLE void adj_atomic_minmax(
304
+ mat_t<Rows,Cols,Type> *addr,
305
+ mat_t<Rows,Cols,Type> *adj_addr,
306
+ const mat_t<Rows,Cols,Type> &value,
307
+ mat_t<Rows,Cols,Type> &adj_value)
308
+ {
309
+ for (unsigned i=0; i < Rows; ++i)
310
+ for (unsigned j=0; j < Cols; ++j)
311
+ adj_atomic_minmax(&addr->data[i][j], &adj_addr->data[i][j], value.data[i][j], adj_value.data[i][j]);
312
+ }
313
+
314
+ template<unsigned Rows, unsigned Cols, typename Type>
315
+ inline CUDA_CALLABLE vec_t<Cols,Type> extract(const mat_t<Rows,Cols,Type>& m, int row)
302
316
  {
303
317
  vec_t<Cols,Type> ret;
304
318
  for(unsigned i=0; i < Cols; ++i)
@@ -309,7 +323,7 @@ inline CUDA_CALLABLE vec_t<Cols,Type> index(const mat_t<Rows,Cols,Type>& m, int
309
323
  }
310
324
 
311
325
  template<unsigned Rows, unsigned Cols, typename Type>
312
- inline CUDA_CALLABLE Type index(const mat_t<Rows,Cols,Type>& m, int row, int col)
326
+ inline CUDA_CALLABLE Type extract(const mat_t<Rows,Cols,Type>& m, int row, int col)
313
327
  {
314
328
  #ifndef NDEBUG
315
329
  if (row < 0 || row >= Rows)
@@ -327,7 +341,7 @@ inline CUDA_CALLABLE Type index(const mat_t<Rows,Cols,Type>& m, int row, int col
327
341
  }
328
342
 
329
343
  template<unsigned Rows, unsigned Cols, typename Type>
330
- inline CUDA_CALLABLE void indexset(mat_t<Rows,Cols,Type>& m, int row, vec_t<Cols, Type> value)
344
+ inline CUDA_CALLABLE vec_t<Cols, Type>* index(mat_t<Rows,Cols,Type>& m, int row)
331
345
  {
332
346
  #ifndef NDEBUG
333
347
  if (row < 0 || row >= Rows)
@@ -337,12 +351,11 @@ inline CUDA_CALLABLE void indexset(mat_t<Rows,Cols,Type>& m, int row, vec_t<Cols
337
351
  }
338
352
  #endif
339
353
 
340
- for(unsigned i=0; i < Cols; ++i)
341
- m.data[row][i] = value[i];
354
+ return reinterpret_cast<vec_t<Cols, Type>*>(&m.data[row]);
342
355
  }
343
356
 
344
357
  template<unsigned Rows, unsigned Cols, typename Type>
345
- inline CUDA_CALLABLE void indexset(mat_t<Rows,Cols,Type>& m, int row, int col, Type value)
358
+ inline CUDA_CALLABLE Type* index(mat_t<Rows,Cols,Type>& m, int row, int col)
346
359
  {
347
360
  #ifndef NDEBUG
348
361
  if (row < 0 || row >= Rows)
@@ -356,18 +369,19 @@ inline CUDA_CALLABLE void indexset(mat_t<Rows,Cols,Type>& m, int row, int col, T
356
369
  assert(0);
357
370
  }
358
371
  #endif
359
- m.data[row][col] = value;
372
+
373
+ return &m.data[row][col];
360
374
  }
361
375
 
362
376
  template<unsigned Rows, unsigned Cols, typename Type>
363
- inline CUDA_CALLABLE void adj_indexset(const mat_t<Rows,Cols,Type>& m, int row, const vec_t<Cols, Type>& value,
377
+ inline CUDA_CALLABLE void adj_index(const mat_t<Rows,Cols,Type>& m, int row,
364
378
  const mat_t<Rows,Cols,Type>& adj_m, int adj_row, const vec_t<Cols, Type>& adj_value)
365
379
  {
366
380
  // nop
367
381
  }
368
382
 
369
383
  template<unsigned Rows, unsigned Cols, typename Type>
370
- inline CUDA_CALLABLE void adj_indexset(const mat_t<Rows,Cols,Type>& m, int row, int col, Type value,
384
+ inline CUDA_CALLABLE void adj_index(const mat_t<Rows,Cols,Type>& m, int row, int col,
371
385
  const mat_t<Rows,Cols,Type>& adj_m, int adj_row, int adj_col, Type adj_value)
372
386
  {
373
387
  // nop
@@ -425,7 +439,22 @@ inline CUDA_CALLABLE mat_t<Rows,Cols,Type> div(const mat_t<Rows,Cols,Type>& a, T
425
439
  }
426
440
  }
427
441
 
428
- return t;
442
+ return t;
443
+ }
444
+
445
+ template<unsigned Rows, unsigned Cols, typename Type>
446
+ inline CUDA_CALLABLE mat_t<Rows,Cols,Type> div(Type b, const mat_t<Rows,Cols,Type>& a)
447
+ {
448
+ mat_t<Rows,Cols,Type> t;
449
+ for (unsigned i=0; i < Rows; ++i)
450
+ {
451
+ for (unsigned j=0; j < Cols; ++j)
452
+ {
453
+ t.data[i][j] = b / a.data[i][j];
454
+ }
455
+ }
456
+
457
+ return t;
429
458
  }
430
459
 
431
460
  template<unsigned Rows, unsigned Cols, typename Type>
@@ -440,7 +469,7 @@ inline CUDA_CALLABLE mat_t<Rows,Cols,Type> mul(const mat_t<Rows,Cols,Type>& a, T
440
469
  }
441
470
  }
442
471
 
443
- return t;
472
+ return t;
444
473
  }
445
474
 
446
475
  template<unsigned Rows, unsigned Cols, typename Type>
@@ -473,6 +502,17 @@ inline CUDA_CALLABLE vec_t<Rows,Type> mul(const mat_t<Rows,Cols,Type>& a, const
473
502
  return r;
474
503
  }
475
504
 
505
+ template<unsigned Rows, unsigned Cols, typename Type>
506
+ inline CUDA_CALLABLE vec_t<Cols,Type> mul(const vec_t<Rows,Type>& b, const mat_t<Rows,Cols,Type>& a)
507
+ {
508
+ vec_t<Cols,Type> r = a.get_row(0)*b[0];
509
+ for( unsigned i=1; i < Rows; ++i )
510
+ {
511
+ r += a.get_row(i)*b[i];
512
+ }
513
+ return r;
514
+ }
515
+
476
516
  template<unsigned Rows, unsigned Cols, unsigned ColsOut, typename Type>
477
517
  inline CUDA_CALLABLE mat_t<Rows,ColsOut,Type> mul(const mat_t<Rows,Cols,Type>& a, const mat_t<Cols,ColsOut,Type>& b)
478
518
  {
@@ -861,14 +901,14 @@ inline CUDA_CALLABLE vec_t<3,Type> transform_vector(const mat_t<4,4,Type>& m, co
861
901
  }
862
902
 
863
903
  template<unsigned Rows, unsigned Cols, typename Type>
864
- inline CUDA_CALLABLE void adj_index(const mat_t<Rows,Cols,Type>& m, int row, mat_t<Rows,Cols,Type>& adj_m, int& adj_row, const vec_t<Cols,Type>& adj_ret)
904
+ inline CUDA_CALLABLE void adj_extract(const mat_t<Rows,Cols,Type>& m, int row, mat_t<Rows,Cols,Type>& adj_m, int& adj_row, const vec_t<Cols,Type>& adj_ret)
865
905
  {
866
906
  for( unsigned col=0; col < Cols; ++col )
867
907
  adj_m.data[row][col] += adj_ret[col];
868
908
  }
869
909
 
870
910
  template<unsigned Rows, unsigned Cols, typename Type>
871
- inline void CUDA_CALLABLE adj_index(const mat_t<Rows,Cols,Type>& m, int row, int col, mat_t<Rows,Cols,Type>& adj_m, int& adj_row, int& adj_col, Type adj_ret)
911
+ inline void CUDA_CALLABLE adj_extract(const mat_t<Rows,Cols,Type>& m, int row, int col, mat_t<Rows,Cols,Type>& adj_m, int& adj_row, int& adj_col, Type adj_ret)
872
912
  {
873
913
  #ifndef NDEBUG
874
914
  if (row < 0 || row > Rows)
@@ -932,6 +972,20 @@ inline CUDA_CALLABLE void adj_div(const mat_t<Rows,Cols,Type>& a, Type s, mat_t<
932
972
  }
933
973
  }
934
974
 
975
+ template<unsigned Rows, unsigned Cols, typename Type>
976
+ inline CUDA_CALLABLE void adj_div(Type s, const mat_t<Rows,Cols,Type>& a, Type& adj_s, mat_t<Rows,Cols,Type>& adj_a, const mat_t<Rows,Cols,Type>& adj_ret)
977
+ {
978
+ adj_s -= tensordot(a , adj_ret)/ (s * s); // - a / s^2
979
+
980
+ for (unsigned i=0; i < Rows; ++i)
981
+ {
982
+ for (unsigned j=0; j < Cols; ++j)
983
+ {
984
+ adj_a.data[i][j] += s / adj_ret.data[i][j];
985
+ }
986
+ }
987
+ }
988
+
935
989
  template<unsigned Rows, unsigned Cols, typename Type>
936
990
  inline CUDA_CALLABLE void adj_mul(const mat_t<Rows,Cols,Type>& a, Type b, mat_t<Rows,Cols,Type>& adj_a, Type& adj_b, const mat_t<Rows,Cols,Type>& adj_ret)
937
991
  {
@@ -965,6 +1019,13 @@ inline CUDA_CALLABLE void adj_mul(const mat_t<Rows,Cols,Type>& a, const vec_t<Co
965
1019
  adj_b += mul(transpose(a), adj_ret);
966
1020
  }
967
1021
 
1022
+ template<unsigned Rows, unsigned Cols, typename Type>
1023
+ inline CUDA_CALLABLE void adj_mul(const vec_t<Rows,Type>& b, const mat_t<Rows,Cols,Type>& a, vec_t<Rows,Type>& adj_b, mat_t<Rows,Cols,Type>& adj_a, const vec_t<Cols,Type>& adj_ret)
1024
+ {
1025
+ adj_a += outer(b, adj_ret);
1026
+ adj_b += mul(adj_ret, transpose(a));
1027
+ }
1028
+
968
1029
  template<unsigned Rows, unsigned Cols, unsigned ColsOut, typename Type>
969
1030
  inline CUDA_CALLABLE void adj_mul(const mat_t<Rows,Cols,Type>& a, const mat_t<Cols,ColsOut,Type>& b, mat_t<Rows,Cols,Type>& adj_a, mat_t<Cols,ColsOut,Type>& adj_b, const mat_t<Rows,ColsOut,Type>& adj_ret)
970
1031
  {
@@ -1105,10 +1166,10 @@ inline CUDA_CALLABLE void adj_determinant(const mat_t<4,4,Type>& m, mat_t<4,4,Ty
1105
1166
  }
1106
1167
 
1107
1168
  template<unsigned Rows, typename Type>
1108
- inline CUDA_CALLABLE void adj_inverse(const mat_t<Rows,Rows,Type>& m, mat_t<Rows,Rows,Type>& adj_m, const mat_t<Rows,Rows,Type>& adj_ret)
1169
+ inline CUDA_CALLABLE void adj_inverse(const mat_t<Rows,Rows,Type>& m, mat_t<Rows,Rows,Type>& ret, mat_t<Rows,Rows,Type>& adj_m, const mat_t<Rows,Rows,Type>& adj_ret)
1109
1170
  {
1110
1171
  // todo: how to cache this from the forward pass?
1111
- mat_t<Rows,Rows,Type> invt = transpose(inverse(m));
1172
+ mat_t<Rows,Rows,Type> invt = transpose(ret);
1112
1173
 
1113
1174
  // see https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf 2.2.3
1114
1175
  adj_m -= mul(mul(invt, adj_ret), invt);
@@ -1150,10 +1211,10 @@ inline CUDA_CALLABLE void adj_cw_mul(const mat_t<Rows,Cols,Type>& a, const mat_t
1150
1211
  }
1151
1212
 
1152
1213
  template<unsigned Rows, unsigned Cols, typename Type>
1153
- inline CUDA_CALLABLE void adj_cw_div(const mat_t<Rows,Cols,Type>& a, const mat_t<Rows,Cols,Type>& b, mat_t<Rows,Cols,Type>& adj_a, mat_t<Rows,Cols,Type>& adj_b, const mat_t<Rows,Cols,Type>& adj_ret)
1214
+ inline CUDA_CALLABLE void adj_cw_div(const mat_t<Rows,Cols,Type>& a, const mat_t<Rows,Cols,Type>& b, mat_t<Rows,Cols,Type>& ret, mat_t<Rows,Cols,Type>& adj_a, mat_t<Rows,Cols,Type>& adj_b, const mat_t<Rows,Cols,Type>& adj_ret)
1154
1215
  {
1155
1216
  adj_a += cw_div(adj_ret, b);
1156
- adj_b -= cw_mul(adj_ret, cw_div(cw_div(a, b), b));
1217
+ adj_b -= cw_mul(adj_ret, cw_div(ret, b));
1157
1218
  }
1158
1219
 
1159
1220
  // adjoint for the constant constructor:
warp/native/matnn.h CHANGED
@@ -248,7 +248,7 @@ CUDA_CALLABLE inline void adj_mlp(const array_t<float>& weights, const array_t<f
248
248
  tmp += weights.data[i*n + j]*x.data[index + b*j];
249
249
  }
250
250
 
251
- // adjoint w.r.t to acivation
251
+ // adjoint w.r.t to activation
252
252
  float adj_f = 0.0f;
253
253
 
254
254
  if (adj_out.data)
@@ -313,7 +313,7 @@ CUDA_CALLABLE inline void adj_mlp(const array_t<float>& weights, const array_t<f
313
313
  // tmp += weights[i*n + j]*x[index + b*j];
314
314
  // }
315
315
 
316
- // // adjoint w.r.t to acivation
316
+ // // adjoint w.r.t to activation
317
317
  // float adj_f = 0.0f;
318
318
  // adj_activation(tmp, adj_f, adj_out[index + b*i]);
319
319