warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,619 @@
1
+ ################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ ################################################################################
32
+
33
+ from typing import Generic, TypeVar
34
+ from treelib import Tree
35
+ import numpy as np
36
+
37
+ from pycutlass import *
38
+ import pycutlass
39
+
40
+ import ast
41
+ import textwrap
42
+ import inspect
43
+
44
+ ################################################################################
45
+ # Type annotation for input arguments
46
+ ################################################################################
47
+
48
# Type variables used only to parameterize the ``NDArray`` annotation below;
# they carry no runtime behavior in this module.
Ttype, Dtype = TypeVar("Ttype"), TypeVar("Dtype")


class NDArray(np.ndarray, Generic[Ttype, Dtype]):
    """Annotation-only ``np.ndarray`` subclass, generic over two type parameters.

    Used to annotate epilogue input arguments; it adds no attributes or methods.
    """
53
+
54
+ ################################################################################
55
+ # Operations
56
+ ################################################################################
57
+
58
# Python AST operator node type -> operator name used by the epilogue builder.
# BinOpNode instantiates ``pycutlass.Vector<name>`` from this name.
operators = dict([
    (ast.Add, "Add"),
    (ast.Div, "Div"),
    (ast.Eq, "Equal"),
    (ast.Mult, "Mult"),
])
64
+
65
+ ################################################################################
66
+ # AST Node abstractions
67
+ ################################################################################
68
class UnaryNode:
    """Epilogue-tree node for a unary (elementwise) operation.

    Created from an ``ast.Call`` in the epilogue source (see
    ``EpilogueAST.visit_Call``).  The called function's name is looked up as
    ``pycutlass.<name>`` to obtain the epilogue op.  An existing ``BinOpNode``
    may also be passed, in which case its op name is reused.
    """

    # Class-wide counter giving every UnaryNode a unique id/tag.
    cnt = 0

    def __init__(self,
                 element_accumulator, element_compute, elements_per_access,
                 node, args) -> None:
        """
        Args:
            element_accumulator: accumulator element type, stored for the op.
            element_compute: compute element type; passed to the pycutlass op.
            elements_per_access: vector width, stored for the op.
            node: an ``ast.Call`` (or a ``BinOpNode``) describing the op.
            args: extra op parameters; names to resolve in ``kwargs`` or literals.

        Raises:
            TypeError: if ``node`` is neither a ``BinOpNode`` nor an ``ast.Call``
                whose callee is a plain name or an attribute access.
        """
        if isinstance(node, BinOpNode):
            self.op = node.op
        elif isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name):
                self.op = node.func.id
            elif isinstance(node.func, ast.Attribute):
                # for a call like ``f.attr(x)`` the op name is taken from ``f``
                self.op = node.func.value.id
            else:
                raise TypeError(
                    "unsupported callee for unary epilogue op: "
                    + type(node.func).__name__)
        else:
            raise TypeError(
                "UnaryNode cannot be built from " + type(node).__name__)
        self.tag = "Unary" + self.op + str(UnaryNode.cnt)
        self.id = self.op + str(UnaryNode.cnt)
        self.args = args
        UnaryNode.cnt += 1

        self.type = "tensor"

        self.epilogue_op = getattr(pycutlass, self.op)(element_compute)

        # data types
        self.element_accumulator = element_accumulator
        self.element_compute = element_compute
        self.elements_per_access = elements_per_access

    def get_epilogue_node(self, visitors):
        """Create the pycutlass ``UnaryOp`` wrapping the child ``visitors``."""
        self.epilogue_node = UnaryOp(
            self.element_accumulator, self.element_compute,
            self.elements_per_access, *visitors, self.epilogue_op)

    def get_argument(self, visitor_args, kwargs):
        """Build ``self.argument`` for this node.

        Each entry of ``self.args`` is looked up in ``kwargs``.  Entries that are
        not keys of ``kwargs`` (e.g. literal constants from the source) are
        passed through unchanged.
        """
        epilogue_ops = []
        for arg in self.args:
            try:
                value = kwargs[arg]
            except (KeyError, TypeError):
                # not a named runtime argument (or unhashable): use it directly
                value = arg
            epilogue_ops.append(value)
        self.argument = self.epilogue_node.argument_type(
            self.epilogue_op.argument_type(*epilogue_ops), *visitor_args)
