react-native-executorch 0.9.0 → 0.9.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106) hide show
  1. package/android/libs/classes.jar +0 -0
  2. package/common/rnexecutorch/host_objects/JsiConversions.h +43 -0
  3. package/common/rnexecutorch/models/llm/LLM.cpp +55 -42
  4. package/common/rnexecutorch/models/llm/LLM.h +4 -3
  5. package/common/rnexecutorch/models/llm/Types.h +23 -0
  6. package/common/runner/base_llm_runner.cpp +10 -3
  7. package/common/runner/base_llm_runner.h +1 -0
  8. package/common/runner/constants.h +15 -1
  9. package/common/runner/encoders/audio_encoder.cpp +111 -0
  10. package/common/runner/encoders/audio_encoder.h +40 -0
  11. package/common/runner/encoders/vision_encoder.cpp +13 -5
  12. package/common/runner/encoders/vision_encoder.h +15 -2
  13. package/common/runner/irunner.h +5 -0
  14. package/common/runner/multimodal_decoder_runner.h +50 -1
  15. package/common/runner/multimodal_input.h +16 -1
  16. package/common/runner/multimodal_prefiller.cpp +374 -64
  17. package/common/runner/multimodal_prefiller.h +57 -6
  18. package/common/runner/multimodal_runner.cpp +19 -12
  19. package/common/runner/multimodal_runner.h +1 -1
  20. package/common/runner/sampler.cpp +126 -39
  21. package/common/runner/sampler.h +13 -5
  22. package/common/runner/text_decoder_runner.cpp +1 -4
  23. package/common/runner/text_decoder_runner.h +3 -2
  24. package/common/runner/text_prefiller.cpp +8 -8
  25. package/common/runner/text_prefiller.h +8 -1
  26. package/common/runner/text_runner.cpp +35 -9
  27. package/common/runner/text_token_generator.h +2 -3
  28. package/common/runner/util.h +0 -1
  29. package/lib/module/constants/llmDefaults.js +1 -1
  30. package/lib/module/constants/llmDefaults.js.map +1 -1
  31. package/lib/module/constants/modelRegistry.js +62 -3
  32. package/lib/module/constants/modelRegistry.js.map +1 -1
  33. package/lib/module/constants/modelUrls.js +62 -6
  34. package/lib/module/constants/modelUrls.js.map +1 -1
  35. package/lib/module/controllers/LLMController.js +69 -20
  36. package/lib/module/controllers/LLMController.js.map +1 -1
  37. package/lib/module/hooks/natural_language_processing/useLLM.js +1 -5
  38. package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
  39. package/lib/module/modules/computer_vision/PoseEstimationModule.js +13 -1
  40. package/lib/module/modules/computer_vision/PoseEstimationModule.js.map +1 -1
  41. package/lib/module/modules/natural_language_processing/LLMModule.js +12 -7
  42. package/lib/module/modules/natural_language_processing/LLMModule.js.map +1 -1
  43. package/lib/module/types/llm.js +11 -0
  44. package/lib/module/types/llm.js.map +1 -1
  45. package/lib/module/types/poseEstimation.js.map +1 -1
  46. package/lib/typescript/constants/llmDefaults.d.ts +1 -1
  47. package/lib/typescript/constants/llmDefaults.d.ts.map +1 -1
  48. package/lib/typescript/constants/modelRegistry.d.ts +38 -1
  49. package/lib/typescript/constants/modelRegistry.d.ts.map +1 -1
  50. package/lib/typescript/constants/modelUrls.d.ts +52 -12
  51. package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
  52. package/lib/typescript/controllers/LLMController.d.ts +7 -9
  53. package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
  54. package/lib/typescript/modules/computer_vision/PoseEstimationModule.d.ts +6 -0
  55. package/lib/typescript/modules/computer_vision/PoseEstimationModule.d.ts.map +1 -1
  56. package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts +6 -3
  57. package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts.map +1 -1
  58. package/lib/typescript/types/llm.d.ts +63 -36
  59. package/lib/typescript/types/llm.d.ts.map +1 -1
  60. package/lib/typescript/types/poseEstimation.d.ts +3 -0
  61. package/lib/typescript/types/poseEstimation.d.ts.map +1 -1
  62. package/package.json +1 -1
  63. package/react-native-executorch.podspec +6 -0
  64. package/src/constants/llmDefaults.ts +1 -1
  65. package/src/constants/modelRegistry.ts +62 -2
  66. package/src/constants/modelUrls.ts +69 -6
  67. package/src/controllers/LLMController.ts +89 -40
  68. package/src/hooks/natural_language_processing/useLLM.ts +5 -6
  69. package/src/modules/computer_vision/PoseEstimationModule.ts +12 -0
  70. package/src/modules/natural_language_processing/LLMModule.ts +19 -8
  71. package/src/types/llm.ts +64 -34
  72. package/src/types/poseEstimation.ts +10 -4
  73. package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
  74. package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
  75. package/third-party/include/executorch/ExecuTorch.h +2 -0
  76. package/third-party/include/executorch/ExecuTorchModule.h +46 -0
  77. package/third-party/include/executorch/extension/data_loader/buffer_data_loader.h +4 -3
  78. package/third-party/include/executorch/extension/data_loader/mman.h +46 -0
  79. package/third-party/include/executorch/extension/data_loader/mmap_data_loader.h +4 -0
  80. package/third-party/include/executorch/extension/data_loader/shared_ptr_data_loader.h +7 -3
  81. package/third-party/include/executorch/extension/module/module.h +47 -8
  82. package/third-party/include/executorch/extension/tensor/tensor_ptr.h +17 -5
  83. package/third-party/include/executorch/kernels/optimized/Functions.h +12 -0
  84. package/third-party/include/executorch/kernels/optimized/NativeFunctions.h +4 -0
  85. package/third-party/include/executorch/kernels/portable/Functions.h +18 -0
  86. package/third-party/include/executorch/kernels/portable/NativeFunctions.h +6 -0
  87. package/third-party/include/executorch/runtime/backend/backend_options_map.h +37 -0
  88. package/third-party/include/executorch/runtime/core/array_ref.h +3 -1
  89. package/third-party/include/executorch/runtime/core/error.h +1 -0
  90. package/third-party/include/executorch/runtime/core/evalue.h +256 -9
  91. package/third-party/include/executorch/runtime/core/exec_aten/exec_aten.h +24 -0
  92. package/third-party/include/executorch/runtime/core/hierarchical_allocator.h +9 -6
  93. package/third-party/include/executorch/runtime/core/portable_type/device.h +3 -4
  94. package/third-party/include/executorch/runtime/core/portable_type/tensor_impl.h +31 -1
  95. package/third-party/include/executorch/runtime/executor/method.h +9 -3
  96. package/third-party/include/executorch/runtime/executor/method_meta.h +14 -0
  97. package/third-party/include/executorch/runtime/executor/platform_memory_allocator.h +12 -2
  98. package/third-party/include/executorch/runtime/executor/program.h +3 -1
  99. package/third-party/include/executorch/runtime/executor/tensor_parser.h +5 -1
  100. package/third-party/include/executorch/runtime/kernel/operator_registry.h +9 -0
  101. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
  102. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
  103. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/mlx.metallib +0 -0
  104. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
  105. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
  106. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/mlx.metallib +0 -0
