raif 1.4.0 → 1.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137):
  1. checksums.yaml +4 -4
  2. data/README.md +2 -2
  3. data/app/assets/builds/raif_admin.css +40 -2
  4. data/app/assets/builds/raif_admin_sprockets.js +2709 -0
  5. data/app/assets/javascript/raif/admin/copy_to_clipboard_controller.js +132 -0
  6. data/app/assets/javascript/raif/admin/cost_estimate_controller.js +80 -0
  7. data/app/assets/javascript/raif/admin/judge_config_controller.js +23 -0
  8. data/app/assets/javascript/raif/admin/select_all_checkboxes_controller.js +33 -0
  9. data/app/assets/javascript/raif/admin/sortable_table_controller.js +51 -0
  10. data/app/assets/javascript/raif/admin/table_search_controller.js +15 -0
  11. data/app/assets/javascript/raif/admin/tom_select_controller.js +33 -0
  12. data/app/assets/javascript/raif_admin.js +23 -0
  13. data/app/assets/javascript/raif_admin_sprockets.js +24 -0
  14. data/app/assets/stylesheets/raif_admin.scss +50 -1
  15. data/app/controllers/raif/admin/agents_controller.rb +27 -1
  16. data/app/controllers/raif/admin/configs_controller.rb +1 -0
  17. data/app/controllers/raif/admin/llms_controller.rb +27 -0
  18. data/app/controllers/raif/admin/model_completions_controller.rb +6 -0
  19. data/app/controllers/raif/admin/prompt_studio/agents_controller.rb +25 -0
  20. data/app/controllers/raif/admin/prompt_studio/base_controller.rb +32 -0
  21. data/app/controllers/raif/admin/prompt_studio/batch_runs_controller.rb +102 -0
  22. data/app/controllers/raif/admin/prompt_studio/conversations_controller.rb +25 -0
  23. data/app/controllers/raif/admin/prompt_studio/tasks_controller.rb +64 -0
  24. data/app/controllers/raif/admin/tasks_controller.rb +5 -0
  25. data/app/helpers/raif/application_helper.rb +40 -0
  26. data/app/jobs/raif/prompt_studio_batch_run_item_job.rb +11 -0
  27. data/app/jobs/raif/prompt_studio_batch_run_job.rb +15 -0
  28. data/app/jobs/raif/prompt_studio_task_run_job.rb +36 -0
  29. data/app/models/raif/agent.rb +36 -5
  30. data/app/models/raif/agents/native_tool_calling_agent.rb +101 -19
  31. data/app/models/raif/concerns/has_prompt_templates.rb +88 -0
  32. data/app/models/raif/concerns/has_runtime_duration.rb +41 -0
  33. data/app/models/raif/concerns/json_schema_definition.rb +16 -3
  34. data/app/models/raif/concerns/llm_prompt_caching.rb +20 -0
  35. data/app/models/raif/concerns/llms/anthropic/message_formatting.rb +6 -0
  36. data/app/models/raif/concerns/llms/anthropic/tool_formatting.rb +5 -1
  37. data/app/models/raif/concerns/llms/bedrock/message_formatting.rb +7 -0
  38. data/app/models/raif/concerns/llms/bedrock/tool_formatting.rb +4 -0
  39. data/app/models/raif/concerns/llms/google/message_formatting.rb +5 -2
  40. data/app/models/raif/concerns/llms/google/tool_formatting.rb +4 -0
  41. data/app/models/raif/concerns/llms/message_formatting.rb +30 -0
  42. data/app/models/raif/concerns/llms/open_ai_completions/response_tool_calls.rb +1 -1
  43. data/app/models/raif/concerns/llms/open_ai_completions/tool_formatting.rb +4 -0
  44. data/app/models/raif/concerns/llms/open_ai_responses/tool_formatting.rb +4 -0
  45. data/app/models/raif/concerns/provider_managed_tool_calls.rb +162 -0
  46. data/app/models/raif/conversation.rb +24 -3
  47. data/app/models/raif/conversation_entry.rb +6 -3
  48. data/app/models/raif/embedding_models/bedrock.rb +10 -1
  49. data/app/models/raif/embedding_models/google.rb +37 -0
  50. data/app/models/raif/evals/llm_judge.rb +70 -0
  51. data/{lib → app/models}/raif/evals/llm_judges/binary.rb +38 -0
  52. data/{lib → app/models}/raif/evals/llm_judges/comparative.rb +38 -0
  53. data/{lib → app/models}/raif/evals/llm_judges/scored.rb +38 -0
  54. data/{lib → app/models}/raif/evals/llm_judges/summarization.rb +38 -0
  55. data/app/models/raif/llm.rb +82 -7
  56. data/app/models/raif/llms/anthropic.rb +26 -4
  57. data/app/models/raif/llms/bedrock.rb +59 -5
  58. data/app/models/raif/llms/google.rb +28 -2
  59. data/app/models/raif/llms/open_ai_base.rb +4 -0
  60. data/app/models/raif/llms/open_ai_completions.rb +9 -2
  61. data/app/models/raif/llms/open_ai_responses.rb +9 -2
  62. data/app/models/raif/llms/open_router.rb +10 -3
  63. data/app/models/raif/model_completion.rb +75 -34
  64. data/app/models/raif/model_tool.rb +45 -3
  65. data/app/models/raif/model_tool_invocation.rb +31 -1
  66. data/app/models/raif/prompt_studio_batch_run.rb +155 -0
  67. data/app/models/raif/prompt_studio_batch_run_item.rb +220 -0
  68. data/app/models/raif/streaming_responses/bedrock.rb +60 -1
  69. data/app/models/raif/task.rb +30 -6
  70. data/app/views/layouts/raif/admin.html.erb +31 -1
  71. data/app/views/raif/admin/agents/_agent.html.erb +1 -0
  72. data/app/views/raif/admin/agents/index.html.erb +48 -0
  73. data/app/views/raif/admin/agents/show.html.erb +4 -0
  74. data/app/views/raif/admin/llms/index.html.erb +110 -0
  75. data/app/views/raif/admin/model_completions/_model_completion.html.erb +3 -7
  76. data/app/views/raif/admin/model_completions/index.html.erb +14 -1
  77. data/app/views/raif/admin/model_completions/show.html.erb +164 -55
  78. data/app/views/raif/admin/model_tool_invocations/index.html.erb +1 -1
  79. data/app/views/raif/admin/model_tool_invocations/show.html.erb +18 -0
  80. data/app/views/raif/admin/prompt_studio/agents/index.html.erb +56 -0
  81. data/app/views/raif/admin/prompt_studio/agents/show.html.erb +57 -0
  82. data/app/views/raif/admin/prompt_studio/batch_runs/_batch_run_item.html.erb +54 -0
  83. data/app/views/raif/admin/prompt_studio/batch_runs/_judge_config_fields.html.erb +76 -0
  84. data/app/views/raif/admin/prompt_studio/batch_runs/_judge_detail_modal.html.erb +27 -0
  85. data/app/views/raif/admin/prompt_studio/batch_runs/_modal.html.erb +35 -0
  86. data/app/views/raif/admin/prompt_studio/batch_runs/_progress.html.erb +78 -0
  87. data/app/views/raif/admin/prompt_studio/batch_runs/show.html.erb +49 -0
  88. data/app/views/raif/admin/prompt_studio/conversations/index.html.erb +48 -0
  89. data/app/views/raif/admin/prompt_studio/conversations/show.html.erb +36 -0
  90. data/app/views/raif/admin/prompt_studio/shared/_nav_tabs.html.erb +17 -0
  91. data/app/views/raif/admin/prompt_studio/shared/_prompt_comparison.html.erb +87 -0
  92. data/app/views/raif/admin/prompt_studio/shared/_type_filter.html.erb +54 -0
  93. data/app/views/raif/admin/prompt_studio/tasks/_task_result.html.erb +145 -0
  94. data/app/views/raif/admin/prompt_studio/tasks/_task_row.html.erb +12 -0
  95. data/app/views/raif/admin/prompt_studio/tasks/_task_type_filter.html.erb +58 -0
  96. data/app/views/raif/admin/prompt_studio/tasks/_tasks_table.html.erb +22 -0
  97. data/app/views/raif/admin/prompt_studio/tasks/index.html.erb +35 -0
  98. data/app/views/raif/admin/prompt_studio/tasks/show.html.erb +19 -0
  99. data/app/views/raif/admin/tasks/_task.html.erb +1 -0
  100. data/app/views/raif/admin/tasks/index.html.erb +17 -5
  101. data/app/views/raif/admin/tasks/show.html.erb +20 -0
  102. data/app/views/raif/conversation_entries/_message.html.erb +10 -6
  103. data/config/importmap.rb +8 -0
  104. data/config/locales/admin.en.yml +128 -0
  105. data/config/locales/en.yml +36 -2
  106. data/config/routes.rb +8 -0
  107. data/db/migrate/20260307000000_add_prompt_studio_run_to_raif_tasks.rb +7 -0
  108. data/db/migrate/20260308000000_create_raif_prompt_studio_batch_runs.rb +27 -0
  109. data/db/migrate/20260308000001_create_raif_prompt_studio_batch_run_items.rb +24 -0
  110. data/db/migrate/20260407000000_add_cache_token_columns_to_raif_model_completions.rb +8 -0
  111. data/lib/generators/raif/agent/agent_generator.rb +18 -0
  112. data/lib/generators/raif/agent/templates/agent.rb.tt +7 -5
  113. data/lib/generators/raif/agent/templates/system_prompt.erb.tt +3 -0
  114. data/lib/generators/raif/conversation/conversation_generator.rb +19 -1
  115. data/lib/generators/raif/conversation/templates/system_prompt.erb.tt +4 -0
  116. data/lib/generators/raif/install/templates/initializer.rb +68 -27
  117. data/lib/generators/raif/task/task_generator.rb +18 -0
  118. data/lib/generators/raif/task/templates/prompt.erb.tt +4 -0
  119. data/lib/generators/raif/task/templates/task.rb.tt +9 -8
  120. data/lib/raif/configuration.rb +10 -0
  121. data/lib/raif/embedding_model_registry.rb +8 -0
  122. data/lib/raif/engine.rb +16 -1
  123. data/lib/raif/errors/blank_response_error.rb +8 -0
  124. data/lib/raif/errors/prompt_template_error.rb +15 -0
  125. data/lib/raif/errors.rb +2 -0
  126. data/lib/raif/evals.rb +0 -6
  127. data/lib/raif/llm_registry.rb +230 -9
  128. data/lib/raif/prompt_studio_comparison_builder.rb +138 -0
  129. data/lib/raif/token_estimator.rb +28 -0
  130. data/lib/raif/version.rb +1 -1
  131. data/lib/raif.rb +2 -0
  132. data/spec/support/rspec_helpers.rb +7 -1
  133. data/spec/support/test_task.rb +9 -0
  134. data/spec/support/test_template_task.rb +41 -0
  135. metadata +65 -7
  136. data/lib/raif/evals/llm_judge.rb +0 -32
  137. /data/{lib → app/models}/raif/evals/scoring_rubric.rb +0 -0
