ai_guardrails 1.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,47 @@
1
+ # frozen_string_literal: true
2
+
3
module AiGuardrails
  # Helpers for running AiGuardrails work inside background jobs or CLI scripts.
  module BackgroundJob
    class << self
      # Executes a block with a temporary logger/debug configuration.
      #
      # Example usage:
      #   AiGuardrails::BackgroundJob.perform do
      #     AiGuardrails::DSL.run(prompt: "...", schema: {...})
      #   end
      #
      # @param logger [#error, nil] logger installed while the block runs;
      #   the currently configured logger is kept when nil
      # @param debug [Boolean] debug mode installed while the block runs
      # @yield the work to execute
      # @return [Object] whatever the block returns
      # @raise [ArgumentError] when called without a block
      # @raise [StandardError] any error raised by the block, re-raised after logging
      def perform(logger: nil, debug: false, &block)
        raise ArgumentError, "BackgroundJob.perform requires a block" unless block

        # No rescue here on purpose: the failure is already logged once by
        # perform_with_error_logging while the temporary logger is still
        # installed. Logging again at this level duplicated the entry on the
        # restored logger.
        with_temp_logger(logger, debug, &block)
      end

      # Swaps in the given logger and debug flag, runs the block, and always
      # restores the previous values afterwards (even when the block raises).
      #
      # @param temp_logger [#error, nil] replacement logger; nil keeps the current one
      # @param temp_debug [Boolean] debug flag used while the block runs
      # @return [Object] whatever the block returns
      def with_temp_logger(temp_logger, temp_debug, &block)
        prev_logger = Logger.logger
        prev_debug = Logger.debug_mode

        Logger.logger = temp_logger if temp_logger
        Logger.debug_mode = temp_debug

        perform_with_error_logging(&block)
      ensure
        Logger.logger = prev_logger
        Logger.debug_mode = prev_debug
      end

      private

      # Calls the block, logging any StandardError to the active logger
      # before re-raising it unchanged.
      def perform_with_error_logging(&block)
        block.call
      rescue StandardError => e
        Logger.logger&.error("Background job failed: #{e.class} - #{e.message}")
        raise
      end
    end
  end
end
@@ -0,0 +1,50 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "digest"
4
+
5
module AiGuardrails
  # Simple caching layer for AI responses.
  module Cache
    class << self
      attr_accessor :enabled, :store, :expires_in

      # Sets up the cache.
      #
      # @param enabled [Boolean] whether fetch should consult the store at all
      # @param store [#read, #write, nil] any object responding to
      #   read(key, **opts) / write(key, value, **opts) (e.g. Rails.cache);
      #   a NullStore is used when nil
      # @param expires_in [Integer] seconds, passed through to the store
      def configure(enabled: true, store: nil, expires_in: 300)
        @enabled = enabled
        @store = store || NullStore.new
        @expires_in = expires_in
      end

      # Returns the cached value for key, computing and storing it on a miss.
      # The value comes from the block when one is given, otherwise from
      # +default+. With caching disabled the value is computed every time.
      #
      # @param key [String] cache key (see .key)
      # @param default [Object] value used when no block is given
      # @return [Object] cached or freshly computed value
      def fetch(key, default = nil)
        return (block_given? ? yield : default) unless enabled

        # Guard against `enabled = true` being set without configure/store.
        backend = store || NullStore.new

        cached = backend.read(key, expires_in: expires_in)
        # Only nil means "miss"; a cached `false` is a legitimate hit.
        return cached unless cached.nil?

        result = block_given? ? yield : default
        backend.write(key, result, expires_in: expires_in)
        result
      end

      # Generates a cache key from prompt + schema.
      #
      # @return [String] SHA-256 hex digest of "<prompt>-<schema>"
      def key(prompt, schema)
        digest_input = "#{prompt}-#{schema}"
        Digest::SHA256.hexdigest(digest_input)
      end

      # Null object used when no cache store is provided: never returns a
      # hit and discards writes.
      class NullStore
        def read(_key, **_options)
          nil
        end

        def write(_key, value, **_options)
          value
        end
      end
    end
  end
