warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315) hide show
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/warp.cpp CHANGED
@@ -10,33 +10,99 @@
10
10
  #include "scan.h"
11
11
  #include "array.h"
12
12
 
13
+ #include "exports.h"
14
+
13
15
  #include "stdlib.h"
14
16
  #include "string.h"
15
17
 
18
+ int cuda_init();
16
19
 
17
- namespace wp
18
- {
19
20
 
20
- extern "C"
21
+ uint16_t float_to_half_bits(float x)
21
22
  {
22
- #include "exports.h"
23
- }
24
-
25
- } // namespace wp
23
+ // adapted from Fabien Giesen's post: https://gist.github.com/rygorous/2156668
24
+ union fp32
25
+ {
26
+ uint32_t u;
27
+ float f;
26
28
 
27
- int cuda_init();
29
+ struct
30
+ {
31
+ unsigned int mantissa : 23;
32
+ unsigned int exponent : 8;
33
+ unsigned int sign : 1;
34
+ };
35
+ };
36
+
37
+ fp32 f;
38
+ f.f = x;
39
+
40
+ fp32 f32infty = { 255 << 23 };
41
+ fp32 f16infty = { 31 << 23 };
42
+ fp32 magic = { 15 << 23 };
43
+ uint32_t sign_mask = 0x80000000u;
44
+ uint32_t round_mask = ~0xfffu;
45
+ uint16_t u;
46
+
47
+ uint32_t sign = f.u & sign_mask;
48
+ f.u ^= sign;
49
+
50
+ // NOTE all the integer compares in this function can be safely
51
+ // compiled into signed compares since all operands are below
52
+ // 0x80000000. Important if you want fast straight SSE2 code
53
+ // (since there's no unsigned PCMPGTD).
54
+
55
+ if (f.u >= f32infty.u) // Inf or NaN (all exponent bits set)
56
+ u = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
57
+ else // (De)normalized number or zero
58
+ {
59
+ f.u &= round_mask;
60
+ f.f *= magic.f;
61
+ f.u -= round_mask;
62
+ if (f.u > f16infty.u) f.u = f16infty.u; // Clamp to signed infinity if overflowed
28
63
 
64
+ u = f.u >> 13; // Take the bits!
65
+ }
29
66
 
30
- uint16_t float_to_half_bits(float x)
31
- {
32
- return wp::half(x).u;
67
+ u |= sign >> 16;
68
+ return u;
33
69
  }
34
70
 
35
71
  float half_bits_to_float(uint16_t u)
36
72
  {
37
- wp::half h;
38
- h.u = u;
39
- return half_to_float(h);
73
+ // adapted from Fabien Giesen's post: https://gist.github.com/rygorous/2156668
74
+ union fp32
75
+ {
76
+ uint32_t u;
77
+ float f;
78
+
79
+ struct
80
+ {
81
+ unsigned int mantissa : 23;
82
+ unsigned int exponent : 8;
83
+ unsigned int sign : 1;
84
+ };
85
+ };
86
+
87
+ static const fp32 magic = { 113 << 23 };
88
+ static const uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift
89
+ fp32 o;
90
+
91
+ o.u = (u & 0x7fff) << 13; // exponent/mantissa bits
92
+ uint32_t exp = shifted_exp & o.u; // just the exponent
93
+ o.u += (127 - 15) << 23; // exponent adjust
94
+
95
+ // handle exponent special cases
96
+ if (exp == shifted_exp) // Inf/NaN?
97
+ o.u += (128 - 16) << 23; // extra exp adjust
98
+ else if (exp == 0) // Zero/Denormal?
99
+ {
100
+ o.u += 1 << 23; // extra exp adjust
101
+ o.f -= magic.f; // renormalize
102
+ }
103
+
104
+ o.u |= (u & 0x8000) << 16; // sign bit
105
+ return o.f;
40
106
  }
41
107
 
42
108
  int init()
@@ -102,34 +168,38 @@ void memset_host(void* dest, int value, size_t n)
102
168
  }
103
169
  }
104
170
 
105
- void memtile_host(void* dest, void *src, size_t srcsize, size_t n)
171
+ // fill memory buffer with a value: this is a faster memtile variant
172
+ // for types bigger than one byte, but requires proper alignment of dst
173
+ template <typename T>
174
+ void memtile_value_host(T* dst, T value, size_t n)
106
175
  {
107
- for( size_t i=0; i < n; ++i )
108
- {
109
- memcpy(dest,src,srcsize);
110
- dest = (char*)dest + srcsize;
111
- }
176
+ while (n--)
177
+ *dst++ = value;
112
178
  }
113
179
 
