mindspore 2.7.0rc1__cp310-cp310-win_amd64.whl → 2.7.1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +5 -2
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +2 -2
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
- mindspore/_extends/parse/parser.py +28 -22
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +23 -2
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
- mindspore/amp.py +0 -18
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/__init__.py +18 -12
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +371 -96
- mindspore/common/_utils.py +7 -43
- mindspore/common/api.py +434 -135
- mindspore/common/dtype.py +98 -57
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/hook_handle.py +82 -3
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +17 -127
- mindspore/common/recompute.py +4 -13
- mindspore/common/tensor.py +50 -217
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +20 -106
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +35 -1
- mindspore/dataset/engine/datasets.py +338 -319
- mindspore/dataset/engine/datasets_user_defined.py +38 -22
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/api/cell.h +28 -4
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +0 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +5 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +6 -1
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +4 -3
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/__init__.py +4 -0
- mindspore/mint/distributed/distributed.py +392 -69
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/_functions.py +1 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +10 -10
- mindspore/mint/nn/layer/normalization.py +11 -16
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +231 -239
- mindspore/nn/layer/activation.py +4 -2
- mindspore/nn/layer/basic.py +56 -14
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/image.py +1 -1
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +32 -127
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +1 -4
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +2 -4
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +39 -5
- mindspore/nn/wrap/grad_reducer.py +4 -89
- mindspore/numpy/array_creations.py +4 -4
- mindspore/numpy/fft.py +9 -9
- mindspore/numpy/utils_const.py +1 -1
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +1 -5
- mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
- mindspore/ops/auto_generate/gen_extend_func.py +6 -11
- mindspore/ops/auto_generate/gen_ops_def.py +385 -154
- mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +16 -2
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +2 -0
- mindspore/ops/function/array_func.py +24 -18
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +7 -6
- mindspore/ops/function/grad/grad_func.py +4 -12
- mindspore/ops/function/math_func.py +89 -86
- mindspore/ops/function/nn_func.py +92 -313
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +4 -1
- mindspore/ops/functional_overload.py +377 -30
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +12 -50
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +5 -50
- mindspore/ops/operations/comm_ops.py +95 -17
- mindspore/ops/operations/custom_ops.py +237 -22
- mindspore/ops/operations/debug_ops.py +33 -35
- mindspore/ops/operations/manually_defined/ops_def.py +39 -318
- mindspore/ops/operations/math_ops.py +5 -5
- mindspore/ops/operations/nn_ops.py +3 -3
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +4 -27
- mindspore/ops/tensor_method.py +88 -10
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_auto_parallel_context.py +5 -15
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +4 -6
- mindspore/parallel/_ps_context.py +2 -2
- mindspore/parallel/_utils.py +34 -17
- mindspore/parallel/auto_parallel.py +23 -9
- mindspore/parallel/checkpoint_transform.py +20 -2
- mindspore/parallel/cluster/process_entity/_api.py +28 -33
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/parallel/cluster/run.py +5 -3
- mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/function/reshard_func.py +6 -5
- mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
- mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
- mindspore/parallel/shard.py +7 -21
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +127 -20
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +40 -4
- mindspore/profiler/common/path_manager.py +65 -24
- mindspore/profiler/common/profiler_context.py +27 -14
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +10 -6
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/dynamic_profiler.py +91 -46
- mindspore/profiler/envprofiler.py +30 -5
- mindspore/profiler/experimental_config.py +18 -2
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +34 -7
- mindspore/profiler/profiler.py +193 -145
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +9 -6
- mindspore/runtime/executor.py +35 -0
- mindspore/runtime/memory.py +113 -0
- mindspore/runtime/thread_bind_core.py +1 -1
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +8 -21
- mindspore/train/amp.py +6 -7
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +1 -17
- mindspore/train/callback/_flops_collector.py +10 -6
- mindspore/train/callback/_train_fault_tolerance.py +72 -25
- mindspore/train/data_sink.py +5 -9
- mindspore/train/dataset_helper.py +5 -5
- mindspore/train/model.py +41 -230
- mindspore/train/serialization.py +160 -401
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +152 -16
- mindspore/version.py +1 -1
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/__init__.py +0 -23
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/train/memory_profiling_pb2.py +0 -298
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/ops/operations/_rl_inner_ops.py

@@ -16,199 +16,9 @@
 """Inner operators for reinforcement learning."""
 
 from __future__ import absolute_import
-import functools
-from mindspore.common.dtype import type_size_in_bytes
-import mindspore.context as context
 from mindspore import _checkparam as validator
 from mindspore.common import dtype as mstype
 from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer, Primitive
-from mindspore.communication.management import GlobalComm
-
-
-class EnvCreate(PrimitiveWithInfer):
-    r"""
-    Create a built-in reinforcement learning environment. Repeated calls to the operator will return the previously
-    created handle. Make sure to create a new operator instance if you want to create a new environment instance.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Args:
-        name (str): Name of built-in environment.
-        kwargs (any): Environment related parameters.
-
-    Inputs:
-        No inputs.
-
-    Outputs:
-        handle(Tensor): Handle of created environment instance with dtype int and shape (1,).
-
-    Raises:
-        TypeError: The environment not supported.
-        TypeError: The environment parameters not provided.
-
-    Supported Platforms:
-        ``GPU``
-    """
-
-    def __init__(self, name, **kwargs):
-        super(EnvCreate, self).__init__(self.__class__.__name__)
-        self.add_prim_attr('name', name)
-        for key in kwargs:
-            self.add_prim_attr(key, kwargs[key])
-
-    def infer_shape(self, *args):
-        return (1,)
-
-    def infer_dtype(self, *args):
-        return mstype.int64
-
-
-class EnvReset(PrimitiveWithInfer):
-    r"""
-    Reset reinforcement learning built-in environment.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Args:
-        handle (int): The handle returned by `EnvCreate` operator.
-        state_shape (list[tuple[int]]): The dimensionality of the state.
-        state_dtype (list[:class:`mindspore.dtype`]): The type of the state.
-        reward_shape (list[tuple[int]]): The dimensionality of the reward.
-        reward_dtype (list[:class:`mindspore.dtype`]): The type of the reward.
-
-    Inputs:
-        No inputs.
-
-    Outputs:
-        Tensor, environment observation after reset.
-
-    Raises:
-        TypeError: Environment instance not exist.
-
-    Supported Platforms:
-        ``GPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, handle, state_shape, state_dtype):
-        super(EnvReset, self).__init__(self.__class__.__name__)
-        validator.check_value_type("handle", handle, [int], self.name)
-        validator.check_value_type("state_shape", state_shape, [list, tuple], self.name)
-
-    def infer_shape(self, *args):
-        return self.state_shape
-
-    def infer_dtype(self, *args):
-        return self.state_dtype
-
-
-class EnvStep(PrimitiveWithInfer):
-    r"""
-    Run one environment timestep.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Args:
-        handle (int): The handle returned by `EnvCreate` operator.
-        state_shape (list[tuple[int]]): The dimensionality of the state.
-        state_dtype (list[:class:`mindspore.dtype`]): The type of the state.
-        reward_shape (list[tuple[int]]): The dimensionality of the reward.
-        reward_dtype (list[:class:`mindspore.dtype`]): The type of the reward.
-
-    Inputs:
-        - **action** (Tensor) - action
-
-    Outputs:
-        - **state** (Tensor) - Environment state after previous action.
-        - **reward** (Tensor), - Reward returned by environment.
-        - **done** (Tensor), whether the episode has ended.
-
-    Raises:
-        TypeError: If dtype of `handle` is not int.
-        TypeError: If dtype of `state_shape` is neither tuple nor list.
-        TypeError: If dtype of `state_dtype` is not int nor float.
-        TypeError: If dtype of `state_shape` is neither tuple nor list.
-        TypeError: If dtype of `reward_dtype` is not int nor float.
-
-    Supported Platforms:
-        ``GPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, handle, state_shape, state_dtype, reward_shape, reward_dtype):
-        super(EnvStep, self).__init__(self.__class__.__name__)
-        validator.check_value_type("handle", handle, [int], self.name)
-        validator.check_value_type("state_shape", state_shape, [list, tuple], self.name)
-        validator.check_value_type("reward_shape", reward_shape, [list, tuple], self.name)
-
-    def infer_shape(self, action_shape):
-        return self.state_shape, self.reward_shape, (self.state_shape[0],)
-
-    def infer_dtype(self, action_dtype):
-        return self.state_dtype, self.reward_dtype, mstype.bool_
-
-
-class DiscountedReturn(PrimitiveWithInfer):
-    r"""
-    Calculate discounted return.
-
-    Set discounted return as :math:`G`, discounted factor as :math:`\gamma`, reward as :math:`R`,
-    timestep as :math:`t`, max timestep as :math:`N`. Then :math:`G_{t} = \Sigma_{t=0}^N{\gamma^tR_{t+1}}`
-
-    For the reward sequence contain multi-episode, :math:`done` is introduced for indicating episode boundary,
-    :math:`last\_state\_value` represents value after final step of last episode.
-
-    Args:
-        gamma (float): Discounted factor between [0, 1].
-
-    Inputs:
-        - **reward** (Tensor) - The reward sequence contains multi-episode.
-          Tensor of shape :math:`(Timestep, Batch, ...)`
-        - **done** (Tensor) - The episode done flag. Tensor of shape :math:`(Timestep, Batch)`.
-          The data type must be bool.
-        - **last_state_value** (Tensor) - The value after final step of last episode.
-          Tensor of shape :math:`(Batch, ...)`
-
-    Examples:
-        >>> net = DiscountedReturn(gamma=0.99)
-        >>> reward = Tensor([[1, 1, 1, 1]], dtype=mindspore.float32)
-        >>> done = Tensor([[False, False, True, False]])
-        >>> last_state_value = Tensor([2.], dtype=mindspore.float32)
-        >>> ret = net(reward, done, last_state_value)
-        >>> print(output.shape)
-        (2, 2)
-    """
-
-    @prim_attr_register
-    def __init__(self, gamma):
-        self.init_prim_io_names(inputs=['reward', 'done', 'last_state_value'], outputs=['output'])
-        validator.check_float_range(gamma, 0, 1, validator.INC_RIGHT, "gamma", self.name)
-
-    def infer_shape(self, reward_shape, done_shape, last_state_value_shape):
-        if len(reward_shape) != len(done_shape):
-            raise ValueError(f'For \'{self.name}\', len(reward) and len(done) must be the same, ',
-                             f'but got {len(reward_shape)} and {len(done_shape)}.')
-
-        if reward_shape[0] != done_shape[0]:
-            raise ValueError(f'For \'{self.name}\', the first element of the shape of \'reward\' '
-                             f'and \'done\' must be the same, but got reward.shape[0]:'
-                             f' {reward_shape[0]} and done.shape[0]: {done_shape[0]}.')
-
-        if reward_shape[1:] != last_state_value_shape:
-            raise ValueError(f'For \'{self.name}\', reward.shape[1:] and last_state_value.shape must be the same, '
-                             f'but got reward.shape[1:]: {reward_shape[1:]} '
-                             f'and last_state_value.shape: {last_state_value_shape}.')
-        return reward_shape
-
-    def infer_dtype(self, reward_dtype, done_dtype, last_state_value_dtype):
-        valid_dtypes = (mstype.float16, mstype.float32)
-        args = {"reward": reward_dtype, "last_state_value": last_state_value_dtype}
-        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
-        validator.check_tensor_dtype_valid('done_dtype', done_dtype, [mstype.bool_], self.name)
-        return reward_dtype
 
 
 class GRUV2(PrimitiveWithInfer):
@@ -486,746 +296,3 @@ class CudnnGRU(Primitive):
             self.num_directions = 2
         else:
             self.num_directions = 1
-
-
-class PriorityReplayBufferCreate(PrimitiveWithInfer):
-    r"""
-    PriorityReplayBuffer is experience container used in Deep Q-Networks.
-    The algorithm is proposed in `Prioritized Experience Replay <https://arxiv.org/abs/1511.05952>`.
-    Same as the normal replay buffer, it lets the reinforcement learning agents remember and reuse experiences from the
-    past. Besides, it replays important transitions more frequently and improve sample efficiency.
-
-    Args:
-        capcity (int64): Capacity of the buffer. It is recommended that set capacity to pow(2, N).
-        alpha (float): The parameter determines how much prioritization is used between [0, 1].
-        beta (float): The parameter determines how much compensations for non-uniform probabilities between [0, 1].
-        shapes (list[tuple[int]]): The dimensionality of the transition.
-        dtypes (list[:class:`mindspore.dtype`]): The type of the transition.
-        seed0 (int): Random seed0, must be non-negative. Default: 0.
-        seed1 (int): Random seed1, must be non-negative. Default: 0.
-
-    Outputs:
-        handle(Tensor): Handle of created priority replay buffer instance with dtype int64 and shape (1,).
-
-    Raises:
-        TypeError: The args not provided.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, capacity, alpha, shapes, dtypes, seed0, seed1):
-        """Initialize PriorityReplaBufferCreate."""
-        validator.check_int(capacity, 1, validator.GE, "capacity", self.name)
-        validator.check_float_range(alpha, 0.0, 1.0, validator.INC_BOTH)
-        validator.check_value_type("shape of init data", shapes, [tuple, list], self.name)
-        validator.check_value_type("dtypes of init data", dtypes, [tuple, list], self.name)
-        validator.check_non_negative_int(seed0, "seed0", self.name)
-        validator.check_non_negative_int(seed1, "seed1", self.name)
-
-        schema = []
-        for shape, dtype in zip(shapes, dtypes):
-            num_element = functools.reduce(lambda x, y: x * y, shape, 1)
-            schema.append(num_element * type_size_in_bytes(dtype))
-        self.add_prim_attr("schema", schema)
-
-    def infer_shape(self):
-        return (1,)
-
-    def infer_dtype(self):
-        return mstype.int64
-
-
-class PriorityReplayBufferPush(PrimitiveWithInfer):
-    r"""
-    Push a transition to the priority replay buffer.
-
-    Args:
-        handle(Tensor): Priority replay buffer instance handle with dtype int64 and shape (1,).
-
-    Outputs:
-        handle(Tensor): Priority replay buffer instance handle with dtype int64 and shape (1,).
-
-    Raises:
-        TypeError: The priority replay buffer not created before.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, handle):
-        """Initialize PriorityReplaBufferPush."""
-        validator.check_int(handle, 0, validator.GE, "handle", self.name)
-
-    def infer_shape(self, *inputs):
-        return (1,)
-
-    def infer_dtype(self, *inputs):
-        return mstype.int64
-
-
-class PriorityReplayBufferSample(PrimitiveWithInfer):
-    r"""
-    Sample a transition to the priority replay buffer.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Args:
-        handle(Tensor): Priority replay buffer instance handle with dtype int64 and shape (1,).
-        batch_size (int): The size of the sampled transitions.
-        shapes (list[tuple[int]]): The dimensionality of the transition.
-        dtypes (list[:class:`mindspore.dtype`]): The type of the transition.
-
-    Outputs:
-        tuple(Tensor): Transition with its indices and bias correction weights.
-
-    Raises:
-        TypeError: The priority replay buffer not created before.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, handle, batch_size, shapes, dtypes):
-        """Initialize PriorityReplaBufferSample."""
-        validator.check_int(handle, 0, validator.GE, "capacity", self.name)
-        validator.check_int(batch_size, 1, validator.GE, "batch_size", self.name)
-        validator.check_value_type("shape of init data", shapes, [tuple, list], self.name)
-        validator.check_value_type("dtypes of init data", dtypes, [tuple, list], self.name)
-
-        schema = []
-        for shape, dtype in zip(shapes, dtypes):
-            num_element = functools.reduce(lambda x, y: x * y, shape, 1)
-            schema.append(num_element * type_size_in_bytes(dtype))
-        self.add_prim_attr("schema", schema)
-
-    def infer_shape(self, beta):
-        output_shape = [(self.batch_size,), (self.batch_size,)]
-        for shape in self.shapes:
-            output_shape.append((self.batch_size,) + shape)
-        # indices, weights, transitions
-        return tuple(output_shape)
-
-    def infer_dtype(self, beta):
-        return (mstype.int64, mstype.float32) + self.dtypes
-
-
-class PriorityReplayBufferUpdate(PrimitiveWithInfer):
-    r"""
-    Update transition prorities.
-
-    Args:
-        handle(Tensor): Priority replay buffer instance handle with dtype int64 and shape (1,).
-
-    Inputs:
-        - **indices** (Tensor) - transition indices.
-        - **priorities** (Tensor) - Transition priorities.
-
-    Outputs:
-        Priority replay buffer instance handle with dtype int64 and shape (1,).
-
-    Raises:
-        TypeError: The priority replay buffer not created before.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, handle):
-        """Initialize PriorityReplaBufferUpdate."""
-        validator.check_int(handle, 0, validator.GE, "capacity", self.name)
-
-    def infer_shape(self, indices, priorities):
-        return (1,)
-
-    def infer_dtype(self, indices, priorities):
-        return mstype.int64
-
-
-class PriorityReplayBufferDestroy(PrimitiveWithInfer):
-    r"""
-    Destroy the replay buffer.
-
-    Args:
-        handle(Tensor): Priority replay buffer instance handle with dtype int64 and shape (1,).
-
-    Outputs:
-        Priority replay buffer instance handle with dtype int64 and shape (1,).
-
-    Raises:
-        TypeError: The priority replay buffer not created before.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, handle):
-        """Initialize PriorityReplayBufferDestroy."""
-        validator.check_int(handle, 0, validator.GE, "handle", self.name)
-
-    def infer_shape(self):
-        return (1,)
-
-    def infer_dtype(self):
-        return mstype.int64
-
-
-class ReservoirReplayBufferCreate(Primitive):
-    r"""
-    ReservoirReplayBufferCreate is experience container used in reinforcement learning.
-    The algorithm is proposed in `Random sampling with a reservoir <https://dl.acm.org/doi/pdf/10.1145/3147.3165>`
-    which used in `Deep Counterfactual Regret Minimization <https://arxiv.org/abs/1811.00164>`.
-    It lets the reinforcement learning agents remember and reuse experiences from the past. Besides, It keeps an
-    'unbiased' sample of previous iterations.
-
-    Args:
-        capcity (int64): Capacity of the buffer.
-        shapes (list[tuple[int]]): The dimensionality of the transition.
-        dtypes (list[:class:`mindspore.dtype`]): The type of the transition.
-        seed0 (int): Random seed0, must be non-negative. Default: 0.
-        seed1 (int): Random seed1, must be non-negative. Default: 0.
-
-    Outputs:
-        handle(Tensor): Handle of created replay buffer instance with dtype int64 and shape (1,).
-
-    Raises:
-        TypeError: The args not provided.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, capacity, shapes, dtypes, seed0, seed1):
-        """Initialize ReservoirReplayBufferCreate."""
-        validator.check_int(capacity, 1, validator.GE, "capacity", self.name)
-        validator.check_value_type("shape of init data", shapes, [tuple, list], self.name)
-        validator.check_value_type("dtypes of init data", dtypes, [tuple, list], self.name)
-        validator.check_non_negative_int(seed0, "seed0", self.name)
-        validator.check_non_negative_int(seed1, "seed1", self.name)
-
-        schema = []
-        for shape, dtype in zip(shapes, dtypes):
-            num_element = functools.reduce(lambda x, y: x * y, shape, 1)
-            schema.append(num_element * type_size_in_bytes(dtype))
-        self.add_prim_attr("schema", schema)
-
-
-class ReservoirReplayBufferPush(Primitive):
-    r"""
-    Push a transition to the replay buffer.
-
-    Args:
-        handle(Tensor): The replay buffer instance handle with dtype int64 and shape (1,).
-
-    Outputs:
-        handle(Tensor): The replay buffer instance handle with dtype int64 and shape (1,).
-
-    Raises:
-        TypeError: The replay buffer not created before.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, handle):
-        """Initialize ReservoirReplayBufferPush."""
-        validator.check_int(handle, 0, validator.GE, "handle", self.name)
-
-
-class ReservoirReplayBufferSample(Primitive):
-    r"""
-    Sample a transition to the replay buffer.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Args:
-        handle(Tensor): Priority replay buffer instance handle with dtype int64 and shape (1,).
-        batch_size (int): The size of the sampled transitions.
-        shapes (list[tuple[int]]): The dimensionality of the transition.
-        dtypes (list[:class:`mindspore.dtype`]): The type of the transition.
-
-    Outputs:
-        tuple(Tensor): Transition with its indices and bias correction weights.
-
-    Raises:
-        TypeError: The replay buffer not created before.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, handle, batch_size, shapes, dtypes):
-        """Initialize PriorityReplaBufferSample."""
-        validator.check_int(handle, 0, validator.GE, "capacity", self.name)
-        validator.check_int(batch_size, 1, validator.GE, "batch_size", self.name)
-        validator.check_value_type("shape of init data", shapes, [tuple, list], self.name)
-        validator.check_value_type("dtypes of init data", dtypes, [tuple, list], self.name)
-
-        schema = []
-        for shape, dtype in zip(shapes, dtypes):
-            num_element = functools.reduce(lambda x, y: x * y, shape, 1)
-            schema.append(num_element * type_size_in_bytes(dtype))
-        self.add_prim_attr("schema", schema)
-
-
-class ReservoirReplayBufferDestroy(PrimitiveWithInfer):
-    r"""
-    Destroy the replay buffer.
-
-    Args:
-        handle(Tensor): The Replay buffer instance handle with dtype int64 and shape (1,).
-
-    Outputs:
-        Replay buffer instance handle with dtype int64 and shape (1,).
-
-    Raises:
-        TypeError: The replay buffer not created before.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, handle):
-        """Initialize ReservoirReplayBufferDestroy."""
-        validator.check_int(handle, 0, validator.GE, "handle", self.name)
-
-
-class BatchAssign(PrimitiveWithInfer):
-    """
-    Assign the parameters of the source to overwrite the target.
-
-    Args:
-        lock (bool): Lock when the operator is Write, else shared the mutex. Default: ``True``.
-
-    Inputs:
-        - **dst_model** (tuple) - A parameters tuple of the dst model.
-        - **source_model** (tuple) - A parameters tuple of the source model.
-
-    Outputs:
-        None.
-
-    Raises:
-        TypeError: If `lock` is not a bool.
-        ValueError: If elements shape between inputs are not the same.
-        TypeError: If inputs are not in Tensor type.
-
-    Supported Platforms:
-        ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, lock=True):
-        """Initialize BatchAssign."""
-        self.lock = validator.check_value_type("lock", lock, (bool,), self.name)
-        self.add_prim_attr("lock", self.lock)
-        self.add_prim_attr('side_effect_mem', True)
-        if context.get_context('device_target') == "Ascend":
-            self.add_prim_attr('device_target', "CPU")
-
-    def infer_shape(self, dst_shape, source_shape):
-        validator.check_equal_int(len(dst_shape), len(source_shape), "inputs elements", self.name)
-        for i, shp in enumerate(dst_shape):
-            if shp != source_shape[i]:
-                raise ValueError(f'{self.name} element must be same, ',
-                                 f'but got {shp} and {dst_shape[i]}.')
-        return []
-
-    def infer_dtype(self, dst_dtype, source_dtype):
-        for i, dst_type in enumerate(dst_dtype):
-            args = {'dst': dst_type, 'source': source_dtype[i]}
-            validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name)
-        return mstype.int64
-
-
-class TensorsQueueCreate(PrimitiveWithInfer):
-    r"""
-    TensorsQueueCreate used to create a TensorsQueue and return an unique handle.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Args:
-        dtype (mindspore.dtype): the data type in the TensorsQueue.
-        shapes (tuple(tuple(int))): the shape of each tensor in element.
-        size (int): The size of the TensorsQueue.
-        name (str): the name of this TensorsQueue. Default: "Q".
-
-    Inputs:
-        None.
-
-    Outputs:
-        - **output** (Tensor[mindspore.int64]) - an unique handle binded to the TensorsQueue.
-
-    Supported Platforms:
-        ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> import mindspore.ops.operations._rl_inner_ops as rl_ops
-        >>> create_op = rl_ops.TensorsQueueCreate(mindspore.float32,((), (1, 16)), 10, "q")
-        >>> handle = create_op()
-        >>> print(handle)
-        0
-    """
-    @prim_attr_register
-    def __init__(self, dtype, shapes, size=0, name="Q"):
-        validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name)
-        validator.check_int(size, 0, validator.GE, "size", self.name)
-        elements_num = len(shapes)
-        validator.check_int(elements_num, 1, validator.GE, "elements_num", self.name)
-        self.add_prim_attr('shapes', shapes)
-        self.add_prim_attr('dtype', dtype)
-        self.add_prim_attr('elements_num', elements_num)
-        self.add_prim_attr('size', size)
-        self.add_prim_attr('side_effect_mem', True)
-        self.add_prim_attr('name', name)
-
-    def infer_shape(self):
-        return ()
-
-    def infer_dtype(self):
-        return mstype.int64
-
-
-class TensorsQueuePut(PrimitiveWithInfer):
-    r"""
-    TensorsQueuePut used to put tensors into a created TensorsQueue.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Args:
-        dtype (mindspore.dtype): the data type in the TensorsQueue.
-        shapes (tuple(tuple(int))): the shape of each tensor in element.
-
-    Inputs:
-        - **handle** (Tensor[int64]) - The handle pointed to the TensorsQueue.
-        - **value** (list[Tensor] or tuple(Tensors)) - The element to add into the TensorsQueue.
-
-    Outputs:
-        None.
-
-    Supported Platforms:
-        ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> import mindspore.ops.operations._rl_inner_ops as rl_ops
-        >>> create_op = rl_ops.TensorsQueueCreate(mstype.float32, ((), (1, 16)), 10)
-        >>> handle = create_op()
-        >>> out_op = rl_ops.TensorsQueuePut(mstype.float32, ((), (1, 16)))
-        >>> out_op.put(handle, (Tensor(1, mstype.float32), Tensor(2, mstype.float32)))
-    """
-    @prim_attr_register
-    def __init__(self, dtype, shapes):
-        validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name)
-        elements_num = len(shapes)
-        self.elements_num = validator.check_positive_int(elements_num, "elements_num", self.name)
-        self.shapes = shapes
-        self.add_prim_attr('dtype', dtype)
-        self.add_prim_attr('elements_num', elements_num)
-        self.add_prim_attr('side_effect_mem', True)
-
-    def infer_shape(self, handle_shape, elements_shape):
-        validator.check_equal_int(len(elements_shape), self.elements_num, "inputs elements", self.name)
-        for i, shape in enumerate(elements_shape):
-            if tuple(shape) != self.shapes[i]:
-                raise ValueError(f'{self.name} init shape and input shape must be the same, ',
-                                 f'but got {self.shapes[i]} and input {shape} in position {i}.')
-        return ()
-
-    def infer_dtype(self, handle_type, elements_type):
-        validator.check_type_name("handle", handle_type, (mstype.int64), self.name)
-        return mstype.int64
-
-
-class TensorsQueueGet(PrimitiveWithInfer):
-    r"""
-    TensorsQueueGet used to get tensors in the front of the TensorsQueue.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Args:
-        shapes (tuple(tuple(int))): the shape of each tensor in element.
-        dtype (mindspore.dtype): the data type in the TensorsQueue.
-        pop_after_get (bool): if true, pop the element from TensorsQueue after get.
-
-    Inputs:
-        - **handle** (Tensor[int64]) - The handle pointed to the TensorsQueue.
-
-    Outputs:
-        - **value** (list[Tensor] or tuple(Tensors)) - The element in the front of the TensorsQueue.
-
-    Supported Platforms:
-        ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> import mindspore.ops.operations._rl_inner_ops as rl_ops
-        >>> create_op = rl_ops.TensorsQueueCreate(mstype.float32, ((), (1,2)), 10)
-        >>> handle = create_op()
-        >>> get_op = rl_ops.TensorsQueueGet(mstype.float32, ((), (1,2)))
-        >>> tensors_list = get_op.get(handle)
-    """
-    @prim_attr_register
-    def __init__(self, dtype, shapes, pop_after_get=False):
-        validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name)
-        elements_num = len(shapes)
-        self.elements_num = validator.check_positive_int(elements_num, "elements_num", self.name)
-        validator.check_bool(pop_after_get, "pop_after_get", self.name)
-        self.shapes = shapes
-        self.dtype = dtype
-        self.add_prim_attr('dtype', dtype)
-        self.add_prim_attr("shapes", shapes)
-        self.add_prim_attr('elements_num', elements_num)
-        self.add_prim_attr("pop_after_get", pop_after_get)
-        self.add_prim_attr('side_effect_mem', True)
-
-    def infer_shape(self, handle_shape):
-        return tuple(self.shapes)
-
-    def infer_dtype(self, handle_type):
-        validator.check_type_name("handle", handle_type, (mstype.int64), self.name)
-        out_shape = []
-        for _ in range(self.elements_num):
-            out_shape.append(self.dtype)
-        return tuple(out_shape)
-
-
-class TensorsQueueClose(PrimitiveWithInfer):
-    r"""
-    TensorsQueueClose used to close the created TensorsQueue. The resources in TensorsQueue will be deleted.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Inputs:
-        - **handle** (mindspore.int64) - The handle pointed to the TensorsQueue.
-
-    Outputs:
-        None.
-
-    Supported Platforms:
-        ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> import mindspore.ops.operations._rl_inner_ops as rl_ops
-        >>> create_op = rl_ops.TensorsQueueCreate(mindspore.float32, ((), (3, 3)), 10)
-        >>> handle = create_op()
-        >>> close_op = ops.TensorsQueueClose()
-        >>> close_op(handle)
-    """
-    @prim_attr_register
-    def __init__(self):
-        self.add_prim_attr('side_effect_mem', True)
-
-    def infer_shape(self, handle_shape):
-        return ()
-
-    def infer_dtype(self, handle_type):
-        validator.check_type_name("handle", handle_type, (mstype.int64), self.name)
-        return mstype.int64
-
-
-class TensorsQueueSize(PrimitiveWithInfer):
-    r"""
-    TensorsQueueSize used get the indeed size of TensorsQueue.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Inputs:
-        - **handle** (mindspore.int64) - The handle pointed to the TensorsQueue.
-
-    Outputs:
-        - **size** (mindspore.int64) - The used size of the TensorsQueue.
-
-    Supported Platforms:
-        ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> import mindspore.ops.operations._rl_inner_ops as rl_ops
-        >>> create_op = rl_ops.TensorsQueueCreate(mindspore.int32, ((), (3, 2)), 10)
-        >>> handle = create_op()
-        >>> size_op = ops.TensorsQueueSize()
-        >>> print(size_op())
-        >>> 0
-    """
-    @prim_attr_register
-    def __init__(self):
-        self.add_prim_attr('side_effect_mem', True)
-
-    def infer_shape(self, handle_shape):
-        return ()
-
-    def infer_dtype(self, handle_type):
-        validator.check_type_name("handle", handle_type, (mstype.int64), self.name)
-        return mstype.int64
-
-
-class TensorsQueueClear(PrimitiveWithInfer):
-    r"""
-    TensorsQueueClear used to reset the created TensorsQueue. The instance of TensorsQueue is still aviliable.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Inputs:
-        - **handle** (mindspore.int64) - The handle pointed to the TensorsQueue.
-
-    Outputs:
-        None.
-
-    Supported Platforms:
-        ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> import mindspore.ops.operations._rl_inner_ops as rl_ops
-        >>> create_op = rl_ops.TensorsQueueCreate(mindspore.float32, ((), (2, 2)), 4)
-        >>> handle = create_op()
-        >>> clear_op = ops.TensorsQueueClear()
-        >>> clear_op(handle)
-    """
-    @prim_attr_register
-    def __init__(self):
-        self.add_prim_attr('side_effect_mem', True)
-
-    def infer_shape(self, handle_shape):
-        return ()
-
-    def infer_dtype(self, handle_type):
-        validator.check_type_name("handle", handle_type, (mstype.int64), self.name)
-        return mstype.int64
-
-
-class MuxSend(PrimitiveWithInfer):
-    r"""
-    Send tensors to the specified dest_rank.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Note:
-        Send and Receive must be used in combination.
-        Send must be used between servers.
-
-    Args:
-        dest_rank (int): A required integer identifying the destination rank.
-        group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
-
-    Inputs:
-        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
-
-    Examples:
-        >>> from mindspore import ops
-        >>> import mindspore.nn as nn
-        >>> from mindspore.communication import init
-        >>> from mindspore import Tensor
-        >>> import numpy as np
-        >>>
-        >>> init()
-        >>> class Net(nn.Cell):
-        >>>     def __init__(self):
-        >>>         super(Net, self).__init__()
-        >>>         self.depend = ops.Depend()
-        >>>         self.send = ops.Send(dest_rank=8, group="hccl_world_group")
-        >>>
-        >>>     def construct(self, x):
-        >>>         out = self.depend(x, self.send(x))
-        >>>         return out
-        >>>
-        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
-        >>> net = Net()
-        >>> output = net(input_)
-    """
-
-    @prim_attr_register
-    def __init__(self, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
-        self.dest_rank = dest_rank
-        self.group = group
-        self.add_prim_attr("fusion", 1)
-        self.add_prim_attr('side_effect_mem', True)
-
-    def infer_shape(self, x_shape):
-        self.add_prim_attr("shape", x_shape)
-        return []
-
-    def infer_dtype(self, x_dtype):
-        return x_dtype[0]
-
-
-class MuxReceive(PrimitiveWithInfer):
-    r"""
-    receive tensors from src_rank.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Note:
-        Send and Receive must be used in combination.
-        Receive must be used between servers.
-
-    Args:
-        shape (list[int]): A required list identifying the shape of the tensor to be received.
-        dtype (Type): A required Type identifying the type of the tensor to be received. The supported types:
-            int8, int16, int32, float16, float32.
-        group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
-
-    Inputs:
-        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
-
-    Examples:
-        >>> from mindspore import ops
-        >>> import mindspore.nn as nn
-        >>> from mindspore.communication import init
-        >>> from mindspore import Tensor
-        >>> import numpy as np
-        >>>
-        >>> init()
-        >>> class Net(nn.Cell):
-        >>>     def __init__(self):
-        >>>         super(Net, self).__init__()
-        >>>         self.recv = ops.Receive(shape=[2, 8], dtype=np.float32, group="hccl_world_group")
-        >>>
-        >>>     def construct(self):
-        >>>         out = self.recv()
-        >>>         return out
-        >>>
-        >>> net = Net()
-        >>> output = net()
-    """
-
-    @prim_attr_register
-    def __init__(self, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
-        self.shape = shape
-        self.dtype = dtype
-        self.group = group
-        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
-        args = {"dtype": dtype}
-        self.add_prim_attr('side_effect_mem', True)
-        self.add_prim_attr("fusion", 1)
-        validator.check_scalar_or_tensor_types_same(args, valid_type, self.name)
-
-    def infer_shape(self, x_shape=None):
-        return tuple(self.get_attr_dict()['shape'])
-
-    def infer_dtype(self, x_dtype=None):
-        out_type = []
-        for _ in self.shape:
-            out_type.append(self.dtype)
-        return tuple(out_type)
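
For reference, the removed DiscountedReturn primitive documents the recurrence G_t = r_t + gamma * G_{t+1}, with `done` marking episode boundaries and `last_state_value` bootstrapping the step after the final reward. A minimal NumPy sketch of that computation follows; it is an illustration only, not an API shipped in either wheel, and it assumes the common convention that `done[t]` cuts the discounted tail at step t (the removed kernel's exact boundary handling is not spelled out in its docstring):

import numpy as np

def discounted_return(reward, done, last_state_value, gamma=0.99):
    """Backward recurrence G_t = r_t + gamma * G_{t+1}, reset where done[t] is True.

    reward: (T, B) float array; done: (T, B) bool array; last_state_value: (B,) float array.
    """
    ret = np.zeros_like(reward)
    g = last_state_value.astype(reward.dtype)   # bootstrap value after the final step
    for t in reversed(range(reward.shape[0])):
        g = reward[t] + gamma * g * (~done[t])  # done zeroes the discounted tail
        ret[t] = g
    return ret

reward = np.ones((4, 1), dtype=np.float32)                # (Timestep, Batch)
done = np.array([[False], [False], [True], [False]])
last_state_value = np.array([2.0], dtype=np.float32)
print(discounted_return(reward, done, last_state_value))

Here `done` at t=2 ends the first episode, so G_2 = r_2, while the last step is bootstrapped as G_3 = r_3 + gamma * last_state_value.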