langchainrb 0.7.5 → 0.12.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +78 -0
- data/README.md +113 -56
- data/lib/langchain/assistants/assistant.rb +213 -0
- data/lib/langchain/assistants/message.rb +58 -0
- data/lib/langchain/assistants/thread.rb +34 -0
- data/lib/langchain/chunker/markdown.rb +37 -0
- data/lib/langchain/chunker/recursive_text.rb +0 -2
- data/lib/langchain/chunker/semantic.rb +1 -3
- data/lib/langchain/chunker/sentence.rb +0 -2
- data/lib/langchain/chunker/text.rb +0 -2
- data/lib/langchain/contextual_logger.rb +1 -1
- data/lib/langchain/data.rb +4 -3
- data/lib/langchain/llm/ai21.rb +1 -1
- data/lib/langchain/llm/anthropic.rb +86 -11
- data/lib/langchain/llm/aws_bedrock.rb +52 -0
- data/lib/langchain/llm/azure.rb +10 -97
- data/lib/langchain/llm/base.rb +3 -2
- data/lib/langchain/llm/cohere.rb +5 -7
- data/lib/langchain/llm/google_palm.rb +4 -2
- data/lib/langchain/llm/google_vertex_ai.rb +151 -0
- data/lib/langchain/llm/hugging_face.rb +1 -1
- data/lib/langchain/llm/llama_cpp.rb +18 -16
- data/lib/langchain/llm/mistral_ai.rb +68 -0
- data/lib/langchain/llm/ollama.rb +209 -27
- data/lib/langchain/llm/openai.rb +138 -170
- data/lib/langchain/llm/prompts/ollama/summarize_template.yaml +9 -0
- data/lib/langchain/llm/replicate.rb +1 -7
- data/lib/langchain/llm/response/anthropic_response.rb +20 -0
- data/lib/langchain/llm/response/base_response.rb +7 -0
- data/lib/langchain/llm/response/google_palm_response.rb +4 -0
- data/lib/langchain/llm/response/google_vertex_ai_response.rb +33 -0
- data/lib/langchain/llm/response/llama_cpp_response.rb +13 -0
- data/lib/langchain/llm/response/mistral_ai_response.rb +39 -0
- data/lib/langchain/llm/response/ollama_response.rb +27 -1
- data/lib/langchain/llm/response/openai_response.rb +8 -0
- data/lib/langchain/loader.rb +3 -2
- data/lib/langchain/output_parsers/base.rb +0 -4
- data/lib/langchain/output_parsers/output_fixing_parser.rb +7 -14
- data/lib/langchain/output_parsers/structured_output_parser.rb +0 -10
- data/lib/langchain/processors/csv.rb +37 -3
- data/lib/langchain/processors/eml.rb +64 -0
- data/lib/langchain/processors/markdown.rb +17 -0
- data/lib/langchain/processors/pptx.rb +29 -0
- data/lib/langchain/prompt/loading.rb +1 -1
- data/lib/langchain/tool/base.rb +21 -53
- data/lib/langchain/tool/calculator/calculator.json +19 -0
- data/lib/langchain/tool/{calculator.rb → calculator/calculator.rb} +8 -16
- data/lib/langchain/tool/database/database.json +46 -0
- data/lib/langchain/tool/database/database.rb +99 -0
- data/lib/langchain/tool/file_system/file_system.json +57 -0
- data/lib/langchain/tool/file_system/file_system.rb +32 -0
- data/lib/langchain/tool/google_search/google_search.json +19 -0
- data/lib/langchain/tool/{google_search.rb → google_search/google_search.rb} +5 -15
- data/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json +19 -0
- data/lib/langchain/tool/{ruby_code_interpreter.rb → ruby_code_interpreter/ruby_code_interpreter.rb} +8 -4
- data/lib/langchain/tool/vectorsearch/vectorsearch.json +24 -0
- data/lib/langchain/tool/vectorsearch/vectorsearch.rb +36 -0
- data/lib/langchain/tool/weather/weather.json +19 -0
- data/lib/langchain/tool/{weather.rb → weather/weather.rb} +3 -15
- data/lib/langchain/tool/wikipedia/wikipedia.json +19 -0
- data/lib/langchain/tool/{wikipedia.rb → wikipedia/wikipedia.rb} +9 -9
- data/lib/langchain/utils/token_length/ai21_validator.rb +6 -2
- data/lib/langchain/utils/token_length/base_validator.rb +1 -1
- data/lib/langchain/utils/token_length/cohere_validator.rb +6 -2
- data/lib/langchain/utils/token_length/google_palm_validator.rb +5 -1
- data/lib/langchain/utils/token_length/openai_validator.rb +55 -1
- data/lib/langchain/utils/token_length/token_limit_exceeded.rb +1 -1
- data/lib/langchain/vectorsearch/base.rb +11 -4
- data/lib/langchain/vectorsearch/chroma.rb +10 -1
- data/lib/langchain/vectorsearch/elasticsearch.rb +53 -4
- data/lib/langchain/vectorsearch/epsilla.rb +149 -0
- data/lib/langchain/vectorsearch/hnswlib.rb +5 -1
- data/lib/langchain/vectorsearch/milvus.rb +4 -2
- data/lib/langchain/vectorsearch/pgvector.rb +14 -4
- data/lib/langchain/vectorsearch/pinecone.rb +8 -5
- data/lib/langchain/vectorsearch/qdrant.rb +16 -4
- data/lib/langchain/vectorsearch/weaviate.rb +20 -2
- data/lib/langchain/version.rb +1 -1
- data/lib/langchain.rb +20 -5
- metadata +182 -45
- data/lib/langchain/agent/agents.md +0 -54
- data/lib/langchain/agent/base.rb +0 -20
- data/lib/langchain/agent/react_agent/react_agent_prompt.yaml +0 -26
- data/lib/langchain/agent/react_agent.rb +0 -131
- data/lib/langchain/agent/sql_query_agent/sql_query_agent_answer_prompt.yaml +0 -11
- data/lib/langchain/agent/sql_query_agent/sql_query_agent_sql_prompt.yaml +0 -21
- data/lib/langchain/agent/sql_query_agent.rb +0 -82
- data/lib/langchain/conversation/context.rb +0 -8
- data/lib/langchain/conversation/memory.rb +0 -86
- data/lib/langchain/conversation/message.rb +0 -48
- data/lib/langchain/conversation/prompt.rb +0 -8
- data/lib/langchain/conversation/response.rb +0 -8
- data/lib/langchain/conversation.rb +0 -93
- data/lib/langchain/tool/database.rb +0 -90
data/lib/langchain/assistants/thread.rb
ADDED
@@ -0,0 +1,34 @@
+# frozen_string_literal: true
+
+module Langchain
+  # Langchain::Thread keeps track of messages in a conversation.
+  # TODO: Add functionality to persist to the thread to disk, DB, storage, etc.
+  class Thread
+    attr_accessor :messages
+
+    # @param messages [Array<Langchain::Message>]
+    def initialize(messages: [])
+      raise ArgumentError, "messages array must only contain Langchain::Message instance(s)" unless messages.is_a?(Array) && messages.all? { |m| m.is_a?(Langchain::Message) }
+
+      @messages = messages
+    end
+
+    # Convert the thread to an OpenAI API-compatible array of hashes
+    #
+    # @return [Array<Hash>] The thread as an OpenAI API-compatible array of hashes
+    def openai_messages
+      messages.map(&:to_openai_format)
+    end
+
+    # Add a message to the thread
+    #
+    # @param message [Langchain::Message] The message to add
+    # @return [Array<Langchain::Message>] The updated messages array
+    def add_message(message)
+      raise ArgumentError, "message must be a Langchain::Message instance" unless message.is_a?(Langchain::Message)
+
+      # Prepend the message to the thread
+      messages << message
+    end
+  end
+end
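
A minimal usage sketch for the new assistants Thread API, based only on the code above. The `Langchain::Message` constructor keywords are an assumption taken from `message.rb` (+58 lines in this release), which is not reproduced in this diff:

```ruby
require "langchain"

thread = Langchain::Thread.new

# role:/content: keywords are assumed; Message lives in
# data/lib/langchain/assistants/message.rb, not shown here.
thread.add_message(Langchain::Message.new(role: "user", content: "What's 2 + 2?"))

# Anything that is not a Langchain::Message raises:
# thread.add_message(role: "user", content: "hi") # => ArgumentError

thread.openai_messages # delegates to each Message#to_openai_format
```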
data/lib/langchain/chunker/markdown.rb
ADDED
@@ -0,0 +1,37 @@
+# frozen_string_literal: true
+
+require "baran"
+
+module Langchain
+  module Chunker
+    # Simple text chunker
+    #
+    # Usage:
+    #     Langchain::Chunker::Markdown.new(text).chunks
+    class Markdown < Base
+      attr_reader :text, :chunk_size, :chunk_overlap
+
+      # @param [String] text
+      # @param [Integer] chunk_size
+      # @param [Integer] chunk_overlap
+      # @param [String] separator
+      def initialize(text, chunk_size: 1000, chunk_overlap: 200)
+        @text = text
+        @chunk_size = chunk_size
+        @chunk_overlap = chunk_overlap
+      end
+
+      # @return [Array<Langchain::Chunk>]
+      def chunks
+        splitter = Baran::MarkdownSplitter.new(
+          chunk_size: chunk_size,
+          chunk_overlap: chunk_overlap
+        )
+
+        splitter.chunks(text).map do |chunk|
+          Langchain::Chunk.new(text: chunk[:text])
+        end
+      end
+    end
+  end
+end
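
A short sketch of the new markdown-aware chunker, assuming the `baran` gem is installed; the 1000/200 defaults come from the initializer above:

```ruby
require "langchain"

text = File.read("README.md")

chunker = Langchain::Chunker::Markdown.new(text, chunk_size: 500, chunk_overlap: 50)
chunker.chunks.each do |chunk|
  # Each element is a Langchain::Chunk built from Baran::MarkdownSplitter output
  puts chunk.text
end
```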
data/lib/langchain/chunker/recursive_text.rb
CHANGED
@@ -4,12 +4,10 @@ require "baran"
 
 module Langchain
   module Chunker
-    #
     # Recursive text chunker. Preferentially splits on separators.
     #
     # Usage:
     #     Langchain::Chunker::RecursiveText.new(text).chunks
-    #
     class RecursiveText < Base
       attr_reader :text, :chunk_size, :chunk_overlap, :separators
 
data/lib/langchain/chunker/semantic.rb
CHANGED
@@ -2,7 +2,6 @@
 
 module Langchain
   module Chunker
-    #
     # LLM-powered semantic chunker.
     # Semantic chunking is a technique of splitting texts by their semantic meaning, e.g.: themes, topics, and ideas.
     # We use an LLM to accomplish this. The Anthropic LLM is highly recommended for this task as it has the longest context window (100k tokens).
@@ -12,7 +11,6 @@ module Langchain
    #       text,
    #       llm: Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
    #     ).chunks
-    #
     class Semantic < Base
       attr_reader :text, :llm, :prompt_template
       # @param [Langchain::LLM::Base] Langchain::LLM::* instance
@@ -28,7 +26,7 @@ module Langchain
         prompt = prompt_template.format(text: text)
 
         # Replace static 50k limit with dynamic limit based on text length (max_tokens_to_sample)
-        completion = llm.complete(prompt: prompt, max_tokens_to_sample: 50000)
+        completion = llm.complete(prompt: prompt, max_tokens_to_sample: 50000).completion
         completion
           .gsub("Here are the paragraphs split by topic:\n\n", "")
           .split("---")
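
The chunker now calls `#completion` on the response object returned by `llm.complete` instead of treating the return value as plain text. A hedged sketch of that pattern:

```ruby
llm = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])

response = llm.complete(prompt: "Split this text by topic: ...", max_tokens_to_sample: 1024)
response.class       # => Langchain::LLM::AnthropicResponse (see the anthropic.rb hunks below)
response.completion  # => the raw completion text that the chunker post-processes
```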
data/lib/langchain/contextual_logger.rb
CHANGED
@@ -42,7 +42,7 @@ module Langchain
       for_class_name = for_class&.name
 
       log_line_parts = []
-      log_line_parts << "[
+      log_line_parts << "[Langchain.rb]".colorize(color: :yellow)
       log_line_parts << if for_class.respond_to?(:logger_options)
         "[#{for_class_name}]".colorize(for_class.logger_options) + ":"
       elsif for_class_name
data/lib/langchain/data.rb
CHANGED
@@ -9,9 +9,10 @@ module Langchain
 
     # @param data [String] data that was loaded
     # @option options [String] :source URL or Path of the data source
-    def initialize(data,
-      @source =
+    def initialize(data, source: nil, chunker: Langchain::Chunker::Text)
+      @source = source
       @data = data
+      @chunker_klass = chunker
     end
 
     # @return [String]
@@ -22,7 +23,7 @@ module Langchain
     # @param opts [Hash] options passed to the chunker
     # @return [Array<String>]
     def chunks(opts = {})
-
+      @chunker_klass.new(@data, **opts).chunks
     end
   end
 end
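
`Langchain::Data` now accepts `source:` and `chunker:` keywords and forwards `#chunks` options to the chosen chunker class (the default remains `Langchain::Chunker::Text`). A small sketch:

```ruby
data = Langchain::Data.new(
  File.read("notes.md"),
  source: "notes.md",
  chunker: Langchain::Chunker::Markdown
)

data.chunks(chunk_size: 500) # opts are passed to Langchain::Chunker::Markdown.new
```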
data/lib/langchain/llm/ai21.rb
CHANGED
@@ -35,7 +35,7 @@ module Langchain::LLM
     def complete(prompt:, **params)
       parameters = complete_parameters params
 
-      parameters[:maxTokens] = LENGTH_VALIDATOR.validate_max_tokens!(prompt, parameters[:model], client)
+      parameters[:maxTokens] = LENGTH_VALIDATOR.validate_max_tokens!(prompt, parameters[:model], {llm: client})
 
       response = client.complete(prompt, parameters)
       Langchain::LLM::AI21Response.new response, model: parameters[:model]
data/lib/langchain/llm/anthropic.rb
CHANGED
@@ -14,12 +14,19 @@ module Langchain::LLM
     DEFAULTS = {
       temperature: 0.0,
       completion_model_name: "claude-2",
+      chat_completion_model_name: "claude-3-sonnet-20240229",
       max_tokens_to_sample: 256
     }.freeze
 
     # TODO: Implement token length validator for Anthropic
     # LENGTH_VALIDATOR = Langchain::Utils::TokenLength::AnthropicValidator
 
+    # Initialize an Anthropic LLM instance
+    #
+    # @param api_key [String] The API key to use
+    # @param llm_options [Hash] Options to pass to the Anthropic client
+    # @param default_options [Hash] Default options to use on every call to LLM, e.g.: { temperature:, completion_model_name:, chat_completion_model_name:, max_tokens_to_sample: }
+    # @return [Langchain::LLM::Anthropic] Langchain::LLM::Anthropic instance
     def initialize(api_key:, llm_options: {}, default_options: {})
       depends_on "anthropic"
 
@@ -27,17 +34,43 @@ module Langchain::LLM
       @defaults = DEFAULTS.merge(default_options)
     end
 
-    #
     # Generate a completion for a given prompt
     #
-    # @param prompt [String]
-    # @param
+    # @param prompt [String] Prompt to generate a completion for
+    # @param model [String] The model to use
+    # @param max_tokens_to_sample [Integer] The maximum number of tokens to sample
+    # @param stop_sequences [Array<String>] The stop sequences to use
+    # @param temperature [Float] The temperature to use
+    # @param top_p [Float] The top p value to use
+    # @param top_k [Integer] The top k value to use
+    # @param metadata [Hash] The metadata to use
+    # @param stream [Boolean] Whether to stream the response
     # @return [Langchain::LLM::AnthropicResponse] The completion
-
-
-
+    def complete(
+      prompt:,
+      model: @defaults[:completion_model_name],
+      max_tokens_to_sample: @defaults[:max_tokens_to_sample],
+      stop_sequences: nil,
+      temperature: @defaults[:temperature],
+      top_p: nil,
+      top_k: nil,
+      metadata: nil,
+      stream: nil
+    )
+      raise ArgumentError.new("model argument is required") if model.empty?
+      raise ArgumentError.new("max_tokens_to_sample argument is required") if max_tokens_to_sample.nil?
 
-      parameters
+      parameters = {
+        model: model,
+        prompt: prompt,
+        max_tokens_to_sample: max_tokens_to_sample,
+        temperature: temperature
+      }
+      parameters[:stop_sequences] = stop_sequences if stop_sequences
+      parameters[:top_p] = top_p if top_p
+      parameters[:top_k] = top_k if top_k
+      parameters[:metadata] = metadata if metadata
+      parameters[:stream] = stream if stream
 
       # TODO: Implement token length validator for Anthropic
       # parameters[:max_tokens_to_sample] = validate_max_tokens(prompt, parameters[:completion_model_name])
@@ -46,12 +79,54 @@ module Langchain::LLM
       Langchain::LLM::AnthropicResponse.new(response)
     end
 
-
+    # Generate a chat completion for given messages
+    #
+    # @param messages [Array<String>] Input messages
+    # @param model [String] The model that will complete your prompt
+    # @param max_tokens [Integer] Maximum number of tokens to generate before stopping
+    # @param metadata [Hash] Object describing metadata about the request
+    # @param stop_sequences [Array<String>] Custom text sequences that will cause the model to stop generating
+    # @param stream [Boolean] Whether to incrementally stream the response using server-sent events
+    # @param system [String] System prompt
+    # @param temperature [Float] Amount of randomness injected into the response
+    # @param tools [Array<String>] Definitions of tools that the model may use
+    # @param top_k [Integer] Only sample from the top K options for each subsequent token
+    # @param top_p [Float] Use nucleus sampling.
+    # @return [Langchain::LLM::AnthropicResponse] The chat completion
+    def chat(
+      messages: [],
+      model: @defaults[:chat_completion_model_name],
+      max_tokens: @defaults[:max_tokens_to_sample],
+      metadata: nil,
+      stop_sequences: nil,
+      stream: nil,
+      system: nil,
+      temperature: @defaults[:temperature],
+      tools: [],
+      top_k: nil,
+      top_p: nil
+    )
+      raise ArgumentError.new("messages argument is required") if messages.empty?
+      raise ArgumentError.new("model argument is required") if model.empty?
+      raise ArgumentError.new("max_tokens argument is required") if max_tokens.nil?
+
+      parameters = {
+        messages: messages,
+        model: model,
+        max_tokens: max_tokens,
+        temperature: temperature
+      }
+      parameters[:metadata] = metadata if metadata
+      parameters[:stop_sequences] = stop_sequences if stop_sequences
+      parameters[:stream] = stream if stream
+      parameters[:system] = system if system
+      parameters[:tools] = tools if tools.any?
+      parameters[:top_k] = top_k if top_k
+      parameters[:top_p] = top_p if top_p
 
-
-      default_params = {model: model}.merge(@defaults.except(:completion_model_name))
+      response = client.messages(parameters: parameters)
 
-
+      Langchain::LLM::AnthropicResponse.new(response)
     end
 
     # TODO: Implement token length validator for Anthropic
data/lib/langchain/llm/aws_bedrock.rb
CHANGED
@@ -46,7 +46,10 @@ module Langchain::LLM
       }
     }.freeze
 
+    attr_reader :client, :defaults
+
     SUPPORTED_COMPLETION_PROVIDERS = %i[anthropic cohere ai21].freeze
+    SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[anthropic].freeze
     SUPPORTED_EMBEDDING_PROVIDERS = %i[amazon].freeze
 
     def initialize(completion_model: DEFAULTS[:completion_model_name], embedding_model: DEFAULTS[:embedding_model_name], aws_client_options: {}, default_options: {})
@@ -91,6 +94,8 @@ module Langchain::LLM
     def complete(prompt:, **params)
       raise "Completion provider #{completion_provider} is not supported." unless SUPPORTED_COMPLETION_PROVIDERS.include?(completion_provider)
 
+      raise "Model #{@defaults[:completion_model_name]} only supports #chat." if @defaults[:completion_model_name].include?("claude-3")
+
       parameters = compose_parameters params
 
       parameters[:prompt] = wrap_prompt prompt
@@ -105,6 +110,53 @@ module Langchain::LLM
       parse_response response
     end
 
+    # Generate a chat completion for a given prompt
+    # Currently only configured to work with the Anthropic provider and
+    # the claude-3 model family
+    # @param messages [Array] The messages to generate a completion for
+    # @param system [String] The system prompt to provide instructions
+    # @param model [String] The model to use for completion defaults to @defaults[:chat_completion_model_name]
+    # @param max_tokens [Integer] The maximum number of tokens to generate
+    # @param stop_sequences [Array] The stop sequences to use for completion
+    # @param temperature [Float] The temperature to use for completion
+    # @param top_p [Float] The top p to use for completion
+    # @param top_k [Integer] The top k to use for completion
+    # @return [Langchain::LLM::AnthropicMessagesResponse] Response object
+    def chat(
+      messages: [],
+      system: nil,
+      model: defaults[:completion_model_name],
+      max_tokens: defaults[:max_tokens_to_sample],
+      stop_sequences: nil,
+      temperature: nil,
+      top_p: nil,
+      top_k: nil
+    )
+      raise ArgumentError.new("messages argument is required") if messages.empty?
+
+      raise "Model #{model} does not support chat completions." unless Langchain::LLM::AwsBedrock::SUPPORTED_CHAT_COMPLETION_PROVIDERS.include?(completion_provider)
+
+      inference_parameters = {
+        messages: messages,
+        max_tokens: max_tokens,
+        anthropic_version: @defaults[:anthropic_version]
+      }
+      inference_parameters[:system] = system if system
+      inference_parameters[:stop_sequences] = stop_sequences if stop_sequences
+      inference_parameters[:temperature] = temperature if temperature
+      inference_parameters[:top_p] = top_p if top_p
+      inference_parameters[:top_k] = top_k if top_k
+
+      response = client.invoke_model({
+        model_id: model,
+        body: inference_parameters.to_json,
+        content_type: "application/json",
+        accept: "application/json"
+      })
+
+      parse_response response
+    end
+
     private
 
     def completion_provider
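
A sketch of the new Bedrock `#chat`, which is restricted to the Anthropic provider (Claude 3 models). AWS credentials and a Claude 3 completion model are assumed; the default model name is not visible in this diff:

```ruby
# Assumes AWS credentials in the environment and an Anthropic Claude 3 model
# configured as the completion model (e.g. via default_options); #chat raises
# for providers outside SUPPORTED_CHAT_COMPLETION_PROVIDERS.
llm = Langchain::LLM::AwsBedrock.new

response = llm.chat(
  messages: [{role: "user", content: "Summarize RFC 2119 in one sentence."}],
  system: "Answer in one short sentence."
)
```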
data/lib/langchain/llm/azure.rb
CHANGED
@@ -4,7 +4,7 @@ module Langchain::LLM
   # LLM interface for Azure OpenAI Service APIs: https://learn.microsoft.com/en-us/azure/ai-services/openai/
   #
   # Gem requirements:
-  #     gem "ruby-openai", "~>
+  #     gem "ruby-openai", "~> 6.3.0"
   #
   # Usage:
   #     openai = Langchain::LLM::Azure.new(api_key:, llm_options: {}, embedding_deployment_url: chat_deployment_url:)
@@ -34,106 +34,19 @@ module Langchain::LLM
       @defaults = DEFAULTS.merge(default_options)
     end
 
-
-
-
-    # @param text [String] The text to generate an embedding for
-    # @param params extra parameters passed to OpenAI::Client#embeddings
-    # @return [Langchain::LLM::OpenAIResponse] Response object
-    #
-    def embed(text:, **params)
-      parameters = {model: @defaults[:embeddings_model_name], input: text}
-
-      validate_max_tokens(text, parameters[:model])
-
-      response = with_api_error_handling do
-        embed_client.embeddings(parameters: parameters.merge(params))
-      end
-
-      Langchain::LLM::OpenAIResponse.new(response)
+    def embed(...)
+      @client = @embed_client
+      super(...)
     end
 
-
-
-
-    # @param prompt [String] The prompt to generate a completion for
-    # @param params extra parameters passed to OpenAI::Client#complete
-    # @return [Langchain::LLM::Response::OpenaAI] Response object
-    #
-    def complete(prompt:, **params)
-      parameters = compose_parameters @defaults[:completion_model_name], params
-
-      parameters[:messages] = compose_chat_messages(prompt: prompt)
-      parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
-
-      response = with_api_error_handling do
-        chat_client.chat(parameters: parameters)
-      end
-
-      Langchain::LLM::OpenAIResponse.new(response)
+    def complete(...)
+      @client = @chat_client
+      super(...)
     end
 
-
-
-
-    # == Examples
-    #
-    #     # simplest case, just give a prompt
-    #     openai.chat prompt: "When was Ruby first released?"
-    #
-    #     # prompt plus some context about how to respond
-    #     openai.chat context: "You are RubyGPT, a helpful chat bot for helping people learn Ruby", prompt: "Does Ruby have a REPL like IPython?"
-    #
-    #     # full control over messages that get sent, equivilent to the above
-    #     openai.chat messages: [
-    #       {
-    #         role: "system",
-    #         content: "You are RubyGPT, a helpful chat bot for helping people learn Ruby", prompt: "Does Ruby have a REPL like IPython?"
-    #       },
-    #       {
-    #         role: "user",
-    #         content: "When was Ruby first released?"
-    #       }
-    #     ]
-    #
-    #     # few-short prompting with examples
-    #     openai.chat prompt: "When was factory_bot released?",
-    #       examples: [
-    #         {
-    #           role: "user",
-    #           content: "When was Ruby on Rails released?"
-    #         }
-    #         {
-    #           role: "assistant",
-    #           content: "2004"
-    #         },
-    #       ]
-    #
-    # @param prompt [String] The prompt to generate a chat completion for
-    # @param messages [Array<Hash>] The messages that have been sent in the conversation
-    # @param context [String] An initial context to provide as a system message, ie "You are RubyGPT, a helpful chat bot for helping people learn Ruby"
-    # @param examples [Array<Hash>] Examples of messages to provide to the model. Useful for Few-Shot Prompting
-    # @param options [Hash] extra parameters passed to OpenAI::Client#chat
-    # @yield [Hash] Stream responses back one token at a time
-    # @return [Langchain::LLM::OpenAIResponse] Response object
-    #
-    def chat(prompt: "", messages: [], context: "", examples: [], **options, &block)
-      raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
-
-      parameters = compose_parameters @defaults[:chat_completion_model_name], options, &block
-      parameters[:messages] = compose_chat_messages(prompt: prompt, messages: messages, context: context, examples: examples)
-
-      if functions
-        parameters[:functions] = functions
-      else
-        parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
-      end
-
-      response = with_api_error_handling { chat_client.chat(parameters: parameters) }
-
-      return if block
-
-      Langchain::LLM::OpenAIResponse.new(response)
+    def chat(...)
+      @client = @chat_client
+      super(...)
     end
   end
 end
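
The Azure class now swaps in the right client (`@embed_client` or `@chat_client`) and delegates to its OpenAI superclass via `super(...)`, so it takes the same keyword arguments as `Langchain::LLM::OpenAI`. A hedged sketch; the deployment-URL keywords come from the usage comment above, while the `messages:` shape is assumed from openai.rb, which is not reproduced in this section:

```ruby
azure = Langchain::LLM::Azure.new(
  api_key: ENV["AZURE_OPENAI_API_KEY"],             # assumed env var names
  embedding_deployment_url: ENV["AZURE_EMBED_URL"],
  chat_deployment_url: ENV["AZURE_CHAT_URL"]
)

azure.embed(text: "Ruby is a programmer's best friend")
azure.chat(messages: [{role: "user", content: "When was Ruby first released?"}])
```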
data/lib/langchain/llm/base.rb
CHANGED
@@ -11,6 +11,7 @@ module Langchain::LLM
   # - {Langchain::LLM::Azure}
   # - {Langchain::LLM::Cohere}
   # - {Langchain::LLM::GooglePalm}
+  # - {Langchain::LLM::GoogleVertexAi}
   # - {Langchain::LLM::HuggingFace}
   # - {Langchain::LLM::LlamaCpp}
   # - {Langchain::LLM::OpenAI}
@@ -23,8 +24,8 @@ module Langchain::LLM
     # A client for communicating with the LLM
     attr_reader :client
 
-    def
-      self.class.const_get(:DEFAULTS).dig(:
+    def default_dimensions
+      self.class.const_get(:DEFAULTS).dig(:dimensions)
     end
 
     #
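
`Langchain::LLM::Base#default_dimensions` now digs the `dimensions:` key out of each subclass's `DEFAULTS` (see the `dimensions:` entries added in the cohere.rb and google_palm.rb hunks below). A small sketch, assuming the respective client gems are installed and the `api_key:` keyword shown in those classes:

```ruby
Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"]).default_dimensions          # => 1024
Langchain::LLM::GooglePalm.new(api_key: ENV["GOOGLE_PALM_API_KEY"]).default_dimensions # => 768
```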
data/lib/langchain/llm/cohere.rb
CHANGED
@@ -15,7 +15,7 @@ module Langchain::LLM
       temperature: 0.0,
       completion_model_name: "command",
       embeddings_model_name: "small",
-
+      dimensions: 1024,
       truncate: "START"
     }.freeze
 
@@ -62,17 +62,15 @@ module Langchain::LLM
 
       default_params.merge!(params)
 
-      default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], client)
+      default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], llm: client)
 
       response = client.generate(**default_params)
       Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
     end
 
-    #
-    def chat
-
-      ::Langchain::Conversation::Response.new(response_text)
-    end
+    # TODO: Implement chat method: https://github.com/andreibondarev/cohere-ruby/issues/11
+    # def chat
+    # end
 
     # Generate a summary in English for a given text
     #
data/lib/langchain/llm/google_palm.rb
CHANGED
@@ -13,7 +13,7 @@ module Langchain::LLM
   class GooglePalm < Base
     DEFAULTS = {
       temperature: 0.0,
-
+      dimensions: 768, # This is what the `embedding-gecko-001` model generates
       completion_model_name: "text-bison-001",
       chat_completion_model_name: "chat-bison-001",
       embeddings_model_name: "embedding-gecko-001"
@@ -23,6 +23,8 @@ module Langchain::LLM
       "assistant" => "ai"
     }
 
+    attr_reader :defaults
+
     def initialize(api_key:, default_options: {})
       depends_on "google_palm_api"
 
@@ -131,7 +133,7 @@ module Langchain::LLM
         prompt: prompt,
         temperature: @defaults[:temperature],
         # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
-        max_tokens:
+        max_tokens: 256
       )
     end
 