ruby_llm-contract 0.4.2 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61):
  1. checksums.yaml +4 -4
  2. data/.rubycritic.yml +8 -0
  3. data/.simplecov +22 -0
  4. data/CHANGELOG.md +59 -0
  5. data/Gemfile +2 -0
  6. data/Gemfile.lock +104 -2
  7. data/README.md +42 -2
  8. data/lib/ruby_llm/contract/concerns/context_helpers.rb +11 -10
  9. data/lib/ruby_llm/contract/concerns/deep_freeze.rb +13 -7
  10. data/lib/ruby_llm/contract/concerns/deep_symbolize.rb +15 -5
  11. data/lib/ruby_llm/contract/concerns/eval_host.rb +51 -7
  12. data/lib/ruby_llm/contract/contract/schema_validator/bound_rule.rb +85 -0
  13. data/lib/ruby_llm/contract/contract/schema_validator/enum_rule.rb +23 -0
  14. data/lib/ruby_llm/contract/contract/schema_validator/node.rb +70 -0
  15. data/lib/ruby_llm/contract/contract/schema_validator/object_rules.rb +66 -0
  16. data/lib/ruby_llm/contract/contract/schema_validator/scalar_rules.rb +22 -0
  17. data/lib/ruby_llm/contract/contract/schema_validator/schema_extractor.rb +23 -0
  18. data/lib/ruby_llm/contract/contract/schema_validator/type_rule.rb +30 -0
  19. data/lib/ruby_llm/contract/contract/schema_validator.rb +41 -266
  20. data/lib/ruby_llm/contract/contract/validator.rb +9 -0
  21. data/lib/ruby_llm/contract/cost_calculator.rb +41 -1
  22. data/lib/ruby_llm/contract/eval/case_executor.rb +52 -0
  23. data/lib/ruby_llm/contract/eval/case_result_builder.rb +35 -0
  24. data/lib/ruby_llm/contract/eval/case_scorer.rb +66 -0
  25. data/lib/ruby_llm/contract/eval/evaluator/exact.rb +8 -6
  26. data/lib/ruby_llm/contract/eval/evaluator/proc_evaluator.rb +22 -10
  27. data/lib/ruby_llm/contract/eval/evaluator/regex.rb +11 -8
  28. data/lib/ruby_llm/contract/eval/expectation_evaluator.rb +26 -0
  29. data/lib/ruby_llm/contract/eval/prompt_diff.rb +39 -0
  30. data/lib/ruby_llm/contract/eval/prompt_diff_comparator.rb +116 -0
  31. data/lib/ruby_llm/contract/eval/prompt_diff_presenter.rb +99 -0
  32. data/lib/ruby_llm/contract/eval/prompt_diff_serializer.rb +23 -0
  33. data/lib/ruby_llm/contract/eval/report.rb +19 -191
  34. data/lib/ruby_llm/contract/eval/report_presenter.rb +65 -0
  35. data/lib/ruby_llm/contract/eval/report_stats.rb +65 -0
  36. data/lib/ruby_llm/contract/eval/report_storage.rb +107 -0
  37. data/lib/ruby_llm/contract/eval/runner.rb +30 -207
  38. data/lib/ruby_llm/contract/eval/step_expectation_applier.rb +67 -0
  39. data/lib/ruby_llm/contract/eval/step_result_normalizer.rb +39 -0
  40. data/lib/ruby_llm/contract/eval.rb +13 -0
  41. data/lib/ruby_llm/contract/minitest.rb +116 -2
  42. data/lib/ruby_llm/contract/pipeline/base.rb +15 -2
  43. data/lib/ruby_llm/contract/rake_task.rb +20 -1
  44. data/lib/ruby_llm/contract/rspec/helpers.rb +91 -6
  45. data/lib/ruby_llm/contract/rspec/pass_eval.rb +84 -3
  46. data/lib/ruby_llm/contract/rspec.rb +18 -0
  47. data/lib/ruby_llm/contract/step/adapter_caller.rb +23 -0
  48. data/lib/ruby_llm/contract/step/base.rb +94 -37
  49. data/lib/ruby_llm/contract/step/dsl.rb +61 -16
  50. data/lib/ruby_llm/contract/step/input_validator.rb +34 -0
  51. data/lib/ruby_llm/contract/step/limit_checker.rb +28 -11
  52. data/lib/ruby_llm/contract/step/prompt_compiler.rb +33 -0
  53. data/lib/ruby_llm/contract/step/result.rb +3 -2
  54. data/lib/ruby_llm/contract/step/result_builder.rb +60 -0
  55. data/lib/ruby_llm/contract/step/retry_executor.rb +1 -0
  56. data/lib/ruby_llm/contract/step/runner.rb +47 -84
  57. data/lib/ruby_llm/contract/step/runner_config.rb +37 -0
  58. data/lib/ruby_llm/contract/step.rb +5 -0
  59. data/lib/ruby_llm/contract/version.rb +1 -1
  60. data/lib/ruby_llm/contract.rb +28 -0
  61. metadata +28 -1
