mindspore 2.1.0__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (550)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +49 -16
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  20. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  21. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  22. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  23. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  24. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  25. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  26. mindspore/_check_jit_forbidden_api.py +3 -1
  27. mindspore/_checkparam.py +26 -32
  28. mindspore/_extends/graph_kernel/__init__.py +0 -1
  29. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  30. mindspore/_extends/graph_kernel/splitter.py +1 -9
  31. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  32. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  33. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  34. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  35. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
  36. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  37. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  38. mindspore/_extends/parse/__init__.py +12 -15
  39. mindspore/_extends/parse/namespace.py +7 -33
  40. mindspore/_extends/parse/parser.py +61 -71
  41. mindspore/_extends/parse/resources.py +1 -1
  42. mindspore/_extends/parse/standard_method.py +72 -95
  43. mindspore/_extends/parse/trope.py +1 -1
  44. mindspore/_extends/remote/kernel_build_server.py +24 -7
  45. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  46. mindspore/_install_custom.py +43 -0
  47. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  48. mindspore/amp.py +47 -11
  49. mindspore/bin/cache_admin +0 -0
  50. mindspore/bin/cache_server +0 -0
  51. mindspore/boost/boost.py +1 -8
  52. mindspore/boost/boost_cell_wrapper.py +3 -2
  53. mindspore/boost/grad_accumulation.py +1 -1
  54. mindspore/boost/group_loss_scale_manager.py +8 -7
  55. mindspore/common/__init__.py +5 -3
  56. mindspore/common/_jit_fallback_utils.py +6 -0
  57. mindspore/common/_register_for_adapter.py +2 -0
  58. mindspore/common/_register_for_tensor.py +2 -2
  59. mindspore/common/_stub_tensor.py +13 -0
  60. mindspore/common/_utils.py +13 -0
  61. mindspore/common/api.py +173 -258
  62. mindspore/common/auto_dynamic_shape.py +498 -0
  63. mindspore/common/dtype.py +18 -11
  64. mindspore/common/dump.py +6 -4
  65. mindspore/common/initializer.py +14 -14
  66. mindspore/common/jit_config.py +33 -15
  67. mindspore/common/lazy_inline.py +126 -7
  68. mindspore/common/mindir_util.py +101 -0
  69. mindspore/common/parameter.py +51 -41
  70. mindspore/common/seed.py +4 -4
  71. mindspore/common/sparse_tensor.py +13 -14
  72. mindspore/common/tensor.py +240 -145
  73. mindspore/communication/__init__.py +7 -4
  74. mindspore/communication/_comm_helper.py +83 -4
  75. mindspore/communication/management.py +152 -84
  76. mindspore/config/op_info.config +13 -2
  77. mindspore/config/super_bar_config.json +4 -2
  78. mindspore/context.py +143 -59
  79. mindspore/dataset/__init__.py +5 -5
  80. mindspore/dataset/audio/__init__.py +2 -2
  81. mindspore/dataset/audio/transforms.py +52 -52
  82. mindspore/dataset/callback/ds_callback.py +16 -2
  83. mindspore/dataset/core/config.py +68 -51
  84. mindspore/dataset/engine/cache_client.py +28 -5
  85. mindspore/dataset/engine/datasets.py +250 -112
  86. mindspore/dataset/engine/datasets_audio.py +43 -211
  87. mindspore/dataset/engine/datasets_standard_format.py +11 -35
  88. mindspore/dataset/engine/datasets_text.py +43 -67
  89. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  90. mindspore/dataset/engine/datasets_vision.py +219 -1029
  91. mindspore/dataset/engine/iterators.py +11 -4
  92. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  93. mindspore/dataset/engine/obs/util.py +3 -0
  94. mindspore/dataset/engine/samplers.py +1 -1
  95. mindspore/dataset/engine/validators.py +19 -5
  96. mindspore/dataset/text/__init__.py +3 -3
  97. mindspore/dataset/text/transforms.py +101 -127
  98. mindspore/dataset/text/utils.py +205 -138
  99. mindspore/dataset/transforms/__init__.py +1 -1
  100. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  101. mindspore/dataset/transforms/transforms.py +95 -40
  102. mindspore/dataset/utils/browse_dataset.py +8 -2
  103. mindspore/dataset/utils/line_reader.py +17 -19
  104. mindspore/dataset/vision/__init__.py +3 -3
  105. mindspore/dataset/vision/c_transforms.py +6 -3
  106. mindspore/dataset/vision/transforms.py +409 -287
  107. mindspore/dataset/vision/utils.py +13 -14
  108. mindspore/dataset/vision/validators.py +11 -1
  109. mindspore/experimental/map_parameter.py +14 -0
  110. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  111. mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
  112. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  113. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  114. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  115. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  116. mindspore/gen_ops.py +273 -0
  117. mindspore/include/OWNERS +0 -1
  118. mindspore/include/api/data_type.h +2 -1
  119. mindspore/include/api/graph.h +0 -15
  120. mindspore/include/api/kernel.h +2 -0
  121. mindspore/include/api/kernel_api.h +37 -12
  122. mindspore/include/api/model.h +0 -14
  123. mindspore/include/api/types.h +37 -4
  124. mindspore/include/c_api/ms/abstract.h +67 -0
  125. mindspore/include/c_api/ms/attribute.h +197 -0
  126. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  127. mindspore/include/c_api/ms/base/macros.h +32 -0
  128. mindspore/include/c_api/ms/base/status.h +33 -0
  129. mindspore/include/c_api/ms/base/types.h +282 -0
  130. mindspore/include/c_api/ms/context.h +102 -0
  131. mindspore/include/c_api/ms/graph.h +160 -0
  132. mindspore/include/c_api/ms/node.h +606 -0
  133. mindspore/include/c_api/ms/tensor.h +161 -0
  134. mindspore/include/c_api/ms/value.h +84 -0
  135. mindspore/include/dataset/constants.h +6 -5
  136. mindspore/include/dataset/execute.h +23 -13
  137. mindspore/include/dataset/text.h +26 -26
  138. mindspore/include/dataset/transforms.h +13 -13
  139. mindspore/include/dataset/vision.h +60 -60
  140. mindspore/include/dataset/vision_ascend.h +5 -6
  141. mindspore/include/dataset/vision_lite.h +17 -17
  142. mindspore/include/mindapi/base/type_id.h +1 -0
  143. mindspore/include/mindapi/base/types.h +1 -0
  144. mindspore/lib/libdnnl.so.2 +0 -0
  145. mindspore/lib/libjemalloc.so.2 +0 -0
  146. mindspore/lib/libmindspore.so +0 -0
  147. mindspore/lib/libmindspore_backend.so +0 -0
  148. mindspore/lib/libmindspore_common.so +0 -0
  149. mindspore/lib/libmindspore_core.so +0 -0
  150. mindspore/lib/libmindspore_glog.so.0 +0 -0
  151. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  152. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  153. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  154. mindspore/lib/libmindspore_shared_lib.so +0 -0
  155. mindspore/lib/libnnacl.so +0 -0
  156. mindspore/lib/libopencv_core.so.4.5 +0 -0
  157. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  158. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  159. mindspore/lib/libps_cache.so +0 -0
  160. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  161. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  163. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  164. mindspore/lib/plugin/ascend/libakg.so +0 -0
  165. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  166. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  167. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  168. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  169. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  170. mindspore/lib/plugin/cpu/libakg.so +0 -0
  171. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  172. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  173. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  174. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  175. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  176. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  177. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  178. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  179. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  180. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  181. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  182. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  183. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  184. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  185. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  186. mindspore/nn/__init__.py +0 -2
  187. mindspore/nn/cell.py +316 -74
  188. mindspore/nn/dynamic_lr.py +21 -21
  189. mindspore/nn/layer/activation.py +21 -28
  190. mindspore/nn/layer/basic.py +15 -13
  191. mindspore/nn/layer/channel_shuffle.py +1 -1
  192. mindspore/nn/layer/container.py +271 -9
  193. mindspore/nn/layer/conv.py +310 -207
  194. mindspore/nn/layer/dense.py +8 -5
  195. mindspore/nn/layer/embedding.py +33 -27
  196. mindspore/nn/layer/flash_attention.py +82 -41
  197. mindspore/nn/layer/image.py +8 -6
  198. mindspore/nn/layer/math.py +13 -18
  199. mindspore/nn/layer/normalization.py +107 -66
  200. mindspore/nn/layer/padding.py +1 -1
  201. mindspore/nn/layer/pooling.py +131 -109
  202. mindspore/nn/layer/rnn_cells.py +22 -17
  203. mindspore/nn/layer/rnns.py +13 -16
  204. mindspore/nn/layer/thor_layer.py +1 -1
  205. mindspore/nn/layer/transformer.py +221 -154
  206. mindspore/nn/learning_rate_schedule.py +9 -1
  207. mindspore/nn/loss/loss.py +235 -174
  208. mindspore/nn/optim/ada_grad.py +2 -1
  209. mindspore/nn/optim/adadelta.py +1 -0
  210. mindspore/nn/optim/adafactor.py +2 -1
  211. mindspore/nn/optim/adam.py +7 -4
  212. mindspore/nn/optim/adamax.py +3 -2
  213. mindspore/nn/optim/adasum.py +2 -2
  214. mindspore/nn/optim/asgd.py +2 -3
  215. mindspore/nn/optim/ftrl.py +6 -5
  216. mindspore/nn/optim/lamb.py +7 -4
  217. mindspore/nn/optim/lars.py +1 -1
  218. mindspore/nn/optim/lazyadam.py +5 -3
  219. mindspore/nn/optim/momentum.py +2 -1
  220. mindspore/nn/optim/optimizer.py +53 -4
  221. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  222. mindspore/nn/optim/rmsprop.py +4 -3
  223. mindspore/nn/optim/rprop.py +23 -12
  224. mindspore/nn/optim/sgd.py +26 -11
  225. mindspore/nn/optim/thor.py +9 -7
  226. mindspore/nn/probability/bijector/bijector.py +5 -5
  227. mindspore/nn/probability/bijector/power_transform.py +27 -27
  228. mindspore/nn/probability/bijector/softplus.py +3 -3
  229. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  230. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  231. mindspore/nn/probability/distribution/beta.py +3 -3
  232. mindspore/nn/probability/distribution/categorical.py +7 -7
  233. mindspore/nn/probability/distribution/cauchy.py +0 -1
  234. mindspore/nn/probability/distribution/distribution.py +3 -3
  235. mindspore/nn/probability/distribution/gamma.py +3 -3
  236. mindspore/nn/probability/distribution/geometric.py +4 -4
  237. mindspore/nn/probability/distribution/gumbel.py +4 -4
  238. mindspore/nn/probability/distribution/log_normal.py +2 -2
  239. mindspore/nn/probability/distribution/logistic.py +2 -2
  240. mindspore/nn/probability/distribution/poisson.py +4 -4
  241. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  242. mindspore/nn/probability/distribution/uniform.py +6 -6
  243. mindspore/nn/wrap/cell_wrapper.py +78 -34
  244. mindspore/nn/wrap/grad_reducer.py +8 -5
  245. mindspore/nn/wrap/loss_scale.py +105 -42
  246. mindspore/numpy/array_creations.py +1 -2
  247. mindspore/numpy/array_ops.py +3 -2
  248. mindspore/offline_debug/convert_async.py +2 -2
  249. mindspore/ops/_grad_experimental/__init__.py +0 -5
  250. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
  251. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  252. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  253. mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
  254. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  255. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
  256. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  257. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  258. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  259. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  260. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  261. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  262. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  263. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  264. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  265. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  266. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  267. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  268. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  269. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  270. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  271. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  272. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  273. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  274. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  275. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  276. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  277. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  278. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  279. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  280. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  281. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  282. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  283. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  284. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  285. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  286. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  287. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  288. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  289. mindspore/ops/_primitive_cache.py +1 -1
  290. mindspore/ops/_tracefunc.py +45 -13
  291. mindspore/ops/_utils/utils.py +4 -1
  292. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  293. mindspore/ops/_vmap/vmap_base.py +3 -3
  294. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  295. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  296. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  297. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  298. mindspore/ops/arg_dtype_cast.py +54 -0
  299. mindspore/ops/composite/base.py +37 -10
  300. mindspore/ops/composite/math_ops.py +5 -4
  301. mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
  302. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  303. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  304. mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
  305. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  306. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  308. mindspore/ops/deprecated.py +304 -0
  309. mindspore/ops/function/__init__.py +4 -1
  310. mindspore/ops/function/array_func.py +167 -189
  311. mindspore/ops/function/clip_func.py +81 -13
  312. mindspore/ops/function/debug_func.py +1 -1
  313. mindspore/ops/function/grad/grad_func.py +18 -8
  314. mindspore/ops/function/image_func.py +10 -4
  315. mindspore/ops/function/linalg_func.py +5 -5
  316. mindspore/ops/function/math_func.py +575 -386
  317. mindspore/ops/function/nn_func.py +470 -251
  318. mindspore/ops/function/random_func.py +86 -56
  319. mindspore/ops/function/sparse_func.py +1 -1
  320. mindspore/ops/function/sparse_unary_func.py +14 -12
  321. mindspore/ops/function/vmap_func.py +6 -5
  322. mindspore/ops/functional.py +15 -10
  323. mindspore/ops/op_info_register.py +235 -19
  324. mindspore/ops/operations/__init__.py +25 -17
  325. mindspore/ops/operations/_grad_ops.py +52 -7
  326. mindspore/ops/operations/_inner_ops.py +213 -12
  327. mindspore/ops/operations/_quant_ops.py +4 -8
  328. mindspore/ops/operations/_sequence_ops.py +42 -0
  329. mindspore/ops/operations/array_ops.py +64 -280
  330. mindspore/ops/operations/comm_ops.py +105 -57
  331. mindspore/ops/operations/custom_ops.py +10 -3
  332. mindspore/ops/operations/debug_ops.py +8 -4
  333. mindspore/ops/operations/image_ops.py +18 -12
  334. mindspore/ops/operations/math_ops.py +185 -138
  335. mindspore/ops/operations/nn_ops.py +716 -492
  336. mindspore/ops/operations/other_ops.py +0 -22
  337. mindspore/ops/operations/random_ops.py +53 -111
  338. mindspore/ops/operations/sparse_ops.py +3 -1
  339. mindspore/ops/primitive.py +24 -18
  340. mindspore/parallel/_auto_parallel_context.py +68 -8
  341. mindspore/parallel/_cost_model_context.py +2 -2
  342. mindspore/parallel/_offload_context.py +17 -3
  343. mindspore/parallel/_parallel_serialization.py +2 -2
  344. mindspore/parallel/_ps_context.py +12 -0
  345. mindspore/parallel/_tensor.py +14 -12
  346. mindspore/parallel/_transformer/layers.py +5 -3
  347. mindspore/parallel/_transformer/loss.py +1 -0
  348. mindspore/parallel/_transformer/moe.py +2 -2
  349. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  350. mindspore/parallel/_transformer/transformer.py +23 -3
  351. mindspore/parallel/_utils.py +11 -7
  352. mindspore/parallel/algo_parameter_config.py +85 -5
  353. mindspore/parallel/checkpoint_transform.py +6 -10
  354. mindspore/parallel/shard.py +4 -4
  355. mindspore/profiler/common/struct_type.py +3 -3
  356. mindspore/profiler/common/util.py +3 -2
  357. mindspore/profiler/envprofiling.py +1 -1
  358. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  359. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  360. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  361. mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
  362. mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
  363. mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
  364. mindspore/profiler/parser/ascend_op_generator.py +5 -5
  365. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  366. mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
  367. mindspore/profiler/parser/base_timeline_generator.py +9 -7
  368. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
  369. mindspore/profiler/parser/flops_parser.py +15 -11
  370. mindspore/profiler/parser/framework_parser.py +37 -21
  371. mindspore/profiler/parser/hccl_parser.py +16 -12
  372. mindspore/profiler/parser/integrator.py +22 -11
  373. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  374. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  375. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  376. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  377. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  378. mindspore/profiler/parser/optime_parser.py +1 -1
  379. mindspore/profiler/parser/profiler_info.py +2 -2
  380. mindspore/profiler/parser/step_trace_parser.py +11 -14
  381. mindspore/profiler/profiling.py +139 -71
  382. mindspore/rewrite/api/node.py +102 -19
  383. mindspore/rewrite/api/node_type.py +5 -1
  384. mindspore/rewrite/api/scoped_value.py +9 -17
  385. mindspore/rewrite/api/symbol_tree.py +131 -47
  386. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  387. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  388. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  389. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  390. mindspore/rewrite/common/rewrite_elog.py +5 -1
  391. mindspore/rewrite/namer.py +33 -24
  392. mindspore/rewrite/namespace.py +14 -5
  393. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  394. mindspore/rewrite/node/call_function.py +79 -0
  395. mindspore/rewrite/node/cell_container.py +135 -0
  396. mindspore/rewrite/node/control_flow.py +88 -0
  397. mindspore/rewrite/{node.py → node/node.py} +273 -234
  398. mindspore/rewrite/node/node_manager.py +254 -0
  399. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  400. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  401. mindspore/rewrite/parsers/assign_parser.py +216 -221
  402. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  403. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  404. mindspore/rewrite/parsers/constant_parser.py +9 -6
  405. mindspore/rewrite/parsers/container_parser.py +9 -7
  406. mindspore/rewrite/parsers/for_parser.py +36 -15
  407. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  408. mindspore/rewrite/parsers/if_parser.py +28 -24
  409. mindspore/rewrite/parsers/module_parser.py +196 -25
  410. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  411. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  412. mindspore/rewrite/parsers/return_parser.py +6 -6
  413. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  414. mindspore/rewrite/sparsify/utils.py +1 -1
  415. mindspore/rewrite/symbol_tree.py +525 -577
  416. mindspore/rewrite/symbol_tree_builder.py +9 -193
  417. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  418. mindspore/run_check/_check_version.py +2 -2
  419. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  420. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  421. mindspore/scipy/linalg.py +1 -1
  422. mindspore/scipy/optimize/minimize.py +7 -3
  423. mindspore/train/_utils.py +7 -3
  424. mindspore/train/amp.py +323 -123
  425. mindspore/train/anf_ir_pb2.py +14 -2
  426. mindspore/train/callback/_backup_and_restore.py +2 -12
  427. mindspore/train/callback/_callback.py +29 -4
  428. mindspore/train/callback/_checkpoint.py +23 -8
  429. mindspore/train/callback/_early_stop.py +2 -2
  430. mindspore/train/callback/_landscape.py +4 -4
  431. mindspore/train/callback/_loss_monitor.py +2 -2
  432. mindspore/train/callback/_on_request_exit.py +2 -2
  433. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  434. mindspore/train/callback/_summary_collector.py +14 -7
  435. mindspore/train/callback/_time_monitor.py +58 -5
  436. mindspore/train/data_sink.py +5 -11
  437. mindspore/train/dataset_helper.py +83 -57
  438. mindspore/train/loss_scale_manager.py +2 -2
  439. mindspore/train/metrics/__init__.py +3 -3
  440. mindspore/train/metrics/cosine_similarity.py +1 -1
  441. mindspore/train/metrics/hausdorff_distance.py +3 -2
  442. mindspore/train/metrics/mean_surface_distance.py +3 -2
  443. mindspore/train/metrics/metric.py +39 -19
  444. mindspore/train/metrics/roc.py +2 -2
  445. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  446. mindspore/train/mind_ir_pb2.py +85 -36
  447. mindspore/train/model.py +185 -45
  448. mindspore/train/serialization.py +390 -150
  449. mindspore/train/summary/_writer_pool.py +3 -2
  450. mindspore/train/summary/summary_record.py +14 -10
  451. mindspore/train/train_thor/convert_utils.py +3 -3
  452. mindspore/train/train_thor/dataset_helper.py +1 -1
  453. mindspore/version.py +1 -1
  454. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
  455. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +458 -518
  456. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  457. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  458. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  459. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  460. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  461. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  462. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  463. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  464. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  465. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  466. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  467. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  468. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  469. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  470. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  471. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  472. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  473. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  474. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  475. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  476. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  477. mindspore/_extends/graph_kernel/expander.py +0 -80
  478. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  479. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  480. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  481. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  482. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  483. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  484. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  485. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  486. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  487. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  488. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  489. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  490. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  491. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  492. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  493. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  494. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  495. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  496. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  497. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  498. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  499. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  500. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  501. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  502. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  503. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  504. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  505. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  506. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  507. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  508. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  509. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  510. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  511. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  512. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  513. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  514. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  515. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  516. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  517. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  518. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  519. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  520. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  521. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  522. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  523. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  524. mindspore/dataset/datapreprocess/__init__.py +0 -20
  525. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  526. mindspore/include/api/net.h +0 -142
  527. mindspore/nn/lr_scheduler.py +0 -262
  528. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  529. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  530. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  531. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  532. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  533. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  534. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  537. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  539. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  541. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  542. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  543. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  546. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  547. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  548. mindspore/rewrite/node_visitor.py +0 -44
  549. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  550. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -77,6 +77,7 @@ class BiDense(Cell):