112
+
113
+
114
class BinOpNode:
    """Epilogue-tree node for a binary elementwise operation.

    Built from an ``ast.BinOp`` (see ``EpilogueAST.visit_BinOp``).  The AST
    operator is translated through the module-level ``operators`` table and
    ``pycutlass.Vector<name>`` is instantiated as the epilogue op.
    """

    # Class-wide counter giving every BinOpNode a unique id/tag.
    cnt = 0

    def __init__(self,
                 element_accumulator, element_compute, elements_per_access,
                 node) -> None:
        op_name = operators[type(node.op)]
        serial = str(BinOpNode.cnt)
        self.op = op_name
        self.tag = "Binary" + op_name + serial
        self.id = op_name + serial
        # binary ops take no extra parameters
        self.args = None
        BinOpNode.cnt += 1

        self.type = "tensor"

        vector_op_cls = getattr(pycutlass, "Vector" + op_name)
        self.epilogue_op = vector_op_cls(element_compute)

        # element types / vector width forwarded to the pycutlass BinaryOp
        self.element_accumulator = element_accumulator
        self.element_compute = element_compute
        self.elements_per_access = elements_per_access

    def get_epilogue_node(self, visitors):
        """Create the pycutlass ``BinaryOp`` combining the two child ``visitors``."""
        self.epilogue_node = BinaryOp(
            self.element_accumulator,
            self.element_compute,
            self.elements_per_access,
            *visitors,
            self.epilogue_op,
        )

    def get_argument(self, visitor_args, kwargs):
        """Build ``self.argument``; ``kwargs`` is not consulted for binary ops."""
        make_argument = self.epilogue_node.argument_type
        op_argument = self.epilogue_op.argument_type(self.args)
        self.argument = make_argument(op_argument, *visitor_args)
142
+
143
+
144
class NameNode:
    """Base for epilogue-tree nodes identified by a Python variable name.

    Accepts either an ``ast.Name`` (its ``id`` is used) or an assignment
    statement such as ``ast.Assign`` (the name of its first target is used).
    Both ``self.id`` and ``self.tag`` are set to that name; subclasses prefix
    the tag.
    """

    def __init__(self, node) -> None:
        try:
            self.id = node.id
        except AttributeError:
            # assignment statements have no ``id``; use the first target's name
            self.id = node.targets[0].id
        self.tag = self.id
152
+
153
class ScalarInputNode(NameNode):
    """Epilogue leaf for an input declared with kind ``"scalar"``.

    Defines no ``get_epilogue_node``/``get_argument`` of its own.
    """

    def __init__(self, node) -> None:
        super().__init__(node)
        self.type = "scalar"
        self.tag = "Scalar:" + self.tag
159
+
160
class AccumulatorNode(NameNode):
    """Epilogue leaf for the ``accum`` name, mapped to pycutlass ``AccumulatorOp``."""

    def __init__(self,
                 element_accumulator, elements_per_access, node) -> None:
        super().__init__(node)
        self.element_accumulator = element_accumulator
        self.elements_per_access = elements_per_access
        self.type = "tensor"
        self.tag = "Accum:" + self.tag

    def get_epilogue_node(self, visitors):
        # leaf node: ``visitors`` is accepted for interface uniformity but unused
        self.epilogue_node = AccumulatorOp(
            self.element_accumulator, self.elements_per_access)

    def get_argument(self, visitor_args, kwargs):
        # the accumulator op's argument is constructed with no fields
        make_argument = self.epilogue_node.argument_type
        self.argument = make_argument()