@@ -0,0 +1,102 @@
# frozen_string_literal: true

module Raif
  module Admin
    module PromptStudio
      # Starts and displays Prompt Studio batch runs: a set of completed source
      # tasks re-run against one LLM, optionally evaluated by an LLM judge.
      class BatchRunsController < BaseController
        # Validates the request, persists the batch run together with one item per
        # selected source task, then enqueues Raif::PromptStudioBatchRunJob.
        # Every failure path redirects back to the task list with an alert.
        def create
          unless prompt_studio_runs_enabled?
            redirect_to raif.admin_prompt_studio_tasks_path, alert: t("raif.admin.prompt_studio.common.runs_disabled")
            return
          end

          # Materialize once so the empty check, total_count and the created items
          # all come from the same result set (a lazy relation re-queries each time).
          source_tasks = resolve_source_tasks.to_a
          if source_tasks.empty?
            redirect_to raif.admin_prompt_studio_tasks_path(task_type: params[:task_type]),
              alert: t("raif.admin.prompt_studio.batch_runs.create.no_tasks_selected")
            return
          end

          available_keys = Raif.available_llm_keys.map(&:to_s)

          unless params[:llm_model_key].present? && available_keys.include?(params[:llm_model_key])
            redirect_to raif.admin_prompt_studio_tasks_path(task_type: params[:task_type]),
              alert: t("raif.admin.prompt_studio.tasks.rerun.invalid_model")
            return
          end

          # NOTE(review): judge_type is stored verbatim from params; confirm the batch
          # run model restricts it to the known Raif::Evals::LlmJudges classes.
          judge_type = params[:judge_type].presence
          # A judge model is only meaningful when a judge is selected, so it is
          # neither validated nor persisted otherwise.
          judge_llm_model_key = params[:judge_llm_model_key].presence if judge_type

          if judge_llm_model_key && !available_keys.include?(judge_llm_model_key)
            redirect_to raif.admin_prompt_studio_tasks_path(task_type: params[:task_type]),
              alert: t("raif.admin.prompt_studio.tasks.rerun.invalid_model")
            return
          end

          batch_run = Raif::PromptStudioBatchRun.new(
            task_type: params[:task_type],
            llm_model_key: params[:llm_model_key],
            judge_type: judge_type,
            judge_llm_model_key: judge_llm_model_key,
            judge_config: build_judge_config,
            total_count: source_tasks.size
          )

          # All-or-nothing: a failure while creating items must not leave behind a
          # batch run whose items disagree with total_count.
          Raif::PromptStudioBatchRun.transaction do
            batch_run.save!
            source_tasks.each { |task| batch_run.items.create!(source_task: task) }
          end

          # Enqueue only after the transaction commits so the job never sees a partial run.
          Raif::PromptStudioBatchRunJob.perform_later(batch_run: batch_run)

          redirect_to raif.admin_prompt_studio_batch_run_path(batch_run)
        rescue StandardError => e
          redirect_to raif.admin_prompt_studio_tasks_path(task_type: params[:task_type]),
            alert: t("raif.admin.prompt_studio.batch_runs.create.error", message: e.message)
        end

        # Shows a batch run with its items (source, result and judge tasks eager-loaded), paginated.
        def show
          @batch_run = Raif::PromptStudioBatchRun.find(params[:id])
          items = @batch_run.items.includes(:source_task, :result_task, :judge_task).order(:id)
          @pagy, @items = pagy(items)
        end

        private

        # Completed tasks whose ids were submitted in source_task_ids, optionally
        # narrowed to task_type. Non-numeric / zero ids are discarded.
        def resolve_source_tasks
          ids = Array(params[:source_task_ids]).map(&:to_i).reject(&:zero?)
          scope = Raif::Task.where(id: ids).completed
          scope = scope.where(type: params[:task_type]) if params[:task_type].present?
          scope
        end

        # Builds the judge configuration hash from the judge_* params for the
        # selected judge type; unknown or blank judge types yield an empty hash.
        def build_judge_config
          config = case params[:judge_type]
          when "Raif::Evals::LlmJudges::Binary"
            {
              "criteria" => params[:judge_criteria].presence || "",
              "strict_mode" => params[:judge_strict_mode] == "1"
            }
          when "Raif::Evals::LlmJudges::Scored"
            {
              "scoring_rubric" => params[:judge_scoring_rubric].presence || "accuracy"
            }
          when "Raif::Evals::LlmJudges::Comparative"
            {
              "comparison_criteria" => params[:judge_comparison_criteria].presence || ""
            }
          else
            # Includes "Raif::Evals::LlmJudges::Summarization", which takes no extra options.
            {}
          end

          if params[:judge_type].present?
            config["include_original_prompt_as_context"] = params[:judge_include_original_prompt_as_context] == "1"
          end

          config
        end
      end
    end
  end
