langchainrb 0.18.0 → 0.19.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +30 -0
  3. data/README.md +4 -4
  4. data/lib/langchain/assistant/llm/adapter.rb +7 -6
  5. data/lib/langchain/assistant/llm/adapters/anthropic.rb +1 -3
  6. data/lib/langchain/assistant/llm/adapters/aws_bedrock_anthropic.rb +35 -0
  7. data/lib/langchain/assistant/llm/adapters/ollama.rb +1 -3
  8. data/lib/langchain/assistant/messages/anthropic_message.rb +89 -17
  9. data/lib/langchain/assistant/messages/base.rb +4 -0
  10. data/lib/langchain/assistant/messages/google_gemini_message.rb +62 -21
  11. data/lib/langchain/assistant/messages/mistral_ai_message.rb +69 -24
  12. data/lib/langchain/assistant/messages/ollama_message.rb +9 -5
  13. data/lib/langchain/assistant/messages/openai_message.rb +78 -26
  14. data/lib/langchain/assistant.rb +2 -1
  15. data/lib/langchain/llm/anthropic.rb +10 -10
  16. data/lib/langchain/llm/aws_bedrock.rb +75 -120
  17. data/lib/langchain/llm/azure.rb +1 -1
  18. data/lib/langchain/llm/base.rb +1 -1
  19. data/lib/langchain/llm/cohere.rb +8 -8
  20. data/lib/langchain/llm/google_gemini.rb +5 -6
  21. data/lib/langchain/llm/google_vertex_ai.rb +6 -5
  22. data/lib/langchain/llm/hugging_face.rb +4 -4
  23. data/lib/langchain/llm/mistral_ai.rb +4 -4
  24. data/lib/langchain/llm/ollama.rb +10 -8
  25. data/lib/langchain/llm/openai.rb +6 -5
  26. data/lib/langchain/llm/parameters/chat.rb +4 -1
  27. data/lib/langchain/llm/replicate.rb +6 -6
  28. data/lib/langchain/llm/response/ai21_response.rb +20 -0
  29. data/lib/langchain/tool_definition.rb +7 -0
  30. data/lib/langchain/utils/image_wrapper.rb +37 -0
  31. data/lib/langchain/version.rb +1 -1
  32. metadata +4 -2
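
Most of the LLM classes below share one theme: the default-option keys drop the "_name" suffix, so completion_model_name becomes completion_model, chat_completion_model_name becomes chat_model, and embeddings_model_name becomes embedding_model (Anthropic also renames max_tokens_to_sample to max_tokens). A minimal before/after sketch based on the option keys shown in the Anthropic hunks further down; the API key handling is illustrative:

    # 0.18.0
    llm = Langchain::LLM::Anthropic.new(
      api_key: ENV["ANTHROPIC_API_KEY"],
      default_options: {chat_completion_model_name: "claude-3-5-sonnet-20240620", max_tokens_to_sample: 512}
    )

    # 0.19.1
    llm = Langchain::LLM::Anthropic.new(
      api_key: ENV["ANTHROPIC_API_KEY"],
      default_options: {chat_model: "claude-3-5-sonnet-20240620", max_tokens: 512}
    )
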
data/lib/langchain/assistant/messages/openai_message.rb
@@ -50,32 +50,14 @@ module Langchain
         #
         # @return [Hash] The message as an OpenAI API-compatible hash
         def to_hash
-          {}.tap do |h|
-            h[:role] = role
-
-            if tool_calls.any?
-              h[:tool_calls] = tool_calls
-            else
-              h[:tool_call_id] = tool_call_id if tool_call_id
-
-              h[:content] = []
-
-              if content && !content.empty?
-                h[:content] << {
-                  type: "text",
-                  text: content
-                }
-              end
-
-              if image_url
-                h[:content] << {
-                  type: "image_url",
-                  image_url: {
-                    url: image_url
-                  }
-                }
-              end
-            end
+          if assistant?
+            assistant_hash
+          elsif system?
+            system_hash
+          elsif tool?
+            tool_hash
+          elsif user?
+            user_hash
           end
         end
 
@@ -99,6 +81,76 @@ module Langchain
         def tool?
           role == "tool"
         end
+
+        def user?
+          role == "user"
+        end
+
+        # Convert the message to an OpenAI API-compatible hash
+        # @return [Hash] The message as an OpenAI API-compatible hash, with the role as "assistant"
+        def assistant_hash
+          if tool_calls.any?
+            {
+              role: "assistant",
+              tool_calls: tool_calls
+            }
+          else
+            {
+              role: "assistant",
+              content: build_content_array
+            }
+          end
+        end
+
+        # Convert the message to an OpenAI API-compatible hash
+        # @return [Hash] The message as an OpenAI API-compatible hash, with the role as "system"
+        def system_hash
+          {
+            role: "system",
+            content: build_content_array
+          }
+        end
+
+        # Convert the message to an OpenAI API-compatible hash
+        # @return [Hash] The message as an OpenAI API-compatible hash, with the role as "tool"
+        def tool_hash
+          {
+            role: "tool",
+            tool_call_id: tool_call_id,
+            content: build_content_array
+          }
+        end
+
+        # Convert the message to an OpenAI API-compatible hash
+        # @return [Hash] The message as an OpenAI API-compatible hash, with the role as "user"
+        def user_hash
+          {
+            role: "user",
+            content: build_content_array
+          }
+        end
+
+        # Builds the content value for the message hash
+        # @return [Array<Hash>] An array of content hashes, with keys :type and :text or :image_url.
+        def build_content_array
+          content_details = []
+          if content && !content.empty?
+            content_details << {
+              type: "text",
+              text: content
+            }
+          end
+
+          if image_url
+            content_details << {
+              type: "image_url",
+              image_url: {
+                url: image_url
+              }
+            }
+          end
+          content_details
+        end
       end
     end
   end
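
With the refactor above, OpenAIMessage#to_hash dispatches on the message role instead of building one hash inline; the resulting shapes follow directly from the added helpers. Illustrative values, assuming a user message carrying text plus an image and a separate tool-result message:

    # user (or system, or assistant without tool calls): content is always an array built by build_content_array
    {
      role: "user",
      content: [
        {type: "text", text: "Describe this image"},
        {type: "image_url", image_url: {url: "https://example.com/cat.png"}}
      ]
    }

    # tool result: keeps its tool_call_id and wraps the output the same way
    {
      role: "tool",
      tool_call_id: "call_123",
      content: [{type: "text", text: "42"}]
    }
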

data/lib/langchain/assistant.rb
@@ -196,7 +196,7 @@ module Langchain
 
       if @llm_adapter.support_system_message?
         # TODO: Should we still set a system message even if @instructions is "" or nil?
-        replace_system_message!(content: new_instructions) if @instructions
+        replace_system_message!(content: new_instructions)
       end
     end
 
@@ -217,6 +217,7 @@ module Langchain
     # @return [Array<Langchain::Message>] The messages
     def replace_system_message!(content:)
       messages.delete_if(&:system?)
+      return if content.nil?
 
       message = build_message(role: "system", content: content)
       messages.unshift(message)
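
Taken together, these two assistant.rb hunks mean that clearing the instructions now also clears the system message rather than leaving a stale one behind. A sketch, assuming the first hunk sits inside the Assistant#instructions= setter (the new_instructions local suggests so) and that the LLM adapter supports system messages:

    assistant = Langchain::Assistant.new(llm: llm, instructions: "You are a helpful bot")
    assistant.instructions = nil  # removes the existing system message; since content is nil, no new one is added
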

data/lib/langchain/llm/anthropic.rb
@@ -13,16 +13,16 @@ module Langchain::LLM
   class Anthropic < Base
     DEFAULTS = {
       temperature: 0.0,
-      completion_model_name: "claude-2.1",
-      chat_completion_model_name: "claude-3-5-sonnet-20240620",
-      max_tokens_to_sample: 256
+      completion_model: "claude-2.1",
+      chat_model: "claude-3-5-sonnet-20240620",
+      max_tokens: 256
     }.freeze
 
     # Initialize an Anthropic LLM instance
     #
     # @param api_key [String] The API key to use
     # @param llm_options [Hash] Options to pass to the Anthropic client
-    # @param default_options [Hash] Default options to use on every call to LLM, e.g.: { temperature:, completion_model_name:, chat_completion_model_name:, max_tokens_to_sample: }
+    # @param default_options [Hash] Default options to use on every call to LLM, e.g.: { temperature:, completion_model:, chat_model:, max_tokens: }
     # @return [Langchain::LLM::Anthropic] Langchain::LLM::Anthropic instance
     def initialize(api_key:, llm_options: {}, default_options: {})
       depends_on "anthropic"
@@ -30,9 +30,9 @@ module Langchain::LLM
       @client = ::Anthropic::Client.new(access_token: api_key, **llm_options)
       @defaults = DEFAULTS.merge(default_options)
       chat_parameters.update(
-        model: {default: @defaults[:chat_completion_model_name]},
+        model: {default: @defaults[:chat_model]},
         temperature: {default: @defaults[:temperature]},
-        max_tokens: {default: @defaults[:max_tokens_to_sample]},
+        max_tokens: {default: @defaults[:max_tokens]},
         metadata: {},
         system: {}
       )
@@ -54,8 +54,8 @@ module Langchain::LLM
     # @return [Langchain::LLM::AnthropicResponse] The completion
     def complete(
       prompt:,
-      model: @defaults[:completion_model_name],
-      max_tokens_to_sample: @defaults[:max_tokens_to_sample],
+      model: @defaults[:completion_model],
+      max_tokens: @defaults[:max_tokens],
       stop_sequences: nil,
       temperature: @defaults[:temperature],
       top_p: nil,
@@ -64,12 +64,12 @@ module Langchain::LLM
       stream: nil
     )
       raise ArgumentError.new("model argument is required") if model.empty?
-      raise ArgumentError.new("max_tokens_to_sample argument is required") if max_tokens_to_sample.nil?
+      raise ArgumentError.new("max_tokens argument is required") if max_tokens.nil?
 
       parameters = {
         model: model,
         prompt: prompt,
-        max_tokens_to_sample: max_tokens_to_sample,
+        max_tokens_to_sample: max_tokens,
         temperature: temperature
       }
       parameters[:stop_sequences] = stop_sequences if stop_sequences
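
For callers of Anthropic#complete the keyword changes from max_tokens_to_sample: to max_tokens:, while the hunk above shows the request body still sends max_tokens_to_sample to the Anthropic completions endpoint. A small sketch (the API key handling is illustrative):

    llm = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
    llm.complete(prompt: "\n\nHuman: Hello\n\nAssistant:", max_tokens: 128)  # was max_tokens_to_sample: 128 in 0.18.0
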

data/lib/langchain/llm/aws_bedrock.rb
@@ -7,51 +7,40 @@ module Langchain::LLM
   # gem 'aws-sdk-bedrockruntime', '~> 1.1'
   #
   # Usage:
-  # llm = Langchain::LLM::AwsBedrock.new(llm_options: {})
+  # llm = Langchain::LLM::AwsBedrock.new(default_options: {})
   #
   class AwsBedrock < Base
     DEFAULTS = {
-      chat_completion_model_name: "anthropic.claude-v2",
-      completion_model_name: "anthropic.claude-v2",
-      embeddings_model_name: "amazon.titan-embed-text-v1",
+      chat_model: "anthropic.claude-3-5-sonnet-20240620-v1:0",
+      completion_model: "anthropic.claude-v2:1",
+      embedding_model: "amazon.titan-embed-text-v1",
       max_tokens_to_sample: 300,
       temperature: 1,
       top_k: 250,
       top_p: 0.999,
       stop_sequences: ["\n\nHuman:"],
-      anthropic_version: "bedrock-2023-05-31",
-      return_likelihoods: "NONE",
-      count_penalty: {
-        scale: 0,
-        apply_to_whitespaces: false,
-        apply_to_punctuations: false,
-        apply_to_numbers: false,
-        apply_to_stopwords: false,
-        apply_to_emojis: false
-      },
-      presence_penalty: {
-        scale: 0,
-        apply_to_whitespaces: false,
-        apply_to_punctuations: false,
-        apply_to_numbers: false,
-        apply_to_stopwords: false,
-        apply_to_emojis: false
-      },
-      frequency_penalty: {
-        scale: 0,
-        apply_to_whitespaces: false,
-        apply_to_punctuations: false,
-        apply_to_numbers: false,
-        apply_to_stopwords: false,
-        apply_to_emojis: false
-      }
+      return_likelihoods: "NONE"
     }.freeze
 
     attr_reader :client, :defaults
 
-    SUPPORTED_COMPLETION_PROVIDERS = %i[anthropic ai21 cohere meta].freeze
-    SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[anthropic].freeze
-    SUPPORTED_EMBEDDING_PROVIDERS = %i[amazon cohere].freeze
+    SUPPORTED_COMPLETION_PROVIDERS = %i[
+      anthropic
+      ai21
+      cohere
+      meta
+    ].freeze
+
+    SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[
+      anthropic
+      ai21
+      mistral
+    ].freeze
+
+    SUPPORTED_EMBEDDING_PROVIDERS = %i[
+      amazon
+      cohere
+    ].freeze
 
     def initialize(aws_client_options: {}, default_options: {})
       depends_on "aws-sdk-bedrockruntime", req: "aws-sdk-bedrockruntime"
@@ -60,12 +49,11 @@ module Langchain::LLM
       @defaults = DEFAULTS.merge(default_options)
 
       chat_parameters.update(
-        model: {default: @defaults[:chat_completion_model_name]},
+        model: {default: @defaults[:chat_model]},
         temperature: {},
         max_tokens: {default: @defaults[:max_tokens_to_sample]},
         metadata: {},
-        system: {},
-        anthropic_version: {default: "bedrock-2023-05-31"}
+        system: {}
       )
       chat_parameters.ignore(:n, :user)
       chat_parameters.remap(stop: :stop_sequences)
@@ -84,7 +72,7 @@ module Langchain::LLM
       parameters = compose_embedding_parameters params.merge(text:)
 
       response = client.invoke_model({
-        model_id: @defaults[:embeddings_model_name],
+        model_id: @defaults[:embedding_model],
         body: parameters.to_json,
         content_type: "application/json",
         accept: "application/json"
@@ -100,23 +88,25 @@ module Langchain::LLM
     # @param params extra parameters passed to Aws::BedrockRuntime::Client#invoke_model
     # @return [Langchain::LLM::AnthropicResponse], [Langchain::LLM::CohereResponse] or [Langchain::LLM::AI21Response] Response object
     #
-    def complete(prompt:, **params)
-      raise "Completion provider #{completion_provider} is not supported." unless SUPPORTED_COMPLETION_PROVIDERS.include?(completion_provider)
+    def complete(
+      prompt:,
+      model: @defaults[:completion_model],
+      **params
+    )
+      raise "Completion provider #{model} is not supported." unless SUPPORTED_COMPLETION_PROVIDERS.include?(provider_name(model))
 
-      raise "Model #{@defaults[:completion_model_name]} only supports #chat." if @defaults[:completion_model_name].include?("claude-3")
-
-      parameters = compose_parameters params
+      parameters = compose_parameters(params, model)
 
       parameters[:prompt] = wrap_prompt prompt
 
       response = client.invoke_model({
-        model_id: @defaults[:completion_model_name],
+        model_id: model,
         body: parameters.to_json,
         content_type: "application/json",
         accept: "application/json"
       })
 
-      parse_response response
+      parse_response(response, model)
     end
 
     # Generate a chat completion for a given prompt
@@ -126,7 +116,7 @@ module Langchain::LLM
     # @param [Hash] params unified chat parmeters from [Langchain::LLM::Parameters::Chat::SCHEMA]
     # @option params [Array<String>] :messages The messages to generate a completion for
     # @option params [String] :system The system prompt to provide instructions
-    # @option params [String] :model The model to use for completion defaults to @defaults[:chat_completion_model_name]
+    # @option params [String] :model The model to use for completion defaults to @defaults[:chat_model]
     # @option params [Integer] :max_tokens The maximum number of tokens to generate defaults to @defaults[:max_tokens_to_sample]
     # @option params [Array<String>] :stop The stop sequences to use for completion
     # @option params [Array<String>] :stop_sequences The stop sequences to use for completion
@@ -137,10 +127,11 @@ module Langchain::LLM
     # @return [Langchain::LLM::AnthropicResponse] Response object
     def chat(params = {}, &block)
       parameters = chat_parameters.to_params(params)
+      parameters = compose_parameters(parameters, parameters[:model])
 
-      raise ArgumentError.new("messages argument is required") if Array(parameters[:messages]).empty?
-
-      raise "Model #{parameters[:model]} does not support chat completions." unless Langchain::LLM::AwsBedrock::SUPPORTED_CHAT_COMPLETION_PROVIDERS.include?(completion_provider)
+      unless SUPPORTED_CHAT_COMPLETION_PROVIDERS.include?(provider_name(parameters[:model]))
+        raise "Chat provider #{parameters[:model]} is not supported."
+      end
 
       if block
         response_chunks = []
@@ -168,18 +159,32 @@ module Langchain::LLM
           accept: "application/json"
         })
 
-        parse_response response
+        parse_response(response, parameters[:model])
       end
     end
 
     private
 
+    def parse_model_id(model_id)
+      model_id
+        .gsub("us.", "") # Meta append "us." to their model ids
+        .split(".")
+    end
+
+    def provider_name(model_id)
+      parse_model_id(model_id).first.to_sym
+    end
+
+    def model_name(model_id)
+      parse_model_id(model_id).last
+    end
+
     def completion_provider
-      @defaults[:completion_model_name].split(".").first.to_sym
+      @defaults[:completion_model].split(".").first.to_sym
    end
 
     def embedding_provider
-      @defaults[:embeddings_model_name].split(".").first.to_sym
+      @defaults[:embedding_model].split(".").first.to_sym
     end
 
     def wrap_prompt(prompt)
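
The new parse_model_id/provider_name helpers mean the provider is derived from whichever model id is passed at call time, not from the configured default. Roughly what they compute (the first id comes from the new DEFAULTS; the Meta id is illustrative):

    "anthropic.claude-3-5-sonnet-20240620-v1:0".split(".").first.to_sym            # => :anthropic
    "us.meta.llama3-1-70b-instruct-v1:0".gsub("us.", "").split(".").first.to_sym   # => :meta
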
@@ -200,15 +205,17 @@ module Langchain::LLM
       end
     end
 
-    def compose_parameters(params)
-      if completion_provider == :anthropic
-        compose_parameters_anthropic params
-      elsif completion_provider == :cohere
-        compose_parameters_cohere params
-      elsif completion_provider == :ai21
-        compose_parameters_ai21 params
-      elsif completion_provider == :meta
-        compose_parameters_meta params
+    def compose_parameters(params, model_id)
+      if provider_name(model_id) == :anthropic
+        compose_parameters_anthropic(params)
+      elsif provider_name(model_id) == :cohere
+        compose_parameters_cohere(params)
+      elsif provider_name(model_id) == :ai21
+        params
+      elsif provider_name(model_id) == :meta
+        params
+      elsif provider_name(model_id) == :mistral
+        params
       end
     end
 
@@ -220,15 +227,17 @@ module Langchain::LLM
       end
     end
 
-    def parse_response(response)
-      if completion_provider == :anthropic
+    def parse_response(response, model_id)
+      if provider_name(model_id) == :anthropic
         Langchain::LLM::AnthropicResponse.new(JSON.parse(response.body.string))
-      elsif completion_provider == :cohere
+      elsif provider_name(model_id) == :cohere
         Langchain::LLM::CohereResponse.new(JSON.parse(response.body.string))
-      elsif completion_provider == :ai21
+      elsif provider_name(model_id) == :ai21
         Langchain::LLM::AI21Response.new(JSON.parse(response.body.string, symbolize_names: true))
-      elsif completion_provider == :meta
+      elsif provider_name(model_id) == :meta
         Langchain::LLM::AwsBedrockMetaResponse.new(JSON.parse(response.body.string))
+      elsif provider_name(model_id) == :mistral
+        Langchain::LLM::MistralAIResponse.new(JSON.parse(response.body.string))
       end
     end
 
@@ -276,61 +285,7 @@ module Langchain::LLM
     end
 
     def compose_parameters_anthropic(params)
-      default_params = @defaults.merge(params)
-
-      {
-        max_tokens_to_sample: default_params[:max_tokens_to_sample],
-        temperature: default_params[:temperature],
-        top_k: default_params[:top_k],
-        top_p: default_params[:top_p],
-        stop_sequences: default_params[:stop_sequences],
-        anthropic_version: default_params[:anthropic_version]
-      }
-    end
-
-    def compose_parameters_ai21(params)
-      default_params = @defaults.merge(params)
-
-      {
-        maxTokens: default_params[:max_tokens_to_sample],
-        temperature: default_params[:temperature],
-        topP: default_params[:top_p],
-        stopSequences: default_params[:stop_sequences],
-        countPenalty: {
-          scale: default_params[:count_penalty][:scale],
-          applyToWhitespaces: default_params[:count_penalty][:apply_to_whitespaces],
-          applyToPunctuations: default_params[:count_penalty][:apply_to_punctuations],
-          applyToNumbers: default_params[:count_penalty][:apply_to_numbers],
-          applyToStopwords: default_params[:count_penalty][:apply_to_stopwords],
-          applyToEmojis: default_params[:count_penalty][:apply_to_emojis]
-        },
-        presencePenalty: {
-          scale: default_params[:presence_penalty][:scale],
-          applyToWhitespaces: default_params[:presence_penalty][:apply_to_whitespaces],
-          applyToPunctuations: default_params[:presence_penalty][:apply_to_punctuations],
-          applyToNumbers: default_params[:presence_penalty][:apply_to_numbers],
-          applyToStopwords: default_params[:presence_penalty][:apply_to_stopwords],
-          applyToEmojis: default_params[:presence_penalty][:apply_to_emojis]
-        },
-        frequencyPenalty: {
-          scale: default_params[:frequency_penalty][:scale],
-          applyToWhitespaces: default_params[:frequency_penalty][:apply_to_whitespaces],
-          applyToPunctuations: default_params[:frequency_penalty][:apply_to_punctuations],
-          applyToNumbers: default_params[:frequency_penalty][:apply_to_numbers],
-          applyToStopwords: default_params[:frequency_penalty][:apply_to_stopwords],
-          applyToEmojis: default_params[:frequency_penalty][:apply_to_emojis]
-        }
-      }
-    end
-
-    def compose_parameters_meta(params)
-      default_params = @defaults.merge(params)
-
-      {
-        temperature: default_params[:temperature],
-        top_p: default_params[:top_p],
-        max_gen_len: default_params[:max_tokens_to_sample]
-      }
+      params.merge(anthropic_version: "bedrock-2023-05-31")
     end
 
     def response_from_chunks(chunks)
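
Net effect for AwsBedrock users: complete takes the model id as an argument and infers the provider from it, chat now accepts AI21 and Mistral models in addition to Anthropic, and the AI21/Meta penalty plumbing is gone in favor of passing params straight through. A hedged usage sketch (model ids are illustrative and AWS credentials are assumed to come from the environment):

    llm = Langchain::LLM::AwsBedrock.new
    llm.complete(prompt: "Hi", model: "cohere.command-text-v14")   # routed via provider_name(model)
    llm.chat(messages: [{role: "user", content: "Hi"}])            # uses the chat_model default, now Claude 3.5 Sonnet
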

data/lib/langchain/llm/azure.rb
@@ -33,7 +33,7 @@ module Langchain::LLM
       )
       @defaults = DEFAULTS.merge(default_options)
       chat_parameters.update(
-        model: {default: @defaults[:chat_completion_model_name]},
+        model: {default: @defaults[:chat_model]},
         logprobs: {},
         top_logprobs: {},
         n: {default: @defaults[:n]},

data/lib/langchain/llm/base.rb
@@ -34,7 +34,7 @@ module Langchain::LLM
       default_dimensions
     end
 
-    # Returns the number of vector dimensions used by DEFAULTS[:chat_completion_model_name]
+    # Returns the number of vector dimensions used by DEFAULTS[:chat_model]
     #
     # @return [Integer] Vector dimensions
     def default_dimensions

data/lib/langchain/llm/cohere.rb
@@ -13,9 +13,9 @@ module Langchain::LLM
   class Cohere < Base
     DEFAULTS = {
       temperature: 0.0,
-      completion_model_name: "command",
-      chat_completion_model_name: "command-r-plus",
-      embeddings_model_name: "small",
+      completion_model: "command",
+      chat_model: "command-r-plus",
+      embedding_model: "small",
       dimensions: 1024,
       truncate: "START"
     }.freeze
@@ -26,7 +26,7 @@ module Langchain::LLM
       @client = ::Cohere::Client.new(api_key: api_key)
       @defaults = DEFAULTS.merge(default_options)
       chat_parameters.update(
-        model: {default: @defaults[:chat_completion_model_name]},
+        model: {default: @defaults[:chat_model]},
         temperature: {default: @defaults[:temperature]},
         response_format: {default: @defaults[:response_format]}
       )
@@ -48,10 +48,10 @@ module Langchain::LLM
     def embed(text:)
       response = client.embed(
         texts: [text],
-        model: @defaults[:embeddings_model_name]
+        model: @defaults[:embedding_model]
       )
 
-      Langchain::LLM::CohereResponse.new response, model: @defaults[:embeddings_model_name]
+      Langchain::LLM::CohereResponse.new response, model: @defaults[:embedding_model]
     end
 
     #
@@ -65,7 +65,7 @@ module Langchain::LLM
       default_params = {
         prompt: prompt,
         temperature: @defaults[:temperature],
-        model: @defaults[:completion_model_name],
+        model: @defaults[:completion_model],
         truncate: @defaults[:truncate]
       }
 
@@ -76,7 +76,7 @@ module Langchain::LLM
       default_params.merge!(params)
 
       response = client.generate(**default_params)
-      Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
+      Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model]
     end
 
     # Generate a chat completion for given messages

data/lib/langchain/llm/google_gemini.rb
@@ -5,8 +5,8 @@ module Langchain::LLM
   # llm = Langchain::LLM::GoogleGemini.new(api_key: ENV['GOOGLE_GEMINI_API_KEY'])
   class GoogleGemini < Base
     DEFAULTS = {
-      chat_completion_model_name: "gemini-1.5-pro-latest",
-      embeddings_model_name: "text-embedding-004",
+      chat_model: "gemini-1.5-pro-latest",
+      embedding_model: "text-embedding-004",
       temperature: 0.0
     }
 
@@ -17,10 +17,10 @@ module Langchain::LLM
       @defaults = DEFAULTS.merge(default_options)
 
       chat_parameters.update(
-        model: {default: @defaults[:chat_completion_model_name]},
+        model: {default: @defaults[:chat_model]},
         temperature: {default: @defaults[:temperature]},
         generation_config: {default: nil},
-        safety_settings: {default: nil}
+        safety_settings: {default: @defaults[:safety_settings]}
       )
       chat_parameters.remap(
         messages: :contents,
@@ -72,9 +72,8 @@ module Langchain::LLM
 
     def embed(
       text:,
-      model: @defaults[:embeddings_model_name]
+      model: @defaults[:embedding_model]
     )
-
       params = {
         content: {
           parts: [
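
Because safety_settings now defaults to @defaults[:safety_settings], a value supplied through default_options is picked up as the default on every chat call (previously the default was hard-coded to nil). A sketch with an illustrative Gemini safety setting; the same change appears in GoogleVertexAI below:

    llm = Langchain::LLM::GoogleGemini.new(
      api_key: ENV["GOOGLE_GEMINI_API_KEY"],
      default_options: {
        safety_settings: [{category: "HARM_CATEGORY_DANGEROUS_CONTENT", threshold: "BLOCK_ONLY_HIGH"}]
      }
    )
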

data/lib/langchain/llm/google_vertex_ai.rb
@@ -17,8 +17,8 @@ module Langchain::LLM
       top_p: 0.8,
       top_k: 40,
       dimensions: 768,
-      embeddings_model_name: "textembedding-gecko",
-      chat_completion_model_name: "gemini-1.0-pro"
+      embedding_model: "textembedding-gecko",
+      chat_model: "gemini-1.0-pro"
     }.freeze
 
     # Google Cloud has a project id and a specific region of deployment.
@@ -38,8 +38,9 @@ module Langchain::LLM
       @defaults = DEFAULTS.merge(default_options)
 
       chat_parameters.update(
-        model: {default: @defaults[:chat_completion_model_name]},
-        temperature: {default: @defaults[:temperature]}
+        model: {default: @defaults[:chat_model]},
+        temperature: {default: @defaults[:temperature]},
+        safety_settings: {default: @defaults[:safety_settings]}
       )
       chat_parameters.remap(
         messages: :contents,
@@ -57,7 +58,7 @@ module Langchain::LLM
     #
     def embed(
       text:,
-      model: @defaults[:embeddings_model_name]
+      model: @defaults[:embedding_model]
     )
       params = {instances: [{content: text}]}
 

data/lib/langchain/llm/hugging_face.rb
@@ -12,7 +12,7 @@ module Langchain::LLM
   #
   class HuggingFace < Base
     DEFAULTS = {
-      embeddings_model_name: "sentence-transformers/all-MiniLM-L6-v2"
+      embedding_model: "sentence-transformers/all-MiniLM-L6-v2"
     }.freeze
 
     EMBEDDING_SIZES = {
@@ -36,7 +36,7 @@ module Langchain::LLM
     def default_dimensions
       # since Huggin Face can run multiple models, look it up or generate an embedding and return the size
       @default_dimensions ||= @defaults[:dimensions] ||
-        EMBEDDING_SIZES.fetch(@defaults[:embeddings_model_name].to_sym) do
+        EMBEDDING_SIZES.fetch(@defaults[:embedding_model].to_sym) do
          embed(text: "test").embedding.size
        end
     end
@@ -50,9 +50,9 @@ module Langchain::LLM
     def embed(text:)
       response = client.embedding(
         input: text,
-        model: @defaults[:embeddings_model_name]
+        model: @defaults[:embedding_model]
       )
-      Langchain::LLM::HuggingFaceResponse.new(response, model: @defaults[:embeddings_model_name])
+      Langchain::LLM::HuggingFaceResponse.new(response, model: @defaults[:embedding_model])
     end
   end
 end