77
77
  bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter.
78
78
  The values of str refer to the function `initializer`. Default: ``None`` .
79
79
  has_bias (bool): Specifies whether the layer uses :math:`\text{bias}` vector. Default: ``True`` .
80
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
80
81
 
81
82
  Shape:
82
83
  - **input1** - :math:`(*, H_{in1})` where :math:`H_{in1}=\text{in1_channels}` and
@@ -90,8 +91,8 @@ class BiDense(Cell):
90
91
  are the same shape as the inputs.
91
92
 
92
93
  Dtype:
93
- - **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input2**.
94
- - **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input1**.
94
+ - **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input2** .
95
+ - **input2** (Tensor) - The dtype must be float16 or float32 and be same as **input1** .
95
96
  - **output** (Tensor) - With the same dtype as the inputs.
96
97
 
97
98
  Weights:
@@ -133,7 +134,8 @@ class BiDense(Cell):
133
134
  out_channels,
134
135
  weight_init=None,
135
136
  bias_init=None,
136
- has_bias=True):
137
+ has_bias=True,
138
+ dtype=mstype.float32):
137
139
  super().__init__()
138
140
  self.in_channels = Validator.check_positive_int(in1_channels, "in1_channels", self.cls_name)
139
141
  self.in_channels = Validator.check_positive_int(in2_channels, "in2_channels", self.cls_name)