end
@@ -0,0 +1,25 @@
# frozen_string_literal: true

module Raif
  module Admin
    module PromptStudio
      # Prompt Studio screens for browsing conversations by type and inspecting
      # their prompts.
      class ConversationsController < BaseController
        # Lists conversations of the selected type; nothing is listed until a
        # conversation type is chosen.
        def index
          # `.compact` drops NULL `type` values (e.g. rows saved through the STI
          # base class, if any); without it `sort` raises when nil meets a String.
          @conversation_types = Raif::Conversation.distinct.pluck(:type).compact.sort
          @selected_type = params[:conversation_type].presence
          return if @selected_type.blank?

          type_scope = Raif::Conversation.where(type: @selected_type)
          @llm_model_keys = type_scope.distinct.pluck(:llm_model_key).compact.sort
          @pagy, @conversations = pagy(apply_filters(type_scope).order(created_at: :desc))
        end

        # Shows one conversation together with its prompt comparison.
        def show
          @conversation = Raif::Conversation.find(params[:id])
          @comparison = build_prompt_comparison(@conversation)
        end
      end
    end
  end
end
@@ -0,0 +1,64 @@
# frozen_string_literal: true

module Raif
  module Admin
    module PromptStudio
      # Prompt Studio screens for browsing completed tasks and re-running a single
      # task against a different LLM.
      class TasksController < BaseController
        # Lists completed tasks of the selected type; nothing is listed until a task
        # type is chosen. Batch-run controls are shown only when runs are enabled
        # and the current page has tasks.
        def index
          # `.compact` drops NULL `type` values; without it `sort` raises
          # ArgumentError when nil is compared with class-name strings.
          @task_types = Raif::Task.distinct.pluck(:type).compact.sort
          @selected_type = params[:task_type].presence

          if @selected_type.present?
            type_scope = Raif::Task.where(type: @selected_type)
            @llm_model_keys = type_scope.distinct.pluck(:llm_model_key).compact.sort

            tasks = apply_filters(type_scope.completed).includes(:raif_model_completion).order(created_at: :desc)
            @pagy, @tasks = pagy(tasks)
          end

          @show_batch_runs = prompt_studio_runs_enabled? && @selected_type.present? && @tasks.present?
        end

        # Shows one task with its prompt comparison. For Prompt Studio re-runs the
        # original source task is exposed too, along with the selectable LLM keys.
        def show
          @task = Raif::Task.find(params[:id])
          @comparison = build_prompt_comparison(@task)
          @original_task = @task.source if @task.prompt_studio_run? && @task.source.is_a?(Raif::Task)
          @available_llm_keys = Raif.available_llm_keys.map(&:to_s).sort
        end

        # Re-runs a task with another LLM: builds a new task of the same class that
        # copies the original's creator, tools, run_with and
        # prompt_studio_task_attributes, flags it as a Prompt Studio run, and runs
        # it through Raif::PromptStudioTaskRunJob.
        def create
          original_task = Raif::Task.find(params[:source_task_id])

          unless prompt_studio_runs_enabled?
            redirect_to raif.admin_prompt_studio_task_path(original_task), alert: t("raif.admin.prompt_studio.common.runs_disabled")
            return
          end

          llm_model_key = params[:llm_model_key]

          unless llm_model_key.present? && Raif.available_llm_keys.map(&:to_s).include?(llm_model_key)
            redirect_to raif.admin_prompt_studio_task_path(original_task), alert: t("raif.admin.prompt_studio.tasks.rerun.invalid_model")
            return
          end

          new_task = original_task.class.new(
            creator: original_task.creator,
            source: original_task,
            llm_model_key: llm_model_key,
            available_model_tools: original_task.available_model_tools,
            run_with: original_task.run_with,
            prompt_studio_run: true,
            started_at: Time.current
          )
          new_task.assign_attributes(original_task.prompt_studio_task_attributes)
          new_task.save!
          Raif::PromptStudioTaskRunJob.perform_later(task: new_task)

          redirect_to raif.admin_prompt_studio_task_path(new_task)
        rescue StandardError => e
          # Only stamp a task that actually reached the database: calling `update` on
          # an unsaved record would re-attempt the failed INSERT from inside this
          # error handler (and could raise again).
          new_task.update(failed_at: Time.current) if new_task&.persisted? && !new_task.failed_at?
          redirect_to raif.admin_prompt_studio_task_path(original_task || params[:source_task_id]),
            alert: t("raif.admin.prompt_studio.tasks.rerun.error", message: e.message)
        end
      end
    end
  end
