langchainrb 0.7.5 → 0.12.0

Files changed (95)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +78 -0
  3. data/README.md +113 -56
  4. data/lib/langchain/assistants/assistant.rb +213 -0
  5. data/lib/langchain/assistants/message.rb +58 -0
  6. data/lib/langchain/assistants/thread.rb +34 -0
  7. data/lib/langchain/chunker/markdown.rb +37 -0
  8. data/lib/langchain/chunker/recursive_text.rb +0 -2
  9. data/lib/langchain/chunker/semantic.rb +1 -3
  10. data/lib/langchain/chunker/sentence.rb +0 -2
  11. data/lib/langchain/chunker/text.rb +0 -2
  12. data/lib/langchain/contextual_logger.rb +1 -1
  13. data/lib/langchain/data.rb +4 -3
  14. data/lib/langchain/llm/ai21.rb +1 -1
  15. data/lib/langchain/llm/anthropic.rb +86 -11
  16. data/lib/langchain/llm/aws_bedrock.rb +52 -0
  17. data/lib/langchain/llm/azure.rb +10 -97
  18. data/lib/langchain/llm/base.rb +3 -2
  19. data/lib/langchain/llm/cohere.rb +5 -7
  20. data/lib/langchain/llm/google_palm.rb +4 -2
  21. data/lib/langchain/llm/google_vertex_ai.rb +151 -0
  22. data/lib/langchain/llm/hugging_face.rb +1 -1
  23. data/lib/langchain/llm/llama_cpp.rb +18 -16
  24. data/lib/langchain/llm/mistral_ai.rb +68 -0
  25. data/lib/langchain/llm/ollama.rb +209 -27
  26. data/lib/langchain/llm/openai.rb +138 -170
  27. data/lib/langchain/llm/prompts/ollama/summarize_template.yaml +9 -0
  28. data/lib/langchain/llm/replicate.rb +1 -7
  29. data/lib/langchain/llm/response/anthropic_response.rb +20 -0
  30. data/lib/langchain/llm/response/base_response.rb +7 -0
  31. data/lib/langchain/llm/response/google_palm_response.rb +4 -0
  32. data/lib/langchain/llm/response/google_vertex_ai_response.rb +33 -0
  33. data/lib/langchain/llm/response/llama_cpp_response.rb +13 -0
  34. data/lib/langchain/llm/response/mistral_ai_response.rb +39 -0
  35. data/lib/langchain/llm/response/ollama_response.rb +27 -1
  36. data/lib/langchain/llm/response/openai_response.rb +8 -0
  37. data/lib/langchain/loader.rb +3 -2
  38. data/lib/langchain/output_parsers/base.rb +0 -4
  39. data/lib/langchain/output_parsers/output_fixing_parser.rb +7 -14
  40. data/lib/langchain/output_parsers/structured_output_parser.rb +0 -10
  41. data/lib/langchain/processors/csv.rb +37 -3
  42. data/lib/langchain/processors/eml.rb +64 -0
  43. data/lib/langchain/processors/markdown.rb +17 -0
  44. data/lib/langchain/processors/pptx.rb +29 -0
  45. data/lib/langchain/prompt/loading.rb +1 -1
  46. data/lib/langchain/tool/base.rb +21 -53
  47. data/lib/langchain/tool/calculator/calculator.json +19 -0
  48. data/lib/langchain/tool/{calculator.rb → calculator/calculator.rb} +8 -16
  49. data/lib/langchain/tool/database/database.json +46 -0
  50. data/lib/langchain/tool/database/database.rb +99 -0
  51. data/lib/langchain/tool/file_system/file_system.json +57 -0
  52. data/lib/langchain/tool/file_system/file_system.rb +32 -0
  53. data/lib/langchain/tool/google_search/google_search.json +19 -0
  54. data/lib/langchain/tool/{google_search.rb → google_search/google_search.rb} +5 -15
  55. data/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json +19 -0
  56. data/lib/langchain/tool/{ruby_code_interpreter.rb → ruby_code_interpreter/ruby_code_interpreter.rb} +8 -4
  57. data/lib/langchain/tool/vectorsearch/vectorsearch.json +24 -0
  58. data/lib/langchain/tool/vectorsearch/vectorsearch.rb +36 -0
  59. data/lib/langchain/tool/weather/weather.json +19 -0
  60. data/lib/langchain/tool/{weather.rb → weather/weather.rb} +3 -15
  61. data/lib/langchain/tool/wikipedia/wikipedia.json +19 -0
  62. data/lib/langchain/tool/{wikipedia.rb → wikipedia/wikipedia.rb} +9 -9
  63. data/lib/langchain/utils/token_length/ai21_validator.rb +6 -2
  64. data/lib/langchain/utils/token_length/base_validator.rb +1 -1
  65. data/lib/langchain/utils/token_length/cohere_validator.rb +6 -2
  66. data/lib/langchain/utils/token_length/google_palm_validator.rb +5 -1
  67. data/lib/langchain/utils/token_length/openai_validator.rb +55 -1
  68. data/lib/langchain/utils/token_length/token_limit_exceeded.rb +1 -1
  69. data/lib/langchain/vectorsearch/base.rb +11 -4
  70. data/lib/langchain/vectorsearch/chroma.rb +10 -1
  71. data/lib/langchain/vectorsearch/elasticsearch.rb +53 -4
  72. data/lib/langchain/vectorsearch/epsilla.rb +149 -0
  73. data/lib/langchain/vectorsearch/hnswlib.rb +5 -1
  74. data/lib/langchain/vectorsearch/milvus.rb +4 -2
  75. data/lib/langchain/vectorsearch/pgvector.rb +14 -4
  76. data/lib/langchain/vectorsearch/pinecone.rb +8 -5
  77. data/lib/langchain/vectorsearch/qdrant.rb +16 -4
  78. data/lib/langchain/vectorsearch/weaviate.rb +20 -2
  79. data/lib/langchain/version.rb +1 -1
  80. data/lib/langchain.rb +20 -5
  81. metadata +182 -45
  82. data/lib/langchain/agent/agents.md +0 -54
  83. data/lib/langchain/agent/base.rb +0 -20
  84. data/lib/langchain/agent/react_agent/react_agent_prompt.yaml +0 -26
  85. data/lib/langchain/agent/react_agent.rb +0 -131
  86. data/lib/langchain/agent/sql_query_agent/sql_query_agent_answer_prompt.yaml +0 -11
  87. data/lib/langchain/agent/sql_query_agent/sql_query_agent_sql_prompt.yaml +0 -21
  88. data/lib/langchain/agent/sql_query_agent.rb +0 -82
  89. data/lib/langchain/conversation/context.rb +0 -8
  90. data/lib/langchain/conversation/memory.rb +0 -86
  91. data/lib/langchain/conversation/message.rb +0 -48
  92. data/lib/langchain/conversation/prompt.rb +0 -8
  93. data/lib/langchain/conversation/response.rb +0 -8
  94. data/lib/langchain/conversation.rb +0 -93
  95. data/lib/langchain/tool/database.rb +0 -90
data/lib/langchain/llm/google_vertex_ai.rb
@@ -0,0 +1,151 @@
+ # frozen_string_literal: true
+
+ module Langchain::LLM
+   #
+   # Wrapper around the Google Vertex AI APIs: https://cloud.google.com/vertex-ai?hl=en
+   #
+   # Gem requirements:
+   #     gem "google-apis-aiplatform_v1", "~> 0.7"
+   #
+   # Usage:
+   #     google_palm = Langchain::LLM::GoogleVertexAi.new(project_id: ENV["GOOGLE_VERTEX_AI_PROJECT_ID"])
+   #
+   class GoogleVertexAi < Base
+     DEFAULTS = {
+       temperature: 0.1, # 0.1 is the default in the API, quite low ("grounded")
+       max_output_tokens: 1000,
+       top_p: 0.8,
+       top_k: 40,
+       dimensions: 768,
+       completion_model_name: "text-bison", # Optional: tect-bison@001
+       embeddings_model_name: "textembedding-gecko"
+     }.freeze
+
+     # TODO: Implement token length validation
+     # LENGTH_VALIDATOR = Langchain::Utils::TokenLength::...
+
+     # Google Cloud has a project id and a specific region of deployment.
+     # For GenAI-related things, a safe choice is us-central1.
+     attr_reader :project_id, :client, :region
+
+     def initialize(project_id:, default_options: {})
+       depends_on "google-apis-aiplatform_v1"
+
+       @project_id = project_id
+       @region = default_options.fetch :region, "us-central1"
+
+       @client = Google::Apis::AiplatformV1::AiplatformService.new
+
+       # TODO: Adapt for other regions; Pass it in via the constructor
+       # For the moment only us-central1 available so no big deal.
+       @client.root_url = "https://#{@region}-aiplatform.googleapis.com/"
+       @client.authorization = Google::Auth.get_application_default
+
+       @defaults = DEFAULTS.merge(default_options)
+     end
+
+     #
+     # Generate an embedding for a given text
+     #
+     # @param text [String] The text to generate an embedding for
+     # @return [Langchain::LLM::GoogleVertexAiResponse] Response object
+     #
+     def embed(text:)
+       content = [{content: text}]
+       request = Google::Apis::AiplatformV1::GoogleCloudAiplatformV1PredictRequest.new(instances: content)
+
+       api_path = "projects/#{@project_id}/locations/us-central1/publishers/google/models/#{@defaults[:embeddings_model_name]}"
+
+       # puts("api_path: #{api_path}")
+
+       response = client.predict_project_location_publisher_model(api_path, request)
+
+       Langchain::LLM::GoogleVertexAiResponse.new(response.to_h, model: @defaults[:embeddings_model_name])
+     end
+
+     #
+     # Generate a completion for a given prompt
+     #
+     # @param prompt [String] The prompt to generate a completion for
+     # @param params extra parameters passed to GooglePalmAPI::Client#generate_text
+     # @return [Langchain::LLM::GooglePalmResponse] Response object
+     #
+     def complete(prompt:, **params)
+       default_params = {
+         prompt: prompt,
+         temperature: @defaults[:temperature],
+         top_k: @defaults[:top_k],
+         top_p: @defaults[:top_p],
+         max_output_tokens: @defaults[:max_output_tokens],
+         model: @defaults[:completion_model_name]
+       }
+
+       if params[:stop_sequences]
+         default_params[:stop_sequences] = params.delete(:stop_sequences)
+       end
+
+       if params[:max_output_tokens]
+         default_params[:max_output_tokens] = params.delete(:max_output_tokens)
+       end
+
+       # to be tested
+       temperature = params.delete(:temperature) || @defaults[:temperature]
+       max_output_tokens = default_params.fetch(:max_output_tokens, @defaults[:max_output_tokens])
+
+       default_params.merge!(params)
+
+       # response = client.generate_text(**default_params)
+       request = Google::Apis::AiplatformV1::GoogleCloudAiplatformV1PredictRequest.new \
+         instances: [{
+           prompt: prompt # key used to be :content, changed to :prompt
+         }],
+         parameters: {
+           temperature: temperature,
+           maxOutputTokens: max_output_tokens,
+           topP: 0.8,
+           topK: 40
+         }
+
+       response = client.predict_project_location_publisher_model \
+         "projects/#{project_id}/locations/us-central1/publishers/google/models/#{@defaults[:completion_model_name]}",
+         request
+
+       Langchain::LLM::GoogleVertexAiResponse.new(response, model: default_params[:model])
+     end
+
+     #
+     # Generate a summarization for a given text
+     #
+     # @param text [String] The text to generate a summarization for
+     # @return [String] The summarization
+     #
+     # TODO(ricc): add params for Temp, topP, topK, MaxTokens and have it default to these 4 values.
+     def summarize(text:)
+       prompt_template = Langchain::Prompt.load_from_path(
+         file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
+       )
+       prompt = prompt_template.format(text: text)
+
+       complete(
+         prompt: prompt,
+         # For best temperature, topP, topK, MaxTokens for summarization: see
+         # https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-summarization
+         temperature: 0.2,
+         top_p: 0.95,
+         top_k: 40,
+         # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
+         max_output_tokens: 256
+       )
+     end
+
+     # def chat(...)
+     #   https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chathat
+     #   Chat params: https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chat
+     #   \"temperature\": 0.3,\n"
+     #   + " \"maxDecodeSteps\": 200,\n"
+     #   + " \"topP\": 0.8,\n"
+     #   + " \"topK\": 40\n"
+     #   + "}";
+     # end
+   end
+ end
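
A minimal usage sketch for the new Google Vertex AI wrapper, assuming the "google-apis-aiplatform_v1" gem is installed and Google application-default credentials are configured; the project id below is a placeholder:

```ruby
require "langchain"

# Hypothetical project id; in practice read it from the environment as the class docs suggest.
llm = Langchain::LLM::GoogleVertexAi.new(project_id: "my-gcp-project")

# Embedding and completion both return Langchain::LLM::GoogleVertexAiResponse objects.
embedding  = llm.embed(text: "Ruby is a programmer's best friend")
completion = llm.complete(prompt: "Name three Ruby web frameworks.", temperature: 0.2)

# #summarize loads the bundled summarize_template.yaml prompt and delegates to #complete.
summary = llm.summarize(text: "Vertex AI is Google Cloud's managed ML platform. ...")
```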
data/lib/langchain/llm/hugging_face.rb
@@ -16,7 +16,7 @@ module Langchain::LLM
      DEFAULTS = {
        temperature: 0.0,
        embeddings_model_name: "sentence-transformers/all-MiniLM-L6-v2",
-       dimension: 384 # Vector size generated by the above model
+       dimensions: 384 # Vector size generated by the above model
      }.freeze

      #
data/lib/langchain/llm/llama_cpp.rb
@@ -22,7 +22,7 @@ module Langchain::LLM
      # @param n_ctx [Integer] The number of context tokens to use
      # @param n_threads [Integer] The CPU number of threads to use
      # @param seed [Integer] The seed to use
-     def initialize(model_path:, n_gpu_layers: 1, n_ctx: 2048, n_threads: 1, seed: -1)
+     def initialize(model_path:, n_gpu_layers: 1, n_ctx: 2048, n_threads: 1, seed: 0)
        depends_on "llama_cpp"

        @model_path = model_path
@@ -33,30 +33,25 @@ module Langchain::LLM
      end

      # @param text [String] The text to embed
-     # @param n_threads [Integer] The number of CPU threads to use
      # @return [Array<Float>] The embedding
-     def embed(text:, n_threads: nil)
+     def embed(text:)
        # contexts are kinda stateful when it comes to embeddings, so allocate one each time
        context = embedding_context

-       embedding_input = context.tokenize(text: text, add_bos: true)
+       embedding_input = @model.tokenize(text: text, add_bos: true)
        return unless embedding_input.size.positive?

-       n_threads ||= self.n_threads
-
-       context.eval(tokens: embedding_input, n_past: 0, n_threads: n_threads)
-       context.embeddings
+       context.eval(tokens: embedding_input, n_past: 0)
+       Langchain::LLM::LlamaCppResponse.new(context, model: context.model.desc)
      end

      # @param prompt [String] The prompt to complete
      # @param n_predict [Integer] The number of tokens to predict
-     # @param n_threads [Integer] The number of CPU threads to use
      # @return [String] The completed prompt
-     def complete(prompt:, n_predict: 128, n_threads: nil)
-       n_threads ||= self.n_threads
+     def complete(prompt:, n_predict: 128)
        # contexts do not appear to be stateful when it comes to completion, so re-use the same one
        context = completion_context
-       ::LLaMACpp.generate(context, prompt, n_threads: n_threads, n_predict: n_predict)
+       ::LLaMACpp.generate(context, prompt, n_predict: n_predict)
      end

      private
@@ -71,23 +66,30 @@ module Langchain::LLM

        context_params.seed = seed
        context_params.n_ctx = n_ctx
-       context_params.n_gpu_layers = n_gpu_layers
+       context_params.n_threads = n_threads
        context_params.embedding = embeddings

        context_params
      end

+     def build_model_params
+       model_params = ::LLaMACpp::ModelParams.new
+       model_params.n_gpu_layers = n_gpu_layers
+
+       model_params
+     end
+
      def build_model(embeddings: false)
        return @model if defined?(@model)
-       @model = ::LLaMACpp::Model.new(model_path: model_path, params: build_context_params(embeddings: embeddings))
+       @model = ::LLaMACpp::Model.new(model_path: model_path, params: build_model_params)
      end

      def build_completion_context
-       ::LLaMACpp::Context.new(model: build_model)
+       ::LLaMACpp::Context.new(model: build_model, params: build_context_params(embeddings: false))
      end

      def build_embedding_context
-       ::LLaMACpp::Context.new(model: build_model(embeddings: true))
+       ::LLaMACpp::Context.new(model: build_model, params: build_context_params(embeddings: true))
      end

      def completion_context
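
A short sketch of the reworked LlamaCpp interface under these changes, assuming a local GGUF model file (the path below is a placeholder): n_gpu_layers now feeds the new ModelParams, n_threads is applied through the context params, seed defaults to 0, and embed/complete no longer take an n_threads: argument.

```ruby
llm = Langchain::LLM::LlamaCpp.new(
  model_path: "/models/llama-2-7b.Q4_K_M.gguf", # placeholder path
  n_gpu_layers: 1,
  n_ctx: 2048,
  n_threads: 4
)

completion = llm.complete(prompt: "Hello", n_predict: 64) # returns a String
embedding  = llm.embed(text: "Hello")                     # now returns a Langchain::LLM::LlamaCppResponse
```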
data/lib/langchain/llm/mistral_ai.rb
@@ -0,0 +1,68 @@
+ # frozen_string_literal: true
+
+ module Langchain::LLM
+   # Gem requirements:
+   #     gem "mistral-ai"
+   #
+   # Usage:
+   #     llm = Langchain::LLM::MistralAI.new(api_key: ENV["MISTRAL_AI_API_KEY"])
+   class MistralAI < Base
+     DEFAULTS = {
+       chat_completion_model_name: "mistral-medium",
+       embeddings_model_name: "mistral-embed"
+     }.freeze
+
+     attr_reader :defaults
+
+     def initialize(api_key:, default_options: {})
+       depends_on "mistral-ai"
+
+       @client = Mistral.new(
+         credentials: {api_key: api_key},
+         options: {server_sent_events: true}
+       )
+
+       @defaults = DEFAULTS.merge(default_options)
+     end
+
+     def chat(
+       messages:,
+       model: defaults[:chat_completion_model_name],
+       temperature: nil,
+       top_p: nil,
+       max_tokens: nil,
+       safe_prompt: nil,
+       random_seed: nil
+     )
+       params = {
+         messages: messages,
+         model: model
+       }
+       params[:temperature] = temperature if temperature
+       params[:top_p] = top_p if top_p
+       params[:max_tokens] = max_tokens if max_tokens
+       params[:safe_prompt] = safe_prompt if safe_prompt
+       params[:random_seed] = random_seed if random_seed
+
+       response = client.chat_completions(params)
+
+       Langchain::LLM::MistralAIResponse.new(response.to_h)
+     end
+
+     def embed(
+       text:,
+       model: defaults[:embeddings_model_name],
+       encoding_format: nil
+     )
+       params = {
+         input: text,
+         model: model
+       }
+       params[:encoding_format] = encoding_format if encoding_format
+
+       response = client.embeddings(params)
+
+       Langchain::LLM::MistralAIResponse.new(response.to_h)
+     end
+   end
+ end
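
A usage sketch for the new Mistral AI wrapper, assuming the "mistral-ai" gem is installed and an API key is available in the environment:

```ruby
llm = Langchain::LLM::MistralAI.new(api_key: ENV["MISTRAL_AI_API_KEY"])

chat_response = llm.chat(
  messages: [{role: "user", content: "Say hello in French"}],
  temperature: 0.7
)

embed_response = llm.embed(text: "Bonjour")
# Both calls wrap the raw API hash in a Langchain::LLM::MistralAIResponse.
```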
data/lib/langchain/llm/ollama.rb
@@ -1,25 +1,52 @@
  # frozen_string_literal: true

+ require "active_support/core_ext/hash"
+
  module Langchain::LLM
    # Interface to Ollama API.
    # Available models: https://ollama.ai/library
    #
    # Usage:
-   #    ollama = Langchain::LLM::Ollama.new(url: ENV["OLLAMA_URL"])
+   #    ollama = Langchain::LLM::Ollama.new(url: ENV["OLLAMA_URL"], default_options: {})
    #
    class Ollama < Base
-     attr_reader :url
+     attr_reader :url, :defaults

      DEFAULTS = {
-       temperature: 0.0,
+       temperature: 0.8,
        completion_model_name: "llama2",
-       embeddings_model_name: "llama2"
+       embeddings_model_name: "llama2",
+       chat_completion_model_name: "llama2"
+     }.freeze
+
+     EMBEDDING_SIZES = {
+       codellama: 4_096,
+       "dolphin-mixtral": 4_096,
+       llama2: 4_096,
+       llava: 4_096,
+       mistral: 4_096,
+       "mistral-openorca": 4_096,
+       mixtral: 4_096
      }.freeze

      # Initialize the Ollama client
      # @param url [String] The URL of the Ollama instance
-     def initialize(url:)
+     # @param default_options [Hash] The default options to use
+     #
+     def initialize(url:, default_options: {})
+       depends_on "faraday"
        @url = url
+       @defaults = DEFAULTS.deep_merge(default_options)
+     end
+
+     # Returns the # of vector dimensions for the embeddings
+     # @return [Integer] The # of vector dimensions
+     def default_dimensions
+       # since Ollama can run multiple models, look it up or generate an embedding and return the size
+       @default_dimensions ||=
+         EMBEDDING_SIZES.fetch(defaults[:embeddings_model_name].to_sym) do
+           embed(text: "test").embedding.size
+         end
      end

      #
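
A sketch of the new constructor options and the dimension lookup, assuming an Ollama server is reachable at OLLAMA_URL:

```ruby
ollama = Langchain::LLM::Ollama.new(
  url: ENV["OLLAMA_URL"],
  default_options: {completion_model_name: "mistral", embeddings_model_name: "mistral"}
)

# Known models are looked up in EMBEDDING_SIZES; unknown ones fall back to
# generating a test embedding and measuring its size.
ollama.default_dimensions # => 4096
```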
@@ -27,32 +54,135 @@ module Langchain::LLM
      #
      # @param prompt [String] The prompt to complete
      # @param model [String] The model to use
-     # @param options [Hash] The options to use (https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
+     # For a list of valid parameters and values, see:
+     # https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
      # @return [Langchain::LLM::OllamaResponse] Response object
      #
-     def complete(prompt:, model: nil, **options)
-       response = +""
+     def complete(
+       prompt:,
+       model: defaults[:completion_model_name],
+       images: nil,
+       format: nil,
+       system: nil,
+       template: nil,
+       context: nil,
+       stream: nil,
+       raw: nil,
+       mirostat: nil,
+       mirostat_eta: nil,
+       mirostat_tau: nil,
+       num_ctx: nil,
+       num_gqa: nil,
+       num_gpu: nil,
+       num_thread: nil,
+       repeat_last_n: nil,
+       repeat_penalty: nil,
+       temperature: defaults[:temperature],
+       seed: nil,
+       stop: nil,
+       tfs_z: nil,
+       num_predict: nil,
+       top_k: nil,
+       top_p: nil,
+       stop_sequences: nil,
+       &block
+     )
+       if stop_sequences
+         stop = stop_sequences
+       end

-       model_name = model || DEFAULTS[:completion_model_name]
+       parameters = {
+         prompt: prompt,
+         model: model,
+         images: images,
+         format: format,
+         system: system,
+         template: template,
+         context: context,
+         stream: stream,
+         raw: raw
+       }.compact

-       client.post("api/generate") do |req|
-         req.body = {}
-         req.body["prompt"] = prompt
-         req.body["model"] = model_name
+       llm_parameters = {
+         mirostat: mirostat,
+         mirostat_eta: mirostat_eta,
+         mirostat_tau: mirostat_tau,
+         num_ctx: num_ctx,
+         num_gqa: num_gqa,
+         num_gpu: num_gpu,
+         num_thread: num_thread,
+         repeat_last_n: repeat_last_n,
+         repeat_penalty: repeat_penalty,
+         temperature: temperature,
+         seed: seed,
+         stop: stop,
+         tfs_z: tfs_z,
+         num_predict: num_predict,
+         top_k: top_k,
+         top_p: top_p
+       }

-         req.body["options"] = options if options.any?
+       parameters[:options] = llm_parameters.compact
+
+       response = ""
+
+       client.post("api/generate") do |req|
+         req.body = parameters

-         # TODO: Implement streaming support when a &block is passed in
          req.options.on_data = proc do |chunk, size|
-           json_chunk = JSON.parse(chunk)
+           chunk.split("\n").each do |line_chunk|
+             json_chunk = begin
+               JSON.parse(line_chunk)
+             # In some instance the chunk exceeds the buffer size and the JSON parser fails
+             rescue JSON::ParserError
+               nil
+             end

-           unless json_chunk.dig("done")
-             response.to_s << JSON.parse(chunk).dig("response")
+             response += json_chunk.dig("response") unless json_chunk.blank?
            end
+
+           yield json_chunk, size if block
          end
        end

-       Langchain::LLM::OllamaResponse.new(response, model: model_name)
+       Langchain::LLM::OllamaResponse.new(response, model: parameters[:model])
+     end
+
+     # Generate a chat completion
+     #
+     # @param model [String] Model name
+     # @param messages [Array<Hash>] Array of messages
+     # @param format [String] Format to return a response in. Currently the only accepted value is `json`
+     # @param temperature [Float] The temperature to use
+     # @param template [String] The prompt template to use (overrides what is defined in the `Modelfile`)
+     # @param stream [Boolean] Streaming the response. If false the response will be returned as a single response object, rather than a stream of objects
+     #
+     # The message object has the following fields:
+     #   role: the role of the message, either system, user or assistant
+     #   content: the content of the message
+     #   images (optional): a list of images to include in the message (for multimodal models such as llava)
+     def chat(
+       model: defaults[:chat_completion_model_name],
+       messages: [],
+       format: nil,
+       temperature: defaults[:temperature],
+       template: nil,
+       stream: false # TODO: Fix streaming.
+     )
+       parameters = {
+         model: model,
+         messages: messages,
+         format: format,
+         temperature: temperature,
+         template: template,
+         stream: stream
+       }.compact
+
+       response = client.post("api/chat") do |req|
+         req.body = parameters
+       end
+
+       Langchain::LLM::OllamaResponse.new(response.body, model: parameters[:model])
      end

      #
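
A sketch of the expanded #complete signature and the new #chat method, reusing the `ollama` client from the sketch above; streamed chunks are passed to the block as parsed JSON hashes (or nil when a partial line fails to parse):

```ruby
ollama.complete(prompt: "Why is the sky blue?", temperature: 0.5) do |chunk, _size|
  print chunk&.dig("response")
end

chat_response = ollama.chat(
  messages: [{role: "user", content: "Hi there"}],
  temperature: 0.7
)
```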
@@ -63,18 +193,70 @@ module Langchain::LLM
      # @param options [Hash] The options to use
      # @return [Langchain::LLM::OllamaResponse] Response object
      #
-     def embed(text:, model: nil, **options)
-       model_name = model || DEFAULTS[:embeddings_model_name]
+     def embed(
+       text:,
+       model: defaults[:embeddings_model_name],
+       mirostat: nil,
+       mirostat_eta: nil,
+       mirostat_tau: nil,
+       num_ctx: nil,
+       num_gqa: nil,
+       num_gpu: nil,
+       num_thread: nil,
+       repeat_last_n: nil,
+       repeat_penalty: nil,
+       temperature: defaults[:temperature],
+       seed: nil,
+       stop: nil,
+       tfs_z: nil,
+       num_predict: nil,
+       top_k: nil,
+       top_p: nil
+     )
+       parameters = {
+         prompt: text,
+         model: model
+       }.compact

-       response = client.post("api/embeddings") do |req|
-         req.body = {}
-         req.body["prompt"] = text
-         req.body["model"] = model_name
+       llm_parameters = {
+         mirostat: mirostat,
+         mirostat_eta: mirostat_eta,
+         mirostat_tau: mirostat_tau,
+         num_ctx: num_ctx,
+         num_gqa: num_gqa,
+         num_gpu: num_gpu,
+         num_thread: num_thread,
+         repeat_last_n: repeat_last_n,
+         repeat_penalty: repeat_penalty,
+         temperature: temperature,
+         seed: seed,
+         stop: stop,
+         tfs_z: tfs_z,
+         num_predict: num_predict,
+         top_k: top_k,
+         top_p: top_p
+       }

-         req.body["options"] = options if options.any?
+       parameters[:options] = llm_parameters.compact
+
+       response = client.post("api/embeddings") do |req|
+         req.body = parameters
        end

-       Langchain::LLM::OllamaResponse.new(response.body, model: model_name)
+       Langchain::LLM::OllamaResponse.new(response.body, model: parameters[:model])
+     end
+
+     # Generate a summary for a given text
+     #
+     # @param text [String] The text to generate a summary for
+     # @return [String] The summary
+     def summarize(text:)
+       prompt_template = Langchain::Prompt.load_from_path(
+         file_path: Langchain.root.join("langchain/llm/prompts/ollama/summarize_template.yaml")
+       )
+       prompt = prompt_template.format(text: text)
+
+       complete(prompt: prompt)
      end

      private
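
And a sketch of the reworked #embed plus the new #summarize, again reusing the `ollama` client from above (the article text is a placeholder variable):

```ruby
embedding = ollama.embed(text: "Hello world", model: "llama2", temperature: 0.0)

summary = ollama.summarize(text: article_text) # article_text is assumed to hold the document to condense
```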