@@ -38,7 +38,7 @@ import { RnExecutorchErrorCode } from '../errors/ErrorCodes';
38
38
  * compile-time error.
39
39
  * @category Utils
40
40
  */
41
- export type Backend = 'xnnpack' | 'coreml' | 'vulkan' | 'qnn';
41
+ export type Backend = 'xnnpack' | 'coreml' | 'vulkan' | 'qnn' | 'mlx';
42
42
 
43
43
  /**
44
44
  * Options for a `models` accessor call.
@@ -78,7 +78,7 @@ type ConfigOf<V> = Extract<
78
78
  >;
79
79
  type BackendsOf<V> = Extract<keyof V, Backend>;
80
80
 
81
- const BACKEND_ORDER: Backend[] = ['xnnpack', 'coreml', 'vulkan', 'qnn'];
81
+ const BACKEND_ORDER: Backend[] = ['xnnpack', 'coreml', 'mlx', 'vulkan', 'qnn'];
82
82
 
83
83
  function firstBackend(variants: AnyVariantMap): Backend {
84
84
  for (const b of BACKEND_ORDER) {
@@ -181,6 +181,33 @@ function tts<C extends TextToSpeechModelConfig>(c: C): () => C {
181
181
  // Per-backend variant maps for models that ship more than one backend.
182
182
  // ─────────────────────────────────────────────────────────────────────────────
183
183
 
184
+ const GEMMA4_E2B_VARIANTS = {
185
+ mlx: {
186
+ base: {
187
+ modelName: 'gemma4-e2b' as const,
188
+ modelSource: M.GEMMA4_E2B_MLX_MODEL,
189
+ tokenizerSource: M.GEMMA4_E2B_TOKENIZER,
190
+ tokenizerConfigSource: M.GEMMA4_E2B_TOKENIZER_CONFIG,
191
+ },
192
+ },
193
+ xnnpack: {
194
+ base: {
195
+ modelName: 'gemma4-e2b' as const,
196
+ modelSource: M.GEMMA4_E2B_XNNPACK_MODEL,
197
+ tokenizerSource: M.GEMMA4_E2B_TOKENIZER,
198
+ tokenizerConfigSource: M.GEMMA4_E2B_TOKENIZER_CONFIG,
199
+ },
200
+ },
201
+ vulkan: {
202
+ base: {
203
+ modelName: 'gemma4-e2b' as const,
204
+ modelSource: M.GEMMA4_E2B_VULKAN_MODEL,
205
+ tokenizerSource: M.GEMMA4_E2B_TOKENIZER,
206
+ tokenizerConfigSource: M.GEMMA4_E2B_TOKENIZER_CONFIG,
207
+ },
208
+ },
209
+ };
210
+
184
211
  const EFFICIENTNET_V2_S_VARIANTS = {
185
212
  xnnpack: {
186
213
  base: {
@@ -249,6 +276,31 @@ const RF_DETR_NANO_SEG_VARIANTS = {
249
276
  },
250
277
  };
251
278
 
279
+ // RF-DETR Keypoint (pose estimation) — BETA preview.
280
+ // All three backends ship fp32
281
+ // (non-quantized); this entry may be re-exported under a different constant
282
+ // once more RF-DETR keypoint weights are released.
283
+ const RF_DETR_KEYPOINT_PREVIEW_VARIANTS = {
284
+ xnnpack: {
285
+ base: {
286
+ modelName: 'rfdetr-keypoint-preview' as const,
287
+ modelSource: M.RF_DETR_KEYPOINT_PREVIEW_XNNPACK_FP32_MODEL,
288
+ },
289
+ },
290
+ coreml: {
291
+ base: {
292
+ modelName: 'rfdetr-keypoint-preview' as const,
293
+ modelSource: M.RF_DETR_KEYPOINT_PREVIEW_COREML_FP32_MODEL,
294
+ },
295
+ },
296
+ mlx: {
297
+ base: {
298
+ modelName: 'rfdetr-keypoint-preview' as const,
299
+ modelSource: M.RF_DETR_KEYPOINT_PREVIEW_MLX_FP32_MODEL,
300
+ },
301
+ },
302
+ };
303
+
252
304
  const FASTSAM_S_VARIANTS = {
253
305
  xnnpack: {
254
306
  base: {
@@ -496,10 +548,15 @@ export const models = {
496
548
  M.LFM2_5_1_2B_INSTRUCT_QUANTIZED
497
549
  ),
498
550
  bielik_v3_0_1_5b: pair(M.BIELIK_V3_0_1_5B, M.BIELIK_V3_0_1_5B_QUANTIZED),
551
+ gemma4_e2b: variant(GEMMA4_E2B_VARIANTS, {
552
+ ios: 'mlx',
553
+ android: 'vulkan',
554
+ }),
499
555
  // Multimodal LLMs — same hook/module as plain LLMs, listed here so users
500
556
  // pick a model by capability ("LLM") rather than by modality.
501
557
  lfm2_5_vl_1_6b: base(M.LFM2_5_VL_1_6B_QUANTIZED),
502
558
  lfm2_5_vl_450m: base(M.LFM2_5_VL_450M_QUANTIZED),
559
+ gemma4_e2b_multimodal: base(M.GEMMA4_E2B_MM),
503
560
  },
504
561
  classification: {
505
562
  efficientnet_v2_s: variant(EFFICIENTNET_V2_S_VARIANTS),
@@ -521,6 +578,9 @@ export const models = {
521
578
  },
522
579
  pose_estimation: {
523
580
  yolo26n: base(M.YOLO26N_POSE),
581
+ // BETA preview — may be re-exported under a different constant once a
582
+ // stable RF-DETR keypoint model ships.
583
+ rfdetr_keypoint_preview: variant(RF_DETR_KEYPOINT_PREVIEW_VARIANTS),
524
584
  },
525
585
  semantic_segmentation: {
526
586
  deeplab_v3_resnet50: pair(
@@ -125,6 +125,47 @@ export const QWEN3_0_6B_QUANTIZED = {
125
125
  generationConfig: QWEN3_GENERATION_CONFIG,
126
126
  } as const;
127
127
 
128
+ // GEMMA 4 — separate HF repo; tokenizer files live at the e2b root and are
129
+ // shared by all backend variants.
130
+ const GEMMA4_E2B_PREFIX = `${URL_PREFIX}-gemma-4/${VERSION_TAG}/e2b`;
131
+ export const GEMMA4_E2B_MLX_MODEL = `${GEMMA4_E2B_PREFIX}/mlx/gemma4_e2b_mlx_int4.pte`;
132
+ export const GEMMA4_E2B_XNNPACK_MODEL = `${GEMMA4_E2B_PREFIX}/xnnpack/gemma_4_e2b_xnnpack_8da4w.pte`;
133
+ export const GEMMA4_E2B_VULKAN_MODEL = `${GEMMA4_E2B_PREFIX}/vulkan/gemma_4_e2b_vulkan_8da4w.pte`;
134
+ export const GEMMA4_E2B_TOKENIZER = `${GEMMA4_E2B_PREFIX}/tokenizer.json`;
135
+ export const GEMMA4_E2B_TOKENIZER_CONFIG = `${GEMMA4_E2B_PREFIX}/tokenizer_config.json`;
136
+
137
+ const GEMMA4_E2B_MODEL =
138
+ Platform.OS === `android` ? GEMMA4_E2B_VULKAN_MODEL : GEMMA4_E2B_MLX_MODEL;
139
+
140
+ const GEMMA4_E2B_MLX_MM = `${URL_PREFIX}-gemma-4-multimodal/${VERSION_TAG}/e2b/mlx/gemma4_e2b_mlx_int4.pte`;
141
+ const GEMMA4_E2B_VULKAN_MM = `${URL_PREFIX}-gemma-4-multimodal/${VERSION_TAG}/e2b/vulkan/gemma_4_e2b_vulkan_8da4w.pte`;
142
+
143
+ /**
144
+ * @category Models - LLM
145
+ */
146
+ export const GEMMA4_E2B = {
147
+ modelName: 'gemma4-e2b',
148
+ modelSource: GEMMA4_E2B_MODEL,
149
+ tokenizerSource: GEMMA4_E2B_TOKENIZER,
150
+ tokenizerConfigSource: GEMMA4_E2B_TOKENIZER_CONFIG,
151
+ } as const;
152
+
153
+ /**
154
+ * @category Models - LLM Multimodal
155
+ */
156
+ export const GEMMA4_E2B_MM = {
157
+ modelName: 'gemma4-e2b-multimodal',
158
+ modelSource:
159
+ Platform.OS === `android` ? GEMMA4_E2B_VULKAN_MM : GEMMA4_E2B_MLX_MM,
160
+ tokenizerSource: GEMMA4_E2B_TOKENIZER,
161
+ tokenizerConfigSource: GEMMA4_E2B_TOKENIZER_CONFIG,
162
+ capabilities: ['vision', 'audio'],
163
+ audioConfig: {
164
+ samplesPerBlock: 7680,
165
+ tokensPerBlock: 12,
166
+ },
167
+ } as const;
168
+
128
169
  /**
129
170
  * @category Models - LLM
130
171
  */
