xmos-ai-tools 1.3.2.dev80__py3-none-macosx_10_15_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xmos_ai_tools/__init__.py +7 -0
- xmos_ai_tools/io_server/__init__.py +151 -0
- xmos_ai_tools/runtime/__init__.py +0 -0
- xmos_ai_tools/runtime/buildfiles/aitoolslib.cmake +13 -0
- xmos_ai_tools/runtime/buildfiles/aitoolslib.make +8 -0
- xmos_ai_tools/runtime/include/flash_server.h +74 -0
- xmos_ai_tools/runtime/include/flatbuffers/allocator.h +68 -0
- xmos_ai_tools/runtime/include/flatbuffers/array.h +243 -0
- xmos_ai_tools/runtime/include/flatbuffers/base.h +474 -0
- xmos_ai_tools/runtime/include/flatbuffers/bfbs_generator.h +43 -0
- xmos_ai_tools/runtime/include/flatbuffers/buffer.h +142 -0
- xmos_ai_tools/runtime/include/flatbuffers/buffer_ref.h +53 -0
- xmos_ai_tools/runtime/include/flatbuffers/code_generators.h +235 -0
- xmos_ai_tools/runtime/include/flatbuffers/default_allocator.h +64 -0
- xmos_ai_tools/runtime/include/flatbuffers/detached_buffer.h +114 -0
- xmos_ai_tools/runtime/include/flatbuffers/flatbuffer_builder.h +1197 -0
- xmos_ai_tools/runtime/include/flatbuffers/flatbuffers.h +270 -0
- xmos_ai_tools/runtime/include/flatbuffers/flatc.h +111 -0
- xmos_ai_tools/runtime/include/flatbuffers/flexbuffers.h +1897 -0
- xmos_ai_tools/runtime/include/flatbuffers/grpc.h +300 -0
- xmos_ai_tools/runtime/include/flatbuffers/hash.h +127 -0
- xmos_ai_tools/runtime/include/flatbuffers/idl.h +1232 -0
- xmos_ai_tools/runtime/include/flatbuffers/minireflect.h +419 -0
- xmos_ai_tools/runtime/include/flatbuffers/pch/flatc_pch.h +39 -0
- xmos_ai_tools/runtime/include/flatbuffers/pch/pch.h +38 -0
- xmos_ai_tools/runtime/include/flatbuffers/reflection.h +502 -0
- xmos_ai_tools/runtime/include/flatbuffers/reflection_generated.h +1449 -0
- xmos_ai_tools/runtime/include/flatbuffers/registry.h +128 -0
- xmos_ai_tools/runtime/include/flatbuffers/stl_emulation.h +509 -0
- xmos_ai_tools/runtime/include/flatbuffers/string.h +64 -0
- xmos_ai_tools/runtime/include/flatbuffers/struct.h +53 -0
- xmos_ai_tools/runtime/include/flatbuffers/table.h +168 -0
- xmos_ai_tools/runtime/include/flatbuffers/util.h +690 -0
- xmos_ai_tools/runtime/include/flatbuffers/vector.h +370 -0
- xmos_ai_tools/runtime/include/flatbuffers/vector_downward.h +271 -0
- xmos_ai_tools/runtime/include/flatbuffers/verifier.h +283 -0
- xmos_ai_tools/runtime/include/ioserver.h +44 -0
- xmos_ai_tools/runtime/include/lib_nn/api/TransposeConv.h +24 -0
- xmos_ai_tools/runtime/include/lib_nn/api/add_int16.h +27 -0
- xmos_ai_tools/runtime/include/lib_nn/api/add_int16_transform.h +42 -0
- xmos_ai_tools/runtime/include/lib_nn/api/dequantize_int16.h +22 -0
- xmos_ai_tools/runtime/include/lib_nn/api/dequantize_int16_transform.h +34 -0
- xmos_ai_tools/runtime/include/lib_nn/api/expand_8_to_16.h +8 -0
- xmos_ai_tools/runtime/include/lib_nn/api/multiply_int16.h +42 -0
- xmos_ai_tools/runtime/include/lib_nn/api/multiply_int16_transform.h +71 -0
- xmos_ai_tools/runtime/include/lib_nn/api/nn_api.h +15 -0
- xmos_ai_tools/runtime/include/lib_nn/api/nn_bin_types.h +14 -0
- xmos_ai_tools/runtime/include/lib_nn/api/nn_config.h +287 -0
- xmos_ai_tools/runtime/include/lib_nn/api/nn_conv2d_structs.h +72 -0
- xmos_ai_tools/runtime/include/lib_nn/api/nn_image.h +26 -0
- xmos_ai_tools/runtime/include/lib_nn/api/nn_layers.h +303 -0
- xmos_ai_tools/runtime/include/lib_nn/api/nn_op_helper.h +132 -0
- xmos_ai_tools/runtime/include/lib_nn/api/nn_op_utils.h +150 -0
- xmos_ai_tools/runtime/include/lib_nn/api/nn_operator.h +18 -0
- xmos_ai_tools/runtime/include/lib_nn/api/nn_pooling.h +551 -0
- xmos_ai_tools/runtime/include/lib_nn/api/nn_types.h +83 -0
- xmos_ai_tools/runtime/include/lib_nn/api/nn_window_params.h +55 -0
- xmos_ai_tools/runtime/include/lib_nn/api/output_transform_fn_int16.h +54 -0
- xmos_ai_tools/runtime/include/lib_nn/api/output_transform_fn_int16_kernel_transform.h +37 -0
- xmos_ai_tools/runtime/include/lib_nn/api/output_transform_fn_int16_mappings.h +13 -0
- xmos_ai_tools/runtime/include/lib_nn/api/quadratic_approximation.h +82 -0
- xmos_ai_tools/runtime/include/lib_nn/api/quadratic_interpolation.h +23 -0
- xmos_ai_tools/runtime/include/lib_nn/api/quantize_int16.h +22 -0
- xmos_ai_tools/runtime/include/lib_nn/api/quantize_int16_transform.h +33 -0
- xmos_ai_tools/runtime/include/lib_nn/api/version.h +13 -0
- xmos_ai_tools/runtime/include/lib_nn/api/vpu_memmove_word_aligned.h +15 -0
- xmos_ai_tools/runtime/include/lib_nn/api/vpu_memset_256.h +55 -0
- xmos_ai_tools/runtime/include/lib_nn/api/vpu_sim.h +118 -0
- xmos_ai_tools/runtime/include/lib_nn/api/xs3_vpu.h +216 -0
- xmos_ai_tools/runtime/include/lib_nn/api/xs3a_registers.h +2869 -0
- xmos_ai_tools/runtime/include/lib_nn/src/asm/asm_constants.h +41 -0
- xmos_ai_tools/runtime/include/lib_nn/src/asm/window_op_plan.h +25 -0
- xmos_ai_tools/runtime/include/lib_tflite_micro/api/fast_flash.h +47 -0
- xmos_ai_tools/runtime/include/lib_tflite_micro/api/inference_engine.h +218 -0
- xmos_ai_tools/runtime/include/lib_tflite_micro/api/memory_parallel_transport.h +52 -0
- xmos_ai_tools/runtime/include/lib_tflite_micro/api/version.h +13 -0
- xmos_ai_tools/runtime/include/lib_tflite_micro/api/xcore_config.h +17 -0
- xmos_ai_tools/runtime/include/lib_tflite_micro/api/xcore_device_memory.h +62 -0
- xmos_ai_tools/runtime/include/lib_tflite_micro/api/xcore_shared_config.h +31 -0
- xmos_ai_tools/runtime/include/lib_tflite_micro/src/tflite-xcore-kernels/conv2d_float.h +155 -0
- xmos_ai_tools/runtime/include/lib_tflite_micro/src/tflite-xcore-kernels/xcore_common.h +19 -0
- xmos_ai_tools/runtime/include/lib_tflite_micro/src/tflite-xcore-kernels/xcore_custom_options.h +28 -0
- xmos_ai_tools/runtime/include/lib_tflite_micro/src/tflite-xcore-kernels/xcore_error_reporter.h +32 -0
- xmos_ai_tools/runtime/include/lib_tflite_micro/src/tflite-xcore-kernels/xcore_interpreter.h +49 -0
- xmos_ai_tools/runtime/include/lib_tflite_micro/src/tflite-xcore-kernels/xcore_ops.h +71 -0
- xmos_ai_tools/runtime/include/lib_tflite_micro/src/tflite-xcore-kernels/xcore_profiler.h +49 -0
- xmos_ai_tools/runtime/include/lib_tflite_micro/src/tflite-xcore-kernels/xcore_utils.h +160 -0
- xmos_ai_tools/runtime/include/lib_tflite_micro/src/thread_call.h +119 -0
- xmos_ai_tools/runtime/include/lib_xud/lib_xud/api/legacy/usb_defs.h +4 -0
- xmos_ai_tools/runtime/include/lib_xud/lib_xud/api/legacy/usb_device.h +4 -0
- xmos_ai_tools/runtime/include/lib_xud/lib_xud/api/legacy/usb_std_descriptors.h +4 -0
- xmos_ai_tools/runtime/include/lib_xud/lib_xud/api/legacy/usb_std_requests.h +4 -0
- xmos_ai_tools/runtime/include/lib_xud/lib_xud/api/xud.h +518 -0
- xmos_ai_tools/runtime/include/lib_xud/lib_xud/api/xud_conf_default.h +11 -0
- xmos_ai_tools/runtime/include/lib_xud/lib_xud/api/xud_device.h +87 -0
- xmos_ai_tools/runtime/include/lib_xud/lib_xud/api/xud_std_descriptors.h +191 -0
- xmos_ai_tools/runtime/include/lib_xud/lib_xud/api/xud_std_requests.h +120 -0
- xmos_ai_tools/runtime/include/lib_xud/lib_xud/src/user/XUD_USB_Defines.h +70 -0
- xmos_ai_tools/runtime/include/lib_xud/lib_xud/src/user/class/hid.h +23 -0
- xmos_ai_tools/runtime/include/lib_xud/lib_xud/src/user/class/usbaudio10.h +30 -0
- xmos_ai_tools/runtime/include/lib_xud/lib_xud/src/user/class/usbaudio20.h +357 -0
- xmos_ai_tools/runtime/include/lib_xud/lib_xud/src/user/class/usbaudiocommon.h +168 -0
- xmos_ai_tools/runtime/include/signal/micro/kernels/delay_flexbuffers_generated_data.h +25 -0
- xmos_ai_tools/runtime/include/signal/micro/kernels/energy_flexbuffers_generated_data.h +28 -0
- xmos_ai_tools/runtime/include/signal/micro/kernels/fft_flexbuffers_generated_data.h +37 -0
- xmos_ai_tools/runtime/include/signal/micro/kernels/filter_bank_flexbuffers_generated_data.h +25 -0
- xmos_ai_tools/runtime/include/signal/micro/kernels/filter_bank_log_flexbuffers_generated_data.h +27 -0
- xmos_ai_tools/runtime/include/signal/micro/kernels/filter_bank_spectral_subtraction_flexbuffers_generated_data.h +26 -0
- xmos_ai_tools/runtime/include/signal/micro/kernels/framer_flexbuffers_generated_data.h +25 -0
- xmos_ai_tools/runtime/include/signal/micro/kernels/irfft.h +31 -0
- xmos_ai_tools/runtime/include/signal/micro/kernels/overlap_add_flexbuffers_generated_data.h +25 -0
- xmos_ai_tools/runtime/include/signal/micro/kernels/pcan_flexbuffers_generated_data.h +7 -0
- xmos_ai_tools/runtime/include/signal/micro/kernels/rfft.h +31 -0
- xmos_ai_tools/runtime/include/signal/micro/kernels/stacker_flexbuffers_generated_data.h +25 -0
- xmos_ai_tools/runtime/include/signal/micro/kernels/window_flexbuffers_generated_data.h +25 -0
- xmos_ai_tools/runtime/include/signal/src/circular_buffer.h +118 -0
- xmos_ai_tools/runtime/include/signal/src/complex.h +29 -0
- xmos_ai_tools/runtime/include/signal/src/energy.h +38 -0
- xmos_ai_tools/runtime/include/signal/src/fft_auto_scale.h +35 -0
- xmos_ai_tools/runtime/include/signal/src/filter_bank.h +69 -0
- xmos_ai_tools/runtime/include/signal/src/filter_bank_log.h +38 -0
- xmos_ai_tools/runtime/include/signal/src/filter_bank_spectral_subtraction.h +73 -0
- xmos_ai_tools/runtime/include/signal/src/filter_bank_square_root.h +34 -0
- xmos_ai_tools/runtime/include/signal/src/irfft.h +84 -0
- xmos_ai_tools/runtime/include/signal/src/kiss_fft_wrappers/kiss_fft_common.h +49 -0
- xmos_ai_tools/runtime/include/signal/src/kiss_fft_wrappers/kiss_fft_float.h +31 -0
- xmos_ai_tools/runtime/include/signal/src/kiss_fft_wrappers/kiss_fft_int16.h +30 -0
- xmos_ai_tools/runtime/include/signal/src/kiss_fft_wrappers/kiss_fft_int32.h +31 -0
- xmos_ai_tools/runtime/include/signal/src/log.h +30 -0
- xmos_ai_tools/runtime/include/signal/src/max_abs.h +31 -0
- xmos_ai_tools/runtime/include/signal/src/msb.h +32 -0
- xmos_ai_tools/runtime/include/signal/src/overlap_add.h +46 -0
- xmos_ai_tools/runtime/include/signal/src/pcan_argc_fixed.h +41 -0
- xmos_ai_tools/runtime/include/signal/src/rfft.h +85 -0
- xmos_ai_tools/runtime/include/signal/src/square_root.h +32 -0
- xmos_ai_tools/runtime/include/signal/src/window.h +31 -0
- xmos_ai_tools/runtime/include/signal/testdata/fft_test_data.h +48 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/array.h +156 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/builtin_op_data.h +22 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/builtin_ops.h +241 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/c/builtin_op_data.h +20 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/c/c_api_types.h +26 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/c/common.h +30 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/context_util.h +54 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/core/api/error_reporter.h +72 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/core/api/flatbuffer_conversions.h +440 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/core/api/tensor_utils.h +28 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/core/c/builtin_op_data.h +626 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/core/c/c_api_types.h +178 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/core/c/common.h +1496 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/core/macros.h +78 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/bits.h +102 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/fft.h +50 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/fft_io.h +34 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/fft_util.h +34 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/filterbank.h +63 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/filterbank_io.h +35 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/filterbank_util.h +50 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/frontend.h +64 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/frontend_io.h +31 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/frontend_util.h +52 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/kiss_fft_common.h +48 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/kiss_fft_int16.h +33 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/log_lut.h +40 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/log_scale.h +39 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/log_scale_io.h +33 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/log_scale_util.h +45 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/noise_reduction.h +46 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/noise_reduction_io.h +36 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/noise_reduction_util.h +50 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control.h +47 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_util.h +57 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/window.h +49 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/window_io.h +34 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/experimental/microfrontend/lib/window_util.h +45 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/common.h +1358 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/compatibility.h +122 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/cppmath.h +40 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/max.h +35 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/min.h +35 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/optimized/neon_check.h +20 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/portable_tensor.h +141 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/portable_tensor_utils.h +623 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/quantization_util.h +292 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/add.h +561 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/add_n.h +86 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/arg_min_max.h +88 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/batch_matmul.h +275 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/batch_to_space_nd.h +101 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/binary_function.h +91 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/broadcast_args.h +56 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/broadcast_to.h +97 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/ceil.h +37 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/comparisons.h +271 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/concatenation.h +141 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/conv.h +289 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/cumsum.h +175 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/depth_to_space.h +79 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h +100 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h +319 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/dequantize.h +78 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/div.h +247 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/elu.h +37 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/exp.h +38 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/fill.h +38 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/floor.h +39 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/floor_div.h +35 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/floor_mod.h +44 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/fully_connected.h +323 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/hard_swish.h +168 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/integer_ops/add.h +250 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h +241 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h +291 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h +126 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h +67 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h +121 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/integer_ops/mean.h +18 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h +194 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h +264 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h +117 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h +224 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/l2normalization.h +90 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/leaky_relu.h +69 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/log_softmax.h +256 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/logistic.h +132 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/lstm_cell.h +422 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/maximum_minimum.h +64 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/mul.h +267 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/neg.h +37 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/pad.h +169 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/pooling.h +303 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h +333 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h +244 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/prelu.h +111 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h +140 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/quantize.h +89 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/reduce.h +491 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/requantize.h +70 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/resize_bilinear.h +233 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h +102 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/round.h +51 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/select.h +151 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/slice.h +80 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/softmax.h +233 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h +109 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/space_to_depth.h +80 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/strided_slice.h +147 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/sub.h +465 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/tanh.h +129 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/transpose.h +203 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/reference/transpose_conv.h +225 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/runtime_shape.h +168 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/strided_slice_logic.h +278 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/tensor_ctypes.h +42 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/internal/types.h +1096 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/kernel_util.h +341 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/op_macros.h +49 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/kernels/padding.h +115 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/arena_allocator/ibuffer_allocator.h +100 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.h +104 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.h +58 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/arena_allocator/recording_single_arena_buffer_allocator.h +63 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h +144 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/benchmarks/micro_benchmark.h +95 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/compatibility.h +32 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/cortex_m_generic/debug_log_callback.h +49 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/debug_log.h +38 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/examples/micro_speech/micro_model_settings.h +37 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/examples/network_tester/expected_output_data.h +47 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/examples/network_tester/input_data.h +108 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/examples/network_tester/network_model.h +166 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/examples/person_detection/detection_responder.h +32 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/examples/person_detection/image_provider.h +38 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/examples/person_detection/main_functions.h +37 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/examples/person_detection/model_settings.h +35 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/fake_micro_context.h +70 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/flatbuffer_utils.h +65 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/activation_utils.h +57 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/activations.h +64 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/add.h +78 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/arc_mli/mli_function_specializations.h +141 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/arc_mli/mli_interface.h +75 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/arc_mli/mli_slicers.h +56 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/arc_mli/mli_tf_utils.h +310 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/arc_mli/scratch_buf_mgr.h +145 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.h +78 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/ceva/ceva_common.h +24 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/ceva/ceva_tflm_lib.h +613 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/ceva/mcps_macros.h +115 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/ceva/types.h +1286 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/circular_buffer.h +45 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/circular_buffer_flexbuffers_generated_data.h +22 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/conv.h +117 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/conv_test.h +94 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/depthwise_conv.h +80 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/dequantize.h +38 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/detection_postprocess_flexbuffers_generated_data.h +25 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/ethosu.h +28 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/fully_connected.h +112 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/hard_swish.h +30 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/kernel_runner.h +86 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/kernel_util.h +150 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/leaky_relu.h +43 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/logical.h +35 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/logistic.h +42 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/lstm_eval.h +541 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/lstm_eval_test.h +817 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/lstm_shared.h +150 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/micro_ops.h +158 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/micro_tensor_utils.h +56 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/mul.h +74 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/pad.h +27 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/pooling.h +142 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/prelu.h +39 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/quantize.h +37 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/reduce.h +65 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/reshape.h +26 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/softmax.h +67 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/strided_slice.h +40 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/sub.h +60 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/svdf.h +100 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/testdata/conv_test_data.h +37 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/testdata/lstm_test_data.h +579 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/unidirectional_sequence_lstm.h +47 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/xtensa/hifimini/fixedpoint_utils.h +139 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h +216 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/xtensa/lstm_shared.h +78 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/xtensa/xtensa.h +38 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/xtensa/xtensa_add.h +48 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/xtensa/xtensa_conv.h +89 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/xtensa/xtensa_depthwise_conv.h +74 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/xtensa/xtensa_fully_connected.h +78 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/xtensa/xtensa_pad.h +49 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/xtensa/xtensa_pooling.h +76 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/xtensa/xtensa_reduce.h +47 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/xtensa/xtensa_reshape.h +44 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/xtensa/xtensa_softmax.h +58 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/kernels/xtensa/xtensa_svdf.h +39 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/memory_helpers.h +64 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/memory_planner/greedy_memory_planner.h +170 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/memory_planner/linear_memory_planner.h +53 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/memory_planner/memory_plan_struct.h +73 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/memory_planner/micro_memory_planner.h +95 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/memory_planner/non_persistent_buffer_planner_shim.h +133 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/micro_allocation_info.h +138 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/micro_allocator.h +351 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/micro_arena_constants.h +28 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/micro_common.h +38 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/micro_context.h +176 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/micro_graph.h +79 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/micro_interpreter.h +189 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/micro_interpreter_context.h +125 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/micro_interpreter_graph.h +110 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/micro_log.h +42 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/micro_mutable_op_resolver.h +708 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/micro_op_resolver.h +62 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/micro_profiler.h +140 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/micro_profiler_interface.h +38 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/micro_resource_variable.h +89 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/micro_time.h +36 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/micro_utils.h +162 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/mock_micro_graph.h +60 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/python/interpreter/src/python_ops_resolver.h +21 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/python/tflite_size/src/flatbuffer_size.h +30 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/python/tflite_size/src/flatbuffer_size_wrapper.h +33 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/recording_micro_allocator.h +125 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/recording_micro_interpreter.h +69 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/system_setup.h +27 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/test_helper_custom_ops.h +49 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/test_helpers.h +334 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/testing/micro_test.h +267 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/testing/test_conv_model.h +23 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/tflite_bridge/flatbuffer_conversions_bridge.h +45 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/tflite_bridge/micro_error_reporter.h +36 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/tools/benchmarking/log_utils.h +273 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/tools/benchmarking/metrics.h +41 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/micro/tools/benchmarking/op_resolver.h +127 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/portable_type_to_tflitetype.h +75 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/schema/schema_generated.h +24644 -0
- xmos_ai_tools/runtime/include/tensorflow/lite/schema/schema_utils.h +33 -0
- xmos_ai_tools/runtime/include/tile_ram_server.h +38 -0
- xmos_ai_tools/runtime/lib/libhost_xtflitemicro.a +0 -0
- xmos_ai_tools/runtime/lib/libxtflitemicro.a +0 -0
- xmos_ai_tools/xformer/__init__.py +60 -0
- xmos_ai_tools/xformer/flash.py +190 -0
- xmos_ai_tools/xinterpreters/__init__.py +1 -0
- xmos_ai_tools/xinterpreters/exceptions.py +38 -0
- xmos_ai_tools/xinterpreters/host_interpreter.py +652 -0
- xmos_ai_tools/xinterpreters/libs/macos/xtflm_python.1.0.1.dylib +0 -0
- xmos_ai_tools/xinterpreters/libs/macos/xtflm_python.dylib +0 -0
- xmos_ai_tools-1.3.2.dev80.data/data/bin/xcore-opt +0 -0
- xmos_ai_tools-1.3.2.dev80.dist-info/METADATA +33 -0
- xmos_ai_tools-1.3.2.dev80.dist-info/RECORD +395 -0
- xmos_ai_tools-1.3.2.dev80.dist-info/WHEEL +5 -0
- xmos_ai_tools-1.3.2.dev80.dist-info/top_level.txt +1 -0
@@ -0,0 +1,541 @@
|
|
1
|
+
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2
|
+
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
==============================================================================*/
|
15
|
+
|
16
|
+
// Functions to perform integer evaluation for standard LSTM (e.g., defined in
|
17
|
+
// the keras lstm layer, no peephole etc.). Currently used by the 16 bits
|
18
|
+
// activation case only
|
19
|
+
|
20
|
+
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_GENERAL_H_
|
21
|
+
#define TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_GENERAL_H_
|
22
|
+
#include <algorithm>
#include <cstdint>
#include <cstring>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/lstm_shared.h"
#include "tensorflow/lite/micro/micro_log.h"
|
30
|
+
|
31
|
+
namespace tflite_micro {
|
32
|
+
|
33
|
+
// Interface to access all the TempTfLiteTensors of the LSTM kernel during the
|
34
|
+
// preparation phase. Can only be constructed through the constructor to avoid
|
35
|
+
// memory leakage. All TempTfLiteTensors will be deallocated through the
|
36
|
+
// destructor.
|
37
|
+
class LstmTensors {
|
38
|
+
public:
|
39
|
+
LstmTensors(const LstmTensors& other) = delete;
|
40
|
+
LstmTensors& operator=(const LstmTensors& other) = delete;
|
41
|
+
|
42
|
+
LstmTensors(TfLiteContext* context, TfLiteNode* node);
|
43
|
+
~LstmTensors();
|
44
|
+
|
45
|
+
// Verify the LSTM internal tensor properties (e.g., type checks)
|
46
|
+
// Input/output/states/fc weights tensors are required for kernel evaulation.
|
47
|
+
// The state tensors should be variables. Variants of the standard LSTM
|
48
|
+
// are not supported here, therefore their corresponding tensors should be
|
49
|
+
// invalid
|
50
|
+
TfLiteStatus ValidateTensorStatus(TfLiteContext* context) const;
|
51
|
+
|
52
|
+
// Internal tensors. see lstm_shared.h for tensor names
|
53
|
+
const TfLiteTensor* GetInternalTensor(const int tensor_index) const {
|
54
|
+
return internal_tensors_[tensor_index];
|
55
|
+
}
|
56
|
+
|
57
|
+
const TfLiteTensor* HiddenStateTensor() const {
|
58
|
+
return internal_tensors_[kLstmOutputStateTensor];
|
59
|
+
}
|
60
|
+
const TfLiteTensor* CellStateTensor() const {
|
61
|
+
return internal_tensors_[kLstmCellStateTensor];
|
62
|
+
}
|
63
|
+
const TfLiteTensor* OutputTensor() const { return output_tensor_; }
|
64
|
+
|
65
|
+
private:
|
66
|
+
// see lstm_shared.h for tensor names
|
67
|
+
MicroContext* micro_context_;
|
68
|
+
TfLiteTensor* internal_tensors_[24];
|
69
|
+
TfLiteTensor* output_tensor_;
|
70
|
+
};
|
71
|
+
|
72
|
+
// Deduce the size information (Batch (B), Time Steps (T), Input dimension (I),
|
73
|
+
// State dimension (S)) that defines the LSTM using the input and hidden state
|
74
|
+
// tensor
|
75
|
+
LstmSizeInfo CreateLstmSizeInfo(
|
76
|
+
const bool time_major, const TfLiteIntArray* input_tensor_shape,
|
77
|
+
const TfLiteIntArray* hidden_state_tensor_shape);
|
78
|
+
|
79
|
+
TfLiteStatus ValidateWeightTensorSize(TfLiteContext* context,
|
80
|
+
const TfLiteTensor* tensor, int dim1_size,
|
81
|
+
int dim2_size);
|
82
|
+
|
83
|
+
TfLiteStatus ValidateBiasTensorSize(TfLiteContext* context,
|
84
|
+
const TfLiteTensor* tensor, int size);
|
85
|
+
|
86
|
+
// Go through every tensors and make sure their shape match the kernel
|
87
|
+
// configuration
|
88
|
+
TfLiteStatus ValidateTensorSize(TfLiteContext* context,
|
89
|
+
const LstmTensors& tensors,
|
90
|
+
const LstmSizeInfo& size_info);
|
91
|
+
|
92
|
+
// Wrapper function to create gate parameters for the four internal LSTM gates
|
93
|
+
TfLiteStatus CreateGateParams(
|
94
|
+
TfLiteContext* context,
|
95
|
+
/*Input tensors*/
|
96
|
+
const TfLiteTensor* input, const TfLiteTensor* input_weight,
|
97
|
+
const TfLiteTensor* input_bias,
|
98
|
+
/*Hidden state tensors*/
|
99
|
+
const TfLiteTensor* hidden_state, const TfLiteTensor* hidden_state_weight,
|
100
|
+
const TfLiteTensor* hidden_state_bias,
|
101
|
+
/*Scale of the fc output (input to non-linear activation)*/
|
102
|
+
const float nonlinear_activation_input_scale, const TfLiteType cell_type,
|
103
|
+
const tflite_micro::GateParameters& gate_params);
|
104
|
+
|
105
|
+
// Create parameters for element wise multiplication that happens in a) cell
|
106
|
+
// state update ; b) hidden state update
|
107
|
+
// Note that all the output of gates are symmetrically quantized so only scales
|
108
|
+
// are required for input. However, during the hidden state update phase, the
|
109
|
+
// output is the updated hidden state, which is asymmetrically quantized. Thus
|
110
|
+
// output may require zero point
|
111
|
+
tflite_micro::ArithmeticParams CreateInterGateMulParams(const float input1_scale,
|
112
|
+
const float input2_scale,
|
113
|
+
const float output_scale,
|
114
|
+
const TfLiteType output_type,
|
115
|
+
const int output_zp = 0);
|
116
|
+
|
117
|
+
// Create the additional information about the cell state, which include:
|
118
|
+
// cell_state_scale_power: used in integer nonlinear function (e.g., tanh)
|
119
|
+
// quantized_cell_clip: quantized cell clip range
|
120
|
+
CellStateInfo CreateLstmCellStateInfo(const float cell_state_scale,
|
121
|
+
const float cell_clip);
|
122
|
+
|
123
|
+
CellStateInfo CreateLstmCellStateInfoFloat(const float cell_clip);
|
124
|
+
tflite_micro::FullyConnectedParams CreateFCParamsFloat();
|
125
|
+
|
126
|
+
tflite_micro::GateParameters CreateGateParamsFloat();
|
127
|
+
|
128
|
+
tflite_micro::ArithmeticParams CreateInterGateMulParamsFloat();
|
129
|
+
|
130
|
+
TfLiteStatus PrepareGateParametersFloat(TfLiteContext* context,
|
131
|
+
const LstmTensors& lstm_tensors,
|
132
|
+
OpDataLSTM* op_data_lstm);
|
133
|
+
|
134
|
+
TfLiteStatus PrepareGateParametersInteger(TfLiteContext* context,
|
135
|
+
const LstmTensors& lstm_tensors,
|
136
|
+
OpDataLSTM* op_data_lstm);
|
137
|
+
|
138
|
+
LSTMKernelContents CreateLSTMKernelContent(TfLiteContext* context,
|
139
|
+
TfLiteNode* node);
|
140
|
+
|
141
|
+
template <typename CellType>
|
142
|
+
LSTMBuffers<CellType> CreateLSTMBuffers(TfLiteContext* context,
|
143
|
+
const int* buffer_indices) {
|
144
|
+
LSTMBuffers<CellType> buffers;
|
145
|
+
buffers.buffer0 = reinterpret_cast<CellType*>(
|
146
|
+
context->GetScratchBuffer(context, buffer_indices[0]));
|
147
|
+
buffers.buffer1 = reinterpret_cast<CellType*>(
|
148
|
+
context->GetScratchBuffer(context, buffer_indices[1]));
|
149
|
+
buffers.buffer2 = reinterpret_cast<CellType*>(
|
150
|
+
context->GetScratchBuffer(context, buffer_indices[2]));
|
151
|
+
buffers.buffer3 = reinterpret_cast<CellType*>(
|
152
|
+
context->GetScratchBuffer(context, buffer_indices[3]));
|
153
|
+
return buffers;
|
154
|
+
}
|
155
|
+
|
156
|
+
// Since LSTM includes multiple intermediate stages, introducing the internal
|
157
|
+
// namespace to expose them for testing
|
158
|
+
namespace lstm_internal {
|
159
|
+
|
160
|
+
void Sigmoid(const RuntimeShape& data_shape, int16_t* data);
|
161
|
+
|
162
|
+
void Sigmoid(const RuntimeShape& data_shape, float* data);
|
163
|
+
|
164
|
+
void Tanh(int32_t cell_state_scale_power, const RuntimeShape& input_data_shape,
|
165
|
+
int16_t* input_data, const RuntimeShape& output_data_shape,
|
166
|
+
int16_t* output_data);
|
167
|
+
|
168
|
+
void Tanh(int32_t cell_state_scale_power, const RuntimeShape& input_data_shape,
|
169
|
+
float* input_data, const RuntimeShape& output_data_shape,
|
170
|
+
float* output_data);
|
171
|
+
|
172
|
+
void Mul(const RuntimeShape& shape, const ArithmeticParams& params,
|
173
|
+
const int16_t* input1_data, const int16_t* input2_data,
|
174
|
+
int8_t* output_data);
|
175
|
+
|
176
|
+
void Mul(const RuntimeShape& shape, const ArithmeticParams& params,
|
177
|
+
const int16_t* input1_data, const int16_t* input2_data,
|
178
|
+
int16_t* output_data);
|
179
|
+
|
180
|
+
void Mul(const RuntimeShape& shape, const ArithmeticParams& params,
|
181
|
+
const float* input1_data, const float* input2_data,
|
182
|
+
float* output_data);
|
183
|
+
|
184
|
+
void FullyConnected(const FullyConnectedParams& params,
|
185
|
+
const RuntimeShape& input_shape, const int8_t* input_data,
|
186
|
+
const RuntimeShape& filter_shape, const int8_t* filter_data,
|
187
|
+
const RuntimeShape& bias_shape, const int32_t* bias_data,
|
188
|
+
const RuntimeShape& output_shape, int16_t* output_data);
|
189
|
+
|
190
|
+
void FullyConnected(const FullyConnectedParams& params,
|
191
|
+
const RuntimeShape& input_shape, const int16_t* input_data,
|
192
|
+
const RuntimeShape& filter_shape, const int8_t* filter_data,
|
193
|
+
const RuntimeShape& bias_shape, const int64_t* bias_data,
|
194
|
+
const RuntimeShape& output_shape, int16_t* output_data);
|
195
|
+
|
196
|
+
void FullyConnected(const FullyConnectedParams& params,
|
197
|
+
const RuntimeShape& input_shape, const float* input_data,
|
198
|
+
const RuntimeShape& filter_shape, const float* filter_data,
|
199
|
+
const RuntimeShape& bias_shape, const float* bias_data,
|
200
|
+
const RuntimeShape& output_shape, float* output_data);
|
201
|
+
|
202
|
+
void AddElementWise(const int16_t* input_1, const int16_t* input_2, int n_batch,
|
203
|
+
int n_input, int16_t* output);
|
204
|
+
|
205
|
+
void AddElementWise(const float* input_1, const float* input_2, int n_batch,
|
206
|
+
int n_input, float* output);
|
207
|
+
|
208
|
+
void Clipping(const int v_size, const CellStateInfo& cell_state_info,
|
209
|
+
int16_t* vector);
|
210
|
+
|
211
|
+
void Clipping(const int v_size, const CellStateInfo& cell_state_info,
|
212
|
+
float* vector);
|
213
|
+
|
214
|
+
// Manages the slice position (offset), slice length (sliced tensor shape),
|
215
|
+
// and update rules for input/output/hidden state/cell state tensors at each
|
216
|
+
// time step.
|
217
|
+
class LstmStepManager {
|
218
|
+
public:
|
219
|
+
LstmStepManager() = delete;
|
220
|
+
// Does not take any ownership, and all pointers must refer to valid objects
|
221
|
+
// that outlive the one constructed.
|
222
|
+
explicit LstmStepManager(const LstmSizeInfo* size_info)
|
223
|
+
: size_info_(*size_info) {}
|
224
|
+
|
225
|
+
void UpdateTime();
|
226
|
+
void UpdateBatch();
|
227
|
+
|
228
|
+
void ResetTime() { current_time_ = 0; }
|
229
|
+
RuntimeShape InputShape() const;
|
230
|
+
RuntimeShape StateShape() const;
|
231
|
+
|
232
|
+
int InputOffset() const { return input_offset_; }
|
233
|
+
int OutputOffset() const { return output_offset_; }
|
234
|
+
int HiddenStateOffset() const { return hidden_state_offset_; }
|
235
|
+
int CellStateOffset() const { return cell_state_offset_; }
|
236
|
+
|
237
|
+
private:
|
238
|
+
int current_time_ = 0;
|
239
|
+
int current_batch_ = 0;
|
240
|
+
int input_offset_ = 0;
|
241
|
+
int output_offset_ = 0;
|
242
|
+
int hidden_state_offset_ = 0;
|
243
|
+
int cell_state_offset_ = 0;
|
244
|
+
// Sizeinfo is from LstmOpData, which reside in the memory arena
|
245
|
+
// (guarante to outlast LSTMStepManager, which reside in stack)
|
246
|
+
const LstmSizeInfo& size_info_;
|
247
|
+
};
|
248
|
+
|
249
|
+
// Calculates a single LSTM gate.
|
250
|
+
// Implements the following formula:
|
251
|
+
// gate = activate(FC(input) + FC(recurrent))
|
252
|
+
// Activation is sigmoid except for the "cell" gate (configurable, usually tanh)
|
253
|
+
template <typename ActivationType, typename WeightType, typename CellType,
|
254
|
+
typename BiasType>
|
255
|
+
void CalculateLstmGate(
|
256
|
+
const LstmStepManager& step_info, const GateParameters& gate_params,
|
257
|
+
// Input FC
|
258
|
+
const TfLiteEvalTensor* input, const TfLiteEvalTensor* input_weight,
|
259
|
+
const TfLiteEvalTensor* input_bias,
|
260
|
+
// Recurrent FC
|
261
|
+
const TfLiteEvalTensor* recurrent, const TfLiteEvalTensor* recurrent_weight,
|
262
|
+
const TfLiteEvalTensor* recurrent_bias,
|
263
|
+
// Output
|
264
|
+
CellType* gate_output,
|
265
|
+
// Scratch arrays
|
266
|
+
CellType* fc_output_buffer, const TfLiteFusedActivation activation) {
|
267
|
+
const auto gate_output_shape = step_info.StateShape();
|
268
|
+
// Check offset validity to avoid memory overflow
|
269
|
+
TFLITE_DCHECK_LE(step_info.InputOffset() + step_info.InputShape().FlatSize(),
|
270
|
+
tflite_micro::micro::GetTensorShape(input).FlatSize());
|
271
|
+
TFLITE_DCHECK_LE(
|
272
|
+
step_info.HiddenStateOffset() + step_info.StateShape().FlatSize(),
|
273
|
+
tflite_micro::micro::GetTensorShape(recurrent).FlatSize());
|
274
|
+
|
275
|
+
// Input FC
|
276
|
+
FullyConnected(gate_params.input_fc_params, step_info.InputShape(),
|
277
|
+
tflite_micro::micro::GetTensorData<ActivationType>(input) +
|
278
|
+
step_info.InputOffset(),
|
279
|
+
micro::GetTensorShape(input_weight),
|
280
|
+
tflite_micro::micro::GetTensorData<WeightType>(input_weight),
|
281
|
+
tflite_micro::micro::GetTensorShape(input_bias),
|
282
|
+
tflite_micro::micro::GetOptionalTensorData<BiasType>(input_bias),
|
283
|
+
gate_output_shape, gate_output);
|
284
|
+
|
285
|
+
// Recurrent FC
|
286
|
+
FullyConnected(gate_params.recurrent_fc_params, step_info.StateShape(),
|
287
|
+
tflite_micro::micro::GetTensorData<ActivationType>(recurrent) +
|
288
|
+
step_info.HiddenStateOffset(),
|
289
|
+
tflite_micro::micro::GetTensorShape(recurrent_weight),
|
290
|
+
tflite_micro::micro::GetTensorData<WeightType>(recurrent_weight),
|
291
|
+
tflite_micro::micro::GetTensorShape(recurrent_bias),
|
292
|
+
tflite_micro::micro::GetOptionalTensorData<BiasType>(recurrent_bias),
|
293
|
+
gate_output_shape, fc_output_buffer);
|
294
|
+
|
295
|
+
AddElementWise(gate_output, fc_output_buffer,
|
296
|
+
/*n_batch=*/gate_output_shape.DimsData()[0],
|
297
|
+
/*n_state=*/gate_output_shape.DimsData()[1], gate_output);
|
298
|
+
// Apply activation
|
299
|
+
switch (activation) {
|
300
|
+
case kTfLiteActSigmoid:
|
301
|
+
Sigmoid(gate_output_shape, gate_output);
|
302
|
+
break;
|
303
|
+
case kTfLiteActTanh: {
|
304
|
+
// Set the scale power to -12 to avoid shift
|
305
|
+
Tanh(/*cell_state_scale_power=*/-12, gate_output_shape, gate_output,
|
306
|
+
gate_output_shape, gate_output);
|
307
|
+
} break;
|
308
|
+
default:
|
309
|
+
// Only Sigmoid or Tanh is used.
|
310
|
+
TFLITE_ASSERT_FALSE;
|
311
|
+
}
|
312
|
+
}
|
313
|
+
|
314
|
+
// Update the cell state using the output from the forget gate, input gate, and
|
315
|
+
// cell gate Formula: updated_cell_state = forget_gate_output*cell_state +
|
316
|
+
// input_gate_output * cell_gate_output, where * denotes element wise
|
317
|
+
// multiplication
|
318
|
+
template <typename CellType>
|
319
|
+
void UpdateLstmCell(const LstmStepManager& step_info,
|
320
|
+
TfLiteEvalTensor* cell_state,
|
321
|
+
// Gate outputs
|
322
|
+
CellType* forget_gate_output,
|
323
|
+
const CellType* input_gate_output,
|
324
|
+
const CellType* cell_gate_output,
|
325
|
+
// Mul parameters
|
326
|
+
const ArithmeticParams& forget_cell_mul_params,
|
327
|
+
const ArithmeticParams& input_mul_params,
|
328
|
+
const CellStateInfo& cell_state_info, CellType* buffer) {
|
329
|
+
// Check offset validity to avoid memory overflow
|
330
|
+
TFLITE_DCHECK_LE(
|
331
|
+
step_info.CellStateOffset() + step_info.StateShape().FlatSize(),
|
332
|
+
tflite_micro::micro::GetTensorShape(cell_state).FlatSize());
|
333
|
+
|
334
|
+
auto cell_state_shape = step_info.StateShape();
|
335
|
+
// Forget Gate x Cell State
|
336
|
+
Mul(cell_state_shape, forget_cell_mul_params, forget_gate_output,
|
337
|
+
tflite_micro::micro::GetTensorData<CellType>(cell_state) +
|
338
|
+
step_info.CellStateOffset(),
|
339
|
+
tflite_micro::micro::GetTensorData<CellType>(cell_state) +
|
340
|
+
step_info.CellStateOffset());
|
341
|
+
// Input Gate x Cell Gate
|
342
|
+
Mul(cell_state_shape, input_mul_params, input_gate_output, cell_gate_output,
|
343
|
+
buffer);
|
344
|
+
|
345
|
+
// Update the cell state
|
346
|
+
AddElementWise(tflite_micro::micro::GetTensorData<CellType>(cell_state) +
|
347
|
+
step_info.CellStateOffset(),
|
348
|
+
buffer,
|
349
|
+
/*n_batch=*/cell_state_shape.DimsData()[0],
|
350
|
+
/*n_state=*/cell_state_shape.DimsData()[1],
|
351
|
+
tflite_micro::micro::GetTensorData<CellType>(cell_state) +
|
352
|
+
step_info.CellStateOffset());
|
353
|
+
|
354
|
+
if (cell_state_info.cell_clip > 0) {
|
355
|
+
Clipping(cell_state_shape.FlatSize(), cell_state_info,
|
356
|
+
tflite_micro::micro::GetTensorData<CellType>(cell_state) +
|
357
|
+
step_info.CellStateOffset());
|
358
|
+
}
|
359
|
+
}
|
360
|
+
|
361
|
+
// Update the hidden state of the LSTM kernel using the following formula:
|
362
|
+
// updated_hidden_state = Tanh(updated_cell_state) * output_gate_output, * means
|
363
|
+
// element wise multiplication
|
364
|
+
template <typename CellType, typename ActivationType>
|
365
|
+
void UpdateLstmHidden(const LstmStepManager& step_info,
|
366
|
+
TfLiteEvalTensor* cell_state,
|
367
|
+
TfLiteEvalTensor* hidden_state,
|
368
|
+
const CellType* output_gate_output,
|
369
|
+
const ArithmeticParams& mul_params,
|
370
|
+
int32_t cell_state_scale_power, CellType* buffer) {
|
371
|
+
// Check offset validity to avoid memory overflow
|
372
|
+
TFLITE_DCHECK_LE(
|
373
|
+
step_info.CellStateOffset() + step_info.StateShape().FlatSize(),
|
374
|
+
tflite_micro::micro::GetTensorShape(cell_state).FlatSize());
|
375
|
+
TFLITE_DCHECK_LE(
|
376
|
+
step_info.HiddenStateOffset() + step_info.StateShape().FlatSize(),
|
377
|
+
tflite_micro::micro::GetTensorShape(hidden_state).FlatSize());
|
378
|
+
|
379
|
+
auto cell_state_shape = step_info.StateShape();
|
380
|
+
CellType* cell_state_data =
|
381
|
+
tflite_micro::micro::GetTensorData<CellType>(cell_state) +
|
382
|
+
step_info.CellStateOffset();
|
383
|
+
// Tanh(cell_state)
|
384
|
+
Tanh(cell_state_scale_power, cell_state_shape, cell_state_data,
|
385
|
+
cell_state_shape, buffer);
|
386
|
+
// Update the hidden state
|
387
|
+
Mul(cell_state_shape, mul_params, buffer, output_gate_output,
|
388
|
+
tflite_micro::micro::GetTensorData<ActivationType>(hidden_state) +
|
389
|
+
step_info.HiddenStateOffset());
|
390
|
+
}
|
391
|
+
|
392
|
+
template <typename ActivationType, typename WeightType, typename CellType,
|
393
|
+
typename BiasType>
|
394
|
+
void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
|
395
|
+
LSTMKernelContents& kernel_content,
|
396
|
+
const LSTMBuffers<CellType>& buffers) {
|
397
|
+
/*Step1: Calculate gate outputs to prepare cell state update*/
|
398
|
+
CellType* gate_internal_buffer = buffers.buffer3;
|
399
|
+
CellType* forget_gate_output = buffers.buffer0;
|
400
|
+
CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
|
401
|
+
step_info, op_data.forget_gate_parameters,
|
402
|
+
// Input FC
|
403
|
+
kernel_content.GetInternalTensor(tflite_micro::kLstmInputTensor),
|
404
|
+
kernel_content.GetInternalTensor(tflite_micro::kLstmInputToForgetWeightsTensor),
|
405
|
+
kernel_content.GetInternalTensor(tflite_micro::kLstmForgetGateBiasTensor),
|
406
|
+
// Recurrent FC
|
407
|
+
kernel_content.HiddenStateTensor(),
|
408
|
+
kernel_content.GetInternalTensor(
|
409
|
+
tflite_micro::kLstmRecurrentToForgetWeightsTensor),
|
410
|
+
/*recurrent_bias*/ nullptr,
|
411
|
+
// Output
|
412
|
+
forget_gate_output,
|
413
|
+
// Scratch arrays
|
414
|
+
gate_internal_buffer, kTfLiteActSigmoid);
|
415
|
+
|
416
|
+
// Input Gate calculation;
|
417
|
+
CellType* input_gate_output = buffers.buffer1;
|
418
|
+
CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
|
419
|
+
step_info, op_data.input_gate_parameters,
|
420
|
+
// Input FC
|
421
|
+
kernel_content.GetInternalTensor(tflite_micro::kLstmInputTensor),
|
422
|
+
kernel_content.GetInternalTensor(tflite_micro::kLstmInputToInputWeightsTensor),
|
423
|
+
kernel_content.GetInternalTensor(tflite_micro::kLstmInputGateBiasTensor),
|
424
|
+
// Recurrent FC
|
425
|
+
kernel_content.HiddenStateTensor(),
|
426
|
+
kernel_content.GetInternalTensor(
|
427
|
+
tflite_micro::kLstmRecurrentToInputWeightsTensor),
|
428
|
+
/*recurrent_bias*/ nullptr,
|
429
|
+
// Output
|
430
|
+
input_gate_output,
|
431
|
+
// Scratch arrays
|
432
|
+
gate_internal_buffer, kTfLiteActSigmoid);
|
433
|
+
|
434
|
+
// Cell Gate calculation
|
435
|
+
CellType* cell_gate_output = buffers.buffer2;
|
436
|
+
CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
|
437
|
+
step_info, op_data.cell_gate_parameters,
|
438
|
+
// Input FC
|
439
|
+
kernel_content.GetInternalTensor(tflite_micro::kLstmInputTensor),
|
440
|
+
kernel_content.GetInternalTensor(tflite_micro::kLstmInputToCellWeightsTensor),
|
441
|
+
kernel_content.GetInternalTensor(tflite_micro::kLstmCellGateBiasTensor),
|
442
|
+
// Recurrent FC
|
443
|
+
kernel_content.HiddenStateTensor(),
|
444
|
+
kernel_content.GetInternalTensor(
|
445
|
+
tflite_micro::kLstmRecurrentToCellWeightsTensor),
|
446
|
+
/*recurrent_bias*/ nullptr,
|
447
|
+
// Output
|
448
|
+
cell_gate_output,
|
449
|
+
// Scratch arrays
|
450
|
+
gate_internal_buffer, op_data.cell_gate_nonlinear_type);
|
451
|
+
|
452
|
+
/*Step2: update the cell state */
|
453
|
+
const InterGateParameters& inter_gate_params = op_data.inter_gate_parameters;
|
454
|
+
CellType* updated_input_buffer = buffers.buffer1; // reuse buffer
|
455
|
+
|
456
|
+
UpdateLstmCell<CellType>(step_info, kernel_content.CellStateTensor(),
|
457
|
+
forget_gate_output, input_gate_output,
|
458
|
+
cell_gate_output,
|
459
|
+
inter_gate_params.forget_cell_mul_params,
|
460
|
+
inter_gate_params.input_mul_params,
|
461
|
+
op_data.cell_state_info, updated_input_buffer);
|
462
|
+
|
463
|
+
/*Step3: update the hidden state */
|
464
|
+
CellType* output_gate_output = buffers.buffer1; // reuse buffer
|
465
|
+
CalculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
|
466
|
+
step_info, op_data.output_gate_parameters,
|
467
|
+
// Input FC
|
468
|
+
kernel_content.GetInternalTensor(tflite_micro::kLstmInputTensor),
|
469
|
+
kernel_content.GetInternalTensor(tflite_micro::kLstmInputToOutputWeightsTensor),
|
470
|
+
kernel_content.GetInternalTensor(tflite_micro::kLstmOutputGateBiasTensor),
|
471
|
+
// Recurrent FC
|
472
|
+
kernel_content.HiddenStateTensor(),
|
473
|
+
kernel_content.GetInternalTensor(
|
474
|
+
tflite_micro::kLstmRecurrentToOutputWeightsTensor),
|
475
|
+
/*recurrent_bias*/ nullptr,
|
476
|
+
// Output
|
477
|
+
output_gate_output,
|
478
|
+
// Scratch arrays
|
479
|
+
gate_internal_buffer, kTfLiteActSigmoid);
|
480
|
+
|
481
|
+
CellType* tanh_activated_cell_buffer = buffers.buffer0; // reuse buffer
|
482
|
+
tflite_micro::lstm_internal::UpdateLstmHidden<CellType, ActivationType>(
|
483
|
+
step_info, kernel_content.CellStateTensor(),
|
484
|
+
kernel_content.HiddenStateTensor(), output_gate_output,
|
485
|
+
inter_gate_params.output_mul_params,
|
486
|
+
op_data.cell_state_info.cell_state_scale_power,
|
487
|
+
tanh_activated_cell_buffer);
|
488
|
+
|
489
|
+
/*Step4: copy the update the hidden state to output*/
|
490
|
+
// Check offset validity to avoid memory overflow
|
491
|
+
TFLITE_DCHECK_LE(
|
492
|
+
step_info.OutputOffset() + step_info.StateShape().FlatSize(),
|
493
|
+
tflite_micro::micro::GetTensorShape(kernel_content.output_tensor).FlatSize());
|
494
|
+
// record the output (from the updated hidden state)
|
495
|
+
ActivationType* output_ptr = tflite_micro::micro::GetTensorData<ActivationType>(
|
496
|
+
kernel_content.output_tensor);
|
497
|
+
const auto* hidden_state = kernel_content.HiddenStateTensor();
|
498
|
+
std::memcpy(output_ptr + step_info.OutputOffset(),
|
499
|
+
tflite_micro::micro::GetTensorData<ActivationType>(hidden_state) +
|
500
|
+
step_info.HiddenStateOffset(),
|
501
|
+
step_info.StateShape().FlatSize() * sizeof(ActivationType));
|
502
|
+
}
|
503
|
+
|
504
|
+
} // namespace lstm_internal
|
505
|
+
|
506
|
+
// Evaulate the LSTM kernel with (potential) multi-steps and multi-batch input
|
507
|
+
// Since
|
508
|
+
template <typename ActivationType, typename WeightType, typename CellType,
|
509
|
+
typename BiasType>
|
510
|
+
TfLiteStatus EvalLstm(const OpDataLSTM& op_data,
|
511
|
+
LSTMKernelContents& kernel_content,
|
512
|
+
const LSTMBuffers<CellType>& buffers) {
|
513
|
+
lstm_internal::LstmStepManager step_info(&op_data.size_info);
|
514
|
+
const auto& size_info = op_data.size_info;
|
515
|
+
// time is the first dimention, enable batch computation
|
516
|
+
if (size_info.time_major) {
|
517
|
+
for (int t = 0; t < size_info.time_steps; t++) {
|
518
|
+
lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
|
519
|
+
step_info, op_data, kernel_content, buffers);
|
520
|
+
// prepare for the next time step
|
521
|
+
step_info.UpdateTime();
|
522
|
+
}
|
523
|
+
} else {
|
524
|
+
// batch first, unable to size the input data. single batch inference
|
525
|
+
for (int b = 0; b < size_info.batch_size; b++) {
|
526
|
+
for (int t = 0; t < size_info.time_steps; t++) {
|
527
|
+
lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
|
528
|
+
step_info, op_data, kernel_content, buffers);
|
529
|
+
// prepare for the next time step
|
530
|
+
step_info.UpdateTime();
|
531
|
+
}
|
532
|
+
// prepare for the next batch
|
533
|
+
step_info.UpdateBatch();
|
534
|
+
step_info.ResetTime();
|
535
|
+
}
|
536
|
+
}
|
537
|
+
return kTfLiteOk;
|
538
|
+
}
|
539
|
+
} // namespace tflite_micro
|
540
|
+
|
541
|
+
#endif  // TENSORFLOW_LITE_MICRO_KERNELS_LSTM_EVAL_GENERAL_H_
|