114
- void array_inner_host(uint64_t a, uint64_t b, uint64_t out, int len)
180
+ void memtile_host(void* dst, const void* src, size_t srcsize, size_t n)
115
181
  {
116
- const float* ptr_a = (const float*)(a);
117
- const float* ptr_b = (const float*)(b);
118
- float* ptr_out = (float*)(out);
119
-
120
- *ptr_out = 0.0f;
121
- for (int i=0; i < len; ++i)
122
- *ptr_out += ptr_a[i]*ptr_b[i];
123
- }
182
+ size_t dst_addr = reinterpret_cast<size_t>(dst);
183
+ size_t src_addr = reinterpret_cast<size_t>(src);
124
184
 
125
- void array_sum_host(uint64_t a, uint64_t out, int len)
126
- {
127
- const float* ptr_a = (const float*)(a);
128
- float* ptr_out = (float*)(out);
129
-
130
- *ptr_out = 0.0f;
131
- for (int i=0; i < len; ++i)
132
- *ptr_out += ptr_a[i];
185
+ // try memtile_value first because it should be faster, but we need to ensure proper alignment
186
+ if (srcsize == 8 && (dst_addr & 7) == 0 && (src_addr & 7) == 0)
187
+ memtile_value_host(reinterpret_cast<int64_t*>(dst), *reinterpret_cast<const int64_t*>(src), n);
188
+ else if (srcsize == 4 && (dst_addr & 3) == 0 && (src_addr & 3) == 0)
189
+ memtile_value_host(reinterpret_cast<int32_t*>(dst), *reinterpret_cast<const int32_t*>(src), n);
190
+ else if (srcsize == 2 && (dst_addr & 1) == 0 && (src_addr & 1) == 0)
191
+ memtile_value_host(reinterpret_cast<int16_t*>(dst), *reinterpret_cast<const int16_t*>(src), n);
192
+ else if (srcsize == 1)
193
+ memset(dst, *reinterpret_cast<const int8_t*>(src), n);
194
+ else
195
+ {
196
+ // generic version
197
+ while (n--)
198
+ {
199
+ memcpy(dst, src, srcsize);
200
+ dst = (int8_t*)dst + srcsize;
201
+ }
202
+ }
133
203
  }
134
204
 
135
205
  void array_scan_int_host(uint64_t in, uint64_t out, int len, bool inclusive)
@@ -175,6 +245,312 @@ static void array_copy_nd(void* dst, const void* src,
175
245
  }
176
246
 
177
247
 