@@ -690,6 +731,28 @@ export const YOLO26N_POSE = {
690
731
  modelSource: YOLO26N_POSE_MODEL,
691
732
  } as const;
692
733
 
734
+ // RF-DETR Keypoint (pose estimation) — BETA preview.
735
+ // NOTE: served from the `preview/` path under VERSION_TAG (shipping as
736
+ // part of a patch release). This export is a preview and may be re-exported
737
+ // under a different constant once a stable version ships.
738
+ export const RF_DETR_KEYPOINT_PREVIEW_XNNPACK_FP32_MODEL = `${URL_PREFIX}-rfdetr-keypoint/${VERSION_TAG}/preview/xnnpack/rfdetr_keypoint_preview_xnnpack_fp32.pte`;
739
+ export const RF_DETR_KEYPOINT_PREVIEW_COREML_FP32_MODEL = `${URL_PREFIX}-rfdetr-keypoint/${VERSION_TAG}/preview/coreml/rfdetr_keypoint_preview_coreml_fp32.pte`;
740
+ export const RF_DETR_KEYPOINT_PREVIEW_MLX_FP32_MODEL = `${URL_PREFIX}-rfdetr-keypoint/${VERSION_TAG}/preview/mlx/rfdetr_keypoint_preview_mlx_fp32.pte`;
741
+ const RF_DETR_KEYPOINT_PREVIEW_MODEL =
742
+ Platform.OS === 'ios'
743
+ ? RF_DETR_KEYPOINT_PREVIEW_COREML_FP32_MODEL
744
+ : RF_DETR_KEYPOINT_PREVIEW_XNNPACK_FP32_MODEL;
745
+
746
+ /**
747
+ * @category Models - Pose Estimation
748
+ * @beta Preview export — may be re-exported under a different constant once a
749
+ * stable RF-DETR keypoint model ships.
750
+ */
751
+ export const RF_DETR_KEYPOINT_PREVIEW = {
752
+ modelName: 'rfdetr-keypoint-preview',
753
+ modelSource: RF_DETR_KEYPOINT_PREVIEW_MODEL,
754
+ } as const;
755
+
693
756
  // Style transfer
