langchainrb 0.13.5 → 0.14.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +2 -17
- data/lib/langchain/assistants/assistant.rb +90 -19
- data/lib/langchain/assistants/messages/ollama_message.rb +86 -0
- data/lib/langchain/assistants/thread.rb +8 -1
- data/lib/langchain/llm/ai21.rb +0 -4
- data/lib/langchain/llm/anthropic.rb +15 -6
- data/lib/langchain/llm/azure.rb +3 -3
- data/lib/langchain/llm/base.rb +1 -0
- data/lib/langchain/llm/cohere.rb +0 -2
- data/lib/langchain/llm/google_palm.rb +1 -4
- data/lib/langchain/llm/ollama.rb +1 -1
- data/lib/langchain/llm/response/google_gemini_response.rb +1 -1
- data/lib/langchain/llm/response/ollama_response.rb +19 -1
- data/lib/langchain/vectorsearch/milvus.rb +1 -1
- data/lib/langchain/version.rb +1 -1
- metadata +5 -24
- data/lib/langchain/utils/token_length/ai21_validator.rb +0 -41
- data/lib/langchain/utils/token_length/base_validator.rb +0 -42
- data/lib/langchain/utils/token_length/cohere_validator.rb +0 -49
- data/lib/langchain/utils/token_length/google_palm_validator.rb +0 -57
- data/lib/langchain/utils/token_length/openai_validator.rb +0 -138
- data/lib/langchain/utils/token_length/token_limit_exceeded.rb +0 -17
checksums.yaml
CHANGED

@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 68900cd116cf0fb1b77376a4906e5551f0d578ee2bb47c7ec86d32bf44f84e33
+  data.tar.gz: f68782c3cdc856799778618d78b6411a85b0c69adf6a4d33489b8025fdca3dce
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 158410fd769caaf9074eddc1143ddee9256ac5a466a510c32b74d337eba62fab80b676661cbf1673604d236014a5cb4defdd4743e71abb713a659ddea0fe5e8c
+  data.tar.gz: 2e956356a443ff37ad711f6c42f8c4940925bcee4be075b403c78c3f702b487c12790dca9ba7d68a01acaf1c245b2910650b3f938e80cedd1fc2d5af14f7ffa8
data/CHANGELOG.md
CHANGED

@@ -1,5 +1,11 @@
 ## [Unreleased]
 
+## [0.14.0] - 2024-07-12
+- Removed TokenLength validators
+- Assistant works with a Mistral LLM now
+- Assistant keeps track of tokens used
+- Misc fixes and improvements
+
 ## [0.13.5] - 2024-07-01
 - Add Milvus#remove_texts() method
 - Langchain::Assistant has a `state` now
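The token-tracking entry above is observable directly on the Assistant. A minimal sketch of reading the new counters after a run — the prompt and the token numbers are illustrative, but the `total_prompt_tokens`/`total_completion_tokens`/`total_tokens` readers come from the `attr_reader` added in `assistant.rb` further down:

```ruby
require "langchain"

llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
assistant = Langchain::Assistant.new(llm: llm, instructions: "You are a helpful assistant")

assistant.add_message_and_run(content: "Say hello")

# New in 0.14.0: cumulative usage across every LLM call the Assistant has made
assistant.total_prompt_tokens     # => e.g. 25
assistant.total_completion_tokens # => e.g. 9
assistant.total_tokens            # => e.g. 34
```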
data/README.md
CHANGED

@@ -428,25 +428,10 @@ Assistants are Agent-like objects that leverage helpful instructions, LLMs, tool
 ```ruby
 llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
 ```
-2. Instantiate
-```ruby
-thread = Langchain::Thread.new
-```
-You can pass old message from previously using the Assistant:
-```ruby
-thread.messages = messages
-```
-Messages contain the conversation history and the whole message history is sent to the LLM every time. A Message belongs to 1 of the 4 roles:
-* `Message(role: "system")` message usually contains the instructions.
-* `Message(role: "user")` messages come from the user.
-* `Message(role: "assistant")` messages are produced by the LLM.
-* `Message(role: "tool")` messages are sent in response to tool calls with tool outputs.
-
-3. Instantiate an Assistant
+2. Instantiate an Assistant
 ```ruby
 assistant = Langchain::Assistant.new(
   llm: llm,
-  thread: thread,
   instructions: "You are a Meteorologist Assistant that is able to pull the weather for any location",
   tools: [
     Langchain::Tool::Weather.new(api_key: ENV["OPEN_WEATHER_API_KEY"])
@@ -482,7 +467,7 @@ assistant.add_message_and_run content: "What about Sacramento, CA?", auto_tool_e
 ### Accessing Thread messages
 You can access the messages in a Thread by calling `assistant.thread.messages`.
 ```ruby
-assistant.
+assistant.messages
 ```
 
 The Assistant checks the context window limits before every request to the LLM and remove oldest thread messages one by one if the context window is exceeded.
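Putting the updated README steps together: the caller no longer builds a `Langchain::Thread`; presumably the Assistant now creates its own. A sketch under that assumption (the keyword arguments and the Weather tool are taken from the README hunk above):

```ruby
llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

assistant = Langchain::Assistant.new(
  llm: llm,
  instructions: "You are a Meteorologist Assistant that is able to pull the weather for any location",
  tools: [Langchain::Tool::Weather.new(api_key: ENV["OPEN_WEATHER_API_KEY"])]
)

assistant.messages # conversation history, delegated to the internally managed thread
```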
data/lib/langchain/assistants/assistant.rb
CHANGED

@@ -16,13 +16,15 @@ module Langchain
     def_delegators :thread, :messages, :messages=
 
     attr_reader :llm, :thread, :instructions, :state
+    attr_reader :total_prompt_tokens, :total_completion_tokens, :total_tokens
     attr_accessor :tools
 
     SUPPORTED_LLMS = [
       Langchain::LLM::Anthropic,
-      Langchain::LLM::OpenAI,
       Langchain::LLM::GoogleGemini,
-      Langchain::LLM::GoogleVertexAI
+      Langchain::LLM::GoogleVertexAI,
+      Langchain::LLM::Ollama,
+      Langchain::LLM::OpenAI
     ]
 
     # Create a new assistant
@@ -40,6 +42,9 @@ module Langchain
       unless SUPPORTED_LLMS.include?(llm.class)
         raise ArgumentError, "Invalid LLM; currently only #{SUPPORTED_LLMS.join(", ")} are supported"
       end
+      if llm.is_a?(Langchain::LLM::Ollama)
+        raise ArgumentError, "Currently only `mistral:7b-instruct-v0.3-fp16` model is supported for Ollama LLM" unless llm.defaults[:completion_model_name] == "mistral:7b-instruct-v0.3-fp16"
+      end
       raise ArgumentError, "Tools must be an array of Langchain::Tool::Base instance(s)" unless tools.is_a?(Array) && tools.all? { |tool| tool.is_a?(Langchain::Tool::Base) }
 
       @llm = llm
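A hedged sketch of what the new Ollama guard means for callers; the `url:` and `default_options:` arguments are assumptions about `Langchain::LLM::Ollama.new`, while the model name and error text are quoted from the hunk above:

```ruby
llm = Langchain::LLM::Ollama.new(
  url: "http://localhost:11434",
  default_options: {completion_model_name: "mistral:7b-instruct-v0.3-fp16"}
)

# Works: the default matches the one supported model.
assistant = Langchain::Assistant.new(llm: llm, instructions: "You are a helpful assistant")

# Any other completion_model_name raises:
#   ArgumentError: Currently only `mistral:7b-instruct-v0.3-fp16` model is supported for Ollama LLM
```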
@@ -48,13 +53,15 @@ module Langchain
       @instructions = instructions
       @state = :ready
 
+      @total_prompt_tokens = 0
+      @total_completion_tokens = 0
+      @total_tokens = 0
+
       raise ArgumentError, "Thread must be an instance of Langchain::Thread" unless @thread.is_a?(Langchain::Thread)
 
       # The first message in the thread should be the system instructions
       # TODO: What if the user added old messages and the system instructions are already in there? Should this overwrite the existing instructions?
-
-      add_message(role: "system", content: instructions) if instructions
-    end
+      initialize_instructions
       # For Google Gemini, and Anthropic system instructions are added to the `system:` param in the `chat` method
     end
 
@@ -150,7 +157,6 @@ module Langchain
 
     # Handle the current state and transition to the next state
     #
-    # @param state [Symbol] The current state
     # @return [Symbol] The next state
     def handle_state
       case @state
@@ -189,7 +195,6 @@ module Langchain
 
     # Handle LLM message scenario
     #
-    # @param auto_tool_execution [Boolean] Flag to indicate if tools should be executed automatically
     # @return [Symbol] The next state
     def handle_llm_message
       thread.messages.last.tool_calls.any? ? :requires_action : :completed
@@ -208,14 +213,29 @@ module Langchain
     # @return [Symbol] The next state
     def handle_user_or_tool_message
       response = chat_with_llm
-      add_message(role: response.role, content: response.chat_completion, tool_calls: response.tool_calls)
 
+      # With Ollama, we're calling the `llm.complete()` method
+      content = if llm.is_a?(Langchain::LLM::Ollama)
+        response.completion
+      else
+        response.chat_completion
+      end
+
+      add_message(role: response.role, content: content, tool_calls: response.tool_calls)
+      record_used_tokens(response.prompt_tokens, response.completion_tokens, response.total_tokens)
+
+      set_state_for(response: response)
+    end
+
+    def set_state_for(response:)
       if response.tool_calls.any?
         :in_progress
       elsif response.chat_completion
         :completed
+      elsif response.completion # Currently only used by Ollama
+        :completed
       else
-        Langchain.logger.error("LLM response does not contain tool calls or
+        Langchain.logger.error("LLM response does not contain tool calls, chat or completion response")
         :failed
       end
     end
@@ -236,6 +256,8 @@ module Langchain
     # @return [String] The tool role
     def determine_tool_role
       case llm
+      when Langchain::LLM::Ollama
+        Langchain::Messages::OllamaMessage::TOOL_ROLE
       when Langchain::LLM::OpenAI
         Langchain::Messages::OpenAIMessage::TOOL_ROLE
       when Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI
@@ -245,31 +267,58 @@ module Langchain
       end
     end
 
+    def initialize_instructions
+      if llm.is_a?(Langchain::LLM::Ollama)
+        content = String.new # rubocop: disable Performance/UnfreezeString
+        if tools.any?
+          content << %([AVAILABLE_TOOLS] #{tools.map(&:to_openai_tools).flatten}[/AVAILABLE_TOOLS])
+        end
+        if instructions
+          content << "[INST] #{instructions}[/INST]"
+        end
+
+        add_message(role: "system", content: content)
+      elsif llm.is_a?(Langchain::LLM::OpenAI)
+        add_message(role: "system", content: instructions) if instructions
+      end
+    end
+
     # Call to the LLM#chat() method
     #
     # @return [Langchain::LLM::BaseResponse] The LLM response object
     def chat_with_llm
       Langchain.logger.info("Sending a call to #{llm.class}", for: self.class)
 
-      params = {
+      params = {}
 
-      if
-      if
+      if llm.is_a?(Langchain::LLM::OpenAI)
+        if tools.any?
           params[:tools] = tools.map(&:to_openai_tools).flatten
           params[:tool_choice] = "auto"
-
+        end
+      elsif llm.is_a?(Langchain::LLM::Anthropic)
+        if tools.any?
           params[:tools] = tools.map(&:to_anthropic_tools).flatten
-          params[:system] = instructions if instructions
           params[:tool_choice] = {type: "auto"}
-
+        end
+        params[:system] = instructions if instructions
+      elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+        if tools.any?
           params[:tools] = tools.map(&:to_google_gemini_tools).flatten
           params[:system] = instructions if instructions
           params[:tool_choice] = "auto"
         end
-      # TODO: Not sure that tool_choice should always be "auto"; Maybe we can let the user toggle it.
       end
+      # TODO: Not sure that tool_choice should always be "auto"; Maybe we can let the user toggle it.
 
-      llm.
+      if llm.is_a?(Langchain::LLM::Ollama)
+        params[:raw] = true
+        params[:prompt] = thread.prompt_of_concatenated_messages
+        llm.complete(**params)
+      else
+        params[:messages] = thread.array_of_message_hashes
+        llm.chat(**params)
+      end
     end
 
     # Run the tools automatically
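Roughly what the two branches at the end of `chat_with_llm` boil down to (a sketch, not additional diff content):

```ruby
# Chat-style LLMs receive the structured message history:
response = llm.chat(messages: thread.array_of_message_hashes)

# Ollama instead receives one raw, pre-formatted Mistral prompt string, so the
# model sees the [INST]/[TOOL_RESULTS] markers from OllamaMessage#to_s verbatim:
response = llm.complete(raw: true, prompt: thread.prompt_of_concatenated_messages)
```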
@@ -278,7 +327,9 @@ module Langchain
     def run_tools(tool_calls)
       # Iterate over each function invocation and submit tool output
       tool_calls.each do |tool_call|
-        tool_call_id, tool_name, method_name, tool_arguments = if llm.is_a?(Langchain::LLM::
+        tool_call_id, tool_name, method_name, tool_arguments = if llm.is_a?(Langchain::LLM::Ollama)
+          extract_ollama_tool_call(tool_call: tool_call)
+        elsif llm.is_a?(Langchain::LLM::OpenAI)
           extract_openai_tool_call(tool_call: tool_call)
         elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
           extract_google_gemini_tool_call(tool_call: tool_call)
@@ -296,6 +347,12 @@ module Langchain
       end
     end
 
+    def extract_ollama_tool_call(tool_call:)
+      tool_name, method_name = tool_call.dig("name").split("__")
+      tool_arguments = tool_call.dig("arguments").transform_keys(&:to_sym)
+      [nil, tool_name, method_name, tool_arguments]
+    end
+
     # Extract the tool call information from the OpenAI tool call hash
     #
     # @param tool_call [Hash] The tool call hash
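How `extract_ollama_tool_call` slices up a Mistral-style tool call — the hash below is a hypothetical example, but the `"tool__method"` split and the symbolized arguments follow from the method body above:

```ruby
tool_call = {"name" => "langchain_tool_weather__execute", "arguments" => {"input" => "Boston, MA"}}

tool_name, method_name = tool_call.dig("name").split("__")
tool_name   # => "langchain_tool_weather"
method_name # => "execute"
tool_call.dig("arguments").transform_keys(&:to_sym) # => {input: "Boston, MA"}
# There is no tool_call_id in this format, hence the nil first element of the returned array.
```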
@@ -346,7 +403,9 @@ module Langchain
     # @param tool_call_id [String] The ID of the tool call to include in the message
     # @return [Langchain::Message] The Message object
     def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
-      if llm.is_a?(Langchain::LLM::
+      if llm.is_a?(Langchain::LLM::Ollama)
+        Langchain::Messages::OllamaMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+      elsif llm.is_a?(Langchain::LLM::OpenAI)
         Langchain::Messages::OpenAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
       elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
         Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
@@ -355,6 +414,18 @@ module Langchain
       end
     end
 
+    # Increment the tokens count based on the last interaction with the LLM
+    #
+    # @param prompt_tokens [Integer] The number of used prompt tokens
+    # @param completion_tokens [Integer] The number of used completion tokens
+    # @param total_tokens [Integer] The total number of used tokens
+    # @return [Integer] The current total tokens count
+    def record_used_tokens(prompt_tokens, completion_tokens, total_tokens_from_operation)
+      @total_prompt_tokens += prompt_tokens if prompt_tokens
+      @total_completion_tokens += completion_tokens if completion_tokens
+      @total_tokens += total_tokens_from_operation if total_tokens_from_operation
+    end
+
     # TODO: Fix the message truncation when context window is exceeded
   end
 end
data/lib/langchain/assistants/messages/ollama_message.rb
ADDED

@@ -0,0 +1,86 @@
+# frozen_string_literal: true
+
+module Langchain
+  module Messages
+    class OllamaMessage < Base
+      # OpenAI uses the following roles:
+      ROLES = [
+        "system",
+        "assistant",
+        "user",
+        "tool"
+      ].freeze
+
+      TOOL_ROLE = "tool"
+
+      # Initialize a new OpenAI message
+      #
+      # @param [String] The role of the message
+      # @param [String] The content of the message
+      # @param [Array<Hash>] The tool calls made in the message
+      # @param [String] The ID of the tool call
+      def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil)
+        raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+        raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+        @role = role
+        # Some Tools return content as a JSON hence `.to_s`
+        @content = content.to_s
+        @tool_calls = tool_calls
+        @tool_call_id = tool_call_id
+      end
+
+      def to_s
+        send(:"to_#{role}_message_string")
+      end
+
+      def to_system_message_string
+        content
+      end
+
+      def to_user_message_string
+        "[INST] #{content}[/INST]"
+      end
+
+      def to_tool_message_string
+        "[TOOL_RESULTS] #{content}[/TOOL_RESULTS]"
+      end
+
+      def to_assistant_message_string
+        if tool_calls.any?
+          %("[TOOL_CALLS] #{tool_calls}")
+        else
+          content
+        end
+      end
+
+      # Check if the message came from an LLM
+      #
+      # @return [Boolean] true/false whether this message was produced by an LLM
+      def llm?
+        assistant?
+      end
+
+      # Check if the message came from an LLM
+      #
+      # @return [Boolean] true/false whether this message was produced by an LLM
+      def assistant?
+        role == "assistant"
+      end
+
+      # Check if the message are system instructions
+      #
+      # @return [Boolean] true/false whether this message are system instructions
+      def system?
+        role == "system"
+      end
+
+      # Check if the message is a tool call
+      #
+      # @return [Boolean] true/false whether this message is a tool call
+      def tool?
+        role == "tool"
+      end
+    end
+  end
+end
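`OllamaMessage#to_s` renders each role into Mistral's instruction format; a small sketch of the strings it produces:

```ruby
Langchain::Messages::OllamaMessage.new(role: "user", content: "What's 2 + 2?").to_s
# => "[INST] What's 2 + 2?[/INST]"

Langchain::Messages::OllamaMessage.new(role: "tool", content: "4").to_s
# => "[TOOL_RESULTS] 4[/TOOL_RESULTS]"

Langchain::Messages::OllamaMessage.new(role: "banker")
# => ArgumentError: Role must be one of system, assistant, user, tool
```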
data/lib/langchain/assistants/thread.rb
CHANGED

@@ -17,7 +17,14 @@ module Langchain
     #
     # @return [Array<Hash>] The thread as an OpenAI API-compatible array of hashes
     def array_of_message_hashes
-      messages
+      messages
+        .map(&:to_hash)
+        .compact
+    end
+
+    # Only used by the Assistant when it calls the LLM#complete() method
+    def prompt_of_concatenated_messages
+      messages.map(&:to_s).join
     end
 
     # Add a message to the thread
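`prompt_of_concatenated_messages` just joins each message's `to_s`; assuming `Langchain::Thread.new` accepts a `messages:` array, the raw Ollama prompt for a short thread would look like:

```ruby
thread = Langchain::Thread.new(messages: [
  Langchain::Messages::OllamaMessage.new(role: "system", content: "[INST] Be terse.[/INST]"),
  Langchain::Messages::OllamaMessage.new(role: "user", content: "Hi")
])

thread.prompt_of_concatenated_messages
# => "[INST] Be terse.[/INST][INST] Hi[/INST]"
```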
data/lib/langchain/llm/ai21.rb
CHANGED

@@ -16,8 +16,6 @@ module Langchain::LLM
       model: "j2-ultra"
     }.freeze
 
-    LENGTH_VALIDATOR = Langchain::Utils::TokenLength::AI21Validator
-
     def initialize(api_key:, default_options: {})
       depends_on "ai21"
 
@@ -35,8 +33,6 @@ module Langchain::LLM
     def complete(prompt:, **params)
       parameters = complete_parameters params
 
-      parameters[:maxTokens] = LENGTH_VALIDATOR.validate_max_tokens!(prompt, parameters[:model], {llm: client})
-
       response = client.complete(prompt, parameters)
       Langchain::LLM::AI21Response.new response, model: parameters[:model]
     end
data/lib/langchain/llm/anthropic.rb
CHANGED

@@ -5,10 +5,10 @@ module Langchain::LLM
   # Wrapper around Anthropic APIs.
   #
   # Gem requirements:
-  #     gem "anthropic", "~> 0.
+  #     gem "anthropic", "~> 0.3.0"
   #
   # Usage:
-  #
+  #     anthropic = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
   #
   class Anthropic < Base
     DEFAULTS = {
@@ -18,9 +18,6 @@ module Langchain::LLM
       max_tokens_to_sample: 256
     }.freeze
 
-    # TODO: Implement token length validator for Anthropic
-    # LENGTH_VALIDATOR = Langchain::Utils::TokenLength::AnthropicValidator
-
     # Initialize an Anthropic LLM instance
     #
     # @param api_key [String] The API key to use
@@ -81,7 +78,10 @@ module Langchain::LLM
       parameters[:metadata] = metadata if metadata
       parameters[:stream] = stream if stream
 
-      response =
+      response = with_api_error_handling do
+        client.complete(parameters: parameters)
+      end
+
       Langchain::LLM::AnthropicResponse.new(response)
     end
 
@@ -114,6 +114,15 @@ module Langchain::LLM
       Langchain::LLM::AnthropicResponse.new(response)
     end
 
+    def with_api_error_handling
+      response = yield
+      return if response.empty?
+
+      raise Langchain::LLM::ApiError.new "Anthropic API error: #{response.dig("error", "message")}" if response&.dig("error")
+
+      response
+    end
+
     private
 
     def set_extra_headers!
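A sketch of the new error path; the response hash shown in the comment is an illustrative Anthropic error payload, not taken from the diff:

```ruby
anthropic = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])

# If the API returned {"error" => {"type" => "invalid_request_error", "message" => "max_tokens: required"}},
# complete() now raises instead of wrapping the error hash in an AnthropicResponse:
#   Langchain::LLM::ApiError: Anthropic API error: max_tokens: required
anthropic.complete(prompt: "Hello")
```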
data/lib/langchain/llm/azure.rb
CHANGED

@@ -42,17 +42,17 @@ module Langchain::LLM
 
     def embed(...)
       @client = @embed_client
-      super
+      super
     end
 
     def complete(...)
       @client = @chat_client
-      super
+      super
     end
 
     def chat(...)
       @client = @chat_client
-      super
+      super
     end
   end
 end
data/lib/langchain/llm/base.rb
CHANGED

@@ -8,6 +8,7 @@ module Langchain::LLM
   # Langchain.rb provides a common interface to interact with all supported LLMs:
   #
   # - {Langchain::LLM::AI21}
+  # - {Langchain::LLM::Anthropic}
   # - {Langchain::LLM::Azure}
   # - {Langchain::LLM::Cohere}
   # - {Langchain::LLM::GooglePalm}
data/lib/langchain/llm/cohere.rb
CHANGED

@@ -74,8 +74,6 @@ module Langchain::LLM
 
       default_params.merge!(params)
 
-      default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], llm: client)
-
       response = client.generate(**default_params)
       Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
     end
data/lib/langchain/llm/google_palm.rb
CHANGED

@@ -18,7 +18,7 @@ module Langchain::LLM
       chat_completion_model_name: "chat-bison-001",
       embeddings_model_name: "embedding-gecko-001"
     }.freeze
-
+
     ROLE_MAPPING = {
       "assistant" => "ai"
     }
@@ -96,9 +96,6 @@ module Langchain::LLM
         examples: compose_examples(examples)
       }
 
-      # chat-bison-001 is the only model that currently supports countMessageTokens functions
-      LENGTH_VALIDATOR.validate_max_tokens!(default_params[:messages], "chat-bison-001", llm: self)
-
       if options[:stop_sequences]
         default_params[:stop] = options.delete(:stop_sequences)
       end
data/lib/langchain/llm/response/ollama_response.rb
CHANGED

@@ -36,7 +36,7 @@ module Langchain::LLM
     end
 
     def prompt_tokens
-      raw_response.
+      raw_response.fetch("prompt_eval_count", 0) if done?
     end
 
     def completion_tokens
@@ -47,6 +47,24 @@ module Langchain::LLM
       prompt_tokens + completion_tokens if done?
     end
 
+    def tool_calls
+      if chat_completion && (parsed_tool_calls = JSON.parse(chat_completion))
+        [parsed_tool_calls]
+      elsif completion&.include?("[TOOL_CALLS]") && (
+        parsed_tool_calls = JSON.parse(
+          completion
+            # Slice out the serialized JSON
+            .slice(/\{.*\}/)
+            # Replace hash rocket with colon
+            .gsub("=>", ":")
+        )
+      )
+        [parsed_tool_calls]
+      else
+        []
+      end
+    end
+
     private
 
     def done?
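The `[TOOL_CALLS]` parsing above turns Mistral's Ruby-hash-looking output back into JSON; a standalone sketch with a hypothetical completion string:

```ruby
require "json"

completion = '[TOOL_CALLS] {"name"=>"langchain_tool_weather__execute", "arguments"=>{"input"=>"Boston"}}'

parsed = JSON.parse(
  completion
    .slice(/\{.*\}/) # keep only the serialized hash
    .gsub("=>", ":") # hash rockets -> JSON colons
)
parsed # => {"name"=>"langchain_tool_weather__execute", "arguments"=>{"input"=>"Boston"}}
```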
data/lib/langchain/vectorsearch/milvus.rb
CHANGED

@@ -140,7 +140,7 @@ module Langchain::Vectorsearch
 
       client.search(
         collection_name: index_name,
-        output_fields: ["id", "content", "vectors"
+        output_fields: ["id", "content"], # Add "vectors" if need to have full vectors returned.
         top_k: k.to_s,
         vectors: [embedding],
         dsl_type: 1,
data/lib/langchain/version.rb
CHANGED
metadata
CHANGED

@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: langchainrb
 version: !ruby/object:Gem::Version
-  version: 0.
+  version: 0.14.0
 platform: ruby
 authors:
 - Andrei Bondarev
 autorequire:
 bindir: exe
 cert_chain: []
-date: 2024-07-
+date: 2024-07-12 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: baran
@@ -212,14 +212,14 @@ dependencies:
   requirements:
   - - "~>"
     - !ruby/object:Gem::Version
-      version: '0.
+      version: '0.3'
   type: :development
   prerelease: false
   version_requirements: !ruby/object:Gem::Requirement
     requirements:
     - - "~>"
     - !ruby/object:Gem::Version
-      version: '0.
+      version: '0.3'
 - !ruby/object:Gem::Dependency
   name: aws-sdk-bedrockruntime
   requirement: !ruby/object:Gem::Requirement
@@ -682,20 +682,6 @@ dependencies:
   - - "~>"
   - !ruby/object:Gem::Version
     version: 0.1.0
-- !ruby/object:Gem::Dependency
-  name: tiktoken_ruby
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - "~>"
-    - !ruby/object:Gem::Version
-      version: 0.0.9
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - "~>"
-    - !ruby/object:Gem::Version
-      version: 0.0.9
 description: Build LLM-backed Ruby applications with Ruby's Langchain.rb
 email:
 - andrei.bondarev13@gmail.com
@@ -711,6 +697,7 @@ files:
 - lib/langchain/assistants/messages/anthropic_message.rb
 - lib/langchain/assistants/messages/base.rb
 - lib/langchain/assistants/messages/google_gemini_message.rb
+- lib/langchain/assistants/messages/ollama_message.rb
 - lib/langchain/assistants/messages/openai_message.rb
 - lib/langchain/assistants/thread.rb
 - lib/langchain/chunk.rb
@@ -810,12 +797,6 @@ files:
 - lib/langchain/tool/wikipedia/wikipedia.rb
 - lib/langchain/utils/cosine_similarity.rb
 - lib/langchain/utils/hash_transformer.rb
-- lib/langchain/utils/token_length/ai21_validator.rb
-- lib/langchain/utils/token_length/base_validator.rb
-- lib/langchain/utils/token_length/cohere_validator.rb
-- lib/langchain/utils/token_length/google_palm_validator.rb
-- lib/langchain/utils/token_length/openai_validator.rb
-- lib/langchain/utils/token_length/token_limit_exceeded.rb
 - lib/langchain/vectorsearch/base.rb
 - lib/langchain/vectorsearch/chroma.rb
 - lib/langchain/vectorsearch/elasticsearch.rb
@@ -1,41 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
module Langchain
|
|
4
|
-
module Utils
|
|
5
|
-
module TokenLength
|
|
6
|
-
#
|
|
7
|
-
# This class is meant to validate the length of the text passed in to AI21's API.
|
|
8
|
-
# It is used to validate the token length before the API call is made
|
|
9
|
-
#
|
|
10
|
-
|
|
11
|
-
class AI21Validator < BaseValidator
|
|
12
|
-
TOKEN_LIMITS = {
|
|
13
|
-
"j2-ultra" => 8192,
|
|
14
|
-
"j2-mid" => 8192,
|
|
15
|
-
"j2-light" => 8192
|
|
16
|
-
}.freeze
|
|
17
|
-
|
|
18
|
-
#
|
|
19
|
-
# Calculate token length for a given text and model name
|
|
20
|
-
#
|
|
21
|
-
# @param text [String] The text to calculate the token length for
|
|
22
|
-
# @param model_name [String] The model name to validate against
|
|
23
|
-
# @return [Integer] The token length of the text
|
|
24
|
-
#
|
|
25
|
-
def self.token_length(text, model_name, options = {})
|
|
26
|
-
res = options[:llm].tokenize(text)
|
|
27
|
-
res.dig(:tokens).length
|
|
28
|
-
end
|
|
29
|
-
|
|
30
|
-
def self.token_limit(model_name)
|
|
31
|
-
TOKEN_LIMITS[model_name]
|
|
32
|
-
end
|
|
33
|
-
singleton_class.alias_method :completion_token_limit, :token_limit
|
|
34
|
-
|
|
35
|
-
def self.token_length_from_messages(messages, model_name, options)
|
|
36
|
-
messages.sum { |message| token_length(message.to_json, model_name, options) }
|
|
37
|
-
end
|
|
38
|
-
end
|
|
39
|
-
end
|
|
40
|
-
end
|
|
41
|
-
end
|
|
@@ -1,42 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
module Langchain
|
|
4
|
-
module Utils
|
|
5
|
-
module TokenLength
|
|
6
|
-
#
|
|
7
|
-
# Calculate the `max_tokens:` parameter to be set by calculating the context length of the text minus the prompt length
|
|
8
|
-
#
|
|
9
|
-
# @param content [String | Array<String>] The text or array of texts to validate
|
|
10
|
-
# @param model_name [String] The model name to validate against
|
|
11
|
-
# @return [Integer] Whether the text is valid or not
|
|
12
|
-
# @raise [TokenLimitExceeded] If the text is too long
|
|
13
|
-
#
|
|
14
|
-
class BaseValidator
|
|
15
|
-
def self.validate_max_tokens!(content, model_name, options = {})
|
|
16
|
-
text_token_length = if content.is_a?(Array)
|
|
17
|
-
token_length_from_messages(content, model_name, options)
|
|
18
|
-
else
|
|
19
|
-
token_length(content, model_name, options)
|
|
20
|
-
end
|
|
21
|
-
|
|
22
|
-
leftover_tokens = token_limit(model_name) - text_token_length
|
|
23
|
-
|
|
24
|
-
# Some models have a separate token limit for completions (e.g. GPT-4 Turbo)
|
|
25
|
-
# We want the lower of the two limits
|
|
26
|
-
max_tokens = [leftover_tokens, completion_token_limit(model_name)].min
|
|
27
|
-
|
|
28
|
-
# Raise an error even if whole prompt is equal to the model's token limit (leftover_tokens == 0)
|
|
29
|
-
if max_tokens < 0
|
|
30
|
-
raise limit_exceeded_exception(token_limit(model_name), text_token_length)
|
|
31
|
-
end
|
|
32
|
-
|
|
33
|
-
max_tokens
|
|
34
|
-
end
|
|
35
|
-
|
|
36
|
-
def self.limit_exceeded_exception(limit, length)
|
|
37
|
-
TokenLimitExceeded.new("This model's maximum context length is #{limit} tokens, but the given text is #{length} tokens long.", length - limit)
|
|
38
|
-
end
|
|
39
|
-
end
|
|
40
|
-
end
|
|
41
|
-
end
|
|
42
|
-
end
|
|
@@ -1,49 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
module Langchain
|
|
4
|
-
module Utils
|
|
5
|
-
module TokenLength
|
|
6
|
-
#
|
|
7
|
-
# This class is meant to validate the length of the text passed in to Cohere's API.
|
|
8
|
-
# It is used to validate the token length before the API call is made
|
|
9
|
-
#
|
|
10
|
-
|
|
11
|
-
class CohereValidator < BaseValidator
|
|
12
|
-
TOKEN_LIMITS = {
|
|
13
|
-
# Source:
|
|
14
|
-
# https://docs.cohere.com/docs/models
|
|
15
|
-
"command-light" => 4096,
|
|
16
|
-
"command" => 4096,
|
|
17
|
-
"base-light" => 2048,
|
|
18
|
-
"base" => 2048,
|
|
19
|
-
"embed-english-light-v2.0" => 512,
|
|
20
|
-
"embed-english-v2.0" => 512,
|
|
21
|
-
"embed-multilingual-v2.0" => 256,
|
|
22
|
-
"summarize-medium" => 2048,
|
|
23
|
-
"summarize-xlarge" => 2048
|
|
24
|
-
}.freeze
|
|
25
|
-
|
|
26
|
-
#
|
|
27
|
-
# Calculate token length for a given text and model name
|
|
28
|
-
#
|
|
29
|
-
# @param text [String] The text to calculate the token length for
|
|
30
|
-
# @param model_name [String] The model name to validate against
|
|
31
|
-
# @return [Integer] The token length of the text
|
|
32
|
-
#
|
|
33
|
-
def self.token_length(text, model_name, options = {})
|
|
34
|
-
res = options[:llm].tokenize(text: text)
|
|
35
|
-
res["tokens"].length
|
|
36
|
-
end
|
|
37
|
-
|
|
38
|
-
def self.token_limit(model_name)
|
|
39
|
-
TOKEN_LIMITS[model_name]
|
|
40
|
-
end
|
|
41
|
-
singleton_class.alias_method :completion_token_limit, :token_limit
|
|
42
|
-
|
|
43
|
-
def self.token_length_from_messages(messages, model_name, options)
|
|
44
|
-
messages.sum { |message| token_length(message.to_json, model_name, options) }
|
|
45
|
-
end
|
|
46
|
-
end
|
|
47
|
-
end
|
|
48
|
-
end
|
|
49
|
-
end
|
|
@@ -1,57 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
module Langchain
|
|
4
|
-
module Utils
|
|
5
|
-
module TokenLength
|
|
6
|
-
#
|
|
7
|
-
# This class is meant to validate the length of the text passed in to Google Palm's API.
|
|
8
|
-
# It is used to validate the token length before the API call is made
|
|
9
|
-
#
|
|
10
|
-
class GooglePalmValidator < BaseValidator
|
|
11
|
-
TOKEN_LIMITS = {
|
|
12
|
-
# Source:
|
|
13
|
-
# This data can be pulled when `list_models()` method is called: https://github.com/andreibondarev/google_palm_api#usage
|
|
14
|
-
|
|
15
|
-
# chat-bison-001 is the only model that currently supports countMessageTokens functions
|
|
16
|
-
"chat-bison-001" => {
|
|
17
|
-
"input_token_limit" => 4000, # 4096 is the limit but the countMessageTokens does not return anything higher than 4000
|
|
18
|
-
"output_token_limit" => 1024
|
|
19
|
-
}
|
|
20
|
-
# "text-bison-001" => {
|
|
21
|
-
# "input_token_limit" => 8196,
|
|
22
|
-
# "output_token_limit" => 1024
|
|
23
|
-
# },
|
|
24
|
-
# "embedding-gecko-001" => {
|
|
25
|
-
# "input_token_limit" => 1024
|
|
26
|
-
# }
|
|
27
|
-
}.freeze
|
|
28
|
-
|
|
29
|
-
#
|
|
30
|
-
# Calculate token length for a given text and model name
|
|
31
|
-
#
|
|
32
|
-
# @param text [String] The text to calculate the token length for
|
|
33
|
-
# @param model_name [String] The model name to validate against
|
|
34
|
-
# @param options [Hash] the options to create a message with
|
|
35
|
-
# @option options [Langchain::LLM:GooglePalm] :llm The Langchain::LLM:GooglePalm instance
|
|
36
|
-
# @return [Integer] The token length of the text
|
|
37
|
-
#
|
|
38
|
-
def self.token_length(text, model_name = "chat-bison-001", options = {})
|
|
39
|
-
response = options[:llm].client.count_message_tokens(model: model_name, prompt: text)
|
|
40
|
-
|
|
41
|
-
raise Langchain::LLM::ApiError.new(response["error"]["message"]) unless response["error"].nil?
|
|
42
|
-
|
|
43
|
-
response.dig("tokenCount")
|
|
44
|
-
end
|
|
45
|
-
|
|
46
|
-
def self.token_length_from_messages(messages, model_name, options = {})
|
|
47
|
-
messages.sum { |message| token_length(message.to_json, model_name, options) }
|
|
48
|
-
end
|
|
49
|
-
|
|
50
|
-
def self.token_limit(model_name)
|
|
51
|
-
TOKEN_LIMITS.dig(model_name, "input_token_limit")
|
|
52
|
-
end
|
|
53
|
-
singleton_class.alias_method :completion_token_limit, :token_limit
|
|
54
|
-
end
|
|
55
|
-
end
|
|
56
|
-
end
|
|
57
|
-
end
|
|
@@ -1,138 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require "tiktoken_ruby"
|
|
4
|
-
|
|
5
|
-
module Langchain
|
|
6
|
-
module Utils
|
|
7
|
-
module TokenLength
|
|
8
|
-
#
|
|
9
|
-
# This class is meant to validate the length of the text passed in to OpenAI's API.
|
|
10
|
-
# It is used to validate the token length before the API call is made
|
|
11
|
-
#
|
|
12
|
-
class OpenAIValidator < BaseValidator
|
|
13
|
-
COMPLETION_TOKEN_LIMITS = {
|
|
14
|
-
# GPT-4 Turbo has a separate token limit for completion
|
|
15
|
-
# Source:
|
|
16
|
-
# https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
|
|
17
|
-
"gpt-4-1106-preview" => 4096,
|
|
18
|
-
"gpt-4-vision-preview" => 4096,
|
|
19
|
-
"gpt-3.5-turbo-1106" => 4096
|
|
20
|
-
}
|
|
21
|
-
|
|
22
|
-
# NOTE: The gpt-4-turbo-preview is an alias that will always point to the latest GPT 4 Turbo preview
|
|
23
|
-
# the future previews may have a different token limit!
|
|
24
|
-
TOKEN_LIMITS = {
|
|
25
|
-
# Source:
|
|
26
|
-
# https://platform.openai.com/docs/api-reference/embeddings
|
|
27
|
-
# https://platform.openai.com/docs/models/gpt-4
|
|
28
|
-
"text-embedding-3-large" => 8191,
|
|
29
|
-
"text-embedding-3-small" => 8191,
|
|
30
|
-
"text-embedding-ada-002" => 8191,
|
|
31
|
-
"gpt-3.5-turbo" => 16385,
|
|
32
|
-
"gpt-3.5-turbo-0301" => 4096,
|
|
33
|
-
"gpt-3.5-turbo-0613" => 4096,
|
|
34
|
-
"gpt-3.5-turbo-1106" => 16385,
|
|
35
|
-
"gpt-3.5-turbo-0125" => 16385,
|
|
36
|
-
"gpt-3.5-turbo-16k" => 16384,
|
|
37
|
-
"gpt-3.5-turbo-16k-0613" => 16384,
|
|
38
|
-
"text-davinci-003" => 4097,
|
|
39
|
-
"text-davinci-002" => 4097,
|
|
40
|
-
"code-davinci-002" => 8001,
|
|
41
|
-
"gpt-4" => 8192,
|
|
42
|
-
"gpt-4-0314" => 8192,
|
|
43
|
-
"gpt-4-0613" => 8192,
|
|
44
|
-
"gpt-4-32k" => 32768,
|
|
45
|
-
"gpt-4-32k-0314" => 32768,
|
|
46
|
-
"gpt-4-32k-0613" => 32768,
|
|
47
|
-
"gpt-4-1106-preview" => 128000,
|
|
48
|
-
"gpt-4-turbo" => 128000,
|
|
49
|
-
"gpt-4-turbo-2024-04-09" => 128000,
|
|
50
|
-
"gpt-4-turbo-preview" => 128000,
|
|
51
|
-
"gpt-4-0125-preview" => 128000,
|
|
52
|
-
"gpt-4-vision-preview" => 128000,
|
|
53
|
-
"gpt-4o" => 128000,
|
|
54
|
-
"gpt-4o-2024-05-13" => 128000,
|
|
55
|
-
"text-curie-001" => 2049,
|
|
56
|
-
"text-babbage-001" => 2049,
|
|
57
|
-
"text-ada-001" => 2049,
|
|
58
|
-
"davinci" => 2049,
|
|
59
|
-
"curie" => 2049,
|
|
60
|
-
"babbage" => 2049,
|
|
61
|
-
"ada" => 2049
|
|
62
|
-
}.freeze
|
|
63
|
-
|
|
64
|
-
#
|
|
65
|
-
# Calculate token length for a given text and model name
|
|
66
|
-
#
|
|
67
|
-
# @param text [String] The text to calculate the token length for
|
|
68
|
-
# @param model_name [String] The model name to validate against
|
|
69
|
-
# @return [Integer] The token length of the text
|
|
70
|
-
#
|
|
71
|
-
def self.token_length(text, model_name, options = {})
|
|
72
|
-
# tiktoken-ruby doesn't support text-embedding-3-large or text-embedding-3-small yet
|
|
73
|
-
if ["text-embedding-3-large", "text-embedding-3-small"].include?(model_name)
|
|
74
|
-
model_name = "text-embedding-ada-002"
|
|
75
|
-
end
|
|
76
|
-
|
|
77
|
-
encoder = Tiktoken.encoding_for_model(model_name)
|
|
78
|
-
encoder.encode(text).length
|
|
79
|
-
end
|
|
80
|
-
|
|
81
|
-
def self.token_limit(model_name)
|
|
82
|
-
TOKEN_LIMITS[model_name]
|
|
83
|
-
end
|
|
84
|
-
|
|
85
|
-
def self.completion_token_limit(model_name)
|
|
86
|
-
COMPLETION_TOKEN_LIMITS[model_name] || token_limit(model_name)
|
|
87
|
-
end
|
|
88
|
-
|
|
89
|
-
# If :max_tokens is passed in, take the lower of it and the calculated max_tokens
|
|
90
|
-
def self.validate_max_tokens!(content, model_name, options = {})
|
|
91
|
-
max_tokens = super(content, model_name, options)
|
|
92
|
-
[options[:max_tokens], max_tokens].reject(&:nil?).min
|
|
93
|
-
end
|
|
94
|
-
|
|
95
|
-
# Copied from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
|
96
|
-
# Return the number of tokens used by a list of messages
|
|
97
|
-
#
|
|
98
|
-
# @param messages [Array<Hash>] The messages to calculate the token length for
|
|
99
|
-
# @param model [String] The model name to validate against
|
|
100
|
-
# @return [Integer] The token length of the messages
|
|
101
|
-
#
|
|
102
|
-
def self.token_length_from_messages(messages, model_name, options = {})
|
|
103
|
-
encoding = Tiktoken.encoding_for_model(model_name)
|
|
104
|
-
|
|
105
|
-
if ["gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-4-0314", "gpt-4-32k-0314", "gpt-4-0613", "gpt-4-32k-0613"].include?(model_name)
|
|
106
|
-
tokens_per_message = 3
|
|
107
|
-
tokens_per_name = 1
|
|
108
|
-
elsif model_name == "gpt-3.5-turbo-0301"
|
|
109
|
-
tokens_per_message = 4 # every message follows {role/name}\n{content}\n
|
|
110
|
-
tokens_per_name = -1 # if there's a name, the role is omitted
|
|
111
|
-
elsif model_name.include?("gpt-3.5-turbo")
|
|
112
|
-
# puts "Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613."
|
|
113
|
-
return token_length_from_messages(messages, "gpt-3.5-turbo-0613", options)
|
|
114
|
-
elsif model_name.include?("gpt-4")
|
|
115
|
-
# puts "Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
|
|
116
|
-
return token_length_from_messages(messages, "gpt-4-0613", options)
|
|
117
|
-
else
|
|
118
|
-
raise NotImplementedError.new(
|
|
119
|
-
"token_length_from_messages() is not implemented for model #{model_name}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."
|
|
120
|
-
)
|
|
121
|
-
end
|
|
122
|
-
|
|
123
|
-
num_tokens = 0
|
|
124
|
-
messages.each do |message|
|
|
125
|
-
num_tokens += tokens_per_message
|
|
126
|
-
message.each do |key, value|
|
|
127
|
-
num_tokens += encoding.encode(value).length
|
|
128
|
-
num_tokens += tokens_per_name if ["name", :name].include?(key)
|
|
129
|
-
end
|
|
130
|
-
end
|
|
131
|
-
|
|
132
|
-
num_tokens += 3 # every reply is primed with assistant
|
|
133
|
-
num_tokens
|
|
134
|
-
end
|
|
135
|
-
end
|
|
136
|
-
end
|
|
137
|
-
end
|
|
138
|
-
end
|
|
@@ -1,17 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
module Langchain
|
|
4
|
-
module Utils
|
|
5
|
-
module TokenLength
|
|
6
|
-
class TokenLimitExceeded < StandardError
|
|
7
|
-
attr_reader :token_overflow
|
|
8
|
-
|
|
9
|
-
def initialize(message = "", token_overflow = 0)
|
|
10
|
-
super(message)
|
|
11
|
-
|
|
12
|
-
@token_overflow = token_overflow
|
|
13
|
-
end
|
|
14
|
-
end
|
|
15
|
-
end
|
|
16
|
-
end
|
|
17
|
-
end
|