248
+ static void array_copy_to_fabric(wp::fabricarray_t<void>& dst, const void* src_data,
249
+ int src_stride, const int* src_indices, int elem_size)
250
+ {
251
+ const int8_t* src_ptr = static_cast<const int8_t*>(src_data);
252
+
253
+ if (src_indices)
254
+ {
255
+ // copy from indexed array
256
+ for (size_t i = 0; i < dst.nbuckets; i++)
257
+ {
258
+ const wp::fabricbucket_t& bucket = dst.buckets[i];
259
+ int8_t* dst_ptr = static_cast<int8_t*>(bucket.ptr);
260
+ size_t bucket_size = bucket.index_end - bucket.index_start;
261
+ for (size_t j = 0; j < bucket_size; j++)
262
+ {
263
+ int idx = *src_indices;
264
+ memcpy(dst_ptr, src_ptr + idx * elem_size, elem_size);
265
+ dst_ptr += elem_size;
266
+ ++src_indices;
267
+ }
268
+ }
269
+ }
270
+ else
271
+ {
272
+ if (src_stride == elem_size)
273
+ {
274
+ // copy from contiguous array
275
+ for (size_t i = 0; i < dst.nbuckets; i++)
276
+ {
277
+ const wp::fabricbucket_t& bucket = dst.buckets[i];
278
+ size_t num_bytes = (bucket.index_end - bucket.index_start) * elem_size;
279
+ memcpy(bucket.ptr, src_ptr, num_bytes);
280
+ src_ptr += num_bytes;
281
+ }
282
+ }
283
+ else
284
+ {
285
+ // copy from strided array
286
+ for (size_t i = 0; i < dst.nbuckets; i++)
287
+ {
288
+ const wp::fabricbucket_t& bucket = dst.buckets[i];
289
+ int8_t* dst_ptr = static_cast<int8_t*>(bucket.ptr);
290
+ size_t bucket_size = bucket.index_end - bucket.index_start;
291
+ for (size_t j = 0; j < bucket_size; j++)
292
+ {
293
+ memcpy(dst_ptr, src_ptr, elem_size);
294
+ src_ptr += src_stride;
295
+ dst_ptr += elem_size;
296
+ }
297
+ }
298
+ }
299
+ }
300
+ }
301
+
302
+ static void array_copy_from_fabric(const wp::fabricarray_t<void>& src, void* dst_data,
303
+ int dst_stride, const int* dst_indices, int elem_size)
304
+ {
305
+ int8_t* dst_ptr = static_cast<int8_t*>(dst_data);
306
+
307
+ if (dst_indices)
308
+ {
309
+ // copy to indexed array
310
+ for (size_t i = 0; i < src.nbuckets; i++)
311
+ {
312
+ const wp::fabricbucket_t& bucket = src.buckets[i];
313
+ const int8_t* src_ptr = static_cast<const int8_t*>(bucket.ptr);
314
+ size_t bucket_size = bucket.index_end - bucket.index_start;
315
+ for (size_t j = 0; j < bucket_size; j++)
316
+ {
317
+ int idx = *dst_indices;
318
+ memcpy(dst_ptr + idx * elem_size, src_ptr, elem_size);
319
+ src_ptr += elem_size;
320
+ ++dst_indices;
321
+ }
322
+ }
323
+ }
324
+ else
325
+ {
326
+ if (dst_stride == elem_size)
327
+ {
328
+ // copy to contiguous array
329
+ for (size_t i = 0; i < src.nbuckets; i++)
330
+ {
331
+ const wp::fabricbucket_t& bucket = src.buckets[i];
332
+ size_t num_bytes = (bucket.index_end - bucket.index_start) * elem_size;
333
+ memcpy(dst_ptr, bucket.ptr, num_bytes);
334
+ dst_ptr += num_bytes;
335
+ }
336
+ }
337
+ else
338
+ {
339
+ // copy to strided array
340
+ for (size_t i = 0; i < src.nbuckets; i++)
341
+ {
342
+ const wp::fabricbucket_t& bucket = src.buckets[i];
343
+ const int8_t* src_ptr = static_cast<const int8_t*>(bucket.ptr);
344
+ size_t bucket_size = bucket.index_end - bucket.index_start;
345
+ for (size_t j = 0; j < bucket_size; j++)
346
+ {
347
+ memcpy(dst_ptr, src_ptr, elem_size);
348
+ dst_ptr += dst_stride;
349
+ src_ptr += elem_size;
350
+ }
351
+ }
352
+ }
353
+ }
354
+ }
355
+
356
+ static void array_copy_fabric_to_fabric(wp::fabricarray_t<void>& dst, const wp::fabricarray_t<void>& src, int elem_size)
357
+ {
358
+ wp::fabricbucket_t* dst_bucket = dst.buckets;
359
+ const wp::fabricbucket_t* src_bucket = src.buckets;
360
+ int8_t* dst_ptr = static_cast<int8_t*>(dst_bucket->ptr);
361
+ const int8_t* src_ptr = static_cast<const int8_t*>(src_bucket->ptr);
362
+ size_t dst_remaining = dst_bucket->index_end - dst_bucket->index_start;
363
+ size_t src_remaining = src_bucket->index_end - src_bucket->index_start;
364
+ size_t total_copied = 0;
365
+
366
+ while (total_copied < dst.size)
367
+ {
368
+ if (dst_remaining <= src_remaining)
369
+ {
370
+ // copy to destination bucket
371
+ size_t num_elems = dst_remaining;
372
+ size_t num_bytes = num_elems * elem_size;
373
+ memcpy(dst_ptr, src_ptr, num_bytes);
374
+
375
+ // advance to next destination bucket
376
+ ++dst_bucket;
377
+ dst_ptr = static_cast<int8_t*>(dst_bucket->ptr);
378
+ dst_remaining = dst_bucket->index_end - dst_bucket->index_start;
379
+
380
+ // advance source offset
381
+ src_ptr += num_bytes;
382
+ src_remaining -= num_elems;
383
+
384
+ total_copied += num_elems;
385
+ }
386
+ else
387
+ {
388
+ // copy to destination bucket
389
+ size_t num_elems = src_remaining;
390
+ size_t num_bytes = num_elems * elem_size;
391
+ memcpy(dst_ptr, src_ptr, num_bytes);
392
+
393
+ // advance to next source bucket
394
+ ++src_bucket;
395
+ src_ptr = static_cast<const int8_t*>(src_bucket->ptr);
396
+ src_remaining = src_bucket->index_end - src_bucket->index_start;
397
+
398
+ // advance destination offset
399
+ dst_ptr += num_bytes;
400
+ dst_remaining -= num_elems;
401
+
402
+ total_copied += num_elems;
403
+ }
404
+ }
405
+ }
406
+
407
+
408
+ static void array_copy_to_fabric_indexed(wp::indexedfabricarray_t<void>& dst, const void* src_data,
409
+ int src_stride, const int* src_indices, int elem_size)
410
+ {
411
+ const int8_t* src_ptr = static_cast<const int8_t*>(src_data);
412
+
413
+ if (src_indices)
414
+ {
415
+ // copy from indexed array
416
+ for (size_t i = 0; i < dst.size; i++)
417
+ {
418
+ size_t src_idx = src_indices[i];
419
+ size_t dst_idx = dst.indices[i];
420
+ void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
421
+ memcpy(dst_ptr, src_ptr + dst_idx * elem_size, elem_size);
422
+ }
423
+ }
424
+ else
425
+ {
426
+ // copy from contiguous/strided array
427
+ for (size_t i = 0; i < dst.size; i++)
428
+ {
429
+ size_t dst_idx = dst.indices[i];
430
+ void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
431
+ if (dst_ptr)
432
+ {
433
+ memcpy(dst_ptr, src_ptr, elem_size);
434
+ src_ptr += src_stride;
435
+ }
436
+ }
437
+ }
438
+ }
439
+
440
+
441
+ static void array_copy_fabric_indexed_to_fabric(wp::fabricarray_t<void>& dst, const wp::indexedfabricarray_t<void>& src, int elem_size)
442
+ {
443
+ wp::fabricbucket_t* dst_bucket = dst.buckets;
444
+ int8_t* dst_ptr = static_cast<int8_t*>(dst_bucket->ptr);
445
+ int8_t* dst_end = dst_ptr + elem_size * (dst_bucket->index_end - dst_bucket->index_start);
446
+
447
+ for (size_t i = 0; i < src.size; i++)
448
+ {
449
+ size_t src_idx = src.indices[i];
450
+ const void* src_ptr = fabricarray_element_ptr(src.fa, src_idx, elem_size);
451
+
452
+ if (dst_ptr >= dst_end)
453
+ {
454
+ // advance to next destination bucket
455
+ ++dst_bucket;
456
+ dst_ptr = static_cast<int8_t*>(dst_bucket->ptr);
457
+ dst_end = dst_ptr + elem_size * (dst_bucket->index_end - dst_bucket->index_start);
458
+ }
459
+
460
+ memcpy(dst_ptr, src_ptr, elem_size);
461
+
462
+ dst_ptr += elem_size;
463
+ }
464
+ }
465
+
466
+
467
+ static void array_copy_fabric_indexed_to_fabric_indexed(wp::indexedfabricarray_t<void>& dst, const wp::indexedfabricarray_t<void>& src, int elem_size)
468
+ {
469
+ for (size_t i = 0; i < src.size; i++)
470
+ {
471
+ size_t src_idx = src.indices[i];
472
+ size_t dst_idx = dst.indices[i];
473
+
474
+ const void* src_ptr = fabricarray_element_ptr(src.fa, src_idx, elem_size);
475
+ void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
476
+
477
+ memcpy(dst_ptr, src_ptr, elem_size);
478
+ }
479
+ }
480
+
481
+
482
+ static void array_copy_fabric_to_fabric_indexed(wp::indexedfabricarray_t<void>& dst, const wp::fabricarray_t<void>& src, int elem_size)
483
+ {
484
+ wp::fabricbucket_t* src_bucket = src.buckets;
485
+ const int8_t* src_ptr = static_cast<const int8_t*>(src_bucket->ptr);
486
+ const int8_t* src_end = src_ptr + elem_size * (src_bucket->index_end - src_bucket->index_start);
487
+
488
+ for (size_t i = 0; i < dst.size; i++)
489
+ {
490
+ size_t dst_idx = dst.indices[i];
491
+ void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
492
+
493
+ if (src_ptr >= src_end)
494
+ {
495
+ // advance to next source bucket
496
+ ++src_bucket;
497
+ src_ptr = static_cast<int8_t*>(src_bucket->ptr);
498
+ src_end = src_ptr + elem_size * (src_bucket->index_end - src_bucket->index_start);
499
+ }
500
+
501
+ memcpy(dst_ptr, src_ptr, elem_size);
502
+
503
+ src_ptr += elem_size;
504
+ }
505
+ }
506
+
507
+
508
+ static void array_copy_from_fabric_indexed(const wp::indexedfabricarray_t<void>& src, void* dst_data,
509
+ int dst_stride, const int* dst_indices, int elem_size)
510
+ {
511
+ int8_t* dst_ptr = static_cast<int8_t*>(dst_data);
512
+
513
+ if (dst_indices)
514
+ {
515
+ // copy to indexed array
516
+ for (size_t i = 0; i < src.size; i++)
517
+ {
518
+ size_t idx = src.indices[i];
519
+ if (idx < src.fa.size)
520
+ {
521
+ const void* src_ptr = fabricarray_element_ptr(src.fa, idx, elem_size);
522
+ int dst_idx = dst_indices[i];
523
+ memcpy(dst_ptr + dst_idx * elem_size, src_ptr, elem_size);
524
+ }
525
+ else
526
+ {
527
+ fprintf(stderr, "Warp copy error: Source index %llu is out of bounds for fabric array of size %llu",
528
+ (unsigned long long)idx, (unsigned long long)src.fa.size);
529
+ }
530
+ }
531
+ }
532
+ else
533
+ {
534
+ // copy to contiguous/strided array
535
+ for (size_t i = 0; i < src.size; i++)
536
+ {
537
+ size_t idx = src.indices[i];
538
+ if (idx < src.fa.size)
539
+ {
540
+ const void* src_ptr = fabricarray_element_ptr(src.fa, idx, elem_size);
541
+ memcpy(dst_ptr, src_ptr, elem_size);
542
+ dst_ptr += dst_stride;
543
+ }
544
+ else
545
+ {
546
+ fprintf(stderr, "Warp copy error: Source index %llu is out of bounds for fabric array of size %llu",
547
+ (unsigned long long)idx, (unsigned long long)src.fa.size);
548
+ }
549
+ }
550
+ }
551
+ }
552
+
553
+
178
554
  WP_API size_t array_copy_host(void* dst, void* src, int dst_type, int src_type, int elem_size)
