warp-lang 1.0.0b2-py3-none-manylinux2014_x86_64.whl → 1.0.0b6-py3-none-manylinux2014_x86_64.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (269)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.so +0 -0
  57. warp/bin/warp.so +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/fem/field/discrete_field.py +0 -80
  257. warp/fem/space/nodal_function_space.py +0 -233
  258. warp/tests/test_all.py +0 -223
  259. warp/tests/test_array_scan.py +0 -60
  260. warp/tests/test_base.py +0 -208
  261. warp/tests/test_unresolved_func.py +0 -7
  262. warp/tests/test_unresolved_symbol.py +0 -7
  263. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  264. warp_lang-1.0.0b2.dist-info/RECORD +0 -378
  265. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  266. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  267. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  268. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  269. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/native/array.h CHANGED
@@ -19,6 +19,12 @@ namespace wp
     printf(")\n"); \
     assert(0); \
 
+#define FP_VERIFY_FWD(value) \
+    if (!isfinite(value)) { \
+        printf("%s:%d - %s(addr", __FILE__, __LINE__, __FUNCTION__); \
+        FP_ASSERT_FWD(value) \
+    } \
+
 #define FP_VERIFY_FWD_1(value) \
     if (!isfinite(value)) { \
         printf("%s:%d - %s(arr, %d) ", __FILE__, __LINE__, __FUNCTION__, i); \
@@ -43,6 +49,13 @@ namespace wp
         FP_ASSERT_FWD(value) \
     } \
 
+#define FP_VERIFY_ADJ(value, adj_value) \
+    if (!isfinite(value) || !isfinite(adj_value)) \
+    { \
+        printf("%s:%d - %s(addr", __FILE__, __LINE__, __FUNCTION__); \
+        FP_ASSERT_ADJ(value, adj_value); \
+    } \
+
 #define FP_VERIFY_ADJ_1(value, adj_value) \
     if (!isfinite(value) || !isfinite(adj_value)) \
     { \
@@ -74,11 +87,13 @@ namespace wp
 
 #else
 
+#define FP_VERIFY_FWD(value) {}
 #define FP_VERIFY_FWD_1(value) {}
 #define FP_VERIFY_FWD_2(value) {}
 #define FP_VERIFY_FWD_3(value) {}
 #define FP_VERIFY_FWD_4(value) {}
 
+#define FP_VERIFY_ADJ(value, adj_value) {}
 #define FP_VERIFY_ADJ_1(value, adj_value) {}
 #define FP_VERIFY_ADJ_2(value, adj_value) {}
 #define FP_VERIFY_ADJ_3(value, adj_value) {}
@@ -113,12 +128,12 @@ struct shape_t
     }
 };
 
-CUDA_CALLABLE inline int index(const shape_t& s, int i)
+CUDA_CALLABLE inline int extract(const shape_t& s, int i)
 {
     return s.dims[i];
 }
 
-CUDA_CALLABLE inline void adj_index(const shape_t& s, int i, const shape_t& adj_s, int adj_i, int adj_ret) {}
+CUDA_CALLABLE inline void adj_extract(const shape_t& s, int i, const shape_t& adj_s, int adj_i, int adj_ret) {}
 
 inline CUDA_CALLABLE void print(shape_t s)
 {
@@ -670,43 +685,60 @@ template<template<typename> class A, typename T>
 inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_max(&index(buf, i, j, k, l), value); }
 
 template<template<typename> class A, typename T>
-inline CUDA_CALLABLE T load(const A<T>& buf, int i) { return index(buf, i); }
+inline CUDA_CALLABLE T* address(const A<T>& buf, int i) { return &index(buf, i); }
 template<template<typename> class A, typename T>
-inline CUDA_CALLABLE T load(const A<T>& buf, int i, int j) { return index(buf, i, j); }
+inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j) { return &index(buf, i, j); }
 template<template<typename> class A, typename T>
-inline CUDA_CALLABLE T load(const A<T>& buf, int i, int j, int k) { return index(buf, i, j, k); }
+inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j, int k) { return &index(buf, i, j, k); }
 template<template<typename> class A, typename T>
-inline CUDA_CALLABLE T load(const A<T>& buf, int i, int j, int k, int l) { return index(buf, i, j, k, l); }
+inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j, int k, int l) { return &index(buf, i, j, k, l); }
 
 template<template<typename> class A, typename T>
-inline CUDA_CALLABLE void store(const A<T>& buf, int i, T value)
+inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, T value)
 {
     FP_VERIFY_FWD_1(value)
 
     index(buf, i) = value;
 }
 template<template<typename> class A, typename T>
-inline CUDA_CALLABLE void store(const A<T>& buf, int i, int j, T value)
+inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, T value)
 {
     FP_VERIFY_FWD_2(value)
 
     index(buf, i, j) = value;
 }
 template<template<typename> class A, typename T>
-inline CUDA_CALLABLE void store(const A<T>& buf, int i, int j, int k, T value)
+inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, int k, T value)
 {
     FP_VERIFY_FWD_3(value)
 
     index(buf, i, j, k) = value;
 }
 template<template<typename> class A, typename T>
-inline CUDA_CALLABLE void store(const A<T>& buf, int i, int j, int k, int l, T value)
+inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, int k, int l, T value)
 {
     FP_VERIFY_FWD_4(value)
 
     index(buf, i, j, k, l) = value;
 }
 
+template<typename T>
+inline CUDA_CALLABLE void store(T* address, T value)
+{
+    FP_VERIFY_FWD(value)
+
+    *address = value;
+}
+
+template<typename T>
+inline CUDA_CALLABLE T load(T* address)
+{
+    T value = *address;
+    FP_VERIFY_FWD(value)
+
+    return value;
+}
+
 // select operator to check for array being null
 template <typename T1, typename T2>
 CUDA_CALLABLE inline T2 select(const array_t<T1>& arr, const T2& a, const T2& b) { return arr.data?b:a; }
@@ -744,32 +776,32 @@ CUDA_CALLABLE inline void adj_atomic_add(bool* buf, bool value) { }
 
 // only generate gradients for T types
 template<typename T>
-inline CUDA_CALLABLE void adj_load(const array_t<T>& buf, int i, const array_t<T>& adj_buf, int& adj_i, const T& adj_output)
+inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, const array_t<T>& adj_buf, int& adj_i, const T& adj_output)
 {
     if (buf.grad)
         adj_atomic_add(&index_grad(buf, i), adj_output);
 }
 template<typename T>
-inline CUDA_CALLABLE void adj_load(const array_t<T>& buf, int i, int j, const array_t<T>& adj_buf, int& adj_i, int& adj_j, const T& adj_output)
+inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, const array_t<T>& adj_buf, int& adj_i, int& adj_j, const T& adj_output)
 {
     if (buf.grad)
         adj_atomic_add(&index_grad(buf, i, j), adj_output);
 }
 template<typename T>
-inline CUDA_CALLABLE void adj_load(const array_t<T>& buf, int i, int j, int k, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, const T& adj_output)
+inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, int k, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, const T& adj_output)
 {
     if (buf.grad)
         adj_atomic_add(&index_grad(buf, i, j, k), adj_output);
 }
 template<typename T>
-inline CUDA_CALLABLE void adj_load(const array_t<T>& buf, int i, int j, int k, int l, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, const T& adj_output)
+inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, int k, int l, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, const T& adj_output)
 {
     if (buf.grad)
         adj_atomic_add(&index_grad(buf, i, j, k, l), adj_output);
 }
 
 template<typename T>
-inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int& adj_i, T& adj_value)
+inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int& adj_i, T& adj_value)
 {
     if (buf.grad)
         adj_value += index_grad(buf, i);
@@ -777,7 +809,7 @@ inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, T value, const
     FP_VERIFY_ADJ_1(value, adj_value)
 }
 template<typename T>
-inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value)
+inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value)
 {
     if (buf.grad)
         adj_value += index_grad(buf, i, j);
@@ -786,7 +818,7 @@ inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, T value
 
 }
 template<typename T>
-inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, int k, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value)
+inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, int k, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value)
 {
     if (buf.grad)
         adj_value += index_grad(buf, i, j, k);
@@ -794,7 +826,7 @@ inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, int k,
     FP_VERIFY_ADJ_3(value, adj_value)
 }
 template<typename T>
-inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, int k, int l, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value)
+inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, int k, int l, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value)
 {
     if (buf.grad)
         adj_value += index_grad(buf, i, j, k, l);
@@ -802,6 +834,19 @@ inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, int k,
     FP_VERIFY_ADJ_4(value, adj_value)
 }
 
+template<typename T>
+inline CUDA_CALLABLE void adj_store(const T* address, T value, const T& adj_address, T& adj_value)
+{
+    // nop; generic store() operations are not differentiable, only array_store() is
+    FP_VERIFY_ADJ(value, adj_value)
+}
+
+template<typename T>
+inline CUDA_CALLABLE void adj_load(const T* address, const T& adj_address, T& adj_value)
+{
+    // nop; generic load() operations are not differentiable
+}
+
 template<typename T>
 inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret)
 {
@@ -871,22 +916,22 @@ inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, int j, in
 
 // generic array types that do not support gradient computation (indexedarray, etc.)
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_load(const A1<T>& buf, int i, const A2<T>& adj_buf, int& adj_i, const T& adj_output) {}
+inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, const A2<T>& adj_buf, int& adj_i, const T& adj_output) {}
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_load(const A1<T>& buf, int i, int j, const A2<T>& adj_buf, int& adj_i, int& adj_j, const T& adj_output) {}
+inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, const A2<T>& adj_buf, int& adj_i, int& adj_j, const T& adj_output) {}
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_load(const A1<T>& buf, int i, int j, int k, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, const T& adj_output) {}
+inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, int k, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, const T& adj_output) {}
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_load(const A1<T>& buf, int i, int j, int k, int l, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, const T& adj_output) {}
+inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, int k, int l, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, const T& adj_output) {}
 
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_store(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value) {}
+inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value) {}
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_store(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value) {}
+inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value) {}
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_store(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value) {}
+inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value) {}
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_store(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value) {}
+inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value) {}
 
 template<template<typename> class A1, template<typename> class A2, typename T>
 inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {}
@@ -906,23 +951,64 @@ inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, int k,
 template<template<typename> class A1, template<typename> class A2, typename T>
 inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {}
 
+// generic handler for scalar values
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {}
+inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {
+    if (buf.grad)
+        adj_atomic_minmax(&index(buf, i), &index_grad(buf, i), value, adj_value);
+
+    FP_VERIFY_ADJ_1(value, adj_value)
+}
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {}
+inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {
+    if (buf.grad)
+        adj_atomic_minmax(&index(buf, i, j), &index_grad(buf, i, j), value, adj_value);
+
+    FP_VERIFY_ADJ_2(value, adj_value)
+}
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {}
+inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {
+    if (buf.grad)
+        adj_atomic_minmax(&index(buf, i, j, k), &index_grad(buf, i, j, k), value, adj_value);
+
+    FP_VERIFY_ADJ_3(value, adj_value)
+}
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {}
+inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {
+    if (buf.grad)
+        adj_atomic_minmax(&index(buf, i, j, k, l), &index_grad(buf, i, j, k, l), value, adj_value);
+
+    FP_VERIFY_ADJ_4(value, adj_value)
+}
 
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {}
+inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {
+    if (buf.grad)
+        adj_atomic_minmax(&index(buf, i), &index_grad(buf, i), value, adj_value);
+
+    FP_VERIFY_ADJ_1(value, adj_value)
+}
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {}
+inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {
+    if (buf.grad)
+        adj_atomic_minmax(&index(buf, i, j), &index_grad(buf, i, j), value, adj_value);
+
+    FP_VERIFY_ADJ_2(value, adj_value)
+}
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {}
+inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {
+    if (buf.grad)
+        adj_atomic_minmax(&index(buf, i, j, k), &index_grad(buf, i, j, k), value, adj_value);
+
+    FP_VERIFY_ADJ_3(value, adj_value)
+}
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {}
+inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {
+    if (buf.grad)
+        adj_atomic_minmax(&index(buf, i, j, k, l), &index_grad(buf, i, j, k, l), value, adj_value);
+
+    FP_VERIFY_ADJ_4(value, adj_value)
+}
 
 } // namespace wp
 
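Net effect of the array.h changes above: reading an element is now expressed as address() followed by load(), writing as array_store(), and gradients flow only through the array-level operations, while the pointer-level load()/store() pair stays non-differentiable. A minimal standalone sketch of that contract, in plain C++ with a toy array type (not Warp's actual array_t):

// Toy model of the address/load/array_store split: address() yields a
// pointer into the buffer, load() reads through it, and the adjoint of
// a read accumulates the upstream gradient into a parallel grad buffer.
#include <cassert>
#include <cstdio>

struct toy_array
{
    float* data;
    float* grad;   // same shape as data; may be null
    int size;
};

float* address(const toy_array& buf, int i) { return &buf.data[i]; }
float  load(float* addr) { return *addr; }
void   array_store(const toy_array& buf, int i, float value) { buf.data[i] = value; }

// adjoint of a read: d(out)/d(buf[i]) == 1, so the output gradient
// flows straight into the gradient buffer at the same index
void adj_address(const toy_array& buf, int i, float adj_output)
{
    if (buf.grad)
        buf.grad[i] += adj_output;
}

int main()
{
    float data[2] = {3.0f, 0.0f};
    float grad[2] = {0.0f, 0.0f};
    toy_array a{data, grad, 2};

    float x = load(address(a, 0));   // forward read: x = a[0]
    array_store(a, 1, x * 2.0f);     // forward write: a[1] = 2*x

    adj_address(a, 0, 2.0f);         // backward pass for the read, upstream gradient 2.0
    assert(grad[0] == 2.0f);
    printf("grad[0] = %f\n", grad[0]);
    return 0;
}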
warp/native/builtin.h CHANGED
@@ -251,8 +251,6 @@ CUDA_CALLABLE inline void adj_int8(T, T&, int8) {}
 template <typename T>
 CUDA_CALLABLE inline void adj_uint8(T, T&, uint8) {}
 template <typename T>
-CUDA_CALLABLE inline void adj_bool(T, T&, bool) {}
-template <typename T>
 CUDA_CALLABLE inline void adj_int16(T, T&, int16) {}
 template <typename T>
 CUDA_CALLABLE inline void adj_uint16(T, T&, uint16) {}
@@ -297,7 +295,7 @@ inline CUDA_CALLABLE T rshift(T a, T b) { return a>>b; } \
 inline CUDA_CALLABLE T invert(T x) { return ~x; } \
 inline CUDA_CALLABLE bool isfinite(T x) { return true; } \
 inline CUDA_CALLABLE void adj_mul(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
-inline CUDA_CALLABLE void adj_div(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
+inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret) { } \
 inline CUDA_CALLABLE void adj_add(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
 inline CUDA_CALLABLE void adj_sub(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
 inline CUDA_CALLABLE void adj_mod(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
@@ -435,11 +433,6 @@ inline CUDA_CALLABLE void adj_clamp(T x, T a, T b, T& adj_x, T& adj_a, T& adj_b,
     else\
         adj_x += adj_ret;\
 }\
-inline CUDA_CALLABLE void adj_round(T x, T& adj_x, T adj_ret){ }\
-inline CUDA_CALLABLE void adj_rint(T x, T& adj_x, T adj_ret){ }\
-inline CUDA_CALLABLE void adj_trunc(T x, T& adj_x, T adj_ret){ }\
-inline CUDA_CALLABLE void adj_floor(T x, T& adj_x, T adj_ret){ }\
-inline CUDA_CALLABLE void adj_ceil(T x, T& adj_x, T adj_ret){ }\
 inline CUDA_CALLABLE T div(T a, T b)\
 {\
     DO_IF_FPCHECK(\
@@ -450,10 +443,10 @@ inline CUDA_CALLABLE T div(T a, T b)\
     })\
     return a/b;\
 }\
-inline CUDA_CALLABLE void adj_div(T a, T b, T& adj_a, T& adj_b, T adj_ret)\
+inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
 {\
     adj_a += adj_ret/b;\
-    adj_b -= adj_ret*(a/b)/b;\
+    adj_b -= adj_ret*(ret)/b;\
     DO_IF_FPCHECK(\
     if (!isfinite(adj_a) || !isfinite(adj_b))\
     {\
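For reference, the reason the new adj_div signature can substitute ret for a/b: with r = a/b the partial derivatives are

    ∂r/∂a = 1/b,    ∂r/∂b = -a/b² = -r/b

so the backward pass needs only the quotient saved from the forward pass, which is exactly adj_a += adj_ret/b and adj_b -= adj_ret*(ret)/b in the new body, with no re-division.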
@@ -792,6 +785,10 @@ inline CUDA_CALLABLE half sqrt(half x)
     return ::sqrtf(float(x));
 }
 
+inline CUDA_CALLABLE float cbrt(float x) { return ::cbrtf(x); }
+inline CUDA_CALLABLE double cbrt(double x) { return ::cbrt(x); }
+inline CUDA_CALLABLE half cbrt(half x) { return ::cbrtf(float(x)); }
+
 inline CUDA_CALLABLE float tan(float x) { return ::tanf(x); }
 inline CUDA_CALLABLE float sinh(float x) { return ::sinhf(x);}
 inline CUDA_CALLABLE float cosh(float x) { return ::coshf(x);}
@@ -818,6 +815,21 @@ inline CUDA_CALLABLE float rint(float x) { return ::rintf(x); }
 inline CUDA_CALLABLE float trunc(float x) { return ::truncf(x); }
 inline CUDA_CALLABLE float floor(float x) { return ::floorf(x); }
 inline CUDA_CALLABLE float ceil(float x) { return ::ceilf(x); }
+inline CUDA_CALLABLE float frac(float x) { return x - trunc(x); }
+
+inline CUDA_CALLABLE double round(double x) { return ::round(x); }
+inline CUDA_CALLABLE double rint(double x) { return ::rint(x); }
+inline CUDA_CALLABLE double trunc(double x) { return ::trunc(x); }
+inline CUDA_CALLABLE double floor(double x) { return ::floor(x); }
+inline CUDA_CALLABLE double ceil(double x) { return ::ceil(x); }
+inline CUDA_CALLABLE double frac(double x) { return x - trunc(x); }
+
+inline CUDA_CALLABLE half round(half x) { return ::roundf(float(x)); }
+inline CUDA_CALLABLE half rint(half x) { return ::rintf(float(x)); }
+inline CUDA_CALLABLE half trunc(half x) { return ::truncf(float(x)); }
+inline CUDA_CALLABLE half floor(half x) { return ::floorf(float(x)); }
+inline CUDA_CALLABLE half ceil(half x) { return ::ceilf(float(x)); }
+inline CUDA_CALLABLE half frac(half x) { return float(x) - trunc(float(x)); }
 
 #define DECLARE_ADJOINTS(T)\
 inline CUDA_CALLABLE void adj_log(T a, T& adj_a, T adj_ret)\
@@ -847,11 +859,11 @@ inline CUDA_CALLABLE void adj_log10(T a, T& adj_a, T adj_ret)\
     assert(0);\
     })\
 }\
-inline CUDA_CALLABLE void adj_exp(T a, T& adj_a, T adj_ret) { adj_a += exp(a)*adj_ret; }\
-inline CUDA_CALLABLE void adj_pow(T a, T b, T& adj_a, T& adj_b, T adj_ret)\
+inline CUDA_CALLABLE void adj_exp(T a, T ret, T& adj_a, T adj_ret) { adj_a += ret*adj_ret; }\
+inline CUDA_CALLABLE void adj_pow(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
 { \
     adj_a += b*pow(a, b-T(1))*adj_ret;\
-    adj_b += log(a)*pow(a, b)*adj_ret;\
+    adj_b += log(a)*ret*adj_ret;\
     DO_IF_FPCHECK(if (!isfinite(adj_a) || !isfinite(adj_b))\
     {\
     printf("%s:%d - adj_pow(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(a), float(b), float(adj_a), float(adj_b), float(adj_ret));\
@@ -950,20 +962,28 @@ inline CUDA_CALLABLE void adj_cosh(T x, T& adj_x, T adj_ret)\
 {\
     adj_x += sinh(x)*adj_ret;\
 }\
-inline CUDA_CALLABLE void adj_tanh(T x, T& adj_x, T adj_ret)\
+inline CUDA_CALLABLE void adj_tanh(T x, T ret, T& adj_x, T adj_ret)\
 {\
-    T tanh_x = tanh(x);\
-    adj_x += (T(1) - tanh_x*tanh_x)*adj_ret;\
+    adj_x += (T(1) - ret*ret)*adj_ret;\
 }\
-inline CUDA_CALLABLE void adj_sqrt(T x, T& adj_x, T adj_ret)\
+inline CUDA_CALLABLE void adj_sqrt(T x, T ret, T& adj_x, T adj_ret)\
 {\
-    adj_x += T(0.5)*(T(1)/sqrt(x))*adj_ret;\
+    adj_x += T(0.5)*(T(1)/ret)*adj_ret;\
     DO_IF_FPCHECK(if (!isfinite(adj_x))\
     {\
     printf("%s:%d - adj_sqrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
     assert(0);\
     })\
 }\
+inline CUDA_CALLABLE void adj_cbrt(T x, T ret, T& adj_x, T adj_ret)\
+{\
+    adj_x += (T(1)/T(3))*(T(1)/(ret*ret))*adj_ret;\
+    DO_IF_FPCHECK(if (!isfinite(adj_x))\
+    {\
+    printf("%s:%d - adj_cbrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
+    assert(0);\
+    })\
+}\
 inline CUDA_CALLABLE void adj_degrees(T x, T& adj_x, T adj_ret)\
 {\
     adj_x += RAD_TO_DEG * adj_ret;\
@@ -971,7 +991,13 @@ inline CUDA_CALLABLE void adj_degrees(T x, T& adj_x, T adj_ret)\
 inline CUDA_CALLABLE void adj_radians(T x, T& adj_x, T adj_ret)\
 {\
     adj_x += DEG_TO_RAD * adj_ret;\
-}
+}\
+inline CUDA_CALLABLE void adj_round(T x, T& adj_x, T adj_ret){ }\
+inline CUDA_CALLABLE void adj_rint(T x, T& adj_x, T adj_ret){ }\
+inline CUDA_CALLABLE void adj_trunc(T x, T& adj_x, T adj_ret){ }\
+inline CUDA_CALLABLE void adj_floor(T x, T& adj_x, T adj_ret){ }\
+inline CUDA_CALLABLE void adj_ceil(T x, T& adj_x, T adj_ret){ }\
+inline CUDA_CALLABLE void adj_frac(T x, T& adj_x, T adj_ret){ }
 
 DECLARE_ADJOINTS(float16)
 DECLARE_ADJOINTS(float32)
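The adjoints above all follow the same pattern as adj_div: the forward result ret is passed back in so the derivative can be formed without re-evaluating the function (d/dx tanh x = 1 - tanh²x, d/dx √x = 1/(2√x), d/dx ∛x = 1/(3∛x²)). A standalone illustration of the idea, independent of Warp's macros:

// Forward/backward pair for tanh that reuses the saved forward result,
// avoiding a second tanh() evaluation in the backward pass.
#include <cmath>
#include <cstdio>

float tanh_fwd(float x) { return std::tanh(x); }

// ret is the value computed (and stored) in the forward pass
void adj_tanh(float x, float ret, float& adj_x, float adj_ret)
{
    adj_x += (1.0f - ret * ret) * adj_ret;   // no second tanh() call
}

int main()
{
    float x = 0.5f;
    float ret = tanh_fwd(x);        // forward: save the result
    float adj_x = 0.0f;
    adj_tanh(x, ret, adj_x, 1.0f);  // backward with upstream gradient 1
    printf("tanh(%f) = %f, d/dx = %f\n", x, ret, adj_x);
    return 0;
}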
@@ -995,17 +1021,31 @@ CUDA_CALLABLE inline void adj_select(const C& cond, const T& a, const T& b, C& a
 }
 
 template <typename T>
-CUDA_CALLABLE inline void copy(T& dest, const T& src)
+CUDA_CALLABLE inline T copy(const T& src)
+{
+    return src;
+}
+
+template <typename T>
+CUDA_CALLABLE inline void adj_copy(const T& src, T& adj_src, T& adj_dest)
+{
+    adj_src = adj_dest;
+    adj_dest = T{};
+}
+
+template <typename T>
+CUDA_CALLABLE inline void assign(T& dest, const T& src)
 {
     dest = src;
 }
 
 template <typename T>
-CUDA_CALLABLE inline void adj_copy(T& dest, const T& src, T& adj_dest, T& adj_src)
+CUDA_CALLABLE inline void adj_assign(T& dest, const T& src, T& adj_dest, T& adj_src)
 {
-    // nop, this is non-differentiable operation since it violates SSA
+    // this is generally a non-differentiable operation since it violates SSA,
+    // except in read-modify-write statements which are reversible through backpropagation
     adj_src = adj_dest;
-    adj_dest = T(0);
+    adj_dest = T{};
 }
 
 
@@ -1050,34 +1090,8 @@ struct launch_bounds_t
     size_t size; // total number of threads
 };
 
-#ifdef __CUDACC__
-
-// store launch bounds in shared memory so
-// we can access them from any user func
-// this is to avoid having to explicitly
-// set another piece of __constant__ memory
-// from the host
-__shared__ launch_bounds_t s_launchBounds;
-
-__device__ inline void set_launch_bounds(const launch_bounds_t& b)
-{
-    if (threadIdx.x == 0)
-        s_launchBounds = b;
-
-    __syncthreads();
-}
-
-#else
-
-// for single-threaded CPU we store launch
-// bounds in static memory to share globally
-static launch_bounds_t s_launchBounds;
+#ifndef __CUDACC__
 static size_t s_threadIdx;
-
-inline void set_launch_bounds(const launch_bounds_t& b)
-{
-    s_launchBounds = b;
-}
 #endif
 
 inline CUDA_CALLABLE size_t grid_index()
@@ -1091,10 +1105,8 @@ inline CUDA_CALLABLE size_t grid_index()
 #endif
 }
 
-inline CUDA_CALLABLE int tid()
+inline CUDA_CALLABLE int tid(size_t index)
 {
-    const size_t index = grid_index();
-
     // For the 1-D tid() we need to warn the user if we're about to provide a truncated index
     // Only do this in _DEBUG when called from device to avoid excessive register allocation
 #if defined(_DEBUG) || !defined(__CUDA_ARCH__)
@@ -1105,23 +1117,19 @@ inline CUDA_CALLABLE int tid()
     return static_cast<int>(index);
 }
 
-inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j)
+inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, size_t index, const launch_bounds_t& launch_bounds)
 {
-    const size_t index = grid_index();
-
-    const size_t n = s_launchBounds.shape[1];
+    const size_t n = launch_bounds.shape[1];
 
     // convert to work item
    i = index/n;
    j = index%n;
 }
 
-inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k)
+inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, size_t index, const launch_bounds_t& launch_bounds)
 {
-    const size_t index = grid_index();
-
-    const size_t n = s_launchBounds.shape[1];
-    const size_t o = s_launchBounds.shape[2];
+    const size_t n = launch_bounds.shape[1];
+    const size_t o = launch_bounds.shape[2];
 
     // convert to work item
     i = index/(n*o);
@@ -1129,13 +1137,11 @@ inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k)
     k = index%o;
 }
 
-inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l)
+inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l, size_t index, const launch_bounds_t& launch_bounds)
 {
-    const size_t index = grid_index();
-
-    const size_t n = s_launchBounds.shape[1];
-    const size_t o = s_launchBounds.shape[2];
-    const size_t p = s_launchBounds.shape[3];
+    const size_t n = launch_bounds.shape[1];
+    const size_t o = launch_bounds.shape[2];
+    const size_t p = launch_bounds.shape[3];
 
     // convert to work item
     i = index/(n*o*p);
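The tid() overloads above now receive the flat thread index and the launch bounds explicitly rather than reading the removed s_launchBounds shared/static state; the multi-dimensional overloads then simply decompose the flat index in row-major order. A quick self-contained check of that arithmetic (the shape values are an example of mine, not taken from the diff):

// Verifies the row-major decomposition used by the 3-D tid() overload:
// over shape (m, n, o), a flat index splits as i = index/(n*o),
// j = (index/o) % n (equivalent to index % (n*o) / o), k = index % o.
#include <cassert>

int main()
{
    const int m = 2, n = 3, o = 4;
    int flat = 0;
    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
            for (int k = 0; k < o; ++k, ++flat)
            {
                assert(flat / (n * o) == i);   // as in the hunk: i = index/(n*o)
                assert((flat / o) % n == j);   // middle coordinate
                assert(flat % o == k);         // as in the hunk: k = index%o
            }
    return 0;
}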
@@ -1265,9 +1271,36 @@ inline CUDA_CALLABLE int atomic_min(int* address, int val)
 #endif
 }
 
+// default behavior for adjoint of atomic min/max operation that accumulates gradients for all elements matching the min/max value
+template <typename T>
+CUDA_CALLABLE inline void adj_atomic_minmax(T *addr, T *adj_addr, const T &value, T &adj_value)
+{
+    if (value == *addr)
+        adj_value += *adj_addr;
+}
+
+// for integral types we do not accumulate gradients
+CUDA_CALLABLE inline void adj_atomic_minmax(int8* buf, int8* adj_buf, const int8 &value, int8 &adj_value) { }
+CUDA_CALLABLE inline void adj_atomic_minmax(uint8* buf, uint8* adj_buf, const uint8 &value, uint8 &adj_value) { }
+CUDA_CALLABLE inline void adj_atomic_minmax(int16* buf, int16* adj_buf, const int16 &value, int16 &adj_value) { }
+CUDA_CALLABLE inline void adj_atomic_minmax(uint16* buf, uint16* adj_buf, const uint16 &value, uint16 &adj_value) { }
+CUDA_CALLABLE inline void adj_atomic_minmax(int32* buf, int32* adj_buf, const int32 &value, int32 &adj_value) { }
+CUDA_CALLABLE inline void adj_atomic_minmax(uint32* buf, uint32* adj_buf, const uint32 &value, uint32 &adj_value) { }
+CUDA_CALLABLE inline void adj_atomic_minmax(int64* buf, int64* adj_buf, const int64 &value, int64 &adj_value) { }
+CUDA_CALLABLE inline void adj_atomic_minmax(uint64* buf, uint64* adj_buf, const uint64 &value, uint64 &adj_value) { }
+CUDA_CALLABLE inline void adj_atomic_minmax(bool* buf, bool* adj_buf, const bool &value, bool &adj_value) { }
+
 
 } // namespace wp
 
+
+// bool and printf are defined outside of the wp namespace in crt.h, hence
+// their adjoint counterparts are also defined in the global namespace.
+template <typename T>
+CUDA_CALLABLE inline void adj_bool(T, T&, bool) {}
+inline CUDA_CALLABLE void adj_printf(const char* fmt, ...) {}
+
+
 #include "vec.h"
 #include "mat.h"
 #include "quat.h"
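The new adj_atomic_minmax above encodes the usual subgradient choice for min/max: gradient is propagated only to contributions equal to the extremum finally stored at the address (and not at all for integral types). A minimal standalone model of that rule, with toy scalars rather than Warp's buffers:

// After a forward min reduction, gradient flows back only to the
// input(s) whose value matches the stored result.
#include <cassert>

template <typename T>
void adj_atomic_minmax(T* addr, T* adj_addr, const T& value, T& adj_value)
{
    if (value == *addr)      // this input achieved the min/max
        adj_value += *adj_addr;
}

int main()
{
    float result = 2.0f;       // forward: min(5.0, 2.0) == 2.0
    float adj_result = 1.0f;   // upstream gradient on the stored minimum

    float adj_a = 0.0f, adj_b = 0.0f;
    adj_atomic_minmax(&result, &adj_result, 5.0f, adj_a);  // loser: no gradient
    adj_atomic_minmax(&result, &adj_result, 2.0f, adj_b);  // winner: gets gradient

    assert(adj_a == 0.0f && adj_b == 1.0f);
    return 0;
}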
@@ -1432,10 +1465,6 @@ inline CUDA_CALLABLE void adj_print(transform_t<Type> t, transform_t<Type>& adj_
 inline CUDA_CALLABLE void adj_print(str t, str& adj_t) {}
 
 
-// printf defined globally in crt.h
-inline CUDA_CALLABLE void adj_printf(const char* fmt, ...) {}
-
-
 template <typename T>
 inline CUDA_CALLABLE void expect_eq(const T& actual, const T& expected)
 {