mindspore 2.7.0__cp311-cp311-win_amd64.whl → 2.7.1__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +8 -1
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +275 -64
- mindspore/common/_utils.py +0 -44
- mindspore/common/api.py +285 -35
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
- mindspore/common/hook_handle.py +60 -0
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/parameter.py +13 -107
- mindspore/common/recompute.py +4 -11
- mindspore/common/tensor.py +16 -169
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +5 -85
- mindspore/dataset/engine/datasets.py +8 -4
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dnnl.dll +0 -0
- mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +1 -1
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/distributed.py +182 -62
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +4 -4
- mindspore/mint/nn/layer/normalization.py +8 -13
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +16 -66
- mindspore/nn/layer/basic.py +49 -1
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +31 -124
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +0 -1
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +0 -1
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/wrap/grad_reducer.py +4 -74
- mindspore/numpy/array_creations.py +2 -2
- mindspore/numpy/fft.py +9 -9
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +0 -5
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
- mindspore/ops/auto_generate/gen_extend_func.py +2 -7
- mindspore/ops/auto_generate/gen_ops_def.py +98 -141
- mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +15 -1
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +1 -0
- mindspore/ops/function/array_func.py +14 -12
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +3 -4
- mindspore/ops/function/math_func.py +45 -54
- mindspore/ops/function/nn_func.py +75 -294
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +2 -0
- mindspore/ops/functional_overload.py +354 -18
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +1 -38
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +1 -0
- mindspore/ops/operations/comm_ops.py +94 -2
- mindspore/ops/operations/custom_ops.py +228 -19
- mindspore/ops/operations/debug_ops.py +27 -29
- mindspore/ops/operations/manually_defined/ops_def.py +27 -306
- mindspore/ops/operations/nn_ops.py +2 -2
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +1 -17
- mindspore/ops/tensor_method.py +72 -3
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +1 -4
- mindspore/parallel/_utils.py +29 -6
- mindspore/parallel/checkpoint_transform.py +18 -2
- mindspore/parallel/cluster/process_entity/_api.py +24 -32
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +117 -16
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +38 -2
- mindspore/profiler/common/path_manager.py +56 -24
- mindspore/profiler/common/profiler_context.py +2 -12
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/experimental_config.py +2 -1
- mindspore/profiler/platform/npu_profiler.py +33 -6
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +3 -2
- mindspore/runtime/executor.py +11 -3
- mindspore/runtime/memory.py +112 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +5 -18
- mindspore/train/amp.py +6 -4
- mindspore/train/callback/_checkpoint.py +0 -9
- mindspore/train/callback/_train_fault_tolerance.py +69 -18
- mindspore/train/data_sink.py +1 -5
- mindspore/train/model.py +38 -211
- mindspore/train/serialization.py +126 -387
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +144 -8
- mindspore/version.py +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/train/memory_profiling_pb2.py +0 -298
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/nn/probability/bijector/power_transform.py CHANGED

@@ -13,8 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """PowerTransform Bijector"""
-
-from mindspore.ops import functional as F
+import mindspore.ops as ops
 from ..distribution._utils.utils import check_greater_equal_zero
 from ..distribution._utils.custom_ops import exp_generic, log_generic
 from .bijector import Bijector
@@ -76,16 +75,16 @@ class PowerTransform(Bijector):
         self._power = self._add_parameter(power, 'power')
         check_greater_equal_zero(self._power, 'Power')

-        self.pow =
-        self.dtypeop =
-        self.cast =
-        self.equal_base =
+        self.pow = ops.Pow()
+        self.dtypeop = ops.DType()
+        self.cast = ops.Cast()
+        self.equal_base = ops.Equal()
         self.exp = exp_generic
-        self.expm1 =
+        self.expm1 = ops.Expm1()
         self.log = log_generic
-        self.log1p =
-        self.select_base =
-        self.shape =
+        self.log1p = ops.Log1p()
+        self.select_base = ops.Select()
+        self.shape = ops.Shape()

     @property
     def power(self):
@@ -113,17 +112,17 @@ class PowerTransform(Bijector):
         power_local = self.cast_param_by_value(x, self.power)

         # broad cast the value of x and power
-        ones =
-
+        ones = ops.fill(self.dtypeop(power_local), self.shape(x + power_local),
+                        1.)
         power_local = power_local * ones
         x = x * ones
         safe_power = self.select_base(
             self.equal_base(power_local,
-
+                            ops.ZerosLike()(power_local)), ones, power_local)

         forward_v = self.select_base(
             self.equal_base(power_local,
-
+                            ops.ZerosLike()(power_local)), self.exp(x),
             self.exp(self.log1p(x * safe_power) / safe_power))
         return forward_v

@@ -135,17 +134,17 @@ class PowerTransform(Bijector):
         power_local = self.cast_param_by_value(y, self.power)

         # broad cast the value of x and power
-        ones =
-
+        ones = ops.fill(self.dtypeop(power_local), self.shape(y + power_local),
+                        1.)
         power_local = power_local * ones
         y = y * ones
         safe_power = self.select_base(
             self.equal_base(power_local,
-
+                            ops.ZerosLike()(power_local)), ones, power_local)

         inverse_v = self.select_base(
             self.equal_base(power_local,
-
+                            ops.ZerosLike()(power_local)), self.log(y),
             self.expm1(self.log(y) * safe_power) / safe_power)

         return inverse_v
@@ -166,14 +165,14 @@ class PowerTransform(Bijector):
         power_local = self.cast_param_by_value(x, self.power)

         # broad cast the value of x and power
-        ones =
-
+        ones = ops.fill(self.dtypeop(power_local), self.shape(x + power_local),
+                        1.)
         power_local = power_local * ones
         x = x * ones

         forward_log_j = self.select_base(
             self.equal_base(power_local,
-
+                            ops.ZerosLike()(power_local)), x,
             (1. / power_local - 1) * self.log1p(x * power_local))

         return forward_log_j
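For context, the forward map this bijector implements is exp(x) when the power is 0 and (1 + x*power)**(1/power) otherwise; the ops.Select/ops.ZerosLike construct above simply evaluates that branch element-wise without dividing by zero. A minimal NumPy sketch of the same math (illustration only, not MindSpore code):

import numpy as np

def power_transform_forward(x, power):
    """Reference math: exp(x) if power == 0, else (1 + x*power)**(1/power)."""
    x = np.asarray(x, dtype=np.float64)
    safe_power = np.where(power == 0, 1.0, power)      # mirrors the Select/ZerosLike guard above
    general = np.exp(np.log1p(x * safe_power) / safe_power)
    return np.where(power == 0, np.exp(x), general)

print(power_transform_forward([0.0, 0.5, 1.0], 0.0))   # equals exp(x)
print(power_transform_forward([0.0, 0.5, 1.0], 2.0))   # equals (1 + 2*x)**0.5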
mindspore/nn/probability/bijector/scalar_affine.py CHANGED

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """Scalar Affine Bijector"""
-
+import mindspore.ops as ops
 from ..distribution._utils.custom_ops import log_generic
 from .bijector import Bijector

@@ -86,10 +86,10 @@ class ScalarAffine(Bijector):
         self._scale = self._add_parameter(scale, 'scale')
         self._shift = self._add_parameter(shift, 'shift')

-        self.abs =
-        self.oneslike =
-        self.dtypeop =
-        self.cast =
+        self.abs = ops.Abs()
+        self.oneslike = ops.OnesLike()
+        self.dtypeop = ops.DType()
+        self.cast = ops.Cast()
         self.log = log_generic

     @property
mindspore/nn/probability/bijector/softplus.py CHANGED

@@ -14,8 +14,7 @@
 # ============================================================================
 """Softplus Bijector"""
 import numpy as np
-
-from mindspore.ops import functional as F
+import mindspore.ops as ops
 from mindspore.nn.layer.activation import LogSigmoid
 from ..distribution._utils.custom_ops import exp_generic, log_generic
 from .bijector import Bijector
@@ -82,17 +81,17 @@ class Softplus(Bijector):

         self.exp = exp_generic
         self.log = log_generic
-        self.expm1 =
-        self.abs =
-        self.dtypeop =
-        self.cast =
-        self.greater =
-        self.less =
+        self.expm1 = ops.Expm1()
+        self.abs = ops.Abs()
+        self.dtypeop = ops.DType()
+        self.cast = ops.Cast()
+        self.greater = ops.Greater()
+        self.less = ops.Less()
         self.log_sigmoid = LogSigmoid()
-        self.logicalor =
-        self.select =
-        self.shape =
-        self.sigmoid =
+        self.logicalor = ops.LogicalOr()
+        self.select = ops.Select()
+        self.shape = ops.Shape()
+        self.sigmoid = ops.Sigmoid()
         self.softplus = self._softplus
         self.inverse_softplus = self._inverse_softplus

@@ -104,7 +103,7 @@ class Softplus(Bijector):
         too_large = self.greater(x, -self.threshold)
         too_small_value = self.exp(x)
         too_large_value = x
-        ones =
+        ones = ops.fill(self.dtypeop(x), self.shape(x), 1.0)
         too_small_or_too_large = self.logicalor(too_small, too_large)
         x = self.select(too_small_or_too_large, ones, x)
         y = self.log(self.exp(x) + 1.0)
@@ -120,7 +119,7 @@ class Softplus(Bijector):
         too_large = self.greater(x, (-1) * self.threshold)
         too_small_value = self.log(x)
         too_large_value = x
-        ones =
+        ones = ops.fill(self.dtypeop(x), self.shape(x), 1.0)
         too_small_or_too_large = self.logicalor(too_small, too_large)
         x = self.select(too_small_or_too_large, ones, x)
         y = x + self.log(self.abs(self.expm1((-1)*x)))
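The thresholded select logic above computes softplus y = log(1 + exp(x)) while routing very large inputs straight through and very negative inputs to exp(x), so the exponential never overflows. A minimal NumPy sketch of the same stabilisation (threshold value illustrative, not the bijector's actual constant):

import numpy as np

def stable_softplus(x, threshold=20.0):
    """log(1 + exp(x)), returning x for huge inputs and exp(x) for very negative inputs."""
    x = np.asarray(x, dtype=np.float64)
    mid = np.log1p(np.exp(np.clip(x, -threshold, threshold)))   # safe middle range
    return np.where(x > threshold, x, np.where(x < -threshold, np.exp(x), mid))

print(stable_softplus([-50.0, 0.0, 50.0]))   # approximately [exp(-50), log(2), 50]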
mindspore/nn/wrap/grad_reducer.py CHANGED

@@ -140,34 +140,6 @@ def _tensors_allreduce_post(degree, mean, allreduce_filter, grad):
     return grad


-@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "Tensor", "Bool")
-def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
-    """
-    Apply allreduce on gradient.
-
-    Args:
-        degree (int): The mean coefficient.
-        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
-        allgather (Primitive): The communication operator for sparse gradients.
-        allreduce (Primitive): The communication operator for gradients.
-        allreduce_filter (bool): When it is true, allreduce would apply.
-        grad (Tensor): The gradient tensor before operation.
-        ps_parameter (bool): Use parameter server or not.
-
-    Returns:
-        Tensor, the gradient tensor after operation.
-    """
-    if ps_parameter:
-        return grad
-
-    if allreduce_filter:
-        grad = allreduce(grad)
-        if mean:
-            grad = ops.tensor_mul(grad, ops.cast(degree, ops.dtype(grad)))
-        return grad
-    return grad
-
-
 @reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "RowTensor")
 def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad):
     """
@@ -193,37 +165,6 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce
         grad = RowTensorInner(indices, dout, grad.dense_shape)
     return grad

-
-@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "RowTensor", "Bool")
-def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
-    """
-    Apply allgather on gradient instead of allreduce for sparse feature.
-    Allgather is a communication operation used for distributed deep learning.
-
-    Args:
-        degree (int): The mean coefficient.
-        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
-        allgather (Primitive): The communication operator for sparse gradients.
-        allreduce (Primitive): The communication operator for gradients.
-        allreduce_filter (bool): When it is true, allgather would apply.
-        grad (tuple): The indices, gradient tensor and tensor_shape before operation.
-        ps_parameter (bool): Use parameter server or not.
-
-    Returns:
-        RowTensor, the gradient after operation.
-    """
-    if ps_parameter:
-        return grad
-
-    if allreduce_filter:
-        indices = allgather(grad.indices)
-        dout = allgather(grad.values)
-        if mean:
-            dout = ops.tensor_mul(dout, ops.cast(degree, ops.dtype(dout)))
-        grad = RowTensorInner(indices, dout, grad.dense_shape)
-    return grad
-
-
 _get_datatype = ops.MultitypeFuncGraph("_get_datatype")


@@ -423,9 +364,6 @@ class DistributedGradReducer(Cell):
             self.split_fusion = False
             self.allreduce = AllReduce('sum', group).add_prim_attr('fusion', fusion_type)
         self.allgather = AllGather(group)
-        ps_filter = lambda x: x.is_param_ps
-        self.ps_parameters = tuple(ps_filter(x) for x in parameters)
-        self.enable_parameter_server = any(self.ps_parameters)
         self.mode = context.get_context("mode")
         self.enable_tuple_broaden = True

@@ -446,19 +384,11 @@ class DistributedGradReducer(Cell):
         grads = self.map_(ops.partial(_cast_datatype, mstype.float32), grads)

         if self.split_fusion:
-
-
-                                     self.op_list, self.allreduce_filter, grads, self.ps_parameters)
-            else:
-                new_grad = self.map_(ops.partial(reduce_opt, self.degree, self.mean, self.allgather),
-                                     self.op_list, self.allreduce_filter, grads)
+            new_grad = self.map_(ops.partial(reduce_opt, self.degree, self.mean, self.allgather),
+                                 self.op_list, self.allreduce_filter, grads)
         else:
-
-
-                                     self.allreduce), self.allreduce_filter, grads, self.ps_parameters)
-            else:
-                new_grad = self.map_(ops.partial(reduce_opt, self.degree, self.mean, self.allgather,
-                                     self.allreduce), self.allreduce_filter, grads)
+            new_grad = self.map_(ops.partial(reduce_opt, self.degree, self.mean, self.allgather,
+                                 self.allreduce), self.allreduce_filter, grads)
         new_grad = self.map_(ops.partial(_cast_datatype), datatypes, new_grad)
         return new_grad

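The deleted _tensors_allreduce_ps / _tensors_allreduce_with_sparse_ps overloads were extra reduce_opt.register(...) entries for parameter-server training; reduce_opt is a MultitypeFuncGraph, which picks an implementation from the type signature of its arguments and is applied per-gradient with map_/partial. A minimal sketch of that registration-and-dispatch pattern, with made-up names rather than the library's actual reduce_opt graph:

import numpy as np
import mindspore.ops as ops
from mindspore import Tensor

# A MultitypeFuncGraph dispatches on the type tags of its arguments; reduce_opt uses the
# same mechanism to choose between dense and sparse allreduce implementations.
scale_grad = ops.MultitypeFuncGraph("scale_grad")

@scale_grad.register("Number", "Tensor")
def _scale_dense(factor, grad):
    """Dense gradient: plain element-wise scaling."""
    return grad * factor

grads = (Tensor(np.ones((2, 2), np.float32)), Tensor(np.arange(4, dtype=np.float32)))
hyper_map = ops.HyperMap()
scaled = hyper_map(ops.partial(scale_grad, 0.5), grads)   # applies the registered overload to each gradient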
mindspore/numpy/array_creations.py CHANGED

@@ -2622,7 +2622,7 @@ def pad(arr, pad_width, mode="constant", stat_length=None, constant_values=0,
            unique pad widths for each axis. ``((before, after),)`` yields same
            before and after pad for each axis. ``(pad,)`` or int is a shortcut
            for ``before = after = pad width`` for all axes.
-        mode (
+        mode (str, optional):
            One of the following string values:

            - constant (default): Pads with a constant value.
@@ -2660,7 +2660,7 @@ def pad(arr, pad_width, mode="constant", stat_length=None, constant_values=0,
            unique end values for each axis. ``((before, after),)`` yields same before
            and after end values for each axis. ``(constant,)`` or ``constant``
            is a shortcut for ``before = after = constant`` for all axes. Default: ``0`` .
-        reflect_type(
+        reflect_type(str, optional) can choose between \'even\' and \'odd\'. Used in
            \'reflect\', and \'symmetric\'. The \'even\' style is the default with an
            unaltered reflection around the edge value. For the \'odd\' style, the extended
            part of the `arr` is created by subtracting the reflected values from two times
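The corrected docstring above documents the `mode` and `reflect_type` parameters of mindspore.numpy.pad; a short hedged usage sketch based on that signature:

import numpy as onp
import mindspore.numpy as mnp
from mindspore import Tensor

arr = Tensor(onp.array([1, 2, 3], onp.float32))
padded_const = mnp.pad(arr, 2, mode="constant", constant_values=0)    # default constant padding, width 2 each side
padded_sym = mnp.pad(arr, 2, mode="symmetric", reflect_type="odd")    # the 'odd' reflection style described above
print(padded_const)
print(padded_sym)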
mindspore/numpy/fft.py CHANGED

@@ -185,7 +185,7 @@ def rfft(a, n=None, axis=-1, norm=None):
            Default: ``None``.
        axis (int, optional): Axis over which to compute the `rfft`.
            Default: ``-1``, which means the last axis of `a` is used.
-        norm (
+        norm (str, optional): Normalization mode. Default: ``None`` that means ``"backward"``.
            Three modes are defined as,

            - ``"backward"`` (no normalization).
@@ -224,7 +224,7 @@ def irfft(a, n=None, axis=-1, norm=None):
            Default: ``None``.
        axis (int, optional): Axis over which to compute the `irfft`.
            Default: ``-1``, which means the last axis of `a` is used.
-        norm (
+        norm (str, optional): Normalization mode. Default: ``None`` that means ``"backward"``.
            Three modes are defined as,

            - ``"backward"`` (normalize by :math:`1/n`).
@@ -266,7 +266,7 @@ def fft2(a, s=None, axes=(-2, -1), norm=None):
            Default: ``None`` , which does not need to process `a`.
        axes (tuple[int], optional): The dimension along which to take the one dimensional `fft2`.
            Default: ``(-2, -1)`` , which means transform the last two dimension of `a`.
-        norm (
+        norm (str, optional): Normalization mode. Default: ``None`` that means ``"backward"`` .
            Three modes are defined as, where :math: `n = prod(s)`

            - ``"backward"`` (no normalization).
@@ -361,7 +361,7 @@ def fftn(a, s=None, axes=None, norm=None):
        axes (tuple[int], optional): The dimension along which to take the one dimensional `fftn`.
            Default: ``None`` , which means transform the all dimension of `a`,
            or the last `len(s)` dimensions if s is given.
-        norm (
+        norm (str, optional): Normalization mode. Default: ``None`` that means ``"backward"`` .
            Three modes are defined as, where :math: `n = prod(s)`

            - ``"backward"`` (no normalization).
@@ -409,7 +409,7 @@ def ifftn(a, s=None, axes=None, norm=None):
        axes (tuple[int], optional): The dimension along which to take the one dimensional `ifftn`.
            Default: ``None`` , which means transform the all dimension of `a`,
            or the last `len(s)` dimensions if s is given.
-        norm (
+        norm (str, optional): Normalization mode. Default: ``None`` that means ``"backward"`` .
            Three modes are defined as, where :math: `n = prod(s)`

            - ``"backward"`` (normalize by :math:`1/n`).
@@ -457,7 +457,7 @@ def rfft2(a, s=None, axes=(-2, -1), norm=None):
            Default: ``None`` , which does not need to process `a`.
        axes (tuple[int], optional): The dimension along which to take the one dimensional `rfft2`.
            Default: ``(-2, -1)`` , which means transform the last two dimension of `a`.
-        norm (
+        norm (str, optional): Normalization mode. Default: ``None`` that means ``"backward"`` .
            Three modes are defined as, where :math: `n = prod(s)`

            - ``"backward"`` (no normalization).
@@ -502,7 +502,7 @@ def irfft2(a, s=None, axes=(-2, -1), norm=None):
            Default: ``None`` , the axes[-1] of the `a` will be zero-padded to :math:`2*(a.shape[axes[-1]]-1)`.
        axes (tuple[int], optional): The dimension along which to take the one dimensional `irfft2`.
            Default: ``(-2, -1)`` , which means transform the last two dimension of `a`.
-        norm (
+        norm (str, optional): Normalization mode. Default: ``None`` that means ``"backward"`` .
            Three modes are defined as, where :math: `n = prod(s)`

            - ``"backward"`` (normalize by :math:`1/n`).
@@ -551,7 +551,7 @@ def rfftn(a, s=None, axes=None, norm=None):
        axes (tuple[int], optional): The dimension along which to take the one dimensional `rfftn`.
            Default: ``None`` , which means transform the all dimension of `a`,
            or the last `len(s)` dimensions if s is given.
-        norm (
+        norm (str, optional): Normalization mode. Default: ``None`` that means ``"backward"`` .
            Three modes are defined as, where :math: `n = prod(s)`

            - ``"backward"`` (no normalization).
@@ -599,7 +599,7 @@ def irfftn(a, s=None, axes=None, norm=None):
        axes (tuple[int], optional): The dimension along which to take the one dimensional `irfftn`.
            Default: ``None`` , which means transform the all dimension of `a`,
            or the last `len(s)` dimensions if s is given.
-        norm (
+        norm (str, optional): Normalization mode. Default: ``None`` that means ``"backward"`` .
            Three modes are defined as, where :math: `n = prod(s)`

            - ``"backward"`` (normalize by :math:`1/n`).
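Every fixed docstring above documents `norm` as a string normalization mode whose ``None`` default behaves like ``"backward"``. A hedged usage sketch, assuming the remaining mode names follow the NumPy convention ("ortho" and "forward", which are not visible in the hunks above):

import numpy as onp
import mindspore.numpy as mnp
from mindspore import Tensor

a = Tensor(onp.arange(8, dtype=onp.float32))
spec_default = mnp.fft.rfft(a)               # norm=None behaves like "backward": no scaling on the forward transform
spec_ortho = mnp.fft.rfft(a, norm="ortho")   # assumed NumPy-style mode: scales by 1/sqrt(n)
print(spec_default)
print(spec_ortho)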
mindspore/{nn/reinforcement → onnx}/__init__.py RENAMED

@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,13 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""
-
-"""
+"""onnx module."""
+
 from __future__ import absolute_import

-from
+from .onnx_export import export

-__all__ = [
-    "TensorArray",
-]
+__all__ = ["export"]
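This renamed package __init__ wires `export` into the new `mindspore.onnx` namespace; the full new module follows. A hedged usage sketch with a toy network (TinyNet and its names are made up for illustration):

import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

class TinyNet(nn.Cell):
    """Toy network used only to illustrate the new ms.onnx.export entry point."""
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(32, 10)

    def construct(self, x):
        return self.dense(x)

net = TinyNet()
x = Tensor(np.ones((4, 32), np.float32))
# dynamic_axes follows the {input_name: {axis_index: axis_name}} scheme documented in onnx_export.py.
ms.onnx.export(net, x, file_name="tiny_net.onnx",
               input_names=["input1"], output_names=["logits"],
               dynamic_axes={"input1": {0: "batch_size"}})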
mindspore/onnx/onnx_export.py ADDED

@@ -0,0 +1,137 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Model export to ONNX."""
+from __future__ import absolute_import
+from __future__ import division
+
+import os
+
+import mindspore.nn as nn
+from mindspore import log as logger
+from mindspore._checkparam import check_input_dataset
+from mindspore import _checkparam as Validator
+from mindspore.common.api import _cell_graph_executor as _executor
+from mindspore.train.serialization import _calculation_net_size
+from mindspore.dataset.engine.datasets import Dataset
+
+PROTO_LIMIT_SIZE = 1024 * 1024 * 2
+
+
+def export(net, *inputs, file_name, input_names=None, output_names=None, export_params=True,
+           keep_initializers_as_inputs=False, dynamic_axes=None):
+    """
+    Export the MindSpore network into an ONNX model.
+
+    Note:
+        - Support exporting network larger than 2GB. When the network exceeds 2GB,
+          parameters are saved in additional binary files stored in the same directory as the ONNX file.
+        - When `file_name` does not have a suffix, the system will automatically add the suffix `.onnx` .
+
+    Args:
+        net (Union[Cell, function]): MindSpore network.
+        inputs (Union[Tensor, list, tuple, Number, bool]): It represents the inputs of the `net` , if the network has
+            multiple inputs, set them together.
+        file_name (str): File name of the model to be exported.
+        input_names (list, optional): Names to assign to the input nodes of the graph, in order. Default: ``None`` .
+        output_names (list, optional): Names to assign to the output nodes of the graph, in order. Default: ``None`` .
+        export_params (bool, optional): If false, parameters (weights) will not be exported,
+            parameters will add input nodes as input of the graph. Default: ``True`` .
+        keep_initializers_as_inputs (bool, optional): If True, all the initializers (model parameters/weights) will
+            add as inputs to the graph. This allows modifying any or all weights when running the exported ONNX model.
+            Default: ``False`` .
+        dynamic_axes (dict[str, dict[int, str]], optional): To specify axes of input tensors as dynamic (at runtime).
+            Default: ``None`` .
+
+            - Set a dict with scheme: {input_node_name: {axis_index:axis_name}},
+              for example, {"input1": {0:"batch_size", 1: "seq_len"}, "input2": {{0:"batch_size"}}.
+            - By default, the shapes of all input tensors in the exported model exactly match those specified in
+              `inputs`.
+
+    Raises:
+        ValueError: If the parameter `net` is not :class:`mindspore.nn.Cell`.
+        ValueError: If the parameter `input_names` is not list type.
+        ValueError: If the parameter `output_names` is not list type
+        ValueError: If the parameter `dynamic_axes` is not dict type.
+
+    Examples:
+        >>> import mindspore as ms
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>>
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+        >>> net = LeNet5()
+        >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
+        >>> ms.onnx.export(net, input_tensor, file_name='lenet.onnx', input_names=['input1'], output_names=['output1'])
+
+    """
+    Validator.check_file_name_by_regular(file_name)
+    logger.info("exporting model file:%s format:%s.", file_name, "ONNX")
+    Validator.check_isinstance("net", net, nn.Cell)
+    input_names = input_names or []
+    Validator.check_isinstance("input_names", input_names, list)
+    output_names = output_names or []
+    Validator.check_isinstance("output_names", output_names, list)
+    dynamic_axes = dynamic_axes or {}
+    Validator.check_isinstance("dynamic_axes", dynamic_axes, dict)
+
+    if check_input_dataset(*inputs, dataset_type=Dataset):
+        raise ValueError(f"Can not support dataset as inputs to export ONNX model.")
+
+    cell_mode = net.training
+    net.set_train(mode=False)
+
+    extra_save_params = False
+    total_size = _calculation_net_size(net)
+    if total_size > PROTO_LIMIT_SIZE:
+        logger.warning('Network size is: {}G, it exceeded the protobuf: {}G limit, now parameters in network are saved '
+                       'in external data files.'.format(total_size / 1024 / 1024, PROTO_LIMIT_SIZE / 1024 / 1024))
+        extra_save_params = True
+
+    phase_name = 'export.onnx'
+    graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
+
+    abs_file_name = os.path.abspath(file_name)
+    if not abs_file_name.endswith('.onnx'):
+        abs_file_name += ".onnx"
+
+    dir_path = os.path.dirname(abs_file_name)
+    if dir_path and not os.path.exists(dir_path):
+        os.makedirs(dir_path, exist_ok=True)
+
+    abs_file_dir = os.path.dirname(abs_file_name) if extra_save_params else ""
+
+    onnx_stream = _executor._get_onnx_func_graph_proto(obj=net, exec_id=graph_id, input_names=input_names,
+                                                       output_names=output_names, export_params=export_params,
+                                                       keep_initializers_as_inputs=keep_initializers_as_inputs,
+                                                       dynamic_axes=dynamic_axes, extra_save_params=extra_save_params,
+                                                       save_file_dir=abs_file_dir)
+    if onnx_stream is None:
+        raise RuntimeError("Export onnx model failed, ensure that the model has been compiled correctly")
+
+    try:
+        with open(abs_file_name, 'wb') as f:
+            f.write(onnx_stream)
+
+        if os.path.getsize(abs_file_name) != len(onnx_stream):
+            logger.warning("ONNX file size doesn't match expected value, but proceeding continue.")
+
+    except IOError as e:
+        logger.error(f"Failed to write ONNX file: {e}")
+        if os.path.exists(abs_file_name):
+            os.remove(abs_file_name)
+
+    net.set_train(mode=cell_mode)
mindspore/ops/__init__.py CHANGED

@@ -37,6 +37,7 @@ from mindspore.ops.functional_overload import all_gather_matmul, matmul_reduce_s
 from mindspore.ops.composite import *
 from mindspore.ops.operations import *
 from mindspore.ops.function import *
+from mindspore.ops.communication import *
 from mindspore.ops.functional import *
 from mindspore.ops._utils import arg_dtype_cast, arg_handler

@@ -55,4 +56,5 @@ __all__.extend(composite.__all__)
 __all__.extend(operations.__all__)
 __all__.extend(functional.__all__)
 __all__.extend(function.__all__)
+__all__.extend(communication.__all__)
 __all__.extend(auto_generate.__all__)