mindspore-2.1.0-cp38-cp38-manylinux1_x86_64.whl → mindspore-2.2.0-cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +26 -32
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +72 -95
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +173 -258
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +240 -145
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +13 -2
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +143 -59
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +11 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +0 -14
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +316 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +21 -28
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +310 -207
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +82 -41
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +13 -18
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +22 -17
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +78 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +4 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +167 -189
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -8
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +470 -251
- mindspore/ops/function/random_func.py +86 -56
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +235 -19
- mindspore/ops/operations/__init__.py +25 -17
- mindspore/ops/operations/_grad_ops.py +52 -7
- mindspore/ops/operations/_inner_ops.py +213 -12
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +64 -280
- mindspore/ops/operations/comm_ops.py +105 -57
- mindspore/ops/operations/custom_ops.py +10 -3
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/math_ops.py +185 -138
- mindspore/ops/operations/nn_ops.py +716 -492
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +14 -12
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +6 -10
- mindspore/parallel/shard.py +4 -4
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
- mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
- mindspore/profiler/parser/ascend_op_generator.py +5 -5
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
- mindspore/profiler/parser/base_timeline_generator.py +9 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +37 -21
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +2 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +139 -71
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +525 -577
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +2 -2
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +14 -7
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +83 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +185 -45
- mindspore/train/serialization.py +390 -150
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +14 -10
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +458 -518
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
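
The {nn/optim_ex → experimental/optim} renames above move the experimental optimizer package, so user imports change with this release. A hedged sketch of the import change (class names AdamW and SGD are inferred from the moved adamw.py and sgd.py files; the lr_scheduler.py module is newly added alongside them):

    from mindspore.experimental.optim import AdamW, SGD   # 2.2 location
    # from mindspore.nn.optim_ex import AdamW, SGD        # 2.1 location, removed by this release
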
mindspore/nn/layer/dense.py
CHANGED

@@ -77,6 +77,7 @@ class BiDense(Cell):
         bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter.
             The values of str refer to the function `initializer`. Default: ``None`` .
         has_bias (bool): Specifies whether the layer uses :math:`\text{bias}` vector. Default: ``True`` .
+        dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
 
     Shape:
         - **input1** - :math:`(*, H_{in1})` where :math:`H_{in1}=\text{in1_channels}` and
@@ -90,8 +91,8 @@ class BiDense(Cell):
           are the same shape as the inputs.
 
     Dtype:
-        - **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input2
-        - **
+        - **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input2** .
+        - **input2** (Tensor) - The dtype must be float16 or float32 and be same as **input1** .
         - **output** (Tensor) - With the same dtype as the inputs.
 
     Weights:
@@ -133,7 +134,8 @@ class BiDense(Cell):
                  out_channels,
                  weight_init=None,
                  bias_init=None,
-                 has_bias=True):
+                 has_bias=True,
+                 dtype=mstype.float32):
         super().__init__()
         self.in_channels = Validator.check_positive_int(in1_channels, "in1_channels", self.cls_name)
         self.in_channels = Validator.check_positive_int(in2_channels, "in2_channels", self.cls_name)
@@ -156,7 +158,8 @@ class BiDense(Cell):
                              f"equal to 'in2_channels'. But got 'weight_init': {weight_init}, "
                              f"'out_channels': {out_channels}, 'in_channels': {in1_channels}, "
                              f"'in2_channels': {in2_channels}")
-        self.weight = Parameter(initializer(weight_init, (out_channels, in1_channels, in2_channels)), 'weight')
+        self.weight = Parameter(initializer(weight_init, (out_channels, in1_channels, in2_channels), dtype=dtype),
+                                'weight')
 
         if self.has_bias:
             if bias_init is None:
@@ -166,7 +169,7 @@ class BiDense(Cell):
                 raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' should "
                                  f"be equal to 1, and the first dim must be equal to 'out_channels'. But got "
                                  f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
-            self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
+            self.bias = Parameter(initializer(bias_init, [out_channels], dtype=dtype), name="bias")
         self.bias_add = P.BiasAdd()
         self.matmul = P.MatMul()
 
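
The dense.py change threads a new dtype argument through BiDense into both parameter initializers. A minimal usage sketch (shapes and values are illustrative, not from the diff):

    import numpy as np
    import mindspore.common.dtype as mstype
    from mindspore import Tensor, nn

    # dtype now reaches both the weight and bias initializers
    net = nn.BiDense(20, 30, 40, dtype=mstype.float16)
    x1 = Tensor(np.random.randn(128, 20).astype(np.float16))
    x2 = Tensor(np.random.randn(128, 30).astype(np.float16))
    print(net(x1, x2).shape)  # (128, 40)
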
mindspore/nn/layer/embedding.py
CHANGED

@@ -64,11 +64,13 @@ class Embedding(Cell):
         embedding_size (int): The size of each embedding vector.
         use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: ``False`` .
         embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
-            Refer to class `initializer
-
+            Refer to class `mindspore.common.initializer
+            <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
+            for the values of string when a string is specified. Default: ``'normal'`` .
         dtype (:class:`mindspore.dtype`): Data type of `x`. Default: ``mstype.float32`` .
         padding_idx (int, None): When the padding_idx encounters index, the output embedding vector of this index
             will be initialized to zero. Default: ``None`` . The feature is inactivated.
+
     Inputs:
         - **x** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{x_length})`. The elements of
           the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will
@@ -145,9 +147,8 @@ class Embedding(Cell):
         return output
 
     def extend_repr(self):
-
-        self.
-        return s
+        return f'vocab_size={self.vocab_size}, embedding_size={self.embedding_size}, use_one_hot={self.use_one_hot}, ' \
+               f'embedding_table={self.embedding_table}, dtype={self.dtype}, padding_idx={self.padding_idx}'
 
 
 @_primexpr
@@ -190,6 +191,7 @@ class EmbeddingLookup(Cell):
             parameter server trainning mode and 'DEVICE' target. And the moment parameter of corresponding
             optimizer will also be set to the cache size. In addition, it should be noted that it will cost the 'DEVICE'
             memory, so suggests setting a reasonable value to avoid insufficient memory.
+        dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
 
     Inputs:
         - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
@@ -231,7 +233,7 @@ class EmbeddingLookup(Cell):
 
     def __init__(self, vocab_size, embedding_size, param_init='normal',
                  target='CPU', slice_mode='batch_slice', manual_shapes=None,
-                 max_norm=None, sparse=True, vocab_cache_size=0):
+                 max_norm=None, sparse=True, vocab_cache_size=0, dtype=mstype.float32):
         """Initialize EmbeddingLookup."""
         super(EmbeddingLookup, self).__init__()
         Validator.check_value_type('sparse', sparse, [bool], self.cls_name)
@@ -255,8 +257,8 @@ class EmbeddingLookup(Cell):
         if enable_ps:
             self._process_vocab_cache(slice_mode)
         self.embedding_size = Validator.check_positive_int(embedding_size, 'embedding_size', self.cls_name)
-        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]
-
+        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size],
+                                         dtype=dtype), name='embedding_table')
         parallel_mode = _get_parallel_mode()
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.gather_revert = P.Gather()
@@ -267,7 +269,7 @@ class EmbeddingLookup(Cell):
         if is_auto_parallel:
             self.unique = P.Unique().shard(((1,),))
         if self.cache_enable and enable_ps:
-            self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size, param_init)
+            self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size, param_init, dtype=dtype)
         if is_auto_parallel:
             self.unique.add_prim_attr('cache_enable', True)
         indices_shape_size = 2
@@ -310,8 +312,8 @@ class EmbeddingLookup(Cell):
         else:
             if is_auto_parallel:
                 support_mode = ["field_slice", "table_row_slice", "table_column_slice", "batch_slice"]
-                raise ValueError("For '{}', the 'slice_mode' must be in {}, "
-                                 "but got \"{}\"."
+                raise ValueError(f"For '{self.cls_name}', the 'slice_mode' must be in {support_mode}, "
+                                 f"but got \"{slice_mode}\".")
             if self.cache_enable and not enable_ps:
                 raise ValueError(f"For '{self.cls_name}', haven't supported cache enable for not ps mode.")
         self.embedding_table.unique = self.forward_unique
@@ -354,7 +356,8 @@ class EmbeddingLookup(Cell):
             if _is_role_worker():
                 self.vocab_size = self.vocab_cache_size
 
-    def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size, param_init):
+    def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size, param_init,
+                                       dtype=mstype.float32):
         """PS embeddingLookup cache enable set."""
         if self.sparse:
             self.forward_unique = True
@@ -368,10 +371,10 @@ class EmbeddingLookup(Cell):
         if _enable_distributed_mindrt():
             self.rank_id = get_rank()
             if self.is_ps_server:
-                self._slice_pserver_embeddings("zeros")
+                self._slice_pserver_embeddings("zeros", dtype=dtype)
                 self._set_cache_enable_and_key_for_pserver(param_key)
 
-    def _slice_pserver_embeddings(self, param_init):
+    def _slice_pserver_embeddings(self, param_init, dtype=mstype.float32):
         '''
         Method to slice embedding tables on Parameter Servers.
         It helps to train with a large scale embedding table and is used only in Parameter Server training mode.
@@ -399,7 +402,7 @@ class EmbeddingLookup(Cell):
         for i in range(server_num):
             self.embedding_table_list.append(Parameter(initializer(param_init,
                                                                    [self.embedding_table_vocab_dim_list[i],
-                                                                    self.embedding_size]),
+                                                                    self.embedding_size], dtype=dtype),
                                                        name="embedding_table_server_" + str(i)))
 
             self.embedding_offset.append(offset)
@@ -505,12 +508,13 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
             :class:`mindspore.nn.EmbeddingLookup`. Default: ``'batch_slice'``.
         feature_num_list (tuple): The accompaniment array in field slice mode. This is unused currently.
             Default: ``None`` .
-        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
-
+        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32.
+            Default: ``None`` .
         sparse (bool): Using sparse mode. When 'target' is set to ``'CPU'`` , 'sparse' has to be true.
             Default: ``True`` .
         operator (str): The pooling method for the features in one field. Support ``'SUM'`` , ``'MEAN'`` and
             ``'MAX'`` . Default: ``'SUM'`` .
+        dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
 
     Inputs:
         - **input_indices** (Tensor) - The shape of tensor is :math:`(batch\_size, seq\_length)`.
@@ -529,12 +533,12 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
         TypeError: If `vocab_size` or `embedding_size` or `field_size` is not an int.
         TypeError: If `sparse` is not a bool or `feature_num_list` is not a tuple.
         ValueError: If `vocab_size` or `embedding_size` or `field_size` is less than 1.
-        ValueError: If `target` is neither 'CPU' nor 'DEVICE'
-        ValueError: If `slice_mode` is not one of 'batch_slice'
-            'table_column_slice'.
-        ValueError: If `sparse` is False and `target` is 'CPU'.
-        ValueError: If `slice_mode` is 'field_slice' and `feature_num_list` is None.
-        ValueError: If `operator` is not one of 'SUM'
+        ValueError: If `target` is neither ``'CPU'`` nor ``'DEVICE'``.
+        ValueError: If `slice_mode` is not one of ``'batch_slice'``, ``'field_slice'``, ``'table_row_slice'``,
+            ``'table_column_slice'`` .
+        ValueError: If `sparse` is False and `target` is ``'CPU'`` .
+        ValueError: If `slice_mode` is ``'field_slice'`` and `feature_num_list` is None.
+        ValueError: If `operator` is not one of ``'SUM'``, ``'MAX'``, ``'MEAN'`` .
 
     Supported Platforms:
         ``Ascend`` ``GPU``
@@ -555,10 +559,11 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
     OPERATOR_MAX = 'MAX'
 
     def __init__(self, vocab_size, embedding_size, field_size, param_init='normal', target='CPU',
-                 slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
+                 slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM',
+                 dtype=mstype.float32):
         """Initialize MultiFieldEmbeddingLookup."""
         super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
-                                                        slice_mode, feature_num_list, max_norm, sparse)
+                                                        slice_mode, feature_num_list, max_norm, sparse, dtype=dtype)
         self.field_size = Validator.check_positive_int(field_size, 'field_size', self.cls_name)
         self.operator = operator
 
@@ -622,8 +627,9 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
                 self.inf_add.shard(((1, 1, get_group_size()), (1, 1, 1)))
         else:
             if is_auto_parallel:
-                raise ValueError(
-
+                raise ValueError(
+                    f"For '{self.cls_name}', the 'slice_mode' must be in ['table_row_slice', 'batch_slice' "
+                    f"and 'table_column_slice'], but got {str(slice_mode)}.")
 
         # Min value for fp32
         self.negative_inf_value = -3.402823466E+38
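
The same pattern runs through embedding.py: EmbeddingLookup and MultiFieldEmbeddingLookup gain a dtype argument that sets the dtype of the embedding_table Parameter, including the sliced server-side tables in parameter-server mode. A sketch with illustrative values:

    import numpy as np
    import mindspore.common.dtype as mstype
    from mindspore import Tensor, nn

    # dtype controls the embedding_table Parameter, not the index input
    lookup = nn.EmbeddingLookup(vocab_size=2000, embedding_size=16, dtype=mstype.float16)
    indices = Tensor(np.array([[1, 7], [4, 5]]), mstype.int32)
    print(lookup(indices).dtype)  # Float16, matching the table dtype
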
mindspore/nn/layer/flash_attention.py
CHANGED

@@ -17,12 +17,13 @@ A FlashAttention Layer.
 """
 import math
 
-import mindspore.
-from mindspore import ops
-from mindspore.common import dtype as mstype
+import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
+from mindspore import ops
 from mindspore.nn.cell import Cell
 from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_impl import get_flash_attention
+from mindspore.ops.operations.nn_ops import FlashAttentionScore
+from mindspore._c_expression import MSContext
 
 __all__ = ['FlashAttention']
 
@@ -92,6 +93,7 @@ class FlashAttention(Cell):
 
     def __init__(self,
                  head_dim,
+                 head_num,
                  dropout_rate=0.0,
                  prev_block_num=65536,
                  next_block_num=65536,
@@ -104,18 +106,42 @@
                  ):
         super(FlashAttention, self).__init__()
 
-        self.flash_attention = get_flash_attention(
-            prev_block_num=prev_block_num,
-            next_block_num=next_block_num,
-            tiling_stgy_name=tiling_stgy_name,
-            high_precision=high_precision
-        )
-        self.flash_attention.add_prim_attr("primitive_target", "Ascend")
         scaling_constant = math.sqrt(head_dim)
-        if scaling_constant != 0:
-            self.scale_factor = Tensor([1. / scaling_constant], dtype=mstype.float16)
-        else:
+        if scaling_constant == 0:
             raise ValueError("the scaling constant must not be 0.")
+        self.scale_factor = Tensor([1. / scaling_constant], dtype=mstype.float16)
+
+        self.is_910A = MSContext.get_instance().get_ascend_soc_version() == "Ascend910"
+        if self.is_910A:
+            self.flash_attention = get_flash_attention(
+                prev_block_num=prev_block_num,
+                next_block_num=next_block_num,
+                tiling_stgy_name=tiling_stgy_name,
+                high_precision=high_precision
+            )
+            self.flash_attention.add_prim_attr("primitive_target", "Ascend")
+        else:
+            if alibi:
+                raise ValueError(f"When soc_version is not Ascend910A, alibi must be False")
+            self.transpose_4d_pre = ops.Transpose().shard(((dp, mp, 1, 1),))
+            self.transpose_4d_post = ops.Transpose().shard(((dp, 1, mp, 1),))
+            self.reshape = ops.Reshape()
+            self.zeros_like = ops.ZerosLike().shard(((dp, mp, 1, 1),))
+            self.zeros = ops.Zeros()
+            self.attn_expand_dims = ops.ExpandDims().shard(((dp, 1, 1),))
+            fa_strategies = ((dp, 1, mp),
+                             (dp, 1, mp),
+                             (dp, 1, mp),
+                             (dp, 1, 1, 1))
+            if dropout_rate > 1e-5:
+                fa_strategies += ((dp, mp, 1, 1),)
+            self.flash_attention = FlashAttentionScore(head_num=head_num, pre_tokens=prev_block_num,
+                                                       next_tokens=next_block_num,
+                                                       keep_prob=1 - dropout_rate,
+                                                       scale_value=1.0,
+                                                       inner_precise=0 if high_precision else 1).shard(fa_strategies)
+
+        self.ones = ops.Ones()
         self.dim_mask = Tensor([1 for _ in range(head_dim)], dtype=mstype.int8)
         self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
         self.dropout_rate = dropout_rate
@@ -136,38 +162,35 @@
             such as MatMul. Default: None.
         :return:
         """
-        if in_strategy is not None:
-            shard_stgy = list(in_strategy)
-            shard_stgy.insert(3, (1,))  # dim_mask
-            shard_stgy = tuple(shard_stgy)
-        else:
+        if in_strategy is None:
             # default: dp=1, mp=1, construct inputs only contain query, key, value
-            shard_stgy = (
+            in_strategy = (
                 (1, 1, 1, 1),
                 (1, 1, 1, 1),
                 (1, 1, 1, 1),
-                (1,),  # dim_mask
             )
-        self.flash_attention.shard(shard_stgy)
-        dp = shard_stgy[0][0]
-        mp = shard_stgy[0][1]
+        self.flash_attention.shard(in_strategy)
+        dp = in_strategy[0][0]
+        mp = in_strategy[0][1]
         self.flash_attention.add_prim_attr("dev_matrix_shape", [dp, mp, 1, 1])
         inputs_tensor_map = [
             [3, 2, 1, 0],
             [3, 2, 1, 0],
             [3, 2, 1, 0],
-            [-1]
         ]
         if self.have_attention_mask_batch:
             inputs_tensor_map.append([3, 1, 0])
         else:
             inputs_tensor_map.append([-1, 1, 0])
 
+        input_empty_args_num = 2
         # dropout_mask
         if self.dropout_rate > 1e-5:
+            input_empty_args_num -= 1
             inputs_tensor_map.append([3, 2, 1, 0])
 
         if self.alibi:
+            input_empty_args_num -= 1
            inputs_tensor_map.append([3, 2, 1, 0])
 
         self.flash_attention.add_prim_attr("inputs_tensor_map", inputs_tensor_map)
@@ -178,7 +201,7 @@
             [3, 2, 1]  # M
         ])
         self.flash_attention.add_prim_attr("as_loss_divisor", 0)
-        self.flash_attention.add_prim_attr("empty_mirror_ops",
+        self.flash_attention.add_prim_attr("empty_mirror_ops", input_empty_args_num)
 
     def construct(self, query, key, value, attn_mask=None, alibi_mask=None):
         """FlashAttention forward
@@ -200,24 +223,42 @@
         if seq_len % 16 != 0 or k_seq_len % 16 != 0 or k_seq_len != v_seq_len:
             raise ValueError(
                 "query, key, value seq_len must be a multiple of 16, and key seq_len, value seq_len must be the same.")
-        if self.dropout_rate > 1e-5:
-            drop_mask_bits = self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob)
-            tensor_shape = Tensor((bsz, head_num, seq_len, seq_len), mstype.int32)
-            ones = self.fill_v2(tensor_shape, self.tensor_one)
-            ones = self.depend(ones, query)
-            drop_mask = self.do_dropout(ones, drop_mask_bits, self.keep_prob)
-        else:
-            drop_mask = None
+
         if head_dim > 304:
             raise ValueError(
                 "the head_dim must be less than 304, otherwise the ub would be OOM.")
-
-
-
-
-
-
-
+
+        if self.is_910A:
+            # 910A -- FlashAttentionPrimtive
+            if self.dropout_rate > 1e-5:
+                drop_mask_bits = self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob)
+                tensor_shape = Tensor((bsz, head_num, seq_len, seq_len), mstype.int32)
+                ones = self.fill_v2(tensor_shape, self.tensor_one)
+                ones = self.depend(ones, query)
+                drop_mask = self.do_dropout(ones, drop_mask_bits, self.keep_prob)
+            else:
+                drop_mask = None
+            output, _, _ = self.flash_attention(query, key, value, attn_mask, drop_mask, alibi_mask)
         else:
-
+            # FlashAttentionScore
+            # Useless input, just for binary calls.
+            if self.dropout_rate > 1e-5:
+                drop_mask_bits = self.reshape(self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob),
+                                              (bsz, head_num, seq_len, seq_len // 8))
+            else:
+                drop_mask_bits = None
+            # (B, N, S, D) -> (B, S, H)
+            query = self.reshape(self.transpose_4d_pre(query, (0, 2, 1, 3)), (bsz, seq_len, -1))
+            key = self.reshape(self.transpose_4d_pre(key, (0, 2, 1, 3)), (bsz, seq_len, -1))
+            value = self.reshape(self.transpose_4d_pre(value, (0, 2, 1, 3)), (bsz, seq_len, -1))
+            attn_mask = self.attn_expand_dims(attn_mask, 1)
+            output, _, _ = self.flash_attention(query,
+                                                key,
+                                                value,
+                                                attn_mask,
+                                                drop_mask_bits,
+                                                None,
+                                                None)
+            output = self.transpose_4d_post(self.reshape(output, (bsz, seq_len, head_num, head_dim)), (0, 2, 1, 3))
+
         return output
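
flash_attention.py now requires a head_num argument and dispatches by SoC version: the TIK-based primitive from get_flash_attention on Ascend 910A, and the fused FlashAttentionScore operator (with (B, N, S, D) -> (B, S, H) reshaping around the call) on other Ascend chips. A hedged sketch of the updated constructor; this layer runs only on Ascend devices, and the shapes below are illustrative:

    import numpy as np
    import mindspore.common.dtype as mstype
    from mindspore import Tensor
    from mindspore.nn.layer.flash_attention import FlashAttention

    bsz, head_num, seq_len, head_dim = 2, 8, 128, 64
    fa = FlashAttention(head_dim=head_dim, head_num=head_num, dropout_rate=0.0)
    q = Tensor(np.random.randn(bsz, head_num, seq_len, head_dim).astype(np.float16))
    k = Tensor(np.random.randn(bsz, head_num, seq_len, head_dim).astype(np.float16))
    v = Tensor(np.random.randn(bsz, head_num, seq_len, head_dim).astype(np.float16))
    attn_mask = Tensor(np.zeros((bsz, seq_len, seq_len)).astype(np.float16))
    output = fa(q, k, v, attn_mask)  # (B, N, S, D)
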
mindspore/nn/layer/image.py
CHANGED

@@ -83,17 +83,17 @@ class ImageGradients(Cell):
         _check_input_4d(F.shape(images), "images", self.cls_name)
         batch_size, depth, height, width = P.Shape()(images)
         if height == 1:
-            dy =
+            dy = F.fill(P.DType()(images), (batch_size, depth, 1, width), 0)
         else:
             dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
-            dy_last =
+            dy_last = F.fill(P.DType()(images), (batch_size, depth, 1, width), 0)
             dy = P.Concat(2)((dy, dy_last))
 
         if width == 1:
-            dx =
+            dx = F.fill(P.DType()(images), (batch_size, depth, height, 1), 0)
         else:
             dx = images[:, :, :, 1:] - images[:, :, :, :width - 1]
-            dx_last =
+            dx_last = F.fill(P.DType()(images), (batch_size, depth, height, 1), 0)
             dx = P.Concat(3)((dx, dx_last))
         return dy, dx
 
@@ -571,7 +571,8 @@ class PixelShuffle(Cell):
     <https://arxiv.org/abs/1609.05158>`_ .
 
     Typically, the input is of shape :math:`(*, C \times r^2, H, W)` , and the output is of shape
-    :math:`(*, C, H \times r, W \times r)`,
+    :math:`(*, C, H \times r, W \times r)`,
+    where :math:`r` is an upscale factor and :math:`*` is zero or more batch dimensions.
 
     Note:
         The dimension of input Tensor on Ascend should be less than 7.
@@ -621,7 +622,8 @@ class PixelUnshuffle(Cell):
     <https://arxiv.org/abs/1609.05158>`_ .
 
     Typically, the input is of shape :math:`(*, C, H \times r, W \times r)` , and the output is of shape
-    :math:`(*, C \times r^2, H, W)` ,
+    :math:`(*, C \times r^2, H, W)` ,
+    where :math:`r` is a downscale factor and :math:`*` is zero or more batch dimensions.
 
     Args:
         downscale_factor (int): factor to unshuffle the input, and is a positive integer.
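
The image.py edit swaps the zero-padding call sites in ImageGradients over to the functional F.fill with (dtype, shape, value) arguments; the computed gradients are unchanged. A quick check with a single 2x2 image, values illustrative:

    import numpy as np
    import mindspore.common.dtype as mstype
    from mindspore import Tensor, nn

    net = nn.ImageGradients()
    image = Tensor(np.array([[[[1, 2], [3, 4]]]]), dtype=mstype.int32)
    dy, dx = net(image)
    print(dy)  # vertical differences, last row zero-filled
    print(dx)  # horizontal differences, last column zero-filled
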
mindspore/nn/layer/math.py
CHANGED

@@ -223,7 +223,6 @@ class LGamma(Cell):
         self.abs = P.Abs()
         self.shape = P.Shape()
         self.dtype = P.DType()
-        self.fill = P.Fill()
         self.floor = P.Floor()
         self.equal = P.Equal()
         self.greater = P.Greater()
@@ -240,7 +239,7 @@ class LGamma(Cell):
         if F.is_sequence_value_unknown(self.shape(x)):
             infinity = self.ones_like(x) * F.cast(self.inf, input_dtype)
         else:
-            infinity = self.fill(input_dtype, self.shape(x), self.inf)
+            infinity = F.fill(input_dtype, self.shape(x), self.inf)
 
         need_to_reflect = self.less(x, 0.5)
         neg_input = -x
@@ -335,7 +334,6 @@ class DiGamma(Cell):
         self.abs = P.Abs()
         self.shape = P.Shape()
         self.dtype = P.DType()
-        self.fill = P.Fill()
         self.floor = P.Floor()
         self.equal = P.Equal()
         self.less = P.Less()
@@ -371,7 +369,7 @@ class DiGamma(Cell):
         reduced_input = x + self.abs(self.floor(x + 0.5))
         reflection = y - self.pi * self.cos(self.pi * reduced_input) / self.sin(self.pi * reduced_input)
         real_result = self.select(need_to_reflect, reflection, y)
-        nan = self.fill(self.dtype(x), self.shape(x), np.nan)
+        nan = F.fill(self.dtype(x), self.shape(x), np.nan)
 
         return self.select(self.logicaland(self.less(x, 0), self.equal(x, self.floor(x))),
                            nan, real_result)
@@ -391,7 +389,6 @@ def _igamma_series(ax, x, a, enabled):
 
     logicaland = P.LogicalAnd()
     greater = P.Greater()
-    fill = P.Fill()
     shape = P.Shape()
     dtype = P.DType()
     select = P.Select()
@@ -424,8 +421,8 @@ def _igamma_series(ax, x, a, enabled):
                 select(enabled, x, vals[4]), select(enabled, dc_da, vals[5]),
                 select(enabled, dans_da, vals[6]))
 
-    ones = fill(dtype(a), shape(a), 1)
-    zeros = fill(dtype(a), shape(a), 0)
+    ones = F.fill(dtype(a), shape(a), 1)
+    zeros = F.fill(dtype(a), shape(a), 0)
    vals = (enabled, a, ones, ones, x, zeros, zeros)
 
     vals = _while_helper_func(cond, body, vals)
@@ -441,7 +438,6 @@ def _igammac_continued_fraction(ax, x, a, enabled):
     greater = P.Greater()
     less = P.Less()
     notequal = P.NotEqual()
-    fill = P.Fill()
     shape = P.Shape()
     dtype = P.DType()
     select = P.Select()
@@ -482,7 +478,7 @@ def _igammac_continued_fraction(ax, x, a, enabled):
         qk_is_nonzero = notequal(qk, 0)
         r = pk / qk
 
-        t = select(qk_is_nonzero, abs_x((ans - r) / r), fill(dtype(t), shape(t), 1))
+        t = select(qk_is_nonzero, abs_x((ans - r) / r), F.fill(dtype(t), shape(t), 1))
         ans = select(qk_is_nonzero, r, ans)
 
         dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c
@@ -490,7 +486,7 @@ def _igammac_continued_fraction(ax, x, a, enabled):
         dans_da_new = select(qk_is_nonzero, (dpk_da - ans * dqk_da) / qk, dans_da)
         grad_conditional = select(qk_is_nonzero,
                                   abs_x(dans_da_new - dans_da),
-                                  fill(dtype(dans_da), shape(dans_da), 1))
+                                  F.fill(dtype(dans_da), shape(dans_da), 1))
 
         pkm2 = pkm1
         pkm1 = pk
@@ -525,16 +521,16 @@ def _igammac_continued_fraction(ax, x, a, enabled):
 
     y = 1 - a
     z = x + y + 1
-    c = fill(dtype(x), shape(x), 0)
-    pkm2 = fill(dtype(x), shape(x), 1)
+    c = F.fill(dtype(x), shape(x), 0)
+    pkm2 = F.fill(dtype(x), shape(x), 1)
     qkm2 = x
     pkm1 = x + 1
     qkm1 = z * x
     ans = pkm1 / qkm1
-    t = fill(dtype(x), shape(x), 1)
-    dpkm2_da = fill(dtype(x), shape(x), 0)
-    dqkm2_da = fill(dtype(x), shape(x), 0)
-    dpkm1_da = fill(dtype(x), shape(x), 0)
+    t = F.fill(dtype(x), shape(x), 1)
+    dpkm2_da = F.fill(dtype(x), shape(x), 0)
+    dqkm2_da = F.fill(dtype(x), shape(x), 0)
+    dpkm1_da = F.fill(dtype(x), shape(x), 0)
     dqkm1_da = -x
     dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1
     vals = (enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da)
@@ -606,7 +602,6 @@ class IGamma(Cell):
         self.exp = P.Exp()
         self.select = P.Select()
         self.zeroslike = P.ZerosLike()
-        self.fill = P.Fill()
         self.shape = P.Shape()
         self.dtype = P.DType()
         self.lgamma = LGamma()
@@ -633,7 +628,7 @@ class IGamma(Cell):
                              1 - _igammac_continued_fraction(ax, x, a, self.logicaland(enabled, use_igammac)),
                              _igamma_series(ax, x, a, self.logicaland(enabled, self.logicalnot(use_igammac))))
         output = self.select(x_is_zero, self.zeroslike(output), output)
-        output = self.select(domain_error, self.fill(self.dtype(a), self.shape(a), np.nan), output)
+        output = self.select(domain_error, F.fill(self.dtype(a), self.shape(a), np.nan), output)
         return output
 
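
math.py applies the same migration throughout LGamma, DiGamma, IGamma and the igamma helper functions: every P.Fill operator instance is dropped in favor of the functional F.fill, with identical (dtype, shape, value) arguments. A standalone sketch of the before/after pattern, values illustrative:

    import mindspore.common.dtype as mstype
    from mindspore.ops import functional as F

    # 2.2 style, as used in this diff
    ones = F.fill(mstype.float32, (2, 3), 1)
    zeros = F.fill(mstype.float32, (2, 3), 0)
    # 2.1 style removed here: an operator instance called as P.Fill()(dtype, shape, value)
    print(ones, zeros)
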