@inferrlm/react-native-mlx 0.4.2-alpha.7 → 0.4.2-alpha.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/ios/Sources/HybridLLM.swift +48 -12
- package/package.json +1 -1
|
@@ -406,6 +406,8 @@ class HybridLLM: HybridLLMSpec {
|
|
|
406
406
|
}
|
|
407
407
|
}
|
|
408
408
|
|
|
409
|
+
private static let fallbackTemplate = "{%- for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{%- endfor %}{%- if add_generation_prompt %}{%- if enable_thinking is defined and enable_thinking is true %}{{ '<|im_start|>assistant\\n<think>\\n' }}{%- else %}{{ '<|im_start|>assistant\\n' }}{%- endif %}{%- endif %}"
|
|
410
|
+
|
|
409
411
|
private func buildChatMessages(
|
|
410
412
|
prompt: String,
|
|
411
413
|
toolResults: [String]?,
|
|
@@ -447,6 +449,50 @@ class HybridLLM: HybridLLMSpec {
|
|
|
447
449
|
return chat
|
|
448
450
|
}
|
|
449
451
|
|
|
452
|
+
private func prepareInput(
|
|
453
|
+
container: ModelContainer,
|
|
454
|
+
chat: [Chat.Message]
|
|
455
|
+
) async throws -> LMInput {
|
|
456
|
+
let tools = !self.toolSchemas.isEmpty ? self.toolSchemas : nil
|
|
457
|
+
let additionalCtx: [String: any Sendable] = ["enable_thinking": true]
|
|
458
|
+
|
|
459
|
+
let messages: [[String: any Sendable]] = chat.map {
|
|
460
|
+
["role": $0.role.rawValue, "content": $0.content]
|
|
461
|
+
}
|
|
462
|
+
|
|
463
|
+
let tokens: [Int] = try await container.perform { context in
|
|
464
|
+
do {
|
|
465
|
+
let result = try context.tokenizer.applyChatTemplate(
|
|
466
|
+
messages: messages,
|
|
467
|
+
addGenerationPrompt: true,
|
|
468
|
+
tools: tools,
|
|
469
|
+
additionalContext: additionalCtx
|
|
470
|
+
)
|
|
471
|
+
self.log("template_applied token_count=\(result.count)")
|
|
472
|
+
let decoded = context.tokenizer.decode(tokens: Array(result.suffix(60)))
|
|
473
|
+
self.log("input_tail_decoded: \(decoded)")
|
|
474
|
+
return result
|
|
475
|
+
} catch {
|
|
476
|
+
self.log("template_error: \(error), retrying with fallback")
|
|
477
|
+
let result = try context.tokenizer.applyChatTemplate(
|
|
478
|
+
messages: messages,
|
|
479
|
+
chatTemplate: .literal(HybridLLM.fallbackTemplate),
|
|
480
|
+
addGenerationPrompt: true,
|
|
481
|
+
truncation: false,
|
|
482
|
+
maxLength: nil,
|
|
483
|
+
tools: nil,
|
|
484
|
+
additionalContext: additionalCtx
|
|
485
|
+
)
|
|
486
|
+
self.log("fallback_template_applied token_count=\(result.count)")
|
|
487
|
+
let decoded = context.tokenizer.decode(tokens: Array(result.suffix(60)))
|
|
488
|
+
self.log("fallback_input_tail_decoded: \(decoded)")
|
|
489
|
+
return result
|
|
490
|
+
}
|
|
491
|
+
}
|
|
492
|
+
|
|
493
|
+
return LMInput(tokens: MLXArray(tokens))
|
|
494
|
+
}
|
|
495
|
+
|
|
450
496
|
private func executeToolCall(
|
|
451
497
|
tool: ToolDefinition,
|
|
452
498
|
argsDict: [String: Any]
|
|
@@ -487,12 +533,7 @@ class HybridLLM: HybridLLMSpec {
|
|
|
487
533
|
log("perform_gen_events depth=\(depth) prompt=\(prompt.count)chars toolResults=\(toolResults?.count ?? 0)")
|
|
488
534
|
|
|
489
535
|
let chat = buildChatMessages(prompt: prompt, toolResults: toolResults, depth: depth)
|
|
490
|
-
let userInput = UserInput(
|
|
491
|
-
chat: chat,
|
|
492
|
-
tools: !self.toolSchemas.isEmpty ? self.toolSchemas : nil
|
|
493
|
-
)
|
|
494
|
-
|
|
495
|
-
let lmInput = try await container.prepare(input: userInput)
|
|
536
|
+
let lmInput = try await prepareInput(container: container, chat: chat)
|
|
496
537
|
log("perform_gen_events input_prepared messages=\(chat.count) maxTokens=\(self.maxTokens) temperature=\(self.temperature)")
|
|
497
538
|
|
|
498
539
|
let stream = try await container.perform { context in
|
|
@@ -681,12 +722,7 @@ class HybridLLM: HybridLLMSpec {
|
|
|
681
722
|
log("perform_gen depth=\(depth) prompt=\(prompt.count)chars toolResults=\(toolResults?.count ?? 0)")
|
|
682
723
|
|
|
683
724
|
let chat = buildChatMessages(prompt: prompt, toolResults: toolResults, depth: depth)
|
|
684
|
-
let userInput = UserInput(
|
|
685
|
-
chat: chat,
|
|
686
|
-
tools: !self.toolSchemas.isEmpty ? self.toolSchemas : nil
|
|
687
|
-
)
|
|
688
|
-
|
|
689
|
-
let lmInput = try await container.prepare(input: userInput)
|
|
725
|
+
let lmInput = try await prepareInput(container: container, chat: chat)
|
|
690
726
|
log("perform_gen input_prepared messages=\(chat.count) maxTokens=\(self.maxTokens) temperature=\(self.temperature)")
|
|
691
727
|
|
|
692
728
|
let stream = try await container.perform { context in
|
package/package.json
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@inferrlm/react-native-mlx",
|
|
3
3
|
"description": "MLX Swift integration for React Native - InferrLM fork with enhanced features",
|
|
4
|
-
"version": "0.4.2-alpha.7",
|
|
4
|
+
"version": "0.4.2-alpha.8",
|
|
5
5
|
"main": "./lib/module/index.js",
|
|
6
6
|
"module": "./lib/module/index.js",
|
|
7
7
|
"types": "./lib/typescript/src/index.d.ts",
|