mindspore 2.7.0-cp311-cp311-win_amd64.whl → 2.7.1-cp311-cp311-win_amd64.whl

This diff shows the contents of publicly available package versions released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in that registry.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (290)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  6. mindspore/_extends/parse/compile_config.py +24 -1
  7. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
  8. mindspore/_extends/parse/resources.py +1 -1
  9. mindspore/_extends/parse/standard_method.py +8 -1
  10. mindspore/_extends/parse/trope.py +2 -1
  11. mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/boost/base.py +29 -2
  18. mindspore/common/_decorator.py +3 -2
  19. mindspore/common/_grad_function.py +3 -1
  20. mindspore/common/_tensor_cpp_method.py +1 -1
  21. mindspore/common/_tensor_docs.py +275 -64
  22. mindspore/common/_utils.py +0 -44
  23. mindspore/common/api.py +285 -35
  24. mindspore/common/dump.py +7 -108
  25. mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
  26. mindspore/common/hook_handle.py +60 -0
  27. mindspore/common/jit_config.py +5 -1
  28. mindspore/common/jit_trace.py +27 -12
  29. mindspore/common/lazy_inline.py +5 -3
  30. mindspore/common/parameter.py +13 -107
  31. mindspore/common/recompute.py +4 -11
  32. mindspore/common/tensor.py +16 -169
  33. mindspore/communication/_comm_helper.py +11 -1
  34. mindspore/communication/comm_func.py +138 -4
  35. mindspore/communication/management.py +85 -1
  36. mindspore/config/op_info.config +0 -15
  37. mindspore/context.py +5 -85
  38. mindspore/dataset/engine/datasets.py +8 -4
  39. mindspore/dataset/engine/datasets_vision.py +1 -1
  40. mindspore/dataset/engine/validators.py +1 -15
  41. mindspore/dnnl.dll +0 -0
  42. mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
  43. mindspore/graph/custom_pass.py +55 -0
  44. mindspore/include/dataset/execute.h +2 -2
  45. mindspore/jpeg62.dll +0 -0
  46. mindspore/mindrecord/__init__.py +3 -3
  47. mindspore/mindrecord/common/exceptions.py +1 -0
  48. mindspore/mindrecord/config.py +1 -1
  49. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  50. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  51. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  52. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  53. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  54. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  55. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  56. mindspore/mindrecord/filereader.py +4 -4
  57. mindspore/mindrecord/filewriter.py +5 -5
  58. mindspore/mindrecord/mindpage.py +2 -2
  59. mindspore/mindrecord/tools/cifar10.py +1 -1
  60. mindspore/mindrecord/tools/cifar100.py +1 -1
  61. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  62. mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
  63. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  64. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  65. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  66. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  67. mindspore/mindspore_backend_common.dll +0 -0
  68. mindspore/mindspore_backend_manager.dll +0 -0
  69. mindspore/mindspore_cluster.dll +0 -0
  70. mindspore/mindspore_common.dll +0 -0
  71. mindspore/mindspore_core.dll +0 -0
  72. mindspore/mindspore_cpu.dll +0 -0
  73. mindspore/mindspore_dump.dll +0 -0
  74. mindspore/mindspore_frontend.dll +0 -0
  75. mindspore/mindspore_glog.dll +0 -0
  76. mindspore/mindspore_hardware_abstract.dll +0 -0
  77. mindspore/mindspore_memory_pool.dll +0 -0
  78. mindspore/mindspore_ms_backend.dll +0 -0
  79. mindspore/mindspore_ops.dll +0 -0
  80. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  81. mindspore/mindspore_profiler.dll +0 -0
  82. mindspore/mindspore_pyboost.dll +0 -0
  83. mindspore/mindspore_pynative.dll +0 -0
  84. mindspore/mindspore_runtime_pipeline.dll +0 -0
  85. mindspore/mindspore_runtime_utils.dll +0 -0
  86. mindspore/mindspore_tools.dll +0 -0
  87. mindspore/mint/__init__.py +15 -10
  88. mindspore/mint/distributed/distributed.py +182 -62
  89. mindspore/mint/nn/__init__.py +2 -16
  90. mindspore/mint/nn/functional.py +4 -110
  91. mindspore/mint/nn/layer/__init__.py +0 -2
  92. mindspore/mint/nn/layer/activation.py +0 -6
  93. mindspore/mint/nn/layer/basic.py +0 -47
  94. mindspore/mint/nn/layer/conv.py +4 -4
  95. mindspore/mint/nn/layer/normalization.py +8 -13
  96. mindspore/mint/nn/layer/pooling.py +0 -4
  97. mindspore/nn/__init__.py +1 -3
  98. mindspore/nn/cell.py +16 -66
  99. mindspore/nn/layer/basic.py +49 -1
  100. mindspore/nn/layer/container.py +16 -0
  101. mindspore/nn/layer/embedding.py +4 -169
  102. mindspore/nn/layer/normalization.py +2 -1
  103. mindspore/nn/layer/thor_layer.py +4 -85
  104. mindspore/nn/optim/ada_grad.py +0 -1
  105. mindspore/nn/optim/adafactor.py +0 -1
  106. mindspore/nn/optim/adam.py +31 -124
  107. mindspore/nn/optim/adamax.py +0 -1
  108. mindspore/nn/optim/asgd.py +0 -1
  109. mindspore/nn/optim/ftrl.py +8 -102
  110. mindspore/nn/optim/lamb.py +0 -1
  111. mindspore/nn/optim/lars.py +0 -3
  112. mindspore/nn/optim/lazyadam.py +25 -218
  113. mindspore/nn/optim/momentum.py +5 -43
  114. mindspore/nn/optim/optimizer.py +6 -55
  115. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  116. mindspore/nn/optim/rmsprop.py +0 -1
  117. mindspore/nn/optim/rprop.py +0 -1
  118. mindspore/nn/optim/sgd.py +0 -1
  119. mindspore/nn/optim/tft_wrapper.py +0 -1
  120. mindspore/nn/optim/thor.py +0 -2
  121. mindspore/nn/probability/bijector/bijector.py +7 -8
  122. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  123. mindspore/nn/probability/bijector/power_transform.py +20 -21
  124. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  125. mindspore/nn/probability/bijector/softplus.py +13 -14
  126. mindspore/nn/wrap/grad_reducer.py +4 -74
  127. mindspore/numpy/array_creations.py +2 -2
  128. mindspore/numpy/fft.py +9 -9
  129. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  130. mindspore/onnx/onnx_export.py +137 -0
  131. mindspore/opencv_core4110.dll +0 -0
  132. mindspore/opencv_imgcodecs4110.dll +0 -0
  133. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  134. mindspore/ops/__init__.py +2 -0
  135. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  136. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  137. mindspore/ops/_op_impl/cpu/__init__.py +0 -5
  138. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
  139. mindspore/ops/auto_generate/gen_extend_func.py +2 -7
  140. mindspore/ops/auto_generate/gen_ops_def.py +98 -141
  141. mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
  142. mindspore/ops/communication.py +97 -0
  143. mindspore/ops/composite/__init__.py +5 -2
  144. mindspore/ops/composite/base.py +15 -1
  145. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  146. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  147. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  148. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  149. mindspore/ops/function/__init__.py +1 -0
  150. mindspore/ops/function/array_func.py +14 -12
  151. mindspore/ops/function/comm_func.py +3883 -0
  152. mindspore/ops/function/debug_func.py +3 -4
  153. mindspore/ops/function/math_func.py +45 -54
  154. mindspore/ops/function/nn_func.py +75 -294
  155. mindspore/ops/function/random_func.py +9 -18
  156. mindspore/ops/functional.py +2 -0
  157. mindspore/ops/functional_overload.py +354 -18
  158. mindspore/ops/operations/__init__.py +2 -5
  159. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  160. mindspore/ops/operations/_inner_ops.py +1 -38
  161. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  162. mindspore/ops/operations/array_ops.py +1 -0
  163. mindspore/ops/operations/comm_ops.py +94 -2
  164. mindspore/ops/operations/custom_ops.py +228 -19
  165. mindspore/ops/operations/debug_ops.py +27 -29
  166. mindspore/ops/operations/manually_defined/ops_def.py +27 -306
  167. mindspore/ops/operations/nn_ops.py +2 -2
  168. mindspore/ops/operations/sparse_ops.py +0 -83
  169. mindspore/ops/primitive.py +1 -17
  170. mindspore/ops/tensor_method.py +72 -3
  171. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  172. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  173. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  174. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  175. mindspore/ops_generate/common/gen_constants.py +11 -10
  176. mindspore/ops_generate/common/op_proto.py +18 -1
  177. mindspore/ops_generate/common/template.py +102 -245
  178. mindspore/ops_generate/common/template_utils.py +212 -0
  179. mindspore/ops_generate/gen_custom_ops.py +69 -0
  180. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  181. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  182. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  183. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  184. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  185. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  186. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  187. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  188. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  189. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  190. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  191. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  192. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  193. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  194. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  195. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  196. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  197. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  198. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  199. mindspore/parallel/_cell_wrapper.py +1 -1
  200. mindspore/parallel/_parallel_serialization.py +1 -4
  201. mindspore/parallel/_utils.py +29 -6
  202. mindspore/parallel/checkpoint_transform.py +18 -2
  203. mindspore/parallel/cluster/process_entity/_api.py +24 -32
  204. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  205. mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
  206. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  207. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  208. mindspore/parallel/strategy.py +336 -0
  209. mindspore/parallel/transform_safetensors.py +117 -16
  210. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
  211. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  212. mindspore/profiler/common/constant.py +5 -0
  213. mindspore/profiler/common/file_manager.py +9 -0
  214. mindspore/profiler/common/msprof_cmd_tool.py +38 -2
  215. mindspore/profiler/common/path_manager.py +56 -24
  216. mindspore/profiler/common/profiler_context.py +2 -12
  217. mindspore/profiler/common/profiler_info.py +3 -3
  218. mindspore/profiler/common/profiler_path_manager.py +13 -0
  219. mindspore/profiler/common/util.py +30 -3
  220. mindspore/profiler/experimental_config.py +2 -1
  221. mindspore/profiler/platform/npu_profiler.py +33 -6
  222. mindspore/run_check/_check_version.py +108 -24
  223. mindspore/runtime/__init__.py +3 -2
  224. mindspore/runtime/executor.py +11 -3
  225. mindspore/runtime/memory.py +112 -0
  226. mindspore/swresample-4.dll +0 -0
  227. mindspore/swscale-6.dll +0 -0
  228. mindspore/tinyxml2.dll +0 -0
  229. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  230. mindspore/tools/data_dump.py +130 -0
  231. mindspore/tools/sdc_detect.py +91 -0
  232. mindspore/tools/stress_detect.py +63 -0
  233. mindspore/train/__init__.py +6 -6
  234. mindspore/train/_utils.py +5 -18
  235. mindspore/train/amp.py +6 -4
  236. mindspore/train/callback/_checkpoint.py +0 -9
  237. mindspore/train/callback/_train_fault_tolerance.py +69 -18
  238. mindspore/train/data_sink.py +1 -5
  239. mindspore/train/model.py +38 -211
  240. mindspore/train/serialization.py +126 -387
  241. mindspore/turbojpeg.dll +0 -0
  242. mindspore/utils/__init__.py +6 -3
  243. mindspore/utils/dlpack.py +92 -0
  244. mindspore/utils/dryrun.py +1 -1
  245. mindspore/utils/runtime_execution_order_check.py +10 -0
  246. mindspore/utils/sdc_detect.py +14 -12
  247. mindspore/utils/stress_detect.py +43 -0
  248. mindspore/utils/utils.py +144 -8
  249. mindspore/version.py +1 -1
  250. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  251. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
  252. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
  253. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  254. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  255. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  256. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  257. mindspore/experimental/llm_boost/register.py +0 -130
  258. mindspore/experimental/llm_boost/utils.py +0 -31
  259. mindspore/include/OWNERS +0 -7
  260. mindspore/mindspore_cpu_res_manager.dll +0 -0
  261. mindspore/mindspore_ops_kernel_common.dll +0 -0
  262. mindspore/mindspore_res_manager.dll +0 -0
  263. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  264. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  265. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  266. mindspore/nn/reinforcement/tensor_array.py +0 -145
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  270. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  271. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  272. mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
  273. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  274. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  275. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  276. mindspore/ops/operations/_tensor_array.py +0 -359
  277. mindspore/ops/operations/rl_ops.py +0 -288
  278. mindspore/parallel/_offload_context.py +0 -275
  279. mindspore/parallel/_recovery_context.py +0 -115
  280. mindspore/parallel/_transformer/__init__.py +0 -35
  281. mindspore/parallel/_transformer/layers.py +0 -765
  282. mindspore/parallel/_transformer/loss.py +0 -251
  283. mindspore/parallel/_transformer/moe.py +0 -693
  284. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  285. mindspore/parallel/_transformer/transformer.py +0 -3124
  286. mindspore/parallel/mpi/_mpi_config.py +0 -116
  287. mindspore/train/memory_profiling_pb2.py +0 -298
  288. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  289. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  290. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/ops/operations/_tensor_array.py
@@ -1,359 +0,0 @@
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- """Operators for TensorArray."""
-
- import mindspore as ms
- from mindspore import _checkparam as validator
- from ...common import dtype as mstype
- from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
-
-
- class TensorArray(PrimitiveWithInfer):
-     r"""
-     TensorArrayCreate used to create a TensorArray and return an unique handle.
-
-     .. warning::
-         This is an experimental API that is subject to change or deletion.
-
-     Args:
-         dtype (mindspore.dtype): the data type in the TensorArray.
-         element_shape (tuple[int]): the shape of each tensor in a TensorArray.
-         dynamic_size (bool): If true the TensorArray can increase the size. Default: ``True``.
-         size (int): The size of the TensorArray if dynamic_size = False.
-         name (str): the name of this TensorArray. Default: "TA".
-
-     Inputs:
-         None.
-
-     Outputs:
-         - **output** (Tensor[mindspore.int64]) - an unique handle binded to the TensorArray.
-
-     Supported Platforms:
-         ``GPU`` ``CPU``
-
-     Examples:
-         >>> import mindspore
-         >>> from mindspore import ops
-         >>> create_op = ops.TensorArray(mindspore.int32, ())
-         >>> handle = create_op()
-         >>> print(handle)
-         0
-     """
-     @prim_attr_register
-     def __init__(self, dtype, element_shape, dynamic_size=True, size=0, name="TA"):
-         validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name)
-         validator.check_int(size, 0, validator.GE, "size", self.name)
-         self.add_prim_attr('dtype', dtype)
-         self.add_prim_attr('element_shape', element_shape)
-         self.add_prim_attr('dynamic_size', dynamic_size)
-         self.add_prim_attr('size', size)
-         self.add_prim_attr('side_effect_mem', True)
-         self.add_prim_attr('name', name)
-
-     def infer_shape(self):
-         return ()
-
-     def infer_dtype(self):
-         return mstype.int64
-
-
- class TensorArrayWrite(PrimitiveWithInfer):
-     r"""
-     TensorArrayWrite used to write tensor into a created TensorArray.
-
-     .. warning::
-         This is an experimental API that is subject to change or deletion.
-
-     Inputs:
-         - **index** (Tensor[int64]) - The position to write.
-         - **value** (Tensor) - The value to add into the TensorArray.
-         - **handle** (Tensor[int64]) - The handle pointed to the TensorArray.
-
-     Outputs:
-         None.
-
-     Supported Platforms:
-         ``GPU`` ``CPU``
-
-     Examples:
-         >>> import mindspore
-         >>> from mindspore import ops
-         >>> create_op = ops.TensorArray(mindspore.int32, ())
-         >>> handle = create_op()
-         >>> write_op = ops.TensorArrayWrite()
-         >>> write_op.write(handle, 0, 1)
-     """
-     @prim_attr_register
-     def __init__(self):
-         self.add_prim_attr('side_effect_mem', True)
-
-     def infer_shape(self, handle_shape, index_shape, value_shape):
-         return ()
-
-     def infer_dtype(self, handle_type, index_type, value_type):
-         validator.check_type_name("handle", handle_type, (ms.int64), self.name)
-         validator.check_type_name("index", index_type, (int, ms.int64), self.name)
-         validator.check_type_name("value", value_type, mstype.number_type + (mstype.bool_,), self.name)
-         return mstype.int64
-
-
- class TensorArrayRead(PrimitiveWithInfer):
-     r"""
-     TensorArrayRead used to read tensor from a created TensorArray by the given index.
-
-     .. warning::
-         This is an experimental API that is subject to change or deletion.
-
-     Args:
-         dtype (mindspore.dtype): the data type in the TensorArray.
-         element_shape (tuple[int]): the shape of each tensor in a TensorArray.
-
-     Inputs:
-         - **index** (Tensor[int64]) - The position to read.
-         - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
-
-     Outputs:
-         - **output** (Tensor) - the value in position index.
-
-     Supported Platforms:
-         ``GPU`` ``CPU``
-
-     Examples:
-         >>> import mindspore
-         >>> from mindspore import ops
-         >>> create_op = ops.TensorArray(mindspore.int32, ())
-         >>> handle = create_op()
-         >>> write_op = ops.TensorArrayWrite()
-         >>> write_op.write(handle, 0, 1)
-         >>> read_op = ops.TensorArrayRead(mindspore.int32, ())
-         >>> ans = read_op(handle, 0)
-         >>> print(ans)
-         1
-     """
-     @prim_attr_register
-     def __init__(self, dtype, element_shape):
-         validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name)
-         self.add_prim_attr('dtype', dtype)
-         self.add_prim_attr('element_shape', element_shape)
-         self.add_prim_attr('side_effect_mem', True)
-         self.dtype = dtype
-         self.shape = element_shape
-
-     def infer_shape(self, handle_shape, index_shape):
-         return self.shape
-
-     def infer_dtype(self, handle_type, index_type):
-         validator.check_type_name("handle", handle_type, (ms.int64), self.name)
-         validator.check_type_name("index", index_type, (int, ms.int64), self.name)
-         return self.dtype
-
-
- class TensorArrayClose(PrimitiveWithInfer):
-     r"""
-     TensorArrayClose used to close the created TensorArray. The resources in TensorArray will be deleted.
-
-     .. warning::
-         This is an experimental API that is subject to change or deletion.
-
-     Inputs:
-         - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
-
-     Outputs:
-         None.
-
-     Supported Platforms:
-         ``GPU`` ``CPU``
-
-     Examples:
-         >>> import mindspore
-         >>> from mindspore import ops
-         >>> create_op = ops.TensorArray(mindspore.int32, ())
-         >>> handle = create_op()
-         >>> close_op = ops.TensorArrayClose()
-         >>> close_op(handle)
-     """
-     @prim_attr_register
-     def __init__(self):
-         self.add_prim_attr('side_effect_mem', True)
-
-     def infer_shape(self, handle_shape):
-         return ()
-
-     def infer_dtype(self, handle_type):
-         validator.check_type_name("handle", handle_type, (ms.int64), self.name)
-         return mstype.int64
-
-
- class TensorArrayClear(PrimitiveWithInfer):
-     r"""
-     TensorArrayClear used to reset the created TensorArray. The instance of TensorArray is still aviliable.
-
-     .. warning::
-         This is an experimental API that is subject to change or deletion.
-
-     Inputs:
-         - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
-
-     Outputs:
-         None.
-
-     Supported Platforms:
-         ``GPU`` ``CPU``
-
-     Examples:
-         >>> import mindspore
-         >>> from mindspore import ops
-         >>> create_op = ops.TensorArray(mindspore.int32, ())
-         >>> handle = create_op()
-         >>> clear_op = ops.TensorArrayClear()
-         >>> clear_op(handle)
-     """
-     @prim_attr_register
-     def __init__(self):
-         self.add_prim_attr('side_effect_mem', True)
-
-     def infer_shape(self, handle_shape):
-         return ()
-
-     def infer_dtype(self, handle_type):
-         validator.check_type_name("handle", handle_type, (ms.int64), self.name)
-         return mstype.int64
-
-
- class TensorArrayStack(Primitive):
-     r"""
-     TensorArrayStack used to stack the tensors in a created TensorArray into one tensor.
-
-     .. warning::
-         This is an experimental API that is subject to change or deletion.
-
-     Args:
-         dtype (mindspore.dtype): the data type in the TensorArray.
-         element_shape (tuple[int]): the shape of each tensor in a TensorArray.
-
-     Inputs:
-         - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
-
-     Outputs:
-         - **output** (Tensor) - the stacked value from the TensorArray.
-
-     Supported Platforms:
-         ``GPU`` ``CPU``
-
-     Examples:
-         >>> import mindspore
-         >>> from mindspore import ops
-         >>> create_op = ops.TensorArray(mindspore.int32, ())
-         >>> handle = create_op()
-         >>> write_op = ops.TensorArrayWrite()
-         >>> write_op.write(handle, 0, 1)
-         >>> write_op.write(handle, 1, 2)
-         >>> stack_op = ops.TensorArrayStack(mindspore.int32, ())
-         >>> ans = stack_op(handle)
-         >>> print(ans)
-         [1 2]
-     """
-     @prim_attr_register
-     def __init__(self, dtype, element_shape, dynamic_size, size):
-         """Initialize TensorArrayStack"""
-         self.init_prim_io_names(inputs=[''], outputs=['output'])
-         self.add_prim_attr('dtype', dtype)
-         self.add_prim_attr('element_shape', element_shape)
-         self.add_prim_attr('is_dynamic_shape', dynamic_size)
-         self.add_prim_attr('size', size)
-         self.add_prim_attr('side_effect_mem', True)
-
-
- class TensorArraySize(PrimitiveWithInfer):
-     r"""
-     TensorArraySize used to get the logical size of the created TensorArray.
-
-     .. warning::
-         This is an experimental API that is subject to change or deletion.
-
-     Inputs:
-         - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
-
-     Outputs:
-         - **output** (Tensor[mindspore.int64]) - the logical size of the TensorArray.
-
-     Supported Platforms:
-         ``GPU`` ``CPU``
-
-     Examples:
-         >>> import mindspore
-         >>> from mindspore import ops
-         >>> create_op = ops.TensorArray(mindspore.int32, ())
-         >>> handle = create_op()
-         >>> size_op = ops.TensorArraySize()
-         >>> size = size_op(handle)
-     """
-     @prim_attr_register
-     def __init__(self):
-         self.add_prim_attr('side_effect_mem', True)
-
-     def infer_shape(self, handle_shape):
-         return ()
-
-     def infer_dtype(self, handle_type):
-         validator.check_type_name("handle", handle_type, (ms.int64), self.name)
-         return mstype.int64
-
-
- class TensorArrayGather(PrimitiveWithInfer):
-     r"""
-     TensorArrayGather used to gather specified elements from the created TensorArray.
-
-     .. warning::
-         This is an experimental API that is subject to change or deletion.
-
-     Args:
-         dtype (mindspore.dtype): the data type in the TensorArray.
-         element_shape (tuple[int]): the shape of each tensor in a TensorArray.
-
-     Inputs:
-         - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
-         - **indices** (mindspore.int32) - The locations of the gathered elements.
-
-     Outputs:
-         - **output** (Tensor) - The gathered value from the TensorArray.
-
-     Examples:
-         >>> import mindspore
-         >>> from mindspore import ops
-         >>> from mindspore import numpy as mnp
-         >>> create_op = ops.TensorArray(mindspore.float32, dynamic_size=False, element_shape=(8,))
-         >>> handle = create_op()
-         >>> indices = mnp.range(0, 25, 1, mindspore.int32)
-         >>> gather_op = ops.TensorArrayGather(dtype=mindspore.float32, element_shape=(8,))
-         >>> gather_result = gather_op(handle, indices)
-     """
-     @prim_attr_register
-     def __init__(self, dtype, element_shape):
-         self.init_prim_io_names(inputs=['handle', 'indices'], outputs=['value'])
-         self.add_prim_attr("side_effect_mem", True)
-         self.dtype = dtype
-         self.element_shape = element_shape
-
-     def infer_shape(self, handle, indices):
-         if len(indices) != 1:
-             return ValueError("indices dimension should be equal to 1")
-         return [indices[0]] + list(self.element_shape)
-
-     def infer_dtype(self, handle, indices):
-         validator.check_type_name("handle", handle, (ms.int64), self.name)
-         validator.check_type_name("indices", indices, (ms.int32), self.name)
-         return self.dtype
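
The file above (mindspore/ops/operations/_tensor_array.py, removed in 2.7.1) held the experimental TensorArray primitives. For readers assessing the removal, the sketch below simply stitches together the docstring examples from the deleted code; it assumes a MindSpore release that still ships these ops (2.7.0 or earlier) on a GPU or CPU target, and the dynamic_size/size arguments passed to TensorArrayStack are assumed values that the original example omitted.

import mindspore
from mindspore import ops

# Create a dynamically sized TensorArray of int32 scalars; the op returns an int64 handle.
create_op = ops.TensorArray(mindspore.int32, ())
handle = create_op()

# Write two elements, stack them into one tensor, then release the TensorArray.
write_op = ops.TensorArrayWrite()
write_op.write(handle, 0, 1)
write_op.write(handle, 1, 2)
stack_op = ops.TensorArrayStack(mindspore.int32, (), True, 0)   # dynamic_size/size assumed
print(stack_op(handle))   # the removed docstring shows: [1 2]
ops.TensorArrayClose()(handle)
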
mindspore/ops/operations/rl_ops.py
@@ -1,288 +0,0 @@
- # Copyright 2021-2023 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- """Operators for reinforce learning."""
-
- from functools import reduce
- import mindspore.context as context
- from mindspore import _checkparam as validator
- from ...common import dtype as mstype
- from ..primitive import prim_attr_register, PrimitiveWithInfer
-
-
- class BufferSample(PrimitiveWithInfer):
-     r"""
-     In reinforcement learning, the data is sampled from the replaybuffer randomly.
-
-     Returns the tuple tensor with the given shape, decided by the given batchsize.
-
-     .. warning::
-         This is an experimental API that is subject to change or deletion.
-
-     Args:
-         capacity (int64): Capacity of the buffer, must be non-negative.
-         batch_size (int64): The size of the sampled data, lessequal to `capacity`.
-         buffer_shape (tuple(shape)): The shape of an buffer.
-         buffer_dtype (tuple(type)): The type of an buffer.
-         seed (int64): Random seed for sample. Default: ``0`` . If use the default seed, it will generate a ramdom
-             one in kernel. Set a number other than `0` to keep a specific seed. Default: ``0`` .
-         unique (bool): Whether the sampled data is strictly unique. Setting it to False has a better performance.
-             Default: ``False`` .
-
-     Inputs:
-         - **data** (tuple(Parameter(Tensor))) - The tuple(Tensor) represents replaybuffer,
-           each tensor is described by the `buffer_shape` and `buffer_type`.
-         - **count** (Parameter) - The count means the real available size of the buffer,
-           data type: int32.
-         - **head** (Parameter) - The position of the first data in buffer, data type: int32.
-
-     Outputs:
-         tuple(Tensor). The shape is `batch_size` * `buffer_shape`. The dtype is `buffer_dtype`.
-
-     Raises:
-         TypeError: If `buffer_shape` is not a tuple.
-         ValueError: If batch_size is larger than capacity.
-         ValueError: If `capacity` is not a positive integer.
-
-     Supported Platforms:
-         ``GPU`` ``CPU``
-
-     Examples:
-         >>> capacity = 100
-         >>> batch_size = 5
-         >>> count = Parameter(Tensor(5, ms.int32), name="count")
-         >>> head = Parameter(Tensor(0, ms.int32), name="head")
-         >>> shapes = [(4,), (2,), (1,), (4,)]
-         >>> types = [ms.float32, ms.int32, ms.int32, ms.float32]
-         >>> buffer = [Parameter(Tensor(np.arange(100 * 4).reshape(100, 4).astype(np.float32)), name="states"),
-         ... Parameter(Tensor(np.arange(100 * 2).reshape(100, 2).astype(np.int32)), name="action"),
-         ... Parameter(Tensor(np.ones((100, 1)).astype(np.int32)), name="reward"),
-         ... Parameter(Tensor(np.arange(100 * 4).reshape(100, 4).astype(np.float32)), name="state_")]
-         >>> buffer_sample = ops.BufferSample(capacity, batch_size, shapes, types)
-         >>> output = buffer_sample(buffer, count, head)
-         >>> print(output)
-         (Tensor(shape=[5, 4], dtype=Float32, value=
-         [[ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00, 3.00000000e+00],
-         [ 8.00000000e+00, 9.00000000e+00, 1.00000000e+01, 1.10000000e+01],
-         [ 1.60000000e+01, 1.70000000e+01, 1.80000000e+01, 1.90000000e+01],
-         [ 1.20000000e+01, 1.30000000e+01, 1.40000000e+01, 1.50000000e+01],
-         [ 3.20000000e+01, 3.30000000e+01, 3.40000000e+01, 3.50000000e+01]]),
-         Tensor(shape=[5, 2], dtype=Int32, value=
-         [[ 0, 1],
-         [ 4, 5],
-         [ 8, 9],
-         [ 6, 7],
-         [16, 17]]),
-         Tensor(shape=[5, 1], dtype=Int32, value=
-         [[1],
-         [1],
-         [1],
-         [1],
-         [1]]),
-         Tensor(shape=[5, 4], dtype=Float32, value=
-         [[ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00, 3.00000000e+00],
-         [ 8.00000000e+00, 9.00000000e+00, 1.00000000e+01, 1.10000000e+01],
-         [ 1.60000000e+01, 1.70000000e+01, 1.80000000e+01, 1.90000000e+01],
-         [ 1.20000000e+01, 1.30000000e+01, 1.40000000e+01, 1.50000000e+01],
-         [ 3.20000000e+01, 3.30000000e+01, 3.40000000e+01, 3.50000000e+01]]))
-     """
-
-     @prim_attr_register
-     def __init__(self, capacity, batch_size, buffer_shape, buffer_dtype, seed=0, unique=False):
-         """Initialize BufferSample."""
-         self.init_prim_io_names(inputs=["buffer"], outputs=["sample"])
-         validator.check_value_type("shape of init data", buffer_shape, [tuple, list], self.name)
-         validator.check_int(capacity, 1, validator.GE, "capacity", self.name)
-         self._batch_size = batch_size
-         self._buffer_shape = buffer_shape
-         self._buffer_dtype = buffer_dtype
-         self._n = len(buffer_shape)
-         validator.check_int(self._batch_size, capacity, validator.LE, "batchsize", self.name)
-         self.add_prim_attr('capacity', capacity)
-         self.add_prim_attr('seed', seed)
-         self.add_prim_attr('unique', unique)
-         buffer_elements = []
-         for shape in buffer_shape:
-             buffer_elements.append(reduce(lambda x, y: x * y, shape))
-         self.add_prim_attr('buffer_elements', buffer_elements)
-         self.add_prim_attr('buffer_dtype', buffer_dtype)
-         self.add_prim_attr('side_effect_mem', True)
-         if context.get_context('device_target') == "Ascend":
-             self.add_prim_attr('device_target', "CPU")
-
-     def infer_shape(self, data_shape, count_shape, head_shape):
-         validator.check_value_type("shape of data", data_shape, [tuple, list], self.name)
-         out_shapes = []
-         for i in range(self._n):
-             out_shapes.append((self._batch_size,) + self._buffer_shape[i])
-         return tuple(out_shapes)
-
-     def infer_dtype(self, data_type, count_type, head_type):
-         validator.check_type_name("count type", count_type, (mstype.int32), self.name)
-         validator.check_type_name("head type", head_type, (mstype.int32), self.name)
-         return tuple(self._buffer_dtype)
-
-
- class BufferAppend(PrimitiveWithInfer):
-     r"""
-     In reinforcement learning, the experience data is collected in each step. We use `BufferAppend` to
-     push data to the bottom of buffer under the First-In-First-Out rule.
-
-     .. warning::
-         This is an experimental API that is subject to change or deletion.
-
-     Args:
-         capacity (int64): Capacity of the buffer, must be non-negative.
-         buffer_shape (tuple(shape)): The shape of an buffer.
-         buffer_dtype (tuple(type)): The type of an buffer.
-
-     Inputs:
-         - **data** (tuple(Parameter(Tensor))) - The tuple(Tensor) represents replaybuffer,
-           each tensor is described by the `buffer_shape` and `buffer_type`.
-         - **exp** (tuple(Parameter(Tensor))) - The tuple(Tensor) represents one list of experience data,
-           each tensor is described by the `buffer_shape` and `buffer_type`.
-         - **count** (Parameter) - The count means the real available size of the buffer,
-           data type: int32.
-         - **head** (Parameter) - The position of the first data in buffer, data type: int32.
-
-     Outputs:
-         None.
-
-     Raises:
-         ValueError: If `count` and `head` is not an integer.
-         ValueError: If `capacity` is not a positive integer.
-         ValueError: If length of `data` is not equal to length of `exp`.
-         ValueError: If dim of data is equal to dim of exp, but `data[1:]` is not equal to the shape in `exp`.
-         ValueError: If the shape of `data[1:]` is not equal to the shape in `exp`.
-         TypeError: If the type in `exp` is not the same with `data`.
-
-     Supported Platforms:
-         ``GPU`` ``CPU``
-
-     Examples:
-         >>> capacity = 100
-         >>> count = Parameter(Tensor(5, ms.int32), name="count")
-         >>> head = Parameter(Tensor(0, ms.int32), name="head")
-         >>> shapes = [(4,), (2,), (1,), (4,)]
-         >>> types = [ms.float32, ms.int32, ms.int32, ms.float32]
-         >>> buffer = [Parameter(Tensor(np.arange(100 * 4).reshape(100, 4).astype(np.float32)), name="states"),
-         ... Parameter(Tensor(np.arange(100 * 2).reshape(100, 2).astype(np.int32)), name="action"),
-         ... Parameter(Tensor(np.ones((100, 1)).astype(np.int32)), name="reward"),
-         ... Parameter(Tensor(np.arange(100 * 4).reshape(100, 4).astype(np.float32)), name="state_")]
-         >>> exp = [Tensor(np.array([2, 2, 2, 2]), ms.float32), Tensor(np.array([0, 0]), ms.int32),
-         ... Tensor(np.array([0]), ms.int32), Tensor(np.array([3, 3, 3, 3]), ms.float32)]
-         >>> batch_exp = [Tensor(np.array([[2, 2, 2, 2], [2, 2, 2, 2]]), ms.float32),
-         ... Tensor(np.array([[0, 0], [0, 0]]), ms.int32),
-         ... Tensor(np.array([[0], [0]]), ms.int32),
-         ... Tensor(np.array([[3, 3, 3, 3], [3, 3, 3, 3]]), ms.float32)]
-         >>> buffer_append = ops.BufferAppend(capacity, shapes, types)
-         >>> buffer_append(buffer, exp, count, head)
-         >>> buffer_append(buffer, batch_exp, count, head)
-     """
-     @prim_attr_register
-     def __init__(self, capacity, buffer_shape, buffer_dtype):
-         """Initialize BufferAppend."""
-         validator.check_int(capacity, 1, validator.GE, "capacity", self.name)
-         self.add_prim_attr('capacity', capacity)
-         buffer_elements = []
-         for shape in buffer_shape:
-             buffer_elements.append(reduce(lambda x, y: x * y, shape))
-         self.add_prim_attr('buffer_elements', buffer_elements)
-         self.add_prim_attr('buffer_dtype', buffer_dtype)
-         self.add_prim_attr('side_effect_mem', True)
-         if context.get_context('device_target') == "Ascend":
-             self.add_prim_attr('device_target', "CPU")
-
-
- class BufferGetItem(PrimitiveWithInfer):
-     r"""
-     Get the data from buffer in the position of input index.
-
-     .. warning::
-         This is an experimental API that is subject to change or deletion.
-
-     Args:
-         capacity (int64): Capacity of the buffer, must be non-negative.
-         buffer_shape (tuple(shape)): The shape of an buffer.
-         buffer_dtype (tuple(type)): The type of an buffer.
-
-     Inputs:
-         - **data** (tuple(Parameter(Tensor))) - The tuple(Tensor) represents replaybuffer,
-           each tensor is described by the `buffer_shape` and `buffer_type`.
-         - **count** (Parameter) - The count means the real available size of the buffer,
-           data type: int32.
-         - **head** (Parameter) - The position of the first data in buffer, data type: int32.
-         - **index** (int64) - The position of the data in buffer.
-
-     Outputs:
-         tuple(Tensor). The shape is `buffer_shape`. The dtype is `buffer_dtype`.
-
-     Raises:
-         ValueError: If `count` and `head` is not an integer.
-         ValueError: If `capacity` is not a positive integer.
-         TypeError: If `buffer_shape` is not a tuple.
-
-     Supported Platforms:
-         ``GPU`` ``CPU``
-
-     Examples:
-         >>> capacity = 100
-         >>> index = 3
-         >>> count = Parameter(Tensor(5, ms.int32), name="count")
-         >>> head = Parameter(Tensor(0, ms.int32), name="head")
-         >>> shapes = [(4,), (2,), (1,), (4,)]
-         >>> types = [ms.float32, ms.int32, ms.int32, ms.float32]
-         >>> buffer = [Parameter(Tensor(np.arange(100 * 4).reshape(100, 4).astype(np.float32)), name="states"),
-         ... Parameter(Tensor(np.arange(100 * 2).reshape(100, 2).astype(np.int32)), name="action"),
-         ... Parameter(Tensor(np.ones((100, 1)).astype(np.int32)), name="reward"),
-         ... Parameter(Tensor(np.arange(100 * 4).reshape(100, 4).astype(np.float32)), name="state_")]
-         >>> buffer_get = ops.BufferGetItem(capacity, shapes, types)
-         >>> output = buffer_get(buffer, count, head, index)
-         >>> print(output)
-         (Tensor(shape=[4], dtype=Float32, value=
-         [ 1.20000000e+01, 1.30000000e+01, 1.40000000e+01, 1.50000000e+01]),
-         Tensor(shape=[2], dtype=Int32, value= [6, 7]),
-         Tensor(shape=[1], dtype=Int32, value= [1]),
-         Tensor(shape=[4], dtype=Float32, value=
-         [ 1.20000000e+01, 1.30000000e+01, 1.40000000e+01, 1.50000000e+01]))
-
-     """
-     @prim_attr_register
-     def __init__(self, capacity, buffer_shape, buffer_dtype):
-         """Initialize BufferGetItem."""
-         self.init_prim_io_names(inputs=["buffer"], outputs=["item"])
-         validator.check_int(capacity, 1, validator.GE, "capacity", self.name)
-         self._buffer_shape = buffer_shape
-         self._buffer_dtype = buffer_dtype
-         self._n = len(buffer_shape)
-         buffer_elements = []
-         for shape in buffer_shape:
-             buffer_elements.append(reduce(lambda x, y: x * y, shape))
-         self.add_prim_attr('buffer_elements', buffer_elements)
-         self.add_prim_attr('buffer_dtype', buffer_dtype)
-         self.add_prim_attr('capacity', capacity)
-         self.add_prim_attr('side_effect_mem', True)
-         if context.get_context('device_target') == "Ascend":
-             self.add_prim_attr('device_target', "CPU")
-
-     def infer_shape(self, data_shape, count_shape, head_shape, index_shape):
-         validator.check_value_type("shape of data", data_shape, [tuple, list], self.name)
-         return tuple(self._buffer_shape)
-
-     def infer_dtype(self, data_type, count_type, head_type, index_type):
-         validator.check_type_name("count type", count_type, (mstype.int32), self.name)
-         validator.check_type_name("head type", head_type, (mstype.int32), self.name)
-         validator.check_type_name("index type", index_type, (mstype.int64, mstype.int32), self.name)
-         return tuple(self._buffer_dtype)
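
Similarly, the replay-buffer primitives deleted above (BufferAppend, BufferGetItem and BufferSample from mindspore/ops/operations/rl_ops.py) were driven as in the sketch below, which is assembled from the removed docstring examples with the imports they left implicit; it likewise assumes MindSpore 2.7.0 or earlier on a GPU or CPU target.

import numpy as np
import mindspore as ms
from mindspore import Parameter, Tensor, ops

# Replay buffer backing store: one Parameter per field, plus count/head bookkeeping.
capacity = 100
count = Parameter(Tensor(5, ms.int32), name="count")
head = Parameter(Tensor(0, ms.int32), name="head")
shapes = [(4,), (2,), (1,), (4,)]
types = [ms.float32, ms.int32, ms.int32, ms.float32]
buffer = [Parameter(Tensor(np.arange(100 * 4).reshape(100, 4).astype(np.float32)), name="states"),
          Parameter(Tensor(np.arange(100 * 2).reshape(100, 2).astype(np.int32)), name="action"),
          Parameter(Tensor(np.ones((100, 1)).astype(np.int32)), name="reward"),
          Parameter(Tensor(np.arange(100 * 4).reshape(100, 4).astype(np.float32)), name="state_")]

# Append one transition, read back the entry at index 3, then sample a batch of 5.
exp = [Tensor(np.array([2, 2, 2, 2]), ms.float32), Tensor(np.array([0, 0]), ms.int32),
       Tensor(np.array([0]), ms.int32), Tensor(np.array([3, 3, 3, 3]), ms.float32)]
ops.BufferAppend(capacity, shapes, types)(buffer, exp, count, head)
item = ops.BufferGetItem(capacity, shapes, types)(buffer, count, head, 3)
batch = ops.BufferSample(capacity, 5, shapes, types)(buffer, count, head)
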