mindspore-2.4.0-cp311-cp311-manylinux1_x86_64.whl → mindspore-2.4.1-cp311-cp311-manylinux1_x86_64.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (97)
  1. mindspore/.commit_id +1 -1
  2. mindspore/_c_dataengine.cpython-311-x86_64-linux-gnu.so +0 -0
  3. mindspore/_c_expression.cpython-311-x86_64-linux-gnu.so +0 -0
  4. mindspore/common/initializer.py +51 -15
  5. mindspore/common/parameter.py +18 -4
  6. mindspore/common/tensor.py +15 -49
  7. mindspore/communication/comm_func.py +7 -7
  8. mindspore/context.py +9 -0
  9. mindspore/include/mindapi/base/format.h +13 -0
  10. mindspore/lib/libdnnl.so.2 +0 -0
  11. mindspore/lib/libmindspore_backend.so +0 -0
  12. mindspore/lib/libmindspore_common.so +0 -0
  13. mindspore/lib/libmindspore_core.so +0 -0
  14. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  15. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  16. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  17. mindspore/lib/libmindspore_ops.so +0 -0
  18. mindspore/lib/libopencv_core.so.4.5 +0 -0
  19. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  20. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  21. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/config/ascend910b/all_finite.json +10 -10
  22. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/config/ascend910b/binary_info_config.json +8 -8
  23. mindspore/lib/plugin/ascend/custom_compiler/setup.py +1 -1
  24. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  25. mindspore/lib/plugin/ascend/libmindspore_internal_kernels.so +0 -0
  26. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/base/types.h +5 -5
  27. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops.so +0 -0
  28. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops_static.a +0 -0
  29. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/liblcal.so +0 -0
  30. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/liblcal_static.a +0 -0
  31. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/acme_op.h +1 -0
  32. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/paged_attention_op.h +6 -1
  33. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/rms_norm_op.h +4 -3
  34. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_310p_impl.so +0 -0
  35. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libms_kernels_internal.so +0 -0
  36. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_bf16_bnsd_full_mix.o +0 -0
  37. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_bf16_bnsd_tri_mix.o +0 -0
  38. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_bf16_bsh_full_mix.o +0 -0
  39. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_bf16_bsh_tri_mix.o +0 -0
  40. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_fp16_bnsd_full_mix.o +0 -0
  41. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_fp16_bnsd_tri_mix.o +0 -0
  42. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_fp16_bsh_full_mix.o +0 -0
  43. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_fp16_bsh_tri_mix.o +0 -0
  44. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/paged_attention/paged_attention_bf16_bnsd_mix.o +0 -0
  45. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/paged_attention/paged_attention_bf16_bsh_mix.o +0 -0
  46. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/paged_attention/paged_attention_fp16_bnsd_mix.o +0 -0
  47. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/paged_attention/paged_attention_fp16_bsh_mix.o +0 -0
  48. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblcal.so +0 -0
  49. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  50. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  51. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  52. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  53. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  54. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  55. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  56. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  57. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  58. mindspore/mint/__init__.py +490 -2
  59. mindspore/mint/nn/__init__.py +2 -2
  60. mindspore/mint/optim/adamw.py +6 -14
  61. mindspore/nn/cell.py +1 -3
  62. mindspore/nn/layer/basic.py +24 -7
  63. mindspore/nn/layer/embedding.py +31 -14
  64. mindspore/nn/optim/tft_wrapper.py +12 -15
  65. mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
  66. mindspore/ops/_grad_experimental/grad_comm_ops.py +20 -1
  67. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +6 -0
  68. mindspore/ops/auto_generate/gen_extend_func.py +33 -0
  69. mindspore/ops/auto_generate/gen_ops_def.py +52 -3
  70. mindspore/ops/auto_generate/gen_ops_prim.py +155 -6
  71. mindspore/ops/function/array_func.py +2 -0
  72. mindspore/ops/function/math_func.py +7 -1
  73. mindspore/ops/function/random_func.py +221 -7
  74. mindspore/ops/operations/__init__.py +1 -1
  75. mindspore/ops/operations/array_ops.py +3 -1
  76. mindspore/ops/operations/comm_ops.py +21 -0
  77. mindspore/ops/operations/manually_defined/ops_def.py +8 -10
  78. mindspore/parallel/_auto_parallel_context.py +3 -1
  79. mindspore/parallel/_cell_wrapper.py +2 -0
  80. mindspore/parallel/_tensor.py +46 -2
  81. mindspore/parallel/_utils.py +40 -21
  82. mindspore/parallel/transform_safetensors.py +196 -43
  83. mindspore/profiler/profiling.py +5 -1
  84. mindspore/run_check/_check_version.py +4 -2
  85. mindspore/train/_utils.py +92 -32
  86. mindspore/train/callback/_checkpoint.py +12 -9
  87. mindspore/train/callback/_on_request_exit.py +12 -1
  88. mindspore/train/callback/_tft_register.py +27 -4
  89. mindspore/train/dataset_helper.py +10 -2
  90. mindspore/train/model.py +20 -0
  91. mindspore/train/serialization.py +8 -18
  92. mindspore/version.py +1 -1
  93. {mindspore-2.4.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +8 -6
  94. {mindspore-2.4.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +97 -97
  95. {mindspore-2.4.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
  96. {mindspore-2.4.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
  97. {mindspore-2.4.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
@@ -35,7 +35,7 @@ namespace acme {
 class RmsNormOp : public MultiImplsOp {
  public:
   RmsNormOp(const InputsImmutableInfoList &inputs_ii, const OutputsImmutableInfoList &outputs_ii,
-            const NormParam &param, const std::string &op_name);
+            const NormParam &param, const std::string &op_name);
   ~RmsNormOp() = default;

   AcmeStatus InitImpl() override;
@@ -43,8 +43,8 @@ class RmsNormOp : public MultiImplsOp {
   const std::string &TargetKernelName() const override { return target_kernel_name; }
   ShapeInfoList InferShape(const ShapeInfoList &inputs_shape) const override;

- protected:
-  bool UseAsdopImpl() override { return false; }
+ protected:
+  bool UseAsdopImpl() override;
   AcmeStatus TilingImplAcme(RawHostAddr host_ptr, HostRunInfoPtr *run_info_ptr) override;
   AcmeStatus LaunchImplAcme(const InputsAddrList &input_ptrs, const OutputsAddrList &output_ptrs,
                             const WsAddrList &ws_ptrs, void *stream) override;
@@ -53,6 +53,7 @@ protected:

  private:
   NormParam param_;
+  bool is_ascend_310p_{false};
   const std::string target_kernel_name{"NormOperation"};
 };
@@ -15,6 +15,7 @@
 """mint module."""
 from __future__ import absolute_import
 import mindspore.ops as ops
+from mindspore.ops.primitive import constexpr
 from mindspore.common._register_for_tensor import tensor_operator_registry_for_mint
 from mindspore.common.tensor import Tensor
 from mindspore.ops.function.array_func import gather_ext as gather, max_ext as max, min_ext as min
@@ -320,8 +321,13 @@ from mindspore.ops.auto_generate import erfc
 from mindspore.ops.auto_generate import expm1
 # 208
 from mindspore.ops.function.array_func import eye
+from mindspore.ops.function.random_func import randperm_ext as randperm
 from mindspore.ops.function.random_func import rand_ext as rand
 from mindspore.ops.function.random_func import rand_like_ext as rand_like
+from mindspore.ops.function.random_func import randn_ext as randn
+from mindspore.ops.function.random_func import randn_like_ext as randn_like
+from mindspore.ops.function.random_func import randint_ext as randint
+from mindspore.ops.function.random_func import randint_like_ext as randint_like
 # 210
 from mindspore.ops.auto_generate import floor
 # 231
@@ -364,6 +370,9 @@ from mindspore.ops.auto_generate import logaddexp_ext as logaddexp
 # 610
 from mindspore.ops.function.math_func import nan_to_num

+# 695
+from mindspore.ops.auto_generate import count_nonzero
+

 def add(input, other, *, alpha=1):
     r"""
@@ -661,6 +670,474 @@ def cummax(input, dim):
     return ops.auto_generate.cummax(input, dim)


+def _einsum_convert_sublist_to_label(num, ell_num=False):
+    """Convert a sublist entry to an equation label."""
+    if num == Ellipsis or (ell_num and num == 52):
+        return '...'
+    if 0 <= num < 26:
+        return chr(num + ord('A'))
+    if 26 <= num < 52:
+        return chr(num + ord('a') - 26)
+    raise ValueError(f'For einsum, the number in sublist must be in range [0, 52), but got {num}')
+
+
+def _einsum_convert_label_to_index(label):
+    """Convert an equation label to an index."""
+    label_num = ord(label)
+    if ord('A') <= label_num <= ord('Z'):
+        return label_num - ord('A')
+    if ord('a') <= label_num <= ord('z'):
+        return label_num - ord('a') + 26
+    if label_num == ord('.'):
+        return 52
+    raise ValueError(f'For einsum, the label in equation must be in [a-zA-Z] or ., but got {label}')
+
+
+def _einsum_convert_sublist(equation, *operands):
+    """Convert sublist-format inputs to an equation string and an operand tuple."""
+    if isinstance(equation, Tensor):
+        equation_tmp = ''
+        for i, lst in enumerate(operands):
+            if i % 2 == 0:
+                for _, num in enumerate(lst):
+                    equation_tmp += _einsum_convert_sublist_to_label(num)
+                if i in (len(operands) - 1, len(operands) - 2):
+                    continue
+                equation_tmp += ','
+        if len(operands) % 2 == 0:
+            equation_tmp += '->'
+            for _, num in enumerate(operands[-1]):
+                equation_tmp += _einsum_convert_sublist_to_label(num)
+            operands_tmp = list([equation]) + list(operands[1:-1:2])
+        else:
+            operands_tmp = list([equation]) + list(operands[1::2])
+        equation = equation_tmp
+        operands = tuple(operands_tmp)
+    if len(operands) == 0:  # pylint: disable=len-as-condition
+        raise ValueError("For einsum, the 'operands' must have at least one operand.")
+    return equation, operands
+
+
+def _einsum_check_inputargs(equation, operands):
+    """Check the types of equation and operands."""
+    if not isinstance(equation, str):
+        raise TypeError(f"For einsum, 'equation' must be a str, but got {type(equation)}.")
+    for operand in operands:
+        if not isinstance(operand, Tensor):
+            raise TypeError(f"For einsum, members of 'operands' must be Tensor, but got {type(operand)}.")
+
+
+@constexpr
+def _einsum_parse_equation(equation):
+    """Parse the equation into per-operand label lists and an output label list."""
+    l_equation = ''
+    r_equation = ''
+    equation = equation.replace(' ', '')
+
+    if '->' in equation:
+        l_equation, r_equation = equation.split('->', 1)
+        if l_equation == '':
+            raise ValueError('For einsum, the equation must contain characters to the left of the arrow.')
+    else:
+        l_equation = equation
+
+    l_equationlst = []
+    for subequation in l_equation.split(','):
+        if '.' in subequation and ('...' not in subequation or subequation.count('.') != 3):
+            raise ValueError("For einsum, an ellipsis in the equation must be three consecutive '.' "
+                             "and can appear only once per term.")
+        subequation_lst = [_einsum_convert_label_to_index(label) for label in subequation.replace('...', '.')]
+        l_equationlst.append(subequation_lst)
+
+    if '.' in r_equation and ('...' not in r_equation or r_equation.count('.') != 3):
+        raise ValueError("For einsum, an ellipsis in the equation must be three consecutive '.' "
+                         "and can appear only once.")
+    r_equationlst = [_einsum_convert_label_to_index(label) for label in r_equation.replace('...', '.')]
+
+    return l_equationlst, r_equationlst, ('->' in equation)
+
+
+def _einsum_parse_labels(l_equationlst, operands):
+    """Parse the left subscripts of the equation against the operands."""
+    align_rank = 0
+    max_labels = 53
+    labels_count = [0] * max_labels
+    labels2dimlst = [None] * max_labels
+
+    if len(operands) != len(l_equationlst):
+        raise ValueError(f"For einsum, the number of 'operands' does not match the number specified in the "
+                         f"'equation': got {len(operands)} and {len(l_equationlst)}.")
+
+    for idx, sub_equ in enumerate(l_equationlst):
+        start_dim = 0
+        label_num = 0
+        operand_shape = list(operands[idx].shape)
+        for label in sub_equ:
+            label_num += 1
+            end_dim = start_dim + 1
+
+            # Label is an ellipsis.
+            if label == 52:
+                end_dim = len(operand_shape) - len(sub_equ) + label_num
+            if labels2dimlst[label] is None:
+                labels2dimlst[label] = operand_shape[start_dim:end_dim]
+                align_rank += (end_dim - start_dim)
+            else:
+                if labels2dimlst[label] != operand_shape[start_dim:end_dim]:
+                    raise ValueError(f"For einsum, one label in 'equation' can only represent the same dimension "
+                                     f"in 'operands', but '{_einsum_convert_sublist_to_label(label, True)}' "
+                                     f"represented different dimensions.")
+            labels_count[label] += 1
+            start_dim = end_dim
+        if label_num != len(sub_equ) or start_dim != len(operand_shape):
+            raise ValueError(f"For einsum, the number of labels specified in the 'equation' does not match "
+                             f"'operands[{idx}]'.")
+    return labels2dimlst, labels_count, align_rank
+
+
+def _einsum_infer_output(r_equationlst, arrow_exist, labels2dimlst, labels_count):
+    """Parse the right subscripts of the equation and infer the output shape."""
+    idx = 0
+    idle_idx = -1
+    output_shape = []
+    labels_perm_idx = [idle_idx] * 53
+
+    if arrow_exist:
+        for label in r_equationlst:
+            if labels_count[label] != 0:
+                output_shape += labels2dimlst[label]
+                if labels_perm_idx[label] != idle_idx:
+                    raise ValueError(f"For einsum, '{_einsum_convert_sublist_to_label(label, True)}' "
+                                     f"(or {label} in sublist format) appears more than once in the "
+                                     f"output subscript.")
+                labels_perm_idx[label] = idx
+                idx += len(labels2dimlst[label])
+            else:
+                raise ValueError(f"For einsum, a label on the right of the arrow in the 'equation' must also "
+                                 f"appear on the left, but '{_einsum_convert_sublist_to_label(label, True)}' "
+                                 f"does not.")
+    else:
+        if labels_count[52] != 0:
+            output_shape += labels2dimlst[52]
+            labels_perm_idx[52] = idx
+            idx += len(labels2dimlst[52])
+        for label, count in enumerate(labels_count):
+            if count == 1:
+                output_shape += labels2dimlst[label]
+                labels_perm_idx[label] = idx
+                idx += len(labels2dimlst[label])
+
+    for label, count in enumerate(labels_count):
+        if count != 0 and labels_perm_idx[label] == idle_idx:
+            labels_perm_idx[label] = idx
+            idx += 1
+
+    return output_shape, labels_perm_idx
+
+
+def _einsum_adjust_operands(operands, l_equationlst, labels2dimlst, labels_perm_idx, align_rank):
+    """Align operands to the output layout where possible."""
+    # Unsqueeze missing dimensions so all operands have the same rank, and compute the
+    # diagonal when an operand repeats a label. Then use labels_perm_idx to transpose
+    # all operands so their dimensions align with the output.
+    adjust_operands = []
+    for idx, operand in enumerate(operands):
+        idle_dim = -1
+        align_axis = [idle_dim] * align_rank
+        label_dims = [idle_dim] * 53
+        dim = 0
+
+        for label in l_equationlst[idx]:
+            if label_dims[label] != idle_dim:
+                operand = ops.diagonal(operand, 0, label_dims[label], dim)
+                diag_perm = []
+                diag_dim = 0
+                for i in range(len(operand.shape)):
+                    if i == label_dims[label]:
+                        diag_perm.append(len(operand.shape) - 1)
+                    else:
+                        diag_perm.append(diag_dim)
+                        diag_dim += 1
+                operand = permute(operand, tuple(diag_perm))
+            else:
+                label_dims[label] = dim
+                if label == 52:
+                    for ell_idx in range(len(labels2dimlst[label])):
+                        align_axis[labels_perm_idx[label] + ell_idx] = dim
+                        dim += 1
+                else:
+                    align_axis[labels_perm_idx[label]] = dim
+                    dim += 1
+        if len(operand.shape) < align_rank:
+            for i, axis in enumerate(align_axis):
+                if axis == idle_dim:
+                    align_axis[i] = dim
+                    dim += 1
+            missing_dims = [1] * (align_rank - len(operand.shape))
+            operand_shape = list(operand.shape) + missing_dims
+            operand = reshape(operand, operand_shape)
+        operand = permute(operand, tuple(align_axis))
+        adjust_operands.append(operand)
+    return adjust_operands
+
+
+def _einsum_find_dimlastop(align_rank, operands, adjust_operands):
+    """For each aligned dimension, record the last operand that uses it."""
+    dim_last_op = [0 for _ in range(align_rank)]
+    has_zero_dim = False
+    for dim in range(align_rank):
+        broadcast_dim = adjust_operands[0].shape[dim]
+        for idx in range(1, len(adjust_operands)):
+            other_dim = adjust_operands[idx].shape[dim]
+            if broadcast_dim != other_dim and broadcast_dim != 1 and other_dim != 1:
+                err_msg = "For einsum, operands do not broadcast after aligning to output [shapes: origin -> adjusted]:"
+                for i in range(len(operands)):
+                    err_msg += f" {operands[i].shape} -> {adjust_operands[i].shape}"
+                raise ValueError(err_msg)
+            if other_dim != 1:
+                dim_last_op[dim] = idx
+                broadcast_dim = other_dim
+        has_zero_dim = has_zero_dim or broadcast_dim == 0
+    return dim_last_op, has_zero_dim
+
+
+def _einsum_multiplication(sum_dims, l_tensor, r_tensor):
+    """Compute a pairwise contraction via bmm for einsum."""
+    batch_dims = []
+    lonly_dims = []
+    ronly_dims = []
+    batch_size = 1
+    lonly_size = 1
+    ronly_size = 1
+    sum_size = 1
+
+    l_shape = l_tensor.shape
+    r_shape = r_tensor.shape
+
+    # Sum out a dimension immediately if only one side has it; collect the shapes for bmm.
+    for i in range(len(l_shape)):
+        sum_l = l_shape[i] > 1
+        sum_r = r_shape[i] > 1
+        if i in sum_dims:
+            if sum_l and sum_r:
+                sum_size *= l_shape[i]
+            elif sum_l:
+                l_tensor = sum(l_tensor, i, True)
+            elif sum_r:
+                r_tensor = sum(r_tensor, i, True)
+        elif sum_l and sum_r:
+            batch_dims.append(i)
+            batch_size *= l_shape[i]
+        elif sum_l:
+            lonly_dims.append(i)
+            lonly_size *= l_shape[i]
+        else:
+            ronly_dims.append(i)
+            ronly_size *= r_shape[i]
+
+    # Build the einsum bmm operator pipeline:
+    # transpose(in) -> reshape(in) -> bmm -> reshape(out) -> transpose(out).
+    l_reshape_shape = (batch_size, lonly_size, sum_size)
+    r_reshape_shape = (batch_size, sum_size, ronly_size)
+
+    out_reshape_shape = [l_shape[dim] for dim in batch_dims]
+    out_reshape_shape += [l_shape[dim] for dim in lonly_dims]
+    out_reshape_shape += [1 for _ in sum_dims]
+    out_reshape_shape += [r_shape[dim] for dim in ronly_dims]
+
+    l_perm_axis = batch_dims + lonly_dims + sum_dims + ronly_dims
+    r_perm_axis = batch_dims + sum_dims + ronly_dims + lonly_dims
+    out_perm_axis = [-1] * len(out_reshape_shape)
+
+    out_dim = 0
+    for idx in range(len(l_perm_axis)):
+        out_perm_axis[l_perm_axis[idx]] = out_dim
+        out_dim += 1
+
+    l_tensor = permute(l_tensor, tuple(l_perm_axis))
+    l_tensor = reshape(l_tensor, l_reshape_shape)
+
+    r_tensor = permute(r_tensor, tuple(r_perm_axis))
+    r_tensor = reshape(r_tensor, r_reshape_shape)
+
+    output = bmm(l_tensor, r_tensor)
+    output = reshape(output, out_reshape_shape)
+    output = permute(output, tuple(out_perm_axis))
+
+    # Squeeze the summed dimensions (kept as size 1 above) out of the result.
+    output_origin_shape = output.shape
+    output_squeeze_shape = []
+    for dim in range(len(output_origin_shape)):
+        if dim not in sum_dims:
+            output_squeeze_shape.append(output_origin_shape[dim])
+
+    return reshape(output, output_squeeze_shape)
+
+
+def _einsum_squeeze(operand, dim):
+    '''Will be replaced by mint.squeeze in the future.'''
+    operand_shape = operand.shape
+    squeeze_shape = []
+    for idx in range(len(operand_shape)):
+        if idx != dim:
+            squeeze_shape.append(operand_shape[idx])
+    return reshape(operand, squeeze_shape)
+
+
+def _einsum(equation, operands):
+    '''Einsum main process.'''
+    _l_equationlst, _r_equationlst, _arrow_exist = _einsum_parse_equation(equation)
+    _labels2dimlst, _labels_count, _align_rank = _einsum_parse_labels(_l_equationlst, operands)
+    _output_shape, _labels_perm_idx = _einsum_infer_output(_r_equationlst, _arrow_exist, _labels2dimlst, _labels_count)
+    _output_rank = len(_output_shape)
+
+    _adjust_operands = _einsum_adjust_operands(operands, _l_equationlst, _labels2dimlst, _labels_perm_idx, _align_rank)
+    _dim_last_op, _has_zero_dim = _einsum_find_dimlastop(_align_rank, operands, _adjust_operands)
+    _result = _adjust_operands[0]
+
+    # Fast path if any operand has a zero dimension.
+    if _has_zero_dim:
+        return zeros(_output_shape, dtype=_result.dtype)
+
+    # Sum or squeeze dimensions of the first operand that no later operand uses.
+    _reduce_dim = _output_rank
+    for dim in range(_output_rank, _align_rank):
+        if _dim_last_op[dim] == 0:
+            if _result.shape[_reduce_dim] == 1:
+                _result = _einsum_squeeze(_result, _reduce_dim)
+            else:
+                _result = sum(_result, _reduce_dim)
+        else:
+            _reduce_dim += 1
+
+    # Contract the remaining operands pairwise.
+    for i in range(1, len(_adjust_operands)):
+        operand = _adjust_operands[i]
+        dim = _output_rank
+        sum_dims = []
+        for j in range(_output_rank, _align_rank):
+            if _dim_last_op[j] < i:
+                operand = _einsum_squeeze(operand, dim)
+            elif _dim_last_op[j] == i:
+                if _result.shape[dim] == 1:
+                    operand = sum(operand, dim)
+                    _result = _einsum_squeeze(_result, dim)
+                else:
+                    sum_dims.append(dim)
+                    dim += 1
+            else:
+                dim += 1
+
+        if sum_dims == []:
+            _result = mul(_result, operand)
+        elif len(sum_dims) == len(_result.shape):
+            _result = ops.auto_generate.dot(flatten(_result), flatten(operand))
+        else:
+            _result = _einsum_multiplication(sum_dims, _result, operand)
+
+    return _result
+
+
+def einsum(equation, *operands):
+    r"""
+    Sums the products of the input tensors' elements along the dimensions specified in the equation,
+    based on the Einstein summation convention (einsum).
+    You can use this function to perform diagonal, reduce-sum, transpose, matmul, mul, and inner-product
+    operations, among others.
+
+    Note:
+        The sublist format is also supported, e.g. mint.einsum(op1, sublist1, op2, sublist2, ..., sublist_out).
+        In this format, the equation is derived from the sublists, which are made up of Python's Ellipsis and
+        integers in [0, 52). Each operand is followed by its sublist, and an optional output sublist comes last.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        equation (str): Notation based on the Einstein summation convention that represents the operation to
+            perform. The value can contain only letters, commas, an ellipsis, and an arrow. Letters represent
+            input tensor dimensions, commas separate tensors, an ellipsis stands for tensor dimensions you do
+            not care about, the left side of the arrow indicates the input tensors, and the right side
+            indicates the desired output dimensions.
+        operands (Tensor): Input tensors used for the calculation. All tensors must have the same dtype.
+
+    Returns:
+        Tensor, whose shape can be obtained from the `equation` and whose dtype is the same as that of the
+        input tensors.
+
+    Raises:
+        TypeError: If `equation` is invalid or does not match the input tensors.
+        ValueError: If a number in a sublist is not in [0, 52) in sublist format.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> import numpy as np
+        >>> from mindspore import Tensor, mint
+        >>> x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
+        >>> equation = "i->"
+        >>> output = mint.einsum(equation, x)
+        >>> print(output)
+        [7.]
+        >>> x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
+        >>> y = Tensor(np.array([2.0, 4.0, 3.0]), mindspore.float32)
+        >>> equation = "i,i->i"
+        >>> output = mint.einsum(equation, x, y)
+        >>> print(output)
+        [ 2.  8. 12.]
+        >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
+        >>> y = Tensor(np.array([[2.0, 3.0], [1.0, 2.0], [4.0, 5.0]]), mindspore.float32)
+        >>> equation = "ij,jk->ik"
+        >>> output = mint.einsum(equation, x, y)
+        >>> print(output)
+        [[16. 22.]
+         [37. 52.]]
+        >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
+        >>> equation = "ij->ji"
+        >>> output = mint.einsum(equation, x)
+        >>> print(output)
+        [[1. 4.]
+         [2. 5.]
+         [3. 6.]]
+        >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
+        >>> equation = "ij->j"
+        >>> output = mint.einsum(equation, x)
+        >>> print(output)
+        [5. 7. 9.]
+        >>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
+        >>> equation = "...->"
+        >>> output = mint.einsum(equation, x)
+        >>> print(output)
+        [21.]
+        >>> x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
+        >>> y = Tensor(np.array([2.0, 4.0, 1.0]), mindspore.float32)
+        >>> equation = "j,i->ji"
+        >>> output = mint.einsum(equation, x, y)
+        >>> print(output)
+        [[ 2.  4.  1.]
+         [ 4.  8.  2.]
+         [ 6. 12.  3.]]
+        >>> x = mindspore.Tensor([1, 2, 3, 4], mindspore.float32)
+        >>> y = mindspore.Tensor([1, 2], mindspore.float32)
+        >>> output = mint.einsum(x, [..., 1], y, [..., 2], [..., 1, 2])
+        >>> print(output)
+        [[1. 2.]
+         [2. 4.]
+         [3. 6.]
+         [4. 8.]]
+    """
+    _equation, _operands = _einsum_convert_sublist(equation, *operands)
+    _einsum_check_inputargs(_equation, _operands)
+
+    for operand in _operands:
+        if ops.is_sequence_shape_unknown(operand.shape) or ops.is_sequence_value_unknown(operand.shape):
+            raise ValueError("For einsum, the elements of 'operands' can't be of dynamic shape or dynamic rank.")
+
+    return _einsum(_equation, _operands)
+
+
 def item(input):
     r"""
     Returns the value of this tensor as a standard Python number.
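The contraction strategy in _einsum_multiplication above groups each operand's axes into (batch, left-only, sum) and (batch, sum, right-only) blocks, flattens each block, and runs a single batched matmul. A minimal NumPy sketch of that decomposition for a plain "bik,bkj->bij" contraction (illustrative shapes, not MindSpore code):

    import numpy as np

    b, m, k, n = 4, 3, 5, 2
    lhs = np.random.rand(b, m, k)   # axes already grouped as (batch, left-only, sum)
    rhs = np.random.rand(b, k, n)   # axes already grouped as (batch, sum, right-only)

    # reshape(in) -> bmm -> reshape(out); the transpose steps are identities here
    # because the axes are pre-grouped.
    out = (lhs.reshape(b, m, k) @ rhs.reshape(b, k, n)).reshape(b, m, n)

    assert np.allclose(out, np.einsum('bik,bkj->bij', lhs, rhs))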
@@ -694,7 +1171,8 @@ def item(input):
     if not isinstance(input, Tensor):
         raise TypeError(f"the input must be a Tensor, but got {type(input)}")
     if input.size != 1:
-        raise RuntimeError("a Tensor with {} elements cannot be converted to Scalar".format(input.size))
+        raise RuntimeError(
+            "a Tensor with {} elements cannot be converted to Scalar".format(input.size))
     return input.asnumpy().item()

@@ -1283,6 +1761,7 @@ __all__ = [
     # 31
     'cummax',
     'cummin',
+    'einsum',
     'sub',
     # 33
     'split',
@@ -1518,8 +1997,13 @@ __all__ = [

     # 256
     'median',
+    'randperm',
     'rand',
     'rand_like',
+    'randn',
+    'randn_like',
+    'randint',
+    'randint_like',
     # 210
     'floor',
     # 231
@@ -1554,6 +2038,9 @@ __all__ = [

     # 610
     'nan_to_num',
+
+    # 695
+    'count_nonzero',
 ]

 setattr(tensor_operator_registry_for_mint, 'add', add)
@@ -1568,7 +2055,8 @@ setattr(tensor_operator_registry_for_mint, 'item', item)
 setattr(tensor_operator_registry_for_mint, 'max', max)
 setattr(tensor_operator_registry_for_mint, 'mean', mean)
 setattr(tensor_operator_registry_for_mint, 'min', min)
-setattr(tensor_operator_registry_for_mint, 'repeat_interleave', repeat_interleave)
+setattr(tensor_operator_registry_for_mint,
+        'repeat_interleave', repeat_interleave)
 setattr(tensor_operator_registry_for_mint, 'ne', ne)
 setattr(tensor_operator_registry_for_mint, 'round', round)
 setattr(tensor_operator_registry_for_mint, 'sin', sin)
@@ -28,7 +28,7 @@ from mindspore.nn import EmbeddingExt as Embedding, MaxPool2dExt as MaxPool2d, L
 # 2

 # 3
-
+from mindspore.nn.layer.basic import Identity
 # 4

 # 5
@@ -529,7 +529,7 @@ __all__ = [
     # 2

     # 3
-
+    'Identity',
     # 4

     # 5
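The two hunks above export Identity through mint.nn. A minimal usage sketch, assuming the re-export behaves like the existing mindspore.nn.Identity pass-through layer:

    import numpy as np
    import mindspore as ms
    from mindspore import mint

    layer = mint.nn.Identity()                         # newly exported in 2.4.1 per this diff
    x = ms.Tensor(np.ones((2, 3)), ms.float32)
    assert (layer(x).asnumpy() == x.asnumpy()).all()   # returns its input unchanged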