mindspore 2.7.0__cp310-cp310-win_amd64.whl → 2.7.1__cp310-cp310-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (290)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_extends/parse/compile_config.py +24 -1
  7. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
  8. mindspore/_extends/parse/resources.py +1 -1
  9. mindspore/_extends/parse/standard_method.py +8 -1
  10. mindspore/_extends/parse/trope.py +2 -1
  11. mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/boost/base.py +29 -2
  18. mindspore/common/_decorator.py +3 -2
  19. mindspore/common/_grad_function.py +3 -1
  20. mindspore/common/_tensor_cpp_method.py +1 -1
  21. mindspore/common/_tensor_docs.py +275 -64
  22. mindspore/common/_utils.py +0 -44
  23. mindspore/common/api.py +285 -35
  24. mindspore/common/dump.py +7 -108
  25. mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
  26. mindspore/common/hook_handle.py +60 -0
  27. mindspore/common/jit_config.py +5 -1
  28. mindspore/common/jit_trace.py +27 -12
  29. mindspore/common/lazy_inline.py +5 -3
  30. mindspore/common/parameter.py +13 -107
  31. mindspore/common/recompute.py +4 -11
  32. mindspore/common/tensor.py +16 -169
  33. mindspore/communication/_comm_helper.py +11 -1
  34. mindspore/communication/comm_func.py +138 -4
  35. mindspore/communication/management.py +85 -1
  36. mindspore/config/op_info.config +0 -15
  37. mindspore/context.py +5 -85
  38. mindspore/dataset/engine/datasets.py +8 -4
  39. mindspore/dataset/engine/datasets_vision.py +1 -1
  40. mindspore/dataset/engine/validators.py +1 -15
  41. mindspore/dnnl.dll +0 -0
  42. mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
  43. mindspore/graph/custom_pass.py +55 -0
  44. mindspore/include/dataset/execute.h +2 -2
  45. mindspore/jpeg62.dll +0 -0
  46. mindspore/mindrecord/__init__.py +3 -3
  47. mindspore/mindrecord/common/exceptions.py +1 -0
  48. mindspore/mindrecord/config.py +1 -1
  49. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  50. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  51. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  52. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  53. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  54. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  55. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  56. mindspore/mindrecord/filereader.py +4 -4
  57. mindspore/mindrecord/filewriter.py +5 -5
  58. mindspore/mindrecord/mindpage.py +2 -2
  59. mindspore/mindrecord/tools/cifar10.py +1 -1
  60. mindspore/mindrecord/tools/cifar100.py +1 -1
  61. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  62. mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
  63. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  64. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  65. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  66. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  67. mindspore/mindspore_backend_common.dll +0 -0
  68. mindspore/mindspore_backend_manager.dll +0 -0
  69. mindspore/mindspore_cluster.dll +0 -0
  70. mindspore/mindspore_common.dll +0 -0
  71. mindspore/mindspore_core.dll +0 -0
  72. mindspore/mindspore_cpu.dll +0 -0
  73. mindspore/mindspore_dump.dll +0 -0
  74. mindspore/mindspore_frontend.dll +0 -0
  75. mindspore/mindspore_glog.dll +0 -0
  76. mindspore/mindspore_hardware_abstract.dll +0 -0
  77. mindspore/mindspore_memory_pool.dll +0 -0
  78. mindspore/mindspore_ms_backend.dll +0 -0
  79. mindspore/mindspore_ops.dll +0 -0
  80. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  81. mindspore/mindspore_profiler.dll +0 -0
  82. mindspore/mindspore_pyboost.dll +0 -0
  83. mindspore/mindspore_pynative.dll +0 -0
  84. mindspore/mindspore_runtime_pipeline.dll +0 -0
  85. mindspore/mindspore_runtime_utils.dll +0 -0
  86. mindspore/mindspore_tools.dll +0 -0
  87. mindspore/mint/__init__.py +15 -10
  88. mindspore/mint/distributed/distributed.py +182 -62
  89. mindspore/mint/nn/__init__.py +2 -16
  90. mindspore/mint/nn/functional.py +4 -110
  91. mindspore/mint/nn/layer/__init__.py +0 -2
  92. mindspore/mint/nn/layer/activation.py +0 -6
  93. mindspore/mint/nn/layer/basic.py +0 -47
  94. mindspore/mint/nn/layer/conv.py +4 -4
  95. mindspore/mint/nn/layer/normalization.py +8 -13
  96. mindspore/mint/nn/layer/pooling.py +0 -4
  97. mindspore/nn/__init__.py +1 -3
  98. mindspore/nn/cell.py +16 -66
  99. mindspore/nn/layer/basic.py +49 -1
  100. mindspore/nn/layer/container.py +16 -0
  101. mindspore/nn/layer/embedding.py +4 -169
  102. mindspore/nn/layer/normalization.py +2 -1
  103. mindspore/nn/layer/thor_layer.py +4 -85
  104. mindspore/nn/optim/ada_grad.py +0 -1
  105. mindspore/nn/optim/adafactor.py +0 -1
  106. mindspore/nn/optim/adam.py +31 -124
  107. mindspore/nn/optim/adamax.py +0 -1
  108. mindspore/nn/optim/asgd.py +0 -1
  109. mindspore/nn/optim/ftrl.py +8 -102
  110. mindspore/nn/optim/lamb.py +0 -1
  111. mindspore/nn/optim/lars.py +0 -3
  112. mindspore/nn/optim/lazyadam.py +25 -218
  113. mindspore/nn/optim/momentum.py +5 -43
  114. mindspore/nn/optim/optimizer.py +6 -55
  115. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  116. mindspore/nn/optim/rmsprop.py +0 -1
  117. mindspore/nn/optim/rprop.py +0 -1
  118. mindspore/nn/optim/sgd.py +0 -1
  119. mindspore/nn/optim/tft_wrapper.py +0 -1
  120. mindspore/nn/optim/thor.py +0 -2
  121. mindspore/nn/probability/bijector/bijector.py +7 -8
  122. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  123. mindspore/nn/probability/bijector/power_transform.py +20 -21
  124. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  125. mindspore/nn/probability/bijector/softplus.py +13 -14
  126. mindspore/nn/wrap/grad_reducer.py +4 -74
  127. mindspore/numpy/array_creations.py +2 -2
  128. mindspore/numpy/fft.py +9 -9
  129. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  130. mindspore/onnx/onnx_export.py +137 -0
  131. mindspore/opencv_core4110.dll +0 -0
  132. mindspore/opencv_imgcodecs4110.dll +0 -0
  133. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  134. mindspore/ops/__init__.py +2 -0
  135. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  136. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  137. mindspore/ops/_op_impl/cpu/__init__.py +0 -5
  138. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
  139. mindspore/ops/auto_generate/gen_extend_func.py +2 -7
  140. mindspore/ops/auto_generate/gen_ops_def.py +98 -141
  141. mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
  142. mindspore/ops/communication.py +97 -0
  143. mindspore/ops/composite/__init__.py +5 -2
  144. mindspore/ops/composite/base.py +15 -1
  145. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  146. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  147. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  148. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  149. mindspore/ops/function/__init__.py +1 -0
  150. mindspore/ops/function/array_func.py +14 -12
  151. mindspore/ops/function/comm_func.py +3883 -0
  152. mindspore/ops/function/debug_func.py +3 -4
  153. mindspore/ops/function/math_func.py +45 -54
  154. mindspore/ops/function/nn_func.py +75 -294
  155. mindspore/ops/function/random_func.py +9 -18
  156. mindspore/ops/functional.py +2 -0
  157. mindspore/ops/functional_overload.py +354 -18
  158. mindspore/ops/operations/__init__.py +2 -5
  159. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  160. mindspore/ops/operations/_inner_ops.py +1 -38
  161. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  162. mindspore/ops/operations/array_ops.py +1 -0
  163. mindspore/ops/operations/comm_ops.py +94 -2
  164. mindspore/ops/operations/custom_ops.py +228 -19
  165. mindspore/ops/operations/debug_ops.py +27 -29
  166. mindspore/ops/operations/manually_defined/ops_def.py +27 -306
  167. mindspore/ops/operations/nn_ops.py +2 -2
  168. mindspore/ops/operations/sparse_ops.py +0 -83
  169. mindspore/ops/primitive.py +1 -17
  170. mindspore/ops/tensor_method.py +72 -3
  171. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  172. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  173. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  174. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  175. mindspore/ops_generate/common/gen_constants.py +11 -10
  176. mindspore/ops_generate/common/op_proto.py +18 -1
  177. mindspore/ops_generate/common/template.py +102 -245
  178. mindspore/ops_generate/common/template_utils.py +212 -0
  179. mindspore/ops_generate/gen_custom_ops.py +69 -0
  180. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  181. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  182. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  183. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  184. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  185. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  186. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  187. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  188. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  189. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  190. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  191. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  192. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  193. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  194. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  195. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  196. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  197. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  198. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  199. mindspore/parallel/_cell_wrapper.py +1 -1
  200. mindspore/parallel/_parallel_serialization.py +1 -4
  201. mindspore/parallel/_utils.py +29 -6
  202. mindspore/parallel/checkpoint_transform.py +18 -2
  203. mindspore/parallel/cluster/process_entity/_api.py +24 -32
  204. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  205. mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
  206. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  207. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  208. mindspore/parallel/strategy.py +336 -0
  209. mindspore/parallel/transform_safetensors.py +117 -16
  210. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
  211. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  212. mindspore/profiler/common/constant.py +5 -0
  213. mindspore/profiler/common/file_manager.py +9 -0
  214. mindspore/profiler/common/msprof_cmd_tool.py +38 -2
  215. mindspore/profiler/common/path_manager.py +56 -24
  216. mindspore/profiler/common/profiler_context.py +2 -12
  217. mindspore/profiler/common/profiler_info.py +3 -3
  218. mindspore/profiler/common/profiler_path_manager.py +13 -0
  219. mindspore/profiler/common/util.py +30 -3
  220. mindspore/profiler/experimental_config.py +2 -1
  221. mindspore/profiler/platform/npu_profiler.py +33 -6
  222. mindspore/run_check/_check_version.py +108 -24
  223. mindspore/runtime/__init__.py +3 -2
  224. mindspore/runtime/executor.py +11 -3
  225. mindspore/runtime/memory.py +112 -0
  226. mindspore/swresample-4.dll +0 -0
  227. mindspore/swscale-6.dll +0 -0
  228. mindspore/tinyxml2.dll +0 -0
  229. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  230. mindspore/tools/data_dump.py +130 -0
  231. mindspore/tools/sdc_detect.py +91 -0
  232. mindspore/tools/stress_detect.py +63 -0
  233. mindspore/train/__init__.py +6 -6
  234. mindspore/train/_utils.py +5 -18
  235. mindspore/train/amp.py +6 -4
  236. mindspore/train/callback/_checkpoint.py +0 -9
  237. mindspore/train/callback/_train_fault_tolerance.py +69 -18
  238. mindspore/train/data_sink.py +1 -5
  239. mindspore/train/model.py +38 -211
  240. mindspore/train/serialization.py +126 -387
  241. mindspore/turbojpeg.dll +0 -0
  242. mindspore/utils/__init__.py +6 -3
  243. mindspore/utils/dlpack.py +92 -0
  244. mindspore/utils/dryrun.py +1 -1
  245. mindspore/utils/runtime_execution_order_check.py +10 -0
  246. mindspore/utils/sdc_detect.py +14 -12
  247. mindspore/utils/stress_detect.py +43 -0
  248. mindspore/utils/utils.py +144 -8
  249. mindspore/version.py +1 -1
  250. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  251. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
  252. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
  253. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  254. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  255. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  256. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  257. mindspore/experimental/llm_boost/register.py +0 -130
  258. mindspore/experimental/llm_boost/utils.py +0 -31
  259. mindspore/include/OWNERS +0 -7
  260. mindspore/mindspore_cpu_res_manager.dll +0 -0
  261. mindspore/mindspore_ops_kernel_common.dll +0 -0
  262. mindspore/mindspore_res_manager.dll +0 -0
  263. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  264. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  265. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  266. mindspore/nn/reinforcement/tensor_array.py +0 -145
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  270. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  271. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  272. mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
  273. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  274. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  275. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  276. mindspore/ops/operations/_tensor_array.py +0 -359
  277. mindspore/ops/operations/rl_ops.py +0 -288
  278. mindspore/parallel/_offload_context.py +0 -275
  279. mindspore/parallel/_recovery_context.py +0 -115
  280. mindspore/parallel/_transformer/__init__.py +0 -35
  281. mindspore/parallel/_transformer/layers.py +0 -765
  282. mindspore/parallel/_transformer/loss.py +0 -251
  283. mindspore/parallel/_transformer/moe.py +0 -693
  284. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  285. mindspore/parallel/_transformer/transformer.py +0 -3124
  286. mindspore/parallel/mpi/_mpi_config.py +0 -116
  287. mindspore/train/memory_profiling_pb2.py +0 -298
  288. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  289. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  290. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -1,152 +0,0 @@
