mindspore 2.7.0rc1__cp310-cp310-win_amd64.whl → 2.7.1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +5 -2
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +2 -2
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
- mindspore/_extends/parse/parser.py +28 -22
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +23 -2
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
- mindspore/amp.py +0 -18
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/__init__.py +18 -12
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +371 -96
- mindspore/common/_utils.py +7 -43
- mindspore/common/api.py +434 -135
- mindspore/common/dtype.py +98 -57
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/hook_handle.py +82 -3
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +17 -127
- mindspore/common/recompute.py +4 -13
- mindspore/common/tensor.py +50 -217
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +20 -106
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +35 -1
- mindspore/dataset/engine/datasets.py +338 -319
- mindspore/dataset/engine/datasets_user_defined.py +38 -22
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/api/cell.h +28 -4
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +0 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +5 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +6 -1
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +4 -3
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/__init__.py +4 -0
- mindspore/mint/distributed/distributed.py +392 -69
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/_functions.py +1 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +10 -10
- mindspore/mint/nn/layer/normalization.py +11 -16
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +231 -239
- mindspore/nn/layer/activation.py +4 -2
- mindspore/nn/layer/basic.py +56 -14
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/image.py +1 -1
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +32 -127
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +1 -4
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +2 -4
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +39 -5
- mindspore/nn/wrap/grad_reducer.py +4 -89
- mindspore/numpy/array_creations.py +4 -4
- mindspore/numpy/fft.py +9 -9
- mindspore/numpy/utils_const.py +1 -1
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +1 -5
- mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
- mindspore/ops/auto_generate/gen_extend_func.py +6 -11
- mindspore/ops/auto_generate/gen_ops_def.py +385 -154
- mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +16 -2
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +2 -0
- mindspore/ops/function/array_func.py +24 -18
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +7 -6
- mindspore/ops/function/grad/grad_func.py +4 -12
- mindspore/ops/function/math_func.py +89 -86
- mindspore/ops/function/nn_func.py +92 -313
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +4 -1
- mindspore/ops/functional_overload.py +377 -30
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +12 -50
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +5 -50
- mindspore/ops/operations/comm_ops.py +95 -17
- mindspore/ops/operations/custom_ops.py +237 -22
- mindspore/ops/operations/debug_ops.py +33 -35
- mindspore/ops/operations/manually_defined/ops_def.py +39 -318
- mindspore/ops/operations/math_ops.py +5 -5
- mindspore/ops/operations/nn_ops.py +3 -3
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +4 -27
- mindspore/ops/tensor_method.py +88 -10
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_auto_parallel_context.py +5 -15
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +4 -6
- mindspore/parallel/_ps_context.py +2 -2
- mindspore/parallel/_utils.py +34 -17
- mindspore/parallel/auto_parallel.py +23 -9
- mindspore/parallel/checkpoint_transform.py +20 -2
- mindspore/parallel/cluster/process_entity/_api.py +28 -33
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/parallel/cluster/run.py +5 -3
- mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/function/reshard_func.py +6 -5
- mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
- mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
- mindspore/parallel/shard.py +7 -21
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +127 -20
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +40 -4
- mindspore/profiler/common/path_manager.py +65 -24
- mindspore/profiler/common/profiler_context.py +27 -14
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +10 -6
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/dynamic_profiler.py +91 -46
- mindspore/profiler/envprofiler.py +30 -5
- mindspore/profiler/experimental_config.py +18 -2
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +34 -7
- mindspore/profiler/profiler.py +193 -145
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +9 -6
- mindspore/runtime/executor.py +35 -0
- mindspore/runtime/memory.py +113 -0
- mindspore/runtime/thread_bind_core.py +1 -1
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +8 -21
- mindspore/train/amp.py +6 -7
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +1 -17
- mindspore/train/callback/_flops_collector.py +10 -6
- mindspore/train/callback/_train_fault_tolerance.py +72 -25
- mindspore/train/data_sink.py +5 -9
- mindspore/train/dataset_helper.py +5 -5
- mindspore/train/model.py +41 -230
- mindspore/train/serialization.py +160 -401
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +152 -16
- mindspore/version.py +1 -1
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/__init__.py +0 -23
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/train/memory_profiling_pb2.py +0 -298
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
|
@@ -14,15 +14,28 @@
|
|
|
14
14
|
"""
|
|
15
15
|
Interpolation Mode, Resampling Filters
|
|
16
16
|
"""
|
|
17
|
+
import gc
|
|
18
|
+
import importlib
|
|
19
|
+
import math
|
|
20
|
+
import numbers
|
|
21
|
+
import os
|
|
22
|
+
import re
|
|
17
23
|
from enum import Enum, IntEnum
|
|
18
24
|
from fractions import Fraction
|
|
19
|
-
import numbers
|
|
20
25
|
|
|
21
26
|
import numpy as np
|
|
22
27
|
from PIL import Image
|
|
23
28
|
|
|
24
29
|
import mindspore
|
|
25
30
|
import mindspore._c_dataengine as cde
|
|
31
|
+
from mindspore import log as logger
|
|
32
|
+
from mindspore.dataset.core.validator_helpers import check_file, check_value, type_check, type_check_list
|
|
33
|
+
from ..core.config import get_video_backend
|
|
34
|
+
|
|
35
|
+
_CALLED_TIMES = 0
|
|
36
|
+
_GC_COLLECTION_INTERVAL = 10
|
|
37
|
+
_INITIALIZED = False
|
|
38
|
+
_INITIALIZED_PID = False
|
|
26
39
|
|
|
27
40
|
# The following constants have been deprecated by Pillow since version 9.1.0
|
|
28
41
|
if int(Image.__version__.split(".")[0]) > 9 or Image.__version__ >= "9.1.0":
|
|
@@ -627,19 +640,371 @@ def read_image(filename, mode=ImageReadMode.UNCHANGED):
|
|
|
627
640
|
return cde.read_image(filename, ImageReadMode.to_c_type(mode)).as_array()
|
|
628
641
|
|
|
629
642
|
|
|
643
|
+
class DecodeParams:
    """
    Plain holder that bundles the inputs of one decode call.

    Args:
        container: The opened media container to read packets from.
        start_offset: First offset of the requested range.
        end_offset: Last offset of the requested range.
        stream: The stream inside `container` that is decoded.
    """

    def __init__(self, container, start_offset, end_offset, stream):
        # Nothing is validated or copied here; the values are stored as given.
        self.container, self.stream = container, stream
        self.start_offset, self.end_offset = start_offset, end_offset
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
class VideoFrameDvpp:
    """
    Record describing one video packet handled by the DVPP decode path.

    The constructor only stores `dts` and `pts`. The `positions` and `frame`
    attributes are declared for type checkers and must be assigned by the
    caller before they are read.

    Args:
        dts: Decoding timestamp of the packet. Default: ``0``.
        pts: Presentation timestamp of the packet. Default: ``0``.
    """

    # Timestamps taken from the packet.
    dts: int
    pts: int
    # Index of this record once all records are ordered by pts.
    positions: int
    # Payload associated with the packet.
    frame: np.ndarray

    def __init__(self, dts=0, pts=0):
        self.dts, self.pts = dts, pts
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
def _get_frame_by_cv(filename, container, stream):
    """
    Collect one `VideoFrameDvpp` record per timestamped packet, reading payloads with OpenCV.

    Args:
        filename: Path of the video file opened with `cv2.VideoCapture`.
        container: Demuxable container; `container.demux(stream)` supplies the packets.
        stream: The stream of `container` whose packets are walked.

    Returns:
        tuple, `(frames, pts_per_frame)`. `frames` maps each packet pts to its
        `VideoFrameDvpp` (with `frame` and `positions` filled in), and
        `pts_per_frame` is the rounded pts distance between two frames as
        reported by OpenCV.

    Raises:
        ImportError: If the `cv2` module is not installed.
    """

    try:
        cv2 = importlib.import_module("cv2")
    except ModuleNotFoundError as err:
        # Keep the original failure attached so the missing module name is visible.
        raise ImportError(
            "Importing cv2 failed, try to install it by running `pip install opencv-python`.") from err

    cap = cv2.VideoCapture(filename)
    frames = {}
    pts_list = []
    try:
        # NOTE(review): per the OpenCV docs, CAP_PROP_FORMAT = -1 requests undecoded raw
        # stream data from `read()` - confirm this holds for the backend in use.
        cap.set(cv2.CAP_PROP_FORMAT, -1)
        pts_per_frame = round(1 / cap.get(cv2.CAP_PROP_POS_AVI_RATIO) / cap.get(cv2.CAP_PROP_FPS), 0)

        # One `cap.read()` per timestamped packet keeps OpenCV in step with the demuxer.
        for packet in container.demux(stream):
            if packet.pts is not None:
                frame = VideoFrameDvpp(packet.dts, packet.pts)
                _, frame.frame = cap.read()
                pts_list.append(packet.pts)
                frames[frame.pts] = frame
    finally:
        # Release the capture even when probing, demuxing or reading fails.
        cap.release()

    # Rank every record by presentation order.
    pts_list.sort()
    position_list = {value: index for index, value in enumerate(pts_list)}
    for frame in frames.values():
        frame.positions = position_list[frame.pts]

    return frames, pts_per_frame
