langchainrb 0.5.3 → 0.5.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/Gemfile.lock +3 -1
- data/README.md +4 -2
- data/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent.rb +1 -1
- data/lib/langchain/chat.rb +50 -0
- data/lib/langchain/llm/google_palm.rb +47 -10
- data/lib/langchain/llm/openai.rb +45 -10
- data/lib/langchain/tool/base.rb +9 -0
- data/lib/langchain/utils/token_length/google_palm_validator.rb +69 -0
- data/lib/langchain/utils/token_length/openai_validator.rb +75 -0
- data/lib/langchain/vectorsearch/chroma.rb +1 -1
- data/lib/langchain/vectorsearch/hnswlib.rb +122 -0
- data/lib/langchain/vectorsearch/milvus.rb +1 -14
- data/lib/langchain/vectorsearch/pgvector.rb +1 -5
- data/lib/langchain/vectorsearch/pinecone.rb +1 -4
- data/lib/langchain/vectorsearch/qdrant.rb +1 -4
- data/lib/langchain/vectorsearch/weaviate.rb +1 -4
- data/lib/langchain/version.rb +1 -1
- data/lib/langchain.rb +8 -1
- metadata +20 -3
- data/lib/langchain/utils/token_length_validator.rb +0 -89
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 87647b8a7e2dc49359f3f6d655eda501dcac26ebdd14247ad6c583be8dc1a71c
|
4
|
+
data.tar.gz: fb7b4321caa4ff026439158f5ccfc2ae9e7b515a69c35cba87385f2cb367fa85
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 4f80677c43c00e6d50e0494aa79cb7648b9f3878ed8d2a5f4f2dc90e308a3639589f8457a4615821b70b44c5a43ae4f26fcf00d7548684740e4c05dbcc165bf8
|
7
|
+
data.tar.gz: 4722233dbed83d21f2dadff19a9b79a30d8fd208d6e30bd057f018060c602b7f00f0526ee0364823597806388e2a8da48e883a5e6fa77b31490199685644b4d2
|
data/CHANGELOG.md
CHANGED
@@ -1,5 +1,10 @@
|
|
1
1
|
## [Unreleased]
|
2
2
|
|
3
|
+
## [0.5.4] - 2023-06-10
|
4
|
+
- 🔍 Vectorsearch
|
5
|
+
- Introducing support for HNSWlib
|
6
|
+
- Improved and new `Langchain::Chat` interface that persists chat history in memory
|
7
|
+
|
3
8
|
## [0.5.3] - 2023-06-09
|
4
9
|
- 🗣️ LLMs
|
5
10
|
- Chat message history support for Langchain::LLM::GooglePalm and Langchain::LLM::OpenAI
|
data/Gemfile.lock
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
PATH
|
2
2
|
remote: .
|
3
3
|
specs:
|
4
|
-
langchainrb (0.5.3)
|
4
|
+
langchainrb (0.5.4)
|
5
5
|
colorize (~> 0.8.1)
|
6
6
|
tiktoken_ruby (~> 0.0.5)
|
7
7
|
|
@@ -135,6 +135,7 @@ GEM
|
|
135
135
|
activesupport (>= 3.0)
|
136
136
|
graphql
|
137
137
|
hashery (2.1.2)
|
138
|
+
hnswlib (0.8.1)
|
138
139
|
httparty (0.21.0)
|
139
140
|
mini_mime (>= 1.0.0)
|
140
141
|
multi_xml (>= 0.5.2)
|
@@ -312,6 +313,7 @@ DEPENDENCIES
|
|
312
313
|
eqn (~> 1.6.5)
|
313
314
|
google_palm_api (~> 0.1.1)
|
314
315
|
google_search_results (~> 2.0.0)
|
316
|
+
hnswlib (~> 0.8.1)
|
315
317
|
hugging-face (~> 0.3.4)
|
316
318
|
langchainrb!
|
317
319
|
milvus (~> 0.9.0)
|
data/README.md
CHANGED
@@ -34,6 +34,7 @@ require "langchain"
|
|
34
34
|
| Database | Querying | Storage | Schema Management | Backups | Rails Integration |
|
35
35
|
| -------- |:------------------:| -------:| -----------------:| -------:| -----------------:|
|
36
36
|
| [Chroma](https://trychroma.com/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
|
37
|
+
| [Hnswlib](https://github.com/nmslib/hnswlib/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
|
37
38
|
| [Milvus](https://milvus.io/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
|
38
39
|
| [Pinecone](https://www.pinecone.io/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
|
39
40
|
| [Pgvector](https://github.com/pgvector/pgvector) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
|
@@ -56,11 +57,12 @@ client = Langchain::Vectorsearch::Weaviate.new(
|
|
56
57
|
)
|
57
58
|
|
58
59
|
# You can instantiate any other supported vector search database:
|
60
|
+
client = Langchain::Vectorsearch::Chroma.new(...) # `gem "chroma-db", "~> 0.3.0"`
|
61
|
+
client = Langchain::Vectorsearch::Hnswlib.new(...) # `gem "hnswlib", "~> 0.8.1"`
|
59
62
|
client = Langchain::Vectorsearch::Milvus.new(...) # `gem "milvus", "~> 0.9.0"`
|
60
|
-
client = Langchain::Vectorsearch::Qdrant.new(...) # `gem"qdrant-ruby", "~> 0.9.0"`
|
61
63
|
client = Langchain::Vectorsearch::Pinecone.new(...) # `gem "pinecone", "~> 0.1.6"`
|
62
|
-
client = Langchain::Vectorsearch::Chroma.new(...) # `gem "chroma-db", "~> 0.3.0"`
|
63
64
|
client = Langchain::Vectorsearch::Pgvector.new(...) # `gem "pgvector", "~> 0.2"`
|
65
|
+
client = Langchain::Vectorsearch::Qdrant.new(...) # `gem "qdrant-ruby", "~> 0.9.0"`
|
64
66
|
```
|
65
67
|
|
66
68
|
```ruby
|
@@ -101,7 +101,7 @@ module Langchain::Agent
|
|
101
101
|
tool_names: "[#{tool_list.join(", ")}]",
|
102
102
|
tools: tools.map do |tool|
|
103
103
|
tool_name = tool.tool_name
|
104
|
-
tool_description = tool.
|
104
|
+
tool_description = tool.tool_description
|
105
105
|
"#{tool_name}: #{tool_description}"
|
106
106
|
end.join("\n")
|
107
107
|
)
|
@@ -0,0 +1,50 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Langchain
|
4
|
+
class Chat
|
5
|
+
attr_reader :context
|
6
|
+
|
7
|
+
def initialize(llm:, **options)
|
8
|
+
@llm = llm
|
9
|
+
@context = nil
|
10
|
+
@examples = []
|
11
|
+
@messages = []
|
12
|
+
end
|
13
|
+
|
14
|
+
# Set the context of the conversation. Usually used to set the model's persona.
|
15
|
+
# @param message [String] The context of the conversation
|
16
|
+
def set_context(message)
|
17
|
+
@context = message
|
18
|
+
end
|
19
|
+
|
20
|
+
# Add examples to the conversation. Used to give the model a sense of the conversation.
|
21
|
+
# @param examples [Array<Hash>] The examples to add to the conversation
|
22
|
+
def add_examples(examples)
|
23
|
+
@examples.concat examples
|
24
|
+
end
|
25
|
+
|
26
|
+
# Message the model with a prompt and return the response.
|
27
|
+
# @param message [String] The prompt to message the model with
|
28
|
+
# @return [String] The response from the model
|
29
|
+
def message(message)
|
30
|
+
append_user_message(message)
|
31
|
+
response = llm_response(message)
|
32
|
+
append_ai_message(response)
|
33
|
+
response
|
34
|
+
end
|
35
|
+
|
36
|
+
private
|
37
|
+
|
38
|
+
def llm_response(prompt)
|
39
|
+
@llm.chat(messages: @messages, context: @context, examples: @examples)
|
40
|
+
end
|
41
|
+
|
42
|
+
def append_ai_message(message)
|
43
|
+
@messages << {role: "ai", content: message}
|
44
|
+
end
|
45
|
+
|
46
|
+
def append_user_message(message)
|
47
|
+
@messages << {role: "user", content: message}
|
48
|
+
end
|
49
|
+
end
|
50
|
+
end
|
@@ -80,28 +80,31 @@ module Langchain::LLM
|
|
80
80
|
# @param params extra parameters passed to GooglePalmAPI::Client#generate_chat_message
|
81
81
|
# @return [String] The chat completion
|
82
82
|
#
|
83
|
-
def chat(prompt: "", messages: [], **
|
83
|
+
def chat(prompt: "", messages: [], context: "", examples: [], **options)
|
84
84
|
raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
|
85
85
|
|
86
|
-
messages << {author: "0", content: prompt} if !prompt.empty?
|
87
|
-
|
88
|
-
# TODO: Figure out how to introduce persisted conversations
|
89
86
|
default_params = {
|
90
87
|
temperature: DEFAULTS[:temperature],
|
91
|
-
|
88
|
+
context: context,
|
89
|
+
messages: compose_chat_messages(prompt: prompt, messages: messages),
|
90
|
+
examples: compose_examples(examples)
|
92
91
|
}
|
93
92
|
|
94
|
-
|
95
|
-
|
93
|
+
Langchain::Utils::TokenLength::GooglePalmValidator.validate_max_tokens!(self, default_params[:messages], "chat-bison-001")
|
94
|
+
|
95
|
+
if options[:stop_sequences]
|
96
|
+
default_params[:stop] = options.delete(:stop_sequences)
|
96
97
|
end
|
97
98
|
|
98
|
-
if
|
99
|
-
default_params[:max_output_tokens] =
|
99
|
+
if options[:max_tokens]
|
100
|
+
default_params[:max_output_tokens] = options.delete(:max_tokens)
|
100
101
|
end
|
101
102
|
|
102
|
-
default_params.merge!(
|
103
|
+
default_params.merge!(options)
|
103
104
|
|
104
105
|
response = client.generate_chat_message(**default_params)
|
106
|
+
raise "GooglePalm API returned an error: #{response}" if response.dig("error")
|
107
|
+
|
105
108
|
response.dig("candidates", 0, "content")
|
106
109
|
end
|
107
110
|
|
@@ -124,5 +127,39 @@ module Langchain::LLM
|
|
124
127
|
max_tokens: 2048
|
125
128
|
)
|
126
129
|
end
|
130
|
+
|
131
|
+
private
|
132
|
+
|
133
|
+
def compose_chat_messages(prompt:, messages:)
|
134
|
+
history = []
|
135
|
+
history.concat transform_messages(messages) unless messages.empty?
|
136
|
+
|
137
|
+
unless prompt.empty?
|
138
|
+
if history.last && history.last[:role] == "user"
|
139
|
+
history.last[:content] += "\n#{prompt}"
|
140
|
+
else
|
141
|
+
history.append({author: "user", content: prompt})
|
142
|
+
end
|
143
|
+
end
|
144
|
+
history
|
145
|
+
end
|
146
|
+
|
147
|
+
def compose_examples(examples)
|
148
|
+
examples.each_slice(2).map do |example|
|
149
|
+
{
|
150
|
+
input: {content: example.first[:content]},
|
151
|
+
output: {content: example.last[:content]}
|
152
|
+
}
|
153
|
+
end
|
154
|
+
end
|
155
|
+
|
156
|
+
def transform_messages(messages)
|
157
|
+
messages.map do |message|
|
158
|
+
{
|
159
|
+
author: message[:role],
|
160
|
+
content: message[:content]
|
161
|
+
}
|
162
|
+
end
|
163
|
+
end
|
127
164
|
end
|
128
165
|
end
|
data/lib/langchain/llm/openai.rb
CHANGED
@@ -35,7 +35,7 @@ module Langchain::LLM
|
|
35
35
|
def embed(text:, **params)
|
36
36
|
parameters = {model: DEFAULTS[:embeddings_model_name], input: text}
|
37
37
|
|
38
|
-
Langchain::Utils::
|
38
|
+
Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(text, parameters[:model])
|
39
39
|
|
40
40
|
response = client.embeddings(parameters: parameters.merge(params))
|
41
41
|
response.dig("data").first.dig("embedding")
|
@@ -52,7 +52,7 @@ module Langchain::LLM
|
|
52
52
|
parameters = compose_parameters DEFAULTS[:completion_model_name], params
|
53
53
|
|
54
54
|
parameters[:prompt] = prompt
|
55
|
-
parameters[:max_tokens] = Langchain::Utils::
|
55
|
+
parameters[:max_tokens] = Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(prompt, parameters[:model])
|
56
56
|
|
57
57
|
response = client.completions(parameters: parameters)
|
58
58
|
response.dig("choices", 0, "text")
|
@@ -63,19 +63,22 @@ module Langchain::LLM
|
|
63
63
|
#
|
64
64
|
# @param prompt [String] The prompt to generate a chat completion for
|
65
65
|
# @param messages [Array] The messages that have been sent in the conversation
|
66
|
-
# @param
|
66
|
+
# @param context [String] The context of the conversation
|
67
|
+
# @param examples [Array] Examples of messages to provide the model with
|
68
|
+
# @param options extra parameters passed to OpenAI::Client#chat
|
67
69
|
# @return [String] The chat completion
|
68
70
|
#
|
69
|
-
def chat(prompt: "", messages: [], **
|
71
|
+
def chat(prompt: "", messages: [], context: "", examples: [], **options)
|
70
72
|
raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
|
71
73
|
|
72
|
-
|
73
|
-
|
74
|
-
parameters =
|
75
|
-
parameters[:messages] = messages
|
76
|
-
parameters[:max_tokens] = validate_max_tokens(messages, parameters[:model])
|
74
|
+
parameters = compose_parameters DEFAULTS[:chat_completion_model_name], options
|
75
|
+
parameters[:messages] = compose_chat_messages(prompt: prompt, messages: messages, context: context, examples: examples)
|
76
|
+
parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
|
77
77
|
|
78
78
|
response = client.chat(parameters: parameters)
|
79
|
+
|
80
|
+
raise "Chat completion failed: #{response}" if response.dig("error")
|
81
|
+
|
79
82
|
response.dig("choices", 0, "message", "content")
|
80
83
|
end
|
81
84
|
|
@@ -104,8 +107,40 @@ module Langchain::LLM
|
|
104
107
|
default_params.merge(params)
|
105
108
|
end
|
106
109
|
|
110
|
+
def compose_chat_messages(prompt:, messages:, context:, examples:)
|
111
|
+
history = []
|
112
|
+
|
113
|
+
history.concat transform_messages(examples) unless examples.empty?
|
114
|
+
|
115
|
+
history.concat transform_messages(messages) unless messages.empty?
|
116
|
+
|
117
|
+
unless context.nil? || context.empty?
|
118
|
+
history.reject! { |message| message[:role] == "system" }
|
119
|
+
history.prepend({role: "system", content: context})
|
120
|
+
end
|
121
|
+
|
122
|
+
unless prompt.empty?
|
123
|
+
if history.last && history.last[:role] == "user"
|
124
|
+
history.last[:content] += "\n#{prompt}"
|
125
|
+
else
|
126
|
+
history.append({role: "user", content: prompt})
|
127
|
+
end
|
128
|
+
end
|
129
|
+
|
130
|
+
history
|
131
|
+
end
|
132
|
+
|
133
|
+
def transform_messages(messages)
|
134
|
+
messages.map do |message|
|
135
|
+
{
|
136
|
+
content: message[:content],
|
137
|
+
role: (message[:role] == "ai") ? "assistant" : message[:role]
|
138
|
+
}
|
139
|
+
end
|
140
|
+
end
|
141
|
+
|
107
142
|
def validate_max_tokens(messages, model)
|
108
|
-
Langchain::Utils::
|
143
|
+
Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(messages, model)
|
109
144
|
end
|
110
145
|
end
|
111
146
|
end
|
data/lib/langchain/tool/base.rb
CHANGED
@@ -57,6 +57,15 @@ module Langchain::Tool
|
|
57
57
|
self.class.const_get(:NAME)
|
58
58
|
end
|
59
59
|
|
60
|
+
#
|
61
|
+
# Returns the DESCRIPTION constant of the tool
|
62
|
+
#
|
63
|
+
# @return [String] tool description
|
64
|
+
#
|
65
|
+
def tool_description
|
66
|
+
self.class.const_get(:DESCRIPTION)
|
67
|
+
end
|
68
|
+
|
60
69
|
#
|
61
70
|
# Sets the DESCRIPTION constant of the tool
|
62
71
|
#
|
@@ -0,0 +1,69 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Langchain
|
4
|
+
module Utils
|
5
|
+
module TokenLength
|
6
|
+
#
|
7
|
+
# This class is meant to validate the length of the text passed in to Google Palm's API.
|
8
|
+
# It is used to validate the token length before the API call is made
|
9
|
+
#
|
10
|
+
class GooglePalmValidator
|
11
|
+
TOKEN_LIMITS = {
|
12
|
+
# Source:
|
13
|
+
# This data can be pulled when `list_models()` method is called: https://github.com/andreibondarev/google_palm_api#usage
|
14
|
+
|
15
|
+
# chat-bison-001 is the only model that currently supports countMessageTokens functions
|
16
|
+
"chat-bison-001" => {
|
17
|
+
"input_token_limit" => 4000, # 4096 is the limit but the countMessageTokens does not return anything higher than 4000
|
18
|
+
"output_token_limit" => 1024
|
19
|
+
}
|
20
|
+
# "text-bison-001" => {
|
21
|
+
# "input_token_limit" => 8196,
|
22
|
+
# "output_token_limit" => 1024
|
23
|
+
# },
|
24
|
+
# "embedding-gecko-001" => {
|
25
|
+
# "input_token_limit" => 1024
|
26
|
+
# }
|
27
|
+
}.freeze
|
28
|
+
|
29
|
+
#
|
30
|
+
# Validate the context length of the text
|
31
|
+
#
|
32
|
+
# @param content [String | Array<String>] The text or array of texts to validate
|
33
|
+
# @param model_name [String] The model name to validate against
|
34
|
+
# @return [Integer] The number of tokens left over (the model's input token limit minus the token length of the content)
|
35
|
+
# @raise [TokenLimitExceeded] If the text is too long
|
36
|
+
#
|
37
|
+
def self.validate_max_tokens!(google_palm_llm, content, model_name)
|
38
|
+
text_token_length = if content.is_a?(Array)
|
39
|
+
content.sum { |item| token_length(google_palm_llm, item.to_json, model_name) }
|
40
|
+
else
|
41
|
+
token_length(google_palm_llm, content, model_name)
|
42
|
+
end
|
43
|
+
|
44
|
+
leftover_tokens = TOKEN_LIMITS.dig(model_name, "input_token_limit") - text_token_length
|
45
|
+
|
46
|
+
# Raise an error even if whole prompt is equal to the model's token limit (leftover_tokens == 0)
|
47
|
+
if leftover_tokens <= 0
|
48
|
+
raise TokenLimitExceeded, "This model's maximum context length is #{TOKEN_LIMITS.dig(model_name, "input_token_limit")} tokens, but the given text is #{text_token_length} tokens long."
|
49
|
+
end
|
50
|
+
|
51
|
+
leftover_tokens
|
52
|
+
end
|
53
|
+
|
54
|
+
#
|
55
|
+
# Calculate token length for a given text and model name
|
56
|
+
#
|
57
|
+
# @param llm [Langchain::LLM::GooglePalm] The Langchain::LLM::GooglePalm instance
|
58
|
+
# @param text [String] The text to calculate the token length for
|
59
|
+
# @param model_name [String] The model name to validate against
|
60
|
+
# @return [Integer] The token length of the text
|
61
|
+
#
|
62
|
+
def self.token_length(llm, text, model_name = "chat-bison-001")
|
63
|
+
response = llm.client.count_message_tokens(model: model_name, prompt: text)
|
64
|
+
response.dig("tokenCount")
|
65
|
+
end
|
66
|
+
end
|
67
|
+
end
|
68
|
+
end
|
69
|
+
end
|
@@ -0,0 +1,75 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "tiktoken_ruby"
|
4
|
+
|
5
|
+
module Langchain
|
6
|
+
module Utils
|
7
|
+
module TokenLength
|
8
|
+
#
|
9
|
+
# This class is meant to validate the length of the text passed in to OpenAI's API.
|
10
|
+
# It is used to validate the token length before the API call is made
|
11
|
+
#
|
12
|
+
class OpenAIValidator
|
13
|
+
TOKEN_LIMITS = {
|
14
|
+
# Source:
|
15
|
+
# https://platform.openai.com/docs/api-reference/embeddings
|
16
|
+
# https://platform.openai.com/docs/models/gpt-4
|
17
|
+
"text-embedding-ada-002" => 8191,
|
18
|
+
"gpt-3.5-turbo" => 4096,
|
19
|
+
"gpt-3.5-turbo-0301" => 4096,
|
20
|
+
"text-davinci-003" => 4097,
|
21
|
+
"text-davinci-002" => 4097,
|
22
|
+
"code-davinci-002" => 8001,
|
23
|
+
"gpt-4" => 8192,
|
24
|
+
"gpt-4-0314" => 8192,
|
25
|
+
"gpt-4-32k" => 32768,
|
26
|
+
"gpt-4-32k-0314" => 32768,
|
27
|
+
"text-curie-001" => 2049,
|
28
|
+
"text-babbage-001" => 2049,
|
29
|
+
"text-ada-001" => 2049,
|
30
|
+
"davinci" => 2049,
|
31
|
+
"curie" => 2049,
|
32
|
+
"babbage" => 2049,
|
33
|
+
"ada" => 2049
|
34
|
+
}.freeze
|
35
|
+
|
36
|
+
#
|
37
|
+
# Calculate the `max_tokens:` parameter to be set by calculating the context length of the text minus the prompt length
|
38
|
+
#
|
39
|
+
# @param content [String | Array<String>] The text or array of texts to validate
|
40
|
+
# @param model_name [String] The model name to validate against
|
41
|
+
# @return [Integer] The `max_tokens:` value to use (the model's token limit minus the token length of the content)
|
42
|
+
# @raise [TokenLimitExceeded] If the text is too long
|
43
|
+
#
|
44
|
+
def self.validate_max_tokens!(content, model_name)
|
45
|
+
text_token_length = if content.is_a?(Array)
|
46
|
+
content.sum { |item| token_length(item.to_json, model_name) }
|
47
|
+
else
|
48
|
+
token_length(content, model_name)
|
49
|
+
end
|
50
|
+
|
51
|
+
max_tokens = TOKEN_LIMITS[model_name] - text_token_length
|
52
|
+
|
53
|
+
# Raise an error even if the whole prompt is equal to the model's token limit (max_tokens == 0), since no response will be returned
|
54
|
+
if max_tokens <= 0
|
55
|
+
raise TokenLimitExceeded, "This model's maximum context length is #{TOKEN_LIMITS[model_name]} tokens, but the given text is #{text_token_length} tokens long."
|
56
|
+
end
|
57
|
+
|
58
|
+
max_tokens
|
59
|
+
end
|
60
|
+
|
61
|
+
#
|
62
|
+
# Calculate token length for a given text and model name
|
63
|
+
#
|
64
|
+
# @param text [String] The text to calculate the token length for
|
65
|
+
# @param model_name [String] The model name to validate against
|
66
|
+
# @return [Integer] The token length of the text
|
67
|
+
#
|
68
|
+
def self.token_length(text, model_name)
|
69
|
+
encoder = Tiktoken.encoding_for_model(model_name)
|
70
|
+
encoder.encode(text).length
|
71
|
+
end
|
72
|
+
end
|
73
|
+
end
|
74
|
+
end
|
75
|
+
end
|
@@ -8,7 +8,7 @@ module Langchain::Vectorsearch
|
|
8
8
|
# Gem requirements: gem "chroma-db", "~> 0.3.0"
|
9
9
|
#
|
10
10
|
# Usage:
|
11
|
-
# chroma = Langchain::Vectorsearch::Chroma.new(url:, index_name:, llm:, api_key: nil)
|
11
|
+
# chroma = Langchain::Vectorsearch::Chroma.new(url:, index_name:, llm:, llm_api_key:, api_key: nil)
|
12
12
|
#
|
13
13
|
|
14
14
|
# Initialize the Chroma client
|
@@ -0,0 +1,122 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Langchain::Vectorsearch
|
4
|
+
class Hnswlib < Base
|
5
|
+
#
|
6
|
+
# Wrapper around HNSW (Hierarchical Navigable Small World) library.
|
7
|
+
# HNSWLib is an in-memory vectorstore that can be saved to a file on disk.
|
8
|
+
#
|
9
|
+
# Gem requirements:
|
10
|
+
# gem "hnswlib", "~> 0.8.1"
|
11
|
+
#
|
12
|
+
# Usage:
|
13
|
+
# hnsw = Langchain::Vectorsearch::Hnswlib.new(llm:, path_to_index:)
|
14
|
+
#
|
15
|
+
|
16
|
+
attr_reader :client, :path_to_index
|
17
|
+
|
18
|
+
#
|
19
|
+
# Initialize the HNSW vector search
|
20
|
+
#
|
21
|
+
# @param llm [Object] The LLM client to use
|
22
|
+
# @param path_to_index [String] The local path to the index file, e.g.: "/storage/index.ann"
|
23
|
+
# @return [Langchain::Vectorsearch::Hnswlib] Class instance
|
24
|
+
#
|
25
|
+
def initialize(llm:, path_to_index:)
|
26
|
+
depends_on "hnswlib"
|
27
|
+
require "hnswlib"
|
28
|
+
|
29
|
+
super(llm: llm)
|
30
|
+
|
31
|
+
@client = ::Hnswlib::HierarchicalNSW.new(space: DEFAULT_METRIC, dim: llm.default_dimension)
|
32
|
+
@path_to_index = path_to_index
|
33
|
+
|
34
|
+
initialize_index
|
35
|
+
end
|
36
|
+
|
37
|
+
#
|
38
|
+
# Add a list of texts and corresponding IDs to the index
|
39
|
+
#
|
40
|
+
# @param texts [Array] The list of texts to add
|
41
|
+
# @param ids [Array] The list of corresponding IDs (integers) to the texts
|
42
|
+
# @return [Boolean] The response from the HNSW library
|
43
|
+
#
|
44
|
+
def add_texts(texts:, ids:)
|
45
|
+
resize_index(texts.size)
|
46
|
+
|
47
|
+
Array(texts).each_with_index do |text, i|
|
48
|
+
embedding = llm.embed(text: text)
|
49
|
+
|
50
|
+
client.add_point(embedding, ids[i])
|
51
|
+
end
|
52
|
+
|
53
|
+
client.save_index(path_to_index)
|
54
|
+
end
|
55
|
+
|
56
|
+
#
|
57
|
+
# Search for similar texts
|
58
|
+
#
|
59
|
+
# @param query [String] The text to search for
|
60
|
+
# @param k [Integer] The number of results to return
|
61
|
+
# @return [Array] Results in the format `[[id1, distance1], [id2, distance2]]`
|
62
|
+
#
|
63
|
+
def similarity_search(
|
64
|
+
query:,
|
65
|
+
k: 4
|
66
|
+
)
|
67
|
+
embedding = llm.embed(text: query)
|
68
|
+
|
69
|
+
similarity_search_by_vector(
|
70
|
+
embedding: embedding,
|
71
|
+
k: k
|
72
|
+
)
|
73
|
+
end
|
74
|
+
|
75
|
+
#
|
76
|
+
# Search for the K nearest neighbors of a given vector
|
77
|
+
#
|
78
|
+
# @param embedding [Array] The embedding to search for
|
79
|
+
# @param k [Integer] The number of results to return
|
80
|
+
# @return [Array] Results in the format `[[id1, distance1], [id2, distance2]]`
|
81
|
+
#
|
82
|
+
def similarity_search_by_vector(
|
83
|
+
embedding:,
|
84
|
+
k: 4
|
85
|
+
)
|
86
|
+
client.search_knn(embedding, k)
|
87
|
+
end
|
88
|
+
|
89
|
+
private
|
90
|
+
|
91
|
+
#
|
92
|
+
# Optionally resizes the index if there's no space for new data
|
93
|
+
#
|
94
|
+
# @param num_of_elements_to_add [Integer] The number of elements to add to the index
|
95
|
+
#
|
96
|
+
def resize_index(num_of_elements_to_add)
|
97
|
+
current_count = client.current_count
|
98
|
+
|
99
|
+
if (current_count + num_of_elements_to_add) > client.max_elements
|
100
|
+
new_size = current_count + num_of_elements_to_add
|
101
|
+
|
102
|
+
client.resize_index(new_size)
|
103
|
+
end
|
104
|
+
end
|
105
|
+
|
106
|
+
#
|
107
|
+
# Loads or initializes the new index
|
108
|
+
#
|
109
|
+
def initialize_index
|
110
|
+
if File.exist?(path_to_index)
|
111
|
+
client.load_index(path_to_index)
|
112
|
+
|
113
|
+
Langchain.logger.info("[#{self.class.name}]".blue + ": Successfully loaded the index at \"#{path_to_index}\"")
|
114
|
+
else
|
115
|
+
# Default max_elements: 100, but we constantly resize the index as new data is written to it
|
116
|
+
client.init_index(max_elements: 100)
|
117
|
+
|
118
|
+
Langchain.logger.info("[#{self.class.name}]".blue + ": Creating a new index at \"#{path_to_index}\"")
|
119
|
+
end
|
120
|
+
end
|
121
|
+
end
|
122
|
+
end
|
@@ -8,17 +8,9 @@ module Langchain::Vectorsearch
|
|
8
8
|
# Gem requirements: gem "milvus", "~> 0.9.0"
|
9
9
|
#
|
10
10
|
# Usage:
|
11
|
-
# milvus = Langchain::Vectorsearch::Milvus.new(url:, index_name:, llm:)
|
11
|
+
# milvus = Langchain::Vectorsearch::Milvus.new(url:, index_name:, llm:, llm_api_key:)
|
12
12
|
#
|
13
13
|
|
14
|
-
#
|
15
|
-
# Initialize the Milvus client
|
16
|
-
#
|
17
|
-
# @param url [String] The URL of the Milvus server
|
18
|
-
# @param api_key [String] The API key to use
|
19
|
-
# @param index_name [String] The name of the index to use
|
20
|
-
# @param llm [Object] The LLM client to use
|
21
|
-
#
|
22
14
|
def initialize(url:, index_name:, llm:, api_key: nil)
|
23
15
|
depends_on "milvus"
|
24
16
|
require "milvus"
|
@@ -29,11 +21,6 @@ module Langchain::Vectorsearch
|
|
29
21
|
super(llm: llm)
|
30
22
|
end
|
31
23
|
|
32
|
-
#
|
33
|
-
# Add a list of texts to the index
|
34
|
-
#
|
35
|
-
# @param texts [Array] The list of texts to add
|
36
|
-
#
|
37
24
|
def add_texts(texts:)
|
38
25
|
client.entities.insert(
|
39
26
|
collection_name: index_name,
|
@@ -8,7 +8,7 @@ module Langchain::Vectorsearch
|
|
8
8
|
# Gem requirements: gem "pgvector", "~> 0.2"
|
9
9
|
#
|
10
10
|
# Usage:
|
11
|
-
# pgvector = Langchain::Vectorsearch::Pgvector.new(url:, index_name:, llm:)
|
11
|
+
# pgvector = Langchain::Vectorsearch::Pgvector.new(url:, index_name:, llm:, llm_api_key:)
|
12
12
|
#
|
13
13
|
|
14
14
|
# The operators supported by the PostgreSQL vector search adapter
|
@@ -20,14 +20,10 @@ module Langchain::Vectorsearch
|
|
20
20
|
|
21
21
|
attr_reader :operator, :quoted_table_name
|
22
22
|
|
23
|
-
#
|
24
|
-
# Initialize the PostgreSQL client
|
25
|
-
#
|
26
23
|
# @param url [String] The URL of the PostgreSQL database
|
27
24
|
# @param index_name [String] The name of the table to use for the index
|
28
25
|
# @param llm [Object] The LLM client to use
|
29
26
|
# @param api_key [String] The API key for the Vectorsearch DB (not used for PostgreSQL)
|
30
|
-
#
|
31
27
|
def initialize(url:, index_name:, llm:, api_key: nil)
|
32
28
|
require "pg"
|
33
29
|
require "pgvector"
|
@@ -8,17 +8,14 @@ module Langchain::Vectorsearch
|
|
8
8
|
# Gem requirements: gem "pinecone", "~> 0.1.6"
|
9
9
|
#
|
10
10
|
# Usage:
|
11
|
-
# pinecone = Langchain::Vectorsearch::Pinecone.new(environment:, api_key:, index_name:, llm:)
|
11
|
+
# pinecone = Langchain::Vectorsearch::Pinecone.new(environment:, api_key:, index_name:, llm:, llm_api_key:)
|
12
12
|
#
|
13
13
|
|
14
|
-
#
|
15
14
|
# Initialize the Pinecone client
|
16
|
-
#
|
17
15
|
# @param environment [String] The environment to use
|
18
16
|
# @param api_key [String] The API key to use
|
19
17
|
# @param index_name [String] The name of the index to use
|
20
18
|
# @param llm [Object] The LLM client to use
|
21
|
-
#
|
22
19
|
def initialize(environment:, api_key:, index_name:, llm:)
|
23
20
|
depends_on "pinecone"
|
24
21
|
require "pinecone"
|
@@ -8,17 +8,14 @@ module Langchain::Vectorsearch
|
|
8
8
|
# Gem requirements: gem "qdrant-ruby", "~> 0.9.0"
|
9
9
|
#
|
10
10
|
# Usage:
|
11
|
-
# qdrant = Langchain::Vectorsearch::Qdrant.new(url:, api_key:, index_name:, llm:)
|
11
|
+
# qdrant = Langchain::Vectorsearch::Qdrant.new(url:, api_key:, index_name:, llm:, llm_api_key:)
|
12
12
|
#
|
13
13
|
|
14
|
-
#
|
15
14
|
# Initialize the Qdrant client
|
16
|
-
#
|
17
15
|
# @param url [String] The URL of the Qdrant server
|
18
16
|
# @param api_key [String] The API key to use
|
19
17
|
# @param index_name [String] The name of the index to use
|
20
18
|
# @param llm [Object] The LLM client to use
|
21
|
-
#
|
22
19
|
def initialize(url:, api_key:, index_name:, llm:)
|
23
20
|
depends_on "qdrant-ruby"
|
24
21
|
require "qdrant"
|
@@ -8,17 +8,14 @@ module Langchain::Vectorsearch
|
|
8
8
|
# Gem requirements: gem "weaviate-ruby", "~> 0.8.0"
|
9
9
|
#
|
10
10
|
# Usage:
|
11
|
-
# weaviate = Langchain::Vectorsearch::Weaviate.new(url:, api_key:, index_name:, llm:)
|
11
|
+
# weaviate = Langchain::Vectorsearch::Weaviate.new(url:, api_key:, index_name:, llm:, llm_api_key:)
|
12
12
|
#
|
13
13
|
|
14
|
-
#
|
15
14
|
# Initialize the Weaviate adapter
|
16
|
-
#
|
17
15
|
# @param url [String] The URL of the Weaviate instance
|
18
16
|
# @param api_key [String] The API key to use
|
19
17
|
# @param index_name [String] The name of the index to use
|
20
18
|
# @param llm [Object] The LLM client to use
|
21
|
-
#
|
22
19
|
def initialize(url:, api_key:, index_name:, llm:)
|
23
20
|
depends_on "weaviate-ruby"
|
24
21
|
require "weaviate"
|
data/lib/langchain/version.rb
CHANGED
data/lib/langchain.rb
CHANGED
@@ -62,6 +62,7 @@ module Langchain
|
|
62
62
|
|
63
63
|
autoload :Loader, "langchain/loader"
|
64
64
|
autoload :Data, "langchain/data"
|
65
|
+
autoload :Chat, "langchain/chat"
|
65
66
|
autoload :DependencyHelper, "langchain/dependency_helper"
|
66
67
|
|
67
68
|
module Agent
|
@@ -92,12 +93,18 @@ module Langchain
|
|
92
93
|
end
|
93
94
|
|
94
95
|
module Utils
|
95
|
-
|
96
|
+
module TokenLength
|
97
|
+
class TokenLimitExceeded < StandardError; end
|
98
|
+
|
99
|
+
autoload :OpenAIValidator, "langchain/utils/token_length/openai_validator"
|
100
|
+
autoload :GooglePalmValidator, "langchain/utils/token_length/google_palm_validator"
|
101
|
+
end
|
96
102
|
end
|
97
103
|
|
98
104
|
module Vectorsearch
|
99
105
|
autoload :Base, "langchain/vectorsearch/base"
|
100
106
|
autoload :Chroma, "langchain/vectorsearch/chroma"
|
107
|
+
autoload :Hnswlib, "langchain/vectorsearch/hnswlib"
|
101
108
|
autoload :Milvus, "langchain/vectorsearch/milvus"
|
102
109
|
autoload :Pinecone, "langchain/vectorsearch/pinecone"
|
103
110
|
autoload :Pgvector, "langchain/vectorsearch/pgvector"
|
metadata
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: langchainrb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.5.3
|
4
|
+
version: 0.5.4
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrei Bondarev
|
@@ -192,6 +192,20 @@ dependencies:
|
|
192
192
|
- - "~>"
|
193
193
|
- !ruby/object:Gem::Version
|
194
194
|
version: 2.0.0
|
195
|
+
- !ruby/object:Gem::Dependency
|
196
|
+
name: hnswlib
|
197
|
+
requirement: !ruby/object:Gem::Requirement
|
198
|
+
requirements:
|
199
|
+
- - "~>"
|
200
|
+
- !ruby/object:Gem::Version
|
201
|
+
version: 0.8.1
|
202
|
+
type: :development
|
203
|
+
prerelease: false
|
204
|
+
version_requirements: !ruby/object:Gem::Requirement
|
205
|
+
requirements:
|
206
|
+
- - "~>"
|
207
|
+
- !ruby/object:Gem::Version
|
208
|
+
version: 0.8.1
|
195
209
|
- !ruby/object:Gem::Dependency
|
196
210
|
name: hugging-face
|
197
211
|
requirement: !ruby/object:Gem::Requirement
|
@@ -432,6 +446,7 @@ files:
|
|
432
446
|
- lib/langchain/agent/sql_query_agent/sql_query_agent.rb
|
433
447
|
- lib/langchain/agent/sql_query_agent/sql_query_agent_answer_prompt.json
|
434
448
|
- lib/langchain/agent/sql_query_agent/sql_query_agent_sql_prompt.json
|
449
|
+
- lib/langchain/chat.rb
|
435
450
|
- lib/langchain/data.rb
|
436
451
|
- lib/langchain/dependency_helper.rb
|
437
452
|
- lib/langchain/llm/ai21.rb
|
@@ -462,9 +477,11 @@ files:
|
|
462
477
|
- lib/langchain/tool/ruby_code_interpreter.rb
|
463
478
|
- lib/langchain/tool/serp_api.rb
|
464
479
|
- lib/langchain/tool/wikipedia.rb
|
465
|
-
- lib/langchain/utils/token_length_validator.rb
|
480
|
+
- lib/langchain/utils/token_length/google_palm_validator.rb
|
481
|
+
- lib/langchain/utils/token_length/openai_validator.rb
|
466
482
|
- lib/langchain/vectorsearch/base.rb
|
467
483
|
- lib/langchain/vectorsearch/chroma.rb
|
484
|
+
- lib/langchain/vectorsearch/hnswlib.rb
|
468
485
|
- lib/langchain/vectorsearch/milvus.rb
|
469
486
|
- lib/langchain/vectorsearch/pgvector.rb
|
470
487
|
- lib/langchain/vectorsearch/pinecone.rb
|
@@ -495,7 +512,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
495
512
|
- !ruby/object:Gem::Version
|
496
513
|
version: '0'
|
497
514
|
requirements: []
|
498
|
-
rubygems_version: 3.3
|
515
|
+
rubygems_version: 3.2.3
|
499
516
|
signing_key:
|
500
517
|
specification_version: 4
|
501
518
|
summary: Build LLM-backed Ruby applications with Ruby's LangChain
|
@@ -1,89 +0,0 @@
|
|
1
|
-
# frozen_string_literal: true
|
2
|
-
|
3
|
-
require "tiktoken_ruby"
|
4
|
-
|
5
|
-
module Langchain
|
6
|
-
module Utils
|
7
|
-
class TokenLimitExceeded < StandardError; end
|
8
|
-
|
9
|
-
class TokenLengthValidator
|
10
|
-
#
|
11
|
-
# This class is meant to validate the length of the text passed in to OpenAI's API.
|
12
|
-
# It is used to validate the token length before the API call is made
|
13
|
-
#
|
14
|
-
TOKEN_LIMITS = {
|
15
|
-
# Source:
|
16
|
-
# https://platform.openai.com/docs/api-reference/embeddings
|
17
|
-
# https://platform.openai.com/docs/models/gpt-4
|
18
|
-
"text-embedding-ada-002" => 8191,
|
19
|
-
"gpt-3.5-turbo" => 4096,
|
20
|
-
"gpt-3.5-turbo-0301" => 4096,
|
21
|
-
"text-davinci-003" => 4097,
|
22
|
-
"text-davinci-002" => 4097,
|
23
|
-
"code-davinci-002" => 8001,
|
24
|
-
"gpt-4" => 8192,
|
25
|
-
"gpt-4-0314" => 8192,
|
26
|
-
"gpt-4-32k" => 32768,
|
27
|
-
"gpt-4-32k-0314" => 32768,
|
28
|
-
"text-curie-001" => 2049,
|
29
|
-
"text-babbage-001" => 2049,
|
30
|
-
"text-ada-001" => 2049,
|
31
|
-
"davinci" => 2049,
|
32
|
-
"curie" => 2049,
|
33
|
-
"babbage" => 2049,
|
34
|
-
"ada" => 2049
|
35
|
-
}.freeze
|
36
|
-
|
37
|
-
# GOOGLE_PALM_TOKEN_LIMITS = {
|
38
|
-
# "chat-bison-001" => {
|
39
|
-
# "inputTokenLimit"=>4096,
|
40
|
-
# "outputTokenLimit"=>1024
|
41
|
-
# },
|
42
|
-
# "text-bison-001" => {
|
43
|
-
# "inputTokenLimit"=>8196,
|
44
|
-
# "outputTokenLimit"=>1024
|
45
|
-
# },
|
46
|
-
# "embedding-gecko-001" => {
|
47
|
-
# "inputTokenLimit"=>1024
|
48
|
-
# }
|
49
|
-
# }.freeze
|
50
|
-
|
51
|
-
#
|
52
|
-
# Calculate the `max_tokens:` parameter to be set by calculating the context length of the text minus the prompt length
|
53
|
-
#
|
54
|
-
# @param content [String | Array<String>] The text or array of texts to validate
|
55
|
-
# @param model_name [String] The model name to validate against
|
56
|
-
# @return [Integer] Whether the text is valid or not
|
57
|
-
# @raise [TokenLimitExceeded] If the text is too long
|
58
|
-
#
|
59
|
-
def self.validate_max_tokens!(content, model_name)
|
60
|
-
text_token_length = if content.is_a?(Array)
|
61
|
-
content.sum { |item| token_length(item.to_json, model_name) }
|
62
|
-
else
|
63
|
-
token_length(content, model_name)
|
64
|
-
end
|
65
|
-
|
66
|
-
max_tokens = TOKEN_LIMITS[model_name] - text_token_length
|
67
|
-
|
68
|
-
# Raise an error even if whole prompt is equal to the model's token limit (max_tokens == 0) since not response will be returned
|
69
|
-
if max_tokens <= 0
|
70
|
-
raise TokenLimitExceeded, "This model's maximum context length is #{TOKEN_LIMITS[model_name]} tokens, but the given text is #{text_token_length} tokens long."
|
71
|
-
end
|
72
|
-
|
73
|
-
max_tokens
|
74
|
-
end
|
75
|
-
|
76
|
-
#
|
77
|
-
# Calculate token length for a given text and model name
|
78
|
-
#
|
79
|
-
# @param text [String] The text to validate
|
80
|
-
# @param model_name [String] The model name to validate against
|
81
|
-
# @return [Integer] The token length of the text
|
82
|
-
#
|
83
|
-
def self.token_length(text, model_name)
|
84
|
-
encoder = Tiktoken.encoding_for_model(model_name)
|
85
|
-
encoder.encode(text).length
|
86
|
-
end
|
87
|
-
end
|
88
|
-
end
|
89
|
-
end
|