@@ -12,11 +12,77 @@ module RubyLLM
12
12
  #
13
13
  # Only affects the specified step — other steps are not affected.
14
14
  #
15
# With a block, the stub is scoped — cleaned up after the block:
#
#   stub_step(ClassifyTicket, response: data) do
#     # only stubbed inside this block
#   end
#   # ClassifyTicket no longer stubbed
#
# Without a block, the stub lives until the RSpec example ends.
#
# @param step_class [Class] the step whose runs should use a test adapter
# @param response [Object, nil] canned response for the test adapter
# @param responses [Array, nil] canned responses for the test adapter
# @return the block's value in block form
def stub_step(step_class, response: nil, responses: nil, &block)
  test_adapter = build_test_adapter(response: response, responses: responses)

  if block.nil?
    # Non-block: wrap `run` via RSpec's `allow` (reset by RSpec after the
    # example). An adapter the caller passes explicitly in context wins.
    allow(step_class).to receive(:run).and_wrap_original do |original, input, **kwargs|
      run_context = kwargs[:context] || {}
      caller_chose_adapter = run_context.key?(:adapter) || run_context.key?("adapter")
      run_context = run_context.merge(adapter: test_adapter) unless caller_chose_adapter
      original.call(input, context: run_context)
    end
  else
    # Block form: install into the per-step override table and put back
    # whatever was there before (or remove the entry) once the block exits.
    overrides = RubyLLM::Contract.step_adapter_overrides
    saved = overrides[step_class]
    overrides[step_class] = test_adapter
    begin
      yield
    ensure
      if saved
        overrides[step_class] = saved
      else
        overrides.delete(step_class)
      end
    end
  end
end
52
+
53
# Stub multiple steps at once with different responses.
# Takes a hash of step_class => options (the response:/responses: keywords
# accepted by build_test_adapter; string keys are symbolized). Requires a
# block: every override installed here is rolled back when the block exits,
# including when installing a later stub fails part-way through.
#
#   stub_steps(
#     ClassifyTicket => { response: { priority: "high" } },
#     RouteToTeam => { response: { team: "billing" } }
#   ) do
#     result = TicketPipeline.run("test")
#   end
#
# @param stubs [Hash{Class => Hash}] per-step adapter options
# @return the block's value
# @raise [ArgumentError] when called without a block
def stub_steps(stubs, &block)
  raise ArgumentError, "stub_steps requires a block" unless block

  overrides = RubyLLM::Contract.step_adapter_overrides
  previous = {}

  begin
    # Install inside the protected region: if building a later adapter raises
    # (bad option key, non-Hash options), the ones already installed are
    # still rolled back by the ensure below.
    stubs.each do |step_class, opts|
      adapter = build_test_adapter(**opts.transform_keys(&:to_sym))
      previous[step_class] = overrides[step_class]
      overrides[step_class] = adapter
    end

    yield
  ensure
    # Only steps that were actually overridden are restored.
    previous.each do |step_class, prior|
      if prior
        overrides[step_class] = prior
      else
        overrides.delete(step_class)
      end
    end
  end
end
22
88
 
@@ -24,9 +90,28 @@ module RubyLLM
24
90
  #
25
91
  # stub_all_steps(response: { default: true })
26
92
  #
27
# Supports an optional block form — the previous adapter is restored
# after the block returns (even if it raises):
#
#   stub_all_steps(response: { default: true }) do
#     # all steps use test adapter
#   end
#   # original adapter restored
#
# Without a block the test adapter becomes the configured default_adapter
# via RubyLLM::Contract.configure, and that call's value is returned.
def stub_all_steps(response: nil, responses: nil, &block)
  test_adapter = build_test_adapter(response: response, responses: responses)
  return RubyLLM::Contract.configure { |c| c.default_adapter = test_adapter } unless block

  saved_adapter = RubyLLM::Contract.configuration.default_adapter
  begin
    RubyLLM::Contract.configuration.default_adapter = test_adapter
    yield
  ensure
    RubyLLM::Contract.configuration.default_adapter = saved_adapter
  end
