cui-llama.rn 1.2.6 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. package/README.md +3 -2
  2. package/android/src/main/CMakeLists.txt +26 -6
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +115 -27
  4. package/android/src/main/java/com/rnllama/RNLlama.java +40 -7
  5. package/android/src/main/jni.cpp +228 -40
  6. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +9 -4
  7. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +9 -4
  8. package/cpp/amx/amx.cpp +196 -0
  9. package/cpp/amx/amx.h +20 -0
  10. package/cpp/amx/common.h +101 -0
  11. package/cpp/amx/mmq.cpp +2524 -0
  12. package/cpp/amx/mmq.h +16 -0
  13. package/cpp/common.cpp +118 -251
  14. package/cpp/common.h +53 -30
  15. package/cpp/ggml-aarch64.c +46 -3395
  16. package/cpp/ggml-aarch64.h +0 -20
  17. package/cpp/ggml-alloc.c +6 -8
  18. package/cpp/ggml-backend-impl.h +33 -11
  19. package/cpp/ggml-backend-reg.cpp +423 -0
  20. package/cpp/ggml-backend.cpp +14 -676
  21. package/cpp/ggml-backend.h +46 -9
  22. package/cpp/ggml-common.h +6 -0
  23. package/cpp/ggml-cpu-aarch64.c +3823 -0
  24. package/cpp/ggml-cpu-aarch64.h +32 -0
  25. package/cpp/ggml-cpu-impl.h +14 -242
  26. package/cpp/ggml-cpu-quants.c +10835 -0
  27. package/cpp/ggml-cpu-quants.h +63 -0
  28. package/cpp/ggml-cpu.c +13971 -13720
  29. package/cpp/ggml-cpu.cpp +715 -0
  30. package/cpp/ggml-cpu.h +65 -63
  31. package/cpp/ggml-impl.h +285 -25
  32. package/cpp/ggml-metal.h +8 -8
  33. package/cpp/ggml-metal.m +1221 -728
  34. package/cpp/ggml-quants.c +189 -10681
  35. package/cpp/ggml-quants.h +78 -125
  36. package/cpp/ggml-threading.cpp +12 -0
  37. package/cpp/ggml-threading.h +12 -0
  38. package/cpp/ggml.c +688 -1460
  39. package/cpp/ggml.h +58 -244
  40. package/cpp/json-schema-to-grammar.cpp +1045 -1045
  41. package/cpp/json.hpp +24766 -24766
  42. package/cpp/llama-sampling.cpp +5 -2
  43. package/cpp/llama.cpp +409 -123
  44. package/cpp/llama.h +8 -4
  45. package/cpp/rn-llama.hpp +89 -25
  46. package/cpp/sampling.cpp +42 -3
  47. package/cpp/sampling.h +22 -1
  48. package/cpp/sgemm.cpp +608 -0
  49. package/cpp/speculative.cpp +270 -0
  50. package/cpp/speculative.h +28 -0
  51. package/cpp/unicode.cpp +11 -0
  52. package/ios/RNLlama.mm +43 -20
  53. package/ios/RNLlamaContext.h +9 -3
  54. package/ios/RNLlamaContext.mm +146 -33
  55. package/jest/mock.js +0 -1
  56. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  57. package/lib/commonjs/grammar.js +4 -2
  58. package/lib/commonjs/grammar.js.map +1 -1
  59. package/lib/commonjs/index.js +52 -15
  60. package/lib/commonjs/index.js.map +1 -1
  61. package/lib/module/NativeRNLlama.js.map +1 -1
  62. package/lib/module/grammar.js +2 -1
  63. package/lib/module/grammar.js.map +1 -1
  64. package/lib/module/index.js +51 -15
  65. package/lib/module/index.js.map +1 -1
  66. package/lib/typescript/NativeRNLlama.d.ts +122 -8
  67. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  68. package/lib/typescript/grammar.d.ts +5 -6
  69. package/lib/typescript/grammar.d.ts.map +1 -1
  70. package/lib/typescript/index.d.ts +15 -6
  71. package/lib/typescript/index.d.ts.map +1 -1
  72. package/package.json +2 -1
  73. package/src/NativeRNLlama.ts +135 -13
  74. package/src/grammar.ts +10 -8
  75. package/src/index.ts +104 -28
package/lib/typescript/index.d.ts CHANGED
@@ -1,15 +1,23 @@
- import type { NativeContextParams, NativeLlamaContext, NativeCompletionParams, NativeCompletionTokenProb, NativeCompletionResult, NativeTokenizeResult, NativeEmbeddingResult, NativeSessionLoadResult, NativeCPUFeatures } from './NativeRNLlama';
+ import type { NativeContextParams, NativeLlamaContext, NativeCompletionParams, NativeCompletionTokenProb, NativeCompletionResult, NativeTokenizeResult, NativeEmbeddingResult, NativeSessionLoadResult, NativeCPUFeatures, NativeEmbeddingParams, NativeCompletionTokenProbItem, NativeCompletionResultTimings } from './NativeRNLlama';
+ import type { SchemaGrammarConverterPropOrder, SchemaGrammarConverterBuiltinRule } from './grammar';
  import { SchemaGrammarConverter, convertJsonSchemaToGrammar } from './grammar';
- import type { RNLlamaOAICompatibleMessage } from './chat';
+ import type { RNLlamaMessagePart, RNLlamaOAICompatibleMessage } from './chat';
+ export type { NativeContextParams, NativeLlamaContext, NativeCompletionParams, NativeCompletionTokenProb, NativeCompletionResult, NativeTokenizeResult, NativeEmbeddingResult, NativeSessionLoadResult, NativeEmbeddingParams, NativeCompletionTokenProbItem, NativeCompletionResultTimings, RNLlamaMessagePart, RNLlamaOAICompatibleMessage, SchemaGrammarConverterPropOrder, SchemaGrammarConverterBuiltinRule, };
  export { SchemaGrammarConverter, convertJsonSchemaToGrammar };
  export type TokenData = {
  token: string;
  completion_probabilities?: Array<NativeCompletionTokenProb>;
  };
- export type ContextParams = NativeContextParams;
+ export type ContextParams = Omit<NativeContextParams, 'cache_type_k' | 'cache_type_v' | 'pooling_type'> & {
+ cache_type_k?: 'f16' | 'f32' | 'q8_0' | 'q4_0' | 'q4_1' | 'iq4_nl' | 'q5_0' | 'q5_1';
+ cache_type_v?: 'f16' | 'f32' | 'q8_0' | 'q4_0' | 'q4_1' | 'iq4_nl' | 'q5_0' | 'q5_1';
+ pooling_type?: 'none' | 'mean' | 'cls' | 'last' | 'rank';
+ };
+ export type EmbeddingParams = NativeEmbeddingParams;
  export type CompletionParams = Omit<NativeCompletionParams, 'emit_partial_completion' | 'prompt'> & {
  prompt?: string;
  messages?: RNLlamaOAICompatibleMessage[];
+ chatTemplate?: string;
  };
  export type BenchResult = {
  modelDesc: string;
@@ -38,18 +46,19 @@ export declare class LlamaContext {
  saveSession(filepath: string, options?: {
  tokenSize: number;
  }): Promise<number>;
- getFormattedChat(messages: RNLlamaOAICompatibleMessage[]): Promise<string>;
+ getFormattedChat(messages: RNLlamaOAICompatibleMessage[], template?: string): Promise<string>;
  completion(params: CompletionParams, callback?: (data: TokenData) => void): Promise<NativeCompletionResult>;
  stopCompletion(): Promise<void>;
  tokenizeAsync(text: string): Promise<NativeTokenizeResult>;
  tokenizeSync(text: string): NativeTokenizeResult;
  detokenize(tokens: number[]): Promise<string>;
- embedding(text: string): Promise<NativeEmbeddingResult>;
+ embedding(text: string, params?: EmbeddingParams): Promise<NativeEmbeddingResult>;
  bench(pp: number, tg: number, pl: number, nr: number): Promise<BenchResult>;
  release(): Promise<void>;
  }
  export declare function getCpuFeatures(): Promise<NativeCPUFeatures>;
  export declare function setContextLimit(limit: number): Promise<void>;
- export declare function initLlama({ model, is_model_asset: isModelAsset, ...rest }: ContextParams, progressCallback?: (progress: number) => void): Promise<LlamaContext>;
+ export declare function loadLlamaModelInfo(model: string): Promise<Object>;
+ export declare function initLlama({ model, is_model_asset: isModelAsset, pooling_type: poolingType, lora, ...rest }: ContextParams, onProgress?: (progress: number) => void): Promise<LlamaContext>;
  export declare function releaseAllLlama(): Promise<void>;
  //# sourceMappingURL=index.d.ts.map
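For context, here is a minimal sketch of how the reworked `ContextParams` typings above could be used from application code. The import path comes from this package's name; the model path and chosen values are illustrative assumptions, not part of this diff. `cache_type_k`/`cache_type_v` and `pooling_type` now take string literals, and `initLlama` accepts a progress callback as its second argument.

```ts
import { initLlama } from 'cui-llama.rn'

// Hypothetical model path and settings, shown only to exercise the new typings.
async function createContext() {
  const context = await initLlama(
    {
      model: 'file:///data/user/0/app/files/model.gguf', // hypothetical path
      n_ctx: 2048,
      flash_attn: true,      // experimental in llama.cpp, mainly useful with GPU offload
      cache_type_k: 'q8_0',  // quantized K cache (string literal union in 1.3.x)
      cache_type_v: 'q8_0',  // quantized V cache
      pooling_type: 'mean',  // 'none' | 'mean' | 'cls' | 'last' | 'rank'
      embedding: true,
      embd_normalize: 2,     // normalization mode passed through to llama.cpp
    },
    (progress) => console.log(`load progress: ${progress}`),
  )
  return context
}
```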
package/lib/typescript/index.d.ts.map CHANGED
@@ -1 +1 @@
- {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../src/index.ts"],"names":[],"mappings":"AAGA,OAAO,KAAK,EACV,mBAAmB,EACnB,kBAAkB,EAClB,sBAAsB,EACtB,yBAAyB,EACzB,sBAAsB,EACtB,oBAAoB,EACpB,qBAAqB,EACrB,uBAAuB,EACvB,iBAAiB,EAClB,MAAM,iBAAiB,CAAA;AACxB,OAAO,EAAE,sBAAsB,EAAE,0BAA0B,EAAE,MAAM,WAAW,CAAA;AAC9E,OAAO,KAAK,EAAE,2BAA2B,EAAE,MAAM,QAAQ,CAAA;AAGzD,OAAO,EAAE,sBAAsB,EAAE,0BAA0B,EAAE,CAAA;AAe7D,MAAM,MAAM,SAAS,GAAG;IACtB,KAAK,EAAE,MAAM,CAAA;IACb,wBAAwB,CAAC,EAAE,KAAK,CAAC,yBAAyB,CAAC,CAAA;CAC5D,CAAA;AAOD,MAAM,MAAM,aAAa,GAAG,mBAAmB,CAAA;AAE/C,MAAM,MAAM,gBAAgB,GAAG,IAAI,CACjC,sBAAsB,EACtB,yBAAyB,GAAG,QAAQ,CACrC,GAAG;IACF,MAAM,CAAC,EAAE,MAAM,CAAA;IACf,QAAQ,CAAC,EAAE,2BAA2B,EAAE,CAAA;CACzC,CAAA;AAED,MAAM,MAAM,WAAW,GAAG;IACxB,SAAS,EAAE,MAAM,CAAA;IACjB,SAAS,EAAE,MAAM,CAAA;IACjB,YAAY,EAAE,MAAM,CAAA;IACpB,KAAK,EAAE,MAAM,CAAA;IACb,KAAK,EAAE,MAAM,CAAA;IACb,KAAK,EAAE,MAAM,CAAA;IACb,KAAK,EAAE,MAAM,CAAA;CACd,CAAA;AAED,qBAAa,YAAY;IACvB,EAAE,EAAE,MAAM,CAAA;IAEV,GAAG,EAAE,OAAO,CAAQ;IAEpB,WAAW,EAAE,MAAM,CAAK;IAExB,KAAK,EAAE;QACL,uBAAuB,CAAC,EAAE,OAAO,CAAA;KAClC,CAAK;gBAEM,EAAE,SAAS,EAAE,GAAG,EAAE,WAAW,EAAE,KAAK,EAAE,EAAE,kBAAkB;IAOtE;;OAEG;IACG,WAAW,CAAC,QAAQ,EAAE,MAAM,GAAG,OAAO,CAAC,uBAAuB,CAAC;IAMrE;;OAEG;IACG,WAAW,CACf,QAAQ,EAAE,MAAM,EAChB,OAAO,CAAC,EAAE;QAAE,SAAS,EAAE,MAAM,CAAA;KAAE,GAC9B,OAAO,CAAC,MAAM,CAAC;IAIZ,gBAAgB,CACpB,QAAQ,EAAE,2BAA2B,EAAE,GACtC,OAAO,CAAC,MAAM,CAAC;IASZ,UAAU,CACd,MAAM,EAAE,gBAAgB,EACxB,QAAQ,CAAC,EAAE,CAAC,IAAI,EAAE,SAAS,KAAK,IAAI,GACnC,OAAO,CAAC,sBAAsB,CAAC;IAkClC,cAAc,IAAI,OAAO,CAAC,IAAI,CAAC;IAI/B,aAAa,CAAC,IAAI,EAAE,MAAM,GAAG,OAAO,CAAC,oBAAoB,CAAC;IAI1D,YAAY,CAAC,IAAI,EAAE,MAAM,GAAG,oBAAoB;IAIhD,UAAU,CAAC,MAAM,EAAE,MAAM,EAAE,GAAG,OAAO,CAAC,MAAM,CAAC;IAI7C,SAAS,CAAC,IAAI,EAAE,MAAM,GAAG,OAAO,CAAC,qBAAqB,CAAC;IAIjD,KAAK,CACT,EAAE,EAAE,MAAM,EACV,EAAE,EAAE,MAAM,EACV,EAAE,EAAE,MAAM,EACV,EAAE,EAAE,MAAM,GACT,OAAO,CAAC,WAAW,CAAC;IAejB,OAAO,IAAI,OAAO,CAAC,IAAI,CAAC;CAG/B;AAED,wBAAsB,cAAc,IAAK,OAAO,CAAC,iBAAiB,CAAC,CAElE;AAED,wBAAsB,eAAe,CAAC,KAAK,EAAE,MAAM,GAAG,OAAO,CAAC,IAAI,CAAC,CAElE;AAED,wBAAsB,SAAS,CAAC,EAC5B,KAAK,EACL,cAAc,EAAE,YAAY,EAC5B,GAAG,IAAI,EACR,EAAE,aAAa,EAChB,gBAAgB,CAAC,EAAE,CAAC,QAAQ,EAAE,MAAM,KAAK,IAAI,GAC5C,OAAO,CAAC,YAAY,CAAC,CAwBvB;AAED,wBAAsB,eAAe,IAAI,OAAO,CAAC,IAAI,CAAC,CAErD"}
+ {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../src/index.ts"],"names":[],"mappings":"AAGA,OAAO,KAAK,EACV,mBAAmB,EACnB,kBAAkB,EAClB,sBAAsB,EACtB,yBAAyB,EACzB,sBAAsB,EACtB,oBAAoB,EACpB,qBAAqB,EACrB,uBAAuB,EACvB,iBAAiB,EACjB,qBAAqB,EACrB,6BAA6B,EAC7B,6BAA6B,EAC9B,MAAM,iBAAiB,CAAA;AACxB,OAAO,KAAK,EAAE,+BAA+B,EAAE,iCAAiC,EAAE,MAAM,WAAW,CAAA;AACnG,OAAO,EAAE,sBAAsB,EAAE,0BAA0B,EAAE,MAAM,WAAW,CAAA;AAC9E,OAAO,KAAK,EAAE,kBAAkB,EAAE,2BAA2B,EAAE,MAAM,QAAQ,CAAA;AAG7E,YAAY,EACV,mBAAmB,EACnB,kBAAkB,EAClB,sBAAsB,EACtB,yBAAyB,EACzB,sBAAsB,EACtB,oBAAoB,EACpB,qBAAqB,EACrB,uBAAuB,EACvB,qBAAqB,EACrB,6BAA6B,EAC7B,6BAA6B,EAC7B,kBAAkB,EAClB,2BAA2B,EAC3B,+BAA+B,EAC/B,iCAAiC,GAClC,CAAA;AAED,OAAO,EAAE,sBAAsB,EAAE,0BAA0B,EAAE,CAAA;AAc7D,MAAM,MAAM,SAAS,GAAG;IACtB,KAAK,EAAE,MAAM,CAAA;IACb,wBAAwB,CAAC,EAAE,KAAK,CAAC,yBAAyB,CAAC,CAAA;CAC5D,CAAA;AAOD,MAAM,MAAM,aAAa,GAAG,IAAI,CAC9B,mBAAmB,EACnB,cAAc,GAAG,cAAc,GAAI,cAAc,CAClD,GAAG;IACF,YAAY,CAAC,EAAE,KAAK,GAAG,KAAK,GAAG,MAAM,GAAG,MAAM,GAAG,MAAM,GAAG,QAAQ,GAAG,MAAM,GAAG,MAAM,CAAA;IACpF,YAAY,CAAC,EAAE,KAAK,GAAG,KAAK,GAAG,MAAM,GAAG,MAAM,GAAG,MAAM,GAAG,QAAQ,GAAG,MAAM,GAAG,MAAM,CAAA;IACpF,YAAY,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,KAAK,GAAG,MAAM,GAAG,MAAM,CAAA;CACzD,CAAA;AAED,MAAM,MAAM,eAAe,GAAG,qBAAqB,CAAA;AAEnD,MAAM,MAAM,gBAAgB,GAAG,IAAI,CACjC,sBAAsB,EACtB,yBAAyB,GAAG,QAAQ,CACrC,GAAG;IACF,MAAM,CAAC,EAAE,MAAM,CAAA;IACf,QAAQ,CAAC,EAAE,2BAA2B,EAAE,CAAA;IACxC,YAAY,CAAC,EAAE,MAAM,CAAA;CACtB,CAAA;AAED,MAAM,MAAM,WAAW,GAAG;IACxB,SAAS,EAAE,MAAM,CAAA;IACjB,SAAS,EAAE,MAAM,CAAA;IACjB,YAAY,EAAE,MAAM,CAAA;IACpB,KAAK,EAAE,MAAM,CAAA;IACb,KAAK,EAAE,MAAM,CAAA;IACb,KAAK,EAAE,MAAM,CAAA;IACb,KAAK,EAAE,MAAM,CAAA;CACd,CAAA;AAED,qBAAa,YAAY;IACvB,EAAE,EAAE,MAAM,CAAA;IAEV,GAAG,EAAE,OAAO,CAAQ;IAEpB,WAAW,EAAE,MAAM,CAAK;IAExB,KAAK,EAAE;QACL,uBAAuB,CAAC,EAAE,OAAO,CAAA;KAClC,CAAK;gBAEM,EAAE,SAAS,EAAE,GAAG,EAAE,WAAW,EAAE,KAAK,EAAE,EAAE,kBAAkB;IAOtE;;OAEG;IACG,WAAW,CAAC,QAAQ,EAAE,MAAM,GAAG,OAAO,CAAC,uBAAuB,CAAC;IAMrE;;OAEG;IACG,WAAW,CACf,QAAQ,EAAE,MAAM,EAChB,OAAO,CAAC,EAAE;QAAE,SAAS,EAAE,MAAM,CAAA;KAAE,GAC9B,OAAO,CAAC,MAAM,CAAC;IAIZ,gBAAgB,CACpB,QAAQ,EAAE,2BAA2B,EAAE,EACvC,QAAQ,CAAC,EAAE,MAAM,GAChB,OAAO,CAAC,MAAM,CAAC;IAOZ,UAAU,CACd,MAAM,EAAE,gBAAgB,EACxB,QAAQ,CAAC,EAAE,CAAC,IAAI,EAAE,SAAS,KAAK,IAAI,GACnC,OAAO,CAAC,sBAAsB,CAAC;IAkClC,cAAc,IAAI,OAAO,CAAC,IAAI,CAAC;IAI/B,aAAa,CAAC,IAAI,EAAE,MAAM,GAAG,OAAO,CAAC,oBAAoB,CAAC;IAI1D,YAAY,CAAC,IAAI,EAAE,MAAM,GAAG,oBAAoB;IAIhD,UAAU,CAAC,MAAM,EAAE,MAAM,EAAE,GAAG,OAAO,CAAC,MAAM,CAAC;IAI7C,SAAS,CACP,IAAI,EAAE,MAAM,EACZ,MAAM,CAAC,EAAE,eAAe,GACvB,OAAO,CAAC,qBAAqB,CAAC;IAI3B,KAAK,CACT,EAAE,EAAE,MAAM,EACV,EAAE,EAAE,MAAM,EACV,EAAE,EAAE,MAAM,EACV,EAAE,EAAE,MAAM,GACT,OAAO,CAAC,WAAW,CAAC;IAejB,OAAO,IAAI,OAAO,CAAC,IAAI,CAAC;CAG/B;AAED,wBAAsB,cAAc,IAAK,OAAO,CAAC,iBAAiB,CAAC,CAElE;AAED,wBAAsB,eAAe,CAAC,KAAK,EAAE,MAAM,GAAG,OAAO,CAAC,IAAI,CAAC,CAElE;AAYD,wBAAsB,kBAAkB,CAAC,KAAK,EAAE,MAAM,GAAG,OAAO,CAAC,MAAM,CAAC,CAIvE;AAWD,wBAAsB,SAAS,CAC7B,EACE,KAAK,EACL,cAAc,EAAE,YAAY,EAC5B,YAAY,EAAE,WAAW,EACzB,IAAI,EACJ,GAAG,IAAI,EACR,EAAE,aAAa,EAChB,UAAU,CAAC,EAAE,CAAC,QAAQ,EAAE,MAAM,KAAK,IAAI,GACtC,OAAO,CAAC,YAAY,CAAC,CAuCvB;AAED,wBAAsB,eAAe,IAAI,OAAO,CAAC,IAAI,CAAC,CAErD"}
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "cui-llama.rn",
- "version": "1.2.6",
+ "version": "1.3.3",
  "description": "Fork of llama.rn for ChatterUI",
  "main": "lib/commonjs/index",
  "module": "lib/module/index",
@@ -14,6 +14,7 @@
  "ios",
  "android",
  "cpp/*.*",
+ "cpp/amx/*.*",
  "*.podspec",
  "!lib/typescript/example",
  "!ios/build",
package/src/NativeRNLlama.ts CHANGED
@@ -1,11 +1,14 @@
  import type { TurboModule } from 'react-native'
  import { TurboModuleRegistry } from 'react-native'

+ export type NativeEmbeddingParams = {
+ embd_normalize?: number
+ }
+
  export type NativeContextParams = {
  model: string
  is_model_asset?: boolean
-
- embedding?: boolean
+ use_progress_callback?: boolean

  n_ctx?: number
  n_batch?: number
@@ -13,6 +16,20 @@ export type NativeContextParams = {
  n_threads?: number
  n_gpu_layers?: number

+ /**
+ * Enable flash attention, only recommended in GPU device (Experimental in llama.cpp)
+ */
+ flash_attn?: boolean
+
+ /**
+ * KV cache data type for the K (Experimental in llama.cpp)
+ */
+ cache_type_k?: string
+ /**
+ * KV cache data type for the V (Experimental in llama.cpp)
+ */
+ cache_type_v?: string
+
  use_mlock?: boolean
  use_mmap?: boolean
  vocab_only?: boolean
@@ -22,35 +39,134 @@ export type NativeContextParams = {

  rope_freq_base?: number
  rope_freq_scale?: number
+
+ pooling_type?: number
+
+ // Embedding params
+ embedding?: boolean
+ embd_normalize?: number
  }

  export type NativeCompletionParams = {
  prompt: string
- grammar?: string
- stop?: Array<string> // -> antiprompt
-
  n_threads?: number
+ /**
+ * Set grammar for grammar-based sampling. Default: no grammar
+ */
+ grammar?: string
+ /**
+ * Specify a JSON array of stopping strings.
+ * These words will not be included in the completion, so make sure to add them to the prompt for the next iteration. Default: `[]`
+ */
+ stop?: Array<string>
+ /**
+ * Set the maximum number of tokens to predict when generating text.
+ * **Note:** May exceed the set limit slightly if the last token is a partial multibyte character.
+ * When 0,no tokens will be generated but the prompt is evaluated into the cache. Default: `-1`, where `-1` is infinity.
+ */
  n_predict?: number
+ /**
+ * If greater than 0, the response also contains the probabilities of top N tokens for each generated token given the sampling settings.
+ * Note that for temperature < 0 the tokens are sampled greedily but token probabilities are still being calculated via a simple softmax of the logits without considering any other sampler settings.
+ * Default: `0`
+ */
  n_probs?: number
+ /**
+ * Limit the next token selection to the K most probable tokens. Default: `40`
+ */
  top_k?: number
+ /**
+ * Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P. Default: `0.95`
+ */
  top_p?: number
+ /**
+ * The minimum probability for a token to be considered, relative to the probability of the most likely token. Default: `0.05`
+ */
  min_p?: number
- xtc_t?: number
- xtc_p?: number
+ /**
+ * Set the chance for token removal via XTC sampler. Default: `0.0`, which is disabled.
+ */
+ xtc_probability?: number
+ /**
+ * Set a minimum probability threshold for tokens to be removed via XTC sampler. Default: `0.1` (> `0.5` disables XTC)
+ */
+ xtc_threshold?: number
+ /**
+ * Enable locally typical sampling with parameter p. Default: `1.0`, which is disabled.
+ */
  typical_p?: number
- temperature?: number // -> temp
+ /**
+ * Adjust the randomness of the generated text. Default: `0.8`
+ */
+ temperature?: number
+ /**
+ * Last n tokens to consider for penalizing repetition. Default: `64`, where `0` is disabled and `-1` is ctx-size.
+ */
  penalty_last_n?: number
+ /**
+ * Control the repetition of token sequences in the generated text. Default: `1.0`
+ */
  penalty_repeat?: number
+ /**
+ * Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
+ */
  penalty_freq?: number
+ /**
+ * Repeat alpha presence penalty. Default: `0.0`, which is disabled.
+ */
  penalty_present?: number
+ /**
+ * Penalize newline tokens when applying the repeat penalty. Default: `false`
+ */
+ penalize_nl?: boolean
+ /**
+ * Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0.
+ */
  mirostat?: number
+ /**
+ * Set the Mirostat target entropy, parameter tau. Default: `5.0`
+ */
  mirostat_tau?: number
+ /**
+ * Set the Mirostat learning rate, parameter eta. Default: `0.1`
+ */
  mirostat_eta?: number
- penalize_nl?: boolean
- seed?: number
-
+ /**
+ * Set the DRY (Don't Repeat Yourself) repetition penalty multiplier. Default: `0.0`, which is disabled.
+ */
+ dry_multiplier?: number
+ /**
+ * Set the DRY repetition penalty base value. Default: `1.75`
+ */
+ dry_base?: number
+ /**
+ * Tokens that extend repetition beyond this receive exponentially increasing penalty: multiplier * base ^ (length of repeating sequence before token - allowed length). Default: `2`
+ */
+ dry_allowed_length?: number
+ /**
+ * How many tokens to scan for repetitions. Default: `-1`, where `0` is disabled and `-1` is context size.
+ */
+ dry_penalty_last_n?: number
+ /**
+ * Specify an array of sequence breakers for DRY sampling. Only a JSON array of strings is accepted. Default: `['\n', ':', '"', '*']`
+ */
+ dry_sequence_breakers?: Array<string>
+ /**
+ * Ignore end of stream token and continue generating. Default: `false`
+ */
  ignore_eos?: boolean
+ /**
+ * Modify the likelihood of a token appearing in the generated text completion.
+ * For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood.
+ * Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings,
+ * e.g.`[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does.
+ * Default: `[]`
+ */
  logit_bias?: Array<Array<number>>
+ /**
+ * Set the random number generator (RNG) seed. Default: `-1`, which is a random seed.
+ */
+ seed?: number

  emit_partial_completion: boolean
  }
@@ -125,7 +241,9 @@ export type NativeCPUFeatures = {

  export interface Spec extends TurboModule {
  setContextLimit(limit: number): Promise<void>
- initContext(params: NativeContextParams): Promise<NativeLlamaContext>
+
+ modelInfo(path: string, skip?: string[]): Promise<Object>
+ initContext(contextId: number, params: NativeContextParams): Promise<NativeLlamaContext>

  loadSession(
  contextId: number,
@@ -150,7 +268,11 @@ export interface Spec extends TurboModule {
  chatTemplate?: string,
  ): Promise<string>
  detokenize(contextId: number, tokens: number[]): Promise<string>
- embedding(contextId: number, text: string): Promise<NativeEmbeddingResult>
+ embedding(
+ contextId: number,
+ text: string,
+ params: NativeEmbeddingParams,
+ ): Promise<NativeEmbeddingResult>
  bench(
  contextId: number,
  pp: number,
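A hedged sketch of how the newly documented sampling fields might be combined in a completion call. The context variable, message shape (the usual OpenAI-compatible role/content form), token id, and values are assumptions for illustration. Note that `xtc_t`/`xtc_p` were replaced by `xtc_threshold`/`xtc_probability`, and the DRY penalty grows as multiplier * base ^ (repeat length - allowed length).

```ts
import type { LlamaContext, TokenData } from 'cui-llama.rn'

// Assumes `context` was created with initLlama (see the earlier sketch).
async function generate(context: LlamaContext) {
  return context.completion(
    {
      messages: [{ role: 'user', content: 'Write a haiku about autumn.' }],
      n_predict: 128,
      temperature: 0.8,
      top_k: 40,
      top_p: 0.95,
      min_p: 0.05,
      // XTC sampler: renamed from xtc_t / xtc_p in this release
      xtc_probability: 0.5,
      xtc_threshold: 0.1,
      // DRY repetition penalty: multiplier * base ^ (repeat length - allowed length)
      dry_multiplier: 0.8,
      dry_base: 1.75,
      dry_allowed_length: 2,
      dry_penalty_last_n: -1, // -1 = scan the whole context
      stop: ['</s>'],
      logit_bias: [[15043, 1.0]], // hypothetical token id, boosted as in the JSDoc example
      seed: -1,
    },
    (data: TokenData) => console.log(data.token), // streamed partial tokens
  )
}
```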
package/src/grammar.ts CHANGED
@@ -74,7 +74,7 @@ function buildRepetition(
  return result
  }

- class BuiltinRule {
+ export class SchemaGrammarConverterBuiltinRule {
  content: string

  deps: string[]
@@ -85,9 +85,11 @@ class BuiltinRule {
  }
  }

+ const BuiltinRule = SchemaGrammarConverterBuiltinRule
+
  const UP_TO_15_DIGITS = buildRepetition('[0-9]', 0, 15)

- const PRIMITIVE_RULES: { [key: string]: BuiltinRule } = {
+ const PRIMITIVE_RULES: { [key: string]: SchemaGrammarConverterBuiltinRule } = {
  boolean: new BuiltinRule('("true" | "false") space', []),
  'decimal-part': new BuiltinRule(`[0-9] ${UP_TO_15_DIGITS}`, []),
  'integral-part': new BuiltinRule(`[0-9] | [1-9] ${UP_TO_15_DIGITS}`, []),
@@ -126,7 +128,7 @@ const PRIMITIVE_RULES: { [key: string]: BuiltinRule } = {
  }

  // TODO: support "uri", "email" string formats
- const STRING_FORMAT_RULES: { [key: string]: BuiltinRule } = {
+ const STRING_FORMAT_RULES: { [key: string]: SchemaGrammarConverterBuiltinRule } = {
  date: new BuiltinRule(
  '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )',
  [],
@@ -173,7 +175,7 @@ const formatLiteral = (literal: string): string => {
  const generateConstantRule = (value: any): string =>
  formatLiteral(JSON.stringify(value))

- interface PropOrder {
+ export interface SchemaGrammarConverterPropOrder {
  [key: string]: number
  }

@@ -196,7 +198,7 @@ function* groupBy(iterable: Iterable<any>, keyFn: (x: any) => any) {
  }

  export class SchemaGrammarConverter {
- private _propOrder: PropOrder
+ private _propOrder: SchemaGrammarConverterPropOrder

  private _allowFetch: boolean

@@ -209,7 +211,7 @@ export class SchemaGrammarConverter {
  private _refsBeingResolved: Set<string>

  constructor(options: {
- prop_order?: PropOrder
+ prop_order?: SchemaGrammarConverterPropOrder
  allow_fetch?: boolean
  dotall?: boolean
  }) {
@@ -690,7 +692,7 @@ export class SchemaGrammarConverter {
  }
  }

- _addPrimitive(name: string, rule: BuiltinRule | undefined) {
+ _addPrimitive(name: string, rule: SchemaGrammarConverterBuiltinRule | undefined) {
  if (!rule) {
  throw new Error(`Rule ${name} not known`)
  }
@@ -828,7 +830,7 @@ export const convertJsonSchemaToGrammar = ({
  allowFetch,
  }: {
  schema: any
- propOrder?: PropOrder
+ propOrder?: SchemaGrammarConverterPropOrder
  dotall?: boolean
  allowFetch?: boolean
  }): string | Promise<string> => {
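Because `SchemaGrammarConverterPropOrder` (and `SchemaGrammarConverterBuiltinRule`) are now exported, a grammar can be built in user code with explicit property ordering. A small sketch follows; the JSON schema is an assumption, and the resulting string can be passed as the `grammar` completion parameter.

```ts
import { convertJsonSchemaToGrammar } from 'cui-llama.rn'
import type { SchemaGrammarConverterPropOrder } from 'cui-llama.rn'

// Emit `name` before `age` in generated objects.
const propOrder: SchemaGrammarConverterPropOrder = { name: 0, age: 1 }

async function buildGrammar(): Promise<string> {
  // The converter's return type is string | Promise<string>, so awaiting covers both cases.
  const grammar = await convertJsonSchemaToGrammar({
    schema: {
      type: 'object',
      properties: {
        name: { type: 'string' },
        age: { type: 'integer' },
      },
      required: ['name', 'age'],
    },
    propOrder,
  })
  return grammar
}
```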
package/src/index.ts CHANGED
@@ -11,17 +11,38 @@ import type {
  NativeEmbeddingResult,
  NativeSessionLoadResult,
  NativeCPUFeatures,
+ NativeEmbeddingParams,
+ NativeCompletionTokenProbItem,
+ NativeCompletionResultTimings,
  } from './NativeRNLlama'
+ import type { SchemaGrammarConverterPropOrder, SchemaGrammarConverterBuiltinRule } from './grammar'
  import { SchemaGrammarConverter, convertJsonSchemaToGrammar } from './grammar'
- import type { RNLlamaOAICompatibleMessage } from './chat'
+ import type { RNLlamaMessagePart, RNLlamaOAICompatibleMessage } from './chat'
  import { formatChat } from './chat'

+ export type {
+ NativeContextParams,
+ NativeLlamaContext,
+ NativeCompletionParams,
+ NativeCompletionTokenProb,
+ NativeCompletionResult,
+ NativeTokenizeResult,
+ NativeEmbeddingResult,
+ NativeSessionLoadResult,
+ NativeEmbeddingParams,
+ NativeCompletionTokenProbItem,
+ NativeCompletionResultTimings,
+ RNLlamaMessagePart,
+ RNLlamaOAICompatibleMessage,
+ SchemaGrammarConverterPropOrder,
+ SchemaGrammarConverterBuiltinRule,
+ }
+
  export { SchemaGrammarConverter, convertJsonSchemaToGrammar }

+ const EVENT_ON_INIT_CONTEXT_PROGRESS = '@RNLlama_onInitContextProgress'
  const EVENT_ON_TOKEN = '@RNLlama_onToken'

- const EVENT_ON_MODEL_PROGRESS = '@RNLlama_onModelProgress'
-
  let EventEmitter: NativeEventEmitter | DeviceEventEmitterStatic
  if (Platform.OS === 'ios') {
  // @ts-ignore
@@ -41,7 +62,16 @@ type TokenNativeEvent = {
  tokenResult: TokenData
  }

- export type ContextParams = NativeContextParams
+ export type ContextParams = Omit<
+ NativeContextParams,
+ 'cache_type_k' | 'cache_type_v' | 'pooling_type'
+ > & {
+ cache_type_k?: 'f16' | 'f32' | 'q8_0' | 'q4_0' | 'q4_1' | 'iq4_nl' | 'q5_0' | 'q5_1'
+ cache_type_v?: 'f16' | 'f32' | 'q8_0' | 'q4_0' | 'q4_1' | 'iq4_nl' | 'q5_0' | 'q5_1'
+ pooling_type?: 'none' | 'mean' | 'cls' | 'last' | 'rank'
+ }
+
+ export type EmbeddingParams = NativeEmbeddingParams

  export type CompletionParams = Omit<
  NativeCompletionParams,
@@ -49,6 +79,7 @@ export type CompletionParams = Omit<
  > & {
  prompt?: string
  messages?: RNLlamaOAICompatibleMessage[]
+ chatTemplate?: string
  }

  export type BenchResult = {
@@ -100,23 +131,22 @@ export class LlamaContext {

  async getFormattedChat(
  messages: RNLlamaOAICompatibleMessage[],
+ template?: string,
  ): Promise<string> {
  const chat = formatChat(messages)
- return RNLlama.getFormattedChat(
- this.id,
- chat,
- this.model?.isChatTemplateSupported ? undefined : 'chatml',
- )
+ let tmpl = this.model?.isChatTemplateSupported ? undefined : 'chatml'
+ if (template) tmpl = template // Force replace if provided
+ return RNLlama.getFormattedChat(this.id, chat, tmpl)
  }

  async completion(
  params: CompletionParams,
  callback?: (data: TokenData) => void,
  ): Promise<NativeCompletionResult> {
-
  let finalPrompt = params.prompt
- if (params.messages) { // messages always win
- finalPrompt = await this.getFormattedChat(params.messages)
+ if (params.messages) {
+ // messages always win
+ finalPrompt = await this.getFormattedChat(params.messages, params.chatTemplate)
  }

  let tokenListener: any =
@@ -162,8 +192,11 @@ export class LlamaContext {
  return RNLlama.detokenize(this.id, tokens)
  }

- embedding(text: string): Promise<NativeEmbeddingResult> {
- return RNLlama.embedding(this.id, text)
+ embedding(
+ text: string,
+ params?: EmbeddingParams,
+ ): Promise<NativeEmbeddingResult> {
+ return RNLlama.embedding(this.id, text, params || {})
  }

  async bench(
@@ -199,35 +232,78 @@ export async function setContextLimit(limit: number): Promise<void> {
  return RNLlama.setContextLimit(limit)
  }

- export async function initLlama({
+ let contextIdCounter = 0
+ const contextIdRandom = () =>
+ process.env.NODE_ENV === 'test' ? 0 : Math.floor(Math.random() * 100000)
+
+ const modelInfoSkip = [
+ // Large fields
+ 'tokenizer.ggml.tokens',
+ 'tokenizer.ggml.token_type',
+ 'tokenizer.ggml.merges',
+ ]
+ export async function loadLlamaModelInfo(model: string): Promise<Object> {
+ let path = model
+ if (path.startsWith('file://')) path = path.slice(7)
+ return RNLlama.modelInfo(path, modelInfoSkip)
+ }
+
+ const poolTypeMap = {
+ // -1 is unspecified as undefined
+ none: 0,
+ mean: 1,
+ cls: 2,
+ last: 3,
+ rank: 4,
+ }
+
+ export async function initLlama(
+ {
  model,
  is_model_asset: isModelAsset,
+ pooling_type: poolingType,
+ lora,
  ...rest
- }: ContextParams,
- progressCallback?: (progress: number) => void
+ }: ContextParams,
+ onProgress?: (progress: number) => void,
  ): Promise<LlamaContext> {
  let path = model
  if (path.startsWith('file://')) path = path.slice(7)
-
- const modelProgressListener = EventEmitter.addListener(EVENT_ON_MODEL_PROGRESS, (event) => {
- if(event.progress && progressCallback)
- progressCallback(event.progress)
- if(event.progress === 100) {
- modelProgressListener.remove()
- }
- })

+ let loraPath = lora
+ if (loraPath?.startsWith('file://')) loraPath = loraPath.slice(7)
+
+ const contextId = contextIdCounter + contextIdRandom()
+ contextIdCounter += 1
+
+ let removeProgressListener: any = null
+ if (onProgress) {
+ removeProgressListener = EventEmitter.addListener(
+ EVENT_ON_INIT_CONTEXT_PROGRESS,
+ (evt: { contextId: number; progress: number }) => {
+ if (evt.contextId !== contextId) return
+ onProgress(evt.progress)
+ },
+ )
+ }
+
+ const poolType = poolTypeMap[poolingType as keyof typeof poolTypeMap]
  const {
- contextId,
  gpu,
  reasonNoGPU,
  model: modelDetails,
- } = await RNLlama.initContext({
+ } = await RNLlama.initContext(contextId, {
  model: path,
  is_model_asset: !!isModelAsset,
+ use_progress_callback: !!onProgress,
+ pooling_type: poolType,
+ lora: loraPath,
  ...rest,
+ }).catch((err: any) => {
+ removeProgressListener?.remove()
+ throw err
  })
-
+ removeProgressListener?.remove()
  return new LlamaContext({ contextId, gpu, reasonNoGPU, model: modelDetails })
  }
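Taken together, the src/index.ts changes above suggest a flow like the following. This is a sketch under assumptions (import path, model location, and the message shape are not spelled out in this diff): inspect metadata with `loadLlamaModelInfo`, initialize with a per-context progress callback, call `embedding` with per-call params, and format chats with an optional template override.

```ts
import { initLlama, loadLlamaModelInfo } from 'cui-llama.rn'

async function demo(modelPath: string) {
  // GGUF metadata without creating a context; large tokenizer arrays are skipped internally.
  const info = await loadLlamaModelInfo(modelPath)
  console.log('model info:', info)

  const context = await initLlama(
    { model: modelPath, embedding: true, pooling_type: 'mean' },
    (progress) => console.log(`init progress: ${progress}`), // delivered via @RNLlama_onInitContextProgress
  )

  // embedding() now accepts optional per-call params (embd_normalize).
  const result = await context.embedding('hello world', { embd_normalize: 2 })
  console.log('embedding result:', result)

  // A chat template can now be forced per call; omit it to use the model's own template.
  const prompt = await context.getFormattedChat(
    [{ role: 'user', content: 'Hi!' }],
    undefined,
  )
  console.log(prompt)

  await context.release()
}
```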