langchainrb 0.12.0 → 0.13.0

Files changed (42)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -0
  3. data/README.md +3 -2
  4. data/lib/langchain/assistants/assistant.rb +75 -20
  5. data/lib/langchain/assistants/messages/base.rb +16 -0
  6. data/lib/langchain/assistants/messages/google_gemini_message.rb +90 -0
  7. data/lib/langchain/assistants/messages/openai_message.rb +74 -0
  8. data/lib/langchain/assistants/thread.rb +5 -5
  9. data/lib/langchain/llm/anthropic.rb +27 -49
  10. data/lib/langchain/llm/aws_bedrock.rb +30 -34
  11. data/lib/langchain/llm/azure.rb +6 -0
  12. data/lib/langchain/llm/base.rb +20 -1
  13. data/lib/langchain/llm/cohere.rb +38 -6
  14. data/lib/langchain/llm/google_gemini.rb +67 -0
  15. data/lib/langchain/llm/google_vertex_ai.rb +68 -112
  16. data/lib/langchain/llm/mistral_ai.rb +10 -19
  17. data/lib/langchain/llm/ollama.rb +23 -27
  18. data/lib/langchain/llm/openai.rb +20 -48
  19. data/lib/langchain/llm/parameters/chat.rb +51 -0
  20. data/lib/langchain/llm/response/base_response.rb +2 -2
  21. data/lib/langchain/llm/response/cohere_response.rb +16 -0
  22. data/lib/langchain/llm/response/google_gemini_response.rb +45 -0
  23. data/lib/langchain/llm/response/openai_response.rb +5 -1
  24. data/lib/langchain/llm/unified_parameters.rb +98 -0
  25. data/lib/langchain/loader.rb +6 -0
  26. data/lib/langchain/tool/base.rb +16 -6
  27. data/lib/langchain/tool/calculator/calculator.json +1 -1
  28. data/lib/langchain/tool/database/database.json +3 -3
  29. data/lib/langchain/tool/file_system/file_system.json +3 -3
  30. data/lib/langchain/tool/news_retriever/news_retriever.json +121 -0
  31. data/lib/langchain/tool/news_retriever/news_retriever.rb +132 -0
  32. data/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json +1 -1
  33. data/lib/langchain/tool/vectorsearch/vectorsearch.json +1 -1
  34. data/lib/langchain/tool/weather/weather.json +1 -1
  35. data/lib/langchain/tool/wikipedia/wikipedia.json +1 -1
  36. data/lib/langchain/tool/wikipedia/wikipedia.rb +2 -2
  37. data/lib/langchain/utils/token_length/openai_validator.rb +6 -1
  38. data/lib/langchain/version.rb +1 -1
  39. data/lib/langchain.rb +3 -0
  40. metadata +22 -15
  41. data/lib/langchain/assistants/message.rb +0 -58
  42. data/lib/langchain/llm/response/google_vertex_ai_response.rb +0 -33
data/lib/langchain/llm/base.rb

@@ -11,7 +11,8 @@ module Langchain::LLM
   # - {Langchain::LLM::Azure}
   # - {Langchain::LLM::Cohere}
   # - {Langchain::LLM::GooglePalm}
-  # - {Langchain::LLM::GoogleVertexAi}
+  # - {Langchain::LLM::GoogleVertexAI}
+  # - {Langchain::LLM::GoogleGemini}
   # - {Langchain::LLM::HuggingFace}
   # - {Langchain::LLM::LlamaCpp}
   # - {Langchain::LLM::OpenAI}
@@ -24,6 +25,15 @@ module Langchain::LLM
     # A client for communicating with the LLM
     attr_reader :client
 
+    # Ensuring backward compatibility after https://github.com/patterns-ai-core/langchainrb/pull/586
+    # TODO: Delete this method later
+    def default_dimension
+      default_dimensions
+    end
+
+    # Returns the number of vector dimensions used by DEFAULTS[:chat_completion_model_name]
+    #
+    # @return [Integer] Vector dimensions
     def default_dimensions
       self.class.const_get(:DEFAULTS).dig(:dimensions)
     end
@@ -61,5 +71,14 @@ module Langchain::LLM
     def summarize(...)
       raise NotImplementedError, "#{self.class.name} does not support summarization"
     end
+
+    #
+    # Returns an instance of Langchain::LLM::Parameters::Chat
+    #
+    def chat_parameters(params = {})
+      @chat_parameters ||= Langchain::LLM::Parameters::Chat.new(
+        parameters: params
+      )
+    end
   end
 end
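
Note: taken together, these base-class changes give every LLM subclass a single memoized parameter schema, plus a deprecated default_dimension alias. A minimal sketch of the intended subclass pattern (MyLLM and its DEFAULTS are hypothetical, used only to illustrate the methods added above):

module Langchain::LLM
  class MyLLM < Base
    # Hypothetical defaults; only :dimensions is read by Base#default_dimensions.
    DEFAULTS = {chat_completion_model_name: "my-model", temperature: 0.0, dimensions: 512}.freeze

    def initialize(default_options: {})
      @defaults = DEFAULTS.merge(default_options)
      # Base#chat_parameters memoizes one Parameters::Chat, so update/remap
      # calls made here shape every later chat(params) call.
      chat_parameters.update(model: {default: @defaults[:chat_completion_model_name]})
    end
  end
end

llm = Langchain::LLM::MyLLM.new
llm.default_dimensions # => 512
llm.default_dimension  # => 512 (backward-compatible alias slated for deletion)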
data/lib/langchain/llm/cohere.rb

@@ -8,22 +8,34 @@ module Langchain::LLM
   # gem "cohere-ruby", "~> 0.9.6"
   #
   # Usage:
-  #     cohere = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])
+  #     llm = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])
   #
   class Cohere < Base
     DEFAULTS = {
       temperature: 0.0,
       completion_model_name: "command",
+      chat_completion_model_name: "command-r-plus",
       embeddings_model_name: "small",
       dimensions: 1024,
       truncate: "START"
     }.freeze
 
-    def initialize(api_key, default_options = {})
+    def initialize(api_key:, default_options: {})
       depends_on "cohere-ruby", req: "cohere"
 
-      @client = ::Cohere::Client.new(api_key)
+      @client = ::Cohere::Client.new(api_key: api_key)
       @defaults = DEFAULTS.merge(default_options)
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        temperature: {default: @defaults[:temperature]}
+      )
+      chat_parameters.remap(
+        system: :preamble,
+        messages: :chat_history,
+        stop: :stop_sequences,
+        top_k: :k,
+        top_p: :p
+      )
     end
 
     #
@@ -68,9 +80,29 @@ module Langchain::LLM
       Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
     end
 
-    # TODO: Implement chat method: https://github.com/andreibondarev/cohere-ruby/issues/11
-    # def chat
-    # end
+    # Generate a chat completion for given messages
+    #
+    # @param [Hash] params unified chat parameters from [Langchain::LLM::Parameters::Chat::SCHEMA]
+    # @option params [Array<String>] :messages Input messages
+    # @option params [String] :model The model that will complete your prompt
+    # @option params [Integer] :max_tokens Maximum number of tokens to generate before stopping
+    # @option params [Array<String>] :stop Custom text sequences that will cause the model to stop generating
+    # @option params [Boolean] :stream Whether to incrementally stream the response using server-sent events
+    # @option params [String] :system System prompt
+    # @option params [Float] :temperature Amount of randomness injected into the response
+    # @option params [Array<String>] :tools Definitions of tools that the model may use
+    # @option params [Integer] :top_k Only sample from the top K options for each subsequent token
+    # @option params [Float] :top_p Use nucleus sampling.
+    # @return [Langchain::LLM::CohereResponse] The chat completion
+    def chat(params = {})
+      raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?
+
+      parameters = chat_parameters.to_params(params)
+
+      response = client.chat(**parameters)
+
+      Langchain::LLM::CohereResponse.new(response)
+    end
 
     # Generate a summary in English for a given text
     #
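
Note: with the keyword constructor and the new chat method, usage looks roughly like this. A sketch only: the message-hash shape follows cohere-ruby's chat_history conventions, and CohereResponse#chat_completion is assumed from the updated response class (cohere_response.rb, +16 in this release):

llm = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])

# :system is remapped to Cohere's :preamble, :stop to :stop_sequences,
# and :top_k/:top_p to :k/:p before reaching the client.
response = llm.chat(
  messages: [{role: "USER", message: "What is Ruby?"}],
  system: "Answer in one sentence.",
  temperature: 0.3
)
response.chat_completion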
data/lib/langchain/llm/google_gemini.rb (new file)

@@ -0,0 +1,67 @@
+# frozen_string_literal: true
+
+module Langchain::LLM
+  # Usage:
+  #     llm = Langchain::LLM::GoogleGemini.new(api_key: ENV['GOOGLE_GEMINI_API_KEY'])
+  class GoogleGemini < Base
+    DEFAULTS = {
+      chat_completion_model_name: "gemini-1.5-pro-latest",
+      temperature: 0.0
+    }
+
+    attr_reader :defaults, :api_key
+
+    def initialize(api_key:, default_options: {})
+      @api_key = api_key
+      @defaults = DEFAULTS.merge(default_options)
+
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        temperature: {default: @defaults[:temperature]}
+      )
+      chat_parameters.remap(
+        messages: :contents,
+        system: :system_instruction,
+        tool_choice: :tool_config
+      )
+    end
+
+    # Generate a chat completion for a given prompt
+    #
+    # @param messages [Array<Hash>] List of messages comprising the conversation so far
+    # @param model [String] The model to use
+    # @param tools [Array<Hash>] A list of Tools the model may use to generate the next response
+    # @param tool_choice [String] Specifies the mode in which function calling should execute. If unspecified, the default value will be set to AUTO. Possible values: AUTO, ANY, NONE
+    # @param system [String] Developer set system instruction
+    def chat(params = {})
+      params[:system] = {parts: [{text: params[:system]}]} if params[:system]
+      params[:tools] = {function_declarations: params[:tools]} if params[:tools]
+      params[:tool_choice] = {function_calling_config: {mode: params[:tool_choice].upcase}} if params[:tool_choice]
+
+      raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?
+
+      parameters = chat_parameters.to_params(params)
+      parameters[:generation_config] = {temperature: parameters.delete(:temperature)} if parameters[:temperature]
+
+      uri = URI("https://generativelanguage.googleapis.com/v1beta/models/#{parameters[:model]}:generateContent?key=#{api_key}")
+
+      request = Net::HTTP::Post.new(uri)
+      request.content_type = "application/json"
+      request.body = parameters.to_json
+
+      response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
+        http.request(request)
+      end
+
+      parsed_response = JSON.parse(response.body)
+
+      wrapped_response = Langchain::LLM::GoogleGeminiResponse.new(parsed_response, model: parameters[:model])
+
+      if wrapped_response.chat_completion || Array(wrapped_response.tool_calls).any?
+        wrapped_response
+      else
+        raise StandardError.new(response)
+      end
+    end
+  end
+end
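
Note: since the class calls the Generative Language API directly over Net::HTTP, an API key is all that's needed. A short usage sketch (the prompt is illustrative; the message shape follows Gemini's contents format, which the messages: :contents remap targets):

llm = Langchain::LLM::GoogleGemini.new(api_key: ENV["GOOGLE_GEMINI_API_KEY"])

response = llm.chat(
  messages: [{role: "user", parts: [{text: "What is the capital of France?"}]}],
  system: "Answer with a single word." # wrapped into {parts: [{text: ...}]} by #chat
)
response.chat_completion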
data/lib/langchain/llm/google_vertex_ai.rb

@@ -2,150 +2,106 @@
 
 module Langchain::LLM
   #
-  # Wrapper around the Google Vertex AI APIs: https://cloud.google.com/vertex-ai?hl=en
+  # Wrapper around the Google Vertex AI APIs: https://cloud.google.com/vertex-ai
   #
   # Gem requirements:
-  #     gem "google-apis-aiplatform_v1", "~> 0.7"
+  #     gem "googleauth"
   #
   # Usage:
-  #     google_palm = Langchain::LLM::GoogleVertexAi.new(project_id: ENV["GOOGLE_VERTEX_AI_PROJECT_ID"])
+  #     llm = Langchain::LLM::GoogleVertexAI.new(project_id: ENV["GOOGLE_VERTEX_AI_PROJECT_ID"], region: "us-central1")
   #
-  class GoogleVertexAi < Base
+  class GoogleVertexAI < Base
     DEFAULTS = {
-      temperature: 0.1, # 0.1 is the default in the API, quite low ("grounded")
+      temperature: 0.1,
       max_output_tokens: 1000,
       top_p: 0.8,
       top_k: 40,
       dimensions: 768,
-      completion_model_name: "text-bison", # Optional: tect-bison@001
-      embeddings_model_name: "textembedding-gecko"
+      embeddings_model_name: "textembedding-gecko",
+      chat_completion_model_name: "gemini-1.0-pro"
     }.freeze
 
-    # TODO: Implement token length validation
-    # LENGTH_VALIDATOR = Langchain::Utils::TokenLength::...
-
     # Google Cloud has a project id and a specific region of deployment.
     # For GenAI-related things, a safe choice is us-central1.
-    attr_reader :project_id, :client, :region
-
-    def initialize(project_id:, default_options: {})
-      depends_on "google-apis-aiplatform_v1"
+    attr_reader :defaults, :url, :authorizer
 
-      @project_id = project_id
-      @region = default_options.fetch :region, "us-central1"
+    def initialize(project_id:, region:, default_options: {})
+      depends_on "googleauth"
 
-      @client = Google::Apis::AiplatformV1::AiplatformService.new
-
-      # TODO: Adapt for other regions; Pass it in via the constructor
-      # For the moment only us-central1 available so no big deal.
-      @client.root_url = "https://#{@region}-aiplatform.googleapis.com/"
-      @client.authorization = Google::Auth.get_application_default
+      @authorizer = ::Google::Auth.get_application_default
+      proj_id = project_id || @authorizer.project_id || @authorizer.quota_project_id
+      @url = "https://#{region}-aiplatform.googleapis.com/v1/projects/#{proj_id}/locations/#{region}/publishers/google/models/"
 
       @defaults = DEFAULTS.merge(default_options)
+
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        temperature: {default: @defaults[:temperature]}
+      )
+      chat_parameters.remap(
+        messages: :contents,
+        system: :system_instruction,
+        tool_choice: :tool_config
+      )
     end
 
     #
     # Generate an embedding for a given text
     #
     # @param text [String] The text to generate an embedding for
-    # @return [Langchain::LLM::GoogleVertexAiResponse] Response object
+    # @param model [String] ID of the model to use
+    # @return [Langchain::LLM::GoogleGeminiResponse] Response object
     #
-    def embed(text:)
-      content = [{content: text}]
-      request = Google::Apis::AiplatformV1::GoogleCloudAiplatformV1PredictRequest.new(instances: content)
-
-      api_path = "projects/#{@project_id}/locations/us-central1/publishers/google/models/#{@defaults[:embeddings_model_name]}"
-
-      # puts("api_path: #{api_path}")
-
-      response = client.predict_project_location_publisher_model(api_path, request)
+    def embed(
+      text:,
+      model: @defaults[:embeddings_model_name]
+    )
+      params = {instances: [{content: text}]}
+
+      response = HTTParty.post(
+        "#{url}#{model}:predict",
+        body: params.to_json,
+        headers: {
+          "Content-Type" => "application/json",
+          "Authorization" => "Bearer #{@authorizer.fetch_access_token!["access_token"]}"
+        }
+      )
 
-      Langchain::LLM::GoogleVertexAiResponse.new(response.to_h, model: @defaults[:embeddings_model_name])
+      Langchain::LLM::GoogleGeminiResponse.new(response, model: model)
     end
 
+    # Generate a chat completion for given messages
     #
-    # Generate a completion for a given prompt
-    #
-    # @param prompt [String] The prompt to generate a completion for
-    # @param params extra parameters passed to GooglePalmAPI::Client#generate_text
-    # @return [Langchain::LLM::GooglePalmResponse] Response object
-    #
-    def complete(prompt:, **params)
-      default_params = {
-        prompt: prompt,
-        temperature: @defaults[:temperature],
-        top_k: @defaults[:top_k],
-        top_p: @defaults[:top_p],
-        max_output_tokens: @defaults[:max_output_tokens],
-        model: @defaults[:completion_model_name]
-      }
-
-      if params[:stop_sequences]
-        default_params[:stop_sequences] = params.delete(:stop_sequences)
+    # @param messages [Array<Hash>] Input messages
+    # @param model [String] The model that will complete your prompt
+    # @param tools [Array<Hash>] The tools to use
+    # @param tool_choice [String] The tool choice to use
+    # @param system [String] The system instruction to use
+    # @return [Langchain::LLM::GoogleGeminiResponse] Response object
+    def chat(params = {})
+      params[:system] = {parts: [{text: params[:system]}]} if params[:system]
+      params[:tools] = {function_declarations: params[:tools]} if params[:tools]
+      params[:tool_choice] = {function_calling_config: {mode: params[:tool_choice].upcase}} if params[:tool_choice]
+
+      raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?
+
+      parameters = chat_parameters.to_params(params)
+      parameters[:generation_config] = {temperature: parameters.delete(:temperature)} if parameters[:temperature]
+
+      uri = URI("#{url}#{parameters[:model]}:generateContent")
+
+      request = Net::HTTP::Post.new(uri)
+      request.content_type = "application/json"
+      request["Authorization"] = "Bearer #{@authorizer.fetch_access_token!["access_token"]}"
+      request.body = parameters.to_json
+
+      response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
+        http.request(request)
       end
 
-      if params[:max_output_tokens]
-        default_params[:max_output_tokens] = params.delete(:max_output_tokens)
-      end
-
-      # to be tested
-      temperature = params.delete(:temperature) || @defaults[:temperature]
-      max_output_tokens = default_params.fetch(:max_output_tokens, @defaults[:max_output_tokens])
-
-      default_params.merge!(params)
-
-      # response = client.generate_text(**default_params)
-      request = Google::Apis::AiplatformV1::GoogleCloudAiplatformV1PredictRequest.new \
-        instances: [{
-          prompt: prompt # key used to be :content, changed to :prompt
-        }],
-        parameters: {
-          temperature: temperature,
-          maxOutputTokens: max_output_tokens,
-          topP: 0.8,
-          topK: 40
-        }
-
-      response = client.predict_project_location_publisher_model \
-        "projects/#{project_id}/locations/us-central1/publishers/google/models/#{@defaults[:completion_model_name]}",
-        request
-
-      Langchain::LLM::GoogleVertexAiResponse.new(response, model: default_params[:model])
-    end
+      parsed_response = JSON.parse(response.body)
 
-    #
-    # Generate a summarization for a given text
-    #
-    # @param text [String] The text to generate a summarization for
-    # @return [String] The summarization
-    #
-    # TODO(ricc): add params for Temp, topP, topK, MaxTokens and have it default to these 4 values.
-    def summarize(text:)
-      prompt_template = Langchain::Prompt.load_from_path(
-        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
-      )
-      prompt = prompt_template.format(text: text)
-
-      complete(
-        prompt: prompt,
-        # For best temperature, topP, topK, MaxTokens for summarization: see
-        # https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-summarization
-        temperature: 0.2,
-        top_p: 0.95,
-        top_k: 40,
-        # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
-        max_output_tokens: 256
-      )
+      Langchain::LLM::GoogleGeminiResponse.new(parsed_response, model: parameters[:model])
     end
-
-    # def chat(...)
-    #   https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chat
-    #   Chat params: https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chat
-    # end
   end
 end
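
Note: the rewrite swaps the google-apis-aiplatform_v1 client for plain HTTP authorized via googleauth application-default credentials, and replaces complete/summarize with a Gemini-style chat. A usage sketch (assumes ADC is configured, e.g. via gcloud auth application-default login):

llm = Langchain::LLM::GoogleVertexAI.new(
  project_id: ENV["GOOGLE_VERTEX_AI_PROJECT_ID"],
  region: "us-central1"
)

# POSTs to .../textembedding-gecko:predict via HTTParty
embedding_response = llm.embed(text: "Hello world")

# POSTs to .../gemini-1.0-pro:generateContent via Net::HTTP
response = llm.chat(messages: [{role: "user", parts: [{text: "Hi"}]}])
response.chat_completion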
data/lib/langchain/llm/mistral_ai.rb

@@ -23,28 +23,19 @@ module Langchain::LLM
       )
 
       @defaults = DEFAULTS.merge(default_options)
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        n: {default: @defaults[:n]},
+        safe_prompt: {}
+      )
+      chat_parameters.remap(seed: :random_seed)
+      chat_parameters.ignore(:n, :top_k)
     end
 
-    def chat(
-      messages:,
-      model: defaults[:chat_completion_model_name],
-      temperature: nil,
-      top_p: nil,
-      max_tokens: nil,
-      safe_prompt: nil,
-      random_seed: nil
-    )
-      params = {
-        messages: messages,
-        model: model
-      }
-      params[:temperature] = temperature if temperature
-      params[:top_p] = top_p if top_p
-      params[:max_tokens] = max_tokens if max_tokens
-      params[:safe_prompt] = safe_prompt if safe_prompt
-      params[:random_seed] = random_seed if random_seed
+    def chat(params = {})
+      parameters = chat_parameters.to_params(params)
 
-      response = client.chat_completions(params)
+      response = client.chat_completions(parameters)
 
       Langchain::LLM::MistralAIResponse.new(response.to_h)
     end
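
Note: callers now pass one params hash instead of fixed keywords; the remap/ignore calls translate it for the mistral-ai client. A sketch (the api_key: constructor keyword is assumed unchanged from 0.12.0, since this hunk does not show it):

llm = Langchain::LLM::MistralAI.new(api_key: ENV["MISTRAL_AI_API_KEY"])

response = llm.chat(
  messages: [{role: "user", content: "Hello"}],
  seed: 42, # remapped to Mistral's :random_seed
  top_k: 5  # silently dropped by chat_parameters.ignore(:n, :top_k)
)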
data/lib/langchain/llm/ollama.rb

@@ -7,22 +7,24 @@ module Langchain::LLM
   # Available models: https://ollama.ai/library
   #
   # Usage:
-  #     ollama = Langchain::LLM::Ollama.new(url: ENV["OLLAMA_URL"], default_options: {})
+  #     llm = Langchain::LLM::Ollama.new
+  #     llm = Langchain::LLM::Ollama.new(url: ENV["OLLAMA_URL"], default_options: {})
   #
   class Ollama < Base
     attr_reader :url, :defaults
 
     DEFAULTS = {
       temperature: 0.8,
-      completion_model_name: "llama2",
-      embeddings_model_name: "llama2",
-      chat_completion_model_name: "llama2"
+      completion_model_name: "llama3",
+      embeddings_model_name: "llama3",
+      chat_completion_model_name: "llama3"
     }.freeze
 
     EMBEDDING_SIZES = {
       codellama: 4_096,
       "dolphin-mixtral": 4_096,
       llama2: 4_096,
+      llama3: 4_096,
       llava: 4_096,
       mistral: 4_096,
       "mistral-openorca": 4_096,
@@ -33,10 +35,17 @@ module Langchain::LLM
     # @param url [String] The URL of the Ollama instance
     # @param default_options [Hash] The default options to use
     #
-    def initialize(url:, default_options: {})
+    def initialize(url: "http://localhost:11434", default_options: {})
       depends_on "faraday"
       @url = url
       @defaults = DEFAULTS.deep_merge(default_options)
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        temperature: {default: @defaults[:temperature]},
+        template: {},
+        stream: {default: false}
+      )
+      chat_parameters.remap(response_format: :format)
     end
 
     # Returns the # of vector dimensions for the embeddings
@@ -150,33 +159,20 @@ module Langchain::LLM
 
     # Generate a chat completion
     #
-    # @param model [String] Model name
-    # @param messages [Array<Hash>] Array of messages
-    # @param format [String] Format to return a response in. Currently the only accepted value is `json`
-    # @param temperature [Float] The temperature to use
-    # @param template [String] The prompt template to use (overrides what is defined in the `Modelfile`)
-    # @param stream [Boolean] Streaming the response. If false the response will be returned as a single response object, rather than a stream of objects
+    # @param [Hash] params unified chat parameters from [Langchain::LLM::Parameters::Chat::SCHEMA]
+    # @option params [String] :model Model name
+    # @option params [Array<Hash>] :messages Array of messages
+    # @option params [String] :format Format to return a response in. Currently the only accepted value is `json`
+    # @option params [Float] :temperature The temperature to use
+    # @option params [String] :template The prompt template to use (overrides what is defined in the `Modelfile`)
+    # @option params [Boolean] :stream Streaming the response. If false the response will be returned as a single response object, rather than a stream of objects
     #
     # The message object has the following fields:
     #   role: the role of the message, either system, user or assistant
     #   content: the content of the message
     #   images (optional): a list of images to include in the message (for multimodal models such as llava)
-    def chat(
-      model: defaults[:chat_completion_model_name],
-      messages: [],
-      format: nil,
-      temperature: defaults[:temperature],
-      template: nil,
-      stream: false # TODO: Fix streaming.
-    )
-      parameters = {
-        model: model,
-        messages: messages,
-        format: format,
-        temperature: temperature,
-        template: template,
-        stream: stream
-      }.compact
+    def chat(params = {})
+      parameters = chat_parameters.to_params(params)
 
       response = client.post("api/chat") do |req|
         req.body = parameters
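
Note: because the URL now defaults to the local daemon and chat flows through the unified schema, the minimal call shrinks accordingly (a sketch; assumes an Ollama server on localhost with llama3 pulled):

llm = Langchain::LLM::Ollama.new # url defaults to http://localhost:11434

response = llm.chat(
  messages: [{role: "user", content: "Why is the sky blue?"}],
  response_format: "json" # remapped to Ollama's :format parameter
)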
data/lib/langchain/llm/openai.rb

@@ -40,6 +40,15 @@ module Langchain::LLM
       @client = ::OpenAI::Client.new(access_token: api_key, **llm_options)
 
       @defaults = DEFAULTS.merge(default_options)
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        logprobs: {},
+        top_logprobs: {},
+        n: {default: @defaults[:n]},
+        temperature: {default: @defaults[:temperature]},
+        user: {}
+      )
+      chat_parameters.ignore(:top_k)
     end
 
     # Generate an embedding for a given text
@@ -102,54 +111,17 @@ module Langchain::LLM
 
     # Generate a chat completion for given messages.
     #
-    # @param messages [Array<Hash>] List of messages comprising the conversation so far
-    # @param model [String] ID of the model to use
-    def chat(
-      messages: [],
-      model: defaults[:chat_completion_model_name],
-      frequency_penalty: nil,
-      logit_bias: nil,
-      logprobs: nil,
-      top_logprobs: nil,
-      max_tokens: nil,
-      n: defaults[:n],
-      presence_penalty: nil,
-      response_format: nil,
-      seed: nil,
-      stop: nil,
-      stream: nil,
-      temperature: defaults[:temperature],
-      top_p: nil,
-      tools: [],
-      tool_choice: nil,
-      user: nil,
-      &block
-    )
-      raise ArgumentError.new("messages argument is required") if messages.empty?
-      raise ArgumentError.new("model argument is required") if model.empty?
-      raise ArgumentError.new("'tool_choice' is only allowed when 'tools' are specified.") if tool_choice && tools.empty?
-
-      parameters = {
-        messages: messages,
-        model: model
-      }
-      parameters[:frequency_penalty] = frequency_penalty if frequency_penalty
-      parameters[:logit_bias] = logit_bias if logit_bias
-      parameters[:logprobs] = logprobs if logprobs
-      parameters[:top_logprobs] = top_logprobs if top_logprobs
-      # TODO: Fix max_tokens validation to account for tools/functions
-      parameters[:max_tokens] = max_tokens if max_tokens # || validate_max_tokens(parameters[:messages], parameters[:model])
-      parameters[:n] = n if n
-      parameters[:presence_penalty] = presence_penalty if presence_penalty
-      parameters[:response_format] = response_format if response_format
-      parameters[:seed] = seed if seed
-      parameters[:stop] = stop if stop
-      parameters[:stream] = stream if stream
-      parameters[:temperature] = temperature if temperature
-      parameters[:top_p] = top_p if top_p
-      parameters[:tools] = tools if tools.any?
-      parameters[:tool_choice] = tool_choice if tool_choice
-      parameters[:user] = user if user
+    # @param [Hash] params unified chat parameters from [Langchain::LLM::Parameters::Chat::SCHEMA]
+    # @option params [Array<Hash>] :messages List of messages comprising the conversation so far
+    # @option params [String] :model ID of the model to use
+    def chat(params = {}, &block)
+      parameters = chat_parameters.to_params(params)
+
+      raise ArgumentError.new("messages argument is required") if Array(parameters[:messages]).empty?
+      raise ArgumentError.new("model argument is required") if parameters[:model].to_s.empty?
+      if parameters[:tool_choice] && Array(parameters[:tools]).empty?
+        raise ArgumentError.new("'tool_choice' is only allowed when 'tools' are specified.")
+      end
 
       # TODO: Clean this part up
       if block
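
Note: the long keyword list collapses into the shared schema, and the argument checks now run against the normalized hash, so model and temperature defaults are applied before validation (a sketch; the api_key: keyword matches the existing constructor):

llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

# model and temperature are filled in from DEFAULTS via chat_parameters
llm.chat(messages: [{role: "user", content: "Hello"}])

# Both still raise ArgumentError, now based on the normalized parameters:
llm.chat(messages: [])                  # "messages argument is required"
llm.chat(messages: [{role: "user", content: "Hi"}],
         tool_choice: "auto")           # "'tool_choice' is only allowed when 'tools' are specified."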
data/lib/langchain/llm/parameters/chat.rb (new file)

@@ -0,0 +1,51 @@
+# frozen_string_literal: true
+
+require "delegate"
+
+module Langchain::LLM::Parameters
+  class Chat < SimpleDelegator
+    # TODO: At the moment, the UnifiedParameters only considers keys. In the
+    # future, we may consider ActiveModel-style validations and further typed
+    # options here.
+    SCHEMA = {
+      # Either "messages" or "prompt" is required
+      messages: {},
+      model: {},
+      prompt: {},
+
+      # System instructions. Used by Cohere, Anthropic and Google Gemini.
+      system: {},
+
+      # Allows forcing the model to produce a specific output format.
+      response_format: {},
+
+      stop: {}, # multiple types (e.g. OpenAI also allows Array, null)
+      stream: {}, # Enable streaming
+
+      max_tokens: {}, # Range: [1, context_length)
+      temperature: {}, # Range: [0, 2]
+      top_p: {}, # Range: (0, 1]
+      top_k: {}, # Range: [1, Infinity) Not available for OpenAI models
+      frequency_penalty: {}, # Range: [-2, 2]
+      presence_penalty: {}, # Range: [-2, 2]
+      repetition_penalty: {}, # Range: (0, 2]
+      seed: {}, # OpenAI only
+
+      # Function-calling
+      tools: {default: []},
+      tool_choice: {},
+
+      # Additional optional parameters
+      logit_bias: {}
+    }
+
+    def initialize(parameters: {})
+      super(
+        ::Langchain::LLM::UnifiedParameters.new(
+          schema: SCHEMA.dup,
+          parameters: parameters
+        )
+      )
+    end
+  end
+end
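
Note: Chat delegates everything to UnifiedParameters (also added in this release as data/lib/langchain/llm/unified_parameters.rb, +98 lines, not reproduced in this excerpt). Inferred from the call sites above, the lifecycle is roughly the following sketch; the exact to_params output is an assumption since the delegate's source is not shown here:

params = Langchain::LLM::Parameters::Chat.new
params.update(model: {default: "command-r-plus"}) # add/override schema entries
params.remap(system: :preamble)                   # rename keys for a provider API
params.ignore(:top_k)                             # drop keys a provider rejects

params.to_params(system: "Be brief.", top_k: 3)
# => roughly {model: "command-r-plus", preamble: "Be brief.", tools: []}
#    (top_k ignored; schema defaults such as tools: [] filled in)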