ruby_llm-contract 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.rspec +3 -0
- data/.rubocop.yml +55 -0
- data/CHANGELOG.md +76 -0
- data/Gemfile +11 -0
- data/Gemfile.lock +176 -0
- data/LICENSE +21 -0
- data/README.md +154 -0
- data/Rakefile +8 -0
- data/examples/00_basics.rb +500 -0
- data/examples/01_classify_threads.rb +220 -0
- data/examples/02_generate_comment.rb +203 -0
- data/examples/03_target_audience.rb +201 -0
- data/examples/04_real_llm.rb +410 -0
- data/examples/05_output_schema.rb +258 -0
- data/examples/07_keyword_extraction.rb +239 -0
- data/examples/08_translation.rb +353 -0
- data/examples/09_eval_dataset.rb +287 -0
- data/examples/10_reddit_full_showcase.rb +363 -0
- data/examples/README.md +140 -0
- data/lib/ruby_llm/contract/adapters/base.rb +13 -0
- data/lib/ruby_llm/contract/adapters/response.rb +17 -0
- data/lib/ruby_llm/contract/adapters/ruby_llm.rb +94 -0
- data/lib/ruby_llm/contract/adapters/test.rb +44 -0
- data/lib/ruby_llm/contract/adapters.rb +6 -0
- data/lib/ruby_llm/contract/concerns/deep_symbolize.rb +17 -0
- data/lib/ruby_llm/contract/concerns/eval_host.rb +109 -0
- data/lib/ruby_llm/contract/concerns/trace_equality.rb +15 -0
- data/lib/ruby_llm/contract/concerns/usage_aggregator.rb +43 -0
- data/lib/ruby_llm/contract/configuration.rb +21 -0
- data/lib/ruby_llm/contract/contract/definition.rb +39 -0
- data/lib/ruby_llm/contract/contract/invariant.rb +23 -0
- data/lib/ruby_llm/contract/contract/parser.rb +143 -0
- data/lib/ruby_llm/contract/contract/schema_validator.rb +239 -0
- data/lib/ruby_llm/contract/contract/validator.rb +104 -0
- data/lib/ruby_llm/contract/contract.rb +7 -0
- data/lib/ruby_llm/contract/cost_calculator.rb +38 -0
- data/lib/ruby_llm/contract/dsl.rb +13 -0
- data/lib/ruby_llm/contract/errors.rb +19 -0
- data/lib/ruby_llm/contract/eval/case_result.rb +76 -0
- data/lib/ruby_llm/contract/eval/contract_detail_builder.rb +47 -0
- data/lib/ruby_llm/contract/eval/dataset.rb +53 -0
- data/lib/ruby_llm/contract/eval/eval_definition.rb +112 -0
- data/lib/ruby_llm/contract/eval/evaluation_result.rb +27 -0
- data/lib/ruby_llm/contract/eval/evaluator/exact.rb +20 -0
- data/lib/ruby_llm/contract/eval/evaluator/json_includes.rb +58 -0
- data/lib/ruby_llm/contract/eval/evaluator/proc_evaluator.rb +40 -0
- data/lib/ruby_llm/contract/eval/evaluator/regex.rb +27 -0
- data/lib/ruby_llm/contract/eval/model_comparison.rb +80 -0
- data/lib/ruby_llm/contract/eval/pipeline_result_adapter.rb +15 -0
- data/lib/ruby_llm/contract/eval/report.rb +115 -0
- data/lib/ruby_llm/contract/eval/runner.rb +162 -0
- data/lib/ruby_llm/contract/eval/trait_evaluator.rb +75 -0
- data/lib/ruby_llm/contract/eval.rb +16 -0
- data/lib/ruby_llm/contract/pipeline/base.rb +62 -0
- data/lib/ruby_llm/contract/pipeline/result.rb +131 -0
- data/lib/ruby_llm/contract/pipeline/runner.rb +139 -0
- data/lib/ruby_llm/contract/pipeline/trace.rb +72 -0
- data/lib/ruby_llm/contract/pipeline.rb +6 -0
- data/lib/ruby_llm/contract/prompt/ast.rb +38 -0
- data/lib/ruby_llm/contract/prompt/builder.rb +47 -0
- data/lib/ruby_llm/contract/prompt/node.rb +25 -0
- data/lib/ruby_llm/contract/prompt/nodes/example_node.rb +27 -0
- data/lib/ruby_llm/contract/prompt/nodes/rule_node.rb +15 -0
- data/lib/ruby_llm/contract/prompt/nodes/section_node.rb +26 -0
- data/lib/ruby_llm/contract/prompt/nodes/system_node.rb +15 -0
- data/lib/ruby_llm/contract/prompt/nodes/user_node.rb +15 -0
- data/lib/ruby_llm/contract/prompt/nodes.rb +7 -0
- data/lib/ruby_llm/contract/prompt/renderer.rb +76 -0
- data/lib/ruby_llm/contract/railtie.rb +20 -0
- data/lib/ruby_llm/contract/rake_task.rb +78 -0
- data/lib/ruby_llm/contract/rspec/pass_eval.rb +96 -0
- data/lib/ruby_llm/contract/rspec/satisfy_contract.rb +31 -0
- data/lib/ruby_llm/contract/rspec.rb +6 -0
- data/lib/ruby_llm/contract/step/base.rb +138 -0
- data/lib/ruby_llm/contract/step/dsl.rb +144 -0
- data/lib/ruby_llm/contract/step/limit_checker.rb +64 -0
- data/lib/ruby_llm/contract/step/result.rb +38 -0
- data/lib/ruby_llm/contract/step/retry_executor.rb +90 -0
- data/lib/ruby_llm/contract/step/retry_policy.rb +76 -0
- data/lib/ruby_llm/contract/step/runner.rb +126 -0
- data/lib/ruby_llm/contract/step/trace.rb +70 -0
- data/lib/ruby_llm/contract/step.rb +10 -0
- data/lib/ruby_llm/contract/token_estimator.rb +19 -0
- data/lib/ruby_llm/contract/types.rb +11 -0
- data/lib/ruby_llm/contract/version.rb +7 -0
- data/lib/ruby_llm/contract.rb +108 -0
- data/ruby_llm-contract.gemspec +33 -0
- metadata +172 -0
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module RubyLLM
  module Contract
    module Step
      # Extracted from Base to reduce class length.
      # DSL accessor methods for step definition (input_type, output_type, prompt, etc.).
      #
      # Each accessor is both a setter (when given an argument or block) and a
      # getter (when called bare). Getters fall back to the superclass when it
      # responds to the same accessor, so step subclasses inherit configuration.
      module Dsl # rubocop:disable Metrics/ModuleLength
        # Sets or reads the expected input type. Defaults to String.
        def input_type(type = nil)
          return @input_type = type if type

          if defined?(@input_type)
            @input_type
          elsif superclass.respond_to?(:input_type)
            superclass.input_type
          else
            String
          end
        end

        # Sets or reads the expected output type. A step that declares its own
        # output_schema (and no explicit output_type) yields Types::Hash; otherwise
        # the superclass value is used, defaulting to Hash.
        def output_type(type = nil)
          return @output_type = type if type

          if defined?(@output_type)
            @output_type
          elsif defined?(@output_schema) && @output_schema
            RubyLLM::Contract::Types::Hash
          elsif superclass.respond_to?(:output_type)
            superclass.output_type
          else
            Hash
          end
        end

        # Defines (with a block) or reads the structured-output schema.
        # ruby_llm/schema is required lazily so it is only loaded when a schema is declared.
        def output_schema(&block)
          if block
            require "ruby_llm/schema"
            @output_schema = ::RubyLLM::Schema.create(&block)
          elsif defined?(@output_schema)
            @output_schema
          elsif superclass.respond_to?(:output_schema)
            superclass.output_schema
          end
        end

        # Sets the prompt from a plain string (wrapped in a `user` node) or a builder
        # block, or reads it. Raises ArgumentError if no prompt is set anywhere in
        # the class hierarchy.
        def prompt(text = nil, &block)
          if text
            @prompt_block = proc { user text }
          elsif block
            @prompt_block = block
          elsif defined?(@prompt_block) && @prompt_block
            @prompt_block
          elsif superclass.respond_to?(:prompt)
            superclass.prompt
          else
            raise(ArgumentError, "prompt has not been set")
          end
        end

        # Defines the output contract from a Definition block, or reads it
        # (inherited if not set here, otherwise an empty Definition).
        def contract(&block)
          return @contract_definition = Definition.new(&block) if block

          if defined?(@contract_definition) && @contract_definition
            @contract_definition
          elsif superclass.respond_to?(:contract)
            superclass.contract
          else
            Definition.new
          end
        end

        # Registers a class-level invariant with a human-readable description.
        def validate(description, &block)
          (@class_validates ||= []) << Invariant.new(description, block)
        end

        # All class-level invariants: the superclass's first, then this class's own.
        def class_validates
          own = defined?(@class_validates) ? @class_validates : []
          inherited = superclass.respond_to?(:class_validates) ? superclass.class_validates : []
          inherited + own
        end

        # Sets (positive number) or reads the output token cap.
        def max_output(tokens = nil)
          return @max_output = validate_positive_limit!(:max_output, tokens) if tokens

          inherited_dsl_value(:@max_output, :max_output)
        end

        # Sets (positive number) or reads the estimated input token cap.
        def max_input(tokens = nil)
          return @max_input = validate_positive_limit!(:max_input, tokens) if tokens

          inherited_dsl_value(:@max_input, :max_input)
        end

        # Sets (positive number) or reads the estimated per-call cost cap.
        def max_cost(amount = nil)
          return @max_cost = validate_positive_limit!(:max_cost, amount) if amount

          inherited_dsl_value(:@max_cost, :max_cost)
        end

        # Sets or reads the RetryPolicy.
        #
        # A policy is built when a block, models: or attempts: is given. retry_on:
        # only customizes such a policy. Passing it alone used to be silently
        # ignored, so it now raises ArgumentError instead.
        def retry_policy(models: nil, attempts: nil, retry_on: nil, &block)
          if block || models || attempts
            return @retry_policy = RetryPolicy.new(models: models, attempts: attempts, retry_on: retry_on, &block)
          end
          raise ArgumentError, "retry_on: requires models:, attempts:, or a block" if retry_on

          if defined?(@retry_policy) && @retry_policy
            @retry_policy
          elsif superclass.respond_to?(:retry_policy)
            superclass.retry_policy
          end
        end

        private

        # Returns +value+ when it is a positive real number; raises ArgumentError otherwise.
        # The real? check keeps Complex values (which lack positive?) on the ArgumentError path.
        def validate_positive_limit!(name, value)
          return value if value.is_a?(Numeric) && value.real? && value.positive?

          raise ArgumentError, "#{name} must be positive, got #{value}"
        end

        # Reads +ivar+ if this class set it; otherwise asks the superclass via +reader+.
        # Returns nil when neither provides a value.
        def inherited_dsl_value(ivar, reader)
          if instance_variable_defined?(ivar)
            instance_variable_get(ivar)
          elsif superclass.respond_to?(reader)
            superclass.public_send(reader)
          end
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module RubyLLM
  module Contract
    module Step
      # Extracted from Runner to reduce class length.
      # Pre-flight guard: before the adapter is called, estimates input tokens for
      # the rendered messages and compares the estimate (and its projected cost)
      # against the including object's @max_input / @max_cost.
      # Reads @max_input, @max_cost, @max_output and @model from the includer.
      module LimitChecker
        private

        # Returns nil when no limit is configured or none is exceeded; otherwise a
        # :limit_exceeded Result listing every violated limit.
        def check_limits(messages)
          return unless @max_input || @max_cost

          estimated_tokens = TokenEstimator.estimate(messages)
          problems = collect_limit_errors(estimated_tokens)
          problems.empty? ? nil : build_limit_result(messages, estimated_tokens, problems)
        end

        # Builds the list of human-readable limit violations for the estimate.
        def collect_limit_errors(estimated)
          [].tap do |errors|
            if @max_input && estimated > @max_input
              errors << "Input token limit exceeded: estimated #{estimated} tokens, max #{@max_input}"
            end
            append_cost_error(estimated, errors) if @max_cost
          end
        end

        # Prices the estimated input plus @max_output output tokens (0 when no output
        # cap is set) and appends an error when that exceeds @max_cost. A model the
        # CostCalculator cannot price only produces a warning.
        def append_cost_error(estimated, errors)
          output_guess = @max_output || 0
          projected = CostCalculator.calculate(
            model_name: @model,
            usage: { input_tokens: estimated, output_tokens: output_guess }
          )

          if projected.nil?
            warn "[ruby_llm-contract] max_cost is configured but model '#{@model}' " \
                 "has no pricing data — cost limit not enforced"
            return
          end
          return unless projected > @max_cost

          errors << "Cost limit exceeded: estimated $#{format("%.6f", projected)} " \
                    "(#{estimated} input + #{output_guess} output tokens), " \
                    "max $#{format("%.6f", @max_cost)}"
        end

        # Wraps the violations in a Result; the trace records zero actual usage
        # alongside the heuristic input estimate.
        def build_limit_result(messages, estimated, errors)
          heuristic_usage = {
            input_tokens: 0,
            output_tokens: 0,
            estimated_input_tokens: estimated,
            estimate_method: :heuristic
          }

          Result.new(
            status: :limit_exceeded,
            raw_output: nil,
            parsed_output: nil,
            validation_errors: errors,
            trace: Trace.new(messages: messages, model: @model, usage: heuristic_usage)
          )
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module RubyLLM
  module Contract
    module Step
      # Immutable outcome of one step execution.
      #
      # +status+ is a Symbol; only :ok counts as success. +validation_errors+ holds
      # error messages. +trace+ carries execution metadata (defaults to an empty Hash).
      class Result
        attr_reader :status, :raw_output, :parsed_output, :validation_errors, :trace

        def initialize(status:, raw_output:, parsed_output:, validation_errors: [], trace: {})
          @status = status
          @raw_output = raw_output
          @parsed_output = parsed_output
          # Freeze private copies so the caller's array/hash is not frozen as a side effect.
          @validation_errors = frozen_copy(validation_errors)
          @trace = frozen_copy(trace)
          freeze
        end

        def ok?
          @status == :ok
        end

        def failed?
          @status != :ok
        end

        # "ok (<trace>)" on success; otherwise the status followed by at most three
        # errors, with ", ..." when more were truncated.
        def to_s
          if ok?
            "#{@status} (#{@trace})"
          else
            errors = @validation_errors.first(3).join(", ")
            errors += ", ..." if @validation_errors.size > 3
            "#{@status}: #{errors}"
          end
        end

        private

        # Already-frozen values (including nil) are shared as-is; anything else is dup'd then frozen.
        def frozen_copy(value)
          value.frozen? ? value : value.dup.freeze
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module RubyLLM
  module Contract
    module Step
      # Extracted from Base to reduce class length.
      # Runs a step repeatedly under a RetryPolicy and folds all attempts into a
      # single Result. The final attempt supplies status/output; the trace gains a
      # per-attempt log plus summed usage, cost and latency.
      # The including class must provide #run_once(input, adapter:, model:).
      module RetryExecutor
        private

        # Calls run_once up to policy.max_attempts times, stopping at the first
        # result the policy does not consider retryable.
        def run_with_retry(input, adapter:, default_model:, policy:)
          history = []

          (0...policy.max_attempts).each do |index|
            chosen_model = policy.model_for_attempt(index, default_model)
            outcome = run_once(input, adapter: adapter, model: chosen_model)
            history << { attempt: index + 1, model: chosen_model, result: outcome }
            break unless policy.retryable?(outcome)
          end

          build_retry_result(history)
        end

        # Returns a copy of the last attempt's Result with an aggregated trace.
        def build_retry_result(all_attempts)
          final = all_attempts.last[:result]
          merged_trace = final.trace.merge(
            attempts: all_attempts.map { |entry| build_attempt_entry(entry) },
            usage: aggregate_retry_usage(all_attempts),
            cost: sum_attempt_costs(all_attempts),
            latency_ms: sum_attempt_latency(all_attempts)
          )

          Result.new(
            status: final.status,
            raw_output: final.raw_output,
            parsed_output: final.parsed_output,
            validation_errors: final.validation_errors,
            trace: merged_trace
          )
        end

        # One log line per attempt: number, model, status and any trace metrics.
        def build_attempt_entry(attempt)
          outcome = attempt[:result]
          append_trace_fields(
            { attempt: attempt[:attempt], model: attempt[:model], status: outcome.status },
            outcome.trace
          )
        end

        # Copies usage, latency_ms and cost onto +entry+ when the trace exposes a truthy value.
        def append_trace_fields(entry, trace)
          %i[usage latency_ms cost].each do |field|
            next unless trace.respond_to?(field)

            value = trace.public_send(field)
            entry[field] = value if value
          end
          entry
        end

        # Total cost rounded to 6 decimals, or nil when no attempt reported a cost.
        def sum_attempt_costs(all_attempts)
          costs = extract_trace_values(all_attempts, :cost)
          costs.sum.round(6) unless costs.empty?
        end

        # Total latency in ms, or nil when no attempt reported one.
        def sum_attempt_latency(all_attempts)
          latencies = extract_trace_values(all_attempts, :latency_ms)
          latencies.sum unless latencies.empty?
        end

        # Truthy values of +reader+ across all attempt traces that respond to it.
        def extract_trace_values(all_attempts, reader)
          all_attempts.filter_map do |entry|
            trace = entry[:result].trace
            trace.public_send(reader) if trace.respond_to?(reader)
          end
        end

        # Sums input/output tokens over every attempt whose trace usage is a Hash.
        def aggregate_retry_usage(all_attempts)
          usages = extract_trace_values(all_attempts, :usage).grep(Hash)
          {
            input_tokens: usages.sum { |usage| usage[:input_tokens] || 0 },
            output_tokens: usages.sum { |usage| usage[:output_tokens] || 0 }
          }
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module RubyLLM
  module Contract
    module Step
      # Describes how a step is retried: how many attempts to make, which model to
      # use on each attempt, and which result statuses trigger another attempt.
      #
      # Configure with keywords:
      #   RetryPolicy.new(models: %w[m1 m2], retry_on: %i[parse_error])
      # or with a block evaluated against the policy:
      #   RetryPolicy.new { attempts 3; escalate "m1", "m2" }
      class RetryPolicy
        attr_reader :max_attempts, :retryable_statuses

        DEFAULT_RETRY_ON = %i[validation_failed parse_error adapter_error].freeze

        # Raises ArgumentError unless the resulting attempt count is an Integer >= 1.
        def initialize(models: nil, attempts: nil, retry_on: nil, &block)
          @models = []
          @retryable_statuses = DEFAULT_RETRY_ON.dup

          if block
            @max_attempts = 1
            instance_eval(&block)
          else
            apply_keywords(models: models, attempts: attempts, retry_on: retry_on)
          end

          validate_max_attempts!
        end

        # Block DSL: sets the total number of attempts.
        def attempts(count)
          @max_attempts = count
          validate_max_attempts!
        end

        # Block DSL: sets the model for each successive attempt (arrays are flattened).
        # Raises the attempt count to the number of models when it is lower.
        def escalate(*names)
          @models = names.flatten.freeze
          @max_attempts = @models.length if @max_attempts < @models.length
        end
        alias models escalate

        def model_list
          @models
        end

        # Block DSL: replaces the statuses that trigger a retry. Accepts symbols as
        # separate arguments or as an array (flattened, like escalate).
        def retry_on(*statuses)
          @retryable_statuses = statuses.flatten
        end

        def retryable?(result)
          retryable_statuses.include?(result.status)
        end

        # Model for the zero-based +attempt+: the matching entry of the model list
        # (repeating the last one once exhausted), or +default_model+ when no list is set.
        def model_for_attempt(attempt, default_model)
          if @models.any?
            @models[attempt] || @models.last
          else
            default_model
          end
        end

        private

        def apply_keywords(models:, attempts:, retry_on:)
          if models
            @models = Array(models).dup.freeze
            @max_attempts = @models.length
            # Mirror the block DSL (`attempts N; escalate ...`): an explicit attempts:
            # larger than the model list extends the run on the last model instead
            # of being silently ignored.
            @max_attempts = attempts if attempts.is_a?(Integer) && attempts > @max_attempts
          else
            @max_attempts = attempts || 1
          end
          @retryable_statuses = Array(retry_on).flatten if retry_on
        end

        def validate_max_attempts!
          return if @max_attempts.is_a?(Integer) && @max_attempts >= 1

          raise ArgumentError, "attempts must be at least 1, got #{@max_attempts.inspect}"
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module RubyLLM
  module Contract
    module Step
      # Executes a single step:
      #   1. validate the input;
      #   2. render the prompt;
      #   3. enforce pre-flight limits (LimitChecker);
      #   4. call the adapter;
      #   5. validate the output against the contract.
      # Bad input, prompt build errors, limit breaches and adapter exceptions are
      # returned as Results with a non-:ok status.
      class Runner
        include LimitChecker

        def initialize(input_type:, output_type:, prompt_block:, contract_definition:,
                       adapter:, model:, output_schema: nil, max_output: nil,
                       max_input: nil, max_cost: nil)
          @input_type = input_type
          @output_type = output_type
          @prompt_block = prompt_block
          @contract_definition = contract_definition
          @adapter = adapter
          @model = model
          @output_schema = output_schema
          @max_output = max_output
          @max_input = max_input
          @max_cost = max_cost
        end

        # Runs the step for +input+ and returns a Result.
        def call(input)
          validated_input = validate_input(input)
          return validated_input if validated_input.is_a?(Result)

          messages = build_and_render_prompt(input)
        rescue RubyLLM::Contract::Error => e
          Result.new(status: :input_error, raw_output: nil, parsed_output: nil,
                     validation_errors: [e.message])
        else
          # Runs outside the rescue above, so pipeline errors are not reported as :input_error.
          execute_pipeline(messages, input)
        end

        private

        def execute_pipeline(messages, input)
          limit_result = check_limits(messages)
          return limit_result if limit_result

          response, latency_ms = execute_adapter(messages)
          return build_error_result(response, messages) if response.is_a?(Result)

          build_result(response, messages, latency_ms, input)
        end

        # Returns nil when the input is acceptable, or an :input_error Result.
        # Type objects (dry-types, callables) are invoked as `type[input]`; plain
        # Ruby classes are checked with is_a?.
        def validate_input(input)
          type = @input_type
          if coercible_type?(type)
            type[input]
          elsif !input.is_a?(type)
            raise TypeError, "#{input.inspect} is not a #{type}"
          end
          nil
        rescue Dry::Types::CoercionError, TypeError, ArgumentError => e
          Result.new(status: :input_error, raw_output: nil, parsed_output: nil, validation_errors: [e.message])
        end

        # True when +type+ should be called as a coercer rather than is_a?-checked.
        # Non-Class objects are always called. Classes are called only when they are
        # dry-types types (e.g. Dry::Struct). respond_to?(:[]) is not used because
        # Hash, Array, Struct and Data classes expose `.[]` as a constructor, which
        # would make them accept (and build from) almost any input.
        def coercible_type?(type)
          return true unless type.is_a?(Class)
          return false unless defined?(Dry::Types::Type)

          type.is_a?(Dry::Types::Type)
        end

        # Builds and renders the prompt AST. Blocks taking an argument receive the
        # input directly; zero-arity blocks are rendered with template variables.
        # Any failure is re-raised as RubyLLM::Contract::Error.
        def build_and_render_prompt(input)
          dynamic = @prompt_block.arity >= 1
          ast = Prompt::Builder.build(input: dynamic ? input : nil, &@prompt_block)

          Prompt::Renderer.render(ast, variables: dynamic ? {} : template_variables_for(input))
        rescue StandardError => e
          raise RubyLLM::Contract::Error, "Prompt build failed: #{e.class}: #{e.message}"
        end

        # Exposes the whole input as :input, plus top-level keys of a Hash input.
        def template_variables_for(input)
          base = { input: input }
          input.is_a?(Hash) ? base.merge(input.transform_keys(&:to_sym)) : base
        end

        # Returns [response, latency_ms], or [adapter_error Result, 0] if the adapter raises.
        def execute_adapter(messages)
          start_time = Process.clock_gettime(Process::CLOCK_MONOTONIC)
          response = @adapter.call(messages: messages, **build_adapter_options)
          latency_ms = ((Process.clock_gettime(Process::CLOCK_MONOTONIC) - start_time) * 1000).round
          [response, latency_ms]
        rescue StandardError => e
          [Result.new(status: :adapter_error, raw_output: nil, parsed_output: nil, validation_errors: [e.message]), 0]
        end

        def build_adapter_options
          { model: @model }.tap do |opts|
            opts[:schema] = @output_schema if @output_schema
            opts[:max_tokens] = @max_output if @max_output
          end
        end

        # Re-wraps an adapter error Result with a trace of the attempted call.
        def build_error_result(error_result, messages)
          Result.new(
            status: error_result.status,
            raw_output: error_result.raw_output,
            parsed_output: error_result.parsed_output,
            validation_errors: error_result.validation_errors,
            trace: Trace.new(messages: messages, model: @model)
          )
        end

        def build_result(response, messages, latency_ms, input)
          raw_output = response.content
          validation_result = validate_output(raw_output, input)
          trace = Trace.new(messages: messages, model: @model, latency_ms: latency_ms, usage: response.usage)

          Result.new(
            status: validation_result[:status],
            raw_output: raw_output,
            parsed_output: validation_result[:parsed_output],
            validation_errors: validation_result[:errors],
            trace: trace
          )
        end

        def validate_output(raw_output, input)
          Validator.validate(
            raw_output: raw_output,
            definition: @contract_definition,
            output_type: @output_type,
            input: input,
            schema: @output_schema
          )
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module RubyLLM
  module Contract
    module Step
      # Immutable record of one step execution. It holds the rendered messages,
      # model, latency, token usage, retry attempts and cost, and offers a
      # read-only Hash-like view ([], key?, to_h) over KNOWN_KEYS.
      class Trace
        include Concerns::TraceEquality

        attr_reader :messages, :model, :latency_ms, :usage, :attempts, :cost

        # When +cost+ is not given it is derived from +model+ and +usage+ via
        # CostCalculator (which may itself return nil).
        def initialize(messages: nil, model: nil, latency_ms: nil, usage: nil, attempts: nil, cost: nil)
          @messages = messages
          @model = model
          @latency_ms = latency_ms
          @usage = usage
          @attempts = attempts
          @cost = cost || CostCalculator.calculate(model_name: model, usage: usage)
          freeze
        end

        KNOWN_KEYS = %i[messages model latency_ms usage attempts cost].freeze

        # Fields the derived cost depends on (see #merge).
        COST_INPUTS = %i[model usage].freeze

        # Hash-style read; returns nil for unknown keys, including keys that cannot be symbolized.
        def [](key)
          return nil unless known_key?(key)

          public_send(key)
        end

        def key?(key)
          known_key?(key) && !public_send(key).nil?
        end
        alias has_key? key?

        # Returns a new Trace with the given fields replaced. When model or usage
        # changes and no cost override is supplied, cost is recomputed instead of
        # carrying over a value derived from the old model/usage.
        def merge(**overrides)
          stale_cost = !overrides.key?(:cost) && COST_INPUTS.any? { |field| overrides.key?(field) }

          self.class.new(
            messages: overrides.fetch(:messages, @messages),
            model: overrides.fetch(:model, @model),
            latency_ms: overrides.fetch(:latency_ms, @latency_ms),
            usage: overrides.fetch(:usage, @usage),
            attempts: overrides.fetch(:attempts, @attempts),
            cost: stale_cost ? nil : overrides.fetch(:cost, @cost)
          )
        end

        def to_h
          { messages: @messages, model: @model, latency_ms: @latency_ms,
            usage: @usage, attempts: @attempts, cost: @cost }.compact
        end

        # One-line summary such as "model 120ms 10+5 tokens $0.000123".
        def to_s
          build_summary_parts.join(" ")
        end

        private

        def known_key?(key)
          key.respond_to?(:to_sym) && KNOWN_KEYS.include?(key.to_sym)
        end

        def build_summary_parts
          parts = [@model || "no-model"]
          parts << "#{@latency_ms}ms" if @latency_ms
          parts << format_token_usage if @usage.is_a?(Hash)
          parts << "$#{format("%.6f", @cost)}" if @cost
          parts
        end

        def format_token_usage
          "#{@usage[:input_tokens] || 0}+#{@usage[:output_tokens] || 0} tokens"
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require_relative "step/trace"
|
|
4
|
+
require_relative "step/result"
|
|
5
|
+
require_relative "step/limit_checker"
|
|
6
|
+
require_relative "step/runner"
|
|
7
|
+
require_relative "step/retry_policy"
|
|
8
|
+
require_relative "step/retry_executor"
|
|
9
|
+
require_relative "step/dsl"
|
|
10
|
+
require_relative "step/base"
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module RubyLLM
  module Contract
    # Cheap, model-agnostic estimate of how many tokens a message list will use,
    # intended for pre-flight limit checks rather than exact accounting.
    module TokenEstimator
      # Characters-per-token divisor (a common rule of thumb for English text).
      # Actual tokenization varies by model and content.
      # NOTE(review): for code or non-Latin scripts this ratio may underestimate
      # real token counts — confirm it is acceptable for limit enforcement.
      CHARS_PER_TOKEN = 4

      # Sums the character length of each message's :content (nil counts as empty)
      # and divides by CHARS_PER_TOKEN, rounding up so any non-empty content counts
      # as at least one token. Returns 0 for anything that is not an Array.
      def self.estimate(messages)
        return 0 unless messages.is_a?(Array)

        char_count = messages.inject(0) { |acc, message| acc + message[:content].to_s.length }
        char_count.fdiv(CHARS_PER_TOKEN).ceil
      end
    end
  end
end
|