langchainrb 0.5.3 → 0.5.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/Gemfile.lock +3 -1
- data/README.md +4 -2
- data/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent.rb +1 -1
- data/lib/langchain/chat.rb +50 -0
- data/lib/langchain/llm/google_palm.rb +47 -10
- data/lib/langchain/llm/openai.rb +45 -10
- data/lib/langchain/tool/base.rb +9 -0
- data/lib/langchain/utils/token_length/google_palm_validator.rb +69 -0
- data/lib/langchain/utils/token_length/openai_validator.rb +75 -0
- data/lib/langchain/vectorsearch/chroma.rb +1 -1
- data/lib/langchain/vectorsearch/hnswlib.rb +122 -0
- data/lib/langchain/vectorsearch/milvus.rb +1 -14
- data/lib/langchain/vectorsearch/pgvector.rb +1 -5
- data/lib/langchain/vectorsearch/pinecone.rb +1 -4
- data/lib/langchain/vectorsearch/qdrant.rb +1 -4
- data/lib/langchain/vectorsearch/weaviate.rb +1 -4
- data/lib/langchain/version.rb +1 -1
- data/lib/langchain.rb +8 -1
- metadata +20 -3
- data/lib/langchain/utils/token_length_validator.rb +0 -89
checksums.yaml CHANGED

@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 87647b8a7e2dc49359f3f6d655eda501dcac26ebdd14247ad6c583be8dc1a71c
+  data.tar.gz: fb7b4321caa4ff026439158f5ccfc2ae9e7b515a69c35cba87385f2cb367fa85
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 4f80677c43c00e6d50e0494aa79cb7648b9f3878ed8d2a5f4f2dc90e308a3639589f8457a4615821b70b44c5a43ae4f26fcf00d7548684740e4c05dbcc165bf8
+  data.tar.gz: 4722233dbed83d21f2dadff19a9b79a30d8fd208d6e30bd057f018060c602b7f00f0526ee0364823597806388e2a8da48e883a5e6fa77b31490199685644b4d2
data/CHANGELOG.md CHANGED

@@ -1,5 +1,10 @@
 ## [Unreleased]
 
+## [0.5.4] - 2023-06-10
+- 🔍 Vectorsearch
+  - Introducing support for HNSWlib
+- Improved and new `Langchain::Chat` interface that persists chat history in memory
+
 ## [0.5.3] - 2023-06-09
 - 🗣️ LLMs
   - Chat message history support for Langchain::LLM::GooglePalm and Langchain::LLM::OpenAI
data/Gemfile.lock CHANGED

@@ -1,7 +1,7 @@
 PATH
   remote: .
   specs:
-    langchainrb (0.5.3)
+    langchainrb (0.5.4)
       colorize (~> 0.8.1)
       tiktoken_ruby (~> 0.0.5)
 

@@ -135,6 +135,7 @@ GEM
       activesupport (>= 3.0)
       graphql
     hashery (2.1.2)
+    hnswlib (0.8.1)
     httparty (0.21.0)
       mini_mime (>= 1.0.0)
       multi_xml (>= 0.5.2)

@@ -312,6 +313,7 @@ DEPENDENCIES
   eqn (~> 1.6.5)
   google_palm_api (~> 0.1.1)
   google_search_results (~> 2.0.0)
+  hnswlib (~> 0.8.1)
   hugging-face (~> 0.3.4)
   langchainrb!
   milvus (~> 0.9.0)
data/README.md CHANGED

@@ -34,6 +34,7 @@ require "langchain"
 | Database | Querying | Storage | Schema Management | Backups | Rails Integration |
 | -------- |:------------------:| -------:| -----------------:| -------:| -----------------:|
 | [Chroma](https://trychroma.com/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
+| [Hnswlib](https://github.com/nmslib/hnswlib/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
 | [Milvus](https://milvus.io/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
 | [Pinecone](https://www.pinecone.io/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
 | [Pgvector](https://github.com/pgvector/pgvector) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |

@@ -56,11 +57,12 @@ client = Langchain::Vectorsearch::Weaviate.new(
 )
 
 # You can instantiate any other supported vector search database:
+client = Langchain::Vectorsearch::Chroma.new(...) # `gem "chroma-db", "~> 0.3.0"`
+client = Langchain::Vectorsearch::Hnswlib.new(...) # `gem "hnswlib", "~> 0.8.1"`
 client = Langchain::Vectorsearch::Milvus.new(...) # `gem "milvus", "~> 0.9.0"`
-client = Langchain::Vectorsearch::Qdrant.new(...) # `gem"qdrant-ruby", "~> 0.9.0"`
 client = Langchain::Vectorsearch::Pinecone.new(...) # `gem "pinecone", "~> 0.1.6"`
-client = Langchain::Vectorsearch::Chroma.new(...) # `gem "chroma-db", "~> 0.3.0"`
 client = Langchain::Vectorsearch::Pgvector.new(...) # `gem "pgvector", "~> 0.2"`
+client = Langchain::Vectorsearch::Qdrant.new(...) # `gem"qdrant-ruby", "~> 0.9.0"`
 ```
 
 ```ruby
data/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent.rb CHANGED

@@ -101,7 +101,7 @@ module Langchain::Agent
         tool_names: "[#{tool_list.join(", ")}]",
         tools: tools.map do |tool|
           tool_name = tool.tool_name
-          tool_description = tool.
+          tool_description = tool.tool_description
           "#{tool_name}: #{tool_description}"
         end.join("\n")
       )
data/lib/langchain/chat.rb ADDED

@@ -0,0 +1,50 @@
+# frozen_string_literal: true
+
+module Langchain
+  class Chat
+    attr_reader :context
+
+    def initialize(llm:, **options)
+      @llm = llm
+      @context = nil
+      @examples = []
+      @messages = []
+    end
+
+    # Set the context of the conversation. Usually used to set the model's persona.
+    # @param message [String] The context of the conversation
+    def set_context(message)
+      @context = message
+    end
+
+    # Add examples to the conversation. Used to give the model a sense of the conversation.
+    # @param examples [Array<Hash>] The examples to add to the conversation
+    def add_examples(examples)
+      @examples.concat examples
+    end
+
+    # Message the model with a prompt and return the response.
+    # @param message [String] The prompt to message the model with
+    # @return [String] The response from the model
+    def message(message)
+      append_user_message(message)
+      response = llm_response(message)
+      append_ai_message(response)
+      response
+    end
+
+    private
+
+    def llm_response(prompt)
+      @llm.chat(messages: @messages, context: @context, examples: @examples)
+    end
+
+    def append_ai_message(message)
+      @messages << {role: "ai", content: message}
+    end
+
+    def append_user_message(message)
+      @messages << {role: "user", content: message}
+    end
+  end
+end
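For orientation, here is a minimal usage sketch of the new `Langchain::Chat` class (not part of the diff; the `Langchain::LLM::OpenAI` setup and the example strings are illustrative assumptions):

```ruby
# Hypothetical usage of the new Langchain::Chat interface.
llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
chat = Langchain::Chat.new(llm: llm)

# The context is forwarded to the LLM's chat call on every message.
chat.set_context("You are a terse assistant.")

# Examples are plain role/content hashes, the same shape as stored messages.
chat.add_examples([
  {role: "user", content: "What is the capital of France?"},
  {role: "ai", content: "Paris"}
])

# Each call appends the user message and the model's reply to the in-memory history.
chat.message("And of Germany?") # => "Berlin" (model-dependent)
```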
data/lib/langchain/llm/google_palm.rb CHANGED

@@ -80,28 +80,31 @@ module Langchain::LLM
     # @param params extra parameters passed to GooglePalmAPI::Client#generate_chat_message
     # @return [String] The chat completion
     #
-    def chat(prompt: "", messages: [], **
+    def chat(prompt: "", messages: [], context: "", examples: [], **options)
       raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
 
-      messages << {author: "0", content: prompt} if !prompt.empty?
-
-      # TODO: Figure out how to introduce persisted conversations
       default_params = {
         temperature: DEFAULTS[:temperature],
-
+        context: context,
+        messages: compose_chat_messages(prompt: prompt, messages: messages),
+        examples: compose_examples(examples)
       }
 
-
-
+      Langchain::Utils::TokenLength::GooglePalmValidator.validate_max_tokens!(self, default_params[:messages], "chat-bison-001")
+
+      if options[:stop_sequences]
+        default_params[:stop] = options.delete(:stop_sequences)
       end
 
-      if
-        default_params[:max_output_tokens] =
+      if options[:max_tokens]
+        default_params[:max_output_tokens] = options.delete(:max_tokens)
       end
 
-      default_params.merge!(
+      default_params.merge!(options)
 
       response = client.generate_chat_message(**default_params)
+      raise "GooglePalm API returned an error: #{response}" if response.dig("error")
+
       response.dig("candidates", 0, "content")
     end
 

@@ -124,5 +127,39 @@ module Langchain::LLM
         max_tokens: 2048
       )
     end
+
+    private
+
+    def compose_chat_messages(prompt:, messages:)
+      history = []
+      history.concat transform_messages(messages) unless messages.empty?
+
+      unless prompt.empty?
+        if history.last && history.last[:role] == "user"
+          history.last[:content] += "\n#{prompt}"
+        else
+          history.append({author: "user", content: prompt})
+        end
+      end
+      history
+    end
+
+    def compose_examples(examples)
+      examples.each_slice(2).map do |example|
+        {
+          input: {content: example.first[:content]},
+          output: {content: example.last[:content]}
+        }
+      end
+    end
+
+    def transform_messages(messages)
+      messages.map do |message|
+        {
+          author: message[:role],
+          content: message[:content]
+        }
+      end
+    end
   end
 end
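A note on the new `compose_examples`: it consumes the flat examples array two entries at a time, treating the first of each pair as the input and the second as the output, which is the shape the PaLM chat API expects. A hypothetical pair and the result it should produce (illustrative values, not from the diff):

```ruby
# Given flat role/content hashes (the shape Langchain::Chat#add_examples stores)...
examples = [
  {role: "user", content: "What is one plus one?"},
  {role: "ai", content: "Two"}
]

# ...compose_examples(examples) should return:
# [
#   {input: {content: "What is one plus one?"}, output: {content: "Two"}}
# ]
```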
data/lib/langchain/llm/openai.rb CHANGED

@@ -35,7 +35,7 @@ module Langchain::LLM
     def embed(text:, **params)
       parameters = {model: DEFAULTS[:embeddings_model_name], input: text}
 
-      Langchain::Utils::TokenLengthValidator.validate_max_tokens!(text, parameters[:model])
+      Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(text, parameters[:model])
 
       response = client.embeddings(parameters: parameters.merge(params))
       response.dig("data").first.dig("embedding")

@@ -52,7 +52,7 @@ module Langchain::LLM
       parameters = compose_parameters DEFAULTS[:completion_model_name], params
 
       parameters[:prompt] = prompt
-      parameters[:max_tokens] = Langchain::Utils::TokenLengthValidator.validate_max_tokens!(prompt, parameters[:model])
+      parameters[:max_tokens] = Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(prompt, parameters[:model])
 
       response = client.completions(parameters: parameters)
       response.dig("choices", 0, "text")

@@ -63,19 +63,22 @@ module Langchain::LLM
     #
     # @param prompt [String] The prompt to generate a chat completion for
     # @param messages [Array] The messages that have been sent in the conversation
-    # @param
+    # @param context [String] The context of the conversation
+    # @param examples [Array] Examples of messages provide model with
+    # @param options extra parameters passed to OpenAI::Client#chat
     # @return [String] The chat completion
     #
-    def chat(prompt: "", messages: [], **
+    def chat(prompt: "", messages: [], context: "", examples: [], **options)
       raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
 
-
-
-      parameters =
-      parameters[:messages] = messages
-      parameters[:max_tokens] = validate_max_tokens(messages, parameters[:model])
+      parameters = compose_parameters DEFAULTS[:chat_completion_model_name], options
+      parameters[:messages] = compose_chat_messages(prompt: prompt, messages: messages, context: context, examples: examples)
+      parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
 
       response = client.chat(parameters: parameters)
+
+      raise "Chat completion failed: #{response}" if response.dig("error")
+
       response.dig("choices", 0, "message", "content")
     end
 

@@ -104,8 +107,40 @@ module Langchain::LLM
       default_params.merge(params)
     end
 
+    def compose_chat_messages(prompt:, messages:, context:, examples:)
+      history = []
+
+      history.concat transform_messages(examples) unless examples.empty?
+
+      history.concat transform_messages(messages) unless messages.empty?
+
+      unless context.nil? || context.empty?
+        history.reject! { |message| message[:role] == "system" }
+        history.prepend({role: "system", content: context})
+      end
+
+      unless prompt.empty?
+        if history.last && history.last[:role] == "user"
+          history.last[:content] += "\n#{prompt}"
+        else
+          history.append({role: "user", content: prompt})
+        end
+      end
+
+      history
+    end
+
+    def transform_messages(messages)
+      messages.map do |message|
+        {
+          content: message[:content],
+          role: (message[:role] == "ai") ? "assistant" : message[:role]
+        }
+      end
+    end
+
     def validate_max_tokens(messages, model)
-      Langchain::Utils::TokenLengthValidator.validate_max_tokens!(messages, model)
+      Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(messages, model)
    end
  end
 end
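To make the new `compose_chat_messages` behavior concrete, here is a hypothetical trace (illustrative values): examples and messages are concatenated in that order, any existing system message is replaced by the `context:`, `"ai"` roles are rewritten to `"assistant"`, and a trailing prompt either merges into a final user message or is appended as a new one.

```ruby
# context:  "You are helpful."
# messages: [{role: "user", content: "Hi"}, {role: "ai", content: "Hello!"}]
# prompt:   "How are you?"
#
# compose_chat_messages(...) should produce:
# [
#   {role: "system", content: "You are helpful."},
#   {content: "Hi", role: "user"},
#   {content: "Hello!", role: "assistant"},
#   {content: "How are you?", role: "user"}   # appended: last entry was not a user message
# ]
```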
data/lib/langchain/tool/base.rb CHANGED

@@ -57,6 +57,15 @@ module Langchain::Tool
       self.class.const_get(:NAME)
     end
 
+    #
+    # Returns the DESCRIPTION constant of the tool
+    #
+    # @return [String] tool description
+    #
+    def tool_description
+      self.class.const_get(:DESCRIPTION)
+    end
+
     #
     # Sets the DESCRIPTION constant of the tool
     #
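The new reader pairs with the existing `tool_name`. A hedged sketch of a custom tool that would expose both (the tool class, its constants, and the `execute` body are hypothetical; `NAME` and `DESCRIPTION` are the constants the two readers resolve):

```ruby
# Hypothetical custom tool for illustration.
class CoinFlip < Langchain::Tool::Base
  NAME = "coin_flip"
  DESCRIPTION = "Flips a coin and returns heads or tails."

  def execute(input:)
    ["heads", "tails"].sample
  end
end

tool = CoinFlip.new
tool.tool_name        # => "coin_flip"
tool.tool_description # => "Flips a coin and returns heads or tails."
```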
data/lib/langchain/utils/token_length/google_palm_validator.rb ADDED

@@ -0,0 +1,69 @@
+# frozen_string_literal: true
+
+module Langchain
+  module Utils
+    module TokenLength
+      #
+      # This class is meant to validate the length of the text passed in to Google Palm's API.
+      # It is used to validate the token length before the API call is made
+      #
+      class GooglePalmValidator
+        TOKEN_LIMITS = {
+          # Source:
+          # This data can be pulled when `list_models()` method is called: https://github.com/andreibondarev/google_palm_api#usage
+
+          # chat-bison-001 is the only model that currently supports countMessageTokens functions
+          "chat-bison-001" => {
+            "input_token_limit" => 4000, # 4096 is the limit but the countMessageTokens does not return anything higher than 4000
+            "output_token_limit" => 1024
+          }
+          # "text-bison-001" => {
+          #   "input_token_limit" => 8196,
+          #   "output_token_limit" => 1024
+          # },
+          # "embedding-gecko-001" => {
+          #   "input_token_limit" => 1024
+          # }
+        }.freeze
+
+        #
+        # Validate the context length of the text
+        #
+        # @param content [String | Array<String>] The text or array of texts to validate
+        # @param model_name [String] The model name to validate against
+        # @return [Integer] Whether the text is valid or not
+        # @raise [TokenLimitExceeded] If the text is too long
+        #
+        def self.validate_max_tokens!(google_palm_llm, content, model_name)
+          text_token_length = if content.is_a?(Array)
+            content.sum { |item| token_length(google_palm_llm, item.to_json, model_name) }
+          else
+            token_length(google_palm_llm, content, model_name)
+          end
+
+          leftover_tokens = TOKEN_LIMITS.dig(model_name, "input_token_limit") - text_token_length
+
+          # Raise an error even if whole prompt is equal to the model's token limit (leftover_tokens == 0)
+          if leftover_tokens <= 0
+            raise TokenLimitExceeded, "This model's maximum context length is #{TOKEN_LIMITS.dig(model_name, "input_token_limit")} tokens, but the given text is #{text_token_length} tokens long."
+          end
+
+          leftover_tokens
+        end
+
+        #
+        # Calculate token length for a given text and model name
+        #
+        # @param llm [Langchain::LLM:GooglePalm] The Langchain::LLM:GooglePalm instance
+        # @param text [String] The text to calculate the token length for
+        # @param model_name [String] The model name to validate against
+        # @return [Integer] The token length of the text
+        #
+        def self.token_length(llm, text, model_name = "chat-bison-001")
+          response = llm.client.count_message_tokens(model: model_name, prompt: text)
+          response.dig("tokenCount")
+        end
+      end
+    end
+  end
+end
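Unlike the OpenAI validator below, this one needs the LLM instance because token counting is delegated to the PaLM API's own `countMessageTokens` endpoint. A hedged usage sketch (the API key and message are illustrative):

```ruby
llm = Langchain::LLM::GooglePalm.new(api_key: ENV["GOOGLE_PALM_API_KEY"])
messages = [{author: "user", content: "Hello there!"}]

# Returns the leftover input-token budget, or raises
# Langchain::Utils::TokenLength::TokenLimitExceeded if the messages
# already meet or exceed chat-bison-001's 4000-token input limit.
Langchain::Utils::TokenLength::GooglePalmValidator.validate_max_tokens!(
  llm, messages, "chat-bison-001"
)
```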
data/lib/langchain/utils/token_length/openai_validator.rb ADDED

@@ -0,0 +1,75 @@
+# frozen_string_literal: true
+
+require "tiktoken_ruby"
+
+module Langchain
+  module Utils
+    module TokenLength
+      #
+      # This class is meant to validate the length of the text passed in to OpenAI's API.
+      # It is used to validate the token length before the API call is made
+      #
+      class OpenAIValidator
+        TOKEN_LIMITS = {
+          # Source:
+          # https://platform.openai.com/docs/api-reference/embeddings
+          # https://platform.openai.com/docs/models/gpt-4
+          "text-embedding-ada-002" => 8191,
+          "gpt-3.5-turbo" => 4096,
+          "gpt-3.5-turbo-0301" => 4096,
+          "text-davinci-003" => 4097,
+          "text-davinci-002" => 4097,
+          "code-davinci-002" => 8001,
+          "gpt-4" => 8192,
+          "gpt-4-0314" => 8192,
+          "gpt-4-32k" => 32768,
+          "gpt-4-32k-0314" => 32768,
+          "text-curie-001" => 2049,
+          "text-babbage-001" => 2049,
+          "text-ada-001" => 2049,
+          "davinci" => 2049,
+          "curie" => 2049,
+          "babbage" => 2049,
+          "ada" => 2049
+        }.freeze
+
+        #
+        # Calculate the `max_tokens:` parameter to be set by calculating the context length of the text minus the prompt length
+        #
+        # @param content [String | Array<String>] The text or array of texts to validate
+        # @param model_name [String] The model name to validate against
+        # @return [Integer] Whether the text is valid or not
+        # @raise [TokenLimitExceeded] If the text is too long
+        #
+        def self.validate_max_tokens!(content, model_name)
+          text_token_length = if content.is_a?(Array)
+            content.sum { |item| token_length(item.to_json, model_name) }
+          else
+            token_length(content, model_name)
+          end
+
+          max_tokens = TOKEN_LIMITS[model_name] - text_token_length
+
+          # Raise an error even if whole prompt is equal to the model's token limit (max_tokens == 0) since not response will be returned
+          if max_tokens <= 0
+            raise TokenLimitExceeded, "This model's maximum context length is #{TOKEN_LIMITS[model_name]} tokens, but the given text is #{text_token_length} tokens long."
+          end
+
+          max_tokens
+        end
+
+        #
+        # Calculate token length for a given text and model name
+        #
+        # @param text [String] The text to calculate the token length for
+        # @param model_name [String] The model name to validate against
+        # @return [Integer] The token length of the text
+        #
+        def self.token_length(text, model_name)
+          encoder = Tiktoken.encoding_for_model(model_name)
+          encoder.encode(text).length
+        end
+      end
+    end
+  end
+end
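A hedged usage sketch (the prompt is illustrative): the return value is what `Langchain::LLM::OpenAI` assigns to `max_tokens:`, i.e. the model's context window minus the prompt's token count as measured by `tiktoken_ruby`.

```ruby
remaining = Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(
  "Write a haiku about Ruby.", "gpt-3.5-turbo"
)
# => 4096 minus the prompt's token count; raises
#    Langchain::Utils::TokenLength::TokenLimitExceeded when nothing is left.
```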
data/lib/langchain/vectorsearch/chroma.rb CHANGED

@@ -8,7 +8,7 @@ module Langchain::Vectorsearch
   # Gem requirements: gem "chroma-db", "~> 0.3.0"
   #
   # Usage:
-  # chroma = Langchain::Vectorsearch::Chroma.new(url:, index_name:, llm:, api_key: nil)
+  # chroma = Langchain::Vectorsearch::Chroma.new(url:, index_name:, llm:, llm_api_key:, api_key: nil)
   #
 
   # Initialize the Chroma client
data/lib/langchain/vectorsearch/hnswlib.rb ADDED

@@ -0,0 +1,122 @@
+# frozen_string_literal: true
+
+module Langchain::Vectorsearch
+  class Hnswlib < Base
+    #
+    # Wrapper around HNSW (Hierarchical Navigable Small World) library.
+    # HNSWLib is an in-memory vectorstore that can be saved to a file on disk.
+    #
+    # Gem requirements:
+    #     gem "hnswlib", "~> 0.8.1"
+    #
+    # Usage:
+    #     hnsw = Langchain::Vectorsearch::Hnswlib.new(llm:, url:, index_name:)
+    #
+
+    attr_reader :client, :path_to_index
+
+    #
+    # Initialize the HNSW vector search
+    #
+    # @param llm [Object] The LLM client to use
+    # @param path_to_index [String] The local path to the index file, e.g.: "/storage/index.ann"
+    # @return [Langchain::Vectorsearch::Hnswlib] Class instance
+    #
+    def initialize(llm:, path_to_index:)
+      depends_on "hnswlib"
+      require "hnswlib"
+
+      super(llm: llm)
+
+      @client = ::Hnswlib::HierarchicalNSW.new(space: DEFAULT_METRIC, dim: llm.default_dimension)
+      @path_to_index = path_to_index
+
+      initialize_index
+    end
+
+    #
+    # Add a list of texts and corresponding IDs to the index
+    #
+    # @param texts [Array] The list of texts to add
+    # @param ids [Array] The list of corresponding IDs (integers) to the texts
+    # @return [Boolean] The response from the HNSW library
+    #
+    def add_texts(texts:, ids:)
+      resize_index(texts.size)
+
+      Array(texts).each_with_index do |text, i|
+        embedding = llm.embed(text: text)
+
+        client.add_point(embedding, ids[i])
+      end
+
+      client.save_index(path_to_index)
+    end
+
+    #
+    # Search for similar texts
+    #
+    # @param query [String] The text to search for
+    # @param k [Integer] The number of results to return
+    # @return [Array] Results in the format `[[id1, distance3], [id2, distance2]]`
+    #
+    def similarity_search(
+      query:,
+      k: 4
+    )
+      embedding = llm.embed(text: query)
+
+      similarity_search_by_vector(
+        embedding: embedding,
+        k: k
+      )
+    end
+
+    #
+    # Search for the K nearest neighbors of a given vector
+    #
+    # @param embedding [Array] The embedding to search for
+    # @param k [Integer] The number of results to return
+    # @return [Array] Results in the format `[[id1, distance3], [id2, distance2]]`
+    #
+    def similarity_search_by_vector(
+      embedding:,
+      k: 4
+    )
+      client.search_knn(embedding, k)
+    end
+
+    private
+
+    #
+    # Optionally resizes the index if there's no space for new data
+    #
+    # @param num_of_elements_to_add [Integer] The number of elements to add to the index
+    #
+    def resize_index(num_of_elements_to_add)
+      current_count = client.current_count
+
+      if (current_count + num_of_elements_to_add) > client.max_elements
+        new_size = current_count + num_of_elements_to_add
+
+        client.resize_index(new_size)
+      end
+    end
+
+    #
+    # Loads or initializes the new index
+    #
+    def initialize_index
+      if File.exist?(path_to_index)
+        client.load_index(path_to_index)
+
+        Langchain.logger.info("[#{self.class.name}]".blue + ": Successfully loaded the index at \"#{path_to_index}\"")
+      else
+        # Default max_elements: 100, but we constantly resize the index as new data is written to it
+        client.init_index(max_elements: 100)
+
+        Langchain.logger.info("[#{self.class.name}]".blue + ": Creating a new index at \"#{path_to_index}\"")
+      end
+    end
+  end
+end
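A hedged usage sketch of the new adapter (the LLM setup, texts, and index path are illustrative assumptions; per the code above, the index is loaded from `path_to_index` if the file exists, otherwise created):

```ruby
llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

hnsw = Langchain::Vectorsearch::Hnswlib.new(
  llm: llm,
  path_to_index: "./storage/index.ann"
)

# IDs must be integers; the index is resized as needed and saved to disk.
hnsw.add_texts(texts: ["Ruby is a language", "Rails is a framework"], ids: [1, 2])

# Returns ids and distances in the format documented above.
hnsw.similarity_search(query: "web framework", k: 1)
```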
data/lib/langchain/vectorsearch/milvus.rb CHANGED

@@ -8,17 +8,9 @@ module Langchain::Vectorsearch
   # Gem requirements: gem "milvus", "~> 0.9.0"
   #
   # Usage:
-  # milvus = Langchain::Vectorsearch::Milvus.new(url:, index_name:, llm:)
+  # milvus = Langchain::Vectorsearch::Milvus.new(url:, index_name:, llm:, llm_api_key:)
   #
 
-  #
-  # Initialize the Milvus client
-  #
-  # @param url [String] The URL of the Milvus server
-  # @param api_key [String] The API key to use
-  # @param index_name [String] The name of the index to use
-  # @param llm [Object] The LLM client to use
-  #
   def initialize(url:, index_name:, llm:, api_key: nil)
     depends_on "milvus"
     require "milvus"

@@ -29,11 +21,6 @@ module Langchain::Vectorsearch
     super(llm: llm)
   end
 
-  #
-  # Add a list of texts to the index
-  #
-  # @param texts [Array] The list of texts to add
-  #
   def add_texts(texts:)
     client.entities.insert(
       collection_name: index_name,
data/lib/langchain/vectorsearch/pgvector.rb CHANGED

@@ -8,7 +8,7 @@ module Langchain::Vectorsearch
   # Gem requirements: gem "pgvector", "~> 0.2"
   #
   # Usage:
-  # pgvector = Langchain::Vectorsearch::Pgvector.new(url:, index_name:, llm:)
+  # pgvector = Langchain::Vectorsearch::Pgvector.new(url:, index_name:, llm:, llm_api_key:)
   #
 
   # The operators supported by the PostgreSQL vector search adapter

@@ -20,14 +20,10 @@ module Langchain::Vectorsearch
 
   attr_reader :operator, :quoted_table_name
 
-  #
-  # Initialize the PostgreSQL client
-  #
   # @param url [String] The URL of the PostgreSQL database
   # @param index_name [String] The name of the table to use for the index
   # @param llm [Object] The LLM client to use
   # @param api_key [String] The API key for the Vectorsearch DB (not used for PostgreSQL)
-  #
   def initialize(url:, index_name:, llm:, api_key: nil)
     require "pg"
     require "pgvector"
data/lib/langchain/vectorsearch/pinecone.rb CHANGED

@@ -8,17 +8,14 @@ module Langchain::Vectorsearch
   # Gem requirements: gem "pinecone", "~> 0.1.6"
   #
   # Usage:
-  # pinecone = Langchain::Vectorsearch::Pinecone.new(environment:, api_key:, index_name:, llm:)
+  # pinecone = Langchain::Vectorsearch::Pinecone.new(environment:, api_key:, index_name:, llm:, llm_api_key:)
   #
 
-  #
   # Initialize the Pinecone client
-  #
   # @param environment [String] The environment to use
   # @param api_key [String] The API key to use
   # @param index_name [String] The name of the index to use
   # @param llm [Object] The LLM client to use
-  #
   def initialize(environment:, api_key:, index_name:, llm:)
     depends_on "pinecone"
     require "pinecone"
data/lib/langchain/vectorsearch/qdrant.rb CHANGED

@@ -8,17 +8,14 @@ module Langchain::Vectorsearch
   # Gem requirements: gem "qdrant-ruby", "~> 0.9.0"
   #
   # Usage:
-  # qdrant = Langchain::Vectorsearch::Qdrant.new(url:, api_key:, index_name:, llm:)
+  # qdrant = Langchain::Vectorsearch::Qdrant.new(url:, api_key:, index_name:, llm:, llm_api_key:)
   #
 
-  #
   # Initialize the Qdrant client
-  #
   # @param url [String] The URL of the Qdrant server
   # @param api_key [String] The API key to use
   # @param index_name [String] The name of the index to use
   # @param llm [Object] The LLM client to use
-  #
   def initialize(url:, api_key:, index_name:, llm:)
     depends_on "qdrant-ruby"
     require "qdrant"
data/lib/langchain/vectorsearch/weaviate.rb CHANGED

@@ -8,17 +8,14 @@ module Langchain::Vectorsearch
   # Gem requirements: gem "weaviate-ruby", "~> 0.8.0"
   #
   # Usage:
-  # weaviate = Langchain::Vectorsearch::Weaviate.new(url:, api_key:, index_name:, llm:)
+  # weaviate = Langchain::Vectorsearch::Weaviate.new(url:, api_key:, index_name:, llm:, llm_api_key:)
   #
 
-  #
   # Initialize the Weaviate adapter
-  #
   # @param url [String] The URL of the Weaviate instance
   # @param api_key [String] The API key to use
   # @param index_name [String] The name of the index to use
   # @param llm [Object] The LLM client to use
-  #
   def initialize(url:, api_key:, index_name:, llm:)
     depends_on "weaviate-ruby"
     require "weaviate"
data/lib/langchain/version.rb CHANGED

-  VERSION = "0.5.3"
+  VERSION = "0.5.4"
data/lib/langchain.rb CHANGED

@@ -62,6 +62,7 @@ module Langchain
 
   autoload :Loader, "langchain/loader"
   autoload :Data, "langchain/data"
+  autoload :Chat, "langchain/chat"
   autoload :DependencyHelper, "langchain/dependency_helper"
 
   module Agent

@@ -92,12 +93,18 @@ module Langchain
   end
 
   module Utils
-
+    module TokenLength
+      class TokenLimitExceeded < StandardError; end
+
+      autoload :OpenAIValidator, "langchain/utils/token_length/openai_validator"
+      autoload :GooglePalmValidator, "langchain/utils/token_length/google_palm_validator"
+    end
   end
 
   module Vectorsearch
     autoload :Base, "langchain/vectorsearch/base"
     autoload :Chroma, "langchain/vectorsearch/chroma"
+    autoload :Hnswlib, "langchain/vectorsearch/hnswlib"
     autoload :Milvus, "langchain/vectorsearch/milvus"
     autoload :Pinecone, "langchain/vectorsearch/pinecone"
     autoload :Pgvector, "langchain/vectorsearch/pgvector"
metadata CHANGED

@@ -1,7 +1,7 @@
 --- !ruby/object:Gem::Specification
 name: langchainrb
 version: !ruby/object:Gem::Version
-  version: 0.5.3
+  version: 0.5.4
 platform: ruby
 authors:
 - Andrei Bondarev

@@ -192,6 +192,20 @@ dependencies:
     - - "~>"
       - !ruby/object:Gem::Version
         version: 2.0.0
+- !ruby/object:Gem::Dependency
+  name: hnswlib
+  requirement: !ruby/object:Gem::Requirement
+    requirements:
+    - - "~>"
+      - !ruby/object:Gem::Version
+        version: 0.8.1
+  type: :development
+  prerelease: false
+  version_requirements: !ruby/object:Gem::Requirement
+    requirements:
+    - - "~>"
+      - !ruby/object:Gem::Version
+        version: 0.8.1
 - !ruby/object:Gem::Dependency
   name: hugging-face
   requirement: !ruby/object:Gem::Requirement

@@ -432,6 +446,7 @@ files:
 - lib/langchain/agent/sql_query_agent/sql_query_agent.rb
 - lib/langchain/agent/sql_query_agent/sql_query_agent_answer_prompt.json
 - lib/langchain/agent/sql_query_agent/sql_query_agent_sql_prompt.json
+- lib/langchain/chat.rb
 - lib/langchain/data.rb
 - lib/langchain/dependency_helper.rb
 - lib/langchain/llm/ai21.rb

@@ -462,9 +477,11 @@ files:
 - lib/langchain/tool/ruby_code_interpreter.rb
 - lib/langchain/tool/serp_api.rb
 - lib/langchain/tool/wikipedia.rb
-- lib/langchain/utils/token_length_validator.rb
+- lib/langchain/utils/token_length/google_palm_validator.rb
+- lib/langchain/utils/token_length/openai_validator.rb
 - lib/langchain/vectorsearch/base.rb
 - lib/langchain/vectorsearch/chroma.rb
+- lib/langchain/vectorsearch/hnswlib.rb
 - lib/langchain/vectorsearch/milvus.rb
 - lib/langchain/vectorsearch/pgvector.rb
 - lib/langchain/vectorsearch/pinecone.rb

@@ -495,7 +512,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
   - !ruby/object:Gem::Version
     version: '0'
 requirements: []
-rubygems_version: 3.3
+rubygems_version: 3.2.3
 signing_key:
 specification_version: 4
 summary: Build LLM-backed Ruby applications with Ruby's LangChain
data/lib/langchain/utils/token_length_validator.rb REMOVED

@@ -1,89 +0,0 @@
-# frozen_string_literal: true
-
-require "tiktoken_ruby"
-
-module Langchain
-  module Utils
-    class TokenLimitExceeded < StandardError; end
-
-    class TokenLengthValidator
-      #
-      # This class is meant to validate the length of the text passed in to OpenAI's API.
-      # It is used to validate the token length before the API call is made
-      #
-      TOKEN_LIMITS = {
-        # Source:
-        # https://platform.openai.com/docs/api-reference/embeddings
-        # https://platform.openai.com/docs/models/gpt-4
-        "text-embedding-ada-002" => 8191,
-        "gpt-3.5-turbo" => 4096,
-        "gpt-3.5-turbo-0301" => 4096,
-        "text-davinci-003" => 4097,
-        "text-davinci-002" => 4097,
-        "code-davinci-002" => 8001,
-        "gpt-4" => 8192,
-        "gpt-4-0314" => 8192,
-        "gpt-4-32k" => 32768,
-        "gpt-4-32k-0314" => 32768,
-        "text-curie-001" => 2049,
-        "text-babbage-001" => 2049,
-        "text-ada-001" => 2049,
-        "davinci" => 2049,
-        "curie" => 2049,
-        "babbage" => 2049,
-        "ada" => 2049
-      }.freeze
-
-      # GOOGLE_PALM_TOKEN_LIMITS = {
-      #   "chat-bison-001" => {
-      #     "inputTokenLimit"=>4096,
-      #     "outputTokenLimit"=>1024
-      #   },
-      #   "text-bison-001" => {
-      #     "inputTokenLimit"=>8196,
-      #     "outputTokenLimit"=>1024
-      #   },
-      #   "embedding-gecko-001" => {
-      #     "inputTokenLimit"=>1024
-      #   }
-      # }.freeze
-
-      #
-      # Calculate the `max_tokens:` parameter to be set by calculating the context length of the text minus the prompt length
-      #
-      # @param content [String | Array<String>] The text or array of texts to validate
-      # @param model_name [String] The model name to validate against
-      # @return [Integer] Whether the text is valid or not
-      # @raise [TokenLimitExceeded] If the text is too long
-      #
-      def self.validate_max_tokens!(content, model_name)
-        text_token_length = if content.is_a?(Array)
-          content.sum { |item| token_length(item.to_json, model_name) }
-        else
-          token_length(content, model_name)
-        end
-
-        max_tokens = TOKEN_LIMITS[model_name] - text_token_length
-
-        # Raise an error even if whole prompt is equal to the model's token limit (max_tokens == 0) since not response will be returned
-        if max_tokens <= 0
-          raise TokenLimitExceeded, "This model's maximum context length is #{TOKEN_LIMITS[model_name]} tokens, but the given text is #{text_token_length} tokens long."
-        end
-
-        max_tokens
-      end
-
-      #
-      # Calculate token length for a given text and model name
-      #
-      # @param text [String] The text to validate
-      # @param model_name [String] The model name to validate against
-      # @return [Integer] The token length of the text
-      #
-      def self.token_length(text, model_name)
-        encoder = Tiktoken.encoding_for_model(model_name)
-        encoder.encode(text).length
-      end
-    end
-  end
-end