179
555
  {
180
556
  if (!src || !dst)
@@ -193,6 +569,12 @@ WP_API size_t array_copy_host(void* dst, void* src, int dst_type, int src_type,
193
569
  const int*const* src_indices = NULL;
194
570
  const int*const* dst_indices = NULL;
195
571
 
572
+ const wp::fabricarray_t<void>* src_fabricarray = NULL;
573
+ wp::fabricarray_t<void>* dst_fabricarray = NULL;
574
+
575
+ const wp::indexedfabricarray_t<void>* src_indexedfabricarray = NULL;
576
+ wp::indexedfabricarray_t<void>* dst_indexedfabricarray = NULL;
577
+
196
578
  const int* null_indices[wp::ARRAY_MAX_DIMS] = { NULL };
197
579
 
198
580
  if (src_type == wp::ARRAY_TYPE_REGULAR)
@@ -214,9 +596,19 @@ WP_API size_t array_copy_host(void* dst, void* src, int dst_type, int src_type,
214
596
  src_strides = src_arr.arr.strides;
215
597
  src_indices = src_arr.indices;
216
598
  }
599
+ else if (src_type == wp::ARRAY_TYPE_FABRIC)
600
+ {
601
+ src_fabricarray = static_cast<const wp::fabricarray_t<void>*>(src);
602
+ src_ndim = 1;
603
+ }
604
+ else if (src_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
605
+ {
606
+ src_indexedfabricarray = static_cast<const wp::indexedfabricarray_t<void>*>(src);
607
+ src_ndim = 1;
608
+ }
217
609
  else
218
610
  {
219
- fprintf(stderr, "Warp error: Invalid array type (%d)\n", src_type);
611
+ fprintf(stderr, "Warp copy error: Invalid source array type (%d)\n", src_type);
220
612
  return 0;
221
613
  }
222
614
 
@@ -239,26 +631,134 @@ WP_API size_t array_copy_host(void* dst, void* src, int dst_type, int src_type,
239
631
  dst_strides = dst_arr.arr.strides;
240
632
  dst_indices = dst_arr.indices;
241
633
  }
634
+ else if (dst_type == wp::ARRAY_TYPE_FABRIC)
635
+ {
636
+ dst_fabricarray = static_cast<wp::fabricarray_t<void>*>(dst);
637
+ dst_ndim = 1;
638
+ }
639
+ else if (dst_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
640
+ {
641
+ dst_indexedfabricarray = static_cast<wp::indexedfabricarray_t<void>*>(dst);
642
+ dst_ndim = 1;
643
+ }
242
644
  else
243
645
  {
244
- fprintf(stderr, "Warp error: Invalid array type (%d)\n", dst_type);
646
+ fprintf(stderr, "Warp copy error: Invalid destination array type (%d)\n", dst_type);
245
647
  return 0;
246
648
  }
247
649
 
248
650
  if (src_ndim != dst_ndim)
249
651
  {
250
- fprintf(stderr, "Warp error: Incompatible array dimensionalities (%d and %d)\n", src_ndim, dst_ndim);
652
+ fprintf(stderr, "Warp copy error: Incompatible array dimensionalities (%d and %d)\n", src_ndim, dst_ndim);
251
653
  return 0;
252
654
  }
253
655
 
254
- bool has_grad = (src_grad && dst_grad);
255
- size_t n = 1;
656
+ // handle fabric arrays
657
+ if (dst_fabricarray)
658
+ {
659
+ size_t n = dst_fabricarray->size;
660
+ if (src_fabricarray)
661
+ {
662
+ // copy from fabric to fabric
663
+ if (src_fabricarray->size != n)
664
+ {
665
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
666
+ return 0;
667
+ }
668
+ array_copy_fabric_to_fabric(*dst_fabricarray, *src_fabricarray, elem_size);
669
+ return n;
670
+ }
671
+ else if (src_indexedfabricarray)
672
+ {
673
+ // copy from fabric indexed to fabric
674
+ if (src_indexedfabricarray->size != n)
675
+ {
676
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
677
+ return 0;
678
+ }
679
+ array_copy_fabric_indexed_to_fabric(*dst_fabricarray, *src_indexedfabricarray, elem_size);
680
+ return n;
681
+ }
682
+ else
683
+ {
684
+ // copy to fabric
685
+ if (size_t(src_shape[0]) != n)
686
+ {
687
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
688
+ return 0;
689
+ }
690
+ array_copy_to_fabric(*dst_fabricarray, src_data, src_strides[0], src_indices[0], elem_size);
691
+ return n;
692
+ }
693
+ }
694
+ else if (dst_indexedfabricarray)
695
+ {
696
+ size_t n = dst_indexedfabricarray->size;
697
+ if (src_fabricarray)
698
+ {
699
+ // copy from fabric to fabric indexed
700
+ if (src_fabricarray->size != n)
701
+ {
702
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
703
+ return 0;
704
+ }
705
+ array_copy_fabric_to_fabric_indexed(*dst_indexedfabricarray, *src_fabricarray, elem_size);
706
+ return n;
707
+ }
708
+ else if (src_indexedfabricarray)
709
+ {
710
+ // copy from fabric indexed to fabric indexed
711
+ if (src_indexedfabricarray->size != n)
712
+ {
713
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
714
+ return 0;
715
+ }
716
+ array_copy_fabric_indexed_to_fabric_indexed(*dst_indexedfabricarray, *src_indexedfabricarray, elem_size);
717
+ return n;
718
+ }
719
+ else
720
+ {
721
+ // copy to fabric indexed
722
+ if (size_t(src_shape[0]) != n)
723
+ {
724
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
725
+ return 0;
726
+ }
727
+ array_copy_to_fabric_indexed(*dst_indexedfabricarray, src_data, src_strides[0], src_indices[0], elem_size);
728
+ return n;
729
+ }
730
+ }
731
+ else if (src_fabricarray)
732
+ {
733
+ // copy from fabric
734
+ size_t n = src_fabricarray->size;
735
+ if (size_t(dst_shape[0]) != n)
736
+ {
737
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
738
+ return 0;
739
+ }
740
+ array_copy_from_fabric(*src_fabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size);
741
+ return n;
742
+ }
743
+ else if (src_indexedfabricarray)
744
+ {
745
+ // copy from fabric indexed
746
+ size_t n = src_indexedfabricarray->size;
747
+ if (size_t(dst_shape[0]) != n)
748
+ {
749
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
750
+ return 0;
751
+ }
752
+ array_copy_from_fabric_indexed(*src_indexedfabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size);
753
+ return n;
754
+ }
256
755
 
756
+ size_t n = 1;
257
757
  for (int i = 0; i < src_ndim; i++)
258
758
  {
259
759
  if (src_shape[i] != dst_shape[i])
260
760
  {
261
- fprintf(stderr, "Warp error: Incompatible array shapes\n");
761
+ fprintf(stderr, "Warp copy error: Incompatible array shapes\n");
262
762
  return 0;
263
763
  }
264
764
  n *= src_shape[i];
@@ -269,15 +769,111 @@ WP_API size_t array_copy_host(void* dst, void* src, int dst_type, int src_type,
269
769
  dst_indices, src_indices,
270
770
  src_shape, src_ndim, elem_size);
271
771
 
272
- if (has_grad)
772
+ return n;
773
+ }
774
+
775
+
776
// Fill every element of a strided N-dimensional host array with one value.
//
//   data        pointer to the first element of the (sub)array
//   shape       extent of each dimension (ndim entries)
//   strides     byte stride of each dimension (ndim entries)
//   ndim        number of dimensions; must be >= 1, the recursion bottoms out
//               at ndim == 1
//   value       pointer to the value_size bytes copied into each element
//   value_size  size of one element in bytes
static void array_fill_strided(void* data, const int* shape, const int* strides, int ndim, const void* value, int value_size)
{
    char* base = (char*)data;

    if (ndim == 1)
    {
        // innermost dimension: write the value element by element
        for (int i = 0; i < shape[0]; i++)
        {
            memcpy(base, value, value_size);
            base += strides[0];
        }
    }
    else
    {
        for (int i = 0; i < shape[0]; i++)
        {
            // Widen before multiplying: i * strides[0] evaluated in int overflows
            // once the array spans more than 2 GiB.
            char* slice = base + (long long)i * strides[0];
            // recurse on next inner dimension
            array_fill_strided(slice, shape + 1, strides + 1, ndim - 1, value, value_size);
        }
    }
}
797
+
798
+
799
// Fill the elements of an indexed N-dimensional host array with one value.
//
//   data        pointer to the first element of the underlying (sub)array
//   shape       extent of each dimension of the indexed view (ndim entries)
//   strides     byte stride of each dimension of the underlying array
//   indices     per-dimension index lists; a NULL entry means that dimension
//               is not indexed and positions 0..shape[d]-1 are used directly
//   ndim        number of dimensions; must be >= 1, the recursion bottoms out
//               at ndim == 1
//   value       pointer to the value_size bytes copied into each element
//   value_size  size of one element in bytes
static void array_fill_indexed(void* data, const int* shape, const int* strides, const int*const* indices, int ndim, const void* value, int value_size)
{
    for (int i = 0; i < shape[0]; i++)
    {
        // map the logical position through this dimension's index list, if any
        int idx = indices[0] ? indices[0][i] : i;

        // Widen before multiplying: idx * strides[0] evaluated in int overflows
        // once the underlying array spans more than 2 GiB.
        char* p = (char*)data + (long long)idx * strides[0];

        if (ndim == 1)
        {
            memcpy(p, value, value_size);
        }
        else
        {
            // recurse on next inner dimension
            array_fill_indexed(p, shape + 1, strides + 1, indices + 1, ndim - 1, value, value_size);
        }
    }
}
279
821
 
280
- return n;
822
+
823
+ static void array_fill_fabric(wp::fabricarray_t<void>& fa, const void* value_ptr, int value_size)
824
+ {
825
+ for (size_t i = 0; i < fa.nbuckets; i++)
826
+ {
827
+ const wp::fabricbucket_t& bucket = fa.buckets[i];
828
+ size_t bucket_size = bucket.index_end - bucket.index_start;
829
+ memtile_host(bucket.ptr, value_ptr, value_size, bucket_size);
830
+ }
831
+ }
832
+
833
+
834
+ static void array_fill_fabric_indexed(wp::indexedfabricarray_t<void>& ifa, const void* value_ptr, int value_size)
835
+ {
836
+ for (size_t i = 0; i < ifa.size; i++)
837
+ {
838
+ size_t idx = size_t(ifa.indices[i]);
839
+ if (idx < ifa.fa.size)
840
+ {
841
+ void* p = fabricarray_element_ptr(ifa.fa, idx, value_size);
842
+ memcpy(p, value_ptr, value_size);
843
+ }
844
+ }
845
+ }
846
+
847
+
848
+ WP_API void array_fill_host(void* arr_ptr, int arr_type, const void* value_ptr, int value_size)
849
+ {
850
+ if (!arr_ptr || !value_ptr)
851
+ return;
852
+
853
+ if (arr_type == wp::ARRAY_TYPE_REGULAR)
854
+ {
855
+ wp::array_t<void>& arr = *static_cast<wp::array_t<void>*>(arr_ptr);
856
+ array_fill_strided(arr.data, arr.shape.dims, arr.strides, arr.ndim, value_ptr, value_size);
857
+ }
858
+ else if (arr_type == wp::ARRAY_TYPE_INDEXED)
859
+ {
860
+ wp::indexedarray_t<void>& ia = *static_cast<wp::indexedarray_t<void>*>(arr_ptr);
861
+ array_fill_indexed(ia.arr.data, ia.shape.dims, ia.arr.strides, ia.indices, ia.arr.ndim, value_ptr, value_size);
862
+ }
863
+ else if (arr_type == wp::ARRAY_TYPE_FABRIC)
864
+ {
865
+ wp::fabricarray_t<void>& fa = *static_cast<wp::fabricarray_t<void>*>(arr_ptr);
866
+ array_fill_fabric(fa, value_ptr, value_size);
867
+ }
868
+ else if (arr_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
869
+ {
870
+ wp::indexedfabricarray_t<void>& ifa = *static_cast<wp::indexedfabricarray_t<void>*>(arr_ptr);
871
+ array_fill_fabric_indexed(ifa, value_ptr, value_size);
872
+ }
873
+ else
874
+ {
875
+ fprintf(stderr, "Warp fill error: Invalid array type id %d\n", arr_type);
876
+ }
281
877
  }
282
878
 
283
879
 
@@ -334,7 +930,7 @@ void memset_device(void* context, void* dest, int value, size_t n)
334
930
  {
335
931
  }
336
932
 
337
- void memtile_device(void* context, void* dest, void *src, size_t srcsize, size_t n)
933
+ void memtile_device(void* context, void* dest, const void* src, size_t srcsize, size_t n)
338
934
  {
339
935
  }
340
936
 
@@ -343,8 +939,13 @@ size_t array_copy_device(void* context, void* dst, void* src, int dst_type, int
343
939
  return 0;
344
940
  }
345
941
 
942
+ void array_fill_device(void* context, void* arr, int arr_type, const void* value, int value_size)
943
+ {
944
+ }
945
+
346
946
  WP_API int cuda_driver_version() { return 0; }
347
947
  WP_API int cuda_toolkit_version() { return 0; }
948
+ WP_API bool cuda_driver_is_initialized() { return false; }
348
949
 
349
950
  WP_API int nvrtc_supported_arch_count() { return 0; }
350
951
  WP_API void nvrtc_supported_archs(int* archs) {}
@@ -354,7 +955,12 @@ WP_API void* cuda_device_primary_context_retain(int ordinal) { return NULL; }
354
955
  WP_API void cuda_device_primary_context_release(int ordinal) {}
355
956
  WP_API const char* cuda_device_get_name(int ordinal) { return NULL; }
356
957
  WP_API int cuda_device_get_arch(int ordinal) { return 0; }
958
+ WP_API void cuda_device_get_uuid(int ordinal, char uuid[16]) {}
959
+ WP_API int cuda_device_get_pci_domain_id(int ordinal) { return -1; }
960
+ WP_API int cuda_device_get_pci_bus_id(int ordinal) { return -1; }
961
+ WP_API int cuda_device_get_pci_device_id(int ordinal) { return -1; }
357
962
  WP_API int cuda_device_is_uva(int ordinal) { return 0; }
963
+ WP_API int cuda_device_is_memory_pool_supported() { return 0; }
358
964
 
359
965
  WP_API void* cuda_context_get_current() { return NULL; }
360
966
  WP_API void cuda_context_set_current(void* ctx) {}
@@ -366,6 +972,7 @@ WP_API void cuda_context_synchronize(void* context) {}
366
972
  WP_API uint64_t cuda_context_check(void* context) { return 0; }
367
973
  WP_API int cuda_context_get_device_ordinal(void* context) { return -1; }
368
974
  WP_API int cuda_context_is_primary(void* context) { return 0; }
975
+ WP_API int cuda_context_is_memory_pool_supported(void* context) { return 0; }
369
976
  WP_API void* cuda_context_get_stream(void* context) { return NULL; }
370
977
  WP_API void cuda_context_set_stream(void* context, void* stream) {}
371
978
  WP_API int cuda_context_can_access_peer(void* context, void* peer_context) { return 0; }
@@ -392,13 +999,11 @@ WP_API size_t cuda_compile_program(const char* cuda_src, int arch, const char* i
392
999
  WP_API void* cuda_load_module(void* context, const char* ptx) { return NULL; }
393
1000
  WP_API void cuda_unload_module(void* context, void* module) {}
394
1001
  WP_API void* cuda_get_kernel(void* context, void* module, const char* name) { return NULL; }
395
- WP_API size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, void** args) { return 0;}
1002
+ WP_API size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, void** args) { return 0;}
396
1003
 
397
1004
  WP_API void cuda_set_context_restore_policy(bool always_restore) {}
398
1005
  WP_API int cuda_get_context_restore_policy() { return false; }
399
1006
 
400
- WP_API void array_inner_device(uint64_t a, uint64_t b, uint64_t out, int len) {}
401
- WP_API void array_sum_device(uint64_t a, uint64_t out, int len) {}
402
1007
  WP_API void array_scan_int_device(uint64_t in, uint64_t out, int len, bool inclusive) {}
403
1008
  WP_API void array_scan_float_device(uint64_t in, uint64_t out, int len, bool inclusive) {}
404
1009