ai_guardrails 1.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.rspec +3 -0
- data/.rubocop.yml +8 -0
- data/CHANGELOG.md +149 -0
- data/CODE_OF_CONDUCT.md +132 -0
- data/LICENSE.txt +21 -0
- data/README.md +528 -0
- data/Rakefile +12 -0
- data/lib/ai_guardrails/auto_correction.rb +85 -0
- data/lib/ai_guardrails/auto_fix.rb +85 -0
- data/lib/ai_guardrails/background_job.rb +47 -0
- data/lib/ai_guardrails/cache.rb +50 -0
- data/lib/ai_guardrails/cli.rb +17 -0
- data/lib/ai_guardrails/config.rb +13 -0
- data/lib/ai_guardrails/dsl.rb +101 -0
- data/lib/ai_guardrails/json_repair.rb +234 -0
- data/lib/ai_guardrails/logger.rb +45 -0
- data/lib/ai_guardrails/mock_model_client.rb +34 -0
- data/lib/ai_guardrails/provider/base_client.rb +19 -0
- data/lib/ai_guardrails/provider/factory.rb +20 -0
- data/lib/ai_guardrails/provider/openai_client.rb +43 -0
- data/lib/ai_guardrails/runner.rb +40 -0
- data/lib/ai_guardrails/safety_filter.rb +33 -0
- data/lib/ai_guardrails/schema_validator.rb +57 -0
- data/lib/ai_guardrails/version.rb +5 -0
- data/lib/ai_guardrails.rb +40 -0
- data/sig/ai_guardrails.rbs +4 -0
- metadata +122 -0
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module AiGuardrails
  # Provides helper methods to run AiGuardrails in background jobs or CLI
  module BackgroundJob
    class << self
      # Executes a task safely in background or CLI
      #
      # Example usage:
      #   AiGuardrails::BackgroundJob.perform do
      #     AiGuardrails::DSL.run(prompt: "...", schema: {...})
      #   end
      #
      # @param logger [#error, nil] logger installed on AiGuardrails::Logger while the
      #   block runs; nil keeps the currently configured logger
      # @param debug [Boolean] debug mode set on AiGuardrails::Logger while the block runs
      # @return [Object] the block's return value
      # @raise [StandardError] any error raised by the block, after it has been logged once
      def perform(logger: nil, debug: false, &block)
        # Errors are logged exactly once, by perform_with_error_logging, while the
        # temporary logger is still installed; they are not logged again here.
        with_temp_logger(logger, debug, &block)
      end

      # Temporarily installs +temp_logger+ / +temp_debug+ on AiGuardrails::Logger,
      # runs the block, and always restores the previous settings afterwards.
      #
      # @param temp_logger [#error, nil] logger to install (nil keeps the current one)
      # @param temp_debug [Boolean] debug mode while the block runs
      # @return [Object] the block's return value
      def with_temp_logger(temp_logger, temp_debug, &block)
        prev_logger = Logger.logger
        prev_debug = Logger.debug_mode

        Logger.logger = temp_logger if temp_logger
        Logger.debug_mode = temp_debug

        perform_with_error_logging(&block)
      ensure
        Logger.logger = prev_logger
        Logger.debug_mode = prev_debug
      end

      private

      # Calls the block; on failure logs to the currently installed logger and re-raises.
      def perform_with_error_logging(&block)
        block.call
      rescue StandardError => e
        Logger.logger&.error("Background job failed: #{e.class} - #{e.message}")
        raise
      end
    end
  end