694
757
  /**
695
758
  * Builds the four `(backend, precision)` URLs for a single style-transfer style.
@@ -816,27 +879,27 @@ export const STYLE_TRANSFER_UDNIE_QUANTIZED = {
816
879
  // S2T
817
880
  export const WHISPER_TINY_EN_TOKENIZER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/tokenizer.json`;
818
881
  export const WHISPER_TINY_EN_MODEL_XNNPACK = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/xnnpack/whisper_tiny_en_xnnpack_fp32.pte`;
819
- export const WHISPER_TINY_EN_MODEL_COREML = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/coreml/whisper_tiny_en_coreml_fp32.pte`;
882
+ export const WHISPER_TINY_EN_MODEL_COREML = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/coreml/whisper_tiny_en_coreml_fp16.pte`;
820
883
 
821
884
  export const WHISPER_BASE_EN_TOKENIZER = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/tokenizer.json`;
822
885
  export const WHISPER_BASE_EN_MODEL_XNNPACK = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/xnnpack/whisper_base_en_xnnpack_fp32.pte`;
823
- export const WHISPER_BASE_EN_MODEL_COREML = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/coreml/whisper_base_en_coreml_fp32.pte`;
886
+ export const WHISPER_BASE_EN_MODEL_COREML = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/coreml/whisper_base_en_coreml_fp16.pte`;
824
887
 
825
888
  export const WHISPER_SMALL_EN_TOKENIZER = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/tokenizer.json`;
826
889
  export const WHISPER_SMALL_EN_MODEL_XNNPACK = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/xnnpack/whisper_small_en_xnnpack_fp32.pte`;
827
- export const WHISPER_SMALL_EN_MODEL_COREML = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/coreml/whisper_small_en_coreml_fp32.pte`;
890
+ export const WHISPER_SMALL_EN_MODEL_COREML = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/coreml/whisper_small_en_coreml_fp16.pte`;
828
891
 
829
892
  export const WHISPER_TINY_TOKENIZER = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/tokenizer.json`;
830
893
  export const WHISPER_TINY_MODEL_XNNPACK = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/xnnpack/whisper_tiny_xnnpack_fp32.pte`;
831
- export const WHISPER_TINY_MODEL_COREML = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/coreml/whisper_tiny_coreml_fp32.pte`;
894
+ export const WHISPER_TINY_MODEL_COREML = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/coreml/whisper_tiny_coreml_fp16.pte`;
832
895
 
833
896
  export const WHISPER_BASE_TOKENIZER = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/tokenizer.json`;
834
897
  export const WHISPER_BASE_MODEL_XNNPACK = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/xnnpack/whisper_base_xnnpack_fp32.pte`;
835
- export const WHISPER_BASE_MODEL_COREML = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/coreml/whisper_base_coreml_fp32.pte`;
898
+ export const WHISPER_BASE_MODEL_COREML = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/coreml/whisper_base_coreml_fp16.pte`;
836
899
 
837
900
  export const WHISPER_SMALL_TOKENIZER = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/tokenizer.json`;
838
901
  export const WHISPER_SMALL_MODEL_XNNPACK = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/xnnpack/whisper_small_xnnpack_fp32.pte`;
839
- export const WHISPER_SMALL_MODEL_COREML = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/coreml/whisper_small_coreml_fp32.pte`;
902
+ export const WHISPER_SMALL_MODEL_COREML = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/coreml/whisper_small_coreml_fp16.pte`;
840
903
 
841
904
  /**
842
905
  * @category Models - Speech To Text
@@ -1,11 +1,11 @@
1
- import { ResourceSource } from '../types/common';
2
1
  import { ResourceFetcher } from '../utils/ResourceFetcher';
3
2
  import { Template } from '@huggingface/jinja';
4
3
  import { DEFAULT_CHAT_CONFIG } from '../constants/llmDefaults';
5
4
  import {
5
+ AudioConfig,
6
6
  ChatConfig,
7
7
  GenerationConfig,
8
- LLMCapability,
8
+ LLMModel,
9
9
  LLMTool,
10
10
  Message,
11
11
  SPECIAL_TOKENS,
@@ -30,6 +30,7 @@ export class LLMController {
30
30
  private messageHistoryCallback: (messageHistory: Message[]) => void;
31
31
  private isReadyCallback: (isReady: boolean) => void;
32
32
  private isGeneratingCallback: (isGenerating: boolean) => void;
33
+ private audioConfig: AudioConfig | undefined;
33
34
 
34
35
  constructor({
35
36
  tokenCallback,
@@ -72,18 +73,10 @@ export class LLMController {
72
73
  }
73
74
 
74
75
  public async load({
75
- modelSource,
76
- tokenizerSource,
77
- tokenizerConfigSource,
78
- capabilities,
79
- defaultGenerationConfig,
76
+ model,
80
77
  onDownloadProgressCallback,
81
78
  }: {
82
- modelSource: ResourceSource;
83
- tokenizerSource: ResourceSource;
84
- tokenizerConfigSource: ResourceSource;
85
- capabilities?: readonly LLMCapability[];
86
- defaultGenerationConfig?: GenerationConfig;
79
+ model: LLMModel;
87
80
  onDownloadProgressCallback?: (downloadProgress: number) => void;
88
81
  }) {
89
82
  // reset inner state when loading new model
@@ -94,13 +87,13 @@ export class LLMController {
94
87
  try {
95
88
  const tokenizersPromise = ResourceFetcher.fetch(
96
89
  undefined,
97
- tokenizerSource,
98
- tokenizerConfigSource
90
+ model.tokenizerSource,
91
+ model.tokenizerConfigSource
99
92
  );
100
93
 
101
94
  const modelPromise = ResourceFetcher.fetch(
102
95
  onDownloadProgressCallback,
103
- modelSource
96
+ model.modelSource
104
97
  );
105
98
 
106
99
  const [tokenizersResults, modelResult] = await Promise.all([
@@ -124,16 +117,18 @@ export class LLMController {
124
117
  this.nativeModule.unload();
125
118
  }
126
119
 
120
+ this.audioConfig = model.audioConfig;
121
+
127
122
  this.nativeModule = await global.loadLLM(
128
123
  modelPath,
129
124
  tokenizerPath,
130
- capabilities ?? []
125
+ model.capabilities ?? []
131
126
  );
132
- if (defaultGenerationConfig) {
127
+ if (model.generationConfig) {
133
128
  // Apply model-specific recommended sampling defaults before flipping
134
129
  // isReady so callers that react to it see the right config on first
135
130
  // send. User-provided `configure()` calls still override these.
136
- this.applyGenerationConfig(defaultGenerationConfig);
131
+ this.applyGenerationConfig(model.generationConfig);
137
132
  }
138
133
  this.isReadyCallback(true);
139
134
  this.onToken = (data: string) => {
@@ -236,6 +231,17 @@ export class LLMController {
236
231
  return token;
237
232
  }
238
233
 
234
+ private getAudioToken(): string {
235
+ const token = this.tokenizerConfig.audio_token;
236
+ if (!token) {
237
+ throw new RnExecutorchError(
238
+ RnExecutorchErrorCode.InvalidConfig,
239
+ "Tokenizer config is missing 'audio_token'. Audio-capable models require tokenizerConfigSource with an 'audio_token' field."
240
+ );
241
+ }
242
+ return token;
243
+ }
244
+
239
245
  private filterSpecialTokens(text: string): string {
240
246
  let filtered = text;
241
247
  if (
@@ -244,6 +250,12 @@ export class LLMController {
244
250
  ) {
245
251
  filtered = filtered.replaceAll(this.tokenizerConfig.eos_token, '');
246
252
  }
253
+ if (
254
+ SPECIAL_TOKENS.EOT_TOKEN in this.tokenizerConfig &&
255
+ this.tokenizerConfig.eot_token
256
+ ) {
257
+ filtered = filtered.replaceAll(this.tokenizerConfig.eot_token, '');
258
+ }
247
259
  if (
248
260
  SPECIAL_TOKENS.PAD_TOKEN in this.tokenizerConfig &&
249
261
  this.tokenizerConfig.pad_token
@@ -269,25 +281,37 @@ export class LLMController {
269
281
  this.isGeneratingCallback(false);
270
282
  }
271
283
 
272
- public async forward(input: string, imagePaths?: string[]): Promise<string> {
284
+ public async forward(
285
+ input: string,
286
+ imagePaths?: string[],
287
+ audioWaveforms?: Float32Array[]
288
+ ): Promise<string> {
273
289
  if (!this._isReady) {
274
290
  throw new RnExecutorchError(RnExecutorchErrorCode.ModuleNotLoaded);
275
291
  }
276
292
  if (this._isGenerating) {
277
293
  throw new RnExecutorchError(RnExecutorchErrorCode.ModelGenerating);
278
294
  }
295
+ const hasImages = !!imagePaths && imagePaths.length > 0;
296
+ const hasAudio = !!audioWaveforms && audioWaveforms.length > 0;
279
297
  try {
280
298
  this.isGeneratingCallback(true);
281
299
  this.nativeModule.reset();
282
- const response =
283
- imagePaths && imagePaths.length > 0
284
- ? await this.nativeModule.generateMultimodal(
285
- input,
286
- imagePaths.map(normalizeImagePath),
287
- this.getImageToken(),
288
- this.onToken
289
- )
290
- : await this.nativeModule.generate(input, this.onToken);
300
+ let response: string;
301
+ if (hasImages || hasAudio) {
302
+ response = await this.nativeModule.generateMultimodal(
303
+ input,
304
+ this.onToken,
305
+ {
306
+ imagePaths: hasImages ? imagePaths!.map(normalizeImagePath) : null,
307
+ imageToken: hasImages ? this.getImageToken() : null,
308
+ audioWaveforms: hasAudio ? audioWaveforms! : null,
309
+ audioToken: hasAudio ? this.getAudioToken() : null,
310
+ }
311
+ );
312
+ } else {
313
+ response = await this.nativeModule.generate(input, this.onToken);
314
+ }
291
315
  return this.filterSpecialTokens(response);
292
316
  } catch (e) {
293
317
  throw parseUnknownError(e);
@@ -355,7 +379,9 @@ export class LLMController {
355
379
  const imagePaths = messages
356
380
  .filter((m) => m.mediaPath)
357
381
  .map((m) => m.mediaPath!);
358
-
382
+ const audioWaveforms = messages
383
+ .filter((m) => m.audioWaveform)
384
+ .map((m) => m.audioWaveform!);
359
385
  const renderedChat: string = this.applyChatTemplate(
360
386
  messages,
361
387
  this.tokenizerConfig,
@@ -365,19 +391,22 @@ export class LLMController {
365
391
 
366
392
  return await this.forward(
367
393
  renderedChat,
368
- imagePaths.length > 0 ? imagePaths : undefined
394
+ imagePaths.length > 0 ? imagePaths : undefined,
395
+ audioWaveforms.length > 0 ? audioWaveforms : undefined
369
396
  );
370
397
  }
371
398
 
372
399
  public async sendMessage(
373
400
  message: string,
374
- media?: { imagePath?: string }
401
+ media?: { imagePath?: string; audioBuffer?: Float32Array }
375
402
  ): Promise<string> {
376
403
  const mediaPath = media?.imagePath;
404
+ const audioBuffer = media?.audioBuffer;
377
405
  const newMessage: Message = {
378
406
  content: message,
379
407
  role: 'user',
380
408
  ...(mediaPath ? { mediaPath } : {}),
409
+ ...(audioBuffer ? { audioWaveform: audioBuffer } : {}),
381
410
  };
382
411
  const updatedHistory = [...this._messageHistory, newMessage];
383
412
  this.messageHistoryCallback(updatedHistory);
@@ -392,7 +421,22 @@ export class LLMController {
392
421
  );
393
422
  const textTokens = this.nativeModule.countTextTokens(rendered);
394
423
  const imageCount = messages.filter((m) => m.mediaPath).length;
395
- return textTokens + imageCount * (visualTokenCount - 1);
424
+ // Audio soft-token expansion: audio_encoder pads samples to
425
+ // multiples of this.audioConfig.samplesPerBlock (7680 @ 16 kHz) and emits
426
+ // this.audioConfig.tokensPerBlock (~12) soft tokens per padded block. The
427
+ // rendered template only contributes 1 token for the audio placeholder,
428
+ // so add (expansion - 1) per audio message to match prefill consumption.
429
+ const audioTokenExpansion = messages.reduce((acc, m) => {
430
+ if (!m.audioWaveform) return acc;
431
+ const kBlocks = Math.max(
432
+ 1,
433
+ Math.ceil(m.audioWaveform.length / this.audioConfig!.samplesPerBlock)
434
+ );
435
+ return acc + (this.audioConfig!.tokensPerBlock * kBlocks - 1);
436
+ }, 0);
437
+ return (
438
+ textTokens + imageCount * (visualTokenCount - 1) + audioTokenExpansion
439
+ );
396
440
  };
397
441
  const maxContextLength = this.nativeModule.getMaxContextLength();
398
442
  const messageHistoryWithPrompt =
@@ -497,12 +541,17 @@ function normalizeImagePath(path: string): string {
497
541
  * @returns Messages with image-bearing turns rewritten to structured content.
498
542
  */