|
|
693
|
+
|
|
694
|
+
|
|
695
|
+
def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
|
|
696
|
+
""" Align audio frames with specified start and end. """
|
|
697
|
+
|
|
698
|
+
start, end = audio_frames[0].pts, audio_frames[-1].pts
|
|
699
|
+
total_aframes = aframes.shape[1]
|
|
700
|
+
step_per_aframe = (end - start + 1) / total_aframes
|
|
701
|
+
s_idx = 0
|
|
702
|
+
e_idx = total_aframes
|
|
703
|
+
if start < ref_start:
|
|
704
|
+
s_idx = int((ref_start - start) / step_per_aframe)
|
|
705
|
+
if end > ref_end:
|
|
706
|
+
e_idx = int((ref_end - end) / step_per_aframe)
|
|
707
|
+
return aframes[:, s_idx:e_idx]
|
|
708
|
+
|
|
709
|
+
|
|
710
|
+
def _decode_video_dvpp(decode_params, frames, pts_per_frame):
    """
    Feed the packets of a video stream to the DVPP decoder and return the decoded frames.

    Args:
        decode_params: Object exposing `container`, `start_offset`, `end_offset` and `stream`.
        frames: Mapping from packet pts to a record with `frame` and `positions` attributes.
        pts_per_frame: Pts distance between two consecutive frames.

    Returns:
        An empty uint8 NumPy array when the requested range contains no frame,
        otherwise the result of calling `.numpy()` on the device buffer that
        holds the decoded frames.

    Raises:
        ValueError: If `stream.name` is neither 'h264' nor 'hevc'.
    """

    container = decode_params.container
    start_offset = decode_params.start_offset
    end_offset = decode_params.end_offset
    stream = decode_params.stream

    codecs_type = stream.name
    # Codec identifiers handed to `cde.decode_video_create_chn`.
    # NOTE(review): presumably the DVPP payload-type codes for H.264 / H.265 - confirm.
    hi_pt_h264 = 96
    hi_pt_h265 = 265
    if codecs_type == "h264":
        codec_id = hi_pt_h264
    elif codecs_type == "hevc":
        codec_id = hi_pt_h265
    else:
        raise ValueError(f"The video codecs_type should be either 'h264' or 'hevc', got {codecs_type}.")

    # Round start_offset down to a frame boundary, so a start that falls between
    # two frames also includes the earlier frame.
    start_offset = int(start_offset / pts_per_frame) * pts_per_frame

    frame_width = stream.width
    frame_height = stream.height

    # Work on a copy so the caller-provided end_offset stays available below.
    end_offset_real = end_offset
    # When end_offset sits exactly on a frame boundary, push it one pts further
    # so that frame is counted as well.
    if end_offset_real % pts_per_frame == 0:
        end_offset_real += 1
    # Never reach past the number of collected frames.
    end_offset_real = min(end_offset_real, len(frames) * pts_per_frame)
    start_frame = math.ceil(start_offset / pts_per_frame)
    total_frame = math.ceil((end_offset_real - start_offset) / pts_per_frame)

    # Empty or inverted range: nothing to decode, and no channel is created.
    if end_offset_real < start_offset or total_frame == 0:
        return np.empty(0, dtype=np.uint8)
    # Output buffer with one [3, height, width] slot per requested frame.
    ret_tensor = cde.DeviceBuffer([total_frame, 3, frame_height, frame_width])

    # Create the decode channel and announce how many frames will be fetched.
    chn = cde.decode_video_create_chn(codec_id)
    cde.decode_video_start_get_frame(chn, total_frame)

    for packet in container.demux(stream):
        # Packets without a pts have no entry in `frames` and are skipped.
        if packet.pts is not None:
            frame = frames[packet.pts].frame
            input_tensor = cde.DeviceBuffer.from_numpy(frame)

            # NOTE(review): this window test uses the caller's end_offset, not
            # end_offset_real computed above - confirm that is intended.
            if start_offset <= int(packet.pts) <= end_offset:
                display = True
                # Slot index is the frame's rank relative to the first requested frame.
                output_tensor = ret_tensor[frames[packet.pts].positions - start_frame]
            else:
                # Outside the window: the packet is still sent, with an empty output buffer.
                display = False
                output_tensor = cde.DeviceBuffer([])
            # 12:rgb888packed; 13:bgr888packed; 69:rgb888planer; 70:bgr888planer. Packed is HWC, planer is CHW
            # use CHW to avoid memory copy
            cde.decode_video_send_stream(chn, input_tensor, 69, display, output_tensor)

    # ret_tensor is ordered by pts
    ret_tensor_dvpp = cde.decode_video_stop_get_frame(chn, total_frame)

    # An empty result here means ret_tensor was already filled in place;
    # otherwise the returned buffer replaces it.
    if ret_tensor_dvpp.size() != 0:
        ret_tensor = ret_tensor_dvpp

    # Convert to NumPy before the channel is destroyed.
    ret_numpy = ret_tensor.numpy()

    cde.decode_video_destroy_chnl(chn)
    return ret_numpy
