langchainrb 0.8.2 → 0.9.1

Files changed (33)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/README.md +57 -27
  4. data/lib/langchain/assistants/assistant.rb +199 -0
  5. data/lib/langchain/assistants/message.rb +58 -0
  6. data/lib/langchain/assistants/thread.rb +34 -0
  7. data/lib/langchain/conversation/memory.rb +1 -6
  8. data/lib/langchain/conversation.rb +7 -18
  9. data/lib/langchain/llm/ai21.rb +1 -1
  10. data/lib/langchain/llm/azure.rb +10 -97
  11. data/lib/langchain/llm/base.rb +1 -0
  12. data/lib/langchain/llm/cohere.rb +4 -6
  13. data/lib/langchain/llm/google_palm.rb +2 -0
  14. data/lib/langchain/llm/google_vertex_ai.rb +12 -10
  15. data/lib/langchain/llm/ollama.rb +167 -27
  16. data/lib/langchain/llm/openai.rb +104 -160
  17. data/lib/langchain/llm/replicate.rb +0 -6
  18. data/lib/langchain/llm/response/anthropic_response.rb +4 -0
  19. data/lib/langchain/llm/response/base_response.rb +7 -0
  20. data/lib/langchain/llm/response/google_palm_response.rb +4 -0
  21. data/lib/langchain/llm/response/ollama_response.rb +22 -0
  22. data/lib/langchain/llm/response/openai_response.rb +8 -0
  23. data/lib/langchain/tool/base.rb +24 -0
  24. data/lib/langchain/tool/google_search.rb +1 -4
  25. data/lib/langchain/utils/token_length/ai21_validator.rb +2 -2
  26. data/lib/langchain/utils/token_length/cohere_validator.rb +2 -2
  27. data/lib/langchain/utils/token_length/google_palm_validator.rb +2 -2
  28. data/lib/langchain/utils/token_length/openai_validator.rb +13 -2
  29. data/lib/langchain/utils/token_length/token_limit_exceeded.rb +1 -1
  30. data/lib/langchain/vectorsearch/pinecone.rb +2 -1
  31. data/lib/langchain/version.rb +1 -1
  32. data/lib/langchain.rb +2 -1
  33. metadata +24 -7
data/lib/langchain/llm/azure.rb
@@ -4,7 +4,7 @@ module Langchain::LLM
   # LLM interface for Azure OpenAI Service APIs: https://learn.microsoft.com/en-us/azure/ai-services/openai/
   #
   # Gem requirements:
-  #     gem "ruby-openai", "~> 6.1.0"
+  #     gem "ruby-openai", "~> 6.3.0"
   #
   # Usage:
   #     openai = Langchain::LLM::Azure.new(api_key:, llm_options: {}, embedding_deployment_url: chat_deployment_url:)
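The only change in this hunk is the gem pin. If you follow it, the matching Gemfile entry (a sketch, assuming Bundler) would be:

    # Gemfile: the Azure wrapper now expects the newer ruby-openai release
    gem "ruby-openai", "~> 6.3.0"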
@@ -34,106 +34,19 @@ module Langchain::LLM
       @defaults = DEFAULTS.merge(default_options)
     end
 
-    #
-    # Generate an embedding for a given text
-    #
-    # @param text [String] The text to generate an embedding for
-    # @param params extra parameters passed to OpenAI::Client#embeddings
-    # @return [Langchain::LLM::OpenAIResponse] Response object
-    #
-    def embed(text:, **params)
-      parameters = {model: @defaults[:embeddings_model_name], input: text}
-
-      validate_max_tokens(text, parameters[:model])
-
-      response = with_api_error_handling do
-        embed_client.embeddings(parameters: parameters.merge(params))
-      end
-
-      Langchain::LLM::OpenAIResponse.new(response)
+    def embed(...)
+      @client = @embed_client
+      super(...)
     end
 
-    #
-    # Generate a completion for a given prompt
-    #
-    # @param prompt [String] The prompt to generate a completion for
-    # @param params extra parameters passed to OpenAI::Client#complete
-    # @return [Langchain::LLM::Response::OpenaAI] Response object
-    #
-    def complete(prompt:, **params)
-      parameters = compose_parameters @defaults[:completion_model_name], params
-
-      parameters[:messages] = compose_chat_messages(prompt: prompt)
-      parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
-
-      response = with_api_error_handling do
-        chat_client.chat(parameters: parameters)
-      end
-
-      Langchain::LLM::OpenAIResponse.new(response)
+    def complete(...)
+      @client = @chat_client
+      super(...)
     end
 
-    #
-    # Generate a chat completion for a given prompt or messages.
-    #
-    # == Examples
-    #
-    #     # simplest case, just give a prompt
-    #     openai.chat prompt: "When was Ruby first released?"
-    #
-    #     # prompt plus some context about how to respond
-    #     openai.chat context: "You are RubyGPT, a helpful chat bot for helping people learn Ruby", prompt: "Does Ruby have a REPL like IPython?"
-    #
-    #     # full control over messages that get sent, equivilent to the above
-    #     openai.chat messages: [
-    #       {
-    #         role: "system",
-    #         content: "You are RubyGPT, a helpful chat bot for helping people learn Ruby", prompt: "Does Ruby have a REPL like IPython?"
-    #       },
-    #       {
-    #         role: "user",
-    #         content: "When was Ruby first released?"
-    #       }
-    #     ]
-    #
-    #     # few-short prompting with examples
-    #     openai.chat prompt: "When was factory_bot released?",
-    #       examples: [
-    #         {
-    #           role: "user",
-    #           content: "When was Ruby on Rails released?"
-    #         }
-    #         {
-    #           role: "assistant",
-    #           content: "2004"
-    #         },
-    #       ]
-    #
-    # @param prompt [String] The prompt to generate a chat completion for
-    # @param messages [Array<Hash>] The messages that have been sent in the conversation
-    # @param context [String] An initial context to provide as a system message, ie "You are RubyGPT, a helpful chat bot for helping people learn Ruby"
-    # @param examples [Array<Hash>] Examples of messages to provide to the model. Useful for Few-Shot Prompting
-    # @param options [Hash] extra parameters passed to OpenAI::Client#chat
-    # @yield [Hash] Stream responses back one token at a time
-    # @return [Langchain::LLM::OpenAIResponse] Response object
-    #
-    def chat(prompt: "", messages: [], context: "", examples: [], **options, &block)
-      raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
-
-      parameters = compose_parameters @defaults[:chat_completion_model_name], options, &block
-      parameters[:messages] = compose_chat_messages(prompt: prompt, messages: messages, context: context, examples: examples)
-
-      if functions
-        parameters[:functions] = functions
-      else
-        parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
-      end
-
-      response = with_api_error_handling { chat_client.chat(parameters: parameters) }
-
-      return if block
-
-      Langchain::LLM::OpenAIResponse.new(response)
+    def chat(...)
+      @client = @chat_client
+      super(...)
     end
   end
 end
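With this change Azure inherits the OpenAI implementations and only swaps the underlying client before calling super. A minimal usage sketch, based on the constructor comment above; the environment variable name and deployment URLs are placeholders, and the chat call assumes the messages: signature of the OpenAI superclass:

    azure = Langchain::LLM::Azure.new(
      api_key: ENV["AZURE_OPENAI_API_KEY"],                # assumed env var name
      embedding_deployment_url: "https://example.openai.azure.com/...",  # your embeddings deployment
      chat_deployment_url: "https://example.openai.azure.com/..."        # your chat deployment
    )

    azure.embed(text: "Hello world")                                     # routed through @embed_client
    azure.chat(messages: [{role: "user", content: "Hello!"}])            # routed through @chat_client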

data/lib/langchain/llm/base.rb
@@ -11,6 +11,7 @@ module Langchain::LLM
   # - {Langchain::LLM::Azure}
   # - {Langchain::LLM::Cohere}
   # - {Langchain::LLM::GooglePalm}
+  # - {Langchain::LLM::GoogleVertexAi}
   # - {Langchain::LLM::HuggingFace}
   # - {Langchain::LLM::LlamaCpp}
   # - {Langchain::LLM::OpenAI}

data/lib/langchain/llm/cohere.rb
@@ -62,17 +62,15 @@ module Langchain::LLM
 
       default_params.merge!(params)
 
-      default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], client)
+      default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], llm: client)
 
       response = client.generate(**default_params)
       Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
     end
 
-    # Cohere does not have a dedicated chat endpoint, so instead we call `complete()`
-    def chat(...)
-      response_text = complete(...)
-      ::Langchain::Conversation::Response.new(response_text)
-    end
+    # TODO: Implement chat method: https://github.com/andreibondarev/cohere-ruby/issues/11
+    # def chat
+    # end
 
     # Generate a summary in English for a given text
     #
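Two behavioral notes follow from this hunk: the validator now takes the client as an explicit llm: keyword, and Cohere no longer responds to chat until the linked TODO lands, so callers should use complete. A hedged sketch of the new validator call (the model name is only an example, not taken from the diff):

    Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(
      prompt,
      "command",       # example Cohere model name
      llm: client      # the Cohere client, previously passed positionally
    )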

data/lib/langchain/llm/google_palm.rb
@@ -23,6 +23,8 @@ module Langchain::LLM
       "assistant" => "ai"
     }
 
+    attr_reader :defaults
+
     def initialize(api_key:, default_options: {})
       depends_on "google_palm_api"
 
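With the new reader, the merged defaults become inspectable from outside the class. A small sketch; the env var name is a placeholder, and the merge with DEFAULTS is assumed to happen in the initializer as in the other LLM classes in this diff:

    palm = Langchain::LLM::GooglePalm.new(api_key: ENV["GOOGLE_PALM_API_KEY"], default_options: {temperature: 0.5})
    palm.defaults  # => DEFAULTS merged with {temperature: 0.5}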

data/lib/langchain/llm/google_vertex_ai.rb
@@ -21,6 +21,9 @@ module Langchain::LLM
       embeddings_model_name: "textembedding-gecko"
     }.freeze
 
+    # TODO: Implement token length validation
+    # LENGTH_VALIDATOR = Langchain::Utils::TokenLength::...
+
     # Google Cloud has a project id and a specific region of deployment.
     # For GenAI-related things, a safe choice is us-central1.
     attr_reader :project_id, :client, :region
@@ -135,15 +138,14 @@ module Langchain::LLM
       )
     end
 
-    def chat(...)
-      # https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chathat
-      # Chat params: https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chat
-      # \"temperature\": 0.3,\n"
-      # + " \"maxDecodeSteps\": 200,\n"
-      # + " \"topP\": 0.8,\n"
-      # + " \"topK\": 40\n"
-      # + "}";
-      raise NotImplementedError, "coming soon for Vertex AI.."
-    end
+    # def chat(...)
+    #   https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chathat
+    #   Chat params: https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chat
+    #   \"temperature\": 0.3,\n"
+    #   + " \"maxDecodeSteps\": 200,\n"
+    #   + " \"topP\": 0.8,\n"
+    #   + " \"topK\": 40\n"
+    #   + "}";
+    # end
   end
 end

data/lib/langchain/llm/ollama.rb
@@ -5,21 +5,26 @@ module Langchain::LLM
   # Available models: https://ollama.ai/library
   #
   # Usage:
-  #     ollama = Langchain::LLM::Ollama.new(url: ENV["OLLAMA_URL"])
+  #     ollama = Langchain::LLM::Ollama.new(url: ENV["OLLAMA_URL"], default_options: {})
   #
   class Ollama < Base
-    attr_reader :url
+    attr_reader :url, :defaults
 
     DEFAULTS = {
-      temperature: 0.0,
+      temperature: 0.8,
       completion_model_name: "llama2",
-      embeddings_model_name: "llama2"
+      embeddings_model_name: "llama2",
+      chat_completion_model_name: "llama2"
     }.freeze
 
     # Initialize the Ollama client
     # @param url [String] The URL of the Ollama instance
-    def initialize(url:)
+    # @param default_options [Hash] The default options to use
+    #
+    def initialize(url:, default_options: {})
+      depends_on "faraday"
       @url = url
+      @defaults = DEFAULTS.merge(default_options)
     end
 
     #
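The constructor now declares its faraday dependency and accepts per-instance defaults. A usage sketch matching the updated Usage comment above ("mistral" is only an example model name):

    ollama = Langchain::LLM::Ollama.new(
      url: ENV["OLLAMA_URL"],
      default_options: {temperature: 0.5, completion_model_name: "mistral"}
    )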
@@ -27,32 +32,128 @@ module Langchain::LLM
     #
     # @param prompt [String] The prompt to complete
     # @param model [String] The model to use
-    # @param options [Hash] The options to use (https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
+    # For a list of valid parameters and values, see:
+    # https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
     # @return [Langchain::LLM::OllamaResponse] Response object
     #
-    def complete(prompt:, model: nil, **options)
-      response = +""
+    def complete(
+      prompt:,
+      model: defaults[:completion_model_name],
+      images: nil,
+      format: nil,
+      system: nil,
+      template: nil,
+      context: nil,
+      stream: nil,
+      raw: nil,
+      mirostat: nil,
+      mirostat_eta: nil,
+      mirostat_tau: nil,
+      num_ctx: nil,
+      num_gqa: nil,
+      num_gpu: nil,
+      num_thread: nil,
+      repeat_last_n: nil,
+      repeat_penalty: nil,
+      temperature: defaults[:temperature],
+      seed: nil,
+      stop: nil,
+      tfs_z: nil,
+      num_predict: nil,
+      top_k: nil,
+      top_p: nil,
+      stop_sequences: nil,
+      &block
+    )
+      if stop_sequences
+        stop = stop_sequences
+      end
 
-      model_name = model || DEFAULTS[:completion_model_name]
+      parameters = {
+        prompt: prompt,
+        model: model,
+        images: images,
+        format: format,
+        system: system,
+        template: template,
+        context: context,
+        stream: stream,
+        raw: raw
+      }.compact
+
+      llm_parameters = {
+        mirostat: mirostat,
+        mirostat_eta: mirostat_eta,
+        mirostat_tau: mirostat_tau,
+        num_ctx: num_ctx,
+        num_gqa: num_gqa,
+        num_gpu: num_gpu,
+        num_thread: num_thread,
+        repeat_last_n: repeat_last_n,
+        repeat_penalty: repeat_penalty,
+        temperature: temperature,
+        seed: seed,
+        stop: stop,
+        tfs_z: tfs_z,
+        num_predict: num_predict,
+        top_k: top_k,
+        top_p: top_p
+      }
+
+      parameters[:options] = llm_parameters.compact
+
+      response = ""
 
       client.post("api/generate") do |req|
-        req.body = {}
-        req.body["prompt"] = prompt
-        req.body["model"] = model_name
-
-        req.body["options"] = options if options.any?
+        req.body = parameters
 
-        # TODO: Implement streaming support when a &block is passed in
         req.options.on_data = proc do |chunk, size|
           json_chunk = JSON.parse(chunk)
 
-          unless json_chunk.dig("done")
-            response.to_s << JSON.parse(chunk).dig("response")
-          end
+          response += json_chunk.dig("response")
+
+          yield json_chunk, size if block
         end
       end
 
-      Langchain::LLM::OllamaResponse.new(response, model: model_name)
+      Langchain::LLM::OllamaResponse.new(response, model: parameters[:model])
+    end
+
+    # Generate a chat completion
+    #
+    # @param model [String] Model name
+    # @param messages [Array<Hash>] Array of messages
+    # @param format [String] Format to return a response in. Currently the only accepted value is `json`
+    # @param temperature [Float] The temperature to use
+    # @param template [String] The prompt template to use (overrides what is defined in the `Modelfile`)
+    # @param stream [Boolean] Streaming the response. If false the response will be returned as a single response object, rather than a stream of objects
+    #
+    # The message object has the following fields:
+    #   role: the role of the message, either system, user or assistant
+    #   content: the content of the message
+    #   images (optional): a list of images to include in the message (for multimodal models such as llava)
+    def chat(
+      model: defaults[:chat_completion_model_name],
+      messages: [],
+      format: nil,
+      temperature: defaults[:temperature],
+      template: nil,
+      stream: false # TODO: Fix streaming.
+    )
+      parameters = {
+        model: model,
+        messages: messages,
+        format: format,
+        temperature: temperature,
+        template: template,
+        stream: stream
+      }.compact
+
+      response = client.post("api/chat") do |req|
+        req.body = parameters
+      end
+
+      Langchain::LLM::OllamaResponse.new(response.body, model: parameters[:model])
     end
 
     #
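A hedged sketch of the new streaming completion and the new chat endpoint, assuming a locally running Ollama server; the block receives the parsed JSON chunks exactly as the on_data proc above yields them:

    ollama = Langchain::LLM::Ollama.new(url: "http://localhost:11434")

    # Streamed completion: each yielded chunk is a parsed JSON hash with a "response" key.
    ollama.complete(prompt: "Why is the sky blue?", stream: true) do |chunk, _size|
      print chunk["response"]
    end

    # Chat against POST api/chat; returns a Langchain::LLM::OllamaResponse wrapping the body.
    response = ollama.chat(messages: [{role: "user", content: "Hey! How are you?"}])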
@@ -63,18 +164,57 @@ module Langchain::LLM
     # @param options [Hash] The options to use
     # @return [Langchain::LLM::OllamaResponse] Response object
     #
-    def embed(text:, model: nil, **options)
-      model_name = model || DEFAULTS[:embeddings_model_name]
+    def embed(
+      text:,
+      model: defaults[:embeddings_model_name],
+      mirostat: nil,
+      mirostat_eta: nil,
+      mirostat_tau: nil,
+      num_ctx: nil,
+      num_gqa: nil,
+      num_gpu: nil,
+      num_thread: nil,
+      repeat_last_n: nil,
+      repeat_penalty: nil,
+      temperature: defaults[:temperature],
+      seed: nil,
+      stop: nil,
+      tfs_z: nil,
+      num_predict: nil,
+      top_k: nil,
+      top_p: nil
+    )
+      parameters = {
+        prompt: text,
+        model: model
+      }.compact
+
+      llm_parameters = {
+        mirostat: mirostat,
+        mirostat_eta: mirostat_eta,
+        mirostat_tau: mirostat_tau,
+        num_ctx: num_ctx,
+        num_gqa: num_gqa,
+        num_gpu: num_gpu,
+        num_thread: num_thread,
+        repeat_last_n: repeat_last_n,
+        repeat_penalty: repeat_penalty,
+        temperature: temperature,
+        seed: seed,
+        stop: stop,
+        tfs_z: tfs_z,
+        num_predict: num_predict,
+        top_k: top_k,
+        top_p: top_p
+      }
+
+      parameters[:options] = llm_parameters.compact
 
       response = client.post("api/embeddings") do |req|
-        req.body = {}
-        req.body["prompt"] = text
-        req.body["model"] = model_name
-
-        req.body["options"] = options if options.any?
+        req.body = parameters
       end
 
-      Langchain::LLM::OllamaResponse.new(response.body, model: model_name)
+      Langchain::LLM::OllamaResponse.new(response.body, model: parameters[:model])
     end
 
     private
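A minimal sketch of the reworked embeddings call; sampling options are now explicit keyword arguments rather than a free-form options hash:

    # Returns a Langchain::LLM::OllamaResponse wrapping the POST api/embeddings body.
    ollama.embed(text: "Ruby is a programmer's best friend", model: "llama2", temperature: 0.0)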