langchainrb 0.12.0 → 0.13.0

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (42)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -0
  3. data/README.md +3 -2
  4. data/lib/langchain/assistants/assistant.rb +75 -20
  5. data/lib/langchain/assistants/messages/base.rb +16 -0
  6. data/lib/langchain/assistants/messages/google_gemini_message.rb +90 -0
  7. data/lib/langchain/assistants/messages/openai_message.rb +74 -0
  8. data/lib/langchain/assistants/thread.rb +5 -5
  9. data/lib/langchain/llm/anthropic.rb +27 -49
  10. data/lib/langchain/llm/aws_bedrock.rb +30 -34
  11. data/lib/langchain/llm/azure.rb +6 -0
  12. data/lib/langchain/llm/base.rb +20 -1
  13. data/lib/langchain/llm/cohere.rb +38 -6
  14. data/lib/langchain/llm/google_gemini.rb +67 -0
  15. data/lib/langchain/llm/google_vertex_ai.rb +68 -112
  16. data/lib/langchain/llm/mistral_ai.rb +10 -19
  17. data/lib/langchain/llm/ollama.rb +23 -27
  18. data/lib/langchain/llm/openai.rb +20 -48
  19. data/lib/langchain/llm/parameters/chat.rb +51 -0
  20. data/lib/langchain/llm/response/base_response.rb +2 -2
  21. data/lib/langchain/llm/response/cohere_response.rb +16 -0
  22. data/lib/langchain/llm/response/google_gemini_response.rb +45 -0
  23. data/lib/langchain/llm/response/openai_response.rb +5 -1
  24. data/lib/langchain/llm/unified_parameters.rb +98 -0
  25. data/lib/langchain/loader.rb +6 -0
  26. data/lib/langchain/tool/base.rb +16 -6
  27. data/lib/langchain/tool/calculator/calculator.json +1 -1
  28. data/lib/langchain/tool/database/database.json +3 -3
  29. data/lib/langchain/tool/file_system/file_system.json +3 -3
  30. data/lib/langchain/tool/news_retriever/news_retriever.json +121 -0
  31. data/lib/langchain/tool/news_retriever/news_retriever.rb +132 -0
  32. data/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json +1 -1
  33. data/lib/langchain/tool/vectorsearch/vectorsearch.json +1 -1
  34. data/lib/langchain/tool/weather/weather.json +1 -1
  35. data/lib/langchain/tool/wikipedia/wikipedia.json +1 -1
  36. data/lib/langchain/tool/wikipedia/wikipedia.rb +2 -2
  37. data/lib/langchain/utils/token_length/openai_validator.rb +6 -1
  38. data/lib/langchain/version.rb +1 -1
  39. data/lib/langchain.rb +3 -0
  40. metadata +22 -15
  41. data/lib/langchain/assistants/message.rb +0 -58
  42. data/lib/langchain/llm/response/google_vertex_ai_response.rb +0 -33
data/lib/langchain/llm/base.rb
@@ -11,7 +11,8 @@ module Langchain::LLM
   # - {Langchain::LLM::Azure}
   # - {Langchain::LLM::Cohere}
   # - {Langchain::LLM::GooglePalm}
-  # - {Langchain::LLM::GoogleVertexAi}
+  # - {Langchain::LLM::GoogleVertexAI}
+  # - {Langchain::LLM::GoogleGemini}
   # - {Langchain::LLM::HuggingFace}
   # - {Langchain::LLM::LlamaCpp}
   # - {Langchain::LLM::OpenAI}
@@ -24,6 +25,15 @@ module Langchain::LLM
     # A client for communicating with the LLM
     attr_reader :client
 
+    # Ensuring backward compatibility after https://github.com/patterns-ai-core/langchainrb/pull/586
+    # TODO: Delete this method later
+    def default_dimension
+      default_dimensions
+    end
+
+    # Returns the number of vector dimensions used by DEFAULTS[:chat_completion_model_name]
+    #
+    # @return [Integer] Vector dimensions
     def default_dimensions
       self.class.const_get(:DEFAULTS).dig(:dimensions)
     end
@@ -61,5 +71,14 @@ module Langchain::LLM
     def summarize(...)
       raise NotImplementedError, "#{self.class.name} does not support summarization"
     end
+
+    #
+    # Returns an instance of Langchain::LLM::Parameters::Chat
+    #
+    def chat_parameters(params = {})
+      @chat_parameters ||= Langchain::LLM::Parameters::Chat.new(
+        parameters: params
+      )
+    end
   end
 end
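The memoized `chat_parameters` added above is the hook the provider hunks below build on: each subclass declares its defaults, key remappings, and ignored keys once in its initializer, and its `#chat` method reduces to a single `to_params` call. A minimal sketch of that pattern, using a hypothetical provider (the class and the return-value comment are illustrative, not part of the gem):

    require "langchain"

    class MyProvider < Langchain::LLM::Base
      DEFAULTS = {chat_completion_model_name: "my-model"}.freeze

      def initialize(default_options: {})
        @defaults = DEFAULTS.merge(default_options)
        # Declare unified-parameter defaults once...
        chat_parameters.update(model: {default: @defaults[:chat_completion_model_name]})
        # ...and rename keys the provider's API spells differently
        chat_parameters.remap(stop: :stop_sequences)
      end

      def chat(params = {})
        # Filters to the unified schema, applies defaults, and remaps keys
        chat_parameters.to_params(params)
      end
    end

    MyProvider.new.chat(stop: ["\n"])
    # => {model: "my-model", stop_sequences: ["\n"], ...} (assumed output shape)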
data/lib/langchain/llm/cohere.rb
@@ -8,22 +8,34 @@ module Langchain::LLM
   #     gem "cohere-ruby", "~> 0.9.6"
   #
   # Usage:
-  #     cohere = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])
+  #     llm = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])
   #
   class Cohere < Base
     DEFAULTS = {
       temperature: 0.0,
       completion_model_name: "command",
+      chat_completion_model_name: "command-r-plus",
       embeddings_model_name: "small",
       dimensions: 1024,
       truncate: "START"
     }.freeze
 
-    def initialize(api_key, default_options = {})
+    def initialize(api_key:, default_options: {})
       depends_on "cohere-ruby", req: "cohere"
 
-      @client = ::Cohere::Client.new(api_key)
+      @client = ::Cohere::Client.new(api_key: api_key)
       @defaults = DEFAULTS.merge(default_options)
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        temperature: {default: @defaults[:temperature]}
+      )
+      chat_parameters.remap(
+        system: :preamble,
+        messages: :chat_history,
+        stop: :stop_sequences,
+        top_k: :k,
+        top_p: :p
+      )
     end
 
     #
@@ -68,9 +80,29 @@ module Langchain::LLM
       Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
     end
 
-    # TODO: Implement chat method: https://github.com/andreibondarev/cohere-ruby/issues/11
-    # def chat
-    # end
+    # Generate a chat completion for given messages
+    #
+    # @param [Hash] params unified chat parameters from [Langchain::LLM::Parameters::Chat::SCHEMA]
+    # @option params [Array<String>] :messages Input messages
+    # @option params [String] :model The model that will complete your prompt
+    # @option params [Integer] :max_tokens Maximum number of tokens to generate before stopping
+    # @option params [Array<String>] :stop Custom text sequences that will cause the model to stop generating
+    # @option params [Boolean] :stream Whether to incrementally stream the response using server-sent events
+    # @option params [String] :system System prompt
+    # @option params [Float] :temperature Amount of randomness injected into the response
+    # @option params [Array<String>] :tools Definitions of tools that the model may use
+    # @option params [Integer] :top_k Only sample from the top K options for each subsequent token
+    # @option params [Float] :top_p Use nucleus sampling.
+    # @return [Langchain::LLM::CohereResponse] The chat completion
+    def chat(params = {})
+      raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?
+
+      parameters = chat_parameters.to_params(params)
+
+      response = client.chat(**parameters)
+
+      Langchain::LLM::CohereResponse.new(response)
+    end
 
     # Generate a summary in English for a given text
     #
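For context, a call to the new Cohere#chat under the unified interface might look like the sketch below. The message-hash shape is an assumption based on Cohere's chat API, and `chat_completion` is presumed to be among the readers added to cohere_response.rb in this release (+16 lines in the file list):

    llm = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])
    response = llm.chat(
      messages: [{role: "USER", message: "What is Ruby?"}], # sent as :chat_history
      system: "Answer in one sentence.",                    # sent as :preamble
      temperature: 0.3
    )
    response.chat_completion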
data/lib/langchain/llm/google_gemini.rb
@@ -0,0 +1,67 @@
+# frozen_string_literal: true
+
+module Langchain::LLM
+  # Usage:
+  #     llm = Langchain::LLM::GoogleGemini.new(api_key: ENV['GOOGLE_GEMINI_API_KEY'])
+  class GoogleGemini < Base
+    DEFAULTS = {
+      chat_completion_model_name: "gemini-1.5-pro-latest",
+      temperature: 0.0
+    }
+
+    attr_reader :defaults, :api_key
+
+    def initialize(api_key:, default_options: {})
+      @api_key = api_key
+      @defaults = DEFAULTS.merge(default_options)
+
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        temperature: {default: @defaults[:temperature]}
+      )
+      chat_parameters.remap(
+        messages: :contents,
+        system: :system_instruction,
+        tool_choice: :tool_config
+      )
+    end
+
+    # Generate a chat completion for a given prompt
+    #
+    # @param messages [Array<Hash>] List of messages comprising the conversation so far
+    # @param model [String] The model to use
+    # @param tools [Array<Hash>] A list of Tools the model may use to generate the next response
+    # @param tool_choice [String] Specifies the mode in which function calling should execute. If unspecified, the default value will be set to AUTO. Possible values: AUTO, ANY, NONE
+    # @param system [String] Developer set system instruction
+    def chat(params = {})
+      params[:system] = {parts: [{text: params[:system]}]} if params[:system]
+      params[:tools] = {function_declarations: params[:tools]} if params[:tools]
+      params[:tool_choice] = {function_calling_config: {mode: params[:tool_choice].upcase}} if params[:tool_choice]
+
+      raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?
+
+      parameters = chat_parameters.to_params(params)
+      parameters[:generation_config] = {temperature: parameters.delete(:temperature)} if parameters[:temperature]
+
+      uri = URI("https://generativelanguage.googleapis.com/v1beta/models/#{parameters[:model]}:generateContent?key=#{api_key}")
+
+      request = Net::HTTP::Post.new(uri)
+      request.content_type = "application/json"
+      request.body = parameters.to_json
+
+      response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
+        http.request(request)
+      end
+
+      parsed_response = JSON.parse(response.body)
+
+      wrapped_response = Langchain::LLM::GoogleGeminiResponse.new(parsed_response, model: parameters[:model])
+
+      if wrapped_response.chat_completion || Array(wrapped_response.tool_calls).any?
+        wrapped_response
+      else
+        raise StandardError.new(response)
+      end
+    end
+  end
+end
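A usage sketch for the new class; the content shape follows the public Gemini generateContent API (a role plus parts array, which the remap above sends as :contents), and the env var name is illustrative:

    llm = Langchain::LLM::GoogleGemini.new(api_key: ENV["GOOGLE_GEMINI_API_KEY"])
    response = llm.chat(
      messages: [{role: "user", parts: [{text: "Hello!"}]}]
    )
    response.chat_completion

Note that #chat raises a bare StandardError wrapping the raw Net::HTTP response whenever the parsed body yields neither a completion nor tool calls, so API-level failures (auth, quota) surface as that exception.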
data/lib/langchain/llm/google_vertex_ai.rb
@@ -2,150 +2,106 @@
 
 module Langchain::LLM
   #
-  # Wrapper around the Google Vertex AI APIs: https://cloud.google.com/vertex-ai?hl=en
+  # Wrapper around the Google Vertex AI APIs: https://cloud.google.com/vertex-ai
   #
   # Gem requirements:
-  #     gem "google-apis-aiplatform_v1", "~> 0.7"
+  #     gem "googleauth"
   #
   # Usage:
-  #     google_palm = Langchain::LLM::GoogleVertexAi.new(project_id: ENV["GOOGLE_VERTEX_AI_PROJECT_ID"])
+  #     llm = Langchain::LLM::GoogleVertexAI.new(project_id: ENV["GOOGLE_VERTEX_AI_PROJECT_ID"], region: "us-central1")
   #
-  class GoogleVertexAi < Base
+  class GoogleVertexAI < Base
     DEFAULTS = {
-      temperature: 0.1, # 0.1 is the default in the API, quite low ("grounded")
+      temperature: 0.1,
       max_output_tokens: 1000,
       top_p: 0.8,
       top_k: 40,
       dimensions: 768,
-      completion_model_name: "text-bison", # Optional: tect-bison@001
-      embeddings_model_name: "textembedding-gecko"
+      embeddings_model_name: "textembedding-gecko",
+      chat_completion_model_name: "gemini-1.0-pro"
     }.freeze
 
-    # TODO: Implement token length validation
-    # LENGTH_VALIDATOR = Langchain::Utils::TokenLength::...
-
     # Google Cloud has a project id and a specific region of deployment.
     # For GenAI-related things, a safe choice is us-central1.
-    attr_reader :project_id, :client, :region
-
-    def initialize(project_id:, default_options: {})
-      depends_on "google-apis-aiplatform_v1"
+    attr_reader :defaults, :url, :authorizer
 
-      @project_id = project_id
-      @region = default_options.fetch :region, "us-central1"
+    def initialize(project_id:, region:, default_options: {})
+      depends_on "googleauth"
 
-      @client = Google::Apis::AiplatformV1::AiplatformService.new
-
-      # TODO: Adapt for other regions; Pass it in via the constructor
-      # For the moment only us-central1 available so no big deal.
-      @client.root_url = "https://#{@region}-aiplatform.googleapis.com/"
-      @client.authorization = Google::Auth.get_application_default
+      @authorizer = ::Google::Auth.get_application_default
+      proj_id = project_id || @authorizer.project_id || @authorizer.quota_project_id
+      @url = "https://#{region}-aiplatform.googleapis.com/v1/projects/#{proj_id}/locations/#{region}/publishers/google/models/"
 
       @defaults = DEFAULTS.merge(default_options)
+
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        temperature: {default: @defaults[:temperature]}
+      )
+      chat_parameters.remap(
+        messages: :contents,
+        system: :system_instruction,
+        tool_choice: :tool_config
+      )
     end
 
     #
     # Generate an embedding for a given text
     #
     # @param text [String] The text to generate an embedding for
-    # @return [Langchain::LLM::GoogleVertexAiResponse] Response object
+    # @param model [String] ID of the model to use
+    # @return [Langchain::LLM::GoogleGeminiResponse] Response object
     #
-    def embed(text:)
-      content = [{content: text}]
-      request = Google::Apis::AiplatformV1::GoogleCloudAiplatformV1PredictRequest.new(instances: content)
-
-      api_path = "projects/#{@project_id}/locations/us-central1/publishers/google/models/#{@defaults[:embeddings_model_name]}"
-
-      # puts("api_path: #{api_path}")
-
-      response = client.predict_project_location_publisher_model(api_path, request)
+    def embed(
+      text:,
+      model: @defaults[:embeddings_model_name]
+    )
+      params = {instances: [{content: text}]}
+
+      response = HTTParty.post(
+        "#{url}#{model}:predict",
+        body: params.to_json,
+        headers: {
+          "Content-Type" => "application/json",
+          "Authorization" => "Bearer #{@authorizer.fetch_access_token!["access_token"]}"
+        }
+      )
 
-      Langchain::LLM::GoogleVertexAiResponse.new(response.to_h, model: @defaults[:embeddings_model_name])
+      Langchain::LLM::GoogleGeminiResponse.new(response, model: model)
     end
 
+    # Generate a chat completion for given messages
     #
-    # Generate a completion for a given prompt
-    #
-    # @param prompt [String] The prompt to generate a completion for
-    # @param params extra parameters passed to GooglePalmAPI::Client#generate_text
-    # @return [Langchain::LLM::GooglePalmResponse] Response object
-    #
-    def complete(prompt:, **params)
-      default_params = {
-        prompt: prompt,
-        temperature: @defaults[:temperature],
-        top_k: @defaults[:top_k],
-        top_p: @defaults[:top_p],
-        max_output_tokens: @defaults[:max_output_tokens],
-        model: @defaults[:completion_model_name]
-      }
-
-      if params[:stop_sequences]
-        default_params[:stop_sequences] = params.delete(:stop_sequences)
+    # @param messages [Array<Hash>] Input messages
+    # @param model [String] The model that will complete your prompt
+    # @param tools [Array<Hash>] The tools to use
+    # @param tool_choice [String] The tool choice to use
+    # @param system [String] The system instruction to use
+    # @return [Langchain::LLM::GoogleGeminiResponse] Response object
+    def chat(params = {})
+      params[:system] = {parts: [{text: params[:system]}]} if params[:system]
+      params[:tools] = {function_declarations: params[:tools]} if params[:tools]
+      params[:tool_choice] = {function_calling_config: {mode: params[:tool_choice].upcase}} if params[:tool_choice]
+
+      raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?
+
+      parameters = chat_parameters.to_params(params)
+      parameters[:generation_config] = {temperature: parameters.delete(:temperature)} if parameters[:temperature]
+
+      uri = URI("#{url}#{parameters[:model]}:generateContent")
+
+      request = Net::HTTP::Post.new(uri)
+      request.content_type = "application/json"
+      request["Authorization"] = "Bearer #{@authorizer.fetch_access_token!["access_token"]}"
+      request.body = parameters.to_json
+
+      response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
+        http.request(request)
       end
 
-      if params[:max_output_tokens]
-        default_params[:max_output_tokens] = params.delete(:max_output_tokens)
-      end
-
-      # to be tested
-      temperature = params.delete(:temperature) || @defaults[:temperature]
-      max_output_tokens = default_params.fetch(:max_output_tokens, @defaults[:max_output_tokens])
-
-      default_params.merge!(params)
-
-      # response = client.generate_text(**default_params)
-      request = Google::Apis::AiplatformV1::GoogleCloudAiplatformV1PredictRequest.new \
-        instances: [{
-          prompt: prompt # key used to be :content, changed to :prompt
-        }],
-        parameters: {
-          temperature: temperature,
-          maxOutputTokens: max_output_tokens,
-          topP: 0.8,
-          topK: 40
-        }
-
-      response = client.predict_project_location_publisher_model \
-        "projects/#{project_id}/locations/us-central1/publishers/google/models/#{@defaults[:completion_model_name]}",
-        request
-
-      Langchain::LLM::GoogleVertexAiResponse.new(response, model: default_params[:model])
-    end
+      parsed_response = JSON.parse(response.body)
 
-    #
-    # Generate a summarization for a given text
-    #
-    # @param text [String] The text to generate a summarization for
-    # @return [String] The summarization
-    #
-    # TODO(ricc): add params for Temp, topP, topK, MaxTokens and have it default to these 4 values.
-    def summarize(text:)
-      prompt_template = Langchain::Prompt.load_from_path(
-        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
-      )
-      prompt = prompt_template.format(text: text)
-
-      complete(
-        prompt: prompt,
-        # For best temperature, topP, topK, MaxTokens for summarization: see
-        # https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-summarization
-        temperature: 0.2,
-        top_p: 0.95,
-        top_k: 40,
-        # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
-        max_output_tokens: 256
-      )
+      Langchain::LLM::GoogleGeminiResponse.new(parsed_response, model: parameters[:model])
     end
-
-    # def chat(...)
-    #   https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chathat
-    #   Chat params: https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chat
-    #   \"temperature\": 0.3,\n"
-    #   + " \"maxDecodeSteps\": 200,\n"
-    #   + " \"topP\": 0.8,\n"
-    #   + " \"topK\": 40\n"
-    #   + "}";
-    # end
   end
 end
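This rewrite drops the google-apis-aiplatform_v1 client in favor of plain HTTParty/Net::HTTP calls authorized through googleauth's application default credentials, so the constructor now requires an explicit region and working ADC. A sketch of the new entry points, with credential setup assumed:

    # Assumes `gcloud auth application-default login` or GOOGLE_APPLICATION_CREDENTIALS
    llm = Langchain::LLM::GoogleVertexAI.new(
      project_id: ENV["GOOGLE_VERTEX_AI_PROJECT_ID"],
      region: "us-central1"
    )
    llm.embed(text: "Hello world")                               # textembedding-gecko
    llm.chat(messages: [{role: "user", parts: [{text: "Hi"}]}])  # gemini-1.0-pro

Note the renamed class: code referring to Langchain::LLM::GoogleVertexAi (lowercase "i") must be updated, and #complete/#summarize are removed in favor of the unified #chat.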
data/lib/langchain/llm/mistral_ai.rb
@@ -23,28 +23,19 @@ module Langchain::LLM
       )
 
       @defaults = DEFAULTS.merge(default_options)
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        n: {default: @defaults[:n]},
+        safe_prompt: {}
+      )
+      chat_parameters.remap(seed: :random_seed)
+      chat_parameters.ignore(:n, :top_k)
     end
 
-    def chat(
-      messages:,
-      model: defaults[:chat_completion_model_name],
-      temperature: nil,
-      top_p: nil,
-      max_tokens: nil,
-      safe_prompt: nil,
-      random_seed: nil
-    )
-      params = {
-        messages: messages,
-        model: model
-      }
-      params[:temperature] = temperature if temperature
-      params[:top_p] = top_p if top_p
-      params[:max_tokens] = max_tokens if max_tokens
-      params[:safe_prompt] = safe_prompt if safe_prompt
-      params[:random_seed] = random_seed if random_seed
+    def chat(params = {})
+      parameters = chat_parameters.to_params(params)
 
-      response = client.chat_completions(params)
+      response = client.chat_completions(parameters)
 
       Langchain::LLM::MistralAIResponse.new(response.to_h)
     end
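The net effect for callers, sketched below: the unified :seed key travels to Mistral as :random_seed, while unsupported keys like :top_k are presumably dropped by the ignore list rather than raising the ArgumentError an unknown keyword would have under the old signature:

    llm.chat(
      messages: [{role: "user", content: "Bonjour!"}],
      seed: 42,  # sent as :random_seed
      top_k: 5   # ignored: Mistral's API has no such parameter
    )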
data/lib/langchain/llm/ollama.rb
@@ -7,22 +7,24 @@ module Langchain::LLM
   # Available models: https://ollama.ai/library
   #
   # Usage:
-  #     ollama = Langchain::LLM::Ollama.new(url: ENV["OLLAMA_URL"], default_options: {})
+  #     llm = Langchain::LLM::Ollama.new
+  #     llm = Langchain::LLM::Ollama.new(url: ENV["OLLAMA_URL"], default_options: {})
   #
   class Ollama < Base
     attr_reader :url, :defaults
 
     DEFAULTS = {
       temperature: 0.8,
-      completion_model_name: "llama2",
-      embeddings_model_name: "llama2",
-      chat_completion_model_name: "llama2"
+      completion_model_name: "llama3",
+      embeddings_model_name: "llama3",
+      chat_completion_model_name: "llama3"
     }.freeze
 
     EMBEDDING_SIZES = {
       codellama: 4_096,
       "dolphin-mixtral": 4_096,
       llama2: 4_096,
+      llama3: 4_096,
       llava: 4_096,
       mistral: 4_096,
       "mistral-openorca": 4_096,
@@ -33,10 +35,17 @@ module Langchain::LLM
     # @param url [String] The URL of the Ollama instance
     # @param default_options [Hash] The default options to use
     #
-    def initialize(url:, default_options: {})
+    def initialize(url: "http://localhost:11434", default_options: {})
      depends_on "faraday"
      @url = url
      @defaults = DEFAULTS.deep_merge(default_options)
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        temperature: {default: @defaults[:temperature]},
+        template: {},
+        stream: {default: false}
+      )
+      chat_parameters.remap(response_format: :format)
    end
 
    # Returns the # of vector dimensions for the embeddings
@@ -150,33 +159,20 @@ module Langchain::LLM
 
     # Generate a chat completion
     #
-    # @param model [String] Model name
-    # @param messages [Array<Hash>] Array of messages
-    # @param format [String] Format to return a response in. Currently the only accepted value is `json`
-    # @param temperature [Float] The temperature to use
-    # @param template [String] The prompt template to use (overrides what is defined in the `Modelfile`)
-    # @param stream [Boolean] Streaming the response. If false the response will be returned as a single response object, rather than a stream of objects
+    # @param [Hash] params unified chat parameters from [Langchain::LLM::Parameters::Chat::SCHEMA]
+    # @option params [String] :model Model name
+    # @option params [Array<Hash>] :messages Array of messages
+    # @option params [String] :format Format to return a response in. Currently the only accepted value is `json`
+    # @option params [Float] :temperature The temperature to use
+    # @option params [String] :template The prompt template to use (overrides what is defined in the `Modelfile`)
+    # @option params [Boolean] :stream Streaming the response. If false the response will be returned as a single response object, rather than a stream of objects
     #
     # The message object has the following fields:
     #   role: the role of the message, either system, user or assistant
     #   content: the content of the message
     #   images (optional): a list of images to include in the message (for multimodal models such as llava)
-    def chat(
-      model: defaults[:chat_completion_model_name],
-      messages: [],
-      format: nil,
-      temperature: defaults[:temperature],
-      template: nil,
-      stream: false # TODO: Fix streaming.
-    )
-      parameters = {
-        model: model,
-        messages: messages,
-        format: format,
-        temperature: temperature,
-        template: template,
-        stream: stream
-      }.compact
+    def chat(params = {})
+      parameters = chat_parameters.to_params(params)
 
       response = client.post("api/chat") do |req|
         req.body = parameters
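With the new URL default, a zero-configuration local setup becomes possible. A sketch assuming a local Ollama daemon with the llama3 model already pulled:

    llm = Langchain::LLM::Ollama.new # defaults to http://localhost:11434
    response = llm.chat(
      messages: [{role: "user", content: "Hey!"}],
      response_format: "json" # remapped to Ollama's :format key
    )
    response.chat_completion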
data/lib/langchain/llm/openai.rb
@@ -40,6 +40,15 @@ module Langchain::LLM
       @client = ::OpenAI::Client.new(access_token: api_key, **llm_options)
 
       @defaults = DEFAULTS.merge(default_options)
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        logprobs: {},
+        top_logprobs: {},
+        n: {default: @defaults[:n]},
+        temperature: {default: @defaults[:temperature]},
+        user: {}
+      )
+      chat_parameters.ignore(:top_k)
     end
 
     # Generate an embedding for a given text
@@ -102,54 +111,17 @@ module Langchain::LLM
 
     # Generate a chat completion for given messages.
     #
-    # @param messages [Array<Hash>] List of messages comprising the conversation so far
-    # @param model [String] ID of the model to use
-    def chat(
-      messages: [],
-      model: defaults[:chat_completion_model_name],
-      frequency_penalty: nil,
-      logit_bias: nil,
-      logprobs: nil,
-      top_logprobs: nil,
-      max_tokens: nil,
-      n: defaults[:n],
-      presence_penalty: nil,
-      response_format: nil,
-      seed: nil,
-      stop: nil,
-      stream: nil,
-      temperature: defaults[:temperature],
-      top_p: nil,
-      tools: [],
-      tool_choice: nil,
-      user: nil,
-      &block
-    )
-      raise ArgumentError.new("messages argument is required") if messages.empty?
-      raise ArgumentError.new("model argument is required") if model.empty?
-      raise ArgumentError.new("'tool_choice' is only allowed when 'tools' are specified.") if tool_choice && tools.empty?
-
-      parameters = {
-        messages: messages,
-        model: model
-      }
-      parameters[:frequency_penalty] = frequency_penalty if frequency_penalty
-      parameters[:logit_bias] = logit_bias if logit_bias
-      parameters[:logprobs] = logprobs if logprobs
-      parameters[:top_logprobs] = top_logprobs if top_logprobs
-      # TODO: Fix max_tokens validation to account for tools/functions
-      parameters[:max_tokens] = max_tokens if max_tokens # || validate_max_tokens(parameters[:messages], parameters[:model])
-      parameters[:n] = n if n
-      parameters[:presence_penalty] = presence_penalty if presence_penalty
-      parameters[:response_format] = response_format if response_format
-      parameters[:seed] = seed if seed
-      parameters[:stop] = stop if stop
-      parameters[:stream] = stream if stream
-      parameters[:temperature] = temperature if temperature
-      parameters[:top_p] = top_p if top_p
-      parameters[:tools] = tools if tools.any?
-      parameters[:tool_choice] = tool_choice if tool_choice
-      parameters[:user] = user if user
+    # @param [Hash] params unified chat parameters from [Langchain::LLM::Parameters::Chat::SCHEMA]
+    # @option params [Array<Hash>] :messages List of messages comprising the conversation so far
+    # @option params [String] :model ID of the model to use
+    def chat(params = {}, &block)
+      parameters = chat_parameters.to_params(params)
+
+      raise ArgumentError.new("messages argument is required") if Array(parameters[:messages]).empty?
+      raise ArgumentError.new("model argument is required") if parameters[:model].to_s.empty?
+      if parameters[:tool_choice] && Array(parameters[:tools]).empty?
+        raise ArgumentError.new("'tool_choice' is only allowed when 'tools' are specified.")
+      end
 
       # TODO: Clean this part up
       if block
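Because Ruby parses trailing keyword-style arguments as a Hash literal, most existing call sites keep working unchanged under the new single-Hash signature. What changes is that unrecognized options are presumably filtered out by the unified schema instead of raising ArgumentError as unknown keywords did. A sketch:

    llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
    llm.chat(messages: [{role: "user", content: "Hi"}], temperature: 0.7) # works as before
    llm.chat(messages: [{role: "user", content: "Hi"}], bogus_option: 1)
    # previously: ArgumentError (unknown keyword); now silently dropped (assumed)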
data/lib/langchain/llm/parameters/chat.rb
@@ -0,0 +1,51 @@
+# frozen_string_literal: true
+
+require "delegate"
+
+module Langchain::LLM::Parameters
+  class Chat < SimpleDelegator
+    # TODO: At the moment, the UnifiedParameters only considers keys. In the
+    # future, we may consider ActiveModel-style validations and further typed
+    # options here.
+    SCHEMA = {
+      # Either "messages" or "prompt" is required
+      messages: {},
+      model: {},
+      prompt: {},
+
+      # System instructions. Used by Cohere, Anthropic and Google Gemini.
+      system: {},
+
+      # Allows to force the model to produce specific output format.
+      response_format: {},
+
+      stop: {}, # multiple types (e.g. OpenAI also allows Array, null)
+      stream: {}, # Enable streaming
+
+      max_tokens: {}, # Range: [1, context_length)
+      temperature: {}, # Range: [0, 2]
+      top_p: {}, # Range: (0, 1]
+      top_k: {}, # Range: [1, Infinity) Not available for OpenAI models
+      frequency_penalty: {}, # Range: [-2, 2]
+      presence_penalty: {}, # Range: [-2, 2]
+      repetition_penalty: {}, # Range: (0, 2]
+      seed: {}, # OpenAI only
+
+      # Function-calling
+      tools: {default: []},
+      tool_choice: {},
+
+      # Additional optional parameters
+      logit_bias: {}
+    }
+
+    def initialize(parameters: {})
+      super(
+        ::Langchain::LLM::UnifiedParameters.new(
+          schema: SCHEMA.dup,
+          parameters: parameters
+        )
+      )
+    end
+  end
+end
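Putting it together, the delegated UnifiedParameters object gives every LLM subclass the same small surface: update to declare defaults, remap to rename keys, ignore to drop unsupported ones, and to_params to produce the provider payload. An illustrative session; the exact to_params output shape is an assumption based on how the adapters above consume it:

    chat = Langchain::LLM::Parameters::Chat.new
    chat.update(model: {default: "command-r-plus"})
    chat.remap(stop: :stop_sequences)
    chat.ignore(:top_k)
    chat.to_params(stop: ["\n"], top_k: 5, verbosity: :high)
    # Keys outside SCHEMA (:verbosity) and ignored keys (:top_k) are dropped,
    # defaults are applied, and :stop arrives renamed (assumed):
    # => {model: "command-r-plus", stop_sequences: ["\n"], tools: []}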