|
|
777
|
+
|
|
778
|
+
|
|
779
|
+
def _check_buffer(extradata):
|
|
780
|
+
""" Check if the video should be buffered. """
|
|
781
|
+
|
|
782
|
+
should_buffer = True
|
|
783
|
+
if extradata and b"DivX" in extradata:
|
|
784
|
+
# can't use regex directly because of some weird characters sometimes...
|
|
785
|
+
pos = extradata.find(b"DivX")
|
|
786
|
+
d = extradata[pos:]
|
|
787
|
+
o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
|
|
788
|
+
if o is None:
|
|
789
|
+
o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
|
|
790
|
+
if o is not None:
|
|
791
|
+
should_buffer = o.group(3) == b"p"
|
|
792
|
+
return should_buffer
|
|
793
|
+
|
|
794
|
+
|
|
795
|
+
def _read_from_stream_dvpp(filename, container, start_offset, end_offset, pts_unit, stream, stream_name):
    """
    Decode the requested range of a video stream through the DVPP path.

    Args:
        filename: Path of the video file; passed to the packet reader and used in log output.
        container: Opened container that is seeked and demuxed.
        start_offset: Start of the range; seconds when `pts_unit` is "sec".
        end_offset: End of the range; seconds when `pts_unit` is "sec".
        pts_unit: "sec" converts both offsets with `stream.time_base`.
        stream: The stream to decode; must be of type "video".
        stream_name: Not used by this function.

    Returns:
        The value returned by `_decode_video_dvpp` for the range.

    Raises:
        RuntimeError: If `stream.type` is not "video".
    """

    if stream.type != "video":
        raise RuntimeError("_read_from_stream_dvpp only handle video type")

    if pts_unit == "sec" and stream.time_base != 0:
        # Seconds -> pts: multiply by the reciprocal of the stream time base.
        ticks_per_second = 1 / stream.time_base
        start_offset = int(math.floor(start_offset * ticks_per_second))
        if end_offset != float("inf") and stream.time_base != 0:
            end_offset = int(math.ceil(end_offset * ticks_per_second))
    else:
        logger.warning("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")

    # DivX-style packed B-frames can carry out-of-order pts (2 frames in one
    # packet), so the seek target is moved back further when buffering is advised.
    extra_margin = 5
    needs_margin = _check_buffer(stream.codec_context.extradata)

    # Some files do not seek to the exact location, hence the one-pts step back.
    seek_offset = max(start_offset - 1, 0)
    if needs_margin:
        seek_offset = max(seek_offset - extra_margin, 0)

    # The packet records are collected before the container is seeked.
    packet_frames, pts_per_frame = _get_frame_by_cv(filename, container, stream)
    container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
    decoded = _decode_video_dvpp(
        DecodeParams(container, start_offset, end_offset, stream),
        packet_frames,
        pts_per_frame,
    )

    if decoded is None:
        logger.warning(f"_decode_video_dvpp failed: {filename}")
    return decoded