177
+
178
class TensorInputNode(NameNode):
    """Epilogue leaf for an input declared with kind ``"tensor"`` (``TensorInputOp``)."""

    def __init__(self, element_accumulator, node) -> None:
        super().__init__(node)
        self.element_accumulator = element_accumulator
        self.type = "tensor"
        self.tag = "TensorInput:" + self.tag

    def get_epilogue_node(self, *args):
        # leaf node: child visitors are ignored
        self.epilogue_node = TensorInputOp(self.element_accumulator)

    def get_argument(self, visitor_args, kwargs):
        """Build the argument from ``kwargs["<id>_ptr"]`` and ``kwargs["problem_size"]``.

        Passes the pointer, ``problem_size[1]`` and
        ``problem_size[0] * problem_size[1]`` (presumably the row stride and the
        per-batch stride -- TODO confirm against TensorInputOp).
        """
        make_argument = self.epilogue_node.argument_type
        pointer = kwargs[self.id + "_ptr"]
        extent = kwargs["problem_size"]
        self.argument = make_argument(pointer, extent[1], extent[0] * extent[1])
193
+
194
class RowBroadcastNode(NameNode):
    """Epilogue leaf for an input declared with kind ``"row"`` (``RowBroadcastOp``)."""

    def __init__(self, element_accumulator, element_fragment, node) -> None:
        super().__init__(node)
        self.element_accumulator = element_accumulator
        self.element_fragment = element_fragment
        self.type = "tensor"
        self.tag = "RowBroadcast:" + self.tag

    def get_epilogue_node(self, *args):
        # leaf node: child visitors are ignored
        self.epilogue_node = RowBroadcastOp(
            self.element_accumulator, self.element_fragment)

    def get_argument(self, visitor_args, kwargs):
        # pointer to the broadcast vector plus ``problem_size[1]``
        make_argument = self.epilogue_node.argument_type
        pointer = kwargs[self.id + "_ptr"]
        self.argument = make_argument(pointer, kwargs["problem_size"][1])
210
+
211
class ColumnBroadcastNode(NameNode):
    """Epilogue leaf for an input declared with kind ``"column"`` (``ColumnBroadcastOp``)."""

    def __init__(self, element_accumulator, element_fragment, node) -> None:
        super().__init__(node)
        self.element_accumulator = element_accumulator
        self.element_fragment = element_fragment
        self.type = "tensor"
        self.tag = "ColumnBroadcast:" + self.tag

    def get_epilogue_node(self, *args):
        # leaf node: child visitors are ignored
        self.epilogue_node = ColumnBroadcastOp(
            self.element_accumulator, self.element_fragment)

    def get_argument(self, visitor_args, kwargs):
        # pointer to the broadcast vector plus ``problem_size[0]``
        make_argument = self.epilogue_node.argument_type
        pointer = kwargs[self.id + "_ptr"]
        self.argument = make_argument(pointer, kwargs["problem_size"][0])
226
+
227
class TensorOutputNode(NameNode):
    """Epilogue node for a named output tensor (``TensorOutputOp``).

    Built either from an ``ast.Name`` or from an assignment whose first target
    names the output.
    """

    def __init__(self, element_accumulator, node) -> None:
        super().__init__(node)
        self.element_accumulator = element_accumulator
        self.type = "tensor"
        self.tag = "TensorOutput:" + self.tag

    def get_epilogue_node(self, visitors):
        """Create the ``TensorOutputOp`` that writes the result of ``visitors``."""
        self.epilogue_node = TensorOutputOp(self.element_accumulator, *visitors)

    def get_argument(self, visitor_args, kwargs):
        """Build the argument: pointer, ``problem_size[1]``, child args, then
        ``problem_size[0] * problem_size[1]``."""
        make_argument = self.epilogue_node.argument_type
        pointer = kwargs[self.id + "_ptr"]
        width = kwargs["problem_size"][1]
        extent = kwargs["problem_size"]
        self.argument = make_argument(
            pointer, width, *visitor_args, extent[0] * extent[1])
