langchainrb 0.5.3 → 0.5.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.env.example +1 -0
- data/CHANGELOG.md +10 -0
- data/Gemfile.lock +10 -1
- data/README.md +5 -2
- data/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent.rb +1 -1
- data/lib/langchain/conversation.rb +66 -0
- data/lib/langchain/llm/google_palm.rb +47 -10
- data/lib/langchain/llm/openai.rb +45 -10
- data/lib/langchain/tool/base.rb +9 -0
- data/lib/langchain/tool/weather.rb +67 -0
- data/lib/langchain/utils/token_length/google_palm_validator.rb +69 -0
- data/lib/langchain/utils/token_length/openai_validator.rb +75 -0
- data/lib/langchain/vectorsearch/hnswlib.rb +122 -0
- data/lib/langchain/version.rb +1 -1
- data/lib/langchain.rb +9 -1
- metadata +36 -3
- data/lib/langchain/utils/token_length_validator.rb +0 -89
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: 9781999daf45e5fedb0c7a905268866fbefd4581fd35a1a512ebb5844598f2c7
|
|
4
|
+
data.tar.gz: 93e2161a331151218cb94706827ab1ca2d94cb363613a5117ef7ce0c36cb9469
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: 58255ecc90b645cf6b276bee83a385d567122fb8b90ad7a075f4b2ec90ba2f6871156c6e5762b0bd67b259371d1c35b3973fd6b4e631be858ea8a368aac163b7
|
|
7
|
+
data.tar.gz: ffcaba1dc980b1f3175269a223ebe522bae1a3e17931b1f911adf387f8fc4a32692d45712ab0032e5a9c273ba304d97ca0ce2086978381d17e716efa1941072a
|
data/.env.example
CHANGED
data/CHANGELOG.md
CHANGED
|
@@ -1,5 +1,15 @@
|
|
|
1
1
|
## [Unreleased]
|
|
2
2
|
|
|
3
|
+
## [0.5.5] - 2023-06-12
|
|
4
|
+
- [BREAKING] Rename `Langchain::Chat` to `Langchain::Conversation`
|
|
5
|
+
- 🛠️ Tools
|
|
6
|
+
- Introducing `Langchain::Tool::Weather`, a tool that calls Open Weather API to retrieve the current weather
|
|
7
|
+
|
|
8
|
+
## [0.5.4] - 2023-06-10
|
|
9
|
+
- 🔍 Vectorsearch
|
|
10
|
+
- Introducing support for HNSWlib
|
|
11
|
+
- Improved and new `Langchain::Chat` interface that persists chat history in memory
|
|
12
|
+
|
|
3
13
|
## [0.5.3] - 2023-06-09
|
|
4
14
|
- 🗣️ LLMs
|
|
5
15
|
- Chat message history support for Langchain::LLM::GooglePalm and Langchain::LLM::OpenAI
|
data/Gemfile.lock
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
PATH
|
|
2
2
|
remote: .
|
|
3
3
|
specs:
|
|
4
|
-
langchainrb (0.5.
|
|
4
|
+
langchainrb (0.5.5)
|
|
5
5
|
colorize (~> 0.8.1)
|
|
6
6
|
tiktoken_ruby (~> 0.0.5)
|
|
7
7
|
|
|
@@ -135,6 +135,8 @@ GEM
|
|
|
135
135
|
activesupport (>= 3.0)
|
|
136
136
|
graphql
|
|
137
137
|
hashery (2.1.2)
|
|
138
|
+
hashie (5.0.0)
|
|
139
|
+
hnswlib (0.8.1)
|
|
138
140
|
httparty (0.21.0)
|
|
139
141
|
mini_mime (>= 1.0.0)
|
|
140
142
|
multi_xml (>= 0.5.2)
|
|
@@ -166,6 +168,11 @@ GEM
|
|
|
166
168
|
racc (~> 1.4)
|
|
167
169
|
nokogiri (1.14.3-x86_64-linux)
|
|
168
170
|
racc (~> 1.4)
|
|
171
|
+
open-weather-ruby-client (0.3.0)
|
|
172
|
+
activesupport
|
|
173
|
+
faraday (>= 1.0.0)
|
|
174
|
+
faraday_middleware
|
|
175
|
+
hashie
|
|
169
176
|
parallel (1.23.0)
|
|
170
177
|
parser (3.2.2.1)
|
|
171
178
|
ast (~> 2.4.1)
|
|
@@ -312,10 +319,12 @@ DEPENDENCIES
|
|
|
312
319
|
eqn (~> 1.6.5)
|
|
313
320
|
google_palm_api (~> 0.1.1)
|
|
314
321
|
google_search_results (~> 2.0.0)
|
|
322
|
+
hnswlib (~> 0.8.1)
|
|
315
323
|
hugging-face (~> 0.3.4)
|
|
316
324
|
langchainrb!
|
|
317
325
|
milvus (~> 0.9.0)
|
|
318
326
|
nokogiri (~> 1.13)
|
|
327
|
+
open-weather-ruby-client (~> 0.3.0)
|
|
319
328
|
pdf-reader (~> 1.4)
|
|
320
329
|
pg (~> 1.5)
|
|
321
330
|
pgvector (~> 0.2)
|
data/README.md
CHANGED
|
@@ -34,6 +34,7 @@ require "langchain"
|
|
|
34
34
|
| Database | Querying | Storage | Schema Management | Backups | Rails Integration |
|
|
35
35
|
| -------- |:------------------:| -------:| -----------------:| -------:| -----------------:|
|
|
36
36
|
| [Chroma](https://trychroma.com/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
|
|
37
|
+
| [Hnswlib](https://github.com/nmslib/hnswlib/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
|
|
37
38
|
| [Milvus](https://milvus.io/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
|
|
38
39
|
| [Pinecone](https://www.pinecone.io/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
|
|
39
40
|
| [Pgvector](https://github.com/pgvector/pgvector) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
|
|
@@ -56,11 +57,12 @@ client = Langchain::Vectorsearch::Weaviate.new(
|
|
|
56
57
|
)
|
|
57
58
|
|
|
58
59
|
# You can instantiate any other supported vector search database:
|
|
60
|
+
client = Langchain::Vectorsearch::Chroma.new(...) # `gem "chroma-db", "~> 0.3.0"`
|
|
61
|
+
client = Langchain::Vectorsearch::Hnswlib.new(...) # `gem "hnswlib", "~> 0.8.1"`
|
|
59
62
|
client = Langchain::Vectorsearch::Milvus.new(...) # `gem "milvus", "~> 0.9.0"`
|
|
60
|
-
client = Langchain::Vectorsearch::Qdrant.new(...) # `gem"qdrant-ruby", "~> 0.9.0"`
|
|
61
63
|
client = Langchain::Vectorsearch::Pinecone.new(...) # `gem "pinecone", "~> 0.1.6"`
|
|
62
|
-
client = Langchain::Vectorsearch::Chroma.new(...) # `gem "chroma-db", "~> 0.3.0"`
|
|
63
64
|
client = Langchain::Vectorsearch::Pgvector.new(...) # `gem "pgvector", "~> 0.2"`
|
|
65
|
+
client = Langchain::Vectorsearch::Qdrant.new(...) # `gem "qdrant-ruby", "~> 0.9.0"`
|
|
64
66
|
```
|
|
65
67
|
|
|
66
68
|
```ruby
|
|
@@ -307,6 +309,7 @@ agent.run(question: "How many users have a name with length greater than 5 in th
|
|
|
307
309
|
| "database" | Useful for querying a SQL database | | `gem "sequel", "~> 5.68.0"` |
|
|
308
310
|
| "ruby_code_interpreter" | Interprets Ruby expressions | | `gem "safe_ruby", "~> 1.0.4"` |
|
|
309
311
|
| "search" | A wrapper around Google Search | `ENV["SERPAPI_API_KEY"]` (https://serpapi.com/manage-api-key) | `gem "google_search_results", "~> 2.0.0"` |
|
|
312
|
+
| "weather" | Calls Open Weather API to retrieve the current weather | `ENV["OPEN_WEATHER_API_KEY"]` (https://home.openweathermap.org/api_keys) | `gem "open-weather-ruby-client", "~> 0.3.0"` |
|
|
310
313
|
| "wikipedia" | Calls Wikipedia API to retrieve the summary | | `gem "wikipedia-client", "~> 1.17.0"` |
|
|
311
314
|
|
|
312
315
|
#### Loaders 🚚
|
|
@@ -101,7 +101,7 @@ module Langchain::Agent
|
|
|
101
101
|
tool_names: "[#{tool_list.join(", ")}]",
|
|
102
102
|
tools: tools.map do |tool|
|
|
103
103
|
tool_name = tool.tool_name
|
|
104
|
-
tool_description = tool.
|
|
104
|
+
tool_description = tool.tool_description
|
|
105
105
|
"#{tool_name}: #{tool_description}"
|
|
106
106
|
end.join("\n")
|
|
107
107
|
)
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Langchain
  #
  # A high-level API for running a conversation with an LLM.
  # Currently supports: OpenAI and Google PaLM LLMs.
  #
  # Usage:
  #     llm = Langchain::LLM::OpenAI.new(api_key: "YOUR_API_KEY")
  #     chat = Langchain::Conversation.new(llm: llm)
  #     chat.set_context("You are a chatbot from the future")
  #     chat.message("Tell me about future technologies")
  #
  class Conversation
    attr_reader :context, :examples, :messages

    # Initialize Conversation with a LLM
    #
    # @param llm [Object] The LLM to use for the conversation; it must respond to
    #   #chat(messages:, context:, examples:, **options)
    # @param options [Hash] Options forwarded to the LLM's #chat on every call, like temperature, top_k, etc.
    # @return [Langchain::Conversation] The Langchain::Conversation instance
    def initialize(llm:, **options)
      @llm = llm
      @options = options
      @context = nil
      @examples = []
      @messages = []
    end

    # Set the context of the conversation. Usually used to set the model's persona.
    # @param message [String] The context of the conversation
    def set_context(message)
      @context = message
    end

    # Add examples to the conversation. Used to give the model a sense of the conversation.
    # @param examples [Array<Hash>] The examples to add to the conversation
    def add_examples(examples)
      @examples.concat examples
    end

    # Message the model with a prompt and return the response.
    #
    # The prompt and the model's reply are both appended to #messages. If the LLM
    # call raises, the just-appended prompt is removed again, so the history never
    # ends with an unanswered user turn. The error is then re-raised.
    #
    # @param message [String] The prompt to message the model with
    # @return [String] The response from the model
    def message(message)
      append_user_message(message)

      begin
        response = llm_response
      rescue StandardError
        # Roll back the user turn so a retry does not send it twice.
        @messages.pop
        raise
      end

      append_ai_message(response)
      response
    end

    private

    # Sends the whole conversation to the LLM. The latest user message is
    # already the last entry of @messages at this point.
    def llm_response
      @llm.chat(messages: @messages, context: @context, examples: @examples, **@options)
    end

    def append_ai_message(message)
      @messages << {role: "ai", content: message}
    end

    def append_user_message(message)
      @messages << {role: "user", content: message}
    end
  end
end
|
|
@@ -80,28 +80,31 @@ module Langchain::LLM
|
|
|
80
80
|
# @param params extra parameters passed to GooglePalmAPI::Client#generate_chat_message
|
|
81
81
|
# @return [String] The chat completion
|
|
82
82
|
#
|
|
83
|
-
def chat(prompt: "", messages: [], **
|
|
83
|
+
def chat(prompt: "", messages: [], context: "", examples: [], **options)
|
|
84
84
|
raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
|
|
85
85
|
|
|
86
|
-
messages << {author: "0", content: prompt} if !prompt.empty?
|
|
87
|
-
|
|
88
|
-
# TODO: Figure out how to introduce persisted conversations
|
|
89
86
|
default_params = {
|
|
90
87
|
temperature: DEFAULTS[:temperature],
|
|
91
|
-
|
|
88
|
+
context: context,
|
|
89
|
+
messages: compose_chat_messages(prompt: prompt, messages: messages),
|
|
90
|
+
examples: compose_examples(examples)
|
|
92
91
|
}
|
|
93
92
|
|
|
94
|
-
|
|
95
|
-
|
|
93
|
+
Langchain::Utils::TokenLength::GooglePalmValidator.validate_max_tokens!(self, default_params[:messages], "chat-bison-001")
|
|
94
|
+
|
|
95
|
+
if options[:stop_sequences]
|
|
96
|
+
default_params[:stop] = options.delete(:stop_sequences)
|
|
96
97
|
end
|
|
97
98
|
|
|
98
|
-
if
|
|
99
|
-
default_params[:max_output_tokens] =
|
|
99
|
+
if options[:max_tokens]
|
|
100
|
+
default_params[:max_output_tokens] = options.delete(:max_tokens)
|
|
100
101
|
end
|
|
101
102
|
|
|
102
|
-
default_params.merge!(
|
|
103
|
+
default_params.merge!(options)
|
|
103
104
|
|
|
104
105
|
response = client.generate_chat_message(**default_params)
|
|
106
|
+
raise "GooglePalm API returned an error: #{response}" if response.dig("error")
|
|
107
|
+
|
|
105
108
|
response.dig("candidates", 0, "content")
|
|
106
109
|
end
|
|
107
110
|
|
|
@@ -124,5 +127,39 @@ module Langchain::LLM
|
|
|
124
127
|
max_tokens: 2048
|
|
125
128
|
)
|
|
126
129
|
end
|
|
130
|
+
|
|
131
|
+
private
|
|
132
|
+
|
|
133
|
+
# Build the PaLM `messages:` payload from prior conversation messages plus an optional new prompt.
#
# If the conversation already ends with a user turn, the prompt is appended to
# that turn (newline-separated) instead of being added as a separate message.
#
# @param prompt [String] New user prompt; ignored when empty
# @param messages [Array<Hash>] Prior messages as {role:, content:} hashes
# @return [Array<Hash>] Messages as {author:, content:} hashes
def compose_chat_messages(prompt:, messages:)
  history = []
  history.concat transform_messages(messages) unless messages.empty?

  unless prompt.empty?
    # transform_messages emits PaLM-style `author:` keys, so the last turn
    # must be inspected via :author (it has no :role key).
    if history.last && history.last[:author] == "user"
      history.last[:content] += "\n#{prompt}"
    else
      history.append({author: "user", content: prompt})
    end
  end
  history
end

# Pair consecutive example messages into PaLM {input:, output:} examples.
#
# @param examples [Array<Hash>] Flat list of {content:} hashes, alternating input/output
# @return [Array<Hash>] One {input: {content:}, output: {content:}} hash per pair
def compose_examples(examples)
  examples.each_slice(2).map do |example|
    {
      input: {content: example.first[:content]},
      output: {content: example.last[:content]}
    }
  end
end

# Convert {role:, content:} messages into PaLM {author:, content:} messages.
#
# @param messages [Array<Hash>] Messages with :role and :content keys
# @return [Array<Hash>] New hashes; the input hashes are not mutated
def transform_messages(messages)
  messages.map do |message|
    {
      author: message[:role],
      content: message[:content]
    }
  end
end
|
|
127
164
|
end
|
|
128
165
|
end
|
data/lib/langchain/llm/openai.rb
CHANGED
|
@@ -35,7 +35,7 @@ module Langchain::LLM
|
|
|
35
35
|
def embed(text:, **params)
|
|
36
36
|
parameters = {model: DEFAULTS[:embeddings_model_name], input: text}
|
|
37
37
|
|
|
38
|
-
Langchain::Utils::
|
|
38
|
+
Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(text, parameters[:model])
|
|
39
39
|
|
|
40
40
|
response = client.embeddings(parameters: parameters.merge(params))
|
|
41
41
|
response.dig("data").first.dig("embedding")
|
|
@@ -52,7 +52,7 @@ module Langchain::LLM
|
|
|
52
52
|
parameters = compose_parameters DEFAULTS[:completion_model_name], params
|
|
53
53
|
|
|
54
54
|
parameters[:prompt] = prompt
|
|
55
|
-
parameters[:max_tokens] = Langchain::Utils::
|
|
55
|
+
parameters[:max_tokens] = Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(prompt, parameters[:model])
|
|
56
56
|
|
|
57
57
|
response = client.completions(parameters: parameters)
|
|
58
58
|
response.dig("choices", 0, "text")
|
|
@@ -63,19 +63,22 @@ module Langchain::LLM
|
|
|
63
63
|
#
|
|
64
64
|
# @param prompt [String] The prompt to generate a chat completion for
|
|
65
65
|
# @param messages [Array] The messages that have been sent in the conversation
|
|
66
|
-
# @param
|
|
66
|
+
# @param context [String] The context of the conversation
|
|
67
|
+
# @param examples [Array] Examples of messages provide model with
|
|
68
|
+
# @param options extra parameters passed to OpenAI::Client#chat
|
|
67
69
|
# @return [String] The chat completion
|
|
68
70
|
#
|
|
69
|
-
def chat(prompt: "", messages: [], **
|
|
71
|
+
def chat(prompt: "", messages: [], context: "", examples: [], **options)
|
|
70
72
|
raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
|
|
71
73
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
parameters =
|
|
75
|
-
parameters[:messages] = messages
|
|
76
|
-
parameters[:max_tokens] = validate_max_tokens(messages, parameters[:model])
|
|
74
|
+
parameters = compose_parameters DEFAULTS[:chat_completion_model_name], options
|
|
75
|
+
parameters[:messages] = compose_chat_messages(prompt: prompt, messages: messages, context: context, examples: examples)
|
|
76
|
+
parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
|
|
77
77
|
|
|
78
78
|
response = client.chat(parameters: parameters)
|
|
79
|
+
|
|
80
|
+
raise "Chat completion failed: #{response}" if response.dig("error")
|
|
81
|
+
|
|
79
82
|
response.dig("choices", 0, "message", "content")
|
|
80
83
|
end
|
|
81
84
|
|
|
@@ -104,8 +107,40 @@ module Langchain::LLM
|
|
|
104
107
|
default_params.merge(params)
|
|
105
108
|
end
|
|
106
109
|
|
|
110
|
+
# Assemble the OpenAI chat `messages:` array.
#
# Order: optional system context, then examples, then prior messages, then the
# new prompt. A non-empty context replaces any "system" messages already present.
# A prompt that follows a trailing user turn is folded into that turn.
#
# @param prompt [String] New user prompt; ignored when empty
# @param messages [Array<Hash>] Prior {role:, content:} messages
# @param context [String, nil] System message content; skipped when nil or empty
# @param examples [Array<Hash>] Example {role:, content:} messages placed before the history
# @return [Array<Hash>] OpenAI-ready {role:, content:} messages
def compose_chat_messages(prompt:, messages:, context:, examples:)
  conversation = transform_messages(examples) + transform_messages(messages)

  if !context.nil? && !context.empty?
    conversation = conversation.reject { |entry| entry[:role] == "system" }
    conversation.unshift({role: "system", content: context})
  end

  return conversation if prompt.empty?

  tail = conversation.last
  if tail && tail[:role] == "user"
    tail[:content] = tail[:content] + "\n#{prompt}"
  else
    conversation << {role: "user", content: prompt}
  end

  conversation
end

# Map generic {role:, content:} messages to OpenAI roles ("ai" becomes "assistant").
#
# @param messages [Array<Hash>] Messages with :role and :content keys
# @return [Array<Hash>] Fresh {content:, role:} hashes
def transform_messages(messages)
  messages.map do |entry|
    openai_role = entry[:role]
    openai_role = "assistant" if openai_role == "ai"
    {content: entry[:content], role: openai_role}
  end
end
|
|
141
|
+
|
|
107
142
|
def validate_max_tokens(messages, model)
|
|
108
|
-
Langchain::Utils::
|
|
143
|
+
Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(messages, model)
|
|
109
144
|
end
|
|
110
145
|
end
|
|
111
146
|
end
|
data/lib/langchain/tool/base.rb
CHANGED
|
@@ -57,6 +57,15 @@ module Langchain::Tool
|
|
|
57
57
|
self.class.const_get(:NAME)
|
|
58
58
|
end
|
|
59
59
|
|
|
60
|
+
#
|
|
61
|
+
# Returns the DESCRIPTION constant of the tool
|
|
62
|
+
#
|
|
63
|
+
# @return [String] tool description
|
|
64
|
+
#
|
|
65
|
+
def tool_description
  # The description is stored per tool subclass as a DESCRIPTION constant.
  tool_class = self.class
  tool_class.const_get("DESCRIPTION")
end
|
|
68
|
+
|
|
60
69
|
#
|
|
61
70
|
# Sets the DESCRIPTION constant of the tool
|
|
62
71
|
#
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Langchain::Tool
  class Weather < Base
    #
    # A weather tool that gets current weather data
    #
    # Current weather data is free for 1000 calls per day (https://home.openweathermap.org/api_keys)
    # Forecast and historical data require registration with credit card, so not supported yet.
    #
    # Gem requirements:
    #     gem "open-weather-ruby-client", "~> 0.3.0"
    # api_key: https://home.openweathermap.org/api_keys
    #
    # Usage:
    #     weather = Langchain::Tool::Weather.new(api_key: "YOUR_API_KEY")
    #     weather.execute(input: "Boston, MA; imperial")
    #

    NAME = "weather"

    description <<~DESC
      Useful for getting current weather data

      The input to this tool should be a city name followed by the units (imperial, metric, or standard)
      Usage:
      Action Input: St Louis, Missouri; metric
      Action Input: Boston, Massachusetts; imperial
      Action Input: Dubai, AE; imperial
      Action Input: Kiev, Ukraine; metric
    DESC

    attr_reader :client, :units

    #
    # Initializes the Weather tool
    #
    # @param api_key [String] Open Weather API key
    # @param units [String] Units used when the input does not specify any (imperial, metric, or standard)
    # @return [Langchain::Tool::Weather] Weather tool
    #
    def initialize(api_key:, units: "metric")
      depends_on "open-weather-ruby-client"
      require "open-weather-ruby-client"

      OpenWeather::Client.configure do |config|
        config.api_key = api_key
        config.user_agent = "Langchainrb Ruby Client"
      end

      @client = OpenWeather::Client.new
      # Previously never assigned, so the default was ignored and `units` was always nil.
      @units = units
    end

    # Returns current weather for a city
    #
    # @param input [String] City name, optionally followed by ";" and the units
    #   (imperial, metric, or standard). When the units are omitted or blank, the
    #   units given to #initialize are used.
    # @return [String] Answer summarizing the `main` section of the current weather response
    def execute(input:)
      Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing for \"#{input}\"")

      city, requested_units = input.split(";").map(&:strip)
      requested_units = units if requested_units.nil? || requested_units.empty?

      data = client.current_weather(city: city, units: requested_units)
      weather = data.main.map { |key, value| "#{key} #{value}" }.join(", ")
      "The current weather in #{data.name} is #{weather}"
    end
  end
end
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Langchain
  module Utils
    module TokenLength
      #
      # This class is meant to validate the length of the text passed in to Google Palm's API.
      # It is used to validate the token length before the API call is made
      #
      class GooglePalmValidator
        TOKEN_LIMITS = {
          # Source:
          # This data can be pulled when `list_models()` method is called: https://github.com/andreibondarev/google_palm_api#usage

          # chat-bison-001 is the only model that currently supports countMessageTokens functions
          "chat-bison-001" => {
            "input_token_limit" => 4000, # 4096 is the limit but the countMessageTokens does not return anything higher than 4000
            "output_token_limit" => 1024
          }
          # "text-bison-001" => {
          #   "input_token_limit" => 8196,
          #   "output_token_limit" => 1024
          # },
          # "embedding-gecko-001" => {
          #   "input_token_limit" => 1024
          # }
        }.freeze

        #
        # Validate the context length of the text and return how many input tokens remain
        #
        # @param google_palm_llm [Langchain::LLM::GooglePalm] LLM whose client is used to count tokens
        # @param content [String | Array] The text, or an array of items (e.g. chat message hashes);
        #   each array item is serialized with #to_json and counted separately
        # @param model_name [String] The model name to validate against; must be a key of TOKEN_LIMITS
        # @return [Integer] The number of input tokens left before the model's input limit is reached
        # @raise [ArgumentError] If no token limit is known for model_name
        # @raise [TokenLimitExceeded] If the text is too long
        #
        def self.validate_max_tokens!(google_palm_llm, content, model_name)
          input_token_limit = TOKEN_LIMITS.dig(model_name, "input_token_limit")
          # Fail fast with a clear message (instead of NoMethodError on nil below)
          # and before spending any countMessageTokens API calls.
          raise ArgumentError, "Token limit is unknown for model #{model_name.inspect}" if input_token_limit.nil?

          text_token_length = if content.is_a?(Array)
            content.sum { |item| token_length(google_palm_llm, item.to_json, model_name) }
          else
            token_length(google_palm_llm, content, model_name)
          end

          leftover_tokens = input_token_limit - text_token_length

          # Raise an error even if whole prompt is equal to the model's token limit (leftover_tokens == 0)
          if leftover_tokens <= 0
            raise TokenLimitExceeded, "This model's maximum context length is #{input_token_limit} tokens, but the given text is #{text_token_length} tokens long."
          end

          leftover_tokens
        end

        #
        # Calculate token length for a given text and model name
        #
        # @param llm [Langchain::LLM::GooglePalm] The Langchain::LLM::GooglePalm instance
        # @param text [String] The text to calculate the token length for
        # @param model_name [String] The model name to validate against
        # @return [Integer] The token length of the text
        #
        def self.token_length(llm, text, model_name = "chat-bison-001")
          response = llm.client.count_message_tokens(model: model_name, prompt: text)
          response.dig("tokenCount")
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "tiktoken_ruby"
|
|
4
|
+
|
|
5
|
+
module Langchain
  module Utils
    module TokenLength
      #
      # This class is meant to validate the length of the text passed in to OpenAI's API.
      # It is used to validate the token length before the API call is made
      #
      class OpenAIValidator
        TOKEN_LIMITS = {
          # Source:
          # https://platform.openai.com/docs/api-reference/embeddings
          # https://platform.openai.com/docs/models/gpt-4
          "text-embedding-ada-002" => 8191,
          "gpt-3.5-turbo" => 4096,
          "gpt-3.5-turbo-0301" => 4096,
          "text-davinci-003" => 4097,
          "text-davinci-002" => 4097,
          "code-davinci-002" => 8001,
          "gpt-4" => 8192,
          "gpt-4-0314" => 8192,
          "gpt-4-32k" => 32768,
          "gpt-4-32k-0314" => 32768,
          "text-curie-001" => 2049,
          "text-babbage-001" => 2049,
          "text-ada-001" => 2049,
          "davinci" => 2049,
          "curie" => 2049,
          "babbage" => 2049,
          "ada" => 2049
        }.freeze

        #
        # Calculate the `max_tokens:` parameter to be set by calculating the context length of the text minus the prompt length
        #
        # @param content [String | Array] The text, or an array of items (e.g. chat message hashes);
        #   each array item is serialized with #to_json and counted separately
        # @param model_name [String] The model name to validate against; must be a key of TOKEN_LIMITS
        # @return [Integer] The number of tokens left for the completion (usable as `max_tokens:`)
        # @raise [ArgumentError] If no token limit is known for model_name
        # @raise [TokenLimitExceeded] If the text is too long
        #
        def self.validate_max_tokens!(content, model_name)
          token_limit = TOKEN_LIMITS[model_name]
          # Fail with a clear message instead of NoMethodError on nil below.
          raise ArgumentError, "Token limit is unknown for model #{model_name.inspect}" if token_limit.nil?

          text_token_length = if content.is_a?(Array)
            content.sum { |item| token_length(item.to_json, model_name) }
          else
            token_length(content, model_name)
          end

          max_tokens = token_limit - text_token_length

          # Raise an error even if whole prompt is equal to the model's token limit (max_tokens == 0) since no response will be returned
          if max_tokens <= 0
            raise TokenLimitExceeded, "This model's maximum context length is #{token_limit} tokens, but the given text is #{text_token_length} tokens long."
          end

          max_tokens
        end

        #
        # Calculate token length for a given text and model name
        #
        # @param text [String] The text to calculate the token length for
        # @param model_name [String] The model name whose tiktoken encoding is used
        # @return [Integer] The token length of the text
        #
        def self.token_length(text, model_name)
          encoder = Tiktoken.encoding_for_model(model_name)
          encoder.encode(text).length
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Langchain::Vectorsearch
  class Hnswlib < Base
    #
    # Wrapper around HNSW (Hierarchical Navigable Small World) library.
    # HNSWLib is an in-memory vectorstore that can be saved to a file on disk.
    #
    # Gem requirements:
    #     gem "hnswlib", "~> 0.8.1"
    #
    # Usage:
    #     hnsw = Langchain::Vectorsearch::Hnswlib.new(llm:, path_to_index:)
    #

    attr_reader :client, :path_to_index

    #
    # Initialize the HNSW vector search
    #
    # @param llm [Object] The LLM client to use; its default_dimension sizes the index
    # @param path_to_index [String] The local path to the index file, e.g.: "/storage/index.ann"
    # @return [Langchain::Vectorsearch::Hnswlib] Class instance
    #
    def initialize(llm:, path_to_index:)
      depends_on "hnswlib"
      require "hnswlib"

      super(llm: llm)

      @client = ::Hnswlib::HierarchicalNSW.new(space: DEFAULT_METRIC, dim: llm.default_dimension)
      @path_to_index = path_to_index

      initialize_index
    end

    #
    # Add a list of texts and corresponding IDs to the index, then save the index to path_to_index
    #
    # @param texts [Array] The list of texts to add
    # @param ids [Array] The list of corresponding IDs (integers) to the texts, same length as texts
    # @return [Boolean] The response from the HNSW library
    # @raise [ArgumentError] If texts and ids differ in length
    #
    def add_texts(texts:, ids:)
      texts = Array(texts)
      ids = Array(ids)
      # Validate before resizing or embedding anything, so a bad call has no side effects.
      if texts.size != ids.size
        raise ArgumentError, "texts and ids must have the same length (got #{texts.size} texts and #{ids.size} ids)"
      end

      resize_index(texts.size)

      texts.zip(ids).each do |text, id|
        embedding = llm.embed(text: text)

        client.add_point(embedding, id)
      end

      client.save_index(path_to_index)
    end

    #
    # Search for similar texts
    #
    # @param query [String] The text to search for
    # @param k [Integer] The number of results to return
    # @return [Array] Results in the format `[[id1, distance1], [id2, distance2]]`
    #
    def similarity_search(
      query:,
      k: 4
    )
      embedding = llm.embed(text: query)

      similarity_search_by_vector(
        embedding: embedding,
        k: k
      )
    end

    #
    # Search for the K nearest neighbors of a given vector
    #
    # @param embedding [Array] The embedding to search for
    # @param k [Integer] The number of results to return
    # @return [Array] Results in the format `[[id1, distance1], [id2, distance2]]`
    #
    def similarity_search_by_vector(
      embedding:,
      k: 4
    )
      client.search_knn(embedding, k)
    end

    private

    #
    # Optionally resizes the index if there's no space for new data
    #
    # @param num_of_elements_to_add [Integer] The number of elements to add to the index
    #
    def resize_index(num_of_elements_to_add)
      current_count = client.current_count

      if (current_count + num_of_elements_to_add) > client.max_elements
        new_size = current_count + num_of_elements_to_add

        client.resize_index(new_size)
      end
    end

    #
    # Loads or initializes the new index
    #
    def initialize_index
      if File.exist?(path_to_index)
        client.load_index(path_to_index)

        Langchain.logger.info("[#{self.class.name}]".blue + ": Successfully loaded the index at \"#{path_to_index}\"")
      else
        # Default max_elements: 100, but we constantly resize the index as new data is written to it
        client.init_index(max_elements: 100)

        Langchain.logger.info("[#{self.class.name}]".blue + ": Creating a new index at \"#{path_to_index}\"")
      end
    end
  end
end
|
data/lib/langchain/version.rb
CHANGED
data/lib/langchain.rb
CHANGED
|
@@ -62,6 +62,7 @@ module Langchain
|
|
|
62
62
|
|
|
63
63
|
autoload :Loader, "langchain/loader"
|
|
64
64
|
autoload :Data, "langchain/data"
|
|
65
|
+
autoload :Conversation, "langchain/conversation"
|
|
65
66
|
autoload :DependencyHelper, "langchain/dependency_helper"
|
|
66
67
|
|
|
67
68
|
module Agent
|
|
@@ -75,6 +76,7 @@ module Langchain
|
|
|
75
76
|
autoload :Calculator, "langchain/tool/calculator"
|
|
76
77
|
autoload :RubyCodeInterpreter, "langchain/tool/ruby_code_interpreter"
|
|
77
78
|
autoload :SerpApi, "langchain/tool/serp_api"
|
|
79
|
+
autoload :Weather, "langchain/tool/weather"
|
|
78
80
|
autoload :Wikipedia, "langchain/tool/wikipedia"
|
|
79
81
|
autoload :Database, "langchain/tool/database"
|
|
80
82
|
end
|
|
@@ -92,12 +94,18 @@ module Langchain
|
|
|
92
94
|
end
|
|
93
95
|
|
|
94
96
|
module Utils
|
|
95
|
-
|
|
97
|
+
module TokenLength
|
|
98
|
+
class TokenLimitExceeded < StandardError; end
|
|
99
|
+
|
|
100
|
+
autoload :OpenAIValidator, "langchain/utils/token_length/openai_validator"
|
|
101
|
+
autoload :GooglePalmValidator, "langchain/utils/token_length/google_palm_validator"
|
|
102
|
+
end
|
|
96
103
|
end
|
|
97
104
|
|
|
98
105
|
module Vectorsearch
|
|
99
106
|
autoload :Base, "langchain/vectorsearch/base"
|
|
100
107
|
autoload :Chroma, "langchain/vectorsearch/chroma"
|
|
108
|
+
autoload :Hnswlib, "langchain/vectorsearch/hnswlib"
|
|
101
109
|
autoload :Milvus, "langchain/vectorsearch/milvus"
|
|
102
110
|
autoload :Pinecone, "langchain/vectorsearch/pinecone"
|
|
103
111
|
autoload :Pgvector, "langchain/vectorsearch/pgvector"
|
metadata
CHANGED
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: langchainrb
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.5.
|
|
4
|
+
version: 0.5.5
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- Andrei Bondarev
|
|
8
8
|
autorequire:
|
|
9
9
|
bindir: exe
|
|
10
10
|
cert_chain: []
|
|
11
|
-
date: 2023-06-
|
|
11
|
+
date: 2023-06-12 00:00:00.000000000 Z
|
|
12
12
|
dependencies:
|
|
13
13
|
- !ruby/object:Gem::Dependency
|
|
14
14
|
name: tiktoken_ruby
|
|
@@ -192,6 +192,20 @@ dependencies:
|
|
|
192
192
|
- - "~>"
|
|
193
193
|
- !ruby/object:Gem::Version
|
|
194
194
|
version: 2.0.0
|
|
195
|
+
- !ruby/object:Gem::Dependency
|
|
196
|
+
name: hnswlib
|
|
197
|
+
requirement: !ruby/object:Gem::Requirement
|
|
198
|
+
requirements:
|
|
199
|
+
- - "~>"
|
|
200
|
+
- !ruby/object:Gem::Version
|
|
201
|
+
version: 0.8.1
|
|
202
|
+
type: :development
|
|
203
|
+
prerelease: false
|
|
204
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
205
|
+
requirements:
|
|
206
|
+
- - "~>"
|
|
207
|
+
- !ruby/object:Gem::Version
|
|
208
|
+
version: 0.8.1
|
|
195
209
|
- !ruby/object:Gem::Dependency
|
|
196
210
|
name: hugging-face
|
|
197
211
|
requirement: !ruby/object:Gem::Requirement
|
|
@@ -234,6 +248,20 @@ dependencies:
|
|
|
234
248
|
- - "~>"
|
|
235
249
|
- !ruby/object:Gem::Version
|
|
236
250
|
version: '1.13'
|
|
251
|
+
- !ruby/object:Gem::Dependency
|
|
252
|
+
name: open-weather-ruby-client
|
|
253
|
+
requirement: !ruby/object:Gem::Requirement
|
|
254
|
+
requirements:
|
|
255
|
+
- - "~>"
|
|
256
|
+
- !ruby/object:Gem::Version
|
|
257
|
+
version: 0.3.0
|
|
258
|
+
type: :development
|
|
259
|
+
prerelease: false
|
|
260
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
261
|
+
requirements:
|
|
262
|
+
- - "~>"
|
|
263
|
+
- !ruby/object:Gem::Version
|
|
264
|
+
version: 0.3.0
|
|
237
265
|
- !ruby/object:Gem::Dependency
|
|
238
266
|
name: pg
|
|
239
267
|
requirement: !ruby/object:Gem::Requirement
|
|
@@ -432,6 +460,7 @@ files:
|
|
|
432
460
|
- lib/langchain/agent/sql_query_agent/sql_query_agent.rb
|
|
433
461
|
- lib/langchain/agent/sql_query_agent/sql_query_agent_answer_prompt.json
|
|
434
462
|
- lib/langchain/agent/sql_query_agent/sql_query_agent_sql_prompt.json
|
|
463
|
+
- lib/langchain/conversation.rb
|
|
435
464
|
- lib/langchain/data.rb
|
|
436
465
|
- lib/langchain/dependency_helper.rb
|
|
437
466
|
- lib/langchain/llm/ai21.rb
|
|
@@ -461,10 +490,13 @@ files:
|
|
|
461
490
|
- lib/langchain/tool/database.rb
|
|
462
491
|
- lib/langchain/tool/ruby_code_interpreter.rb
|
|
463
492
|
- lib/langchain/tool/serp_api.rb
|
|
493
|
+
- lib/langchain/tool/weather.rb
|
|
464
494
|
- lib/langchain/tool/wikipedia.rb
|
|
465
|
-
- lib/langchain/utils/
|
|
495
|
+
- lib/langchain/utils/token_length/google_palm_validator.rb
|
|
496
|
+
- lib/langchain/utils/token_length/openai_validator.rb
|
|
466
497
|
- lib/langchain/vectorsearch/base.rb
|
|
467
498
|
- lib/langchain/vectorsearch/chroma.rb
|
|
499
|
+
- lib/langchain/vectorsearch/hnswlib.rb
|
|
468
500
|
- lib/langchain/vectorsearch/milvus.rb
|
|
469
501
|
- lib/langchain/vectorsearch/pgvector.rb
|
|
470
502
|
- lib/langchain/vectorsearch/pinecone.rb
|
|
@@ -480,6 +512,7 @@ metadata:
|
|
|
480
512
|
homepage_uri: https://rubygems.org/gems/langchainrb
|
|
481
513
|
source_code_uri: https://github.com/andreibondarev/langchainrb
|
|
482
514
|
changelog_uri: https://github.com/andreibondarev/langchainrb/CHANGELOG.md
|
|
515
|
+
documentation_uri: https://rubydoc.info/gems/langchainrb
|
|
483
516
|
post_install_message:
|
|
484
517
|
rdoc_options: []
|
|
485
518
|
require_paths:
|
|
@@ -1,89 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require "tiktoken_ruby"
|
|
4
|
-
|
|
5
|
-
module Langchain
|
|
6
|
-
module Utils
|
|
7
|
-
class TokenLimitExceeded < StandardError; end
|
|
8
|
-
|
|
9
|
-
class TokenLengthValidator
|
|
10
|
-
#
|
|
11
|
-
# This class is meant to validate the length of the text passed in to OpenAI's API.
|
|
12
|
-
# It is used to validate the token length before the API call is made
|
|
13
|
-
#
|
|
14
|
-
TOKEN_LIMITS = {
|
|
15
|
-
# Source:
|
|
16
|
-
# https://platform.openai.com/docs/api-reference/embeddings
|
|
17
|
-
# https://platform.openai.com/docs/models/gpt-4
|
|
18
|
-
"text-embedding-ada-002" => 8191,
|
|
19
|
-
"gpt-3.5-turbo" => 4096,
|
|
20
|
-
"gpt-3.5-turbo-0301" => 4096,
|
|
21
|
-
"text-davinci-003" => 4097,
|
|
22
|
-
"text-davinci-002" => 4097,
|
|
23
|
-
"code-davinci-002" => 8001,
|
|
24
|
-
"gpt-4" => 8192,
|
|
25
|
-
"gpt-4-0314" => 8192,
|
|
26
|
-
"gpt-4-32k" => 32768,
|
|
27
|
-
"gpt-4-32k-0314" => 32768,
|
|
28
|
-
"text-curie-001" => 2049,
|
|
29
|
-
"text-babbage-001" => 2049,
|
|
30
|
-
"text-ada-001" => 2049,
|
|
31
|
-
"davinci" => 2049,
|
|
32
|
-
"curie" => 2049,
|
|
33
|
-
"babbage" => 2049,
|
|
34
|
-
"ada" => 2049
|
|
35
|
-
}.freeze
|
|
36
|
-
|
|
37
|
-
# GOOGLE_PALM_TOKEN_LIMITS = {
|
|
38
|
-
# "chat-bison-001" => {
|
|
39
|
-
# "inputTokenLimit"=>4096,
|
|
40
|
-
# "outputTokenLimit"=>1024
|
|
41
|
-
# },
|
|
42
|
-
# "text-bison-001" => {
|
|
43
|
-
# "inputTokenLimit"=>8196,
|
|
44
|
-
# "outputTokenLimit"=>1024
|
|
45
|
-
# },
|
|
46
|
-
# "embedding-gecko-001" => {
|
|
47
|
-
# "inputTokenLimit"=>1024
|
|
48
|
-
# }
|
|
49
|
-
# }.freeze
|
|
50
|
-
|
|
51
|
-
#
|
|
52
|
-
# Calculate the `max_tokens:` parameter to be set by calculating the context length of the text minus the prompt length
|
|
53
|
-
#
|
|
54
|
-
# @param content [String | Array<String>] The text or array of texts to validate
|
|
55
|
-
# @param model_name [String] The model name to validate against
|
|
56
|
-
# @return [Integer] Whether the text is valid or not
|
|
57
|
-
# @raise [TokenLimitExceeded] If the text is too long
|
|
58
|
-
#
|
|
59
|
-
def self.validate_max_tokens!(content, model_name)
|
|
60
|
-
text_token_length = if content.is_a?(Array)
|
|
61
|
-
content.sum { |item| token_length(item.to_json, model_name) }
|
|
62
|
-
else
|
|
63
|
-
token_length(content, model_name)
|
|
64
|
-
end
|
|
65
|
-
|
|
66
|
-
max_tokens = TOKEN_LIMITS[model_name] - text_token_length
|
|
67
|
-
|
|
68
|
-
# Raise an error even if whole prompt is equal to the model's token limit (max_tokens == 0) since not response will be returned
|
|
69
|
-
if max_tokens <= 0
|
|
70
|
-
raise TokenLimitExceeded, "This model's maximum context length is #{TOKEN_LIMITS[model_name]} tokens, but the given text is #{text_token_length} tokens long."
|
|
71
|
-
end
|
|
72
|
-
|
|
73
|
-
max_tokens
|
|
74
|
-
end
|
|
75
|
-
|
|
76
|
-
#
|
|
77
|
-
# Calculate token length for a given text and model name
|
|
78
|
-
#
|
|
79
|
-
# @param text [String] The text to validate
|
|
80
|
-
# @param model_name [String] The model name to validate against
|
|
81
|
-
# @return [Integer] The token length of the text
|
|
82
|
-
#
|
|
83
|
-
def self.token_length(text, model_name)
|
|
84
|
-
encoder = Tiktoken.encoding_for_model(model_name)
|
|
85
|
-
encoder.encode(text).length
|
|
86
|
-
end
|
|
87
|
-
end
|
|
88
|
-
end
|
|
89
|
-
end
|