langchainrb 0.18.0 → 0.19.1
This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +30 -0
- data/README.md +4 -4
- data/lib/langchain/assistant/llm/adapter.rb +7 -6
- data/lib/langchain/assistant/llm/adapters/anthropic.rb +1 -3
- data/lib/langchain/assistant/llm/adapters/aws_bedrock_anthropic.rb +35 -0
- data/lib/langchain/assistant/llm/adapters/ollama.rb +1 -3
- data/lib/langchain/assistant/messages/anthropic_message.rb +89 -17
- data/lib/langchain/assistant/messages/base.rb +4 -0
- data/lib/langchain/assistant/messages/google_gemini_message.rb +62 -21
- data/lib/langchain/assistant/messages/mistral_ai_message.rb +69 -24
- data/lib/langchain/assistant/messages/ollama_message.rb +9 -5
- data/lib/langchain/assistant/messages/openai_message.rb +78 -26
- data/lib/langchain/assistant.rb +2 -1
- data/lib/langchain/llm/anthropic.rb +10 -10
- data/lib/langchain/llm/aws_bedrock.rb +75 -120
- data/lib/langchain/llm/azure.rb +1 -1
- data/lib/langchain/llm/base.rb +1 -1
- data/lib/langchain/llm/cohere.rb +8 -8
- data/lib/langchain/llm/google_gemini.rb +5 -6
- data/lib/langchain/llm/google_vertex_ai.rb +6 -5
- data/lib/langchain/llm/hugging_face.rb +4 -4
- data/lib/langchain/llm/mistral_ai.rb +4 -4
- data/lib/langchain/llm/ollama.rb +10 -8
- data/lib/langchain/llm/openai.rb +6 -5
- data/lib/langchain/llm/parameters/chat.rb +4 -1
- data/lib/langchain/llm/replicate.rb +6 -6
- data/lib/langchain/llm/response/ai21_response.rb +20 -0
- data/lib/langchain/tool_definition.rb +7 -0
- data/lib/langchain/utils/image_wrapper.rb +37 -0
- data/lib/langchain/version.rb +1 -1
- metadata +4 -2
data/lib/langchain/assistant/messages/openai_message.rb
CHANGED
@@ -50,32 +50,14 @@ module Langchain
       #
       # @return [Hash] The message as an OpenAI API-compatible hash
       def to_hash
-
-
-
-
-
-
-
-
-          h[:content] = []
-
-          if content && !content.empty?
-            h[:content] << {
-              type: "text",
-              text: content
-            }
-          end
-
-          if image_url
-            h[:content] << {
-              type: "image_url",
-              image_url: {
-                url: image_url
-              }
-            }
-          end
-        end
+        if assistant?
+          assistant_hash
+        elsif system?
+          system_hash
+        elsif tool?
+          tool_hash
+        elsif user?
+          user_hash
         end
       end
@@ -99,6 +81,76 @@ module Langchain
       def tool?
         role == "tool"
       end
+
+      def user?
+        role == "user"
+      end
+
+      # Convert the message to an OpenAI API-compatible hash
+      # @return [Hash] The message as an OpenAI API-compatible hash, with the role as "assistant"
+      def assistant_hash
+        if tool_calls.any?
+          {
+            role: "assistant",
+            tool_calls: tool_calls
+          }
+        else
+          {
+            role: "assistant",
+            content: build_content_array
+          }
+        end
+      end
+
+      # Convert the message to an OpenAI API-compatible hash
+      # @return [Hash] The message as an OpenAI API-compatible hash, with the role as "system"
+      def system_hash
+        {
+          role: "system",
+          content: build_content_array
+        }
+      end
+
+      # Convert the message to an OpenAI API-compatible hash
+      # @return [Hash] The message as an OpenAI API-compatible hash, with the role as "tool"
+      def tool_hash
+        {
+          role: "tool",
+          tool_call_id: tool_call_id,
+          content: build_content_array
+        }
+      end
+
+      # Convert the message to an OpenAI API-compatible hash
+      # @return [Hash] The message as an OpenAI API-compatible hash, with the role as "user"
+      def user_hash
+        {
+          role: "user",
+          content: build_content_array
+        }
+      end
+
+      # Builds the content value for the message hash
+      # @return [Array<Hash>] An array of content hashes, with keys :type and :text or :image_url.
+      def build_content_array
+        content_details = []
+        if content && !content.empty?
+          content_details << {
+            type: "text",
+            text: content
+          }
+        end
+
+        if image_url
+          content_details << {
+            type: "image_url",
+            image_url: {
+              url: image_url
+            }
+          }
+        end
+        content_details
+      end
     end
   end
 end
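The hunks above replace the single to_hash body with role-specific builders that all funnel through build_content_array. A rough usage sketch of the resulting payload, assuming the OpenAIMessage constructor keywords (role:, content:, image_url:) are unchanged from 0.18 and using a placeholder image URL:

require "langchain"

# Hypothetical message; the URL below is a placeholder, not taken from this diff.
message = Langchain::Assistant::Messages::OpenAIMessage.new(
  role: "user",
  content: "What is in this picture?",
  image_url: "https://example.com/cat.png"
)

message.to_hash
# => {
#      role: "user",
#      content: [
#        {type: "text", text: "What is in this picture?"},
#        {type: "image_url", image_url: {url: "https://example.com/cat.png"}}
#      ]
#    }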
data/lib/langchain/assistant.rb
CHANGED
@@ -196,7 +196,7 @@ module Langchain
 
       if @llm_adapter.support_system_message?
         # TODO: Should we still set a system message even if @instructions is "" or nil?
-        replace_system_message!(content: new_instructions)
+        replace_system_message!(content: new_instructions)
       end
     end
 
@@ -217,6 +217,7 @@ module Langchain
     # @return [Array<Langchain::Message>] The messages
     def replace_system_message!(content:)
       messages.delete_if(&:system?)
+      return if content.nil?
 
       message = build_message(role: "system", content: content)
       messages.unshift(message)
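With the early return added above, clearing the instructions deletes the existing system message without pushing an empty replacement. A minimal behavioral sketch (the assistant setup is an assumption for illustration, not part of this diff):

# Assumes `llm` is an already-configured LLM client.
assistant = Langchain::Assistant.new(llm: llm, instructions: "You are a helpful bot")

assistant.instructions = nil
# messages.delete_if(&:system?) removes the old system message, and because the
# new content is nil, no replacement system message is unshifted.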
data/lib/langchain/llm/anthropic.rb
CHANGED
@@ -13,16 +13,16 @@ module Langchain::LLM
   class Anthropic < Base
     DEFAULTS = {
       temperature: 0.0,
-
-
-
+      completion_model: "claude-2.1",
+      chat_model: "claude-3-5-sonnet-20240620",
+      max_tokens: 256
     }.freeze
 
     # Initialize an Anthropic LLM instance
     #
     # @param api_key [String] The API key to use
     # @param llm_options [Hash] Options to pass to the Anthropic client
-    # @param default_options [Hash] Default options to use on every call to LLM, e.g.: { temperature:,
+    # @param default_options [Hash] Default options to use on every call to LLM, e.g.: { temperature:, completion_model:, chat_model:, max_tokens: }
     # @return [Langchain::LLM::Anthropic] Langchain::LLM::Anthropic instance
     def initialize(api_key:, llm_options: {}, default_options: {})
       depends_on "anthropic"
@@ -30,9 +30,9 @@ module Langchain::LLM
       @client = ::Anthropic::Client.new(access_token: api_key, **llm_options)
       @defaults = DEFAULTS.merge(default_options)
       chat_parameters.update(
-        model: {default: @defaults[:
+        model: {default: @defaults[:chat_model]},
         temperature: {default: @defaults[:temperature]},
-        max_tokens: {default: @defaults[:
+        max_tokens: {default: @defaults[:max_tokens]},
         metadata: {},
         system: {}
       )
@@ -54,8 +54,8 @@ module Langchain::LLM
     # @return [Langchain::LLM::AnthropicResponse] The completion
     def complete(
       prompt:,
-      model: @defaults[:
-
+      model: @defaults[:completion_model],
+      max_tokens: @defaults[:max_tokens],
       stop_sequences: nil,
       temperature: @defaults[:temperature],
       top_p: nil,
@@ -64,12 +64,12 @@ module Langchain::LLM
       stream: nil
     )
       raise ArgumentError.new("model argument is required") if model.empty?
-      raise ArgumentError.new("
+      raise ArgumentError.new("max_tokens argument is required") if max_tokens.nil?
 
       parameters = {
         model: model,
         prompt: prompt,
-        max_tokens_to_sample:
+        max_tokens_to_sample: max_tokens,
         temperature: temperature
       }
       parameters[:stop_sequences] = stop_sequences if stop_sequences
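Net effect of the Anthropic hunks: the defaults are keyed completion_model, chat_model and max_tokens, and complete now takes max_tokens directly. A hedged configuration sketch (values are examples, not recommendations):

llm = Langchain::LLM::Anthropic.new(
  api_key: ENV["ANTHROPIC_API_KEY"],
  default_options: {
    chat_model: "claude-3-5-sonnet-20240620",
    completion_model: "claude-2.1",
    max_tokens: 512
  }
)

# max_tokens can also be overridden per call.
llm.complete(prompt: "Say hello", max_tokens: 64)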
data/lib/langchain/llm/aws_bedrock.rb
CHANGED
@@ -7,51 +7,40 @@ module Langchain::LLM
   # gem 'aws-sdk-bedrockruntime', '~> 1.1'
   #
   # Usage:
-  # llm = Langchain::LLM::AwsBedrock.new(
+  # llm = Langchain::LLM::AwsBedrock.new(default_options: {})
   #
   class AwsBedrock < Base
     DEFAULTS = {
-
-
-
+      chat_model: "anthropic.claude-3-5-sonnet-20240620-v1:0",
+      completion_model: "anthropic.claude-v2:1",
+      embedding_model: "amazon.titan-embed-text-v1",
       max_tokens_to_sample: 300,
       temperature: 1,
       top_k: 250,
       top_p: 0.999,
       stop_sequences: ["\n\nHuman:"],
-
-      return_likelihoods: "NONE",
-      count_penalty: {
-        scale: 0,
-        apply_to_whitespaces: false,
-        apply_to_punctuations: false,
-        apply_to_numbers: false,
-        apply_to_stopwords: false,
-        apply_to_emojis: false
-      },
-      presence_penalty: {
-        scale: 0,
-        apply_to_whitespaces: false,
-        apply_to_punctuations: false,
-        apply_to_numbers: false,
-        apply_to_stopwords: false,
-        apply_to_emojis: false
-      },
-      frequency_penalty: {
-        scale: 0,
-        apply_to_whitespaces: false,
-        apply_to_punctuations: false,
-        apply_to_numbers: false,
-        apply_to_stopwords: false,
-        apply_to_emojis: false
-      }
+      return_likelihoods: "NONE"
     }.freeze
 
     attr_reader :client, :defaults
 
-    SUPPORTED_COMPLETION_PROVIDERS = %i[
-
-
+    SUPPORTED_COMPLETION_PROVIDERS = %i[
+      anthropic
+      ai21
+      cohere
+      meta
+    ].freeze
+
+    SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[
+      anthropic
+      ai21
+      mistral
+    ].freeze
+
+    SUPPORTED_EMBEDDING_PROVIDERS = %i[
+      amazon
+      cohere
+    ].freeze
 
     def initialize(aws_client_options: {}, default_options: {})
       depends_on "aws-sdk-bedrockruntime", req: "aws-sdk-bedrockruntime"
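With the provider-specific penalty blocks removed, the Bedrock defaults reduce to three model ids plus a few sampling options, and the supported providers are now spelled out per operation. A hedged construction sketch (the model ids shown are the shipped defaults):

llm = Langchain::LLM::AwsBedrock.new(
  default_options: {
    chat_model: "anthropic.claude-3-5-sonnet-20240620-v1:0",
    completion_model: "anthropic.claude-v2:1",
    embedding_model: "amazon.titan-embed-text-v1"
  }
)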
@@ -60,12 +49,11 @@ module Langchain::LLM
       @defaults = DEFAULTS.merge(default_options)
 
       chat_parameters.update(
-        model: {default: @defaults[:
+        model: {default: @defaults[:chat_model]},
         temperature: {},
         max_tokens: {default: @defaults[:max_tokens_to_sample]},
         metadata: {},
-        system: {}
-        anthropic_version: {default: "bedrock-2023-05-31"}
+        system: {}
       )
       chat_parameters.ignore(:n, :user)
       chat_parameters.remap(stop: :stop_sequences)
@@ -84,7 +72,7 @@ module Langchain::LLM
       parameters = compose_embedding_parameters params.merge(text:)
 
       response = client.invoke_model({
-        model_id: @defaults[:
+        model_id: @defaults[:embedding_model],
         body: parameters.to_json,
         content_type: "application/json",
         accept: "application/json"
@@ -100,23 +88,25 @@ module Langchain::LLM
     # @param params extra parameters passed to Aws::BedrockRuntime::Client#invoke_model
     # @return [Langchain::LLM::AnthropicResponse], [Langchain::LLM::CohereResponse] or [Langchain::LLM::AI21Response] Response object
     #
-    def complete(
-
+    def complete(
+      prompt:,
+      model: @defaults[:completion_model],
+      **params
+    )
+      raise "Completion provider #{model} is not supported." unless SUPPORTED_COMPLETION_PROVIDERS.include?(provider_name(model))
 
-
-
-      parameters = compose_parameters params
+      parameters = compose_parameters(params, model)
 
       parameters[:prompt] = wrap_prompt prompt
 
       response = client.invoke_model({
-        model_id:
+        model_id: model,
         body: parameters.to_json,
         content_type: "application/json",
         accept: "application/json"
       })
 
-      parse_response
+      parse_response(response, model)
     end
 
     # Generate a chat completion for a given prompt
@@ -126,7 +116,7 @@ module Langchain::LLM
     # @param [Hash] params unified chat parmeters from [Langchain::LLM::Parameters::Chat::SCHEMA]
     # @option params [Array<String>] :messages The messages to generate a completion for
     # @option params [String] :system The system prompt to provide instructions
-    # @option params [String] :model The model to use for completion defaults to @defaults[:
+    # @option params [String] :model The model to use for completion defaults to @defaults[:chat_model]
     # @option params [Integer] :max_tokens The maximum number of tokens to generate defaults to @defaults[:max_tokens_to_sample]
     # @option params [Array<String>] :stop The stop sequences to use for completion
     # @option params [Array<String>] :stop_sequences The stop sequences to use for completion
@@ -137,10 +127,11 @@ module Langchain::LLM
     # @return [Langchain::LLM::AnthropicResponse] Response object
     def chat(params = {}, &block)
       parameters = chat_parameters.to_params(params)
+      parameters = compose_parameters(parameters, parameters[:model])
 
-
-
-
+      unless SUPPORTED_CHAT_COMPLETION_PROVIDERS.include?(provider_name(parameters[:model]))
+        raise "Chat provider #{parameters[:model]} is not supported."
+      end
 
       if block
         response_chunks = []
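Taken together, complete and chat now derive the Bedrock provider from the model id supplied on each call instead of a fixed default, and raise when the prefix is not in the supported list. A hedged usage sketch (model ids are examples; AWS credentials and the unified messages format are assumed):

llm = Langchain::LLM::AwsBedrock.new

# Routed to the Cohere completion provider because of the "cohere." prefix.
llm.complete(prompt: "Hello", model: "cohere.command-text-v14")

# Routed to the Anthropic chat provider; unsupported prefixes now raise.
llm.chat(
  messages: [{role: "user", content: "Hello"}],
  model: "anthropic.claude-3-5-sonnet-20240620-v1:0"
)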
@@ -168,18 +159,32 @@ module Langchain::LLM
         accept: "application/json"
       })
 
-      parse_response
+      parse_response(response, parameters[:model])
     end
   end
 
   private
 
+    def parse_model_id(model_id)
+      model_id
+        .gsub("us.", "") # Meta append "us." to their model ids
+        .split(".")
+    end
+
+    def provider_name(model_id)
+      parse_model_id(model_id).first.to_sym
+    end
+
+    def model_name(model_id)
+      parse_model_id(model_id).last
+    end
+
     def completion_provider
-      @defaults[:
+      @defaults[:completion_model].split(".").first.to_sym
     end
 
     def embedding_provider
-      @defaults[:
+      @defaults[:embedding_model].split(".").first.to_sym
     end
 
     def wrap_prompt(prompt)
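The new helpers treat the segment before the first "." as the provider, after stripping the "us." prefix Meta adds. A quick sketch of that mapping (plain string manipulation mirroring parse_model_id/provider_name; the helper name provider_for and the model ids are illustrative):

def provider_for(model_id)
  model_id.gsub("us.", "").split(".").first.to_sym
end

provider_for("anthropic.claude-3-5-sonnet-20240620-v1:0") # => :anthropic
provider_for("us.meta.llama3-2-90b-instruct-v1:0")        # => :meta
provider_for("cohere.command-text-v14")                   # => :cohere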
@@ -200,15 +205,17 @@ module Langchain::LLM
       end
     end
 
-    def compose_parameters(params)
-      if
-        compose_parameters_anthropic
-      elsif
-        compose_parameters_cohere
-      elsif
-
-      elsif
-
+    def compose_parameters(params, model_id)
+      if provider_name(model_id) == :anthropic
+        compose_parameters_anthropic(params)
+      elsif provider_name(model_id) == :cohere
+        compose_parameters_cohere(params)
+      elsif provider_name(model_id) == :ai21
+        params
+      elsif provider_name(model_id) == :meta
+        params
+      elsif provider_name(model_id) == :mistral
+        params
       end
     end
 
@@ -220,15 +227,17 @@ module Langchain::LLM
       end
     end
 
-    def parse_response(response)
-      if
+    def parse_response(response, model_id)
+      if provider_name(model_id) == :anthropic
         Langchain::LLM::AnthropicResponse.new(JSON.parse(response.body.string))
-      elsif
+      elsif provider_name(model_id) == :cohere
         Langchain::LLM::CohereResponse.new(JSON.parse(response.body.string))
-      elsif
+      elsif provider_name(model_id) == :ai21
         Langchain::LLM::AI21Response.new(JSON.parse(response.body.string, symbolize_names: true))
-      elsif
+      elsif provider_name(model_id) == :meta
         Langchain::LLM::AwsBedrockMetaResponse.new(JSON.parse(response.body.string))
+      elsif provider_name(model_id) == :mistral
+        Langchain::LLM::MistralAIResponse.new(JSON.parse(response.body.string))
       end
     end
 
@@ -276,61 +285,7 @@ module Langchain::LLM
     end
 
     def compose_parameters_anthropic(params)
-
-
-      {
-        max_tokens_to_sample: default_params[:max_tokens_to_sample],
-        temperature: default_params[:temperature],
-        top_k: default_params[:top_k],
-        top_p: default_params[:top_p],
-        stop_sequences: default_params[:stop_sequences],
-        anthropic_version: default_params[:anthropic_version]
-      }
-    end
-
-    def compose_parameters_ai21(params)
-      default_params = @defaults.merge(params)
-
-      {
-        maxTokens: default_params[:max_tokens_to_sample],
-        temperature: default_params[:temperature],
-        topP: default_params[:top_p],
-        stopSequences: default_params[:stop_sequences],
-        countPenalty: {
-          scale: default_params[:count_penalty][:scale],
-          applyToWhitespaces: default_params[:count_penalty][:apply_to_whitespaces],
-          applyToPunctuations: default_params[:count_penalty][:apply_to_punctuations],
-          applyToNumbers: default_params[:count_penalty][:apply_to_numbers],
-          applyToStopwords: default_params[:count_penalty][:apply_to_stopwords],
-          applyToEmojis: default_params[:count_penalty][:apply_to_emojis]
-        },
-        presencePenalty: {
-          scale: default_params[:presence_penalty][:scale],
-          applyToWhitespaces: default_params[:presence_penalty][:apply_to_whitespaces],
-          applyToPunctuations: default_params[:presence_penalty][:apply_to_punctuations],
-          applyToNumbers: default_params[:presence_penalty][:apply_to_numbers],
-          applyToStopwords: default_params[:presence_penalty][:apply_to_stopwords],
-          applyToEmojis: default_params[:presence_penalty][:apply_to_emojis]
-        },
-        frequencyPenalty: {
-          scale: default_params[:frequency_penalty][:scale],
-          applyToWhitespaces: default_params[:frequency_penalty][:apply_to_whitespaces],
-          applyToPunctuations: default_params[:frequency_penalty][:apply_to_punctuations],
-          applyToNumbers: default_params[:frequency_penalty][:apply_to_numbers],
-          applyToStopwords: default_params[:frequency_penalty][:apply_to_stopwords],
-          applyToEmojis: default_params[:frequency_penalty][:apply_to_emojis]
-        }
-      }
-    end
-
-    def compose_parameters_meta(params)
-      default_params = @defaults.merge(params)
-
-      {
-        temperature: default_params[:temperature],
-        top_p: default_params[:top_p],
-        max_gen_len: default_params[:max_tokens_to_sample]
-      }
+      params.merge(anthropic_version: "bedrock-2023-05-31")
     end
 
     def response_from_chunks(chunks)
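compose_parameters_anthropic no longer rebuilds the request body from @defaults; it passes the already-normalized parameters through and only tags on the Bedrock API version, while the deleted AI21/Meta builders mean those providers now receive the caller's params verbatim. A minimal standalone sketch of the new Anthropic branch (not a verbatim excerpt):

def compose_parameters_anthropic(params)
  params.merge(anthropic_version: "bedrock-2023-05-31")
end

compose_parameters_anthropic(max_tokens: 256, temperature: 0.5)
# => {:max_tokens=>256, :temperature=>0.5, :anthropic_version=>"bedrock-2023-05-31"}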
data/lib/langchain/llm/azure.rb
CHANGED
@@ -33,7 +33,7 @@ module Langchain::LLM
     )
     @defaults = DEFAULTS.merge(default_options)
     chat_parameters.update(
-      model: {default: @defaults[:
+      model: {default: @defaults[:chat_model]},
       logprobs: {},
       top_logprobs: {},
       n: {default: @defaults[:n]},
data/lib/langchain/llm/base.rb
CHANGED
@@ -34,7 +34,7 @@ module Langchain::LLM
     default_dimensions
   end
 
-  # Returns the number of vector dimensions used by DEFAULTS[:
+  # Returns the number of vector dimensions used by DEFAULTS[:chat_model]
   #
   # @return [Integer] Vector dimensions
   def default_dimensions
data/lib/langchain/llm/cohere.rb
CHANGED
@@ -13,9 +13,9 @@ module Langchain::LLM
   class Cohere < Base
     DEFAULTS = {
       temperature: 0.0,
-
-
-
+      completion_model: "command",
+      chat_model: "command-r-plus",
+      embedding_model: "small",
       dimensions: 1024,
       truncate: "START"
     }.freeze
@@ -26,7 +26,7 @@ module Langchain::LLM
       @client = ::Cohere::Client.new(api_key: api_key)
       @defaults = DEFAULTS.merge(default_options)
       chat_parameters.update(
-        model: {default: @defaults[:
+        model: {default: @defaults[:chat_model]},
         temperature: {default: @defaults[:temperature]},
         response_format: {default: @defaults[:response_format]}
       )
@@ -48,10 +48,10 @@ module Langchain::LLM
     def embed(text:)
       response = client.embed(
         texts: [text],
-        model: @defaults[:
+        model: @defaults[:embedding_model]
       )
 
-      Langchain::LLM::CohereResponse.new response, model: @defaults[:
+      Langchain::LLM::CohereResponse.new response, model: @defaults[:embedding_model]
     end
 
     #
@@ -65,7 +65,7 @@ module Langchain::LLM
       default_params = {
         prompt: prompt,
         temperature: @defaults[:temperature],
-        model: @defaults[:
+        model: @defaults[:completion_model],
         truncate: @defaults[:truncate]
       }
 
@@ -76,7 +76,7 @@ module Langchain::LLM
       default_params.merge!(params)
 
       response = client.generate(**default_params)
-      Langchain::LLM::CohereResponse.new response, model: @defaults[:
+      Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model]
     end
 
     # Generate a chat completion for given messages
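For Cohere the chat default is now read from :chat_model (command-r-plus out of the box) and embeddings from :embedding_model. A hedged override sketch (the model names are examples of Cohere model ids, not values from this diff):

llm = Langchain::LLM::Cohere.new(
  api_key: ENV["COHERE_API_KEY"],
  default_options: {
    chat_model: "command-r",
    embedding_model: "embed-english-v3.0"
  }
)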
data/lib/langchain/llm/google_gemini.rb
CHANGED
@@ -5,8 +5,8 @@ module Langchain::LLM
   # llm = Langchain::LLM::GoogleGemini.new(api_key: ENV['GOOGLE_GEMINI_API_KEY'])
   class GoogleGemini < Base
     DEFAULTS = {
-
-
+      chat_model: "gemini-1.5-pro-latest",
+      embedding_model: "text-embedding-004",
       temperature: 0.0
     }
 
@@ -17,10 +17,10 @@ module Langchain::LLM
     @defaults = DEFAULTS.merge(default_options)
 
     chat_parameters.update(
-      model: {default: @defaults[:
+      model: {default: @defaults[:chat_model]},
       temperature: {default: @defaults[:temperature]},
       generation_config: {default: nil},
-      safety_settings: {default:
+      safety_settings: {default: @defaults[:safety_settings]}
     )
     chat_parameters.remap(
       messages: :contents,
@@ -72,9 +72,8 @@ module Langchain::LLM
 
   def embed(
     text:,
-    model: @defaults[:
+    model: @defaults[:embedding_model]
   )
-
     params = {
       content: {
         parts: [
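Besides the key renames, Gemini's chat parameters now pick up safety_settings from the configured defaults. A hedged configuration sketch (the category/threshold values follow Google's documented setting format and are examples, not recommendations):

llm = Langchain::LLM::GoogleGemini.new(
  api_key: ENV["GOOGLE_GEMINI_API_KEY"],
  default_options: {
    chat_model: "gemini-1.5-pro-latest",
    safety_settings: [
      {category: "HARM_CATEGORY_DANGEROUS_CONTENT", threshold: "BLOCK_ONLY_HIGH"}
    ]
  }
)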
data/lib/langchain/llm/google_vertex_ai.rb
CHANGED
@@ -17,8 +17,8 @@ module Langchain::LLM
     top_p: 0.8,
     top_k: 40,
     dimensions: 768,
-
-
+    embedding_model: "textembedding-gecko",
+    chat_model: "gemini-1.0-pro"
   }.freeze
 
   # Google Cloud has a project id and a specific region of deployment.
@@ -38,8 +38,9 @@ module Langchain::LLM
     @defaults = DEFAULTS.merge(default_options)
 
     chat_parameters.update(
-      model: {default: @defaults[:
-      temperature: {default: @defaults[:temperature]}
+      model: {default: @defaults[:chat_model]},
+      temperature: {default: @defaults[:temperature]},
+      safety_settings: {default: @defaults[:safety_settings]}
     )
     chat_parameters.remap(
       messages: :contents,
@@ -57,7 +58,7 @@ module Langchain::LLM
   #
   def embed(
     text:,
-    model: @defaults[:
+    model: @defaults[:embedding_model]
   )
     params = {instances: [{content: text}]}
 
data/lib/langchain/llm/hugging_face.rb
CHANGED
@@ -12,7 +12,7 @@ module Langchain::LLM
   #
   class HuggingFace < Base
     DEFAULTS = {
-
+      embedding_model: "sentence-transformers/all-MiniLM-L6-v2"
     }.freeze
 
     EMBEDDING_SIZES = {
@@ -36,7 +36,7 @@ module Langchain::LLM
     def default_dimensions
       # since Huggin Face can run multiple models, look it up or generate an embedding and return the size
       @default_dimensions ||= @defaults[:dimensions] ||
-        EMBEDDING_SIZES.fetch(@defaults[:
+        EMBEDDING_SIZES.fetch(@defaults[:embedding_model].to_sym) do
           embed(text: "test").embedding.size
         end
     end
@@ -50,9 +50,9 @@ module Langchain::LLM
     def embed(text:)
       response = client.embedding(
         input: text,
-        model: @defaults[:
+        model: @defaults[:embedding_model]
       )
-      Langchain::LLM::HuggingFaceResponse.new(response, model: @defaults[:
+      Langchain::LLM::HuggingFaceResponse.new(response, model: @defaults[:embedding_model])
     end
   end
 end
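The same rename runs through Azure, Cohere, Google Gemini, Google Vertex AI and Hugging Face: provider defaults are now keyed completion_model, chat_model and embedding_model. A final hedged sketch for the embedding default, assuming the Hugging Face constructor still takes api_key: as in 0.18 (the model name is the shipped default):

llm = Langchain::LLM::HuggingFace.new(
  api_key: ENV["HUGGING_FACE_API_KEY"],
  default_options: {embedding_model: "sentence-transformers/all-MiniLM-L6-v2"}
)

llm.embed(text: "hello world") # uses the configured :embedding_model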