@@ -156,7 +158,8 @@ class BiDense(Cell):
156
158
  f"equal to 'in2_channels'. But got 'weight_init': {weight_init}, "
157
159
  f"'out_channels': {out_channels}, 'in_channels': {in1_channels}, "
158
160
  f"'in2_channels': {in2_channels}")
159
- self.weight = Parameter(initializer(weight_init, (out_channels, in1_channels, in2_channels)), 'weight')
161
+ self.weight = Parameter(initializer(weight_init, (out_channels, in1_channels, in2_channels), dtype=dtype),
162
+ 'weight')
160
163
 
161
164
  if self.has_bias:
162
165
  if bias_init is None:
@@ -166,7 +169,7 @@ class BiDense(Cell):
166
169
  raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' should "
167
170
  f"be equal to 1, and the first dim must be equal to 'out_channels'. But got "
168
171
  f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
169
- self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
172
+ self.bias = Parameter(initializer(bias_init, [out_channels], dtype=dtype), name="bias")
170
173
  self.bias_add = P.BiasAdd()
171
174
  self.matmul = P.MatMul()
172
175
 
@@ -64,11 +64,13 @@ class Embedding(Cell):
64
64
  embedding_size (int): The size of each embedding vector.
65
65
  use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: ``False`` .
66
66
  embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
