mindspore 2.1.0__cp39-cp39-win_amd64.whl → 2.2.10__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (505) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +4 -1
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -1
  9. mindspore/_checkparam.py +23 -29
  10. mindspore/_extends/graph_kernel/__init__.py +0 -1
  11. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  12. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  13. mindspore/_extends/graph_kernel/splitter.py +4 -11
  14. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  15. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  16. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  17. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  18. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  19. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  20. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  21. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  22. mindspore/_extends/parse/__init__.py +12 -15
  23. mindspore/_extends/parse/namespace.py +7 -33
  24. mindspore/_extends/parse/parser.py +61 -71
  25. mindspore/_extends/parse/resources.py +1 -1
  26. mindspore/_extends/parse/standard_method.py +74 -104
  27. mindspore/_extends/parse/trope.py +1 -1
  28. mindspore/_extends/remote/kernel_build_server.py +25 -7
  29. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  30. mindspore/_install_custom.py +43 -0
  31. mindspore/amp.py +47 -11
  32. mindspore/atlprov.dll +0 -0
  33. mindspore/boost/boost.py +1 -8
  34. mindspore/boost/boost_cell_wrapper.py +3 -2
  35. mindspore/boost/grad_accumulation.py +1 -1
  36. mindspore/boost/group_loss_scale_manager.py +8 -7
  37. mindspore/c1.dll +0 -0
  38. mindspore/c1xx.dll +0 -0
  39. mindspore/c2.dll +0 -0
  40. mindspore/common/__init__.py +5 -3
  41. mindspore/common/_jit_fallback_utils.py +6 -0
  42. mindspore/common/_register_for_adapter.py +2 -0
  43. mindspore/common/_register_for_tensor.py +2 -2
  44. mindspore/common/_stub_tensor.py +13 -0
  45. mindspore/common/_utils.py +13 -0
  46. mindspore/common/api.py +174 -259
  47. mindspore/common/auto_dynamic_shape.py +494 -0
  48. mindspore/common/dtype.py +18 -11
  49. mindspore/common/dump.py +6 -4
  50. mindspore/common/initializer.py +14 -14
  51. mindspore/common/jit_config.py +33 -15
  52. mindspore/common/lazy_inline.py +126 -7
  53. mindspore/common/mindir_util.py +101 -0
  54. mindspore/common/parameter.py +51 -41
  55. mindspore/common/seed.py +4 -4
  56. mindspore/common/sparse_tensor.py +13 -14
  57. mindspore/common/tensor.py +243 -165
  58. mindspore/communication/__init__.py +7 -4
  59. mindspore/communication/_comm_helper.py +83 -4
  60. mindspore/communication/management.py +152 -84
  61. mindspore/config/op_info.config +14 -3
  62. mindspore/context.py +152 -61
  63. mindspore/dataset/__init__.py +5 -5
  64. mindspore/dataset/audio/__init__.py +2 -2
  65. mindspore/dataset/audio/transforms.py +52 -52
  66. mindspore/dataset/callback/ds_callback.py +16 -2
  67. mindspore/dataset/core/config.py +68 -51
  68. mindspore/dataset/engine/cache_client.py +28 -5
  69. mindspore/dataset/engine/datasets.py +250 -112
  70. mindspore/dataset/engine/datasets_audio.py +43 -211
  71. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  72. mindspore/dataset/engine/datasets_text.py +43 -67
  73. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  74. mindspore/dataset/engine/datasets_vision.py +219 -1029
  75. mindspore/dataset/engine/iterators.py +11 -4
  76. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  77. mindspore/dataset/engine/obs/util.py +3 -0
  78. mindspore/dataset/engine/samplers.py +1 -1
  79. mindspore/dataset/engine/validators.py +19 -5
  80. mindspore/dataset/text/__init__.py +3 -3
  81. mindspore/dataset/text/transforms.py +101 -127
  82. mindspore/dataset/text/utils.py +205 -138
  83. mindspore/dataset/transforms/__init__.py +1 -1
  84. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  85. mindspore/dataset/transforms/transforms.py +95 -40
  86. mindspore/dataset/utils/browse_dataset.py +8 -2
  87. mindspore/dataset/utils/line_reader.py +17 -19
  88. mindspore/dataset/vision/__init__.py +3 -3
  89. mindspore/dataset/vision/c_transforms.py +6 -3
  90. mindspore/dataset/vision/transforms.py +409 -287
  91. mindspore/dataset/vision/utils.py +13 -14
  92. mindspore/dataset/vision/validators.py +11 -1
  93. mindspore/dnnl.dll +0 -0
  94. mindspore/dpcmi.dll +0 -0
  95. mindspore/experimental/map_parameter.py +14 -0
  96. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  97. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  98. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  99. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  100. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  101. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  102. mindspore/gen_ops.py +273 -0
  103. mindspore/include/OWNERS +0 -1
  104. mindspore/include/api/data_type.h +2 -1
  105. mindspore/include/api/graph.h +0 -15
  106. mindspore/include/api/kernel.h +2 -0
  107. mindspore/include/api/kernel_api.h +37 -12
  108. mindspore/include/api/model.h +17 -14
  109. mindspore/include/api/status.h +8 -3
  110. mindspore/include/api/types.h +37 -4
  111. mindspore/include/c_api/ms/abstract.h +67 -0
  112. mindspore/include/c_api/ms/attribute.h +197 -0
  113. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  114. mindspore/include/c_api/ms/base/macros.h +32 -0
  115. mindspore/include/c_api/ms/base/status.h +33 -0
  116. mindspore/include/c_api/ms/base/types.h +282 -0
  117. mindspore/include/c_api/ms/context.h +102 -0
  118. mindspore/include/c_api/ms/graph.h +160 -0
  119. mindspore/include/c_api/ms/node.h +606 -0
  120. mindspore/include/c_api/ms/tensor.h +161 -0
  121. mindspore/include/c_api/ms/value.h +84 -0
  122. mindspore/include/dataset/constants.h +6 -5
  123. mindspore/include/dataset/execute.h +23 -13
  124. mindspore/include/dataset/text.h +26 -26
  125. mindspore/include/dataset/transforms.h +13 -13
  126. mindspore/include/dataset/vision.h +60 -60
  127. mindspore/include/dataset/vision_ascend.h +5 -6
  128. mindspore/include/dataset/vision_lite.h +17 -17
  129. mindspore/jpeg62.dll +0 -0
  130. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  131. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  132. mindspore/mindspore_backend.dll +0 -0
  133. mindspore/mindspore_common.dll +0 -0
  134. mindspore/mindspore_core.dll +0 -0
  135. mindspore/mindspore_glog.dll +0 -0
  136. mindspore/mindspore_shared_lib.dll +0 -0
  137. mindspore/msobj140.dll +0 -0
  138. mindspore/mspdb140.dll +0 -0
  139. mindspore/mspdbcore.dll +0 -0
  140. mindspore/mspdbst.dll +0 -0
  141. mindspore/mspft140.dll +0 -0
  142. mindspore/msvcdis140.dll +0 -0
  143. mindspore/msvcp140_1.dll +0 -0
  144. mindspore/msvcp140_2.dll +0 -0
  145. mindspore/msvcp140_atomic_wait.dll +0 -0
  146. mindspore/msvcp140_codecvt_ids.dll +0 -0
  147. mindspore/nn/__init__.py +0 -2
  148. mindspore/nn/cell.py +313 -74
  149. mindspore/nn/dynamic_lr.py +21 -21
  150. mindspore/nn/layer/activation.py +22 -30
  151. mindspore/nn/layer/basic.py +15 -13
  152. mindspore/nn/layer/channel_shuffle.py +1 -1
  153. mindspore/nn/layer/container.py +271 -9
  154. mindspore/nn/layer/conv.py +323 -204
  155. mindspore/nn/layer/dense.py +8 -5
  156. mindspore/nn/layer/embedding.py +33 -27
  157. mindspore/nn/layer/flash_attention.py +141 -88
  158. mindspore/nn/layer/image.py +8 -6
  159. mindspore/nn/layer/math.py +16 -25
  160. mindspore/nn/layer/normalization.py +107 -66
  161. mindspore/nn/layer/padding.py +1 -1
  162. mindspore/nn/layer/pooling.py +131 -109
  163. mindspore/nn/layer/rnn_cells.py +27 -22
  164. mindspore/nn/layer/rnns.py +13 -16
  165. mindspore/nn/layer/thor_layer.py +1 -1
  166. mindspore/nn/layer/transformer.py +221 -154
  167. mindspore/nn/learning_rate_schedule.py +9 -1
  168. mindspore/nn/loss/loss.py +235 -174
  169. mindspore/nn/optim/ada_grad.py +2 -1
  170. mindspore/nn/optim/adadelta.py +1 -0
  171. mindspore/nn/optim/adafactor.py +2 -1
  172. mindspore/nn/optim/adam.py +7 -4
  173. mindspore/nn/optim/adamax.py +3 -2
  174. mindspore/nn/optim/adasum.py +2 -2
  175. mindspore/nn/optim/asgd.py +2 -3
  176. mindspore/nn/optim/ftrl.py +6 -5
  177. mindspore/nn/optim/lamb.py +7 -4
  178. mindspore/nn/optim/lars.py +1 -1
  179. mindspore/nn/optim/lazyadam.py +5 -3
  180. mindspore/nn/optim/momentum.py +2 -1
  181. mindspore/nn/optim/optimizer.py +53 -4
  182. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  183. mindspore/nn/optim/rmsprop.py +4 -3
  184. mindspore/nn/optim/rprop.py +23 -12
  185. mindspore/nn/optim/sgd.py +26 -11
  186. mindspore/nn/optim/thor.py +9 -7
  187. mindspore/nn/probability/bijector/bijector.py +5 -5
  188. mindspore/nn/probability/bijector/power_transform.py +27 -27
  189. mindspore/nn/probability/bijector/softplus.py +3 -3
  190. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  191. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  192. mindspore/nn/probability/distribution/beta.py +3 -3
  193. mindspore/nn/probability/distribution/categorical.py +7 -7
  194. mindspore/nn/probability/distribution/cauchy.py +0 -1
  195. mindspore/nn/probability/distribution/distribution.py +3 -3
  196. mindspore/nn/probability/distribution/gamma.py +3 -3
  197. mindspore/nn/probability/distribution/geometric.py +4 -4
  198. mindspore/nn/probability/distribution/gumbel.py +4 -4
  199. mindspore/nn/probability/distribution/log_normal.py +2 -2
  200. mindspore/nn/probability/distribution/logistic.py +2 -2
  201. mindspore/nn/probability/distribution/poisson.py +4 -4
  202. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  203. mindspore/nn/probability/distribution/uniform.py +6 -6
  204. mindspore/nn/wrap/cell_wrapper.py +84 -34
  205. mindspore/nn/wrap/grad_reducer.py +8 -5
  206. mindspore/nn/wrap/loss_scale.py +105 -42
  207. mindspore/numpy/array_creations.py +1 -2
  208. mindspore/numpy/array_ops.py +3 -2
  209. mindspore/numpy/utils_const.py +5 -5
  210. mindspore/opencv_core452.dll +0 -0
  211. mindspore/opencv_imgcodecs452.dll +0 -0
  212. mindspore/opencv_imgproc452.dll +0 -0
  213. mindspore/ops/_grad_experimental/__init__.py +0 -5
  214. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  215. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  216. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  217. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  218. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  219. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  220. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  221. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  222. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  223. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  224. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  225. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  226. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  227. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  228. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  229. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  230. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  231. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  232. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  233. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  234. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  235. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  236. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  237. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  238. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  239. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  240. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  241. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  242. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  243. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  244. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  245. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  246. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  247. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  248. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  249. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  250. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  251. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  252. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  253. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  254. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  255. mindspore/ops/_primitive_cache.py +1 -1
  256. mindspore/ops/_tracefunc.py +45 -13
  257. mindspore/ops/_utils/utils.py +6 -1
  258. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  259. mindspore/ops/_vmap/vmap_base.py +3 -3
  260. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  261. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  262. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  263. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  264. mindspore/ops/arg_dtype_cast.py +54 -0
  265. mindspore/ops/composite/base.py +37 -10
  266. mindspore/ops/composite/math_ops.py +5 -4
  267. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  268. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  269. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  270. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  271. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  272. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  273. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  274. mindspore/ops/deprecated.py +304 -0
  275. mindspore/ops/function/__init__.py +4 -1
  276. mindspore/ops/function/array_func.py +174 -193
  277. mindspore/ops/function/clip_func.py +81 -13
  278. mindspore/ops/function/debug_func.py +1 -1
  279. mindspore/ops/function/grad/grad_func.py +18 -9
  280. mindspore/ops/function/image_func.py +10 -4
  281. mindspore/ops/function/linalg_func.py +5 -5
  282. mindspore/ops/function/math_func.py +575 -386
  283. mindspore/ops/function/nn_func.py +568 -260
  284. mindspore/ops/function/random_func.py +88 -57
  285. mindspore/ops/function/sparse_func.py +1 -1
  286. mindspore/ops/function/sparse_unary_func.py +14 -12
  287. mindspore/ops/function/vmap_func.py +6 -5
  288. mindspore/ops/functional.py +15 -10
  289. mindspore/ops/op_info_register.py +244 -25
  290. mindspore/ops/operations/__init__.py +28 -19
  291. mindspore/ops/operations/_grad_ops.py +72 -7
  292. mindspore/ops/operations/_inner_ops.py +350 -17
  293. mindspore/ops/operations/_quant_ops.py +4 -8
  294. mindspore/ops/operations/_sequence_ops.py +42 -0
  295. mindspore/ops/operations/array_ops.py +68 -282
  296. mindspore/ops/operations/comm_ops.py +107 -59
  297. mindspore/ops/operations/custom_ops.py +94 -70
  298. mindspore/ops/operations/debug_ops.py +8 -4
  299. mindspore/ops/operations/image_ops.py +18 -12
  300. mindspore/ops/operations/inner_ops.py +26 -3
  301. mindspore/ops/operations/math_ops.py +189 -141
  302. mindspore/ops/operations/nn_ops.py +794 -489
  303. mindspore/ops/operations/other_ops.py +0 -22
  304. mindspore/ops/operations/random_ops.py +53 -111
  305. mindspore/ops/operations/sparse_ops.py +3 -1
  306. mindspore/ops/primitive.py +24 -18
  307. mindspore/parallel/_auto_parallel_context.py +68 -8
  308. mindspore/parallel/_cost_model_context.py +2 -2
  309. mindspore/parallel/_offload_context.py +17 -3
  310. mindspore/parallel/_parallel_serialization.py +12 -5
  311. mindspore/parallel/_ps_context.py +12 -0
  312. mindspore/parallel/_tensor.py +18 -13
  313. mindspore/parallel/_transformer/layers.py +5 -3
  314. mindspore/parallel/_transformer/loss.py +1 -0
  315. mindspore/parallel/_transformer/moe.py +2 -2
  316. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  317. mindspore/parallel/_transformer/transformer.py +23 -3
  318. mindspore/parallel/_utils.py +11 -7
  319. mindspore/parallel/algo_parameter_config.py +85 -5
  320. mindspore/parallel/checkpoint_transform.py +19 -12
  321. mindspore/parallel/shard.py +21 -14
  322. mindspore/pgodb140.dll +0 -0
  323. mindspore/pgort140.dll +0 -0
  324. mindspore/profiler/common/struct_type.py +3 -3
  325. mindspore/profiler/common/util.py +4 -2
  326. mindspore/profiler/envprofiling.py +1 -1
  327. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  328. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  329. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  330. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  331. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  332. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  333. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  334. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  335. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  336. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  337. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  338. mindspore/profiler/parser/flops_parser.py +15 -11
  339. mindspore/profiler/parser/framework_parser.py +38 -22
  340. mindspore/profiler/parser/hccl_parser.py +16 -12
  341. mindspore/profiler/parser/integrator.py +22 -11
  342. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  343. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  344. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  345. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  346. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  347. mindspore/profiler/parser/optime_parser.py +1 -1
  348. mindspore/profiler/parser/profiler_info.py +21 -2
  349. mindspore/profiler/parser/step_trace_parser.py +11 -14
  350. mindspore/profiler/profiling.py +179 -89
  351. mindspore/rewrite/api/node.py +102 -19
  352. mindspore/rewrite/api/node_type.py +5 -1
  353. mindspore/rewrite/api/pattern_engine.py +1 -1
  354. mindspore/rewrite/api/scoped_value.py +9 -17
  355. mindspore/rewrite/api/symbol_tree.py +131 -47
  356. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  357. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  358. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  359. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  360. mindspore/rewrite/common/rewrite_elog.py +5 -1
  361. mindspore/rewrite/namer.py +33 -24
  362. mindspore/rewrite/namespace.py +14 -5
  363. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  364. mindspore/rewrite/node/call_function.py +79 -0
  365. mindspore/rewrite/node/cell_container.py +135 -0
  366. mindspore/rewrite/node/control_flow.py +88 -0
  367. mindspore/rewrite/{node.py → node/node.py} +273 -234
  368. mindspore/rewrite/node/node_manager.py +254 -0
  369. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  370. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  371. mindspore/rewrite/parsers/assign_parser.py +216 -221
  372. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  373. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  374. mindspore/rewrite/parsers/constant_parser.py +9 -6
  375. mindspore/rewrite/parsers/container_parser.py +9 -7
  376. mindspore/rewrite/parsers/for_parser.py +36 -15
  377. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  378. mindspore/rewrite/parsers/if_parser.py +28 -24
  379. mindspore/rewrite/parsers/module_parser.py +196 -25
  380. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  381. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  382. mindspore/rewrite/parsers/return_parser.py +6 -6
  383. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  384. mindspore/rewrite/sparsify/utils.py +1 -1
  385. mindspore/rewrite/symbol_tree.py +523 -578
  386. mindspore/rewrite/symbol_tree_builder.py +9 -193
  387. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  388. mindspore/run_check/_check_version.py +6 -4
  389. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  390. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  391. mindspore/tbbmalloc.dll +0 -0
  392. mindspore/tinyxml2.dll +0 -0
  393. mindspore/train/_utils.py +7 -3
  394. mindspore/train/amp.py +323 -123
  395. mindspore/train/anf_ir_pb2.py +14 -2
  396. mindspore/train/callback/_backup_and_restore.py +2 -12
  397. mindspore/train/callback/_callback.py +29 -4
  398. mindspore/train/callback/_checkpoint.py +23 -8
  399. mindspore/train/callback/_early_stop.py +2 -2
  400. mindspore/train/callback/_landscape.py +4 -4
  401. mindspore/train/callback/_loss_monitor.py +2 -2
  402. mindspore/train/callback/_on_request_exit.py +2 -2
  403. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  404. mindspore/train/callback/_summary_collector.py +15 -8
  405. mindspore/train/callback/_time_monitor.py +58 -5
  406. mindspore/train/data_sink.py +5 -11
  407. mindspore/train/dataset_helper.py +84 -57
  408. mindspore/train/loss_scale_manager.py +2 -2
  409. mindspore/train/metrics/__init__.py +3 -3
  410. mindspore/train/metrics/cosine_similarity.py +1 -1
  411. mindspore/train/metrics/hausdorff_distance.py +3 -2
  412. mindspore/train/metrics/mean_surface_distance.py +3 -2
  413. mindspore/train/metrics/metric.py +39 -19
  414. mindspore/train/metrics/roc.py +2 -2
  415. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  416. mindspore/train/mind_ir_pb2.py +85 -36
  417. mindspore/train/model.py +187 -47
  418. mindspore/train/serialization.py +487 -161
  419. mindspore/train/summary/_summary_adapter.py +1 -1
  420. mindspore/train/summary/_writer_pool.py +3 -2
  421. mindspore/train/summary/summary_record.py +37 -17
  422. mindspore/train/train_thor/convert_utils.py +3 -3
  423. mindspore/train/train_thor/dataset_helper.py +1 -1
  424. mindspore/turbojpeg.dll +0 -0
  425. mindspore/vcmeta.dll +0 -0
  426. mindspore/vcruntime140.dll +0 -0
  427. mindspore/vcruntime140_1.dll +0 -0
  428. mindspore/version.py +1 -1
  429. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +5 -3
  430. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +433 -479
  431. mindspore/_extends/graph_kernel/expander.py +0 -80
  432. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  433. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  434. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  435. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  436. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  437. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  438. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  439. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  440. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  441. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  442. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  443. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  444. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  445. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  446. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  447. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  448. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  449. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  450. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  451. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  452. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  453. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  454. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  455. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  456. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  457. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  458. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  459. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  460. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  461. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  462. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  463. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  464. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  465. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  466. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  467. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  468. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  469. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  470. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  471. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  472. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  473. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  474. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  475. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  476. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  477. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  478. mindspore/dataset/datapreprocess/__init__.py +0 -20
  479. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  480. mindspore/include/api/net.h +0 -142
  481. mindspore/nn/lr_scheduler.py +0 -262
  482. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  483. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  484. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  485. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  486. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  487. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  488. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  489. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  490. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  491. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  492. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  493. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  494. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  495. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  496. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  497. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  498. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  499. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  500. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  501. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  502. mindspore/rewrite/node_visitor.py +0 -44
  503. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  504. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -0
  505. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