end
31
116
 
32
117
  private
@@ -68,15 +68,28 @@ RSpec::Matchers.define :pass_eval do |eval_name|
68
68
  @check_regressions = true
69
69
  end
70
70
 
71
# Compare against another step (the baseline) on the same eval.
# Using compared_with always enables the regression check.
chain :compared_with do |baseline_step|
  @check_regressions = true
  @comparison_step = baseline_step
end
75
+
71
76
  match do |step_or_pipeline|
72
77
  @eval_name = eval_name
73
78
  @context ||= {}
74
79
  @minimum_score ||= nil
75
80
  @maximum_cost ||= nil
76
81
  @check_regressions ||= false
82
+ @comparison_step ||= nil
77
83
  @error = nil
78
84
  @diff = nil
79
- @report = step_or_pipeline.run_eval(eval_name, context: @context)
85
+ @prompt_diff = nil
86
+
87
+ if @comparison_step && @check_regressions
88
+ @prompt_diff = step_or_pipeline.compare_with(@comparison_step, eval: eval_name, context: @context)
89
+ @report = @prompt_diff.candidate_report
90
+ else
91
+ @report = step_or_pipeline.run_eval(eval_name, context: @context)
92
+ end
80
93
 
81
94
  score_ok = if @minimum_score
82
95
  @report.score >= @minimum_score
@@ -86,7 +99,9 @@ RSpec::Matchers.define :pass_eval do |eval_name|
86
99
 
87
100
  cost_ok = @maximum_cost ? @report.total_cost <= @maximum_cost : true
88
101
 
89
- regression_ok = if @check_regressions && @report.baseline_exists?
102
+ regression_ok = if @prompt_diff
103
+ @prompt_diff.safe_to_switch?
104
+ elsif @check_regressions && @report.baseline_exists?
90
105
  @diff = @report.compare_with_baseline
91
106
  !@diff.regressed?
92
107
  else
@@ -100,11 +115,67 @@ RSpec::Matchers.define :pass_eval do |eval_name|
100
115
  end
101
116
 
102
117
  failure_message do
118
+ if @prompt_diff && !@prompt_diff.safe_to_switch?
119
+ msg = "expected #{@eval_name} eval to be safe to switch from baseline prompt\n"
120
+
121
+ # Check empty sides first — most fundamental problem
122
+ bl_empty = @prompt_diff.baseline_empty?
123
+ cd_empty = @prompt_diff.candidate_empty?
124
+ if bl_empty || cd_empty
125
+ msg += " One side has no evaluated cases (all skipped or no adapter?)\n"
126
+ if sample_response_only_compare?
127
+ msg += " compare_with ignores sample_response; pass model: or with_context(adapter: ...)\n"
128
+ end
129
+ msg += " Candidate score: #{@prompt_diff.candidate_score}, Baseline score: #{@prompt_diff.baseline_score}"
130
+ next msg
131
+ end
132
+
133
+ # Check dataset comparability — names, inputs, AND expected must match
134
+ unless @prompt_diff.cases_comparable?
135
+ unless @prompt_diff.case_names_match?
136
+ mm = @prompt_diff.mismatched_cases
137
+ msg += " Case set mismatch — candidate and baseline must have identical cases:\n"
138
+ mm[:only_in_baseline].each { |n| msg += " only in baseline: #{n}\n" }
139
+ mm[:only_in_candidate].each { |n| msg += " only in candidate: #{n}\n" }
140
+ end
141
+ @prompt_diff.input_mismatches.each do |m|
142
+ msg += " Input mismatch for '#{m[:case]}' — same name but different inputs\n"
143
+ end
144
+ @prompt_diff.expected_mismatches.each do |m|
145
+ msg += " Expected mismatch for '#{m[:case]}' — same name/input but different expected values\n"
146
+ end
147
+ next msg
148
+ end
149
+
150
+ # Check per-case score regressions (even if global average is flat)
151
+ if @prompt_diff.score_regressions.any?
152
+ msg += " Per-case score regressions (#{@prompt_diff.score_regressions.length}):\n"
153
+ @prompt_diff.score_regressions.each do |r|
154
+ msg += " #{r[:case]}: #{r[:baseline_score]} -> #{r[:candidate_score]} (#{r[:delta]})\n"
155
+ end
156
+ msg += " Score delta: #{@prompt_diff.score_delta}"
157
+ next msg
158
+ end
159
+
160
+ # Check pass/fail regressions and removed cases
161
+ removed = @prompt_diff.removed_passing_cases
162
+ reg_count = @prompt_diff.regressions.length + removed.length
163
+ msg += " Found #{reg_count} regression(s):\n"
164
+ @prompt_diff.regressions.each do |r|
165
+ msg += " #{r[:case]}: was PASS, now FAIL -- #{r[:detail]}\n"
166
+ end
167
+ removed.each do |name|
168
+ msg += " #{name}: REMOVED (was passing in baseline)\n"
169
+ end
170
+ msg += " Score delta: #{@prompt_diff.score_delta}"
171
+ next msg
172
+ end
173
+
103
174
  msg = format_failure_message(@eval_name, @error, @report, @minimum_score, @maximum_cost)