|
|
828
|
+
|
|
829
|
+
|
|
830
|
+
def _read_from_stream_ffmpeg(container, start_offset, end_offset, pts_unit, stream, stream_name):
    """
    Decode the frames of one stream that fall inside the requested range.

    Args:
        container: Opened container that is seeked and decoded.
        start_offset: Start of the range; seconds when `pts_unit` is "sec".
        end_offset: End of the range; seconds when `pts_unit` is "sec".
        pts_unit: "sec" converts both offsets with `stream.time_base`;
            any other value logs a warning and uses the offsets as given.
        stream: The stream being read.
        stream_name: Keyword arguments forwarded to `container.decode`.

    Returns:
        list, decoded frames ordered by pts. When no frame sits exactly on a
        positive `start_offset`, the closest earlier frame is put in front.
    """

    global _CALLED_TIMES
    _CALLED_TIMES += 1
    # Run a garbage collection once per interval of calls.
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()

    if pts_unit == "sec":
        # Seconds -> pts: multiply by the reciprocal of the stream time base.
        ticks_per_second = 1 / stream.time_base
        start_offset = int(math.floor(start_offset * ticks_per_second))
        if end_offset != float("inf"):
            end_offset = int(math.ceil(end_offset * ticks_per_second))
    else:
        logger.warning("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")

    extra_margin = 5
    # DivX-style packed B-frames can carry out-of-order pts (2 frames in one
    # packet); only video streams are inspected, everything else is buffered.
    if stream.type == "video":
        needs_margin = _check_buffer(stream.codec_context.extradata)
    else:
        needs_margin = True

    # Some files do not seek to the exact location, hence the one-pts step back.
    seek_offset = max(start_offset - 1, 0)
    if needs_margin:
        seek_offset = max(seek_offset - extra_margin, 0)
    container.seek(seek_offset, any_frame=False, backward=True, stream=stream)

    by_pts = {}
    overshoot = 0
    for decoded in container.decode(**stream_name):
        by_pts[decoded.pts] = decoded
        if decoded.pts < end_offset:
            continue
        # Past the end: keep reading a few more frames when buffering, then stop.
        if not needs_margin or overshoot >= extra_margin:
            break
        overshoot += 1

    # Emit the frames in pts order, restricted to the requested window.
    result = []
    for key in sorted(by_pts):
        candidate = by_pts[key]
        if start_offset <= candidate.pts <= end_offset:
            result.append(candidate)

    if by_pts and start_offset > 0 and start_offset not in by_pts:
        # No frame matches start_offset exactly: prepend the last frame before
        # it so no data is missing at the front. This is most useful for audio.
        earlier = [key for key in by_pts if key < start_offset]
        if earlier:
            result.insert(0, by_pts[max(earlier)])
    return result
|
|
881
|
+
|
|
882
|
+
|
|
883
|
+
def _dvpp_init():
    """
    Initialise the DVPP system resources once per process.

    The first call runs `cde.dvpp_sys_init()` and records the calling process
    id. Later calls from the same process do nothing.

    Raises:
        RuntimeError: If initialisation was recorded under a different process
            id than the current one.
    """

    global _INITIALIZED, _INITIALIZED_PID

    if not _INITIALIZED:
        cde.dvpp_sys_init()
        _INITIALIZED, _INITIALIZED_PID = True, os.getpid()
        return

    # Already initialised: the recorded pid must belong to this process.
    if _INITIALIZED_PID != os.getpid():
        raise RuntimeError("Cannot re-initialize Ascend in forked process. To use Ascend with multiprocessing, "
                           "you must use the 'spawn' start method "
                           "via 'mindspore.dataset.config.set_multiprocessing_start_method('spawn')'.")
