warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for details.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/crt.h CHANGED
@@ -259,6 +259,8 @@ float expf(float);
259
259
  double exp(double);
260
260
  float sqrtf(float);
261
261
  double sqrt(double);
262
+ float cbrtf(float);
263
+ double cbrt(double);
262
264
  float powf(float, float);
263
265
  double pow(double, double);
264
266
  float floorf(float);
warp/native/cuda_crt.h CHANGED
@@ -1033,6 +1033,11 @@ __device_forceinline__ unsigned int atomicAdd(unsigned int *const address, const
1033
1033
  return __uAtomicAdd(address, val);
1034
1034
  }
1035
1035
 
1036
// 64-bit unsigned overload of atomicAdd, forwarding to the __ullAtomicAdd
// intrinsic. Per CUDA's atomicAdd contract the result is the value stored at
// *address before the addition. It is a 64-bit quantity, so the return type
// must be unsigned long long: declaring it unsigned int (as before) silently
// truncated the upper 32 bits of the returned old value.
__device_forceinline__ unsigned long long atomicAdd(unsigned long long *const address, const unsigned long long val)
{
    return __ullAtomicAdd(address, val);
}
1040
+
1036
1041
  __device_forceinline__ int atomicMin(int *const address, const int val)
1037
1042
  {
1038
1043
  return __iAtomicMin(address, val);
warp/native/cuda_util.cpp CHANGED
@@ -59,6 +59,7 @@ static PFN_cuDeviceGet_v2000 pfn_cuDeviceGet;
59
59
  static PFN_cuDeviceGetCount_v2000 pfn_cuDeviceGetCount;
60
60
  static PFN_cuDeviceGetName_v2000 pfn_cuDeviceGetName;
61
61
  static PFN_cuDeviceGetAttribute_v2000 pfn_cuDeviceGetAttribute;
62
+ static PFN_cuDeviceGetUuid_v11040 pfn_cuDeviceGetUuid;
62
63
  static PFN_cuDevicePrimaryCtxRetain_v7000 pfn_cuDevicePrimaryCtxRetain;
63
64
  static PFN_cuDevicePrimaryCtxRelease_v11000 pfn_cuDevicePrimaryCtxRelease;
64
65
  static PFN_cuDeviceCanAccessPeer_v4000 pfn_cuDeviceCanAccessPeer;
@@ -89,6 +90,7 @@ static PFN_cuGraphicsResourceGetMappedPointer_v3020 pfn_cuGraphicsResourceGetMap
89
90
  static PFN_cuGraphicsGLRegisterBuffer_v3000 pfn_cuGraphicsGLRegisterBuffer;
90
91
  static PFN_cuGraphicsUnregisterResource_v3000 pfn_cuGraphicsUnregisterResource;
91
92
 
93
+ static bool cuda_driver_initialized = false;
92
94
 
93
95
  bool ContextGuard::always_restore = false;
94
96
 
@@ -165,6 +167,7 @@ bool init_cuda_driver()
165
167
  get_driver_entry_point("cuDeviceGetCount", &(void*&)pfn_cuDeviceGetCount);
166
168
  get_driver_entry_point("cuDeviceGetName", &(void*&)pfn_cuDeviceGetName);
167
169
  get_driver_entry_point("cuDeviceGetAttribute", &(void*&)pfn_cuDeviceGetAttribute);
170
+ get_driver_entry_point("cuDeviceGetUuid", &(void*&)pfn_cuDeviceGetUuid);
168
171
  get_driver_entry_point("cuDevicePrimaryCtxRetain", &(void*&)pfn_cuDevicePrimaryCtxRetain);
169
172
  get_driver_entry_point("cuDevicePrimaryCtxRelease", &(void*&)pfn_cuDevicePrimaryCtxRelease);
170
173
  get_driver_entry_point("cuDeviceCanAccessPeer", &(void*&)pfn_cuDeviceCanAccessPeer);
@@ -196,11 +199,15 @@ bool init_cuda_driver()
196
199
  get_driver_entry_point("cuGraphicsUnregisterResource", &(void*&)pfn_cuGraphicsUnregisterResource);
197
200
 
198
201
  if (pfn_cuInit)
199
- return check_cu(pfn_cuInit(0));
200
- else
201
- return false;
202
+ cuda_driver_initialized = check_cu(pfn_cuInit(0));
203
+
204
+ return cuda_driver_initialized;
202
205
  }
