warp-lang 0.9.0-py3-none-win_amd64.whl → 0.11.0-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +15 -7
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +22 -443
- warp/build_dll.py +384 -0
- warp/builtins.py +998 -488
- warp/codegen.py +1307 -739
- warp/config.py +5 -3
- warp/constants.py +6 -0
- warp/context.py +1291 -548
- warp/dlpack.py +31 -31
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +164 -55
- warp/native/builtin.h +150 -174
- warp/native/bvh.cpp +75 -328
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +37 -45
- warp/native/clang/clang.cpp +136 -24
- warp/native/crt.cpp +1 -76
- warp/native/crt.h +111 -104
- warp/native/cuda_crt.h +1049 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -949
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/initializer_array.h +2 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +13 -16
- warp/native/marching.cu +157 -161
- warp/native/mat.h +119 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +108 -83
- warp/native/mesh.cu +243 -6
- warp/native/mesh.h +1547 -458
- warp/native/nanovdb/NanoVDB.h +1 -1
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +157 -0
- warp/native/reduce.cu +348 -0
- warp/native/runlength_encode.cpp +62 -0
- warp/native/runlength_encode.cu +46 -0
- warp/native/scan.cu +11 -13
- warp/native/scan.h +1 -0
- warp/native/solid_angle.h +442 -0
- warp/native/sort.cpp +13 -0
- warp/native/sort.cu +9 -1
- warp/native/sparse.cpp +338 -0
- warp/native/sparse.cu +545 -0
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +30 -0
- warp/native/vec.h +126 -24
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +658 -53
- warp/native/warp.cu +660 -68
- warp/native/warp.h +112 -12
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +392 -152
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +21 -8
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +161 -19
- warp/sim/model.py +795 -291
- warp/sim/optimizer.py +2 -6
- warp/sim/render.py +65 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +1227 -0
- warp/stubs.py +665 -223
- warp/tape.py +66 -15
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/torus.usda +105 -105
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +1497 -211
- warp/tests/test_array_reduce.py +150 -0
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +75 -43
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +233 -128
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +136 -108
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -74
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +180 -116
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +577 -24
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +251 -15
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +508 -2778
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +325 -34
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +190 -0
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +460 -0
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +331 -85
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -1987
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +72 -30
- warp/types.py +1744 -713
- warp/utils.py +360 -350
- warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/bin/warp-clang.exp +0 -0
- warp/bin/warp-clang.lib +0 -0
- warp/bin/warp.exp +0 -0
- warp/bin/warp.lib +0 -0
- warp/tests/test_all.py +0 -215
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.9.0.dist-info/METADATA +0 -20
- warp_lang-0.9.0.dist-info/RECORD +0 -177
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/cuda_util.cpp
CHANGED
@@ -59,6 +59,7 @@ static PFN_cuDeviceGet_v2000 pfn_cuDeviceGet;
 static PFN_cuDeviceGetCount_v2000 pfn_cuDeviceGetCount;
 static PFN_cuDeviceGetName_v2000 pfn_cuDeviceGetName;
 static PFN_cuDeviceGetAttribute_v2000 pfn_cuDeviceGetAttribute;
+static PFN_cuDeviceGetUuid_v11040 pfn_cuDeviceGetUuid;
 static PFN_cuDevicePrimaryCtxRetain_v7000 pfn_cuDevicePrimaryCtxRetain;
 static PFN_cuDevicePrimaryCtxRelease_v11000 pfn_cuDevicePrimaryCtxRelease;
 static PFN_cuDeviceCanAccessPeer_v4000 pfn_cuDeviceCanAccessPeer;
@@ -89,6 +90,7 @@ static PFN_cuGraphicsResourceGetMappedPointer_v3020 pfn_cuGraphicsResourceGetMap
 static PFN_cuGraphicsGLRegisterBuffer_v3000 pfn_cuGraphicsGLRegisterBuffer;
 static PFN_cuGraphicsUnregisterResource_v3000 pfn_cuGraphicsUnregisterResource;
 
+static bool cuda_driver_initialized = false;
 
 bool ContextGuard::always_restore = false;
 
@@ -165,6 +167,7 @@ bool init_cuda_driver()
     get_driver_entry_point("cuDeviceGetCount", &(void*&)pfn_cuDeviceGetCount);
     get_driver_entry_point("cuDeviceGetName", &(void*&)pfn_cuDeviceGetName);
     get_driver_entry_point("cuDeviceGetAttribute", &(void*&)pfn_cuDeviceGetAttribute);
+    get_driver_entry_point("cuDeviceGetUuid", &(void*&)pfn_cuDeviceGetUuid);
     get_driver_entry_point("cuDevicePrimaryCtxRetain", &(void*&)pfn_cuDevicePrimaryCtxRetain);
     get_driver_entry_point("cuDevicePrimaryCtxRelease", &(void*&)pfn_cuDevicePrimaryCtxRelease);
     get_driver_entry_point("cuDeviceCanAccessPeer", &(void*&)pfn_cuDeviceCanAccessPeer);
@@ -196,11 +199,15 @@ bool init_cuda_driver()
     get_driver_entry_point("cuGraphicsUnregisterResource", &(void*&)pfn_cuGraphicsUnregisterResource);
 
     if (pfn_cuInit)
-
-
-
+        cuda_driver_initialized = check_cu(pfn_cuInit(0));
+
+    return cuda_driver_initialized;
 }
 
+bool is_cuda_driver_initialized()
+{
+    return cuda_driver_initialized;
+}
 
 bool check_cuda_result(cudaError_t code, const char* file, int line)
 {
@@ -284,6 +291,11 @@ CUresult cuDeviceGetAttribute_f(int* value, CUdevice_attribute attrib, CUdevice
     return pfn_cuDeviceGetAttribute ? pfn_cuDeviceGetAttribute(value, attrib, dev) : DRIVER_ENTRY_POINT_ERROR;
 }
 
+CUresult cuDeviceGetUuid_f(CUuuid* uuid, CUdevice dev)
+{
+    return pfn_cuDeviceGetUuid ? pfn_cuDeviceGetUuid(uuid, dev) : DRIVER_ENTRY_POINT_ERROR;
+}
+
 CUresult cuDevicePrimaryCtxRetain_f(CUcontext* ctx, CUdevice dev)
 {
     return pfn_cuDevicePrimaryCtxRetain ? pfn_cuDevicePrimaryCtxRetain(ctx, dev) : DRIVER_ENTRY_POINT_ERROR;
warp/native/cuda_util.h
CHANGED
@@ -51,6 +51,7 @@ CUresult cuDeviceGet_f(CUdevice *dev, int ordinal);
 CUresult cuDeviceGetCount_f(int* count);
 CUresult cuDeviceGetName_f(char* name, int len, CUdevice dev);
 CUresult cuDeviceGetAttribute_f(int* value, CUdevice_attribute attrib, CUdevice dev);
+CUresult cuDeviceGetUuid_f(CUuuid* uuid, CUdevice dev);
 CUresult cuDevicePrimaryCtxRetain_f(CUcontext* ctx, CUdevice dev);
 CUresult cuDevicePrimaryCtxRelease_f(CUdevice dev);
 CUresult cuDeviceCanAccessPeer_f(int* can_access, CUdevice dev, CUdevice peer_dev);
@@ -83,6 +84,7 @@ CUresult cuGraphicsUnregisterResource_f(CUgraphicsResource resource);
 
 
 bool init_cuda_driver();
+bool is_cuda_driver_initialized();
 
 bool check_cuda_result(cudaError_t code, const char* file, int line);
 inline bool check_cuda_result(uint64_t code, const char* file, int line)
@@ -166,6 +168,6 @@ public:
 #endif // WP_ENABLE_CUDA
 
 // Pass this value to device functions as the `context` parameter to bypass unnecessary context management.
-// This works in
+// This works in conjunction with ContextGuards, which do nothing if the given context is NULL.
 // Using this variable instead of passing NULL directly aids readability and makes the intent clear.
 constexpr void* WP_CURRENT_CONTEXT = NULL;
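The two additions above, cuDeviceGetUuid_f() and is_cuda_driver_initialized(), are thin wrappers over lazily resolved driver entry points. Below is a minimal usage sketch; it is not part of the package, assumes only the functions declared in cuda_util.h plus the CUDA driver types that header already relies on, and the device loop and printing are purely illustrative.

// Hypothetical caller, not from the warp-lang sources.
#include <cstdio>
#include "cuda_util.h"   // declares init_cuda_driver(), is_cuda_driver_initialized(), cuDevice*_f wrappers

int main()
{
    // init_cuda_driver() resolves the entry points; the new is_cuda_driver_initialized()
    // reports whether cuInit() actually succeeded, so later code can check this cheaply.
    init_cuda_driver();
    if (!is_cuda_driver_initialized())
        return 1;

    int count = 0;
    if (cuDeviceGetCount_f(&count) != CUDA_SUCCESS)
        return 1;

    for (int ordinal = 0; ordinal < count; ++ordinal)
    {
        CUdevice dev;
        CUuuid uuid;
        // cuDeviceGetUuid_f() falls back to DRIVER_ENTRY_POINT_ERROR when the installed
        // driver does not expose cuDeviceGetUuid, so the failure path must be handled.
        if (cuDeviceGet_f(&dev, ordinal) == CUDA_SUCCESS &&
            cuDeviceGetUuid_f(&uuid, dev) == CUDA_SUCCESS)
        {
            printf("device %d uuid: ", ordinal);
            for (int i = 0; i < 16; ++i)
                printf("%02x", (unsigned char)uuid.bytes[i]);
            printf("\n");
        }
    }
    return 0;
}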
warp/native/cutlass/tools/library/scripts/conv2d_operation.py
ADDED

@@ -0,0 +1,463 @@
+#
+# \file generator.py
+#
+# \brief Generates the CUTLASS Library's instances
+#
+#
+
+import enum
+import os.path
+import shutil
+
+from library import *
+
+###################################################################################################
+
+#
+class Conv2dOperation:
+  #
+  def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \
+    stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity1, \
+    group_mode = GroupMode.NoneGroup):
+
+    self.operation_kind = OperationKind.Conv2d
+    self.arch = arch
+    self.tile_description = tile_description
+    self.conv_kind = conv_kind
+    self.A = A
+    self.B = B
+    self.C = C
+    self.element_epilogue = element_epilogue
+    self.epilogue_functor = epilogue_functor
+    self.iterator_algorithm = iterator_algorithm
+    self.stride_support = stride_support
+    self.swizzling_functor = swizzling_functor
+    self.group_mode = group_mode
+  #
+  def is_complex(self):
+    complex_operators = [
+      MathOperation.multiply_add_complex,
+      MathOperation.multiply_add_complex_gaussian
+    ]
+    return self.tile_description.math_instruction.math_operation in complex_operators
+
+  #
+  def accumulator_type(self):
+    accum = self.tile_description.math_instruction.element_accumulator
+
+    if self.is_complex():
+      return get_complex_from_real(accum)
+
+    return accum
+
+  #
+  def core_name(self):
+    ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
+
+    intermediate_type = ''
+
+    if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
+      inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
+      if self.tile_description.math_instruction.element_a != self.A.element and \
+        self.tile_description.math_instruction.element_a != self.accumulator_type():
+        intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
+    else:
+      inst_shape = ''
+
+    return "%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \
+      inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm])
+
+  #
+  def extended_name(self):
+    ''' Append data types if they differ from compute type. '''
+    if self.C.element != self.tile_description.math_instruction.element_accumulator and \
+      self.A.element != self.tile_description.math_instruction.element_accumulator:
+      extended_name = "${element_c}_${core_name}_${element_a}"
+    elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
+      self.A.element != self.tile_description.math_instruction.element_accumulator:
+      extended_name = "${core_name}_${element_a}"
+    else:
+      extended_name = "${core_name}"
+
+    extended_name = SubstituteTemplate(extended_name, {
+      'element_a': DataTypeNames[self.A.element],
+      'element_c': DataTypeNames[self.C.element],
+      'core_name': self.core_name()
+      })
+
+    return extended_name
+
+  #
+  def layout_name(self):
+    return "%s" % (ShortLayoutTypeNames[self.A.layout])
+
+  #
+  def configuration_name(self):
+    ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
+
+    opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
+
+    threadblock = self.tile_description.procedural_name()
+
+    # grouped conv
+    if self.group_mode != GroupMode.NoneGroup:
+      group_conv_name = f"{GroupModeNames[self.group_mode]}_"
+    else:
+      group_conv_name = ""
+
+    if self.stride_support == StrideSupport.Unity:
+      configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_${group_conv_name}align${alignment}"
+    else:
+      configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${group_conv_name}align${alignment}"
+
+    return SubstituteTemplate(
+      configuration_name,
+      {
+        'opcode_class': opcode_class_name,
+        'extended_name': self.extended_name(),
+        'threadblock': threadblock,
+        'layout': self.layout_name(),
+        'alignment': "%d" % self.A.alignment,
+        'group_conv_name': group_conv_name
+      }
+    )
+
+  #
+  def procedural_name(self):
+    ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
+    return self.configuration_name()
+
+###################################################################################################
+#
+# Emits single instances of a CUTLASS device-wide operator
+#
+###################################################################################################
+
+class EmitConv2dInstance:
+  def __init__(self):
+    self.template = """
+// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
+using ${operation_name}_base =
+typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
+  ${element_a},
+  ${layout_a},
+  ${element_b},
+  ${layout_b},
+  ${element_c},
+  ${layout_c},
+  ${element_accumulator},
+  ${opcode_class},
+  ${arch},
+  cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
+  cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
+  cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
+  ${epilogue_functor}<
+    ${element_c},
+    ${epilogue_vector_length},
+    ${element_accumulator},
+    ${element_epilogue}
+  >,
+  ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
+  ${stages},
+  ${math_operator},
+  ${iterator_algorithm},
+  ${stride_support},
+  ${align_a},
+  ${align_b}
+>::Kernel;
+"""
+    self.template_group_conv = """
+// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
+using ${operation_name}_base =
+typename cutlass::conv::kernel::DefaultConv2dGroup${conv_kind_name}<
+  ${element_a},
+  ${layout_a},
+  ${element_b},
+  ${layout_b},
+  ${element_c},
+  ${layout_c},
+  ${element_accumulator},
+  ${opcode_class},
+  ${arch},
+  cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
+  cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
+  cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
+  ${epilogue_functor}<
+    ${element_c},
+    ${epilogue_vector_length},
+    ${element_accumulator},
+    ${element_epilogue}
+  >,
+  ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
+  ${stages},
+  ${math_operator},
+  ${group_mode},
+  ${iterator_algorithm},
+  ${stride_support},
+  ${align_a},
+  ${align_b}
+>::Kernel;
+"""
+    self.template_depthwise_direct_conv = """
+// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
+using ${operation_name}_base =
+typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConv${conv_kind_name}<
+  ${element_a},
+  ${layout_a},
+  ${element_b},
+  ${layout_b},
+  ${element_c},
+  ${layout_c},
+  ${element_accumulator},
+  ${opcode_class},
+  ${arch},
+  cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
+  cutlass::conv::TensorNHWCShape<${threadblock_output_shape_n}, ${threadblock_output_shape_p}, ${threadblock_output_shape_q}, ${groups_per_cta}>,
+  cutlass::MatrixShape<${filter_shape_r}, ${filter_shape_s}>,
+  cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
+  cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
+  ${epilogue_functor}<
+    ${element_c},
+    ${epilogue_vector_length},
+    ${element_accumulator},
+    ${element_epilogue},
+    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+  >,
+
+  cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
+    1,
+    ${threadblock_output_shape_n},
+    ${threadblock_output_shape_p},
+    ${threadblock_output_shape_q}>,
+  ${stages},
+  ${math_operator},
+  ${iterator_algorithm},
+  ${stride_support},
+  cutlass::MatrixShape<${stride_r}, ${stride_s}>,
+  cutlass::MatrixShape<${dilation_r}, ${dilation_s}>
+>::Kernel;
+"""
+
+  def emit(self, operation):
+
+    warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)]
+
+    epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
+
+    values = {
+      'operation_name': operation.procedural_name(),
+      'conv_kind': ConvKindTag[operation.conv_kind],
+      'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(),
+      'element_a': DataTypeTag[operation.A.element],
+      'layout_a': LayoutTag[operation.A.layout],
+      'element_b': DataTypeTag[operation.B.element],
+      'layout_b': LayoutTag[operation.B.layout],
+      'element_c': DataTypeTag[operation.C.element],
+      'layout_c': LayoutTag[operation.C.layout],
+      'element_accumulator': DataTypeTag[operation.accumulator_type()],
+      'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
+      'arch': "cutlass::arch::Sm%d" % operation.arch,
+      'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
+      'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
+      'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
+      'warp_shape_m': str(warp_shape[0]),
+      'warp_shape_n': str(warp_shape[1]),
+      'warp_shape_k': str(warp_shape[2]),
+      'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
+      'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
+      'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
+      'epilogue_vector_length': str(epilogue_vector_length),
+      'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
+      'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
+      'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
+      'stages': str(operation.tile_description.stages),
+      'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm],
+      'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
+      'stride_support': StrideSupportTag[operation.stride_support],
+      'math_operator': 'cutlass::arch::OpMultiplyAddComplex' if operation.is_complex() else \
+        MathOperationTag[operation.tile_description.math_instruction.math_operation],
+      'align_a': str(operation.A.alignment),
+      'align_b': str(operation.B.alignment),
+    }
+
+    if operation.group_mode == GroupMode.NoneGroup:
+      return SubstituteTemplate(self.template, values)
+
+    elif operation.group_mode == GroupMode.Depthwise:
+      values['group_mode'] = GroupModeTag[operation.group_mode]
+      # Setup other template params
+      values['threadblock_output_shape_n'] = str(operation.tile_description.threadblock_output_shape[0])
+      values['threadblock_output_shape_p'] = str(operation.tile_description.threadblock_output_shape[1])
+      values['threadblock_output_shape_q'] = str(operation.tile_description.threadblock_output_shape[2])
+
+      values['groups_per_cta'] = str(operation.tile_description.threadblock_output_shape[3])
+
+      values['filter_shape_r'] = str(operation.tile_description.filter_shape[0])
+      values['filter_shape_s'] = str(operation.tile_description.filter_shape[1])
+
+      values['stride_r'] = str(operation.tile_description.stride[0])
+      values['stride_s'] = str(operation.tile_description.stride[1])
+
+      values['dilation_r'] = str(operation.tile_description.dilation[0])
+      values['dilation_s'] = str(operation.tile_description.dilation[1])
+
+      return SubstituteTemplate(self.template_depthwise_direct_conv, values)
+
+    else:
+      values['group_mode'] = GroupModeTag[operation.group_mode]
+      return SubstituteTemplate(self.template_group_conv, values)
+
+###################################################################################################
+#
+# Generator functions for all layouts
+#
+###################################################################################################
+
+#
+def GenerateConv2dTensorOp(manifest, tile_descriptions, min_cc, align = 128):
+
+  for tile in tile_descriptions:
+    for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
+
+      if conv_kind == ConvKind.Fprop or (tile.math_instruction.element_accumulator in [DataType.f16, DataType.f32]):
+
+        #
+        output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \
+          if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \
+          else [tile.math_instruction.element_accumulator,]
+
+        for output_type in output_types:
+          A = TensorDescription(tile.math_instruction.element_a, LayoutType.TensorNHWC, int(align / DataTypeSize[tile.math_instruction.element_a]))
+          B = TensorDescription(tile.math_instruction.element_b, LayoutType.TensorNHWC, int(align / DataTypeSize[tile.math_instruction.element_b]))
+          C = TensorDescription(output_type, LayoutType.TensorNHWC, max(1, int(align / DataTypeSize[output_type])))
+
+          manifest.append(Conv2dOperation(conv_kind, min_cc, tile, A, B, C, tile.math_instruction.element_accumulator))
+
+###################################################################################################
+#
+# Emitters functions for all targets
+#
+###################################################################################################
+
+class EmitConv2dConfigurationLibrary:
+  def __init__(self, operation_path, configuration_name):
+    self.configuration_name = configuration_name
+    self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name)
+
+    self.instance_emitter = EmitConv2dInstance()
+
+    self.instance_template = """
+${operation_instance}
+
+// Derived class
+struct ${operation_name} :
+  public ${operation_name}_base { };
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+"""
+    self.header_template = """
+/*
+  Generated by conv2d_operation.py - Do not edit.
+*/
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+#include "cutlass/cutlass.h"
+#include "cutlass/library/library.h"
+#include "cutlass/library/manifest.h"
+
+#include "library_internal.h"
+#include "conv2d_operation.h"
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+"""
+
+    self.configuration_header = """
+
+namespace cutlass {
+namespace library {
+
+// Initialize all instances
+void initialize_${configuration_name}(Manifest &manifest) {
+
+"""
+
+    self.configuration_instance = """
+  using Operation_${operation_name} = cutlass::conv::device::ImplicitGemmConvolution<
+    ${operation_name}>;
+
+  manifest.append(new cutlass::library::Conv2dOperation<
+    Operation_${operation_name}>(
+      "${operation_name}"));
+
+"""
+
+    self.configuration_direct_conv_instance = """
+  using Operation_${operation_name} = cutlass::conv::device::DirectConvolution<
+    ${operation_name}>;
+
+  manifest.append(new cutlass::library::DirectConv2dOperation<
+    Operation_${operation_name}>(
+      "${operation_name}"));
+
+"""
+
+    self.configuration_epilogue = """
+}
+"""
+    self.epilogue_template = """
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace library
+} // namespace cutlass
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+"""
+
+  #
+  def __enter__(self):
+    self.configuration_file = open(self.configuration_path, "w")
+    self.configuration_file.write(SubstituteTemplate(self.header_template, {
+      'configuration_name': self.configuration_name
+      }))
+    self.operations = []
+    return self
+
+  #
+  def emit(self, operation):
+    self.operations.append(operation)
+    self.configuration_file.write(SubstituteTemplate(self.instance_template, {
+      'configuration_name': self.configuration_name,
+      'operation_name': operation.procedural_name(),
+      'operation_instance': self.instance_emitter.emit(operation)
+      }))
+
+  #
+  def __exit__(self, exception_type, exception_value, traceback):
+
+    self.configuration_file.write(SubstituteTemplate(self.configuration_header, {
+      'configuration_name': self.configuration_name
+      }))
+
+    for operation in self.operations:
+      if operation.group_mode == GroupMode.Depthwise:
+        self.configuration_file.write(SubstituteTemplate(self.configuration_direct_conv_instance, {
+          'configuration_name': self.configuration_name,
+          'operation_name': operation.procedural_name()
+          }))
+      else:
+        self.configuration_file.write(SubstituteTemplate(self.configuration_instance, {
+          'configuration_name': self.configuration_name,
+          'operation_name': operation.procedural_name()
+          }))
+
+    self.configuration_file.write(self.configuration_epilogue)
+    self.configuration_file.write(self.epilogue_template)
+    self.configuration_file.close()
+
+
+###################################################################################################
+###################################################################################################