langchainrb 0.12.0 → 0.13.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/README.md +3 -2
- data/lib/langchain/assistants/assistant.rb +75 -20
- data/lib/langchain/assistants/messages/base.rb +16 -0
- data/lib/langchain/assistants/messages/google_gemini_message.rb +90 -0
- data/lib/langchain/assistants/messages/openai_message.rb +74 -0
- data/lib/langchain/assistants/thread.rb +5 -5
- data/lib/langchain/llm/anthropic.rb +27 -49
- data/lib/langchain/llm/aws_bedrock.rb +30 -34
- data/lib/langchain/llm/azure.rb +6 -0
- data/lib/langchain/llm/base.rb +20 -1
- data/lib/langchain/llm/cohere.rb +38 -6
- data/lib/langchain/llm/google_gemini.rb +67 -0
- data/lib/langchain/llm/google_vertex_ai.rb +68 -112
- data/lib/langchain/llm/mistral_ai.rb +10 -19
- data/lib/langchain/llm/ollama.rb +23 -27
- data/lib/langchain/llm/openai.rb +20 -48
- data/lib/langchain/llm/parameters/chat.rb +51 -0
- data/lib/langchain/llm/response/base_response.rb +2 -2
- data/lib/langchain/llm/response/cohere_response.rb +16 -0
- data/lib/langchain/llm/response/google_gemini_response.rb +45 -0
- data/lib/langchain/llm/response/openai_response.rb +5 -1
- data/lib/langchain/llm/unified_parameters.rb +98 -0
- data/lib/langchain/loader.rb +6 -0
- data/lib/langchain/tool/base.rb +16 -6
- data/lib/langchain/tool/calculator/calculator.json +1 -1
- data/lib/langchain/tool/database/database.json +3 -3
- data/lib/langchain/tool/file_system/file_system.json +3 -3
- data/lib/langchain/tool/news_retriever/news_retriever.json +121 -0
- data/lib/langchain/tool/news_retriever/news_retriever.rb +132 -0
- data/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json +1 -1
- data/lib/langchain/tool/vectorsearch/vectorsearch.json +1 -1
- data/lib/langchain/tool/weather/weather.json +1 -1
- data/lib/langchain/tool/wikipedia/wikipedia.json +1 -1
- data/lib/langchain/tool/wikipedia/wikipedia.rb +2 -2
- data/lib/langchain/utils/token_length/openai_validator.rb +6 -1
- data/lib/langchain/version.rb +1 -1
- data/lib/langchain.rb +3 -0
- metadata +22 -15
- data/lib/langchain/assistants/message.rb +0 -58
- data/lib/langchain/llm/response/google_vertex_ai_response.rb +0 -33
data/lib/langchain/llm/base.rb
CHANGED
@@ -11,7 +11,8 @@ module Langchain::LLM
   # - {Langchain::LLM::Azure}
   # - {Langchain::LLM::Cohere}
   # - {Langchain::LLM::GooglePalm}
-  # - {Langchain::LLM::
+  # - {Langchain::LLM::GoogleVertexAI}
+  # - {Langchain::LLM::GoogleGemini}
   # - {Langchain::LLM::HuggingFace}
   # - {Langchain::LLM::LlamaCpp}
   # - {Langchain::LLM::OpenAI}
@@ -24,6 +25,15 @@ module Langchain::LLM
     # A client for communicating with the LLM
     attr_reader :client
 
+    # Ensuring backward compatibility after https://github.com/patterns-ai-core/langchainrb/pull/586
+    # TODO: Delete this method later
+    def default_dimension
+      default_dimensions
+    end
+
+    # Returns the number of vector dimensions used by DEFAULTS[:chat_completion_model_name]
+    #
+    # @return [Integer] Vector dimensions
     def default_dimensions
       self.class.const_get(:DEFAULTS).dig(:dimensions)
     end
@@ -61,5 +71,14 @@ module Langchain::LLM
     def summarize(...)
       raise NotImplementedError, "#{self.class.name} does not support summarization"
     end
+
+    #
+    # Returns an instance of Langchain::LLM::Parameters::Chat
+    #
+    def chat_parameters(params = {})
+      @chat_parameters ||= Langchain::LLM::Parameters::Chat.new(
+        parameters: params
+      )
+    end
   end
 end
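Every provider now inherits a `chat_parameters` helper from `Base`, which memoizes a `Langchain::LLM::Parameters::Chat` schema that subclasses tune in their constructors, while `default_dimension` stays around as a thin alias for the renamed `default_dimensions`. A minimal sketch of how a subclass wires this up (the `MyLLM` class and its defaults are hypothetical; only the `Base` API is taken from the diff above):

    # Hypothetical subclass illustrating the new Base#chat_parameters helper
    class MyLLM < Langchain::LLM::Base
      DEFAULTS = {chat_completion_model_name: "my-model", temperature: 0.0, dimensions: 1536}.freeze

      def initialize(default_options: {})
        @defaults = DEFAULTS.merge(default_options)
        # Declare provider defaults and key renames once, up front
        chat_parameters.update(model: {default: @defaults[:chat_completion_model_name]})
        chat_parameters.remap(stop: :stop_sequences)
      end
    end

    llm = MyLLM.new
    llm.default_dimension # => 1536, still works via the backward-compatible alias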
data/lib/langchain/llm/cohere.rb
CHANGED
@@ -8,22 +8,34 @@ module Langchain::LLM
   # gem "cohere-ruby", "~> 0.9.6"
   #
   # Usage:
-  #
+  #     llm = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])
   #
   class Cohere < Base
     DEFAULTS = {
       temperature: 0.0,
       completion_model_name: "command",
+      chat_completion_model_name: "command-r-plus",
       embeddings_model_name: "small",
       dimensions: 1024,
       truncate: "START"
     }.freeze
 
-    def initialize(api_key
+    def initialize(api_key:, default_options: {})
       depends_on "cohere-ruby", req: "cohere"
 
-      @client = ::Cohere::Client.new(api_key)
+      @client = ::Cohere::Client.new(api_key: api_key)
       @defaults = DEFAULTS.merge(default_options)
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        temperature: {default: @defaults[:temperature]}
+      )
+      chat_parameters.remap(
+        system: :preamble,
+        messages: :chat_history,
+        stop: :stop_sequences,
+        top_k: :k,
+        top_p: :p
+      )
     end
 
     #
@@ -68,9 +80,29 @@ module Langchain::LLM
       Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
     end
 
-    #
-    #
-    #
+    # Generate a chat completion for given messages
+    #
+    # @param [Hash] params unified chat parmeters from [Langchain::LLM::Parameters::Chat::SCHEMA]
+    # @option params [Array<String>] :messages Input messages
+    # @option params [String] :model The model that will complete your prompt
+    # @option params [Integer] :max_tokens Maximum number of tokens to generate before stopping
+    # @option params [Array<String>] :stop Custom text sequences that will cause the model to stop generating
+    # @option params [Boolean] :stream Whether to incrementally stream the response using server-sent events
+    # @option params [String] :system System prompt
+    # @option params [Float] :temperature Amount of randomness injected into the response
+    # @option params [Array<String>] :tools Definitions of tools that the model may use
+    # @option params [Integer] :top_k Only sample from the top K options for each subsequent token
+    # @option params [Float] :top_p Use nucleus sampling.
+    # @return [Langchain::LLM::CohereResponse] The chat completion
+    def chat(params = {})
+      raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?
+
+      parameters = chat_parameters.to_params(params)
+
+      response = client.chat(**parameters)
+
+      Langchain::LLM::CohereResponse.new(response)
+    end
 
     # Generate a summary in English for a given text
     #
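With this, `Cohere#chat` accepts the same unified keys as the other providers and remaps them to Cohere's native names (`system` → `preamble`, `messages` → `chat_history`, `stop` → `stop_sequences`, `top_k`/`top_p` → `k`/`p`) before calling `cohere-ruby`. A usage sketch, assuming a valid `COHERE_API_KEY` and that `CohereResponse#chat_completion` (extended in this release) returns the reply text:

    llm = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])

    response = llm.chat(
      messages: [{role: "user", message: "What is the capital of France?"}], # Cohere chat_history shape
      system: "Answer in one short sentence.", # sent to Cohere as preamble
      temperature: 0.3
    )
    response.chat_completion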
data/lib/langchain/llm/google_gemini.rb
ADDED
@@ -0,0 +1,67 @@
+# frozen_string_literal: true
+
+module Langchain::LLM
+  # Usage:
+  #     llm = Langchain::LLM::GoogleGemini.new(api_key: ENV['GOOGLE_GEMINI_API_KEY'])
+  class GoogleGemini < Base
+    DEFAULTS = {
+      chat_completion_model_name: "gemini-1.5-pro-latest",
+      temperature: 0.0
+    }
+
+    attr_reader :defaults, :api_key
+
+    def initialize(api_key:, default_options: {})
+      @api_key = api_key
+      @defaults = DEFAULTS.merge(default_options)
+
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        temperature: {default: @defaults[:temperature]}
+      )
+      chat_parameters.remap(
+        messages: :contents,
+        system: :system_instruction,
+        tool_choice: :tool_config
+      )
+    end
+
+    # Generate a chat completion for a given prompt
+    #
+    # @param messages [Array<Hash>] List of messages comprising the conversation so far
+    # @param model [String] The model to use
+    # @param tools [Array<Hash>] A list of Tools the model may use to generate the next response
+    # @param tool_choice [String] Specifies the mode in which function calling should execute. If unspecified, the default value will be set to AUTO. Possible values: AUTO, ANY, NONE
+    # @param system [String] Developer set system instruction
+    def chat(params = {})
+      params[:system] = {parts: [{text: params[:system]}]} if params[:system]
+      params[:tools] = {function_declarations: params[:tools]} if params[:tools]
+      params[:tool_choice] = {function_calling_config: {mode: params[:tool_choice].upcase}} if params[:tool_choice]
+
+      raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?
+
+      parameters = chat_parameters.to_params(params)
+      parameters[:generation_config] = {temperature: parameters.delete(:temperature)} if parameters[:temperature]
+
+      uri = URI("https://generativelanguage.googleapis.com/v1beta/models/#{parameters[:model]}:generateContent?key=#{api_key}")
+
+      request = Net::HTTP::Post.new(uri)
+      request.content_type = "application/json"
+      request.body = parameters.to_json
+
+      response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
+        http.request(request)
+      end
+
+      parsed_response = JSON.parse(response.body)
+
+      wrapped_response = Langchain::LLM::GoogleGeminiResponse.new(parsed_response, model: parameters[:model])
+
+      if wrapped_response.chat_completion || Array(wrapped_response.tool_calls).any?
+        wrapped_response
+      else
+        raise StandardError.new(response)
+      end
+    end
+  end
+end
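The new client posts straight to the Generative Language REST endpoint with `Net::HTTP`, so an API key is the only prerequisite. A usage sketch (the message hash follows Gemini's `contents` shape, since the unified `messages` key is remapped to `contents`):

    llm = Langchain::LLM::GoogleGemini.new(api_key: ENV["GOOGLE_GEMINI_API_KEY"])

    response = llm.chat(
      messages: [{role: "user", parts: [{text: "Name three Ruby web frameworks"}]}],
      system: "Keep answers terse.",
      temperature: 0.0
    )
    response.chat_completion # text of the first candidate, per GoogleGeminiResponse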
data/lib/langchain/llm/google_vertex_ai.rb
CHANGED
@@ -2,150 +2,106 @@
 
 module Langchain::LLM
   #
-  # Wrapper around the Google Vertex AI APIs: https://cloud.google.com/vertex-ai
+  # Wrapper around the Google Vertex AI APIs: https://cloud.google.com/vertex-ai
   #
   # Gem requirements:
-  #     gem "
+  #     gem "googleauth"
   #
   # Usage:
-  #
+  #     llm = Langchain::LLM::GoogleVertexAI.new(project_id: ENV["GOOGLE_VERTEX_AI_PROJECT_ID"], region: "us-central1")
   #
-  class
+  class GoogleVertexAI < Base
     DEFAULTS = {
-      temperature: 0.1,
+      temperature: 0.1,
       max_output_tokens: 1000,
       top_p: 0.8,
       top_k: 40,
      dimensions: 768,
-
-
+      embeddings_model_name: "textembedding-gecko",
+      chat_completion_model_name: "gemini-1.0-pro"
     }.freeze
 
-    # TODO: Implement token length validation
-    # LENGTH_VALIDATOR = Langchain::Utils::TokenLength::...
-
     # Google Cloud has a project id and a specific region of deployment.
     # For GenAI-related things, a safe choice is us-central1.
-    attr_reader :
-
-    def initialize(project_id:, default_options: {})
-      depends_on "google-apis-aiplatform_v1"
+    attr_reader :defaults, :url, :authorizer
 
-
-
+    def initialize(project_id:, region:, default_options: {})
+      depends_on "googleauth"
 
-      @
-
-
-      # For the moment only us-central1 available so no big deal.
-      @client.root_url = "https://#{@region}-aiplatform.googleapis.com/"
-      @client.authorization = Google::Auth.get_application_default
+      @authorizer = ::Google::Auth.get_application_default
+      proj_id = project_id || @authorizer.project_id || @authorizer.quota_project_id
+      @url = "https://#{region}-aiplatform.googleapis.com/v1/projects/#{proj_id}/locations/#{region}/publishers/google/models/"
 
       @defaults = DEFAULTS.merge(default_options)
+
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        temperature: {default: @defaults[:temperature]}
+      )
+      chat_parameters.remap(
+        messages: :contents,
+        system: :system_instruction,
+        tool_choice: :tool_config
+      )
     end
 
     #
     # Generate an embedding for a given text
     #
     # @param text [String] The text to generate an embedding for
-    # @
+    # @param model [String] ID of the model to use
+    # @return [Langchain::LLM::GoogleGeminiResponse] Response object
     #
-    def embed(
-
-
-
-
-
-
-
-
+    def embed(
+      text:,
+      model: @defaults[:embeddings_model_name]
+    )
+      params = {instances: [{content: text}]}
+
+      response = HTTParty.post(
+        "#{url}#{model}:predict",
+        body: params.to_json,
+        headers: {
+          "Content-Type" => "application/json",
+          "Authorization" => "Bearer #{@authorizer.fetch_access_token!["access_token"]}"
+        }
+      )
 
-      Langchain::LLM::
+      Langchain::LLM::GoogleGeminiResponse.new(response, model: model)
     end
 
+    # Generate a chat completion for given messages
     #
-    #
-    #
-    # @param
-    # @param
-    # @
-    #
-    def
-
-
-
-
-
-
-
-    }
-
-
-
+    # @param messages [Array<Hash>] Input messages
+    # @param model [String] The model that will complete your prompt
+    # @param tools [Array<Hash>] The tools to use
+    # @param tool_choice [String] The tool choice to use
+    # @param system [String] The system instruction to use
+    # @return [Langchain::LLM::GoogleGeminiResponse] Response object
+    def chat(params = {})
+      params[:system] = {parts: [{text: params[:system]}]} if params[:system]
+      params[:tools] = {function_declarations: params[:tools]} if params[:tools]
+      params[:tool_choice] = {function_calling_config: {mode: params[:tool_choice].upcase}} if params[:tool_choice]
+
+      raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?
+
+      parameters = chat_parameters.to_params(params)
+      parameters[:generation_config] = {temperature: parameters.delete(:temperature)} if parameters[:temperature]
+
+      uri = URI("#{url}#{parameters[:model]}:generateContent")
+
+      request = Net::HTTP::Post.new(uri)
+      request.content_type = "application/json"
+      request["Authorization"] = "Bearer #{@authorizer.fetch_access_token!["access_token"]}"
+      request.body = parameters.to_json
+
+      response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
+        http.request(request)
       end
 
-
-      default_params[:max_output_tokens] = params.delete(:max_output_tokens)
-      end
-
-      # to be tested
-      temperature = params.delete(:temperature) || @defaults[:temperature]
-      max_output_tokens = default_params.fetch(:max_output_tokens, @defaults[:max_output_tokens])
-
-      default_params.merge!(params)
-
-      # response = client.generate_text(**default_params)
-      request = Google::Apis::AiplatformV1::GoogleCloudAiplatformV1PredictRequest.new \
-        instances: [{
-          prompt: prompt # key used to be :content, changed to :prompt
-        }],
-        parameters: {
-          temperature: temperature,
-          maxOutputTokens: max_output_tokens,
-          topP: 0.8,
-          topK: 40
-        }
-
-      response = client.predict_project_location_publisher_model \
-        "projects/#{project_id}/locations/us-central1/publishers/google/models/#{@defaults[:completion_model_name]}",
-        request
-
-      Langchain::LLM::GoogleVertexAiResponse.new(response, model: default_params[:model])
-    end
+      parsed_response = JSON.parse(response.body)
 
-
-    # Generate a summarization for a given text
-    #
-    # @param text [String] The text to generate a summarization for
-    # @return [String] The summarization
-    #
-    # TODO(ricc): add params for Temp, topP, topK, MaxTokens and have it default to these 4 values.
-    def summarize(text:)
-      prompt_template = Langchain::Prompt.load_from_path(
-        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
-      )
-      prompt = prompt_template.format(text: text)
-
-      complete(
-        prompt: prompt,
-        # For best temperature, topP, topK, MaxTokens for summarization: see
-        # https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-summarization
-        temperature: 0.2,
-        top_p: 0.95,
-        top_k: 40,
-        # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
-        max_output_tokens: 256
-      )
+      Langchain::LLM::GoogleGeminiResponse.new(parsed_response, model: parameters[:model])
     end
-
-    # def chat(...)
-    #   https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chathat
-    #   Chat params: https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chat
-    #   \"temperature\": 0.3,\n"
-    #   + " \"maxDecodeSteps\": 200,\n"
-    #   + " \"topP\": 0.8,\n"
-    #   + " \"topK\": 40\n"
-    #   + "}";
-    # end
   end
 end
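GoogleVertexAI now authenticates through `googleauth` application-default credentials and builds a plain REST URL from the project and region, returning `GoogleGeminiResponse` objects for both embeddings and chat. A sketch, assuming application-default credentials are configured (for example via `gcloud auth application-default login`):

    llm = Langchain::LLM::GoogleVertexAI.new(
      project_id: ENV["GOOGLE_VERTEX_AI_PROJECT_ID"],
      region: "us-central1"
    )

    # textembedding-gecko returns 768-dimensional vectors per the DEFAULTS above
    embedding = llm.embed(text: "Hello world")

    chat = llm.chat(messages: [{role: "user", parts: [{text: "Say hi"}]}])
    chat.chat_completion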
data/lib/langchain/llm/mistral_ai.rb
CHANGED
@@ -23,28 +23,19 @@ module Langchain::LLM
       )
 
       @defaults = DEFAULTS.merge(default_options)
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        n: {default: @defaults[:n]},
+        safe_prompt: {}
+      )
+      chat_parameters.remap(seed: :random_seed)
+      chat_parameters.ignore(:n, :top_k)
     end
 
-    def chat(
-
-      model: defaults[:chat_completion_model_name],
-      temperature: nil,
-      top_p: nil,
-      max_tokens: nil,
-      safe_prompt: nil,
-      random_seed: nil
-    )
-      params = {
-        messages: messages,
-        model: model
-      }
-      params[:temperature] = temperature if temperature
-      params[:top_p] = top_p if top_p
-      params[:max_tokens] = max_tokens if max_tokens
-      params[:safe_prompt] = safe_prompt if safe_prompt
-      params[:random_seed] = random_seed if random_seed
+    def chat(params = {})
+      parameters = chat_parameters.to_params(params)
 
-      response = client.chat_completions(
+      response = client.chat_completions(parameters)
 
       Langchain::LLM::MistralAIResponse.new(response.to_h)
     end
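On the Mistral side, `chat` now takes the unified params hash: `seed` is remapped to Mistral's `random_seed`, and `top_k` (which the API does not accept) is dropped. A sketch, assuming `MISTRAL_AI_API_KEY` is set and that `MistralAIResponse#chat_completion` returns the reply text:

    llm = Langchain::LLM::MistralAI.new(api_key: ENV["MISTRAL_AI_API_KEY"])

    response = llm.chat(
      messages: [{role: "user", content: "Summarize RFC 2119 in one sentence"}],
      seed: 42,  # forwarded to the API as random_seed
      top_k: 5   # silently ignored: not supported by Mistral
    )
    response.chat_completion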
data/lib/langchain/llm/ollama.rb
CHANGED
@@ -7,22 +7,24 @@ module Langchain::LLM
   # Available models: https://ollama.ai/library
   #
   # Usage:
-  #
+  #     llm = Langchain::LLM::Ollama.new
+  #     llm = Langchain::LLM::Ollama.new(url: ENV["OLLAMA_URL"], default_options: {})
   #
   class Ollama < Base
     attr_reader :url, :defaults
 
     DEFAULTS = {
       temperature: 0.8,
-      completion_model_name: "
-      embeddings_model_name: "
-      chat_completion_model_name: "
+      completion_model_name: "llama3",
+      embeddings_model_name: "llama3",
+      chat_completion_model_name: "llama3"
     }.freeze
 
     EMBEDDING_SIZES = {
       codellama: 4_096,
       "dolphin-mixtral": 4_096,
       llama2: 4_096,
+      llama3: 4_096,
       llava: 4_096,
       mistral: 4_096,
       "mistral-openorca": 4_096,
@@ -33,10 +35,17 @@ module Langchain::LLM
     # @param url [String] The URL of the Ollama instance
     # @param default_options [Hash] The default options to use
     #
-    def initialize(url
+    def initialize(url: "http://localhost:11434", default_options: {})
       depends_on "faraday"
       @url = url
       @defaults = DEFAULTS.deep_merge(default_options)
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        temperature: {default: @defaults[:temperature]},
+        template: {},
+        stream: {default: false}
+      )
+      chat_parameters.remap(response_format: :format)
     end
 
     # Returns the # of vector dimensions for the embeddings
@@ -150,33 +159,20 @@ module Langchain::LLM
 
     # Generate a chat completion
     #
-    # @param
-    # @
-    # @
-    # @
-    # @
-    # @
+    # @param [Hash] params unified chat parmeters from [Langchain::LLM::Parameters::Chat::SCHEMA]
+    # @option params [String] :model Model name
+    # @option params [Array<Hash>] :messages Array of messages
+    # @option params [String] :format Format to return a response in. Currently the only accepted value is `json`
+    # @option params [Float] :temperature The temperature to use
+    # @option params [String] :template The prompt template to use (overrides what is defined in the `Modelfile`)
+    # @option params [Boolean] :stream Streaming the response. If false the response will be returned as a single response object, rather than a stream of objects
     #
     # The message object has the following fields:
     #   role: the role of the message, either system, user or assistant
     #   content: the content of the message
     #   images (optional): a list of images to include in the message (for multimodal models such as llava)
-    def chat(
-
-      messages: [],
-      format: nil,
-      temperature: defaults[:temperature],
-      template: nil,
-      stream: false # TODO: Fix streaming.
-    )
-      parameters = {
-        model: model,
-        messages: messages,
-        format: format,
-        temperature: temperature,
-        template: template,
-        stream: stream
-      }.compact
+    def chat(params = {})
+      parameters = chat_parameters.to_params(params)
 
       response = client.post("api/chat") do |req|
         req.body = parameters
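`Ollama#chat` follows the same pattern; the unified `response_format` key is remapped to Ollama's `format` field and streaming defaults to off. A sketch, assuming a local Ollama server with the `llama3` model pulled:

    llm = Langchain::LLM::Ollama.new(url: "http://localhost:11434")

    response = llm.chat(
      messages: [{role: "user", content: "Reply with a JSON object containing a 'greeting' key"}],
      response_format: "json", # sent to the Ollama API as format: "json"
      temperature: 0.2
    )
    response.chat_completion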
data/lib/langchain/llm/openai.rb
CHANGED
@@ -40,6 +40,15 @@ module Langchain::LLM
       @client = ::OpenAI::Client.new(access_token: api_key, **llm_options)
 
       @defaults = DEFAULTS.merge(default_options)
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        logprobs: {},
+        top_logprobs: {},
+        n: {default: @defaults[:n]},
+        temperature: {default: @defaults[:temperature]},
+        user: {}
+      )
+      chat_parameters.ignore(:top_k)
     end
 
     # Generate an embedding for a given text
@@ -102,54 +111,17 @@ module Langchain::LLM
 
     # Generate a chat completion for given messages.
     #
-    # @param
-    # @
-
-
-
-
-
-
-
-
-
-      presence_penalty: nil,
-      response_format: nil,
-      seed: nil,
-      stop: nil,
-      stream: nil,
-      temperature: defaults[:temperature],
-      top_p: nil,
-      tools: [],
-      tool_choice: nil,
-      user: nil,
-      &block
-    )
-      raise ArgumentError.new("messages argument is required") if messages.empty?
-      raise ArgumentError.new("model argument is required") if model.empty?
-      raise ArgumentError.new("'tool_choice' is only allowed when 'tools' are specified.") if tool_choice && tools.empty?
-
-      parameters = {
-        messages: messages,
-        model: model
-      }
-      parameters[:frequency_penalty] = frequency_penalty if frequency_penalty
-      parameters[:logit_bias] = logit_bias if logit_bias
-      parameters[:logprobs] = logprobs if logprobs
-      parameters[:top_logprobs] = top_logprobs if top_logprobs
-      # TODO: Fix max_tokens validation to account for tools/functions
-      parameters[:max_tokens] = max_tokens if max_tokens # || validate_max_tokens(parameters[:messages], parameters[:model])
-      parameters[:n] = n if n
-      parameters[:presence_penalty] = presence_penalty if presence_penalty
-      parameters[:response_format] = response_format if response_format
-      parameters[:seed] = seed if seed
-      parameters[:stop] = stop if stop
-      parameters[:stream] = stream if stream
-      parameters[:temperature] = temperature if temperature
-      parameters[:top_p] = top_p if top_p
-      parameters[:tools] = tools if tools.any?
-      parameters[:tool_choice] = tool_choice if tool_choice
-      parameters[:user] = user if user
+    # @param [Hash] params unified chat parmeters from [Langchain::LLM::Parameters::Chat::SCHEMA]
+    # @option params [Array<Hash>] :messages List of messages comprising the conversation so far
+    # @option params [String] :model ID of the model to use
+    def chat(params = {}, &block)
+      parameters = chat_parameters.to_params(params)
+
+      raise ArgumentError.new("messages argument is required") if Array(parameters[:messages]).empty?
+      raise ArgumentError.new("model argument is required") if parameters[:model].to_s.empty?
+      if parameters[:tool_choice] && Array(parameters[:tools]).empty?
+        raise ArgumentError.new("'tool_choice' is only allowed when 'tools' are specified.")
+      end
 
       # TODO: Clean this part up
       if block
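`OpenAI#chat` keeps its block-based streaming but now builds the request through the shared schema, so validation and tool handling match the other providers. A sketch, assuming `OPENAI_API_KEY` is set (the exact shape of the chunks yielded to the streaming block is an assumption; inspect them before relying on it):

    llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

    # Non-streaming call; :model falls back to the configured default
    response = llm.chat(messages: [{role: "user", content: "ping"}])
    response.chat_completion

    # Streaming call; chunks are handed to the block as they arrive
    llm.chat(messages: [{role: "user", content: "Count to five"}], stream: true) do |chunk|
      p chunk
    end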
data/lib/langchain/llm/parameters/chat.rb
ADDED
@@ -0,0 +1,51 @@
+# frozen_string_literal: true
+
+require "delegate"
+
+module Langchain::LLM::Parameters
+  class Chat < SimpleDelegator
+    # TODO: At the moment, the UnifiedParamters only considers keys. In the
+    # future, we may consider ActiveModel-style validations and further typed
+    # options here.
+    SCHEMA = {
+      # Either "messages" or "prompt" is required
+      messages: {},
+      model: {},
+      prompt: {},
+
+      # System instructions. Used by Cohere, Anthropic and Google Gemini.
+      system: {},
+
+      # Allows to force the model to produce specific output format.
+      response_format: {},
+
+      stop: {}, # multiple types (e.g. OpenAI also allows Array, null)
+      stream: {}, # Enable streaming
+
+      max_tokens: {}, # Range: [1, context_length)
+      temperature: {}, # Range: [0, 2]
+      top_p: {}, # Range: (0, 1]
+      top_k: {}, # Range: [1, Infinity) Not available for OpenAI models
+      frequency_penalty: {}, # Range: [-2, 2]
+      presence_penalty: {}, # Range: [-2, 2]
+      repetition_penalty: {}, # Range: (0, 2]
+      seed: {}, # OpenAI only
+
+      # Function-calling
+      tools: {default: []},
+      tool_choice: {},
+
+      # Additional optional parameters
+      logit_bias: {}
+    }
+
+    def initialize(parameters: {})
+      super(
+        ::Langchain::LLM::UnifiedParameters.new(
+          schema: SCHEMA.dup,
+          parameters: parameters
+        )
+      )
+    end
+  end
+end
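`Parameters::Chat` delegates to the new `UnifiedParameters` (see `data/lib/langchain/llm/unified_parameters.rb`), which holds the schema, per-provider defaults, key remappings, and ignore list, and produces the final request hash. A rough sketch of the flow, mirroring the constructor calls visible in the diffs above (the exact output shape of `to_params` is an assumption):

    params = Langchain::LLM::Parameters::Chat.new

    # Provider-specific wiring, as each client does in its constructor
    params.update(
      model: {default: "command-r-plus"},
      temperature: {default: 0.0}
    )
    params.remap(messages: :chat_history, system: :preamble, stop: :stop_sequences)

    # At request time the unified keys collapse into the provider's keys
    params.to_params(messages: [{role: "user", message: "hi"}], stop: ["END"])
    # => roughly {model: "command-r-plus", temperature: 0.0, chat_history: [...], stop_sequences: ["END"]}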