end
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "digest"
|
|
4
|
+
|
|
5
|
+
module AiGuardrails
  # Simple caching layer for AI responses
  module Cache
    # Null object used when no cache store is provided: #read never hits and
    # #write simply hands the value back.
    class NullStore
      def read(_key, **_options)
        nil
      end

      def write(_key, value, **_options)
        value
      end
    end

    class << self
      attr_accessor :enabled, :store, :expires_in

      # Setup cache
      #
      # @param enabled [Boolean] whether #fetch consults the store at all
      # @param store [#read, #write, nil] any object responding to #read/#write
      #   (e.g., Rails.cache, ActiveSupport::Cache); nil installs a NullStore
      # @param expires_in [Integer] seconds, forwarded to the store on read/write
      def configure(enabled: true, store: nil, expires_in: 300)
        @enabled = enabled
        @store = store || NullStore.new
        @expires_in = expires_in
      end

      # Returns the cached value for +key+, or computes it (block result, else +default+)
      # and writes it to the store. Works with caching disabled (always computes).
      #
      # Only nil is treated as a miss, so a cached +false+ is returned as-is.
      # @param key [String]
      # @param default [Object] value used when no block is given
      # @return [Object]
      def fetch(key, default = nil)
        return (block_given? ? yield : default) unless enabled

        # Fall back to a NullStore if caching was enabled without configuring a store.
        backend = store || NullStore.new
        cached = backend.read(key, expires_in: expires_in)
        return cached unless cached.nil?

        result = block_given? ? yield : default
        backend.write(key, result, expires_in: expires_in)
        result
      end

      # Generate a cache key from prompt + schema.
      # The prompt is length-prefixed so the prompt/schema boundary is unambiguous
      # (otherwise "a-b" + "c" and "a" + "b-c" would produce the same key).
      # @return [String] hex SHA-256 digest
      def key(prompt, schema)
        prompt_text = prompt.to_s
        Digest::SHA256.hexdigest("#{prompt_text.bytesize}:#{prompt_text}-#{schema}")
      end
    end
  end
end
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module AiGuardrails
  # Provides a CLI-friendly interface for running AiGuardrails safely
  module CLI
    class << self
      # Runs the given block through BackgroundJob.perform, handing it whatever logger
      # is currently configured on AiGuardrails::Logger and the requested debug flag.
      #
      # Example:
      #   AiGuardrails::CLI.run do
      #     result = AiGuardrails::DSL.run(prompt: "...", schema: {...})
      #     puts result
      #   end
      #
      # @param debug [Boolean] debug mode for the duration of the block
      # @return [Object] whatever BackgroundJob.perform returns (the block's value)
      def run(debug: false, &block)
        active_logger = Logger.logger
        BackgroundJob.perform(logger: active_logger, debug: debug, &block)
      end
    end
  end
end
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module AiGuardrails
  # Provides a simple developer-friendly interface
  module DSL
    class << self
      # Main entry point used by developers.
      # Run AI model with validation, retries, and safety checks.
      #
      # Only the corrected model output is cached (keyed on prompt + schema). Auto-fix
      # hooks and the blocklist check run on every call, so a cached response can never
      # skip the blocklist or hooks supplied by the current caller.
      #
      # @param prompt [String]
      # @param schema [Object] passed to AutoCorrection; also part of the cache key
      # @param schema_hint [Object, nil] optional hint; preferred over +schema+ for auto-fix
      # @param options [Hash] :provider (default :openai), :provider_config ({}),
      #   :max_retries (3), :sleep_time (0), :auto_fix_hooks ([]), :blocklist ([])
      # @return [Object] the (optionally auto-fixed) result
      # Propagates any error raised by SafetyFilter#check! when blocked content is found.
      def run(prompt:, schema:, schema_hint: nil, **options)
        result = Cache.fetch(Cache.key(prompt, schema)) do
          fetch_with_retries_and_correction(prompt, schema, schema_hint, options)
        end

        Logger.debug("result in DSL: #{result}")

        # Apply JSON + schema auto-fix when hooks are given (explicit nil means "none").
        hooks = options[:auto_fix_hooks] || []
        result = apply_auto_fix(result, schema_hint || schema, hooks) unless hooks.empty?

        check_safety(result, options[:blocklist] || [])
        result
      end

      private

      # Extracted to reduce run method length
      def fetch_with_retries_and_correction(prompt, schema, schema_hint, options)
        client = build_client(options.fetch(:provider, :openai), options.fetch(:provider_config, {}))
        max_retries = options.fetch(:max_retries, 3)
        sleep_time = options.fetch(:sleep_time, 0)
        run_with_retries_helper(
          client: client, schema: schema, prompt: prompt,
          max_retries: max_retries,
          sleep_time: sleep_time,
          schema_hint: schema_hint
        )
      end

      # Builds the provider client
      def build_client(provider, config)
        Provider::Factory.build(provider: provider, config: config)
      end

      # Runs AutoCorrection wrapper (max 5 parameters)
      def run_with_retries_helper(options = {})
        client = options[:client]
        schema = options[:schema]
        prompt = options[:prompt]
        max_retries = options[:max_retries] || 3
        sleep_time = options[:sleep_time] || 0
        schema_hint = options[:schema_hint]

        auto = AutoCorrection.new(
          provider: client, schema: schema, max_retries: max_retries, sleep_time: sleep_time
        )
        auto.call(prompt: prompt, schema_hint: schema_hint)
      end

      # Applies the configured auto-fix hooks to the result via AutoFix.fix
      def apply_auto_fix(result, schema, hooks)
        AiGuardrails::AutoFix.fix(result, schema: schema, hooks: hooks)
      end

      # Runs safety filter when needed.
      def check_safety(result, blocklist)
        return if blocklist.empty?

        content = normalize_result(result)
        check_blocklist(content, blocklist)
      end

      # Normalizes result into a simple string for safety scanning.
      def normalize_result(result)
        case result
        when Hash
          result.values.join(" ")
        when String
          parse_json_string(result)
        else
          result.to_s
        end
      end

      # Attempt to parse string as JSON; fallback to original string if parsing fails
      def parse_json_string(str)
        parsed = JSON.parse(str)
        parsed.is_a?(Hash) ? parsed.values.join(" ") : str
      rescue JSON::ParserError
        str
      end

      # Perform case-insensitive safety check using SafetyFilter
      def check_blocklist(content, blocklist)
        content_down = content.downcase
        blocklist_down = blocklist.map(&:downcase)
        SafetyFilter.new(blocklist: blocklist_down).check!(content_down)
      end
    end
  end
