mindspore-2.7.0-cp310-cp310-win_amd64.whl → mindspore-2.7.1-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
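
A report like this can be reproduced locally by unpacking both wheels and comparing the extracted trees. The sketch below is a minimal illustration, assuming both wheel files sit in the working directory; it is not the tool the registry itself uses.

    # Minimal sketch: unpack both wheels, then do a shallow, top-level
    # comparison of the extracted trees (recurse via report.subdirs if needed).
    import filecmp
    import os
    import tempfile
    import zipfile

    def unpack(whl):
        dest = tempfile.mkdtemp(prefix=os.path.basename(whl))
        with zipfile.ZipFile(whl) as zf:
            zf.extractall(dest)
        return dest

    old = unpack("mindspore-2.7.0-cp310-cp310-win_amd64.whl")  # assumed local paths
    new = unpack("mindspore-2.7.1-cp310-cp310-win_amd64.whl")
    report = filecmp.dircmp(old, new)
    print("changed:", report.diff_files)
    print("only in 2.7.1:", report.right_only)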

Potentially problematic release: this version of mindspore might be problematic.
Files changed (290)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_extends/parse/compile_config.py +24 -1
  7. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
  8. mindspore/_extends/parse/resources.py +1 -1
  9. mindspore/_extends/parse/standard_method.py +8 -1
  10. mindspore/_extends/parse/trope.py +2 -1
  11. mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/boost/base.py +29 -2
  18. mindspore/common/_decorator.py +3 -2
  19. mindspore/common/_grad_function.py +3 -1
  20. mindspore/common/_tensor_cpp_method.py +1 -1
  21. mindspore/common/_tensor_docs.py +275 -64
  22. mindspore/common/_utils.py +0 -44
  23. mindspore/common/api.py +285 -35
  24. mindspore/common/dump.py +7 -108
  25. mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
  26. mindspore/common/hook_handle.py +60 -0
  27. mindspore/common/jit_config.py +5 -1
  28. mindspore/common/jit_trace.py +27 -12
  29. mindspore/common/lazy_inline.py +5 -3
  30. mindspore/common/parameter.py +13 -107
  31. mindspore/common/recompute.py +4 -11
  32. mindspore/common/tensor.py +16 -169
  33. mindspore/communication/_comm_helper.py +11 -1
  34. mindspore/communication/comm_func.py +138 -4
  35. mindspore/communication/management.py +85 -1
  36. mindspore/config/op_info.config +0 -15
  37. mindspore/context.py +5 -85
  38. mindspore/dataset/engine/datasets.py +8 -4
  39. mindspore/dataset/engine/datasets_vision.py +1 -1
  40. mindspore/dataset/engine/validators.py +1 -15
  41. mindspore/dnnl.dll +0 -0
  42. mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
  43. mindspore/graph/custom_pass.py +55 -0
  44. mindspore/include/dataset/execute.h +2 -2
  45. mindspore/jpeg62.dll +0 -0
  46. mindspore/mindrecord/__init__.py +3 -3
  47. mindspore/mindrecord/common/exceptions.py +1 -0
  48. mindspore/mindrecord/config.py +1 -1
  49. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  50. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  51. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  52. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  53. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  54. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  55. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  56. mindspore/mindrecord/filereader.py +4 -4
  57. mindspore/mindrecord/filewriter.py +5 -5
  58. mindspore/mindrecord/mindpage.py +2 -2
  59. mindspore/mindrecord/tools/cifar10.py +1 -1
  60. mindspore/mindrecord/tools/cifar100.py +1 -1
  61. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  62. mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
  63. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  64. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  65. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  66. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  67. mindspore/mindspore_backend_common.dll +0 -0
  68. mindspore/mindspore_backend_manager.dll +0 -0
  69. mindspore/mindspore_cluster.dll +0 -0
  70. mindspore/mindspore_common.dll +0 -0
  71. mindspore/mindspore_core.dll +0 -0
  72. mindspore/mindspore_cpu.dll +0 -0
  73. mindspore/mindspore_dump.dll +0 -0
  74. mindspore/mindspore_frontend.dll +0 -0
  75. mindspore/mindspore_glog.dll +0 -0
  76. mindspore/mindspore_hardware_abstract.dll +0 -0
  77. mindspore/mindspore_memory_pool.dll +0 -0
  78. mindspore/mindspore_ms_backend.dll +0 -0
  79. mindspore/mindspore_ops.dll +0 -0
  80. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  81. mindspore/mindspore_profiler.dll +0 -0
  82. mindspore/mindspore_pyboost.dll +0 -0
  83. mindspore/mindspore_pynative.dll +0 -0
  84. mindspore/mindspore_runtime_pipeline.dll +0 -0
  85. mindspore/mindspore_runtime_utils.dll +0 -0
  86. mindspore/mindspore_tools.dll +0 -0
  87. mindspore/mint/__init__.py +15 -10
  88. mindspore/mint/distributed/distributed.py +182 -62
  89. mindspore/mint/nn/__init__.py +2 -16
  90. mindspore/mint/nn/functional.py +4 -110
  91. mindspore/mint/nn/layer/__init__.py +0 -2
  92. mindspore/mint/nn/layer/activation.py +0 -6
  93. mindspore/mint/nn/layer/basic.py +0 -47
  94. mindspore/mint/nn/layer/conv.py +4 -4
  95. mindspore/mint/nn/layer/normalization.py +8 -13
  96. mindspore/mint/nn/layer/pooling.py +0 -4
  97. mindspore/nn/__init__.py +1 -3
  98. mindspore/nn/cell.py +16 -66
  99. mindspore/nn/layer/basic.py +49 -1
  100. mindspore/nn/layer/container.py +16 -0
  101. mindspore/nn/layer/embedding.py +4 -169
  102. mindspore/nn/layer/normalization.py +2 -1
  103. mindspore/nn/layer/thor_layer.py +4 -85
  104. mindspore/nn/optim/ada_grad.py +0 -1
  105. mindspore/nn/optim/adafactor.py +0 -1
  106. mindspore/nn/optim/adam.py +31 -124
  107. mindspore/nn/optim/adamax.py +0 -1
  108. mindspore/nn/optim/asgd.py +0 -1
  109. mindspore/nn/optim/ftrl.py +8 -102
  110. mindspore/nn/optim/lamb.py +0 -1
  111. mindspore/nn/optim/lars.py +0 -3
  112. mindspore/nn/optim/lazyadam.py +25 -218
  113. mindspore/nn/optim/momentum.py +5 -43
  114. mindspore/nn/optim/optimizer.py +6 -55
  115. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  116. mindspore/nn/optim/rmsprop.py +0 -1
  117. mindspore/nn/optim/rprop.py +0 -1
  118. mindspore/nn/optim/sgd.py +0 -1
  119. mindspore/nn/optim/tft_wrapper.py +0 -1
  120. mindspore/nn/optim/thor.py +0 -2
  121. mindspore/nn/probability/bijector/bijector.py +7 -8
  122. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  123. mindspore/nn/probability/bijector/power_transform.py +20 -21
  124. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  125. mindspore/nn/probability/bijector/softplus.py +13 -14
  126. mindspore/nn/wrap/grad_reducer.py +4 -74
  127. mindspore/numpy/array_creations.py +2 -2
  128. mindspore/numpy/fft.py +9 -9
  129. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  130. mindspore/onnx/onnx_export.py +137 -0
  131. mindspore/opencv_core4110.dll +0 -0
  132. mindspore/opencv_imgcodecs4110.dll +0 -0
  133. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  134. mindspore/ops/__init__.py +2 -0
  135. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  136. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  137. mindspore/ops/_op_impl/cpu/__init__.py +0 -5
  138. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
  139. mindspore/ops/auto_generate/gen_extend_func.py +2 -7
  140. mindspore/ops/auto_generate/gen_ops_def.py +98 -141
  141. mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
  142. mindspore/ops/communication.py +97 -0
  143. mindspore/ops/composite/__init__.py +5 -2
  144. mindspore/ops/composite/base.py +15 -1
  145. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  146. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  147. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  148. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  149. mindspore/ops/function/__init__.py +1 -0
  150. mindspore/ops/function/array_func.py +14 -12
  151. mindspore/ops/function/comm_func.py +3883 -0
  152. mindspore/ops/function/debug_func.py +3 -4
  153. mindspore/ops/function/math_func.py +45 -54
  154. mindspore/ops/function/nn_func.py +75 -294
  155. mindspore/ops/function/random_func.py +9 -18
  156. mindspore/ops/functional.py +2 -0
  157. mindspore/ops/functional_overload.py +354 -18
  158. mindspore/ops/operations/__init__.py +2 -5
  159. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  160. mindspore/ops/operations/_inner_ops.py +1 -38
  161. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  162. mindspore/ops/operations/array_ops.py +1 -0
  163. mindspore/ops/operations/comm_ops.py +94 -2
  164. mindspore/ops/operations/custom_ops.py +228 -19
  165. mindspore/ops/operations/debug_ops.py +27 -29
  166. mindspore/ops/operations/manually_defined/ops_def.py +27 -306
  167. mindspore/ops/operations/nn_ops.py +2 -2
  168. mindspore/ops/operations/sparse_ops.py +0 -83
  169. mindspore/ops/primitive.py +1 -17
  170. mindspore/ops/tensor_method.py +72 -3
  171. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  172. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  173. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  174. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  175. mindspore/ops_generate/common/gen_constants.py +11 -10
  176. mindspore/ops_generate/common/op_proto.py +18 -1
  177. mindspore/ops_generate/common/template.py +102 -245
  178. mindspore/ops_generate/common/template_utils.py +212 -0
  179. mindspore/ops_generate/gen_custom_ops.py +69 -0
  180. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  181. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  182. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  183. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  184. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  185. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  186. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  187. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  188. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  189. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  190. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  191. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  192. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  193. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  194. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  195. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  196. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  197. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  198. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  199. mindspore/parallel/_cell_wrapper.py +1 -1
  200. mindspore/parallel/_parallel_serialization.py +1 -4
  201. mindspore/parallel/_utils.py +29 -6
  202. mindspore/parallel/checkpoint_transform.py +18 -2
  203. mindspore/parallel/cluster/process_entity/_api.py +24 -32
  204. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  205. mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
  206. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  207. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  208. mindspore/parallel/strategy.py +336 -0
  209. mindspore/parallel/transform_safetensors.py +117 -16
  210. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
  211. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  212. mindspore/profiler/common/constant.py +5 -0
  213. mindspore/profiler/common/file_manager.py +9 -0
  214. mindspore/profiler/common/msprof_cmd_tool.py +38 -2
  215. mindspore/profiler/common/path_manager.py +56 -24
  216. mindspore/profiler/common/profiler_context.py +2 -12
  217. mindspore/profiler/common/profiler_info.py +3 -3
  218. mindspore/profiler/common/profiler_path_manager.py +13 -0
  219. mindspore/profiler/common/util.py +30 -3
  220. mindspore/profiler/experimental_config.py +2 -1
  221. mindspore/profiler/platform/npu_profiler.py +33 -6
  222. mindspore/run_check/_check_version.py +108 -24
  223. mindspore/runtime/__init__.py +3 -2
  224. mindspore/runtime/executor.py +11 -3
  225. mindspore/runtime/memory.py +112 -0
  226. mindspore/swresample-4.dll +0 -0
  227. mindspore/swscale-6.dll +0 -0
  228. mindspore/tinyxml2.dll +0 -0
  229. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  230. mindspore/tools/data_dump.py +130 -0
  231. mindspore/tools/sdc_detect.py +91 -0
  232. mindspore/tools/stress_detect.py +63 -0
  233. mindspore/train/__init__.py +6 -6
  234. mindspore/train/_utils.py +5 -18
  235. mindspore/train/amp.py +6 -4
  236. mindspore/train/callback/_checkpoint.py +0 -9
  237. mindspore/train/callback/_train_fault_tolerance.py +69 -18
  238. mindspore/train/data_sink.py +1 -5
  239. mindspore/train/model.py +38 -211
  240. mindspore/train/serialization.py +126 -387
  241. mindspore/turbojpeg.dll +0 -0
  242. mindspore/utils/__init__.py +6 -3
  243. mindspore/utils/dlpack.py +92 -0
  244. mindspore/utils/dryrun.py +1 -1
  245. mindspore/utils/runtime_execution_order_check.py +10 -0
  246. mindspore/utils/sdc_detect.py +14 -12
  247. mindspore/utils/stress_detect.py +43 -0
  248. mindspore/utils/utils.py +144 -8
  249. mindspore/version.py +1 -1
  250. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  251. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
  252. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
  253. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  254. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  255. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  256. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  257. mindspore/experimental/llm_boost/register.py +0 -130
  258. mindspore/experimental/llm_boost/utils.py +0 -31
  259. mindspore/include/OWNERS +0 -7
  260. mindspore/mindspore_cpu_res_manager.dll +0 -0
  261. mindspore/mindspore_ops_kernel_common.dll +0 -0
  262. mindspore/mindspore_res_manager.dll +0 -0
  263. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  264. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  265. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  266. mindspore/nn/reinforcement/tensor_array.py +0 -145
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  270. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  271. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  272. mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
  273. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  274. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  275. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  276. mindspore/ops/operations/_tensor_array.py +0 -359
  277. mindspore/ops/operations/rl_ops.py +0 -288
  278. mindspore/parallel/_offload_context.py +0 -275
  279. mindspore/parallel/_recovery_context.py +0 -115
  280. mindspore/parallel/_transformer/__init__.py +0 -35
  281. mindspore/parallel/_transformer/layers.py +0 -765
  282. mindspore/parallel/_transformer/loss.py +0 -251
  283. mindspore/parallel/_transformer/moe.py +0 -693
  284. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  285. mindspore/parallel/_transformer/transformer.py +0 -3124
  286. mindspore/parallel/mpi/_mpi_config.py +0 -116
  287. mindspore/train/memory_profiling_pb2.py +0 -298
  288. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  289. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  290. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
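
Reading the list: each entry shows lines added and removed, so an entry with +0 and a large removal count is a file deleted outright in 2.7.1 (for example, mindspore/ops/operations/_rl_inner_ops.py at +0 -933, whose diff follows). A hypothetical two-row sketch of that filter:

    # Hypothetical rows mirroring the "path +added -removed" entries above;
    # a file with zero additions and a positive removal count was deleted.
    rows = [
        ("mindspore/common/tensor.py", 16, 169),
        ("mindspore/ops/operations/_rl_inner_ops.py", 0, 933),
    ]
    deleted = [path for path, added, removed in rows if added == 0 and removed > 0]
    print(deleted)  # ['mindspore/ops/operations/_rl_inner_ops.py']
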
@@ -16,199 +16,9 @@
 """Inner operators for reinforcement learning."""
 
 from __future__ import absolute_import
-import functools
-from mindspore.common.dtype import type_size_in_bytes
-import mindspore.context as context
 from mindspore import _checkparam as validator
 from mindspore.common import dtype as mstype
 from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer, Primitive
-from mindspore.communication.management import GlobalComm
-
-
-class EnvCreate(PrimitiveWithInfer):
-    r"""
-    Create a built-in reinforcement learning environment. Repeated calls to the operator return the previously
-    created handle. Make sure to create a new operator instance if you want to create a new environment instance.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Args:
-        name (str): Name of the built-in environment.
-        kwargs (any): Environment-related parameters.
-
-    Inputs:
-        No inputs.
-
-    Outputs:
-        handle(Tensor): Handle of the created environment instance with dtype int and shape (1,).
-
-    Raises:
-        TypeError: The environment is not supported.
-        TypeError: The environment parameters are not provided.
-
-    Supported Platforms:
-        ``GPU``
-    """
-
-    def __init__(self, name, **kwargs):
-        super(EnvCreate, self).__init__(self.__class__.__name__)
-        self.add_prim_attr('name', name)
-        for key in kwargs:
-            self.add_prim_attr(key, kwargs[key])
-
-    def infer_shape(self, *args):
-        return (1,)
-
-    def infer_dtype(self, *args):
-        return mstype.int64
-
-
-class EnvReset(PrimitiveWithInfer):
-    r"""
-    Reset the built-in reinforcement learning environment.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Args:
-        handle (int): The handle returned by the `EnvCreate` operator.
-        state_shape (list[tuple[int]]): The dimensionality of the state.
-        state_dtype (list[:class:`mindspore.dtype`]): The type of the state.
-        reward_shape (list[tuple[int]]): The dimensionality of the reward.
-        reward_dtype (list[:class:`mindspore.dtype`]): The type of the reward.
-
-    Inputs:
-        No inputs.
-
-    Outputs:
-        Tensor, environment observation after reset.
-
-    Raises:
-        TypeError: The environment instance does not exist.
-
-    Supported Platforms:
-        ``GPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, handle, state_shape, state_dtype):
-        super(EnvReset, self).__init__(self.__class__.__name__)
-        validator.check_value_type("handle", handle, [int], self.name)
-        validator.check_value_type("state_shape", state_shape, [list, tuple], self.name)
-
-    def infer_shape(self, *args):
-        return self.state_shape
-
-    def infer_dtype(self, *args):
-        return self.state_dtype
-
-
-class EnvStep(PrimitiveWithInfer):
-    r"""
-    Run one environment timestep.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Args:
-        handle (int): The handle returned by the `EnvCreate` operator.
-        state_shape (list[tuple[int]]): The dimensionality of the state.
-        state_dtype (list[:class:`mindspore.dtype`]): The type of the state.
-        reward_shape (list[tuple[int]]): The dimensionality of the reward.
-        reward_dtype (list[:class:`mindspore.dtype`]): The type of the reward.
-
-    Inputs:
-        - **action** (Tensor) - The action.
-
-    Outputs:
-        - **state** (Tensor) - Environment state after the previous action.
-        - **reward** (Tensor) - Reward returned by the environment.
-        - **done** (Tensor) - Whether the episode has ended.
-
-    Raises:
-        TypeError: If dtype of `handle` is not int.
-        TypeError: If `state_shape` is neither tuple nor list.
-        TypeError: If dtype of `state_dtype` is neither int nor float.
-        TypeError: If `reward_shape` is neither tuple nor list.
-        TypeError: If dtype of `reward_dtype` is neither int nor float.
-
-    Supported Platforms:
-        ``GPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, handle, state_shape, state_dtype, reward_shape, reward_dtype):
-        super(EnvStep, self).__init__(self.__class__.__name__)
-        validator.check_value_type("handle", handle, [int], self.name)
-        validator.check_value_type("state_shape", state_shape, [list, tuple], self.name)
-        validator.check_value_type("reward_shape", reward_shape, [list, tuple], self.name)
-
-    def infer_shape(self, action_shape):
-        return self.state_shape, self.reward_shape, (self.state_shape[0],)
-
-    def infer_dtype(self, action_dtype):
-        return self.state_dtype, self.reward_dtype, mstype.bool_
-
-
-class DiscountedReturn(PrimitiveWithInfer):
-    r"""
-    Calculate the discounted return.
-
-    Set the discounted return as :math:`G`, the discount factor as :math:`\gamma`, the reward as :math:`R`,
-    the timestep as :math:`t`, and the max timestep as :math:`N`.
-    Then :math:`G_t = \sum_{k=0}^{N-t} \gamma^k R_{t+k+1}`.
-
-    For a reward sequence containing multiple episodes, :math:`done` is introduced to indicate episode boundaries,
-    and :math:`last\_state\_value` represents the value after the final step of the last episode.
-
-    Args:
-        gamma (float): Discount factor in [0, 1].
-
-    Inputs:
-        - **reward** (Tensor) - The reward sequence, which may contain multiple episodes.
-          Tensor of shape :math:`(Timestep, Batch, ...)`.
-        - **done** (Tensor) - The episode-done flag. Tensor of shape :math:`(Timestep, Batch)`.
-          The data type must be bool.
-        - **last_state_value** (Tensor) - The value after the final step of the last episode.
-          Tensor of shape :math:`(Batch, ...)`.
-
-    Examples:
-        >>> net = DiscountedReturn(gamma=0.99)
-        >>> reward = Tensor([[1, 1, 1, 1]], dtype=mindspore.float32)
-        >>> done = Tensor([[False, False, True, False]])
-        >>> last_state_value = Tensor([2.], dtype=mindspore.float32)
-        >>> ret = net(reward, done, last_state_value)
-        >>> print(ret.shape)
-        (1, 4)
-    """
-
-    @prim_attr_register
-    def __init__(self, gamma):
-        self.init_prim_io_names(inputs=['reward', 'done', 'last_state_value'], outputs=['output'])
-        validator.check_float_range(gamma, 0, 1, validator.INC_RIGHT, "gamma", self.name)
-
-    def infer_shape(self, reward_shape, done_shape, last_state_value_shape):
-        if len(reward_shape) != len(done_shape):
-            raise ValueError(f'For \'{self.name}\', len(reward) and len(done) must be the same, '
-                             f'but got {len(reward_shape)} and {len(done_shape)}.')
-
-        if reward_shape[0] != done_shape[0]:
-            raise ValueError(f'For \'{self.name}\', the first element of the shape of \'reward\' '
-                             f'and \'done\' must be the same, but got reward.shape[0]:'
-                             f' {reward_shape[0]} and done.shape[0]: {done_shape[0]}.')
-
-        if reward_shape[1:] != last_state_value_shape:
-            raise ValueError(f'For \'{self.name}\', reward.shape[1:] and last_state_value.shape must be the same, '
-                             f'but got reward.shape[1:]: {reward_shape[1:]} '
-                             f'and last_state_value.shape: {last_state_value_shape}.')
-        return reward_shape
-
-    def infer_dtype(self, reward_dtype, done_dtype, last_state_value_dtype):
-        valid_dtypes = (mstype.float16, mstype.float32)
-        args = {"reward": reward_dtype, "last_state_value": last_state_value_dtype}
-        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
-        validator.check_tensor_dtype_valid('done_dtype', done_dtype, [mstype.bool_], self.name)
-        return reward_dtype
 
 
 class GRUV2(PrimitiveWithInfer):
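
Among the primitives this hunk deletes is DiscountedReturn. For reference, the backward recursion its docstring describes (G_t = r_t + gamma * G_{t+1}, restarting at episode boundaries and bootstrapping from last_state_value) fits in a few lines of plain Python; the names below are illustrative, not MindSpore API.

    # Pure-Python sketch of the discounted-return recursion that the removed
    # DiscountedReturn primitive computed, for a single (Timestep,) sequence.
    def discounted_return(rewards, dones, last_state_value, gamma=0.99):
        returns = [0.0] * len(rewards)
        running = last_state_value
        for t in reversed(range(len(rewards))):
            if dones[t]:          # episode boundary: do not bootstrap across it
                running = 0.0
            running = rewards[t] + gamma * running
            returns[t] = running
        return returns

    print(discounted_return([1.0, 1.0, 1.0, 1.0], [False, False, True, False], 2.0))
    # ~[2.9701, 1.99, 1.0, 2.98]
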
@@ -486,746 +296,3 @@ class CudnnGRU(Primitive):
             self.num_directions = 2
         else:
             self.num_directions = 1
-
-
-class PriorityReplayBufferCreate(PrimitiveWithInfer):
-    r"""
-    PriorityReplayBuffer is an experience container used in Deep Q-Networks.
-    The algorithm is proposed in `Prioritized Experience Replay <https://arxiv.org/abs/1511.05952>`.
-    Like the normal replay buffer, it lets reinforcement learning agents remember and reuse experiences from the
-    past. In addition, it replays important transitions more frequently to improve sample efficiency.
-
-    Args:
-        capacity (int64): Capacity of the buffer. It is recommended to set the capacity to pow(2, N).
-        alpha (float): Parameter in [0, 1] that determines how much prioritization is used.
-        beta (float): Parameter in [0, 1] that determines how much compensation for non-uniform probabilities
-            is used.
-        shapes (list[tuple[int]]): The dimensionality of the transition.
-        dtypes (list[:class:`mindspore.dtype`]): The type of the transition.
-        seed0 (int): Random seed0, must be non-negative. Default: 0.
-        seed1 (int): Random seed1, must be non-negative. Default: 0.
-
-    Outputs:
-        handle(Tensor): Handle of the created priority replay buffer instance with dtype int64 and shape (1,).
-
-    Raises:
-        TypeError: The args are not provided.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, capacity, alpha, shapes, dtypes, seed0, seed1):
-        """Initialize PriorityReplayBufferCreate."""
-        validator.check_int(capacity, 1, validator.GE, "capacity", self.name)
-        validator.check_float_range(alpha, 0.0, 1.0, validator.INC_BOTH)
-        validator.check_value_type("shape of init data", shapes, [tuple, list], self.name)
-        validator.check_value_type("dtypes of init data", dtypes, [tuple, list], self.name)
-        validator.check_non_negative_int(seed0, "seed0", self.name)
-        validator.check_non_negative_int(seed1, "seed1", self.name)
-
-        schema = []
-        for shape, dtype in zip(shapes, dtypes):
-            num_element = functools.reduce(lambda x, y: x * y, shape, 1)
-            schema.append(num_element * type_size_in_bytes(dtype))
-        self.add_prim_attr("schema", schema)
-
-    def infer_shape(self):
-        return (1,)
-
-    def infer_dtype(self):
-        return mstype.int64
-
-
-class PriorityReplayBufferPush(PrimitiveWithInfer):
-    r"""
-    Push a transition to the priority replay buffer.
-
-    Args:
-        handle(Tensor): Priority replay buffer instance handle with dtype int64 and shape (1,).
-
-    Outputs:
-        handle(Tensor): Priority replay buffer instance handle with dtype int64 and shape (1,).
-
-    Raises:
-        TypeError: The priority replay buffer was not created beforehand.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, handle):
-        """Initialize PriorityReplayBufferPush."""
-        validator.check_int(handle, 0, validator.GE, "handle", self.name)
-
-    def infer_shape(self, *inputs):
-        return (1,)
-
-    def infer_dtype(self, *inputs):
-        return mstype.int64
-
-
-class PriorityReplayBufferSample(PrimitiveWithInfer):
-    r"""
-    Sample a batch of transitions from the priority replay buffer.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Args:
-        handle(Tensor): Priority replay buffer instance handle with dtype int64 and shape (1,).
-        batch_size (int): The size of the sampled transitions.
-        shapes (list[tuple[int]]): The dimensionality of the transition.
-        dtypes (list[:class:`mindspore.dtype`]): The type of the transition.
-
-    Outputs:
-        tuple(Tensor): Transitions with their indices and bias-correction weights.
-
-    Raises:
-        TypeError: The priority replay buffer was not created beforehand.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, handle, batch_size, shapes, dtypes):
-        """Initialize PriorityReplayBufferSample."""
-        validator.check_int(handle, 0, validator.GE, "capacity", self.name)
-        validator.check_int(batch_size, 1, validator.GE, "batch_size", self.name)
-        validator.check_value_type("shape of init data", shapes, [tuple, list], self.name)
-        validator.check_value_type("dtypes of init data", dtypes, [tuple, list], self.name)
-
-        schema = []
-        for shape, dtype in zip(shapes, dtypes):
-            num_element = functools.reduce(lambda x, y: x * y, shape, 1)
-            schema.append(num_element * type_size_in_bytes(dtype))
-        self.add_prim_attr("schema", schema)
-
-    def infer_shape(self, beta):
-        output_shape = [(self.batch_size,), (self.batch_size,)]
-        for shape in self.shapes:
-            output_shape.append((self.batch_size,) + shape)
-        # indices, weights, transitions
-        return tuple(output_shape)
-
-    def infer_dtype(self, beta):
-        return (mstype.int64, mstype.float32) + self.dtypes
-
-
-class PriorityReplayBufferUpdate(PrimitiveWithInfer):
-    r"""
-    Update transition priorities.
-
-    Args:
-        handle(Tensor): Priority replay buffer instance handle with dtype int64 and shape (1,).
-
-    Inputs:
-        - **indices** (Tensor) - Transition indices.
-        - **priorities** (Tensor) - Transition priorities.
-
-    Outputs:
-        Priority replay buffer instance handle with dtype int64 and shape (1,).
-
-    Raises:
-        TypeError: The priority replay buffer was not created beforehand.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, handle):
-        """Initialize PriorityReplayBufferUpdate."""
-        validator.check_int(handle, 0, validator.GE, "capacity", self.name)
-
-    def infer_shape(self, indices, priorities):
-        return (1,)
-
-    def infer_dtype(self, indices, priorities):
-        return mstype.int64
-
-
-class PriorityReplayBufferDestroy(PrimitiveWithInfer):
-    r"""
-    Destroy the replay buffer.
-
-    Args:
-        handle(Tensor): Priority replay buffer instance handle with dtype int64 and shape (1,).
-
-    Outputs:
-        Priority replay buffer instance handle with dtype int64 and shape (1,).
-
-    Raises:
-        TypeError: The priority replay buffer was not created beforehand.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, handle):
-        """Initialize PriorityReplayBufferDestroy."""
-        validator.check_int(handle, 0, validator.GE, "handle", self.name)
-
-    def infer_shape(self):
-        return (1,)
-
-    def infer_dtype(self):
-        return mstype.int64
-
-
-class ReservoirReplayBufferCreate(Primitive):
-    r"""
-    ReservoirReplayBufferCreate is an experience container used in reinforcement learning.
-    The algorithm is proposed in `Random sampling with a reservoir <https://dl.acm.org/doi/pdf/10.1145/3147.3165>`,
-    which is used in `Deep Counterfactual Regret Minimization <https://arxiv.org/abs/1811.00164>`.
-    It lets reinforcement learning agents remember and reuse experiences from the past. Besides, it keeps an
-    'unbiased' sample of previous iterations.
-
-    Args:
-        capacity (int64): Capacity of the buffer.
-        shapes (list[tuple[int]]): The dimensionality of the transition.
-        dtypes (list[:class:`mindspore.dtype`]): The type of the transition.
-        seed0 (int): Random seed0, must be non-negative. Default: 0.
-        seed1 (int): Random seed1, must be non-negative. Default: 0.
-
-    Outputs:
-        handle(Tensor): Handle of the created replay buffer instance with dtype int64 and shape (1,).
-
-    Raises:
-        TypeError: The args are not provided.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, capacity, shapes, dtypes, seed0, seed1):
-        """Initialize ReservoirReplayBufferCreate."""
-        validator.check_int(capacity, 1, validator.GE, "capacity", self.name)
-        validator.check_value_type("shape of init data", shapes, [tuple, list], self.name)
-        validator.check_value_type("dtypes of init data", dtypes, [tuple, list], self.name)
-        validator.check_non_negative_int(seed0, "seed0", self.name)
-        validator.check_non_negative_int(seed1, "seed1", self.name)
-
-        schema = []
-        for shape, dtype in zip(shapes, dtypes):
-            num_element = functools.reduce(lambda x, y: x * y, shape, 1)
-            schema.append(num_element * type_size_in_bytes(dtype))
-        self.add_prim_attr("schema", schema)
-
-
-class ReservoirReplayBufferPush(Primitive):
-    r"""
-    Push a transition to the replay buffer.
-
-    Args:
-        handle(Tensor): The replay buffer instance handle with dtype int64 and shape (1,).
-
-    Outputs:
-        handle(Tensor): The replay buffer instance handle with dtype int64 and shape (1,).
-
-    Raises:
-        TypeError: The replay buffer was not created beforehand.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, handle):
-        """Initialize ReservoirReplayBufferPush."""
-        validator.check_int(handle, 0, validator.GE, "handle", self.name)
-
-
-class ReservoirReplayBufferSample(Primitive):
-    r"""
-    Sample a batch of transitions from the replay buffer.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Args:
-        handle(Tensor): Replay buffer instance handle with dtype int64 and shape (1,).
-        batch_size (int): The size of the sampled transitions.
-        shapes (list[tuple[int]]): The dimensionality of the transition.
-        dtypes (list[:class:`mindspore.dtype`]): The type of the transition.
-
-    Outputs:
-        tuple(Tensor): Transitions with their indices and bias-correction weights.
-
-    Raises:
-        TypeError: The replay buffer was not created beforehand.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, handle, batch_size, shapes, dtypes):
-        """Initialize ReservoirReplayBufferSample."""
-        validator.check_int(handle, 0, validator.GE, "capacity", self.name)
-        validator.check_int(batch_size, 1, validator.GE, "batch_size", self.name)
-        validator.check_value_type("shape of init data", shapes, [tuple, list], self.name)
-        validator.check_value_type("dtypes of init data", dtypes, [tuple, list], self.name)
-
-        schema = []
-        for shape, dtype in zip(shapes, dtypes):
-            num_element = functools.reduce(lambda x, y: x * y, shape, 1)
-            schema.append(num_element * type_size_in_bytes(dtype))
-        self.add_prim_attr("schema", schema)
-
-
-class ReservoirReplayBufferDestroy(PrimitiveWithInfer):
-    r"""
-    Destroy the replay buffer.
-
-    Args:
-        handle(Tensor): The replay buffer instance handle with dtype int64 and shape (1,).
-
-    Outputs:
-        Replay buffer instance handle with dtype int64 and shape (1,).
-
-    Raises:
-        TypeError: The replay buffer was not created beforehand.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, handle):
-        """Initialize ReservoirReplayBufferDestroy."""
-        validator.check_int(handle, 0, validator.GE, "handle", self.name)
-
-
-class BatchAssign(PrimitiveWithInfer):
-    """
-    Assign the parameters of the source to overwrite the target.
-
-    Args:
-        lock (bool): Lock when the operator writes; otherwise share the mutex. Default: ``True``.
-
-    Inputs:
-        - **dst_model** (tuple) - A parameters tuple of the dst model.
-        - **source_model** (tuple) - A parameters tuple of the source model.
-
-    Outputs:
-        None.
-
-    Raises:
-        TypeError: If `lock` is not a bool.
-        ValueError: If element shapes between the inputs are not the same.
-        TypeError: If the inputs are not Tensors.
-
-    Supported Platforms:
-        ``GPU`` ``CPU``
-    """
-
-    @prim_attr_register
-    def __init__(self, lock=True):
-        """Initialize BatchAssign."""
-        self.lock = validator.check_value_type("lock", lock, (bool,), self.name)
-        self.add_prim_attr("lock", self.lock)
-        self.add_prim_attr('side_effect_mem', True)
-        if context.get_context('device_target') == "Ascend":
-            self.add_prim_attr('device_target', "CPU")
-
-    def infer_shape(self, dst_shape, source_shape):
-        validator.check_equal_int(len(dst_shape), len(source_shape), "inputs elements", self.name)
-        for i, shp in enumerate(dst_shape):
-            if shp != source_shape[i]:
-                raise ValueError(f'{self.name} elements must be the same, '
-                                 f'but got {shp} and {source_shape[i]}.')
-        return []
-
-    def infer_dtype(self, dst_dtype, source_dtype):
-        for i, dst_type in enumerate(dst_dtype):
-            args = {'dst': dst_type, 'source': source_dtype[i]}
-            validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name)
-        return mstype.int64
-
-
-class TensorsQueueCreate(PrimitiveWithInfer):
-    r"""
-    TensorsQueueCreate is used to create a TensorsQueue and return a unique handle.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Args:
-        dtype (mindspore.dtype): The data type in the TensorsQueue.
-        shapes (tuple(tuple(int))): The shape of each tensor in an element.
-        size (int): The size of the TensorsQueue.
-        name (str): The name of this TensorsQueue. Default: "Q".
-
-    Inputs:
-        None.
-
-    Outputs:
-        - **output** (Tensor[mindspore.int64]) - A unique handle bound to the TensorsQueue.
-
-    Supported Platforms:
-        ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> import mindspore.ops.operations._rl_inner_ops as rl_ops
-        >>> create_op = rl_ops.TensorsQueueCreate(mindspore.float32, ((), (1, 16)), 10, "q")
-        >>> handle = create_op()
-        >>> print(handle)
-        0
-    """
-    @prim_attr_register
-    def __init__(self, dtype, shapes, size=0, name="Q"):
-        validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name)
-        validator.check_int(size, 0, validator.GE, "size", self.name)
-        elements_num = len(shapes)
-        validator.check_int(elements_num, 1, validator.GE, "elements_num", self.name)
-        self.add_prim_attr('shapes', shapes)
-        self.add_prim_attr('dtype', dtype)
-        self.add_prim_attr('elements_num', elements_num)
-        self.add_prim_attr('size', size)
-        self.add_prim_attr('side_effect_mem', True)
-        self.add_prim_attr('name', name)
-
-    def infer_shape(self):
-        return ()
-
-    def infer_dtype(self):
-        return mstype.int64
-
-
-class TensorsQueuePut(PrimitiveWithInfer):
-    r"""
-    TensorsQueuePut is used to put tensors into a created TensorsQueue.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Args:
-        dtype (mindspore.dtype): The data type in the TensorsQueue.
-        shapes (tuple(tuple(int))): The shape of each tensor in an element.
-
-    Inputs:
-        - **handle** (Tensor[int64]) - The handle pointing to the TensorsQueue.
-        - **value** (list[Tensor] or tuple(Tensors)) - The element to add into the TensorsQueue.
-
-    Outputs:
-        None.
-
-    Supported Platforms:
-        ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> import mindspore.ops.operations._rl_inner_ops as rl_ops
-        >>> create_op = rl_ops.TensorsQueueCreate(mstype.float32, ((), (1, 16)), 10)
-        >>> handle = create_op()
-        >>> out_op = rl_ops.TensorsQueuePut(mstype.float32, ((), (1, 16)))
-        >>> out_op.put(handle, (Tensor(1, mstype.float32), Tensor(2, mstype.float32)))
-    """
-    @prim_attr_register
-    def __init__(self, dtype, shapes):
-        validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name)
-        elements_num = len(shapes)
-        self.elements_num = validator.check_positive_int(elements_num, "elements_num", self.name)
-        self.shapes = shapes
-        self.add_prim_attr('dtype', dtype)
-        self.add_prim_attr('elements_num', elements_num)
-        self.add_prim_attr('side_effect_mem', True)
-
-    def infer_shape(self, handle_shape, elements_shape):
-        validator.check_equal_int(len(elements_shape), self.elements_num, "inputs elements", self.name)
-        for i, shape in enumerate(elements_shape):
-            if tuple(shape) != self.shapes[i]:
-                raise ValueError(f'{self.name} init shape and input shape must be the same, '
-                                 f'but got {self.shapes[i]} and input {shape} in position {i}.')
-        return ()
-
-    def infer_dtype(self, handle_type, elements_type):
-        validator.check_type_name("handle", handle_type, (mstype.int64), self.name)
-        return mstype.int64
-
-
-class TensorsQueueGet(PrimitiveWithInfer):
-    r"""
-    TensorsQueueGet is used to get the tensors at the front of the TensorsQueue.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Args:
-        shapes (tuple(tuple(int))): The shape of each tensor in an element.
-        dtype (mindspore.dtype): The data type in the TensorsQueue.
-        pop_after_get (bool): If true, pop the element from the TensorsQueue after get.
-
-    Inputs:
-        - **handle** (Tensor[int64]) - The handle pointing to the TensorsQueue.
-
-    Outputs:
-        - **value** (list[Tensor] or tuple(Tensors)) - The element at the front of the TensorsQueue.
-
-    Supported Platforms:
-        ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> import mindspore.ops.operations._rl_inner_ops as rl_ops
-        >>> create_op = rl_ops.TensorsQueueCreate(mstype.float32, ((), (1, 2)), 10)
-        >>> handle = create_op()
-        >>> get_op = rl_ops.TensorsQueueGet(mstype.float32, ((), (1, 2)))
-        >>> tensors_list = get_op.get(handle)
-    """
-    @prim_attr_register
-    def __init__(self, dtype, shapes, pop_after_get=False):
-        validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name)
-        elements_num = len(shapes)
-        self.elements_num = validator.check_positive_int(elements_num, "elements_num", self.name)
-        validator.check_bool(pop_after_get, "pop_after_get", self.name)
-        self.shapes = shapes
-        self.dtype = dtype
-        self.add_prim_attr('dtype', dtype)
-        self.add_prim_attr("shapes", shapes)
-        self.add_prim_attr('elements_num', elements_num)
-        self.add_prim_attr("pop_after_get", pop_after_get)
-        self.add_prim_attr('side_effect_mem', True)
-
-    def infer_shape(self, handle_shape):
-        return tuple(self.shapes)
-
-    def infer_dtype(self, handle_type):
-        validator.check_type_name("handle", handle_type, (mstype.int64), self.name)
-        out_shape = []
-        for _ in range(self.elements_num):
-            out_shape.append(self.dtype)
-        return tuple(out_shape)
-
-
-class TensorsQueueClose(PrimitiveWithInfer):
-    r"""
-    TensorsQueueClose is used to close the created TensorsQueue. The resources in the TensorsQueue will be deleted.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Inputs:
-        - **handle** (mindspore.int64) - The handle pointing to the TensorsQueue.
-
-    Outputs:
-        None.
-
-    Supported Platforms:
-        ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> import mindspore.ops.operations._rl_inner_ops as rl_ops
-        >>> create_op = rl_ops.TensorsQueueCreate(mindspore.float32, ((), (3, 3)), 10)
-        >>> handle = create_op()
-        >>> close_op = rl_ops.TensorsQueueClose()
-        >>> close_op(handle)
-    """
-    @prim_attr_register
-    def __init__(self):
-        self.add_prim_attr('side_effect_mem', True)
-
-    def infer_shape(self, handle_shape):
-        return ()
-
-    def infer_dtype(self, handle_type):
-        validator.check_type_name("handle", handle_type, (mstype.int64), self.name)
-        return mstype.int64
-
-
-class TensorsQueueSize(PrimitiveWithInfer):
-    r"""
-    TensorsQueueSize is used to get the current size of the TensorsQueue.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Inputs:
-        - **handle** (mindspore.int64) - The handle pointing to the TensorsQueue.
-
-    Outputs:
-        - **size** (mindspore.int64) - The used size of the TensorsQueue.
-
-    Supported Platforms:
-        ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> import mindspore.ops.operations._rl_inner_ops as rl_ops
-        >>> create_op = rl_ops.TensorsQueueCreate(mindspore.int32, ((), (3, 2)), 10)
-        >>> handle = create_op()
-        >>> size_op = rl_ops.TensorsQueueSize()
-        >>> print(size_op(handle))
-        0
-    """
-    @prim_attr_register
-    def __init__(self):
-        self.add_prim_attr('side_effect_mem', True)
-
-    def infer_shape(self, handle_shape):
-        return ()
-
-    def infer_dtype(self, handle_type):
-        validator.check_type_name("handle", handle_type, (mstype.int64), self.name)
-        return mstype.int64
-
-
-class TensorsQueueClear(PrimitiveWithInfer):
-    r"""
-    TensorsQueueClear is used to reset the created TensorsQueue. The TensorsQueue instance remains available.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Inputs:
-        - **handle** (mindspore.int64) - The handle pointing to the TensorsQueue.
-
-    Outputs:
-        None.
-
-    Supported Platforms:
-        ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> import mindspore.ops.operations._rl_inner_ops as rl_ops
-        >>> create_op = rl_ops.TensorsQueueCreate(mindspore.float32, ((), (2, 2)), 4)
-        >>> handle = create_op()
-        >>> clear_op = rl_ops.TensorsQueueClear()
-        >>> clear_op(handle)
-    """
-    @prim_attr_register
-    def __init__(self):
-        self.add_prim_attr('side_effect_mem', True)
-
-    def infer_shape(self, handle_shape):
-        return ()
-
-    def infer_dtype(self, handle_type):
-        validator.check_type_name("handle", handle_type, (mstype.int64), self.name)
-        return mstype.int64
-
-
-class MuxSend(PrimitiveWithInfer):
-    r"""
-    Send tensors to the specified dest_rank.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Note:
-        Send and Receive must be used in combination.
-        Send must be used between servers.
-
-    Args:
-        dest_rank (int): A required integer identifying the destination rank.
-        group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
-
-    Inputs:
-        - **input_x** (Tensor) - The shape of the tensor is :math:`(x_1, x_2, ..., x_R)`.
-
-    Examples:
-        >>> from mindspore import ops
-        >>> import mindspore.nn as nn
-        >>> from mindspore.communication import init
-        >>> from mindspore import Tensor
-        >>> import numpy as np
-        >>>
-        >>> init()
-        >>> class Net(nn.Cell):
-        ...     def __init__(self):
-        ...         super(Net, self).__init__()
-        ...         self.depend = ops.Depend()
-        ...         self.send = ops.Send(dest_rank=8, group="hccl_world_group")
-        ...
-        ...     def construct(self, x):
-        ...         out = self.depend(x, self.send(x))
-        ...         return out
-        >>>
-        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
-        >>> net = Net()
-        >>> output = net(input_)
-    """
-
-    @prim_attr_register
-    def __init__(self, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
-        self.dest_rank = dest_rank
-        self.group = group
-        self.add_prim_attr("fusion", 1)
-        self.add_prim_attr('side_effect_mem', True)
-
-    def infer_shape(self, x_shape):
-        self.add_prim_attr("shape", x_shape)
-        return []
-
-    def infer_dtype(self, x_dtype):
-        return x_dtype[0]
-
-
-class MuxReceive(PrimitiveWithInfer):
-    r"""
-    Receive tensors from the specified src_rank.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Note:
-        Send and Receive must be used in combination.
-        Receive must be used between servers.
-
-    Args:
-        shape (list[int]): A required list identifying the shape of the tensor to be received.
-        dtype (Type): A required Type identifying the type of the tensor to be received. The supported types:
-            int8, int16, int32, float16, float32.
-        group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
-
-    Inputs:
-        - **input_x** (Tensor) - The shape of the tensor is :math:`(x_1, x_2, ..., x_R)`.
-
-    Examples:
-        >>> from mindspore import ops
-        >>> import mindspore.nn as nn
-        >>> from mindspore.communication import init
-        >>> from mindspore import Tensor
-        >>> import numpy as np
-        >>>
-        >>> init()
-        >>> class Net(nn.Cell):
-        ...     def __init__(self):
-        ...         super(Net, self).__init__()
-        ...         self.recv = ops.Receive(shape=[2, 8], dtype=np.float32, group="hccl_world_group")
-        ...
-        ...     def construct(self):
-        ...         out = self.recv()
-        ...         return out
-        >>>
-        >>> net = Net()
-        >>> output = net()
-    """
-
-    @prim_attr_register
-    def __init__(self, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
-        self.shape = shape
-        self.dtype = dtype
-        self.group = group
-        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
-        args = {"dtype": dtype}
-        self.add_prim_attr('side_effect_mem', True)
-        self.add_prim_attr("fusion", 1)
-        validator.check_scalar_or_tensor_types_same(args, valid_type, self.name)
-
-    def infer_shape(self, x_shape=None):
-        return tuple(self.get_attr_dict()['shape'])
-
-    def infer_dtype(self, x_dtype=None):
-        out_type = []
-        for _ in self.shape:
-            out_type.append(self.dtype)
-        return tuple(out_type)
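
Because 2.7.1 drops mindspore/ops/operations/_rl_inner_ops.py entirely (+0 -933), code written against the 2.7.0 internals shown above will fail to import. The guard below is a hedged migration sketch, not an official path:

    # Guard the import of the removed RL inner ops; fall back or pin 2.7.0.
    try:
        from mindspore.ops.operations._rl_inner_ops import DiscountedReturn
    except ImportError:  # module removed in mindspore 2.7.1
        DiscountedReturn = None

    if DiscountedReturn is None:
        print("RL inner ops unavailable; pin mindspore==2.7.0 or vendor the logic.")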