@@ -1,506 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """GraphKernel Op Infer"""
16
-
17
- import copy
18
- import sys
19
- from functools import reduce as prod_reduce
20
- from .model import GraphKernelUnsupportedException as GKException
21
- from .model import PrimLib, DataFormat as DF
22
-
23
-
24
- def infer(op_name, inputs, attrs):
25
- """infer shape dtype and format"""
26
-
27
- def _create_opinfer():
28
- self_module = sys.modules.get(__name__, None)
29
- if self_module is None:
30
- raise GKException("OpInfo does not support op {}".format(op_name))
31
-
32
- if hasattr(self_module, op_name):
33
- op_cls = getattr(self_module, op_name)
34
- return op_cls(op_name, inputs, attrs)
35
- # common infer
36
- class_name_map = {
37
- PrimLib.ELEMWISE: "_Elemwise",
38
- PrimLib.REDUCE: "_Reduce",
39
- }
40
- cls_name = class_name_map.get(PrimLib.primtives.get(op_name, PrimLib.default_primtive).iter_type, None)
41
- if not cls_name:
42
- raise GKException("OpInfo does not support op {}".format(op_name))
43
- op_cls = getattr(self_module, cls_name)
44
- return op_cls(op_name, inputs, attrs)
45
-
46
- return _create_opinfer().infer()
47
-
48
-
49
- class OpInfer:
50
- """
51
- OpInfer is the base class for inferring operator info in GraphKernel model builder.
52
-
53
- There are three methods should be overridden to define the infer logic of the operator:
54
- _infer_shape(), _infer_type() and _infer_format().
55
- """
56
-
57
- def __init__(self, name, inputs, attrs):
58
- self.name = name
59
- self.inputs = inputs
60
- self.attrs = attrs
61
-
62
- def infer(self):
63
- """Infer shape, type and format by op inputs"""
64
- self._check()
65
- return self._infer_shape(), self._infer_type(), self._infer_format()
66
-
67
- def _infer_shape(self):
68
- return self.inputs[0].shape
69
-
70
- def _infer_type(self):
71
- return self.inputs[0].dtype
72
-
73
- def _infer_format(self):
74
- return self.inputs[0].data_format
75
-
76
- def _check(self):
77
- self._check_shape()
78
- self._check_type()
79
- self._check_format()
80
-
81
- def _check_shape(self):
82
- pass
83
-
84
- def _check_type(self):
85
- """check all dtypes are same"""
86
- dtype = self.inputs[0].dtype
87
- for i, t in enumerate(self.inputs[1:]):
88
- if t.dtype != dtype:
89
- raise GKException(
90
- "Incompatible data type between input {}({}) and {}({})".format(0, dtype, i + 1, t.dtype))
91
-
92
- def _check_format(self):
93
- """check formats are compatible. only DefaultFormat is compatible with others"""
94
- result = self.inputs[0].data_format
95
- i = 0
96
- for j, t in enumerate(self.inputs[1:]):
97
- if t.data_format != result:
98
- if DF.DEFAULT not in (result, t.data_format):
99
- raise GKException("Incompatible format between input {}({}) and {}({})".format(
100
- i, result, j + 1, t.data_format))
101
- if result == DF.DEFAULT:
102
- result = t.data_format
103
- i = j + 1
104
-
105
-
106
- class _Elemwise(OpInfer):
107
- """Common infer for elementwise operators"""
108
-
109
- @staticmethod
110
- def broadcast_shape(shapes):
111
- """deduce broadcast shape using same rules as numpy"""
112
- dim_size = max(len(shape) for shape in shapes)
113
- align_shapes = [[1] * (dim_size - len(shape)) + shape for shape in shapes]
114
- out_shape = [1] * dim_size
115
- for i in range(dim_size):
116
- for align_shape in align_shapes:
117
- if align_shape[i] == 1:
118
- continue
119
- if out_shape[i] == 1:
120
- out_shape[i] = align_shape[i]
121
- elif out_shape[i] != align_shape[i]:
122
- raise GKException("Input shapes {} can not broadcast.".format(shapes))
123
- return out_shape
124
-
125
- @staticmethod
126
- def defaultformat_to_nz(default_shape):
127
- """default format shape to fractal_Nz format shape"""
128
- # As shape (1,) can broadcast to any shape, it can be regarded as a special FractalNZ shape
129
- if len(default_shape) == 1 and default_shape[0] == 1:
130
- return default_shape
131
- more_two_d_shape, two_d_shape = default_shape[:-2], default_shape[-2:]
132
- # (32) or (1, 32) -> (2, 1, 1, 16)
133
- if len(two_d_shape) == 1 or (len(two_d_shape) == 2 and two_d_shape[0] == 1):
134
- shape = [two_d_shape[-1] // 16, 1, 1, 16]
135
- if two_d_shape[-1] % 16 != 0:
136
- raise GKException("Can not convert default format shape{} to fractal_Nz format shape, because default "
137
- "format shape[-1] should be multiplies of 16, but got {}"
138
- .format(default_shape, two_d_shape[-1]))
139
- return more_two_d_shape + shape
140
- # (32, 1) -> (1, 2, 16, 1)
141
- if len(two_d_shape) == 2 and two_d_shape[1] == 1:
142
- shape = [1, two_d_shape[0] // 16, 16, 1]
143
- if two_d_shape[0] % 16 != 0:
144
- raise GKException("Can not convert default format shape{} to fractal_Nz format shape, because default "
145
- "format shape[-2] should be multiples of 16, but got {}"
146
- .format(default_shape, two_d_shape[0]))
147
- return more_two_d_shape + shape
148
- # (32, 48) -> (3, 2, 16, 16)
149
- shape = [two_d_shape[1] // 16, two_d_shape[0] // 16, 16, 16]
150
- if two_d_shape[0] % 16 != 0 or two_d_shape[1] % 16 != 0:
151
- raise GKException("Can not convert default format shape{} to fractal_Nz format shape, because default "
152
- "format shape[-2] and shape[-1] should be multiples of 16, but got {} and {}"
153
- .format(default_shape, two_d_shape[0], two_d_shape[1]))
154
- return more_two_d_shape + shape
155
-
156
- def _infer_shape(self):
157
- """returns the output shape with broadcast"""
158
-
159
- # in case all inputs are default format/NHWC/NCHW
160
- is_default = [op_input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW) for op_input in self.inputs]
161
- if all(is_default):
162
- return self.broadcast_shape([op_input.shape for op_input in self.inputs])
163
-
164
- # in case formats are fractal_nz, default_fromat/NHWC/HCHW(optional)
165
- is_default_frac_nz = (op_input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW, DF.FRAC_NZ)
166
- for op_input in self.inputs)
167
- if all(is_default_frac_nz):
168
- nz_shapes = [self.defaultformat_to_nz(op_input.shape) if op_input.data_format != DF.FRAC_NZ
169
- else op_input.shape for op_input in self.inputs]
170
- return self.broadcast_shape(nz_shapes)
171
-
172
- inputs_format = [op_input.data_format for op_input in self.inputs]
173
- raise GKException("Only support DefaultFormat, NHWC, NCHW and FRACTAL_NZ in inputs format, but got {}"
174
- .format(inputs_format))
175
-
176
- def _infer_format(self):
177
- for tensor in self.inputs:
178
- if tensor.data_format != DF.DEFAULT:
179
- return tensor.data_format
180
- return DF.DEFAULT
181
-
182
-
183
- class _Reduce(OpInfer):
184
- """Common infer for reduction operators"""
185
-
186
- def _check(self):
187
- super(_Reduce, self)._check()
188
- # check reduce axis in the range [-len, len)
189
- shape_len = len(self.inputs[0].shape)
190
- axis = self.attrs['reduce_axis']
191
- if isinstance(axis, int):
192
- axis = [axis]
193
- if not all((-shape_len <= i < shape_len) for i in axis):
194
- raise GKException(
195
- "Reduce axis should be in range [{},{}) but got {}".format(-shape_len, shape_len, axis))
196
-
197
- def _infer_shape(self):
198
- shape = copy.deepcopy(self.inputs[0].shape)
199
- axis = self.attrs['reduce_axis']
200
-
201
- if isinstance(axis, int):
202
- axis = [axis]
203
- if any(i < 0 for i in axis):
204
- # change the axis to non-negative number.
205
- axis = list(map(lambda i: i + len(shape) if i < 0 else i, axis))
206
- self.attrs['reduce_axis'] = sorted(axis)
207
-
208
- if self.attrs['keep_dims']:
209
- for i in axis:
210
- shape[i] = 1
211
- return shape
212
-
213
- real_shape = []
214
- for i, s in enumerate(shape):
215
- if i not in axis:
216
- real_shape.append(s)
217
- return real_shape
218
-
219
- def _infer_format(self):
220
- return DF.DEFAULT
221
-
222
-
223
- class _Reshape(OpInfer):
224
- """Common infer for reshape operators, should not be instantiated"""
225
-
226
- def _infer_shape(self):
227
- raise GKException("_infer_shape should be implemented by subclass")
228
-
229
- def _infer_format(self):
230
- return DF.DEFAULT if "format" not in self.attrs else self.attrs["format"]
231
-
232
-
233
- class Reshape(_Reshape):
234
- """Reshape op infer"""
235
-
236
- def _check_shape(self):
237
- input_shape = self.inputs[0].shape
238
- output_shape = self.attrs["shape"]
239
- size_before_reshape = prod_reduce(lambda x, y: x * y, input_shape)
240
- size_after_reshape = prod_reduce(lambda x, y: x * y, output_shape)
241
- if size_before_reshape != size_after_reshape:
242
- raise GKException("For 'Reshape', can not reshape {} to {}".format(input_shape, output_shape))
243
-
244
- def _infer_shape(self):
245
- return self.attrs["shape"]
246
-
247
-
248
- class Cast(_Elemwise):
249
- """Cast op infer"""
250
-
251
- def _infer_type(self):
252
- return self.attrs["dst_type"]
253
-
254
-
255
- class InplaceAssign(_Elemwise):
256
- """InplaceAssign op infer"""
257
-
258
- def _infer_shape(self):
259
- return self.inputs[2].shape
260
-
261
- def _infer_type(self):
262
- return self.inputs[2].dtype
263
-
264
- def _infer_format(self):
265
- return self.inputs[2].data_format
266
-
267
-
268
- class BroadcastTo(OpInfer):
269
- """BroadcastTo op infer"""
270
-
271
- def _infer_shape(self):
272
- return self.attrs["shape"]
273
-
274
- def _infer_format(self):
275
- return self.inputs[0].data_format
276
-
277
-
278
- class _CompareOp(_Elemwise):
279
- """Compare operators"""
280
-
281
- def _infer_type(self):
282
- return "bool"
283
-
284
-
285
- class CImag(OpInfer):
286
- """CImag op infer"""
287
-
288
- def _check_type(self):
289
- if self.inputs[0].dtype != "complex64" and self.inputs[0].dtype != "complex128":
290
- raise GKException("For 'CImag', input[0] should be of type complex64 or"
291
- "type complex128, but got {}".format(self.inputs[0].dtype))
292
-
293
- def _infer_type(self):
294
- if self.inputs[0].dtype == "complex64":
295
- return "float32"
296
- return "float64"
297
-
298
-
299
- class CReal(OpInfer):
300
- """CReal op infer"""
301
-
302
- def _check_type(self):
303
- if self.inputs[0].dtype != "complex64" and self.inputs[0].dtype != "complex128":
304
- raise GKException("For 'CReal', input[0] should be of type complex64 or"
305
- "type complex128, but got {}".format(self.inputs[0].dtype))
306
-
307
- def _infer_type(self):
308
- if self.inputs[0].dtype == "complex64":
309
- return "float32"
310
- return "float64"
311
-
312
-
313
- class Complex(OpInfer):
314
- """Complex op infer"""
315
-
316
- def _check_type(self):
317
- if self.inputs[0].dtype != "float32" and self.inputs[0].dtype != "float64":
318
- raise GKException("For 'Complex', input[0] should be of type float32 or type float64,"
319
- "but got {}".format(self.inputs[0].dtype))
320
- if self.inputs[0].dtype != self.inputs[1].dtype:
321
- raise GKException("For 'Complex', inputs data type mismatch ({} vs {})"
322
- .format(self.inputs[0].dtype, self.inputs[1].dtype))
323
-
324
- def _infer_type(self):
325
- if self.inputs[0].dtype == "float32":
326
- return "complex64"
327
- return "complex128"
328
-
329
-
330
- class Less(_CompareOp):
331
- """Less op infer"""
332
-
333
-
334
- class LessEqual(_CompareOp):
335
- """LessEqual op infer"""
336
-
337
-
338
- class Equal(_CompareOp):
339
- """Equal op infer"""
340
-
341
-
342
- class Greater(_CompareOp):
343
- """Greater op infer"""
344
-
345
-
346
- class GreaterEqual(_CompareOp):
347
- """GreaterEqual op infer"""
348
-
349
-
350
- class Select(_Elemwise):
351
- """Select op infer"""
352
-
353
- def _check_type(self):
354
- if self.inputs[0].dtype != "bool":
355
- raise GKException("For 'Select', input[0] should be of type bool, but got {}".format(self.inputs[0].dtype))
356
- if self.inputs[1].dtype != self.inputs[2].dtype:
357
- raise GKException("For 'Select', input[1] and input[2] data type mismatch ({} vs {})"
358
- .format(self.inputs[1].dtype, self.inputs[2].dtype))
359
-
360
- def _infer_type(self):
361
- return self.inputs[1].dtype
362
-
363
-
364
- def check_format_any(formats, checked_format):
365
- """Check whether input format in formats list"""
366
- if not isinstance(formats, (list, tuple)):
367
- raise GKException("formats {} should be of type list or tuple, but got {}.".format(formats, type(formats)))
368
- if checked_format not in formats:
369
- raise GKException("Check {} failed: can not find it in {}".format(checked_format, formats))
370
-
371
-
372
- def check_nd(data, nd):
373
- """Check whether data are nd format"""
374
- if not isinstance(data, (list, tuple)) or len(data) != nd:
375
- raise GKException("input should be {}D list or tuple, but got {}.".format(nd, data))
376
-
377
-
378
- def conv_had_pad(pad_list, pad_mode):
379
- """Check whether conv need to add pad"""
380
- if not isinstance(pad_list, (list, tuple)) or len(pad_list) != 4:
381
- raise GKException("pad_list should be 4D list or tuple, but got {}".format(pad_list))
382
- if pad_list[0] != pad_list[1] or pad_list[2] != pad_list[3]:
383
- return True
384
- if pad_mode not in ["VALID", "valid"]:
385
- for _, pad in enumerate(pad_list):
386
- if pad != 0:
387
- return True
388
- return False
389
-
390
-
391
- class Conv2D(OpInfer):
392
- """Conv2D infer"""
393
-
394
- def _infer_type(self):
395
- if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
396
- return self.attrs["dst_type"]
397
- return self.inputs[0].dtype
398
-
399
- def _infer_shape(self):
400
- shape_0 = list(self.inputs[0].shape)
401
- shape_1 = list(self.inputs[1].shape)
402
- check_nd(shape_0, 4)
403
- check_nd(shape_1, 4)
404
-
405
- formats = [self.inputs[0].data_format, self.inputs[1].data_format, self.attrs["format"]]
406
- check_format_any(formats, DF.NHWC)
407
-
408
- n, h, w, out_channel = shape_0[0], shape_0[1], shape_0[2], shape_1[0]
409
- pad_list = self.attrs["pad_list"]
410
- pad_mode = self.attrs["pad_mode"]
411
- kernel_size = self.attrs["kernel_size"]
412
- stride = self.attrs["stride"]
413
- dilation = self.attrs["dilation"]
414
- check_nd(pad_list, 4)
415
- check_nd(kernel_size, 2)
416
- check_nd(stride, 4)
417
- check_nd(dilation, 4)
418
-
419
- has_pad = conv_had_pad(pad_list, pad_mode)
420
- if not has_pad:
421
- pad_list = [0, 0, 0, 0]
422
-
423
- k_h = (kernel_size[0] - 1) * dilation[-2] + 1
424
- k_w = (kernel_size[1] - 1) * dilation[-1] + 1
425
- out_h = (h + pad_list[0] + pad_list[1] - k_h) // stride[-2] + 1
426
- out_w = (w + pad_list[2] + pad_list[3] - k_w) // stride[-1] + 1
427
- return [n, out_h, out_w, out_channel]
428
-
429
-
430
- class MatMul(OpInfer):
431
- """MatMul infer"""
432
-
433
- def _infer_type(self):
434
- if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
435
- return self.attrs["dst_type"]
436
- return self.inputs[0].dtype
437
-
438
- def _infer_shape(self):
439
- shape_0 = list(self.inputs[0].shape)
440
- shape_1 = list(self.inputs[1].shape)
441
- if len(shape_0) != 2 or len(shape_1) != 2:
442
- raise GKException("For 'MatMul', inputs shape must be 2D, but got {}, {}"
443
- .format(shape_0, shape_1))
444
- transpose_a = self.attrs["transpose_a"]
445
- transpose_b = self.attrs["transpose_b"]
446
- m, k1 = (shape_0[-1], shape_0[-2]) if transpose_a else (shape_0[-2], shape_0[-1])
447
- k2, n = (shape_1[-1], shape_1[-2]) if transpose_b else (shape_1[-2], shape_1[-1])
448
- if k1 != k2:
449
- raise GKException("For 'MatMul', inputs have different k value: {} vs {}".format(k1, k2))
450
- output_shape = [m, n]
451
- return output_shape
452
-
453
-
454
- class PadAkg(OpInfer):
455
- """PadAkg infer"""
456
-
457
- def _infer_shape(self):
458
- shape = list(self.inputs[0].shape)
459
- n = len(shape)
460
- pad_before = list(self.attrs["head"])
461
- pad_after = list(self.attrs["tail"])
462
- if len(pad_before) != n or len(pad_after) != n:
463
- raise GKException("For 'PadAkg', input dimension and pad mismatch: {}d vs {}d vs {}d"
464
- .format(n, len(pad_before), len(pad_after)))
465
- out_shape = [shape[i] + pad_before[i] + pad_after[i] for i in range(n)]
466
- return out_shape
467
-
468
-
469
- class UnPadAkg(OpInfer):
470
- """UnPadAkg infer"""
471
-
472
- def _infer_shape(self):
473
- shape = list(self.inputs[0].shape)
474
- n = len(shape)
475
- unpad_after = list(self.attrs["tail"])
476
- if len(unpad_after) != n:
477
- raise GKException("For 'UnPadAkg', input dimension and pad mismatch: {}d vs {}d"
478
- .format(n, len(unpad_after)))
479
- out_shape = [shape[i] - unpad_after[i] for i in range(n)]
480
- return out_shape
481
-
482
-
483
- class Gather(OpInfer):
484
- """Gather infer"""
485
-
486
- def _infer_shape(self):
487
- input_shape = self.inputs[0].shape
488
- indices_shape = self.inputs[1].shape
489
- axis = self.attrs['axis']
490
- output_shape = input_shape
491
- indices_shape_one_dim = 1
492
- for dim in indices_shape:
493
- indices_shape_one_dim *= dim
494
- output_shape[axis] = indices_shape_one_dim
495
- return output_shape
496
-
497
- def _infer_type(self):
498
- return self.inputs[0].dtype
499
-
500
- def _infer_format(self):
501
- return self.inputs[0].data_format
502
-
503
- def _check_type(self):
504
- if self.inputs[1].dtype != "int32":
505
- raise GKException("For 'Gather', inputs[1] should be of type int32, but got {}"
506
- .format(self.inputs[1].dtype))
@@ -1,20 +0,0 @@
1
- # Copyright 2019 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """Preprocess of dataset.
17
- """
18
- from __future__ import absolute_import
19
-
20
- from mindspore.dataset.datapreprocess.preprocess_imagenet_validate_dataset import *
@@ -1,54 +0,0 @@
1
- # Copyright 2019 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- """Process imagenet validate dataset.
16
- """
17
- from __future__ import absolute_import
18
-
19
- import os
20
- import stat
21
- from mindspore import log as logger
22
-
23
-
24
- def preprocess_imagenet_validation_dataset(train_dataset_path, validation_dataset_path, image_label_mapping_file):
25
- """
26
- Call this function before read imagenet validation dataset.
27
-
28
- Args:
29
- train_dataset_path (str): train dataset path
30
- validation_dataset_path (str): validation dataset path
31
- image_label_mapping_file (str): imagenet_validate_dataset_2012_image_dir_map.txt file path
32
- """
33
- train_dataset_path = os.path.realpath(train_dataset_path)
34
- sub_dir = [dir_.name for dir_ in os.scandir(train_dataset_path) if dir_.is_dir()]
35
- for sub_dir_name in sub_dir:
36
- validate_sub_dir = os.path.join(validation_dataset_path, sub_dir_name)
37
- validate_sub_dir = os.path.realpath(validate_sub_dir)
38
- if not os.path.exists(validate_sub_dir):
39
- os.makedirs(validate_sub_dir, mode=stat.S_IRWXU)
40
- real_file_path = os.path.realpath(image_label_mapping_file)
41
- mappings = [mapping.strip() for mapping in open(real_file_path).readlines()]
42
- for mapping in mappings:
43
- image_dir = mapping.split(':')
44
- old_image_path = os.path.join(validation_dataset_path, image_dir[0])
45
- old_image_path = os.path.realpath(old_image_path)
46
- if not os.path.exists(old_image_path):
47
- logger.warning('Image is not existed %s', old_image_path)
48
- new_image_sub_dir = os.path.join(validation_dataset_path, image_dir[1])
49
- new_image_sub_dir = os.path.realpath(new_image_sub_dir)
50
- new_image_path = os.path.join(new_image_sub_dir, image_dir[0])
51
- new_image_path = os.path.realpath(new_image_path)
52
- if not os.path.exists(new_image_sub_dir):
53
- logger.warning('Image sub dir is not existed %s', new_image_sub_dir)
54
- os.rename(old_image_path, new_image_path)