mindspore-2.3.0rc1-cp37-cp37m-manylinux1_x86_64.whl → mindspore-2.3.0rc2-cp37-cp37m-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +1 -1
  3. mindspore/_akg/akg/utils/tbe_codegen_utils.py +13 -3
  4. mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
  5. mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
  6. mindspore/_checkparam.py +20 -0
  7. mindspore/_extends/parse/parser.py +1 -1
  8. mindspore/_extends/parse/standard_method.py +6 -5
  9. mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
  10. mindspore/amp.py +5 -5
  11. mindspore/bin/cache_admin +0 -0
  12. mindspore/bin/cache_server +0 -0
  13. mindspore/boost/boost_cell_wrapper.py +1 -1
  14. mindspore/boost/group_loss_scale_manager.py +1 -1
  15. mindspore/common/__init__.py +4 -2
  16. mindspore/common/_register_for_recompute.py +48 -0
  17. mindspore/common/_stub_tensor.py +1 -0
  18. mindspore/common/api.py +56 -4
  19. mindspore/common/dtype.py +5 -3
  20. mindspore/common/dump.py +2 -2
  21. mindspore/common/hook_handle.py +51 -4
  22. mindspore/common/initializer.py +1 -1
  23. mindspore/common/jit_config.py +17 -6
  24. mindspore/common/parameter.py +7 -2
  25. mindspore/common/recompute.py +247 -0
  26. mindspore/common/sparse_tensor.py +2 -2
  27. mindspore/common/symbol.py +1 -1
  28. mindspore/common/tensor.py +74 -36
  29. mindspore/communication/__init__.py +3 -3
  30. mindspore/communication/management.py +30 -30
  31. mindspore/context.py +28 -15
  32. mindspore/dataset/__init__.py +5 -5
  33. mindspore/dataset/audio/__init__.py +2 -2
  34. mindspore/dataset/audio/transforms.py +51 -51
  35. mindspore/dataset/callback/ds_callback.py +2 -2
  36. mindspore/dataset/engine/cache_client.py +1 -1
  37. mindspore/dataset/engine/datasets.py +3 -3
  38. mindspore/dataset/engine/datasets_audio.py +14 -14
  39. mindspore/dataset/engine/datasets_standard_format.py +3 -3
  40. mindspore/dataset/engine/datasets_text.py +38 -38
  41. mindspore/dataset/engine/datasets_user_defined.py +3 -3
  42. mindspore/dataset/engine/datasets_vision.py +68 -68
  43. mindspore/dataset/text/__init__.py +3 -3
  44. mindspore/dataset/text/transforms.py +26 -26
  45. mindspore/dataset/transforms/__init__.py +1 -1
  46. mindspore/dataset/vision/__init__.py +3 -3
  47. mindspore/dataset/vision/transforms.py +92 -92
  48. mindspore/dataset/vision/utils.py +1 -1
  49. mindspore/experimental/optim/adadelta.py +2 -2
  50. mindspore/experimental/optim/adagrad.py +2 -2
  51. mindspore/experimental/optim/adam.py +2 -2
  52. mindspore/experimental/optim/adamax.py +2 -2
  53. mindspore/experimental/optim/adamw.py +2 -2
  54. mindspore/experimental/optim/asgd.py +2 -2
  55. mindspore/experimental/optim/lr_scheduler.py +24 -20
  56. mindspore/experimental/optim/nadam.py +2 -2
  57. mindspore/experimental/optim/optimizer.py +1 -1
  58. mindspore/experimental/optim/radam.py +2 -2
  59. mindspore/experimental/optim/rmsprop.py +2 -2
  60. mindspore/experimental/optim/rprop.py +2 -2
  61. mindspore/experimental/optim/sgd.py +2 -2
  62. mindspore/hal/stream.py +2 -0
  63. mindspore/include/mindapi/base/types.h +5 -0
  64. mindspore/lib/libdnnl.so.2 +0 -0
  65. mindspore/lib/libmindspore.so +0 -0
  66. mindspore/lib/libmindspore_backend.so +0 -0
  67. mindspore/lib/libmindspore_common.so +0 -0
  68. mindspore/lib/libmindspore_core.so +0 -0
  69. mindspore/lib/libmindspore_glog.so.0 +0 -0
  70. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  71. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  72. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  73. mindspore/lib/libmindspore_shared_lib.so +0 -0
  74. mindspore/lib/libopencv_core.so.4.5 +0 -0
  75. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  76. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  77. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  78. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6 -6
  79. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  80. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  81. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  82. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  83. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  84. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  85. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  86. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  87. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  88. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  89. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  90. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  91. mindspore/log.py +2 -2
  92. mindspore/mint/__init__.py +457 -0
  93. mindspore/mint/nn/__init__.py +430 -0
  94. mindspore/mint/nn/functional.py +424 -0
  95. mindspore/mint/optim/__init__.py +24 -0
  96. mindspore/mint/optim/adamw.py +186 -0
  97. mindspore/multiprocessing/__init__.py +4 -0
  98. mindspore/nn/__init__.py +3 -0
  99. mindspore/nn/cell.py +51 -47
  100. mindspore/nn/extend/__init__.py +29 -0
  101. mindspore/nn/extend/basic.py +140 -0
  102. mindspore/nn/extend/embedding.py +143 -0
  103. mindspore/nn/extend/layer/__init__.py +27 -0
  104. mindspore/nn/extend/layer/normalization.py +107 -0
  105. mindspore/nn/extend/pooling.py +117 -0
  106. mindspore/nn/generator.py +297 -0
  107. mindspore/nn/layer/basic.py +109 -1
  108. mindspore/nn/layer/container.py +2 -2
  109. mindspore/nn/layer/conv.py +6 -6
  110. mindspore/nn/layer/embedding.py +1 -1
  111. mindspore/nn/layer/normalization.py +21 -43
  112. mindspore/nn/layer/padding.py +4 -0
  113. mindspore/nn/optim/ada_grad.py +2 -2
  114. mindspore/nn/optim/adadelta.py +1 -1
  115. mindspore/nn/optim/adafactor.py +1 -1
  116. mindspore/nn/optim/adam.py +7 -7
  117. mindspore/nn/optim/adamax.py +2 -2
  118. mindspore/nn/optim/adasum.py +2 -2
  119. mindspore/nn/optim/asgd.py +2 -2
  120. mindspore/nn/optim/ftrl.py +1 -1
  121. mindspore/nn/optim/lamb.py +3 -3
  122. mindspore/nn/optim/lars.py +1 -1
  123. mindspore/nn/optim/lazyadam.py +2 -2
  124. mindspore/nn/optim/momentum.py +2 -2
  125. mindspore/nn/optim/optimizer.py +2 -2
  126. mindspore/nn/optim/proximal_ada_grad.py +2 -2
  127. mindspore/nn/optim/rmsprop.py +2 -2
  128. mindspore/nn/optim/rprop.py +2 -2
  129. mindspore/nn/optim/sgd.py +2 -2
  130. mindspore/nn/optim/thor.py +2 -2
  131. mindspore/nn/wrap/cell_wrapper.py +9 -9
  132. mindspore/nn/wrap/grad_reducer.py +5 -5
  133. mindspore/ops/_grad_experimental/grad_comm_ops.py +4 -2
  134. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -2
  135. mindspore/ops/_vmap/vmap_math_ops.py +27 -8
  136. mindspore/ops/_vmap/vmap_nn_ops.py +66 -8
  137. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +73 -1
  138. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +12 -3
  139. mindspore/ops/auto_generate/gen_arg_handler.py +24 -0
  140. mindspore/ops/auto_generate/gen_extend_func.py +274 -0
  141. mindspore/ops/auto_generate/gen_ops_def.py +889 -22
  142. mindspore/ops/auto_generate/gen_ops_prim.py +3541 -253
  143. mindspore/ops/auto_generate/pyboost_inner_prim.py +282 -0
  144. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  145. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +9 -0
  146. mindspore/ops/extend/__init__.py +9 -1
  147. mindspore/ops/extend/array_func.py +134 -27
  148. mindspore/ops/extend/math_func.py +3 -3
  149. mindspore/ops/extend/nn_func.py +363 -2
  150. mindspore/ops/function/__init__.py +19 -2
  151. mindspore/ops/function/array_func.py +463 -439
  152. mindspore/ops/function/clip_func.py +7 -18
  153. mindspore/ops/function/grad/grad_func.py +5 -5
  154. mindspore/ops/function/linalg_func.py +4 -4
  155. mindspore/ops/function/math_func.py +260 -243
  156. mindspore/ops/function/nn_func.py +825 -62
  157. mindspore/ops/function/random_func.py +73 -4
  158. mindspore/ops/function/sparse_unary_func.py +1 -1
  159. mindspore/ops/function/vmap_func.py +1 -1
  160. mindspore/ops/functional.py +2 -2
  161. mindspore/ops/op_info_register.py +1 -31
  162. mindspore/ops/operations/__init__.py +2 -3
  163. mindspore/ops/operations/_grad_ops.py +2 -107
  164. mindspore/ops/operations/_inner_ops.py +5 -5
  165. mindspore/ops/operations/_sequence_ops.py +2 -2
  166. mindspore/ops/operations/array_ops.py +11 -233
  167. mindspore/ops/operations/comm_ops.py +32 -32
  168. mindspore/ops/operations/custom_ops.py +7 -89
  169. mindspore/ops/operations/manually_defined/ops_def.py +329 -4
  170. mindspore/ops/operations/math_ops.py +13 -163
  171. mindspore/ops/operations/nn_ops.py +9 -316
  172. mindspore/ops/operations/random_ops.py +1 -1
  173. mindspore/ops/operations/sparse_ops.py +3 -3
  174. mindspore/ops/primitive.py +2 -2
  175. mindspore/ops_generate/arg_dtype_cast.py +12 -3
  176. mindspore/ops_generate/arg_handler.py +24 -0
  177. mindspore/ops_generate/gen_ops_inner_prim.py +2 -0
  178. mindspore/ops_generate/gen_pyboost_func.py +13 -6
  179. mindspore/ops_generate/pyboost_utils.py +2 -17
  180. mindspore/parallel/__init__.py +3 -2
  181. mindspore/parallel/_auto_parallel_context.py +106 -1
  182. mindspore/parallel/_parallel_serialization.py +34 -2
  183. mindspore/parallel/_utils.py +16 -0
  184. mindspore/parallel/algo_parameter_config.py +4 -4
  185. mindspore/parallel/checkpoint_transform.py +249 -77
  186. mindspore/parallel/cluster/process_entity/_api.py +1 -1
  187. mindspore/parallel/parameter_broadcast.py +1 -1
  188. mindspore/parallel/shard.py +1 -1
  189. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +1 -0
  190. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +17 -5
  191. mindspore/profiler/parser/ascend_msprof_exporter.py +3 -3
  192. mindspore/profiler/parser/ascend_msprof_generator.py +10 -3
  193. mindspore/profiler/parser/ascend_op_generator.py +26 -9
  194. mindspore/profiler/parser/ascend_timeline_generator.py +7 -4
  195. mindspore/profiler/parser/profiler_info.py +11 -1
  196. mindspore/profiler/profiling.py +13 -5
  197. mindspore/rewrite/api/node.py +12 -12
  198. mindspore/rewrite/api/symbol_tree.py +11 -11
  199. mindspore/run_check/_check_version.py +1 -1
  200. mindspore/safeguard/rewrite_obfuscation.py +2 -2
  201. mindspore/train/amp.py +4 -4
  202. mindspore/train/anf_ir_pb2.py +8 -2
  203. mindspore/train/callback/_backup_and_restore.py +2 -2
  204. mindspore/train/callback/_callback.py +4 -4
  205. mindspore/train/callback/_checkpoint.py +2 -2
  206. mindspore/train/callback/_early_stop.py +2 -2
  207. mindspore/train/callback/_landscape.py +4 -4
  208. mindspore/train/callback/_loss_monitor.py +2 -2
  209. mindspore/train/callback/_on_request_exit.py +2 -2
  210. mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
  211. mindspore/train/callback/_summary_collector.py +2 -2
  212. mindspore/train/callback/_time_monitor.py +2 -2
  213. mindspore/train/dataset_helper.py +8 -3
  214. mindspore/train/loss_scale_manager.py +2 -2
  215. mindspore/train/metrics/metric.py +3 -3
  216. mindspore/train/mind_ir_pb2.py +22 -17
  217. mindspore/train/model.py +15 -15
  218. mindspore/train/serialization.py +18 -18
  219. mindspore/train/summary/summary_record.py +7 -7
  220. mindspore/train/train_thor/convert_utils.py +3 -3
  221. mindspore/version.py +1 -1
  222. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +1 -1
  223. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +226 -212
  224. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
  225. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +0 -0
  226. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py CHANGED
@@ -20,10 +20,9 @@ import inspect
  import os
  import time
  from collections import OrderedDict
- from types import FunctionType, MethodType
  import numpy

- from mindspore._checkparam import args_type_check
+ from mindspore._checkparam import args_type_check, check_hook_fn
  from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
  from mindspore import log as logger
  from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
@@ -34,7 +33,7 @@ from mindspore._c_expression import init_pipeline, update_func_graph_hyper_param
  from mindspore import _checkparam as Validator
  from mindspore.common import dtype as mstype
  from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
- from mindspore.common.api import _generate_branch_control_input
+ from mindspore.common.api import _generate_branch_control_input, _convert_python_data, _get_args_for_run_predict
  from mindspore.common.parameter import Parameter, ParameterTuple
  from mindspore.common.tensor import Tensor
  from mindspore.ops.operations import Cast
@@ -43,6 +42,7 @@ from mindspore.ops.operations import _inner_ops as inner
  from mindspore.parallel.shard import Shard
  from mindspore._check_jit_forbidden_api import jit_forbidden_register
  from mindspore.common._decorator import deprecated
+ from mindspore.common._register_for_recompute import recompute_registry


  class Cell(Cell_):
@@ -125,11 +125,13 @@ class Cell(Cell_):
  self._create_time = int(time.time() * 1e9)
  self.arguments_key = ""
  self.compile_cache = set()
+ self.phase_cache = dict()
  cells_compile_cache[id(self)] = self.compile_cache
  self.parameter_broadcast_done = False
  self._id = 1
  self.exist_names = set("")
  self.exist_objs = set()
+ self.recompute_cell = None
  init_pipeline()

  # call gc to release GE session resources used by non-used cell objects
@@ -217,7 +219,7 @@ class Cell(Cell_):

  Tutorial Examples:
  - `Cell and Parameter - Custom Cell Reverse
- <https://mindspore.cn/tutorials/en/r2.3.q1/advanced/modules/layer.html#custom-cell-reverse>`_
+ <https://mindspore.cn/tutorials/en/master/advanced/modules/layer.html#custom-cell-reverse>`_
  """
  return self._bprop_debug

@@ -415,7 +417,7 @@ class Cell(Cell_):
  elif isinstance(item, float):
  res.append(self.cast(item, dst_type))
  elif hasattr(item, "dtype") and item.dtype in \
- {mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
+ {mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
  res.append(self.cast(item, dst_type))
  else:
  res.append(item)
@@ -474,7 +476,10 @@ class Cell(Cell_):
  elif hasattr(self, "_shard_fn"):
  output = self._shard_fn(*cast_inputs, **kwargs)
  else:
- output = self.construct(*cast_inputs, **kwargs)
+ if self.recompute_cell is not None:
+ output = self.recompute_cell(*cast_inputs, **kwargs)
+ else:
+ output = self.construct(*cast_inputs, **kwargs)
  if self._enable_forward_hook:
  output = self._run_forward_hook(cast_inputs, output)
  return output
@@ -659,6 +664,16 @@ class Cell(Cell_):
  self.check_names_and_refresh_name()
  self._is_check_and_refresh = True

+ def _predict(self, *args, **kwargs):
+ if not hasattr(self, "phase"):
+ return False, None
+ if (self.phase == "prefill" or self.phase == 'increment') and self.phase in self.phase_cache:
+ new_args = _get_args_for_run_predict(self, args, kwargs, self._compile_args)
+ res = _cell_graph_executor._graph_executor(tuple(new_args), self.phase_cache[self.phase])
+ res = _convert_python_data(res)
+ return True, res
+ return False, None
+
  def __call__(self, *args, **kwargs):
  # Run in Graph mode.
  if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
@@ -667,7 +682,12 @@ class Cell(Cell_):
  bound_arguments.apply_defaults()
  args = bound_arguments.args
  kwargs = bound_arguments.kwargs
+
+ predict_compiled, res = self._predict(*args, **kwargs)
+ if predict_compiled:
+ return res
  self._check_construct_args(*args)
+
  if self._hook_fn_registered():
  logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook "
  f"function, please use context.set_context to set pynative mode.")
@@ -964,7 +984,6 @@ class Cell(Cell_):
  return self._dynamic_shape_inputs
  return args

-
  def compile(self, *args, **kwargs):
  """
  Compile Cell as a computation graph, the input must be consistent with the input defined in construct.
@@ -1335,7 +1354,7 @@ class Cell(Cell_):

  Tutorial Examples:
  - `Model Training - Optimizer
- <https://mindspore.cn/tutorials/en/r2.3.q1/beginner/train.html#optimizer>`_
+ <https://mindspore.cn/tutorials/en/master/beginner/train.html#optimizer>`_
  """
  return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))

@@ -1446,7 +1465,7 @@ class Cell(Cell_):

  Tutorial Examples:
  - `Building a Network - Model Parameters
- <https://mindspore.cn/tutorials/en/r2.3.q1/beginner/model.html#model-parameters>`_
+ <https://mindspore.cn/tutorials/en/master/beginner/model.html#model-parameters>`_
  """
  cells = []
  if expand:
@@ -1785,7 +1804,7 @@ class Cell(Cell_):
  accelerate the algorithm in the algorithm library.

  If `boost_type` is not in the algorithm library, please view the algorithm in the algorithm library through
- `algorithm library <https://gitee.com/mindspore/mindspore/tree/r2.3.q1/mindspore/python/mindspore/boost>`_.
+ `algorithm library <https://gitee.com/mindspore/mindspore/tree/master/mindspore/python/mindspore/boost>`_.

  Note:
  Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
@@ -1842,7 +1861,7 @@ class Cell(Cell_):

  Tutorial Examples:
  - `Model Training - Implementing Training and Evaluation
- <https://mindspore.cn/tutorials/en/r2.3.q1/beginner/train.html#training-and-evaluation>`_
+ <https://mindspore.cn/tutorials/en/master/beginner/train.html#training-and-evaluation>`_
  """
  if mode:
  self._phase = 'train'
@@ -1936,8 +1955,8 @@ class Cell(Cell_):
  hook_fn (function): Python function. Forward pre hook function.

  Returns:
- Handle, it is an instance of `mindspore.common.hook_handle.HookHandle` and corresponding to the `hook_fn` .
- The handle can be used to remove the added `hook_fn` by calling `handle.remove()` .
+ A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
+ `handle.remove()` .

  Raises:
  TypeError: If the `hook_fn` is not a function of python.
@@ -1972,17 +1991,8 @@ class Cell(Cell_):
  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
  value= [ 2.00000000e+00]))
  """
- if context.get_context("mode") != context.PYNATIVE_MODE:
- logger.warning(f"'register_forward_pre_hook' function is only supported in pynative mode, you can use "
- f"context.set_context to set pynative mode.")
+ if not check_hook_fn("register_forward_pre_hook", hook_fn):
  return HookHandle()
-
- if not isinstance(hook_fn, (FunctionType, MethodType)):
- raise TypeError(f"When using 'register_forward_pre_hook(hook_fn)', the type of 'hook_fn' must be python "
- f"function, but got {type(hook_fn)}.")
- if hook_fn.__code__.co_name == "staging_specialize":
- raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
-
  self._enable_forward_pre_hook = True
  _pynative_executor.set_hook_changed(self)
  if not hasattr(self, '_forward_pre_hook_key'):
@@ -2036,8 +2046,8 @@ class Cell(Cell_):
  hook_fn (function): Python function. Forward hook function.

  Returns:
- Handle, it is an instance of `mindspore.common.hook_handle.HookHandle` and corresponding to the `hook_fn` .
- The handle can be used to remove the added `hook_fn` by calling `handle.remove()` .
+ A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
+ `handle.remove()` .

  Raises:
  TypeError: If the `hook_fn` is not a function of python.
@@ -2074,17 +2084,8 @@ class Cell(Cell_):
  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
  value= [ 2.00000000e+00]))
  """
- if context.get_context("mode") != context.PYNATIVE_MODE:
- logger.warning(f"'register_forward_hook' function is only supported in pynative mode, you can use "
- f"context.set_context to set pynative mode.")
+ if not check_hook_fn("register_forward_hook", hook_fn):
  return HookHandle()
-
- if not isinstance(hook_fn, (FunctionType, MethodType)):
- raise TypeError(f"When using 'register_forward_hook(hook_fn)', the type of 'hook_fn' must be python "
- f"function, but got {type(hook_fn)}.")
- if hook_fn.__code__.co_name == "staging_specialize":
- raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
-
  self._enable_forward_hook = True
  _pynative_executor.set_hook_changed(self)
  if not hasattr(self, '_forward_hook_key'):
@@ -2136,8 +2137,8 @@ class Cell(Cell_):
  hook_fn (function): Python function. Backward hook function.

  Returns:
- Handle, it is an instance of `mindspore.common.hook_handle.HookHandle` and corresponding to the `hook_fn` .
- The handle can be used to remove the added `hook_fn` by calling `handle.remove()` .
+ A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
+ `handle.remove()` .

  Raises:
  TypeError: If the `hook_fn` is not a function of python.
@@ -2172,14 +2173,8 @@ class Cell(Cell_):
  >>> print(output)
  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
  """
- if context.get_context("mode") != context.PYNATIVE_MODE:
- logger.warning(f"'register_backward_hook' function is only supported in pynative mode, you can use "
- f"context.set_context to set pynative mode.")
+ if not check_hook_fn("register_backward_hook", hook_fn):
  return HookHandle()
-
- if not isinstance(hook_fn, (FunctionType, MethodType)):
- raise TypeError(f"When using 'register_backward_hook(hook_fn)', the type of 'hook_fn' must be python "
- f"function, but got {type(hook_fn)}.")
  if self._cell_backward_hook is None:
  self._enable_backward_hook = True
  self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")")
@@ -2209,10 +2204,16 @@
  else:
  inputs = self._cell_backward_hook(*inputs)
  inputs = (inputs,)
- if isinstance(inputs, tuple):
- outputs = self.construct(*inputs, **kwargs)
+ if self.recompute_cell is not None:
+ if isinstance(inputs, tuple):
+ outputs = self.recompute_cell(*inputs, **kwargs)
+ else:
+ outputs = self.recompute_cell(inputs, **kwargs)
  else:
- outputs = self.construct(inputs, **kwargs)
+ if isinstance(inputs, tuple):
+ outputs = self.construct(*inputs, **kwargs)
+ else:
+ outputs = self.construct(inputs, **kwargs)
  outputs = self._cell_backward_hook(outputs)
  return outputs

@@ -2342,6 +2343,9 @@
  introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
  Default: ``False`` .
  """
+ if context.get_context("mode") == context.PYNATIVE_MODE:
+ self.recompute_cell = recompute_registry.get()(self.construct)
+ return
  self._recompute()
  if 'mp_comm_recompute' in kwargs.keys():
  self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
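
The cell.py hunks above add a `recompute_cell` slot and make `Cell.recompute()` register the wrapped `construct` through `recompute_registry` when the context is PYNATIVE_MODE, instead of running the graph-mode `_recompute()` pass. Below is a minimal sketch of how this surfaces to user code, assuming a toy cell written only for illustration (the `Net` class, shapes, and values are not part of the diff):

import numpy as np
import mindspore as ms
from mindspore import nn, ops, Tensor

ms.set_context(mode=ms.PYNATIVE_MODE)

class Net(nn.Cell):
    """Toy cell used only to illustrate the PyNative recompute path."""
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(4, 4)

    def construct(self, x):
        return ops.relu(self.dense(x)).sum()

net = Net()
# In 2.3.0rc2, under PYNATIVE_MODE this assigns net.recompute_cell
# (see the hunk at @@ -2342,6 +2343,9 @@) rather than calling self._recompute().
net.recompute()

x = Tensor(np.ones((2, 4)), ms.float32)
out = net(x)             # forward now runs through net.recompute_cell
grads = ms.grad(net)(x)  # activations are recomputed during the backward pass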
mindspore/nn/extend/__init__.py ADDED
@@ -0,0 +1,29 @@
+ # Copyright 2024 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+ """
+ nn Extend.
+ """
+ from __future__ import absolute_import
+
+ from mindspore.nn.extend.embedding import Embedding
+ from mindspore.nn.extend.basic import Linear
+ from mindspore.nn.extend.pooling import MaxPool2d
+ from mindspore.nn.extend import layer
+ from mindspore.nn.extend.layer import *
+
+ __all__ = ['Embedding', 'Linear', 'MaxPool2d']
+ __all__.extend(layer.__all__)
+
+ __all__.sort()
mindspore/nn/extend/basic.py ADDED
@@ -0,0 +1,140 @@
+ # Copyright 2024 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+
+ """basic"""
+ from __future__ import absolute_import
+
+ import math
+
+ import mindspore.common.dtype as mstype
+ from mindspore import _checkparam as Validator
+ from mindspore._extends import cell_attr_register
+ from mindspore.common.initializer import initializer, HeUniform, Uniform
+ from mindspore.common.parameter import Parameter
+ from mindspore.common.tensor import Tensor
+ from mindspore.nn.cell import Cell
+ from mindspore.ops import operations as P
+
+ __all__ = ['Linear']
+
+
+ class Linear(Cell):
+ r"""
+ The linear connected layer.
+
+ Applies linear connected layer for the input. This layer implements the operation as:
+
+ .. math::
+ \text{outputs} = X * kernel + bias
+
+ where :math:`X` is the input tensors, :math:`\text{kernel}` is a weight matrix with the same
+ data type as the :math:`X` created by the layer, and :math:`\text{bias}` is a bias vector
+ with the same data type as the :math:`X` created by the layer (only if has_bias is True).
+
+ Args:
+ in_features (int): The number of features in the input space.
+ out_features (int): The number of features in the output space.
+ bias (bool): Specifies whether the layer uses a bias vector :math:`\text{bias}`. Default: ``True``.
+ weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
+ is same as `x`. The values of str refer to the function `initializer`. Default: ``None`` ,
+ weight will be initialized using HeUniform.
+ bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
+ same as `x`. The values of str refer to the function `initializer`. Default: ``None`` ,
+ bias will be initialized using Uniform.
+ dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``None`` .
+
+ Inputs:
+ - **x** (Tensor) - Tensor of shape :math:`(*, in\_features)`. The `in_features` in `Args` should be equal
+ to :math:`in\_features` in `Inputs`.
+
+ Outputs:
+ Tensor of shape :math:`(*, out\_features)`.
+
+ Raises:
+ TypeError: If `in_features` or `out_features` is not an int.
+ TypeError: If `bias` is not a bool.
+ ValueError: If length of shape of `weight_init` is not equal to 2 or shape[0] of `weight_init`
+ is not equal to `out_features` or shape[1] of `weight_init` is not equal to `in_features`.
+ ValueError: If length of shape of `bias_init` is not equal to 1
+ or shape[0] of `bias_init` is not equal to `out_features`.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU`` ``CPU``
+
+ Examples:
+ >>> import mindspore
+ >>> from mindspore import Tensor
+ >>> from mindspore.nn.extend import Linear
+ >>> import numpy as np
+ >>> x = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
+ >>> net = Linear(3, 4)
+ >>> output = net(x)
+ >>> print(output.shape)
+ (2, 4)
+ """
+
+ @cell_attr_register(attrs=['has_bias'])
+ def __init__(self,
+ in_features,
+ out_features,
+ bias=True,
+ weight_init=None,
+ bias_init=None,
+ dtype=None):
+ """Initialize Linear."""
+ super(Linear, self).__init__()
+ self.in_features = Validator.check_positive_int(
+ in_features, "in_features", self.cls_name)
+ self.out_features = Validator.check_positive_int(
+ out_features, "out_features", self.cls_name)
+ self.has_bias = Validator.check_bool(
+ bias, "has_bias", self.cls_name)
+ self.dense = P.Dense()
+ if dtype is None:
+ dtype = mstype.float32
+ if isinstance(weight_init, Tensor):
+ if weight_init.ndim != 2 or weight_init.shape[0] != out_features or \
+ weight_init.shape[1] != in_features:
+ raise ValueError(f"For '{self.cls_name}', weight init shape error. The ndim of 'weight_init' must "
+ f"be equal to 2, and the first dim must be equal to 'out_features', and the "
+ f"second dim must be equal to 'in_features'. But got 'weight_init': {weight_init}, "
+ f"'out_features': {out_features}, 'in_features': {in_features}.")
+ if weight_init is None:
+ weight_init = HeUniform(math.sqrt(5))
+ self.weight = Parameter(initializer(
+ weight_init, [out_features, in_features], dtype=dtype), name="weight")
+
+ self.bias = None
+ if self.has_bias:
+ if isinstance(bias_init, Tensor):
+ if bias_init.ndim != 1 or bias_init.shape[0] != out_features:
+ raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' must "
+ f"be equal to 1, and the first dim must be equal to 'out_features'. But got "
+ f"'bias_init': {bias_init}, 'out_features': {out_features}.")
+ if bias_init is None:
+ bound = 1 / math.sqrt(in_features)
+ bias_init = Uniform(scale=bound)
+ self.bias = Parameter(initializer(
+ bias_init, [out_features], dtype=dtype), name="bias")
+
+ def construct(self, x):
+ x = self.dense(x, self.weight, self.bias)
+ return x
+
+ def extend_repr(self):
+ s = f'input_features={self.in_features}, output_features={self.out_features}'
+ if self.has_bias:
+ s += f', has_bias={self.has_bias}'
+ return s
mindspore/nn/extend/embedding.py ADDED
@@ -0,0 +1,143 @@
+ # Copyright 2024 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+ """embedding"""
+ from __future__ import absolute_import
+
+ import mindspore.common.dtype as mstype
+ from mindspore.common.initializer import Normal
+ from mindspore import _checkparam as Validator
+ from mindspore.nn.cell import Cell
+ from mindspore import ops
+ from mindspore.common.parameter import Parameter
+ from mindspore.common.tensor import Tensor
+
+ __all__ = ['Embedding']
+
+
+ class Embedding(Cell):
+ r"""
+ Embedding layer.
+ Retrieve the word embeddings in weight stored in the layer using indices specified in `input`.
+
+ .. warning::
+ On Ascend, the behavior is unpredictable when the value of `input` is invalid.
+
+ Args:
+ num_embeddings (int): Size of the dictionary of embeddings.
+ embedding_dim (int): The size of each embedding vector.
+ padding_idx (int, optional): If the value is not None, the corresponding row of embedding vector
+ will not be updated in training. The value of embedding vector at `padding_idx` will default
+ to zeros when the Embedding layer is newly constructed. The value should be in range
+ `[-num_embeddings, num_embeddings)` if it's not ``None``. Default ``None``.
+ max_norm (float, optional): If the value is not None, firstly get the p-norm result of the embedding
+ vector specified by `input` where p is specified by `norm_type`; if the result is larger then `max_norm`,
+ update the embedding vector` with :math:`\frac{max\_norm}{result+1e^{-7}}`. Default ``None``.
+ norm_type (float, optional): Indicated the value of p in p-norm. Default ``2.0``.
+ scale_grad_by_freq (bool, optional): If ``True`` the gradients will be scaled by the inverse of frequency
+ of the index in `input`. Default ``False``.
+ _weight (Tensor, optional): Used to initialize the weight of Embedding. If ``None``, the weight will be
+ initialized from normal distribution :math:`{N}(\text{sigma=1.0}, \text{mean=0.0})`. Default ``None``.
+ dtype (mindspore.dtype, optional) : Dtype of Parameters. It is meaningless when `_weight` is not None.
+ Default: ``mindspore.float32``.
+
+ Inputs:
+ - **input** (Tensor) - The indices used to lookup in the embedding vector. The data type must be
+ mindspore.int32 or mindspore.int64, and the value should be in range `[0, num_embeddings)`.
+
+ Outputs:
+ Tensor, has the same data type as weight, the shape is :math:`(*input.shape, embedding_dim)`.
+
+ Raises:
+ TypeError: If `num_embeddings` is not an int.
+ TypeError: If `embedding_dim` is not an int.
+ ValueError: If `padding_idx` is out of valid range.
+ TypeError: If `max_norm` is not a float.
+ TypeError: If `norm_type` is not a float.
+ TypeError: If `scale_grad_by_freq` is not a bool.
+ TypeError: If `dtype` is not one of mindspore.dtype.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ >>> import mindspore
+ >>> import numpy as np
+ >>> from mindspore import Tensor, nn
+ >>> input = Tensor([[1, 0, 1, 1], [0, 0, 1, 0]])
+ >>> embedding = nn.extend.Embedding(num_embeddings=10, embedding_dim=3)
+ >>> output = embedding(input)
+ >>> print(output)
+ [[[-0.0024154 -0.01203444 0.00811537]
+ [ 0.00233847 -0.00596091 0.00536799]
+ [-0.0024154 -0.01203444 0.00811537]
+ [-0.0024154 -0.01203444 0.00811537]]
+ [[ 0.00233847 -0.00596091 0.00536799]
+ [ 0.00233847 -0.00596091 0.00536799]
+ [-0.0024154 -0.01203444 0.00811537]
+ [ 0.00233847 -0.00596091 0.00536799]]]
+ """
+
+ def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0,
+ scale_grad_by_freq=False, _weight=None, dtype=mstype.float32):
+ """Initialize Embedding."""
+ super().__init__()
+ self.num_embeddings = Validator.check_value_type(
+ 'num_embeddings', num_embeddings, [int], self.cls_name)
+ self.embedding_dim = Validator.check_value_type(
+ 'embedding_dim', embedding_dim, [int], self.cls_name)
+ Validator.check_subclass(
+ "dtype", dtype, mstype.number_type, self.cls_name)
+ self.dtype = dtype
+ self.padding_idx = padding_idx
+ if _weight is None:
+ init_tensor = Tensor(shape=[num_embeddings, embedding_dim], dtype=dtype, init=Normal(1, 0))
+ init_tensor = self._zero_weight_by_index(init_tensor)
+ self.weight = Parameter(init_tensor, name='weight')
+ else:
+ self.weight = Parameter(_weight)
+
+ self.max_norm = max_norm
+ if max_norm is not None:
+ self.max_norm = Validator.check_value_type('max_norm', max_norm, [float], self.cls_name)
+
+ self.norm_type = norm_type
+ if norm_type is not None:
+ self.norm_type = Validator.check_value_type('norm_type', norm_type,
+ [float], self.cls_name)
+
+ self.scale_grad_by_freq = scale_grad_by_freq
+ if scale_grad_by_freq is not None:
+ self.scale_grad_by_freq = Validator.check_value_type('scale_grad_by_freq',
+ scale_grad_by_freq,
+ [bool], self.cls_name)
+
+ def _zero_weight_by_index(self, init_tensor):
+ if self.padding_idx is not None:
+ self.padding_idx = Validator.check_int_range(self.padding_idx, -self.num_embeddings, self.num_embeddings,
+ Validator.INC_LEFT, "padding_idx", self.cls_name)
+ if isinstance(init_tensor, Tensor) and init_tensor.init is not None:
+ init_tensor = init_tensor.init_data()
+ init_tensor[self.padding_idx] = 0
+
+ return init_tensor
+
+ def construct(self, input):
+ return ops.embedding(input, self.weight, self.padding_idx, self.max_norm,
+ self.norm_type, self.scale_grad_by_freq)
+
+ def extend_repr(self):
+ return f'num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}, ' \
+ f'padding_idx={self.padding_idx}, max_norm={self.max_norm}, norm_type={self.norm_type}, ' \
+ f'scale_grad_by_freq={self.scale_grad_by_freq}, dtype={self.dtype}'
mindspore/nn/extend/layer/__init__.py ADDED
@@ -0,0 +1,27 @@
+ # Copyright 2024 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+ """
+ Layer.
+
+ The high-level components(Cells) used to construct the neural network.
+ """
+ from __future__ import absolute_import
+
+ from mindspore.nn.extend.layer import normalization
+ from mindspore.nn.extend.layer.normalization import *
+
+ __all__ = []
+
+ __all__.extend(normalization.__all__)
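
Together, the new files above (plus pooling.py and layer/normalization.py listed in the file table) define a `mindspore.nn.extend` namespace that exports `Linear`, `Embedding`, `MaxPool2d` and the `layer` submodule. A short usage sketch based only on the signatures and docstring examples shown in this diff (input values are illustrative; the Embedding docstring lists only Ascend as a supported platform):

import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.nn.extend import Linear, Embedding

# Linear(in_features, out_features): weight defaults to HeUniform, bias to Uniform.
linear = Linear(3, 4)
x = Tensor(np.ones((2, 3)), ms.float32)
print(linear(x).shape)  # (2, 4)

# Embedding(num_embeddings, embedding_dim): output shape is (*input.shape, embedding_dim).
embedding = Embedding(num_embeddings=10, embedding_dim=3, padding_idx=0)
ids = Tensor(np.array([[1, 0, 2], [3, 0, 4]]), ms.int32)
print(embedding(ids).shape)  # (2, 3, 3)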