1
- # Copyright 2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """
16
- TensorsQueue, each element in the queue is a list of tensors.
17
- """
18
- from __future__ import absolute_import
19
-
20
- from mindspore.nn.cell import Cell
21
- from mindspore.ops.operations import _rl_inner_ops as rl_ops
22
- from mindspore import _checkparam as Validator
23
- from mindspore.common import dtype as mstype
24
-
25
-
26
- class TensorsQueue(Cell):
27
- r'''
28
- TensorsQueue: a queue which stores tensors lists.
29
-
30
- .. warning::
31
- This is an experiential prototype that is subject to change and/or deletion.
32
-
33
- Args:
34
- dtype (mindspore.dtype): the data type in the TensorsQueue. Each tensor should have the same dtype.
35
- shapes (tuple[int64]): the shape of each element in TensorsQueue.
36
- size (int): the size of the TensorsQueue.
37
- name (str): the name of this TensorsQueue. Default: "TQ".
38
-
39
- Raises:
40
- TypeError: If `dtype` is not mindspore number type.
41
- ValueError: If `size` is less than 0.
42
- ValueError: If `shapes` size is less than 1.
43
-
44
- Supported Platforms:
45
- ``GPU`` ``CPU``
46
-
47
- Examples:
48
- >>> import mindspore as ms
49
- >>> from mindspore import Tensor
50
- >>> import mindspore.nn as nn
51
- >>> data1 = Tensor([[0, 1], [1, 2]], dtype=ms.float32)
52
- >>> data2 = Tensor([1], dtype=ms.float32)
53
- >>> tq = nn.TensorsQueue(dtype=ms.float32, shapes=((2, 2), (1,)), size=5)
54
- >>> tq.put((data1, data2))
55
- >>> ans = tq.pop()
56
- '''
57
-
58
- def __init__(self, dtype, shapes, size=0, name="TQ"):
59
- """Initialize TensorsQueue"""
60
- super(TensorsQueue, self).__init__()
61
- Validator.check_subclass("dtype", dtype, mstype.number_type + (mstype.bool_,), self.cls_name)
62
- Validator.check_int(size, 0, Validator.GE, "size", self.cls_name)
63
- elements_num = len(shapes)
64
- Validator.check_int(elements_num, 1, Validator.GE, "len(shapes)", self.cls_name)
65
- self.handle_ = rl_ops.TensorsQueueCreate(dtype, shapes, size, name)()
66
- self.tensors_q_put = rl_ops.TensorsQueuePut(dtype, shapes)
67
- self.tensors_q_get = rl_ops.TensorsQueueGet(dtype, shapes)
68
- self.tensors_q_pop = rl_ops.TensorsQueueGet(dtype, shapes, pop_after_get=True)
69
- self.tensors_q_clear = rl_ops.TensorsQueueClear()
70
- self.tensors_q_close = rl_ops.TensorsQueueClose()
71
- self.tensors_q_size = rl_ops.TensorsQueueSize()
72
- self.__is_tensors_queue__ = True
73
-
74
- def put(self, element):
75
- """
76
- Put element(tuple(Tensors)) to TensorsQueue in the end of queue.
77
-
78
- Args:
79
- element (tuple(Tensor) or list[tensor]): The input element.
80
-
81
- Returns:
82
- Bool, true.
83
- """
84
- self.tensors_q_put(self.handle_, element)
85
- return True
86
-
87
- def get(self):
88
- """
89
- Get one element int the front of the TensorsQueue.
90
-
91
- Returns:
92
- tuple(Tensors), the element in TensorsQueue.
93
- """
94
- element = self.tensors_q_get(self.handle_)
95
- return element
96
-
97
- def pop(self):
98
- """
99
- Get one element int the front of the TensorsQueue, and remove it.
100
-
101
- Returns:
102
- tuple(Tensors), the element in TensorsQueue.
103
- """
104
- element = self.tensors_q_pop(self.handle_)
105
- return element
106
-
107
- def __graph_pop__(self):
108
- """
109
- Get one element int the front of the TensorsQueue, and remove it.
110
- This is only used in graph mode.
111
-
112
- Returns:
113
- tuple(Tensors), the element in TensorsQueue.
114
- """
115
- element = self.tensors_q_pop(self.handle_)
116
- return self.handle_, element
117
-
118
- def size(self):
119
- """
120
- Get the used/available size of the TensorsQueue, and remove it.
121
-
122
- Returns:
123
- Tensor(mindspore.int64), the used size of TensorsQueue.
124
- """
125
- size = self.tensors_q_size(self.handle_)
126
- return size
127
-
128
- def close(self):
129
- """
130
- Close the created TensorsQueue.
131
-
132
- .. warning::
133
- Once close the TensorsQueue, every functions belong to this TensorsQueue will be disaviliable.
134
- Every resources created in TensorsQueue will be removed. If this TensorsQueue will be used in next step
135
- or somewhere, eg: next loop, please use `clear` instead.
136
-
137
- Returns:
138
- Bool, true.
139
- """
140
- self.tensors_q_close(self.handle_)
141
- return True
142
-
143
- def clear(self):
144
- """
145
- Clear the created TensorsQueue. Only reset the TensorsQueue, clear the data and reset the size
146
- in TensorsQueue and keep the instance of this TensorsQueue.
147
-
148
- Returns:
149
- Bool, true.
150
- """
151
- self.tensors_q_clear(self.handle_)
152
- return True
@@ -1,145 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """
16
- TensorArray
17
- """
18
- from __future__ import absolute_import
19
-
20
- from mindspore.nn.cell import Cell
21
- from mindspore.ops.operations import _tensor_array as ta
22
- from mindspore import _checkparam as Validator
23
- from mindspore.common import dtype as mstype
24
-
25
-
26
- class TensorArray(Cell):
27
- r"""TensorArray: a dynamic array to store tensors.
28
-
29
- .. warning::
30
- This is an experiential prototype that is subject to change and/or deletion.
31
-
32
- Args:
33
- dtype (mindspore.dtype): the data type in the TensorArray.
34
- element_shape (tuple[int]): the shape of each tensor in a TensorArray.
35
- dynamic_size (bool): if ``true`` , the size of TensorArray can be increased. Default: ``True`` .
36
- size (int): if dynamic_size=False, `size` means the max_size of the TensorArray.
37
- name (str): the name of this TensorArray. Default: ``"TA"`` .
38
-
39
- Supported Platforms:
40
- ``GPU`` ``CPU``
41
-
42
- Examples:
43
- >>> import mindspore
44
- >>> import mindspore.nn as nn
45
- >>> ta = nn.TensorArray(mindspore.int64, ())
46
- >>> ta.write(0, 1)
47
- >>> ta.write(1, 2)
48
- >>> ans = ta.read(1)
49
- >>> print(ans)
50
- 2
51
- >>> s = ta.stack()
52
- >>> print(s)
53
- [1 2]
54
- >>> ta.clear()
55
- >>> ta.write(0, 3)
56
- >>> ans = ta.read(0)
57
- >>> print(ans)
58
- 3
59
- >>> ta.close()
60
- """
61
- def __init__(self, dtype, element_shape, dynamic_size=True, size=0, name="TA"):
62
- """Initialize TensorArray"""
63
- super(TensorArray, self).__init__()
64
- Validator.check_subclass("dtype", dtype, mstype.number_type + (mstype.bool_,), self.cls_name)
65
- Validator.check_int(size, 0, Validator.GE, "size", self.cls_name)
66
- self.handle_ = ta.TensorArray(dtype, element_shape, dynamic_size, size, name)()
67
- self.tensor_array_write = ta.TensorArrayWrite()
68
- self.tensor_array_read = ta.TensorArrayRead(dtype, element_shape)
69
- self.tensor_array_close = ta.TensorArrayClose()
70
- self.tensor_array_clear = ta.TensorArrayClear()
71
- self.tensor_array_stack = ta.TensorArrayStack(dtype, element_shape, dynamic_size, size)
72
- self.tensor_array_size = ta.TensorArraySize()
73
-
74
- def write(self, index, value):
75
- """
76
- Write value(Tensor) to TensorArray in position index.
77
-
78
- Args:
79
- index ([int, mindspore.int64]): The position to write.
80
- value (Tensor): The value to add into the TensorArray.
81
-
82
- Returns:
83
- Bool, true.
84
- """
85
- self.tensor_array_write(self.handle_, index, value)
86
- return True
87
-
88
- def read(self, index):
89
- """
90
- Read tensor form the TensorArray by the given position index.
91
-
92
- Args:
93
- index ([int, mindspore.int64]): The given index to get the tensor.
94
-
95
- Returns:
96
- Tensor, the value in position index.
97
- """
98
- value = self.tensor_array_read(self.handle_, index)
99
- return value
100
-
101
- def close(self):
102
- """
103
- Close the created TensorArray.
104
-
105
- .. warning::
106
- Once close the TensorArray, every functions belong to this TensorArray will be disaviliable.
107
- Every resources created in TensorArray will be removed. If this TensorArray will be used in next step
108
- or somewhere, eg: next loop, please use `clear` instead.
109
-
110
- Returns:
111
- Bool, true.
112
- """
113
- self.tensor_array_close(self.handle_)
114
- return True
115
-
116
- def clear(self):
117
- """
118
- Clear the created TensorArray. Only reset the TensorArray, clear the data and reset the size
119
- in TensorArray and keep the instance of this TensorArray.
120
-
121
- Returns:
122
- Bool, true.
123
- """
124
- self.tensor_array_clear(self.handle_)
125
- return True
126
-
127
- def stack(self):
128
- """
129
- Stack the values in TensorArray into a stacked Tensor.
130
-
131
- Returns:
132
- Tensor, all the values will be stacked into one tensor.
133
- """
134
- ans = self.tensor_array_stack(self.handle_)
135
- return ans
136
-
137
- def size(self):
138
- """
139
- The logical size of TensorArray.
140
-
141
- Returns:
142
- Tensor, the size of TensorArray.
143
- """
144
- size = self.tensor_array_size(self.handle_)
145
- return size
Binary file
Binary file
@@ -1,113 +0,0 @@
1
- # Copyright 2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
-
16
- """PriorityReplayBuffer op"""
17
- from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
18
-
19
-
20
- prb_create_op_info = AiCPURegOp("PriorityReplayBufferCreate") \
21
- .fusion_type("OPAQUE") \
22
- .output(0, "handle", "required") \
23
- .attr("capacity", "int") \
24
- .attr("alpha", "float") \
25
- .attr("beta", "float") \
26
- .attr("schema", "listInt") \
27
- .attr("seed0", "int") \
28
- .attr("seed1", "int") \
29
- .dtype_format(DataType.I64_Default) \
30
- .get_op_info()
31
-
32
-
33
- prb_push_op_info = AiCPURegOp("PriorityReplayBufferPush") \
34
- .input(0, "transition", "dynamic") \
35
- .output(0, "handle", "required") \
36
- .attr("handle", "int") \
37
- .dtype_format(DataType.BOOL_Default, DataType.I64_Default) \
38
- .dtype_format(DataType.I8_Default, DataType.I64_Default) \
39
- .dtype_format(DataType.I16_Default, DataType.I64_Default) \
40
- .dtype_format(DataType.I32_Default, DataType.I64_Default) \
41
- .dtype_format(DataType.I64_Default, DataType.I64_Default) \
42
- .dtype_format(DataType.F16_Default, DataType.I64_Default) \
43
- .dtype_format(DataType.U8_Default, DataType.I64_Default) \
44
- .dtype_format(DataType.U16_Default, DataType.I64_Default) \
45
- .dtype_format(DataType.U32_Default, DataType.I64_Default) \
46
- .dtype_format(DataType.U64_Default, DataType.I64_Default) \
47
- .dtype_format(DataType.F32_Default, DataType.I64_Default) \
48
- .get_op_info()
49
-
50
-
51
- prb_sample_op_info = AiCPURegOp("PriorityReplayBufferSample") \
52
- .output(0, "transitions", "dynamic") \
53
- .attr("handle", "int") \
54
- .attr("batch_size", "int") \
55
- .attr("schema", "listInt") \
56
- .dtype_format(DataType.BOOL_Default) \
57
- .dtype_format(DataType.I8_Default) \
58
- .dtype_format(DataType.I16_Default) \
59
- .dtype_format(DataType.I32_Default) \
60
- .dtype_format(DataType.I64_Default) \
61
- .dtype_format(DataType.F16_Default) \
62
- .dtype_format(DataType.U8_Default) \
63
- .dtype_format(DataType.U16_Default) \
64
- .dtype_format(DataType.U32_Default) \
65
- .dtype_format(DataType.U64_Default) \
66
- .dtype_format(DataType.F32_Default) \
67
- .get_op_info()
68
-
69
-
70
- prb_update_op_info = AiCPURegOp("PriorityReplayBufferUpdate") \
71
- .input(0, "indices", "require") \
72
- .input(1, "priorities", "require") \
73
- .output(0, "handle", "require") \
74
- .attr("handle", "int") \
75
- .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.I64_Default) \
76
- .get_op_info()
77
-
78
-
79
- prb_destroy_op_info = AiCPURegOp("PriorityReplayBufferDestroy") \
80
- .output(0, "handle", "required") \
81
- .attr("handle", "int") \
82
- .dtype_format(DataType.I64_Default) \
83
- .get_op_info()
84
-
85
-
86
- @op_info_register(prb_create_op_info)
87
- def _prb_create_op_cpu():
88
- """PriorityReplayBufferSample AICPU register"""
89
- return
90
-
91
-
92
- @op_info_register(prb_push_op_info)
93
- def _prb_push_op_cpu():
94
- """PriorityReplayBufferPush AICPU register"""
95
- return
96
-
97
-
98
- @op_info_register(prb_sample_op_info)
99
- def _prb_sample_op_cpu():
100
- """PriorityReplayBufferSample AICPU register"""
101
- return
102
-
103
-
104
- @op_info_register(prb_update_op_info)
105
- def _prb_update_op_cpu():
106
- """PriorityReplayBufferUpdate AICPU register"""
107
- return
108
-
109
-
110
- @op_info_register(prb_destroy_op_info)
111
- def _prb_destroy_op_cpu():
112
- """PriorityReplayBufferDestroy AICPU register"""
113
- return
@@ -1,96 +0,0 @@
1
- # Copyright 2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
-
16
- """ReservoirReplayBuffer op"""
17
- from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
18
-
19
-
20
- rrb_create_op_info = AiCPURegOp("ReservoirReplayBufferCreate") \
21
- .fusion_type("OPAQUE") \
22
- .output(0, "handle", "required") \
23
- .attr("capacity", "int") \
24
- .attr("schema", "listInt") \
25
- .attr("seed0", "int") \
26
- .attr("seed1", "int") \
27
- .dtype_format(DataType.I64_Default) \
28
- .get_op_info()
29
-
30
-
31
- rrb_push_op_info = AiCPURegOp("ReservoirReplayBufferPush") \
32
- .input(0, "transition", "dynamic") \
33
- .output(0, "handle", "required") \
34
- .attr("handle", "int") \
35
- .dtype_format(DataType.BOOL_Default, DataType.I64_Default) \
36
- .dtype_format(DataType.I8_Default, DataType.I64_Default) \
37
- .dtype_format(DataType.I16_Default, DataType.I64_Default) \
38
- .dtype_format(DataType.I32_Default, DataType.I64_Default) \
39
- .dtype_format(DataType.I64_Default, DataType.I64_Default) \
40
- .dtype_format(DataType.F16_Default, DataType.I64_Default) \
41
- .dtype_format(DataType.U8_Default, DataType.I64_Default) \
42
- .dtype_format(DataType.U16_Default, DataType.I64_Default) \
43
- .dtype_format(DataType.U32_Default, DataType.I64_Default) \
44
- .dtype_format(DataType.U64_Default, DataType.I64_Default) \
45
- .dtype_format(DataType.F32_Default, DataType.I64_Default) \
46
- .get_op_info()
47
-
48
-
49
- rrb_sample_op_info = AiCPURegOp("ReservoirReplayBufferSample") \
50
- .output(0, "transitions", "dynamic") \
51
- .attr("handle", "int") \
52
- .attr("batch_size", "int") \
53
- .attr("schema", "listInt") \
54
- .dtype_format(DataType.BOOL_Default) \
55
- .dtype_format(DataType.I8_Default) \
56
- .dtype_format(DataType.I16_Default) \
57
- .dtype_format(DataType.I32_Default) \
58
- .dtype_format(DataType.I64_Default) \
59
- .dtype_format(DataType.F16_Default) \
60
- .dtype_format(DataType.U8_Default) \
61
- .dtype_format(DataType.U16_Default) \
62
- .dtype_format(DataType.U32_Default) \
63
- .dtype_format(DataType.U64_Default) \
64
- .dtype_format(DataType.F32_Default) \
65
- .get_op_info()
66
-
67
-
68
- rrb_destroy_op_info = AiCPURegOp("ReservoirReplayBufferDestroy") \
69
- .output(0, "handle", "required") \
70
- .attr("handle", "int") \
71
- .dtype_format(DataType.I64_Default) \
72
- .get_op_info()
73
-
74
-
75
- @op_info_register(rrb_create_op_info)
76
- def _rrb_create_op_cpu():
77
- """ReservoirReplayBufferCreate AICPU register"""
78
- return
79
-
80
-
81
- @op_info_register(rrb_push_op_info)
82
- def _rrb_push_op_cpu():
83
- """ReservoirReplayBufferPush AICPU register"""
84
- return
85
-
86
-
87
- @op_info_register(rrb_sample_op_info)
88
- def _rrb_sample_op_cpu():
89
- """ReservoirReplayBufferSample AICPU register"""
90
- return
91
-
92
-
93
- @op_info_register(rrb_destroy_op_info)
94
- def _rrb_destroy_op_cpu():
95
- """ReservoirReplayBufferDestroy AICPU register"""
96
- return
@@ -1,42 +0,0 @@
1
- # Copyright 2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
-
16
- """SparseCross op"""
17
- from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
18
-
19
- sparse_cross_op_info = AiCPURegOp("SparseCross") \
20
- .fusion_type("OPAQUE") \
21
- .attr("N", "int") \
22
- .attr("hashed_output", "bool") \
23
- .attr("hash_key", "int") \
24
- .attr("out_type", "Type") \
25
- .attr("internal_type", "Type") \
26
- .attr("num_buckets", "int") \
27
- .input(0, "indices", "dynamic") \
28
- .input(1, "values", "dynamic") \
29
- .input(2, "shapes", "dynamic") \
30
- .input(3, "dense_inputs", "dynamic") \
31
- .output(0, "output_indices", "required") \
32
- .output(1, "output_values", "required") \
33
- .output(2, "output_shape", "required") \
34
- .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, \
35
- DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
36
- .get_op_info()
37
-
38
-
39
- @op_info_register(sparse_cross_op_info)
40
- def _sparse_cross_aicpu():
41
- """SparseCross AiCPU register"""
42
- return
@@ -1,28 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """BufferAppend op"""
16
- from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
17
-
18
- buffer_append_op_info = CpuRegOp("BufferAppend") \
19
- .input(0, "x", "dynamic") \
20
- .output(0, "output", "dynamic") \
21
- .dtype_format(DataType.I32_Default, DataType.I32_Default) \
22
- .get_op_info()
23
-
24
-
25
- @op_info_register(buffer_append_op_info)
26
- def _buffer_append_cpu():
27
- """BufferAppend cpu register"""
28
- return
@@ -1,28 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """BufferGetItem op"""
16
- from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
17
-
18
- buffer_get_op_info = CpuRegOp("BufferGetItem") \
19
- .input(0, "x", "dynamic") \
20
- .output(0, "output", "dynamic") \
21
- .dtype_format(DataType.I32_Default, DataType.I32_Default) \
22
- .get_op_info()
23
-
24
-
25
- @op_info_register(buffer_get_op_info)
26
- def _buffer_get_cpu():
27
- """BufferGetItem cpu register"""
28
- return
@@ -1,28 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """BufferSample op"""
16
- from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
17
-
18
- buffer_sample_op_info = CpuRegOp("BufferSample") \
19
- .input(0, "x", "dynamic") \
20
- .output(0, "output", "dynamic") \
21
- .dtype_format(DataType.I32_Default, DataType.I32_Default) \
22
- .get_op_info()
23
-
24
-
25
- @op_info_register(buffer_sample_op_info)
26
- def _buffer_sample_cpu():
27
- """BufferSample cpu register"""
28
- return
@@ -1,42 +0,0 @@
1
- # Copyright 2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
-
16
- """PriorityReplayBuffer op"""
17
- from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
18
-
19
- prb_sample_op_info = CpuRegOp("PriorityReplayBufferSample") \
20
- .input(0, "x", "dynamic") \
21
- .output(0, "y", "dynamic") \
22
- .dtype_format(DataType.I32_Default, DataType.I32_Default) \
23
- .get_op_info()
24
-
25
-
26
- @op_info_register(prb_sample_op_info)
27
- def _prb_sample_op_cpu():
28
- """PriorityReplayBufferSample cpu register"""
29
- return
30
-
31
-
32
- prb_push_op_info = CpuRegOp("PriorityReplayBufferPush") \
33
- .input(0, "x", "dynamic") \
34
- .output(0, "y", "dynamic") \
35
- .dtype_format(DataType.I32_Default, DataType.I32_Default) \
36
- .get_op_info()
37
-
38
-
39
- @op_info_register(prb_push_op_info)
40
- def _prb_push_op_cpu():
41
- """PriorityReplayBufferPush cpu register"""
42
- return