240
+
241
class RowReductionNode:
    """Epilogue node for a row reduction, mapped to pycutlass ``RowReductionOp``.

    ``factor`` divides ``problem_size[1]`` (rounded up) when computing the
    batch stride of the reduction output.
    """

    def __init__(self, element_accumulator, element_reduction,
                 element_reduction_accumulator, id, factor) -> None:
        self.id = id
        self.tag = "RowReduction:" + id
        self.type = "tensor"
        self.factor = factor
        # element types forwarded to RowReductionOp
        self.element_accumulator = element_accumulator
        self.element_reduction = element_reduction
        self.element_reduction_accumulator = element_reduction_accumulator

    def get_epilogue_node(self, visitors):
        """Create the ``RowReductionOp`` over the child ``visitors``."""
        self.epilogue_node = RowReductionOp(
            self.element_accumulator,
            self.element_reduction,
            self.element_reduction_accumulator,
            *visitors,
        )

    def get_batch_stride(self, problem_size):
        """Return ``problem_size[0] * ceil(problem_size[1] / factor)`` for positive ``factor``."""
        rows = problem_size[0]
        cols = problem_size[1]
        return rows * ((cols + self.factor - 1) // self.factor)

    def get_argument(self, visitor_args, kwargs):
        """Build the argument: output pointer, child args, then the batch stride."""
        make_argument = self.epilogue_node.argument_type
        self.argument = make_argument(
            kwargs[self.id + "_ptr"],
            *visitor_args,
            self.get_batch_stride(kwargs["problem_size"]),
        )
264
+
265
class ColumnReductionNode:
    """Epilogue node for a column reduction, mapped to pycutlass ``ColumnReductionOp``.

    ``factor`` divides ``problem_size[0]`` (rounded up) when computing the
    batch stride of the reduction output.
    """

    def __init__(self, element_accumulator, element_reduction,
                 element_reduction_accumulator, id, factor) -> None:
        self.id = id
        self.tag = "ColumnReduction:" + id
        self.type = "tensor"
        self.factor = factor
        # element types forwarded to ColumnReductionOp
        self.element_accumulator = element_accumulator
        self.element_reduction = element_reduction
        self.element_reduction_accumulator = element_reduction_accumulator

    def get_epilogue_node(self, visitors):
        """Create the ``ColumnReductionOp`` over the child ``visitors``."""
        self.epilogue_node = ColumnReductionOp(
            self.element_accumulator,
            self.element_reduction,
            self.element_reduction_accumulator,
            *visitors,
        )

    def get_batch_stride(self, problem_size):
        """Return ``problem_size[1] * ceil(problem_size[0] / factor)`` for positive ``factor``."""
        rows = problem_size[0]
        cols = problem_size[1]
        return cols * ((rows + self.factor - 1) // self.factor)

    def get_argument(self, visitor_args, kwargs):
        """Build the argument: output pointer, child args, then the batch stride."""
        make_argument = self.epilogue_node.argument_type
        self.argument = make_argument(
            kwargs[self.id + "_ptr"],
            *visitor_args,
            self.get_batch_stride(kwargs["problem_size"]),
        )
288
+
289
+ ################################################################################
290
+ # Epilogue parser function
291
+ ################################################################################
292
+ class EpilogueAST(ast.NodeVisitor):
293
def __init__(self, epilogue,
             tile_description,
             element_accumulator, elements_per_access,
             element_compute, element_output) -> None:
    """Parse ``epilogue.__call__`` and build the epilogue tree.

    The source of ``epilogue.__call__`` is fetched with ``inspect.getsource``,
    dedented and parsed with ``ast.parse``; the resulting AST is walked
    immediately, and the ``visit_*`` handlers populate ``self.epilogue_tree``.
    """
    # configuration consulted by the node builders in the visit_* handlers
    self.tile_description = tile_description
    self.element_accumulator = element_accumulator
    self.elements_per_access = elements_per_access
    self.element_compute = element_compute
    self.element_output = element_output
    self.epilogue = epilogue

    call_source = inspect.getsource(epilogue.__call__)
    self.source = textwrap.dedent(call_source)
    self.ast_tree = ast.parse(self.source)
    self.epilogue_tree = Tree()
    # debugging aid: print(ast.dump(self.ast_tree, indent=4))

    # input name -> record whose [0] is the input kind
    # ("tensor" / "row" / "column" / "scalar"), as read by visit_Name
    self.input_args = {}
    # names returned by the epilogue (appended by visit_Name)
    self.returns = []
    # reduced name -> [reduction_op arg1 value, arg2 value, assigned name]
    self.reduction_source = {}
    # stack of parent node ids; the top is the parent of newly created nodes
    self.stack = []

    # walking the AST must come last: the handlers rely on the state above
    self.visit(self.ast_tree)
325
+
326
+ # visit the name node
327
def visit_Name(self, node):
    """Handle a bare name in the epilogue source.

    If the current context (top of ``self.stack``) is ``"return"``, the name
    is only recorded in ``self.returns``.  Otherwise a node is added under the
    current parent:

    * ``accum`` becomes an ``AccumulatorNode``;
    * names in ``self.input_args`` become a tensor/row/column/scalar input,
      chosen by their declared kind (any other kind raises ``ValueError``);
    * every other name is treated as a ``TensorOutputNode``.
    """
    parent = self.stack[-1]
    if parent == "return":
        self.returns.append(node.id)
        return

    name = node.id
    if name == "accum":
        leaf = AccumulatorNode(
            self.element_accumulator, self.elements_per_access, node)
    elif name not in self.input_args.keys():
        leaf = TensorOutputNode(self.element_accumulator, node)
    else:
        kind = self.input_args[name][0]
        if kind == "tensor":
            leaf = TensorInputNode(self.element_accumulator, node)
        elif kind == "row":
            leaf = RowBroadcastNode(
                self.element_accumulator, self.element_compute, node)
        elif kind == "column":
            leaf = ColumnBroadcastNode(
                self.element_accumulator, self.element_compute, node)
        elif kind == "scalar":
            leaf = ScalarInputNode(node)
        else:
            raise ValueError(kind)
    self.epilogue_tree.create_node(leaf.tag, leaf.id, data=leaf, parent=parent)
354
+
355
def visit_Assign(self, node):
    """Handle ``target = value`` in the epilogue source.

    * Target not yet in the tree: if ``value`` is a ``reduction_op(...)`` call,
      it is recorded in ``self.reduction_source`` and nothing else happens.
      Otherwise a new root ``TensorOutputNode`` is created for the target.
    * Target already in the tree and returned or reduced: its node is kept and
      becomes the parent of ``value``'s subtree.
    * Any other existing target is an intermediate variable: its node is
      removed, and ``value`` is attached to that node's former parent.
    """
    target = node.targets[0].id
    existing = self.epilogue_tree.get_node(target)

    if existing is not None:
        if target in self.returns or target in self.reduction_source.keys():
            self.stack.append(target)
        else:
            # splice the expression in place of the intermediate variable
            self.stack.append(existing.predecessor(self.epilogue_tree.identifier))
            self.epilogue_tree.remove_node(target)
    else:
        value = node.value
        if isinstance(value, ast.Call):
            callee = value.func
            if isinstance(callee, ast.Name):
                func_type = callee.id
            elif isinstance(callee, ast.Attribute):
                func_type = callee.value.id
            else:
                raise TypeError
            if func_type == 'reduction_op':
                # reduction nodes are not added to the tree here
                self.reduction_source[value.args[0].id] = [
                    value.args[1].value, value.args[2].value, target]
                return
        root = TensorOutputNode(self.element_accumulator, node)
        self.epilogue_tree.create_node(root.tag, root.id, data=root)
        self.stack.append(root.id)

    # build the right-hand side beneath whichever parent was pushed above
    self.visit(node.value)
    self.stack.pop()
383
+
384
def visit_Call(self, node):
    """Add a function call from the epilogue body to the tree.

    ``reduction_op(src, ...)`` calls create no node; only ``src`` is visited.
    Any other call ``f(x, *params)`` becomes a ``UnaryNode`` under the current
    parent. The first argument ``x`` is visited as its child, and the remaining
    arguments (constants or bare names) are passed to the node as parameters.
    For ``obj.attr(...)`` callees, the object's name (``obj``) is used as the
    function type.

    Raises:
        TypeError: if the callee is neither a name nor an attribute, or if a
            parameter argument is neither a constant nor a bare name.
    """
    func = node.func
    if isinstance(func, ast.Name):
        func_type = func.id
    elif isinstance(func, ast.Attribute):
        func_type = func.value.id
    else:
        raise TypeError(f"unsupported callee in epilogue: {type(func).__name__}")

    if func_type == "reduction_op":
        self.visit(node.args[0])
        return

    # args[0] is the operand subtree; the rest are scalar parameters
    arg_list = []
    for arg in node.args[1:]:
        if isinstance(arg, ast.Constant):
            arg_list.append(arg.value)
        elif isinstance(arg, ast.Name):
            arg_list.append(arg.id)
        else:
            raise TypeError(
                f"unsupported argument to {func_type!r} in epilogue: {type(arg).__name__}")

    unary_node = UnaryNode(self.element_accumulator, self.element_compute, self.elements_per_access, node, arg_list)
    self.epilogue_tree.create_node(unary_node.tag, unary_node.id, parent=self.stack[-1], data=unary_node)
    self.stack.append(unary_node.id)
    self.visit(node.args[0])
    self.stack.pop()
409
+
410
def visit_BinOp(self, node):
    """Add a ``BinOpNode`` under the current parent, then visit its operands.

    The left operand is visited before the right one, both with the new node
    on top of ``self.stack``, so they become its first and second children.
    """
    op_node = BinOpNode(
        self.element_accumulator,
        self.element_compute,
        self.elements_per_access,
        node,
    )
    self.epilogue_tree.create_node(op_node.tag, op_node.id, data=op_node, parent=self.stack[-1])

    self.stack.append(op_node.id)
    for operand in (node.left, node.right):
        self.visit(operand)
    self.stack.pop()
418
+
419
def visit_Return(self, node):
    """Visit the returned expression with the ``"return"`` marker on the stack.

    While the marker is on top of ``self.stack``, ``visit_Name`` appends bare
    names to ``self.returns`` instead of creating tree nodes.
    """
    marker_stack = self.stack
    marker_stack.append("return")
    self.visit(node.value)
    marker_stack.pop()
423
+
424
# A function definition
def visit_FunctionDef(self, node: ast.FunctionDef):
    """Collect the annotated arguments, then visit the body bottom-up.

    Every argument except ``self`` whose annotation is a constant (e.g.
    ``c: "row"``) is stored as ``self.input_args[name] = [annotation]``.
    Arguments without a constant annotation are not recorded.

    Statements are visited from last to first. Names used by later
    statements (including the returned ones) are therefore already known when
    the assignments that produce them are visited.
    """
    for arg in node.args.args:
        if arg.arg == "self":
            continue
        if isinstance(arg.annotation, ast.Constant):
            self.input_args[arg.arg] = [arg.annotation.value]

    for stmt in reversed(node.body):
        self.visit(stmt)
434
+
435
+ #
436
+ # Tree optimization pass
437
+ #
438
+
439
# pass 1: lower Binary to Unary
def pass_binary_2_unary(self, tree, nid):
    """Tree pass 1: fold scalar operands of binary ops into unary ops.

    A ``BinOpNode`` whose two children have ``data.type`` ``"scalar"`` and
    ``"tensor"`` (in either order) is rewritten in place:

    * its data becomes a ``UnaryNode`` that takes the scalar child's id as its
      argument;
    * the scalar child is removed from ``tree``;
    * the pass continues into the tensor child.

    All other nodes are left unchanged, and the pass recurses into their
    children.

    Args:
        tree: epilogue tree, modified in place.
        nid: identifier of the node to start from.
    """
    node = tree.get_node(nid)
    children = node.successors(tree.identifier)

    if not isinstance(node.data, BinOpNode):
        for child in children:
            self.pass_binary_2_unary(tree, child)
        return

    lhs_node = tree.get_node(children[0])
    rhs_node = tree.get_node(children[1])
    left_type = lhs_node.data.type
    right_type = rhs_node.data.type

    if left_type == "scalar" and right_type == "tensor":
        scalar_node, tensor_node = lhs_node, rhs_node
    elif left_type == "tensor" and right_type == "scalar":
        scalar_node, tensor_node = rhs_node, lhs_node
    else:
        self.pass_binary_2_unary(tree, lhs_node.data.id)
        self.pass_binary_2_unary(tree, rhs_node.data.id)
        return

    # Always reference the scalar through ``.data.id`` (equal to its tree
    # identifier). The tree node object itself has no ``id`` attribute.
    node.data = UnaryNode(
        self.element_accumulator, self.element_compute,
        self.elements_per_access,
        node.data, [scalar_node.data.id])
    node.tag = node.data.tag
    tree.remove_node(scalar_node.data.id)
    self.pass_binary_2_unary(tree, tensor_node.data.id)
472
+
473
# pass 2: inject reduction nodes
def pass_inject_reduction(self, tree, nid):
    """Tree pass 2: insert reduction nodes for outputs that feed ``reduction_op``.

    For a ``TensorOutputNode`` whose id is a key of ``self.reduction_source``,
    the reduction node depends on the recorded direction (entry ``[0]``):

    * ``'row'`` builds a ``RowReductionNode`` sized by ``threadblock_shape[1]``;
    * ``'column'`` builds a ``ColumnReductionNode`` sized by ``threadblock_shape[0]``.

    The reduction target is the last element of the entry. If the output is
    not returned, the node's data is replaced by the reduction. If it is
    returned, the reduction is inserted between the output and its first child.

    Raises:
        ValueError: for any other direction.
    """
    node = tree.get_node(nid)
    data = node.data

    if not (isinstance(data, TensorOutputNode) and data.id in self.reduction_source):
        for child in node.successors(tree.identifier):
            self.pass_inject_reduction(tree, child)
        return

    entry = self.reduction_source[data.id]
    direction, target = entry[0], entry[-1]
    if direction == 'row':
        reduction_node = RowReductionNode(
            self.element_accumulator, self.element_output,
            self.element_accumulator, target, self.tile_description.threadblock_shape[1])
    elif direction == "column":
        reduction_node = ColumnReductionNode(
            self.element_accumulator, self.element_output,
            self.element_accumulator, target, self.tile_description.threadblock_shape[0])
    else:
        raise ValueError(direction)

    child_nid = node.successors(tree.identifier)[0]
    if data.id in self.returns:
        # The output is also returned: keep it and put the reduction beneath it,
        # re-parenting the old subtree under the reduction.
        tree.create_node(reduction_node.tag, reduction_node.id, data=reduction_node, parent=data.id)
        tree.move_node(child_nid, reduction_node.id)
        for grand_child in tree.get_node(child_nid).successors(tree.identifier):
            self.pass_inject_reduction(tree, grand_child)
    else:
        # The output exists only to feed the reduction: replace it outright.
        node.data = reduction_node
        node.tag = reduction_node.tag
        self.pass_inject_reduction(tree, child_nid)
511
+
512
def pass_inject_epilogue_op(self, tree, nid):
    """Build epilogue visitor nodes bottom-up and return the one for ``nid``.

    Children are processed first, in successor order. Their ``epilogue_node``
    results are handed to this node's ``data.get_epilogue_node``, and this
    node's resulting ``data.epilogue_node`` is returned.
    """
    node = tree.get_node(nid)
    child_visitors = [
        self.pass_inject_epilogue_op(tree, child_id)
        for child_id in node.successors(tree.identifier)
    ]
    payload = node.data
    payload.get_epilogue_node(child_visitors)
    return payload.epilogue_node
520
+
521
def get_arguments(self, tree, nid, kwargs):
    """Build visitor arguments bottom-up and return the one for ``nid``.

    Each child's argument is computed first, in successor order. The list is
    then passed, together with the shared ``kwargs``, to this node's
    ``data.get_argument``, and ``data.argument`` is returned.
    """
    node = tree.get_node(nid)
    child_arguments = [
        self.get_arguments(tree, child_id, kwargs)
        for child_id in node.successors(tree.identifier)
    ]
    payload = node.data
    payload.get_argument(child_arguments, kwargs)
    return payload.argument
529
+
530
class EpilogueVisitTree:
    """Epilogue whose computation is described by the Python source of ``__call__``.

    ``initialize`` passes ``self`` to ``EpilogueAST``, which parses the source
    of ``self.__call__``. This class does not define ``__call__`` itself, so a
    subclass must provide it. After ``initialize``:

    * ``self.tree`` is the rewritten epilogue tree;
    * ``self.visitor`` is the epilogue node at the tree root;
    * ``self.epilogue_type`` is a ``ctypes.Structure`` subclass whose
      constructor converts keyword arguments into the visitor arguments.
    """

    KernelTemplate = """
${visitor}

using ${operation_name}_EpilogueVisitor = cutlass::epilogue::threadblock::EpilogueVisitorGeneric<${visitor_name}>;
"""

    def __init__(self, elementwise_functor, tile_description,
                 element_accumulator, elements_per_access,
                 element_compute, element_output) -> None:
        """Store the data-type and tiling configuration used by ``initialize``."""
        # data types
        self.tile_description = tile_description
        self.element_accumulator = element_accumulator
        self.elements_per_access = elements_per_access
        self.element_compute = element_compute
        self.element_output = element_output
        # TODO: deprecate this
        self.elementwise_functor = elementwise_functor

    def initialize(self):
        """Parse ``__call__``, run the tree passes and build ``self.epilogue_type``."""
        function = EpilogueAST(self, self.tile_description,
                               self.element_accumulator, self.elements_per_access,
                               self.element_compute, self.element_output)
        tree = function.epilogue_tree
        self.tree = tree
        # Passes run in this order; call ``tree.show()`` between them to debug.
        function.pass_binary_2_unary(tree, tree.root)
        function.pass_inject_reduction(tree, tree.root)
        function.pass_inject_epilogue_op(tree, tree.root)

        visitor = tree.get_node(tree.root).data.epilogue_node
        self.visitor = visitor

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("visitor_arg", visitor.argument_type)
            ]

            def __init__(self, **kwargs) -> None:
                """Allocate device buffers and fill ``visitor_arg`` from ``kwargs``.

                ``kwargs`` must hold a host array for every non-scalar input
                argument (other than ``accum``) and for every returned name.
                Scalar inputs are not uploaded here. The whole ``kwargs`` dict
                is still forwarded to ``get_arguments``.
                """
                _kwargs = {}
                # Inputs: ``accum`` and scalars are skipped. Every other
                # annotation ("tensor", "row", "column") gets a device buffer.
                for input_key, input_spec in function.input_args.items():
                    if input_key == "accum" or input_spec[0] == "scalar":
                        continue
                    buffer = NumpyFrontend.argument(kwargs[input_key], False)
                    setattr(self, "buffer_tensor_" + input_key, buffer)
                    ptr = int(buffer.ptr)
                    setattr(self, input_key + "_ptr", ptr)
                    _kwargs[input_key + "_ptr"] = ptr

                # Outputs: device buffers plus the host arrays ``sync`` copies back into.
                for ret in function.returns:
                    buffer = NumpyFrontend.argument(kwargs[ret], True)
                    setattr(self, "buffer_tensor_" + ret, buffer)
                    ptr = int(buffer.ptr)
                    setattr(self, ret + "_ptr", ptr)
                    _kwargs[ret + "_ptr"] = ptr
                    setattr(self, "host_tensor_" + ret, kwargs[ret])

                _kwargs.update(kwargs)
                function.get_arguments(tree, tree.root, _kwargs)
                self.visitor_arg = tree.get_node(tree.root).data.argument

            def sync(self, stream_sync=True):
                """Optionally synchronize the device, then copy each returned tensor to host.

                Raises:
                    RuntimeError: if a CUDA runtime or driver call fails.
                """
                if stream_sync:
                    err, = cudart.cudaDeviceSynchronize()
                    # The runtime API reports a cudaError_t, not a driver CUresult.
                    if err != cudart.cudaError_t.cudaSuccess:
                        raise RuntimeError("CUDA Error %s" % str(err))

                for ret in function.returns:
                    host = getattr(self, "host_tensor_" + ret)
                    err, = cuda.cuMemcpyDtoH(
                        host, cuda.CUdeviceptr(getattr(self, ret + "_ptr")),
                        host.size * host.itemsize  # byte count
                    )
                    if err != cuda.CUresult.CUDA_SUCCESS:
                        raise RuntimeError("CUDA Error %s" % str(err))

        self.epilogue_type = _Argument

    def emit(self, operation):
        """Return the visitor source plus the ``<name>_EpilogueVisitor`` alias for ``operation``."""
        values = {
            'visitor': self.visitor.emit(operation),
            'operation_name': operation.procedural_name(),
            'visitor_name': self.visitor.instance_name
        }
        return SubstituteTemplate(self.KernelTemplate, values)