104
175
  if @diff&.regressed?
105
176
  msg += "\n\nRegressions from baseline:\n"
106
177
  @diff.regressions.each do |r|
107
- msg += " #{r[:case]}: was PASS, now FAIL #{r[:detail]}\n"
178
+ msg += " #{r[:case]}: was PASS, now FAIL -- #{r[:detail]}\n"
108
179
  end
109
180
  msg += " Score delta: #{@diff.score_delta}"
110
181
  end
@@ -114,4 +185,14 @@ RSpec::Matchers.define :pass_eval do |eval_name|
114
185
  failure_message_when_negated do
115
186
  "expected #{@eval_name} eval NOT to pass, but it passed with score: #{@report.score.round(2)}"
116
187
  end
188
+
189
# Hint helper for the failure message. Returns a truthy value (the adapter
# built by the baseline step's eval definition) when compared_with is in use
# and the matcher context supplies neither :adapter nor :model. The failure
# message suggests the definition's adapter comes from sample_response.
# Any error while probing is treated as false.
def sample_response_only_compare?
  return false unless @comparison_step

  explicit_runtime = @context[:adapter] || @context[:model]
  return false if explicit_runtime

  definition = @comparison_step.send(:all_eval_definitions)[@eval_name.to_s]
  definition&.build_adapter
rescue StandardError
  false
end
117
198
  end
@@ -8,4 +8,22 @@ require_relative "rspec/helpers"
8
8
 
9
9
  RSpec.configure do |config|
10
10
  config.include RubyLLM::Contract::RSpec::Helpers
11
+
12
+ # Auto-cleanup: snapshot adapter before each example, restore after.
13
+ # Prevents non-block stub_all_steps from leaking between examples.
14
+ config.around(:each) do |example|
15
+ original_adapter = RubyLLM::Contract.configuration.default_adapter
16
+ original_logger = RubyLLM::Contract.configuration.logger
17
+ original_eval_hosts = RubyLLM::Contract.eval_hosts.dup
18
+ original_overrides = RubyLLM::Contract.step_adapter_overrides.dup
19
+ begin
20
+ example.run
21
+ ensure
22
+ RubyLLM::Contract.configuration.default_adapter = original_adapter
23
+ RubyLLM::Contract.configuration.logger = original_logger
24
+ RubyLLM::Contract.reset_eval_hosts!
25
+ RubyLLM::Contract.eval_hosts.concat(original_eval_hosts)
26
+ RubyLLM::Contract.step_adapter_overrides.replace(original_overrides)
27
+ end
28
+ end
11
29
  end if defined?(::RSpec)
@@ -0,0 +1,23 @@
1
# frozen_string_literal: true

module RubyLLM
  module Contract
    module Step
      # Invokes an adapter with the compiled messages and times the call.
      #
      # #call returns a two-element array [response, latency_ms], where the
      # latency is measured with the monotonic clock and rounded to whole
      # milliseconds. Any StandardError is converted into an :adapter_error
      # Result (the exception message is its only validation error) paired
      # with a latency of 0.
      class AdapterCaller
        def initialize(adapter:, adapter_options:)
          @adapter = adapter
          @adapter_options = adapter_options
        end

        def call(messages)
          started_at = monotonic_now
          response = @adapter.call(messages: messages, **@adapter_options)
          elapsed_ms = ((monotonic_now - started_at) * 1000).round
          [response, elapsed_ms]
        rescue StandardError => e
          failure = Result.new(status: :adapter_error, raw_output: nil, parsed_output: nil,
                               validation_errors: [e.message])
          [failure, 0]
        end

        private

        # Seconds from the monotonic clock (unaffected by wall-clock changes).
        def monotonic_now
          Process.clock_gettime(Process::CLOCK_MONOTONIC)
        end
      end
    end
  end