67
- Refer to class `initializer` for the values of string when a string
68
- is specified. Default: ``'normal'`` .
67
+ Refer to class `mindspore.common.initializer
68
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
69
+ for the values of string when a string is specified. Default: ``'normal'`` .
69
70
  dtype (:class:`mindspore.dtype`): Data type of `x`. Default: ``mstype.float32`` .
70
71
  padding_idx (int, None): When the padding_idx encounters index, the output embedding vector of this index
71
72
  will be initialized to zero. Default: ``None`` . The feature is inactivated.
73
+
72
74
  Inputs:
73
75
  - **x** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{x_length})`. The elements of
74
76
  the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will
@@ -145,9 +147,8 @@ class Embedding(Cell):
145
147
  return output
146
148
 
147
149
  def extend_repr(self):
148
- s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
149
- self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)
150
- return s
150
+ return f'vocab_size={self.vocab_size}, embedding_size={self.embedding_size}, use_one_hot={self.use_one_hot}, ' \
151
+ f'embedding_table={self.embedding_table}, dtype={self.dtype}, padding_idx={self.padding_idx}'
151
152
 
152
153
 
153
154
  @_primexpr
@@ -190,6 +191,7 @@ class EmbeddingLookup(Cell):
190
191
  parameter server trainning mode and 'DEVICE' target. And the moment parameter of corresponding
191
192
  optimizer will also be set to the cache size. In addition, it should be noted that it will cost the 'DEVICE'
192
193
  memory, so suggests setting a reasonable value to avoid insufficient memory.
194
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
193
195
 
194
196
  Inputs:
195
197
  - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
@@ -231,7 +233,7 @@ class EmbeddingLookup(Cell):
231
233
 
232
234
  def __init__(self, vocab_size, embedding_size, param_init='normal',
233
235
  target='CPU', slice_mode='batch_slice', manual_shapes=None,
234
- max_norm=None, sparse=True, vocab_cache_size=0):
236
+ max_norm=None, sparse=True, vocab_cache_size=0, dtype=mstype.float32):
235
237
  """Initialize EmbeddingLookup."""
236
238
  super(EmbeddingLookup, self).__init__()
237
239
  Validator.check_value_type('sparse', sparse, [bool], self.cls_name)
