mindspore 2.3.0__cp39-none-any.whl → 2.3.0rc2__cp39-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the package's registry page for more details.

Files changed (423)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +0 -1512
  3. mindspore/__init__.py +1 -2
  4. mindspore/_c_dataengine.cpython-39-aarch64-linux-gnu.so +0 -0
  5. mindspore/_c_expression.cpython-39-aarch64-linux-gnu.so +0 -0
  6. mindspore/_c_mindrecord.cpython-39-aarch64-linux-gnu.so +0 -0
  7. mindspore/_checkparam.py +25 -5
  8. mindspore/_extends/graph_kernel/model/graph_parallel.py +1 -1
  9. mindspore/_extends/parse/__init__.py +2 -2
  10. mindspore/_extends/parse/compile_config.py +0 -29
  11. mindspore/_extends/parse/namespace.py +2 -2
  12. mindspore/_extends/parse/parser.py +5 -21
  13. mindspore/_extends/parse/resources.py +7 -5
  14. mindspore/_extends/parse/standard_method.py +59 -40
  15. mindspore/_mindspore_offline_debug.cpython-39-aarch64-linux-gnu.so +0 -0
  16. mindspore/amp.py +5 -26
  17. mindspore/bin/cache_admin +0 -0
  18. mindspore/bin/cache_server +0 -0
  19. mindspore/boost/adasum.py +1 -1
  20. mindspore/boost/base.py +1 -1
  21. mindspore/boost/boost_cell_wrapper.py +1 -1
  22. mindspore/boost/grad_freeze.py +2 -2
  23. mindspore/boost/less_batch_normalization.py +6 -9
  24. mindspore/common/__init__.py +1 -8
  25. mindspore/common/_register_for_tensor.py +9 -8
  26. mindspore/common/api.py +65 -275
  27. mindspore/common/dtype.py +4 -8
  28. mindspore/common/dump.py +5 -2
  29. mindspore/common/jit_config.py +1 -1
  30. mindspore/common/lazy_inline.py +2 -14
  31. mindspore/common/parameter.py +15 -14
  32. mindspore/common/recompute.py +5 -20
  33. mindspore/common/sparse_tensor.py +6 -21
  34. mindspore/common/tensor.py +52 -100
  35. mindspore/communication/__init__.py +11 -6
  36. mindspore/communication/management.py +94 -92
  37. mindspore/context.py +18 -180
  38. mindspore/dataset/engine/datasets.py +46 -69
  39. mindspore/dataset/engine/datasets_user_defined.py +53 -72
  40. mindspore/dataset/engine/datasets_vision.py +2 -2
  41. mindspore/dataset/engine/queue.py +38 -56
  42. mindspore/dataset/engine/validators.py +5 -11
  43. mindspore/dataset/vision/__init__.py +5 -5
  44. mindspore/dataset/vision/c_transforms.py +5 -5
  45. mindspore/dataset/vision/py_transforms_util.py +1 -1
  46. mindspore/dataset/vision/transforms.py +46 -591
  47. mindspore/dataset/vision/utils.py +1 -121
  48. mindspore/dataset/vision/validators.py +3 -9
  49. mindspore/hal/__init__.py +1 -7
  50. mindspore/hal/device.py +1 -1
  51. mindspore/include/api/model.h +0 -3
  52. mindspore/include/dataset/vision.h +2 -54
  53. mindspore/include/mindapi/base/types.h +0 -1
  54. mindspore/lib/libdnnl.so.2 +0 -0
  55. mindspore/lib/libmindspore.so +0 -0
  56. mindspore/lib/libmindspore_backend.so +0 -0
  57. mindspore/lib/libmindspore_common.so +0 -0
  58. mindspore/lib/libmindspore_core.so +0 -0
  59. mindspore/lib/libmindspore_glog.so.0 +0 -0
  60. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  61. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  62. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  63. mindspore/lib/libmindspore_shared_lib.so +0 -0
  64. mindspore/lib/libmpi_adapter.so +0 -0
  65. mindspore/lib/libmpi_collective.so +0 -0
  66. mindspore/lib/libnnacl.so +0 -0
  67. mindspore/lib/libopencv_core.so.4.5 +0 -0
  68. mindspore/lib/libps_cache.so +0 -0
  69. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +0 -35
  70. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
  71. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
  72. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  73. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +0 -72
  74. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  75. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/{aclnn_all_finite.h → aclnn_add_custom.h} +11 -9
  76. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_decoder_kv_cache.h +1 -1
  77. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_prompt_kv_cache.h +1 -1
  78. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/lib/libcust_opapi.so +0 -0
  79. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +12 -184
  80. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +15 -7
  81. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +15 -7
  82. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.cpp +81 -0
  83. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.py +134 -0
  84. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.py +31 -77
  85. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.py +31 -77
  86. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/lib/linux/aarch64/libcust_opmaster_rt2.0.so +0 -0
  87. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/liboptiling.so +0 -0
  88. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/inc/op_proto.h +5 -4
  89. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/lib/linux/aarch64/libcust_opsproto_rt2.0.so +0 -0
  90. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  91. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  92. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  93. mindspore/lib/plugin/ascend/liblowlatency_collective.so +0 -0
  94. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  95. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/DeviceBin +0 -0
  96. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/PkgInspect +0 -0
  97. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/op_man +0 -0
  98. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/device/ascend910b/bin/ascend910b.bin +286 -275
  99. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_cann_host.so +0 -0
  100. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_host.so +0 -0
  101. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops.so +0 -0
  102. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops_static.a +0 -0
  103. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/add/add_impl.h +0 -1
  104. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/apply_rotary_pos_emb_impl.h +0 -1
  105. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/asdop/asd_op_impl.h +0 -3
  106. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/backend_param.h +0 -5
  107. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/cast/cast_tiling.h +45 -1
  108. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/compare/compare_impl.h +0 -1
  109. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/flash_attention_score/flash_attention_score_impl.h +4 -8
  110. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/flash_attention_score/flash_attention_score_tiling.h +4 -11
  111. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/flash_attention_score/kernel/flash_attention_score_mix_hwsync.h +0 -18
  112. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/internal_kernel.h +0 -6
  113. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/internal_rtbackend.h +75 -1
  114. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/kernel/matmul.h +5 -5
  115. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/matmul_impl.h +3 -18
  116. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/pp_matmul_common_tiling.h +5 -5
  117. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/pp_matmul_info.h +2 -2
  118. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/tiling_data.h +3 -36
  119. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/kernel/matmul_stridedslice_fusion.h +2 -2
  120. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/matmul_stridedslice_fusion_impl.h +4 -22
  121. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/op_param.h +2 -16
  122. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/kernel/paged_attention_mix_hwsync.h +3 -1
  123. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_impl.h +4 -5
  124. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_tiling.h +4 -9
  125. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/attention_param.h +2 -5
  126. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_ext_param.h +0 -1
  127. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_qkv_param.h +4 -10
  128. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/sub_param.h +12 -0
  129. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/rms_norm_impl.h +0 -1
  130. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_impl.h +0 -1
  131. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/tune_repo/matmul_table.h +1 -1
  132. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/backend.h +2 -10
  133. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/elewise_utils.h +1 -5
  134. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log.h +0 -1
  135. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log_tiling.h +0 -17
  136. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/math.h +7 -2
  137. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libAdd_impl.so +0 -0
  138. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libSub_impl.so +0 -0
  139. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_layernorm_impl.so +0 -0
  140. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_rms_norm_impl.so +0 -0
  141. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_impl.so +0 -0
  142. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libcast_impl.so +0 -0
  143. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libgelu_impl.so +0 -0
  144. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_impl.so +0 -0
  145. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_stridedslice_fusion_impl.so +0 -0
  146. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libms_kernels_internal.so +0 -0
  147. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libnot_equal_impl.so +0 -0
  148. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libreshape_and_cache_impl.so +0 -0
  149. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/librms_norm_impl.so +0 -0
  150. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_full_mix.o +0 -0
  151. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_tri_mix.o +0 -0
  152. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_full_mix.o +0 -0
  153. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_tri_mix.o +0 -0
  154. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_full_mix.o +0 -0
  155. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_tri_mix.o +0 -0
  156. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_full_mix.o +0 -0
  157. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_tri_mix.o +0 -0
  158. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bnsd_full_mix.o +0 -0
  159. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bsh_full_mix.o +0 -0
  160. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bnsd_full_mix.o +0 -0
  161. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bsh_full_mix.o +0 -0
  162. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblcal.so +0 -0
  163. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblccl_wrapper.so +0 -0
  164. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  165. mindspore/mindrecord/filewriter.py +2 -2
  166. mindspore/mint/__init__.py +40 -720
  167. mindspore/mint/nn/__init__.py +7 -89
  168. mindspore/mint/nn/functional.py +16 -165
  169. mindspore/mint/optim/adamw.py +16 -15
  170. mindspore/nn/__init__.py +2 -0
  171. mindspore/nn/cell.py +98 -97
  172. mindspore/nn/extend/basic.py +2 -2
  173. mindspore/nn/extend/embedding.py +1 -1
  174. mindspore/nn/extend/layer/normalization.py +5 -7
  175. mindspore/nn/generator.py +297 -0
  176. mindspore/nn/layer/activation.py +3 -4
  177. mindspore/nn/layer/basic.py +16 -79
  178. mindspore/nn/layer/conv.py +8 -17
  179. mindspore/nn/layer/embedding.py +4 -1
  180. mindspore/nn/layer/math.py +1 -1
  181. mindspore/nn/layer/normalization.py +1 -1
  182. mindspore/nn/layer/pooling.py +0 -5
  183. mindspore/nn/layer/rnn_cells.py +2 -2
  184. mindspore/nn/loss/loss.py +19 -19
  185. mindspore/nn/optim/adasum.py +1 -1
  186. mindspore/nn/optim/sgd.py +2 -3
  187. mindspore/nn/probability/distribution/exponential.py +1 -1
  188. mindspore/nn/probability/distribution/geometric.py +1 -1
  189. mindspore/nn/probability/distribution/logistic.py +1 -1
  190. mindspore/nn/wrap/cell_wrapper.py +1 -25
  191. mindspore/nn/wrap/loss_scale.py +1 -24
  192. mindspore/numpy/array_ops.py +1 -5
  193. mindspore/numpy/dtypes.py +3 -3
  194. mindspore/numpy/math_ops.py +8 -8
  195. mindspore/ops/__init__.py +1 -1
  196. mindspore/ops/_grad_experimental/grad_comm_ops.py +16 -75
  197. mindspore/ops/_vmap/vmap_array_ops.py +0 -27
  198. mindspore/ops/_vmap/vmap_math_ops.py +1 -29
  199. mindspore/ops/_vmap/vmap_nn_ops.py +18 -19
  200. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +8 -34
  201. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +9 -2
  202. mindspore/ops/auto_generate/gen_arg_handler.py +0 -26
  203. mindspore/ops/auto_generate/gen_extend_func.py +27 -603
  204. mindspore/ops/auto_generate/gen_ops_def.py +203 -993
  205. mindspore/ops/auto_generate/gen_ops_prim.py +402 -1946
  206. mindspore/ops/auto_generate/pyboost_inner_prim.py +20 -90
  207. mindspore/ops/composite/base.py +6 -3
  208. mindspore/ops/composite/math_ops.py +1 -1
  209. mindspore/ops/composite/multitype_ops/_compile_utils.py +17 -24
  210. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  211. mindspore/ops/extend/__init__.py +3 -2
  212. mindspore/ops/extend/array_func.py +51 -10
  213. mindspore/ops/extend/nn_func.py +78 -2
  214. mindspore/ops/function/__init__.py +13 -8
  215. mindspore/ops/function/array_func.py +179 -455
  216. mindspore/ops/function/clip_func.py +1 -1
  217. mindspore/ops/function/grad/grad_func.py +3 -3
  218. mindspore/ops/function/math_func.py +103 -117
  219. mindspore/ops/function/nn_func.py +163 -275
  220. mindspore/ops/function/other_func.py +2 -2
  221. mindspore/ops/function/random_func.py +69 -202
  222. mindspore/ops/function/sparse_func.py +4 -4
  223. mindspore/ops/functional.py +327 -332
  224. mindspore/ops/operations/__init__.py +3 -13
  225. mindspore/ops/operations/_grad_ops.py +27 -3
  226. mindspore/ops/operations/_inner_ops.py +356 -53
  227. mindspore/ops/operations/_rl_inner_ops.py +2 -2
  228. mindspore/ops/operations/_tensor_array.py +8 -8
  229. mindspore/ops/operations/array_ops.py +65 -82
  230. mindspore/ops/operations/comm_ops.py +93 -784
  231. mindspore/ops/operations/custom_ops.py +28 -51
  232. mindspore/ops/operations/debug_ops.py +4 -4
  233. mindspore/ops/operations/inner_ops.py +2 -2
  234. mindspore/ops/operations/manually_defined/ops_def.py +4 -304
  235. mindspore/ops/operations/math_ops.py +50 -3
  236. mindspore/ops/operations/nn_ops.py +247 -14
  237. mindspore/ops/operations/other_ops.py +3 -3
  238. mindspore/ops/operations/random_ops.py +1 -1
  239. mindspore/ops/operations/sparse_ops.py +1 -1
  240. mindspore/ops/primitive.py +8 -9
  241. mindspore/ops/silent_check.py +5 -5
  242. mindspore/ops_generate/arg_dtype_cast.py +9 -2
  243. mindspore/ops_generate/arg_handler.py +0 -26
  244. mindspore/ops_generate/gen_aclnn_implement.py +4 -1
  245. mindspore/ops_generate/gen_ops.py +4 -26
  246. mindspore/ops_generate/gen_pyboost_func.py +12 -41
  247. mindspore/ops_generate/gen_utils.py +0 -21
  248. mindspore/ops_generate/pyboost_utils.py +2 -7
  249. mindspore/ops_generate/template.py +0 -1
  250. mindspore/parallel/_auto_parallel_context.py +1 -21
  251. mindspore/parallel/_tensor.py +5 -0
  252. mindspore/parallel/_transformer/transformer.py +1 -1
  253. mindspore/parallel/_utils.py +1 -15
  254. mindspore/parallel/algo_parameter_config.py +3 -1
  255. mindspore/parallel/checkpoint_transform.py +9 -12
  256. mindspore/parallel/cluster/process_entity/_api.py +29 -28
  257. mindspore/parallel/cluster/process_entity/_utils.py +3 -13
  258. mindspore/parallel/cluster/run.py +16 -13
  259. mindspore/parallel/parameter_broadcast.py +2 -2
  260. mindspore/parallel/shard.py +17 -31
  261. mindspore/profiler/__init__.py +2 -3
  262. mindspore/profiler/common/util.py +2 -107
  263. mindspore/profiler/envprofiling.py +1 -1
  264. mindspore/profiler/parser/ascend_analysis/constant.py +21 -8
  265. mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -82
  266. mindspore/profiler/parser/ascend_analysis/function_event.py +28 -43
  267. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +27 -49
  268. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +10 -15
  269. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +20 -25
  270. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +5 -5
  271. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +1 -10
  272. mindspore/profiler/parser/ascend_hccl_generator.py +1 -4
  273. mindspore/profiler/parser/ascend_msprof_exporter.py +22 -43
  274. mindspore/profiler/parser/ascend_timeline_generator.py +5 -7
  275. mindspore/profiler/parser/minddata_parser.py +3 -72
  276. mindspore/profiler/profiling.py +59 -176
  277. mindspore/rewrite/api/node.py +1 -1
  278. mindspore/rewrite/common/namespace.py +5 -5
  279. mindspore/rewrite/parsers/assign_parser.py +0 -2
  280. mindspore/rewrite/parsers/class_def_parser.py +4 -8
  281. mindspore/run_check/_check_version.py +1 -1
  282. mindspore/scipy/fft.py +3 -1
  283. mindspore/scipy/linalg.py +3 -2
  284. mindspore/scipy/ops.py +3 -5
  285. mindspore/scipy/optimize/__init__.py +2 -2
  286. mindspore/train/__init__.py +4 -4
  287. mindspore/train/anf_ir_pb2.py +2 -8
  288. mindspore/train/callback/__init__.py +2 -5
  289. mindspore/train/callback/_backup_and_restore.py +2 -2
  290. mindspore/train/callback/_checkpoint.py +16 -104
  291. mindspore/train/callback/_landscape.py +1 -1
  292. mindspore/train/callback/_time_monitor.py +1 -1
  293. mindspore/train/data_sink.py +4 -5
  294. mindspore/train/dataset_helper.py +20 -45
  295. mindspore/train/model.py +38 -266
  296. mindspore/train/serialization.py +105 -256
  297. mindspore/train/summary/_summary_adapter.py +1 -1
  298. mindspore/version.py +1 -1
  299. {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +2 -2
  300. {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +303 -420
  301. mindspore/_extends/pijit/__init__.py +0 -23
  302. mindspore/_extends/pijit/pijit_func_white_list.py +0 -343
  303. mindspore/common/file_system.py +0 -48
  304. mindspore/common/generator.py +0 -260
  305. mindspore/common/no_inline.py +0 -54
  306. mindspore/common/np_dtype.py +0 -25
  307. mindspore/communication/comm_func.py +0 -1140
  308. mindspore/hal/memory.py +0 -326
  309. mindspore/lib/libavcodec.so.59 +0 -0
  310. mindspore/lib/libavdevice.so.59 +0 -0
  311. mindspore/lib/libavfilter.so.8 +0 -0
  312. mindspore/lib/libavformat.so.59 +0 -0
  313. mindspore/lib/libavutil.so.57 +0 -0
  314. mindspore/lib/libmindspore_np_dtype.so +0 -0
  315. mindspore/lib/libswresample.so.4 +0 -0
  316. mindspore/lib/libswscale.so.6 +0 -0
  317. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/all_finite.cpp +0 -326
  318. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/all_finite.py +0 -180
  319. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_576ceaeef5870c451cab59af55ea46ad.json +0 -58
  320. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_576ceaeef5870c451cab59af55ea46ad.o +0 -0
  321. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_86a73ff6e28d734c96bb8d3054f7dd18.json +0 -58
  322. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_86a73ff6e28d734c96bb8d3054f7dd18.o +0 -0
  323. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_f55e0ebaad1f2f572e43677336992fa0.json +0 -58
  324. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_f55e0ebaad1f2f572e43677336992fa0.o +0 -0
  325. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/config/ascend910b/all_finite.json +0 -109
  326. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/config/ascend910b/binary_info_config.json +0 -38
  327. mindspore/lib/plugin/ascend/custom_compiler/OWNERS +0 -12
  328. mindspore/lib/plugin/ascend/custom_compiler/setup.py +0 -255
  329. mindspore/lib/plugin/ascend/custom_compiler/start.sh +0 -26
  330. mindspore/lib/plugin/ascend/custom_compiler/template.json +0 -40
  331. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/acme.h +0 -24
  332. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/acme_op.h +0 -69
  333. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/base_type.h +0 -133
  334. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/op_creator.h +0 -32
  335. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/op_param.h +0 -35
  336. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/tiling_info.h +0 -60
  337. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/core/kernel_register.h +0 -37
  338. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/core/platform/platform_configs.h +0 -89
  339. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/core/platform/rt_funcs.h +0 -135
  340. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/add_op.h +0 -34
  341. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_backoff_base.h +0 -62
  342. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_elewise_op.h +0 -33
  343. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_ops.h +0 -88
  344. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_pa_op.h +0 -45
  345. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/cast_op.h +0 -52
  346. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/matmul_op.h +0 -95
  347. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/utils/asd_utils.h +0 -84
  348. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/utils/comm_utils.h +0 -61
  349. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_fp32.h +0 -224
  350. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/and_impl.h +0 -29
  351. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/div_impl.h +0 -29
  352. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/elewise_binary_impl.h +0 -48
  353. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/elewise_binary_tiling.h +0 -25
  354. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/and_kernel.h +0 -46
  355. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/div_kernel.h +0 -46
  356. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/elewise_binary_base.h +0 -260
  357. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/elewise_binary_kernel.h +0 -35
  358. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/max_kernel.h +0 -66
  359. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/min_kernel.h +0 -66
  360. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/mul_kernel.h +0 -66
  361. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/or_kernel.h +0 -46
  362. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/max_impl.h +0 -29
  363. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/min_impl.h +0 -29
  364. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/mul_impl.h +0 -29
  365. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/or_impl.h +0 -29
  366. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/abs_impl.h +0 -29
  367. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/elewise_unary_impl.h +0 -47
  368. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/elewise_unary_tiling.h +0 -24
  369. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/exp_impl.h +0 -29
  370. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/abs_kernel.h +0 -45
  371. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/elewise_unary_base.h +0 -148
  372. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/elewise_unary_kernel.h +0 -31
  373. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/exp_kernel.h +0 -45
  374. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/ln_kernel.h +0 -45
  375. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/not_kernel.h +0 -45
  376. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/reciprocal_kernel.h +0 -45
  377. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/relu_kernel.h +0 -55
  378. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/rsqrt_kernel.h +0 -45
  379. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/sqrt_kernel.h +0 -45
  380. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/ln_impl.h +0 -29
  381. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/not_impl.h +0 -29
  382. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/reciprocal_impl.h +0 -29
  383. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/relu_impl.h +0 -29
  384. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/rsqrt_impl.h +0 -29
  385. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/sqrt_impl.h +0 -29
  386. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/grouped_matmul_impl.h +0 -45
  387. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/grouped_matmul_tiling.h +0 -187
  388. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/kernel/grouped_matmul.h +0 -245
  389. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/kernel/grouped_matmul_interface.h +0 -24
  390. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/kernel/grouped_matmul_utils.h +0 -111
  391. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/tiling_data.h +0 -54
  392. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/compare_param.h +0 -31
  393. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/elewise_param.h +0 -41
  394. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/grouped_matmul_param.h +0 -40
  395. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/profiling_util.h +0 -364
  396. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log_utils.h +0 -69
  397. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/register/kernel_creator.h +0 -39
  398. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/register/kernel_registry.h +0 -114
  399. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/utils.h +0 -98
  400. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix.json +0 -19
  401. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix.o +0 -0
  402. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix_mix_aic_0.o +0 -0
  403. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix_mix_aiv_0.o +0 -0
  404. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix.json +0 -19
  405. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix.o +0 -0
  406. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix_mix_aic_0.o +0 -0
  407. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix_mix_aiv_0.o +0 -0
  408. mindspore/mint/linalg/__init__.py +0 -22
  409. mindspore/nn/layer/embedding_service.py +0 -531
  410. mindspore/nn/layer/embedding_service_layer.py +0 -393
  411. mindspore/ops/function/reshard_func.py +0 -102
  412. mindspore/ops/operations/_infer_ops.py +0 -19
  413. mindspore/ops/operations/reshard_ops.py +0 -53
  414. mindspore/profiler/common/process_pool.py +0 -41
  415. mindspore/profiler/common/singleton.py +0 -28
  416. mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
  417. mindspore/profiler/parser/ascend_memory_generator.py +0 -185
  418. mindspore/train/callback/_cluster_monitor.py +0 -201
  419. mindspore/train/callback/_flops_collector.py +0 -238
  420. mindspore/train/callback/_mindio_ttp.py +0 -443
  421. {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
  422. {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +0 -0
  423. {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
@@ -19,14 +19,10 @@ from mindspore._c_expression import ArgMaxWithValuePrim_
19
19
  from mindspore._c_expression import ArgMinWithValuePrim_
20
20
  from mindspore._c_expression import BatchMatMulPrim_
21
21
  from mindspore._c_expression import BatchNormGradExtPrim_
22
- from mindspore._c_expression import BinaryCrossEntropyGradPrim_
23
- from mindspore._c_expression import BinaryCrossEntropyPrim_
24
- from mindspore._c_expression import BCEWithLogitsLossPrim_
25
22
  from mindspore._c_expression import BroadcastToPrim_
26
23
  from mindspore._c_expression import ConcatPrim_
27
24
  from mindspore._c_expression import ConvolutionGradPrim_
28
25
  from mindspore._c_expression import ConvolutionPrim_
29
- from mindspore._c_expression import EluExtPrim_
30
26
  from mindspore._c_expression import FFNExtPrim_
31
27
  from mindspore._c_expression import FlashAttentionScoreGradPrim_
32
28
  from mindspore._c_expression import FlashAttentionScorePrim_
@@ -34,25 +30,22 @@ from mindspore._c_expression import GridSampler2DGradPrim_
34
30
  from mindspore._c_expression import GridSampler2DPrim_
35
31
  from mindspore._c_expression import GridSampler3DGradPrim_
36
32
  from mindspore._c_expression import GridSampler3DPrim_
37
- from mindspore._c_expression import IsClosePrim_
38
33
  from mindspore._c_expression import MatMulPrim_
39
34
  from mindspore._c_expression import MaxPoolGradWithIndicesPrim_
40
35
  from mindspore._c_expression import MaxPoolGradWithMaskPrim_
41
36
  from mindspore._c_expression import MaxPoolWithIndicesPrim_
42
37
  from mindspore._c_expression import MaxPoolWithMaskPrim_
43
38
  from mindspore._c_expression import OneHotExtPrim_
39
+ from mindspore._c_expression import QuantBatchMatmulPrim_
44
40
  from mindspore._c_expression import ReduceAllPrim_
45
41
  from mindspore._c_expression import ReduceAnyPrim_
46
42
  from mindspore._c_expression import ReverseV2Prim_
47
- from mindspore._c_expression import RmsNormPrim_
48
- from mindspore._c_expression import SearchSortedPrim_
49
43
  from mindspore._c_expression import SoftmaxPrim_
50
44
  from mindspore._c_expression import StackExtPrim_
45
+ from mindspore._c_expression import TrilPrim_
51
46
  from mindspore._c_expression import TriuPrim_
52
47
  from mindspore._c_expression import UpsampleTrilinear3DGradPrim_
53
48
  from mindspore._c_expression import UpsampleTrilinear3DPrim_
54
- from mindspore._c_expression import GroupedMatmulPrim_
55
- from mindspore._c_expression import QuantBatchMatmulPrim_
56
49
  from mindspore._c_expression import WeightQuantBatchMatmulPrim_
57
50
 
58
51
 
@@ -92,33 +85,6 @@ class _PyboostBatchNormGradExtPrim(BatchNormGradExtPrim_):
92
85
  batch_norm_grad_ext_impl = _PyboostBatchNormGradExtPrim()
93
86
 
94
87
 
95
- class _PyboostBinaryCrossEntropyGradPrim(BinaryCrossEntropyGradPrim_):
96
- def __call__(self, input, target, grad_output, weight, reduction):
97
- converted_reduction = str_to_enum(reduction)
98
- return _convert_stub(super().__call__(input, target, grad_output, weight, reduction))
99
-
100
-
101
- binary_cross_entropy_grad_impl = _PyboostBinaryCrossEntropyGradPrim()
102
-
103
-
104
- class _PyboostBinaryCrossEntropyPrim(BinaryCrossEntropyPrim_):
105
- def __call__(self, input, target, weight, reduction):
106
- converted_reduction = str_to_enum(reduction)
107
- return _convert_stub(super().__call__(input, target, weight, reduction))
108
-
109
-
110
- binary_cross_entropy_impl = _PyboostBinaryCrossEntropyPrim()
111
-
112
-
113
- class _PyboostBCEWithLogitsLossPrim(BCEWithLogitsLossPrim_):
114
- def __call__(self, input, target, weight, posWeight, reduction):
115
- converted_reduction = str_to_enum(reduction)
116
- return _convert_stub(super().__call__(input, target, weight, posWeight, reduction))
117
-
118
-
119
- binary_cross_entropy_with_logits_impl = _PyboostBCEWithLogitsLossPrim()
120
-
121
-
122
88
  class _PyboostBroadcastToPrim(BroadcastToPrim_):
123
89
  def __call__(self, input, shape):
124
90
 
@@ -161,15 +127,6 @@ class _PyboostConvolutionPrim(ConvolutionPrim_):
161
127
  convolution_impl = _PyboostConvolutionPrim()
162
128
 
163
129
 
164
- class _PyboostEluExtPrim(EluExtPrim_):
165
- def __call__(self, input, alpha):
166
-
167
- return _convert_stub(super().__call__(input, alpha))
168
-
169
-
170
- elu_ext_impl = _PyboostEluExtPrim()
171
-
172
-
173
130
  class _PyboostFFNExtPrim(FFNExtPrim_):
174
131
  def __call__(self, x, weight1, weight2, expertTokens, bias1, bias2, scale, offset, deqScale1, deqScale2, antiquant_scale1, antiquant_scale2, antiquant_offset1, antiquant_offset2, activation, inner_precise):
175
132
  converted_activation = str_to_enum(activation)
@@ -237,15 +194,6 @@ class _PyboostGridSampler3DPrim(GridSampler3DPrim_):
237
194
  grid_sampler_3d_impl = _PyboostGridSampler3DPrim()
238
195
 
239
196
 
240
- class _PyboostIsClosePrim(IsClosePrim_):
241
- def __call__(self, input, other, rtol, atol, equal_nan):
242
-
243
- return _convert_stub(super().__call__(input, other, rtol, atol, equal_nan))
244
-
245
-
246
- isclose_impl = _PyboostIsClosePrim()
247
-
248
-
249
197
  class _PyboostMatMulPrim(MatMulPrim_):
250
198
  def __call__(self, input, mat2, transpose_a, transpose_b):
251
199
 
@@ -312,6 +260,15 @@ class _PyboostOneHotExtPrim(OneHotExtPrim_):
312
260
  one_hot_ext_impl = _PyboostOneHotExtPrim()
313
261
 
314
262
 
263
+ class _PyboostQuantBatchMatmulPrim(QuantBatchMatmulPrim_):
264
+ def __call__(self, x1, x2, scale, offset, bias, transpose_x1, transpose_x2, dtype):
265
+
266
+ return _convert_stub(super().__call__(x1, x2, scale, offset, bias, transpose_x1, transpose_x2, dtype))
267
+
268
+
269
+ quant_batch_matmul_impl = _PyboostQuantBatchMatmulPrim()
270
+
271
+
315
272
  class _PyboostReduceAllPrim(ReduceAllPrim_):
316
273
  def __call__(self, input, axis, keep_dims):
317
274
 
@@ -339,24 +296,6 @@ class _PyboostReverseV2Prim(ReverseV2Prim_):
339
296
  reverse_v2_impl = _PyboostReverseV2Prim()
340
297
 
341
298
 
342
- class _PyboostRmsNormPrim(RmsNormPrim_):
343
- def __call__(self, x, gamma, epsilon):
344
-
345
- return _convert_stub(super().__call__(x, gamma, epsilon))
346
-
347
-
348
- rms_norm_impl = _PyboostRmsNormPrim()
349
-
350
-
351
- class _PyboostSearchSortedPrim(SearchSortedPrim_):
352
- def __call__(self, sorted_sequence, values, sorter, dtype, right):
353
-
354
- return _convert_stub(super().__call__(sorted_sequence, values, sorter, dtype, right))
355
-
356
-
357
- searchsorted_impl = _PyboostSearchSortedPrim()
358
-
359
-
360
299
  class _PyboostSoftmaxPrim(SoftmaxPrim_):
361
300
  def __call__(self, input, axis):
362
301
 
@@ -375,6 +314,15 @@ class _PyboostStackExtPrim(StackExtPrim_):
375
314
  stack_ext_impl = _PyboostStackExtPrim()
376
315
 
377
316
 
317
+ class _PyboostTrilPrim(TrilPrim_):
318
+ def __call__(self, input, diagonal):
319
+
320
+ return _convert_stub(super().__call__(input, diagonal))
321
+
322
+
323
+ tril_impl = _PyboostTrilPrim()
324
+
325
+
378
326
  class _PyboostTriuPrim(TriuPrim_):
379
327
  def __call__(self, input, diagonal):
380
328
 
@@ -402,24 +350,6 @@ class _PyboostUpsampleTrilinear3DPrim(UpsampleTrilinear3DPrim_):
402
350
  upsample_trilinear3d_impl = _PyboostUpsampleTrilinear3DPrim()
403
351
 
404
352
 
405
- class _PyboostGroupedMatmulPrim(GroupedMatmulPrim_):
406
- def __call__(self, x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, group_list, split_item, group_type):
407
-
408
- return _convert_stub(super().__call__(x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, group_list, split_item, group_type))
409
-
410
-
411
- grouped_matmul_impl = _PyboostGroupedMatmulPrim()
412
-
413
-
414
- class _PyboostQuantBatchMatmulPrim(QuantBatchMatmulPrim_):
415
- def __call__(self, x1, x2, scale, offset, bias, transpose_x1, transpose_x2, dtype):
416
-
417
- return _convert_stub(super().__call__(x1, x2, scale, offset, bias, transpose_x1, transpose_x2, dtype))
418
-
419
-
420
- quant_batch_matmul_impl = _PyboostQuantBatchMatmulPrim()
421
-
422
-
423
353
  class _PyboostWeightQuantBatchMatmulPrim(WeightQuantBatchMatmulPrim_):
424
354
  def __call__(self, x, weight, antiquant_scale, antiquant_offset, quant_scale, quant_offset, bias, transpose_x, transpose_weight, antiquant_group_size):
425
355
 
@@ -414,8 +414,9 @@ class GradOperation(GradOperation_):
414
414
  else:
415
415
  # Check if fn have run already
416
416
  if not _pynative_executor.check_run(grad, fn, weights, None, *args, **new_kwargs):
417
- _pynative_executor.set_grad_flag(True)
417
+ fn.set_grad()
418
418
  fn(*args, **new_kwargs)
419
+ fn.set_grad(False)
419
420
 
420
421
 
421
422
  class _TaylorOperation(TaylorOperation_):
@@ -653,8 +654,9 @@ class _Grad(GradOperation_):
653
654
  else:
654
655
  # Check if fn has run already.
655
656
  if not _pynative_executor.check_run(grad, fn, weights, self.grad_position, *args, **new_kwargs):
656
- _pynative_executor.set_grad_flag(True)
657
+ fn.set_grad()
657
658
  outputs = fn(*args, **new_kwargs)
659
+ fn.set_grad(False)
658
660
  return outputs
659
661
  if (self.get_value or self.has_aux) and not outputs:
660
662
  outputs = fn(*args, **new_kwargs)
@@ -713,8 +715,9 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
713
715
  >>> # `add` is a metagraph object which will add two objects according to
714
716
  >>> # input type using ".register" decorator.
715
717
  >>> from mindspore import Tensor
716
- >>> from mindspore import dtype as mstype
717
718
  >>> from mindspore import ops
719
+ >>> from mindspore import dtype as mstype
720
+ >>> import mindspore.ops as ops
718
721
  >>>
719
722
  >>> tensor_add = ops.Add()
720
723
  >>> add = ops.MultitypeFuncGraph('add')
@@ -112,7 +112,7 @@ def mm(input, mat2):
112
112
 
113
113
  Examples:
114
114
  >>> import mindspore as ms
115
- >>> from mindspore import ops
115
+ >>> import mindspore.ops as ops
116
116
  >>> import numpy as np
117
117
  >>> x1 = ms.Tensor(np.random.rand(2, 3), ms.float32)
118
118
  >>> x2 = ms.Tensor(np.random.rand(3, 4), ms.float32)
@@ -218,8 +218,8 @@ def _tensor_setitem(self, index, value):
218
218
  return output
219
219
 
220
220
 
221
- setattr(tensor_operator_registry, "__getitem__", _tensor_getitem)
222
- setattr(tensor_operator_registry, "__setitem__", _tensor_setitem)
221
+ tensor_operator_registry.register("__getitem__", _tensor_getitem)
222
+ tensor_operator_registry.register("__setitem__", _tensor_setitem)
223
223
 
224
224
 
225
225
  def _tensor_add(self, other):
@@ -288,15 +288,15 @@ def _tensor_floordiv(self, other):
288
288
  return F.floordiv(self, other)
289
289
 
290
290
 
291
- setattr(tensor_operator_registry, '__add__', _tensor_add)
292
- setattr(tensor_operator_registry, '__sub__', _tensor_sub)
293
- setattr(tensor_operator_registry, '__mul__', _tensor_mul)
294
- setattr(tensor_operator_registry, '__matmul__', _tensor_matmul)
295
- setattr(tensor_operator_registry, '__truediv__', _tensor_div)
296
- setattr(tensor_operator_registry, '__mod__', _tensor_mod)
297
- setattr(tensor_operator_registry, '__pow__', _tensor_pow)
298
- setattr(tensor_operator_registry, '__rpow__', _tensor_rpow)
299
- setattr(tensor_operator_registry, '__floordiv__', _tensor_floordiv)
291
+ tensor_operator_registry.register('__add__', _tensor_add)
292
+ tensor_operator_registry.register('__sub__', _tensor_sub)
293
+ tensor_operator_registry.register('__mul__', _tensor_mul)
294
+ tensor_operator_registry.register('__matmul__', _tensor_matmul)
295
+ tensor_operator_registry.register('__truediv__', _tensor_div)
296
+ tensor_operator_registry.register('__mod__', _tensor_mod)
297
+ tensor_operator_registry.register('__pow__', _tensor_pow)
298
+ tensor_operator_registry.register('__rpow__', _tensor_rpow)
299
+ tensor_operator_registry.register('__floordiv__', _tensor_floordiv)
300
300
 
301
301
 
302
302
  def _scalar_to_tensor(input_x):
@@ -356,8 +356,8 @@ def tensor_itemset(data, *args):
356
356
  return tensor_itemset_with_number(data, args[0])
357
357
 
358
358
 
359
- setattr(tensor_operator_registry, "item", tensor_item)
360
- setattr(tensor_operator_registry, "itemset", tensor_itemset)
359
+ tensor_operator_registry.register("item", tensor_item)
360
+ tensor_operator_registry.register("itemset", tensor_itemset)
361
361
 
362
362
 
363
363
  def tensor_itemset_with_number(data, number_value):
@@ -1204,11 +1204,7 @@ def tensor_setitem_by_ellipsis_with_tensor(data, value):
1204
1204
  value = value.astype(data_dtype)
1205
1205
 
1206
1206
  value_shape = F.shape(value)
1207
-
1208
- if len(value_shape) > len(data_shape):
1209
- source_shape = data_shape
1210
- else:
1211
- source_shape = value_shape
1207
+ source_shape = const_utils.get_source_shape(data_shape, value_shape)
1212
1208
  value = F.reshape(value, source_shape)
1213
1209
  data = F.broadcast_to(value, data_shape)
1214
1210
  return data
@@ -1233,10 +1229,7 @@ def tensor_setitem_by_bool(data, index, value):
1233
1229
 
1234
1230
  if index:
1235
1231
  value_shape = F.shape(value)
1236
- if len(value_shape) > len(data_shape):
1237
- source_shape = data_shape
1238
- else:
1239
- source_shape = value_shape
1232
+ source_shape = const_utils.get_source_shape(data_shape, value_shape)
1240
1233
  value = F.reshape(value, source_shape)
1241
1234
  data = F.broadcast_to(value, data_shape)
1242
1235
  return data
@@ -1403,7 +1396,7 @@ def reduce_(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None,
1403
1396
  return reduce_fn(a, axes).astype(dtype)
1404
1397
 
1405
1398
 
1406
- setattr(tensor_operator_registry, "reduce", reduce_)
1399
+ tensor_operator_registry.register("reduce", reduce_)
1407
1400
 
1408
1401
 
1409
1402
  def check_indices(dims, indices, mode, allow_negative_index=True):
@@ -1430,7 +1423,7 @@ def check_indices(dims, indices, mode, allow_negative_index=True):
1430
1423
  return clipped
1431
1424
 
1432
1425
 
1433
- setattr(tensor_operator_registry, 'check_indices', check_indices)
1426
+ tensor_operator_registry.register('check_indices', check_indices)
1434
1427
 
1435
1428
 
1436
1429
  def convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, slice_shapes, fancy_position):
@@ -216,7 +216,7 @@ def make_tensor(a, dtype=mstype.int64, data_shape=None, dim_size=None):
216
216
 
217
217
  return Tensor(a, dtype)
218
218
 
219
- setattr(tensor_operator_registry, 'make_tensor', make_tensor)
219
+ tensor_operator_registry.register('make_tensor', make_tensor)
220
220
 
221
221
 
222
222
  def judge_data_dim(data_dim, min_data_dim=0, max_data_dim=8):
@@ -33,7 +33,7 @@ from . import (
33
33
  nn_func,
34
34
  )
35
35
 
36
- from .array_func import gather, max, min, one_hot
36
+ from .array_func import gather, max, min, one_hot, narrow
37
37
  from .math_func import (
38
38
  baddbmm,
39
39
  bmm,
@@ -44,7 +44,8 @@ from .math_func import (
44
44
  from .nn_func import (
45
45
  conv2d,
46
46
  max_pool2d,
47
- leaky_relu_ext
47
+ leaky_relu_ext,
48
+ batch_norm
48
49
  )
49
50
 
50
51
  __all__ = []
@@ -21,12 +21,52 @@ Array Operators
21
21
  from mindspore.common import Tensor
22
22
  from mindspore.ops.operations.array_ops import ArgMaxWithValue, ArgMinWithValue
23
23
  from mindspore.ops._primitive_cache import _get_cache_prim
24
- from mindspore.ops.auto_generate.gen_ops_prim import gather_d_op
24
+ from mindspore.ops.auto_generate.gen_ops_prim import gather_d_op, slice_ext_op, OneHotExt
25
25
  from mindspore.ops.auto_generate.gen_ops_def import max_, min_
26
- from mindspore.ops.auto_generate.pyboost_inner_prim import _PyboostOneHotExtPrim
27
- one_hot_ext_impl = _PyboostOneHotExtPrim()
26
+ from mindspore import _checkparam as validator
27
+
28
28
 
29
29
  # define Primitive global variables
30
+ def narrow(input, dim, start, length):
31
+ """
32
+ Returns a narrowed tensor from input tensor, and
33
+ the dimension axis is input from start to start + length.
34
+
35
+ Args:
36
+ input (Tensor): the tensor to narrow.
37
+ dim (int): dimension along which to narrow.
38
+ start (int): the starting dimension.
39
+ length (int): the distance to the ending dimension.
40
+
41
+ Returns:
42
+ Tensor.
43
+
44
+ - output (Tensors) - The narrowed tensor.
45
+
46
+ Raises:
47
+ TypeError: If the input is not a tensor or tuple or list of tensors.
48
+
49
+ Supported Platforms:
50
+ ``Ascend`` ``GPU`` ``CPU``
51
+
52
+ Examples:
53
+ >>> import mindspore
54
+ >>> from mindspore import ops
55
+ >>> from mindspore import Tensor
56
+ >>> x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mindspore.int32)
57
+ >>> output = ops.narrow(x, 0, 0, 2)
58
+ >>> print(output)
59
+ [[ 1 2 3]
60
+ [ 4 5 6]]
61
+ >>> output = ops.narrow(x, 1, 1, 2)
62
+ >>> print(output)
63
+ [[ 2 3]
64
+ [ 5 6]
65
+ [ 8 9]]
66
+ """
67
+ validator.check_value_type("input", input, Tensor, "narrow")
68
+ return slice_ext_op(input, dim, start, start+length, 1)
69
+
30
70
 
31
71
  def gather(input, dim, index):
32
72
  r"""
@@ -48,7 +88,7 @@ def gather(input, dim, index):
48
88
  index (Tensor): The index tensor, with int32 or int64 data type. An valid `index` should be:
49
89
 
50
90
  - `index.rank == input.rank`;
51
- - for `axis != dim`, `index.shape[axis] <= input.shape[axis]`;
91
+ - `index.shape[axis] <= input.shape[axis]` where axis goes through all dimensions of `input` except `dim`;
52
92
  - the value of `index` is in range `[-input.shape[dim], input.shape[dim])`.
53
93
 
54
94
  Returns:
@@ -72,7 +112,7 @@ def gather(input, dim, index):
72
112
  >>> output = ops.extend.gather(input, 1, index)
73
113
  >>> print(output)
74
114
  [[-0.1 -0.1]
75
- [0.5 0.5]]
115
+ [ 0.5 0.5]]
76
116
  """
77
117
  return gather_d_op(input, dim, index)
78
118
 
@@ -92,7 +132,7 @@ def max(input, dim=None, keepdim=False):
92
132
  and same dtype as `input`.
93
133
 
94
134
  tuple (Tensor) if `dim` is not the default value ``None`` , tuple of 2 tensors, containing the maximum
95
- value of the input tensor along the given dimension `dim` and the corresponding index.
135
+ value of the input tensor along the given dimension `dim` and the corresponding index:
96
136
 
97
137
  - **values (Tensor)** - The maximum value of input tensor along the given dimension `dim`, with same dtype as
98
138
  `input`. If `keepdim` is ``True`` , the shape of output tensors is :math:`(input_1, input_2, ...,
@@ -141,7 +181,7 @@ def min(input, dim=None, keepdim=False):
141
181
  and same dtype as `input`.
142
182
 
143
183
  tuple (Tensor) if `dim` is not the default value ``None`` , tuple of 2 tensors, containing the minimum value
144
- of the input tensor along the given dimension `dim` and the corresponding index.
184
+ of the input tensor along the given dimension `dim` and the corresponding index:
145
185
 
146
186
  - **values (Tensor)** - The minimum value of input tensor along the given dimension `dim`, with same dtype as
147
187
  `input`. If `keepdim` is ``True`` , the shape of output tensors is :math:`(input_1, input_2, ...,
@@ -184,7 +224,7 @@ def one_hot(tensor, num_classes):
184
224
  Args:
185
225
  tensor (Tensor): A tensor of indices. Tensor of shape :math:`(X_0, \ldots, X_n)`.
186
226
  Data type must be int32 or int64.
187
- num_classes (int): A scalar defining the depth of the one-hot dimension.
227
+ num_classes (Union[int, Tensor]): A scalar defining the depth of the one-hot dimension.
188
228
 
189
229
  Returns:
190
230
  Tensor, one-hot tensor.
@@ -200,7 +240,7 @@ def one_hot(tensor, num_classes):
200
240
  Examples:
201
241
  >>> import mindspore
202
242
  >>> import numpy as np
203
- >>> from mindspore import ops
243
+ >>> import mindspore.ops as ops
204
244
  >>> from mindspore import Tensor
205
245
  >>> tensor = Tensor(np.array([0, 1, 2]), mindspore.int32)
206
246
  >>> num_classes = 3
@@ -212,7 +252,8 @@ def one_hot(tensor, num_classes):
212
252
  """
213
253
  on_value = Tensor(1, dtype=tensor.dtype)
214
254
  off_value = Tensor(0, dtype=tensor.dtype)
215
- return one_hot_ext_impl(tensor, num_classes, on_value, off_value, -1)
255
+ onehot = _get_cache_prim(OneHotExt)(-1)
256
+ return onehot(tensor, num_classes, on_value, off_value)
216
257
 
217
258
 
218
259
  __all__ = ['gather', 'max', 'min', 'one_hot']
@@ -21,6 +21,8 @@ NN Operators with better performance
21
21
  from mindspore.ops._primitive_cache import _get_cache_prim
22
22
  from mindspore.ops.auto_generate.gen_ops_prim import Convolution, ConstantPadND, MaxPoolWithIndices, MaxPoolWithMask
23
23
  from mindspore.ops.auto_generate import leaky_relu_ext
24
+ from mindspore.ops.auto_generate import BatchNormExt
25
+ from mindspore import ops
24
26
  from mindspore import _checkparam as validator
25
27
 
26
28
 
@@ -232,7 +234,7 @@ def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, *, ceil_m
232
234
  \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n)
233
235
 
234
236
  .. warning::
235
- Only support on Atlas A2 training series.
237
+ Only support on Atlas training series.
236
238
 
237
239
  Args:
238
240
  input (Tensor): Tensor of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})` with data type of float32
@@ -305,4 +307,78 @@ def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, *, ceil_m
305
307
  return out
306
308
 
307
309
 
308
- __all__ = ['conv2d', 'max_pool2d', 'leaky_relu_ext']
310
+ def batch_norm(input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-5):
311
+ r"""
312
+ Batch Normalization for input data and updated parameters.
313
+
314
+ Batch Normalization is widely used in convolutional neural networks. This operation
315
+ applies Batch Normalization over inputs to avoid internal covariate shift as described
316
+ in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal
317
+ Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
318
+ features using a mini-batch of data and the learned parameters can be described
319
+ in the following formula,
320
+
321
+ .. math::
322
+
323
+ y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
324
+
325
+ where :math:`\gamma` is `weight`, :math:`\beta` is `bias`, :math:`\epsilon` is `eps`, :math:`mean` is the
326
+ mean of :math:`x`, :math:`variance` is the variance of :math:`x`.
327
+
328
+ .. warning::
329
+ - For Atlas 200/300/500 inference product,
330
+ the result accuracy fails to reach 1‰ due to the square root instruction.
331
+
332
+ Note:
333
+ - If `training` is `False`, `weight`, `bias`, `running_mean` and `running_var` are Tensors.
334
+ - If `training` is `True`, `weight`, `bias`, `running_mean` and `running_var` are Parameters.
335
+
336
+ Args:
337
+ input (Tensor): Tensor of shape :math:`(N, C, *)`, with bfloat16, float16 or float32 data type.
338
+ running_mean (Union[Tensor, Parameter]): The shape :math:`(C,)`, has the same data type with `weight`.
339
+ running_var (Union[Tensor, Parameter]): The shape :math:`(C,)`, has the same data type with `weight`.
340
+ weight (Union[Tensor, Parameter]): The shape :math:`(C,)`, with bfloat, float16 or float32 data type.
341
+ bias (Union[Tensor, Parameter]): The shape :math:`(C,)`, has the same data type with `weight`.
342
+ training (bool, optional): If `training` is `True`, `mean` and `variance` are computed during training.
343
+ If `training` is `False`, they're loaded from checkpoint during inference. Default: ``False`` .
344
+ momentum (float, optional): The hyper parameter to compute moving average for `running_mean` and `running_var`
345
+ (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
346
+ Default: ``0.1`` .
347
+ eps (float, optional): A small value added for numerical stability. Default: ``1e-5``.
348
+
349
+ Returns:
350
+ output_x (Tensor) - The same type and shape as the `input_x`. The shape is :math:`(N, C, *)`.
351
+
352
+ Raises:
353
+ TypeError: If `training` is not a bool.
354
+ TypeError: If dtype of `eps` or `momentum` is not float.
355
+ TypeError: If `input_x`, `weight`, `bias`, `running_mean` or `running_var` is not a Tensor.
356
+ TypeError: If dtype of `input_x`, `weight` is not bfloat16, float16 or float32.
357
+
358
+ Supported Platforms:
359
+ ``Ascend``
360
+
361
+ Examples:
362
+ >>> import mindspore
363
+ >>> from mindspore import Tensor, ops
364
+ >>> input_x = Tensor([[1.0, 2.0], [3.0, 4.0]], mindspore.float32)
365
+ >>> running_mean = Tensor([0.5, 1.5], mindspore.float32)
366
+ >>> running_var = Tensor([0.1, 0.2], mindspore.float32)
367
+ >>> weight = Tensor([2.0, 2.0], mindspore.float32)
368
+ >>> bias = Tensor([-1.0, -1.0], mindspore.float32)
369
+ >>> output = ops.batch_norm(input_x, running_mean, running_var, weight, bias)
370
+ >>> print(output)
371
+ [[ 2.1621194 1.2360122]
372
+ [14.810596 10.180061 ]]
373
+ """
374
+ if weight is None:
375
+ weight = ops.ones([input.shape[1]], dtype=input.dtype)
376
+ if bias is None:
377
+ bias = ops.zeros([input.shape[1]], dtype=input.dtype)
378
+ batch_norm_op = BatchNormExt(training=training, momentum=momentum, eps=eps)
379
+ output = batch_norm_op(input, weight, bias, running_mean, running_var)
380
+ return output[0]
381
+
382
+
383
+
384
+ __all__ = ['conv2d', 'max_pool2d', 'leaky_relu_ext', 'batch_norm']
@@ -26,8 +26,7 @@ from . import (
26
26
  nn_func,
27
27
  linalg_func,
28
28
  clip_func,
29
- fft_func,
30
- reshard_func
29
+ fft_func
31
30
  )
32
31
  from .array_func import (
33
32
  unique,
@@ -100,6 +99,7 @@ from .array_func import (
100
99
  tensor_scatter_elements,
101
100
  scatter,
102
101
  scatter_add,
102
+ scatter_add_ext,
103
103
  scatter_mul,
104
104
  scatter_max,
105
105
  scatter_min,
@@ -138,7 +138,6 @@ from .array_func import (
138
138
  index_fill,
139
139
  index_select,
140
140
  max,
141
- argmax,
142
141
  min,
143
142
  population_count,
144
143
  topk,
@@ -166,6 +165,9 @@ from .array_func import (
166
165
  top_k,
167
166
  deepcopy,
168
167
  arange_ext,
168
+ zeros_like_ext,
169
+ ones_like_ext,
170
+ full_ext,
169
171
  )
170
172
  from .parameter_func import (
171
173
  assign,
@@ -178,6 +180,7 @@ from .math_func import (
178
180
  addn,
179
181
  absolute,
180
182
  abs,
183
+ argmax,
181
184
  argmin,
182
185
  angle,
183
186
  bincount,
@@ -197,6 +200,7 @@ from .math_func import (
197
200
  le,
198
201
  lerp,
199
202
  norm,
203
+ norm_ext,
200
204
  vector_norm,
201
205
  matrix_norm,
202
206
  round,
@@ -264,6 +268,7 @@ from .math_func import (
264
268
  matrix_determinant,
265
269
  det,
266
270
  linspace,
271
+ linspace_ext,
267
272
  lu_solve,
268
273
  matrix_solve,
269
274
  maximum,
@@ -369,6 +374,7 @@ from .math_func import (
369
374
  amin,
370
375
  amax,
371
376
  mean,
377
+ mean_ext,
372
378
  prod,
373
379
  all,
374
380
  any,
@@ -456,7 +462,6 @@ from .nn_func import (
456
462
  max_pool2d,
457
463
  max_pool3d,
458
464
  batch_norm,
459
- rms_norm,
460
465
  bidense,
461
466
  celu,
462
467
  bias_add,
@@ -505,6 +510,7 @@ from .nn_func import (
505
510
  softplus,
506
511
  pdist,
507
512
  pad,
513
+ pad_ext,
508
514
  prelu,
509
515
  mirror_pad,
510
516
  nll_loss,
@@ -543,6 +549,7 @@ from .nn_func import (
543
549
  channel_shuffle,
544
550
  hardsigmoid,
545
551
  group_norm,
552
+ dropout_ext,
546
553
  )
547
554
  from .linalg_func import (
548
555
  cond,
@@ -593,6 +600,7 @@ from .random_func import (
593
600
  standard_laplace,
594
601
  random_categorical,
595
602
  uniform,
603
+ uniform_ext,
596
604
  standard_normal,
597
605
  random_gamma,
598
606
  uniform_candidate_sampler,
@@ -600,6 +608,7 @@ from .random_func import (
600
608
  log_uniform_candidate_sampler,
601
609
  shuffle,
602
610
  choice_with_mask,
611
+ normal_ext,
603
612
  normal,
604
613
  laplace,
605
614
  gamma,
@@ -733,9 +742,6 @@ from .other_func import (
733
742
  depend,
734
743
  partial,
735
744
  )
736
- from .reshard_func import (
737
- reshard,
738
- )
739
745
 
740
746
  from ..operations.manually_defined import (rank, scalar_cast)
741
747
 
@@ -756,5 +762,4 @@ __all__.extend(sparse_unary_func.__all__)
756
762
  __all__.extend(clip_func.__all__)
757
763
  __all__.extend(fft_func.__all__)
758
764
  __all__.extend(other_func.__all__)
759
- __all__.extend(reshard_func.__all__)
760
765
  __all__.sort()