langchainrb 0.7.5 → 0.12.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +78 -0
- data/README.md +113 -56
- data/lib/langchain/assistants/assistant.rb +213 -0
- data/lib/langchain/assistants/message.rb +58 -0
- data/lib/langchain/assistants/thread.rb +34 -0
- data/lib/langchain/chunker/markdown.rb +37 -0
- data/lib/langchain/chunker/recursive_text.rb +0 -2
- data/lib/langchain/chunker/semantic.rb +1 -3
- data/lib/langchain/chunker/sentence.rb +0 -2
- data/lib/langchain/chunker/text.rb +0 -2
- data/lib/langchain/contextual_logger.rb +1 -1
- data/lib/langchain/data.rb +4 -3
- data/lib/langchain/llm/ai21.rb +1 -1
- data/lib/langchain/llm/anthropic.rb +86 -11
- data/lib/langchain/llm/aws_bedrock.rb +52 -0
- data/lib/langchain/llm/azure.rb +10 -97
- data/lib/langchain/llm/base.rb +3 -2
- data/lib/langchain/llm/cohere.rb +5 -7
- data/lib/langchain/llm/google_palm.rb +4 -2
- data/lib/langchain/llm/google_vertex_ai.rb +151 -0
- data/lib/langchain/llm/hugging_face.rb +1 -1
- data/lib/langchain/llm/llama_cpp.rb +18 -16
- data/lib/langchain/llm/mistral_ai.rb +68 -0
- data/lib/langchain/llm/ollama.rb +209 -27
- data/lib/langchain/llm/openai.rb +138 -170
- data/lib/langchain/llm/prompts/ollama/summarize_template.yaml +9 -0
- data/lib/langchain/llm/replicate.rb +1 -7
- data/lib/langchain/llm/response/anthropic_response.rb +20 -0
- data/lib/langchain/llm/response/base_response.rb +7 -0
- data/lib/langchain/llm/response/google_palm_response.rb +4 -0
- data/lib/langchain/llm/response/google_vertex_ai_response.rb +33 -0
- data/lib/langchain/llm/response/llama_cpp_response.rb +13 -0
- data/lib/langchain/llm/response/mistral_ai_response.rb +39 -0
- data/lib/langchain/llm/response/ollama_response.rb +27 -1
- data/lib/langchain/llm/response/openai_response.rb +8 -0
- data/lib/langchain/loader.rb +3 -2
- data/lib/langchain/output_parsers/base.rb +0 -4
- data/lib/langchain/output_parsers/output_fixing_parser.rb +7 -14
- data/lib/langchain/output_parsers/structured_output_parser.rb +0 -10
- data/lib/langchain/processors/csv.rb +37 -3
- data/lib/langchain/processors/eml.rb +64 -0
- data/lib/langchain/processors/markdown.rb +17 -0
- data/lib/langchain/processors/pptx.rb +29 -0
- data/lib/langchain/prompt/loading.rb +1 -1
- data/lib/langchain/tool/base.rb +21 -53
- data/lib/langchain/tool/calculator/calculator.json +19 -0
- data/lib/langchain/tool/{calculator.rb → calculator/calculator.rb} +8 -16
- data/lib/langchain/tool/database/database.json +46 -0
- data/lib/langchain/tool/database/database.rb +99 -0
- data/lib/langchain/tool/file_system/file_system.json +57 -0
- data/lib/langchain/tool/file_system/file_system.rb +32 -0
- data/lib/langchain/tool/google_search/google_search.json +19 -0
- data/lib/langchain/tool/{google_search.rb → google_search/google_search.rb} +5 -15
- data/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json +19 -0
- data/lib/langchain/tool/{ruby_code_interpreter.rb → ruby_code_interpreter/ruby_code_interpreter.rb} +8 -4
- data/lib/langchain/tool/vectorsearch/vectorsearch.json +24 -0
- data/lib/langchain/tool/vectorsearch/vectorsearch.rb +36 -0
- data/lib/langchain/tool/weather/weather.json +19 -0
- data/lib/langchain/tool/{weather.rb → weather/weather.rb} +3 -15
- data/lib/langchain/tool/wikipedia/wikipedia.json +19 -0
- data/lib/langchain/tool/{wikipedia.rb → wikipedia/wikipedia.rb} +9 -9
- data/lib/langchain/utils/token_length/ai21_validator.rb +6 -2
- data/lib/langchain/utils/token_length/base_validator.rb +1 -1
- data/lib/langchain/utils/token_length/cohere_validator.rb +6 -2
- data/lib/langchain/utils/token_length/google_palm_validator.rb +5 -1
- data/lib/langchain/utils/token_length/openai_validator.rb +55 -1
- data/lib/langchain/utils/token_length/token_limit_exceeded.rb +1 -1
- data/lib/langchain/vectorsearch/base.rb +11 -4
- data/lib/langchain/vectorsearch/chroma.rb +10 -1
- data/lib/langchain/vectorsearch/elasticsearch.rb +53 -4
- data/lib/langchain/vectorsearch/epsilla.rb +149 -0
- data/lib/langchain/vectorsearch/hnswlib.rb +5 -1
- data/lib/langchain/vectorsearch/milvus.rb +4 -2
- data/lib/langchain/vectorsearch/pgvector.rb +14 -4
- data/lib/langchain/vectorsearch/pinecone.rb +8 -5
- data/lib/langchain/vectorsearch/qdrant.rb +16 -4
- data/lib/langchain/vectorsearch/weaviate.rb +20 -2
- data/lib/langchain/version.rb +1 -1
- data/lib/langchain.rb +20 -5
- metadata +182 -45
- data/lib/langchain/agent/agents.md +0 -54
- data/lib/langchain/agent/base.rb +0 -20
- data/lib/langchain/agent/react_agent/react_agent_prompt.yaml +0 -26
- data/lib/langchain/agent/react_agent.rb +0 -131
- data/lib/langchain/agent/sql_query_agent/sql_query_agent_answer_prompt.yaml +0 -11
- data/lib/langchain/agent/sql_query_agent/sql_query_agent_sql_prompt.yaml +0 -21
- data/lib/langchain/agent/sql_query_agent.rb +0 -82
- data/lib/langchain/conversation/context.rb +0 -8
- data/lib/langchain/conversation/memory.rb +0 -86
- data/lib/langchain/conversation/message.rb +0 -48
- data/lib/langchain/conversation/prompt.rb +0 -8
- data/lib/langchain/conversation/response.rb +0 -8
- data/lib/langchain/conversation.rb +0 -93
- data/lib/langchain/tool/database.rb +0 -90
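
The headline change in this range is the removal of the Agent and Conversation classes (see the deleted files above) in favor of the new Assistants layer (assistant.rb, message.rb, thread.rb) plus per-tool JSON function definitions. A minimal sketch of how those pieces are assumed to fit together, inferred only from the file names above; the exact class and method names (Langchain::Thread, add_message, run) and keyword arguments may differ in 0.12.0:

    require "langchain"

    # Hypothetical wiring of the new Assistant pieces; verify against the 0.12.0 README.
    llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

    assistant = Langchain::Assistant.new(
      llm: llm,
      thread: Langchain::Thread.new,            # conversation state (assistants/thread.rb)
      instructions: "You are a helpful assistant",
      tools: [Langchain::Tool::Calculator.new]  # tools now ship JSON schemas (calculator/calculator.json)
    )

    assistant.add_message(content: "What is 2 + 2?") # appends a Message to the thread
    assistant.run                                    # lets the LLM call tools until it can answer
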
data/lib/langchain/llm/google_vertex_ai.rb ADDED

@@ -0,0 +1,151 @@
+# frozen_string_literal: true
+
+module Langchain::LLM
+  #
+  # Wrapper around the Google Vertex AI APIs: https://cloud.google.com/vertex-ai?hl=en
+  #
+  # Gem requirements:
+  #     gem "google-apis-aiplatform_v1", "~> 0.7"
+  #
+  # Usage:
+  #     google_palm = Langchain::LLM::GoogleVertexAi.new(project_id: ENV["GOOGLE_VERTEX_AI_PROJECT_ID"])
+  #
+  class GoogleVertexAi < Base
+    DEFAULTS = {
+      temperature: 0.1, # 0.1 is the default in the API, quite low ("grounded")
+      max_output_tokens: 1000,
+      top_p: 0.8,
+      top_k: 40,
+      dimensions: 768,
+      completion_model_name: "text-bison", # Optional: tect-bison@001
+      embeddings_model_name: "textembedding-gecko"
+    }.freeze
+
+    # TODO: Implement token length validation
+    # LENGTH_VALIDATOR = Langchain::Utils::TokenLength::...
+
+    # Google Cloud has a project id and a specific region of deployment.
+    # For GenAI-related things, a safe choice is us-central1.
+    attr_reader :project_id, :client, :region
+
+    def initialize(project_id:, default_options: {})
+      depends_on "google-apis-aiplatform_v1"
+
+      @project_id = project_id
+      @region = default_options.fetch :region, "us-central1"
+
+      @client = Google::Apis::AiplatformV1::AiplatformService.new
+
+      # TODO: Adapt for other regions; Pass it in via the constructor
+      # For the moment only us-central1 available so no big deal.
+      @client.root_url = "https://#{@region}-aiplatform.googleapis.com/"
+      @client.authorization = Google::Auth.get_application_default
+
+      @defaults = DEFAULTS.merge(default_options)
+    end
+
+    #
+    # Generate an embedding for a given text
+    #
+    # @param text [String] The text to generate an embedding for
+    # @return [Langchain::LLM::GoogleVertexAiResponse] Response object
+    #
+    def embed(text:)
+      content = [{content: text}]
+      request = Google::Apis::AiplatformV1::GoogleCloudAiplatformV1PredictRequest.new(instances: content)
+
+      api_path = "projects/#{@project_id}/locations/us-central1/publishers/google/models/#{@defaults[:embeddings_model_name]}"
+
+      # puts("api_path: #{api_path}")
+
+      response = client.predict_project_location_publisher_model(api_path, request)
+
+      Langchain::LLM::GoogleVertexAiResponse.new(response.to_h, model: @defaults[:embeddings_model_name])
+    end
+
+    #
+    # Generate a completion for a given prompt
+    #
+    # @param prompt [String] The prompt to generate a completion for
+    # @param params extra parameters passed to GooglePalmAPI::Client#generate_text
+    # @return [Langchain::LLM::GooglePalmResponse] Response object
+    #
+    def complete(prompt:, **params)
+      default_params = {
+        prompt: prompt,
+        temperature: @defaults[:temperature],
+        top_k: @defaults[:top_k],
+        top_p: @defaults[:top_p],
+        max_output_tokens: @defaults[:max_output_tokens],
+        model: @defaults[:completion_model_name]
+      }
+
+      if params[:stop_sequences]
+        default_params[:stop_sequences] = params.delete(:stop_sequences)
+      end
+
+      if params[:max_output_tokens]
+        default_params[:max_output_tokens] = params.delete(:max_output_tokens)
+      end
+
+      # to be tested
+      temperature = params.delete(:temperature) || @defaults[:temperature]
+      max_output_tokens = default_params.fetch(:max_output_tokens, @defaults[:max_output_tokens])
+
+      default_params.merge!(params)
+
+      # response = client.generate_text(**default_params)
+      request = Google::Apis::AiplatformV1::GoogleCloudAiplatformV1PredictRequest.new \
+        instances: [{
+          prompt: prompt # key used to be :content, changed to :prompt
+        }],
+        parameters: {
+          temperature: temperature,
+          maxOutputTokens: max_output_tokens,
+          topP: 0.8,
+          topK: 40
+        }
+
+      response = client.predict_project_location_publisher_model \
+        "projects/#{project_id}/locations/us-central1/publishers/google/models/#{@defaults[:completion_model_name]}",
+        request
+
+      Langchain::LLM::GoogleVertexAiResponse.new(response, model: default_params[:model])
+    end
+
+    #
+    # Generate a summarization for a given text
+    #
+    # @param text [String] The text to generate a summarization for
+    # @return [String] The summarization
+    #
+    # TODO(ricc): add params for Temp, topP, topK, MaxTokens and have it default to these 4 values.
+    def summarize(text:)
+      prompt_template = Langchain::Prompt.load_from_path(
+        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
+      )
+      prompt = prompt_template.format(text: text)
+
+      complete(
+        prompt: prompt,
+        # For best temperature, topP, topK, MaxTokens for summarization: see
+        # https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-summarization
+        temperature: 0.2,
+        top_p: 0.95,
+        top_k: 40,
+        # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
+        max_output_tokens: 256
+      )
+    end
+
+    # def chat(...)
+    #   https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chathat
+    #   Chat params: https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chat
+    #   \"temperature\": 0.3,\n"
+    #   + " \"maxDecodeSteps\": 200,\n"
+    #   + " \"topP\": 0.8,\n"
+    #   + " \"topK\": 40\n"
+    #   + "}";
+    # end
+  end
+end
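
Taken together, the new wrapper is meant to be used roughly as follows (a sketch inferred from the code above; it assumes Application Default Credentials are configured, since the client calls Google::Auth.get_application_default, and the file path is a placeholder):

    require "langchain"

    # Assumes `gcloud auth application-default login` (or an equivalent service account) has been run.
    llm = Langchain::LLM::GoogleVertexAi.new(project_id: ENV["GOOGLE_VERTEX_AI_PROJECT_ID"])

    llm.embed(text: "Hello Vertex AI")            # textembedding-gecko, 768-dimensional by default
    llm.complete(prompt: "Name a Ruby testing framework.", max_output_tokens: 64)
    llm.summarize(text: File.read("article.txt")) # article.txt is a placeholder path
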
data/lib/langchain/llm/hugging_face.rb CHANGED

@@ -16,7 +16,7 @@ module Langchain::LLM
     DEFAULTS = {
       temperature: 0.0,
       embeddings_model_name: "sentence-transformers/all-MiniLM-L6-v2",
-
+      dimensions: 384 # Vector size generated by the above model
     }.freeze
 
     #
data/lib/langchain/llm/llama_cpp.rb CHANGED

@@ -22,7 +22,7 @@ module Langchain::LLM
     # @param n_ctx [Integer] The number of context tokens to use
     # @param n_threads [Integer] The CPU number of threads to use
     # @param seed [Integer] The seed to use
-    def initialize(model_path:, n_gpu_layers: 1, n_ctx: 2048, n_threads: 1, seed:
+    def initialize(model_path:, n_gpu_layers: 1, n_ctx: 2048, n_threads: 1, seed: 0)
       depends_on "llama_cpp"
 
       @model_path = model_path
@@ -33,30 +33,25 @@ module Langchain::LLM
     end
 
     # @param text [String] The text to embed
-    # @param n_threads [Integer] The number of CPU threads to use
     # @return [Array<Float>] The embedding
-    def embed(text
+    def embed(text:)
       # contexts are kinda stateful when it comes to embeddings, so allocate one each time
       context = embedding_context
 
-      embedding_input =
+      embedding_input = @model.tokenize(text: text, add_bos: true)
       return unless embedding_input.size.positive?
 
-
-
-      context.eval(tokens: embedding_input, n_past: 0, n_threads: n_threads)
-      context.embeddings
+      context.eval(tokens: embedding_input, n_past: 0)
+      Langchain::LLM::LlamaCppResponse.new(context, model: context.model.desc)
     end
 
     # @param prompt [String] The prompt to complete
     # @param n_predict [Integer] The number of tokens to predict
-    # @param n_threads [Integer] The number of CPU threads to use
     # @return [String] The completed prompt
-    def complete(prompt:, n_predict: 128
-      n_threads ||= self.n_threads
+    def complete(prompt:, n_predict: 128)
       # contexts do not appear to be stateful when it comes to completion, so re-use the same one
       context = completion_context
-      ::LLaMACpp.generate(context, prompt,
+      ::LLaMACpp.generate(context, prompt, n_predict: n_predict)
     end
 
     private
@@ -71,23 +66,30 @@ module Langchain::LLM
 
       context_params.seed = seed
       context_params.n_ctx = n_ctx
-      context_params.
+      context_params.n_threads = n_threads
       context_params.embedding = embeddings
 
       context_params
     end
 
+    def build_model_params
+      model_params = ::LLaMACpp::ModelParams.new
+      model_params.n_gpu_layers = n_gpu_layers
+
+      model_params
+    end
+
     def build_model(embeddings: false)
       return @model if defined?(@model)
-      @model = ::LLaMACpp::Model.new(model_path: model_path, params:
+      @model = ::LLaMACpp::Model.new(model_path: model_path, params: build_model_params)
     end
 
     def build_completion_context
-      ::LLaMACpp::Context.new(model: build_model)
+      ::LLaMACpp::Context.new(model: build_model, params: build_context_params(embeddings: false))
     end
 
     def build_embedding_context
-      ::LLaMACpp::Context.new(model: build_model(embeddings: true))
+      ::LLaMACpp::Context.new(model: build_model, params: build_context_params(embeddings: true))
     end
 
     def completion_context
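
With these changes the per-call n_threads arguments are gone (threads are now set once via build_context_params), embed returns a LlamaCppResponse instead of a raw array, and GPU offloading goes through the new build_model_params. A hedged usage sketch; the model path and layer count are placeholders:

    # Requires the llama_cpp gem and a local model file.
    llm = Langchain::LLM::LlamaCpp.new(
      model_path: "/models/llama-2-7b.Q4_K_M.gguf", # placeholder path
      n_gpu_layers: 32,  # forwarded to ::LLaMACpp::ModelParams via build_model_params
      n_ctx: 2048,
      n_threads: 8,      # applied through the context params, not per call
      seed: 0            # the new explicit default
    )

    llm.complete(prompt: "def fibonacci(n)", n_predict: 64) # => String
    llm.embed(text: "hello")                                # => Langchain::LLM::LlamaCppResponse
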
data/lib/langchain/llm/mistral_ai.rb ADDED

@@ -0,0 +1,68 @@
+# frozen_string_literal: true
+
+module Langchain::LLM
+  # Gem requirements:
+  #    gem "mistral-ai"
+  #
+  # Usage:
+  #    llm = Langchain::LLM::MistralAI.new(api_key: ENV["MISTRAL_AI_API_KEY"])
+  class MistralAI < Base
+    DEFAULTS = {
+      chat_completion_model_name: "mistral-medium",
+      embeddings_model_name: "mistral-embed"
+    }.freeze
+
+    attr_reader :defaults
+
+    def initialize(api_key:, default_options: {})
+      depends_on "mistral-ai"
+
+      @client = Mistral.new(
+        credentials: {api_key: api_key},
+        options: {server_sent_events: true}
+      )
+
+      @defaults = DEFAULTS.merge(default_options)
+    end
+
+    def chat(
+      messages:,
+      model: defaults[:chat_completion_model_name],
+      temperature: nil,
+      top_p: nil,
+      max_tokens: nil,
+      safe_prompt: nil,
+      random_seed: nil
+    )
+      params = {
+        messages: messages,
+        model: model
+      }
+      params[:temperature] = temperature if temperature
+      params[:top_p] = top_p if top_p
+      params[:max_tokens] = max_tokens if max_tokens
+      params[:safe_prompt] = safe_prompt if safe_prompt
+      params[:random_seed] = random_seed if random_seed
+
+      response = client.chat_completions(params)
+
+      Langchain::LLM::MistralAIResponse.new(response.to_h)
+    end
+
+    def embed(
+      text:,
+      model: defaults[:embeddings_model_name],
+      encoding_format: nil
+    )
+      params = {
+        input: text,
+        model: model
+      }
+      params[:encoding_format] = encoding_format if encoding_format
+
+      response = client.embeddings(params)
+
+      Langchain::LLM::MistralAIResponse.new(response.to_h)
+    end
+  end
+end
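
Usage implied by the new class (a sketch; it assumes the mistral-ai gem is installed and MISTRAL_AI_API_KEY is set):

    llm = Langchain::LLM::MistralAI.new(api_key: ENV["MISTRAL_AI_API_KEY"])

    response = llm.chat(
      messages: [{role: "user", content: "Name three Ruby web frameworks."}],
      temperature: 0.7
    )
    # response is a Langchain::LLM::MistralAIResponse wrapping the raw API hash

    llm.embed(text: "vector me") # uses mistral-embed by default
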
data/lib/langchain/llm/ollama.rb CHANGED

@@ -1,25 +1,52 @@
 # frozen_string_literal: true
 
+require "active_support/core_ext/hash"
+
 module Langchain::LLM
   # Interface to Ollama API.
   # Available models: https://ollama.ai/library
   #
   # Usage:
-  #    ollama = Langchain::LLM::Ollama.new(url: ENV["OLLAMA_URL"])
+  #    ollama = Langchain::LLM::Ollama.new(url: ENV["OLLAMA_URL"], default_options: {})
   #
   class Ollama < Base
-    attr_reader :url
+    attr_reader :url, :defaults
 
     DEFAULTS = {
-      temperature: 0.
+      temperature: 0.8,
       completion_model_name: "llama2",
-      embeddings_model_name: "llama2"
+      embeddings_model_name: "llama2",
+      chat_completion_model_name: "llama2"
+    }.freeze
+
+    EMBEDDING_SIZES = {
+      codellama: 4_096,
+      "dolphin-mixtral": 4_096,
+      llama2: 4_096,
+      llava: 4_096,
+      mistral: 4_096,
+      "mistral-openorca": 4_096,
+      mixtral: 4_096
     }.freeze
 
     # Initialize the Ollama client
     # @param url [String] The URL of the Ollama instance
-
+    # @param default_options [Hash] The default options to use
+    #
+    def initialize(url:, default_options: {})
+      depends_on "faraday"
       @url = url
+      @defaults = DEFAULTS.deep_merge(default_options)
+    end
+
+    # Returns the # of vector dimensions for the embeddings
+    # @return [Integer] The # of vector dimensions
+    def default_dimensions
+      # since Ollama can run multiple models, look it up or generate an embedding and return the size
+      @default_dimensions ||=
+        EMBEDDING_SIZES.fetch(defaults[:embeddings_model_name].to_sym) do
+          embed(text: "test").embedding.size
+        end
     end
 
     #
@@ -27,32 +54,135 @@ module Langchain::LLM
     #
     # @param prompt [String] The prompt to complete
     # @param model [String] The model to use
-    #
+    # For a list of valid parameters and values, see:
+    # https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
     # @return [Langchain::LLM::OllamaResponse] Response object
     #
-    def complete(
-
+    def complete(
+      prompt:,
+      model: defaults[:completion_model_name],
+      images: nil,
+      format: nil,
+      system: nil,
+      template: nil,
+      context: nil,
+      stream: nil,
+      raw: nil,
+      mirostat: nil,
+      mirostat_eta: nil,
+      mirostat_tau: nil,
+      num_ctx: nil,
+      num_gqa: nil,
+      num_gpu: nil,
+      num_thread: nil,
+      repeat_last_n: nil,
+      repeat_penalty: nil,
+      temperature: defaults[:temperature],
+      seed: nil,
+      stop: nil,
+      tfs_z: nil,
+      num_predict: nil,
+      top_k: nil,
+      top_p: nil,
+      stop_sequences: nil,
+      &block
+    )
+      if stop_sequences
+        stop = stop_sequences
+      end
 
-
+      parameters = {
+        prompt: prompt,
+        model: model,
+        images: images,
+        format: format,
+        system: system,
+        template: template,
+        context: context,
+        stream: stream,
+        raw: raw
+      }.compact
 
-
-
-
-
+      llm_parameters = {
+        mirostat: mirostat,
+        mirostat_eta: mirostat_eta,
+        mirostat_tau: mirostat_tau,
+        num_ctx: num_ctx,
+        num_gqa: num_gqa,
+        num_gpu: num_gpu,
+        num_thread: num_thread,
+        repeat_last_n: repeat_last_n,
+        repeat_penalty: repeat_penalty,
+        temperature: temperature,
+        seed: seed,
+        stop: stop,
+        tfs_z: tfs_z,
+        num_predict: num_predict,
+        top_k: top_k,
+        top_p: top_p
+      }
 
-
+      parameters[:options] = llm_parameters.compact
+
+      response = ""
+
+      client.post("api/generate") do |req|
+        req.body = parameters
 
-      # TODO: Implement streaming support when a &block is passed in
        req.options.on_data = proc do |chunk, size|
-
+          chunk.split("\n").each do |line_chunk|
+            json_chunk = begin
+              JSON.parse(line_chunk)
+            # In some instance the chunk exceeds the buffer size and the JSON parser fails
+            rescue JSON::ParserError
+              nil
+            end
 
-
-          response.to_s << JSON.parse(chunk).dig("response")
+            response += json_chunk.dig("response") unless json_chunk.blank?
          end
+
+          yield json_chunk, size if block
        end
      end
 
-    Langchain::LLM::OllamaResponse.new(response, model:
+      Langchain::LLM::OllamaResponse.new(response, model: parameters[:model])
+    end
+
+    # Generate a chat completion
+    #
+    # @param model [String] Model name
+    # @param messages [Array<Hash>] Array of messages
+    # @param format [String] Format to return a response in. Currently the only accepted value is `json`
+    # @param temperature [Float] The temperature to use
+    # @param template [String] The prompt template to use (overrides what is defined in the `Modelfile`)
+    # @param stream [Boolean] Streaming the response. If false the response will be returned as a single response object, rather than a stream of objects
+    #
+    # The message object has the following fields:
+    #   role: the role of the message, either system, user or assistant
+    #   content: the content of the message
+    #   images (optional): a list of images to include in the message (for multimodal models such as llava)
+    def chat(
+      model: defaults[:chat_completion_model_name],
+      messages: [],
+      format: nil,
+      temperature: defaults[:temperature],
+      template: nil,
+      stream: false # TODO: Fix streaming.
+    )
+      parameters = {
+        model: model,
+        messages: messages,
+        format: format,
+        temperature: temperature,
+        template: template,
+        stream: stream
+      }.compact
+
+      response = client.post("api/chat") do |req|
+        req.body = parameters
+      end
+
+      Langchain::LLM::OllamaResponse.new(response.body, model: parameters[:model])
     end
 
     #
@@ -63,18 +193,70 @@ module Langchain::LLM
     # @param options [Hash] The options to use
     # @return [Langchain::LLM::OllamaResponse] Response object
     #
-    def embed(
-
+    def embed(
+      text:,
+      model: defaults[:embeddings_model_name],
+      mirostat: nil,
+      mirostat_eta: nil,
+      mirostat_tau: nil,
+      num_ctx: nil,
+      num_gqa: nil,
+      num_gpu: nil,
+      num_thread: nil,
+      repeat_last_n: nil,
+      repeat_penalty: nil,
+      temperature: defaults[:temperature],
+      seed: nil,
+      stop: nil,
+      tfs_z: nil,
+      num_predict: nil,
+      top_k: nil,
+      top_p: nil
+    )
+      parameters = {
+        prompt: text,
+        model: model
+      }.compact
 
-
-
-
-
+      llm_parameters = {
+        mirostat: mirostat,
+        mirostat_eta: mirostat_eta,
+        mirostat_tau: mirostat_tau,
+        num_ctx: num_ctx,
+        num_gqa: num_gqa,
+        num_gpu: num_gpu,
+        num_thread: num_thread,
+        repeat_last_n: repeat_last_n,
+        repeat_penalty: repeat_penalty,
+        temperature: temperature,
+        seed: seed,
+        stop: stop,
+        tfs_z: tfs_z,
+        num_predict: num_predict,
+        top_k: top_k,
+        top_p: top_p
+      }
 
-
+      parameters[:options] = llm_parameters.compact
+
+      response = client.post("api/embeddings") do |req|
+        req.body = parameters
      end
 
-    Langchain::LLM::OllamaResponse.new(response.body, model:
+      Langchain::LLM::OllamaResponse.new(response.body, model: parameters[:model])
+    end
+
+    # Generate a summary for a given text
+    #
+    # @param text [String] The text to generate a summary for
+    # @return [String] The summary
+    def summarize(text:)
+      prompt_template = Langchain::Prompt.load_from_path(
+        file_path: Langchain.root.join("langchain/llm/prompts/ollama/summarize_template.yaml")
+      )
+      prompt = prompt_template.format(text: text)
+
+      complete(prompt: prompt)
    end
 
     private
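
A sketch of the expanded Ollama surface, assuming a local Ollama server; the URL and model names are placeholders:

    llm = Langchain::LLM::Ollama.new(
      url: ENV.fetch("OLLAMA_URL", "http://localhost:11434"),
      default_options: {completion_model_name: "mistral", embeddings_model_name: "mistral"}
    )

    llm.default_dimensions # looked up in EMBEDDING_SIZES, or measured with a test embedding

    # Streaming completion: each parsed JSON chunk is yielded to the block (a chunk can be nil
    # when it fails to parse, hence the safe navigation).
    llm.complete(prompt: "Tell me a joke.", temperature: 0.5) do |chunk, _size|
      print chunk&.dig("response")
    end

    llm.chat(messages: [{role: "user", content: "Hi!"}]) # => Langchain::LLM::OllamaResponse
    llm.summarize(text: "Long article text...")          # uses prompts/ollama/summarize_template.yaml
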