mindspore 2.3.0__cp310-cp310-win_amd64.whl → 2.4.1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore has been flagged as potentially problematic.

Files changed (275)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/common/__init__.py +6 -4
  13. mindspore/common/_pijit_context.py +190 -0
  14. mindspore/common/_register_for_tensor.py +2 -1
  15. mindspore/common/_tensor_overload.py +139 -0
  16. mindspore/common/api.py +102 -87
  17. mindspore/common/dump.py +5 -6
  18. mindspore/common/generator.py +1 -7
  19. mindspore/common/hook_handle.py +14 -26
  20. mindspore/common/initializer.py +51 -15
  21. mindspore/common/mindir_util.py +2 -2
  22. mindspore/common/parameter.py +62 -15
  23. mindspore/common/recompute.py +39 -9
  24. mindspore/common/sparse_tensor.py +7 -3
  25. mindspore/common/tensor.py +183 -37
  26. mindspore/communication/__init__.py +1 -1
  27. mindspore/communication/_comm_helper.py +38 -3
  28. mindspore/communication/comm_func.py +315 -60
  29. mindspore/communication/management.py +14 -14
  30. mindspore/context.py +132 -22
  31. mindspore/dataset/__init__.py +1 -1
  32. mindspore/dataset/audio/__init__.py +1 -1
  33. mindspore/dataset/core/config.py +7 -0
  34. mindspore/dataset/core/validator_helpers.py +7 -0
  35. mindspore/dataset/engine/cache_client.py +1 -1
  36. mindspore/dataset/engine/datasets.py +72 -44
  37. mindspore/dataset/engine/datasets_audio.py +7 -7
  38. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  39. mindspore/dataset/engine/datasets_text.py +20 -20
  40. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  41. mindspore/dataset/engine/datasets_vision.py +33 -33
  42. mindspore/dataset/engine/iterators.py +29 -0
  43. mindspore/dataset/engine/obs/util.py +7 -0
  44. mindspore/dataset/engine/queue.py +114 -60
  45. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  46. mindspore/dataset/engine/validators.py +34 -14
  47. mindspore/dataset/text/__init__.py +1 -4
  48. mindspore/dataset/transforms/__init__.py +0 -3
  49. mindspore/dataset/utils/line_reader.py +2 -0
  50. mindspore/dataset/vision/__init__.py +1 -4
  51. mindspore/dataset/vision/utils.py +1 -1
  52. mindspore/dataset/vision/validators.py +2 -1
  53. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  54. mindspore/experimental/es/embedding_service.py +883 -0
  55. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  56. mindspore/experimental/llm_boost/__init__.py +21 -0
  57. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  58. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  59. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  60. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  61. mindspore/experimental/llm_boost/register.py +129 -0
  62. mindspore/experimental/llm_boost/utils.py +31 -0
  63. mindspore/experimental/optim/adamw.py +85 -0
  64. mindspore/experimental/optim/optimizer.py +3 -0
  65. mindspore/hal/__init__.py +3 -3
  66. mindspore/hal/contiguous_tensors_handle.py +175 -0
  67. mindspore/hal/stream.py +18 -0
  68. mindspore/include/api/model_group.h +13 -1
  69. mindspore/include/api/types.h +10 -10
  70. mindspore/include/dataset/config.h +2 -2
  71. mindspore/include/dataset/constants.h +2 -2
  72. mindspore/include/dataset/execute.h +2 -2
  73. mindspore/include/dataset/vision.h +4 -0
  74. mindspore/log.py +1 -1
  75. mindspore/mindrecord/filewriter.py +68 -51
  76. mindspore/mindspore_backend.dll +0 -0
  77. mindspore/mindspore_common.dll +0 -0
  78. mindspore/mindspore_core.dll +0 -0
  79. mindspore/mindspore_np_dtype.dll +0 -0
  80. mindspore/mindspore_ops.dll +0 -0
  81. mindspore/mint/__init__.py +983 -46
  82. mindspore/mint/distributed/__init__.py +31 -0
  83. mindspore/mint/distributed/distributed.py +254 -0
  84. mindspore/mint/nn/__init__.py +268 -23
  85. mindspore/mint/nn/functional.py +125 -19
  86. mindspore/mint/nn/layer/__init__.py +39 -0
  87. mindspore/mint/nn/layer/activation.py +133 -0
  88. mindspore/mint/nn/layer/normalization.py +477 -0
  89. mindspore/mint/nn/layer/pooling.py +110 -0
  90. mindspore/mint/optim/adamw.py +26 -13
  91. mindspore/mint/special/__init__.py +63 -0
  92. mindspore/multiprocessing/__init__.py +2 -1
  93. mindspore/nn/__init__.py +0 -1
  94. mindspore/nn/cell.py +276 -96
  95. mindspore/nn/layer/activation.py +211 -44
  96. mindspore/nn/layer/basic.py +137 -10
  97. mindspore/nn/layer/embedding.py +137 -2
  98. mindspore/nn/layer/normalization.py +101 -5
  99. mindspore/nn/layer/padding.py +34 -48
  100. mindspore/nn/layer/pooling.py +161 -7
  101. mindspore/nn/layer/transformer.py +3 -3
  102. mindspore/nn/loss/__init__.py +2 -2
  103. mindspore/nn/loss/loss.py +84 -6
  104. mindspore/nn/optim/__init__.py +2 -1
  105. mindspore/nn/optim/adadelta.py +1 -1
  106. mindspore/nn/optim/adam.py +1 -1
  107. mindspore/nn/optim/lamb.py +1 -1
  108. mindspore/nn/optim/tft_wrapper.py +124 -0
  109. mindspore/nn/wrap/cell_wrapper.py +12 -23
  110. mindspore/nn/wrap/grad_reducer.py +5 -5
  111. mindspore/nn/wrap/loss_scale.py +17 -3
  112. mindspore/numpy/__init__.py +1 -1
  113. mindspore/numpy/array_creations.py +65 -68
  114. mindspore/numpy/array_ops.py +64 -60
  115. mindspore/numpy/fft.py +610 -75
  116. mindspore/numpy/logic_ops.py +11 -10
  117. mindspore/numpy/math_ops.py +85 -84
  118. mindspore/numpy/utils_const.py +4 -4
  119. mindspore/opencv_core452.dll +0 -0
  120. mindspore/opencv_imgcodecs452.dll +0 -0
  121. mindspore/opencv_imgproc452.dll +0 -0
  122. mindspore/ops/__init__.py +6 -4
  123. mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
  124. mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
  125. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  126. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  127. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  128. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  129. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
  130. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  131. mindspore/ops/auto_generate/gen_extend_func.py +767 -13
  132. mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
  133. mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
  134. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  135. mindspore/ops/composite/base.py +85 -48
  136. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  137. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  138. mindspore/ops/function/__init__.py +22 -0
  139. mindspore/ops/function/array_func.py +492 -153
  140. mindspore/ops/function/debug_func.py +113 -1
  141. mindspore/ops/function/fft_func.py +15 -2
  142. mindspore/ops/function/grad/grad_func.py +3 -2
  143. mindspore/ops/function/math_func.py +564 -207
  144. mindspore/ops/function/nn_func.py +817 -383
  145. mindspore/ops/function/other_func.py +3 -2
  146. mindspore/ops/function/random_func.py +402 -12
  147. mindspore/ops/function/reshard_func.py +13 -11
  148. mindspore/ops/function/sparse_unary_func.py +1 -1
  149. mindspore/ops/function/vmap_func.py +3 -2
  150. mindspore/ops/functional.py +24 -14
  151. mindspore/ops/op_info_register.py +3 -3
  152. mindspore/ops/operations/__init__.py +7 -2
  153. mindspore/ops/operations/_grad_ops.py +2 -76
  154. mindspore/ops/operations/_infer_ops.py +1 -1
  155. mindspore/ops/operations/_inner_ops.py +71 -94
  156. mindspore/ops/operations/array_ops.py +14 -146
  157. mindspore/ops/operations/comm_ops.py +63 -53
  158. mindspore/ops/operations/custom_ops.py +83 -19
  159. mindspore/ops/operations/debug_ops.py +42 -10
  160. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  161. mindspore/ops/operations/manually_defined/ops_def.py +273 -20
  162. mindspore/ops/operations/math_ops.py +12 -223
  163. mindspore/ops/operations/nn_ops.py +20 -114
  164. mindspore/ops/operations/other_ops.py +7 -4
  165. mindspore/ops/operations/random_ops.py +46 -1
  166. mindspore/ops/primitive.py +18 -6
  167. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  168. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  169. mindspore/ops_generate/gen_constants.py +36 -0
  170. mindspore/ops_generate/gen_ops.py +67 -52
  171. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  172. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  173. mindspore/ops_generate/op_proto.py +10 -3
  174. mindspore/ops_generate/pyboost_utils.py +14 -1
  175. mindspore/ops_generate/template.py +43 -21
  176. mindspore/parallel/__init__.py +3 -1
  177. mindspore/parallel/_auto_parallel_context.py +31 -9
  178. mindspore/parallel/_cell_wrapper.py +85 -0
  179. mindspore/parallel/_parallel_serialization.py +47 -19
  180. mindspore/parallel/_tensor.py +127 -13
  181. mindspore/parallel/_utils.py +53 -22
  182. mindspore/parallel/algo_parameter_config.py +5 -5
  183. mindspore/parallel/checkpoint_transform.py +46 -39
  184. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  185. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  186. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  187. mindspore/parallel/parameter_broadcast.py +3 -4
  188. mindspore/parallel/shard.py +162 -31
  189. mindspore/parallel/transform_safetensors.py +1146 -0
  190. mindspore/profiler/__init__.py +2 -1
  191. mindspore/profiler/common/constant.py +29 -0
  192. mindspore/profiler/common/registry.py +47 -0
  193. mindspore/profiler/common/util.py +28 -0
  194. mindspore/profiler/dynamic_profiler.py +694 -0
  195. mindspore/profiler/envprofiling.py +17 -19
  196. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  197. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  198. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  199. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  200. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  201. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  202. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  203. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  205. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  206. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  207. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  208. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  209. mindspore/profiler/parser/framework_parser.py +1 -391
  210. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  211. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  212. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  213. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  214. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  215. mindspore/profiler/parser/profiler_info.py +78 -6
  216. mindspore/profiler/profiler.py +153 -0
  217. mindspore/profiler/profiling.py +285 -413
  218. mindspore/rewrite/__init__.py +1 -2
  219. mindspore/rewrite/common/namespace.py +4 -4
  220. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  221. mindspore/run_check/_check_version.py +39 -104
  222. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  223. mindspore/train/__init__.py +4 -3
  224. mindspore/train/_utils.py +105 -19
  225. mindspore/train/amp.py +171 -53
  226. mindspore/train/callback/__init__.py +2 -2
  227. mindspore/train/callback/_callback.py +4 -4
  228. mindspore/train/callback/_checkpoint.py +97 -31
  229. mindspore/train/callback/_cluster_monitor.py +1 -1
  230. mindspore/train/callback/_flops_collector.py +1 -0
  231. mindspore/train/callback/_loss_monitor.py +3 -3
  232. mindspore/train/callback/_on_request_exit.py +145 -31
  233. mindspore/train/callback/_summary_collector.py +5 -5
  234. mindspore/train/callback/_tft_register.py +375 -0
  235. mindspore/train/dataset_helper.py +15 -3
  236. mindspore/train/metrics/metric.py +3 -3
  237. mindspore/train/metrics/roc.py +4 -4
  238. mindspore/train/mind_ir_pb2.py +44 -39
  239. mindspore/train/model.py +154 -58
  240. mindspore/train/serialization.py +342 -128
  241. mindspore/utils/__init__.py +21 -0
  242. mindspore/utils/utils.py +60 -0
  243. mindspore/version.py +1 -1
  244. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
  245. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +248 -242
  246. mindspore/include/c_api/ms/abstract.h +0 -67
  247. mindspore/include/c_api/ms/attribute.h +0 -197
  248. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  249. mindspore/include/c_api/ms/base/macros.h +0 -32
  250. mindspore/include/c_api/ms/base/status.h +0 -33
  251. mindspore/include/c_api/ms/base/types.h +0 -283
  252. mindspore/include/c_api/ms/context.h +0 -102
  253. mindspore/include/c_api/ms/graph.h +0 -160
  254. mindspore/include/c_api/ms/node.h +0 -606
  255. mindspore/include/c_api/ms/tensor.h +0 -161
  256. mindspore/include/c_api/ms/value.h +0 -84
  257. mindspore/mindspore_shared_lib.dll +0 -0
  258. mindspore/nn/extend/basic.py +0 -140
  259. mindspore/nn/extend/embedding.py +0 -143
  260. mindspore/nn/extend/layer/normalization.py +0 -109
  261. mindspore/nn/extend/pooling.py +0 -117
  262. mindspore/nn/layer/embedding_service.py +0 -531
  263. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  264. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  265. mindspore/ops/extend/__init__.py +0 -53
  266. mindspore/ops/extend/array_func.py +0 -218
  267. mindspore/ops/extend/math_func.py +0 -76
  268. mindspore/ops/extend/nn_func.py +0 -308
  269. mindspore/ops/silent_check.py +0 -162
  270. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  271. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  272. mindspore/train/callback/_mindio_ttp.py +0 -443
  273. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
  274. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
  275. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/common/api.py CHANGED
@@ -38,12 +38,13 @@ from mindspore.common.tensor import Tensor as PythonTensor
 from mindspore.common.sparse_tensor import CSRTensor as PythonCSRTensor
 from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
 from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
+from mindspore._c_expression.amp import get_curr_amp_strategy
 from mindspore._c_expression import GraphExecutor_, Tensor, CSRTensor, RowTensor, COOTensor, \
     PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
-    _ms_memory_recycle, _bind_device_ctx, jit_mode_pi_enable, jit_mode_pi_compile
+    _ms_memory_recycle, _bind_device_ctx
 from mindspore.parallel._ps_context import _is_role_sched
 from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_pynative_parallel, \
-    _is_in_auto_parallel_mode
+    _is_in_auto_parallel_mode, _is_parallel_mode
 from mindspore import _checkparam as Validator
 from mindspore._checkparam import is_stub_tensor
 from mindspore.common._utils import is_shape_unknown
@@ -51,6 +52,8 @@ from mindspore.common.mutable import mutable
 from mindspore.common._register_for_adapter import ms_adapter_registry
 from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
     get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature
+from mindspore.common._pijit_context import PIJitCaptureContext
+from mindspore.common.parameter import Parameter
 
 # Store ms_function class compiled pipeline cache.
 ms_compile_cache = set()
@@ -513,6 +516,19 @@ def _generate_dyn_compile_args(compile_args, dyn_args):
     return tuple(new_compile_args)
 
 
+def _get_parameter_ids(args, kwargs):
+    """Get the ids of parameters."""
+    parameter_ids = ""
+    for arg in args:
+        if isinstance(arg, Parameter):
+            parameter_ids += str(id(arg))
+    for _, value in kwargs.items():
+        # The type of key is usually String type.
+        if isinstance(value, Parameter):
+            parameter_ids += str(id(value))
+    return parameter_ids
+
+
 class _MindsporeFunctionExecutor:
     """
     Represents a function compiled by graph compiler.
@@ -625,6 +641,10 @@ class _MindsporeFunctionExecutor:
 
         self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
         key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
+
+        parameter_ids = _get_parameter_ids(args, kwargs)
+        if parameter_ids != "":
+            key = str(key) + '.' + parameter_ids
         phase = generate_name + '.' + str(key)
 
         update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)
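The `_get_parameter_ids` call above folds the identity of any `Parameter` passed as a top-level argument into the compile-cache key, so that passing a different `Parameter` object forces a fresh compile instead of silently reusing a graph bound to the old one. A minimal self-contained sketch of the idea; the `Parameter` class below is a hypothetical stand-in, not MindSpore's:

```python
class Parameter:
    """Hypothetical stand-in for mindspore.common.parameter.Parameter."""
    def __init__(self, name):
        self.name = name

def get_parameter_ids(args, kwargs):
    # Mirror the helper above: concatenate the ids of Parameter arguments.
    parameter_ids = ""
    for arg in args:
        if isinstance(arg, Parameter):
            parameter_ids += str(id(arg))
    for value in kwargs.values():
        if isinstance(value, Parameter):
            parameter_ids += str(id(value))
    return parameter_ids

w1, w2 = Parameter("w1"), Parameter("w2")
base_key = "generate_name.args_key"
# Identical shapes and dtypes, but distinct Parameter objects yield
# distinct phase keys, so each gets its own compiled graph.
print(base_key + "." + get_parameter_ids((w1,), {}))
print(base_key + "." + get_parameter_ids((w2,), {}))
```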
@@ -783,31 +803,28 @@ def _get_jit_hash(hash_input):
     return _get_obj_id(hash_input)
 
 
-def _update_graph_executor_config(jit_config):
-    """Update GraphExecutor jit_config"""
-    if isinstance(jit_config, JitConfig):
-        jit_config = jit_config.jit_config_dict
-    if not isinstance(jit_config, dict):
-        return
-    valid_config = dict()
-    for k, v in jit_config.items():
-        valid_config[str(k)] = str(v)
-    GraphExecutor_.get_instance().set_jit_config(JitConfig(**valid_config).jit_config_dict)
-
-
 def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=None, compile_once=False):
     """
     Create a callable MindSpore graph from a Python function.
 
     This allows the MindSpore runtime to apply optimizations based on graph.
 
+    Note:
+        - If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
+          will not accept `**kwargs`.
+        - Running a function decorated with @jit(mode="PIJit") in static graph mode is not supported;
+          in that case the @jit(mode="PIJit") decoration is considered invalid.
+        - Calling a function decorated with @jit(mode="PIJit") inside another function decorated with
+          @jit(mode="PIJit") is not supported, and the inner decoration is considered invalid.
+
     Args:
         fn (Function): The Python function that will be run as a graph. Default: ``None`` .
         mode (str): The type of jit used, the value of mode should be ``PIJit`` or ``PSJit``. Default: ``PSJit`` .
 
-            - `PSJit <https://www.mindspore.cn/docs/en/master/note/static_graph_syntax_support.html>`_ :
+            - `PSJit <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html>`_ :
              Parse python ast to build graph.
-            - `PIJit <https://www.mindspore.cn/docs/en/master/design/dynamic_graph_and_static_graph.html>`_ :
+            - `PIJit <https://www.mindspore.cn/docs/en/master/model_train/program_form/pynative.html#pijit>`_ :
              Parse python bytecode to build graph at runtime.
 
         input_signature (Union[Tuple, List, Dict, Tensor]): The Tensor which describes the input arguments. The
@@ -831,10 +848,6 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
             it was created again.
             Default: ``False`` .
 
-    Note:
-        If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
-        will not accept `**kwargs`.
-
     Returns:
         Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
         None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
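As a usage sketch of the `input_signature` constraint described in the Note above (shapes and dtypes here are illustrative, not mandated by the API):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor

# With input_signature set, every input must be a Tensor matching the
# declared spec, and **kwargs are rejected.
@ms.jit(input_signature=Tensor(np.zeros((2, 3), np.float32)))
def double(x):
    return x * 2

out = double(Tensor(np.ones((2, 3), np.float32)))
```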
@@ -938,45 +951,20 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
             # only the function or cell instance wrapped by shard will fall into this branch
             if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
                 process_obj = hash_args
+            # Handle auto mixed precision strategy.
+            if not hasattr(func, "amp_strategy"):
+                if isinstance(func, types.MethodType):
+                    setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
+                else:
+                    setattr(func, "amp_strategy", get_curr_amp_strategy())
             out = _MindsporeFunctionExecutor(func, hash_obj, dyn_args, process_obj, jit_config)(*args, **kwargs)
             return out
 
         return staging_specialize
 
-    def pi_wrap_mindspore(decorated):
-        func = decorated
-        if isinstance(func, ms.nn.Cell):
-            func = func.construct
-        if isinstance(func, type) and issubclass(func, ms.nn.Cell):
-            func = func.construct
-        if isinstance(func, types.MethodType):
-            func = func.__func__
-        if not isinstance(func, types.FunctionType):
-            logger.warning("only support function and mindspore.nn.Cell instance")
-            return decorated
-
-        # generator, coroutine, awaitable and a function that return them is unsupported
-        UNSUPPORTED_CODE_TYPE = (inspect.CO_GENERATOR | inspect.CO_COROUTINE |
-                                 inspect.CO_ASYNC_GENERATOR | inspect.CO_ITERABLE_COROUTINE)
-        if func.__code__.co_flags & UNSUPPORTED_CODE_TYPE:
-            return decorated
-
-        _update_graph_executor_config(jit_config)
-        config = dict()
-        if isinstance(jit_config, JitConfig):
-            config.update(jit_config.jit_config_dict)
-        elif jit_config is not None:
-            config.update(jit_config)
-        jit_mode_pi_enable()
-
-        if jit_mode_pi_compile(func, config, input_signature) is False:
-            logger.warning('add fn {} to compile failed '.format(func))
-
-        return decorated
-
     wrap_func = wrap_mindspore
     if mode == "PIJit":
-        wrap_func = pi_wrap_mindspore
+        wrap_func = PIJitCaptureContext(jit_config, input_signature)
 
     if fn is not None:
         return wrap_func(fn)
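A minimal sketch of the PIJit entry point after this change: the decorator now routes through `PIJitCaptureContext`, which captures Python bytecode at runtime in PyNative mode; under static graph mode the decoration is treated as invalid, per the Note above.

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor

ms.set_context(mode=ms.PYNATIVE_MODE)  # PIJit only takes effect in PyNative mode

@ms.jit(mode="PIJit")
def add(x, y):
    return x + y

out = add(Tensor(np.ones((2,), np.float32)), Tensor(np.ones((2,), np.float32)))
```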
@@ -1272,7 +1260,7 @@ def jit_class(cls):
     if not inspect.isclass(cls):
         raise TypeError(f'Decorator jit_class can only be used for class type, but got {cls}.')
     # Check if cls is nn.Cell.
-    if issubclass(cls, nn.Cell):
+    if issubclass(cls, nn.cell.Cell):
         raise TypeError(f"Decorator jit_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
     setattr(cls, '__ms_class__', True)
     return cls
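Usage of the decorator is unchanged; only the subclass check now goes through `nn.cell.Cell`. A typical sketch, which the check above still permits because `Config` is a plain user-defined class:

```python
import mindspore as ms

@ms.jit_class
class Config:
    def __init__(self, scale):
        self.scale = scale

cfg = Config(2.0)
print(cfg.scale)
# Decorating an nn.Cell subclass would raise the TypeError shown above.
```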
@@ -1463,23 +1451,22 @@ class _PyNativeExecutor:
         """
         self._executor.end_graph(obj, output, *args, *(kwargs.values()))
 
-    def check_run(self, grad, obj, weights, grad_hash_id, *args, **kwargs):
+    def check_run(self, grad, obj, weights, grad_hash_id, *args):
         """
         Whether the forward graph need to construct.
 
         Args:
             grad (GradOperation): The gradoperation object.
             obj (Function/Cell): The function or cell instance.
-            grad_hash_id (tuple): The id of objects which contribute to cache of compiled graph in pynative mode.
+            grad_hash_id (tuple): The id of objects, which contributes to cache of compiled graph in pynative mode.
             args (tuple): Function or cell input arguments.
-            kwargs (dict): keyword arguments.
 
         Return:
-            bool, specifies whether the forward graph need to construct.
+            bool, specifies whether the forward graph needs to construct.
         """
-        return self._executor.check_run(grad, obj, weights, grad_hash_id, *args, *(kwargs.values()))
+        return self._executor.check_run(grad, obj, weights, grad_hash_id, *args)
 
-    def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
+    def grad(self, obj, grad, weights, grad_position, *args):
         """
         Get grad graph.
 
@@ -1490,12 +1477,11 @@ class _PyNativeExecutor:
             grad_position (Union(int, tuple[int])): If int, get the gradient with respect to single input.
                 If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
             args (tuple): Function or cell input arguments.
-            kwargs (dict): keyword arguments.
 
         Return:
             None.
         """
-        return self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
+        return self._executor.grad(grad, obj, weights, grad_position, *args)
 
     def clear_res(self):
         """
@@ -1528,9 +1514,23 @@ class _PyNativeExecutor:
         """
         return self._executor.grad_jit(output, *args)
 
+    def call_custom_bprop(self, obj, output, *args, **kwargs):
+        """
+        Call custom bprop to build a variable for cell bprop.
+
+        Args:
+            obj (Cell): The function or cell instance.
+            output (Tensor/tuple/list): Function or cell output object.
+            args (tuple): Function or cell input arguments.
+            kwargs (dict): keyword arguments.
+
+        Return:
+            None.
+        """
+        return self._executor.call_custom_bprop(obj, output, *args, *(kwargs.values()))
+
     def grad_flag(self):
         """
-        The flag of building grad graph.
+        The flag of whether the net is building the grad graph.
 
         Return:
             bool, whether building grad graph.
@@ -1563,7 +1563,7 @@ class _PyNativeExecutor:
 
     def enable_grad(self):
         """
-        The global flag whether needing to calculate gradient.
+        The global flag of whether the gradient needs to be calculated, used by no_grad.
 
         Return:
             bool, whether needing to calculate gradient.
@@ -1582,6 +1582,18 @@ class _PyNativeExecutor:
         """
         self._executor.set_enable_grad(flag)
 
+    def requires_grad(self):
+        """
+        Whether gradient calculation is currently required; true only when
+        both enable_grad and grad_flag are true.
+
+        Return:
+            bool, whether gradient calculation is required.
+        """
+        return self._executor.requires_grad()
+
     def set_jit_compile_status(self, status, phase):
         """
         Set jit is compiling
@@ -1605,6 +1617,18 @@ class _PyNativeExecutor:
         """
         self._executor.set_is_run_recompute(status)
 
+    def set_cell_use_dynamic_shape_process(self, flag):
+        """
+        Set the dynamic shape flag of the eval process.
+
+        Args:
+            flag (bool): Specifying whether using a dynamic process.
+
+        Return:
+            None.
+        """
+        self._executor.set_cell_use_dynamic_shape_process(flag)
+
     def set_dynamic_input(self, obj, *args):
         """
         Set dynamic shape tensor of input arguments.
@@ -1630,27 +1654,19 @@ class _PyNativeExecutor:
         """
         return self._executor.get_dynamic_input(*actual_args)
 
-    def is_first_cell(self):
-        """
-        The flag of first cell instance.
-
-        Return:
-            bool, specifies whether is the first cell.
-        """
-        return self._executor.is_first_cell()
-
-    def set_hook_changed(self, cell):
-        """
-        The flag of registering or removing a hook function on Cell instance.
+    def set_mixed_precision_type(self, mixed_precision_type, is_push=True):
+        """
+        Set the mixed precision type.
 
         Args:
-            cell (Cell): The cell instance.
+            mixed_precision_type (MixedPrecisionType): The mixed precision type.
+            is_push (bool): True when called from __enter__.
 
         Return:
             None.
         """
-        self._executor.set_hook_changed(cell)
+
+        return self._executor.set_mixed_precision_type(mixed_precision_type, is_push)
 
     def constant_folding(self, *args):
         """
@@ -1687,6 +1703,7 @@ class _CellGraphExecutor:
         self._graph_executor = GraphExecutor_.get_instance()
         self._graph_executor.set_py_exe_path(sys.executable)
         self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
+        self._pid = os.getpid()
 
     def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes,
                      input_indexs, phase='dataset', need_run=True):
@@ -1789,6 +1806,10 @@ class _CellGraphExecutor:
         self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
         key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
         obj.arguments_key = str(key)
+        # When a parameter exists in the top graph inputs, check whether the parameter object has changed.
+        parameter_ids = _get_parameter_ids(args, kwargs)
+        if parameter_ids != "":
+            obj.arguments_key = obj.arguments_key + '.' + parameter_ids
         raw_phase = phase
         phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
         obj.phase_cache[raw_phase] = phase
@@ -1825,7 +1846,7 @@ class _CellGraphExecutor:
         if graph is None:
             raise RuntimeError("Compile graph failed for phase {}.".format(phase))
 
-        auto_parallel_mode = _is_in_auto_parallel_mode()
+        auto_parallel_mode = _is_in_auto_parallel_mode() or _is_parallel_mode()
         if not auto_parallel_mode:
             replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
             self._update_param_node_default_input(phase, replace)
@@ -1913,15 +1934,9 @@ class _CellGraphExecutor:
 
     def del_net_res(self, obj, net_id):
        """Clear the memory resource of a network."""
-        self._graph_executor.del_net_res(obj, net_id)
-
-    def inc_graph_cell_count(self):
-        """Increase the count of GraphCell instance."""
-        self._graph_executor.inc_graph_cell_count()
-
-    def dec_graph_cell_count(self):
-        """Decrease the count of GraphCell instance."""
-        self._graph_executor.dec_graph_cell_count()
+        # No need to del net res in an independent dataset process, which is a subprocess forked by the main process.
+        if self._pid == os.getpid():
+            self._graph_executor.del_net_res(obj, net_id)
 
     def _get_branch_control_input(self):
         if ('obf_ratio' not in self.obfuscate_config.keys()) or (
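The guard above is a common fork-safety pattern: record the pid of the creating process and skip teardown when the call runs in a forked child (here, an independent dataset subprocess). A generic sketch of the pattern, independent of MindSpore:

```python
import os

class GraphResource:
    """Sketch: only the process that created a resource may release it."""

    def __init__(self):
        self._pid = os.getpid()  # remember the owning process

    def release(self):
        # A forked dataset worker inherits this object but must not free
        # graph resources owned by the main process.
        if self._pid == os.getpid():
            print("releasing in owner process", self._pid)
```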
mindspore/common/dump.py CHANGED
@@ -27,18 +27,17 @@ def set_dump(target, enabled=True):
     `target` should be an instance of :class:`mindspore.nn.Cell` or :class:`mindspore.ops.Primitive` .
     Please note that this API takes effect only when Synchronous Dump is enabled and the `dump_mode`
     field in dump config file is ``"2"`` . See the `dump document
-    <https://www.mindspore.cn/tutorials/experts/en/master/debug/dump.html>`_ for details.
+    <https://www.mindspore.cn/docs/en/master/model_train/debug/dump.html>`_ for details.
     The default enabled status for
     a :class:`mindspore.nn.Cell` or :class:`mindspore.ops.Primitive` is False.
 
     Note:
-        1. This API is only effective for GRAPH_MODE whose graph compilation level is O0/O1 with Ascend backend.
-        2. This API only supports being called before training starts.
+        1. This API only supports being called before training starts.
            If you call this API during training, it may not be effective.
-        3. After using `set_dump(Cell, True)` , operators in forward and backward
+        2. After using `set_dump(Cell, True)` , operators in forward and backward
            computation (computation generated by the grad operations) of the
            cell will be dumped.
-        4. For :class:`mindspore.nn.SoftmaxCrossEntropyWithLogits` layer, the forward
+        3. For :class:`mindspore.nn.SoftmaxCrossEntropyWithLogits` layer, the forward
            computation and backward computation use the same set of
            operators. So you can only see dump data from backward computation.
            Please note that :class:`mindspore.nn.SoftmaxCrossEntropyWithLogits` layer will also use
@@ -58,7 +57,7 @@ def set_dump(target, enabled=True):
     .. note::
         Please set environment variable `MINDSPORE_DUMP_CONFIG` to the dump config file and set `dump_mode` field
         in dump config file to 2 before running this example.
-        See `dump document <https://www.mindspore.cn/tutorials/experts/en/master/debug/dump.html>`_ for details.
+        See `dump document <https://www.mindspore.cn/docs/en/master/model_train/debug/dump.html>`_ for details.
 
     >>> import numpy as np
     >>> import mindspore as ms
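An end-to-end sketch of the setup the docstring describes. The config skeleton below is illustrative; field names should be checked against the dump document linked above, but `dump_mode` must be 2 for `set_dump` to take effect:

```python
import json
import os

# Illustrative synchronous dump config; dump_mode 2 dumps only the cells and
# primitives explicitly enabled via set_dump.
dump_config = {
    "common_dump_settings": {
        "dump_mode": 2,
        "path": "/tmp/ms_dump",
        "net_name": "Net",
        "iteration": "all",
        "input_output": 0,
        "kernels": [],
        "support_device": [0],
    }
}
with open("/tmp/dump_config.json", "w") as f:
    json.dump(dump_config, f)
# Must be set before MindSpore starts running the network.
os.environ["MINDSPORE_DUMP_CONFIG"] = "/tmp/dump_config.json"
```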
mindspore/common/generator.py CHANGED
@@ -56,12 +56,6 @@ class Generator:
     A generator that manages the state of random numbers and provides seed and offset for random functions.
     When the seed and offset are fixed, the random function generates the same random sequence.
 
-    Inputs:
-        - **step** (int) - Set the step size for offset update.
-
-    Outputs:
-        Tuple consisting of the seed and offset of generator.
-
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -199,7 +193,7 @@ def manual_seed(seed):  # pylint: disable=redefined-outer-name
     >>> print(initial_seed())
     13
     """
-    default_generator.manual_seed(seed)
+    return default_generator.manual_seed(seed)
 
 
 def initial_seed():
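With the change above, `manual_seed` now propagates the return value of `default_generator.manual_seed(seed)` instead of dropping it, so the default generator can be captured from the call. A sketch, assuming the module-level helpers shown in the hunk are exported as documented:

```python
from mindspore import manual_seed, initial_seed

gen = manual_seed(13)   # now returns the default generator, not None
print(initial_seed())   # 13
```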
mindspore/common/hook_handle.py CHANGED
@@ -77,27 +77,19 @@ class HookHandle:
     It is only supported in pynative mode and works when registering or removing hook function for Cell object.
 
     Args:
-        hook_cell (Cell): The Cell object with hook function registered on. Default value: None.
-        hook_key (int): The key of cell hook function in dict. It is generated during cell hook function registration.
-            Default value: -1.
-        hook_type (str): The type of cell hook function: '_forward_pre_hook', '_forward_hook' or '_cell_backward_hook'.
-            Default value: "".
+        hook_dict (dict): The dict in which the hook function is registered. Default value: None.
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
     """
-    def __init__(self, hook_cell=None, hook_key=-1, hook_type=""):
-        if hook_cell is not None:
-            self._hook_cell = weakref.ref(hook_cell)
-        else:
-            self._hook_cell = hook_cell
-        self._hook_key = hook_key
-        self._hook_type = hook_type
-
-    def __del__(self):
-        self._hook_cell = None
-        self._hook_key = None
-        self._hook_type = None
+    unique_id = 0
+
+    def __init__(self, hook_dict=None):
+        self.hook_dict_ref = None
+        if hook_dict is not None:
+            self.hook_dict_ref = weakref.ref(hook_dict)
+        self.handle_id = HookHandle.unique_id
+        HookHandle.unique_id += 1
 
     def remove(self):
         """
@@ -121,7 +113,7 @@ class HookHandle:
         >>> from mindspore import Tensor
         >>> from mindspore.ops import GradOperation
         >>> ms.set_context(mode=ms.PYNATIVE_MODE)
-        >>> def forward_pre_hook_fn(cell_id, inputs):
+        >>> def forward_pre_hook_fn(cell, inputs):
         ...     print("forward inputs: ", inputs)
         ...
         >>> class Net(nn.Cell):
@@ -145,11 +137,7 @@ class HookHandle:
         (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
         value= [ 2.00000000e+00]))
         """
-        if self._hook_cell is not None:
-            hook_cell = self._hook_cell()
-            if self._hook_type == "_forward_pre_hook" and self._hook_key in hook_cell._forward_pre_hook:
-                del hook_cell._forward_pre_hook[self._hook_key]
-            elif self._hook_type == "_forward_hook" and self._hook_key in hook_cell._forward_hook:
-                del hook_cell._forward_hook[self._hook_key]
-            elif self._hook_type == "_cell_backward_hook":
-                hook_cell._cell_backward_hook.remove_backward_hook(self._hook_key)
+        if self.hook_dict_ref is not None:
+            hook_dict = self.hook_dict_ref()
+            if hook_dict is not None and self.handle_id in hook_dict:
+                del hook_dict[self.handle_id]
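In user code the handle behaves as before: the hook APIs on `Cell` return a `HookHandle`, and `remove()` now simply deletes the handle's entry from the owning hook dict through the weak reference. A sketch using the new hook signature from the example above (a `cell` object replaces the old `cell_id` argument):

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor

ms.set_context(mode=ms.PYNATIVE_MODE)

def forward_pre_hook_fn(cell, inputs):   # new-style signature
    print("forward inputs:", inputs)

net = nn.ReLU()
handle = net.register_forward_pre_hook(forward_pre_hook_fn)
net(Tensor(np.array([-1.0, 2.0], np.float32)))
handle.remove()   # deletes this handle's entry from the cell's hook dict
```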
mindspore/common/initializer.py CHANGED
@@ -103,6 +103,12 @@ def _numpy_seed():
     return np.random.randint(low=1, high=(1 << 63), dtype=np.int64)
 
 
+def _init_random_normal_inplace(mean, sigma, arr):
+    if sigma < 0:
+        raise ValueError("sigma < 0")
+    _random_normal(_numpy_seed(), arr, mean, sigma)
+
+
 def _init_random_normal(mean, sigma, shape):
     if sigma < 0:
         raise ValueError("sigma < 0")
@@ -111,12 +117,22 @@ def _init_random_normal(mean, sigma, shape):
     return data
 
 
+def _init_random_uniform_inplace(a, b, arr):
+    _random_uniform(_numpy_seed(), arr, a, b)
+
+
 def _init_random_uniform(a, b, shape):
     data = np.ndarray(shape=shape, dtype=np.float32)
     _random_uniform(_numpy_seed(), data, a, b)
     return data
 
 
+def _init_truncated_normal_inplace(a, b, mean, sigma, arr):
+    if sigma < 0:
+        raise ValueError("sigma < 0")
+    _truncated_normal(_numpy_seed(), arr, a, b, mean, sigma)
+
+
 def _init_truncated_normal(a, b, mean, sigma, shape):
     if sigma < 0:
         raise ValueError("sigma < 0")
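The `_inplace` variants exist so the initializers below can fill a float32 buffer directly instead of allocating a temporary array and copying it over via `_assignment`. A self-contained sketch of that dispatch, with NumPy standing in for the C-backed `_random_normal` helper:

```python
import numpy as np

def fill_normal_inplace(mean, sigma, arr):
    # Stand-in for _init_random_normal_inplace: write into arr directly.
    arr[...] = np.random.normal(mean, sigma, arr.shape).astype(np.float32)

def initialize_normal(mean, sigma, arr):
    if isinstance(arr, np.ndarray) and arr.dtype == np.float32:
        fill_normal_inplace(mean, sigma, arr)     # fast path: no temporary
    else:
        # Fallback path: build a temporary and assign it, as _assignment does.
        data = np.random.normal(mean, sigma, np.shape(arr)).astype(np.float32)
        arr[...] = data

buf = np.empty((4, 4), dtype=np.float32)
initialize_normal(0.0, 0.02, buf)
```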
@@ -298,9 +314,11 @@ class XavierNormal(Initializer):
         fan_in, fan_out = _calculate_fan_in_and_fan_out(arr.shape)
 
         std = self.gain * math.sqrt(2.0 / float(fan_in + fan_out))
-        data = _init_random_normal(0, std, arr.shape)
-
-        _assignment(arr, data)
+        if isinstance(arr, np.ndarray) and arr.dtype == np.float32:
+            _init_random_normal_inplace(0, std, arr)
+        else:
+            data = _init_random_normal(0, std, arr.shape)
+            _assignment(arr, data)
 
 
 @_register('xavier_uniform')
@@ -337,8 +355,11 @@ class XavierUniform(Initializer):
     def _initialize(self, arr):
         n_in, n_out = _calculate_fan_in_and_fan_out(arr.shape)
         boundary = self.gain * math.sqrt(6.0 / (n_in + n_out))
-        data = _init_random_uniform(-boundary, boundary, arr.shape)
-        _assignment(arr, data)
+        if isinstance(arr, np.ndarray) and arr.dtype == np.float32:
+            _init_random_uniform_inplace(-boundary, boundary, arr)
+        else:
+            data = _init_random_uniform(-boundary, boundary, arr.shape)
+            _assignment(arr, data)
 
 
 @_register('he_uniform')
@@ -386,8 +407,11 @@ class HeUniform(Initializer):
         gain = _calculate_gain(self.nonlinearity, self.negative_slope)
         std = gain / math.sqrt(fan)
         boundary = math.sqrt(3.0) * std
-        data = _init_random_uniform(-boundary, boundary, arr.shape)
-        _assignment(arr, data)
+        if isinstance(arr, np.ndarray) and arr.dtype == np.float32:
+            _init_random_uniform_inplace(-boundary, boundary, arr)
+        else:
+            data = _init_random_uniform(-boundary, boundary, arr.shape)
+            _assignment(arr, data)
 
 
 @_register('he_normal')
@@ -432,8 +456,11 @@ class HeNormal(Initializer):
         fan = _calculate_correct_fan(arr.shape, self.mode)
         gain = _calculate_gain(self.nonlinearity, self.negative_slope)
         std = gain / math.sqrt(fan)
-        data = _init_random_normal(0, std, arr.shape)
-        _assignment(arr, data)
+        if isinstance(arr, np.ndarray) and arr.dtype == np.float32:
+            _init_random_normal_inplace(0, std, arr)
+        else:
+            data = _init_random_normal(0, std, arr.shape)
+            _assignment(arr, data)
 
 
 class Constant(Initializer):
@@ -718,8 +745,11 @@ class Uniform(Initializer):
         self.scale = scale
 
     def _initialize(self, arr):
-        tmp = _init_random_uniform(-self.scale, self.scale, arr.shape)
-        _assignment(arr, tmp)
+        if isinstance(arr, np.ndarray) and arr.dtype == np.float32:
+            _init_random_uniform_inplace(-self.scale, self.scale, arr)
+        else:
+            tmp = _init_random_uniform(-self.scale, self.scale, arr.shape)
+            _assignment(arr, tmp)
 
 
 @_register()
@@ -749,8 +779,11 @@ class Normal(Initializer):
         self.mean = mean
 
     def _initialize(self, arr):
-        data = _init_random_normal(self.mean, self.sigma, arr.shape)
-        _assignment(arr, data)
+        if isinstance(arr, np.ndarray) and arr.dtype == np.float32:
+            _init_random_normal_inplace(self.mean, self.sigma, arr)
+        else:
+            data = _init_random_normal(self.mean, self.sigma, arr.shape)
+            _assignment(arr, data)
 
 
 @_register()
@@ -780,8 +813,11 @@ class TruncatedNormal(Initializer):
         self.b = b
 
     def _initialize(self, arr):
-        tmp = _init_truncated_normal(self.a, self.b, self.mean, self.sigma, arr.shape)
-        _assignment(arr, tmp)
+        if isinstance(arr, np.ndarray) and arr.dtype == np.float32:
+            _init_truncated_normal_inplace(self.a, self.b, self.mean, self.sigma, arr)
+        else:
+            tmp = _init_truncated_normal(self.a, self.b, self.mean, self.sigma, arr.shape)
+            _assignment(arr, tmp)
 
 
 def initializer(init, shape=None, dtype=mstype.float32):
mindspore/common/mindir_util.py CHANGED
@@ -90,9 +90,9 @@ def save_mindir(model, file_name):
     if not file_name.endswith('.mindir'):
         file_name += ".mindir"
 
-    current_path = os.path.abspath(file_name)
+    current_path = os.path.realpath(file_name)
     dirname = os.path.dirname(current_path)
-    os.makedirs(dirname, exist_ok=True)
+    os.makedirs(dirname, mode=0o700, exist_ok=True)
     if os.path.exists(file_name):
         os.chmod(file_name, stat.S_IWUSR)
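The two changes above harden `save_mindir`: `realpath` resolves symlinks before the target directory is derived, and the directory is created with owner-only permissions. A sketch of the resulting pattern in isolation (the file name is illustrative):

```python
import os
import stat

file_name = "export/model.mindir"                 # illustrative path
current_path = os.path.realpath(file_name)        # resolve symlinks first
dirname = os.path.dirname(current_path)
os.makedirs(dirname, mode=0o700, exist_ok=True)   # owner-only directory
if os.path.exists(file_name):
    os.chmod(file_name, stat.S_IWUSR)             # make writable before overwriting
```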