end
@@ -12,6 +12,9 @@ module Raif
12
12
  @task_statuses = [:all, :completed, :failed, :in_progress, :pending]
13
13
  @selected_statuses = params[:task_statuses].present? ? params[:task_statuses].to_sym : :all
14
14
 
15
+ @selected_llm_model_key = params[:llm_model_key].presence
16
+ @llm_model_keys = Raif::Task.distinct.order(:llm_model_key).pluck(:llm_model_key)
17
+
15
18
  tasks = Raif::Task.order(created_at: :desc)
16
19
  tasks = tasks.where(type: @selected_type) if @selected_type.present? && @selected_type != "all"
17
20
 
@@ -28,6 +31,8 @@ module Raif
28
31
  end
29
32
  end
30
33
 
34
+ tasks = tasks.where(llm_model_key: @selected_llm_model_key) if @selected_llm_model_key.present?
35
+
31
36
  @pagy, @tasks = pagy(tasks)
32
37
  end
33
38
 
@@ -3,5 +3,45 @@
module Raif
  module ApplicationHelper
    include Pagy::Frontend

    # Returns the task's raw response. It is pretty-printed when the task uses the
    # JSON response format and the response is non-blank, valid JSON.
    def format_task_response(task)
      return task.raw_response unless task.response_format_json? && task.raw_response.present?

      parsed_response = JSON.parse(task.raw_response)
      JSON.pretty_generate(parsed_response)
    rescue JSON::ParserError
      task.raw_response
    end

    # Best-effort pretty-printing of a JSON string; any failure returns the input unchanged.
    def pretty_json(value)
      parsed = JSON.parse(value)
      JSON.pretty_generate(parsed)
    rescue StandardError
      value
    end

    # <option> tags for every available LLM, labelled via raif.model_names.* (falling
    # back to the raw key) and ordered by label.
    def llm_model_options(selected: nil)
      labeled_keys = Raif.available_llm_keys.map do |model_key|
        [I18n.t("raif.model_names.#{model_key}", default: model_key.to_s), model_key.to_s]
      end

      options_for_select(labeled_keys.sort_by(&:first), selected&.to_s)
    end

    # JSON object mapping each configured LLM key to its input/output token costs
    # (missing costs reported as 0). Keys without an llm_config are omitted.
    def llm_pricing_json
      costs_by_key = Raif.available_llm_keys.each_with_object({}) do |model_key, costs|
        model_config = Raif.llm_config(model_key)
        next unless model_config

        costs[model_key.to_s] = {
          input: model_config[:input_token_cost] || 0,
          output: model_config[:output_token_cost] || 0
        }
      end

      costs_by_key.to_json
    end
  end
