raif 1.1.0 → 1.2.1.pre

This diff shows the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (74)
  1. checksums.yaml +4 -4
  2. data/README.md +150 -4
  3. data/app/assets/builds/raif.css +26 -1
  4. data/app/assets/stylesheets/raif/loader.scss +27 -1
  5. data/app/models/raif/concerns/llm_response_parsing.rb +22 -16
  6. data/app/models/raif/concerns/llms/anthropic/tool_formatting.rb +56 -0
  7. data/app/models/raif/concerns/llms/{bedrock_claude → bedrock}/message_formatting.rb +4 -4
  8. data/app/models/raif/concerns/llms/bedrock/tool_formatting.rb +37 -0
  9. data/app/models/raif/concerns/llms/message_formatting.rb +7 -6
  10. data/app/models/raif/concerns/llms/open_ai/json_schema_validation.rb +138 -0
  11. data/app/models/raif/concerns/llms/{open_ai → open_ai_completions}/message_formatting.rb +1 -1
  12. data/app/models/raif/concerns/llms/open_ai_completions/tool_formatting.rb +26 -0
  13. data/app/models/raif/concerns/llms/open_ai_responses/message_formatting.rb +43 -0
  14. data/app/models/raif/concerns/llms/open_ai_responses/tool_formatting.rb +42 -0
  15. data/app/models/raif/conversation.rb +17 -4
  16. data/app/models/raif/conversation_entry.rb +18 -2
  17. data/app/models/raif/embedding_models/{bedrock_titan.rb → bedrock.rb} +2 -2
  18. data/app/models/raif/llm.rb +73 -7
  19. data/app/models/raif/llms/anthropic.rb +56 -36
  20. data/app/models/raif/llms/{bedrock_claude.rb → bedrock.rb} +62 -45
  21. data/app/models/raif/llms/open_ai_base.rb +66 -0
  22. data/app/models/raif/llms/open_ai_completions.rb +100 -0
  23. data/app/models/raif/llms/open_ai_responses.rb +144 -0
  24. data/app/models/raif/llms/open_router.rb +44 -44
  25. data/app/models/raif/model_completion.rb +2 -0
  26. data/app/models/raif/model_tool.rb +4 -0
  27. data/app/models/raif/model_tools/provider_managed/base.rb +9 -0
  28. data/app/models/raif/model_tools/provider_managed/code_execution.rb +5 -0
  29. data/app/models/raif/model_tools/provider_managed/image_generation.rb +5 -0
  30. data/app/models/raif/model_tools/provider_managed/web_search.rb +5 -0
  31. data/app/models/raif/streaming_responses/anthropic.rb +63 -0
  32. data/app/models/raif/streaming_responses/bedrock.rb +89 -0
  33. data/app/models/raif/streaming_responses/open_ai_completions.rb +76 -0
  34. data/app/models/raif/streaming_responses/open_ai_responses.rb +54 -0
  35. data/app/views/raif/admin/conversations/_conversation_entry.html.erb +48 -0
  36. data/app/views/raif/admin/conversations/show.html.erb +1 -1
  37. data/app/views/raif/admin/model_completions/_model_completion.html.erb +7 -0
  38. data/app/views/raif/admin/model_completions/index.html.erb +1 -0
  39. data/app/views/raif/admin/model_completions/show.html.erb +28 -0
  40. data/app/views/raif/conversation_entries/_citations.html.erb +9 -0
  41. data/app/views/raif/conversation_entries/_conversation_entry.html.erb +5 -1
  42. data/app/views/raif/conversation_entries/_message.html.erb +4 -0
  43. data/config/locales/admin.en.yml +2 -0
  44. data/config/locales/en.yml +24 -0
  45. data/db/migrate/20250224234252_create_raif_tables.rb +1 -1
  46. data/db/migrate/20250421202149_add_response_format_to_raif_conversations.rb +1 -1
  47. data/db/migrate/20250424200755_add_cost_columns_to_raif_model_completions.rb +1 -1
  48. data/db/migrate/20250424232946_add_created_at_indexes.rb +1 -1
  49. data/db/migrate/20250502155330_add_status_indexes_to_raif_tasks.rb +1 -1
  50. data/db/migrate/20250527213016_add_response_id_and_response_array_to_model_completions.rb +14 -0
  51. data/db/migrate/20250603140622_add_citations_to_raif_model_completions.rb +13 -0
  52. data/db/migrate/20250603202013_add_stream_response_to_raif_model_completions.rb +7 -0
  53. data/lib/generators/raif/conversation/templates/conversation.rb.tt +3 -3
  54. data/lib/generators/raif/install/templates/initializer.rb +14 -2
  55. data/lib/raif/configuration.rb +27 -5
  56. data/lib/raif/embedding_model_registry.rb +1 -1
  57. data/lib/raif/engine.rb +25 -9
  58. data/lib/raif/errors/streaming_error.rb +18 -0
  59. data/lib/raif/errors.rb +1 -0
  60. data/lib/raif/llm_registry.rb +169 -47
  61. data/lib/raif/migration_checker.rb +74 -0
  62. data/lib/raif/utils/html_fragment_processor.rb +170 -0
  63. data/lib/raif/utils.rb +1 -0
  64. data/lib/raif/version.rb +1 -1
  65. data/lib/raif.rb +2 -0
  66. data/spec/support/complex_test_tool.rb +65 -0
  67. data/spec/support/rspec_helpers.rb +66 -0
  68. data/spec/support/test_conversation.rb +18 -0
  69. data/spec/support/test_embedding_model.rb +27 -0
  70. data/spec/support/test_llm.rb +22 -0
  71. data/spec/support/test_model_tool.rb +32 -0
  72. data/spec/support/test_task.rb +45 -0
  73. metadata +52 -8
  74. data/app/models/raif/llms/open_ai.rb +0 -256
data/app/models/raif/concerns/llms/open_ai/json_schema_validation.rb
@@ -0,0 +1,138 @@
+ # frozen_string_literal: true
+
+ module Raif::Concerns::Llms::OpenAi::JsonSchemaValidation
+   extend ActiveSupport::Concern
+
+   def validate_json_schema!(schema)
+     return if schema.blank?
+
+     errors = []
+
+     # Check if schema is present
+     if schema.blank?
+       errors << "JSON schema must include a 'schema' property"
+     else
+       # Check root object type
+       if schema[:type] != "object" && !schema.key?(:properties)
+         errors << "Root schema must be of type 'object' with 'properties'"
+       end
+
+       # Check all objects in the schema recursively
+       validate_object_properties(schema, errors)
+
+       # Check properties count (max 100 total)
+       validate_properties_count(schema, errors)
+
+       # Check nesting depth (max 5 levels)
+       validate_nesting_depth(schema, errors)
+
+       # Check for unsupported anyOf at root level
+       if schema[:anyOf].present? && schema[:properties].blank?
+         errors << "Root objects cannot be of type 'anyOf'"
+       end
+     end
+
+     # Raise error if any validation issues found
+     if errors.any?
+       error_message = "Invalid JSON schema for OpenAI structured outputs: #{errors.join("; ")}\nSchema was: #{schema.inspect}"
+       raise Raif::Errors::OpenAi::JsonSchemaError, error_message
+     else
+       true
+     end
+   end
+
+   private
+
+   def validate_object_properties(schema, errors)
+     return unless schema.is_a?(Hash)
+
+     # Check if the current schema is an object and validate additionalProperties and required fields
+     if schema[:type] == "object"
+       if schema[:additionalProperties] != false
+         errors << "All objects must have 'additionalProperties' set to false"
+       end
+
+       # Check that all properties are required
+       if schema[:properties].is_a?(Hash) && schema[:properties].any?
+         property_keys = schema[:properties].keys
+         required_fields = schema[:required] || []
+
+         if required_fields.sort != property_keys.map(&:to_s).sort
+           errors << "All object properties must be listed in the 'required' array"
+         end
+       end
+     end
+
+     # Check if the current schema is an object and validate additionalProperties
+     if schema[:type] == "object"
+       if schema[:additionalProperties] != false
+         errors << "All objects must have 'additionalProperties' set to false"
+       end
+
+       # Check properties of the object recursively
+       if schema[:properties].is_a?(Hash)
+         schema[:properties].each_value do |property|
+           validate_object_properties(property, errors)
+         end
+       end
+     end
+
+     # Check array items
+     if schema[:type] == "array" && schema[:items].is_a?(Hash)
+       validate_object_properties(schema[:items], errors)
+     end
+
+     # Check anyOf
+     if schema[:anyOf].is_a?(Array)
+       schema[:anyOf].each do |option|
+         validate_object_properties(option, errors)
+       end
+     end
+   end
+
+   def validate_properties_count(schema, errors, count = 0)
+     return count unless schema.is_a?(Hash)
+
+     if schema[:properties].is_a?(Hash)
+       count += schema[:properties].size
+
+       if count > 100
+         errors << "Schema exceeds maximum of 100 total object properties"
+         return count
+       end
+
+       # Check nested properties
+       schema[:properties].each_value do |property|
+         count = validate_properties_count(property, errors, count)
+       end
+     end
+
+     # Check array items
+     if schema[:type] == "array" && schema[:items].is_a?(Hash)
+       count = validate_properties_count(schema[:items], errors, count)
+     end
+
+     count
+   end
+
+   def validate_nesting_depth(schema, errors, depth = 1)
+     return unless schema.is_a?(Hash)
+
+     if depth > 5
+       errors << "Schema exceeds maximum nesting depth of 5 levels"
+       return
+     end
+
+     if schema[:properties].is_a?(Hash)
+       schema[:properties].each_value do |property|
+         validate_nesting_depth(property, errors, depth + 1)
+       end
+     end
+
+     # Check array items
+     if schema[:type] == "array" && schema[:items].is_a?(Hash)
+       validate_nesting_depth(schema[:items], errors, depth + 1)
+     end
+   end
+
+ end
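For reference, a minimal sketch (not from the gem; the schema below is made up) of what passes these checks, since OpenAI structured outputs expect a root object with additionalProperties set to false and every property listed in required:

    # Hypothetical schema for illustration only.
    valid_schema = {
      type: "object",
      additionalProperties: false,
      required: ["title", "tags"],
      properties: {
        title: { type: "string" },
        tags: { type: "array", items: { type: "string" } }
      }
    }

    # This variant would raise Raif::Errors::OpenAi::JsonSchemaError:
    # summary is missing from required, and additionalProperties is not set to false.
    invalid_schema = {
      type: "object",
      required: ["title"],
      properties: {
        title: { type: "string" },
        summary: { type: "string" }
      }
    }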
data/app/models/raif/concerns/llms/{open_ai → open_ai_completions}/message_formatting.rb
@@ -1,6 +1,6 @@
  # frozen_string_literal: true

- module Raif::Concerns::Llms::OpenAi::MessageFormatting
+ module Raif::Concerns::Llms::OpenAiCompletions::MessageFormatting
    extend ActiveSupport::Concern

    def format_model_image_input_message(image_input)
data/app/models/raif/concerns/llms/open_ai_completions/tool_formatting.rb
@@ -0,0 +1,26 @@
+ # frozen_string_literal: true
+
+ module Raif::Concerns::Llms::OpenAiCompletions::ToolFormatting
+   extend ActiveSupport::Concern
+
+   def build_tools_parameter(model_completion)
+     model_completion.available_model_tools_map.map do |_tool_name, tool|
+       if tool.provider_managed?
+         raise Raif::Errors::UnsupportedFeatureError,
+           "Raif doesn't yet support provider-managed tools for the OpenAI Completions API. Consider using the OpenAI Responses API instead."
+       else
+         # It's a developer-managed tool
+         validate_json_schema!(tool.tool_arguments_schema)
+
+         {
+           type: "function",
+           function: {
+             name: tool.tool_name,
+             description: tool.tool_description,
+             parameters: tool.tool_arguments_schema
+           }
+         }
+       end
+     end
+   end
+ end
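To make the Chat Completions payload concrete: for a hypothetical developer-managed tool (the tool name and schema below are illustrative, not from the gem), build_tools_parameter returns an array shaped like:

    [
      {
        type: "function",
        function: {
          name: "fetch_weather",                      # made-up tool name
          description: "Look up current weather for a city",
          parameters: {
            type: "object",
            additionalProperties: false,
            required: ["city"],
            properties: { city: { type: "string" } }
          }
        }
      }
    ]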
data/app/models/raif/concerns/llms/open_ai_responses/message_formatting.rb
@@ -0,0 +1,43 @@
+ # frozen_string_literal: true
+
+ module Raif::Concerns::Llms::OpenAiResponses::MessageFormatting
+   extend ActiveSupport::Concern
+
+   def format_string_message(content, role: nil)
+     if role == "assistant"
+       { "type" => "output_text", "text" => content }
+     else
+       { "type" => "input_text", "text" => content }
+     end
+   end
+
+   def format_model_image_input_message(image_input)
+     if image_input.source_type == :url
+       {
+         "type" => "input_image",
+         "image_url" => image_input.url
+       }
+     elsif image_input.source_type == :file_content
+       {
+         "type" => "input_image",
+         "image_url" => "data:#{image_input.content_type};base64,#{image_input.base64_data}"
+       }
+     else
+       raise Raif::Errors::InvalidModelImageInputError, "Invalid model image input source type: #{image_input.source_type}"
+     end
+   end
+
+   def format_model_file_input_message(file_input)
+     if file_input.source_type == :url
+       raise Raif::Errors::UnsupportedFeatureError, "#{self.class.name} does not support providing a file by URL"
+     elsif file_input.source_type == :file_content
+       {
+         "type" => "input_file",
+         "filename" => file_input.filename,
+         "file_data" => "data:#{file_input.content_type};base64,#{file_input.base64_data}"
+       }
+     else
+       raise Raif::Errors::InvalidModelFileInputError, "Invalid model image input source type: #{file_input.source_type}"
+     end
+   end
+ end
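For context, the content items produced by these formatters look like the following (values are illustrative):

    # An image referenced by URL
    { "type" => "input_image", "image_url" => "https://example.com/chart.png" }

    # A file passed as raw content, sent as a base64 data URI
    {
      "type" => "input_file",
      "filename" => "report.pdf",
      "file_data" => "data:application/pdf;base64,JVBERi0xLjQK..."
    }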
data/app/models/raif/concerns/llms/open_ai_responses/tool_formatting.rb
@@ -0,0 +1,42 @@
+ # frozen_string_literal: true
+
+ module Raif::Concerns::Llms::OpenAiResponses::ToolFormatting
+   extend ActiveSupport::Concern
+
+   def build_tools_parameter(model_completion)
+     model_completion.available_model_tools_map.map do |_tool_name, tool|
+       if tool.provider_managed?
+         format_provider_managed_tool(tool)
+       else
+         # It's a developer-managed tool
+         validate_json_schema!(tool.tool_arguments_schema)
+
+         {
+           type: "function",
+           name: tool.tool_name,
+           description: tool.tool_description,
+           parameters: tool.tool_arguments_schema
+         }
+       end
+     end
+   end
+
+   def format_provider_managed_tool(tool)
+     validate_provider_managed_tool_support!(tool)
+
+     case tool.name
+     when "Raif::ModelTools::ProviderManaged::WebSearch"
+       { type: "web_search_preview" }
+     when "Raif::ModelTools::ProviderManaged::CodeExecution"
+       {
+         type: "code_interpreter",
+         container: { "type": "auto" }
+       }
+     when "Raif::ModelTools::ProviderManaged::ImageGeneration"
+       { type: "image_generation" }
+     else
+       raise Raif::Errors::UnsupportedFeatureError,
+         "Invalid provider-managed tool: #{tool.name} for #{key}"
+     end
+   end
+ end
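A hedged usage sketch of routing a provider-managed tool through the Responses API; the model key below is an assumption (actual keys come from Raif's LLM registry), and WebSearch maps to the web_search_preview tool type shown above:

    llm = Raif.llm(:open_ai_responses_gpt_4o) # key name is illustrative
    llm.chat(
      message: "What changed in the latest Rails release?",
      available_model_tools: [Raif::ModelTools::ProviderManaged::WebSearch]
    )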
data/app/models/raif/conversation.rb
@@ -16,10 +16,9 @@ class Raif::Conversation < Raif::ApplicationRecord
    after_initialize -> { self.available_user_tools ||= [] }

    before_validation ->{ self.type ||= "Raif::Conversation" }, on: :create
-   before_validation -> { self.system_prompt ||= build_system_prompt }, on: :create

    def build_system_prompt
-     <<~PROMPT
+     <<~PROMPT.strip
        #{system_prompt_intro}
        #{system_prompt_language_preference}
      PROMPT
@@ -35,14 +34,28 @@ class Raif::Conversation < Raif::ApplicationRecord
      I18n.t("#{self.class.name.underscore.gsub("/", ".")}.initial_chat_message")
    end

-   def prompt_model_for_entry_response(entry:)
+   def prompt_model_for_entry_response(entry:, &block)
+     update(system_prompt: build_system_prompt)
+
      llm.chat(
        messages: llm_messages,
        source: entry,
        response_format: response_format.to_sym,
        system_prompt: system_prompt,
-       available_model_tools: available_model_tools
+       available_model_tools: available_model_tools,
+       &block
      )
+   rescue StandardError => e
+     Rails.logger.error("Error processing conversation entry ##{entry.id}. #{e.message}")
+     entry.failed!
+
+     if defined?(Airbrake)
+       notice = Airbrake.build_notice(e)
+       notice[:context][:component] = "raif_conversation"
+       notice[:context][:action] = "prompt_model_for_entry_response"
+
+       Airbrake.notify(notice)
+     end
    end

    def process_model_response_message(message:, entry:)
data/app/models/raif/conversation_entry.rb
@@ -16,7 +16,7 @@ class Raif::ConversationEntry < Raif::ApplicationRecord
    has_one :raif_model_completion, as: :source, dependent: :destroy, class_name: "Raif::ModelCompletion"

    delegate :available_model_tools, to: :raif_conversation
-   delegate :system_prompt, :llm_model_key, to: :raif_model_completion, allow_nil: true
+   delegate :system_prompt, :llm_model_key, :citations, to: :raif_model_completion, allow_nil: true
    delegate :json_response_schema, to: :class

    accepts_nested_attributes_for :raif_user_tool_invocation
@@ -46,7 +46,23 @@ class Raif::ConversationEntry < Raif::ApplicationRecord
    end

    def process_entry!
-     self.raif_model_completion = raif_conversation.prompt_model_for_entry_response(entry: self)
+     self.model_response_message = ""
+
+     self.raif_model_completion = raif_conversation.prompt_model_for_entry_response(entry: self) do |model_completion, _delta, _sse_event|
+       self.raw_response = model_completion.raw_response
+       self.model_response_message = raif_conversation.process_model_response_message(
+         message: model_completion.parsed_response(force_reparse: true),
+         entry: self
+       )
+
+       update_columns(
+         model_response_message: model_response_message,
+         raw_response: raw_response,
+         updated_at: Time.current
+       )
+
+       broadcast_replace_to raif_conversation
+     end

      if raif_model_completion.parsed_response.present? || raif_model_completion.response_tool_calls.present?
        extract_message_and_invoke_tools!
data/app/models/raif/embedding_models/{bedrock_titan.rb → bedrock.rb}
@@ -1,10 +1,10 @@
  # frozen_string_literal: true

- class Raif::EmbeddingModels::BedrockTitan < Raif::EmbeddingModel
+ class Raif::EmbeddingModels::Bedrock < Raif::EmbeddingModel

    def generate_embedding!(input, dimensions: nil)
      unless input.is_a?(String)
-       raise ArgumentError, "Raif::EmbeddingModels::BedrockTitan#generate_embedding! input must be a string"
+       raise ArgumentError, "Raif::EmbeddingModels::Bedrock#generate_embedding! input must be a string"
      end

      params = build_request_parameters(input, dimensions:)
data/app/models/raif/llm.rb
@@ -12,7 +12,8 @@ module Raif
        :supports_native_tool_use,
        :provider_settings,
        :input_token_cost,
-       :output_token_cost
+       :output_token_cost,
+       :supported_provider_managed_tools

      validates :key, presence: true
      validates :api_name, presence: true
@@ -21,8 +22,17 @@

      alias_method :supports_native_tool_use?, :supports_native_tool_use

-     def initialize(key:, api_name:, model_provider_settings: {}, supports_native_tool_use: true, temperature: nil, max_completion_tokens: nil,
-       input_token_cost: nil, output_token_cost: nil)
+     def initialize(
+       key:,
+       api_name:,
+       model_provider_settings: {},
+       supported_provider_managed_tools: [],
+       supports_native_tool_use: true,
+       temperature: nil,
+       max_completion_tokens: nil,
+       input_token_cost: nil,
+       output_token_cost: nil
+     )
        @key = key
        @api_name = api_name
        @provider_settings = model_provider_settings
@@ -31,6 +41,7 @@ module Raif
        @default_max_completion_tokens = max_completion_tokens
        @input_token_cost = input_token_cost
        @output_token_cost = output_token_cost
+       @supported_provider_managed_tools = supported_provider_managed_tools.map(&:to_s)
      end

      def name
@@ -38,7 +49,7 @@
      end

      def chat(message: nil, messages: nil, response_format: :text, available_model_tools: [], source: nil, system_prompt: nil, temperature: nil,
-       max_completion_tokens: nil)
+       max_completion_tokens: nil, &block)
        unless response_format.is_a?(Symbol)
          raise ArgumentError,
            "Raif::Llm#chat - Invalid response format: #{response_format}. Must be a symbol (you passed #{response_format.class}) and be one of: #{VALID_RESPONSE_FORMATS.join(", ")}" # rubocop:disable Layout/LineLength
@@ -75,17 +86,25 @@
          model_api_name: api_name,
          temperature: temperature,
          max_completion_tokens: max_completion_tokens,
-         available_model_tools: available_model_tools
+         available_model_tools: available_model_tools,
+         stream_response: block_given?
        )

        retry_with_backoff(model_completion) do
-         perform_model_completion!(model_completion)
+         perform_model_completion!(model_completion, &block)
        end

        model_completion
+     rescue Raif::Errors::StreamingError => e
+       Rails.logger.error("Raif streaming error -- code: #{e.code} -- type: #{e.type} -- message: #{e.message} -- event: #{e.event}")
+       raise e
+     rescue Faraday::Error => e
+       Raif.logger.error("LLM API request failed (status: #{e.response_status}): #{e.message}")
+       Raif.logger.error(e.response_body)
+       raise e
      end

-     def perform_model_completion!(model_completion)
+     def perform_model_completion!(model_completion, &block)
        raise NotImplementedError, "#{self.class.name} must implement #perform_model_completion!"
      end

@@ -93,6 +112,17 @@
        VALID_RESPONSE_FORMATS
      end

+     def supports_provider_managed_tool?(tool_klass)
+       supported_provider_managed_tools&.include?(tool_klass.to_s)
+     end
+
+     def validate_provider_managed_tool_support!(tool)
+       unless supports_provider_managed_tool?(tool)
+         raise Raif::Errors::UnsupportedFeatureError,
+           "Invalid provider-managed tool: #{tool.name} for #{key}"
+       end
+     end
+
      private

      def retry_with_backoff(model_completion)
@@ -117,5 +147,41 @@
          end
        end
      end
+
+     def streaming_response_type
+       raise NotImplementedError, "#{self.class.name} must implement #streaming_response_type"
+     end
+
+     def streaming_chunk_handler(model_completion, &block)
+       return unless model_completion.stream_response?
+
+       streaming_response = streaming_response_type.new
+       event_parser = EventStreamParser::Parser.new
+       accumulated_delta = ""
+
+       proc do |chunk, _size, _env|
+         event_parser.feed(chunk) do |event_type, data, _id, _reconnect_time|
+           if data.blank? || data == "[DONE]"
+             update_model_completion(model_completion, streaming_response.current_response_json)
+             next
+           end
+
+           event_data = JSON.parse(data)
+           delta, finish_reason = streaming_response.process_streaming_event(event_type, event_data)
+
+           accumulated_delta += delta if delta.present?
+
+           if accumulated_delta.length >= Raif.config.streaming_update_chunk_size_threshold || finish_reason.present?
+             update_model_completion(model_completion, streaming_response.current_response_json)
+
+             if accumulated_delta.present?
+               block.call(model_completion, accumulated_delta, event_data)
+               accumulated_delta = ""
+             end
+           end
+         end
+       end
+     end
+

    end
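A hedged sketch of the new streaming path: passing a block to chat marks the model completion as a streamed response, and the block fires whenever the accumulated delta crosses Raif.config.streaming_update_chunk_size_threshold (the model key below is illustrative):

    llm = Raif.llm(:anthropic_claude_3_5_haiku) # key name is an assumption
    llm.chat(message: "Summarize this changelog") do |model_completion, delta, _sse_event|
      print delta # text accumulated since the previous block call
      # model_completion.raw_response holds everything received so far
    end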
data/app/models/raif/llms/anthropic.rb
@@ -2,29 +2,24 @@

  class Raif::Llms::Anthropic < Raif::Llm
    include Raif::Concerns::Llms::Anthropic::MessageFormatting
+   include Raif::Concerns::Llms::Anthropic::ToolFormatting

-   def perform_model_completion!(model_completion)
+   def perform_model_completion!(model_completion, &block)
      params = build_request_parameters(model_completion)
      response = connection.post("messages") do |req|
        req.body = params
+       req.options.on_data = streaming_chunk_handler(model_completion, &block) if model_completion.stream_response?
      end

-     response_json = response.body
-
-     model_completion.raw_response = if model_completion.response_format_json?
-       extract_json_response(response_json)
-     else
-       extract_text_response(response_json)
+     unless model_completion.stream_response?
+       update_model_completion(model_completion, response.body)
      end

-     model_completion.response_tool_calls = extract_response_tool_calls(response_json)
-     model_completion.completion_tokens = response_json&.dig("usage", "output_tokens")
-     model_completion.prompt_tokens = response_json&.dig("usage", "input_tokens")
-     model_completion.save!
-
      model_completion
    end

+   private
+
    def connection
      @connection ||= Faraday.new(url: "https://api.anthropic.com/v1") do |f|
        f.headers["x-api-key"] = Raif.config.anthropic_api_key
@@ -35,7 +30,26 @@
      end
    end

- protected
+   def streaming_response_type
+     Raif::StreamingResponses::Anthropic
+   end
+
+   def update_model_completion(model_completion, response_json)
+     model_completion.raw_response = if model_completion.response_format_json?
+       extract_json_response(response_json)
+     else
+       extract_text_response(response_json)
+     end
+
+     model_completion.response_id = response_json&.dig("id")
+     model_completion.response_array = response_json&.dig("content")
+     model_completion.response_tool_calls = extract_response_tool_calls(response_json)
+     model_completion.citations = extract_citations(response_json)
+     model_completion.completion_tokens = response_json&.dig("usage", "output_tokens")
+     model_completion.prompt_tokens = response_json&.dig("usage", "input_tokens")
+     model_completion.total_tokens = model_completion.completion_tokens.to_i + model_completion.prompt_tokens.to_i
+     model_completion.save!
+   end

    def build_request_parameters(model_completion)
      params = {
@@ -47,36 +61,20 @@

      params[:system] = model_completion.system_prompt if model_completion.system_prompt.present?

-     # Add tools to the request if needed
-     tools = []
-
-     # If we're looking for a JSON response, add a tool to the request that the model can use to provide a JSON response
-     if model_completion.response_format_json? && model_completion.json_response_schema.present?
-       tools << {
-         name: "json_response",
-         description: "Generate a structured JSON response based on the provided schema.",
-         input_schema: model_completion.json_response_schema
-       }
-     end
-
-     # If we support native tool use and have tools available, add them to the request
-     if supports_native_tool_use? && model_completion.available_model_tools.any?
-       model_completion.available_model_tools_map.each do |_tool_name, tool|
-         tools << {
-           name: tool.tool_name,
-           description: tool.tool_description,
-           input_schema: tool.tool_arguments_schema
-         }
-       end
+     if supports_native_tool_use?
+       tools = build_tools_parameter(model_completion)
+       params[:tools] = tools unless tools.blank?
      end

-     params[:tools] = tools if tools.any?
+     params[:stream] = true if model_completion.stream_response?

      params
    end

    def extract_text_response(resp)
-     resp&.dig("content")&.first&.dig("text")
+     return if resp&.dig("content").blank?
+
+     resp.dig("content").select{|v| v["type"] == "text" }.map{|v| v["text"] }.join("\n")
    end

    def extract_json_response(resp)
@@ -112,4 +110,26 @@
      end
    end

+   def extract_citations(resp)
+     return [] if resp&.dig("content").nil?
+
+     citations = []
+
+     # Look through content blocks for citations
+     resp.dig("content").each do |content|
+       next unless content["type"] == "text" && content["citations"].present?
+
+       content["citations"].each do |citation|
+         next unless citation["type"] == "web_search_result_location"
+
+         citations << {
+           "url" => Raif::Utils::HtmlFragmentProcessor.strip_tracking_parameters(citation["url"]),
+           "title" => citation["title"]
+         }
+       end
+     end
+
+     citations.uniq{|citation| citation["url"] }
+   end
+
  end
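The extracted citations are stored on the model completion (and exposed on conversation entries via the new delegate) as an array of url/title hashes; the values below are illustrative:

    model_completion.citations
    # => [{ "url" => "https://example.com/article", "title" => "Example article" }]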