dspy 0.29.0 → 0.30.0
This diff compares the contents of publicly released versions of this package as published to a supported registry. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/LICENSE +45 -0
- data/README.md +121 -101
- data/lib/dspy/callbacks.rb +74 -19
- data/lib/dspy/context.rb +49 -4
- data/lib/dspy/errors.rb +19 -1
- data/lib/dspy/{datasets.rb → evals/version.rb} +2 -3
- data/lib/dspy/{evaluate.rb → evals.rb} +373 -110
- data/lib/dspy/mixins/instruction_updatable.rb +22 -0
- data/lib/dspy/observability.rb +40 -182
- data/lib/dspy/predict.rb +10 -2
- data/lib/dspy/propose/dataset_summary_generator.rb +28 -18
- data/lib/dspy/re_act.rb +21 -0
- data/lib/dspy/schema/sorbet_json_schema.rb +302 -0
- data/lib/dspy/schema/version.rb +7 -0
- data/lib/dspy/schema.rb +4 -0
- data/lib/dspy/structured_outputs_prompt.rb +48 -0
- data/lib/dspy/support/warning_filters.rb +27 -0
- data/lib/dspy/teleprompt/gepa.rb +9 -588
- data/lib/dspy/teleprompt/instruction_updates.rb +94 -0
- data/lib/dspy/teleprompt/teleprompter.rb +6 -6
- data/lib/dspy/teleprompt/utils.rb +5 -65
- data/lib/dspy/type_system/sorbet_json_schema.rb +2 -299
- data/lib/dspy/version.rb +1 -1
- data/lib/dspy.rb +33 -7
- metadata +14 -60
- data/lib/dspy/code_act.rb +0 -477
- data/lib/dspy/datasets/ade.rb +0 -90
- data/lib/dspy/observability/async_span_processor.rb +0 -250
- data/lib/dspy/observability/observation_type.rb +0 -65
- data/lib/dspy/optimizers/gaussian_process.rb +0 -141
- data/lib/dspy/teleprompt/mipro_v2.rb +0 -1423
- data/lib/gepa/api.rb +0 -61
- data/lib/gepa/core/engine.rb +0 -226
- data/lib/gepa/core/evaluation_batch.rb +0 -26
- data/lib/gepa/core/result.rb +0 -92
- data/lib/gepa/core/state.rb +0 -231
- data/lib/gepa/logging/experiment_tracker.rb +0 -54
- data/lib/gepa/logging/logger.rb +0 -57
- data/lib/gepa/logging.rb +0 -9
- data/lib/gepa/proposer/base.rb +0 -27
- data/lib/gepa/proposer/merge_proposer.rb +0 -424
- data/lib/gepa/proposer/reflective_mutation/base.rb +0 -48
- data/lib/gepa/proposer/reflective_mutation/reflective_mutation.rb +0 -188
- data/lib/gepa/strategies/batch_sampler.rb +0 -91
- data/lib/gepa/strategies/candidate_selector.rb +0 -97
- data/lib/gepa/strategies/component_selector.rb +0 -57
- data/lib/gepa/strategies/instruction_proposal.rb +0 -120
- data/lib/gepa/telemetry.rb +0 -122
- data/lib/gepa/utils/pareto.rb +0 -119
- data/lib/gepa.rb +0 -21
data/lib/dspy/datasets/ade.rb
DELETED
@@ -1,90 +0,0 @@
-# frozen_string_literal: true
-
-require 'json'
-require 'net/http'
-require 'uri'
-require 'cgi'
-require 'fileutils'
-
-module DSPy
-  module Datasets
-    module ADE
-      extend self
-
-      DATASET = 'ade-benchmark-corpus/ade_corpus_v2'
-      CLASSIFICATION_CONFIG = 'Ade_corpus_v2_classification'
-      BASE_URL = 'https://datasets-server.huggingface.co'
-
-      DEFAULT_CACHE_DIR = File.expand_path('../../../tmp/dspy_datasets/ade', __dir__)
-
-      MAX_BATCH_SIZE = 100
-
-      def examples(split: 'train', limit: 200, offset: 0, cache_dir: default_cache_dir)
-        remaining = limit
-        current_offset = offset
-        collected = []
-
-        while remaining.positive?
-          batch_size = [remaining, MAX_BATCH_SIZE].min
-          rows = fetch_rows(
-            split: split,
-            limit: batch_size,
-            offset: current_offset,
-            cache_dir: cache_dir
-          )
-
-          break if rows.empty?
-
-          collected.concat(rows.map do |row|
-            {
-              'text' => row.fetch('text', ''),
-              'label' => row.fetch('label', 0).to_i
-            }
-          end)
-
-          current_offset += batch_size
-          remaining -= batch_size
-        end
-
-        collected
-      end
-
-      def fetch_rows(split:, limit:, offset:, cache_dir:)
-        FileUtils.mkdir_p(cache_dir)
-        cache_path = File.join(cache_dir, "#{CLASSIFICATION_CONFIG}_#{split}_#{offset}_#{limit}.json")
-
-        if File.exist?(cache_path)
-          return JSON.parse(File.read(cache_path))
-        end
-
-        rows = request_rows(split: split, limit: limit, offset: offset)
-        File.write(cache_path, JSON.pretty_generate(rows))
-        rows
-      end
-
-      private
-
-      def request_rows(split:, limit:, offset:)
-        uri = URI("#{BASE_URL}/rows")
-        params = {
-          dataset: DATASET,
-          config: CLASSIFICATION_CONFIG,
-          split: split,
-          offset: offset,
-          length: limit
-        }
-        uri.query = URI.encode_www_form(params)
-
-        response = Net::HTTP.get_response(uri)
-        raise "ADE dataset request failed: #{response.code}" unless response.is_a?(Net::HTTPSuccess)
-
-        body = JSON.parse(response.body)
-        body.fetch('rows', []).map { |row| row.fetch('row', {}) }
-      end
-
-      def default_cache_dir
-        ENV['DSPY_DATASETS_CACHE'] ? File.expand_path('ade', ENV['DSPY_DATASETS_CACHE']) : DEFAULT_CACHE_DIR
-      end
-    end
-  end
-end
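Since 0.30.0 removes this module outright, callers can pin dspy at 0.29.0 or vendor the file. A minimal usage sketch against the pre-0.30.0 gem (the require path is inferred from the file location under data/lib; the return shape comes from the source above):

    require 'dspy/datasets/ade'  # path assumed from the data/lib layout

    # Downloads (or reads from the on-disk cache) classification rows from
    # the Hugging Face datasets server, as implemented above.
    rows = DSPy::Datasets::ADE.examples(split: 'train', limit: 50)
    rows.first  # => { 'text' => '...', 'label' => 0 }  (label is 0 or 1)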
data/lib/dspy/observability/async_span_processor.rb
DELETED
@@ -1,250 +0,0 @@
-# frozen_string_literal: true
-
-require 'concurrent-ruby'
-require 'thread'
-require 'opentelemetry/sdk'
-require 'opentelemetry/sdk/trace/export'
-
-module DSPy
-  class Observability
-    # AsyncSpanProcessor provides non-blocking span export using concurrent-ruby.
-    # Spans are queued and exported on a dedicated single-thread executor to avoid blocking clients.
-    # Implements the same interface as OpenTelemetry::SDK::Trace::Export::BatchSpanProcessor
-    class AsyncSpanProcessor
-      # Default configuration values
-      DEFAULT_QUEUE_SIZE = 1000
-      DEFAULT_EXPORT_INTERVAL = 60.0 # seconds
-      DEFAULT_EXPORT_BATCH_SIZE = 100
-      DEFAULT_SHUTDOWN_TIMEOUT = 10.0 # seconds
-      DEFAULT_MAX_RETRIES = 3
-
-      def initialize(
-        exporter,
-        queue_size: DEFAULT_QUEUE_SIZE,
-        export_interval: DEFAULT_EXPORT_INTERVAL,
-        export_batch_size: DEFAULT_EXPORT_BATCH_SIZE,
-        shutdown_timeout: DEFAULT_SHUTDOWN_TIMEOUT,
-        max_retries: DEFAULT_MAX_RETRIES
-      )
-        @exporter = exporter
-        @queue_size = queue_size
-        @export_interval = export_interval
-        @export_batch_size = export_batch_size
-        @shutdown_timeout = shutdown_timeout
-        @max_retries = max_retries
-        @export_executor = Concurrent::SingleThreadExecutor.new
-
-        # Use thread-safe queue for cross-fiber communication
-        @queue = Thread::Queue.new
-        @shutdown_requested = false
-        @timer_thread = nil
-
-        start_export_task
-      end
-
-      def on_start(span, parent_context)
-        # Non-blocking - no operation needed on span start
-      end
-
-      def on_finish(span)
-        # Only process sampled spans to match BatchSpanProcessor behavior
-        return unless span.context.trace_flags.sampled?
-
-        # Non-blocking enqueue with overflow protection
-        # Note: on_finish is only called for already ended spans
-        begin
-          # Check queue size (non-blocking)
-          if @queue.size >= @queue_size
-            # Drop oldest span
-            begin
-              dropped_span = @queue.pop(true) # non-blocking pop
-              DSPy.log('observability.span_dropped',
-                       reason: 'queue_full',
-                       queue_size: @queue_size)
-            rescue ThreadError
-              # Queue was empty, continue
-            end
-          end
-
-          @queue.push(span)
-
-          # Log span queuing activity
-          DSPy.log('observability.span_queued', queue_size: @queue.size)
-
-          # Trigger immediate export if batch size reached
-          trigger_export_if_batch_full
-        rescue => e
-          DSPy.log('observability.enqueue_error', error: e.message)
-        end
-      end
-
-      def shutdown(timeout: nil)
-        timeout ||= @shutdown_timeout
-        @shutdown_requested = true
-
-        begin
-          # Export any remaining spans
-          result = export_remaining_spans(timeout: timeout, export_all: true)
-
-          future = Concurrent::Promises.future_on(@export_executor) do
-            @exporter.shutdown(timeout: timeout)
-          end
-          future.value!(timeout)
-
-          result
-        rescue => e
-          DSPy.log('observability.shutdown_error', error: e.message, class: e.class.name)
-          OpenTelemetry::SDK::Trace::Export::FAILURE
-        ensure
-          begin
-            @timer_thread&.join(timeout)
-            @timer_thread&.kill if @timer_thread&.alive?
-          rescue StandardError
-            # ignore timer shutdown issues
-          end
-          @export_executor.shutdown
-          unless @export_executor.wait_for_termination(timeout)
-            @export_executor.kill
-          end
-        end
-      end
-
-      def force_flush(timeout: nil)
-        return OpenTelemetry::SDK::Trace::Export::SUCCESS if @queue.empty?
-
-        export_remaining_spans(timeout: timeout, export_all: true)
-      end
-
-      private
-
-      def start_export_task
-        return if @export_interval <= 0 # Disable timer for testing
-        return if ENV['DSPY_DISABLE_OBSERVABILITY'] == 'true' # Skip in tests
-
-        @timer_thread = Thread.new do
-          loop do
-            break if @shutdown_requested
-
-            sleep(@export_interval)
-            break if @shutdown_requested
-            next if @queue.empty?
-
-            schedule_async_export(export_all: true)
-          end
-        rescue => e
-          DSPy.log('observability.export_task_error', error: e.message, class: e.class.name)
-        end
-      end
-
-      def trigger_export_if_batch_full
-        return if @queue.size < @export_batch_size
-        return if ENV['DSPY_DISABLE_OBSERVABILITY'] == 'true' # Skip in tests
-        schedule_async_export(export_all: false)
-      end
-
-      def export_remaining_spans(timeout: nil, export_all: true)
-        return OpenTelemetry::SDK::Trace::Export::SUCCESS if @queue.empty?
-
-        future = Concurrent::Promises.future_on(@export_executor) do
-          export_queued_spans_internal(export_all: export_all)
-        end
-
-        future.value!(timeout || @shutdown_timeout)
-      rescue => e
-        DSPy.log('observability.export_error', error: e.message, class: e.class.name)
-        OpenTelemetry::SDK::Trace::Export::FAILURE
-      end
-
-      def schedule_async_export(export_all: false)
-        return if @shutdown_requested
-
-        @export_executor.post do
-          export_queued_spans_internal(export_all: export_all)
-        rescue => e
-          DSPy.log('observability.batch_export_error', error: e.message, class: e.class.name)
-        end
-      end
-
-      def export_queued_spans
-        export_queued_spans_internal(export_all: false)
-      end
-
-      def export_queued_spans_internal(export_all: false)
-        result = OpenTelemetry::SDK::Trace::Export::SUCCESS
-
-        loop do
-          spans = dequeue_spans(export_all ? @queue_size : @export_batch_size)
-          break if spans.empty?
-
-          result = export_spans_with_retry(spans)
-          break if result == OpenTelemetry::SDK::Trace::Export::FAILURE
-
-          break unless export_all || @queue.size >= @export_batch_size
-        end
-
-        result
-      end
-
-      def dequeue_spans(limit)
-        spans = []
-
-        limit.times do
-          begin
-            spans << @queue.pop(true) # non-blocking pop
-          rescue ThreadError
-            break
-          end
-        end
-
-        spans
-      end
-
-      def export_spans_with_retry(spans)
-        retries = 0
-
-        # Convert spans to SpanData objects (required by OTLP exporter)
-        span_data_batch = spans.map(&:to_span_data)
-
-        # Log export attempt
-        DSPy.log('observability.export_attempt',
-                 spans_count: span_data_batch.size,
-                 batch_size: span_data_batch.size)
-
-        loop do
-          result = @exporter.export(span_data_batch, timeout: @shutdown_timeout)
-
-          case result
-          when OpenTelemetry::SDK::Trace::Export::SUCCESS
-            DSPy.log('observability.export_success',
-                     spans_count: span_data_batch.size,
-                     export_result: 'SUCCESS')
-            return result
-          when OpenTelemetry::SDK::Trace::Export::FAILURE
-            retries += 1
-            if retries <= @max_retries
-              backoff_seconds = 0.1 * (2 ** retries)
-              DSPy.log('observability.export_retry',
-                       attempt: retries,
-                       spans_count: span_data_batch.size,
-                       backoff_seconds: backoff_seconds)
-              # Exponential backoff
-              sleep(backoff_seconds)
-              next
-            else
-              DSPy.log('observability.export_failed',
-                       spans_count: span_data_batch.size,
-                       retries: retries)
-              return result
-            end
-          else
-            return result
-          end
-        end
-      rescue => e
-        DSPy.log('observability.export_error', error: e.message, class: e.class.name)
-        OpenTelemetry::SDK::Trace::Export::FAILURE
-      end
-
-    end
-  end
-end
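The removed processor implemented the standard OpenTelemetry span-processor interface (on_start, on_finish, force_flush, shutdown), so it plugged in wherever a BatchSpanProcessor would. A minimal wiring sketch, assuming dspy 0.29.0 plus the opentelemetry-sdk and opentelemetry-exporter-otlp gems; DSPy performed this setup internally, and the OTLP exporter here is illustrative:

    require 'opentelemetry/sdk'
    require 'opentelemetry-exporter-otlp'

    exporter  = OpenTelemetry::Exporter::OTLP::Exporter.new
    processor = DSPy::Observability::AsyncSpanProcessor.new(exporter, export_interval: 30.0)

    OpenTelemetry::SDK.configure do |c|
      # Same hook used to install a BatchSpanProcessor.
      c.add_span_processor(processor)
    end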
data/lib/dspy/observability/observation_type.rb
DELETED
@@ -1,65 +0,0 @@
-# frozen_string_literal: true
-
-require 'sorbet-runtime'
-
-module DSPy
-  # Langfuse observation types as a T::Enum for type safety
-  # Maps to the official Langfuse observation types: https://langfuse.com/docs/observability/features/observation-types
-  class ObservationType < T::Enum
-    enums do
-      # LLM generation calls - used for direct model inference
-      Generation = new('generation')
-
-      # Agent operations - decision-making processes using tools/LLM guidance
-      Agent = new('agent')
-
-      # External tool calls (APIs, functions, etc.)
-      Tool = new('tool')
-
-      # Chains linking different application steps/components
-      Chain = new('chain')
-
-      # Data retrieval operations (vector stores, databases, memory search)
-      Retriever = new('retriever')
-
-      # Embedding generation calls
-      Embedding = new('embedding')
-
-      # Functions that assess quality/relevance of outputs
-      Evaluator = new('evaluator')
-
-      # Generic spans for durations of work units
-      Span = new('span')
-
-      # Discrete events/moments in time
-      Event = new('event')
-    end
-
-    # Get the appropriate observation type for a DSPy module class
-    sig { params(module_class: T.class_of(DSPy::Module)).returns(ObservationType) }
-    def self.for_module_class(module_class)
-      case module_class.name
-      when /ReAct/, /CodeAct/
-        Agent
-      when /ChainOfThought/
-        Chain
-      when /Evaluator/
-        Evaluator
-      else
-        Span
-      end
-    end
-
-    # Returns the langfuse attribute key and value as an array
-    sig { returns([String, String]) }
-    def langfuse_attribute
-      ['langfuse.observation.type', serialize]
-    end
-
-    # Returns a hash with the langfuse attribute for easy merging
-    sig { returns(T::Hash[String, String]) }
-    def langfuse_attributes
-      { 'langfuse.observation.type' => serialize }
-    end
-  end
-end
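The enum's behavior is fully visible above, so its call pattern is easy to reconstruct; the expected values in this sketch follow directly from the deleted source:

    type = DSPy::ObservationType::Chain
    type.serialize            # => "chain"
    type.langfuse_attribute   # => ["langfuse.observation.type", "chain"]
    type.langfuse_attributes  # => { "langfuse.observation.type" => "chain" }

    # Class-based lookup, per the case statement above:
    DSPy::ObservationType.for_module_class(DSPy::ReAct)  # => Agent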
data/lib/dspy/optimizers/gaussian_process.rb
DELETED
@@ -1,141 +0,0 @@
-# typed: strict
-# frozen_string_literal: true
-
-require 'numo/narray'
-require 'sorbet-runtime'
-
-module DSPy
-  module Optimizers
-    # Pure Ruby Gaussian Process implementation for Bayesian optimization
-    # No external LAPACK/BLAS dependencies required
-    class GaussianProcess
-      extend T::Sig
-
-      sig { params(length_scale: Float, signal_variance: Float, noise_variance: Float).void }
-      def initialize(length_scale: 1.0, signal_variance: 1.0, noise_variance: 1e-6)
-        @length_scale = length_scale
-        @signal_variance = signal_variance
-        @noise_variance = noise_variance
-        @fitted = T.let(false, T::Boolean)
-      end
-
-      sig { params(x1: T::Array[T::Array[Float]], x2: T::Array[T::Array[Float]]).returns(Numo::DFloat) }
-      def rbf_kernel(x1, x2)
-        # Convert to Numo arrays
-        x1_array = Numo::DFloat[*x1]
-        x2_array = Numo::DFloat[*x2]
-
-        # Compute squared Euclidean distances manually
-        n1, n2 = x1_array.shape[0], x2_array.shape[0]
-        sqdist = Numo::DFloat.zeros(n1, n2)
-
-        (0...n1).each do |i|
-          (0...n2).each do |j|
-            diff = x1_array[i, true] - x2_array[j, true]
-            sqdist[i, j] = (diff ** 2).sum
-          end
-        end
-
-        # RBF kernel: σ² * exp(-0.5 * d² / ℓ²)
-        @signal_variance * Numo::NMath.exp(-0.5 * sqdist / (@length_scale ** 2))
-      end
-
-      sig { params(x_train: T::Array[T::Array[Float]], y_train: T::Array[Float]).void }
-      def fit(x_train, y_train)
-        @x_train = x_train
-        @y_train = Numo::DFloat[*y_train]
-
-        # Compute kernel matrix
-        k_matrix = rbf_kernel(x_train, x_train)
-
-        # Add noise to diagonal for numerical stability
-        n = k_matrix.shape[0]
-        (0...n).each { |i| k_matrix[i, i] += @noise_variance }
-
-        # Store inverted kernel matrix using simple LU decomposition
-        @k_inv = matrix_inverse(k_matrix)
-        @alpha = @k_inv.dot(@y_train)
-
-        @fitted = true
-      end
-
-      sig { params(x_test: T::Array[T::Array[Float]], return_std: T::Boolean).returns(T.any(Numo::DFloat, [Numo::DFloat, Numo::DFloat])) }
-      def predict(x_test, return_std: false)
-        raise "Gaussian Process not fitted" unless @fitted
-
-        # Kernel between training and test points
-        k_star = rbf_kernel(T.must(@x_train), x_test)
-
-        # Predictive mean
-        mean = k_star.transpose.dot(@alpha)
-
-        return mean unless return_std
-
-        # Predictive variance (simplified for small matrices)
-        k_star_star = rbf_kernel(x_test, x_test)
-        var_matrix = k_star_star - k_star.transpose.dot(@k_inv).dot(k_star)
-        var = var_matrix.diagonal
-
-        # Ensure positive variance (element-wise maximum)
-        var = var.map { |v| [v, 1e-12].max }
-        std = Numo::NMath.sqrt(var)
-
-        [mean, std]
-      end
-
-      private
-
-      sig { returns(T.nilable(T::Array[T::Array[Float]])) }
-      attr_reader :x_train
-
-      sig { returns(T.nilable(Numo::DFloat)) }
-      attr_reader :y_train, :k_inv, :alpha
-
-      # Simple matrix inversion using Gauss-Jordan elimination
-      # Only suitable for small matrices (< 100x100)
-      sig { params(matrix: Numo::DFloat).returns(Numo::DFloat) }
-      def matrix_inverse(matrix)
-        n = matrix.shape[0]
-        raise "Matrix must be square" unless matrix.shape[0] == matrix.shape[1]
-
-        # Create augmented matrix [A|I]
-        augmented = Numo::DFloat.zeros(n, 2*n)
-        augmented[true, 0...n] = matrix.copy
-        (0...n).each { |i| augmented[i, n+i] = 1.0 }
-
-        # Gauss-Jordan elimination
-        (0...n).each do |i|
-          # Find pivot
-          max_row = i
-          (i+1...n).each do |k|
-            if augmented[k, i].abs > augmented[max_row, i].abs
-              max_row = k
-            end
-          end
-
-          # Swap rows if needed
-          if max_row != i
-            temp = augmented[i, true].copy
-            augmented[i, true] = augmented[max_row, true]
-            augmented[max_row, true] = temp
-          end
-
-          # Make diagonal element 1
-          pivot = augmented[i, i]
-          raise "Matrix is singular" if pivot.abs < 1e-12
-          augmented[i, true] /= pivot
-
-          # Eliminate column
-          (0...n).each do |j|
-            next if i == j
-            factor = augmented[j, i]
-            augmented[j, true] -= factor * augmented[i, true]
-          end
-        end
-
-        # Extract inverse matrix
-        augmented[true, n...2*n]
-      end
-    end
-  end
-end
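For reference, fit/predict above implement the textbook GP regression equations: with K = k(X, X) + σ_n²·I, the predictive mean is μ* = K*ᵀ K⁻¹ y (the @alpha term) and the covariance is Σ* = K** − K*ᵀ K⁻¹ K* (the var_matrix term). A minimal usage sketch against the pre-0.30.0 gem; the data values are made up for illustration:

    gp = DSPy::Optimizers::GaussianProcess.new(length_scale: 1.0, noise_variance: 1e-6)

    # Inputs are arrays of feature vectors; targets are plain floats.
    gp.fit([[0.0], [1.0], [2.0]], [0.0, 1.0, 0.5])

    mean, std = gp.predict([[1.5]], return_std: true)
    # mean/std are Numo::DFloat vectors; std is what an acquisition function
    # (e.g. expected improvement) consumes during Bayesian optimization.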