ruby_llm-contract 0.2.3 → 0.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +64 -0
- data/Gemfile.lock +2 -2
- data/README.md +27 -2
- data/lib/ruby_llm/contract/adapters/response.rb +4 -2
- data/lib/ruby_llm/contract/adapters/ruby_llm.rb +3 -3
- data/lib/ruby_llm/contract/adapters/test.rb +3 -2
- data/lib/ruby_llm/contract/concerns/deep_freeze.rb +23 -0
- data/lib/ruby_llm/contract/concerns/eval_host.rb +11 -2
- data/lib/ruby_llm/contract/contract/schema_validator.rb +70 -3
- data/lib/ruby_llm/contract/eval/baseline_diff.rb +92 -0
- data/lib/ruby_llm/contract/eval/dataset.rb +11 -4
- data/lib/ruby_llm/contract/eval/eval_definition.rb +36 -14
- data/lib/ruby_llm/contract/eval/model_comparison.rb +1 -1
- data/lib/ruby_llm/contract/eval/report.rb +71 -2
- data/lib/ruby_llm/contract/eval/runner.rb +5 -3
- data/lib/ruby_llm/contract/eval/trait_evaluator.rb +6 -0
- data/lib/ruby_llm/contract/eval.rb +1 -0
- data/lib/ruby_llm/contract/pipeline/base.rb +1 -1
- data/lib/ruby_llm/contract/pipeline/result.rb +1 -1
- data/lib/ruby_llm/contract/pipeline/runner.rb +1 -1
- data/lib/ruby_llm/contract/pipeline/trace.rb +3 -2
- data/lib/ruby_llm/contract/prompt/builder.rb +2 -1
- data/lib/ruby_llm/contract/prompt/node.rb +2 -2
- data/lib/ruby_llm/contract/prompt/nodes/example_node.rb +2 -2
- data/lib/ruby_llm/contract/rake_task.rb +31 -4
- data/lib/ruby_llm/contract/rspec/helpers.rb +28 -8
- data/lib/ruby_llm/contract/rspec/pass_eval.rb +23 -2
- data/lib/ruby_llm/contract/step/base.rb +10 -5
- data/lib/ruby_llm/contract/step/dsl.rb +1 -1
- data/lib/ruby_llm/contract/step/limit_checker.rb +1 -1
- data/lib/ruby_llm/contract/step/retry_executor.rb +3 -2
- data/lib/ruby_llm/contract/step/retry_policy.rb +7 -1
- data/lib/ruby_llm/contract/step/runner.rb +10 -2
- data/lib/ruby_llm/contract/step/trace.rb +5 -4
- data/lib/ruby_llm/contract/version.rb +1 -1
- data/lib/ruby_llm/contract.rb +36 -17
- metadata +3 -1
|
@@ -1,14 +1,18 @@
|
|
|
1
1
|
# frozen_string_literal: true
|
|
2
2
|
|
|
3
|
+
require "json"
|
|
4
|
+
require "fileutils"
|
|
5
|
+
|
|
3
6
|
module RubyLLM
|
|
4
7
|
module Contract
|
|
5
8
|
module Eval
|
|
6
9
|
class Report
|
|
7
10
|
attr_reader :dataset_name, :results
|
|
8
11
|
|
|
9
|
-
def initialize(dataset_name:, results:)
|
|
12
|
+
def initialize(dataset_name:, results:, step_name: nil)
|
|
10
13
|
@dataset_name = dataset_name
|
|
11
|
-
@
|
|
14
|
+
@step_name = step_name
|
|
15
|
+
@results = results.dup.freeze
|
|
12
16
|
freeze
|
|
13
17
|
end
|
|
14
18
|
|
|
@@ -78,6 +82,29 @@ module RubyLLM
|
|
|
78
82
|
lines.join("\n")
|
|
79
83
|
end
|
|
80
84
|
|
|
85
|
+
def save_baseline!(path: nil, model: nil)
|
|
86
|
+
file = path || default_baseline_path(model: model)
|
|
87
|
+
FileUtils.mkdir_p(File.dirname(file))
|
|
88
|
+
File.write(file, JSON.pretty_generate(serialize_for_baseline))
|
|
89
|
+
file
|
|
90
|
+
end
|
|
91
|
+
|
|
92
|
+
def compare_with_baseline(path: nil, model: nil)
|
|
93
|
+
file = path || default_baseline_path(model: model)
|
|
94
|
+
raise ArgumentError, "No baseline found at #{file}" unless File.exist?(file)
|
|
95
|
+
|
|
96
|
+
baseline_data = JSON.parse(File.read(file), symbolize_names: true)
|
|
97
|
+
validate_baseline!(baseline_data)
|
|
98
|
+
BaselineDiff.new(
|
|
99
|
+
baseline_cases: baseline_data[:cases],
|
|
100
|
+
current_cases: results.map { |r| serialize_case(r) }
|
|
101
|
+
)
|
|
102
|
+
end
|
|
103
|
+
|
|
104
|
+
def baseline_exists?(path: nil, model: nil)
|
|
105
|
+
File.exist?(path || default_baseline_path(model: model))
|
|
106
|
+
end
|
|
107
|
+
|
|
81
108
|
def print_summary(io = $stdout)
|
|
82
109
|
io.puts summary
|
|
83
110
|
io.puts
|
|
@@ -106,6 +133,48 @@ module RubyLLM
|
|
|
106
133
|
results.reject { |r| r.step_status == :skipped }
|
|
107
134
|
end
|
|
108
135
|
|
|
136
|
+
def default_baseline_path(model: nil)
|
|
137
|
+
parts = [".eval_baselines"]
|
|
138
|
+
parts << sanitize_name(@step_name) if @step_name
|
|
139
|
+
name = sanitize_name(dataset_name)
|
|
140
|
+
name = "#{name}_#{sanitize_name(model)}" if model
|
|
141
|
+
parts << "#{name}.json"
|
|
142
|
+
File.join(*parts)
|
|
143
|
+
end
|
|
144
|
+
|
|
145
|
+
def validate_baseline!(data)
|
|
146
|
+
if data[:dataset_name] && data[:dataset_name] != dataset_name
|
|
147
|
+
raise ArgumentError, "Baseline eval '#{data[:dataset_name]}' does not match '#{dataset_name}'"
|
|
148
|
+
end
|
|
149
|
+
if data[:step_name] && @step_name && data[:step_name] != @step_name
|
|
150
|
+
raise ArgumentError, "Baseline step '#{data[:step_name]}' does not match '#{@step_name}'"
|
|
151
|
+
end
|
|
152
|
+
end
|
|
153
|
+
|
|
154
|
+
def sanitize_name(name)
|
|
155
|
+
name.to_s.gsub(/[^a-zA-Z0-9_-]/, "_")
|
|
156
|
+
end
|
|
157
|
+
|
|
158
|
+
def serialize_for_baseline
|
|
159
|
+
{
|
|
160
|
+
dataset_name: dataset_name,
|
|
161
|
+
step_name: @step_name,
|
|
162
|
+
score: score,
|
|
163
|
+
total_cost: total_cost,
|
|
164
|
+
cases: evaluated_results.map { |r| serialize_case(r) }
|
|
165
|
+
}
|
|
166
|
+
end
|
|
167
|
+
|
|
168
|
+
def serialize_case(result)
|
|
169
|
+
{
|
|
170
|
+
name: result.name,
|
|
171
|
+
passed: result.passed?,
|
|
172
|
+
score: result.score,
|
|
173
|
+
details: result.details,
|
|
174
|
+
cost: result.cost
|
|
175
|
+
}
|
|
176
|
+
end
|
|
177
|
+
|
|
109
178
|
def format_cost(cost)
|
|
110
179
|
"$#{format("%.6f", cost)}"
|
|
111
180
|
end
|
|
@@ -19,7 +19,8 @@ module RubyLLM
|
|
|
19
19
|
|
|
20
20
|
def run
|
|
21
21
|
results = @dataset.cases.map { |test_case| evaluate_case(test_case) }
|
|
22
|
-
|
|
22
|
+
step_name = @step.respond_to?(:name) ? @step.name : @step.to_s
|
|
23
|
+
Report.new(dataset_name: @dataset.name, results: results, step_name: step_name)
|
|
23
24
|
end
|
|
24
25
|
|
|
25
26
|
private
|
|
@@ -31,7 +32,8 @@ module RubyLLM
|
|
|
31
32
|
|
|
32
33
|
build_case_result(test_case, step_result, eval_result)
|
|
33
34
|
rescue RubyLLM::Contract::Error => e
|
|
34
|
-
|
|
35
|
+
raise unless e.message.include?("No adapter configured")
|
|
36
|
+
|
|
35
37
|
skipped_result(test_case, e.message)
|
|
36
38
|
end
|
|
37
39
|
|
|
@@ -81,7 +83,7 @@ module RubyLLM
|
|
|
81
83
|
evaluate_with_custom(step_result, test_case)
|
|
82
84
|
elsif test_case.expected_traits
|
|
83
85
|
evaluate_traits(step_result, test_case)
|
|
84
|
-
elsif test_case.expected
|
|
86
|
+
elsif !test_case.expected.nil?
|
|
85
87
|
evaluate_expected(step_result, test_case)
|
|
86
88
|
else
|
|
87
89
|
evaluate_contract_only
|
|
@@ -26,6 +26,8 @@ module RubyLLM
|
|
|
26
26
|
|
|
27
27
|
def trait_error(key, value, expectation)
|
|
28
28
|
case expectation
|
|
29
|
+
when ::Proc
|
|
30
|
+
trait_proc_error(key, value, expectation)
|
|
29
31
|
when ::Regexp
|
|
30
32
|
trait_regexp_error(key, value, expectation)
|
|
31
33
|
when Range
|
|
@@ -56,6 +58,10 @@ module RubyLLM
|
|
|
56
58
|
"#{key}: expected falsy, got #{value.inspect}" if value
|
|
57
59
|
end
|
|
58
60
|
|
|
61
|
+
def trait_proc_error(key, value, expectation)
|
|
62
|
+
"#{key}: trait check failed, got #{value.inspect}" unless expectation.call(value)
|
|
63
|
+
end
|
|
64
|
+
|
|
59
65
|
def trait_equality_error(key, value, expectation)
|
|
60
66
|
"#{key}: expected #{expectation.inspect}, got #{value.inspect}" unless value == expectation
|
|
61
67
|
end
|
|
@@ -5,14 +5,15 @@ module RubyLLM
|
|
|
5
5
|
module Pipeline
|
|
6
6
|
class Trace
|
|
7
7
|
include Concerns::TraceEquality
|
|
8
|
+
include Concerns::DeepFreeze
|
|
8
9
|
|
|
9
10
|
attr_reader :trace_id, :total_latency_ms, :total_usage, :step_traces, :total_cost
|
|
10
11
|
|
|
11
12
|
def initialize(trace_id: nil, total_latency_ms: nil, total_usage: nil, step_traces: nil)
|
|
12
13
|
@trace_id = trace_id
|
|
13
14
|
@total_latency_ms = total_latency_ms
|
|
14
|
-
@total_usage = total_usage
|
|
15
|
-
@step_traces = step_traces
|
|
15
|
+
@total_usage = deep_dup_freeze(total_usage)
|
|
16
|
+
@step_traces = step_traces&.dup&.freeze
|
|
16
17
|
@total_cost = calculate_total_cost
|
|
17
18
|
freeze
|
|
18
19
|
end
|
|
@@ -8,8 +8,8 @@ module RubyLLM
|
|
|
8
8
|
attr_reader :input, :output
|
|
9
9
|
|
|
10
10
|
def initialize(input:, output:)
|
|
11
|
-
@input = input.freeze
|
|
12
|
-
@output = output.freeze
|
|
11
|
+
@input = input.frozen? ? input : input.dup.freeze
|
|
12
|
+
@output = output.frozen? ? output : output.dup.freeze
|
|
13
13
|
super(type: :example, content: nil)
|
|
14
14
|
end
|
|
15
15
|
|
|
@@ -6,7 +6,8 @@ require "rake/tasklib"
|
|
|
6
6
|
module RubyLLM
|
|
7
7
|
module Contract
|
|
8
8
|
class RakeTask < ::Rake::TaskLib
|
|
9
|
-
attr_accessor :name, :context, :fail_on_empty, :minimum_score, :maximum_cost,
|
|
9
|
+
attr_accessor :name, :context, :fail_on_empty, :minimum_score, :maximum_cost,
|
|
10
|
+
:eval_dirs, :save_baseline, :fail_on_regression
|
|
10
11
|
|
|
11
12
|
def initialize(name = :"ruby_llm_contract:eval", &block)
|
|
12
13
|
super()
|
|
@@ -16,6 +17,8 @@ module RubyLLM
|
|
|
16
17
|
@minimum_score = nil # nil = require 100%; float = threshold
|
|
17
18
|
@maximum_cost = nil # nil = no cost limit; float = budget cap (suite-level)
|
|
18
19
|
@eval_dirs = [] # directories to load eval files from (non-Rails)
|
|
20
|
+
@save_baseline = false
|
|
21
|
+
@fail_on_regression = false
|
|
19
22
|
block&.call(self)
|
|
20
23
|
define_task
|
|
21
24
|
end
|
|
@@ -26,8 +29,7 @@ module RubyLLM
|
|
|
26
29
|
desc "Run all ruby_llm-contract evals"
|
|
27
30
|
task(@name => task_prerequisites) do
|
|
28
31
|
require "ruby_llm/contract"
|
|
29
|
-
|
|
30
|
-
RubyLLM::Contract.load_evals!
|
|
32
|
+
RubyLLM::Contract.load_evals!(*@eval_dirs)
|
|
31
33
|
|
|
32
34
|
results = RubyLLM::Contract.run_all_evals(context: @context)
|
|
33
35
|
|
|
@@ -43,12 +45,16 @@ module RubyLLM
|
|
|
43
45
|
gate_passed = true
|
|
44
46
|
suite_cost = 0.0
|
|
45
47
|
|
|
48
|
+
passed_reports = []
|
|
49
|
+
|
|
46
50
|
results.each do |host, reports|
|
|
47
51
|
puts "\n#{host.name || host.to_s}"
|
|
48
52
|
reports.each_value do |report|
|
|
49
53
|
report.print_summary
|
|
50
54
|
suite_cost += report.total_cost
|
|
51
|
-
|
|
55
|
+
report_ok = report_meets_score?(report) && !check_regression(report)
|
|
56
|
+
gate_passed = false unless report_ok
|
|
57
|
+
passed_reports << report if report_ok
|
|
52
58
|
end
|
|
53
59
|
end
|
|
54
60
|
|
|
@@ -58,6 +64,9 @@ module RubyLLM
|
|
|
58
64
|
end
|
|
59
65
|
|
|
60
66
|
abort "\nEval suite FAILED" unless gate_passed
|
|
67
|
+
|
|
68
|
+
# Save baselines only after ALL gates pass
|
|
69
|
+
passed_reports.each { |r| save_baseline!(r) } if @save_baseline
|
|
61
70
|
puts "\nAll evals passed."
|
|
62
71
|
end
|
|
63
72
|
end
|
|
@@ -70,6 +79,24 @@ module RubyLLM
|
|
|
70
79
|
end
|
|
71
80
|
end
|
|
72
81
|
|
|
82
|
+
def check_regression(report)
|
|
83
|
+
return false unless @fail_on_regression && report.baseline_exists?
|
|
84
|
+
|
|
85
|
+
diff = report.compare_with_baseline
|
|
86
|
+
if diff.regressed?
|
|
87
|
+
puts "\n REGRESSIONS DETECTED:"
|
|
88
|
+
puts " #{diff}"
|
|
89
|
+
true
|
|
90
|
+
else
|
|
91
|
+
false
|
|
92
|
+
end
|
|
93
|
+
end
|
|
94
|
+
|
|
95
|
+
def save_baseline!(report)
|
|
96
|
+
path = report.save_baseline!
|
|
97
|
+
puts " Baseline saved: #{path}"
|
|
98
|
+
end
|
|
99
|
+
|
|
73
100
|
def task_prerequisites
|
|
74
101
|
Rake::Task.task_defined?(:environment) ? [:environment] : []
|
|
75
102
|
end
|
|
@@ -10,18 +10,38 @@ module RubyLLM
|
|
|
10
10
|
# result = ClassifyTicket.run("test")
|
|
11
11
|
# result.parsed_output # => {priority: "high"}
|
|
12
12
|
#
|
|
13
|
-
#
|
|
14
|
-
# stub_step(ClassifyTicket, responses: [{ a: 1 }, { a: 2 }])
|
|
13
|
+
# Only affects the specified step — other steps are not affected.
|
|
15
14
|
#
|
|
16
15
|
def stub_step(step_class, response: nil, responses: nil)
|
|
17
|
-
adapter =
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
16
|
+
adapter = build_test_adapter(response: response, responses: responses)
|
|
17
|
+
allow(step_class).to receive(:run).and_wrap_original do |original, input, **kwargs|
|
|
18
|
+
context = (kwargs[:context] || {}).merge(adapter: adapter)
|
|
19
|
+
original.call(input, context: context)
|
|
20
|
+
end
|
|
21
|
+
end
|
|
22
|
+
|
|
23
|
+
# Set a global test adapter for ALL steps.
|
|
24
|
+
#
|
|
25
|
+
# stub_all_steps(response: { default: true })
|
|
26
|
+
#
|
|
27
|
+
def stub_all_steps(response: nil, responses: nil)
|
|
28
|
+
adapter = build_test_adapter(response: response, responses: responses)
|
|
23
29
|
RubyLLM::Contract.configure { |c| c.default_adapter = adapter }
|
|
24
30
|
end
|
|
31
|
+
|
|
32
|
+
private
|
|
33
|
+
|
|
34
|
+
def build_test_adapter(response: nil, responses: nil)
|
|
35
|
+
if responses
|
|
36
|
+
Adapters::Test.new(responses: responses.map { |r| normalize_test_response(r) })
|
|
37
|
+
else
|
|
38
|
+
Adapters::Test.new(response: normalize_test_response(response))
|
|
39
|
+
end
|
|
40
|
+
end
|
|
41
|
+
|
|
42
|
+
def normalize_test_response(value)
|
|
43
|
+
value
|
|
44
|
+
end
|
|
25
45
|
end
|
|
26
46
|
end
|
|
27
47
|
end
|
|
@@ -64,12 +64,18 @@ RSpec::Matchers.define :pass_eval do |eval_name|
|
|
|
64
64
|
@maximum_cost = cost
|
|
65
65
|
end
|
|
66
66
|
|
|
67
|
+
chain :without_regressions do
|
|
68
|
+
@check_regressions = true
|
|
69
|
+
end
|
|
70
|
+
|
|
67
71
|
match do |step_or_pipeline|
|
|
68
72
|
@eval_name = eval_name
|
|
69
73
|
@context ||= {}
|
|
70
74
|
@minimum_score ||= nil
|
|
71
75
|
@maximum_cost ||= nil
|
|
76
|
+
@check_regressions ||= false
|
|
72
77
|
@error = nil
|
|
78
|
+
@diff = nil
|
|
73
79
|
@report = step_or_pipeline.run_eval(eval_name, context: @context)
|
|
74
80
|
|
|
75
81
|
score_ok = if @minimum_score
|
|
@@ -80,14 +86,29 @@ RSpec::Matchers.define :pass_eval do |eval_name|
|
|
|
80
86
|
|
|
81
87
|
cost_ok = @maximum_cost ? @report.total_cost <= @maximum_cost : true
|
|
82
88
|
|
|
83
|
-
|
|
89
|
+
regression_ok = if @check_regressions && @report.baseline_exists?
|
|
90
|
+
@diff = @report.compare_with_baseline
|
|
91
|
+
!@diff.regressed?
|
|
92
|
+
else
|
|
93
|
+
true
|
|
94
|
+
end
|
|
95
|
+
|
|
96
|
+
score_ok && cost_ok && regression_ok
|
|
84
97
|
rescue StandardError => e
|
|
85
98
|
@error = e
|
|
86
99
|
false
|
|
87
100
|
end
|
|
88
101
|
|
|
89
102
|
failure_message do
|
|
90
|
-
format_failure_message(@eval_name, @error, @report, @minimum_score, @maximum_cost)
|
|
103
|
+
msg = format_failure_message(@eval_name, @error, @report, @minimum_score, @maximum_cost)
|
|
104
|
+
if @diff&.regressed?
|
|
105
|
+
msg += "\n\nRegressions from baseline:\n"
|
|
106
|
+
@diff.regressions.each do |r|
|
|
107
|
+
msg += " #{r[:case]}: was PASS, now FAIL — #{r[:detail]}\n"
|
|
108
|
+
end
|
|
109
|
+
msg += " Score delta: #{@diff.score_delta}"
|
|
110
|
+
end
|
|
111
|
+
msg
|
|
91
112
|
end
|
|
92
113
|
|
|
93
114
|
failure_message_when_negated do
|
|
@@ -58,18 +58,23 @@ module RubyLLM
|
|
|
58
58
|
end
|
|
59
59
|
end
|
|
60
60
|
|
|
61
|
-
KNOWN_CONTEXT_KEYS = %i[adapter model temperature max_tokens
|
|
61
|
+
KNOWN_CONTEXT_KEYS = %i[adapter model temperature max_tokens provider assume_model_exists].freeze
|
|
62
62
|
|
|
63
63
|
def run(input, context: {})
|
|
64
|
+
context = (context || {}).transform_keys { |k| k.respond_to?(:to_sym) ? k.to_sym : k }
|
|
64
65
|
warn_unknown_context_keys(context)
|
|
65
66
|
adapter = resolve_adapter(context)
|
|
66
67
|
default_model = context[:model] || model || RubyLLM::Contract.configuration.default_model
|
|
67
68
|
policy = retry_policy
|
|
68
69
|
|
|
70
|
+
ctx_temp = context[:temperature]
|
|
71
|
+
extra = context.slice(:provider, :assume_model_exists, :max_tokens)
|
|
69
72
|
result = if policy
|
|
70
|
-
run_with_retry(input, adapter: adapter, default_model: default_model,
|
|
73
|
+
run_with_retry(input, adapter: adapter, default_model: default_model,
|
|
74
|
+
policy: policy, context_temperature: ctx_temp, extra_options: extra)
|
|
71
75
|
else
|
|
72
|
-
run_once(input, adapter: adapter, model: default_model,
|
|
76
|
+
run_once(input, adapter: adapter, model: default_model,
|
|
77
|
+
context_temperature: ctx_temp, extra_options: extra)
|
|
73
78
|
end
|
|
74
79
|
|
|
75
80
|
invoke_around_call(input, result)
|
|
@@ -101,14 +106,14 @@ module RubyLLM
|
|
|
101
106
|
"{ |c| c.default_adapter = ... } or pass context: { adapter: ... }"
|
|
102
107
|
end
|
|
103
108
|
|
|
104
|
-
def run_once(input, adapter:, model:, context_temperature: nil)
|
|
109
|
+
def run_once(input, adapter:, model:, context_temperature: nil, extra_options: {})
|
|
105
110
|
effective_temp = context_temperature || temperature
|
|
106
111
|
Runner.new(
|
|
107
112
|
input_type: input_type, output_type: output_type,
|
|
108
113
|
prompt_block: prompt, contract_definition: effective_contract,
|
|
109
114
|
adapter: adapter, model: model, output_schema: output_schema,
|
|
110
115
|
max_output: max_output, max_input: max_input, max_cost: max_cost,
|
|
111
|
-
temperature: effective_temp
|
|
116
|
+
temperature: effective_temp, extra_options: extra_options
|
|
112
117
|
).call(input)
|
|
113
118
|
rescue ArgumentError => e
|
|
114
119
|
Result.new(status: :input_error, raw_output: nil, parsed_output: nil,
|
|
@@ -168,7 +168,7 @@ module RubyLLM
|
|
|
168
168
|
end
|
|
169
169
|
|
|
170
170
|
def retry_policy(models: nil, attempts: nil, retry_on: nil, &block)
|
|
171
|
-
if block || models || attempts
|
|
171
|
+
if block || models || attempts || retry_on
|
|
172
172
|
return @retry_policy = RetryPolicy.new(models: models, attempts: attempts, retry_on: retry_on, &block)
|
|
173
173
|
end
|
|
174
174
|
|
|
@@ -29,7 +29,7 @@ module RubyLLM
|
|
|
29
29
|
end
|
|
30
30
|
|
|
31
31
|
def append_cost_error(estimated, errors)
|
|
32
|
-
estimated_output =
|
|
32
|
+
estimated_output = effective_max_output || 0
|
|
33
33
|
estimated_cost = CostCalculator.calculate(
|
|
34
34
|
model_name: @model,
|
|
35
35
|
usage: { input_tokens: estimated, output_tokens: estimated_output }
|
|
@@ -8,12 +8,13 @@ module RubyLLM
|
|
|
8
8
|
module RetryExecutor
|
|
9
9
|
private
|
|
10
10
|
|
|
11
|
-
def run_with_retry(input, adapter:, default_model:, policy:)
|
|
11
|
+
def run_with_retry(input, adapter:, default_model:, policy:, context_temperature: nil, extra_options: {})
|
|
12
12
|
all_attempts = []
|
|
13
13
|
|
|
14
14
|
policy.max_attempts.times do |attempt_index|
|
|
15
15
|
model = policy.model_for_attempt(attempt_index, default_model)
|
|
16
|
-
result = run_once(input, adapter: adapter, model: model
|
|
16
|
+
result = run_once(input, adapter: adapter, model: model,
|
|
17
|
+
context_temperature: context_temperature, extra_options: extra_options)
|
|
17
18
|
all_attempts << { attempt: attempt_index + 1, model: model, result: result }
|
|
18
19
|
break unless policy.retryable?(result)
|
|
19
20
|
end
|
|
@@ -15,6 +15,7 @@ module RubyLLM
|
|
|
15
15
|
if block
|
|
16
16
|
@max_attempts = 1
|
|
17
17
|
instance_eval(&block)
|
|
18
|
+
warn_no_retry! if @max_attempts == 1 && @models.empty?
|
|
18
19
|
else
|
|
19
20
|
apply_keywords(models: models, attempts: attempts, retry_on: retry_on)
|
|
20
21
|
end
|
|
@@ -38,7 +39,7 @@ module RubyLLM
|
|
|
38
39
|
end
|
|
39
40
|
|
|
40
41
|
def retry_on(*statuses)
|
|
41
|
-
@retryable_statuses = statuses
|
|
42
|
+
@retryable_statuses = statuses.flatten
|
|
42
43
|
end
|
|
43
44
|
|
|
44
45
|
def retryable?(result)
|
|
@@ -65,6 +66,11 @@ module RubyLLM
|
|
|
65
66
|
@retryable_statuses = Array(retry_on).dup if retry_on
|
|
66
67
|
end
|
|
67
68
|
|
|
69
|
+
def warn_no_retry!
|
|
70
|
+
warn "[ruby_llm-contract] retry_policy has max_attempts=1 with no models. " \
|
|
71
|
+
"This means no actual retry will happen. Add `attempts 2` or `escalate %w[model1 model2]`."
|
|
72
|
+
end
|
|
73
|
+
|
|
68
74
|
def validate_max_attempts!
|
|
69
75
|
return if @max_attempts.is_a?(Integer) && @max_attempts >= 1
|
|
70
76
|
|
|
@@ -8,7 +8,7 @@ module RubyLLM
|
|
|
8
8
|
|
|
9
9
|
def initialize(input_type:, output_type:, prompt_block:, contract_definition:,
|
|
10
10
|
adapter:, model:, output_schema: nil, max_output: nil,
|
|
11
|
-
max_input: nil, max_cost: nil, temperature: nil)
|
|
11
|
+
max_input: nil, max_cost: nil, temperature: nil, extra_options: {})
|
|
12
12
|
@input_type = input_type
|
|
13
13
|
@output_type = output_type
|
|
14
14
|
@prompt_block = prompt_block
|
|
@@ -20,6 +20,7 @@ module RubyLLM
|
|
|
20
20
|
@max_input = max_input
|
|
21
21
|
@max_cost = max_cost
|
|
22
22
|
@temperature = temperature
|
|
23
|
+
@extra_options = extra_options
|
|
23
24
|
end
|
|
24
25
|
|
|
25
26
|
def call(input)
|
|
@@ -82,13 +83,20 @@ module RubyLLM
|
|
|
82
83
|
end
|
|
83
84
|
|
|
84
85
|
def build_adapter_options
|
|
86
|
+
effective_max_tokens = @extra_options[:max_tokens] || @max_output
|
|
87
|
+
|
|
85
88
|
{ model: @model }.tap do |opts|
|
|
86
89
|
opts[:schema] = @output_schema if @output_schema
|
|
87
|
-
opts[:max_tokens] =
|
|
90
|
+
opts[:max_tokens] = effective_max_tokens if effective_max_tokens
|
|
88
91
|
opts[:temperature] = @temperature if @temperature
|
|
92
|
+
@extra_options.each { |k, v| opts[k] = v unless opts.key?(k) }
|
|
89
93
|
end
|
|
90
94
|
end
|
|
91
95
|
|
|
96
|
+
def effective_max_output
|
|
97
|
+
@extra_options[:max_tokens] || @max_output
|
|
98
|
+
end
|
|
99
|
+
|
|
92
100
|
def build_error_result(error_result, messages)
|
|
93
101
|
Result.new(
|
|
94
102
|
status: error_result.status,
|
|
@@ -5,15 +5,16 @@ module RubyLLM
|
|
|
5
5
|
module Step
|
|
6
6
|
class Trace
|
|
7
7
|
include Concerns::TraceEquality
|
|
8
|
+
include Concerns::DeepFreeze
|
|
8
9
|
|
|
9
10
|
attr_reader :messages, :model, :latency_ms, :usage, :attempts, :cost
|
|
10
11
|
|
|
11
12
|
def initialize(messages: nil, model: nil, latency_ms: nil, usage: nil, attempts: nil, cost: nil)
|
|
12
|
-
@messages = messages
|
|
13
|
-
@model = model
|
|
13
|
+
@messages = deep_dup_freeze(messages)
|
|
14
|
+
@model = model.frozen? ? model : model&.dup&.freeze
|
|
14
15
|
@latency_ms = latency_ms
|
|
15
|
-
@usage = usage
|
|
16
|
-
@attempts = attempts
|
|
16
|
+
@usage = deep_dup_freeze(usage)
|
|
17
|
+
@attempts = deep_dup_freeze(attempts)
|
|
17
18
|
@cost = cost || CostCalculator.calculate(model_name: model, usage: usage)
|
|
18
19
|
freeze
|
|
19
20
|
end
|