langchainrb 0.7.5 → 0.12.0

Files changed (95)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +78 -0
  3. data/README.md +113 -56
  4. data/lib/langchain/assistants/assistant.rb +213 -0
  5. data/lib/langchain/assistants/message.rb +58 -0
  6. data/lib/langchain/assistants/thread.rb +34 -0
  7. data/lib/langchain/chunker/markdown.rb +37 -0
  8. data/lib/langchain/chunker/recursive_text.rb +0 -2
  9. data/lib/langchain/chunker/semantic.rb +1 -3
  10. data/lib/langchain/chunker/sentence.rb +0 -2
  11. data/lib/langchain/chunker/text.rb +0 -2
  12. data/lib/langchain/contextual_logger.rb +1 -1
  13. data/lib/langchain/data.rb +4 -3
  14. data/lib/langchain/llm/ai21.rb +1 -1
  15. data/lib/langchain/llm/anthropic.rb +86 -11
  16. data/lib/langchain/llm/aws_bedrock.rb +52 -0
  17. data/lib/langchain/llm/azure.rb +10 -97
  18. data/lib/langchain/llm/base.rb +3 -2
  19. data/lib/langchain/llm/cohere.rb +5 -7
  20. data/lib/langchain/llm/google_palm.rb +4 -2
  21. data/lib/langchain/llm/google_vertex_ai.rb +151 -0
  22. data/lib/langchain/llm/hugging_face.rb +1 -1
  23. data/lib/langchain/llm/llama_cpp.rb +18 -16
  24. data/lib/langchain/llm/mistral_ai.rb +68 -0
  25. data/lib/langchain/llm/ollama.rb +209 -27
  26. data/lib/langchain/llm/openai.rb +138 -170
  27. data/lib/langchain/llm/prompts/ollama/summarize_template.yaml +9 -0
  28. data/lib/langchain/llm/replicate.rb +1 -7
  29. data/lib/langchain/llm/response/anthropic_response.rb +20 -0
  30. data/lib/langchain/llm/response/base_response.rb +7 -0
  31. data/lib/langchain/llm/response/google_palm_response.rb +4 -0
  32. data/lib/langchain/llm/response/google_vertex_ai_response.rb +33 -0
  33. data/lib/langchain/llm/response/llama_cpp_response.rb +13 -0
  34. data/lib/langchain/llm/response/mistral_ai_response.rb +39 -0
  35. data/lib/langchain/llm/response/ollama_response.rb +27 -1
  36. data/lib/langchain/llm/response/openai_response.rb +8 -0
  37. data/lib/langchain/loader.rb +3 -2
  38. data/lib/langchain/output_parsers/base.rb +0 -4
  39. data/lib/langchain/output_parsers/output_fixing_parser.rb +7 -14
  40. data/lib/langchain/output_parsers/structured_output_parser.rb +0 -10
  41. data/lib/langchain/processors/csv.rb +37 -3
  42. data/lib/langchain/processors/eml.rb +64 -0
  43. data/lib/langchain/processors/markdown.rb +17 -0
  44. data/lib/langchain/processors/pptx.rb +29 -0
  45. data/lib/langchain/prompt/loading.rb +1 -1
  46. data/lib/langchain/tool/base.rb +21 -53
  47. data/lib/langchain/tool/calculator/calculator.json +19 -0
  48. data/lib/langchain/tool/{calculator.rb → calculator/calculator.rb} +8 -16
  49. data/lib/langchain/tool/database/database.json +46 -0
  50. data/lib/langchain/tool/database/database.rb +99 -0
  51. data/lib/langchain/tool/file_system/file_system.json +57 -0
  52. data/lib/langchain/tool/file_system/file_system.rb +32 -0
  53. data/lib/langchain/tool/google_search/google_search.json +19 -0
  54. data/lib/langchain/tool/{google_search.rb → google_search/google_search.rb} +5 -15
  55. data/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json +19 -0
  56. data/lib/langchain/tool/{ruby_code_interpreter.rb → ruby_code_interpreter/ruby_code_interpreter.rb} +8 -4
  57. data/lib/langchain/tool/vectorsearch/vectorsearch.json +24 -0
  58. data/lib/langchain/tool/vectorsearch/vectorsearch.rb +36 -0
  59. data/lib/langchain/tool/weather/weather.json +19 -0
  60. data/lib/langchain/tool/{weather.rb → weather/weather.rb} +3 -15
  61. data/lib/langchain/tool/wikipedia/wikipedia.json +19 -0
  62. data/lib/langchain/tool/{wikipedia.rb → wikipedia/wikipedia.rb} +9 -9
  63. data/lib/langchain/utils/token_length/ai21_validator.rb +6 -2
  64. data/lib/langchain/utils/token_length/base_validator.rb +1 -1
  65. data/lib/langchain/utils/token_length/cohere_validator.rb +6 -2
  66. data/lib/langchain/utils/token_length/google_palm_validator.rb +5 -1
  67. data/lib/langchain/utils/token_length/openai_validator.rb +55 -1
  68. data/lib/langchain/utils/token_length/token_limit_exceeded.rb +1 -1
  69. data/lib/langchain/vectorsearch/base.rb +11 -4
  70. data/lib/langchain/vectorsearch/chroma.rb +10 -1
  71. data/lib/langchain/vectorsearch/elasticsearch.rb +53 -4
  72. data/lib/langchain/vectorsearch/epsilla.rb +149 -0
  73. data/lib/langchain/vectorsearch/hnswlib.rb +5 -1
  74. data/lib/langchain/vectorsearch/milvus.rb +4 -2
  75. data/lib/langchain/vectorsearch/pgvector.rb +14 -4
  76. data/lib/langchain/vectorsearch/pinecone.rb +8 -5
  77. data/lib/langchain/vectorsearch/qdrant.rb +16 -4
  78. data/lib/langchain/vectorsearch/weaviate.rb +20 -2
  79. data/lib/langchain/version.rb +1 -1
  80. data/lib/langchain.rb +20 -5
  81. metadata +182 -45
  82. data/lib/langchain/agent/agents.md +0 -54
  83. data/lib/langchain/agent/base.rb +0 -20
  84. data/lib/langchain/agent/react_agent/react_agent_prompt.yaml +0 -26
  85. data/lib/langchain/agent/react_agent.rb +0 -131
  86. data/lib/langchain/agent/sql_query_agent/sql_query_agent_answer_prompt.yaml +0 -11
  87. data/lib/langchain/agent/sql_query_agent/sql_query_agent_sql_prompt.yaml +0 -21
  88. data/lib/langchain/agent/sql_query_agent.rb +0 -82
  89. data/lib/langchain/conversation/context.rb +0 -8
  90. data/lib/langchain/conversation/memory.rb +0 -86
  91. data/lib/langchain/conversation/message.rb +0 -48
  92. data/lib/langchain/conversation/prompt.rb +0 -8
  93. data/lib/langchain/conversation/response.rb +0 -8
  94. data/lib/langchain/conversation.rb +0 -93
  95. data/lib/langchain/tool/database.rb +0 -90
data/lib/langchain/assistants/thread.rb
@@ -0,0 +1,34 @@
+# frozen_string_literal: true
+
+module Langchain
+  # Langchain::Thread keeps track of messages in a conversation.
+  # TODO: Add functionality to persist the thread to disk, DB, storage, etc.
+  class Thread
+    attr_accessor :messages
+
+    # @param messages [Array<Langchain::Message>]
+    def initialize(messages: [])
+      raise ArgumentError, "messages array must only contain Langchain::Message instance(s)" unless messages.is_a?(Array) && messages.all? { |m| m.is_a?(Langchain::Message) }
+
+      @messages = messages
+    end
+
+    # Convert the thread to an OpenAI API-compatible array of hashes
+    #
+    # @return [Array<Hash>] The thread as an OpenAI API-compatible array of hashes
+    def openai_messages
+      messages.map(&:to_openai_format)
+    end
+
+    # Add a message to the thread
+    #
+    # @param message [Langchain::Message] The message to add
+    # @return [Array<Langchain::Message>] The updated messages array
+    def add_message(message)
+      raise ArgumentError, "message must be a Langchain::Message instance" unless message.is_a?(Langchain::Message)
+
+      # Append the message to the thread
+      messages << message
+    end
+  end
+end
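For orientation, a minimal usage sketch of the new Thread class. It assumes Langchain::Message.new accepts role: and content: keywords and defines #to_openai_format, as suggested by the message.rb file added in this release; exact signatures are not shown in this excerpt.

    # Hypothetical usage of Langchain::Thread (new in 0.12.0)
    thread = Langchain::Thread.new
    thread.add_message(Langchain::Message.new(role: "user", content: "Hello!"))
    thread.openai_messages
    # => [{role: "user", content: "Hello!"}]  (shape assumed from #to_openai_format)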
data/lib/langchain/chunker/markdown.rb
@@ -0,0 +1,37 @@
+# frozen_string_literal: true
+
+require "baran"
+
+module Langchain
+  module Chunker
+    # Markdown chunker
+    #
+    # Usage:
+    #   Langchain::Chunker::Markdown.new(text).chunks
+    class Markdown < Base
+      attr_reader :text, :chunk_size, :chunk_overlap
+
+      # @param [String] text
+      # @param [Integer] chunk_size
+      # @param [Integer] chunk_overlap
+      # @param [String] separator
+      def initialize(text, chunk_size: 1000, chunk_overlap: 200)
+        @text = text
+        @chunk_size = chunk_size
+        @chunk_overlap = chunk_overlap
+      end
+
+      # @return [Array<Langchain::Chunk>]
+      def chunks
+        splitter = Baran::MarkdownSplitter.new(
+          chunk_size: chunk_size,
+          chunk_overlap: chunk_overlap
+        )
+
+        splitter.chunks(text).map do |chunk|
+          Langchain::Chunk.new(text: chunk[:text])
+        end
+      end
+    end
+  end
+end
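The new Markdown chunker mirrors the existing Text chunker API; a quick sketch of how it might be called (the chunk sizes are illustrative, and the #text reader on the returned chunks is inferred from the text: keyword above):

    text = File.read("README.md")  # any markdown string
    chunks = Langchain::Chunker::Markdown.new(text, chunk_size: 500, chunk_overlap: 50).chunks
    chunks.first.text  # => first markdown-aware chunk, built from chunk[:text] above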
data/lib/langchain/chunker/recursive_text.rb
@@ -4,12 +4,10 @@ require "baran"
 
 module Langchain
   module Chunker
-    #
     # Recursive text chunker. Preferentially splits on separators.
     #
     # Usage:
     #   Langchain::Chunker::RecursiveText.new(text).chunks
-    #
     class RecursiveText < Base
       attr_reader :text, :chunk_size, :chunk_overlap, :separators
 
data/lib/langchain/chunker/semantic.rb
@@ -2,7 +2,6 @@
 
 module Langchain
   module Chunker
-    #
     # LLM-powered semantic chunker.
     # Semantic chunking is a technique of splitting texts by their semantic meaning, e.g.: themes, topics, and ideas.
     # We use an LLM to accomplish this. The Anthropic LLM is highly recommended for this task as it has the longest context window (100k tokens).
@@ -12,7 +11,6 @@ module Langchain
     #     text,
     #     llm: Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
     #   ).chunks
-    #
     class Semantic < Base
       attr_reader :text, :llm, :prompt_template
       # @param [Langchain::LLM::Base] Langchain::LLM::* instance
@@ -28,7 +26,7 @@ module Langchain
         prompt = prompt_template.format(text: text)
 
         # Replace static 50k limit with dynamic limit based on text length (max_tokens_to_sample)
-        completion = llm.complete(prompt: prompt, max_tokens_to_sample: 50000)
+        completion = llm.complete(prompt: prompt, max_tokens_to_sample: 50000).completion
         completion
           .gsub("Here are the paragraphs split by topic:\n\n", "")
           .split("---")
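Usage, taken from the class's own comment (the `.completion` change above means the chunker now reads the generated text out of the response object rather than treating the raw response as a string):

    Langchain::Chunker::Semantic.new(
      text,
      llm: Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
    ).chunks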
data/lib/langchain/chunker/sentence.rb
@@ -4,12 +4,10 @@ require "pragmatic_segmenter"
 
 module Langchain
   module Chunker
-    #
     # This chunker splits text by sentences.
     #
     # Usage:
     #   Langchain::Chunker::Sentence.new(text).chunks
-    #
     class Sentence < Base
       attr_reader :text
 
data/lib/langchain/chunker/text.rb
@@ -4,12 +4,10 @@ require "baran"
 
 module Langchain
  module Chunker
-    #
     # Simple text chunker
     #
     # Usage:
     #   Langchain::Chunker::Text.new(text).chunks
-    #
     class Text < Base
       attr_reader :text, :chunk_size, :chunk_overlap, :separator
 
data/lib/langchain/contextual_logger.rb
@@ -42,7 +42,7 @@ module Langchain
       for_class_name = for_class&.name
 
       log_line_parts = []
-      log_line_parts << "[LangChain.rb]".colorize(color: :yellow)
+      log_line_parts << "[Langchain.rb]".colorize(color: :yellow)
       log_line_parts << if for_class.respond_to?(:logger_options)
         "[#{for_class_name}]".colorize(for_class.logger_options) + ":"
       elsif for_class_name
data/lib/langchain/data.rb
@@ -9,9 +9,10 @@ module Langchain
 
     # @param data [String] data that was loaded
     # @option options [String] :source URL or Path of the data source
-    def initialize(data, options = {})
-      @source = options[:source]
+    def initialize(data, source: nil, chunker: Langchain::Chunker::Text)
+      @source = source
       @data = data
+      @chunker_klass = chunker
     end
 
     # @return [String]
@@ -22,7 +23,7 @@
     # @param opts [Hash] options passed to the chunker
     # @return [Array<String>]
     def chunks(opts = {})
-      Langchain::Chunker::Text.new(@data, **opts).chunks
+      @chunker_klass.new(@data, **opts).chunks
     end
   end
 end
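Langchain::Data now accepts an injectable chunker class instead of hardcoding Langchain::Chunker::Text. A sketch (raw_text and the path are placeholders):

    data = Langchain::Data.new(raw_text, source: "docs/guide.md", chunker: Langchain::Chunker::RecursiveText)
    # #chunks instantiates the injected class:
    # RecursiveText.new(raw_text, chunk_size: 800).chunks
    data.chunks(chunk_size: 800)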
data/lib/langchain/llm/ai21.rb
@@ -35,7 +35,7 @@ module Langchain::LLM
     def complete(prompt:, **params)
       parameters = complete_parameters params
 
-      parameters[:maxTokens] = LENGTH_VALIDATOR.validate_max_tokens!(prompt, parameters[:model], client)
+      parameters[:maxTokens] = LENGTH_VALIDATOR.validate_max_tokens!(prompt, parameters[:model], {llm: client})
 
       response = client.complete(prompt, parameters)
       Langchain::LLM::AI21Response.new response, model: parameters[:model]
data/lib/langchain/llm/anthropic.rb
@@ -14,12 +14,19 @@ module Langchain::LLM
     DEFAULTS = {
       temperature: 0.0,
       completion_model_name: "claude-2",
+      chat_completion_model_name: "claude-3-sonnet-20240229",
       max_tokens_to_sample: 256
     }.freeze
 
     # TODO: Implement token length validator for Anthropic
     # LENGTH_VALIDATOR = Langchain::Utils::TokenLength::AnthropicValidator
 
+    # Initialize an Anthropic LLM instance
+    #
+    # @param api_key [String] The API key to use
+    # @param llm_options [Hash] Options to pass to the Anthropic client
+    # @param default_options [Hash] Default options to use on every call to LLM, e.g.: { temperature:, completion_model_name:, chat_completion_model_name:, max_tokens_to_sample: }
+    # @return [Langchain::LLM::Anthropic] Langchain::LLM::Anthropic instance
     def initialize(api_key:, llm_options: {}, default_options: {})
       depends_on "anthropic"
 
@@ -27,17 +34,43 @@ module Langchain::LLM
       @defaults = DEFAULTS.merge(default_options)
     end
 
-    #
     # Generate a completion for a given prompt
     #
-    # @param prompt [String] The prompt to generate a completion for
-    # @param params [Hash] extra parameters passed to Anthropic::Client#complete
+    # @param prompt [String] Prompt to generate a completion for
+    # @param model [String] The model to use
+    # @param max_tokens_to_sample [Integer] The maximum number of tokens to sample
+    # @param stop_sequences [Array<String>] The stop sequences to use
+    # @param temperature [Float] The temperature to use
+    # @param top_p [Float] The top p value to use
+    # @param top_k [Integer] The top k value to use
+    # @param metadata [Hash] The metadata to use
+    # @param stream [Boolean] Whether to stream the response
    # @return [Langchain::LLM::AnthropicResponse] The completion
-    #
-    def complete(prompt:, **params)
-      parameters = compose_parameters @defaults[:completion_model_name], params
+    def complete(
+      prompt:,
+      model: @defaults[:completion_model_name],
+      max_tokens_to_sample: @defaults[:max_tokens_to_sample],
+      stop_sequences: nil,
+      temperature: @defaults[:temperature],
+      top_p: nil,
+      top_k: nil,
+      metadata: nil,
+      stream: nil
+    )
+      raise ArgumentError.new("model argument is required") if model.empty?
+      raise ArgumentError.new("max_tokens_to_sample argument is required") if max_tokens_to_sample.nil?
 
-      parameters[:prompt] = prompt
+      parameters = {
+        model: model,
+        prompt: prompt,
+        max_tokens_to_sample: max_tokens_to_sample,
+        temperature: temperature
+      }
+      parameters[:stop_sequences] = stop_sequences if stop_sequences
+      parameters[:top_p] = top_p if top_p
+      parameters[:top_k] = top_k if top_k
+      parameters[:metadata] = metadata if metadata
+      parameters[:stream] = stream if stream
 
       # TODO: Implement token length validator for Anthropic
       # parameters[:max_tokens_to_sample] = validate_max_tokens(prompt, parameters[:completion_model_name])
@@ -46,12 +79,54 @@ module Langchain::LLM
       Langchain::LLM::AnthropicResponse.new(response)
     end
 
-    private
+    # Generate a chat completion for given messages
+    #
+    # @param messages [Array<String>] Input messages
+    # @param model [String] The model that will complete your prompt
+    # @param max_tokens [Integer] Maximum number of tokens to generate before stopping
+    # @param metadata [Hash] Object describing metadata about the request
+    # @param stop_sequences [Array<String>] Custom text sequences that will cause the model to stop generating
+    # @param stream [Boolean] Whether to incrementally stream the response using server-sent events
+    # @param system [String] System prompt
+    # @param temperature [Float] Amount of randomness injected into the response
+    # @param tools [Array<String>] Definitions of tools that the model may use
+    # @param top_k [Integer] Only sample from the top K options for each subsequent token
+    # @param top_p [Float] Use nucleus sampling.
+    # @return [Langchain::LLM::AnthropicResponse] The chat completion
+    def chat(
+      messages: [],
+      model: @defaults[:chat_completion_model_name],
+      max_tokens: @defaults[:max_tokens_to_sample],
+      metadata: nil,
+      stop_sequences: nil,
+      stream: nil,
+      system: nil,
+      temperature: @defaults[:temperature],
+      tools: [],
+      top_k: nil,
+      top_p: nil
+    )
+      raise ArgumentError.new("messages argument is required") if messages.empty?
+      raise ArgumentError.new("model argument is required") if model.empty?
+      raise ArgumentError.new("max_tokens argument is required") if max_tokens.nil?
+
+      parameters = {
+        messages: messages,
+        model: model,
+        max_tokens: max_tokens,
+        temperature: temperature
+      }
+      parameters[:metadata] = metadata if metadata
+      parameters[:stop_sequences] = stop_sequences if stop_sequences
+      parameters[:stream] = stream if stream
+      parameters[:system] = system if system
+      parameters[:tools] = tools if tools.any?
+      parameters[:top_k] = top_k if top_k
+      parameters[:top_p] = top_p if top_p
 
-    def compose_parameters(model, params)
-      default_params = {model: model}.merge(@defaults.except(:completion_model_name))
+      response = client.messages(parameters: parameters)
 
-      default_params.merge(params)
+      Langchain::LLM::AnthropicResponse.new(response)
     end
 
     # TODO: Implement token length validator for Anthropic
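With the new #chat method, Anthropic usage might look like this. The message-hash shape follows Anthropic's Messages API and is an assumption; this excerpt only shows that the array is passed through to client.messages:

    llm = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
    response = llm.chat(
      messages: [{role: "user", content: "What is the capital of France?"}]
    )
    response  # => Langchain::LLM::AnthropicResponse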
data/lib/langchain/llm/aws_bedrock.rb
@@ -46,7 +46,10 @@ module Langchain::LLM
       }
     }.freeze
 
+    attr_reader :client, :defaults
+
     SUPPORTED_COMPLETION_PROVIDERS = %i[anthropic cohere ai21].freeze
+    SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[anthropic].freeze
     SUPPORTED_EMBEDDING_PROVIDERS = %i[amazon].freeze
 
     def initialize(completion_model: DEFAULTS[:completion_model_name], embedding_model: DEFAULTS[:embedding_model_name], aws_client_options: {}, default_options: {})
@@ -91,6 +94,8 @@ module Langchain::LLM
     def complete(prompt:, **params)
       raise "Completion provider #{completion_provider} is not supported." unless SUPPORTED_COMPLETION_PROVIDERS.include?(completion_provider)
 
+      raise "Model #{@defaults[:completion_model_name]} only supports #chat." if @defaults[:completion_model_name].include?("claude-3")
+
       parameters = compose_parameters params
 
       parameters[:prompt] = wrap_prompt prompt
@@ -105,6 +110,53 @@ module Langchain::LLM
       parse_response response
     end
 
+    # Generate a chat completion for a given prompt.
+    # Currently only configured to work with the Anthropic provider and
+    # the claude-3 model family.
+    # @param messages [Array] The messages to generate a completion for
+    # @param system [String] The system prompt to provide instructions
+    # @param model [String] The model to use for completion; defaults to @defaults[:chat_completion_model_name]
+    # @param max_tokens [Integer] The maximum number of tokens to generate
+    # @param stop_sequences [Array] The stop sequences to use for completion
+    # @param temperature [Float] The temperature to use for completion
+    # @param top_p [Float] The top p to use for completion
+    # @param top_k [Integer] The top k to use for completion
+    # @return [Langchain::LLM::AnthropicMessagesResponse] Response object
+    def chat(
+      messages: [],
+      system: nil,
+      model: defaults[:completion_model_name],
+      max_tokens: defaults[:max_tokens_to_sample],
+      stop_sequences: nil,
+      temperature: nil,
+      top_p: nil,
+      top_k: nil
+    )
+      raise ArgumentError.new("messages argument is required") if messages.empty?
+
+      raise "Model #{model} does not support chat completions." unless Langchain::LLM::AwsBedrock::SUPPORTED_CHAT_COMPLETION_PROVIDERS.include?(completion_provider)
+
+      inference_parameters = {
+        messages: messages,
+        max_tokens: max_tokens,
+        anthropic_version: @defaults[:anthropic_version]
+      }
+      inference_parameters[:system] = system if system
+      inference_parameters[:stop_sequences] = stop_sequences if stop_sequences
+      inference_parameters[:temperature] = temperature if temperature
+      inference_parameters[:top_p] = top_p if top_p
+      inference_parameters[:top_k] = top_k if top_k
+
+      response = client.invoke_model({
+        model_id: model,
+        body: inference_parameters.to_json,
+        content_type: "application/json",
+        accept: "application/json"
+      })
+
+      parse_response response
+    end
+
     private
 
     def completion_provider
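A sketch of the new Bedrock chat path. The model id is illustrative (any claude-3 model id enabled in your AWS region), since SUPPORTED_CHAT_COMPLETION_PROVIDERS above only lists :anthropic:

    llm = Langchain::LLM::AwsBedrock.new(
      completion_model: "anthropic.claude-3-sonnet-20240229-v1:0"  # assumed model id
    )
    llm.chat(
      messages: [{role: "user", content: "Summarize RAG in one sentence."}],
      system: "You are a terse assistant."
    )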
data/lib/langchain/llm/azure.rb
@@ -4,7 +4,7 @@ module Langchain::LLM
   # LLM interface for Azure OpenAI Service APIs: https://learn.microsoft.com/en-us/azure/ai-services/openai/
   #
   # Gem requirements:
-  #     gem "ruby-openai", "~> 5.2.0"
+  #     gem "ruby-openai", "~> 6.3.0"
   #
   # Usage:
   #     openai = Langchain::LLM::Azure.new(api_key:, llm_options: {}, embedding_deployment_url:, chat_deployment_url:)
@@ -34,106 +34,19 @@ module Langchain::LLM
       @defaults = DEFAULTS.merge(default_options)
     end
 
-    #
-    # Generate an embedding for a given text
-    #
-    # @param text [String] The text to generate an embedding for
-    # @param params extra parameters passed to OpenAI::Client#embeddings
-    # @return [Langchain::LLM::OpenAIResponse] Response object
-    #
-    def embed(text:, **params)
-      parameters = {model: @defaults[:embeddings_model_name], input: text}
-
-      validate_max_tokens(text, parameters[:model])
-
-      response = with_api_error_handling do
-        embed_client.embeddings(parameters: parameters.merge(params))
-      end
-
-      Langchain::LLM::OpenAIResponse.new(response)
+    def embed(...)
+      @client = @embed_client
+      super(...)
     end
 
-    #
-    # Generate a completion for a given prompt
-    #
-    # @param prompt [String] The prompt to generate a completion for
-    # @param params extra parameters passed to OpenAI::Client#complete
-    # @return [Langchain::LLM::OpenAIResponse] Response object
-    #
-    def complete(prompt:, **params)
-      parameters = compose_parameters @defaults[:completion_model_name], params
-
-      parameters[:messages] = compose_chat_messages(prompt: prompt)
-      parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
-
-      response = with_api_error_handling do
-        chat_client.chat(parameters: parameters)
-      end
-
-      Langchain::LLM::OpenAIResponse.new(response)
+    def complete(...)
+      @client = @chat_client
+      super(...)
     end
 
-    #
-    # Generate a chat completion for a given prompt or messages.
-    #
-    # == Examples
-    #
-    #     # simplest case, just give a prompt
-    #     openai.chat prompt: "When was Ruby first released?"
-    #
-    #     # prompt plus some context about how to respond
-    #     openai.chat context: "You are RubyGPT, a helpful chat bot for helping people learn Ruby", prompt: "Does Ruby have a REPL like IPython?"
-    #
-    #     # full control over messages that get sent, equivalent to the above
-    #     openai.chat messages: [
-    #       {
-    #         role: "system",
-    #         content: "You are RubyGPT, a helpful chat bot for helping people learn Ruby", prompt: "Does Ruby have a REPL like IPython?"
-    #       },
-    #       {
-    #         role: "user",
-    #         content: "When was Ruby first released?"
-    #       }
-    #     ]
-    #
-    #     # few-shot prompting with examples
-    #     openai.chat prompt: "When was factory_bot released?",
-    #       examples: [
-    #         {
-    #           role: "user",
-    #           content: "When was Ruby on Rails released?"
-    #         },
-    #         {
-    #           role: "assistant",
-    #           content: "2004"
-    #         },
-    #       ]
-    #
-    # @param prompt [String] The prompt to generate a chat completion for
-    # @param messages [Array<Hash>] The messages that have been sent in the conversation
-    # @param context [String] An initial context to provide as a system message, i.e. "You are RubyGPT, a helpful chat bot for helping people learn Ruby"
-    # @param examples [Array<Hash>] Examples of messages to provide to the model. Useful for Few-Shot Prompting
-    # @param options [Hash] extra parameters passed to OpenAI::Client#chat
-    # @yield [Hash] Stream responses back one token at a time
-    # @return [Langchain::LLM::OpenAIResponse] Response object
-    #
-    def chat(prompt: "", messages: [], context: "", examples: [], **options, &block)
-      raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
-
-      parameters = compose_parameters @defaults[:chat_completion_model_name], options, &block
-      parameters[:messages] = compose_chat_messages(prompt: prompt, messages: messages, context: context, examples: examples)
-
-      if functions
-        parameters[:functions] = functions
-      else
-        parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
-      end
-
-      response = with_api_error_handling { chat_client.chat(parameters: parameters) }
-
-      return if block
-
-      Langchain::LLM::OpenAIResponse.new(response)
+    def chat(...)
+      @client = @chat_client
+      super(...)
     end
   end
 end
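The rewritten Azure class now leans on Ruby 3's `...` argument forwarding: each method swaps in the right deployment client and forwards every positional, keyword, and block argument to the superclass implementation (the reworked OpenAI class, judging by the @embed_client/@chat_client delegation). A stripped-down illustration of the pattern, with made-up names:

    class Parent
      def call(a, b:, &blk) = blk.call(a + b)
    end

    class Child < Parent
      def call(...)   # `...` forwards all args, kwargs, and the block
        super(...)
      end
    end

    Child.new.call(1, b: 2) { |sum| sum * 10 }  # => 30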
data/lib/langchain/llm/base.rb
@@ -11,6 +11,7 @@ module Langchain::LLM
   # - {Langchain::LLM::Azure}
   # - {Langchain::LLM::Cohere}
   # - {Langchain::LLM::GooglePalm}
+  # - {Langchain::LLM::GoogleVertexAi}
   # - {Langchain::LLM::HuggingFace}
   # - {Langchain::LLM::LlamaCpp}
   # - {Langchain::LLM::OpenAI}
@@ -23,8 +24,8 @@ module Langchain::LLM
     # A client for communicating with the LLM
     attr_reader :client
 
-    def default_dimension
-      self.class.const_get(:DEFAULTS).dig(:dimension)
+    def default_dimensions
+      self.class.const_get(:DEFAULTS).dig(:dimensions)
     end
 
     #
data/lib/langchain/llm/cohere.rb
@@ -15,7 +15,7 @@ module Langchain::LLM
       temperature: 0.0,
       completion_model_name: "command",
       embeddings_model_name: "small",
-      dimension: 1024,
+      dimensions: 1024,
       truncate: "START"
     }.freeze
 
@@ -62,17 +62,15 @@
 
       default_params.merge!(params)
 
-      default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], client)
+      default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], llm: client)
 
       response = client.generate(**default_params)
       Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
     end
 
-    # Cohere does not have a dedicated chat endpoint, so instead we call `complete()`
-    def chat(...)
-      response_text = complete(...)
-      ::Langchain::Conversation::Response.new(response_text)
-    end
+    # TODO: Implement chat method: https://github.com/andreibondarev/cohere-ruby/issues/11
+    # def chat
+    # end
 
     # Generate a summary in English for a given text
     #
data/lib/langchain/llm/google_palm.rb
@@ -13,7 +13,7 @@ module Langchain::LLM
   class GooglePalm < Base
     DEFAULTS = {
       temperature: 0.0,
-      dimension: 768, # This is what the `embedding-gecko-001` model generates
+      dimensions: 768, # This is what the `embedding-gecko-001` model generates
       completion_model_name: "text-bison-001",
       chat_completion_model_name: "chat-bison-001",
       embeddings_model_name: "embedding-gecko-001"
@@ -23,6 +23,8 @@ module Langchain::LLM
       "assistant" => "ai"
     }
 
+    attr_reader :defaults
+
     def initialize(api_key:, default_options: {})
       depends_on "google_palm_api"
 
@@ -131,7 +133,7 @@ module Langchain::LLM
         prompt: prompt,
         temperature: @defaults[:temperature],
         # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
-        max_tokens: 2048
+        max_tokens: 256
       )
     end