mindspore 2.4.1__cp311-cp311-macosx_11_0_arm64.whl → 2.5.0__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +8 -3
- mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
- mindspore/_c_expression.cpython-311-darwin.so +0 -0
- mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
- mindspore/_checkparam.py +0 -5
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/compile_config.py +64 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +375 -0
- mindspore/_extends/parse/parser.py +23 -5
- mindspore/_extends/parse/standard_method.py +123 -27
- mindspore/_extends/pijit/pijit_func_white_list.py +1 -1
- mindspore/amp.py +7 -1
- mindspore/boost/boost_cell_wrapper.py +136 -41
- mindspore/common/__init__.py +3 -1
- mindspore/common/_register_for_tensor.py +0 -1
- mindspore/common/_stub_tensor.py +25 -4
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +6132 -0
- mindspore/common/api.py +99 -25
- mindspore/common/dtype.py +34 -34
- mindspore/common/dump.py +2 -1
- mindspore/common/file_system.py +8 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +3 -1
- mindspore/common/initializer.py +3 -4
- mindspore/common/lazy_inline.py +8 -2
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/parameter.py +30 -27
- mindspore/common/tensor.py +713 -1337
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +10 -0
- mindspore/communication/comm_func.py +215 -173
- mindspore/communication/management.py +23 -20
- mindspore/context.py +292 -193
- mindspore/dataset/__init__.py +23 -19
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +84 -3
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +5 -4
- mindspore/dataset/engine/datasets.py +192 -149
- mindspore/dataset/engine/datasets_audio.py +14 -0
- mindspore/dataset/engine/datasets_standard_format.py +28 -11
- mindspore/dataset/engine/datasets_text.py +38 -1
- mindspore/dataset/engine/datasets_user_defined.py +125 -65
- mindspore/dataset/engine/datasets_vision.py +81 -8
- mindspore/dataset/engine/iterators.py +281 -63
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +26 -2
- mindspore/dataset/engine/serializer_deserializer.py +1 -1
- mindspore/dataset/engine/validators.py +43 -11
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +29 -12
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +94 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +127 -0
- mindspore/device_context/cpu/__init__.py +25 -0
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +134 -0
- mindspore/experimental/llm_boost/__init__.py +3 -2
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +239 -64
- mindspore/experimental/llm_boost/atb/llama_boost.py +52 -30
- mindspore/experimental/llm_boost/atb/qwen_boost.py +47 -24
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/optim/adadelta.py +26 -22
- mindspore/experimental/optim/adam.py +3 -0
- mindspore/experimental/optim/lr_scheduler.py +33 -24
- mindspore/experimental/optim/radam.py +33 -30
- mindspore/hal/device.py +28 -0
- mindspore/hal/event.py +17 -0
- mindspore/hal/memory.py +94 -3
- mindspore/hal/stream.py +91 -6
- mindspore/include/api/context.h +1 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/lib/libavcodec.59.dylib +0 -0
- mindspore/lib/libavdevice.59.dylib +0 -0
- mindspore/lib/libavfilter.8.dylib +0 -0
- mindspore/lib/libavformat.59.dylib +0 -0
- mindspore/lib/libavutil.57.dylib +0 -0
- mindspore/lib/libdnnl.2.dylib +0 -0
- mindspore/lib/libmindspore_backend.dylib +0 -0
- mindspore/lib/libmindspore_common.dylib +0 -0
- mindspore/lib/libmindspore_core.dylib +0 -0
- mindspore/lib/libmindspore_glog.0.dylib +0 -0
- mindspore/lib/libmindspore_gpr.15.dylib +0 -0
- mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
- mindspore/lib/libmindspore_grpc.15.dylib +0 -0
- mindspore/lib/libmindspore_ops.dylib +0 -0
- mindspore/lib/libnnacl.dylib +0 -0
- mindspore/lib/libopencv_core.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
- mindspore/lib/libswresample.4.dylib +0 -0
- mindspore/lib/libswscale.6.dylib +0 -0
- mindspore/lib/libtinyxml2.8.dylib +0 -0
- mindspore/log.py +12 -0
- mindspore/mindrecord/__init__.py +1 -1
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mint/__init__.py +824 -218
- mindspore/mint/distributed/__init__.py +66 -4
- mindspore/mint/distributed/distributed.py +2594 -44
- mindspore/mint/linalg/__init__.py +6 -0
- mindspore/mint/nn/__init__.py +473 -14
- mindspore/mint/nn/functional.py +486 -11
- mindspore/mint/nn/layer/__init__.py +17 -4
- mindspore/mint/nn/layer/_functions.py +330 -0
- mindspore/mint/nn/layer/activation.py +169 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +727 -0
- mindspore/mint/nn/layer/normalization.py +215 -19
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +170 -0
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +2 -0
- mindspore/nn/cell.py +142 -21
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +6 -6
- mindspore/nn/layer/basic.py +35 -25
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/conv.py +3 -0
- mindspore/nn/layer/embedding.py +3 -3
- mindspore/nn/layer/normalization.py +8 -7
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +55 -23
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +48 -26
- mindspore/nn/learning_rate_schedule.py +5 -3
- mindspore/nn/loss/loss.py +31 -36
- mindspore/nn/optim/ada_grad.py +1 -0
- mindspore/nn/optim/adadelta.py +2 -2
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/rprop.py +2 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/utils/__init__.py +22 -0
- mindspore/nn/utils/init.py +73 -0
- mindspore/nn/wrap/cell_wrapper.py +4 -6
- mindspore/nn/wrap/loss_scale.py +3 -4
- mindspore/numpy/array_creations.py +60 -62
- mindspore/numpy/array_ops.py +148 -143
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +16 -16
- mindspore/numpy/utils_const.py +4 -4
- mindspore/ops/__init__.py +2 -1
- mindspore/ops/_grad_experimental/grad_comm_ops.py +107 -8
- mindspore/ops/_grad_experimental/grad_debug_ops.py +6 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_vmap/vmap_array_ops.py +20 -19
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +19 -13
- mindspore/ops/_vmap/vmap_math_ops.py +11 -9
- mindspore/ops/_vmap/vmap_nn_ops.py +20 -34
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +149 -12
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -61
- mindspore/ops/auto_generate/gen_extend_func.py +554 -60
- mindspore/ops/auto_generate/gen_ops_def.py +1621 -115
- mindspore/ops/auto_generate/gen_ops_prim.py +8027 -3411
- mindspore/ops/auto_generate/pyboost_inner_prim.py +183 -79
- mindspore/ops/composite/base.py +1 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +229 -30
- mindspore/ops/composite/multitype_ops/pow_impl.py +0 -29
- mindspore/ops/function/__init__.py +12 -0
- mindspore/ops/function/array_func.py +561 -159
- mindspore/ops/function/clip_func.py +64 -0
- mindspore/ops/function/debug_func.py +28 -20
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +5 -4
- mindspore/ops/function/math_func.py +1664 -294
- mindspore/ops/function/nn_func.py +988 -317
- mindspore/ops/function/parameter_func.py +3 -56
- mindspore/ops/function/random_func.py +243 -33
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/functional.py +18 -5
- mindspore/ops/functional_overload.py +897 -0
- mindspore/ops/operations/__init__.py +3 -2
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -34
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +38 -8
- mindspore/ops/operations/array_ops.py +45 -303
- mindspore/ops/operations/comm_ops.py +23 -17
- mindspore/ops/operations/custom_ops.py +7 -49
- mindspore/ops/operations/debug_ops.py +42 -47
- mindspore/ops/operations/inner_ops.py +6 -4
- mindspore/ops/operations/linalg_ops.py +3 -2
- mindspore/ops/operations/manually_defined/ops_def.py +185 -104
- mindspore/ops/operations/math_ops.py +11 -216
- mindspore/ops/operations/nn_ops.py +153 -310
- mindspore/ops/primitive.py +23 -21
- mindspore/ops/tensor_method.py +1669 -0
- mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py +110 -0
- mindspore/ops_generate/add_tensor_docs_generator.py +54 -0
- mindspore/ops_generate/arg_handler.py +0 -61
- mindspore/ops_generate/auto_grad_impl_cc_generator.py +135 -0
- mindspore/ops_generate/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/base_generator.py +11 -0
- mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py +108 -0
- mindspore/ops_generate/functional_map_cpp_generator.py +491 -0
- mindspore/ops_generate/functional_overload_py_generator.py +110 -0
- mindspore/ops_generate/functions_cc_generator.py +233 -0
- mindspore/ops_generate/gen_aclnn_implement.py +110 -114
- mindspore/ops_generate/gen_constants.py +157 -3
- mindspore/ops_generate/gen_ops.py +245 -990
- mindspore/ops_generate/gen_pyboost_func.py +97 -998
- mindspore/ops_generate/gen_utils.py +119 -33
- mindspore/ops_generate/lite_ops_cpp_generator.py +155 -0
- mindspore/ops_generate/op_api_proto.py +206 -0
- mindspore/ops_generate/op_def_py_generator.py +131 -0
- mindspore/ops_generate/op_prim_py_generator.py +480 -0
- mindspore/ops_generate/op_proto.py +373 -108
- mindspore/ops_generate/op_template_parser.py +436 -0
- mindspore/ops_generate/ops_def_cc_generator.py +288 -0
- mindspore/ops_generate/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/ops_name_h_generator.py +68 -0
- mindspore/ops_generate/ops_primitive_h_generator.py +81 -0
- mindspore/ops_generate/pyboost_functions_cpp_generator.py +370 -0
- mindspore/ops_generate/pyboost_functions_h_generator.py +68 -0
- mindspore/ops_generate/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost_grad_function_cpp_generator.py +154 -0
- mindspore/ops_generate/pyboost_inner_prim_generator.py +131 -0
- mindspore/ops_generate/pyboost_native_grad_functions_generator.py +268 -0
- mindspore/ops_generate/pyboost_op_cpp_code_generator.py +851 -0
- mindspore/ops_generate/pyboost_overload_functions_cpp_generator.py +344 -0
- mindspore/ops_generate/pyboost_utils.py +92 -33
- mindspore/ops_generate/template.py +294 -44
- mindspore/ops_generate/tensor_func_reg_cpp_generator.py +422 -0
- mindspore/parallel/__init__.py +3 -3
- mindspore/parallel/_auto_parallel_context.py +44 -34
- mindspore/parallel/_cell_wrapper.py +22 -3
- mindspore/parallel/_parallel_serialization.py +13 -2
- mindspore/parallel/_utils.py +4 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +44 -0
- mindspore/parallel/cluster/process_entity/_api.py +131 -37
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +20 -3
- mindspore/parallel/parameter_broadcast.py +1 -1
- mindspore/parallel/shard.py +3 -0
- mindspore/parallel/transform_safetensors.py +119 -253
- mindspore/profiler/__init__.py +17 -4
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +166 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +261 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +84 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +260 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +333 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +252 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +313 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +322 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +265 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +97 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +138 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +174 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +202 -0
- mindspore/profiler/common/path_manager.py +371 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +476 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +210 -0
- mindspore/profiler/common/profiler_path_manager.py +120 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +270 -37
- mindspore/profiler/envprofiler.py +138 -0
- mindspore/profiler/mstx.py +199 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +309 -0
- mindspore/profiler/profiler.py +580 -93
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +114 -0
- mindspore/profiler/schedule.py +208 -0
- mindspore/rewrite/api/symbol_tree.py +1 -2
- mindspore/run_check/_check_version.py +18 -13
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +148 -0
- mindspore/runtime/memory.py +392 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/train/__init__.py +2 -2
- mindspore/train/_utils.py +53 -18
- mindspore/train/amp.py +8 -4
- mindspore/train/callback/_checkpoint.py +32 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +105 -69
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_summary_collector.py +44 -6
- mindspore/train/callback/_tft_register.py +37 -15
- mindspore/train/dataset_helper.py +11 -11
- mindspore/train/metrics/precision.py +4 -5
- mindspore/train/mind_ir_pb2.py +167 -46
- mindspore/train/model.py +13 -14
- mindspore/train/serialization.py +461 -72
- mindspore/train/summary/summary_record.py +1 -2
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/utils/__init__.py +4 -2
- mindspore/utils/dryrun.py +138 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/METADATA +3 -4
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/RECORD +370 -244
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/entry_points.txt +1 -1
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/lib/libmindspore_np_dtype.dylib +0 -0
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/top_level.txt +0 -0
mindspore/.commit_id
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__commit_id__ = ''[sha1]:
|
|
1
|
+
__commit_id__ = ''[sha1]:241c405d,[branch]:(HEAD,origin/master,origin/HEAD,master)''
|
mindspore/__init__.py
CHANGED
|
@@ -17,13 +17,17 @@ from __future__ import absolute_import
|
|
|
17
17
|
|
|
18
18
|
from mindspore.run_check import run_check
|
|
19
19
|
from mindspore import common, dataset, mindrecord, train, log, amp
|
|
20
|
-
from mindspore import profiler, communication, numpy, parallel, hal
|
|
20
|
+
from mindspore import profiler, communication, numpy, parallel, hal, runtime, device_context
|
|
21
21
|
from mindspore.common import *
|
|
22
|
+
from mindspore.common import _tensor_docs
|
|
23
|
+
del _tensor_docs
|
|
22
24
|
from mindspore.mindrecord import *
|
|
23
25
|
from mindspore.ops import _op_impl, grad, value_and_grad, vjp, jvp, jacfwd, jacrev, vmap, get_grad, constexpr, reshard
|
|
24
26
|
from mindspore.train import *
|
|
25
27
|
from mindspore.log import *
|
|
26
28
|
from mindspore.utils import *
|
|
29
|
+
from mindspore.device_manager import *
|
|
30
|
+
from mindspore.runtime import *
|
|
27
31
|
from mindspore.context import GRAPH_MODE, PYNATIVE_MODE, set_context, get_context, set_auto_parallel_context, \
|
|
28
32
|
get_auto_parallel_context, reset_auto_parallel_context, ParallelMode, set_ps_context, \
|
|
29
33
|
get_ps_context, reset_ps_context, set_offload_context, get_offload_context, STRICT, COMPATIBLE, LAX
|
|
@@ -31,8 +35,7 @@ from mindspore.version import __version__
|
|
|
31
35
|
from mindspore.profiler import Profiler, EnvProfiler
|
|
32
36
|
from mindspore.parallel import set_algo_parameters, get_algo_parameters, reset_algo_parameters, \
|
|
33
37
|
rank_list_for_transform, transform_checkpoint_by_rank, transform_checkpoints, merge_pipeline_strategys, shard, \
|
|
34
|
-
Layout, sync_pipeline_shared_parameters, parameter_broadcast, load_segmented_checkpoints,
|
|
35
|
-
safetensors_to_ckpt, ckpt_to_safetensors, unified_safetensors
|
|
38
|
+
Layout, sync_pipeline_shared_parameters, parameter_broadcast, load_segmented_checkpoints, unified_safetensors
|
|
36
39
|
from mindspore.rewrite import SymbolTree, ScopedValue, Node, NodeType
|
|
37
40
|
from mindspore.safeguard import obfuscate_ckpt, load_obf_params_into_net
|
|
38
41
|
from mindspore._check_jit_forbidden_api import get_obj_module_and_name_info, is_jit_forbidden_module, \
|
|
@@ -50,4 +53,6 @@ __all__.extend(context.__all__)
|
|
|
50
53
|
__all__.extend(parallel.__all__)
|
|
51
54
|
__all__.extend(rewrite.__all__)
|
|
52
55
|
__all__.extend(safeguard.__all__)
|
|
56
|
+
__all__.extend(device_manager.__all__)
|
|
57
|
+
__all__.extend(runtime.__all__)
|
|
53
58
|
__all__.append("Profiler")
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
mindspore/_checkparam.py
CHANGED
|
@@ -1373,11 +1373,6 @@ def args_type_check(*type_args, **type_kwargs):
|
|
|
1373
1373
|
|
|
1374
1374
|
def check_hook_fn(hook_type, hook_fn):
|
|
1375
1375
|
"""Check hook fn"""
|
|
1376
|
-
if context.get_context("mode") != context.PYNATIVE_MODE:
|
|
1377
|
-
logger.warning(f"'{hook_type}' function is only supported in pynative mode, you can use "
|
|
1378
|
-
f"context.set_context to set pynative mode.")
|
|
1379
|
-
return False
|
|
1380
|
-
|
|
1381
1376
|
if not isinstance(hook_fn, (FunctionType, MethodType)):
|
|
1382
1377
|
raise TypeError(f"When using 'hook_type(hook_fn)', the type of 'hook_fn' must be python "
|
|
1383
1378
|
f"function, but got {type(hook_fn)}.")
|
|
@@ -23,7 +23,7 @@ from itertools import product
|
|
|
23
23
|
|
|
24
24
|
SUPPORTED_INPUT_NUM = [1, 2, 3, 4, 5, 6, 7]
|
|
25
25
|
SUPPORTED_OUTPUT_NUM = [1, 2, 3, 4, 5]
|
|
26
|
-
SUPPORTED_DEVICE_ARCH = ["ascend910", "ascend910b"]
|
|
26
|
+
SUPPORTED_DEVICE_ARCH = ["ascend910", "ascend910b", "ascend910_93"]
|
|
27
27
|
VALUE_ALL = "all"
|
|
28
28
|
VALUE = "value"
|
|
29
29
|
NAME = "name"
|
|
@@ -267,6 +267,64 @@ Value Range:
|
|
|
267
267
|
"""
|
|
268
268
|
STRICT_CHECK_PARENT_CONTEXT = ''
|
|
269
269
|
|
|
270
|
+
"""
|
|
271
|
+
Name: CELL_PARAMETER_HOOK
|
|
272
|
+
Function: Whether to enable cell parameter hook.
|
|
273
|
+
Cell parameter hook is an experimental api that may be deleted later.
|
|
274
|
+
Value Range:
|
|
275
|
+
1: Enable
|
|
276
|
+
Default: Disable
|
|
277
|
+
"""
|
|
278
|
+
CELL_PARAMETERS_HOOK = ''
|
|
279
|
+
|
|
280
|
+
"""
|
|
281
|
+
Name: CHECK_BPROP
|
|
282
|
+
Function: Whether to check back propagation nodes. The checking ensures that the shape and dtype of
|
|
283
|
+
back propagation node outputs is the same as input parameters.
|
|
284
|
+
Value Range:
|
|
285
|
+
1: Enable
|
|
286
|
+
Default: Disable.
|
|
287
|
+
"""
|
|
288
|
+
CHECK_BPROP = ''
|
|
289
|
+
|
|
290
|
+
"""
|
|
291
|
+
Name: GRAD_FOR_SCALAR
|
|
292
|
+
Function: Whether to get gradient for scalar. When enable, the function's scalar input can be derived.
|
|
293
|
+
Because the back-end does not support scaling operations currently, this interface only
|
|
294
|
+
supports simple operations that can be deduced by the front-end.
|
|
295
|
+
Value Range:
|
|
296
|
+
1: Enable
|
|
297
|
+
Default: Disable.
|
|
298
|
+
"""
|
|
299
|
+
GRAD_FOR_SCALAR = ''
|
|
300
|
+
|
|
301
|
+
"""
|
|
302
|
+
Name: DEBUG_LEVEL
|
|
303
|
+
Function: Whether to record more debug information in compiling process. Used for debugging when errors occur.
|
|
304
|
+
Value Range:
|
|
305
|
+
1: Enable
|
|
306
|
+
Default: Disable.
|
|
307
|
+
"""
|
|
308
|
+
DEBUG_LEVEL = ''
|
|
309
|
+
|
|
310
|
+
"""
|
|
311
|
+
Name: PYNATIVE_JIT_GRAD_MODE
|
|
312
|
+
Function: Which method used for grad jit in pynative mode
|
|
313
|
+
Value Range:
|
|
314
|
+
1: Replace ValueNode
|
|
315
|
+
Default: Parametrization
|
|
316
|
+
"""
|
|
317
|
+
PYNATIVE_JIT_GRAD_MODE = ''
|
|
318
|
+
|
|
319
|
+
"""
|
|
320
|
+
Name: PUT_ALL_CNODE_INTO_ORDER_LIST
|
|
321
|
+
Function: Whether to put all CNode into order list in back prop.
|
|
322
|
+
Value Range:
|
|
323
|
+
0: Disable
|
|
324
|
+
Default: Enable.
|
|
325
|
+
"""
|
|
326
|
+
PUT_ALL_CNODE_INTO_ORDER_LIST = ''
|
|
327
|
+
|
|
270
328
|
__all__ = [
|
|
271
329
|
"COMPILE_PROFILE",
|
|
272
330
|
"COMPILE_PROFILE_FINISH_ACTION",
|
|
@@ -296,4 +354,10 @@ __all__ = [
|
|
|
296
354
|
"ENABLE_RECOMPUTE_BEFORE_INLINE",
|
|
297
355
|
"STRICT_CHECK_PARENT_CONTEXT",
|
|
298
356
|
"AUTO_PASSES_OPTIMIZE_PATH",
|
|
357
|
+
"CELL_PARAMETERS_HOOK",
|
|
358
|
+
"CHECK_BPROP",
|
|
359
|
+
"GRAD_FOR_SCALAR",
|
|
360
|
+
"DEBUG_LEVEL",
|
|
361
|
+
"PYNATIVE_JIT_GRAD_MODE",
|
|
362
|
+
"PUT_ALL_CNODE_INTO_ORDER_LIST",
|
|
299
363
|
]
|
|
File without changes
|
|
@@ -0,0 +1,375 @@
|
|
|
1
|
+
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
|
2
|
+
#
|
|
3
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
# ============================================================================
|
|
17
|
+
"""Deprecated Tensor method"""
|
|
18
|
+
|
|
19
|
+
deprecated_tensor_method_map = {
|
|
20
|
+
# 1 to
|
|
21
|
+
|
|
22
|
+
# 2 masked_fill
|
|
23
|
+
|
|
24
|
+
# 3 abs
|
|
25
|
+
|
|
26
|
+
# 4 __abs__
|
|
27
|
+
|
|
28
|
+
# 5 add
|
|
29
|
+
"add": "deprecated_tensor_add",
|
|
30
|
+
"add_": "deprecated_tensor_add_",
|
|
31
|
+
# 6 all
|
|
32
|
+
"all": "tensor_all",
|
|
33
|
+
# 7 allclose
|
|
34
|
+
"allclose": "tensor_allclose",
|
|
35
|
+
# 8 any
|
|
36
|
+
"any": "tensor_any",
|
|
37
|
+
# 9 arctan2
|
|
38
|
+
"arctan2": "tensor_arctan2",
|
|
39
|
+
# 10 argmax
|
|
40
|
+
"argmax": "deprecated_tensor_argmax",
|
|
41
|
+
# 11 argmin
|
|
42
|
+
"argmin": "deprecated_tensor_argmin",
|
|
43
|
+
# 12 argsort
|
|
44
|
+
"argsort": "deprecated_tensor_argsort",
|
|
45
|
+
# 13 atan2
|
|
46
|
+
"atan2": "tensor_atan2",
|
|
47
|
+
# 14 bfloat16
|
|
48
|
+
|
|
49
|
+
# 15 bmm
|
|
50
|
+
|
|
51
|
+
# 16 bool
|
|
52
|
+
|
|
53
|
+
# 17 broadcast_to
|
|
54
|
+
|
|
55
|
+
# 18 byte
|
|
56
|
+
|
|
57
|
+
# 19 ceil
|
|
58
|
+
|
|
59
|
+
# 20 chunk
|
|
60
|
+
"chunk": "deprecated_tensor_chunk",
|
|
61
|
+
# 21 clamp
|
|
62
|
+
|
|
63
|
+
# 22 clip
|
|
64
|
+
|
|
65
|
+
# 23 cos
|
|
66
|
+
|
|
67
|
+
# 24 cumprod
|
|
68
|
+
|
|
69
|
+
# 25 cumsum
|
|
70
|
+
"cumsum": "deprecated_tensor_cumsum",
|
|
71
|
+
# 26 dim
|
|
72
|
+
|
|
73
|
+
# 27 div
|
|
74
|
+
"div": "tensor_div",
|
|
75
|
+
# 28 divide
|
|
76
|
+
|
|
77
|
+
# 29 eq
|
|
78
|
+
|
|
79
|
+
# 30 erf
|
|
80
|
+
|
|
81
|
+
# 31 exp
|
|
82
|
+
|
|
83
|
+
# 32 expand
|
|
84
|
+
|
|
85
|
+
# 33 expand_as
|
|
86
|
+
"expand_as": "deprecated_tensor_expand_as",
|
|
87
|
+
# 34 flatten
|
|
88
|
+
"flatten": "deprecated_tensor_flatten",
|
|
89
|
+
# 35 flip
|
|
90
|
+
|
|
91
|
+
# 36 float
|
|
92
|
+
|
|
93
|
+
# 37 floor
|
|
94
|
+
|
|
95
|
+
# 38 gather
|
|
96
|
+
"gather": "deprecated_tensor_gather",
|
|
97
|
+
# 39 greater
|
|
98
|
+
|
|
99
|
+
# 40 greater_equal
|
|
100
|
+
|
|
101
|
+
# 41 gt
|
|
102
|
+
|
|
103
|
+
# 42 half
|
|
104
|
+
|
|
105
|
+
# 43 index_put
|
|
106
|
+
|
|
107
|
+
# 44 index_select
|
|
108
|
+
"index_select": "deprecated_tensor_index_select",
|
|
109
|
+
# 45 int
|
|
110
|
+
|
|
111
|
+
# 46 inverse
|
|
112
|
+
"inverse": "deprecated_tensor_inverse",
|
|
113
|
+
# 47 is_contiguous
|
|
114
|
+
|
|
115
|
+
# 48 isclose
|
|
116
|
+
"isclose": "deprecated_tensor_isclose",
|
|
117
|
+
# 49 isfinite
|
|
118
|
+
|
|
119
|
+
# 50 isnan
|
|
120
|
+
|
|
121
|
+
# 51 item
|
|
122
|
+
|
|
123
|
+
# 52 le
|
|
124
|
+
|
|
125
|
+
# 53 less
|
|
126
|
+
|
|
127
|
+
# 54 less_equal
|
|
128
|
+
|
|
129
|
+
# 55 log
|
|
130
|
+
|
|
131
|
+
# 56 log2
|
|
132
|
+
"log2": "tensor_log2",
|
|
133
|
+
# 57 logical_and
|
|
134
|
+
"logical_and": "tensor_logical_and",
|
|
135
|
+
# 58 logical_not
|
|
136
|
+
|
|
137
|
+
# 59 logical_or
|
|
138
|
+
"logical_or": "tensor_logical_or",
|
|
139
|
+
|
|
140
|
+
# 60 long
|
|
141
|
+
|
|
142
|
+
# 61 lt
|
|
143
|
+
|
|
144
|
+
# 62 masked_fill
|
|
145
|
+
|
|
146
|
+
# 63 masked_select
|
|
147
|
+
|
|
148
|
+
# 64 matmul
|
|
149
|
+
"matmul": "deprecated_tensor_matmul",
|
|
150
|
+
# 65 max
|
|
151
|
+
"max": "deprecated_tensor_max",
|
|
152
|
+
# 66 maximum
|
|
153
|
+
|
|
154
|
+
# 67 mean
|
|
155
|
+
"mean": "deprecated_tensor_mean",
|
|
156
|
+
# 68 min
|
|
157
|
+
"min": "deprecated_tensor_min",
|
|
158
|
+
# 69 minimum
|
|
159
|
+
|
|
160
|
+
# 70 mul
|
|
161
|
+
|
|
162
|
+
# 71 nan_to_num
|
|
163
|
+
|
|
164
|
+
# 72 narrow
|
|
165
|
+
"narrow": "deprecated_tensor_narrow",
|
|
166
|
+
# 73 ne
|
|
167
|
+
|
|
168
|
+
# 74 neg
|
|
169
|
+
|
|
170
|
+
# 75 negative
|
|
171
|
+
|
|
172
|
+
# 76 nonzero
|
|
173
|
+
|
|
174
|
+
# 77 norm
|
|
175
|
+
|
|
176
|
+
# 78 numel
|
|
177
|
+
|
|
178
|
+
# 79 numpy
|
|
179
|
+
|
|
180
|
+
# 80 outer
|
|
181
|
+
"outer": "deprecated_tensor_outer",
|
|
182
|
+
# 81 permute
|
|
183
|
+
|
|
184
|
+
# 82 pow
|
|
185
|
+
"pow": "deprecated_tensor_pow",
|
|
186
|
+
# 83 prod
|
|
187
|
+
"prod": "deprecated_tensor_prod",
|
|
188
|
+
# 84 reciprocal
|
|
189
|
+
|
|
190
|
+
# 85 remainder
|
|
191
|
+
"remainder": "deprecated_tensor_remainder",
|
|
192
|
+
|
|
193
|
+
# 86 repeat
|
|
194
|
+
|
|
195
|
+
# 87 repeat_interleave
|
|
196
|
+
"repeat_interleave": "deprecated_tensor_repeat_interleave",
|
|
197
|
+
# 88 reshape
|
|
198
|
+
|
|
199
|
+
# 89 round
|
|
200
|
+
|
|
201
|
+
# 90 rsqrt
|
|
202
|
+
|
|
203
|
+
# 91 scatter
|
|
204
|
+
"scatter": "deprecated_tensor_scatter",
|
|
205
|
+
|
|
206
|
+
# 92 scatter_add
|
|
207
|
+
"scatter_add": "deprecated_tensor_scatter_add",
|
|
208
|
+
# 93 select
|
|
209
|
+
"select": "deprecated_tensor_select",
|
|
210
|
+
# 94 sigmoid
|
|
211
|
+
|
|
212
|
+
# 95 sin
|
|
213
|
+
|
|
214
|
+
# 96 size
|
|
215
|
+
|
|
216
|
+
# 97 sort
|
|
217
|
+
"sort": "deprecated_tensor_sort",
|
|
218
|
+
# 98 split
|
|
219
|
+
"split": "deprecated_tensor_split",
|
|
220
|
+
# 99 sqrt
|
|
221
|
+
|
|
222
|
+
# 100 square
|
|
223
|
+
|
|
224
|
+
# 101 squeeze
|
|
225
|
+
|
|
226
|
+
# 102 std
|
|
227
|
+
"std": "deprecated_tensor_std",
|
|
228
|
+
# 103 sub
|
|
229
|
+
"sub": "deprecated_tensor_sub",
|
|
230
|
+
# 104 sum
|
|
231
|
+
"sum": "deprecated_tensor_sum",
|
|
232
|
+
# 105 swapaxes
|
|
233
|
+
|
|
234
|
+
# 106 t
|
|
235
|
+
"t": "deprecated_tensor_t",
|
|
236
|
+
# 107 tanh
|
|
237
|
+
|
|
238
|
+
# 108 tile
|
|
239
|
+
"tile": "deprecated_tensor_tile",
|
|
240
|
+
# 109 tolist
|
|
241
|
+
|
|
242
|
+
# 110 topk
|
|
243
|
+
"topk": "deprecated_tensor_topk",
|
|
244
|
+
# 111 transpose
|
|
245
|
+
"transpose": "deprecated_tensor_transpose",
|
|
246
|
+
# 112 tril
|
|
247
|
+
"tril": "deprecated_tensor_tril",
|
|
248
|
+
# 113 trunc
|
|
249
|
+
|
|
250
|
+
# 114 type
|
|
251
|
+
|
|
252
|
+
# 115 type_as
|
|
253
|
+
"type_as": "deprecated_tensor_type_as",
|
|
254
|
+
# 116 unbind
|
|
255
|
+
"unbind": "deprecated_tensor_unbind",
|
|
256
|
+
# 117 unfold
|
|
257
|
+
|
|
258
|
+
# 118 unique
|
|
259
|
+
"unique": "deprecated_tensor_unique",
|
|
260
|
+
# 119 unsqeeze
|
|
261
|
+
|
|
262
|
+
# 120 view
|
|
263
|
+
|
|
264
|
+
# 121 contiguous
|
|
265
|
+
|
|
266
|
+
# 122 where
|
|
267
|
+
"where": "deprecated_tensor_where",
|
|
268
|
+
# 123 div_
|
|
269
|
+
|
|
270
|
+
# 124 fill_
|
|
271
|
+
|
|
272
|
+
# 125 floor_
|
|
273
|
+
|
|
274
|
+
# 126 masked_fill_
|
|
275
|
+
|
|
276
|
+
# 127 mul_
|
|
277
|
+
|
|
278
|
+
# 128 normal_
|
|
279
|
+
|
|
280
|
+
# 129 requires_grad_
|
|
281
|
+
|
|
282
|
+
# 130 sub_
|
|
283
|
+
"sub_": "deprecated_tensor_sub_",
|
|
284
|
+
# 131 uniform_
|
|
285
|
+
|
|
286
|
+
# 132 absolute
|
|
287
|
+
|
|
288
|
+
# 133 bincount
|
|
289
|
+
"bincount": "tensor_bincount",
|
|
290
|
+
# 134 diff
|
|
291
|
+
|
|
292
|
+
# 135 double
|
|
293
|
+
|
|
294
|
+
# 136 lcm
|
|
295
|
+
|
|
296
|
+
# 137 mm
|
|
297
|
+
"mm": "deprecated_tensor_mm",
|
|
298
|
+
# 138 ravel
|
|
299
|
+
|
|
300
|
+
# 139 nelement
|
|
301
|
+
|
|
302
|
+
# 140 stride
|
|
303
|
+
|
|
304
|
+
# 141 indices
|
|
305
|
+
|
|
306
|
+
# 142 view_as
|
|
307
|
+
"view_as": "deprecated_tensor_view_as",
|
|
308
|
+
# 143 values
|
|
309
|
+
|
|
310
|
+
# 144 index_copy
|
|
311
|
+
|
|
312
|
+
# 145 element_size
|
|
313
|
+
|
|
314
|
+
# 146 gcd
|
|
315
|
+
|
|
316
|
+
# 147 isinf
|
|
317
|
+
|
|
318
|
+
# 148 not_equal
|
|
319
|
+
|
|
320
|
+
# 149 triu
|
|
321
|
+
|
|
322
|
+
# 150 __eq__
|
|
323
|
+
|
|
324
|
+
# 151
|
|
325
|
+
|
|
326
|
+
# 152
|
|
327
|
+
'median': 'deprecated_tensor_median',
|
|
328
|
+
|
|
329
|
+
# 153 acos, arccos; acosh, arccosh, asin, arcsin; asinh, arcsinh; atan, arctan; dot
|
|
330
|
+
"acos": "deprecated_tensor_acos",
|
|
331
|
+
"arccos": "deprecated_tensor_arccos",
|
|
332
|
+
"acosh": "deprecated_tensor_acosh",
|
|
333
|
+
"arccosh": "deprecated_tensor_arccosh",
|
|
334
|
+
"asin": "deprecated_tensor_asin",
|
|
335
|
+
"arcsin": "deprecated_tensor_arcsin",
|
|
336
|
+
"asinh": "deprecated_tensor_asinh",
|
|
337
|
+
"arcsinh": "deprecated_tensor_arcsinh",
|
|
338
|
+
"atan": "deprecated_tensor_atan",
|
|
339
|
+
"arctan": "deprecated_tensor_arctan",
|
|
340
|
+
"dot": "deprecated_tensor_dot",
|
|
341
|
+
|
|
342
|
+
# 153
|
|
343
|
+
|
|
344
|
+
# 154
|
|
345
|
+
|
|
346
|
+
# 155
|
|
347
|
+
|
|
348
|
+
# 156
|
|
349
|
+
|
|
350
|
+
# 157
|
|
351
|
+
|
|
352
|
+
# 158
|
|
353
|
+
"unsqueeze": "deprecated_tensor_unsqueeze",
|
|
354
|
+
# 159 histc
|
|
355
|
+
"histc": "tensor_histc",
|
|
356
|
+
|
|
357
|
+
# 160 frac
|
|
358
|
+
"frac": "tensor_frac",
|
|
359
|
+
|
|
360
|
+
# 161
|
|
361
|
+
"fmod": "deprecated_tensor_fmod",
|
|
362
|
+
# 162 log10
|
|
363
|
+
"log10": "tensor_log10",
|
|
364
|
+
|
|
365
|
+
# 501
|
|
366
|
+
"addbmm": "deprecated_tensor_addbmm",
|
|
367
|
+
# 931
|
|
368
|
+
"nansum": "deprecated_tensor_nansum",
|
|
369
|
+
# 502
|
|
370
|
+
"addmm": "deprecated_tensor_addmm",
|
|
371
|
+
# 790 addmv
|
|
372
|
+
"addmv": "deprecated_tensor_addmv",
|
|
373
|
+
# 1028
|
|
374
|
+
"var": "deprecated_tensor_var",
|
|
375
|
+
}
|
|
@@ -107,7 +107,22 @@ MODULE_FROM_USER_WORKSPACE = 2
|
|
|
107
107
|
# Process expr statement white list
|
|
108
108
|
# Add as needed, eg: "clear", "extend", "insert", "remove", "reverse"
|
|
109
109
|
parse_expr_statement_white_list = (
|
|
110
|
-
"append", "insert", "clear", "reverse", "extend", "update",
|
|
110
|
+
"append", "insert", "clear", "reverse", "extend", "update", "register_hook",
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
# Methods that need to reorder after it's caller is used before
|
|
114
|
+
# e.g. We need to reorder `x.register_hook` after x is used in `out = x + 1` when `register_hook` is called.
|
|
115
|
+
# def construct(x):
|
|
116
|
+
# out = x + 1
|
|
117
|
+
# x.register_hook(hook_fn)
|
|
118
|
+
# return out
|
|
119
|
+
# equals to:
|
|
120
|
+
# def construct(x):
|
|
121
|
+
# x = x.register_hook(hook_fn) # register_hook will return itself when it is called in the graph (in `GRAPH_MODE`).
|
|
122
|
+
# out = x + 1
|
|
123
|
+
# return out
|
|
124
|
+
_need_reorder_methods = (
|
|
125
|
+
"register_hook",
|
|
111
126
|
)
|
|
112
127
|
|
|
113
128
|
_builtin_function_or_method_type = type(abs)
|
|
@@ -666,7 +681,7 @@ def expand_expr_statement(node):
|
|
|
666
681
|
Process the expr statement and expand it.
|
|
667
682
|
|
|
668
683
|
Returns:
|
|
669
|
-
|
|
684
|
+
(False,)/(True, expr.value, target, bool)/(True, expr.value).
|
|
670
685
|
"""
|
|
671
686
|
if isinstance(node, ast.Expr):
|
|
672
687
|
expr_value = node.value
|
|
@@ -679,7 +694,7 @@ def expand_expr_statement(node):
|
|
|
679
694
|
target = func.value
|
|
680
695
|
if method in parse_expr_statement_white_list:
|
|
681
696
|
logger.debug("Expand expr, target:%s, method:%s", target, method)
|
|
682
|
-
return True, expr_value, target
|
|
697
|
+
return True, expr_value, target, method in _need_reorder_methods
|
|
683
698
|
if not isinstance(expr_value, ast.Str):
|
|
684
699
|
return True, expr_value
|
|
685
700
|
return (False,)
|
|
@@ -1075,7 +1090,9 @@ def is_ms_tensor_method(obj):
|
|
|
1075
1090
|
if not hasattr(obj, '__name__') or not hasattr(Tensor, obj.__name__):
|
|
1076
1091
|
return False
|
|
1077
1092
|
fn = inspect.unwrap(obj.__func__ if isinstance(obj, types.MethodType) else obj)
|
|
1078
|
-
|
|
1093
|
+
tensor_method = getattr(Tensor, obj.__name__)
|
|
1094
|
+
tensor_method = tensor_method.__func__ if hasattr(tensor_method, "__func__") else tensor_method
|
|
1095
|
+
return fn == tensor_method
|
|
1079
1096
|
|
|
1080
1097
|
|
|
1081
1098
|
def can_constant_fold(obj):
|
|
@@ -1224,7 +1241,8 @@ class Parser:
|
|
|
1224
1241
|
self.col_offset = \
|
|
1225
1242
|
len(original_src.split('\n')[0]) - len(src.split('\n')[0])
|
|
1226
1243
|
logger.debug("Get source: %s", src)
|
|
1227
|
-
self.
|
|
1244
|
+
if not hasattr(self.fn, attr):
|
|
1245
|
+
self.check_lambda(src)
|
|
1228
1246
|
try:
|
|
1229
1247
|
ast_tokens = asttokens.ASTTokens(src, parse=True)
|
|
1230
1248
|
except IndentationError as idt_err:
|