499
543
  function messagesForChatTemplate(messages: Message[]): any[] {
500
- return messages.map((m) =>
501
- m.mediaPath && typeof m.content === 'string'
502
- ? {
503
- ...m,
504
- content: [{ type: 'image' }, { type: 'text', text: m.content }],
505
- }
506
- : m
507
- );
544
+ return messages.map((m) => {
545
+ if (typeof m.content !== 'string') return m;
546
+ const hasImage = !!m.mediaPath;
547
+ const hasAudio = !!m.audioWaveform;
548
+ if (!hasImage && !hasAudio) return m;
549
+ const parts: any[] = [];
550
+ if (hasImage) parts.push({ type: 'image' });
551
+ if (hasAudio) parts.push({ type: 'audio' });
552
+ parts.push({ type: 'text', text: m.content });
553
+ // Drop the Float32Array on the clone only — passing it into the Jinja
554
+ // template engine slows render past 3s. Don't mutate `m` itself.
555
+ return { ...m, content: parts, audioWaveform: undefined };
556
+ });
508
557
  }
@@ -58,11 +58,7 @@ export function useLLM({
58
58
  (async () => {
59
59
  try {
60
60
  await controllerInstance.load({
61
- modelSource: model.modelSource,
62
- tokenizerSource: model.tokenizerSource,
63
- tokenizerConfigSource: model.tokenizerConfigSource!,
64
- capabilities: model.capabilities,
65
- defaultGenerationConfig: model.generationConfig,
61
+ model: model,
66
62
  onDownloadProgressCallback: setDownloadProgress,
67
63
  });
68
64
  } catch (e) {
@@ -106,7 +102,10 @@ export function useLLM({
106
102
  );
107
103
 
108
104
  const sendMessage = useCallback(
109
- (message: string, media?: { imagePath?: string }) => {
105
+ (
106
+ message: string,
107
+ media?: { imagePath?: string; audioBuffer?: Float32Array }
108
+ ) => {
110
109
  setResponse('');
111
110
  return controllerInstance.sendMessage(message, media);
112
111
  },
@@ -29,8 +29,20 @@ const YOLO_POSE_CONFIG = {
29
29
  defaultKeypointThreshold: 0.5,
30
30
  } satisfies PoseEstimationConfig<typeof CocoKeypoint>;
31
31
 
32
+ // RF-DETR keypoint preview (BETA). Unlike yolo26n-pose's multi-method
33
+ // `forward_<size>` export, this ships a single `forward` method — omitting
34
+ // availableInputSizes/defaultInputSize makes forward() dispatch to plain
35
+ // `forward`. May be renamed once a stable model ships.
36
+ const RFDETR_KEYPOINT_CONFIG = {
37
+ keypointMap: CocoKeypoint,
38
+ preprocessorConfig: undefined,
39
+ defaultDetectionThreshold: 0.5,
40
+ defaultKeypointThreshold: 0.5,
41
+ } satisfies PoseEstimationConfig<typeof CocoKeypoint>;
42
+
32
43
  const ModelConfigs = {
33
44
  'yolo26n-pose': YOLO_POSE_CONFIG,
45
+ 'rfdetr-keypoint-preview': RFDETR_KEYPOINT_CONFIG,
34
46
  } as const satisfies Record<
35
47
  PoseEstimationModelName,
36
48
  PoseEstimationConfig<LabelEnum>
@@ -3,6 +3,7 @@ import { Logger } from '../../common/Logger';
3
3
  import { parseUnknownError } from '../../errors/errorUtils';
4
4
  import { ResourceSource } from '../../types/common';
5
5
  import {
6
+ AudioConfig,
6
7
  LLMCapability,
7
8
  LLMConfig,
8
9
  LLMModelName,
@@ -51,6 +52,7 @@ export class LLMModule {
51
52
  tokenizerSource: ResourceSource;
52
53
  tokenizerConfigSource: ResourceSource;
53
54
  capabilities?: readonly LLMCapability[];
55
+ audioConfig?: AudioConfig;
54
56
  },
55
57
  onDownloadProgress: (progress: number) => void = () => {},
56
58
  tokenCallback?: (token: string) => void,
@@ -59,10 +61,14 @@ export class LLMModule {
59
61
  const instance = new LLMModule({ tokenCallback, messageHistoryCallback });
60
62
  try {
61
63
  await instance.controller.load({
62
- modelSource: namedSources.modelSource,
63
- tokenizerSource: namedSources.tokenizerSource,
64
- tokenizerConfigSource: namedSources.tokenizerConfigSource,
65
- capabilities: namedSources.capabilities,
64
+ model: {
65
+ modelName: namedSources.modelName,
66
+ modelSource: namedSources.modelSource,
67
+ tokenizerSource: namedSources.tokenizerSource,
68
+ tokenizerConfigSource: namedSources.tokenizerConfigSource,
69
+ capabilities: namedSources.capabilities,
70
+ audioConfig: namedSources.audioConfig,
71
+ },
66
72
  onDownloadProgressCallback: onDownloadProgress,
67
73
  });
68
74
  return instance;
@@ -140,10 +146,15 @@ export class LLMModule {
140
146
  * If you want a simple chat with model the consider using `sendMessage`
141
147
  * @param input - Raw input string containing the prompt and conversation history.
142
148
  * @param imagePaths - Optional array of local image paths for multimodal inference. Each entry may be either `file:///absolute/path` or `/absolute/path` — the controller normalizes the path before passing it to native code.
149
+ * @param audioWaveforms - Optional array of 16kHz waveforms of audio recordings for multimodal inference.
143
150
  * @returns The generated response as a string.
144
151
  */
145
- async forward(input: string, imagePaths?: string[]): Promise<string> {
146
- return await this.controller.forward(input, imagePaths);
152
+ async forward(
153
+ input: string,
154
+ imagePaths?: string[],
155
+ audioWaveforms?: Float32Array[]
156
+ ): Promise<string> {
157
+ return await this.controller.forward(input, imagePaths, audioWaveforms);
147
158
  }
148
159
 
149
160
  /**
@@ -162,12 +173,12 @@ export class LLMModule {
162
173
  * After model responds it will call `messageHistoryCallback()` containing both user message and model response.
163
174
  * It also returns them.
164
175
  * @param message - The message string to send.
165
- * @param media - Optional media object containing a local image path for multimodal models.
176
+ * @param media - Optional media object containing a local image path or 16kHz waveform of an audio recording for multimodal models.
166
177
  * @returns - Updated message history including the new user message and model response.
167
178
  */
168
179
  async sendMessage(
169
180
  message: string,
170
- media?: { imagePath?: string }
181
+ media?: { imagePath?: string; audioBuffer?: Float32Array }
171
182
  ): Promise<Message[]> {
172
183
  await this.controller.sendMessage(message, media);
173
184
  return this.controller.messageHistory;