end
@@ -4,6 +4,8 @@ module RubyLLM
4
4
  module Contract
5
5
  module Step
6
6
  class Base
7
+ DEFAULT_OUTPUT_TOKENS = 256
8
+
7
9
  def self.inherited(subclass)
8
10
  super
9
11
  Contract.register_eval_host(subclass) if respond_to?(:eval_defined?) && eval_defined?
@@ -15,30 +17,23 @@ module RubyLLM
15
17
  include Dsl
16
18
 
17
19
  def eval_case(input:, expected: nil, expected_traits: nil, evaluator: nil, context: {})
18
- dataset = Eval::Dataset.define("single_case") do
19
- add_case("inline", input: input, expected: expected,
20
- expected_traits: expected_traits, evaluator: evaluator)
21
- end
22
- report = Eval::Runner.run(step: self, dataset: dataset, context: context)
23
- report.results.first
20
+ Eval::Runner.run(step: self, dataset: inline_dataset(input, expected, expected_traits, evaluator),
21
+ context: context).results.first
24
22
  end
25
23
 
26
24
  def estimate_cost(input:, model: nil)
27
- model_name = model || RubyLLM::Contract.configuration.default_model
28
- messages = build_messages(input)
29
- input_tokens = TokenEstimator.estimate(messages)
30
- output_tokens = max_output || 256 # conservative default
31
-
25
+ model_name = estimated_model_name(model)
32
26
  model_info = CostCalculator.send(:find_model, model_name)
33
27
  return nil unless model_info
34
28
 
35
- estimated = CostCalculator.send(:compute_cost, model_info,
36
- { input_tokens: input_tokens, output_tokens: output_tokens })
29
+ input_tokens = TokenEstimator.estimate(build_messages(input))
30
+ output_tokens = max_output || DEFAULT_OUTPUT_TOKENS
31
+
37
32
  {
38
33
  model: model_name,
39
34
  input_tokens: input_tokens,
40
35
  output_tokens_estimate: output_tokens,
41
- estimated_cost: estimated
36
+ estimated_cost: estimated_cost_for(model_info, input_tokens, output_tokens)
42
37
  }
43
38
  end
44
39
 
@@ -46,15 +41,11 @@ module RubyLLM
46
41
  defn = send(:all_eval_definitions)[eval_name.to_s]
47
42
  raise ArgumentError, "No eval '#{eval_name}' defined" unless defn
48
43
 
49
- model_list = models || [RubyLLM::Contract.configuration.default_model].compact
44
+ model_list = models || [estimated_model_name].compact
50
45
  cases = defn.build_dataset.cases
51
46
 
52
47
  model_list.each_with_object({}) do |model_name, result|
53
- per_case = cases.sum do |c|
54
- est = estimate_cost(input: c.input, model: model_name)
55
- est ? est[:estimated_cost] : 0.0
56
- end
57
- result[model_name] = per_case.round(6)
48
+ result[model_name] = estimate_eval_cost_for_model(cases, model_name)
58
49
  end
59
50
  end
60
51
 
@@ -65,20 +56,8 @@ module RubyLLM
65
56
  def run(input, context: {})
66
57
  context = safe_context(context)
67
58
  warn_unknown_context_keys(context)
68
- adapter = resolve_adapter(context)
69
- default_model = context[:model] || model || RubyLLM::Contract.configuration.default_model
70
- policy = retry_policy
71
-
72
- ctx_temp = context[:temperature]
73
- extra = context.slice(:provider, :assume_model_exists, :max_tokens)
74
- result = if policy
75
- run_with_retry(input, adapter: adapter, default_model: default_model,
76
- policy: policy, context_temperature: ctx_temp, extra_options: extra)
77
- else
78
- run_once(input, adapter: adapter, model: default_model,
79
- context_temperature: ctx_temp, extra_options: extra)
80
- end
81
59
 
60
+ result = dispatch_run(input, context)
82
61
  log_result(result)
83
62
  invoke_around_call(input, result)
84
63
  end
@@ -87,13 +66,43 @@ module RubyLLM
87
66
  dynamic = prompt.arity >= 1
88
67
  builder_input = dynamic ? input : Prompt::Builder::NOT_PROVIDED