end
@@ -0,0 +1,11 @@
# frozen_string_literal: true

module Raif
  # Background job that executes one item of a Prompt Studio batch run.
  class PromptStudioBatchRunItemJob < ApplicationJob
    # @param item a batch run item (enqueued per pending item by PromptStudioBatchRunJob)
    def perform(item:) = item.execute!
  end
end
@@ -0,0 +1,15 @@
# frozen_string_literal: true

module Raif
  # Stamps a Prompt Studio batch run as started, then fans out one
  # PromptStudioBatchRunItemJob for each item still in the "pending" status.
  class PromptStudioBatchRunJob < ApplicationJob
    def perform(batch_run:)
      batch_run.update!(started_at: Time.current)

      pending_items = batch_run.items.where(status: "pending")
      pending_items.find_each { |pending_item| Raif::PromptStudioBatchRunItemJob.perform_later(item: pending_item) }
    end
  end
end
@@ -0,0 +1,36 @@
# frozen_string_literal: true

module Raif
  # Background job for a Prompt Studio re-run. It runs the task, then re-renders
  # the task result partial and broadcasts it over the task's Turbo stream.
  class PromptStudioTaskRunJob < ApplicationJob
    # @param task [Raif::Task] the Prompt Studio copy of a task (as built by the
    #   admin TasksController#create)
    def perform(task:)
      begin
        task.run
      rescue StandardError => e
        logger.error "Error running prompt studio task: #{e.message}"
        logger.error e.backtrace&.join("\n")

        task.update(failed_at: Time.current) unless task.failed_at?
      end

      # Broadcast outside the rescue above. If rendering or broadcasting fails for a
      # task that ran successfully, the error surfaces as a job error instead of the
      # completed task being wrongly stamped with failed_at.
      broadcast_task_result(task)
    end

    private

    # Renders raif/admin/prompt_studio/tasks/_task_result and sends it as a Turbo
    # "replace" targeting the task's `result` DOM id. The partial receives the task,
    # its prompt comparison, and the original source task for Prompt Studio runs.
    def broadcast_task_result(task)
      comparison = Raif::PromptStudioComparisonBuilder.build(task)
      original_task = task.prompt_studio_run? && task.source.is_a?(Raif::Task) ? task.source : nil

      html = Raif::Admin::PromptStudio::TasksController.render(
        partial: "raif/admin/prompt_studio/tasks/task_result",
        locals: { task: task, comparison: comparison, original_task: original_task }
      )

      Turbo::StreamsChannel.broadcast_replace_to(
        task,
        target: ActionView::RecordIdentifier.dom_id(task, :result),
        html: html
      )
    end
  end
end
@@ -35,11 +35,15 @@
35
35
  #
36
36
  module Raif
37
37
  class Agent < ApplicationRecord
38
+ prepend Raif::Concerns::HasPromptTemplates
39
+
38
40
  include Raif::Concerns::HasLlm
39
41
  include Raif::Concerns::HasRequestedLanguage
40
42
  include Raif::Concerns::HasAvailableModelTools
43
+ include Raif::Concerns::HasRuntimeDuration
41
44
  include Raif::Concerns::InvokesModelTools
42
45
  include Raif::Concerns::AgentInferenceStats
46
+ include Raif::Concerns::LlmPromptCaching
43
47
  include Raif::Concerns::RunWith
