raif 1.4.0 → 1.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +2 -2
- data/app/assets/builds/raif_admin.css +40 -2
- data/app/assets/builds/raif_admin_sprockets.js +2709 -0
- data/app/assets/javascript/raif/admin/copy_to_clipboard_controller.js +132 -0
- data/app/assets/javascript/raif/admin/cost_estimate_controller.js +80 -0
- data/app/assets/javascript/raif/admin/judge_config_controller.js +23 -0
- data/app/assets/javascript/raif/admin/select_all_checkboxes_controller.js +33 -0
- data/app/assets/javascript/raif/admin/sortable_table_controller.js +51 -0
- data/app/assets/javascript/raif/admin/table_search_controller.js +15 -0
- data/app/assets/javascript/raif/admin/tom_select_controller.js +33 -0
- data/app/assets/javascript/raif_admin.js +23 -0
- data/app/assets/javascript/raif_admin_sprockets.js +24 -0
- data/app/assets/stylesheets/raif_admin.scss +50 -1
- data/app/controllers/raif/admin/agents_controller.rb +27 -1
- data/app/controllers/raif/admin/configs_controller.rb +1 -0
- data/app/controllers/raif/admin/llms_controller.rb +27 -0
- data/app/controllers/raif/admin/model_completions_controller.rb +6 -0
- data/app/controllers/raif/admin/prompt_studio/agents_controller.rb +25 -0
- data/app/controllers/raif/admin/prompt_studio/base_controller.rb +32 -0
- data/app/controllers/raif/admin/prompt_studio/batch_runs_controller.rb +102 -0
- data/app/controllers/raif/admin/prompt_studio/conversations_controller.rb +25 -0
- data/app/controllers/raif/admin/prompt_studio/tasks_controller.rb +64 -0
- data/app/controllers/raif/admin/tasks_controller.rb +5 -0
- data/app/helpers/raif/application_helper.rb +40 -0
- data/app/jobs/raif/prompt_studio_batch_run_item_job.rb +11 -0
- data/app/jobs/raif/prompt_studio_batch_run_job.rb +15 -0
- data/app/jobs/raif/prompt_studio_task_run_job.rb +36 -0
- data/app/models/raif/agent.rb +36 -5
- data/app/models/raif/agents/native_tool_calling_agent.rb +101 -19
- data/app/models/raif/concerns/has_prompt_templates.rb +88 -0
- data/app/models/raif/concerns/has_runtime_duration.rb +41 -0
- data/app/models/raif/concerns/json_schema_definition.rb +16 -3
- data/app/models/raif/concerns/llm_prompt_caching.rb +20 -0
- data/app/models/raif/concerns/llms/anthropic/message_formatting.rb +6 -0
- data/app/models/raif/concerns/llms/anthropic/tool_formatting.rb +5 -1
- data/app/models/raif/concerns/llms/bedrock/message_formatting.rb +7 -0
- data/app/models/raif/concerns/llms/bedrock/tool_formatting.rb +4 -0
- data/app/models/raif/concerns/llms/google/message_formatting.rb +5 -2
- data/app/models/raif/concerns/llms/google/tool_formatting.rb +4 -0
- data/app/models/raif/concerns/llms/message_formatting.rb +30 -0
- data/app/models/raif/concerns/llms/open_ai_completions/response_tool_calls.rb +1 -1
- data/app/models/raif/concerns/llms/open_ai_completions/tool_formatting.rb +4 -0
- data/app/models/raif/concerns/llms/open_ai_responses/tool_formatting.rb +4 -0
- data/app/models/raif/concerns/provider_managed_tool_calls.rb +162 -0
- data/app/models/raif/conversation.rb +24 -3
- data/app/models/raif/conversation_entry.rb +6 -3
- data/app/models/raif/embedding_models/bedrock.rb +10 -1
- data/app/models/raif/embedding_models/google.rb +37 -0
- data/app/models/raif/evals/llm_judge.rb +70 -0
- data/{lib → app/models}/raif/evals/llm_judges/binary.rb +38 -0
- data/{lib → app/models}/raif/evals/llm_judges/comparative.rb +38 -0
- data/{lib → app/models}/raif/evals/llm_judges/scored.rb +38 -0
- data/{lib → app/models}/raif/evals/llm_judges/summarization.rb +38 -0
- data/app/models/raif/llm.rb +82 -7
- data/app/models/raif/llms/anthropic.rb +26 -4
- data/app/models/raif/llms/bedrock.rb +59 -5
- data/app/models/raif/llms/google.rb +28 -2
- data/app/models/raif/llms/open_ai_base.rb +4 -0
- data/app/models/raif/llms/open_ai_completions.rb +9 -2
- data/app/models/raif/llms/open_ai_responses.rb +9 -2
- data/app/models/raif/llms/open_router.rb +10 -3
- data/app/models/raif/model_completion.rb +75 -34
- data/app/models/raif/model_tool.rb +45 -3
- data/app/models/raif/model_tool_invocation.rb +31 -1
- data/app/models/raif/prompt_studio_batch_run.rb +155 -0
- data/app/models/raif/prompt_studio_batch_run_item.rb +220 -0
- data/app/models/raif/streaming_responses/bedrock.rb +60 -1
- data/app/models/raif/task.rb +30 -6
- data/app/views/layouts/raif/admin.html.erb +31 -1
- data/app/views/raif/admin/agents/_agent.html.erb +1 -0
- data/app/views/raif/admin/agents/index.html.erb +48 -0
- data/app/views/raif/admin/agents/show.html.erb +4 -0
- data/app/views/raif/admin/llms/index.html.erb +110 -0
- data/app/views/raif/admin/model_completions/_model_completion.html.erb +3 -7
- data/app/views/raif/admin/model_completions/index.html.erb +14 -1
- data/app/views/raif/admin/model_completions/show.html.erb +164 -55
- data/app/views/raif/admin/model_tool_invocations/index.html.erb +1 -1
- data/app/views/raif/admin/model_tool_invocations/show.html.erb +18 -0
- data/app/views/raif/admin/prompt_studio/agents/index.html.erb +56 -0
- data/app/views/raif/admin/prompt_studio/agents/show.html.erb +57 -0
- data/app/views/raif/admin/prompt_studio/batch_runs/_batch_run_item.html.erb +54 -0
- data/app/views/raif/admin/prompt_studio/batch_runs/_judge_config_fields.html.erb +76 -0
- data/app/views/raif/admin/prompt_studio/batch_runs/_judge_detail_modal.html.erb +27 -0
- data/app/views/raif/admin/prompt_studio/batch_runs/_modal.html.erb +35 -0
- data/app/views/raif/admin/prompt_studio/batch_runs/_progress.html.erb +78 -0
- data/app/views/raif/admin/prompt_studio/batch_runs/show.html.erb +49 -0
- data/app/views/raif/admin/prompt_studio/conversations/index.html.erb +48 -0
- data/app/views/raif/admin/prompt_studio/conversations/show.html.erb +36 -0
- data/app/views/raif/admin/prompt_studio/shared/_nav_tabs.html.erb +17 -0
- data/app/views/raif/admin/prompt_studio/shared/_prompt_comparison.html.erb +87 -0
- data/app/views/raif/admin/prompt_studio/shared/_type_filter.html.erb +54 -0
- data/app/views/raif/admin/prompt_studio/tasks/_task_result.html.erb +145 -0
- data/app/views/raif/admin/prompt_studio/tasks/_task_row.html.erb +12 -0
- data/app/views/raif/admin/prompt_studio/tasks/_task_type_filter.html.erb +58 -0
- data/app/views/raif/admin/prompt_studio/tasks/_tasks_table.html.erb +22 -0
- data/app/views/raif/admin/prompt_studio/tasks/index.html.erb +35 -0
- data/app/views/raif/admin/prompt_studio/tasks/show.html.erb +19 -0
- data/app/views/raif/admin/tasks/_task.html.erb +1 -0
- data/app/views/raif/admin/tasks/index.html.erb +17 -5
- data/app/views/raif/admin/tasks/show.html.erb +20 -0
- data/app/views/raif/conversation_entries/_message.html.erb +10 -6
- data/config/importmap.rb +8 -0
- data/config/locales/admin.en.yml +128 -0
- data/config/locales/en.yml +36 -2
- data/config/routes.rb +8 -0
- data/db/migrate/20260307000000_add_prompt_studio_run_to_raif_tasks.rb +7 -0
- data/db/migrate/20260308000000_create_raif_prompt_studio_batch_runs.rb +27 -0
- data/db/migrate/20260308000001_create_raif_prompt_studio_batch_run_items.rb +24 -0
- data/db/migrate/20260407000000_add_cache_token_columns_to_raif_model_completions.rb +8 -0
- data/lib/generators/raif/agent/agent_generator.rb +18 -0
- data/lib/generators/raif/agent/templates/agent.rb.tt +7 -5
- data/lib/generators/raif/agent/templates/system_prompt.erb.tt +3 -0
- data/lib/generators/raif/conversation/conversation_generator.rb +19 -1
- data/lib/generators/raif/conversation/templates/system_prompt.erb.tt +4 -0
- data/lib/generators/raif/install/templates/initializer.rb +68 -27
- data/lib/generators/raif/task/task_generator.rb +18 -0
- data/lib/generators/raif/task/templates/prompt.erb.tt +4 -0
- data/lib/generators/raif/task/templates/task.rb.tt +9 -8
- data/lib/raif/configuration.rb +10 -0
- data/lib/raif/embedding_model_registry.rb +8 -0
- data/lib/raif/engine.rb +16 -1
- data/lib/raif/errors/blank_response_error.rb +8 -0
- data/lib/raif/errors/prompt_template_error.rb +15 -0
- data/lib/raif/errors.rb +2 -0
- data/lib/raif/evals.rb +0 -6
- data/lib/raif/llm_registry.rb +230 -9
- data/lib/raif/prompt_studio_comparison_builder.rb +138 -0
- data/lib/raif/token_estimator.rb +28 -0
- data/lib/raif/version.rb +1 -1
- data/lib/raif.rb +2 -0
- data/spec/support/rspec_helpers.rb +7 -1
- data/spec/support/test_task.rb +9 -0
- data/spec/support/test_template_task.rb +41 -0
- metadata +65 -7
- data/lib/raif/evals/llm_judge.rb +0 -32
- /data/{lib → app/models}/raif/evals/scoring_rubric.rb +0 -0
|
@@ -39,6 +39,8 @@ private
|
|
|
39
39
|
end
|
|
40
40
|
|
|
41
41
|
def update_model_completion(model_completion, response_json)
|
|
42
|
+
return if response_json.nil?
|
|
43
|
+
|
|
42
44
|
raw_response = if model_completion.response_format_json?
|
|
43
45
|
extract_json_response(response_json)
|
|
44
46
|
else
|
|
@@ -52,7 +54,8 @@ private
|
|
|
52
54
|
response_array: response_json["choices"],
|
|
53
55
|
completion_tokens: response_json.dig("usage", "completion_tokens"),
|
|
54
56
|
prompt_tokens: response_json.dig("usage", "prompt_tokens"),
|
|
55
|
-
total_tokens: response_json.dig("usage", "total_tokens")
|
|
57
|
+
total_tokens: response_json.dig("usage", "total_tokens"),
|
|
58
|
+
cache_read_input_tokens: response_json.dig("usage", "prompt_tokens_details", "cached_tokens")
|
|
56
59
|
)
|
|
57
60
|
end
|
|
58
61
|
|
|
@@ -87,9 +90,13 @@ private
|
|
|
87
90
|
|
|
88
91
|
params[:tools] = tools unless tools.blank?
|
|
89
92
|
|
|
90
|
-
if model_completion.tool_choice
|
|
93
|
+
if model_completion.tool_choice == "required"
|
|
94
|
+
params[:tool_choice] = build_required_tool_choice
|
|
95
|
+
params[:parallel_tool_calls] = false unless tools.blank?
|
|
96
|
+
elsif model_completion.tool_choice.present?
|
|
91
97
|
tool_klass = model_completion.tool_choice.constantize
|
|
92
98
|
params[:tool_choice] = build_forced_tool_choice(tool_klass.tool_name)
|
|
99
|
+
params[:parallel_tool_calls] = false unless tools.blank?
|
|
93
100
|
end
|
|
94
101
|
end
|
|
95
102
|
|
|
@@ -114,7 +121,7 @@ private
|
|
|
114
121
|
end
|
|
115
122
|
|
|
116
123
|
def extract_json_response(resp)
|
|
117
|
-
tool_calls = resp
|
|
124
|
+
tool_calls = resp&.dig("choices", 0, "message", "tool_calls")
|
|
118
125
|
return extract_text_response(resp) if tool_calls.blank?
|
|
119
126
|
|
|
120
127
|
tool_response = tool_calls.find do |tool_call|
|
|
@@ -4,39 +4,41 @@
|
|
|
4
4
|
#
|
|
5
5
|
# Table name: raif_model_completions
|
|
6
6
|
#
|
|
7
|
-
# id
|
|
8
|
-
# available_model_tools
|
|
9
|
-
#
|
|
10
|
-
#
|
|
11
|
-
#
|
|
12
|
-
#
|
|
13
|
-
#
|
|
14
|
-
#
|
|
15
|
-
#
|
|
16
|
-
#
|
|
17
|
-
#
|
|
18
|
-
#
|
|
19
|
-
#
|
|
20
|
-
#
|
|
21
|
-
#
|
|
22
|
-
#
|
|
23
|
-
#
|
|
24
|
-
#
|
|
25
|
-
#
|
|
26
|
-
#
|
|
27
|
-
#
|
|
28
|
-
#
|
|
29
|
-
#
|
|
30
|
-
#
|
|
31
|
-
#
|
|
32
|
-
#
|
|
33
|
-
#
|
|
34
|
-
#
|
|
35
|
-
#
|
|
36
|
-
#
|
|
37
|
-
#
|
|
38
|
-
#
|
|
39
|
-
#
|
|
7
|
+
# id :bigint not null, primary key
|
|
8
|
+
# available_model_tools :jsonb not null
|
|
9
|
+
# cache_creation_input_tokens :integer
|
|
10
|
+
# cache_read_input_tokens :integer
|
|
11
|
+
# citations :jsonb
|
|
12
|
+
# completed_at :datetime
|
|
13
|
+
# completion_tokens :integer
|
|
14
|
+
# failed_at :datetime
|
|
15
|
+
# failure_error :string
|
|
16
|
+
# failure_reason :text
|
|
17
|
+
# llm_model_key :string not null
|
|
18
|
+
# max_completion_tokens :integer
|
|
19
|
+
# messages :jsonb not null
|
|
20
|
+
# model_api_name :string not null
|
|
21
|
+
# output_token_cost :decimal(10, 6)
|
|
22
|
+
# prompt_token_cost :decimal(10, 6)
|
|
23
|
+
# prompt_tokens :integer
|
|
24
|
+
# raw_response :text
|
|
25
|
+
# response_array :jsonb
|
|
26
|
+
# response_format :integer default("text"), not null
|
|
27
|
+
# response_format_parameter :string
|
|
28
|
+
# response_tool_calls :jsonb
|
|
29
|
+
# retry_count :integer default(0), not null
|
|
30
|
+
# source_type :string
|
|
31
|
+
# started_at :datetime
|
|
32
|
+
# stream_response :boolean default(FALSE), not null
|
|
33
|
+
# system_prompt :text
|
|
34
|
+
# temperature :decimal(5, 3)
|
|
35
|
+
# tool_choice :string
|
|
36
|
+
# total_cost :decimal(10, 6)
|
|
37
|
+
# total_tokens :integer
|
|
38
|
+
# created_at :datetime not null
|
|
39
|
+
# updated_at :datetime not null
|
|
40
|
+
# response_id :string
|
|
41
|
+
# source_id :bigint
|
|
40
42
|
#
|
|
41
43
|
# Indexes
|
|
42
44
|
#
|
|
@@ -49,8 +51,12 @@
|
|
|
49
51
|
class Raif::ModelCompletion < Raif::ApplicationRecord
|
|
50
52
|
include Raif::Concerns::LlmResponseParsing
|
|
51
53
|
include Raif::Concerns::HasAvailableModelTools
|
|
54
|
+
include Raif::Concerns::HasRuntimeDuration
|
|
55
|
+
include Raif::Concerns::ProviderManagedToolCalls
|
|
52
56
|
include Raif::Concerns::BooleanTimestamp
|
|
53
57
|
|
|
58
|
+
attr_accessor :anthropic_prompt_caching_enabled, :bedrock_prompt_caching_enabled
|
|
59
|
+
|
|
54
60
|
boolean_timestamp :started_at
|
|
55
61
|
boolean_timestamp :completed_at
|
|
56
62
|
boolean_timestamp :failed_at
|
|
@@ -82,8 +88,12 @@ class Raif::ModelCompletion < Raif::ApplicationRecord
|
|
|
82
88
|
end
|
|
83
89
|
|
|
84
90
|
def calculate_costs
|
|
91
|
+
# Each retry resends the same prompt, so the provider charges input tokens
|
|
92
|
+
# for every attempt. Factor in retry_count to reflect actual billing.
|
|
93
|
+
total_attempts = (retry_count || 0) + 1
|
|
94
|
+
|
|
85
95
|
if prompt_tokens.present? && llm_config[:input_token_cost].present?
|
|
86
|
-
self.prompt_token_cost =
|
|
96
|
+
self.prompt_token_cost = calculate_prompt_token_cost(total_attempts)
|
|
87
97
|
end
|
|
88
98
|
|
|
89
99
|
if completion_tokens.present? && llm_config[:output_token_cost].present?
|
|
@@ -104,6 +114,37 @@ class Raif::ModelCompletion < Raif::ApplicationRecord
|
|
|
104
114
|
|
|
105
115
|
private
|
|
106
116
|
|
|
117
|
+
def calculate_prompt_token_cost(total_attempts)
|
|
118
|
+
input_cost = llm_config[:input_token_cost]
|
|
119
|
+
llm_class = llm_config[:llm_class]
|
|
120
|
+
cache_read_multiplier = llm_class&.cache_read_input_token_cost_multiplier
|
|
121
|
+
cache_creation_multiplier = llm_class&.cache_creation_input_token_cost_multiplier
|
|
122
|
+
cached_reads = cache_read_input_tokens.to_i
|
|
123
|
+
cached_writes = cache_creation_input_tokens.to_i
|
|
124
|
+
|
|
125
|
+
if cached_reads > 0 && cache_read_multiplier.present?
|
|
126
|
+
cache_read_cost = input_cost * cache_read_multiplier
|
|
127
|
+
|
|
128
|
+
if llm_class.prompt_tokens_include_cached_tokens?
|
|
129
|
+
# OpenAI / Google / OpenRouter: cached tokens are a subset of prompt_tokens
|
|
130
|
+
non_cached = prompt_tokens - cached_reads
|
|
131
|
+
cost = (non_cached * input_cost) + (cached_reads * cache_read_cost)
|
|
132
|
+
else
|
|
133
|
+
# Anthropic / Bedrock: cached tokens are separate from prompt_tokens
|
|
134
|
+
cost = (prompt_tokens * input_cost) + (cached_reads * cache_read_cost)
|
|
135
|
+
end
|
|
136
|
+
else
|
|
137
|
+
cost = prompt_tokens * input_cost
|
|
138
|
+
end
|
|
139
|
+
|
|
140
|
+
# Cache creation surcharge (Anthropic / Bedrock)
|
|
141
|
+
if cached_writes > 0 && cache_creation_multiplier.present?
|
|
142
|
+
cost += cached_writes * input_cost * cache_creation_multiplier
|
|
143
|
+
end
|
|
144
|
+
|
|
145
|
+
cost * total_attempts
|
|
146
|
+
end
|
|
147
|
+
|
|
107
148
|
def llm_config
|
|
108
149
|
@llm_config ||= Raif.llm_config(llm_model_key.to_sym)
|
|
109
150
|
end
|
|
@@ -53,9 +53,9 @@ class Raif::ModelTool
|
|
|
53
53
|
name.gsub("Raif::ModelTools::", "").underscore
|
|
54
54
|
end
|
|
55
55
|
|
|
56
|
-
def tool_arguments_schema(&block)
|
|
56
|
+
def tool_arguments_schema(dynamic: false, &block)
|
|
57
57
|
if block_given?
|
|
58
|
-
json_schema_definition(:tool_arguments, &block)
|
|
58
|
+
json_schema_definition(:tool_arguments, dynamic: dynamic, &block)
|
|
59
59
|
elsif schema_defined?(:tool_arguments)
|
|
60
60
|
schema_for(:tool_arguments)
|
|
61
61
|
else
|
|
@@ -77,11 +77,13 @@ class Raif::ModelTool
|
|
|
77
77
|
end
|
|
78
78
|
|
|
79
79
|
def invoke_tool(provider_tool_call_id:, tool_arguments:, source:)
|
|
80
|
+
prepared_arguments = prepare_tool_arguments(tool_arguments)
|
|
81
|
+
|
|
80
82
|
tool_invocation = Raif::ModelToolInvocation.new(
|
|
81
83
|
provider_tool_call_id: provider_tool_call_id,
|
|
82
84
|
source: source,
|
|
83
85
|
tool_type: name,
|
|
84
|
-
tool_arguments:
|
|
86
|
+
tool_arguments: prepared_arguments
|
|
85
87
|
)
|
|
86
88
|
|
|
87
89
|
ActiveRecord::Base.transaction do
|
|
@@ -95,6 +97,46 @@ class Raif::ModelTool
|
|
|
95
97
|
tool_invocation.failed!
|
|
96
98
|
raise e
|
|
97
99
|
end
|
|
100
|
+
|
|
101
|
+
# Prepares tool arguments before validation and invocation. Override in subclasses
|
|
102
|
+
# to add tool-specific argument processing (e.g. type coercion, default injection).
|
|
103
|
+
# The base implementation strips keys not declared in the tool's argument schema,
|
|
104
|
+
# which handles LLMs that hallucinate extra parameters.
|
|
105
|
+
#
|
|
106
|
+
# @param arguments [Hash] The raw tool arguments from the LLM response
|
|
107
|
+
# @return [Hash] The prepared arguments ready for validation and processing
|
|
108
|
+
def prepare_tool_arguments(arguments)
|
|
109
|
+
strip_unknown_tool_arguments(arguments)
|
|
110
|
+
end
|
|
111
|
+
|
|
112
|
+
private
|
|
113
|
+
|
|
114
|
+
# Removes keys from the arguments hash that are not declared in the tool's
|
|
115
|
+
# argument schema. Logs a warning when keys are stripped so hallucination
|
|
116
|
+
# patterns can be monitored. Normalizes all keys to strings for consistent
|
|
117
|
+
# comparison since the schema builder uses symbol keys and LLM responses
|
|
118
|
+
# use string keys.
|
|
119
|
+
#
|
|
120
|
+
# @param arguments [Hash] The raw tool arguments
|
|
121
|
+
# @return [Hash] The arguments with only schema-declared keys
|
|
122
|
+
def strip_unknown_tool_arguments(arguments)
|
|
123
|
+
return arguments unless arguments.is_a?(Hash)
|
|
124
|
+
|
|
125
|
+
schema_properties = tool_arguments_schema[:properties] || tool_arguments_schema["properties"]
|
|
126
|
+
return arguments if schema_properties.blank?
|
|
127
|
+
|
|
128
|
+
normalized_arguments = arguments.deep_stringify_keys
|
|
129
|
+
allowed_keys = schema_properties.keys.map(&:to_s)
|
|
130
|
+
dropped_keys = normalized_arguments.keys - allowed_keys
|
|
131
|
+
|
|
132
|
+
if dropped_keys.any?
|
|
133
|
+
Rails.logger.warn(
|
|
134
|
+
"[Raif::ModelTool] Stripped unexpected tool arguments for #{name}: #{dropped_keys.join(", ")}"
|
|
135
|
+
)
|
|
136
|
+
end
|
|
137
|
+
|
|
138
|
+
normalized_arguments.slice(*allowed_keys)
|
|
139
|
+
end
|
|
98
140
|
end
|
|
99
141
|
|
|
100
142
|
# Instance method to get the tool arguments schema
|
|
@@ -56,7 +56,7 @@ class Raif::ModelToolInvocation < Raif::ApplicationRecord
|
|
|
56
56
|
|
|
57
57
|
# Returns tool result in the format expected by LLM message formatting
|
|
58
58
|
# @return [Hash] Hash representation for JSONB storage and LLM APIs
|
|
59
|
-
def as_tool_call_result_message
|
|
59
|
+
def as_tool_call_result_message(result: self.result)
|
|
60
60
|
Raif::Messages::ToolCallResult.new(
|
|
61
61
|
provider_tool_call_id: provider_tool_call_id,
|
|
62
62
|
name: tool_name,
|
|
@@ -68,10 +68,40 @@ class Raif::ModelToolInvocation < Raif::ApplicationRecord
|
|
|
68
68
|
"raif/model_tool_invocations/#{tool.invocation_partial_name}"
|
|
69
69
|
end
|
|
70
70
|
|
|
71
|
+
def admin_observation
|
|
72
|
+
admin_observation_result[:observation]
|
|
73
|
+
end
|
|
74
|
+
|
|
75
|
+
def admin_observation_error
|
|
76
|
+
admin_observation_result[:error]
|
|
77
|
+
end
|
|
78
|
+
|
|
79
|
+
def admin_observation_available?
|
|
80
|
+
admin_observation.present? || admin_observation_error.present?
|
|
81
|
+
end
|
|
82
|
+
|
|
71
83
|
def ensure_valid_tool_argument_schema
|
|
72
84
|
unless JSON::Validator.validate(tool_arguments_schema, tool_arguments)
|
|
73
85
|
errors.add(:tool_arguments, "does not match schema")
|
|
74
86
|
end
|
|
75
87
|
end
|
|
76
88
|
|
|
89
|
+
private
|
|
90
|
+
|
|
91
|
+
# Best-effort reconstruction of the observation shown in admin. This uses the
|
|
92
|
+
# current formatter code against persisted invocation data, so failures are
|
|
93
|
+
# captured for display instead of breaking the page render.
|
|
94
|
+
def admin_observation_result
|
|
95
|
+
@admin_observation_result ||= if completed? && triggers_observation_to_model?
|
|
96
|
+
begin
|
|
97
|
+
observation = tool.observation_for_invocation(self)
|
|
98
|
+
{ observation: observation.presence, error: nil }
|
|
99
|
+
rescue StandardError => e
|
|
100
|
+
{ observation: nil, error: e.message }
|
|
101
|
+
end
|
|
102
|
+
else
|
|
103
|
+
{ observation: nil, error: nil }
|
|
104
|
+
end
|
|
105
|
+
end
|
|
106
|
+
|
|
77
107
|
end
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
# == Schema Information
|
|
4
|
+
#
|
|
5
|
+
# Table name: raif_prompt_studio_batch_runs
|
|
6
|
+
#
|
|
7
|
+
# id :bigint not null, primary key
|
|
8
|
+
# completed_at :datetime
|
|
9
|
+
# completed_count :integer default(0), not null
|
|
10
|
+
# failed_at :datetime
|
|
11
|
+
# failed_count :integer default(0), not null
|
|
12
|
+
# judge_config :jsonb not null
|
|
13
|
+
# judge_llm_model_key :string
|
|
14
|
+
# judge_type :string
|
|
15
|
+
# llm_model_key :string not null
|
|
16
|
+
# started_at :datetime
|
|
17
|
+
# task_type :string not null
|
|
18
|
+
# total_count :integer default(0), not null
|
|
19
|
+
# created_at :datetime not null
|
|
20
|
+
# updated_at :datetime not null
|
|
21
|
+
#
|
|
22
|
+
|
|
23
|
+
module Raif
|
|
24
|
+
class PromptStudioBatchRun < Raif::ApplicationRecord
|
|
25
|
+
ALLOWED_JUDGE_TYPES = [
|
|
26
|
+
"Raif::Evals::LlmJudges::Binary",
|
|
27
|
+
"Raif::Evals::LlmJudges::Scored",
|
|
28
|
+
"Raif::Evals::LlmJudges::Comparative",
|
|
29
|
+
"Raif::Evals::LlmJudges::Summarization"
|
|
30
|
+
].freeze
|
|
31
|
+
|
|
32
|
+
after_initialize -> { self.judge_config ||= {} }
|
|
33
|
+
|
|
34
|
+
has_many :items,
|
|
35
|
+
class_name: "Raif::PromptStudioBatchRunItem",
|
|
36
|
+
foreign_key: :batch_run_id,
|
|
37
|
+
dependent: :destroy,
|
|
38
|
+
inverse_of: :batch_run
|
|
39
|
+
|
|
40
|
+
boolean_timestamp :started_at
|
|
41
|
+
boolean_timestamp :completed_at
|
|
42
|
+
boolean_timestamp :failed_at
|
|
43
|
+
|
|
44
|
+
validates :task_type, presence: true
|
|
45
|
+
validates :llm_model_key, presence: true
|
|
46
|
+
validates :judge_type, inclusion: { in: ALLOWED_JUDGE_TYPES }, allow_nil: true
|
|
47
|
+
|
|
48
|
+
def status
|
|
49
|
+
if completed_at?
|
|
50
|
+
:completed
|
|
51
|
+
elsif failed_at?
|
|
52
|
+
:failed
|
|
53
|
+
elsif started_at?
|
|
54
|
+
:in_progress
|
|
55
|
+
else
|
|
56
|
+
:pending
|
|
57
|
+
end
|
|
58
|
+
end
|
|
59
|
+
|
|
60
|
+
def progress_percentage
|
|
61
|
+
return 0 if total_count.zero?
|
|
62
|
+
|
|
63
|
+
((completed_count + failed_count).to_f / total_count * 100).round
|
|
64
|
+
end
|
|
65
|
+
|
|
66
|
+
def has_judge?
|
|
67
|
+
judge_type.present?
|
|
68
|
+
end
|
|
69
|
+
|
|
70
|
+
def judge_class
|
|
71
|
+
judge_type&.safe_constantize
|
|
72
|
+
end
|
|
73
|
+
|
|
74
|
+
def judge_pass_rate
|
|
75
|
+
judge_tasks = completed_judge_tasks
|
|
76
|
+
return if judge_tasks.empty?
|
|
77
|
+
|
|
78
|
+
pass_count = judge_tasks.count(&:passes?)
|
|
79
|
+
percentage = ((pass_count.to_f / judge_tasks.size) * 100).round
|
|
80
|
+
"#{percentage}% (#{pass_count}/#{judge_tasks.size})"
|
|
81
|
+
end
|
|
82
|
+
|
|
83
|
+
def judge_average_score
|
|
84
|
+
scores = completed_judge_tasks.filter_map(&:judgment_score)
|
|
85
|
+
return if scores.empty?
|
|
86
|
+
|
|
87
|
+
(scores.sum.to_f / scores.size).round(1)
|
|
88
|
+
end
|
|
89
|
+
|
|
90
|
+
def judge_comparative_summary
|
|
91
|
+
completed_items = items.where.not(judge_task_id: nil).includes(:judge_task)
|
|
92
|
+
return if completed_items.empty?
|
|
93
|
+
|
|
94
|
+
new_wins = 0
|
|
95
|
+
original_wins = 0
|
|
96
|
+
ties = 0
|
|
97
|
+
|
|
98
|
+
completed_items.each do |item|
|
|
99
|
+
next unless item.judge_task&.completed?
|
|
100
|
+
|
|
101
|
+
parsed = item.judge_task.parsed_response
|
|
102
|
+
next unless parsed.is_a?(Hash)
|
|
103
|
+
|
|
104
|
+
winner = parsed["winner"]
|
|
105
|
+
if winner == "tie"
|
|
106
|
+
ties += 1
|
|
107
|
+
elsif winner == item.metadata&.dig("new_response_letter")
|
|
108
|
+
new_wins += 1
|
|
109
|
+
else
|
|
110
|
+
original_wins += 1
|
|
111
|
+
end
|
|
112
|
+
end
|
|
113
|
+
|
|
114
|
+
total = new_wins + original_wins + ties
|
|
115
|
+
return if total.zero?
|
|
116
|
+
|
|
117
|
+
{
|
|
118
|
+
new_wins: new_wins,
|
|
119
|
+
original_wins: original_wins,
|
|
120
|
+
ties: ties,
|
|
121
|
+
total: total,
|
|
122
|
+
new_win_pct: ((new_wins.to_f / total) * 100).round,
|
|
123
|
+
original_win_pct: ((original_wins.to_f / total) * 100).round,
|
|
124
|
+
tie_pct: ((ties.to_f / total) * 100).round
|
|
125
|
+
}
|
|
126
|
+
end
|
|
127
|
+
|
|
128
|
+
private
|
|
129
|
+
|
|
130
|
+
def completed_judge_tasks
|
|
131
|
+
Raif::Task.where(
|
|
132
|
+
id: items.where.not(judge_task_id: nil).select(:judge_task_id)
|
|
133
|
+
).where.not(completed_at: nil)
|
|
134
|
+
end
|
|
135
|
+
|
|
136
|
+
public
|
|
137
|
+
|
|
138
|
+
def check_completion!
|
|
139
|
+
reload
|
|
140
|
+
remaining = items.where(status: %w[pending running judging]).count
|
|
141
|
+
self.completed_count = items.where(status: "completed").count
|
|
142
|
+
self.failed_count = items.where(status: "failed").count
|
|
143
|
+
|
|
144
|
+
if remaining.zero?
|
|
145
|
+
if failed_count > 0 && completed_count == 0
|
|
146
|
+
self.failed_at = Time.current
|
|
147
|
+
else
|
|
148
|
+
self.completed_at = Time.current
|
|
149
|
+
end
|
|
150
|
+
end
|
|
151
|
+
|
|
152
|
+
save!
|
|
153
|
+
end
|
|
154
|
+
end
|
|
155
|
+
end
|
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
# == Schema Information
|
|
4
|
+
#
|
|
5
|
+
# Table name: raif_prompt_studio_batch_run_items
|
|
6
|
+
#
|
|
7
|
+
# id :bigint not null, primary key
|
|
8
|
+
# metadata :jsonb
|
|
9
|
+
# status :string default("pending"), not null
|
|
10
|
+
# created_at :datetime not null
|
|
11
|
+
# updated_at :datetime not null
|
|
12
|
+
# batch_run_id :bigint not null
|
|
13
|
+
# judge_task_id :bigint
|
|
14
|
+
# result_task_id :bigint
|
|
15
|
+
# source_task_id :bigint not null
|
|
16
|
+
#
|
|
17
|
+
# Indexes
|
|
18
|
+
#
|
|
19
|
+
# index_raif_prompt_studio_batch_run_items_on_batch_run_id (batch_run_id)
|
|
20
|
+
# index_raif_prompt_studio_batch_run_items_on_judge_task_id (judge_task_id)
|
|
21
|
+
# index_raif_prompt_studio_batch_run_items_on_result_task_id (result_task_id)
|
|
22
|
+
# index_raif_prompt_studio_batch_run_items_on_source_task_id (source_task_id)
|
|
23
|
+
# index_raif_prompt_studio_batch_run_items_on_status (status)
|
|
24
|
+
#
|
|
25
|
+
# Foreign Keys
|
|
26
|
+
#
|
|
27
|
+
# fk_rails_... (batch_run_id => raif_prompt_studio_batch_runs.id)
|
|
28
|
+
# fk_rails_... (judge_task_id => raif_tasks.id)
|
|
29
|
+
# fk_rails_... (result_task_id => raif_tasks.id)
|
|
30
|
+
# fk_rails_... (source_task_id => raif_tasks.id)
|
|
31
|
+
#
|
|
32
|
+
|
|
33
|
+
module Raif
|
|
34
|
+
class PromptStudioBatchRunItem < Raif::ApplicationRecord
|
|
35
|
+
include ActionView::RecordIdentifier
|
|
36
|
+
|
|
37
|
+
STATUSES = %w[pending running judging completed failed].freeze
|
|
38
|
+
|
|
39
|
+
after_initialize -> { self.metadata ||= {} }
|
|
40
|
+
|
|
41
|
+
belongs_to :batch_run,
|
|
42
|
+
class_name: "Raif::PromptStudioBatchRun",
|
|
43
|
+
inverse_of: :items
|
|
44
|
+
|
|
45
|
+
belongs_to :source_task,
|
|
46
|
+
class_name: "Raif::Task"
|
|
47
|
+
|
|
48
|
+
belongs_to :result_task,
|
|
49
|
+
class_name: "Raif::Task",
|
|
50
|
+
optional: true
|
|
51
|
+
|
|
52
|
+
belongs_to :judge_task,
|
|
53
|
+
class_name: "Raif::Task",
|
|
54
|
+
optional: true
|
|
55
|
+
|
|
56
|
+
validates :status, inclusion: { in: STATUSES }
|
|
57
|
+
|
|
58
|
+
def execute!
|
|
59
|
+
update!(status: "running")
|
|
60
|
+
broadcast_item
|
|
61
|
+
|
|
62
|
+
new_task = create_and_run_task
|
|
63
|
+
run_judge_if_configured(new_task)
|
|
64
|
+
|
|
65
|
+
update!(status: "completed")
|
|
66
|
+
rescue StandardError => e
|
|
67
|
+
Rails.logger.error "Error running batch run item ##{id}: #{e.message}"
|
|
68
|
+
Rails.logger.error e.backtrace&.join("\n")
|
|
69
|
+
|
|
70
|
+
update!(status: "failed")
|
|
71
|
+
ensure
|
|
72
|
+
broadcast_item
|
|
73
|
+
batch_run.check_completion!
|
|
74
|
+
broadcast_progress
|
|
75
|
+
end
|
|
76
|
+
|
|
77
|
+
def judge_summary
|
|
78
|
+
return unless judge_task&.completed?
|
|
79
|
+
|
|
80
|
+
parsed = judge_task.parsed_response
|
|
81
|
+
return unless parsed.is_a?(Hash)
|
|
82
|
+
|
|
83
|
+
case batch_run.judge_type
|
|
84
|
+
when "Raif::Evals::LlmJudges::Binary"
|
|
85
|
+
parsed["passes"] ? "PASS" : "FAIL"
|
|
86
|
+
when "Raif::Evals::LlmJudges::Scored"
|
|
87
|
+
"Score: #{parsed["score"]}"
|
|
88
|
+
when "Raif::Evals::LlmJudges::Comparative"
|
|
89
|
+
if parsed["winner"] == "tie"
|
|
90
|
+
I18n.t("raif.admin.prompt_studio.batch_runs.judge.tie")
|
|
91
|
+
else
|
|
92
|
+
winner_label = comparative_winner_label(parsed["winner"])
|
|
93
|
+
I18n.t("raif.admin.prompt_studio.batch_runs.judge.winner", name: winner_label)
|
|
94
|
+
end
|
|
95
|
+
when "Raif::Evals::LlmJudges::Summarization"
|
|
96
|
+
"Overall: #{parsed.dig("overall", "score")}/5"
|
|
97
|
+
end
|
|
98
|
+
end
|
|
99
|
+
|
|
100
|
+
def judge_reasoning
|
|
101
|
+
return unless judge_task&.completed?
|
|
102
|
+
|
|
103
|
+
parsed = judge_task.parsed_response
|
|
104
|
+
return unless parsed.is_a?(Hash)
|
|
105
|
+
|
|
106
|
+
parsed["reasoning"]
|
|
107
|
+
end
|
|
108
|
+
|
|
109
|
+
def comparative_winner_label(winner_letter)
|
|
110
|
+
new_response_letter = metadata&.dig("new_response_letter")
|
|
111
|
+
return winner_letter unless new_response_letter
|
|
112
|
+
|
|
113
|
+
if winner_letter == new_response_letter
|
|
114
|
+
I18n.t("raif.admin.prompt_studio.batch_runs.judge.new_response")
|
|
115
|
+
else
|
|
116
|
+
I18n.t("raif.admin.prompt_studio.batch_runs.judge.original_response")
|
|
117
|
+
end
|
|
118
|
+
end
|
|
119
|
+
|
|
120
|
+
private
|
|
121
|
+
|
|
122
|
+
def create_and_run_task
|
|
123
|
+
new_task = source_task.class.new(
|
|
124
|
+
creator: source_task.creator,
|
|
125
|
+
source: source_task,
|
|
126
|
+
llm_model_key: batch_run.llm_model_key,
|
|
127
|
+
available_model_tools: source_task.available_model_tools,
|
|
128
|
+
run_with: source_task.run_with,
|
|
129
|
+
prompt_studio_run: true,
|
|
130
|
+
started_at: Time.current
|
|
131
|
+
)
|
|
132
|
+
new_task.assign_attributes(source_task.prompt_studio_task_attributes)
|
|
133
|
+
new_task.save!
|
|
134
|
+
|
|
135
|
+
update!(result_task_id: new_task.id)
|
|
136
|
+
new_task.run
|
|
137
|
+
new_task
|
|
138
|
+
end
|
|
139
|
+
|
|
140
|
+
def run_judge_if_configured(new_task)
|
|
141
|
+
return unless batch_run.has_judge? && new_task.completed?
|
|
142
|
+
|
|
143
|
+
update!(status: "judging")
|
|
144
|
+
broadcast_item
|
|
145
|
+
|
|
146
|
+
judge_result = invoke_judge(new_task)
|
|
147
|
+
update!(judge_task_id: judge_result.id)
|
|
148
|
+
end
|
|
149
|
+
|
|
150
|
+
def invoke_judge(new_task)
|
|
151
|
+
judge_class = batch_run.judge_class
|
|
152
|
+
config = batch_run.judge_config
|
|
153
|
+
judge_args = {
|
|
154
|
+
creator: source_task.creator,
|
|
155
|
+
prompt_studio_run: true,
|
|
156
|
+
llm_model_key: batch_run.judge_llm_model_key
|
|
157
|
+
}
|
|
158
|
+
judge_args.merge!(source_task.prompt_studio_task_attributes)
|
|
159
|
+
|
|
160
|
+
if config["include_original_prompt_as_context"]
|
|
161
|
+
judge_args[:additional_context] =
|
|
162
|
+
"The content being evaluated was generated in response to the following prompt:\n\n#{source_task.prompt}"
|
|
163
|
+
end
|
|
164
|
+
|
|
165
|
+
case batch_run.judge_type
|
|
166
|
+
when "Raif::Evals::LlmJudges::Binary"
|
|
167
|
+
judge_class.run(
|
|
168
|
+
content_to_judge: new_task.raw_response,
|
|
169
|
+
criteria: config["criteria"],
|
|
170
|
+
strict_mode: config["strict_mode"],
|
|
171
|
+
**judge_args
|
|
172
|
+
)
|
|
173
|
+
when "Raif::Evals::LlmJudges::Scored"
|
|
174
|
+
rubric = Raif::Evals::ScoringRubric.send(config["scoring_rubric"])
|
|
175
|
+
judge_class.run(
|
|
176
|
+
content_to_judge: new_task.raw_response,
|
|
177
|
+
scoring_rubric: rubric,
|
|
178
|
+
**judge_args
|
|
179
|
+
)
|
|
180
|
+
when "Raif::Evals::LlmJudges::Comparative"
|
|
181
|
+
result = judge_class.run(
|
|
182
|
+
content_to_judge: new_task.raw_response,
|
|
183
|
+
over_content: source_task.raw_response,
|
|
184
|
+
comparison_criteria: config["comparison_criteria"],
|
|
185
|
+
**judge_args
|
|
186
|
+
)
|
|
187
|
+
# Store which letter was assigned to the new response so we can display
|
|
188
|
+
# "Winner: New Response" / "Winner: Original Response" instead of "A"/"B"
|
|
189
|
+
update!(metadata: metadata.merge("new_response_letter" => result.expected_winner))
|
|
190
|
+
result
|
|
191
|
+
when "Raif::Evals::LlmJudges::Summarization"
|
|
192
|
+
judge_class.run(
|
|
193
|
+
original_content: source_task.prompt,
|
|
194
|
+
summary: new_task.raw_response,
|
|
195
|
+
**judge_args
|
|
196
|
+
)
|
|
197
|
+
end
|
|
198
|
+
end
|
|
199
|
+
|
|
200
|
+
def broadcast_item
|
|
201
|
+
Turbo::StreamsChannel.broadcast_replace_to(
|
|
202
|
+
batch_run,
|
|
203
|
+
target: dom_id(self),
|
|
204
|
+
partial: "raif/admin/prompt_studio/batch_runs/batch_run_item",
|
|
205
|
+
locals: { item: self }
|
|
206
|
+
)
|
|
207
|
+
end
|
|
208
|
+
|
|
209
|
+
def broadcast_progress
|
|
210
|
+
batch_run.reload
|
|
211
|
+
Turbo::StreamsChannel.broadcast_replace_to(
|
|
212
|
+
batch_run,
|
|
213
|
+
target: dom_id(batch_run, :progress),
|
|
214
|
+
partial: "raif/admin/prompt_studio/batch_runs/progress",
|
|
215
|
+
locals: { batch_run: batch_run }
|
|
216
|
+
)
|
|
217
|
+
end
|
|
218
|
+
|
|
219
|
+
end
|
|
220
|
+
end
|