langchainrb 0.5.3 → 0.5.5
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.env.example +1 -0
- data/CHANGELOG.md +10 -0
- data/Gemfile.lock +10 -1
- data/README.md +5 -2
- data/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent.rb +1 -1
- data/lib/langchain/conversation.rb +66 -0
- data/lib/langchain/llm/google_palm.rb +47 -10
- data/lib/langchain/llm/openai.rb +45 -10
- data/lib/langchain/tool/base.rb +9 -0
- data/lib/langchain/tool/weather.rb +67 -0
- data/lib/langchain/utils/token_length/google_palm_validator.rb +69 -0
- data/lib/langchain/utils/token_length/openai_validator.rb +75 -0
- data/lib/langchain/vectorsearch/hnswlib.rb +122 -0
- data/lib/langchain/version.rb +1 -1
- data/lib/langchain.rb +9 -1
- metadata +36 -3
- data/lib/langchain/utils/token_length_validator.rb +0 -89
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 9781999daf45e5fedb0c7a905268866fbefd4581fd35a1a512ebb5844598f2c7
|
4
|
+
data.tar.gz: 93e2161a331151218cb94706827ab1ca2d94cb363613a5117ef7ce0c36cb9469
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 58255ecc90b645cf6b276bee83a385d567122fb8b90ad7a075f4b2ec90ba2f6871156c6e5762b0bd67b259371d1c35b3973fd6b4e631be858ea8a368aac163b7
|
7
|
+
data.tar.gz: ffcaba1dc980b1f3175269a223ebe522bae1a3e17931b1f911adf387f8fc4a32692d45712ab0032e5a9c273ba304d97ca0ce2086978381d17e716efa1941072a
|
data/.env.example
CHANGED
data/CHANGELOG.md
CHANGED
@@ -1,5 +1,15 @@
|
|
1
1
|
## [Unreleased]
|
2
2
|
|
3
|
+
## [0.5.5] - 2023-06-12
|
4
|
+
- [BREAKING] Rename `Langchain::Chat` to `Langchain::Conversation`
|
5
|
+
- 🛠️ Tools
|
6
|
+
- Introducing `Langchain::Tool::Weather`, a tool that calls Open Weather API to retrieve the current weather
|
7
|
+
|
8
|
+
## [0.5.4] - 2023-06-10
|
9
|
+
- 🔍 Vectorsearch
|
10
|
+
- Introducing support for HNSWlib
|
11
|
+
- Improved and new `Langchain::Chat` interface that persists chat history in memory
|
12
|
+
|
3
13
|
## [0.5.3] - 2023-06-09
|
4
14
|
- 🗣️ LLMs
|
5
15
|
- Chat message history support for Langchain::LLM::GooglePalm and Langchain::LLM::OpenAI
|
data/Gemfile.lock
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
PATH
|
2
2
|
remote: .
|
3
3
|
specs:
|
4
|
-
langchainrb (0.5.3)
|
4
|
+
langchainrb (0.5.5)
|
5
5
|
colorize (~> 0.8.1)
|
6
6
|
tiktoken_ruby (~> 0.0.5)
|
7
7
|
|
@@ -135,6 +135,8 @@ GEM
|
|
135
135
|
activesupport (>= 3.0)
|
136
136
|
graphql
|
137
137
|
hashery (2.1.2)
|
138
|
+
hashie (5.0.0)
|
139
|
+
hnswlib (0.8.1)
|
138
140
|
httparty (0.21.0)
|
139
141
|
mini_mime (>= 1.0.0)
|
140
142
|
multi_xml (>= 0.5.2)
|
@@ -166,6 +168,11 @@ GEM
|
|
166
168
|
racc (~> 1.4)
|
167
169
|
nokogiri (1.14.3-x86_64-linux)
|
168
170
|
racc (~> 1.4)
|
171
|
+
open-weather-ruby-client (0.3.0)
|
172
|
+
activesupport
|
173
|
+
faraday (>= 1.0.0)
|
174
|
+
faraday_middleware
|
175
|
+
hashie
|
169
176
|
parallel (1.23.0)
|
170
177
|
parser (3.2.2.1)
|
171
178
|
ast (~> 2.4.1)
|
@@ -312,10 +319,12 @@ DEPENDENCIES
|
|
312
319
|
eqn (~> 1.6.5)
|
313
320
|
google_palm_api (~> 0.1.1)
|
314
321
|
google_search_results (~> 2.0.0)
|
322
|
+
hnswlib (~> 0.8.1)
|
315
323
|
hugging-face (~> 0.3.4)
|
316
324
|
langchainrb!
|
317
325
|
milvus (~> 0.9.0)
|
318
326
|
nokogiri (~> 1.13)
|
327
|
+
open-weather-ruby-client (~> 0.3.0)
|
319
328
|
pdf-reader (~> 1.4)
|
320
329
|
pg (~> 1.5)
|
321
330
|
pgvector (~> 0.2)
|
data/README.md
CHANGED
@@ -34,6 +34,7 @@ require "langchain"
|
|
34
34
|
| Database | Querying | Storage | Schema Management | Backups | Rails Integration |
|
35
35
|
| -------- |:------------------:| -------:| -----------------:| -------:| -----------------:|
|
36
36
|
| [Chroma](https://trychroma.com/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
|
37
|
+
| [Hnswlib](https://github.com/nmslib/hnswlib/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
|
37
38
|
| [Milvus](https://milvus.io/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
|
38
39
|
| [Pinecone](https://www.pinecone.io/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
|
39
40
|
| [Pgvector](https://github.com/pgvector/pgvector) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
|
@@ -56,11 +57,12 @@ client = Langchain::Vectorsearch::Weaviate.new(
|
|
56
57
|
)
|
57
58
|
|
58
59
|
# You can instantiate any other supported vector search database:
|
60
|
+
client = Langchain::Vectorsearch::Chroma.new(...) # `gem "chroma-db", "~> 0.3.0"`
|
61
|
+
client = Langchain::Vectorsearch::Hnswlib.new(...) # `gem "hnswlib", "~> 0.8.1"`
|
59
62
|
client = Langchain::Vectorsearch::Milvus.new(...) # `gem "milvus", "~> 0.9.0"`
|
60
|
-
client = Langchain::Vectorsearch::Qdrant.new(...) # `gem"qdrant-ruby", "~> 0.9.0"`
|
61
63
|
client = Langchain::Vectorsearch::Pinecone.new(...) # `gem "pinecone", "~> 0.1.6"`
|
62
|
-
client = Langchain::Vectorsearch::Chroma.new(...) # `gem "chroma-db", "~> 0.3.0"`
|
63
64
|
client = Langchain::Vectorsearch::Pgvector.new(...) # `gem "pgvector", "~> 0.2"`
|
65
|
+
client = Langchain::Vectorsearch::Qdrant.new(...) # `gem "qdrant-ruby", "~> 0.9.0"`
|
64
66
|
```
|
65
67
|
|
66
68
|
```ruby
|
@@ -307,6 +309,7 @@ agent.run(question: "How many users have a name with length greater than 5 in th
|
|
307
309
|
| "database" | Useful for querying a SQL database | | `gem "sequel", "~> 5.68.0"` |
|
308
310
|
| "ruby_code_interpreter" | Interprets Ruby expressions | | `gem "safe_ruby", "~> 1.0.4"` |
|
309
311
|
| "search" | A wrapper around Google Search | `ENV["SERPAPI_API_KEY"]` (https://serpapi.com/manage-api-key) | `gem "google_search_results", "~> 2.0.0"` |
|
312
|
+
| "weather" | Calls Open Weather API to retrieve the current weather | `ENV["OPEN_WEATHER_API_KEY"]` (https://home.openweathermap.org/api_keys) | `gem "open-weather-ruby-client", "~> 0.3.0"` |
|
310
313
|
| "wikipedia" | Calls Wikipedia API to retrieve the summary | | `gem "wikipedia-client", "~> 1.17.0"` |
|
311
314
|
|
312
315
|
#### Loaders 🚚
|
@@ -101,7 +101,7 @@ module Langchain::Agent
|
|
101
101
|
tool_names: "[#{tool_list.join(", ")}]",
|
102
102
|
tools: tools.map do |tool|
|
103
103
|
tool_name = tool.tool_name
|
104
|
-
tool_description = tool.
|
104
|
+
tool_description = tool.tool_description
|
105
105
|
"#{tool_name}: #{tool_description}"
|
106
106
|
end.join("\n")
|
107
107
|
)
|
@@ -0,0 +1,66 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Langchain
|
4
|
+
#
|
5
|
+
# A high-level API for running a conversation with an LLM.
|
6
|
+
# Currently supports: OpenAI and Google PaLM LLMs.
|
7
|
+
#
|
8
|
+
# Usage:
|
9
|
+
# llm = Langchain::LLM::OpenAI.new(api_key: "YOUR_API_KEY")
|
10
|
+
# chat = Langchain::Conversation.new(llm: llm)
|
11
|
+
# chat.set_context("You are a chatbot from the future")
|
12
|
+
# chat.message("Tell me about future technologies")
|
13
|
+
#
|
14
|
+
class Conversation
|
15
|
+
attr_reader :context, :examples, :messages
|
16
|
+
|
17
|
+
# Initialize Conversation with an LLM
|
18
|
+
#
|
19
|
+
# @param llm [Object] The LLM to use for the conversation
|
20
|
+
# @param options [Hash] Options to pass to the LLM, like temperature, top_k, etc.
|
21
|
+
# @return [Langchain::Conversation] The Langchain::Conversation instance
|
22
|
+
def initialize(llm:, **options)
|
23
|
+
@llm = llm
|
24
|
+
@options = options
|
25
|
+
@context = nil
|
26
|
+
@examples = []
|
27
|
+
@messages = []
|
28
|
+
end
|
29
|
+
|
30
|
+
# Set the context of the conversation. Usually used to set the model's persona.
|
31
|
+
# @param message [String] The context of the conversation
|
32
|
+
def set_context(message)
|
33
|
+
@context = message
|
34
|
+
end
|
35
|
+
|
36
|
+
# Add examples to the conversation. Used to give the model a sense of the conversation.
|
37
|
+
# @param examples [Array<Hash>] The examples to add to the conversation
|
38
|
+
def add_examples(examples)
|
39
|
+
@examples.concat examples
|
40
|
+
end
|
41
|
+
|
42
|
+
# Message the model with a prompt and return the response.
|
43
|
+
# @param message [String] The prompt to message the model with
|
44
|
+
# @return [String] The response from the model
|
45
|
+
def message(message)
|
46
|
+
append_user_message(message)
|
47
|
+
response = llm_response(message)
|
48
|
+
append_ai_message(response)
|
49
|
+
response
|
50
|
+
end
|
51
|
+
|
52
|
+
private
|
53
|
+
|
54
|
+
def llm_response(prompt)
|
55
|
+
@llm.chat(messages: @messages, context: @context, examples: @examples, **@options)
|
56
|
+
end
|
57
|
+
|
58
|
+
def append_ai_message(message)
|
59
|
+
@messages << {role: "ai", content: message}
|
60
|
+
end
|
61
|
+
|
62
|
+
def append_user_message(message)
|
63
|
+
@messages << {role: "user", content: message}
|
64
|
+
end
|
65
|
+
end
|
66
|
+
end
|
@@ -80,28 +80,31 @@ module Langchain::LLM
|
|
80
80
|
# @param params extra parameters passed to GooglePalmAPI::Client#generate_chat_message
|
81
81
|
# @return [String] The chat completion
|
82
82
|
#
|
83
|
-
def chat(prompt: "", messages: [], **
|
83
|
+
def chat(prompt: "", messages: [], context: "", examples: [], **options)
|
84
84
|
raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
|
85
85
|
|
86
|
-
messages << {author: "0", content: prompt} if !prompt.empty?
|
87
|
-
|
88
|
-
# TODO: Figure out how to introduce persisted conversations
|
89
86
|
default_params = {
|
90
87
|
temperature: DEFAULTS[:temperature],
|
91
|
-
|
88
|
+
context: context,
|
89
|
+
messages: compose_chat_messages(prompt: prompt, messages: messages),
|
90
|
+
examples: compose_examples(examples)
|
92
91
|
}
|
93
92
|
|
94
|
-
|
95
|
-
|
93
|
+
Langchain::Utils::TokenLength::GooglePalmValidator.validate_max_tokens!(self, default_params[:messages], "chat-bison-001")
|
94
|
+
|
95
|
+
if options[:stop_sequences]
|
96
|
+
default_params[:stop] = options.delete(:stop_sequences)
|
96
97
|
end
|
97
98
|
|
98
|
-
if
|
99
|
-
default_params[:max_output_tokens] =
|
99
|
+
if options[:max_tokens]
|
100
|
+
default_params[:max_output_tokens] = options.delete(:max_tokens)
|
100
101
|
end
|
101
102
|
|
102
|
-
default_params.merge!(
|
103
|
+
default_params.merge!(options)
|
103
104
|
|
104
105
|
response = client.generate_chat_message(**default_params)
|
106
|
+
raise "GooglePalm API returned an error: #{response}" if response.dig("error")
|
107
|
+
|
105
108
|
response.dig("candidates", 0, "content")
|
106
109
|
end
|
107
110
|
|
@@ -124,5 +127,39 @@ module Langchain::LLM
|
|
124
127
|
max_tokens: 2048
|
125
128
|
)
|
126
129
|
end
|
130
|
+
|
131
|
+
private
|
132
|
+
|
133
|
+
def compose_chat_messages(prompt:, messages:)
|
134
|
+
history = []
|
135
|
+
history.concat transform_messages(messages) unless messages.empty?
|
136
|
+
|
137
|
+
unless prompt.empty?
|
138
|
+
if history.last && history.last[:role] == "user"
|
139
|
+
history.last[:content] += "\n#{prompt}"
|
140
|
+
else
|
141
|
+
history.append({author: "user", content: prompt})
|
142
|
+
end
|
143
|
+
end
|
144
|
+
history
|
145
|
+
end
|
146
|
+
|
147
|
+
def compose_examples(examples)
|
148
|
+
examples.each_slice(2).map do |example|
|
149
|
+
{
|
150
|
+
input: {content: example.first[:content]},
|
151
|
+
output: {content: example.last[:content]}
|
152
|
+
}
|
153
|
+
end
|
154
|
+
end
|
155
|
+
|
156
|
+
def transform_messages(messages)
|
157
|
+
messages.map do |message|
|
158
|
+
{
|
159
|
+
author: message[:role],
|
160
|
+
content: message[:content]
|
161
|
+
}
|
162
|
+
end
|
163
|
+
end
|
127
164
|
end
|
128
165
|
end
|
data/lib/langchain/llm/openai.rb
CHANGED
@@ -35,7 +35,7 @@ module Langchain::LLM
|
|
35
35
|
def embed(text:, **params)
|
36
36
|
parameters = {model: DEFAULTS[:embeddings_model_name], input: text}
|
37
37
|
|
38
|
-
Langchain::Utils::
|
38
|
+
Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(text, parameters[:model])
|
39
39
|
|
40
40
|
response = client.embeddings(parameters: parameters.merge(params))
|
41
41
|
response.dig("data").first.dig("embedding")
|
@@ -52,7 +52,7 @@ module Langchain::LLM
|
|
52
52
|
parameters = compose_parameters DEFAULTS[:completion_model_name], params
|
53
53
|
|
54
54
|
parameters[:prompt] = prompt
|
55
|
-
parameters[:max_tokens] = Langchain::Utils::
|
55
|
+
parameters[:max_tokens] = Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(prompt, parameters[:model])
|
56
56
|
|
57
57
|
response = client.completions(parameters: parameters)
|
58
58
|
response.dig("choices", 0, "text")
|
@@ -63,19 +63,22 @@ module Langchain::LLM
|
|
63
63
|
#
|
64
64
|
# @param prompt [String] The prompt to generate a chat completion for
|
65
65
|
# @param messages [Array] The messages that have been sent in the conversation
|
66
|
-
# @param
|
66
|
+
# @param context [String] The context of the conversation
|
67
|
+
# @param examples [Array] Example messages to provide the model with
|
68
|
+
# @param options extra parameters passed to OpenAI::Client#chat
|
67
69
|
# @return [String] The chat completion
|
68
70
|
#
|
69
|
-
def chat(prompt: "", messages: [], **
|
71
|
+
def chat(prompt: "", messages: [], context: "", examples: [], **options)
|
70
72
|
raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
|
71
73
|
|
72
|
-
|
73
|
-
|
74
|
-
parameters =
|
75
|
-
parameters[:messages] = messages
|
76
|
-
parameters[:max_tokens] = validate_max_tokens(messages, parameters[:model])
|
74
|
+
parameters = compose_parameters DEFAULTS[:chat_completion_model_name], options
|
75
|
+
parameters[:messages] = compose_chat_messages(prompt: prompt, messages: messages, context: context, examples: examples)
|
76
|
+
parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
|
77
77
|
|
78
78
|
response = client.chat(parameters: parameters)
|
79
|
+
|
80
|
+
raise "Chat completion failed: #{response}" if response.dig("error")
|
81
|
+
|
79
82
|
response.dig("choices", 0, "message", "content")
|
80
83
|
end
|
81
84
|
|
@@ -104,8 +107,40 @@ module Langchain::LLM
|
|
104
107
|
default_params.merge(params)
|
105
108
|
end
|
106
109
|
|
110
|
+
def compose_chat_messages(prompt:, messages:, context:, examples:)
|
111
|
+
history = []
|
112
|
+
|
113
|
+
history.concat transform_messages(examples) unless examples.empty?
|
114
|
+
|
115
|
+
history.concat transform_messages(messages) unless messages.empty?
|
116
|
+
|
117
|
+
unless context.nil? || context.empty?
|
118
|
+
history.reject! { |message| message[:role] == "system" }
|
119
|
+
history.prepend({role: "system", content: context})
|
120
|
+
end
|
121
|
+
|
122
|
+
unless prompt.empty?
|
123
|
+
if history.last && history.last[:role] == "user"
|
124
|
+
history.last[:content] += "\n#{prompt}"
|
125
|
+
else
|
126
|
+
history.append({role: "user", content: prompt})
|
127
|
+
end
|
128
|
+
end
|
129
|
+
|
130
|
+
history
|
131
|
+
end
|
132
|
+
|
133
|
+
def transform_messages(messages)
|
134
|
+
messages.map do |message|
|
135
|
+
{
|
136
|
+
content: message[:content],
|
137
|
+
role: (message[:role] == "ai") ? "assistant" : message[:role]
|
138
|
+
}
|
139
|
+
end
|
140
|
+
end
|
141
|
+
|
107
142
|
def validate_max_tokens(messages, model)
|
108
|
-
Langchain::Utils::
|
143
|
+
Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(messages, model)
|
109
144
|
end
|
110
145
|
end
|
111
146
|
end
|
data/lib/langchain/tool/base.rb
CHANGED
@@ -57,6 +57,15 @@ module Langchain::Tool
|
|
57
57
|
self.class.const_get(:NAME)
|
58
58
|
end
|
59
59
|
|
60
|
+
#
|
61
|
+
# Returns the DESCRIPTION constant of the tool
|
62
|
+
#
|
63
|
+
# @return [String] tool description
|
64
|
+
#
|
65
|
+
def tool_description
|
66
|
+
self.class.const_get(:DESCRIPTION)
|
67
|
+
end
|
68
|
+
|
60
69
|
#
|
61
70
|
# Sets the DESCRIPTION constant of the tool
|
62
71
|
#
|
@@ -0,0 +1,67 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Langchain::Tool
|
4
|
+
class Weather < Base
|
5
|
+
#
|
6
|
+
# A weather tool that gets current weather data
|
7
|
+
#
|
8
|
+
# Current weather data is free for 1000 calls per day (https://home.openweathermap.org/api_keys)
|
9
|
+
# Forecast and historical data require registration with credit card, so not supported yet.
|
10
|
+
#
|
11
|
+
# Gem requirements:
|
12
|
+
# gem "open-weather-ruby-client", "~> 0.3.0"
|
13
|
+
# api_key: https://home.openweathermap.org/api_keys
|
14
|
+
#
|
15
|
+
# Usage:
|
16
|
+
# weather = Langchain::Tool::Weather.new(api_key: "YOUR_API_KEY")
|
17
|
+
# weather.execute(input: "Boston, MA; imperial")
|
18
|
+
#
|
19
|
+
|
20
|
+
NAME = "weather"
|
21
|
+
|
22
|
+
description <<~DESC
|
23
|
+
Useful for getting current weather data
|
24
|
+
|
25
|
+
The input to this tool should be a city name followed by the units (imperial, metric, or standard)
|
26
|
+
Usage:
|
27
|
+
Action Input: St Louis, Missouri; metric
|
28
|
+
Action Input: Boston, Massachusetts; imperial
|
29
|
+
Action Input: Dubai, AE; imperial
|
30
|
+
Action Input: Kiev, Ukraine; metric
|
31
|
+
DESC
|
32
|
+
|
33
|
+
attr_reader :client, :units
|
34
|
+
|
35
|
+
#
|
36
|
+
# Initializes the Weather tool
|
37
|
+
#
|
38
|
+
# @param api_key [String] Open Weather API key
|
39
|
+
# @return [Langchain::Tool::Weather] Weather tool
|
40
|
+
#
|
41
|
+
def initialize(api_key:, units: "metric")
|
42
|
+
depends_on "open-weather-ruby-client"
|
43
|
+
require "open-weather-ruby-client"
|
44
|
+
|
45
|
+
OpenWeather::Client.configure do |config|
|
46
|
+
config.api_key = api_key
|
47
|
+
config.user_agent = "Langchainrb Ruby Client"
|
48
|
+
end
|
49
|
+
|
50
|
+
@client = OpenWeather::Client.new
|
51
|
+
end
|
52
|
+
|
53
|
+
# Returns current weather for a city
|
54
|
+
# @param input [String] semicolon-separated city and units (units optional: imperial, metric, or standard)
|
55
|
+
# @return [String] Answer
|
56
|
+
def execute(input:)
|
57
|
+
Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing for \"#{input}\"")
|
58
|
+
|
59
|
+
input_array = input.split(";")
|
60
|
+
city, units = *input_array.map(&:strip)
|
61
|
+
|
62
|
+
data = client.current_weather(city: city, units: units)
|
63
|
+
weather = data.main.map { |key, value| "#{key} #{value}" }.join(", ")
|
64
|
+
"The current weather in #{data.name} is #{weather}"
|
65
|
+
end
|
66
|
+
end
|
67
|
+
end
|
@@ -0,0 +1,69 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Langchain
|
4
|
+
module Utils
|
5
|
+
module TokenLength
|
6
|
+
#
|
7
|
+
# This class is meant to validate the length of the text passed in to Google Palm's API.
|
8
|
+
# It is used to validate the token length before the API call is made
|
9
|
+
#
|
10
|
+
class GooglePalmValidator
|
11
|
+
TOKEN_LIMITS = {
|
12
|
+
# Source:
|
13
|
+
# This data can be pulled when `list_models()` method is called: https://github.com/andreibondarev/google_palm_api#usage
|
14
|
+
|
15
|
+
# chat-bison-001 is the only model that currently supports countMessageTokens functions
|
16
|
+
"chat-bison-001" => {
|
17
|
+
"input_token_limit" => 4000, # 4096 is the limit but the countMessageTokens does not return anything higher than 4000
|
18
|
+
"output_token_limit" => 1024
|
19
|
+
}
|
20
|
+
# "text-bison-001" => {
|
21
|
+
# "input_token_limit" => 8196,
|
22
|
+
# "output_token_limit" => 1024
|
23
|
+
# },
|
24
|
+
# "embedding-gecko-001" => {
|
25
|
+
# "input_token_limit" => 1024
|
26
|
+
# }
|
27
|
+
}.freeze
|
28
|
+
|
29
|
+
#
|
30
|
+
# Validate the context length of the text
|
31
|
+
#
|
32
|
+
# @param content [String | Array<String>] The text or array of texts to validate
|
33
|
+
# @param model_name [String] The model name to validate against
|
34
|
+
# @return [Integer] The number of input tokens left over once the content has been counted
|
35
|
+
# @raise [TokenLimitExceeded] If the text is too long
|
36
|
+
#
|
37
|
+
def self.validate_max_tokens!(google_palm_llm, content, model_name)
|
38
|
+
text_token_length = if content.is_a?(Array)
|
39
|
+
content.sum { |item| token_length(google_palm_llm, item.to_json, model_name) }
|
40
|
+
else
|
41
|
+
token_length(google_palm_llm, content, model_name)
|
42
|
+
end
|
43
|
+
|
44
|
+
leftover_tokens = TOKEN_LIMITS.dig(model_name, "input_token_limit") - text_token_length
|
45
|
+
|
46
|
+
# Raise an error even if whole prompt is equal to the model's token limit (leftover_tokens == 0)
|
47
|
+
if leftover_tokens <= 0
|
48
|
+
raise TokenLimitExceeded, "This model's maximum context length is #{TOKEN_LIMITS.dig(model_name, "input_token_limit")} tokens, but the given text is #{text_token_length} tokens long."
|
49
|
+
end
|
50
|
+
|
51
|
+
leftover_tokens
|
52
|
+
end
|
53
|
+
|
54
|
+
#
|
55
|
+
# Calculate token length for a given text and model name
|
56
|
+
#
|
57
|
+
# @param llm [Langchain::LLM::GooglePalm] The Langchain::LLM::GooglePalm instance
|
58
|
+
# @param text [String] The text to calculate the token length for
|
59
|
+
# @param model_name [String] The model name to validate against
|
60
|
+
# @return [Integer] The token length of the text
|
61
|
+
#
|
62
|
+
def self.token_length(llm, text, model_name = "chat-bison-001")
|
63
|
+
response = llm.client.count_message_tokens(model: model_name, prompt: text)
|
64
|
+
response.dig("tokenCount")
|
65
|
+
end
|
66
|
+
end
|
67
|
+
end
|
68
|
+
end
|
69
|
+
end
|
@@ -0,0 +1,75 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "tiktoken_ruby"
|
4
|
+
|
5
|
+
module Langchain
|
6
|
+
module Utils
|
7
|
+
module TokenLength
|
8
|
+
#
|
9
|
+
# This class is meant to validate the length of the text passed in to OpenAI's API.
|
10
|
+
# It is used to validate the token length before the API call is made
|
11
|
+
#
|
12
|
+
class OpenAIValidator
|
13
|
+
TOKEN_LIMITS = {
|
14
|
+
# Source:
|
15
|
+
# https://platform.openai.com/docs/api-reference/embeddings
|
16
|
+
# https://platform.openai.com/docs/models/gpt-4
|
17
|
+
"text-embedding-ada-002" => 8191,
|
18
|
+
"gpt-3.5-turbo" => 4096,
|
19
|
+
"gpt-3.5-turbo-0301" => 4096,
|
20
|
+
"text-davinci-003" => 4097,
|
21
|
+
"text-davinci-002" => 4097,
|
22
|
+
"code-davinci-002" => 8001,
|
23
|
+
"gpt-4" => 8192,
|
24
|
+
"gpt-4-0314" => 8192,
|
25
|
+
"gpt-4-32k" => 32768,
|
26
|
+
"gpt-4-32k-0314" => 32768,
|
27
|
+
"text-curie-001" => 2049,
|
28
|
+
"text-babbage-001" => 2049,
|
29
|
+
"text-ada-001" => 2049,
|
30
|
+
"davinci" => 2049,
|
31
|
+
"curie" => 2049,
|
32
|
+
"babbage" => 2049,
|
33
|
+
"ada" => 2049
|
34
|
+
}.freeze
|
35
|
+
|
36
|
+
#
|
37
|
+
# Calculate the `max_tokens:` parameter to be set by calculating the context length of the text minus the prompt length
|
38
|
+
#
|
39
|
+
# @param content [String | Array<String>] The text or array of texts to validate
|
40
|
+
# @param model_name [String] The model name to validate against
|
41
|
+
# @return [Integer] The number of tokens left for the completion (model token limit minus content length)
|
42
|
+
# @raise [TokenLimitExceeded] If the text is too long
|
43
|
+
#
|
44
|
+
def self.validate_max_tokens!(content, model_name)
|
45
|
+
text_token_length = if content.is_a?(Array)
|
46
|
+
content.sum { |item| token_length(item.to_json, model_name) }
|
47
|
+
else
|
48
|
+
token_length(content, model_name)
|
49
|
+
end
|
50
|
+
|
51
|
+
max_tokens = TOKEN_LIMITS[model_name] - text_token_length
|
52
|
+
|
53
|
+
# Raise an error even if whole prompt is equal to the model's token limit (max_tokens == 0) since no response will be returned
|
54
|
+
if max_tokens <= 0
|
55
|
+
raise TokenLimitExceeded, "This model's maximum context length is #{TOKEN_LIMITS[model_name]} tokens, but the given text is #{text_token_length} tokens long."
|
56
|
+
end
|
57
|
+
|
58
|
+
max_tokens
|
59
|
+
end
|
60
|
+
|
61
|
+
#
|
62
|
+
# Calculate token length for a given text and model name
|
63
|
+
#
|
64
|
+
# @param text [String] The text to calculate the token length for
|
65
|
+
# @param model_name [String] The model name to validate against
|
66
|
+
# @return [Integer] The token length of the text
|
67
|
+
#
|
68
|
+
def self.token_length(text, model_name)
|
69
|
+
encoder = Tiktoken.encoding_for_model(model_name)
|
70
|
+
encoder.encode(text).length
|
71
|
+
end
|
72
|
+
end
|
73
|
+
end
|
74
|
+
end
|
75
|
+
end
|
@@ -0,0 +1,122 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Langchain::Vectorsearch
|
4
|
+
class Hnswlib < Base
|
5
|
+
#
|
6
|
+
# Wrapper around HNSW (Hierarchical Navigable Small World) library.
|
7
|
+
# HNSWLib is an in-memory vectorstore that can be saved to a file on disk.
|
8
|
+
#
|
9
|
+
# Gem requirements:
|
10
|
+
# gem "hnswlib", "~> 0.8.1"
|
11
|
+
#
|
12
|
+
# Usage:
|
13
|
+
# hnsw = Langchain::Vectorsearch::Hnswlib.new(llm:, path_to_index:)
|
14
|
+
#
|
15
|
+
|
16
|
+
attr_reader :client, :path_to_index
|
17
|
+
|
18
|
+
#
|
19
|
+
# Initialize the HNSW vector search
|
20
|
+
#
|
21
|
+
# @param llm [Object] The LLM client to use
|
22
|
+
# @param path_to_index [String] The local path to the index file, e.g.: "/storage/index.ann"
|
23
|
+
# @return [Langchain::Vectorsearch::Hnswlib] Class instance
|
24
|
+
#
|
25
|
+
def initialize(llm:, path_to_index:)
|
26
|
+
depends_on "hnswlib"
|
27
|
+
require "hnswlib"
|
28
|
+
|
29
|
+
super(llm: llm)
|
30
|
+
|
31
|
+
@client = ::Hnswlib::HierarchicalNSW.new(space: DEFAULT_METRIC, dim: llm.default_dimension)
|
32
|
+
@path_to_index = path_to_index
|
33
|
+
|
34
|
+
initialize_index
|
35
|
+
end
|
36
|
+
|
37
|
+
#
|
38
|
+
# Add a list of texts and corresponding IDs to the index
|
39
|
+
#
|
40
|
+
# @param texts [Array] The list of texts to add
|
41
|
+
# @param ids [Array] The list of corresponding IDs (integers) to the texts
|
42
|
+
# @return [Boolean] The response from the HNSW library
|
43
|
+
#
|
44
|
+
def add_texts(texts:, ids:)
|
45
|
+
resize_index(texts.size)
|
46
|
+
|
47
|
+
Array(texts).each_with_index do |text, i|
|
48
|
+
embedding = llm.embed(text: text)
|
49
|
+
|
50
|
+
client.add_point(embedding, ids[i])
|
51
|
+
end
|
52
|
+
|
53
|
+
client.save_index(path_to_index)
|
54
|
+
end
|
55
|
+
|
56
|
+
#
|
57
|
+
# Search for similar texts
|
58
|
+
#
|
59
|
+
# @param query [String] The text to search for
|
60
|
+
# @param k [Integer] The number of results to return
|
61
|
+
# @return [Array] Results in the format `[[id1, distance1], [id2, distance2]]`
|
62
|
+
#
|
63
|
+
def similarity_search(
|
64
|
+
query:,
|
65
|
+
k: 4
|
66
|
+
)
|
67
|
+
embedding = llm.embed(text: query)
|
68
|
+
|
69
|
+
similarity_search_by_vector(
|
70
|
+
embedding: embedding,
|
71
|
+
k: k
|
72
|
+
)
|
73
|
+
end
|
74
|
+
|
75
|
+
#
|
76
|
+
# Search for the K nearest neighbors of a given vector
|
77
|
+
#
|
78
|
+
# @param embedding [Array] The embedding to search for
|
79
|
+
# @param k [Integer] The number of results to return
|
80
|
+
# @return [Array] Results in the format `[[id1, distance1], [id2, distance2]]`
|
81
|
+
#
|
82
|
+
def similarity_search_by_vector(
|
83
|
+
embedding:,
|
84
|
+
k: 4
|
85
|
+
)
|
86
|
+
client.search_knn(embedding, k)
|
87
|
+
end
|
88
|
+
|
89
|
+
private
|
90
|
+
|
91
|
+
#
|
92
|
+
# Optionally resizes the index if there's no space for new data
|
93
|
+
#
|
94
|
+
# @param num_of_elements_to_add [Integer] The number of elements to add to the index
|
95
|
+
#
|
96
|
+
def resize_index(num_of_elements_to_add)
|
97
|
+
current_count = client.current_count
|
98
|
+
|
99
|
+
if (current_count + num_of_elements_to_add) > client.max_elements
|
100
|
+
new_size = current_count + num_of_elements_to_add
|
101
|
+
|
102
|
+
client.resize_index(new_size)
|
103
|
+
end
|
104
|
+
end
|
105
|
+
|
106
|
+
#
|
107
|
+
# Loads or initializes the new index
|
108
|
+
#
|
109
|
+
def initialize_index
|
110
|
+
if File.exist?(path_to_index)
|
111
|
+
client.load_index(path_to_index)
|
112
|
+
|
113
|
+
Langchain.logger.info("[#{self.class.name}]".blue + ": Successfully loaded the index at \"#{path_to_index}\"")
|
114
|
+
else
|
115
|
+
# Default max_elements: 100, but we constantly resize the index as new data is written to it
|
116
|
+
client.init_index(max_elements: 100)
|
117
|
+
|
118
|
+
Langchain.logger.info("[#{self.class.name}]".blue + ": Creating a new index at \"#{path_to_index}\"")
|
119
|
+
end
|
120
|
+
end
|
121
|
+
end
|
122
|
+
end
|
data/lib/langchain/version.rb
CHANGED
data/lib/langchain.rb
CHANGED
@@ -62,6 +62,7 @@ module Langchain
|
|
62
62
|
|
63
63
|
autoload :Loader, "langchain/loader"
|
64
64
|
autoload :Data, "langchain/data"
|
65
|
+
autoload :Conversation, "langchain/conversation"
|
65
66
|
autoload :DependencyHelper, "langchain/dependency_helper"
|
66
67
|
|
67
68
|
module Agent
|
@@ -75,6 +76,7 @@ module Langchain
|
|
75
76
|
autoload :Calculator, "langchain/tool/calculator"
|
76
77
|
autoload :RubyCodeInterpreter, "langchain/tool/ruby_code_interpreter"
|
77
78
|
autoload :SerpApi, "langchain/tool/serp_api"
|
79
|
+
autoload :Weather, "langchain/tool/weather"
|
78
80
|
autoload :Wikipedia, "langchain/tool/wikipedia"
|
79
81
|
autoload :Database, "langchain/tool/database"
|
80
82
|
end
|
@@ -92,12 +94,18 @@ module Langchain
|
|
92
94
|
end
|
93
95
|
|
94
96
|
module Utils
|
95
|
-
|
97
|
+
module TokenLength
|
98
|
+
class TokenLimitExceeded < StandardError; end
|
99
|
+
|
100
|
+
autoload :OpenAIValidator, "langchain/utils/token_length/openai_validator"
|
101
|
+
autoload :GooglePalmValidator, "langchain/utils/token_length/google_palm_validator"
|
102
|
+
end
|
96
103
|
end
|
97
104
|
|
98
105
|
module Vectorsearch
|
99
106
|
autoload :Base, "langchain/vectorsearch/base"
|
100
107
|
autoload :Chroma, "langchain/vectorsearch/chroma"
|
108
|
+
autoload :Hnswlib, "langchain/vectorsearch/hnswlib"
|
101
109
|
autoload :Milvus, "langchain/vectorsearch/milvus"
|
102
110
|
autoload :Pinecone, "langchain/vectorsearch/pinecone"
|
103
111
|
autoload :Pgvector, "langchain/vectorsearch/pgvector"
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: langchainrb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.5.3
|
4
|
+
version: 0.5.5
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrei Bondarev
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-06-
|
11
|
+
date: 2023-06-12 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: tiktoken_ruby
|
@@ -192,6 +192,20 @@ dependencies:
|
|
192
192
|
- - "~>"
|
193
193
|
- !ruby/object:Gem::Version
|
194
194
|
version: 2.0.0
|
195
|
+
- !ruby/object:Gem::Dependency
|
196
|
+
name: hnswlib
|
197
|
+
requirement: !ruby/object:Gem::Requirement
|
198
|
+
requirements:
|
199
|
+
- - "~>"
|
200
|
+
- !ruby/object:Gem::Version
|
201
|
+
version: 0.8.1
|
202
|
+
type: :development
|
203
|
+
prerelease: false
|
204
|
+
version_requirements: !ruby/object:Gem::Requirement
|
205
|
+
requirements:
|
206
|
+
- - "~>"
|
207
|
+
- !ruby/object:Gem::Version
|
208
|
+
version: 0.8.1
|
195
209
|
- !ruby/object:Gem::Dependency
|
196
210
|
name: hugging-face
|
197
211
|
requirement: !ruby/object:Gem::Requirement
|
@@ -234,6 +248,20 @@ dependencies:
|
|
234
248
|
- - "~>"
|
235
249
|
- !ruby/object:Gem::Version
|
236
250
|
version: '1.13'
|
251
|
+
- !ruby/object:Gem::Dependency
|
252
|
+
name: open-weather-ruby-client
|
253
|
+
requirement: !ruby/object:Gem::Requirement
|
254
|
+
requirements:
|
255
|
+
- - "~>"
|
256
|
+
- !ruby/object:Gem::Version
|
257
|
+
version: 0.3.0
|
258
|
+
type: :development
|
259
|
+
prerelease: false
|
260
|
+
version_requirements: !ruby/object:Gem::Requirement
|
261
|
+
requirements:
|
262
|
+
- - "~>"
|
263
|
+
- !ruby/object:Gem::Version
|
264
|
+
version: 0.3.0
|
237
265
|
- !ruby/object:Gem::Dependency
|
238
266
|
name: pg
|
239
267
|
requirement: !ruby/object:Gem::Requirement
|
@@ -432,6 +460,7 @@ files:
|
|
432
460
|
- lib/langchain/agent/sql_query_agent/sql_query_agent.rb
|
433
461
|
- lib/langchain/agent/sql_query_agent/sql_query_agent_answer_prompt.json
|
434
462
|
- lib/langchain/agent/sql_query_agent/sql_query_agent_sql_prompt.json
|
463
|
+
- lib/langchain/conversation.rb
|
435
464
|
- lib/langchain/data.rb
|
436
465
|
- lib/langchain/dependency_helper.rb
|
437
466
|
- lib/langchain/llm/ai21.rb
|
@@ -461,10 +490,13 @@ files:
|
|
461
490
|
- lib/langchain/tool/database.rb
|
462
491
|
- lib/langchain/tool/ruby_code_interpreter.rb
|
463
492
|
- lib/langchain/tool/serp_api.rb
|
493
|
+
- lib/langchain/tool/weather.rb
|
464
494
|
- lib/langchain/tool/wikipedia.rb
|
465
|
-
- lib/langchain/utils/token_length_validator.rb
|
495
|
+
- lib/langchain/utils/token_length/google_palm_validator.rb
|
496
|
+
- lib/langchain/utils/token_length/openai_validator.rb
|
466
497
|
- lib/langchain/vectorsearch/base.rb
|
467
498
|
- lib/langchain/vectorsearch/chroma.rb
|
499
|
+
- lib/langchain/vectorsearch/hnswlib.rb
|
468
500
|
- lib/langchain/vectorsearch/milvus.rb
|
469
501
|
- lib/langchain/vectorsearch/pgvector.rb
|
470
502
|
- lib/langchain/vectorsearch/pinecone.rb
|
@@ -480,6 +512,7 @@ metadata:
|
|
480
512
|
homepage_uri: https://rubygems.org/gems/langchainrb
|
481
513
|
source_code_uri: https://github.com/andreibondarev/langchainrb
|
482
514
|
changelog_uri: https://github.com/andreibondarev/langchainrb/CHANGELOG.md
|
515
|
+
documentation_uri: https://rubydoc.info/gems/langchainrb
|
483
516
|
post_install_message:
|
484
517
|
rdoc_options: []
|
485
518
|
require_paths:
|
@@ -1,89 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require "tiktoken_ruby"
|
4
|
-
|
5
|
-
module Langchain
|
6
|
-
module Utils
|
7
|
-
class TokenLimitExceeded < StandardError; end
|
8
|
-
|
9
|
-
class TokenLengthValidator
|
10
|
-
#
|
11
|
-
# This class is meant to validate the length of the text passed in to OpenAI's API.
|
12
|
-
# It is used to validate the token length before the API call is made
|
13
|
-
#
|
14
|
-
TOKEN_LIMITS = {
|
15
|
-
# Source:
|
16
|
-
# https://platform.openai.com/docs/api-reference/embeddings
|
17
|
-
# https://platform.openai.com/docs/models/gpt-4
|
18
|
-
"text-embedding-ada-002" => 8191,
|
19
|
-
"gpt-3.5-turbo" => 4096,
|
20
|
-
"gpt-3.5-turbo-0301" => 4096,
|
21
|
-
"text-davinci-003" => 4097,
|
22
|
-
"text-davinci-002" => 4097,
|
23
|
-
"code-davinci-002" => 8001,
|
24
|
-
"gpt-4" => 8192,
|
25
|
-
"gpt-4-0314" => 8192,
|
26
|
-
"gpt-4-32k" => 32768,
|
27
|
-
"gpt-4-32k-0314" => 32768,
|
28
|
-
"text-curie-001" => 2049,
|
29
|
-
"text-babbage-001" => 2049,
|
30
|
-
"text-ada-001" => 2049,
|
31
|
-
"davinci" => 2049,
|
32
|
-
"curie" => 2049,
|
33
|
-
"babbage" => 2049,
|
34
|
-
"ada" => 2049
|
35
|
-
}.freeze
|
36
|
-
|
37
|
-
# GOOGLE_PALM_TOKEN_LIMITS = {
|
38
|
-
# "chat-bison-001" => {
|
39
|
-
# "inputTokenLimit"=>4096,
|
40
|
-
# "outputTokenLimit"=>1024
|
41
|
-
# },
|
42
|
-
# "text-bison-001" => {
|
43
|
-
# "inputTokenLimit"=>8196,
|
44
|
-
# "outputTokenLimit"=>1024
|
45
|
-
# },
|
46
|
-
# "embedding-gecko-001" => {
|
47
|
-
# "inputTokenLimit"=>1024
|
48
|
-
# }
|
49
|
-
# }.freeze
|
50
|
-
|
51
|
-
#
|
52
|
-
# Calculate the `max_tokens:` parameter to be set by calculating the context length of the text minus the prompt length
|
53
|
-
#
|
54
|
-
# @param content [String | Array<String>] The text or array of texts to validate
|
55
|
-
# @param model_name [String] The model name to validate against
|
56
|
-
# @return [Integer] The number of tokens left for the completion (the model's token limit minus the prompt's token length)
|
57
|
-
# @raise [TokenLimitExceeded] If the text is too long
|
58
|
-
#
|
59
|
-
def self.validate_max_tokens!(content, model_name)
|
60
|
-
text_token_length = if content.is_a?(Array)
|
61
|
-
content.sum { |item| token_length(item.to_json, model_name) }
|
62
|
-
else
|
63
|
-
token_length(content, model_name)
|
64
|
-
end
|
65
|
-
|
66
|
-
max_tokens = TOKEN_LIMITS[model_name] - text_token_length
|
67
|
-
|
68
|
-
# Raise an error even if the whole prompt is equal to the model's token limit (max_tokens == 0), since no response will be returned
|
69
|
-
if max_tokens <= 0
|
70
|
-
raise TokenLimitExceeded, "This model's maximum context length is #{TOKEN_LIMITS[model_name]} tokens, but the given text is #{text_token_length} tokens long."
|
71
|
-
end
|
72
|
-
|
73
|
-
max_tokens
|
74
|
-
end
|
75
|
-
|
76
|
-
#
|
77
|
-
# Calculate token length for a given text and model name
|
78
|
-
#
|
79
|
-
# @param text [String] The text to validate
|
80
|
-
# @param model_name [String] The model name to validate against
|
81
|
-
# @return [Integer] The token length of the text
|
82
|
-
#
|
83
|
-
def self.token_length(text, model_name)
|
84
|
-
encoder = Tiktoken.encoding_for_model(model_name)
|
85
|
-
encoder.encode(text).length
|
86
|
-
end
|
87
|
-
end
|
88
|
-
end
|
89
|
-
end
|