langchainrb 0.18.0 → 0.19.1

Files changed (32)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +30 -0
  3. data/README.md +4 -4
  4. data/lib/langchain/assistant/llm/adapter.rb +7 -6
  5. data/lib/langchain/assistant/llm/adapters/anthropic.rb +1 -3
  6. data/lib/langchain/assistant/llm/adapters/aws_bedrock_anthropic.rb +35 -0
  7. data/lib/langchain/assistant/llm/adapters/ollama.rb +1 -3
  8. data/lib/langchain/assistant/messages/anthropic_message.rb +89 -17
  9. data/lib/langchain/assistant/messages/base.rb +4 -0
  10. data/lib/langchain/assistant/messages/google_gemini_message.rb +62 -21
  11. data/lib/langchain/assistant/messages/mistral_ai_message.rb +69 -24
  12. data/lib/langchain/assistant/messages/ollama_message.rb +9 -5
  13. data/lib/langchain/assistant/messages/openai_message.rb +78 -26
  14. data/lib/langchain/assistant.rb +2 -1
  15. data/lib/langchain/llm/anthropic.rb +10 -10
  16. data/lib/langchain/llm/aws_bedrock.rb +75 -120
  17. data/lib/langchain/llm/azure.rb +1 -1
  18. data/lib/langchain/llm/base.rb +1 -1
  19. data/lib/langchain/llm/cohere.rb +8 -8
  20. data/lib/langchain/llm/google_gemini.rb +5 -6
  21. data/lib/langchain/llm/google_vertex_ai.rb +6 -5
  22. data/lib/langchain/llm/hugging_face.rb +4 -4
  23. data/lib/langchain/llm/mistral_ai.rb +4 -4
  24. data/lib/langchain/llm/ollama.rb +10 -8
  25. data/lib/langchain/llm/openai.rb +6 -5
  26. data/lib/langchain/llm/parameters/chat.rb +4 -1
  27. data/lib/langchain/llm/replicate.rb +6 -6
  28. data/lib/langchain/llm/response/ai21_response.rb +20 -0
  29. data/lib/langchain/tool_definition.rb +7 -0
  30. data/lib/langchain/utils/image_wrapper.rb +37 -0
  31. data/lib/langchain/version.rb +1 -1
  32. metadata +4 -2
@@ -50,32 +50,14 @@ module Langchain
         #
         # @return [Hash] The message as an OpenAI API-compatible hash
         def to_hash
-          {}.tap do |h|
-            h[:role] = role
-
-            if tool_calls.any?
-              h[:tool_calls] = tool_calls
-            else
-              h[:tool_call_id] = tool_call_id if tool_call_id
-
-              h[:content] = []
-
-              if content && !content.empty?
-                h[:content] << {
-                  type: "text",
-                  text: content
-                }
-              end
-
-              if image_url
-                h[:content] << {
-                  type: "image_url",
-                  image_url: {
-                    url: image_url
-                  }
-                }
-              end
-            end
-          end
+          if assistant?
+            assistant_hash
+          elsif system?
+            system_hash
+          elsif tool?
+            tool_hash
+          elsif user?
+            user_hash
           end
         end
 
@@ -99,6 +81,76 @@ module Langchain
         def tool?
           role == "tool"
         end
+
+        def user?
+          role == "user"
+        end
+
+        # Convert the message to an OpenAI API-compatible hash
+        # @return [Hash] The message as an OpenAI API-compatible hash, with the role as "assistant"
+        def assistant_hash
+          if tool_calls.any?
+            {
+              role: "assistant",
+              tool_calls: tool_calls
+            }
+          else
+            {
+              role: "assistant",
+              content: build_content_array
+            }
+          end
+        end
+
+        # Convert the message to an OpenAI API-compatible hash
+        # @return [Hash] The message as an OpenAI API-compatible hash, with the role as "system"
+        def system_hash
+          {
+            role: "system",
+            content: build_content_array
+          }
+        end
+
+        # Convert the message to an OpenAI API-compatible hash
+        # @return [Hash] The message as an OpenAI API-compatible hash, with the role as "tool"
+        def tool_hash
+          {
+            role: "tool",
+            tool_call_id: tool_call_id,
+            content: build_content_array
+          }
+        end
+
+        # Convert the message to an OpenAI API-compatible hash
+        # @return [Hash] The message as an OpenAI API-compatible hash, with the role as "user"
+        def user_hash
+          {
+            role: "user",
+            content: build_content_array
+          }
+        end
+
+        # Builds the content value for the message hash
+        # @return [Array<Hash>] An array of content hashes, with keys :type and :text or :image_url.
+        def build_content_array
+          content_details = []
+          if content && !content.empty?
+            content_details << {
+              type: "text",
+              text: content
+            }
+          end
+
+          if image_url
+            content_details << {
+              type: "image_url",
+              image_url: {
+                url: image_url
+              }
+            }
+          end
+          content_details
+        end
       end
     end
   end
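For orientation, a rough sketch of what the refactored to_hash now returns for a user message that carries both text and an image (the constructor keywords mirror the attributes used above; the exact values are illustrative):

    message = Langchain::Assistant::Messages::OpenAIMessage.new(
      role: "user",
      content: "What is in this picture?",
      image_url: "https://example.com/cat.png"
    )

    message.to_hash
    # => {
    #      role: "user",
    #      content: [
    #        {type: "text", text: "What is in this picture?"},
    #        {type: "image_url", image_url: {url: "https://example.com/cat.png"}}
    #      ]
    #    }

Assistant messages that contain tool calls skip build_content_array and return only role and tool_calls, as assistant_hash shows above.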
@@ -196,7 +196,7 @@ module Langchain
 
       if @llm_adapter.support_system_message?
         # TODO: Should we still set a system message even if @instructions is "" or nil?
-        replace_system_message!(content: new_instructions) if @instructions
+        replace_system_message!(content: new_instructions)
       end
     end
 
@@ -217,6 +217,7 @@ module Langchain
     # @return [Array<Langchain::Message>] The messages
     def replace_system_message!(content:)
       messages.delete_if(&:system?)
+      return if content.nil?
 
       message = build_message(role: "system", content: content)
       messages.unshift(message)
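Taken together, these two changes mean that assigning nil instructions now clears the existing system message without pushing an empty replacement. A hedged usage sketch (assuming the usual Langchain::Assistant constructor keywords):

    assistant = Langchain::Assistant.new(llm: llm, instructions: "You are a helpful bot")
    assistant.instructions = nil
    # the old system message is deleted; replace_system_message! returns early and adds nothing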
@@ -13,16 +13,16 @@ module Langchain::LLM
   class Anthropic < Base
     DEFAULTS = {
       temperature: 0.0,
-      completion_model_name: "claude-2.1",
-      chat_completion_model_name: "claude-3-5-sonnet-20240620",
-      max_tokens_to_sample: 256
+      completion_model: "claude-2.1",
+      chat_model: "claude-3-5-sonnet-20240620",
+      max_tokens: 256
     }.freeze
 
     # Initialize an Anthropic LLM instance
     #
     # @param api_key [String] The API key to use
     # @param llm_options [Hash] Options to pass to the Anthropic client
-    # @param default_options [Hash] Default options to use on every call to LLM, e.g.: { temperature:, completion_model_name:, chat_completion_model_name:, max_tokens_to_sample: }
+    # @param default_options [Hash] Default options to use on every call to LLM, e.g.: { temperature:, completion_model:, chat_model:, max_tokens: }
     # @return [Langchain::LLM::Anthropic] Langchain::LLM::Anthropic instance
     def initialize(api_key:, llm_options: {}, default_options: {})
       depends_on "anthropic"
@@ -30,9 +30,9 @@ module Langchain::LLM
       @client = ::Anthropic::Client.new(access_token: api_key, **llm_options)
       @defaults = DEFAULTS.merge(default_options)
       chat_parameters.update(
-        model: {default: @defaults[:chat_completion_model_name]},
+        model: {default: @defaults[:chat_model]},
         temperature: {default: @defaults[:temperature]},
-        max_tokens: {default: @defaults[:max_tokens_to_sample]},
+        max_tokens: {default: @defaults[:max_tokens]},
         metadata: {},
         system: {}
       )
@@ -54,8 +54,8 @@ module Langchain::LLM
     # @return [Langchain::LLM::AnthropicResponse] The completion
     def complete(
       prompt:,
-      model: @defaults[:completion_model_name],
-      max_tokens_to_sample: @defaults[:max_tokens_to_sample],
+      model: @defaults[:completion_model],
+      max_tokens: @defaults[:max_tokens],
       stop_sequences: nil,
       temperature: @defaults[:temperature],
       top_p: nil,
@@ -64,12 +64,12 @@ module Langchain::LLM
       stream: nil
     )
       raise ArgumentError.new("model argument is required") if model.empty?
-      raise ArgumentError.new("max_tokens_to_sample argument is required") if max_tokens_to_sample.nil?
+      raise ArgumentError.new("max_tokens argument is required") if max_tokens.nil?
 
       parameters = {
         model: model,
         prompt: prompt,
-        max_tokens_to_sample: max_tokens_to_sample,
+        max_tokens_to_sample: max_tokens,
         temperature: temperature
       }
       parameters[:stop_sequences] = stop_sequences if stop_sequences
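The renamed default keys flow straight through default_options; the 0.18.x names (completion_model_name, chat_completion_model_name, max_tokens_to_sample) are no longer read by this class. A minimal sketch, assuming the ANTHROPIC_API_KEY environment variable:

    llm = Langchain::LLM::Anthropic.new(
      api_key: ENV["ANTHROPIC_API_KEY"],
      default_options: {
        chat_model: "claude-3-5-sonnet-20240620",
        max_tokens: 512
      }
    )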
@@ -7,51 +7,40 @@ module Langchain::LLM
   # gem 'aws-sdk-bedrockruntime', '~> 1.1'
   #
   # Usage:
-  # llm = Langchain::LLM::AwsBedrock.new(llm_options: {})
+  # llm = Langchain::LLM::AwsBedrock.new(default_options: {})
   #
   class AwsBedrock < Base
     DEFAULTS = {
-      chat_completion_model_name: "anthropic.claude-v2",
-      completion_model_name: "anthropic.claude-v2",
-      embeddings_model_name: "amazon.titan-embed-text-v1",
+      chat_model: "anthropic.claude-3-5-sonnet-20240620-v1:0",
+      completion_model: "anthropic.claude-v2:1",
+      embedding_model: "amazon.titan-embed-text-v1",
       max_tokens_to_sample: 300,
       temperature: 1,
       top_k: 250,
       top_p: 0.999,
       stop_sequences: ["\n\nHuman:"],
-      anthropic_version: "bedrock-2023-05-31",
-      return_likelihoods: "NONE",
-      count_penalty: {
-        scale: 0,
-        apply_to_whitespaces: false,
-        apply_to_punctuations: false,
-        apply_to_numbers: false,
-        apply_to_stopwords: false,
-        apply_to_emojis: false
-      },
-      presence_penalty: {
-        scale: 0,
-        apply_to_whitespaces: false,
-        apply_to_punctuations: false,
-        apply_to_numbers: false,
-        apply_to_stopwords: false,
-        apply_to_emojis: false
-      },
-      frequency_penalty: {
-        scale: 0,
-        apply_to_whitespaces: false,
-        apply_to_punctuations: false,
-        apply_to_numbers: false,
-        apply_to_stopwords: false,
-        apply_to_emojis: false
-      }
+      return_likelihoods: "NONE"
     }.freeze
 
     attr_reader :client, :defaults
 
-    SUPPORTED_COMPLETION_PROVIDERS = %i[anthropic ai21 cohere meta].freeze
-    SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[anthropic].freeze
-    SUPPORTED_EMBEDDING_PROVIDERS = %i[amazon cohere].freeze
+    SUPPORTED_COMPLETION_PROVIDERS = %i[
+      anthropic
+      ai21
+      cohere
+      meta
+    ].freeze
+
+    SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[
+      anthropic
+      ai21
+      mistral
+    ].freeze
+
+    SUPPORTED_EMBEDDING_PROVIDERS = %i[
+      amazon
+      cohere
+    ].freeze
 
     def initialize(aws_client_options: {}, default_options: {})
       depends_on "aws-sdk-bedrockruntime", req: "aws-sdk-bedrockruntime"
@@ -60,12 +49,11 @@ module Langchain::LLM
       @defaults = DEFAULTS.merge(default_options)
 
       chat_parameters.update(
-        model: {default: @defaults[:chat_completion_model_name]},
+        model: {default: @defaults[:chat_model]},
         temperature: {},
         max_tokens: {default: @defaults[:max_tokens_to_sample]},
         metadata: {},
-        system: {},
-        anthropic_version: {default: "bedrock-2023-05-31"}
+        system: {}
       )
       chat_parameters.ignore(:n, :user)
       chat_parameters.remap(stop: :stop_sequences)
@@ -84,7 +72,7 @@ module Langchain::LLM
       parameters = compose_embedding_parameters params.merge(text:)
 
       response = client.invoke_model({
-        model_id: @defaults[:embeddings_model_name],
+        model_id: @defaults[:embedding_model],
         body: parameters.to_json,
         content_type: "application/json",
         accept: "application/json"
@@ -100,23 +88,25 @@ module Langchain::LLM
     # @param params extra parameters passed to Aws::BedrockRuntime::Client#invoke_model
     # @return [Langchain::LLM::AnthropicResponse], [Langchain::LLM::CohereResponse] or [Langchain::LLM::AI21Response] Response object
     #
-    def complete(prompt:, **params)
-      raise "Completion provider #{completion_provider} is not supported." unless SUPPORTED_COMPLETION_PROVIDERS.include?(completion_provider)
+    def complete(
+      prompt:,
+      model: @defaults[:completion_model],
+      **params
+    )
+      raise "Completion provider #{model} is not supported." unless SUPPORTED_COMPLETION_PROVIDERS.include?(provider_name(model))
 
-      raise "Model #{@defaults[:completion_model_name]} only supports #chat." if @defaults[:completion_model_name].include?("claude-3")
-
-      parameters = compose_parameters params
+      parameters = compose_parameters(params, model)
 
       parameters[:prompt] = wrap_prompt prompt
 
       response = client.invoke_model({
-        model_id: @defaults[:completion_model_name],
+        model_id: model,
         body: parameters.to_json,
         content_type: "application/json",
         accept: "application/json"
       })
 
-      parse_response response
+      parse_response(response, model)
     end
 
     # Generate a chat completion for a given prompt
@@ -126,7 +116,7 @@ module Langchain::LLM
     # @param [Hash] params unified chat parmeters from [Langchain::LLM::Parameters::Chat::SCHEMA]
     # @option params [Array<String>] :messages The messages to generate a completion for
     # @option params [String] :system The system prompt to provide instructions
-    # @option params [String] :model The model to use for completion defaults to @defaults[:chat_completion_model_name]
+    # @option params [String] :model The model to use for completion defaults to @defaults[:chat_model]
     # @option params [Integer] :max_tokens The maximum number of tokens to generate defaults to @defaults[:max_tokens_to_sample]
     # @option params [Array<String>] :stop The stop sequences to use for completion
     # @option params [Array<String>] :stop_sequences The stop sequences to use for completion
@@ -137,10 +127,11 @@ module Langchain::LLM
     # @return [Langchain::LLM::AnthropicResponse] Response object
     def chat(params = {}, &block)
       parameters = chat_parameters.to_params(params)
+      parameters = compose_parameters(parameters, parameters[:model])
 
-      raise ArgumentError.new("messages argument is required") if Array(parameters[:messages]).empty?
-
-      raise "Model #{parameters[:model]} does not support chat completions." unless Langchain::LLM::AwsBedrock::SUPPORTED_CHAT_COMPLETION_PROVIDERS.include?(completion_provider)
+      unless SUPPORTED_CHAT_COMPLETION_PROVIDERS.include?(provider_name(parameters[:model]))
+        raise "Chat provider #{parameters[:model]} is not supported."
+      end
 
       if block
         response_chunks = []
@@ -168,18 +159,32 @@ module Langchain::LLM
          accept: "application/json"
        })
 
-        parse_response response
+        parse_response(response, parameters[:model])
      end
    end
 
     private
 
+    def parse_model_id(model_id)
+      model_id
+        .gsub("us.", "") # Meta append "us." to their model ids
+        .split(".")
+    end
+
+    def provider_name(model_id)
+      parse_model_id(model_id).first.to_sym
+    end
+
+    def model_name(model_id)
+      parse_model_id(model_id).last
+    end
+
     def completion_provider
-      @defaults[:completion_model_name].split(".").first.to_sym
+      @defaults[:completion_model].split(".").first.to_sym
     end
 
     def embedding_provider
-      @defaults[:embeddings_model_name].split(".").first.to_sym
+      @defaults[:embedding_model].split(".").first.to_sym
     end
 
     def wrap_prompt(prompt)
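The new helpers derive the provider from the model id itself rather than from a single configured default, which is what lets complete and chat route per-request models. Roughly (the model ids below are illustrative):

    provider_name("anthropic.claude-3-5-sonnet-20240620-v1:0") # => :anthropic
    provider_name("us.meta.llama3-1-70b-instruct-v1:0")        # => :meta ("us." prefix stripped)
    provider_name("mistral.mistral-7b-instruct-v0:2")          # => :mistral
    model_name("cohere.command-text-v14")                      # => "command-text-v14"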
@@ -200,15 +205,17 @@ module Langchain::LLM
       end
     end
 
-    def compose_parameters(params)
-      if completion_provider == :anthropic
-        compose_parameters_anthropic params
-      elsif completion_provider == :cohere
-        compose_parameters_cohere params
-      elsif completion_provider == :ai21
-        compose_parameters_ai21 params
-      elsif completion_provider == :meta
-        compose_parameters_meta params
+    def compose_parameters(params, model_id)
+      if provider_name(model_id) == :anthropic
+        compose_parameters_anthropic(params)
+      elsif provider_name(model_id) == :cohere
+        compose_parameters_cohere(params)
+      elsif provider_name(model_id) == :ai21
+        params
+      elsif provider_name(model_id) == :meta
+        params
+      elsif provider_name(model_id) == :mistral
+        params
       end
     end
 
@@ -220,15 +227,17 @@ module Langchain::LLM
       end
     end
 
-    def parse_response(response)
-      if completion_provider == :anthropic
+    def parse_response(response, model_id)
+      if provider_name(model_id) == :anthropic
         Langchain::LLM::AnthropicResponse.new(JSON.parse(response.body.string))
-      elsif completion_provider == :cohere
+      elsif provider_name(model_id) == :cohere
         Langchain::LLM::CohereResponse.new(JSON.parse(response.body.string))
-      elsif completion_provider == :ai21
+      elsif provider_name(model_id) == :ai21
         Langchain::LLM::AI21Response.new(JSON.parse(response.body.string, symbolize_names: true))
-      elsif completion_provider == :meta
+      elsif provider_name(model_id) == :meta
         Langchain::LLM::AwsBedrockMetaResponse.new(JSON.parse(response.body.string))
+      elsif provider_name(model_id) == :mistral
+        Langchain::LLM::MistralAIResponse.new(JSON.parse(response.body.string))
       end
     end
 
@@ -276,61 +285,7 @@ module Langchain::LLM
     end
 
     def compose_parameters_anthropic(params)
-      default_params = @defaults.merge(params)
-
-      {
-        max_tokens_to_sample: default_params[:max_tokens_to_sample],
-        temperature: default_params[:temperature],
-        top_k: default_params[:top_k],
-        top_p: default_params[:top_p],
-        stop_sequences: default_params[:stop_sequences],
-        anthropic_version: default_params[:anthropic_version]
-      }
-    end
-
-    def compose_parameters_ai21(params)
-      default_params = @defaults.merge(params)
-
-      {
-        maxTokens: default_params[:max_tokens_to_sample],
-        temperature: default_params[:temperature],
-        topP: default_params[:top_p],
-        stopSequences: default_params[:stop_sequences],
-        countPenalty: {
-          scale: default_params[:count_penalty][:scale],
-          applyToWhitespaces: default_params[:count_penalty][:apply_to_whitespaces],
-          applyToPunctuations: default_params[:count_penalty][:apply_to_punctuations],
-          applyToNumbers: default_params[:count_penalty][:apply_to_numbers],
-          applyToStopwords: default_params[:count_penalty][:apply_to_stopwords],
-          applyToEmojis: default_params[:count_penalty][:apply_to_emojis]
-        },
-        presencePenalty: {
-          scale: default_params[:presence_penalty][:scale],
-          applyToWhitespaces: default_params[:presence_penalty][:apply_to_whitespaces],
-          applyToPunctuations: default_params[:presence_penalty][:apply_to_punctuations],
-          applyToNumbers: default_params[:presence_penalty][:apply_to_numbers],
-          applyToStopwords: default_params[:presence_penalty][:apply_to_stopwords],
-          applyToEmojis: default_params[:presence_penalty][:apply_to_emojis]
-        },
-        frequencyPenalty: {
-          scale: default_params[:frequency_penalty][:scale],
-          applyToWhitespaces: default_params[:frequency_penalty][:apply_to_whitespaces],
-          applyToPunctuations: default_params[:frequency_penalty][:apply_to_punctuations],
-          applyToNumbers: default_params[:frequency_penalty][:apply_to_numbers],
-          applyToStopwords: default_params[:frequency_penalty][:apply_to_stopwords],
-          applyToEmojis: default_params[:frequency_penalty][:apply_to_emojis]
-        }
-      }
-    end
-
-    def compose_parameters_meta(params)
-      default_params = @defaults.merge(params)
-
-      {
-        temperature: default_params[:temperature],
-        top_p: default_params[:top_p],
-        max_gen_len: default_params[:max_tokens_to_sample]
-      }
+      params.merge(anthropic_version: "bedrock-2023-05-31")
     end
 
     def response_from_chunks(chunks)
@@ -33,7 +33,7 @@ module Langchain::LLM
       )
       @defaults = DEFAULTS.merge(default_options)
       chat_parameters.update(
-        model: {default: @defaults[:chat_completion_model_name]},
+        model: {default: @defaults[:chat_model]},
         logprobs: {},
         top_logprobs: {},
         n: {default: @defaults[:n]},
@@ -34,7 +34,7 @@ module Langchain::LLM
       default_dimensions
     end
 
-    # Returns the number of vector dimensions used by DEFAULTS[:chat_completion_model_name]
+    # Returns the number of vector dimensions used by DEFAULTS[:chat_model]
     #
     # @return [Integer] Vector dimensions
     def default_dimensions
@@ -13,9 +13,9 @@ module Langchain::LLM
   class Cohere < Base
     DEFAULTS = {
       temperature: 0.0,
-      completion_model_name: "command",
-      chat_completion_model_name: "command-r-plus",
-      embeddings_model_name: "small",
+      completion_model: "command",
+      chat_model: "command-r-plus",
+      embedding_model: "small",
       dimensions: 1024,
       truncate: "START"
     }.freeze
@@ -26,7 +26,7 @@ module Langchain::LLM
       @client = ::Cohere::Client.new(api_key: api_key)
       @defaults = DEFAULTS.merge(default_options)
       chat_parameters.update(
-        model: {default: @defaults[:chat_completion_model_name]},
+        model: {default: @defaults[:chat_model]},
         temperature: {default: @defaults[:temperature]},
         response_format: {default: @defaults[:response_format]}
       )
@@ -48,10 +48,10 @@ module Langchain::LLM
     def embed(text:)
       response = client.embed(
         texts: [text],
-        model: @defaults[:embeddings_model_name]
+        model: @defaults[:embedding_model]
       )
 
-      Langchain::LLM::CohereResponse.new response, model: @defaults[:embeddings_model_name]
+      Langchain::LLM::CohereResponse.new response, model: @defaults[:embedding_model]
     end
 
     #
@@ -65,7 +65,7 @@ module Langchain::LLM
       default_params = {
         prompt: prompt,
         temperature: @defaults[:temperature],
-        model: @defaults[:completion_model_name],
+        model: @defaults[:completion_model],
         truncate: @defaults[:truncate]
       }
 
@@ -76,7 +76,7 @@ module Langchain::LLM
       default_params.merge!(params)
 
       response = client.generate(**default_params)
-      Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
+      Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model]
     end
 
     # Generate a chat completion for given messages
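The same key rename applies to Cohere. A sketch, assuming the initializer keeps its api_key:/default_options: keywords (the model name is illustrative):

    llm = Langchain::LLM::Cohere.new(
      api_key: ENV["COHERE_API_KEY"],
      default_options: {embedding_model: "embed-english-v3.0"}
    )
    llm.embed(text: "hello world")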
@@ -5,8 +5,8 @@ module Langchain::LLM
   # llm = Langchain::LLM::GoogleGemini.new(api_key: ENV['GOOGLE_GEMINI_API_KEY'])
   class GoogleGemini < Base
     DEFAULTS = {
-      chat_completion_model_name: "gemini-1.5-pro-latest",
-      embeddings_model_name: "text-embedding-004",
+      chat_model: "gemini-1.5-pro-latest",
+      embedding_model: "text-embedding-004",
       temperature: 0.0
     }
 
@@ -17,10 +17,10 @@ module Langchain::LLM
       @defaults = DEFAULTS.merge(default_options)
 
       chat_parameters.update(
-        model: {default: @defaults[:chat_completion_model_name]},
+        model: {default: @defaults[:chat_model]},
         temperature: {default: @defaults[:temperature]},
         generation_config: {default: nil},
-        safety_settings: {default: nil}
+        safety_settings: {default: @defaults[:safety_settings]}
       )
       chat_parameters.remap(
         messages: :contents,
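Because safety_settings now defaults from @defaults, a blanket policy can be set once at construction time. A sketch (the category/threshold strings follow the Gemini API enums and are illustrative):

    llm = Langchain::LLM::GoogleGemini.new(
      api_key: ENV["GOOGLE_GEMINI_API_KEY"],
      default_options: {
        safety_settings: [
          {category: "HARM_CATEGORY_DANGEROUS_CONTENT", threshold: "BLOCK_ONLY_HIGH"}
        ]
      }
    )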
@@ -72,9 +72,8 @@ module Langchain::LLM
 
     def embed(
       text:,
-      model: @defaults[:embeddings_model_name]
+      model: @defaults[:embedding_model]
     )
-
       params = {
         content: {
           parts: [
@@ -17,8 +17,8 @@ module Langchain::LLM
       top_p: 0.8,
       top_k: 40,
       dimensions: 768,
-      embeddings_model_name: "textembedding-gecko",
-      chat_completion_model_name: "gemini-1.0-pro"
+      embedding_model: "textembedding-gecko",
+      chat_model: "gemini-1.0-pro"
     }.freeze
 
     # Google Cloud has a project id and a specific region of deployment.
@@ -38,8 +38,9 @@ module Langchain::LLM
       @defaults = DEFAULTS.merge(default_options)
 
       chat_parameters.update(
-        model: {default: @defaults[:chat_completion_model_name]},
-        temperature: {default: @defaults[:temperature]}
+        model: {default: @defaults[:chat_model]},
+        temperature: {default: @defaults[:temperature]},
+        safety_settings: {default: @defaults[:safety_settings]}
       )
       chat_parameters.remap(
         messages: :contents,
@@ -57,7 +58,7 @@ module Langchain::LLM
     #
     def embed(
       text:,
-      model: @defaults[:embeddings_model_name]
+      model: @defaults[:embedding_model]
     )
       params = {instances: [{content: text}]}
 
@@ -12,7 +12,7 @@ module Langchain::LLM
   #
   class HuggingFace < Base
     DEFAULTS = {
-      embeddings_model_name: "sentence-transformers/all-MiniLM-L6-v2"
+      embedding_model: "sentence-transformers/all-MiniLM-L6-v2"
     }.freeze
 
     EMBEDDING_SIZES = {
@@ -36,7 +36,7 @@ module Langchain::LLM
     def default_dimensions
       # since Huggin Face can run multiple models, look it up or generate an embedding and return the size
       @default_dimensions ||= @defaults[:dimensions] ||
-        EMBEDDING_SIZES.fetch(@defaults[:embeddings_model_name].to_sym) do
+        EMBEDDING_SIZES.fetch(@defaults[:embedding_model].to_sym) do
          embed(text: "test").embedding.size
        end
    end
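With the renamed key, dimension lookup is unchanged in shape: an explicit dimensions option wins, then the EMBEDDING_SIZES table keyed by embedding_model, and only then a one-off test embedding. A rough sketch:

    llm = Langchain::LLM::HuggingFace.new(api_key: ENV["HUGGING_FACE_API_KEY"])
    llm.default_dimensions
    # => 384 for the default sentence-transformers/all-MiniLM-L6-v2 (via EMBEDDING_SIZES),
    #    otherwise computed once via embed(text: "test").embedding.size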
@@ -50,9 +50,9 @@ module Langchain::LLM
     def embed(text:)
       response = client.embedding(
         input: text,
-        model: @defaults[:embeddings_model_name]
+        model: @defaults[:embedding_model]
       )
-      Langchain::LLM::HuggingFaceResponse.new(response, model: @defaults[:embeddings_model_name])
+      Langchain::LLM::HuggingFaceResponse.new(response, model: @defaults[:embedding_model])
     end
   end
 end