203
206
 
207
+ bool is_cuda_driver_initialized()
208
+ {
209
+ return cuda_driver_initialized;
210
+ }
204
211
 
205
212
  bool check_cuda_result(cudaError_t code, const char* file, int line)
206
213
  {
@@ -284,6 +291,11 @@ CUresult cuDeviceGetAttribute_f(int* value, CUdevice_attribute attrib, CUdevice
284
291
  return pfn_cuDeviceGetAttribute ? pfn_cuDeviceGetAttribute(value, attrib, dev) : DRIVER_ENTRY_POINT_ERROR;
285
292
  }
286
293
 
294
+ CUresult cuDeviceGetUuid_f(CUuuid* uuid, CUdevice dev)
295
+ {
296
+ return pfn_cuDeviceGetUuid ? pfn_cuDeviceGetUuid(uuid, dev) : DRIVER_ENTRY_POINT_ERROR;
297
+ }
298
+
287
299
  CUresult cuDevicePrimaryCtxRetain_f(CUcontext* ctx, CUdevice dev)
288
300
  {
289
301
  return pfn_cuDevicePrimaryCtxRetain ? pfn_cuDevicePrimaryCtxRetain(ctx, dev) : DRIVER_ENTRY_POINT_ERROR;
warp/native/cuda_util.h CHANGED
@@ -51,6 +51,7 @@ CUresult cuDeviceGet_f(CUdevice *dev, int ordinal);
51
51
  CUresult cuDeviceGetCount_f(int* count);
52
52
  CUresult cuDeviceGetName_f(char* name, int len, CUdevice dev);
53
53
  CUresult cuDeviceGetAttribute_f(int* value, CUdevice_attribute attrib, CUdevice dev);
54
+ CUresult cuDeviceGetUuid_f(CUuuid* uuid, CUdevice dev);
54
55
  CUresult cuDevicePrimaryCtxRetain_f(CUcontext* ctx, CUdevice dev);
55
56
  CUresult cuDevicePrimaryCtxRelease_f(CUdevice dev);
56
57
  CUresult cuDeviceCanAccessPeer_f(int* can_access, CUdevice dev, CUdevice peer_dev);
@@ -83,6 +84,7 @@ CUresult cuGraphicsUnregisterResource_f(CUgraphicsResource resource);
83
84
 
84
85
 
85
86
  bool init_cuda_driver();
87
+ bool is_cuda_driver_initialized();
86
88
 
87
89
  bool check_cuda_result(cudaError_t code, const char* file, int line);
88
90
  inline bool check_cuda_result(uint64_t code, const char* file, int line)
@@ -166,6 +168,6 @@ public:
166
168
  #endif // WP_ENABLE_CUDA
167
169
 
168
170
  // Pass this value to device functions as the `context` parameter to bypass unnecessary context management.
169
- // This works in conjuntion with ContextGuards, which do nothing if the given context is NULL.
171
+ // This works in conjunction with ContextGuards, which do nothing if the given context is NULL.
170
172
  // Using this variable instead of passing NULL directly aids readability and makes the intent clear.
171
173
  constexpr void* WP_CURRENT_CONTEXT = NULL;
warp/native/cutlass/tools/library/scripts/conv2d_operation.py ADDED
@@ -0,0 +1,463 @@
1
+ #
2
+ # \file generator.py
3
+ #
4
+ # \brief Generates the CUTLASS Library's instances
5
+ #
6
+ #
7
+
8
+ import enum
9
+ import os.path
10
+ import shutil
11
+
12
+ from library import *
13
+
14
+ ###################################################################################################
15
+
16
+ #
17
class Conv2dOperation:
  """Description of one CUTLASS 2-D convolution kernel instance.

  Stores the convolution kind, iterator algorithm, target architecture, tile
  description, the A/B/C tensor descriptions, the epilogue element type and the
  stride-support / epilogue / swizzling / group-mode options, and derives the
  procedural kernel name from those fields.
  """

  def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue,
               stride_support, epilogue_functor=EpilogueFunctor.LinearCombination,
               swizzling_functor=SwizzlingFunctor.Identity1, group_mode=GroupMode.NoneGroup):
    self.operation_kind = OperationKind.Conv2d
    self.arch = arch
    self.tile_description = tile_description
    self.conv_kind = conv_kind
    self.iterator_algorithm = iterator_algorithm
    self.stride_support = stride_support
    # Operand tensor descriptions; .element, .layout and .alignment are read below.
    self.A, self.B, self.C = A, B, C
    self.element_epilogue = element_epilogue
    self.epilogue_functor = epilogue_functor
    self.swizzling_functor = swizzling_functor
    self.group_mode = group_mode

  def is_complex(self):
    """Return True when the tile's math operation is a complex multiply-add variant."""
    math_operation = self.tile_description.math_instruction.math_operation
    return math_operation in (
      MathOperation.multiply_add_complex,
      MathOperation.multiply_add_complex_gaussian,
    )

  def accumulator_type(self):
    """Return the accumulator element type, promoted to its complex form for complex math."""
    accum = self.tile_description.math_instruction.element_accumulator
    return get_complex_from_real(accum) if self.is_complex() else accum

  def core_name(self):
    """Return the core kernel name.

    Layout: short accumulator-type name, instruction shape (tensor ops only),
    optional intermediate type, conv-kind name, then ``_`` and the iterator
    algorithm name.
    """
    math_inst = self.tile_description.math_instruction

    inst_shape = ''
    intermediate_type = ''
    if math_inst.opcode_class == OpcodeClass.TensorOp:
      inst_shape = "%d%d%d" % tuple(math_inst.instruction_shape)
      # The instruction's A type is only spelled out when it matches neither the
      # A operand's element type nor the accumulator type.
      if math_inst.element_a != self.A.element:
        if math_inst.element_a != self.accumulator_type():
          intermediate_type = DataTypeNames[math_inst.element_a]

    accum_name = ShortDataTypeNames[self.accumulator_type()]
    conv_kind_name = ConvKindNames[self.conv_kind]
    iterator_name = IteratorAlgorithmNames[self.iterator_algorithm]
    return f"{accum_name!s}{inst_shape!s}{intermediate_type!s}{conv_kind_name!s}_{iterator_name!s}"

  def extended_name(self):
    """Return the core name decorated with operand types that differ from the accumulator.

    The A element type is appended whenever it differs from the accumulator;
    the C element type is additionally prefixed only when A differs as well.
    """
    accum = self.tile_description.math_instruction.element_accumulator
    a_differs = self.A.element != accum
    c_differs = self.C.element != accum

    if not a_differs:
      template = "${core_name}"
    elif c_differs:
      template = "${element_c}_${core_name}_${element_a}"
    else:
      template = "${core_name}_${element_a}"

    substitutions = {
      'element_a': DataTypeNames[self.A.element],
      'element_c': DataTypeNames[self.C.element],
      'core_name': self.core_name(),
    }
    return SubstituteTemplate(template, substitutions)

  def layout_name(self):
    """Return the short name of the A operand's layout."""
    return str(ShortLayoutTypeNames[self.A.layout])

  def configuration_name(self):
    """Return the full kernel name.

    Built from the opcode class, extended name, threadblock tile name, layout,
    an optional ``unity_stride`` tag, an optional group-mode tag, and the A
    operand alignment. The target architecture is not part of the name.
    """
    tile = self.tile_description
    opcode_class_name = OpcodeClassNames[tile.math_instruction.opcode_class]
    threadblock = tile.procedural_name()

    group_conv_name = f"{GroupModeNames[self.group_mode]}_" if self.group_mode != GroupMode.NoneGroup else ""

    configuration_name = (
      "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_${group_conv_name}align${alignment}"
      if self.stride_support == StrideSupport.Unity
      else "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${group_conv_name}align${alignment}"
    )

    substitutions = {
      'opcode_class': opcode_class_name,
      'extended_name': self.extended_name(),
      'threadblock': threadblock,
      'layout': self.layout_name(),
      'alignment': "%d" % self.A.alignment,
      'group_conv_name': group_conv_name,
    }
    return SubstituteTemplate(configuration_name, substitutions)

  def procedural_name(self):
    """Return the kernel's procedural name; identical to configuration_name()."""
    return self.configuration_name()
129
+
130
+ ###################################################################################################
131
+ #
132
+ # Emits single instances of a CUTLASS device-wide operator
133
+ #
134
+ ###################################################################################################
135
+
136
+ class EmitConv2dInstance:
137
+ def __init__(self):
138
+ self.template = """
139
+ // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
140
+ using ${operation_name}_base =
141
+ typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
142
+ ${element_a},
143
+ ${layout_a},
144
+ ${element_b},
145
+ ${layout_b},
146
+ ${element_c},
147
+ ${layout_c},
148
+ ${element_accumulator},
149
+ ${opcode_class},
150
+ ${arch},
151
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
152
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
153
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
154
+ ${epilogue_functor}<
155
+ ${element_c},
156
+ ${epilogue_vector_length},
157
+ ${element_accumulator},
158
+ ${element_epilogue}
159
+ >,
160
+ ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
161
+ ${stages},
162
+ ${math_operator},
163
+ ${iterator_algorithm},
164
+ ${stride_support},
165
+ ${align_a},
166
+ ${align_b}
167
+ >::Kernel;
168
+ """
169
+ self.template_group_conv = """
170
+ // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
171
+ using ${operation_name}_base =
172
+ typename cutlass::conv::kernel::DefaultConv2dGroup${conv_kind_name}<
173
+ ${element_a},
174
+ ${layout_a},
175
+ ${element_b},
176
+ ${layout_b},
177
+ ${element_c},
178
+ ${layout_c},
179
+ ${element_accumulator},
180
+ ${opcode_class},
181
+ ${arch},
182
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
183
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
184
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
185
+ ${epilogue_functor}<
186
+ ${element_c},
187
+ ${epilogue_vector_length},
188
+ ${element_accumulator},
189
+ ${element_epilogue}
190
+ >,
191
+ ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
192
+ ${stages},
193
+ ${math_operator},
194
+ ${group_mode},
195
+ ${iterator_algorithm},
196
+ ${stride_support},
197
+ ${align_a},
198
+ ${align_b}
199
+ >::Kernel;
200
+ """
201
+ self.template_depthwise_direct_conv = """
202
+ // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
203
+ using ${operation_name}_base =
204
+ typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConv${conv_kind_name}<
205
+ ${element_a},
206
+ ${layout_a},
207
+ ${element_b},
208
+ ${layout_b},
209
+ ${element_c},
210
+ ${layout_c},
211
+ ${element_accumulator},
212
+ ${opcode_class},
213
+ ${arch},
214
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
215
+ cutlass::conv::TensorNHWCShape<${threadblock_output_shape_n}, ${threadblock_output_shape_p}, ${threadblock_output_shape_q}, ${groups_per_cta}>,
216
+ cutlass::MatrixShape<${filter_shape_r}, ${filter_shape_s}>,
217
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
218
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
219
+ ${epilogue_functor}<
220
+ ${element_c},
221
+ ${epilogue_vector_length},
222
+ ${element_accumulator},
223
+ ${element_epilogue},
224
+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
225
+ >,
226
+
227
+ cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
228
+ 1,
229
+ ${threadblock_output_shape_n},
230
+ ${threadblock_output_shape_p},
231
+ ${threadblock_output_shape_q}>,
232
+ ${stages},
233
+ ${math_operator},
234
+ ${iterator_algorithm},
235
+ ${stride_support},
236
+ cutlass::MatrixShape<${stride_r}, ${stride_s}>,
237
+ cutlass::MatrixShape<${dilation_r}, ${dilation_s}>
238
+ >::Kernel;
239
+ """
240
+
241
def emit(self, operation):
    """Render the C++ kernel-instance text for a single Conv2d ``operation``.

    Builds a name -> string substitution table from ``operation`` and passes it to
    ``SubstituteTemplate`` together with one of three templates, chosen by
    ``operation.group_mode``:

    * ``GroupMode.NoneGroup`` -> ``self.template``
    * ``GroupMode.Depthwise`` -> ``self.template_depthwise_direct_conv`` (the table is
      extended with output-tile, groups-per-CTA, filter, stride and dilation entries)
    * any other mode          -> ``self.template_group_conv``

    Both grouped variants also receive a ``group_mode`` entry.

    Returns:
        The text produced by ``SubstituteTemplate``.
    """
    tile = operation.tile_description
    math_inst = tile.math_instruction
    tb_shape = tile.threadblock_shape
    inst_shape = math_inst.instruction_shape

    # Per-warp tile extent for each of the three GEMM dimensions (M, N, K):
    # threadblock extent divided by the number of warps along that dimension.
    warp_shape = [int(tb_shape[dim] / tile.warp_count[dim]) for dim in range(3)]

    # Epilogue vector width in C elements: (alignment * element size) capped at 128,
    # then divided back by the element size. NOTE(review): DataTypeSize looks like a
    # size in bits, making the cap a 128-bit access width; confirm.
    c_elem_size = DataTypeSize[operation.C.element]
    epilogue_vector_length = int(min(operation.C.alignment * c_elem_size, 128) / c_elem_size)

    values = dict(
        operation_name=operation.procedural_name(),
        conv_kind=ConvKindTag[operation.conv_kind],
        conv_kind_name=ConvKindNames[operation.conv_kind].capitalize(),
        element_a=DataTypeTag[operation.A.element],
        layout_a=LayoutTag[operation.A.layout],
        element_b=DataTypeTag[operation.B.element],
        layout_b=LayoutTag[operation.B.layout],
        element_c=DataTypeTag[operation.C.element],
        layout_c=LayoutTag[operation.C.layout],
        element_accumulator=DataTypeTag[operation.accumulator_type()],
        opcode_class=OpcodeClassTag[math_inst.opcode_class],
        arch="cutlass::arch::Sm%d" % operation.arch,
        threadblock_shape_m=str(tb_shape[0]),
        threadblock_shape_n=str(tb_shape[1]),
        threadblock_shape_k=str(tb_shape[2]),
        warp_shape_m=str(warp_shape[0]),
        warp_shape_n=str(warp_shape[1]),
        warp_shape_k=str(warp_shape[2]),
        instruction_shape_m=str(inst_shape[0]),
        instruction_shape_n=str(inst_shape[1]),
        instruction_shape_k=str(inst_shape[2]),
        epilogue_vector_length=str(epilogue_vector_length),
        epilogue_functor=EpilogueFunctorTag[operation.epilogue_functor],
        element_epilogue=str(DataTypeTag[operation.element_epilogue]),
        swizzling_functor=SwizzlingFunctorTag[operation.swizzling_functor],
        stages=str(tile.stages),
        iterator_algorithm=IteratorAlgorithmTag[operation.iterator_algorithm],
        iterator_algorithm_name=IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
        stride_support=StrideSupportTag[operation.stride_support],
        # Complex operations always use the complex multiply-add operator tag.
        math_operator=('cutlass::arch::OpMultiplyAddComplex' if operation.is_complex()
                       else MathOperationTag[math_inst.math_operation]),
        align_a=str(operation.A.alignment),
        align_b=str(operation.B.alignment),
    )

    mode = operation.group_mode
    if mode == GroupMode.NoneGroup:
        return SubstituteTemplate(self.template, values)

    # Every grouped variant is tagged with its group mode.
    values['group_mode'] = GroupModeTag[mode]

    if mode == GroupMode.Depthwise:
        # The depthwise direct-conv template needs extra tile/filter/stride/dilation parameters.
        out_shape = tile.threadblock_output_shape
        values.update(
            threadblock_output_shape_n=str(out_shape[0]),
            threadblock_output_shape_p=str(out_shape[1]),
            threadblock_output_shape_q=str(out_shape[2]),
            groups_per_cta=str(out_shape[3]),
            filter_shape_r=str(tile.filter_shape[0]),
            filter_shape_s=str(tile.filter_shape[1]),
            stride_r=str(tile.stride[0]),
            stride_s=str(tile.stride[1]),
            dilation_r=str(tile.dilation[0]),
            dilation_s=str(tile.dilation[1]),
        )
        return SubstituteTemplate(self.template_depthwise_direct_conv, values)

    return SubstituteTemplate(self.template_group_conv, values)
309
+
310
+ ###################################################################################################
311
+ #
312
+ # Generator functions for all layouts
313
+ #
314
+ ###################################################################################################
315
+
316
+ #
317
def GenerateConv2dTensorOp(manifest, tile_descriptions, min_cc, align=128):
    """Append tensor-op Conv2d operations for every tile description to ``manifest``.

    For each tile, Fprop is always generated. Dgrad and Wgrad are generated only when
    the tile's accumulator type is ``DataType.f16`` or ``DataType.f32``. When the
    accumulator is 32 bits wide, two variants are produced per conv kind: one whose
    output C uses element A's type and one whose output uses the accumulator type.
    Otherwise only the accumulator-typed output is produced. All tensors use the
    ``TensorNHWC`` layout.

    Args:
        manifest: Receives each generated ``Conv2dOperation`` via ``append``.
        tile_descriptions: Iterable of tile descriptions exposing ``math_instruction``.
        min_cc: Minimum compute capability passed to every ``Conv2dOperation``.
        align: Access width in ``DataTypeSize`` units (presumably bits). It is divided
            by each operand's element size to obtain that operand's alignment in elements.
    """

    def _alignment(element):
        # Elements per `align`-sized access. Clamped to at least 1 for every operand
        # (previously only C was clamped), so an element type wider than `align` never
        # yields an invalid zero alignment.
        return max(1, int(align / DataTypeSize[element]))

    for tile in tile_descriptions:
        math_inst = tile.math_instruction
        accumulator = math_inst.element_accumulator

        # Alignments for A and B depend only on the tile, so compute them once per tile.
        align_a = _alignment(math_inst.element_a)
        align_b = _alignment(math_inst.element_b)

        if DataTypeSize[accumulator] == 32:
            output_types = [math_inst.element_a, accumulator]
        else:
            output_types = [accumulator]

        for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
            # Backward passes (Dgrad/Wgrad) are limited to f16/f32 accumulation.
            if not (conv_kind == ConvKind.Fprop or accumulator in [DataType.f16, DataType.f32]):
                continue

            for output_type in output_types:
                # Fresh descriptors per operation, so generated operations do not share objects.
                A = TensorDescription(math_inst.element_a, LayoutType.TensorNHWC, align_a)
                B = TensorDescription(math_inst.element_b, LayoutType.TensorNHWC, align_b)
                C = TensorDescription(output_type, LayoutType.TensorNHWC, _alignment(output_type))

                manifest.append(Conv2dOperation(conv_kind, min_cc, tile, A, B, C, accumulator))
335
+
336
+ ###################################################################################################
337
+ #
338
+ # Emitter functions for all targets
339
+ #
340
+ ###################################################################################################
341
+
342
class EmitConv2dConfigurationLibrary:
    """Context manager that writes one CUTLASS-library ``.cu`` source file of Conv2d kernels.

    Enter the context, call ``emit(operation)`` once per operation, then exit. The file
    ``<operation_path>/<configuration_name>.cu`` receives:

    * a fixed header (``header_template``);
    * for each emitted operation, its kernel-instance text (from ``EmitConv2dInstance``)
      followed by a ``struct <operation_name>`` deriving from ``<operation_name>_base``;
    * on exit, an ``initialize_<configuration_name>(Manifest &)`` function that appends
      every emitted operation to the manifest, followed by the closing namespaces.
      Depthwise operations are wrapped in ``DirectConvolution`` / ``DirectConv2dOperation``;
      all others in ``ImplicitGemmConvolution`` / ``Conv2dOperation``.
    """

    def __init__(self, operation_path, configuration_name):
        """Record the output location and prepare the templates.

        The file itself is not opened until ``__enter__``.

        Args:
            operation_path: Directory in which the ``.cu`` file is created.
            configuration_name: Base name of the file and suffix of the generated
                ``initialize_*`` function.
        """
        self.configuration_name = configuration_name
        self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name)

        self.instance_emitter = EmitConv2dInstance()

        self.instance_template = """
${operation_instance}

// Derived class
struct ${operation_name} :
public ${operation_name}_base { };

///////////////////////////////////////////////////////////////////////////////////////////////////

"""
        self.header_template = """
/*
Generated by conv2d_operation.py - Do not edit.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

#include "library_internal.h"
#include "conv2d_operation.h"

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

        self.configuration_header = """