44
48
 
45
49
  belongs_to :creator, polymorphic: true
@@ -122,7 +126,9 @@ module Raif
122
126
  source: self,
123
127
  system_prompt: system_prompt,
124
128
  available_model_tools: native_model_tools,
125
- tool_choice: tool_choice_for_iteration
129
+ tool_choice: tool_choice_for_iteration,
130
+ anthropic_prompt_caching_enabled: self.class.anthropic_prompt_caching_enabled,
131
+ bedrock_prompt_caching_enabled: self.class.bedrock_prompt_caching_enabled
126
132
  )
127
133
 
128
134
  logger.debug <<~DEBUG
@@ -137,14 +143,14 @@ module Raif
137
143
  DEBUG
138
144
 
139
145
  process_iteration_model_completion(model_completion)
140
- break if final_answer.present?
146
+ break if final_answer.present? || failed?
141
147
  end
142
148
 
143
- completed!
149
+ finalize_run!
144
150
  final_answer
145
151
  rescue StandardError => e
146
- self.failed_at = Time.current
147
- self.failure_reason = e.message
152
+ self.failed_at ||= Time.current
153
+ self.failure_reason ||= e.message
148
154
  save!
149
155
 
150
156
  raise
@@ -160,6 +166,17 @@ module Raif
160
166
  # no-op by default. Can be overridden by subclasses to add default model tools
161
167
  end
162
168
 
169
# Runs after the iteration loop ends. Subclasses get a chance to fail the run via
# validate_successful_completion; the run is marked completed only if that check
# did not fail it.
def finalize_run!
  validate_successful_completion
  completed! unless failed?
end
175
+
176
# Extension point invoked before a run is marked completed. The base
# implementation accepts every run; subclasses may fail the run here
# (e.g. with fail_run!) to enforce their own success criteria.
def validate_successful_completion = nil
179
+
163
180
  def process_iteration_model_completion(model_completion)
164
181
  raise NotImplementedError, "#{self.class.name} must implement process_iteration_model_completion"
165
182
  end
@@ -181,6 +198,14 @@ module Raif
181
198
  nil
182
199
  end
183
200
 
201
# Subclass hook naming the tool the model must call on the current iteration.
# Overrides keep the prompt warnings and the provider-level tool_choice in
# agreement, and must be deterministic and side-effect free for a given iteration.
# @return [Class, nil] the required model tool class, or nil when no specific tool is required
def required_tool_for_iteration = nil
208
+
184
209
  def add_conversation_history_entry(entry)
185
210
  entry_stringified = entry.stringify_keys
186
211
  conversation_history << entry_stringified
@@ -188,6 +213,12 @@ module Raif
188
213
  on_conversation_history_entry.call(entry_stringified) if on_conversation_history_entry.present?
189
214
  end
190
215
 
216
# Marks the run failed and saves it. An already-recorded failure time or reason is
# kept, so the first failure wins.
# @param reason [String] failure description, stored only if none is set yet
def fail_run!(reason)
  self.failed_at = Time.current unless failed_at
  self.failure_reason = reason unless failure_reason
  save!
end
221
+
191
222
  def build_system_prompt
192
223
  raise NotImplementedError, "Subclasses of Raif::Agent must implement build_system_prompt"
193
224
  end
@@ -88,24 +88,32 @@ module Raif
88
88
  available_model_tools_map["agent_final_answer"]
89
89
  end
90
90
 
91
- # Warn the agent that it must provide a final answer on the next iteration
91
# Requires the agent_final_answer tool on the final iteration; before that, no
# particular tool is required.
def required_tool_for_iteration
  final_iteration? ? final_answer_tool : nil
end
96
+
92
97
  def before_iteration_llm_chat
93
- return unless final_iteration?
98
+ required_tool = current_iteration_required_tool
99
+ return if required_tool.blank?
94
100
 
95
101
  warning_message = Raif::Messages::UserMessage.new(
96
- content: I18n.t("raif.agents.native_tool_calling_agent.final_answer_warning")
102
+ content: required_tool_warning_message(required_tool)
97
103
  )
98
104
  add_conversation_history_entry(warning_message.to_h)
99
105
  end
100
106
 
101
- # On the final iteration, force the agent to use the agent_final_answer tool
102
107
  def tool_choice_for_iteration
103
- return unless final_iteration?
108
+ return current_iteration_required_tool if current_iteration_required_tool.present?
109
+ return :required if llm.supports_faithful_required_tool_choice?(native_model_tools)
104
110
 
105
- final_answer_tool
111
+ log_required_tool_choice_fallback_once!
112
+ nil
106
113
  end
107
114
 
108
115
  def process_iteration_model_completion(model_completion)
116
+ required_tool = current_iteration_required_tool
109
117
  assistant_response_message = model_completion.parsed_response if model_completion.parsed_response.present?
110
118
 
111
119
  # The model made no tool call in this completion. Tell it to make a tool call.
@@ -115,36 +123,61 @@ module Raif
115
123
  add_conversation_history_entry(assistant_message.to_h)
116
124
  end
117
125
 