89
68
  ast = Prompt::Builder.build(input: builder_input, &prompt)
90
- variables = dynamic ? {} : { input: input }
91
- variables.merge!(input.transform_keys(&:to_sym)) if !dynamic && input.is_a?(Hash)
92
- Prompt::Renderer.render(ast, variables: variables)
69
+ Prompt::Renderer.render(ast, variables: prompt_variables(input, dynamic))
93
70
  end
94
71
 
95
72
  private
96
73
 
74
# Builds the throwaway dataset used by eval_case: a dataset named
# "single_case" holding exactly one case named "inline".
def inline_dataset(input, expected, expected_traits, evaluator)
  Eval::Dataset.define("single_case") do
    add_case(
      "inline",
      input: input,
      expected: expected,
      expected_traits: expected_traits,
      evaluator: evaluator
    )
  end
end
80
+
81
# Picks the model name used for cost estimates. The order is:
# 1. the explicit +model+ argument;
# 2. the receiver's own model (when it responds to #model);
# 3. the configured default_model.
def estimated_model_name(model = nil)
  return model if model

  own_model = respond_to?(:model) ? self.model : nil
  own_model || RubyLLM::Contract.configuration.default_model
end
84
+
85
# Prices a token estimate for +model_info+ using CostCalculator's private
# compute_cost with a synthetic usage hash.
def estimated_cost_for(model_info, input_tokens, output_tokens)
  usage = { input_tokens: input_tokens, output_tokens: output_tokens }
  CostCalculator.send(:compute_cost, model_info, usage)
end
92
+
93
# Sums the per-case cost estimates of +cases+ for +model_name+, rounded to
# 6 decimal places. A case contributes 0.0 when its estimate is nil (no
# pricing for the model) or carries no :estimated_cost.
#
# Summation starts from 0.0, so the result is always a Float. Previously an
# empty case list produced Integer 0.
def estimate_eval_cost_for_model(cases, model_name)
  total = cases.sum(0.0) do |test_case|
    estimate = estimate_cost(input: test_case.input, model: model_name)
    (estimate && estimate[:estimated_cost]) || 0.0
  end
  total.round(6)
end
99
+
100
# Variables handed to Prompt::Renderer.
#
# For a dynamic prompt (one whose block takes the input as an argument) the
# builder already received the input, so there are no variables. Otherwise
# the raw input is exposed as :input. For a Hash input, each entry is also
# exposed as a top-level variable with its key symbolized (these override
# :input on collision).
#
# Keys that cannot be symbolized (e.g. Integer keys) are kept as-is instead
# of raising NoMethodError.
def prompt_variables(input, dynamic)
  return {} if dynamic

  variables = { input: input }
  return variables unless input.is_a?(Hash)

  symbolized = input.transform_keys { |key| key.respond_to?(:to_sym) ? key.to_sym : key }
  variables.merge(symbolized)
end
105
+
97
106
  def warn_unknown_context_keys(context)
98
107
  unknown = context.keys - KNOWN_CONTEXT_KEYS
99
108
  return if unknown.empty?
@@ -102,6 +111,39 @@ module RubyLLM
102
111
  "Known keys: #{KNOWN_CONTEXT_KEYS.inspect}"
103
112
  end
104
113
 
114
# Executes the step once, or under the retry policy when one is configured.
# Both paths receive the adapter resolved from +context+ and the runtime
# settings (model, temperature, extra adapter options).
def dispatch_run(input, context)
  adapter = resolve_adapter(context)
  settings = runtime_settings(context)
  shared = {
    adapter: adapter,
    context_temperature: settings[:temperature],
    extra_options: settings[:extra_options]
  }

  policy = settings[:policy]
  return run_once(input, model: settings[:model], **shared) unless policy

  run_with_retry(input, default_model: settings[:model], policy: policy, **shared)
end
137
+
138
# Per-run execution settings:
# - :model is context[:model], else the step's model, else the configured default_model.
# - :temperature is context[:temperature] (nil when absent).
# - :extra_options holds only the :provider, :assume_model_exists and :max_tokens entries of context.
# - :policy is the step's retry_policy.
def runtime_settings(context)
  chosen_model = context[:model] || model || RubyLLM::Contract.configuration.default_model
  passthrough = context.slice(:provider, :assume_model_exists, :max_tokens)

  {
    model: chosen_model,
    temperature: context[:temperature],
    extra_options: passthrough,
    policy: retry_policy
  }