end
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
|
|
4
|
+
|
|
5
|
+
module AiGuardrails
  # Repairs malformed JSON strings
  #
  # Fast path: strip markdown code fences and, if the remainder already parses, return it.
  # Otherwise run a sequence of heuristic regex/scanner passes (quote keys, insert missing
  # commas, balance brackets, drop trailing commas) and parse the result.
  #
  # NOTE: the heuristics are lossy: every single quote is turned into a double quote
  # (including apostrophes inside values) and whitespace runs are collapsed to one space.
  # rubocop:disable Metrics/ClassLength
  class JsonRepair
    class RepairError < StandardError; end

    # Maps each JSON container opener to its closer (used by #balance_braces).
    CLOSERS = { "{" => "}", "[" => "]" }.freeze
    private_constant :CLOSERS

    # Class method entrypoint
    # @param raw [#to_s] model output
    # @return [Object] the parsed JSON value
    # @raise [RepairError] when the heuristics cannot produce valid JSON
    def self.repair(raw)
      new(raw).repair
    end

    def initialize(raw)
      @raw = raw.to_s.strip
    end

    # Main repair
    def repair
      raw_sanitized = sanitize_llm_output(@raw)
      return JSON.parse(raw_sanitized) if valid_json?(raw_sanitized)

      repaired = run_full_repair(raw_sanitized)
      raise RepairError, "Unable to repair JSON" unless valid_json?(repaired)

      JSON.parse(repaired)
    end

    private

    # --------------------------
    # Full repair workflow extracted from original repair
    # --------------------------
    def run_full_repair(str)
      str = preprocess(str)
      str = normalize_structure(str)
      str = balance_braces(str)
      str = run_recursive_fixes(str)
      str = remove_trailing_commas(str)
      str.gsub(/\s+/, " ").strip
    end

    # --------------------------
    # JSON validation
    # --------------------------
    def valid_json?(str)
      JSON.parse(str)
      true
    rescue JSON::ParserError
      false
    end

    # --------------------------
    # Preprocessing
    # --------------------------
    def preprocess(str)
      str = str.strip
      str.gsub!("'", '"')
      str = quote_all_keys(str)
      str = insert_missing_commas_regex(str)
      remove_trailing_commas(str)
    end

    # --------------------------
    # Quote keys
    # --------------------------
    # Repeats until a fixpoint. The non-mutating gsub matters: with gsub! the
    # `prev` reference would alias `current` and the loop would stop after one pass.
    def quote_all_keys(str)
      prev = nil
      current = str.dup
      while current != prev
        prev = current
        current = current.gsub(/([{\s,])([a-zA-Z0-9_-]+)\s*:/, '\1"\2":')
      end
      current
    end

    def insert_missing_commas_regex(str)
      str.gsub(/([}\]"0-9a-zA-Z])\s+("?[\w-]+"?\s*:)/, '\1, \2')
    end

    # --------------------------
    # Normalization
    # --------------------------
    def normalize_structure(input)
      repaired = input.dup
      repaired = fix_double_braces(repaired)
      repaired = fix_object_brace_spacing(repaired)
      repaired = insert_missing_commas_by_scanner(repaired)
      repaired.gsub!(/([}\]])\s*(?=([A-Za-z0-9_"-]+\s*:))/, '\1, ')
      repaired.gsub!(/([}\]])\s*(?=(\{|\[|"|\d|true|false|null))/i, '\1, ')
      repaired.gsub!(/,+/, ",")
      repaired.gsub!(/\s+/, " ")
      repaired.strip
    end

    # Collapses "[{{" / ",{{" to a single brace, repeating until nothing changes
    # (so "[{{{" is fully collapsed). Uses non-mutating gsub so `prev` is a real snapshot.
    def fix_double_braces(str)
      prev = nil
      current = str.dup
      while current != prev
        prev = current
        current = current.gsub(/(\[|,)\s*\{\s*\{/, '\1 {')
      end
      current
    end

    def fix_object_brace_spacing(str)
      str.gsub(/}\s*{/, "}, {")
         .gsub(/]\s*{/, "], {")
         .gsub(/}\s*\]\s*\{/, "}], {")
    end

    # --------------------------
    # Recursive fixes runner
    # --------------------------
    def run_recursive_fixes(str)
      str = quote_all_keys(str)
      str = insert_commas_recursively(str)
      str = insert_final_commas(str)
      str = insert_commas_recursive_nested(str)
      str = fix_consecutive_objects_in_arrays(str)
      str = fix_double_object_braces(str)
      fix_adjacent_arrays(str)
    end

    def fix_adjacent_arrays(str)
      str.gsub(/\]\s*\[/, "], [")
    end

    # --------------------------
    # Scanner-based comma insertion
    # --------------------------
    def insert_missing_commas_by_scanner(str)
      s = str.dup
      out_chars = []
      i = 0
      while i < s.length
        char = s[i]
        out_chars << char
        insert_comma_after_close_brace?(char, s, i, out_chars)
        i += 1
      end
      out_chars.join.gsub(/,+/, ",").gsub(/\s+/, " ").strip
    end

    # rubocop:disable Metrics/CyclomaticComplexity
    def insert_comma_after_close_brace?(char, string, index, output_chars)
      return unless ["}", "]"].include?(char)

      j = index + 1
      j += 1 while j < string.length && string[j] =~ /\s/
      next_char = j < string.length ? string[j] : nil
      return unless next_char && ![",", "}", "]", ":"].include?(next_char)

      output_chars << "," if next_char =~ /[\[\]"0-9A-Za-z_-]/
    end
    # rubocop:enable Metrics/CyclomaticComplexity

    # --------------------------
    # Recursive comma insertion
    # --------------------------
    def insert_commas_recursively(str)
      loop do
        prev = str.dup
        str.gsub!(/([}\]"0-9a-zA-Z])\s+(?=(\{|"[^"]*"|\d+|true|false|null|\[))/i, '\1, ')
        str.gsub!(/(\})\s+(?=\{)/, '\1, ')
        str.gsub!(/(\])\s+(?=\[)/, '\1, ')
        break if str == prev
      end
      str
    end

    def insert_final_commas(str)
      loop do
        prev = str.dup
        str.gsub!(/([}\]])\s+(?=("[a-zA-Z_][a-zA-Z0-9_]*"\s*:))/, '\1, ')
        str.gsub!(/([}\]])\s+(?=[{\[])/, '\1, ')
        break if str == prev
      end
      str
    end

    def insert_commas_recursive_nested(str)
      loop do
        prev = str.dup
        str.gsub!(/}\s*(?=\{)/, "}, {")
        str.gsub!(/]\s*(?=\[)/, "], [")
        str.gsub!(/([}\]])\s+(?=("[^"]+"\s*:))/, '\1, ')
        break if str == prev
      end
      str
    end

    def fix_consecutive_objects_in_arrays(str)
      loop do
        prev = str.dup
        str.gsub!(/({[^{}]*})\s*(?=\{)/, '\1, ')
        break if str == prev
      end
      str
    end

    def fix_double_object_braces(str)
      fix_double_braces(str)
    end

    def remove_trailing_commas(str)
      str.gsub(/,(\s*[}\]])/, '\1')
    end

    # Appends the closers needed for any still-open objects/arrays, innermost first,
    # so a truncated `{"a": [1, 2` becomes `{"a": [1, 2]}` (not `{"a": [1, 2}]`).
    # Brackets inside string literals are ignored; an unterminated string is closed first.
    def balance_braces(str)
      open_stack, in_string = scan_open_containers(str)
      closing = open_stack.reverse.map { |opener| CLOSERS.fetch(opener) }.join
      (in_string ? "#{str}\"" : str) + closing
    end

    # Walks +str+ once and returns [unmatched "{"/"[" openers in order, ends_inside_string?].
    # Closers that do not match the innermost open container are ignored.
    def scan_open_containers(str)
      stack = []
      in_string = false
      escaped = false
      str.each_char do |char|
        if in_string
          in_string, escaped = string_state_after(char, escaped)
          next
        end

        case char
        when '"'
          in_string = true
        when "{", "["
          stack.push(char)
        when "}", "]"
          stack.pop if CLOSERS[stack.last] == char
        end
      end
      [stack, in_string]
    end

    # Advances the string-literal state by one character.
    # Returns [still_inside_string, escape_pending].
    def string_state_after(char, escaped)
      return [true, false] if escaped
      return [true, true] if char == "\\"

      [char != '"', false]
    end

    def sanitize_llm_output(str)
      return str unless str.is_a?(String)

      # Remove everything before the first ```json or ```
      sanitized = str.sub(/\A.*?```(?:json)?\s*/m, "")

      # Remove trailing ```
      sanitized = sanitized.sub(/```\s*\z/, "")

      sanitized.strip
    end
  end
  # rubocop:enable Metrics/ClassLength
end
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module AiGuardrails
  # Thin logging facade used inside the gem.
  # Any object responding to #info/#error/#debug (Rails.logger, ::Logger.new, ...) can be
  # assigned to +logger+; when none is set, messages are silently dropped.
  module Logger
    class << self
      attr_accessor :logger, :debug_mode

      # Forwards +message+ at info level, prefixed with "[AiGuardrails]".
      # @return whatever the underlying logger's #info returns
      def info(message)
        tagged = "[AiGuardrails] #{message}"
        safe_logger.info(tagged)
      end

      # Forwards +message+ at error level, prefixed with "[AiGuardrails ERROR]".
      # @return whatever the underlying logger's #error returns
      def error(message)
        tagged = "[AiGuardrails ERROR] #{message}"
        safe_logger.error(tagged)
      end

      # Forwards +message+ at debug level, prefixed with "[AiGuardrails DEBUG]",
      # but only while debug_mode is truthy; otherwise does nothing and returns nil.
      def debug(message)
        safe_logger.debug("[AiGuardrails DEBUG] #{message}") if debug_mode
      end

      private

      # The configured logger, or a fresh NullLogger when none (nil/false) is set.
      def safe_logger
        return NullLogger.new unless logger

        logger
      end
    end

    # Fallback logger whose methods accept one message and ignore it,
    # so the facade never hits NoMethodError when no logger is configured.
    class NullLogger
      def info(_msg)
        nil
      end

      def error(_msg)
        nil
      end

      def debug(_msg)
        nil
      end
    end
  end
end
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module AiGuardrails
  # MockModelClient simulates AI LLM responses for tests
  class MockModelClient
    class MockError < StandardError; end

    # Initialize with a hash of prompt => response
    # @param responses [Hash] canned responses; keys are stringified so symbols and
    #   strings address the same prompt
    def initialize(responses = {})
      @responses = responses.transform_keys(&:to_s)
    end

    # Simulates a call to the model.
    #
    # @param prompt [#to_s] looked up (as a string) in the canned responses
    # @param raise_error [Boolean] when truthy, simulate an API failure
    # @param default_fallback [Object] returned for prompts with no registered response
    # @return [Object] the registered response, or +default_fallback+ for unknown prompts
    # @raise [MockError] when +raise_error+ is truthy, or when the prompt is registered
    #   with a nil/false response
    def call(prompt:, raise_error: false, default_fallback: nil)
      raise MockError, "Simulated model error" if raise_error

      key = prompt.to_s
      # Any falsy raise_error (false or nil) falls back for unknown prompts.
      return default_fallback unless @responses.key?(key)

      response = @responses[key]
      raise MockError, "No mock response defined for prompt: #{prompt}" unless response

      response
    end

    # Provider-style entry point (same signature as Provider::BaseClient#call_model),
    # so the mock can be used wherever a provider is expected, e.g. Runner.
    def call_model(prompt:)
      call(prompt: prompt)
    end

    # Add or update mock responses dynamically
    def add_response(prompt, response)
      @responses[prompt.to_s] = response
    end
  end
end
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module AiGuardrails
  module Provider
    # BaseClient defines a common interface for all providers
    class BaseClient
      # Initialize with optional config hash.
      # @param config [Hash, nil] provider settings; nil is treated as an empty hash so
      #   subclasses can always index it (e.g. @config[:api_key])
      def initialize(config = {})
        @config = config || {}
      end

      # Call AI model with a prompt.
      # Must be implemented by subclasses; they are expected to return the model's raw
      # text output (see the concrete provider clients).
      # @param prompt [String]
      # @raise [NotImplementedError] always, in the base class
      def call_model(prompt:)
        raise NotImplementedError, "Subclasses must implement call_model"
      end
    end
  end
end
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module AiGuardrails
  module Provider
    # Factory returns the right provider client
    class Factory
      PROVIDERS = {
        openai: OpenAIClient
        # add :anthropic => AnthropicClient later
      }.freeze

      # Builds a client for +provider+.
      # @param provider [Symbol, String] provider name, e.g. :openai or "openai"
      # @param config [Hash] passed to the client's constructor
      # @return [BaseClient] a new client instance
      # @raise [ArgumentError] for unknown providers, including nil
      def self.build(provider:, config: {})
        # to_s first so nil (or any value without #to_sym) reaches the ArgumentError
        # below instead of raising NoMethodError.
        klass = PROVIDERS[provider.to_s.to_sym]
        raise ArgumentError, "Unknown provider: #{provider}" unless klass

        klass.new(config)
      end
    end
  end
end
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module AiGuardrails
  module Provider
    # Handles actual OpenAI API calls.
    # The ruby-openai gem is only loaded when call_model is used.
    class OpenAIClient < BaseClient
      def initialize(config = {})
        super
        @client = nil
      end

      # Sends +prompt+ as a single user message through the ruby-openai client's #chat.
      # The underlying client is built lazily on first use and then reused.
      # Model defaults to "gpt-4o-mini" and temperature to 0.7 when not set in config.
      # @return [Object] response["choices"][0]["message"]["content"], or nil if absent
      def call_model(prompt:)
        ensure_provider_loaded

        api = (@client ||= ::OpenAI::Client.new(access_token: @config[:api_key]))
        reply = api.chat(parameters: chat_parameters(prompt))
        reply.dig("choices", 0, "message", "content")
      end

      private

      # Request parameters for the chat call, with config overrides applied.
      def chat_parameters(prompt)
        {
          model: @config[:model] || "gpt-4o-mini",
          messages: [{ role: "user", content: prompt }],
          temperature: @config[:temperature] || 0.7
        }
      end

      # Loads ruby-openai on demand; re-raises LoadError with installation guidance.
      def ensure_provider_loaded
        require "ruby/openai"
      rescue LoadError
        raise LoadError,
              "ruby-openai gem is not installed. Add:\n" \
              " gem 'ruby-openai', require: false\n" \
              "to your Gemfile if using OpenAI provider."
      end
    end
  end
end
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module AiGuardrails
  # Coordinates the full validation and repair flow:
  # provider call -> JsonRepair -> SchemaValidator.
  class Runner
    def initialize(prompt:, provider:, schema:, options: {})
      @prompt = prompt
      @provider = provider
      @schema = schema
      @options = options
    end

    # Runs the flow once.
    # @return [Hash] { ok: true, result: <validator result> } on success, or
    #   { ok: false, errors: <validator result> } when validation fails
    # @raise [StandardError] any error from the provider, JsonRepair or the validator,
    #   after logging it
    def run
      Logger.info("Starting run")
      Logger.debug("Prompt: #{@prompt}")

      parsed = repaired_model_output
      is_valid, outcome = SchemaValidator.new(@schema).validate(parsed)

      if is_valid
        Logger.info("Run completed successfully")
        { ok: true, result: outcome }
      else
        Logger.error("Schema validation failed: #{outcome}")
        { ok: false, errors: outcome }
      end
    rescue StandardError => e
      Logger.error("Unhandled exception: #{e.class} - #{e.message}")
      raise
    end

    private

    # Calls the provider and repairs its raw output, logging both steps at debug level.
    def repaired_model_output
      raw = @provider.call_model(prompt: @prompt)
      Logger.debug("Raw model output: #{raw.inspect}")

      JsonRepair.repair(raw).tap do |repaired|
        Logger.debug("Repaired JSON: #{repaired.inspect}")
      end
    end
  end
end
|