118
- error_message = Raif::Messages::UserMessage.new(
119
- content: "Error: Previous message contained no tool call. Make a tool call at each step. Available tools: #{available_model_tools_map.keys.join(", ")}" # rubocop:disable Layout/LineLength
120
- )
121
- add_conversation_history_entry(error_message.to_h)
126
+ error_content = if required_tool.present?
127
+ "Error: This iteration required the tool '#{required_tool.tool_name}', but the model response contained no tool call. Available tools: #{available_model_tools_map.keys.join(", ")}" # rubocop:disable Layout/LineLength
128
+ else
129
+ "Error: Previous message contained no tool call. Make a tool call at each step. Available tools: #{available_model_tools_map.keys.join(", ")}" # rubocop:disable Layout/LineLength
130
+ end
131
+ handle_iteration_error(error_content, required_tool:)
132
+
133
+ return
134
+ end
135
+
136
+ # The model returned multiple tool calls. We only allow one per step.
137
+ if model_completion.response_tool_calls.length > 1
138
+ if assistant_response_message.present?
139
+ assistant_message = Raif::Messages::AssistantMessage.new(content: assistant_response_message)
140
+ add_conversation_history_entry(assistant_message.to_h)
141
+ end
142
+
143
+ error_content = "Error: Multiple tool calls received. Only one tool call is allowed per step. " \
144
+ "Please call exactly one tool at a time."
145
+ handle_iteration_error(error_content, required_tool:)
122
146
 
123
147
  return
124
148
  end
125
149
 
126
150
  tool_call = model_completion.response_tool_calls.first
127
151
 
128
- # Add the tool call to history
152
+ tool_name = tool_call["name"]
153
+ tool_arguments = tool_call["arguments"]
154
+ tool_klass = available_model_tools_map[tool_name]
155
+
156
+ # Prepare tool arguments before recording to history so the history
157
+ # accurately reflects what was actually invoked
158
+ tool_arguments = tool_klass.prepare_tool_arguments(tool_arguments) if tool_klass.present?
159
+
160
+ # Add the tool call to history (with prepared arguments if tool is known)
129
161
  tool_call_message = Raif::Messages::ToolCall.new(
130
162
  provider_tool_call_id: tool_call["provider_tool_call_id"],
131
163
  name: tool_call["name"],
132
- arguments: tool_call["arguments"],
164
+ arguments: tool_arguments,
133
165
  assistant_message: assistant_response_message,
134
166
  provider_metadata: tool_call["provider_metadata"]
135
167
  )
136
168
  add_conversation_history_entry(tool_call_message.to_h)
137
169
 
138
- tool_name = tool_call["name"]
139
- tool_arguments = tool_call["arguments"]
140
- tool_klass = available_model_tools_map[tool_name]
170
+ if required_tool.present? && tool_name != required_tool.tool_name
171
+ error_content = "Error: This iteration required the tool '#{required_tool.tool_name}', but the model called '#{tool_name}' instead."
172
+ handle_iteration_error(error_content, required_tool:)
173
+ return
174
+ end
141
175
 
142
176
  # The model tried to use a tool that doesn't exist
143
177
  if tool_klass.blank?
144
178
  error_content = "Error: Tool '#{tool_name}' is not a valid tool. " \
145
179
  "Available tools: #{available_model_tools_map.keys.join(", ")}"
146
- error_message = Raif::Messages::UserMessage.new(content: error_content)
147
- add_conversation_history_entry(error_message.to_h)
180
+ handle_iteration_error(error_content, required_tool:)
148
181
  return
149
182
  end
150
183
 
@@ -152,8 +185,7 @@ module Raif
152
185
  unless JSON::Validator.validate(tool_klass.tool_arguments_schema, tool_arguments)
153
186
  error_content = "Error: Invalid tool arguments for the tool '#{tool_name}'. " \
154
187
  "Tool arguments schema: #{tool_klass.tool_arguments_schema.to_json}"
155
- error_message = Raif::Messages::UserMessage.new(content: error_content)
156
- add_conversation_history_entry(error_message.to_h)
188
+ handle_iteration_error(error_content, required_tool:)
157
189
  return
158
190
  end
159
191
 
@@ -171,6 +203,56 @@ module Raif
171
203
  end
172
204
  end
173
205
 
206
# Fails a run that ended without the model ever calling agent_final_answer,
# unless the run has already failed for some other reason.
def validate_successful_completion
  fail_run!("Agent completed without calling agent_final_answer") unless failed? || final_answer.present?
end
211
+
212
# Warning text shown to the model when a tool is required on this iteration.
# On the final iteration the final-answer warning comes from I18n. If the
# final-answer tool is required earlier, the model is told the next iteration
# is its last chance.
# @param required_tool [Class] model tool class required for this iteration
# @return [String]
def required_tool_warning_message(required_tool)
  return "Warning: This iteration requires the #{required_tool.tool_name} tool." unless required_tool == final_answer_tool
  return I18n.t("raif.agents.native_tool_calling_agent.final_answer_warning") if final_iteration?

  "Warning: This iteration requires the agent_final_answer tool. If you do not use it now, the next iteration will be your final chance."
end
223
+
224
# Memoizes required_tool_for_iteration per iteration_count. The warning,
# tool_choice and response-validation steps of one iteration therefore all see
# the same required tool.
def current_iteration_required_tool
  return @current_iteration_required_tool if @required_tool_iteration_count == iteration_count

  @required_tool_iteration_count = iteration_count
  @current_iteration_required_tool = required_tool_for_iteration
