langchainrb 0.13.5 → 0.14.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +2 -17
- data/lib/langchain/assistants/assistant.rb +90 -19
- data/lib/langchain/assistants/messages/ollama_message.rb +86 -0
- data/lib/langchain/assistants/thread.rb +8 -1
- data/lib/langchain/llm/ai21.rb +0 -4
- data/lib/langchain/llm/anthropic.rb +15 -6
- data/lib/langchain/llm/azure.rb +3 -3
- data/lib/langchain/llm/base.rb +1 -0
- data/lib/langchain/llm/cohere.rb +0 -2
- data/lib/langchain/llm/google_palm.rb +1 -4
- data/lib/langchain/llm/ollama.rb +1 -1
- data/lib/langchain/llm/response/google_gemini_response.rb +1 -1
- data/lib/langchain/llm/response/ollama_response.rb +19 -1
- data/lib/langchain/vectorsearch/milvus.rb +1 -1
- data/lib/langchain/version.rb +1 -1
- metadata +5 -24
- data/lib/langchain/utils/token_length/ai21_validator.rb +0 -41
- data/lib/langchain/utils/token_length/base_validator.rb +0 -42
- data/lib/langchain/utils/token_length/cohere_validator.rb +0 -49
- data/lib/langchain/utils/token_length/google_palm_validator.rb +0 -57
- data/lib/langchain/utils/token_length/openai_validator.rb +0 -138
- data/lib/langchain/utils/token_length/token_limit_exceeded.rb +0 -17
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 68900cd116cf0fb1b77376a4906e5551f0d578ee2bb47c7ec86d32bf44f84e33
+  data.tar.gz: f68782c3cdc856799778618d78b6411a85b0c69adf6a4d33489b8025fdca3dce
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 158410fd769caaf9074eddc1143ddee9256ac5a466a510c32b74d337eba62fab80b676661cbf1673604d236014a5cb4defdd4743e71abb713a659ddea0fe5e8c
+  data.tar.gz: 2e956356a443ff37ad711f6c42f8c4940925bcee4be075b403c78c3f702b487c12790dca9ba7d68a01acaf1c245b2910650b3f938e80cedd1fc2d5af14f7ffa8
data/CHANGELOG.md
CHANGED
@@ -1,5 +1,11 @@
 ## [Unreleased]
 
+## [0.14.0] - 2024-07-12
+- Removed TokenLength validators
+- Assistant works with a Mistral LLM now
+- Assistant keeps track of tokens used
+- Misc fixes and improvements
+
 ## [0.13.5] - 2024-07-01
 - Add Milvus#remove_texts() method
 - Langchain::Assistant has a `state` now
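The new token accounting can be read straight off the assistant after any run, via the `total_*` readers added to `Langchain::Assistant` further down this diff. A minimal sketch (counter values are illustrative):

```ruby
llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
assistant = Langchain::Assistant.new(llm: llm, instructions: "You are a helpful assistant")

assistant.add_message_and_run(content: "Hello!")

# Totals accumulate across every call the Assistant makes to the LLM.
assistant.total_prompt_tokens     # => 42 (illustrative)
assistant.total_completion_tokens # => 17 (illustrative)
assistant.total_tokens            # => 59 (illustrative)
```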
data/README.md
CHANGED
@@ -428,25 +428,10 @@ Assistants are Agent-like objects that leverage helpful instructions, LLMs, tool
 ```ruby
 llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
 ```
-2. Instantiate
-```ruby
-thread = Langchain::Thread.new
-```
-You can pass old message from previously using the Assistant:
-```ruby
-thread.messages = messages
-```
-Messages contain the conversation history and the whole message history is sent to the LLM every time. A Message belongs to 1 of the 4 roles:
-* `Message(role: "system")` message usually contains the instructions.
-* `Message(role: "user")` messages come from the user.
-* `Message(role: "assistant")` messages are produced by the LLM.
-* `Message(role: "tool")` messages are sent in response to tool calls with tool outputs.
-
-3. Instantiate an Assistant
+2. Instantiate an Assistant
 ```ruby
 assistant = Langchain::Assistant.new(
   llm: llm,
-  thread: thread,
   instructions: "You are a Meteorologist Assistant that is able to pull the weather for any location",
   tools: [
     Langchain::Tool::Weather.new(api_key: ENV["OPEN_WEATHER_API_KEY"])
@@ -482,7 +467,7 @@ assistant.add_message_and_run content: "What about Sacramento, CA?", auto_tool_e
 ### Accessing Thread messages
 You can access the messages in a Thread by calling `assistant.thread.messages`.
 ```ruby
-assistant.
+assistant.messages
 ```
 
 The Assistant checks the context window limits before every request to the LLM and remove oldest thread messages one by one if the context window is exceeded.
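With the explicit Thread step removed from the quick start, a condensed version of the new flow looks like this (a sketch based on the README lines above; the weather question is only an example):

```ruby
llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

assistant = Langchain::Assistant.new(
  llm: llm,
  instructions: "You are a Meteorologist Assistant that is able to pull the weather for any location",
  tools: [Langchain::Tool::Weather.new(api_key: ENV["OPEN_WEATHER_API_KEY"])]
)

assistant.add_message_and_run(content: "What is the weather in San Francisco, CA?", auto_tool_execution: true)

# New shorthand for assistant.thread.messages
assistant.messages
```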
data/lib/langchain/assistants/assistant.rb
CHANGED
@@ -16,13 +16,15 @@ module Langchain
     def_delegators :thread, :messages, :messages=
 
     attr_reader :llm, :thread, :instructions, :state
+    attr_reader :total_prompt_tokens, :total_completion_tokens, :total_tokens
     attr_accessor :tools
 
     SUPPORTED_LLMS = [
       Langchain::LLM::Anthropic,
-      Langchain::LLM::OpenAI,
       Langchain::LLM::GoogleGemini,
-      Langchain::LLM::GoogleVertexAI
+      Langchain::LLM::GoogleVertexAI,
+      Langchain::LLM::Ollama,
+      Langchain::LLM::OpenAI
     ]
 
     # Create a new assistant
@@ -40,6 +42,9 @@ module Langchain
       unless SUPPORTED_LLMS.include?(llm.class)
         raise ArgumentError, "Invalid LLM; currently only #{SUPPORTED_LLMS.join(", ")} are supported"
       end
+      if llm.is_a?(Langchain::LLM::Ollama)
+        raise ArgumentError, "Currently only `mistral:7b-instruct-v0.3-fp16` model is supported for Ollama LLM" unless llm.defaults[:completion_model_name] == "mistral:7b-instruct-v0.3-fp16"
+      end
       raise ArgumentError, "Tools must be an array of Langchain::Tool::Base instance(s)" unless tools.is_a?(Array) && tools.all? { |tool| tool.is_a?(Langchain::Tool::Base) }
 
       @llm = llm
@@ -48,13 +53,15 @@ module Langchain
       @instructions = instructions
       @state = :ready
 
+      @total_prompt_tokens = 0
+      @total_completion_tokens = 0
+      @total_tokens = 0
+
       raise ArgumentError, "Thread must be an instance of Langchain::Thread" unless @thread.is_a?(Langchain::Thread)
 
       # The first message in the thread should be the system instructions
       # TODO: What if the user added old messages and the system instructions are already in there? Should this overwrite the existing instructions?
-
-      add_message(role: "system", content: instructions) if instructions
-    end
+      initialize_instructions
       # For Google Gemini, and Anthropic system instructions are added to the `system:` param in the `chat` method
     end
 
@@ -150,7 +157,6 @@ module Langchain
 
     # Handle the current state and transition to the next state
     #
-    # @param state [Symbol] The current state
     # @return [Symbol] The next state
     def handle_state
       case @state
@@ -189,7 +195,6 @@ module Langchain
 
     # Handle LLM message scenario
    #
-    # @param auto_tool_execution [Boolean] Flag to indicate if tools should be executed automatically
     # @return [Symbol] The next state
     def handle_llm_message
       thread.messages.last.tool_calls.any? ? :requires_action : :completed
@@ -208,14 +213,29 @@ module Langchain
     # @return [Symbol] The next state
     def handle_user_or_tool_message
       response = chat_with_llm
-      add_message(role: response.role, content: response.chat_completion, tool_calls: response.tool_calls)
 
+      # With Ollama, we're calling the `llm.complete()` method
+      content = if llm.is_a?(Langchain::LLM::Ollama)
+        response.completion
+      else
+        response.chat_completion
+      end
+
+      add_message(role: response.role, content: content, tool_calls: response.tool_calls)
+      record_used_tokens(response.prompt_tokens, response.completion_tokens, response.total_tokens)
+
+      set_state_for(response: response)
+    end
+
+    def set_state_for(response:)
       if response.tool_calls.any?
         :in_progress
       elsif response.chat_completion
         :completed
+      elsif response.completion # Currently only used by Ollama
+        :completed
       else
-        Langchain.logger.error("LLM response does not contain tool calls or
+        Langchain.logger.error("LLM response does not contain tool calls, chat or completion response")
         :failed
       end
     end
@@ -236,6 +256,8 @@ module Langchain
     # @return [String] The tool role
     def determine_tool_role
       case llm
+      when Langchain::LLM::Ollama
+        Langchain::Messages::OllamaMessage::TOOL_ROLE
       when Langchain::LLM::OpenAI
         Langchain::Messages::OpenAIMessage::TOOL_ROLE
       when Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI
@@ -245,31 +267,58 @@ module Langchain
       end
     end
 
+    def initialize_instructions
+      if llm.is_a?(Langchain::LLM::Ollama)
+        content = String.new # rubocop: disable Performance/UnfreezeString
+        if tools.any?
+          content << %([AVAILABLE_TOOLS] #{tools.map(&:to_openai_tools).flatten}[/AVAILABLE_TOOLS])
+        end
+        if instructions
+          content << "[INST] #{instructions}[/INST]"
+        end
+
+        add_message(role: "system", content: content)
+      elsif llm.is_a?(Langchain::LLM::OpenAI)
+        add_message(role: "system", content: instructions) if instructions
+      end
+    end
+
     # Call to the LLM#chat() method
     #
     # @return [Langchain::LLM::BaseResponse] The LLM response object
     def chat_with_llm
       Langchain.logger.info("Sending a call to #{llm.class}", for: self.class)
 
-      params = {
+      params = {}
 
-      if
-      if
+      if llm.is_a?(Langchain::LLM::OpenAI)
+        if tools.any?
           params[:tools] = tools.map(&:to_openai_tools).flatten
           params[:tool_choice] = "auto"
-
+        end
+      elsif llm.is_a?(Langchain::LLM::Anthropic)
+        if tools.any?
           params[:tools] = tools.map(&:to_anthropic_tools).flatten
-          params[:system] = instructions if instructions
           params[:tool_choice] = {type: "auto"}
-
+        end
+        params[:system] = instructions if instructions
+      elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+        if tools.any?
           params[:tools] = tools.map(&:to_google_gemini_tools).flatten
           params[:system] = instructions if instructions
           params[:tool_choice] = "auto"
         end
-      # TODO: Not sure that tool_choice should always be "auto"; Maybe we can let the user toggle it.
       end
+      # TODO: Not sure that tool_choice should always be "auto"; Maybe we can let the user toggle it.
 
-      llm.
+      if llm.is_a?(Langchain::LLM::Ollama)
+        params[:raw] = true
+        params[:prompt] = thread.prompt_of_concatenated_messages
+        llm.complete(**params)
+      else
+        params[:messages] = thread.array_of_message_hashes
+        llm.chat(**params)
+      end
     end
 
     # Run the tools automatically
@@ -278,7 +327,9 @@ module Langchain
     def run_tools(tool_calls)
       # Iterate over each function invocation and submit tool output
       tool_calls.each do |tool_call|
-        tool_call_id, tool_name, method_name, tool_arguments = if llm.is_a?(Langchain::LLM::
+        tool_call_id, tool_name, method_name, tool_arguments = if llm.is_a?(Langchain::LLM::Ollama)
+          extract_ollama_tool_call(tool_call: tool_call)
+        elsif llm.is_a?(Langchain::LLM::OpenAI)
           extract_openai_tool_call(tool_call: tool_call)
         elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
           extract_google_gemini_tool_call(tool_call: tool_call)
@@ -296,6 +347,12 @@ module Langchain
       end
     end
 
+    def extract_ollama_tool_call(tool_call:)
+      tool_name, method_name = tool_call.dig("name").split("__")
+      tool_arguments = tool_call.dig("arguments").transform_keys(&:to_sym)
+      [nil, tool_name, method_name, tool_arguments]
+    end
+
     # Extract the tool call information from the OpenAI tool call hash
     #
     # @param tool_call [Hash] The tool call hash
@@ -346,7 +403,9 @@ module Langchain
     # @param tool_call_id [String] The ID of the tool call to include in the message
     # @return [Langchain::Message] The Message object
     def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
-      if llm.is_a?(Langchain::LLM::
+      if llm.is_a?(Langchain::LLM::Ollama)
+        Langchain::Messages::OllamaMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+      elsif llm.is_a?(Langchain::LLM::OpenAI)
        Langchain::Messages::OpenAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
       elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
         Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
@@ -355,6 +414,18 @@ module Langchain
       end
     end
 
+    # Increment the tokens count based on the last interaction with the LLM
+    #
+    # @param prompt_tokens [Integer] The number of used prmopt tokens
+    # @param completion_tokens [Integer] The number of used completion tokens
+    # @param total_tokens [Integer] The total number of used tokens
+    # @return [Integer] The current total tokens count
+    def record_used_tokens(prompt_tokens, completion_tokens, total_tokens_from_operation)
+      @total_prompt_tokens += prompt_tokens if prompt_tokens
+      @total_completion_tokens += completion_tokens if completion_tokens
+      @total_tokens += total_tokens_from_operation if total_tokens_from_operation
+    end
+
     # TODO: Fix the message truncation when context window is exceeded
   end
 end
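The new Ollama branch above only accepts one Mistral model. A sketch of wiring it up, assuming the Ollama LLM is configured through `default_options` (that is the hash the new guard reads back via `llm.defaults[:completion_model_name]`):

```ruby
llm = Langchain::LLM::Ollama.new(
  url: "http://localhost:11434", # assumed local Ollama endpoint
  default_options: {completion_model_name: "mistral:7b-instruct-v0.3-fp16"}
)

assistant = Langchain::Assistant.new(
  llm: llm,
  instructions: "You are a Meteorologist Assistant",
  tools: [Langchain::Tool::Weather.new(api_key: ENV["OPEN_WEATHER_API_KEY"])]
)
# Any other completion_model_name raises ArgumentError in the constructor.
```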
data/lib/langchain/assistants/messages/ollama_message.rb
ADDED
@@ -0,0 +1,86 @@
+# frozen_string_literal: true
+
+module Langchain
+  module Messages
+    class OllamaMessage < Base
+      # OpenAI uses the following roles:
+      ROLES = [
+        "system",
+        "assistant",
+        "user",
+        "tool"
+      ].freeze
+
+      TOOL_ROLE = "tool"
+
+      # Initialize a new OpenAI message
+      #
+      # @param [String] The role of the message
+      # @param [String] The content of the message
+      # @param [Array<Hash>] The tool calls made in the message
+      # @param [String] The ID of the tool call
+      def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil)
+        raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+        raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+        @role = role
+        # Some Tools return content as a JSON hence `.to_s`
+        @content = content.to_s
+        @tool_calls = tool_calls
+        @tool_call_id = tool_call_id
+      end
+
+      def to_s
+        send(:"to_#{role}_message_string")
+      end
+
+      def to_system_message_string
+        content
+      end
+
+      def to_user_message_string
+        "[INST] #{content}[/INST]"
+      end
+
+      def to_tool_message_string
+        "[TOOL_RESULTS] #{content}[/TOOL_RESULTS]"
+      end
+
+      def to_assistant_message_string
+        if tool_calls.any?
+          %("[TOOL_CALLS] #{tool_calls}")
+        else
+          content
+        end
+      end
+
+      # Check if the message came from an LLM
+      #
+      # @return [Boolean] true/false whether this message was produced by an LLM
+      def llm?
+        assistant?
+      end
+
+      # Check if the message came from an LLM
+      #
+      # @return [Boolean] true/false whether this message was produced by an LLM
+      def assistant?
+        role == "assistant"
+      end
+
+      # Check if the message are system instructions
+      #
+      # @return [Boolean] true/false whether this message are system instructions
+      def system?
+        role == "system"
+      end
+
+      # Check if the message is a tool call
+      #
+      # @return [Boolean] true/false whether this message is a tool call
+      def tool?
+        role == "tool"
+      end
+    end
+  end
+end
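Since `OllamaMessage#to_s` is what drives the concatenated prompt the Assistant sends to Ollama, the per-role formatting is easy to see in isolation:

```ruby
user = Langchain::Messages::OllamaMessage.new(role: "user", content: "What is 2 + 2?")
user.to_s # => "[INST] What is 2 + 2?[/INST]"

tool = Langchain::Messages::OllamaMessage.new(role: "tool", content: "4")
tool.to_s # => "[TOOL_RESULTS] 4[/TOOL_RESULTS]"
```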
data/lib/langchain/assistants/thread.rb
CHANGED
@@ -17,7 +17,14 @@ module Langchain
     #
     # @return [Array<Hash>] The thread as an OpenAI API-compatible array of hashes
     def array_of_message_hashes
-      messages
+      messages
+        .map(&:to_hash)
+        .compact
+    end
+
+    # Only used by the Assistant when it calls the LLM#complete() method
+    def prompt_of_concatenated_messages
+      messages.map(&:to_s).join
     end
 
     # Add a message to the thread
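A small sketch of the new `prompt_of_concatenated_messages` helper, which the Assistant feeds to `llm.complete(prompt: ...)` for Ollama (the messages are built by hand here purely for illustration):

```ruby
thread = Langchain::Thread.new
thread.messages = [
  Langchain::Messages::OllamaMessage.new(role: "system", content: "[INST] Be terse.[/INST]"),
  Langchain::Messages::OllamaMessage.new(role: "user", content: "Hi")
]

thread.prompt_of_concatenated_messages
# => "[INST] Be terse.[/INST][INST] Hi[/INST]"
```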
data/lib/langchain/llm/ai21.rb
CHANGED
@@ -16,8 +16,6 @@ module Langchain::LLM
       model: "j2-ultra"
     }.freeze
 
-    LENGTH_VALIDATOR = Langchain::Utils::TokenLength::AI21Validator
-
     def initialize(api_key:, default_options: {})
       depends_on "ai21"
 
@@ -35,8 +33,6 @@ module Langchain::LLM
     def complete(prompt:, **params)
       parameters = complete_parameters params
 
-      parameters[:maxTokens] = LENGTH_VALIDATOR.validate_max_tokens!(prompt, parameters[:model], {llm: client})
-
       response = client.complete(prompt, parameters)
       Langchain::LLM::AI21Response.new response, model: parameters[:model]
     end
data/lib/langchain/llm/anthropic.rb
CHANGED
@@ -5,10 +5,10 @@ module Langchain::LLM
   # Wrapper around Anthropic APIs.
   #
   # Gem requirements:
-  # gem "anthropic", "~> 0.
+  # gem "anthropic", "~> 0.3.0"
   #
   # Usage:
-  #
+  # anthropic = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
   #
   class Anthropic < Base
     DEFAULTS = {
@@ -18,9 +18,6 @@ module Langchain::LLM
       max_tokens_to_sample: 256
     }.freeze
 
-    # TODO: Implement token length validator for Anthropic
-    # LENGTH_VALIDATOR = Langchain::Utils::TokenLength::AnthropicValidator
-
     # Initialize an Anthropic LLM instance
     #
     # @param api_key [String] The API key to use
@@ -81,7 +78,10 @@ module Langchain::LLM
       parameters[:metadata] = metadata if metadata
       parameters[:stream] = stream if stream
 
-      response =
+      response = with_api_error_handling do
+        client.complete(parameters: parameters)
+      end
+
       Langchain::LLM::AnthropicResponse.new(response)
     end
 
@@ -114,6 +114,15 @@ module Langchain::LLM
       Langchain::LLM::AnthropicResponse.new(response)
     end
 
+    def with_api_error_handling
+      response = yield
+      return if response.empty?
+
+      raise Langchain::LLM::ApiError.new "Anthropic API error: #{response.dig("error", "message")}" if response&.dig("error")
+
+      response
+    end
+
     private
 
     def set_extra_headers!
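With the new `with_api_error_handling` wrapper, an error payload from the Anthropic API now surfaces as an exception rather than being wrapped into a response object. A sketch of handling it:

```ruby
anthropic = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])

begin
  anthropic.complete(prompt: "Hello, Claude")
rescue Langchain::LLM::ApiError => e
  puts e.message # => "Anthropic API error: ..."
end
```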
data/lib/langchain/llm/azure.rb
CHANGED
@@ -42,17 +42,17 @@ module Langchain::LLM
 
     def embed(...)
       @client = @embed_client
-      super
+      super
     end
 
     def complete(...)
       @client = @chat_client
-      super
+      super
     end
 
     def chat(...)
       @client = @chat_client
-      super
+      super
     end
   end
 end
data/lib/langchain/llm/base.rb
CHANGED
@@ -8,6 +8,7 @@ module Langchain::LLM
   # Langchain.rb provides a common interface to interact with all supported LLMs:
   #
   # - {Langchain::LLM::AI21}
+  # - {Langchain::LLM::Anthropic}
   # - {Langchain::LLM::Azure}
   # - {Langchain::LLM::Cohere}
   # - {Langchain::LLM::GooglePalm}
data/lib/langchain/llm/cohere.rb
CHANGED
@@ -74,8 +74,6 @@ module Langchain::LLM
 
       default_params.merge!(params)
 
-      default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], llm: client)
-
       response = client.generate(**default_params)
       Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
     end
data/lib/langchain/llm/google_palm.rb
CHANGED
@@ -18,7 +18,7 @@ module Langchain::LLM
       chat_completion_model_name: "chat-bison-001",
       embeddings_model_name: "embedding-gecko-001"
     }.freeze
-
+
     ROLE_MAPPING = {
       "assistant" => "ai"
     }
@@ -96,9 +96,6 @@ module Langchain::LLM
         examples: compose_examples(examples)
       }
 
-      # chat-bison-001 is the only model that currently supports countMessageTokens functions
-      LENGTH_VALIDATOR.validate_max_tokens!(default_params[:messages], "chat-bison-001", llm: self)
-
       if options[:stop_sequences]
         default_params[:stop] = options.delete(:stop_sequences)
       end
data/lib/langchain/llm/ollama.rb
CHANGED
data/lib/langchain/llm/response/ollama_response.rb
CHANGED
@@ -36,7 +36,7 @@ module Langchain::LLM
     end
 
     def prompt_tokens
-      raw_response.
+      raw_response.fetch("prompt_eval_count", 0) if done?
     end
 
     def completion_tokens
@@ -47,6 +47,24 @@ module Langchain::LLM
       prompt_tokens + completion_tokens if done?
     end
 
+    def tool_calls
+      if chat_completion && (parsed_tool_calls = JSON.parse(chat_completion))
+        [parsed_tool_calls]
+      elsif completion&.include?("[TOOL_CALLS]") && (
+        parsed_tool_calls = JSON.parse(
+          completion
+            # Slice out the serialize JSON
+            .slice(/\{.*\}/)
+            # Replace hash rocket with colon
+            .gsub("=>", ":")
+        )
+      )
+        [parsed_tool_calls]
+      else
+        []
+      end
+    end
+
     private
 
     def done?
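Rough illustration of the new `tool_calls` parsing: a `[TOOL_CALLS]` marker followed by a Ruby-hash-looking payload is sliced out, hash rockets are rewritten to colons, and the parsed hash is returned wrapped in an array. The raw payload below is hypothetical, and the constructor call assumes the response object is built from the raw `generate` payload as it is elsewhere in the gem:

```ruby
raw = {
  "response" => %([TOOL_CALLS] {"name"=>"weather__fetch_current_weather", "arguments"=>{"city"=>"Boston"}}),
  "done" => true
}

response = Langchain::LLM::OllamaResponse.new(raw, model: "mistral:7b-instruct-v0.3-fp16")
response.tool_calls
# => [{"name"=>"weather__fetch_current_weather", "arguments"=>{"city"=>"Boston"}}]
```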
data/lib/langchain/vectorsearch/milvus.rb
CHANGED
@@ -140,7 +140,7 @@ module Langchain::Vectorsearch
 
     client.search(
       collection_name: index_name,
-      output_fields: ["id", "content", "vectors"
+      output_fields: ["id", "content"], # Add "vectors" if need to have full vectors returned.
       top_k: k.to_s,
       vectors: [embedding],
       dsl_type: 1,
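The Milvus change only trims the requested output fields. A hypothetical search call (constructor arguments assumed) now comes back with `id` and `content` but no raw vectors unless `"vectors"` is added back in the source:

```ruby
milvus = Langchain::Vectorsearch::Milvus.new(
  url: ENV["MILVUS_URL"],
  index_name: "Documents",
  llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
)

milvus.similarity_search(query: "quick brown fox", k: 4)
```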
data/lib/langchain/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: langchainrb
 version: !ruby/object:Gem::Version
-  version: 0.
+  version: 0.14.0
 platform: ruby
 authors:
 - Andrei Bondarev
 autorequire:
 bindir: exe
 cert_chain: []
-date: 2024-07-
+date: 2024-07-12 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: baran
@@ -212,14 +212,14 @@ dependencies:
     requirements:
     - - "~>"
       - !ruby/object:Gem::Version
-        version: '0.
+        version: '0.3'
   type: :development
   prerelease: false
   version_requirements: !ruby/object:Gem::Requirement
     requirements:
     - - "~>"
       - !ruby/object:Gem::Version
-        version: '0.
+        version: '0.3'
 - !ruby/object:Gem::Dependency
   name: aws-sdk-bedrockruntime
   requirement: !ruby/object:Gem::Requirement
@@ -682,20 +682,6 @@ dependencies:
     - - "~>"
       - !ruby/object:Gem::Version
         version: 0.1.0
-- !ruby/object:Gem::Dependency
-  name: tiktoken_ruby
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - "~>"
-      - !ruby/object:Gem::Version
-        version: 0.0.9
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - "~>"
-      - !ruby/object:Gem::Version
-        version: 0.0.9
 description: Build LLM-backed Ruby applications with Ruby's Langchain.rb
 email:
 - andrei.bondarev13@gmail.com
@@ -711,6 +697,7 @@ files:
 - lib/langchain/assistants/messages/anthropic_message.rb
 - lib/langchain/assistants/messages/base.rb
 - lib/langchain/assistants/messages/google_gemini_message.rb
+- lib/langchain/assistants/messages/ollama_message.rb
 - lib/langchain/assistants/messages/openai_message.rb
 - lib/langchain/assistants/thread.rb
 - lib/langchain/chunk.rb
@@ -810,12 +797,6 @@ files:
 - lib/langchain/tool/wikipedia/wikipedia.rb
 - lib/langchain/utils/cosine_similarity.rb
 - lib/langchain/utils/hash_transformer.rb
-- lib/langchain/utils/token_length/ai21_validator.rb
-- lib/langchain/utils/token_length/base_validator.rb
-- lib/langchain/utils/token_length/cohere_validator.rb
-- lib/langchain/utils/token_length/google_palm_validator.rb
-- lib/langchain/utils/token_length/openai_validator.rb
-- lib/langchain/utils/token_length/token_limit_exceeded.rb
 - lib/langchain/vectorsearch/base.rb
 - lib/langchain/vectorsearch/chroma.rb
 - lib/langchain/vectorsearch/elasticsearch.rb
data/lib/langchain/utils/token_length/ai21_validator.rb
DELETED
@@ -1,41 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Utils
-    module TokenLength
-      #
-      # This class is meant to validate the length of the text passed in to AI21's API.
-      # It is used to validate the token length before the API call is made
-      #
-
-      class AI21Validator < BaseValidator
-        TOKEN_LIMITS = {
-          "j2-ultra" => 8192,
-          "j2-mid" => 8192,
-          "j2-light" => 8192
-        }.freeze
-
-        #
-        # Calculate token length for a given text and model name
-        #
-        # @param text [String] The text to calculate the token length for
-        # @param model_name [String] The model name to validate against
-        # @return [Integer] The token length of the text
-        #
-        def self.token_length(text, model_name, options = {})
-          res = options[:llm].tokenize(text)
-          res.dig(:tokens).length
-        end
-
-        def self.token_limit(model_name)
-          TOKEN_LIMITS[model_name]
-        end
-        singleton_class.alias_method :completion_token_limit, :token_limit
-
-        def self.token_length_from_messages(messages, model_name, options)
-          messages.sum { |message| token_length(message.to_json, model_name, options) }
-        end
-      end
-    end
-  end
-end
data/lib/langchain/utils/token_length/base_validator.rb
DELETED
@@ -1,42 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Utils
-    module TokenLength
-      #
-      # Calculate the `max_tokens:` parameter to be set by calculating the context length of the text minus the prompt length
-      #
-      # @param content [String | Array<String>] The text or array of texts to validate
-      # @param model_name [String] The model name to validate against
-      # @return [Integer] Whether the text is valid or not
-      # @raise [TokenLimitExceeded] If the text is too long
-      #
-      class BaseValidator
-        def self.validate_max_tokens!(content, model_name, options = {})
-          text_token_length = if content.is_a?(Array)
-            token_length_from_messages(content, model_name, options)
-          else
-            token_length(content, model_name, options)
-          end
-
-          leftover_tokens = token_limit(model_name) - text_token_length
-
-          # Some models have a separate token limit for completions (e.g. GPT-4 Turbo)
-          # We want the lower of the two limits
-          max_tokens = [leftover_tokens, completion_token_limit(model_name)].min
-
-          # Raise an error even if whole prompt is equal to the model's token limit (leftover_tokens == 0)
-          if max_tokens < 0
-            raise limit_exceeded_exception(token_limit(model_name), text_token_length)
-          end
-
-          max_tokens
-        end
-
-        def self.limit_exceeded_exception(limit, length)
-          TokenLimitExceeded.new("This model's maximum context length is #{limit} tokens, but the given text is #{length} tokens long.", length - limit)
-        end
-      end
-    end
-  end
-end
data/lib/langchain/utils/token_length/cohere_validator.rb
DELETED
@@ -1,49 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Utils
-    module TokenLength
-      #
-      # This class is meant to validate the length of the text passed in to Cohere's API.
-      # It is used to validate the token length before the API call is made
-      #
-
-      class CohereValidator < BaseValidator
-        TOKEN_LIMITS = {
-          # Source:
-          # https://docs.cohere.com/docs/models
-          "command-light" => 4096,
-          "command" => 4096,
-          "base-light" => 2048,
-          "base" => 2048,
-          "embed-english-light-v2.0" => 512,
-          "embed-english-v2.0" => 512,
-          "embed-multilingual-v2.0" => 256,
-          "summarize-medium" => 2048,
-          "summarize-xlarge" => 2048
-        }.freeze
-
-        #
-        # Calculate token length for a given text and model name
-        #
-        # @param text [String] The text to calculate the token length for
-        # @param model_name [String] The model name to validate against
-        # @return [Integer] The token length of the text
-        #
-        def self.token_length(text, model_name, options = {})
-          res = options[:llm].tokenize(text: text)
-          res["tokens"].length
-        end
-
-        def self.token_limit(model_name)
-          TOKEN_LIMITS[model_name]
-        end
-        singleton_class.alias_method :completion_token_limit, :token_limit
-
-        def self.token_length_from_messages(messages, model_name, options)
-          messages.sum { |message| token_length(message.to_json, model_name, options) }
-        end
-      end
-    end
-  end
-end
data/lib/langchain/utils/token_length/google_palm_validator.rb
DELETED
@@ -1,57 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Utils
-    module TokenLength
-      #
-      # This class is meant to validate the length of the text passed in to Google Palm's API.
-      # It is used to validate the token length before the API call is made
-      #
-      class GooglePalmValidator < BaseValidator
-        TOKEN_LIMITS = {
-          # Source:
-          # This data can be pulled when `list_models()` method is called: https://github.com/andreibondarev/google_palm_api#usage
-
-          # chat-bison-001 is the only model that currently supports countMessageTokens functions
-          "chat-bison-001" => {
-            "input_token_limit" => 4000, # 4096 is the limit but the countMessageTokens does not return anything higher than 4000
-            "output_token_limit" => 1024
-          }
-          # "text-bison-001" => {
-          #   "input_token_limit" => 8196,
-          #   "output_token_limit" => 1024
-          # },
-          # "embedding-gecko-001" => {
-          #   "input_token_limit" => 1024
-          # }
-        }.freeze
-
-        #
-        # Calculate token length for a given text and model name
-        #
-        # @param text [String] The text to calculate the token length for
-        # @param model_name [String] The model name to validate against
-        # @param options [Hash] the options to create a message with
-        # @option options [Langchain::LLM:GooglePalm] :llm The Langchain::LLM:GooglePalm instance
-        # @return [Integer] The token length of the text
-        #
-        def self.token_length(text, model_name = "chat-bison-001", options = {})
-          response = options[:llm].client.count_message_tokens(model: model_name, prompt: text)
-
-          raise Langchain::LLM::ApiError.new(response["error"]["message"]) unless response["error"].nil?
-
-          response.dig("tokenCount")
-        end
-
-        def self.token_length_from_messages(messages, model_name, options = {})
-          messages.sum { |message| token_length(message.to_json, model_name, options) }
-        end
-
-        def self.token_limit(model_name)
-          TOKEN_LIMITS.dig(model_name, "input_token_limit")
-        end
-        singleton_class.alias_method :completion_token_limit, :token_limit
-      end
-    end
-  end
-end
data/lib/langchain/utils/token_length/openai_validator.rb
DELETED
@@ -1,138 +0,0 @@
-# frozen_string_literal: true
-
-require "tiktoken_ruby"
-
-module Langchain
-  module Utils
-    module TokenLength
-      #
-      # This class is meant to validate the length of the text passed in to OpenAI's API.
-      # It is used to validate the token length before the API call is made
-      #
-      class OpenAIValidator < BaseValidator
-        COMPLETION_TOKEN_LIMITS = {
-          # GPT-4 Turbo has a separate token limit for completion
-          # Source:
-          # https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
-          "gpt-4-1106-preview" => 4096,
-          "gpt-4-vision-preview" => 4096,
-          "gpt-3.5-turbo-1106" => 4096
-        }
-
-        # NOTE: The gpt-4-turbo-preview is an alias that will always point to the latest GPT 4 Turbo preview
-        # the future previews may have a different token limit!
-        TOKEN_LIMITS = {
-          # Source:
-          # https://platform.openai.com/docs/api-reference/embeddings
-          # https://platform.openai.com/docs/models/gpt-4
-          "text-embedding-3-large" => 8191,
-          "text-embedding-3-small" => 8191,
-          "text-embedding-ada-002" => 8191,
-          "gpt-3.5-turbo" => 16385,
-          "gpt-3.5-turbo-0301" => 4096,
-          "gpt-3.5-turbo-0613" => 4096,
-          "gpt-3.5-turbo-1106" => 16385,
-          "gpt-3.5-turbo-0125" => 16385,
-          "gpt-3.5-turbo-16k" => 16384,
-          "gpt-3.5-turbo-16k-0613" => 16384,
-          "text-davinci-003" => 4097,
-          "text-davinci-002" => 4097,
-          "code-davinci-002" => 8001,
-          "gpt-4" => 8192,
-          "gpt-4-0314" => 8192,
-          "gpt-4-0613" => 8192,
-          "gpt-4-32k" => 32768,
-          "gpt-4-32k-0314" => 32768,
-          "gpt-4-32k-0613" => 32768,
-          "gpt-4-1106-preview" => 128000,
-          "gpt-4-turbo" => 128000,
-          "gpt-4-turbo-2024-04-09" => 128000,
-          "gpt-4-turbo-preview" => 128000,
-          "gpt-4-0125-preview" => 128000,
-          "gpt-4-vision-preview" => 128000,
-          "gpt-4o" => 128000,
-          "gpt-4o-2024-05-13" => 128000,
-          "text-curie-001" => 2049,
-          "text-babbage-001" => 2049,
-          "text-ada-001" => 2049,
-          "davinci" => 2049,
-          "curie" => 2049,
-          "babbage" => 2049,
-          "ada" => 2049
-        }.freeze
-
-        #
-        # Calculate token length for a given text and model name
-        #
-        # @param text [String] The text to calculate the token length for
-        # @param model_name [String] The model name to validate against
-        # @return [Integer] The token length of the text
-        #
-        def self.token_length(text, model_name, options = {})
-          # tiktoken-ruby doesn't support text-embedding-3-large or text-embedding-3-small yet
-          if ["text-embedding-3-large", "text-embedding-3-small"].include?(model_name)
-            model_name = "text-embedding-ada-002"
-          end
-
-          encoder = Tiktoken.encoding_for_model(model_name)
-          encoder.encode(text).length
-        end
-
-        def self.token_limit(model_name)
-          TOKEN_LIMITS[model_name]
-        end
-
-        def self.completion_token_limit(model_name)
-          COMPLETION_TOKEN_LIMITS[model_name] || token_limit(model_name)
-        end
-
-        # If :max_tokens is passed in, take the lower of it and the calculated max_tokens
-        def self.validate_max_tokens!(content, model_name, options = {})
-          max_tokens = super(content, model_name, options)
-          [options[:max_tokens], max_tokens].reject(&:nil?).min
-        end
-
-        # Copied from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
-        # Return the number of tokens used by a list of messages
-        #
-        # @param messages [Array<Hash>] The messages to calculate the token length for
-        # @param model [String] The model name to validate against
-        # @return [Integer] The token length of the messages
-        #
-        def self.token_length_from_messages(messages, model_name, options = {})
-          encoding = Tiktoken.encoding_for_model(model_name)
-
-          if ["gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-4-0314", "gpt-4-32k-0314", "gpt-4-0613", "gpt-4-32k-0613"].include?(model_name)
-            tokens_per_message = 3
-            tokens_per_name = 1
-          elsif model_name == "gpt-3.5-turbo-0301"
-            tokens_per_message = 4 # every message follows {role/name}\n{content}\n
-            tokens_per_name = -1 # if there's a name, the role is omitted
-          elsif model_name.include?("gpt-3.5-turbo")
-            # puts "Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613."
-            return token_length_from_messages(messages, "gpt-3.5-turbo-0613", options)
-          elsif model_name.include?("gpt-4")
-            # puts "Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
-            return token_length_from_messages(messages, "gpt-4-0613", options)
-          else
-            raise NotImplementedError.new(
-              "token_length_from_messages() is not implemented for model #{model_name}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."
-            )
-          end
-
-          num_tokens = 0
-          messages.each do |message|
-            num_tokens += tokens_per_message
-            message.each do |key, value|
-              num_tokens += encoding.encode(value).length
-              num_tokens += tokens_per_name if ["name", :name].include?(key)
-            end
-          end
-
-          num_tokens += 3 # every reply is primed with assistant
-          num_tokens
-        end
-      end
-    end
-  end
-end
data/lib/langchain/utils/token_length/token_limit_exceeded.rb
DELETED
@@ -1,17 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Utils
-    module TokenLength
-      class TokenLimitExceeded < StandardError
-        attr_reader :token_overflow
-
-        def initialize(message = "", token_overflow = 0)
-          super(message)
-
-          @token_overflow = token_overflow
-        end
-      end
-    end
-  end
-end