|
|
896
|
+
|
|
897
|
+
|
|
898
|
+
def _read_video_dvpp(filename, start_pts=0, end_pts=None, pts_unit="pts", output_format="THWC"):
    """
    Read the video, audio and metadata of a file, decoding the video stream with DVPP.

    Only "hevc" and "h264" video streams are decoded with DVPP. Any other codec logs a warning
    and the whole request is delegated to `_read_video_ffmpeg`. The audio stream, if any, is
    always read through `_read_from_stream_ffmpeg`.

    Args:
        filename (str): Path of the video file, opened with `cde.pyav_open`.
        start_pts (Union[float, Fraction, int], optional): Start presentation timestamp. Default: ``0``.
        end_pts (Union[float, Fraction, int], optional): End presentation timestamp.
            Default: ``None``, treated as ``float("inf")`` (read until the end).
        pts_unit (str, optional): Unit of the timestamps, "pts" or "sec". Default: ``"pts"``.
        output_format (str, optional): Layout of the returned video frames, "THWC" or "TCHW"
            (case-insensitive). Default: ``"THWC"``.

    Returns:
        tuple, `(vframes, aframes, info)`. `vframes` is the decoded video (an empty uint8 array when
        the file has no video stream), `aframes` is the aligned audio (an empty float32 array of shape
        (1, 0) when no audio frame was read), `info` is a dict that may contain "video_fps" and
        "audio_fps".

    Raises:
        ValueError: If `output_format` is neither "THWC" nor "TCHW".
    """
    _dvpp_init()

    output_format = output_format.upper()
    if output_format not in ("THWC", "TCHW"):
        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

    # The signature advertises None as "no upper bound"; map it to the sentinel that the
    # comparisons below already recognise, instead of failing later on None arithmetic.
    if end_pts is None:
        end_pts = float("inf")

    info = {}
    audio_frames = []
    audio_timebase = Fraction(0, 1)

    with cde.pyav_open(filename) as container:
        if container.streams.audio:
            audio_timebase = container.streams.audio[0].time_base

        if container.streams.video:
            video_stream = container.streams.video[0]
            if video_stream.name not in ("hevc", "h264"):
                logger.warning(f"This video in {filename} is coding by {video_stream.name}, "
                               "not supported on DVPP backend and will fall back to run on the FFMPEG. "
                               "This may have performance implications.")
                # NOTE(review): the fallback result is returned as-is, so `output_format` is not
                # applied on this path - confirm whether "TCHW" callers need a transpose here.
                return _read_video_ffmpeg(filename, start_pts, end_pts, pts_unit)

            vframes = _read_from_stream_dvpp(
                filename,
                container,
                start_pts,
                end_pts,
                pts_unit,
                video_stream,
                {"video": 0})

            video_fps = video_stream.average_rate
            # guard against potentially corrupted files
            if video_fps is not None:
                info["video_fps"] = float(video_fps)
        else:
            vframes = np.empty(0, dtype=np.uint8)

        if container.streams.audio:
            audio_frames = _read_from_stream_ffmpeg(
                container,
                start_pts,
                end_pts,
                pts_unit,
                container.streams.audio[0],
                {"audio": 0},
            )
            info["audio_fps"] = container.streams.audio[0].rate

    aframes_list = [np.vstack(frame.to_ndarray()) for frame in audio_frames]

    if aframes_list:
        aframes = np.concatenate(aframes_list, 1)
        # Convert second-based bounds to audio time-base ticks before aligning; a zero time base
        # (the initial placeholder) would divide by zero, so it is skipped.
        if pts_unit == "sec" and audio_timebase != 0:
            ticks_per_second = 1 / audio_timebase
            start_pts = int(math.floor(start_pts * ticks_per_second))
            if end_pts != float("inf"):
                end_pts = int(math.ceil(end_pts * ticks_per_second))
        aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
    else:
        aframes = np.empty((1, 0), dtype=np.float32)

    if output_format == "THWC" and vframes is not None and vframes.size != 0:
        # [T,C,H,W] --> [T,H,W,C]
        vframes = vframes.transpose(0, 2, 3, 1)

    return vframes, aframes, info
|
|
969
|
+
|
|
970
|
+
|
|
971
|
+
def _read_video_ffmpeg(filename, start_pts=0, end_pts=None, pts_unit="pts"):
    """
    Read the video, audio and metadata of a file through the FFMPEG based `cde.read_video`.

    Args:
        filename (str): Path of the video file.
        start_pts (Union[float, Fraction, int], optional): Start presentation timestamp, converted
            with `float()`. Default: ``0``.
        end_pts (Union[float, Fraction, int], optional): End presentation timestamp, converted with
            `float()`. NOTE(review): the default ``None`` cannot be converted by `float()`, so
            callers are expected to pass a number - confirm against `read_video`.
        pts_unit (str, optional): Unit of the timestamps. Default: ``"pts"``.

    Returns:
        tuple, `(video, audio, metadata)`. `video` and `audio` are the results of `as_array()` on the
        native outputs (or None when the native output is None); `metadata` is a dict in which
        "video_fps" is a float and "audio_fps" is an int.
    """
    video, audio, raw_info = cde.read_video(filename, float(start_pts), float(end_pts), pts_unit)

    video = None if video is None else video.as_array()
    audio = None if audio is None else audio.as_array()

    # Entries listed here are normalised to a Python type; every other entry is copied unchanged.
    casts = {"video_fps": float, "audio_fps": int}
    info = {}
    for name in raw_info:
        value = raw_info[name]
        cast = casts.get(name)
        info[name] = value if cast is None else cast(value)
    return video, audio, info