end
@@ -0,0 +1,17 @@
1
+ # frozen_string_literal: true
2
+
3
module AiGuardrails
  # CLI-facing entry point for running AiGuardrails work.
  module CLI
    class << self
      # Runs the given block through BackgroundJob.perform, passing along the
      # currently configured logger and the requested debug flag.
      #
      # Example:
      #   AiGuardrails::CLI.run do
      #     result = AiGuardrails::DSL.run(prompt: "...", schema: {...})
      #     puts result
      #   end
      #
      # @param debug [Boolean] debug mode to use while the block runs
      # @return [Object] whatever BackgroundJob.perform returns
      def run(debug: false, &block)
        active_logger = Logger.logger
        BackgroundJob.perform(logger: active_logger, debug: debug, &block)
      end
    end
  end
end
@@ -0,0 +1,13 @@
1
+ # frozen_string_literal: true
2
+
3
module AiGuardrails
  # Configuration container for AiGuardrails.
  #
  # Attributes:
  #   logger - logger object, nil until one is assigned
  #   debug  - debug flag, false until changed
  class Config
    attr_accessor :logger, :debug

    # Starts with no logger and debug switched off.
    def initialize
      self.logger = nil
      self.debug = false
    end
  end
end
@@ -0,0 +1,101 @@
1
+ # frozen_string_literal: true
2
+
3
module AiGuardrails
  # Provides a simple developer-friendly interface.
  module DSL
    class << self
      # Main entry point used by developers.
      # Runs the AI model with validation, retries, optional auto-fix hooks
      # and a blocklist safety check. Results are cached by prompt + schema.
      #
      # @param prompt [String] prompt sent to the provider
      # @param schema [Object] schema handed to AutoCorrection (and to AutoFix
      #   when no schema_hint is given)
      # @param schema_hint [Object, nil] optional hint forwarded to AutoCorrection
      # @param options [Hash] :provider, :provider_config, :max_retries,
      #   :sleep_time, :auto_fix_hooks, :blocklist
      # @return [Object] the (possibly auto-fixed) model result
      def run(prompt:, schema:, schema_hint: nil, **options)
        blocklist = options.fetch(:blocklist, [])
        computed = false

        result = Cache.fetch(Cache.key(prompt, schema)) do
          computed = true
          fresh = fetch_with_retries_and_correction(prompt, schema, schema_hint, options)

          # Library code must not write to stdout; route through the gem logger.
          Logger.debug("result in DSL: #{fresh}")

          # Apply JSON + schema auto-fix when hooks are given.
          hooks = options.fetch(:auto_fix_hooks, [])
          fix_schema = schema_hint || schema
          fresh = apply_auto_fix(fresh, fix_schema, hooks) unless hooks.empty?

          # Checked inside the block so an unsafe result is never cached.
          check_safety(fresh, blocklist)
          fresh
        end

        # The cache key does not include the blocklist, so a cache hit must
        # still be validated against the blocklist of *this* call.
        check_safety(result, blocklist) unless computed
        result
      end

      private

      # Builds the provider client from options and runs the retry/correction loop.
      def fetch_with_retries_and_correction(prompt, schema, schema_hint, options)
        client = build_client(options.fetch(:provider, :openai), options.fetch(:provider_config, {}))
        max_retries = options.fetch(:max_retries, 3)
        sleep_time = options.fetch(:sleep_time, 0)
        run_with_retries_helper(
          client: client, schema: schema, prompt: prompt,
          max_retries: max_retries,
          sleep_time: sleep_time,
          schema_hint: schema_hint
        )
      end

      # Builds the provider client.
      def build_client(provider, config)
        Provider::Factory.build(provider: provider, config: config)
      end

      # Runs the AutoCorrection wrapper. Takes a single options hash with
      # :client, :schema, :prompt, :max_retries, :sleep_time and :schema_hint.
      def run_with_retries_helper(options = {})
        client = options[:client]
        schema = options[:schema]
        prompt = options[:prompt]
        max_retries = options[:max_retries] || 3
        sleep_time = options[:sleep_time] || 0
        schema_hint = options[:schema_hint]

        auto = AutoCorrection.new(
          provider: client, schema: schema, max_retries: max_retries, sleep_time: sleep_time
        )
        auto.call(prompt: prompt, schema_hint: schema_hint)
      end

      # Passes the result through AutoFix with the given schema and hooks.
      def apply_auto_fix(result, schema, hooks)
        AiGuardrails::AutoFix.fix(result, schema: schema, hooks: hooks)
      end

      # Runs the safety filter; a no-op when the blocklist is empty.
      def check_safety(result, blocklist)
        return if blocklist.empty?

        content = normalize_result(result)
        check_blocklist(content, blocklist)
      end

      # Normalizes result into a simple string for safety scanning.
      def normalize_result(result)
        case result
        when Hash
          result.values.join(" ")
        when String
          parse_json_string(result)
        else
          result.to_s
        end
      end

      # Attempts to parse the string as JSON; a JSON object is flattened to its
      # values, anything else (or a parse failure) falls back to the raw string.
      def parse_json_string(str)
        parsed = JSON.parse(str)
        parsed.is_a?(Hash) ? parsed.values.join(" ") : str
      rescue JSON::ParserError
        str
      end

      # Performs a case-insensitive safety check using SafetyFilter by
      # downcasing both the content and every blocklist entry.
      def check_blocklist(content, blocklist)
        content_down = content.downcase
        blocklist_down = blocklist.map(&:downcase)
        SafetyFilter.new(blocklist: blocklist_down).check!(content_down)
      end
    end
  end
