langchainrb 0.8.2 → 0.9.0

@@ -11,6 +11,7 @@ module Langchain::LLM
   # - {Langchain::LLM::Azure}
   # - {Langchain::LLM::Cohere}
   # - {Langchain::LLM::GooglePalm}
+  # - {Langchain::LLM::GoogleVertexAi}
   # - {Langchain::LLM::HuggingFace}
   # - {Langchain::LLM::LlamaCpp}
   # - {Langchain::LLM::OpenAI}
@@ -62,17 +62,15 @@ module Langchain::LLM
 
       default_params.merge!(params)
 
-      default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], client)
+      default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], llm: client)
 
       response = client.generate(**default_params)
       Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
     end
 
-    # Cohere does not have a dedicated chat endpoint, so instead we call `complete()`
-    def chat(...)
-      response_text = complete(...)
-      ::Langchain::Conversation::Response.new(response_text)
-    end
+    # TODO: Implement chat method: https://github.com/andreibondarev/cohere-ruby/issues/11
+    # def chat
+    # end
 
     # Generate a summary in English for a given text
     #
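The Cohere completion path now hands the client to the token-length validator as an `llm:` keyword instead of a positional argument. A minimal sketch of the new calling convention; the API key, prompt, and model name are illustrative, and the `client` reader on the LLM instance is an assumption:

    cohere = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])

    # 0.8.x passed the raw client positionally; 0.9.0 expects it as `llm:`.
    Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(
      "Summarize the history of Ruby",  # prompt text (illustrative)
      "command",                        # Cohere model name (illustrative)
      llm: cohere.client
    )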
@@ -23,6 +23,8 @@ module Langchain::LLM
       "assistant" => "ai"
     }
 
+    attr_reader :defaults
+
     def initialize(api_key:, default_options: {})
       depends_on "google_palm_api"
 
@@ -21,6 +21,9 @@ module Langchain::LLM
       embeddings_model_name: "textembedding-gecko"
     }.freeze
 
+    # TODO: Implement token length validation
+    # LENGTH_VALIDATOR = Langchain::Utils::TokenLength::...
+
     # Google Cloud has a project id and a specific region of deployment.
     # For GenAI-related things, a safe choice is us-central1.
     attr_reader :project_id, :client, :region
@@ -135,15 +138,14 @@ module Langchain::LLM
       )
     end
 
-    def chat(...)
-      # https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chathat
-      # Chat params: https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chat
-      # \"temperature\": 0.3,\n"
-      # + " \"maxDecodeSteps\": 200,\n"
-      # + " \"topP\": 0.8,\n"
-      # + " \"topK\": 40\n"
-      # + "}";
-      raise NotImplementedError, "coming soon for Vertex AI.."
-    end
+    # def chat(...)
+    #   https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chathat
+    #   Chat params: https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chat
+    #   \"temperature\": 0.3,\n"
+    #   + " \"maxDecodeSteps\": 200,\n"
+    #   + " \"topP\": 0.8,\n"
+    #   + " \"topK\": 40\n"
+    #   + "}";
+    # end
   end
 end
@@ -4,156 +4,170 @@ module Langchain::LLM
   # LLM interface for OpenAI APIs: https://platform.openai.com/overview
   #
   # Gem requirements:
-  #    gem "ruby-openai", "~> 6.1.0"
+  #    gem "ruby-openai", "~> 6.3.0"
   #
   # Usage:
-  #    openai = Langchain::LLM::OpenAI.new(api_key:, llm_options: {})
-  #
+  #    openai = Langchain::LLM::OpenAI.new(
+  #      api_key: ENV["OPENAI_API_KEY"],
+  #      llm_options: {},
+  #      default_options: {}
+  #    )
   class OpenAI < Base
     DEFAULTS = {
       n: 1,
       temperature: 0.0,
-      completion_model_name: "gpt-3.5-turbo",
       chat_completion_model_name: "gpt-3.5-turbo",
       embeddings_model_name: "text-embedding-ada-002",
       dimension: 1536
     }.freeze
 
-    LEGACY_COMPLETION_MODELS = %w[
-      ada
-      babbage
-      curie
-      davinci
-    ].freeze
-
     LENGTH_VALIDATOR = Langchain::Utils::TokenLength::OpenAIValidator
 
-    attr_accessor :functions
+    attr_reader :defaults
 
+    # Initialize an OpenAI LLM instance
+    #
+    # @param api_key [String] The API key to use
+    # @param client_options [Hash] Options to pass to the OpenAI::Client constructor
     def initialize(api_key:, llm_options: {}, default_options: {})
       depends_on "ruby-openai", req: "openai"
 
       @client = ::OpenAI::Client.new(access_token: api_key, **llm_options)
+
       @defaults = DEFAULTS.merge(default_options)
     end
 
-    #
     # Generate an embedding for a given text
     #
     # @param text [String] The text to generate an embedding for
-    # @param params extra parameters passed to OpenAI::Client#embeddings
+    # @param model [String] ID of the model to use
+    # @param encoding_format [String] The format to return the embeddings in. Can be either float or base64.
+    # @param user [String] A unique identifier representing your end-user
     # @return [Langchain::LLM::OpenAIResponse] Response object
-    #
-    def embed(text:, **params)
-      parameters = {model: @defaults[:embeddings_model_name], input: text}
+    def embed(
+      text:,
+      model: defaults[:embeddings_model_name],
+      encoding_format: nil,
+      user: nil
+    )
+      raise ArgumentError.new("text argument is required") if text.empty?
+      raise ArgumentError.new("model argument is required") if model.empty?
+      raise ArgumentError.new("encoding_format must be either float or base64") if encoding_format && %w[float base64].include?(encoding_format)
+
+      parameters = {
+        input: text,
+        model: model
+      }
+      parameters[:encoding_format] = encoding_format if encoding_format
+      parameters[:user] = user if user
 
       validate_max_tokens(text, parameters[:model])
 
       response = with_api_error_handling do
-        client.embeddings(parameters: parameters.merge(params))
+        client.embeddings(parameters: parameters)
       end
 
       Langchain::LLM::OpenAIResponse.new(response)
     end
 
-    #
+    # rubocop:disable Style/ArgumentsForwarding
     # Generate a completion for a given prompt
     #
     # @param prompt [String] The prompt to generate a completion for
-    # @param params extra parameters passed to OpenAI::Client#complete
-    # @return [Langchain::LLM::Response::OpenaAI] Response object
-    #
+    # @param params [Hash] The parameters to pass to the `chat()` method
+    # @return [Langchain::LLM::OpenAIResponse] Response object
     def complete(prompt:, **params)
-      parameters = compose_parameters @defaults[:completion_model_name], params
-
-      return legacy_complete(prompt, parameters) if is_legacy_model?(parameters[:model])
-
-      parameters[:messages] = compose_chat_messages(prompt: prompt)
-      parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model], parameters[:max_tokens])
-
-      response = with_api_error_handling do
-        client.chat(parameters: parameters)
+      if params[:stop_sequences]
+        params[:stop] = params.delete(:stop_sequences)
       end
-
-      Langchain::LLM::OpenAIResponse.new(response)
+      # Should we still accept the `messages: []` parameter here?
+      messages = [{role: "user", content: prompt}]
+      chat(messages: messages, **params)
     end
+    # rubocop:enable Style/ArgumentsForwarding
 
-    #
     # Generate a chat completion for a given prompt or messages.
     #
-    # == Examples
-    #
-    #     # simplest case, just give a prompt
-    #     openai.chat prompt: "When was Ruby first released?"
-    #
-    #     # prompt plus some context about how to respond
-    #     openai.chat context: "You are RubyGPT, a helpful chat bot for helping people learn Ruby", prompt: "Does Ruby have a REPL like IPython?"
-    #
-    #     # full control over messages that get sent, equivilent to the above
-    #     openai.chat messages: [
-    #       {
-    #         role: "system",
-    #         content: "You are RubyGPT, a helpful chat bot for helping people learn Ruby", prompt: "Does Ruby have a REPL like IPython?"
-    #       },
-    #       {
-    #         role: "user",
-    #         content: "When was Ruby first released?"
-    #       }
-    #     ]
-    #
-    #     # few-short prompting with examples
-    #     openai.chat prompt: "When was factory_bot released?",
-    #       examples: [
-    #         {
-    #           role: "user",
-    #           content: "When was Ruby on Rails released?"
-    #         }
-    #         {
-    #           role: "assistant",
-    #           content: "2004"
-    #         },
-    #       ]
-    #
-    # @param prompt [String] The prompt to generate a chat completion for
-    # @param messages [Array<Hash>] The messages that have been sent in the conversation
-    # @param context [String] An initial context to provide as a system message, ie "You are RubyGPT, a helpful chat bot for helping people learn Ruby"
-    # @param examples [Array<Hash>] Examples of messages to provide to the model. Useful for Few-Shot Prompting
-    # @param options [Hash] extra parameters passed to OpenAI::Client#chat
-    # @yield [Hash] Stream responses back one token at a time
-    # @return [Langchain::LLM::OpenAIResponse] Response object
-    #
-    def chat(prompt: "", messages: [], context: "", examples: [], **options, &block)
-      raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
-
-      parameters = compose_parameters @defaults[:chat_completion_model_name], options, &block
-      parameters[:messages] = compose_chat_messages(prompt: prompt, messages: messages, context: context, examples: examples)
+    # @param messages [Array<Hash>] List of messages comprising the conversation so far
+    # @param model [String] ID of the model to use
+    def chat(
+      messages: [],
+      model: defaults[:chat_completion_model_name],
+      frequency_penalty: nil,
+      logit_bias: nil,
+      logprobs: nil,
+      top_logprobs: nil,
+      max_tokens: nil,
+      n: defaults[:n],
+      presence_penalty: nil,
+      response_format: nil,
+      seed: nil,
+      stop: nil,
+      stream: nil,
+      temperature: defaults[:temperature],
+      top_p: nil,
+      tools: [],
+      tool_choice: nil,
+      user: nil,
+      &block
+    )
+      raise ArgumentError.new("messages argument is required") if messages.empty?
+      raise ArgumentError.new("model argument is required") if model.empty?
+      raise ArgumentError.new("'tool_choice' is only allowed when 'tools' are specified.") if tool_choice && tools.empty?
+
+      parameters = {
+        messages: messages,
+        model: model
+      }
+      parameters[:frequency_penalty] = frequency_penalty if frequency_penalty
+      parameters[:logit_bias] = logit_bias if logit_bias
+      parameters[:logprobs] = logprobs if logprobs
+      parameters[:top_logprobs] = top_logprobs if top_logprobs
+      # TODO: Fix max_tokens validation to account for tools/functions
+      parameters[:max_tokens] = max_tokens if max_tokens # || validate_max_tokens(parameters[:messages], parameters[:model])
+      parameters[:n] = n if n
+      parameters[:presence_penalty] = presence_penalty if presence_penalty
+      parameters[:response_format] = response_format if response_format
+      parameters[:seed] = seed if seed
+      parameters[:stop] = stop if stop
+      parameters[:stream] = stream if stream
+      parameters[:temperature] = temperature if temperature
+      parameters[:top_p] = top_p if top_p
+      parameters[:tools] = tools if tools.any?
+      parameters[:tool_choice] = tool_choice if tool_choice
+      parameters[:user] = user if user
+
+      # TODO: Clean this part up
+      if block
+        @response_chunks = []
+        parameters[:stream] = proc do |chunk, _bytesize|
+          chunk_content = chunk.dig("choices", 0)
+          @response_chunks << chunk
+          yield chunk_content
+        end
+      end
 
-      if functions
-        parameters[:functions] = functions
-      else
-        parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model], parameters[:max_tokens])
+      response = with_api_error_handling do
+        client.chat(parameters: parameters)
       end
 
-      response = with_api_error_handling { client.chat(parameters: parameters) }
       response = response_from_chunks if block
       reset_response_chunks
+
       Langchain::LLM::OpenAIResponse.new(response)
     end
 
-    #
     # Generate a summary for a given text
     #
     # @param text [String] The text to generate a summary for
     # @return [String] The summary
-    #
     def summarize(text:)
       prompt_template = Langchain::Prompt.load_from_path(
         file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
       )
       prompt = prompt_template.format(text: text)
 
-      complete(prompt: prompt, temperature: @defaults[:temperature])
-      # Should this return a Langchain::LLM::OpenAIResponse as well?
+      complete(prompt: prompt)
     end
 
     private
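`complete` is now a thin wrapper that builds a single user message and delegates to `chat`, and `chat` takes explicit keyword arguments mirroring OpenAI's Chat Completions parameters instead of the old `prompt:`/`context:`/`examples:` interface. A rough usage sketch against the new signatures; prompts, model output, and the streamed chunk shape are illustrative, and a valid OPENAI_API_KEY is assumed:

    openai = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

    # Embeddings: model, encoding_format and user are explicit keywords now.
    embedding = openai.embed(text: "Ruby is a programmer's best friend")
    embedding.embedding # => [0.0023, -0.0091, ...]

    # Chat: the caller assembles the full messages array; there is no prompt:/context: shorthand anymore.
    response = openai.chat(
      messages: [
        {role: "system", content: "You are a helpful assistant"},
        {role: "user", content: "When was Ruby first released?"}
      ],
      temperature: 0.2
    )
    response.chat_completion # => "Ruby was first released in 1995..." (illustrative)

    # Streaming: pass a block; each yielded chunk is the first element of "choices".
    openai.chat(messages: [{role: "user", content: "Write a haiku about Ruby"}]) do |chunk|
      print chunk.dig("delta", "content")
    end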
@@ -164,71 +178,6 @@ module Langchain::LLM
       @response_chunks = []
     end
 
-    def is_legacy_model?(model)
-      LEGACY_COMPLETION_MODELS.any? { |legacy_model| model.include?(legacy_model) }
-    end
-
-    def legacy_complete(prompt, parameters)
-      Langchain.logger.warn "DEPRECATION WARNING: The model #{parameters[:model]} is deprecated. Please use gpt-3.5-turbo instead. Details: https://platform.openai.com/docs/deprecations/2023-07-06-gpt-and-embeddings"
-
-      parameters[:prompt] = prompt
-      parameters[:max_tokens] = validate_max_tokens(prompt, parameters[:model])
-
-      response = with_api_error_handling do
-        client.completions(parameters: parameters)
-      end
-      response.dig("choices", 0, "text")
-    end
-
-    def compose_parameters(model, params, &block)
-      default_params = {model: model, temperature: @defaults[:temperature], n: @defaults[:n]}
-      default_params[:stop] = params.delete(:stop_sequences) if params[:stop_sequences]
-      parameters = default_params.merge(params)
-
-      if block
-        @response_chunks = []
-        parameters[:stream] = proc do |chunk, _bytesize|
-          chunk_content = chunk.dig("choices", 0)
-          @response_chunks << chunk
-          yield chunk_content
-        end
-      end
-
-      parameters
-    end
-
-    def compose_chat_messages(prompt:, messages: [], context: "", examples: [])
-      history = []
-
-      history.concat transform_messages(examples) unless examples.empty?
-
-      history.concat transform_messages(messages) unless messages.empty?
-
-      unless context.nil? || context.empty?
-        history.reject! { |message| message[:role] == "system" }
-        history.prepend({role: "system", content: context})
-      end
-
-      unless prompt.empty?
-        if history.last && history.last[:role] == "user"
-          history.last[:content] += "\n#{prompt}"
-        else
-          history.append({role: "user", content: prompt})
-        end
-      end
-
-      history
-    end
-
-    def transform_messages(messages)
-      messages.map do |message|
-        {
-          role: message[:role],
-          content: message[:content]
-        }
-      end
-    end
-
     def with_api_error_handling
       response = yield
       return if response.empty?
@@ -239,12 +188,7 @@ module Langchain::LLM
     end
 
     def validate_max_tokens(messages, model, max_tokens = nil)
-      LENGTH_VALIDATOR.validate_max_tokens!(messages, model, max_tokens: max_tokens)
-    end
-
-    def extract_response(response)
-      results = response.dig("choices").map { |choice| choice.dig("message", "content") }
-      (results.size == 1) ? results.first : results
+      LENGTH_VALIDATOR.validate_max_tokens!(messages, model, max_tokens: max_tokens, llm: self)
     end
 
     def response_from_chunks
@@ -77,12 +77,6 @@ module Langchain::LLM
       Langchain::LLM::ReplicateResponse.new(response, model: @defaults[:completion_model_name])
     end
 
-    # Cohere does not have a dedicated chat endpoint, so instead we call `complete()`
-    def chat(...)
-      response_text = complete(...)
-      ::Langchain::Conversation::Response.new(response_text)
-    end
-
     #
     # Generate a summary for a given text
     #
@@ -25,5 +25,9 @@ module Langchain::LLM
     def log_id
       raw_response.dig("log_id")
     end
+
+    def role
+      "assistant"
+    end
   end
 end
@@ -32,5 +32,9 @@ module Langchain::LLM
     def embeddings
       [raw_response.dig("embedding", "value")]
     end
+
+    def role
+      "assistant"
+    end
   end
 end
@@ -22,5 +22,9 @@ module Langchain::LLM
     def embeddings
       [raw_response&.dig("embedding")]
     end
+
+    def role
+      "assistant"
+    end
   end
 end
@@ -16,10 +16,18 @@ module Langchain::LLM
       completions&.dig(0, "message", "content")
     end
 
+    def role
+      completions&.dig(0, "message", "role")
+    end
+
     def chat_completion
       completion
     end
 
+    def tool_calls
+      chat_completions&.dig(0, "message", "tool_calls")
+    end
+
     def embedding
       embeddings&.first
     end
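The response wrappers now expose the message `role`, and the OpenAI wrapper also surfaces `tool_calls` from the first choice. A small sketch of the readers; the raw payload below is abbreviated and illustrative:

    raw = {
      "choices" => [
        {"message" => {
          "role" => "assistant",
          "content" => nil,
          "tool_calls" => [
            {"id" => "call_1", "type" => "function",
             "function" => {"name" => "calculator", "arguments" => "{\"input\":\"2+2\"}"}}
          ]
        }}
      ]
    }

    response = Langchain::LLM::OpenAIResponse.new(raw)
    response.role       # => "assistant"
    response.tool_calls # => the array of tool call hashes above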
@@ -91,6 +91,30 @@ module Langchain::Tool
       new.execute(input: input)
     end
 
+    # Returns the tool as an OpenAI tool
+    #
+    # @return [Hash] tool as an OpenAI tool
+    def to_openai_tool
+      # TODO: This is hardcoded to def execute(input:) found in each tool, needs to be dynamic.
+      {
+        type: "function",
+        function: {
+          name: name,
+          description: description,
+          parameters: {
+            type: "object",
+            properties: {
+              input: {
+                type: "string",
+                description: "Input to the tool"
+              }
+            },
+            required: ["input"]
+          }
+        }
+      }
+    end
+
     #
     # Executes the tool and returns the answer
     #
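`to_openai_tool` emits the function-calling schema that the new `chat(tools:)` keyword expects, with a single string `input` parameter matching each tool's `execute(input:)`. A rough sketch of wiring a tool into a chat call, reusing the `openai` instance from the sketch above; the calculator tool, prompt, and `tool_choice` value are illustrative:

    calculator = Langchain::Tool::Calculator.new

    response = openai.chat(
      messages: [{role: "user", content: "What is 2 ** 10?"}],
      tools: [calculator.to_openai_tool],
      tool_choice: "auto"
    )

    if (call = response.tool_calls&.first)
      args = JSON.parse(call.dig("function", "arguments"))
      calculator.execute(input: args["input"])
    end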
@@ -17,10 +17,7 @@ module Langchain::Tool
     description <<~DESC
       A wrapper around SerpApi's Google Search API.
 
-      Useful for when you need to answer questions about current events.
-      Always one of the first options when you need to find information on internet.
-
-      Input should be a search query.
+      Useful for when you need to answer questions about current events. Always one of the first options when you need to find information on internet. Input should be a search query.
     DESC
 
     attr_reader :api_key
@@ -22,8 +22,8 @@ module Langchain
       # @param model_name [String] The model name to validate against
       # @return [Integer] The token length of the text
       #
-      def self.token_length(text, model_name, client)
-        res = client.tokenize(text)
+      def self.token_length(text, model_name, options = {})
+        res = options[:llm].tokenize(text)
         res.dig(:tokens).length
       end
 
@@ -30,8 +30,8 @@ module Langchain
       # @param model_name [String] The model name to validate against
       # @return [Integer] The token length of the text
       #
-      def self.token_length(text, model_name, client)
-        res = client.tokenize(text: text)
+      def self.token_length(text, model_name, options = {})
+        res = options[:llm].tokenize(text: text)
         res["tokens"].length
       end
 
@@ -35,7 +35,7 @@ module Langchain
       # @option options [Langchain::LLM:GooglePalm] :llm The Langchain::LLM:GooglePalm instance
       # @return [Integer] The token length of the text
       #
-      def self.token_length(text, model_name = "chat-bison-001", options)
+      def self.token_length(text, model_name = "chat-bison-001", options = {})
         response = options[:llm].client.count_message_tokens(model: model_name, prompt: text)
 
         raise Langchain::LLM::ApiError.new(response["error"]["message"]) unless response["error"].nil?
@@ -43,7 +43,7 @@ module Langchain
         response.dig("tokenCount")
       end
 
-      def self.token_length_from_messages(messages, model_name, options)
+      def self.token_length_from_messages(messages, model_name, options = {})
         messages.sum { |message| token_length(message.to_json, model_name, options) }
       end
 
@@ -93,10 +93,10 @@ module Langchain
           tokens_per_message = 4 # every message follows {role/name}\n{content}\n
           tokens_per_name = -1 # if there's a name, the role is omitted
         elsif model_name.include?("gpt-3.5-turbo")
-          puts "Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613."
+          # puts "Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613."
           return token_length_from_messages(messages, "gpt-3.5-turbo-0613", options)
         elsif model_name.include?("gpt-4")
-          puts "Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
+          # puts "Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
          return token_length_from_messages(messages, "gpt-4-0613", options)
         else
           raise NotImplementedError.new(
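Across the validators, the `(text, model_name, client)` signature is replaced by `(text, model_name, options = {})`, with the LLM read from `options[:llm]`; this is what lets `validate_max_tokens!` forward the `llm:` option shown earlier. A minimal sketch of the new convention; the Cohere instance, text, and model name are illustrative, and the `client` reader on the LLM instance is assumed:

    cohere = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])

    # 0.8.x: token_length(text, model_name, client)
    # 0.9.0: the client travels in the trailing options hash.
    Langchain::Utils::TokenLength::CohereValidator.token_length(
      "How many tokens is this?",
      "command",
      llm: cohere.client
    )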
@@ -1,5 +1,5 @@
 # frozen_string_literal: true
 
 module Langchain
-  VERSION = "0.8.2"
+  VERSION = "0.9.0"
 end
data/lib/langchain.rb CHANGED
@@ -24,6 +24,7 @@ loader.inflector.inflect(
   "sql_query_agent" => "SQLQueryAgent"
 )
 loader.collapse("#{__dir__}/langchain/llm/response")
+loader.collapse("#{__dir__}/langchain/assistants")
 loader.setup
 
 # Langchain.rb a is library for building LLM-backed Ruby applications. It is an abstraction layer that sits on top of the emerging AI-related tools that makes it easy for developers to consume and string those services together.
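The added `loader.collapse` tells Zeitwerk that files under `lib/langchain/assistants` define constants directly under `Langchain`, with no `Assistants` namespace segment. A hedged illustration of the effect; the `Assistant` file and constant name are inferred from the directory name:

    # Without the collapse, Zeitwerk would expect:
    #   lib/langchain/assistants/assistant.rb => Langchain::Assistants::Assistant
    # With loader.collapse("#{__dir__}/langchain/assistants"), the same file maps to:
    #   lib/langchain/assistants/assistant.rb => Langchain::Assistant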
@@ -82,7 +83,7 @@ module Langchain
     attr_reader :root
   end
 
-  self.logger ||= ::Logger.new($stdout, level: :warn)
+  self.logger ||= ::Logger.new($stdout, level: :debug)
 
   @root = Pathname.new(__dir__)
 
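The default logger level moves from `:warn` to `:debug`, so 0.9.0 logs much more out of the box. A small sketch of dialing it back in an application, assuming the module-level `logger` accessor:

    require "langchain"

    # The library only assigns its logger with ||=, so an app can override it
    # or simply lower the level after requiring the gem.
    Langchain.logger.level = Logger::WARN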