raif 1.0.0 → 1.2.0

This diff shows the content of publicly available package versions as released to the supported public registries; it is provided for informational purposes only.
Files changed (116)
  1. checksums.yaml +4 -4
  2. data/README.md +346 -43
  3. data/app/assets/builds/raif.css +26 -1
  4. data/app/assets/stylesheets/raif/admin/stats.scss +12 -0
  5. data/app/assets/stylesheets/raif/loader.scss +27 -1
  6. data/app/controllers/raif/admin/application_controller.rb +14 -0
  7. data/app/controllers/raif/admin/stats/tasks_controller.rb +25 -0
  8. data/app/controllers/raif/admin/stats_controller.rb +19 -0
  9. data/app/controllers/raif/admin/tasks_controller.rb +18 -2
  10. data/app/controllers/raif/conversations_controller.rb +5 -1
  11. data/app/models/raif/agent.rb +11 -9
  12. data/app/models/raif/agents/native_tool_calling_agent.rb +11 -1
  13. data/app/models/raif/agents/re_act_agent.rb +6 -0
  14. data/app/models/raif/concerns/has_available_model_tools.rb +1 -1
  15. data/app/models/raif/concerns/json_schema_definition.rb +28 -0
  16. data/app/models/raif/concerns/llm_response_parsing.rb +42 -14
  17. data/app/models/raif/concerns/llm_temperature.rb +17 -0
  18. data/app/models/raif/concerns/llms/anthropic/message_formatting.rb +51 -0
  19. data/app/models/raif/concerns/llms/anthropic/tool_formatting.rb +56 -0
  20. data/app/models/raif/concerns/llms/bedrock/message_formatting.rb +70 -0
  21. data/app/models/raif/concerns/llms/bedrock/tool_formatting.rb +37 -0
  22. data/app/models/raif/concerns/llms/message_formatting.rb +42 -0
  23. data/app/models/raif/concerns/llms/open_ai/json_schema_validation.rb +138 -0
  24. data/app/models/raif/concerns/llms/open_ai_completions/message_formatting.rb +41 -0
  25. data/app/models/raif/concerns/llms/open_ai_completions/tool_formatting.rb +26 -0
  26. data/app/models/raif/concerns/llms/open_ai_responses/message_formatting.rb +43 -0
  27. data/app/models/raif/concerns/llms/open_ai_responses/tool_formatting.rb +42 -0
  28. data/app/models/raif/conversation.rb +28 -7
  29. data/app/models/raif/conversation_entry.rb +40 -8
  30. data/app/models/raif/embedding_model.rb +22 -0
  31. data/app/models/raif/embedding_models/bedrock.rb +34 -0
  32. data/app/models/raif/embedding_models/open_ai.rb +40 -0
  33. data/app/models/raif/llm.rb +108 -9
  34. data/app/models/raif/llms/anthropic.rb +72 -57
  35. data/app/models/raif/llms/bedrock.rb +165 -0
  36. data/app/models/raif/llms/open_ai_base.rb +66 -0
  37. data/app/models/raif/llms/open_ai_completions.rb +100 -0
  38. data/app/models/raif/llms/open_ai_responses.rb +144 -0
  39. data/app/models/raif/llms/open_router.rb +88 -0
  40. data/app/models/raif/model_completion.rb +23 -2
  41. data/app/models/raif/model_file_input.rb +113 -0
  42. data/app/models/raif/model_image_input.rb +4 -0
  43. data/app/models/raif/model_tool.rb +82 -52
  44. data/app/models/raif/model_tool_invocation.rb +8 -6
  45. data/app/models/raif/model_tools/agent_final_answer.rb +18 -27
  46. data/app/models/raif/model_tools/fetch_url.rb +27 -36
  47. data/app/models/raif/model_tools/provider_managed/base.rb +9 -0
  48. data/app/models/raif/model_tools/provider_managed/code_execution.rb +5 -0
  49. data/app/models/raif/model_tools/provider_managed/image_generation.rb +5 -0
  50. data/app/models/raif/model_tools/provider_managed/web_search.rb +5 -0
  51. data/app/models/raif/model_tools/wikipedia_search.rb +46 -55
  52. data/app/models/raif/streaming_responses/anthropic.rb +63 -0
  53. data/app/models/raif/streaming_responses/bedrock.rb +89 -0
  54. data/app/models/raif/streaming_responses/open_ai_completions.rb +76 -0
  55. data/app/models/raif/streaming_responses/open_ai_responses.rb +54 -0
  56. data/app/models/raif/task.rb +71 -16
  57. data/app/views/layouts/raif/admin.html.erb +10 -0
  58. data/app/views/raif/admin/agents/show.html.erb +3 -1
  59. data/app/views/raif/admin/conversations/_conversation.html.erb +1 -1
  60. data/app/views/raif/admin/conversations/_conversation_entry.html.erb +48 -0
  61. data/app/views/raif/admin/conversations/show.html.erb +4 -2
  62. data/app/views/raif/admin/model_completions/_model_completion.html.erb +8 -0
  63. data/app/views/raif/admin/model_completions/index.html.erb +2 -0
  64. data/app/views/raif/admin/model_completions/show.html.erb +58 -3
  65. data/app/views/raif/admin/stats/index.html.erb +128 -0
  66. data/app/views/raif/admin/stats/tasks/index.html.erb +45 -0
  67. data/app/views/raif/admin/tasks/_task.html.erb +5 -4
  68. data/app/views/raif/admin/tasks/index.html.erb +20 -2
  69. data/app/views/raif/admin/tasks/show.html.erb +3 -1
  70. data/app/views/raif/conversation_entries/_citations.html.erb +9 -0
  71. data/app/views/raif/conversation_entries/_conversation_entry.html.erb +22 -14
  72. data/app/views/raif/conversation_entries/_form.html.erb +1 -1
  73. data/app/views/raif/conversation_entries/_form_with_available_tools.html.erb +4 -4
  74. data/app/views/raif/conversation_entries/_message.html.erb +14 -3
  75. data/config/locales/admin.en.yml +16 -0
  76. data/config/locales/en.yml +47 -3
  77. data/config/routes.rb +6 -0
  78. data/db/migrate/20250224234252_create_raif_tables.rb +1 -1
  79. data/db/migrate/20250421202149_add_response_format_to_raif_conversations.rb +7 -0
  80. data/db/migrate/20250424200755_add_cost_columns_to_raif_model_completions.rb +14 -0
  81. data/db/migrate/20250424232946_add_created_at_indexes.rb +11 -0
  82. data/db/migrate/20250502155330_add_status_indexes_to_raif_tasks.rb +14 -0
  83. data/db/migrate/20250507155314_add_retry_count_to_raif_model_completions.rb +7 -0
  84. data/db/migrate/20250527213016_add_response_id_and_response_array_to_model_completions.rb +14 -0
  85. data/db/migrate/20250603140622_add_citations_to_raif_model_completions.rb +13 -0
  86. data/db/migrate/20250603202013_add_stream_response_to_raif_model_completions.rb +7 -0
  87. data/lib/generators/raif/agent/agent_generator.rb +22 -12
  88. data/lib/generators/raif/agent/templates/agent.rb.tt +3 -3
  89. data/lib/generators/raif/agent/templates/application_agent.rb.tt +7 -0
  90. data/lib/generators/raif/conversation/conversation_generator.rb +10 -0
  91. data/lib/generators/raif/conversation/templates/application_conversation.rb.tt +7 -0
  92. data/lib/generators/raif/conversation/templates/conversation.rb.tt +16 -14
  93. data/lib/generators/raif/install/templates/initializer.rb +62 -6
  94. data/lib/generators/raif/model_tool/model_tool_generator.rb +0 -5
  95. data/lib/generators/raif/model_tool/templates/model_tool.rb.tt +69 -56
  96. data/lib/generators/raif/task/templates/task.rb.tt +34 -23
  97. data/lib/raif/configuration.rb +63 -4
  98. data/lib/raif/embedding_model_registry.rb +83 -0
  99. data/lib/raif/engine.rb +56 -7
  100. data/lib/raif/errors/{open_ai/api_error.rb → invalid_model_file_input_error.rb} +1 -3
  101. data/lib/raif/errors/{anthropic/api_error.rb → invalid_model_image_input_error.rb} +1 -3
  102. data/lib/raif/errors/streaming_error.rb +18 -0
  103. data/lib/raif/errors/unsupported_feature_error.rb +8 -0
  104. data/lib/raif/errors.rb +4 -2
  105. data/lib/raif/json_schema_builder.rb +104 -0
  106. data/lib/raif/llm_registry.rb +315 -0
  107. data/lib/raif/migration_checker.rb +74 -0
  108. data/lib/raif/utils/html_fragment_processor.rb +169 -0
  109. data/lib/raif/utils.rb +1 -0
  110. data/lib/raif/version.rb +1 -1
  111. data/lib/raif.rb +7 -32
  112. data/lib/tasks/raif_tasks.rake +9 -4
  113. metadata +62 -12
  114. data/app/models/raif/llms/bedrock_claude.rb +0 -134
  115. data/app/models/raif/llms/open_ai.rb +0 -259
  116. data/lib/raif/default_llms.rb +0 -37
data/app/models/raif/llm.rb

@@ -3,13 +3,17 @@
 module Raif
   class Llm
     include ActiveModel::Model
+    include Raif::Concerns::Llms::MessageFormatting

     attr_accessor :key,
       :api_name,
       :default_temperature,
       :default_max_completion_tokens,
       :supports_native_tool_use,
-      :provider_settings
+      :provider_settings,
+      :input_token_cost,
+      :output_token_cost,
+      :supported_provider_managed_tools

     validates :key, presence: true
     validates :api_name, presence: true
@@ -18,13 +22,26 @@ module Raif

     alias_method :supports_native_tool_use?, :supports_native_tool_use

-    def initialize(key:, api_name:, model_provider_settings: {}, supports_native_tool_use: true, temperature: nil, max_completion_tokens: nil)
+    def initialize(
+      key:,
+      api_name:,
+      model_provider_settings: {},
+      supported_provider_managed_tools: [],
+      supports_native_tool_use: true,
+      temperature: nil,
+      max_completion_tokens: nil,
+      input_token_cost: nil,
+      output_token_cost: nil
+    )
       @key = key
       @api_name = api_name
       @provider_settings = model_provider_settings
       @supports_native_tool_use = supports_native_tool_use
       @default_temperature = temperature || 0.7
       @default_max_completion_tokens = max_completion_tokens
+      @input_token_cost = input_token_cost
+      @output_token_cost = output_token_cost
+      @supported_provider_managed_tools = supported_provider_managed_tools.map(&:to_s)
     end

     def name
@@ -32,7 +49,7 @@ module Raif
     end

     def chat(message: nil, messages: nil, response_format: :text, available_model_tools: [], source: nil, system_prompt: nil, temperature: nil,
-      max_completion_tokens: nil)
+      max_completion_tokens: nil, &block)
       unless response_format.is_a?(Symbol)
         raise ArgumentError,
           "Raif::Llm#chat - Invalid response format: #{response_format}. Must be a symbol (you passed #{response_format.class}) and be one of: #{VALID_RESPONSE_FORMATS.join(", ")}" # rubocop:disable Layout/LineLength
@@ -55,13 +72,13 @@ module Raif
         return
       end

-      messages = [{ role: "user", content: message }] if message.present?
+      messages = [{ "role" => "user", "content" => message }] if message.present?

       temperature ||= default_temperature
       max_completion_tokens ||= default_max_completion_tokens

       model_completion = Raif::ModelCompletion.new(
-        messages: messages,
+        messages: format_messages(messages),
         system_prompt: system_prompt,
         response_format: response_format,
         source: source,
@@ -69,20 +86,102 @@ module Raif
         model_api_name: api_name,
         temperature: temperature,
         max_completion_tokens: max_completion_tokens,
-        available_model_tools: available_model_tools
+        available_model_tools: available_model_tools,
+        stream_response: block_given?
       )

-      perform_model_completion!(model_completion)
+      retry_with_backoff(model_completion) do
+        perform_model_completion!(model_completion, &block)
+      end
+
       model_completion
+    rescue Raif::Errors::StreamingError => e
+      Rails.logger.error("Raif streaming error -- code: #{e.code} -- type: #{e.type} -- message: #{e.message} -- event: #{e.event}")
+      raise e
+    rescue Faraday::Error => e
+      Raif.logger.error("LLM API request failed (status: #{e.response_status}): #{e.message}")
+      Raif.logger.error(e.response_body)
+      raise e
     end

-    def perform_model_completion!(model_completion)
-      raise NotImplementedError, "Raif::Llm subclasses must implement #perform_model_completion!"
+    def perform_model_completion!(model_completion, &block)
+      raise NotImplementedError, "#{self.class.name} must implement #perform_model_completion!"
     end

     def self.valid_response_formats
       VALID_RESPONSE_FORMATS
     end

+    def supports_provider_managed_tool?(tool_klass)
+      supported_provider_managed_tools&.include?(tool_klass.to_s)
+    end
+
+    def validate_provider_managed_tool_support!(tool)
+      unless supports_provider_managed_tool?(tool)
+        raise Raif::Errors::UnsupportedFeatureError,
+          "Invalid provider-managed tool: #{tool.name} for #{key}"
+      end
+    end
+
+    private
+
+    def retry_with_backoff(model_completion)
+      retries = 0
+      max_retries = Raif.config.llm_request_max_retries
+      base_delay = 3
+      max_delay = 30
+
+      begin
+        yield
+      rescue *Raif.config.llm_request_retriable_exceptions => e
+        retries += 1
+        if retries <= max_retries
+          delay = [base_delay * (2**(retries - 1)), max_delay].min
+          Raif.logger.warn("Retrying LLM API request after error: #{e.message}. Attempt #{retries}/#{max_retries}. Waiting #{delay} seconds...")
+          model_completion.increment!(:retry_count)
+          sleep delay
+          retry
+        else
+          Raif.logger.error("LLM API request failed after #{max_retries} retries. Last error: #{e.message}")
+          raise
+        end
+      end
+    end
+
+    def streaming_response_type
+      raise NotImplementedError, "#{self.class.name} must implement #streaming_response_type"
+    end
+
+    def streaming_chunk_handler(model_completion, &block)
+      return unless model_completion.stream_response?
+
+      streaming_response = streaming_response_type.new
+      event_parser = EventStreamParser::Parser.new
+      accumulated_delta = ""
+
+      proc do |chunk, _size, _env|
+        event_parser.feed(chunk) do |event_type, data, _id, _reconnect_time|
+          if data.blank? || data == "[DONE]"
+            update_model_completion(model_completion, streaming_response.current_response_json)
+            next
+          end
+
+          event_data = JSON.parse(data)
+          delta, finish_reason = streaming_response.process_streaming_event(event_type, event_data)
+
+          accumulated_delta += delta if delta.present?
+
+          if accumulated_delta.length >= Raif.config.streaming_update_chunk_size_threshold || finish_reason.present?
+            update_model_completion(model_completion, streaming_response.current_response_json)
+
+            if accumulated_delta.present?
+              block.call(model_completion, accumulated_delta, event_data)
+              accumulated_delta = ""
+            end
+          end
+        end
+      end
+    end
+
   end
 end
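
With these changes, Raif::Llm#chat accepts a block: block_given? flags the Raif::ModelCompletion as streamed, and streaming_chunk_handler hands the block accumulated text deltas as SSE events arrive. A minimal caller sketch, assuming an LLM handle obtained from the new registry (the Raif.llm lookup and the :open_ai_gpt_4o key are illustrative assumptions, not shown in this hunk):

# Hypothetical streaming call; the block arity mirrors
# block.call(model_completion, accumulated_delta, event_data) above.
llm = Raif.llm(:open_ai_gpt_4o) # assumed lookup backed by lib/raif/llm_registry.rb

llm.chat(message: "Summarize the release notes", response_format: :text) do |model_completion, delta, _event|
  print delta # invoked roughly every streaming_update_chunk_size_threshold characters
end

Note the backoff arithmetic in retry_with_backoff: with base_delay = 3 and max_delay = 30, successive retriable failures wait 3, 6, 12, and 24 seconds, with every later attempt capped at 30.
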
data/app/models/raif/llms/anthropic.rb

@@ -1,47 +1,57 @@
 # frozen_string_literal: true

 class Raif::Llms::Anthropic < Raif::Llm
+  include Raif::Concerns::Llms::Anthropic::MessageFormatting
+  include Raif::Concerns::Llms::Anthropic::ToolFormatting

-  def perform_model_completion!(model_completion)
-    params = build_api_parameters(model_completion)
-
+  def perform_model_completion!(model_completion, &block)
+    params = build_request_parameters(model_completion)
     response = connection.post("messages") do |req|
-      req.body = params.to_json
+      req.body = params
+      req.options.on_data = streaming_chunk_handler(model_completion, &block) if model_completion.stream_response?
     end

-    resp = JSON.parse(response.body, symbolize_names: true)
-
-    # Handle API errors
-    unless response.success?
-      error_message = resp[:error]&.dig(:message) || "Anthropic API error: #{response.status}"
-      raise Raif::Errors::Anthropic::ApiError, error_message
+    unless model_completion.stream_response?
+      update_model_completion(model_completion, response.body)
     end

-    model_completion.raw_response = if model_completion.response_format_json?
-      extract_json_response(resp)
-    else
-      extract_text_response(resp)
-    end
-
-    model_completion.response_tool_calls = extract_response_tool_calls(resp)
-    model_completion.completion_tokens = resp&.dig(:usage, :output_tokens)
-    model_completion.prompt_tokens = resp&.dig(:usage, :input_tokens)
-    model_completion.save!
-
     model_completion
   end

+  private
+
   def connection
     @connection ||= Faraday.new(url: "https://api.anthropic.com/v1") do |f|
-      f.headers["Content-Type"] = "application/json"
       f.headers["x-api-key"] = Raif.config.anthropic_api_key
       f.headers["anthropic-version"] = "2023-06-01"
+      f.request :json
+      f.response :json
+      f.response :raise_error
     end
   end

-  protected
+  def streaming_response_type
+    Raif::StreamingResponses::Anthropic
+  end
+
+  def update_model_completion(model_completion, response_json)
+    model_completion.raw_response = if model_completion.response_format_json?
+      extract_json_response(response_json)
+    else
+      extract_text_response(response_json)
+    end
+
+    model_completion.response_id = response_json&.dig("id")
+    model_completion.response_array = response_json&.dig("content")
+    model_completion.response_tool_calls = extract_response_tool_calls(response_json)
+    model_completion.citations = extract_citations(response_json)
+    model_completion.completion_tokens = response_json&.dig("usage", "output_tokens")
+    model_completion.prompt_tokens = response_json&.dig("usage", "input_tokens")
+    model_completion.total_tokens = model_completion.completion_tokens.to_i + model_completion.prompt_tokens.to_i
+    model_completion.save!
+  end

-  def build_api_parameters(model_completion)
+  def build_request_parameters(model_completion)
     params = {
       model: model_completion.model_api_name,
       messages: model_completion.messages,
@@ -51,70 +61,75 @@ protected

     params[:system] = model_completion.system_prompt if model_completion.system_prompt.present?

-    # Add tools to the request if needed
-    tools = []
-
-    # If we're looking for a JSON response, add a tool to the request that the model can use to provide a JSON response
-    if model_completion.response_format_json? && model_completion.json_response_schema.present?
-      tools << {
-        name: "json_response",
-        description: "Generate a structured JSON response based on the provided schema.",
-        input_schema: model_completion.json_response_schema
-      }
+    if supports_native_tool_use?
+      tools = build_tools_parameter(model_completion)
+      params[:tools] = tools unless tools.blank?
     end

-    # If we support native tool use and have tools available, add them to the request
-    if supports_native_tool_use? && model_completion.available_model_tools.any?
-      model_completion.available_model_tools_map.each do |_tool_name, tool|
-        tools << {
-          name: tool.tool_name,
-          description: tool.tool_description,
-          input_schema: tool.tool_arguments_schema
-        }
-      end
-    end
-
-    params[:tools] = tools if tools.any?
+    params[:stream] = true if model_completion.stream_response?

     params
   end

   def extract_text_response(resp)
-    resp&.dig(:content)&.first&.dig(:text)
+    return if resp&.dig("content").blank?
+
+    resp.dig("content").select { |v| v["type"] == "text" }.map { |v| v["text"] }.join("\n")
   end

   def extract_json_response(resp)
-    return extract_text_response(resp) if resp&.dig(:content).nil?
+    return extract_text_response(resp) if resp&.dig("content").nil?

     # Look for tool_use blocks in the content array
-    tool_name = "json_response"
-    tool_response = resp&.dig(:content)&.find do |content|
-      content[:type] == "tool_use" && content[:name] == tool_name
+    tool_response = resp&.dig("content")&.find do |content|
+      content["type"] == "tool_use" && content["name"] == "json_response"
     end

     if tool_response
-      JSON.generate(tool_response[:input])
+      JSON.generate(tool_response["input"])
     else
       extract_text_response(resp)
     end
   end

   def extract_response_tool_calls(resp)
-    return if resp&.dig(:content).nil?
+    return if resp&.dig("content").nil?

     # Find any tool_use content blocks
-    tool_uses = resp&.dig(:content)&.select do |content|
-      content[:type] == "tool_use"
+    tool_uses = resp&.dig("content")&.select do |content|
+      content["type"] == "tool_use"
     end

     return if tool_uses.blank?

     tool_uses.map do |tool_use|
       {
-        "name" => tool_use[:name],
-        "arguments" => tool_use[:input]
+        "name" => tool_use["name"],
+        "arguments" => tool_use["input"]
       }
     end
   end

+  def extract_citations(resp)
+    return [] if resp&.dig("content").nil?
+
+    citations = []
+
+    # Look through content blocks for citations
+    resp.dig("content").each do |content|
+      next unless content["type"] == "text" && content["citations"].present?
+
+      content["citations"].each do |citation|
+        next unless citation["type"] == "web_search_result_location"
+
+        citations << {
+          "url" => Raif::Utils::HtmlFragmentProcessor.strip_tracking_parameters(citation["url"]),
+          "title" => citation["title"]
+        }
+      end
+    end
+
+    citations.uniq { |citation| citation["url"] }
+  end
+
 end
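
For context, extract_citations keeps only web_search_result_location citations attached to text blocks and dedupes them by URL. A sketch of the data flow, with a hand-written response fragment shaped like Anthropic's web-search citations (the fragment is an assumption for illustration, not taken from this diff):

response_json = {
  "content" => [
    {
      "type" => "text",
      "text" => "Ruby 3.4 shipped in December 2024.",
      "citations" => [
        {
          "type" => "web_search_result_location",
          "url" => "https://www.ruby-lang.org/en/news/?utm_source=example",
          "title" => "Ruby News"
        }
      ]
    }
  ]
}

extract_citations(response_json)
# => [{ "url" => "https://www.ruby-lang.org/en/news/", "title" => "Ruby News" }]
# assuming HtmlFragmentProcessor.strip_tracking_parameters removes the utm_source parameter
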
data/app/models/raif/llms/bedrock.rb (new file)

@@ -0,0 +1,165 @@
+# frozen_string_literal: true
+
+class Raif::Llms::Bedrock < Raif::Llm
+  include Raif::Concerns::Llms::Bedrock::MessageFormatting
+  include Raif::Concerns::Llms::Bedrock::ToolFormatting
+
+  def perform_model_completion!(model_completion, &block)
+    if Raif.config.aws_bedrock_model_name_prefix.present?
+      model_completion.model_api_name = "#{Raif.config.aws_bedrock_model_name_prefix}.#{model_completion.model_api_name}"
+    end
+
+    params = build_request_parameters(model_completion)
+
+    if model_completion.stream_response?
+      bedrock_client.converse_stream(params) do |stream|
+        stream.on_error_event do |event|
+          raise Raif::Errors::StreamingError.new(
+            message: event.error_message,
+            type: event.event_type,
+            code: event.error_code,
+            event: event
+          )
+        end
+
+        handler = streaming_chunk_handler(model_completion, &block)
+        stream.on_event do |event|
+          handler.call(event)
+        end
+      end
+    else
+      response = bedrock_client.converse(params)
+      update_model_completion(model_completion, response)
+    end
+
+    model_completion
+  end
+
+  private
+
+  def bedrock_client
+    @bedrock_client ||= Aws::BedrockRuntime::Client.new(region: Raif.config.aws_bedrock_region)
+  end
+
+  def update_model_completion(model_completion, resp)
+    model_completion.raw_response = if model_completion.response_format_json?
+      extract_json_response(resp)
+    else
+      extract_text_response(resp)
+    end
+
+    model_completion.response_array = resp.output.message.content
+    model_completion.response_tool_calls = extract_response_tool_calls(resp)
+    model_completion.completion_tokens = resp.usage.output_tokens
+    model_completion.prompt_tokens = resp.usage.input_tokens
+    model_completion.total_tokens = resp.usage.total_tokens
+    model_completion.save!
+  end
+
+  def build_request_parameters(model_completion)
+    # The AWS Bedrock SDK requires symbols for keys
+    messages_param = model_completion.messages.map(&:deep_symbolize_keys)
+    replace_tmp_base64_data_with_bytes(messages_param)
+
+    params = {
+      model_id: model_completion.model_api_name,
+      inference_config: { max_tokens: model_completion.max_completion_tokens || 8192 },
+      messages: messages_param
+    }
+
+    params[:system] = [{ text: model_completion.system_prompt }] if model_completion.system_prompt.present?
+
+    if supports_native_tool_use?
+      tools = build_tools_parameter(model_completion)
+      params[:tool_config] = tools unless tools.blank?
+    end
+
+    params
+  end
+
+  def replace_tmp_base64_data_with_bytes(messages)
+    # The AWS Bedrock SDK requires data sent as bytes (and doesn't support base64 like everyone else).
+    # The ModelCompletion stores the messages as JSON though, so it can't be raw bytes.
+    # We store the image data as base64, so we need to convert that to bytes before sending to AWS.
+    messages.each do |message|
+      message[:content].each do |content|
+        next unless content[:image] || content[:document]
+
+        type_key = content[:image] ? :image : :document
+        base64_data = content[type_key][:source].delete(:tmp_base64_data)
+        content[type_key][:source][:bytes] = Base64.strict_decode64(base64_data)
+      end
+    end
+  end
+
+  def extract_text_response(resp)
+    message = resp.output.message
+
+    # Find the first text content block
+    text_block = message.content&.find do |content|
+      content.respond_to?(:text) && content.text.present?
+    end
+
+    text_block&.text
+  end
+
+  def extract_json_response(resp)
+    # Get the message from the response object
+    message = resp.output.message
+
+    return extract_text_response(resp) if message.content.nil?
+
+    # Look for tool_use blocks in the content array
+    tool_response = message.content.find do |content|
+      content.respond_to?(:tool_use) && content.tool_use.present? && content.tool_use.name == "json_response"
+    end
+
+    if tool_response&.tool_use
+      JSON.generate(tool_response.tool_use.input)
+    else
+      extract_text_response(resp)
+    end
+  end
+
+  def extract_response_tool_calls(resp)
+    # Get the message from the response object
+    message = resp.output.message
+    return if message.content.nil?
+
+    # Find any tool_use blocks in the content array
+    tool_uses = message.content.select do |content|
+      content.respond_to?(:tool_use) && content.tool_use.present?
+    end
+
+    return if tool_uses.blank?
+
+    tool_uses.map do |content|
+      {
+        "name" => content.tool_use.name,
+        "arguments" => content.tool_use.input
+      }
+    end
+  end
+
+  def streaming_chunk_handler(model_completion, &block)
+    return unless model_completion.stream_response?
+
+    streaming_response = Raif::StreamingResponses::Bedrock.new
+    accumulated_delta = ""
+
+    proc do |event|
+      delta, finish_reason = streaming_response.process_streaming_event(event.class, event)
+      accumulated_delta += delta if delta.present?
+
+      if accumulated_delta.length >= Raif.config.streaming_update_chunk_size_threshold || finish_reason.present?
+        update_model_completion(model_completion, streaming_response.current_response)
+
+        if accumulated_delta.present?
+          block.call(model_completion, accumulated_delta, event)
+          accumulated_delta = ""
+        end
+      end
+    end
+  end
+
+end
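
The tmp_base64_data handling is the subtle piece: messages persist on the Raif::ModelCompletion as JSON (hence base64), and are converted to the raw bytes the Bedrock SDK expects only at request time. A sketch of the in-place rewrite (the message hash is illustrative; png_bytes stands in for raw image data):

messages = [
  {
    role: "user",
    content: [
      { text: "Describe this image" },
      { image: { format: "png", source: { tmp_base64_data: Base64.strict_encode64(png_bytes) } } }
    ]
  }
]

replace_tmp_base64_data_with_bytes(messages)
# :tmp_base64_data is deleted and replaced with the decoded bytes:
# messages[0][:content][1][:image][:source] # => { bytes: png_bytes }
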
data/app/models/raif/llms/open_ai_base.rb (new file)

@@ -0,0 +1,66 @@
+# frozen_string_literal: true
+
+class Raif::Llms::OpenAiBase < Raif::Llm
+  include Raif::Concerns::Llms::OpenAi::JsonSchemaValidation
+
+  def perform_model_completion!(model_completion, &block)
+    if supports_temperature?
+      model_completion.temperature ||= default_temperature
+    else
+      Raif.logger.warn "Temperature is not supported for #{api_name}. Ignoring temperature parameter."
+      model_completion.temperature = nil
+    end
+
+    parameters = build_request_parameters(model_completion)
+
+    response = connection.post(api_path) do |req|
+      req.body = parameters
+      req.options.on_data = streaming_chunk_handler(model_completion, &block) if model_completion.stream_response?
+    end
+
+    unless model_completion.stream_response?
+      update_model_completion(model_completion, response.body)
+    end
+
+    model_completion
+  end
+
+  private
+
+  def connection
+    @connection ||= Faraday.new(url: "https://api.openai.com/v1") do |f|
+      f.headers["Authorization"] = "Bearer #{Raif.config.open_ai_api_key}"
+      f.request :json
+      f.response :json
+      f.response :raise_error
+    end
+  end
+
+  def format_system_prompt(model_completion)
+    formatted_system_prompt = model_completion.system_prompt.to_s.strip
+
+    # If the response format is JSON, we need to include "as json" in the system prompt.
+    # OpenAI requires this and will throw an error if it's not included.
+    if model_completion.response_format_json?
+      # Ensure the system prompt ends with punctuation if not empty
+      if formatted_system_prompt.present? && !formatted_system_prompt.end_with?(".", "?", "!")
+        formatted_system_prompt += "."
+      end
+      formatted_system_prompt += " Return your response as JSON."
+      formatted_system_prompt.strip!
+    end
+
+    formatted_system_prompt
+  end
+
+  def supports_structured_outputs?
+    # Not all OpenAI models support structured outputs:
+    # https://platform.openai.com/docs/guides/structured-outputs?api-mode=chat#supported-models
+    provider_settings.key?(:supports_structured_outputs) ? provider_settings[:supports_structured_outputs] : true
+  end
+
+  def supports_temperature?
+    provider_settings.key?(:supports_temperature) ? provider_settings[:supports_temperature] : true
+  end
+
+end
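
Both capability checks default to true and are opted out per model via provider settings. A sketch of how a registry entry might declare them (the key and api_name are hypothetical; real entries live in lib/raif/llm_registry.rb, which this section doesn't show):

Raif::Llms::OpenAiCompletions.new(
  key: :open_ai_o3_mini,   # hypothetical key
  api_name: "o3-mini",     # hypothetical API name
  model_provider_settings: {
    supports_temperature: false,       # perform_model_completion! will nil out temperature
    supports_structured_outputs: true  # allows strict json_schema response formats
  }
)
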
data/app/models/raif/llms/open_ai_completions.rb (new file)

@@ -0,0 +1,100 @@
+# frozen_string_literal: true
+
+class Raif::Llms::OpenAiCompletions < Raif::Llms::OpenAiBase
+  include Raif::Concerns::Llms::OpenAiCompletions::MessageFormatting
+  include Raif::Concerns::Llms::OpenAiCompletions::ToolFormatting
+
+  private
+
+  def api_path
+    "chat/completions"
+  end
+
+  def streaming_response_type
+    Raif::StreamingResponses::OpenAiCompletions
+  end
+
+  def update_model_completion(model_completion, response_json)
+    model_completion.update!(
+      response_id: response_json["id"],
+      response_tool_calls: extract_response_tool_calls(response_json),
+      raw_response: response_json.dig("choices", 0, "message", "content"),
+      response_array: response_json["choices"],
+      completion_tokens: response_json.dig("usage", "completion_tokens"),
+      prompt_tokens: response_json.dig("usage", "prompt_tokens"),
+      total_tokens: response_json.dig("usage", "total_tokens")
+    )
+  end
+
+  def extract_response_tool_calls(resp)
+    return if resp.dig("choices", 0, "message", "tool_calls").blank?
+
+    resp.dig("choices", 0, "message", "tool_calls").map do |tool_call|
+      {
+        "name" => tool_call["function"]["name"],
+        "arguments" => JSON.parse(tool_call["function"]["arguments"])
+      }
+    end
+  end
+
+  def build_request_parameters(model_completion)
+    formatted_system_prompt = format_system_prompt(model_completion)
+
+    messages = model_completion.messages
+    messages_with_system = if formatted_system_prompt.blank?
+      messages
+    else
+      [{ "role" => "system", "content" => formatted_system_prompt }] + messages
+    end
+
+    parameters = {
+      model: api_name,
+      messages: messages_with_system
+    }
+
+    if supports_temperature?
+      parameters[:temperature] = model_completion.temperature.to_f
+    end
+
+    # If the LLM supports native tool use and there are available tools, add them to the parameters
+    if supports_native_tool_use?
+      tools = build_tools_parameter(model_completion)
+      parameters[:tools] = tools unless tools.blank?
+    end
+
+    if model_completion.stream_response?
+      parameters[:stream] = true
+      # Ask for usage stats in the last chunk
+      parameters[:stream_options] = { include_usage: true }
+    end
+
+    # Add response format if needed
+    response_format = determine_response_format(model_completion)
+    parameters[:response_format] = response_format if response_format
+    model_completion.response_format_parameter = response_format[:type] if response_format
+
+    parameters
+  end
+
+  def determine_response_format(model_completion)
+    # Only configure response format for JSON outputs
+    return unless model_completion.response_format_json?
+
+    if model_completion.json_response_schema.present? && supports_structured_outputs?
+      validate_json_schema!(model_completion.json_response_schema)
+
+      {
+        type: "json_schema",
+        json_schema: {
+          name: "json_response_schema",
+          strict: true,
+          schema: model_completion.json_response_schema
+        }
+      }
+    else
+      # Default JSON mode for models that don't support structured outputs, or when no schema is provided
+      { type: "json_object" }
+    end
+  end
+
+end
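
determine_response_format has exactly two return shapes, both visible above. Restated as a sketch with an illustrative schema:

schema = {
  type: "object",
  properties: { answer: { type: "string" } },
  required: ["answer"],
  additionalProperties: false
}

# With a schema present and supports_structured_outputs? true:
structured = { type: "json_schema", json_schema: { name: "json_response_schema", strict: true, schema: schema } }

# Otherwise (no schema, or the model lacks structured outputs), plain JSON mode:
fallback = { type: "json_object" }
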