end
@@ -0,0 +1,234 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "json"
4
+
5
module AiGuardrails
  # Repairs malformed JSON strings (typically LLM output) with a pipeline of
  # regex/scanner heuristics: strip markdown fences, quote bare keys, insert
  # missing commas, close truncated containers, drop trailing commas.
  #
  # The heuristics are best-effort and operate on raw text; they do not
  # understand string contents except where noted.
  # rubocop:disable Metrics/ClassLength
  class JsonRepair
    class RepairError < StandardError; end

    # Closing delimiter for each opening delimiter, used by balance_braces.
    CLOSING_DELIMITERS = { "{" => "}", "[" => "]" }.freeze
    private_constant :CLOSING_DELIMITERS

    # Class method entrypoint.
    #
    # @param raw [#to_s] raw model output
    # @return [Object] parsed JSON value
    # @raise [RepairError] when the input cannot be repaired into valid JSON
    def self.repair(raw)
      new(raw).repair
    end

    def initialize(raw)
      @raw = raw.to_s.strip
    end

    # Main repair: returns the parsed input when it is already valid JSON
    # (after fence stripping), otherwise runs the repair pipeline.
    #
    # @return [Object] parsed JSON value
    # @raise [RepairError] when the repaired text still does not parse
    def repair
      sanitized = sanitize_llm_output(@raw)
      parsed, ok = try_parse(sanitized)
      return parsed if ok

      repaired = run_full_repair(sanitized)
      parsed, ok = try_parse(repaired)
      raise RepairError, "Unable to repair JSON" unless ok

      parsed
    end

    private

    # --------------------------
    # Full repair workflow
    # --------------------------
    def run_full_repair(str)
      str = preprocess(str)
      str = normalize_structure(str)
      str = balance_braces(str)
      str = run_recursive_fixes(str)
      str = remove_trailing_commas(str)
      str.gsub(/\s+/, " ").strip
    end

    # --------------------------
    # JSON parsing
    # --------------------------
    # Parses once and reports success separately, because a valid document
    # may itself parse to nil or false.
    #
    # @return [Array(Object, Boolean)] parsed value and whether parsing succeeded
    def try_parse(str)
      [JSON.parse(str), true]
    rescue JSON::ParserError
      [nil, false]
    end

    # --------------------------
    # Preprocessing
    # --------------------------
    def preprocess(str)
      str = str.strip
      # NOTE: replaces every apostrophe, including ones inside string values.
      str.gsub!("'", '"')
      str = quote_all_keys(str)
      str = insert_missing_commas_regex(str)
      remove_trailing_commas(str)
    end

    # --------------------------
    # Quote keys
    # --------------------------
    # Wraps bare keys (preceded by `{`, whitespace or `,`) in double quotes,
    # repeating until no further substitution happens.
    def quote_all_keys(str)
      current = str.dup
      # gsub! returns nil once nothing was replaced.
      loop { break unless current.gsub!(/([{\s,])([a-zA-Z0-9_-]+)\s*:/, '\1"\2":') }
      current
    end

    def insert_missing_commas_regex(str)
      str.gsub(/([}\]"0-9a-zA-Z])\s+("?[\w-]+"?\s*:)/, '\1, \2')
    end

    # --------------------------
    # Normalization
    # --------------------------
    def normalize_structure(input)
      repaired = input.dup
      repaired = fix_double_braces(repaired)
      repaired = fix_object_brace_spacing(repaired)
      repaired = insert_missing_commas_by_scanner(repaired)
      repaired.gsub!(/([}\]])\s*(?=([A-Za-z0-9_"-]+\s*:))/, '\1, ')
      repaired.gsub!(/([}\]])\s*(?=(\{|\[|"|\d|true|false|null))/i, '\1, ')
      repaired.gsub!(/,+/, ",")
      repaired.gsub!(/\s+/, " ")
      repaired.strip
    end

    # Collapses a doubled `{ {` that follows `[` or `,` into a single `{`,
    # repeating until stable so that three or more stacked braces collapse too.
    def fix_double_braces(str)
      current = str.dup
      # The previous implementation compared the string with an alias of
      # itself and therefore always stopped after a single pass.
      loop { break unless current.gsub!(/(\[|,)\s*\{\s*\{/, '\1 {') }
      current
    end

    def fix_object_brace_spacing(str)
      str.gsub(/}\s*{/, "}, {")
         .gsub(/]\s*{/, "], {")
         .gsub(/}\s*\]\s*\{/, "}], {")
    end

    # --------------------------
    # Recursive fixes runner
    # --------------------------
    def run_recursive_fixes(str)
      str = quote_all_keys(str)
      str = insert_commas_recursively(str)
      str = insert_final_commas(str)
      str = insert_commas_recursive_nested(str)
      str = fix_consecutive_objects_in_arrays(str)
      str = fix_double_object_braces(str)
      fix_adjacent_arrays(str)
    end

    def fix_adjacent_arrays(str)
      str.gsub(/\]\s*\[/, "], [")
    end

    # --------------------------
    # Scanner-based comma insertion
    # --------------------------
    def insert_missing_commas_by_scanner(str)
      s = str.dup
      out_chars = []
      i = 0
      while i < s.length
        char = s[i]
        out_chars << char
        insert_comma_after_close_brace?(char, s, i, out_chars)
        i += 1
      end
      out_chars.join.gsub(/,+/, ",").gsub(/\s+/, " ").strip
    end

    # Appends a comma to output_chars when char is a closing brace/bracket and
    # the next non-whitespace character starts a new value or key.
    # rubocop:disable Metrics/CyclomaticComplexity
    def insert_comma_after_close_brace?(char, string, index, output_chars)
      return unless ["}", "]"].include?(char)

      j = index + 1
      j += 1 while j < string.length && string[j] =~ /\s/
      next_char = j < string.length ? string[j] : nil
      return unless next_char && ![",", "}", "]", ":"].include?(next_char)

      output_chars << "," if next_char =~ /[\[\]"0-9A-Za-z_-]/
    end
    # rubocop:enable Metrics/CyclomaticComplexity

    # --------------------------
    # Recursive comma insertion
    # (these mutate and return the string they are given)
    # --------------------------
    def insert_commas_recursively(str)
      loop do
        prev = str.dup
        str.gsub!(/([}\]"0-9a-zA-Z])\s+(?=(\{|"[^"]*"|\d+|true|false|null|\[))/i, '\1, ')
        str.gsub!(/(\})\s+(?=\{)/, '\1, ')
        str.gsub!(/(\])\s+(?=\[)/, '\1, ')
        break if str == prev
      end
      str
    end

    def insert_final_commas(str)
      loop do
        prev = str.dup
        str.gsub!(/([}\]])\s+(?=("[a-zA-Z_][a-zA-Z0-9_]*"\s*:))/, '\1, ')
        str.gsub!(/([}\]])\s+(?=[{\[])/, '\1, ')
        break if str == prev
      end
      str
    end

    def insert_commas_recursive_nested(str)
      loop do
        prev = str.dup
        str.gsub!(/}\s*(?=\{)/, "}, {")
        str.gsub!(/]\s*(?=\[)/, "], [")
        str.gsub!(/([}\]])\s+(?=("[^"]+"\s*:))/, '\1, ')
        break if str == prev
      end
      str
    end

    def fix_consecutive_objects_in_arrays(str)
      loop do
        prev = str.dup
        str.gsub!(/({[^{}]*})\s*(?=\{)/, '\1, ')
        break if str == prev
      end
      str
    end

    def fix_double_object_braces(str)
      fix_double_braces(str)
    end

    def remove_trailing_commas(str)
      str.gsub(/,(\s*[}\]])/, '\1')
    end

    # Appends the closers needed for containers left open at the end of str,
    # innermost first, so `{"a": [1` becomes `{"a": [1]}`. Delimiters inside
    # double-quoted strings are ignored; closers that do not match the
    # innermost open container are left alone.
    def balance_braces(str)
      pending = []
      in_string = false
      escaped = false

      str.each_char do |char|
        if in_string
          if escaped
            escaped = false
          elsif char == "\\"
            escaped = true
          elsif char == '"'
            in_string = false
          end
        elsif char == '"'
          in_string = true
        elsif CLOSING_DELIMITERS.key?(char)
          pending << CLOSING_DELIMITERS[char]
        elsif char == pending.last
          pending.pop
        end
      end

      str + pending.reverse.join
    end

    # Strips markdown code fences from LLM output. When a complete fenced
    # block is present its body is returned, discarding any prose before or
    # after it; otherwise everything up to an opening fence and a trailing
    # fence at the very end are removed.
    def sanitize_llm_output(str)
      return str unless str.is_a?(String)

      fenced_body = str[/```(?:json)?\s*(.*?)```/m, 1]
      return fenced_body.strip if fenced_body

      # Remove everything before the first ```json or ```
      sanitized = str.sub(/\A.*?```(?:json)?\s*/m, "")

      # Remove trailing ```
      sanitized = sanitized.sub(/```\s*\z/, "")

      sanitized.strip
    end
  end
  # rubocop:enable Metrics/ClassLength
end
@@ -0,0 +1,45 @@
1
+ # frozen_string_literal: true
2
+
3
module AiGuardrails
  # Thin logging facade used inside the gem.
  # Any object responding to #info, #error and #debug can be assigned as
  # the logger (Rails.logger, Logger.new, etc.).
  module Logger
    class << self
      attr_accessor :logger, :debug_mode

      # Writes an informational message with the gem prefix.
      def info(message)
        emit(:info, "[AiGuardrails] #{message}")
      end

      # Writes an error message with the gem error prefix.
      def error(message)
        emit(:error, "[AiGuardrails ERROR] #{message}")
      end

      # Writes a debug message, but only while debug_mode is truthy.
      def debug(message)
        emit(:debug, "[AiGuardrails DEBUG] #{message}") if debug_mode
      end

      private

      # Sends an already-formatted line to the active logger at the given level.
      def emit(level, line)
        safe_logger.public_send(level, line)
      end

      # The configured logger, or a NullLogger when none is set.
      def safe_logger
        return logger if logger

        NullLogger.new
      end
    end

    # Fallback logger that silently drops every message, so logging calls
    # never fail when the user has not configured a logger.
    class NullLogger
      def debug(_msg); end

      def info(_msg); end

      def error(_msg); end
    end
  end
end
@@ -0,0 +1,34 @@
1
+ # frozen_string_literal: true
2
+
3
module AiGuardrails
  # Test double that stands in for an AI LLM client.
  class MockModelClient
    class MockError < StandardError; end

    # @param responses [Hash] prompt => canned response; keys are stored as strings
    def initialize(responses = {})
      @responses = responses.transform_keys { |prompt| prompt.to_s }
    end

    # Simulates a call to the model.
    #
    # @param prompt [#to_s] prompt to look up
    # @param raise_error [Boolean, nil] truthy simulates an API failure
    # @param default_fallback [Object] returned when raise_error is exactly
    #   false and the prompt has no registered response
    # @return [Object] the registered response or default_fallback
    # @raise [MockError] on simulated failure, or when no usable response
    #   exists and the fallback path does not apply
    def call(prompt:, raise_error: false, default_fallback: nil)
      raise MockError, "Simulated model error" if raise_error

      lookup = prompt.to_s
      # The fallback applies only for a literal `false`, not for nil.
      return default_fallback if raise_error == false && !@responses.key?(lookup)

      @responses[lookup] || raise(MockError, "No mock response defined for prompt: #{prompt}")
    end

    # Registers or replaces the canned response for a prompt.
    def add_response(prompt, response)
      @responses[prompt.to_s] = response
    end
  end
end
@@ -0,0 +1,19 @@
1
+ # frozen_string_literal: true
2
+
3
module AiGuardrails
  module Provider
    # Common interface shared by all provider clients.
    class BaseClient
      # @param config [Hash] provider settings, stored as-is in @config
      def initialize(config = {})
        @config = config
      end

      # Sends a prompt to the AI model. Concrete providers override this.
      #
      # @param prompt [String] prompt text
      # @raise [NotImplementedError] always, in this base class
      def call_model(prompt:)
        raise NotImplementedError, "Subclasses must implement call_model"
      end
    end
  end
end
@@ -0,0 +1,20 @@
1
+ # frozen_string_literal: true
2
+
3
module AiGuardrails
  module Provider
    # Factory returns the right provider client.
    class Factory
      # Registry of provider name => client class.
      PROVIDERS = {
        openai: OpenAIClient
        # add :anthropic => AnthropicClient later
      }.freeze

      # Instantiates the client class registered for the given provider.
      #
      # @param provider [Symbol, String] provider name (e.g. :openai)
      # @param config [Hash] passed to the client constructor
      # @return [Object] an instance of the registered client class
      # @raise [ArgumentError] when the provider is unknown, nil, or cannot be
      #   converted to a symbol
      def self.build(provider:, config: {})
        # nil / Integer / etc. have no #to_sym; treat them as unknown instead
        # of leaking a NoMethodError.
        key = provider.to_sym if provider.respond_to?(:to_sym)
        klass = PROVIDERS[key]
        raise ArgumentError, "Unknown provider: #{key ? provider : provider.inspect}" unless klass

        klass.new(config)
      end
    end
  end
end
@@ -0,0 +1,43 @@
1
+ # frozen_string_literal: true
2
+
3
module AiGuardrails
  module Provider
    # Performs the real OpenAI API calls.
    # The ruby-openai gem is required lazily, the first time call_model runs.
    class OpenAIClient < BaseClient
      def initialize(config = {})
        super
        @client = nil
      end

      # Sends the prompt as a single user message to the chat endpoint.
      #
      # @param prompt [String] prompt text
      # @return [String, nil] content of the first choice's message
      # @raise [LoadError] when the ruby-openai gem cannot be loaded
      def call_model(prompt:)
        ensure_provider_loaded

        completion = api_client.chat(parameters: chat_parameters(prompt))
        completion.dig("choices", 0, "message", "content")
      end

      private

      # Memoized OpenAI client built from the configured :api_key.
      def api_client
        @client ||= ::OpenAI::Client.new(access_token: @config[:api_key])
      end

      # Request parameters; :model and :temperature fall back to defaults
      # when not present in the config.
      def chat_parameters(prompt)
        {
          model: @config[:model] || "gpt-4o-mini",
          messages: [{ role: "user", content: prompt }],
          temperature: @config[:temperature] || 0.7
        }
      end

      # Loads ruby-openai only when needed, turning a missing gem into a
      # LoadError with installation instructions.
      def ensure_provider_loaded
        require "ruby/openai"
      rescue LoadError
        raise LoadError,
              "ruby-openai gem is not installed. Add:\n" \
              "  gem 'ruby-openai', require: false\n" \
              "to your Gemfile if using OpenAI provider."
      end
    end
  end
end
@@ -0,0 +1,40 @@
1
+ # frozen_string_literal: true
2
+
3
module AiGuardrails
  # Drives one complete pass: call the provider, repair its JSON output,
  # validate it against the schema.
  class Runner
    def initialize(prompt:, provider:, schema:, options: {})
      @prompt = prompt
      @provider = provider
      @schema = schema
      @options = options
    end

    # Executes the pass and reports the outcome as a hash.
    #
    # @return [Hash] { ok: true, result: ... } when validation succeeds,
    #   { ok: false, errors: ... } when it fails
    # @raise [StandardError] any error raised along the way, after it is logged
    # rubocop:disable Metrics/MethodLength
    def run
      Logger.info("Starting run")
      Logger.debug("Prompt: #{@prompt}")

      model_output = @provider.call_model(prompt: @prompt)
      Logger.debug("Raw model output: #{model_output.inspect}")

      repaired = JsonRepair.repair(model_output)
      Logger.debug("Repaired JSON: #{repaired.inspect}")

      # validate returns [valid?, payload]; payload is the result on success
      # and the error details on failure.
      valid, payload = SchemaValidator.new(@schema).validate(repaired)

      if valid
        Logger.info("Run completed successfully")
        { ok: true, result: payload }
      else
        Logger.error("Schema validation failed: #{payload}")
        { ok: false, errors: payload }
      end
    rescue StandardError => e
      Logger.error("Unhandled exception: #{e.class} - #{e.message}")
      raise
    end
    # rubocop:enable Metrics/MethodLength
  end
end