end
232
+
233
# Records error_content as a user message so the model can correct itself on the
# next iteration. When a specific tool was required and no iterations remain,
# the run is failed with that same message.
def handle_iteration_error(error_content, required_tool: nil)
  add_conversation_history_entry(Raif::Messages::UserMessage.new(content: error_content).to_h)
  fail_run!(error_content) unless required_tool.blank? || retry_iteration_available?
end
241
+
242
# True while iteration_count is still below max_iterations.
def retry_iteration_available? = iteration_count < max_iterations
245
+
246
# Warns (at most once per instance) that tool_choice: :required cannot be enforced
# by the provider, so tool calls are validated only at runtime.
def log_required_tool_choice_fallback_once!
  return if @logged_required_tool_choice_fallback

  @logged_required_tool_choice_fallback = true
  model_key = llm.key
  tool_names = available_model_tools_map.keys.join(", ")
  Raif.logger.warn(
    "NativeToolCallingAgent is falling back to runtime tool-call validation because #{model_key} " \
    "cannot faithfully enforce tool_choice: :required for tools: #{tool_names}"
  )
end
255
+
174
256
  def ensure_llm_supports_native_tool_use
175
257
  unless llm.supports_native_tool_use?
176
258
  errors.add(:base, "Raif::Agent#llm_model_key must use an LLM that supports native tool use")
@@ -0,0 +1,88 @@
# frozen_string_literal: true

module Raif
  module Concerns
    # Renders a model's prompt and system prompt from ERB view templates whose path
    # is derived from the class name. For example, Raif::Tasks::SummarizeDocument
    # looks for raif/tasks/summarize_document.prompt.erb and
    # raif/tasks/summarize_document.system_prompt.erb. When no template exists, the
    # class's own build_prompt / build_system_prompt is reached via `super`, which is
    # why this module is prepended rather than included.
    module HasPromptTemplates
      extend ActiveSupport::Concern

      # Minimal ActionView context for rendering prompt templates. Calls to methods
      # the view does not define are forwarded to the owning instance, so a template
      # can call that instance's public methods directly.
      class TemplateContext < ActionView::Base.with_empty_template_cache
        def initialize(lookup_context, instance)
          super(lookup_context, {}, nil)
          @_instance = instance
        end

        def method_missing(method_name, ...)
          if @_instance.respond_to?(method_name)
            @_instance.public_send(method_name, ...)
          else
            super
          end
        end

        # Kept in lockstep with method_missing, which only forwards *public* methods.
        # The include_private flag is deliberately not passed to the instance, so its
        # private methods are never reported as callable here.
        def respond_to_missing?(method_name, include_private = false)
          @_instance.respond_to?(method_name) || super
        end
      end

      class_methods do
        # Returns the template prefix path derived from the class name.
        # e.g. Raif::Tasks::SummarizeDocument -> "raif/tasks/summarize_document"
        # e.g. Raif::Tasks::Docs::Summarize -> "raif/tasks/docs/summarize"
        def prompt_template_prefix
          name.underscore
        end

        # View paths searched for prompt templates.
        def prompt_template_view_paths
          ActionController::Base.view_paths
        end
      end

      # Uses the `.prompt.erb` template when one exists; otherwise defers to the
      # class's own build_prompt.
      def build_prompt
        if prompt_template_exists?(:prompt)
          render_prompt_template(:prompt)
        else
          super
        end
      end

      # Uses the `.system_prompt.erb` template when one exists; otherwise defers to
      # the class's own build_system_prompt.
      def build_system_prompt
        if prompt_template_exists?(:system_prompt)
          render_prompt_template(:system_prompt)
        else
          super
        end
      end

      private

      # Last segment of the prefix, e.g. "summarize_document".
      def prompt_template_name
        self.class.prompt_template_prefix.split("/").last
      end

      # Directory part of the prefix, e.g. "raif/tasks".
      def prompt_template_dir
        File.dirname(self.class.prompt_template_prefix)
      end

      def prompt_template_exists?(template_type)
        # An anonymous class (e.g. Class.new(SomeTask)) has no name to derive a
        # template path from. Fall back to the Ruby-defined prompt instead of
        # raising NoMethodError on nil.underscore.
        return false if self.class.name.nil?

        prompt_lookup_context_for(template_type).exists?(prompt_template_name, prompt_template_dir)
      end

      # Fresh lookup context whose only format is the template type, so that
      # name.prompt.erb and name.system_prompt.erb resolve independently.
      def prompt_lookup_context_for(template_type)
        lookup = ActionView::LookupContext.new(self.class.prompt_template_view_paths)
        lookup.formats = [template_type]
        lookup
      end

      # Renders the template in a TemplateContext bound to this instance and strips
      # surrounding whitespace. ActionView template/missing-template errors are
      # wrapped in Raif::Errors::PromptTemplateError with the expected template path.
      def render_prompt_template(template_type)
        lookup = prompt_lookup_context_for(template_type)
        context = TemplateContext.new(lookup, self)
        context.render(template: "#{prompt_template_dir}/#{prompt_template_name}").strip
      rescue ActionView::Template::Error, ActionView::MissingTemplate => e
        raise Raif::Errors::PromptTemplateError.new(
          template_path: "#{self.class.prompt_template_prefix}.#{template_type}.erb",
          original_error: e
        )
      end
    end
  end
end