|
|
990
|
+
|
|
991
|
+
|
|
630
992
|
def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
|
|
631
993
|
"""
|
|
632
994
|
Read the video, audio, metadata from a video file.
|
|
633
995
|
|
|
634
|
-
It supports AVI, H264, H265, MOV, MP4, WMV file formats.
|
|
996
|
+
It supports AVI, H264, H265, MOV, MP4, WMV file formats on CPU, and H264, H265 file formats on Ascend.
|
|
997
|
+
|
|
998
|
+
Note:
|
|
999
|
+
This method is executed on CPU by default, but it is also supported to be executed on Ascend by
|
|
1000
|
+
setting video backend with `mindspore.dataset.config.set_video_backend("Ascend")` .
|
|
635
1001
|
|
|
636
1002
|
Args:
|
|
637
1003
|
filename(str): The path to the video file to be read.
|
|
638
1004
|
start_pts(Union[float, Fraction, int], optional): The start presentation timestamp of the video.
|
|
639
|
-
Default: ``0
|
|
1005
|
+
Default: ``0``, read from the beginning.
|
|
640
1006
|
end_pts(Union[float, Fraction, int], optional): The end presentation timestamp of the video.
|
|
641
|
-
Default: ``None
|
|
642
|
-
The None is represented by 2147483647.
|
|
1007
|
+
Default: ``None``, read until the end.
|
|
643
1008
|
pts_unit(str, optional): The unit of the timestamps. It can be any of ["pts", "sec"]. Default: ``"pts"``.
|
|
644
1009
|
|
|
645
1010
|
Returns:
|
|
@@ -661,7 +1026,7 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
|
|
|
661
1026
|
ValueError: If `pts_unit` is not in ["pts", "sec"].
|
|
662
1027
|
|
|
663
1028
|
Supported Platforms:
|
|
664
|
-
``CPU``
|
|
1029
|
+
``CPU`` ``Ascend``
|
|
665
1030
|
|
|
666
1031
|
Examples:
|
|
667
1032
|
>>> import mindspore.dataset.vision as vision
|
|
@@ -689,22 +1054,268 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
|
|
|
689
1054
|
if pts_unit not in ["pts", "sec"]:
|
|
690
1055
|
raise ValueError("Not supported pts_unit for " + pts_unit)
|
|
691
1056
|
|
|
692
|
-
|
|
1057
|
+
filepath = os.path.realpath(filename)
|
|
693
1058
|
|
|
694
|
-
if
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
1059
|
+
if not os.path.exists(filepath):
|
|
1060
|
+
raise ValueError("Invalid file path, " + filename + " does not exist.")
|
|
1061
|
+
|
|
1062
|
+
if not os.path.isfile(filepath):
|
|
1063
|
+
raise ValueError("Invalid file path, " + filename + " is not a regular file.")
|
|
1064
|
+
|
|
1065
|
+
if get_video_backend() == "Ascend":
|
|
1066
|
+
return _read_video_dvpp(filename, start_pts, end_pts, pts_unit)
|
|
1067
|
+
return _read_video_ffmpeg(filename, start_pts, end_pts, pts_unit)
|
|
1068
|
+
|
|
1069
|
+
|
|
1070
|
+
class VideoDecoder:
    """
    A decoder for single video streams, capable of parsing metadata and extracting frames
    from H264/H265-encoded content.

    Args:
        source(str): The path to the video file.

    Raises:
        TypeError: If `source` is not string.
        ValueError: If `source` does not exist or permission denied.

    Examples:
        >>> import mindspore.dataset as ds
        >>> import mindspore.dataset.vision as vision
        >>>
        >>> ds.config.set_video_backend("Ascend")
        >>> reader = vision.VideoDecoder(source="/path/to/filename")
    """
    def __init__(self, source):
        check_file(source)
        self.source = source
        # Snapshot taken once at construction; get_frames_at validates indices against it.
        self._metadata = self.metadata

    def get_frames_at(self, indices):
        """
        Retrieve the frames at the specified indices.

        Args:
            indices (list[int]): List of frame indices to acquire.

        Returns:
            numpy.ndarray, four dimensions uint8 data for video. The format is [T, H, W, C].
            `T` is the number of frames, `H` is the height, `W` is the width, `C` is the channel for RGB.

        Raises:
            TypeError: If `indices` is not of type list.
            TypeError: If `indices` value is not of type int.
            ValueError: If `indices` value is not in range [0, total frames).
            RuntimeError: If the video backend is not "Ascend", or the video stream is not
                encoded with H264/H265.

        Supported Platforms:
            ``Ascend``

        Examples:
            >>> import mindspore.dataset as ds
            >>> import mindspore.dataset.vision as vision
            >>>
            >>> ds.config.set_video_backend("Ascend")
            >>> reader = vision.VideoDecoder(source="/path/to/filename")
            >>> output_frames = reader.get_frames_at([0, 1, 2, 3])
        """
        if get_video_backend() != "Ascend":
            raise RuntimeError("Method get_frames_at is only supported on Ascend platform.")
        type_check(indices, (list,), "indices")
        type_check_list(indices, (int,), "indices")

        _dvpp_init()

        if not indices:
            return np.empty(0, dtype=np.uint8)
        for i, frame_index in enumerate(indices):
            check_value(frame_index, [0, self._metadata["num_frames"]], "Invalid frame index[{0}]={1}".format(
                i, frame_index), right_open_interval=True)
        filepath = os.path.realpath(self.source)

        with cde.pyav_open(filepath) as container:
            if container.streams.video:
                video_stream = container.streams.video[0]
                if video_stream.name not in ("hevc", "h264"):
                    raise RuntimeError(f"This video in {filepath} is coding by {video_stream.name}, "
                                       "not supported on DVPP backend.")
                # Scan the whole stream (offset 0 to infinity); the requested indices select
                # which frames are actually returned.
                vframes = self._read_from_stream_dvpp_frames(
                    filepath,
                    container,
                    0,
                    float("inf"),
                    video_stream,
                    {"video": 0},
                    indices,
                )
            else:
                vframes = np.empty(0, dtype=np.uint8)

        if vframes is not None and vframes.size != 0:
            # [T,C,H,W] --> [T,H,W,C]
            vframes = vframes.transpose(0, 2, 3, 1)

        return vframes

    @property
    def metadata(self):
        """
        Getting metadata of the video stream.

        The file is re-opened on every access and the first video stream is inspected.

        Returns:
            dict, information about the metadata, with keys "width", "height", "duration_seconds",
            "num_frames" and "average_fps". "duration_seconds" is None when the stream does not
            report a duration.

        Examples:
            >>> import mindspore.dataset as ds
            >>> import mindspore.dataset.vision as vision
            >>>
            >>> ds.config.set_video_backend("Ascend")
            >>> reader = vision.VideoDecoder(source="/path/to/filename")
            >>> metadata = reader.metadata
        """
        metadata = {}
        filepath = os.path.realpath(self.source)
        with cde.pyav_open(filepath) as container:
            stream = container.streams.video[0]
            duration = stream.duration
            metadata["width"] = stream.width
            metadata["height"] = stream.height
            # The stream duration may be unknown (None), e.g. when only the container carries it;
            # report None instead of failing on None arithmetic.
            metadata["duration_seconds"] = \
                None if duration is None else round(float(duration * stream.time_base), 6)
            # NOTE(review): presumably 0 when the container does not store a frame count - confirm.
            metadata["num_frames"] = stream.frames
            metadata["average_fps"] = float(stream.average_rate)
        return metadata

    def _read_from_stream_dvpp_frames(self, filename, container, start_offset, end_offset,
                                      stream, stream_name, indices):
        """
        Read the selected frames of a video stream with DVPP.

        Args:
            filename (str): Path of the video file, used for packet extraction and log messages.
            container: The opened container that owns `stream`.
            start_offset: Start of the decoded range, in pts.
            end_offset: End of the decoded range, in pts.
            stream: The video stream to decode.
            stream_name (dict): Stream selector; not used by this method.
            indices (list[int]): Frame indices to return.

        Returns:
            The result of `_decode_video_dvpp_frames`.

        Raises:
            RuntimeError: If `stream` is not a video stream.
        """
        if stream.type != "video":
            raise RuntimeError("_read_from_stream_dvpp_frames only handle video type")
        max_buffer_size = 5
        # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
        # so need to buffer some extra frames to sort everything properly
        should_buffer = _check_buffer(stream.codec_context.extradata)

        # some files don't seek to the right location, so better be safe here
        seek_offset = max(start_offset - 1, 0)
        if should_buffer:
            seek_offset = max(seek_offset - max_buffer_size, 0)

        # init frames before seek
        frames, pts_per_frame = _get_frame_by_cv(filename, container, stream)
        container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
        decode_params = DecodeParams(container, start_offset, end_offset, stream)
        frames = self._decode_video_dvpp_frames(decode_params, frames, pts_per_frame, indices)

        if frames is None:
            logger.warning(f"_decode_video_dvpp failed: {filename}")

        return frames

    def _get_key_frames_and_first_pts(self, container, stream, pts_per_frame):
        """
        Get key frames and first pts.

        Returns:
            tuple, `(key_list, first_pts)`. `key_list` holds the frame index of every key-frame packet
            that has a pts; `first_pts` is the pts remainder modulo `pts_per_frame` of the last such
            packet, or None when there is none.
        """
        key_list = []
        first_pts = None
        for packet in container.demux(stream):
            if packet.pts is None or not packet.is_keyframe:
                continue
            key_list.append(int(packet.pts * stream.time_base * stream.average_rate))
            first_pts = packet.pts % pts_per_frame
        return key_list, first_pts

    @staticmethod
    def _group_frames_by_keyframe(key_list, target_frame_list):
        """
        Group target frame indices by the closest key frame at or before each of them.

        Args:
            key_list (list[int]): Frame indices of the key frames.
            target_frame_list (list[int]): Frame indices to decode.

        Returns:
            dict, maps a key frame index to the list of target frames that are decoded starting from it,
            in the iteration order of `target_frame_list`.

        Raises:
            ValueError: If a target frame has no key frame at or before it.
        """
        groups = {}
        for frame in target_frame_list:
            keyframe = max((k for k in key_list if k <= frame), default=None)
            if keyframe is None:
                raise ValueError(f"No key frame found at or before frame index {frame}, "
                                 "the frame can not be decoded.")
            groups.setdefault(keyframe, []).append(frame)
        return groups

    def _compute_ret_tensor(self, container, stream, frames_count, frames_in_group, frames, pts_per_frame,
                            ret_tensor, target_frame_positions, start_frame, chn):
        """
        Send the packets of one key-frame group to the DVPP channel.

        Packets are demuxed from the current container position. Every packet with a pts is sent;
        only those whose frame index is in `frames_in_group` are flagged for output and bound to
        their slot of `ret_tensor`. Sending stops once all frames of the group were submitted.
        """
        wanted = set(frames_in_group)
        group_size = len(frames_in_group)
        for packet in container.demux(stream):
            if packet.pts is None:
                continue
            if frames_count == group_size:
                break
            packet_info = frames[packet.pts]
            input_tensor = cde.DeviceBuffer.from_numpy(packet_info.frame)
            if packet.pts // pts_per_frame in wanted:
                display = True
                frames_count += 1
                output_tensor = ret_tensor[target_frame_positions[packet_info.positions - start_frame][0]]
            else:
                # Frames that are only needed as decoding references get no output buffer.
                display = False
                output_tensor = cde.DeviceBuffer([])
            # NOTE(review): 69 is forwarded verbatim to the native API; its meaning is not
            # visible here - confirm against decode_video_send_stream.
            cde.decode_video_send_stream(chn, input_tensor, 69, display, output_tensor)

    def _decode_video_dvpp_frames(self, decode_params, frames, pts_per_frame, indices):
        """
        Send frames to Ascend and using DVPP to decode.

        Args:
            decode_params (DecodeParams): Carries the container, start/end offsets and the stream.
            frames: Per-pts packet records, indexed by packet pts.
            pts_per_frame: Number of pts ticks per frame.
            indices (list[int]): Requested frame indices; duplicates and arbitrary order are allowed.

        Returns:
            numpy.ndarray, the decoded frames in the order of `indices`, or an empty uint8 array
            when the requested range contains no frame.

        Raises:
            ValueError: If the codec is neither 'h264' nor 'hevc', or a requested frame has no
                key frame at or before it.
        """
        container = decode_params.container
        start_offset = decode_params.start_offset
        end_offset = decode_params.end_offset
        stream = decode_params.stream

        codecs_type = stream.name
        hi_pt_h264 = 96
        hi_pt_h265 = 265
        if codecs_type == "h264":
            codec_id = hi_pt_h264
        elif codecs_type == "hevc":
            codec_id = hi_pt_h265
        else:
            raise ValueError(f"The video codecs_type should be either 'h264' or 'hevc', got {codecs_type}.")

        # if start_offset is between 2 frames, get one more previous frame
        start_offset = int(start_offset / pts_per_frame) * pts_per_frame

        frame_width = stream.width
        frame_height = stream.height

        # update end_offset_real
        end_offset_real = min(end_offset, len(frames) * pts_per_frame)
        start_frame = math.ceil(start_offset / pts_per_frame)
        total_frame = math.ceil((end_offset_real - start_offset) / pts_per_frame)

        if end_offset_real < start_offset or total_frame == 0:
            return np.empty(0, dtype=np.uint8)

        # Decode every distinct frame once, in ascending order; duplicates and the caller's
        # order are restored at the end.
        target_frame_list = sorted(set(indices))
        target_frame_positions = {value: [index] for index, value in enumerate(target_frame_list)}

        # Group the targets by key frame before touching the device, so that a stream without a
        # usable key frame fails here instead of after the DVPP channel has been created.
        key_list, first_pts = self._get_key_frames_and_first_pts(container, stream, pts_per_frame)
        groups = self._group_frames_by_keyframe(key_list, target_frame_list)
        container.seek(0, any_frame=False, backward=True, stream=stream)
        average_rate = stream.average_rate
        time_base = stream.time_base

        ret_tensor = cde.DeviceBuffer([len(target_frame_list), 3, frame_height, frame_width])

        # decode from dvpp
        # NOTE(review): the channel is not destroyed if one of the calls below raises; the correct
        # teardown sequence for a started channel is not visible here, so it is left as is.
        chn = cde.decode_video_create_chn(codec_id)
        cde.decode_video_start_get_frame(chn, len(target_frame_list))

        for keyframe, frames_in_group in groups.items():
            frames_count = 0
            timestamp = keyframe / average_rate
            seek_target = int(timestamp / time_base + first_pts)
            container.seek(seek_target, any_frame=False, backward=True, stream=stream)
            self._compute_ret_tensor(container, stream, frames_count, frames_in_group, frames,
                                     pts_per_frame, ret_tensor, target_frame_positions, start_frame, chn)

        # ret_tensor is ordered by pts len(target_frame_list)
        ret_tensor_dvpp = cde.decode_video_stop_get_frame(chn, len(target_frame_list))

        # if ret_tensor_dvpp empty, means ret_tensor already filled
        if ret_tensor_dvpp.size() != 0:
            ret_tensor = ret_tensor_dvpp

        ret_numpy = ret_tensor.numpy()

        cde.decode_video_destroy_chnl(chn)

        if indices != target_frame_list:
            mapping = {val: ret_numpy[index] for index, val in enumerate(target_frame_list)}
            ret_numpy = np.stack([mapping[val] for val in indices])
        return ret_numpy
|
|
708
1319
|
|
|
709
1320
|
|
|
710
1321
|
def read_video_timestamps(filename, pts_unit="pts"):
|