end
146
+
105
147
  def resolve_adapter(context)
106
148
  adapter = context[:adapter] || RubyLLM::Contract.configuration.default_adapter
107
149
  return adapter if adapter
@@ -117,7 +159,9 @@ module RubyLLM
117
159
  prompt_block: prompt, contract_definition: effective_contract,
118
160
  adapter: adapter, model: model, output_schema: output_schema,
119
161
  max_output: max_output, max_input: max_input, max_cost: max_cost,
120
- temperature: effective_temp, extra_options: extra_options
162
+ on_unknown_pricing: on_unknown_pricing,
163
+ temperature: effective_temp, extra_options: extra_options,
164
+ observers: class_observers
121
165
  ).call(input)
122
166
  rescue ArgumentError => e
123
167
  Result.new(status: :input_error, raw_output: nil, parsed_output: nil,
@@ -135,6 +179,19 @@ module RubyLLM
135
179
  "tokens=#{trace.usage&.dig(:input_tokens) || 0}+#{trace.usage&.dig(:output_tokens) || 0} " \
136
180
  "cost=$#{format("%.6f", trace.cost || 0)}"
137
181
  logger.info(msg)
182
+
183
+ log_failed_observations(result, logger)
184
+ end
185
+
186
# Emits one logger.warn line per observation on +result+ whose :passed is
# falsy. Each line names the step (name, falling back to self when name is
# nil) and the observation's description, plus its :error when present.
# Returns nil when nothing failed.
def log_failed_observations(result, logger)
  failures = result.observations.reject { |obs| obs[:passed] }

  failures.each do |obs|
    line = "[ruby_llm-contract] #{name || self} observation failed: #{obs[:description]}"
    line += " (#{obs[:error]})" if obs[:error]
    logger.warn(line)
  end unless failures.empty?
end
139
196
 
140
197
  def invoke_around_call(input, result)
@@ -79,6 +79,16 @@ module RubyLLM
79
79
  inherited + own
80
80
  end
81
81
 
82
# Appends an Invariant built from +description+ and +block+ to this class's
# own observer list (created on first use). Returns that list.
def observe(description, &block)
  @class_observers ||= []
  @class_observers << Invariant.new(description, block)
end
85
+
86
# All observers for this class: those inherited from the superclass chain
# first, then this class's own. A new Array is built on every call.
def class_observers
  own = instance_variable_defined?(:@class_observers) ? @class_observers : []
  parent = superclass.respond_to?(:class_observers) ? superclass.class_observers : []
  parent + own
end
91
+
82
92
  def max_output(tokens = nil)
83
93
  if tokens
84
94
  unless tokens.is_a?(Numeric) && tokens.positive?
@@ -111,48 +121,83 @@ module RubyLLM
111
121
  end
112
122
  end
113
123
 
114
# DSL: declare, reset or read the per-run cost ceiling.
#
#   max_cost 0.01                              # ceiling; unknown pricing => :refuse
#   max_cost 0.01, on_unknown_pricing: :warn   # ceiling; unknown pricing => :warn
#   max_cost on_unknown_pricing: :warn         # keep ceiling, change policy only
#   max_cost :default                          # drop any (inherited) ceiling
#   max_cost                                   # read: own value, else inherited
#
# @param amount [Numeric, :default, nil] positive ceiling, :default to unset,
#   nil to read
# @param on_unknown_pricing [:refuse, :warn, nil] policy stored alongside the
#   ceiling; defaults to :refuse when a ceiling is set or reset
# @return [Numeric, nil] the effective ceiling
# @raise [ArgumentError] for a non-positive amount or an unsupported policy
def max_cost(amount = nil, on_unknown_pricing: nil)
  if amount && amount != :default && !(amount.is_a?(Numeric) && amount.positive?)
    raise ArgumentError, "max_cost must be positive, got #{amount}"
  end

  if on_unknown_pricing && !%i[refuse warn].include?(on_unknown_pricing)
    raise ArgumentError, "on_unknown_pricing must be :refuse or :warn, got #{on_unknown_pricing.inspect}"
  end

  if amount == :default
    @max_cost = nil
    @max_cost_explicitly_unset = true
    # Reset the policy to its default rather than nil, so on_unknown_pricing
    # never reports a value outside :refuse/:warn.
    @on_unknown_pricing = on_unknown_pricing || :refuse
    return nil
  end

  if amount
    @max_cost_explicitly_unset = false
    @max_cost = amount
    @on_unknown_pricing = on_unknown_pricing || :refuse
    return @max_cost
  end

  # Reader path. A policy given without an amount used to be silently
  # dropped; it now updates the policy and leaves the ceiling as is.
  @on_unknown_pricing = on_unknown_pricing if on_unknown_pricing

  return @max_cost if defined?(@max_cost) && !@max_cost_explicitly_unset
  return nil if @max_cost_explicitly_unset

  superclass.max_cost if superclass.respond_to?(:max_cost)
end
152
+
153
# Policy for unknown model pricing (set through max_cost's
# on_unknown_pricing: keyword). Returns this class's own setting, else the
# superclass's, else :refuse.
def on_unknown_pricing
  return @on_unknown_pricing if instance_variable_defined?(:@on_unknown_pricing)
  return superclass.on_unknown_pricing if superclass.respond_to?(:on_unknown_pricing)

  :refuse
end
129
162
 
130
163
  def model(name = nil)
164
+ if name == :default
165
+ @model = nil
166
+ @model_explicitly_unset = true
167
+ return nil
168
+ end
169
+
131
170
  if name
171
+ @model_explicitly_unset = false
132
172
  return @model = name
133
173
  end
134
174
 
135
- if defined?(@model)
136
- @model
137
- elsif superclass.respond_to?(:model)
138
- superclass.model
139
- end
175
+ return @model if defined?(@model) && !@model_explicitly_unset
176
+ return nil if @model_explicitly_unset
177
+
178
+ superclass.model if superclass.respond_to?(:model)
140
179
  end
141
180
 
142
181
  def temperature(value = nil)
182
+ if value == :default
183
+ @temperature = nil
184
+ @temperature_explicitly_unset = true
185
+ return nil
186
+ end
187
+
143
188
  if value
144
189
  unless value.is_a?(Numeric) && value >= 0 && value <= 2
145
190
  raise ArgumentError, "temperature must be 0.0-2.0, got #{value}"
146
191
  end
147
192
 
193
+ @temperature_explicitly_unset = false
148
194
  return @temperature = value
149
195
  end
150
196
 
151
- if defined?(@temperature)
152
- @temperature
153
- elsif superclass.respond_to?(:temperature)
154
- superclass.temperature
155
- end
197
+ return @temperature if defined?(@temperature) && !@temperature_explicitly_unset
198
+ return nil if @temperature_explicitly_unset
199
+
200
+ superclass.temperature if superclass.respond_to?(:temperature)
156
201
  end
157
202
 
158
203
  def around_call(&block)
@@ -0,0 +1,34 @@
1
# frozen_string_literal: true

module RubyLLM
  module Contract
    module Step
      # Checks a step input against the step's declared input type.
      #
      # Two kinds of input_type are supported:
      # * a plain Ruby class, validated with is_a?;
      # * anything else responding to #[] (e.g. a Dry::Types type or a
      #   lambda), invoked with the input and expected to raise when the
      #   input is invalid.
      #
      # Core classes such as Array, Hash, Struct, Set and Data also respond
      # to .[], but there it is a constructor: Array[x] wraps anything and
      # Hash[[[k, v]]] builds a hash from pairs. Delegating to it would accept
      # invalid input, so these classes (and their subclasses) are validated
      # with is_a? like any other plain class.
      #
      # #call returns nil when the input is valid, otherwise an :input_error
      # Result carrying the error message.
      class InputValidator
        def initialize(input_type:)
          @input_type = input_type
        end

        def call(input)
          validate(input)
          nil
        rescue Dry::Types::CoercionError, TypeError, ArgumentError => error
          Result.new(status: :input_error, raw_output: nil, parsed_output: nil, validation_errors: [error.message])
        end

        private

        def validate(input)
          if ruby_class_input?
            raise TypeError, "#{input.inspect} is not a #{@input_type}" unless input.is_a?(@input_type)
          else
            @input_type[input]
          end
        end

        def ruby_class_input?
          return false unless @input_type.is_a?(Class)

          !@input_type.respond_to?(:[]) || core_constructor_class?
        end

        # True for core classes whose .[] builds a new instance instead of
        # validating its argument.
        def core_constructor_class?
          constructor_roots.any? { |root| @input_type <= root }
        end

        def constructor_roots
          roots = [Array, Hash, Struct]
          roots << ::Set if defined?(::Set)
          roots << ::Data if defined?(::Data) && ::Data.respond_to?(:define)
          roots
        end
      end
    end
  end
end