namespace cutlass {
namespace library {

// Initialize all instances
void initialize_${configuration_name}(Manifest &manifest) {

"""

        self.configuration_instance = """
using Operation_${operation_name} = cutlass::conv::device::ImplicitGemmConvolution<
${operation_name}>;

manifest.append(new cutlass::library::Conv2dOperation<
Operation_${operation_name}>(
"${operation_name}"));

"""

        self.configuration_direct_conv_instance = """
using Operation_${operation_name} = cutlass::conv::device::DirectConvolution<
${operation_name}>;

manifest.append(new cutlass::library::DirectConv2dOperation<
Operation_${operation_name}>(
"${operation_name}"));

"""

        # Closes the body of initialize_${configuration_name}.
        self.configuration_epilogue = """
}
"""
        self.epilogue_template = """

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////

"""

    def __enter__(self):
        """Open the output file, write the header, and start a fresh operation list."""
        self.configuration_file = open(self.configuration_path, "w")
        try:
            self.configuration_file.write(SubstituteTemplate(self.header_template, {
                'configuration_name': self.configuration_name
            }))
        except BaseException:
            # __exit__ is not called when __enter__ raises, so release the handle here.
            self.configuration_file.close()
            raise
        self.operations = []
        return self

    def emit(self, operation):
        """Write the kernel-instance definition for ``operation`` and remember it for registration."""
        self.operations.append(operation)
        self.configuration_file.write(SubstituteTemplate(self.instance_template, {
            'configuration_name': self.configuration_name,
            'operation_name': operation.procedural_name(),
            'operation_instance': self.instance_emitter.emit(operation)
        }))

    def __exit__(self, exception_type, exception_value, traceback):
        """Write the ``initialize_*`` registration function and trailer, then close the file.

        The file is closed even if writing or template substitution fails. Exceptions
        are never suppressed (the method returns ``None``).
        """
        try:
            self.configuration_file.write(SubstituteTemplate(self.configuration_header, {
                'configuration_name': self.configuration_name
            }))

            for operation in self.operations:
                # Depthwise kernels are direct convolutions; all others are implicit GEMM.
                if operation.group_mode == GroupMode.Depthwise:
                    instance_template = self.configuration_direct_conv_instance
                else:
                    instance_template = self.configuration_instance
                self.configuration_file.write(SubstituteTemplate(instance_template, {
                    'configuration_name': self.configuration_name,
                    'operation_name': operation.procedural_name()
                }))

            self.configuration_file.write(self.configuration_epilogue)
            self.configuration_file.write(self.epilogue_template)
        finally:
            # Previously a failure in any write above left the file handle open.
            self.configuration_file.close()
460
+
461
+
462
+ ###################################################################################################
463
+ ###################################################################################################