whisper.rn 0.1.3 → 0.1.5

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
package/LICENSE CHANGED
@@ -1,6 +1,6 @@
1
1
  MIT License
2
2
 
3
- Copyright (c) 2023 Jhen <developer@jhen.me>
3
+ Copyright (c) 2023 Jhen-Jie Hong
4
4
  Permission is hereby granted, free of charge, to any person obtaining a copy
5
5
  of this software and associated documentation files (the "Software"), to deal
6
6
  in the Software without restriction, including without limitation the rights
@@ -12,6 +12,7 @@ import com.facebook.react.bridge.ReactContextBaseJavaModule;
12
12
  import com.facebook.react.bridge.ReactMethod;
13
13
  import com.facebook.react.bridge.LifecycleEventListener;
14
14
  import com.facebook.react.bridge.ReadableMap;
15
+ import com.facebook.react.bridge.WritableMap;
15
16
  import com.facebook.react.module.annotations.ReactModule;
16
17
 
17
18
  import java.util.HashMap;
@@ -72,11 +73,11 @@ public class RNWhisperModule extends ReactContextBaseJavaModule implements Lifec
72
73
 
73
74
  @ReactMethod
74
75
  public void transcribe(int id, String filePath, ReadableMap options, Promise promise) {
75
- new AsyncTask<Void, Void, String>() {
76
+ new AsyncTask<Void, Void, WritableMap>() {
76
77
  private Exception exception;
77
78
 
78
79
  @Override
79
- protected String doInBackground(Void... voids) {
80
+ protected WritableMap doInBackground(Void... voids) {
80
81
  try {
81
82
  WhisperContext context = contexts.get(id);
82
83
  if (context == null) {
@@ -90,12 +91,12 @@ public class RNWhisperModule extends ReactContextBaseJavaModule implements Lifec
90
91
  }
91
92
 
92
93
  @Override
93
- protected void onPostExecute(String result) {
94
+ protected void onPostExecute(WritableMap data) {
94
95
  if (exception != null) {
95
96
  promise.reject(exception);
96
97
  return;
97
98
  }
98
- promise.resolve(result);
99
+ promise.resolve(data);
99
100
  }
100
101
  }.execute();
101
102
  }
@@ -1,5 +1,8 @@
1
1
  package com.rnwhisper;
2
2
 
3
+ import com.facebook.react.bridge.Arguments;
4
+ import com.facebook.react.bridge.WritableArray;
5
+ import com.facebook.react.bridge.WritableMap;
3
6
  import com.facebook.react.bridge.ReadableMap;
4
7
 
5
8
  import android.util.Log;
@@ -29,7 +32,7 @@ public class WhisperContext {
29
32
  this.context = context;
30
33
  }
31
34
 
32
- public String transcribe(final String filePath, final ReadableMap options) throws IOException, Exception {
35
+ public WritableMap transcribe(final String filePath, final ReadableMap options) throws IOException, Exception {
33
36
  int code = fullTranscribe(
34
37
  context,
35
38
  decodeWaveFile(new File(filePath)),
@@ -37,14 +40,18 @@ public class WhisperContext {
37
40
  options.hasKey("maxThreads") ? options.getInt("maxThreads") : -1,
38
41
  // jint max_context,
39
42
  options.hasKey("maxContext") ? options.getInt("maxContext") : -1,
43
+
44
+ // jint word_thold,
45
+ options.hasKey("wordThold") ? options.getInt("wordThold") : -1,
40
46
  // jint max_len,
41
47
  options.hasKey("maxLen") ? options.getInt("maxLen") : -1,
48
+ // jboolean token_timestamps,
49
+ options.hasKey("tokenTimestamps") ? options.getBoolean("tokenTimestamps") : false,
50
+
42
51
  // jint offset,
43
52
  options.hasKey("offset") ? options.getInt("offset") : -1,
44
53
  // jint duration,
45
54
  options.hasKey("duration") ? options.getInt("duration") : -1,
46
- // jint word_thold,
47
- options.hasKey("wordThold") ? options.getInt("wordThold") : -1,
48
55
  // jfloat temperature,
49
56
  options.hasKey("temperature") ? (float) options.getDouble("temperature") : -1.0f,
50
57
  // jfloat temperature_inc,
@@ -58,17 +65,31 @@ public class WhisperContext {
58
65
  // jboolean translate,
59
66
  options.hasKey("translate") ? options.getBoolean("translate") : false,
60
67
  // jstring language,
61
- options.hasKey("language") ? options.getString("language") : "auto"
68
+ options.hasKey("language") ? options.getString("language") : "auto",
69
+ // jstring prompt
70
+ options.hasKey("prompt") ? options.getString("prompt") : null
62
71
  );
63
72
  if (code != 0) {
64
73
  throw new Exception("Transcription failed with code " + code);
65
74
  }
66
75
  Integer count = getTextSegmentCount(context);
67
76
  StringBuilder builder = new StringBuilder();
77
+
78
+ WritableMap data = Arguments.createMap();
79
+ WritableArray segments = Arguments.createArray();
68
80
  for (int i = 0; i < count; i++) {
69
- builder.append(getTextSegment(context, i));
81
+ String text = getTextSegment(context, i);
82
+ builder.append(text);
83
+
84
+ WritableMap segment = Arguments.createMap();
85
+ segment.putString("text", text);
86
+ segment.putInt("t0", getTextSegmentT0(context, i));
87
+ segment.putInt("t1", getTextSegmentT1(context, i));
88
+ segments.pushMap(segment);
70
89
  }
71
- return builder.toString();
90
+ data.putString("result", builder.toString());
91
+ data.putArray("segments", segments);
92
+ return data;
72
93
  }
73
94
 
74
95
  public void release() {
@@ -168,19 +189,23 @@ public class WhisperContext {
168
189
  float[] audio_data,
169
190
  int n_threads,
170
191
  int max_context,
192
+ int word_thold,
171
193
  int max_len,
194
+ boolean token_timestamps,
172
195
  int offset,
173
196
  int duration,
174
- int word_thold,
175
197
  float temperature,
176
198
  float temperature_inc,
177
199
  int beam_size,
178
200
  int best_of,
179
201
  boolean speed_up,
180
202
  boolean translate,
181
- String language
203
+ String language,
204
+ String prompt
182
205
  );
183
206
  protected static native int getTextSegmentCount(long context);
184
207
  protected static native String getTextSegment(long context, int index);
208
+ protected static native int getTextSegmentT0(long context, int index);
209
+ protected static native int getTextSegmentT1(long context, int index);
185
210
  protected static native void freeContext(long contextPtr);
186
211
  }
@@ -15,4 +15,5 @@ LOCAL_CFLAGS += -DSTDC_HEADERS -std=c11 -I $(WHISPER_LIB_DIR)
15
15
  LOCAL_CPPFLAGS += -std=c++11
16
16
  LOCAL_SRC_FILES := $(WHISPER_LIB_DIR)/ggml.c \
17
17
  $(WHISPER_LIB_DIR)/whisper.cpp \
18
- $(LOCAL_PATH)/jni.c
18
+ $(WHISPER_LIB_DIR)/rn-whisper.cpp \
19
+ $(LOCAL_PATH)/jni.cpp
@@ -2,10 +2,11 @@
2
2
  #include <android/asset_manager.h>
3
3
  #include <android/asset_manager_jni.h>
4
4
  #include <android/log.h>
5
- #include <stdlib.h>
5
+ #include <cstdlib>
6
6
  #include <sys/sysinfo.h>
7
- #include <string.h>
7
+ #include <string>
8
8
  #include "whisper.h"
9
+ #include "rn-whisper.h"
9
10
  #include "ggml.h"
10
11
 
11
12
  #define UNUSED(x) (void)(x)
@@ -18,31 +19,17 @@ static inline int min(int a, int b) {
18
19
  return (a < b) ? a : b;
19
20
  }
20
21
 
21
- static inline int max(int a, int b) {
22
- return (a > b) ? a : b;
23
- }
24
-
25
- static size_t asset_read(void *ctx, void *output, size_t read_size) {
26
- return AAsset_read((AAsset *) ctx, output, read_size);
27
- }
28
-
29
- static bool asset_is_eof(void *ctx) {
30
- return AAsset_getRemainingLength64((AAsset *) ctx) <= 0;
31
- }
32
-
33
- static void asset_close(void *ctx) {
34
- AAsset_close((AAsset *) ctx);
35
- }
22
+ extern "C" {
36
23
 
37
24
  JNIEXPORT jlong JNICALL
38
25
  Java_com_rnwhisper_WhisperContext_initContext(
39
26
  JNIEnv *env, jobject thiz, jstring model_path_str) {
40
27
  UNUSED(thiz);
41
- struct whisper_context *context = NULL;
42
- const char *model_path_chars = (*env)->GetStringUTFChars(env, model_path_str, NULL);
28
+ struct whisper_context *context = nullptr;
29
+ const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
43
30
  context = whisper_init_from_file(model_path_chars);
44
- (*env)->ReleaseStringUTFChars(env, model_path_str, model_path_chars);
45
- return (jlong) context;
31
+ env->ReleaseStringUTFChars(model_path_str, model_path_chars);
32
+ return reinterpret_cast<jlong>(context);
46
33
  }
47
34
 
48
35
  JNIEXPORT jint JNICALL
@@ -53,29 +40,31 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
53
40
  jfloatArray audio_data,
54
41
  jint n_threads,
55
42
  jint max_context,
56
- jint max_len,
43
+ int word_thold,
44
+ int max_len,
45
+ jboolean token_timestamps,
57
46
  jint offset,
58
47
  jint duration,
59
- jint word_thold,
60
48
  jfloat temperature,
61
49
  jfloat temperature_inc,
62
50
  jint beam_size,
63
51
  jint best_of,
64
52
  jboolean speed_up,
65
53
  jboolean translate,
66
- jstring language
54
+ jstring language,
55
+ jstring prompt
67
56
  ) {
68
57
  UNUSED(thiz);
69
- struct whisper_context *context = (struct whisper_context *) context_ptr;
70
- jfloat *audio_data_arr = (*env)->GetFloatArrayElements(env, audio_data, NULL);
71
- const jsize audio_data_length = (*env)->GetArrayLength(env, audio_data);
58
+ struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
59
+ jfloat *audio_data_arr = env->GetFloatArrayElements(audio_data, nullptr);
60
+ const jsize audio_data_length = env->GetArrayLength(audio_data);
72
61
 
73
- int max_threads = max(1, min(8, get_nprocs() - 2));
62
+ int max_threads = min(4, get_nprocs());
74
63
 
75
64
  LOGI("About to create params");
76
65
 
77
66
  struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
78
-
67
+
79
68
  if (beam_size > -1) {
80
69
  params.strategy = WHISPER_SAMPLING_BEAM_SEARCH;
81
70
  params.beam_search.beam_size = beam_size;
@@ -86,22 +75,25 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
86
75
  params.print_timestamps = false;
87
76
  params.print_special = false;
88
77
  params.translate = translate;
89
- params.language = language;
78
+ const char *language_chars = env->GetStringUTFChars(language, nullptr);
79
+ params.language = language_chars;
90
80
  params.n_threads = n_threads > 0 ? n_threads : max_threads;
91
81
  params.speed_up = speed_up;
92
82
  params.offset_ms = 0;
93
83
  params.no_context = true;
94
84
  params.single_segment = false;
95
85
 
86
+ if (max_len > -1) {
87
+ params.max_len = max_len;
88
+ }
89
+ params.token_timestamps = token_timestamps;
90
+
96
91
  if (best_of > -1) {
97
92
  params.greedy.best_of = best_of;
98
93
  }
99
94
  if (max_context > -1) {
100
95
  params.n_max_text_ctx = max_context;
101
96
  }
102
- if (max_len > -1) {
103
- params.max_len = max_len;
104
- }
105
97
  if (offset > -1) {
106
98
  params.offset_ms = offset;
107
99
  }
@@ -117,6 +109,13 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
117
109
  if (temperature_inc > -1) {
118
110
  params.temperature_inc = temperature_inc;
119
111
  }
112
+ if (prompt != nullptr) {
113
+ rn_whisper_convert_prompt(
114
+ context,
115
+ params,
116
+ new std::string(env->GetStringUTFChars(prompt, nullptr))
117
+ );
118
+ }
120
119
 
121
120
  LOGI("About to reset timings");
122
121
  whisper_reset_timings(context);
@@ -126,7 +125,8 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
126
125
  if (code == 0) {
127
126
  // whisper_print_timings(context);
128
127
  }
129
- (*env)->ReleaseFloatArrayElements(env, audio_data, audio_data_arr, JNI_ABORT);
128
+ env->ReleaseFloatArrayElements(audio_data, audio_data_arr, JNI_ABORT);
129
+ env->ReleaseStringUTFChars(language, language_chars);
130
130
  return code;
131
131
  }
132
132
 
@@ -135,7 +135,7 @@ Java_com_rnwhisper_WhisperContext_getTextSegmentCount(
135
135
  JNIEnv *env, jobject thiz, jlong context_ptr) {
136
136
  UNUSED(env);
137
137
  UNUSED(thiz);
138
- struct whisper_context *context = (struct whisper_context *) context_ptr;
138
+ struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
139
139
  return whisper_full_n_segments(context);
140
140
  }
141
141
 
@@ -143,17 +143,37 @@ JNIEXPORT jstring JNICALL
143
143
  Java_com_rnwhisper_WhisperContext_getTextSegment(
144
144
  JNIEnv *env, jobject thiz, jlong context_ptr, jint index) {
145
145
  UNUSED(thiz);
146
- struct whisper_context *context = (struct whisper_context *) context_ptr;
146
+ struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
147
147
  const char *text = whisper_full_get_segment_text(context, index);
148
- jstring string = (*env)->NewStringUTF(env, text);
148
+ jstring string = env->NewStringUTF(text);
149
149
  return string;
150
150
  }
151
151
 
152
+ JNIEXPORT jint JNICALL
153
+ Java_com_rnwhisper_WhisperContext_getTextSegmentT0(
154
+ JNIEnv *env, jobject thiz, jlong context_ptr, jint index) {
155
+ UNUSED(env);
156
+ UNUSED(thiz);
157
+ struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
158
+ return whisper_full_get_segment_t0(context, index);
159
+ }
160
+
161
+ JNIEXPORT jint JNICALL
162
+ Java_com_rnwhisper_WhisperContext_getTextSegmentT1(
163
+ JNIEnv *env, jobject thiz, jlong context_ptr, jint index) {
164
+ UNUSED(env);
165
+ UNUSED(thiz);
166
+ struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
167
+ return whisper_full_get_segment_t1(context, index);
168
+ }
169
+
152
170
  JNIEXPORT void JNICALL
153
171
  Java_com_rnwhisper_WhisperContext_freeContext(
154
172
  JNIEnv *env, jobject thiz, jlong context_ptr) {
155
173
  UNUSED(env);
156
174
  UNUSED(thiz);
157
- struct whisper_context *context = (struct whisper_context *) context_ptr;
175
+ struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
158
176
  whisper_free(context);
159
177
  }
178
+
179
+ } // extern "C"
@@ -0,0 +1,31 @@
1
+ #include <cstdio>
2
+ #include <string>
3
+ #include <vector>
4
+ #include "whisper.h"
5
+
6
+ extern "C" {
7
+
8
+ void rn_whisper_convert_prompt(
9
+ struct whisper_context * ctx,
10
+ struct whisper_full_params params,
11
+ std::string * prompt
12
+ ) {
13
+ std::vector<whisper_token> prompt_tokens;
14
+ if (!prompt->empty()) {
15
+ prompt_tokens.resize(1024);
16
+ prompt_tokens.resize(whisper_tokenize(ctx, prompt->c_str(), prompt_tokens.data(), prompt_tokens.size()));
17
+
18
+ // fprintf(stderr, "\n");
19
+ // fprintf(stderr, "initial prompt: '%s'\n", prompt->c_str());
20
+ // fprintf(stderr, "initial tokens: [ ");
21
+ // for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
22
+ // fprintf(stderr, "%d ", prompt_tokens[i]);
23
+ // }
24
+ // fprintf(stderr, "]\n");
25
+
26
+ params.prompt_tokens = prompt_tokens.data();
27
+ params.prompt_n_tokens = prompt_tokens.size();
28
+ }
29
+ }
30
+
31
+ }
@@ -0,0 +1,16 @@
1
+
2
+ #ifdef __cplusplus
3
+ #include <string>
4
+ #include <whisper.h>
5
+ extern "C" {
6
+ #endif
7
+
8
+ void rn_whisper_convert_prompt(
9
+ struct whisper_context * ctx,
10
+ struct whisper_full_params params,
11
+ std::string * prompt
12
+ );
13
+
14
+ #ifdef __cplusplus
15
+ }
16
+ #endif