@@ -255,8 +257,8 @@ class EmbeddingLookup(Cell):
255
257
  if enable_ps:
256
258
  self._process_vocab_cache(slice_mode)
257
259
  self.embedding_size = Validator.check_positive_int(embedding_size, 'embedding_size', self.cls_name)
258
- self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
259
- name='embedding_table')
260
+ self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size],
261
+ dtype=dtype), name='embedding_table')
260
262
  parallel_mode = _get_parallel_mode()
261
263
  is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
262
264
  self.gather_revert = P.Gather()
@@ -267,7 +269,7 @@ class EmbeddingLookup(Cell):
267
269
  if is_auto_parallel:
268
270
  self.unique = P.Unique().shard(((1,),))
269
271
  if self.cache_enable and enable_ps:
270
- self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size, param_init)
272
+ self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size, param_init, dtype=dtype)
271
273
  if is_auto_parallel:
272
274
  self.unique.add_prim_attr('cache_enable', True)
273
275
  indices_shape_size = 2
@@ -310,8 +312,8 @@ class EmbeddingLookup(Cell):
310
312
  else:
311
313
  if is_auto_parallel:
312
314
  support_mode = ["field_slice", "table_row_slice", "table_column_slice", "batch_slice"]
313
- raise ValueError("For '{}', the 'slice_mode' must be in {}, "
314
- "but got \"{}\".".format(self.cls_name, support_mode, slice_mode))
315
+ raise ValueError(f"For '{self.cls_name}', the 'slice_mode' must be in {support_mode}, "
316
+ f"but got \"{slice_mode}\".")
315
317
  if self.cache_enable and not enable_ps:
316
318
  raise ValueError(f"For '{self.cls_name}', haven't supported cache enable for not ps mode.")
317
319
  self.embedding_table.unique = self.forward_unique
@@ -354,7 +356,8 @@ class EmbeddingLookup(Cell):
354
356
  if _is_role_worker():
355
357
  self.vocab_size = self.vocab_cache_size
356
358
 
357
- def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size, param_init):
359
+ def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size, param_init,
360
+ dtype=mstype.float32):
358
361
  """PS embeddingLookup cache enable set."""
359
362
  if self.sparse:
360
363
  self.forward_unique = True
@@ -368,10 +371,10 @@ class EmbeddingLookup(Cell):
368
371
  if _enable_distributed_mindrt():
369
372
  self.rank_id = get_rank()
370
373
  if self.is_ps_server:
371
- self._slice_pserver_embeddings("zeros")
374
+ self._slice_pserver_embeddings("zeros", dtype=dtype)
372
375
  self._set_cache_enable_and_key_for_pserver(param_key)
373
376
 
374
- def _slice_pserver_embeddings(self, param_init):
377
+ def _slice_pserver_embeddings(self, param_init, dtype=mstype.float32):
375
378
  '''
376
379
  Method to slice embedding tables on Parameter Servers.
377
380
  It helps to train with a large scale embedding table and is used only in Parameter Server training mode.
@@ -399,7 +402,7 @@ class EmbeddingLookup(Cell):
399
402
  for i in range(server_num):
400
403
  self.embedding_table_list.append(Parameter(initializer(param_init,
401
404
  [self.embedding_table_vocab_dim_list[i],
402
- self.embedding_size]),
405
+ self.embedding_size], dtype=dtype),
403
406
  name="embedding_table_server_" + str(i)))
404
407
 
405
408
  self.embedding_offset.append(offset)
@@ -505,12 +508,13 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
505
508
  :class:`mindspore.nn.EmbeddingLookup`. Default: ``'batch_slice'``.
506
509
  feature_num_list (tuple): The accompaniment array in field slice mode. This is unused currently.
507
510
  Default: ``None`` .
508
- max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
509
- or None. Default: ``None`` .
511
+ max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32.
512
+ Default: ``None`` .
510
513
  sparse (bool): Using sparse mode. When 'target' is set to ``'CPU'`` , 'sparse' has to be true.
511
514
  Default: ``True`` .
512
515
  operator (str): The pooling method for the features in one field. Support ``'SUM'`` , ``'MEAN'`` and
513
516
  ``'MAX'`` . Default: ``'SUM'`` .
517
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
514
518
 
515
519
  Inputs:
516
520
  - **input_indices** (Tensor) - The shape of tensor is :math:`(batch\_size, seq\_length)`.
@@ -529,12 +533,12 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
529
533
  TypeError: If `vocab_size` or `embedding_size` or `field_size` is not an int.
530
534
  TypeError: If `sparse` is not a bool or `feature_num_list` is not a tuple.
531
535
  ValueError: If `vocab_size` or `embedding_size` or `field_size` is less than 1.
532
- ValueError: If `target` is neither 'CPU' nor 'DEVICE'.
533
- ValueError: If `slice_mode` is not one of 'batch_slice', 'field_slice', 'table_row_slice',
534
- 'table_column_slice'.
535
- ValueError: If `sparse` is False and `target` is 'CPU'.
536
- ValueError: If `slice_mode` is 'field_slice' and `feature_num_list` is None.
537
- ValueError: If `operator` is not one of 'SUM', 'MAX', 'MEAN'.
536
+ ValueError: If `target` is neither ``'CPU'`` nor ``'DEVICE'``.
537
+ ValueError: If `slice_mode` is not one of ``'batch_slice'``, ``'field_slice'``, ``'table_row_slice'``,
538
+ ``'table_column_slice'`` .
539
+ ValueError: If `sparse` is False and `target` is ``'CPU'`` .
540
+ ValueError: If `slice_mode` is ``'field_slice'`` and `feature_num_list` is None.
541
+ ValueError: If `operator` is not one of ``'SUM'``, ``'MAX'``, ``'MEAN'`` .
538
542
 
539
543
  Supported Platforms:
540
544
  ``Ascend`` ``GPU``
@@ -555,10 +559,11 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
555
559
  OPERATOR_MAX = 'MAX'
556
560
 
557
561
  def __init__(self, vocab_size, embedding_size, field_size, param_init='normal', target='CPU',
558
- slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
562
+ slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM',
563
+ dtype=mstype.float32):
559
564
  """Initialize MultiFieldEmbeddingLookup."""
560
565
  super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
561
- slice_mode, feature_num_list, max_norm, sparse)
566
+ slice_mode, feature_num_list, max_norm, sparse, dtype=dtype)
562
567
  self.field_size = Validator.check_positive_int(field_size, 'field_size', self.cls_name)
563
568
  self.operator = operator
564
569
 
@@ -622,8 +627,9 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
622
627
  self.inf_add.shard(((1, 1, get_group_size()), (1, 1, 1)))
623
628
  else:
624
629
  if is_auto_parallel:
625
- raise ValueError("For '{}', the 'slice_mode' must be in ['table_row_slice', 'batch_slice' and \
626
- 'table_column_slice'], but got {}".format(self.cls_name, str(slice_mode)))
630
+ raise ValueError(
631
+ f"For '{self.cls_name}', the 'slice_mode' must be in ['table_row_slice', 'batch_slice' "
632
+ f"and 'table_column_slice'], but got {str(slice_mode)}.")
627
633
 
628
634
  # Min value for fp32
629
635
  self.negative_inf_value = -3.402823466E+38
@@ -17,12 +17,13 @@ A FlashAttention Layer.
17
17
  """
18
18
  import math
19
19
 
20
- import mindspore.numpy as mnp
21
- from mindspore import ops
22
- from mindspore.common import dtype as mstype
20
+ import mindspore.common.dtype as mstype
23
21
  from mindspore.common.tensor import Tensor
22
+ from mindspore import ops
24
23
  from mindspore.nn.cell import Cell
25
24
  from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_impl import get_flash_attention
25
+ from mindspore.ops.operations.nn_ops import FlashAttentionScore
26
+ from mindspore._c_expression import MSContext
26
27
 
27
28
  __all__ = ['FlashAttention']
28
29
 
@@ -92,6 +93,7 @@ class FlashAttention(Cell):
92
93
 
93
94
  def __init__(self,
94
95
  head_dim,
96
+ head_num,
95
97
  dropout_rate=0.0,
96
98
  prev_block_num=65536,
97
99
  next_block_num=65536,
@@ -104,18 +106,42 @@ class FlashAttention(Cell):
104
106
  ):
105
107
  super(FlashAttention, self).__init__()
106
108
 
107
- self.flash_attention = get_flash_attention(
108
- prev_block_num=prev_block_num,
109
- next_block_num=next_block_num,
110
- tiling_stgy_name=tiling_stgy_name,
111
- high_precision=high_precision
112
- )
113
- self.flash_attention.add_prim_attr("primitive_target", "Ascend")
114
109
  scaling_constant = math.sqrt(head_dim)
115
- if scaling_constant != 0:
116
- self.scale_factor = Tensor([1. / scaling_constant], dtype=mstype.float16)
117
- else:
110
+ if scaling_constant == 0:
118
111
  raise ValueError("the scaling constant must not be 0.")
112
+ self.scale_factor = Tensor([1. / scaling_constant], dtype=mstype.float16)
113
+
114
+ self.is_910A = MSContext.get_instance().get_ascend_soc_version() == "Ascend910"
115
+ if self.is_910A:
116
+ self.flash_attention = get_flash_attention(
117
+ prev_block_num=prev_block_num,
118
+ next_block_num=next_block_num,
119
+ tiling_stgy_name=tiling_stgy_name,
120
+ high_precision=high_precision
121
+ )
122
+ self.flash_attention.add_prim_attr("primitive_target", "Ascend")
123
+ else:
124
+ if alibi:
125
+ raise ValueError(f"When soc_version is not Ascend910A, alibi must be False")
126
+ self.transpose_4d_pre = ops.Transpose().shard(((dp, mp, 1, 1),))
127
+ self.transpose_4d_post = ops.Transpose().shard(((dp, 1, mp, 1),))
128
+ self.reshape = ops.Reshape()
129
+ self.zeros_like = ops.ZerosLike().shard(((dp, mp, 1, 1),))
130
+ self.zeros = ops.Zeros()
131
+ self.attn_expand_dims = ops.ExpandDims().shard(((dp, 1, 1),))
132
+ fa_strategies = ((dp, 1, mp),
133
+ (dp, 1, mp),
134
+ (dp, 1, mp),
135
+ (dp, 1, 1, 1))
136
+ if dropout_rate > 1e-5:
137
+ fa_strategies += ((dp, mp, 1, 1),)
138
+ self.flash_attention = FlashAttentionScore(head_num=head_num, pre_tokens=prev_block_num,
139
+ next_tokens=next_block_num,
140
+ keep_prob=1 - dropout_rate,
141
+ scale_value=1.0,
142
+ inner_precise=0 if high_precision else 1).shard(fa_strategies)
143
+
144
+ self.ones = ops.Ones()
119
145
  self.dim_mask = Tensor([1 for _ in range(head_dim)], dtype=mstype.int8)
120
146
  self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
121
147
  self.dropout_rate = dropout_rate
@@ -136,38 +162,35 @@ class FlashAttention(Cell):
136
162
  such as MatMul. Default: None.
137
163
  :return:
138
164
  """
139
- if in_strategy is not None:
140
- shard_stgy = list(in_strategy)
141
- shard_stgy.insert(3, (1,)) # dim_mask
142
- shard_stgy = tuple(shard_stgy)
143
- else:
165
+ if in_strategy is None:
144
166
  # default: dp=1, mp=1, construct inputs only contain query, key, value
145
- shard_stgy = (
167
+ in_strategy = (
146
168
  (1, 1, 1, 1),
147
169
  (1, 1, 1, 1),
148
170
  (1, 1, 1, 1),
149
- (1,), # dim_mask
150
171
  )
151
- self.flash_attention.shard(shard_stgy)
152
- dp = shard_stgy[0][0]
153
- mp = shard_stgy[0][1]
172
+ self.flash_attention.shard(in_strategy)
173
+ dp = in_strategy[0][0]
174
+ mp = in_strategy[0][1]
154
175
  self.flash_attention.add_prim_attr("dev_matrix_shape", [dp, mp, 1, 1])
155
176
  inputs_tensor_map = [
156
177
  [3, 2, 1, 0],
157
178
  [3, 2, 1, 0],
158
179
  [3, 2, 1, 0],
159
- [-1]
160
180
  ]
161
181
  if self.have_attention_mask_batch:
162
182
  inputs_tensor_map.append([3, 1, 0])
163
183
  else:
164
184
  inputs_tensor_map.append([-1, 1, 0])
165
185
 
186
+ input_empty_args_num = 2
166
187
  # dropout_mask
167
188
  if self.dropout_rate > 1e-5:
189
+ input_empty_args_num -= 1
168
190
  inputs_tensor_map.append([3, 2, 1, 0])
169
191
 
170
192
  if self.alibi:
193
+ input_empty_args_num -= 1
171
194
  inputs_tensor_map.append([3, 2, 1, 0])
172
195
 
173
196
  self.flash_attention.add_prim_attr("inputs_tensor_map", inputs_tensor_map)
@@ -178,7 +201,7 @@ class FlashAttention(Cell):
178
201
  [3, 2, 1] # M
179
202
  ])
180
203
  self.flash_attention.add_prim_attr("as_loss_divisor", 0)
181
- self.flash_attention.add_prim_attr("empty_mirror_ops", 1)
204
+ self.flash_attention.add_prim_attr("empty_mirror_ops", input_empty_args_num)
182
205
 
183
206
  def construct(self, query, key, value, attn_mask=None, alibi_mask=None):
184
207
  """FlashAttention forward
@@ -200,24 +223,42 @@ class FlashAttention(Cell):
200
223
  if seq_len % 16 != 0 or k_seq_len % 16 != 0 or k_seq_len != v_seq_len:
201
224
  raise ValueError(
202
225
  "query, key, value seq_len must be a multiple of 16, and key seq_len, value seq_len must be the same.")
203
- if self.dropout_rate > 1e-5:
204
- drop_mask_bits = self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob)
205
- tensor_shape = Tensor((bsz, head_num, seq_len, seq_len), mstype.int32)
206
- ones = self.fill_v2(tensor_shape, self.tensor_one)
207
- ones = self.depend(ones, query)
208
- drop_mask = self.do_dropout(ones, drop_mask_bits, self.keep_prob)
209
- else:
210
- drop_mask = None
226
+
211
227
  if head_dim > 304:
212
228
  raise ValueError(
213
229
  "the head_dim must be less than 304, otherwise the ub would be OOM.")
214
- if head_dim % 16 != 0:
215
- padding_size = 16 - head_dim % 16
216
- query = mnp.pad(query, ((0, 0), (0, 0), (0, 0), (0, padding_size)), constant_values=0)
217
- key = mnp.pad(key, ((0, 0), (0, 0), (0, 0), (0, padding_size)), constant_values=0)
218
- value = mnp.pad(value, ((0, 0), (0, 0), (0, 0), (0, padding_size)), constant_values=0)
219
- output, _, _ = self.flash_attention(query, key, value, self.dim_mask, attn_mask, drop_mask, alibi_mask)
220
- output = ops.slice(output, [0, 0, 0, 0], [bsz, head_num, seq_len, head_dim])
230
+
231
+ if self.is_910A:
232
+ # 910A -- FlashAttentionPrimtive
233
+ if self.dropout_rate > 1e-5:
234
+ drop_mask_bits = self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob)
235
+ tensor_shape = Tensor((bsz, head_num, seq_len, seq_len), mstype.int32)
236
+ ones = self.fill_v2(tensor_shape, self.tensor_one)
237
+ ones = self.depend(ones, query)
238
+ drop_mask = self.do_dropout(ones, drop_mask_bits, self.keep_prob)
239
+ else:
240
+ drop_mask = None
241
+ output, _, _ = self.flash_attention(query, key, value, attn_mask, drop_mask, alibi_mask)
221
242
  else:
222
- output, _, _ = self.flash_attention(query, key, value, self.dim_mask, attn_mask, drop_mask, alibi_mask)
243
+ # FlashAttentionScore
244
+ # Useless input, just for binary calls.
245
+ if self.dropout_rate > 1e-5:
246
+ drop_mask_bits = self.reshape(self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob),
247
+ (bsz, head_num, seq_len, seq_len // 8))
248
+ else:
249
+ drop_mask_bits = None
250
+ # (B, N, S, D) -> (B, S, H)
251
+ query = self.reshape(self.transpose_4d_pre(query, (0, 2, 1, 3)), (bsz, seq_len, -1))
252
+ key = self.reshape(self.transpose_4d_pre(key, (0, 2, 1, 3)), (bsz, seq_len, -1))
253
+ value = self.reshape(self.transpose_4d_pre(value, (0, 2, 1, 3)), (bsz, seq_len, -1))
254
+ attn_mask = self.attn_expand_dims(attn_mask, 1)
255
+ output, _, _ = self.flash_attention(query,
256
+ key,
257
+ value,
258
+ attn_mask,
259
+ drop_mask_bits,
260
+ None,
261
+ None)
262
+ output = self.transpose_4d_post(self.reshape(output, (bsz, seq_len, head_num, head_dim)), (0, 2, 1, 3))
263
+
223
264
  return output
@@ -83,17 +83,17 @@ class ImageGradients(Cell):
83
83
  _check_input_4d(F.shape(images), "images", self.cls_name)
84
84
  batch_size, depth, height, width = P.Shape()(images)
85
85
  if height == 1:
86
- dy = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
86
+ dy = F.fill(P.DType()(images), (batch_size, depth, 1, width), 0)
87
87
  else:
88
88
  dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
89
- dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
89
+ dy_last = F.fill(P.DType()(images), (batch_size, depth, 1, width), 0)
90
90
  dy = P.Concat(2)((dy, dy_last))
91
91
 
92
92
  if width == 1:
93
- dx = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
93
+ dx = F.fill(P.DType()(images), (batch_size, depth, height, 1), 0)
94
94
  else:
95
95
  dx = images[:, :, :, 1:] - images[:, :, :, :width - 1]
96
- dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
96
+ dx_last = F.fill(P.DType()(images), (batch_size, depth, height, 1), 0)
97
97
  dx = P.Concat(3)((dx, dx_last))
98
98
  return dy, dx
99
99
 
@@ -571,7 +571,8 @@ class PixelShuffle(Cell):
571
571
  <https://arxiv.org/abs/1609.05158>`_ .
572
572
 
573
573
  Typically, the input is of shape :math:`(*, C \times r^2, H, W)` , and the output is of shape
574
- :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor and * is zero or more batch dimensions.
574
+ :math:`(*, C, H \times r, W \times r)`,
575
+ where :math:`r` is an upscale factor and :math:`*` is zero or more batch dimensions.
575
576
 
576
577
  Note:
577
578
  The dimension of input Tensor on Ascend should be less than 7.
@@ -621,7 +622,8 @@ class PixelUnshuffle(Cell):
621
622
  <https://arxiv.org/abs/1609.05158>`_ .
622
623
 
623
624
  Typically, the input is of shape :math:`(*, C, H \times r, W \times r)` , and the output is of shape
624
- :math:`(*, C \times r^2, H, W)` , where r is a downscale factor and * is zero or more batch dimensions.
625
+ :math:`(*, C \times r^2, H, W)` ,
626
+ where :math:`r` is a downscale factor and :math:`*` is zero or more batch dimensions.
625
627
 
626
628
  Args:
627
629
  downscale_factor (int): factor to unshuffle the input, and is a positive integer.
@@ -223,7 +223,6 @@ class LGamma(Cell):
223
223
  self.abs = P.Abs()
224
224
  self.shape = P.Shape()
225
225
  self.dtype = P.DType()
226
- self.fill = P.Fill()
227
226
  self.floor = P.Floor()
228
227
  self.equal = P.Equal()
229
228
  self.greater = P.Greater()
@@ -240,7 +239,7 @@ class LGamma(Cell):
240
239
  if F.is_sequence_value_unknown(self.shape(x)):
241
240
  infinity = self.ones_like(x) * F.cast(self.inf, input_dtype)
242
241
  else:
243
- infinity = self.fill(input_dtype, self.shape(x), self.inf)
242
+ infinity = F.fill(input_dtype, self.shape(x), self.inf)
244
243
 
245
244
  need_to_reflect = self.less(x, 0.5)
246
245
  neg_input = -x
@@ -335,7 +334,6 @@ class DiGamma(Cell):
335
334
  self.abs = P.Abs()
336
335
  self.shape = P.Shape()
337
336
  self.dtype = P.DType()
338
- self.fill = P.Fill()
339
337
  self.floor = P.Floor()
340
338
  self.equal = P.Equal()
341
339
  self.less = P.Less()
@@ -371,7 +369,7 @@ class DiGamma(Cell):
371
369
  reduced_input = x + self.abs(self.floor(x + 0.5))
372
370
  reflection = y - self.pi * self.cos(self.pi * reduced_input) / self.sin(self.pi * reduced_input)
373
371
  real_result = self.select(need_to_reflect, reflection, y)
374
- nan = self.fill(self.dtype(x), self.shape(x), np.nan)
372
+ nan = F.fill(self.dtype(x), self.shape(x), np.nan)
375
373
 
376
374
  return self.select(self.logicaland(self.less(x, 0), self.equal(x, self.floor(x))),
377
375
  nan, real_result)
@@ -391,7 +389,6 @@ def _igamma_series(ax, x, a, enabled):
391
389
 
392
390
  logicaland = P.LogicalAnd()
393
391
  greater = P.Greater()
394
- fill = P.Fill()
395
392
  shape = P.Shape()
396
393
  dtype = P.DType()
397
394
  select = P.Select()
@@ -424,8 +421,8 @@ def _igamma_series(ax, x, a, enabled):
424
421
  select(enabled, x, vals[4]), select(enabled, dc_da, vals[5]),
425
422
  select(enabled, dans_da, vals[6]))
426
423
 
427
- ones = fill(dtype(a), shape(a), 1)
428
- zeros = fill(dtype(a), shape(a), 0)
424
+ ones = F.fill(dtype(a), shape(a), 1)
425
+ zeros = F.fill(dtype(a), shape(a), 0)
429
426
  vals = (enabled, a, ones, ones, x, zeros, zeros)
430
427
 
431
428
  vals = _while_helper_func(cond, body, vals)
@@ -441,7 +438,6 @@ def _igammac_continued_fraction(ax, x, a, enabled):
441
438
  greater = P.Greater()
442
439
  less = P.Less()
443
440
  notequal = P.NotEqual()
444
- fill = P.Fill()
445
441
  shape = P.Shape()
446
442
  dtype = P.DType()
447
443
  select = P.Select()
@@ -482,7 +478,7 @@ def _igammac_continued_fraction(ax, x, a, enabled):
482
478
  qk_is_nonzero = notequal(qk, 0)
483
479
  r = pk / qk
484
480
 
485
- t = select(qk_is_nonzero, abs_x((ans - r) / r), fill(dtype(t), shape(t), 1))
481
+ t = select(qk_is_nonzero, abs_x((ans - r) / r), F.fill(dtype(t), shape(t), 1))
486
482
  ans = select(qk_is_nonzero, r, ans)
487
483
 
488
484
  dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c
@@ -490,7 +486,7 @@ def _igammac_continued_fraction(ax, x, a, enabled):
490
486
  dans_da_new = select(qk_is_nonzero, (dpk_da - ans * dqk_da) / qk, dans_da)
491
487
  grad_conditional = select(qk_is_nonzero,
492
488
  abs_x(dans_da_new - dans_da),
493
- fill(dtype(dans_da), shape(dans_da), 1))
489
+ F.fill(dtype(dans_da), shape(dans_da), 1))
494
490
 
495
491
  pkm2 = pkm1
496
492
  pkm1 = pk
@@ -525,16 +521,16 @@ def _igammac_continued_fraction(ax, x, a, enabled):
525
521
 
526
522
  y = 1 - a
527
523
  z = x + y + 1
528
- c = fill(dtype(x), shape(x), 0)
529
- pkm2 = fill(dtype(x), shape(x), 1)
524
+ c = F.fill(dtype(x), shape(x), 0)
525
+ pkm2 = F.fill(dtype(x), shape(x), 1)
530
526
  qkm2 = x
531
527
  pkm1 = x + 1
532
528
  qkm1 = z * x
533
529
  ans = pkm1 / qkm1
534
- t = fill(dtype(x), shape(x), 1)
535
- dpkm2_da = fill(dtype(x), shape(x), 0)
536
- dqkm2_da = fill(dtype(x), shape(x), 0)
537
- dpkm1_da = fill(dtype(x), shape(x), 0)
530
+ t = F.fill(dtype(x), shape(x), 1)
531
+ dpkm2_da = F.fill(dtype(x), shape(x), 0)
532
+ dqkm2_da = F.fill(dtype(x), shape(x), 0)
533
+ dpkm1_da = F.fill(dtype(x), shape(x), 0)
538
534
  dqkm1_da = -x
539
535
  dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1
540
536
  vals = (enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da)
@@ -606,7 +602,6 @@ class IGamma(Cell):
606
602
  self.exp = P.Exp()
607
603
  self.select = P.Select()
608
604
  self.zeroslike = P.ZerosLike()
609
- self.fill = P.Fill()
610
605
  self.shape = P.Shape()
611
606
  self.dtype = P.DType()
612
607
  self.lgamma = LGamma()
@@ -633,7 +628,7 @@ class IGamma(Cell):
633
628
  1 - _igammac_continued_fraction(ax, x, a, self.logicaland(enabled, use_igammac)),
634
629
  _igamma_series(ax, x, a, self.logicaland(enabled, self.logicalnot(use_igammac))))
635
630
  output = self.select(x_is_zero, self.zeroslike(output), output)
636
- output = self.select(domain_error, self.fill(self.dtype(a), self.shape(a), np.nan), output)
631
+ output = self.select(domain_error, F.fill(self.dtype(a), self.shape(a), np.nan), output)
637
632
  return output
638
633
 
639
634