lda-ruby 0.3.9 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +5 -13
- data/CHANGELOG.md +16 -0
- data/Gemfile +9 -0
- data/README.md +126 -3
- data/VERSION.yml +3 -3
- data/docs/modernization-handoff.md +233 -0
- data/docs/porting-strategy.md +148 -0
- data/docs/precompiled-platform-policy.md +81 -0
- data/docs/precompiled-target-evaluation.md +67 -0
- data/docs/release-runbook.md +192 -0
- data/docs/rust-orchestration-guardrails.md +50 -0
- data/ext/lda-ruby/cokus.c +10 -11
- data/ext/lda-ruby/cokus.h +3 -3
- data/ext/lda-ruby/extconf.rb +10 -6
- data/ext/lda-ruby/lda-inference.c +23 -7
- data/ext/lda-ruby/utils.c +8 -0
- data/ext/lda-ruby-rust/Cargo.toml +12 -0
- data/ext/lda-ruby-rust/README.md +73 -0
- data/ext/lda-ruby-rust/extconf.rb +135 -0
- data/ext/lda-ruby-rust/include/strings.h +35 -0
- data/ext/lda-ruby-rust/src/lib.rs +1263 -0
- data/lda-ruby.gemspec +0 -0
- data/lib/lda-ruby/backends/base.rb +133 -0
- data/lib/lda-ruby/backends/native.rb +158 -0
- data/lib/lda-ruby/backends/pure_ruby.rb +675 -0
- data/lib/lda-ruby/backends/rust.rb +607 -0
- data/lib/lda-ruby/backends.rb +58 -0
- data/lib/lda-ruby/corpus/corpus.rb +17 -15
- data/lib/lda-ruby/corpus/data_corpus.rb +2 -2
- data/lib/lda-ruby/corpus/directory_corpus.rb +2 -2
- data/lib/lda-ruby/corpus/text_corpus.rb +2 -2
- data/lib/lda-ruby/document/document.rb +6 -6
- data/lib/lda-ruby/document/text_document.rb +5 -4
- data/lib/lda-ruby/rust_build_policy.rb +21 -0
- data/lib/lda-ruby/version.rb +5 -0
- data/lib/lda-ruby.rb +293 -48
- data/test/backend_compatibility_test.rb +146 -0
- data/test/backends_selection_test.rb +100 -0
- data/test/benchmark_scripts_test.rb +23 -0
- data/test/gemspec_test.rb +27 -0
- data/test/lda_ruby_test.rb +49 -11
- data/test/packaged_gem_smoke_test.rb +33 -0
- data/test/pure_ruby_orchestration_test.rb +109 -0
- data/test/release_scripts_test.rb +93 -0
- data/test/rust_build_policy_test.rb +23 -0
- data/test/rust_orchestration_test.rb +911 -0
- data/test/simple_pipeline_test.rb +22 -0
- data/test/simple_yaml.rb +1 -7
- data/test/test_helper.rb +5 -6
- metadata +54 -38
- data/Rakefile +0 -61
- data/ext/lda-ruby/Makefile +0 -181
- data/test/data/.gitignore +0 -2
- data/test/simple_test.rb +0 -26
|
@@ -0,0 +1,607 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Lda
|
|
4
|
+
module Backends
|
|
5
|
+
class Rust < Base
|
|
6
|
+
# Configuration attributes whose readers and writers are forwarded
# unchanged to the pure-Ruby fallback backend.
SETTINGS = %i[
  max_iter
  convergence
  em_max_iter
  em_convergence
  num_topics
  init_alpha
  est_alpha
  verbose
].freeze
# Value passed as the `min_probability` argument to the Rust EM entry points.
MIN_PROBABILITY = 1.0e-12
|
|
8
|
+
|
|
9
|
+
# Reports whether the compiled Rust extension can be used.
#
# Returns false when Lda::RUST_EXTENSION_LOADED is undefined or falsy, when
# Lda::RustBackend is undefined, or when probing raises a StandardError.
# If the extension defines its own `available?`, that answer is returned
# as-is; otherwise the extension is assumed usable.
#
# @return [Boolean, Object]
def self.available?
  extension_loaded = defined?(::Lda::RUST_EXTENSION_LOADED) && ::Lda::RUST_EXTENSION_LOADED
  return false unless extension_loaded && defined?(::Lda::RustBackend)

  extension = ::Lda::RustBackend
  extension.respond_to?(:available?) ? extension.available? : true
rescue StandardError
  false
end
|
|
21
|
+
|
|
22
|
+
# Forward each setting's reader and writer to the fallback, which holds the
# actual configuration values.
SETTINGS.each do |setting|
  writer = :"#{setting}="

  define_method(setting) { @fallback.public_send(setting) }
  define_method(writer) { |value| @fallback.public_send(writer, value) }
end
|
|
31
|
+
|
|
32
|
+
# Creates a Rust-accelerated backend.
#
# All model state lives in a PureRuby fallback. This object installs
# Rust-backed kernel callables into it and sets `trusted_kernel_outputs`.
#
# @param random_seed [Integer, nil] forwarded to Base and to the fallback.
# @raise [LoadError] when the Rust extension is unavailable.
def initialize(random_seed: nil)
  super(random_seed: random_seed)
  raise LoadError, "Rust backend is unavailable for this environment" unless self.class.available?

  # Cached corpus snapshot and Rust session handle; filled in by corpus assignment.
  @rust_corpus_session_id = @rust_corpus_terms = @rust_document_lengths = nil
  @rust_document_words = @rust_document_counts = nil

  @fallback = PureRuby.new(random_seed: random_seed)
  {
    topic_weights_kernel: :rust_topic_weights_for_word,
    topic_term_accumulator_kernel: :rust_accumulate_topic_term_counts,
    document_inference_kernel: :rust_infer_document,
    corpus_iteration_kernel: :rust_infer_corpus_iteration,
    topic_term_finalizer_kernel: :rust_finalize_topic_term_counts,
    gamma_shift_kernel: :rust_average_gamma_shift,
    topic_document_probability_kernel: :rust_topic_document_probability,
    topic_term_seed_kernel: :rust_seeded_topic_term_probabilities
  }.each do |kernel_slot, handler_name|
    @fallback.public_send(:"#{kernel_slot}=", method(handler_name))
  end
  @fallback.trusted_kernel_outputs = true
end
|
|
53
|
+
|
|
54
|
+
# Identifier of this backend.
#
# @return [String] always "rust".
def name
  'rust'
end
|
|
57
|
+
|
|
58
|
+
# Assigns +corpus+ to this backend and to the fallback, then rebuilds the
# Rust corpus snapshot/session, handing over the previously held session id.
#
# @param corpus [Object, nil] the corpus to train on.
# @return [true]
def corpus=(corpus)
  stale_session = @rust_corpus_session_id
  @corpus = corpus
  @fallback.corpus = corpus
  register_rust_corpus_session(stale_session)

  true
end
|
|
65
|
+
|
|
66
|
+
# Loads a corpus file through the fallback, then adopts whatever corpus the
# fallback holds afterwards so the Rust snapshot is rebuilt.
#
# @param filename [String] path handed to the fallback loader.
# @return [Object] the fallback loader's return value.
def fast_load_corpus_from_file(filename)
  @fallback.fast_load_corpus_from_file(filename).tap do
    self.corpus = @fallback.corpus
  end
end
|
|
71
|
+
|
|
72
|
+
# Loads a settings file through the fallback, then re-syncs with the
# fallback's corpus so the Rust snapshot matches it.
#
# @param settings_file [String] path handed to the fallback.
# @return [Object] the fallback's return value.
def load_settings(settings_file)
  result = @fallback.load_settings(settings_file)
  self.corpus = @fallback.corpus
  result
end
|
|
77
|
+
|
|
78
|
+
# Forwards the full configuration tuple, in the same order, to the fallback.
#
# @return [Object] the fallback's return value.
def set_config(init_alpha, num_topics, max_iter, convergence, em_max_iter, em_convergence, est_alpha)
  configuration = [init_alpha, num_topics, max_iter, convergence, em_max_iter, em_convergence, est_alpha]
  @fallback.set_config(*configuration)
end
|
|
81
|
+
|
|
82
|
+
# Runs EM training.
#
# Notifies the extension via `rust_before_em`, then tries the Rust
# orchestration paths. If none of them handles the run, the fallback's `em`
# is called with the original +start+ argument.
#
# @param start [#to_s] start mode.
# @return [nil, Object] nil when Rust handled the run, otherwise the fallback's result.
def em(start)
  mode = start.to_s
  rust_before_em(mode)
  rust_orchestrated_em(mode) ? nil : @fallback.em(start)
end
|
|
89
|
+
|
|
90
|
+
# Topic-term matrix held by the fallback.
def beta
  @fallback.beta
end

# Per-document gamma values held by the fallback.
def gamma
  @fallback.gamma
end

# Delegates phi computation to the fallback.
def compute_phi
  @fallback.compute_phi
end

# Model summary as produced by the fallback.
def model
  @fallback.model
end
|
|
105
|
+
|
|
106
|
+
private
|
|
107
|
+
|
|
108
|
+
# Tries the Rust EM strategies in order of preference.
#
# The managed-session and start-seed strategies are attempted first; the
# first one to report success ends the search with +true+. Otherwise the
# result of the `run_em`-based strategy is returned.
#
# @param start [String] start mode.
# @return [Boolean]
def rust_orchestrated_em(start)
  preferred_strategies = %i[
    rust_orchestrated_em_with_managed_corpus
    rust_orchestrated_em_with_start_seed
  ]
  return true if preferred_strategies.any? { |strategy| send(strategy, start) }

  rust_orchestrated_em_with_beta(start)
end
|
|
117
|
+
|
|
118
|
+
# Preferred Rust EM path: runs EM against a Rust-side corpus session.
#
# When the extension offers `run_em_on_session_with_corpus`, the cached
# snapshot is sent along with the current session id (0 when there is none).
# The session id the extension returns is then stored: nil unless it is a
# positive number. Otherwise `run_em_on_session` is used, which requires an
# existing session id. Validated results are applied to the fallback via
# `apply_em_state`.
#
# @param start [String] start mode.
# @return [Boolean] true when a validated result was applied; false when the
#   path is unavailable, the output fails validation, or a StandardError occurs.
def rust_orchestrated_em_with_managed_corpus(start)
  return false unless defined?(::Lda::RustBackend) && ensure_rust_corpus_snapshot

  extension = ::Lda::RustBackend
  seed = Integer(next_random_seed)

  if extension.respond_to?(:run_em_on_session_with_corpus)
    response = extension.run_em_on_session_with_corpus(
      Integer(@rust_corpus_session_id || 0),
      @rust_document_words,
      @rust_document_counts,
      Integer(@rust_corpus_terms),
      start.to_s,
      *current_rust_session_config_signature,
      seed
    )
    return false unless response.is_a?(Array) && response.size == 5

    # Layout: [session_id, beta_probabilities, beta_log, gamma, phi].
    returned_session_id, *em_state = response
    return false unless valid_rust_em_output?(em_state, @rust_document_lengths, Integer(num_topics), Integer(@rust_corpus_terms))

    @rust_corpus_session_id =
      returned_session_id.is_a?(Numeric) && returned_session_id.positive? ? Integer(returned_session_id) : nil
  else
    return false unless extension.respond_to?(:run_em_on_session) && @rust_corpus_session_id

    em_state = extension.run_em_on_session(
      Integer(@rust_corpus_session_id),
      start.to_s,
      *current_rust_session_config_signature,
      seed
    )
    return false unless valid_rust_em_output?(em_state, @rust_document_lengths, Integer(num_topics), Integer(@rust_corpus_terms))
  end

  beta_probabilities, beta_log, gamma, phi = em_state
  @fallback.apply_em_state(beta_probabilities: beta_probabilities, beta_log: beta_log, gamma: gamma, phi: phi)
  true
rescue StandardError
  false
end
|
|
176
|
+
|
|
177
|
+
# Second Rust EM path: a single `run_em_with_start_seed` call on the cached
# corpus snapshot, using a seed obtained from `next_random_seed`.
#
# @param start [String] start mode.
# @return [Boolean] true when a validated result was applied to the fallback;
#   false when the entry point or snapshot is missing, num_topics is not
#   positive, the output fails validation, or a StandardError occurs.
def rust_orchestrated_em_with_start_seed(start)
  extension = defined?(::Lda::RustBackend) && ::Lda::RustBackend
  return false unless extension && extension.respond_to?(:run_em_with_start_seed) && ensure_rust_corpus_snapshot

  topics = Integer(num_topics)
  return false unless topics.positive?

  seed = Integer(next_random_seed)
  terms = Integer(@rust_corpus_terms)
  em_state = extension.run_em_with_start_seed(
    start.to_s,
    @rust_document_words,
    @rust_document_counts,
    topics,
    terms,
    Integer(max_iter),
    Float(convergence),
    Integer(em_max_iter),
    Float(em_convergence),
    Float(init_alpha),
    MIN_PROBABILITY,
    seed
  )
  return false unless valid_rust_em_output?(em_state, @rust_document_lengths, topics, terms)

  beta_probabilities, beta_log, gamma, phi = em_state
  @fallback.apply_em_state(beta_probabilities: beta_probabilities, beta_log: beta_log, gamma: gamma, phi: phi)
  true
rescue StandardError
  false
end
|
|
219
|
+
|
|
220
|
+
# Last Rust EM path: drives the stateless `run_em` entry point from an
# explicit initial topic-term matrix.
#
# The EM input comes from the cached corpus snapshot when the fallback can
# seed beta for it (`rust_initial_beta_probabilities`). Otherwise it comes
# from the fallback's `rust_em_input`. If the Rust call raises, or its output
# fails validation or cannot be applied, the pure-Ruby fallback runs EM on the
# same input via `em_from_input`.
#
# Only Rust-side failures are absorbed. An error raised by `em_from_input`
# itself propagates after a single attempt; it is not retried.
#
# @param start [String] start mode.
# @return [Boolean] true when EM was handled (by Rust, by the fallback, or
#   because there was no EM input); false when `run_em` is unavailable or the
#   EM input could not be built.
def rust_orchestrated_em_with_beta(start)
  return false unless defined?(::Lda::RustBackend)
  return false unless ::Lda::RustBackend.respond_to?(:run_em)

  em_input =
    begin
      rust_beta_em_input(start)
    rescue StandardError
      return false
    end
  return true if em_input.nil?
  return true if apply_rust_run_em(em_input)

  @fallback.em_from_input(em_input)
  true
end

# Builds the input hash for `run_em`. The cached corpus snapshot is preferred
# when the fallback can provide initial beta probabilities for it; otherwise
# the fallback's own `rust_em_input` is used (which may return nil).
def rust_beta_em_input(start)
  unless ensure_rust_corpus_snapshot && @fallback.respond_to?(:rust_initial_beta_probabilities)
    return @fallback.rust_em_input(start)
  end

  topics = Integer(num_topics)
  terms = Integer(@rust_corpus_terms)
  initial_beta_probabilities = @fallback.rust_initial_beta_probabilities(
    start,
    @rust_document_words,
    @rust_document_counts,
    topics,
    terms
  )

  {
    topics: topics,
    terms: terms,
    document_words: @rust_document_words,
    document_counts: @rust_document_counts,
    document_totals: @rust_document_counts.map { |counts| counts.sum.to_f },
    document_lengths: @rust_document_lengths,
    initial_beta_probabilities: initial_beta_probabilities,
    min_probability: MIN_PROBABILITY
  }
end

# Runs `run_em` on +em_input+ and applies a validated result to the fallback.
# Returns true on success. Returns false, never raising a StandardError, when
# the Rust call fails, its output is invalid, or applying the result fails.
def apply_rust_run_em(em_input)
  output = ::Lda::RustBackend.run_em(
    em_input.fetch(:initial_beta_probabilities),
    em_input.fetch(:document_words),
    em_input.fetch(:document_counts),
    Integer(max_iter),
    Float(convergence),
    Integer(em_max_iter),
    Float(em_convergence),
    Float(init_alpha),
    Float(em_input.fetch(:min_probability))
  )
  return false unless valid_rust_em_output?(
    output,
    em_input.fetch(:document_lengths),
    em_input.fetch(:topics),
    em_input.fetch(:terms)
  )

  beta_probabilities, beta_log, gamma, phi = output
  @fallback.apply_em_state(beta_probabilities: beta_probabilities, beta_log: beta_log, gamma: gamma, phi: phi)
  true
rescue StandardError
  false
end
|
|
290
|
+
|
|
291
|
+
# Builds the corpus snapshot hash that register_rust_corpus_session caches.
#
# Word ids are coerced with `to_i` and counts with `to_f`. The term count is
# `max_term_index + 1`.
#
# @return [Hash, nil] nil when there is no corpus or it has zero documents.
# @raise [ArgumentError] when num_topics is not positive or there are no terms.
def rust_em_corpus_input
  return nil if @corpus.nil? || @corpus.num_docs.zero?

  topic_count = Integer(num_topics)
  raise ArgumentError, "num_topics must be greater than zero" unless topic_count.positive?

  term_count = max_term_index + 1
  raise ArgumentError, "corpus must contain terms" unless term_count.positive?

  documents = @corpus.documents
  words_per_doc = documents.map { |doc| doc.words.map(&:to_i) }
  counts_per_doc = documents.map { |doc| doc.counts.map(&:to_f) }

  {
    topics: topic_count,
    terms: term_count,
    document_words: words_per_doc,
    document_counts: counts_per_doc,
    document_lengths: words_per_doc.map(&:length),
    min_probability: MIN_PROBABILITY
  }
end
|
|
312
|
+
|
|
313
|
+
# Largest word id across all corpus documents.
#
# @return [Integer] -1 when there is no corpus, no documents, or no words.
def max_term_index
  return -1 if @corpus.nil?

  documents = @corpus.documents
  return -1 if documents.empty?

  documents.flat_map(&:words).max || -1
end
|
|
320
|
+
|
|
321
|
+
# (Re)builds the Rust corpus snapshot and session for the current corpus.
#
# All cached snapshot state is cleared first.
# - No corpus, or no usable snapshot input: +previous_session_id+ is dropped.
# - Lda::RustBackend undefined (corpus present): returns without dropping.
# - Otherwise the snapshot is cached and the extension either replaces
#   +previous_session_id+ (`replace_corpus_session`) or creates a fresh
#   session after dropping the old one (`create_corpus_session`). A result
#   that is not a positive number drops the previous session.
# - Any StandardError clears the cache again and drops the previous session.
#
# @param previous_session_id [Integer, nil] session to replace or release.
def register_rust_corpus_session(previous_session_id = nil)
  @rust_corpus_session_id = @rust_corpus_terms = @rust_document_lengths = nil
  @rust_document_words = @rust_document_counts = nil

  return unless @corpus.nil? || defined?(::Lda::RustBackend)

  snapshot = @corpus.nil? ? nil : rust_em_corpus_input
  if snapshot.nil?
    drop_rust_corpus_session_by_id(previous_session_id)
    return
  end

  @rust_corpus_terms = Integer(snapshot.fetch(:terms))
  @rust_document_lengths = snapshot.fetch(:document_lengths)
  @rust_document_words = snapshot.fetch(:document_words)
  @rust_document_counts = snapshot.fetch(:document_counts)

  extension = ::Lda::RustBackend
  new_session_id =
    if extension.respond_to?(:replace_corpus_session)
      extension.replace_corpus_session(
        Integer(previous_session_id || 0),
        @rust_document_words,
        @rust_document_counts,
        Integer(@rust_corpus_terms)
      )
    elsif extension.respond_to?(:create_corpus_session)
      drop_rust_corpus_session_by_id(previous_session_id)
      extension.create_corpus_session(@rust_document_words, @rust_document_counts, Integer(@rust_corpus_terms))
    end

  unless new_session_id.is_a?(Numeric) && new_session_id.positive?
    drop_rust_corpus_session_by_id(previous_session_id)
    return
  end

  @rust_corpus_session_id = Integer(new_session_id)
rescue StandardError
  @rust_corpus_session_id = @rust_corpus_terms = @rust_document_lengths = nil
  @rust_document_words = @rust_document_counts = nil
  drop_rust_corpus_session_by_id(previous_session_id)
end
|
|
377
|
+
|
|
378
|
+
# Makes sure the cached corpus snapshot is populated.
#
# @return [true] when all four snapshot values were already present.
# @return [Object, nil, false] otherwise the snapshot is rebuilt (reusing the
#   current session id) and the truthiness of the rebuilt snapshot is returned;
#   false on StandardError.
def ensure_rust_corpus_snapshot
  snapshot_complete = lambda do
    @rust_corpus_terms && @rust_document_lengths && @rust_document_words && @rust_document_counts
  end
  return true if snapshot_complete.call

  register_rust_corpus_session(@rust_corpus_session_id)
  snapshot_complete.call
rescue StandardError
  false
end
|
|
387
|
+
|
|
388
|
+
# Forgets the cached snapshot and session id, then asks the extension to drop
# the session that was held.
#
# @return [Object, nil] the drop helper's result, or nil on StandardError.
def release_rust_corpus_session
  held_session, @rust_corpus_session_id = @rust_corpus_session_id, nil
  @rust_corpus_terms = @rust_document_lengths = nil
  @rust_document_words = @rust_document_counts = nil

  drop_rust_corpus_session_by_id(held_session)
rescue StandardError
  nil
end
|
|
401
|
+
|
|
402
|
+
# Asks the extension to drop +session_id+, when both exist.
#
# @param session_id [Integer, nil]
# @return [Object, nil] the extension's result; nil when skipped or on StandardError.
def drop_rust_corpus_session_by_id(session_id)
  return unless session_id
  return unless defined?(::Lda::RustBackend) && ::Lda::RustBackend.respond_to?(:drop_corpus_session)

  ::Lda::RustBackend.drop_corpus_session(Integer(session_id))
rescue StandardError
  nil
end
|
|
411
|
+
|
|
412
|
+
# Positional configuration arguments shared by the session-based Rust EM calls.
#
# @return [Array] [num_topics, max_iter, convergence, em_max_iter,
#   em_convergence, init_alpha, MIN_PROBABILITY], coerced with Integer()/Float().
# @raise [ArgumentError, TypeError] when a setting cannot be coerced.
def current_rust_session_config_signature
  topic_count = Integer(num_topics)
  iterations = Integer(max_iter)
  tolerance = Float(convergence)
  em_iterations = Integer(em_max_iter)
  em_tolerance = Float(em_convergence)
  alpha = Float(init_alpha)

  [topic_count, iterations, tolerance, em_iterations, em_tolerance, alpha, MIN_PROBABILITY]
end
|
|
423
|
+
|
|
424
|
+
# Checks that +output+ is a 4-element [beta_probabilities, beta_log, gamma,
# phi] array whose parts have the expected shapes and finite values.
#
# @param document_lengths [Array<Integer>] expected per-document phi row counts.
# @return [Boolean]
def valid_rust_em_output?(output, document_lengths, topics, terms)
  return false unless output.is_a?(Array) && output.size == 4

  beta_probabilities, beta_log, gamma, phi = output
  [beta_probabilities, beta_log].all? { |matrix| valid_topic_term_matrix?(matrix, topics, terms) } &&
    valid_gamma_matrix?(gamma, document_lengths.size, topics) &&
    valid_phi_tensor?(phi, document_lengths, topics)
end
|
|
435
|
+
|
|
436
|
+
# True when +matrix+ is a +topics+ x +terms+ array of finite numbers.
def valid_topic_term_matrix?(matrix, topics, terms)
  matrix.is_a?(Array) && matrix.size == topics &&
    matrix.all? { |row| row.is_a?(Array) && row.size == terms && row.all? { |value| finite_numeric?(value) } }
end
|
|
446
|
+
|
|
447
|
+
# True when +gamma+ is an +expected_docs+ x +topics+ array of finite,
# strictly positive numbers.
def valid_gamma_matrix?(gamma, expected_docs, topics)
  return false unless gamma.is_a?(Array) && gamma.length == expected_docs

  gamma.all? do |row|
    next false unless row.is_a?(Array) && row.length == topics

    row.all? { |value| finite_numeric?(value) && value.positive? }
  end
end
|
|
457
|
+
|
|
458
|
+
# True when +phi+ holds one matrix per document, where document i has
# document_lengths[i] rows of +topics+ finite numbers.
def valid_phi_tensor?(phi, document_lengths, topics)
  return false unless phi.is_a?(Array) && phi.size == document_lengths.size

  phi.zip(document_lengths).all? do |doc_phi, expected_rows|
    doc_phi.is_a?(Array) && doc_phi.size == expected_rows &&
      doc_phi.all? { |row| row.is_a?(Array) && row.size == topics && row.all? { |value| finite_numeric?(value) } }
  end
end
|
|
472
|
+
|
|
473
|
+
# True for Numeric values that are neither infinite nor NaN.
def finite_numeric?(value)
  return false unless value.is_a?(Numeric)

  value.finite?
end
|
|
476
|
+
|
|
477
|
+
# Notifies the extension's `before_em` hook, if any, of the start mode and
# the corpus size. Document and term counts are 0 when there is no corpus.
#
# @return [Object, nil] the hook's result; nil when absent or on StandardError.
def rust_before_em(start)
  return unless defined?(::Lda::RustBackend) && ::Lda::RustBackend.respond_to?(:before_em)

  mode = start.to_s
  doc_count, term_count =
    if @corpus.nil?
      [0, 0]
    else
      [@corpus.num_docs.to_i, @corpus.num_terms.to_i]
    end
  ::Lda::RustBackend.before_em(mode, doc_count, term_count)
rescue StandardError
  nil
end
|
|
485
|
+
|
|
486
|
+
# Kernel installed as the fallback's `topic_weights_kernel`. Delegates to the
# extension's `topic_weights_for_word` after coercing the word index and
# probability floor with Integer()/Float().
#
# @return [Object, nil] the extension's result; nil when the entry point is
#   missing or a StandardError occurs.
def rust_topic_weights_for_word(beta_probabilities, gamma, word_index, min_probability)
  return nil unless defined?(::Lda::RustBackend) && ::Lda::RustBackend.respond_to?(:topic_weights_for_word)

  word = Integer(word_index)
  min_prob = Float(min_probability)
  ::Lda::RustBackend.topic_weights_for_word(beta_probabilities, gamma, word, min_prob)
rescue StandardError
  nil
end
|
|
499
|
+
|
|
500
|
+
# Kernel installed as the fallback's `topic_term_accumulator_kernel`.
# Passes its arguments unchanged to the extension's
# `accumulate_topic_term_counts`.
#
# @return [Object, nil] the extension's result; nil when the entry point is
#   missing or a StandardError occurs.
def rust_accumulate_topic_term_counts(topic_term_counts, phi_d, words, counts)
  return nil unless defined?(::Lda::RustBackend)

  extension = ::Lda::RustBackend
  if extension.respond_to?(:accumulate_topic_term_counts)
    extension.accumulate_topic_term_counts(topic_term_counts, phi_d, words, counts)
  end
rescue StandardError
  nil
end
|
|
513
|
+
|
|
514
|
+
# Kernel installed as the fallback's `document_inference_kernel`.
#
# Calls the extension's `infer_document` and splits its flat result: the
# first element is the document's gamma, the remaining elements are phi rows.
#
# @return [Array(Object, Array), nil] [gamma, phi_rows]; nil when the entry
#   point is missing, the result is not a non-empty Array, or a StandardError occurs.
def rust_infer_document(beta_probabilities, gamma_initial, words, counts, max_iter, convergence, min_probability, init_alpha)
  return nil unless defined?(::Lda::RustBackend) && ::Lda::RustBackend.respond_to?(:infer_document)

  flat_result = ::Lda::RustBackend.infer_document(
    beta_probabilities,
    gamma_initial,
    words,
    counts,
    Integer(max_iter),
    Float(convergence),
    Float(min_probability),
    Float(init_alpha)
  )
  return nil unless flat_result.is_a?(Array) && !flat_result.empty?

  document_gamma, *phi_rows = flat_result
  [document_gamma, phi_rows]
rescue StandardError
  nil
end
|
|
538
|
+
|
|
539
|
+
# Kernel installed as the fallback's `corpus_iteration_kernel`. Delegates to
# the extension's `infer_corpus_iteration` with the scalar settings coerced
# via Integer()/Float().
#
# @return [Object, nil] the extension's result; nil when the entry point is
#   missing or a StandardError occurs.
def rust_infer_corpus_iteration(beta_probabilities, document_words, document_counts, max_iter, convergence, min_probability, init_alpha)
  extension = ::Lda::RustBackend if defined?(::Lda::RustBackend)
  return nil unless extension.respond_to?(:infer_corpus_iteration)

  iterations = Integer(max_iter)
  tolerance = Float(convergence)
  min_prob = Float(min_probability)
  alpha = Float(init_alpha)
  extension.infer_corpus_iteration(beta_probabilities, document_words, document_counts, iterations, tolerance, min_prob, alpha)
rescue StandardError
  nil
end
|
|
555
|
+
|
|
556
|
+
# Kernel installed as the fallback's `topic_term_finalizer_kernel`. Delegates
# to the extension's `normalize_topic_term_counts`.
#
# @return [Object, nil] the extension's result; nil when the entry point is
#   missing or a StandardError occurs.
def rust_finalize_topic_term_counts(topic_term_counts, min_probability)
  extension = ::Lda::RustBackend if defined?(::Lda::RustBackend)
  return nil unless extension.respond_to?(:normalize_topic_term_counts)

  extension.normalize_topic_term_counts(topic_term_counts, Float(min_probability))
rescue StandardError
  nil
end
|
|
567
|
+
|
|
568
|
+
# Kernel installed as the fallback's `gamma_shift_kernel`. Delegates to the
# extension's `average_gamma_shift`.
#
# @return [Object, nil] the extension's result; nil when the entry point is
#   missing or a StandardError occurs.
def rust_average_gamma_shift(previous_gamma, current_gamma)
  extension = ::Lda::RustBackend if defined?(::Lda::RustBackend)
  return nil unless extension.respond_to?(:average_gamma_shift)

  extension.average_gamma_shift(previous_gamma, current_gamma)
rescue StandardError
  nil
end
|
|
576
|
+
|
|
577
|
+
# Kernel installed as the fallback's `topic_document_probability_kernel`.
# Delegates to the extension's `topic_document_probability` with the topic
# count and probability floor coerced via Integer()/Float().
#
# @return [Object, nil] the extension's result; nil when the entry point is
#   missing or a StandardError occurs.
def rust_topic_document_probability(phi_matrix, document_counts, num_topics, min_probability)
  return nil unless defined?(::Lda::RustBackend)

  extension = ::Lda::RustBackend
  return nil unless extension.respond_to?(:topic_document_probability)

  topic_count = Integer(num_topics)
  min_prob = Float(min_probability)
  extension.topic_document_probability(phi_matrix, document_counts, topic_count, min_prob)
rescue StandardError
  nil
end
|
|
590
|
+
|
|
591
|
+
# Kernel installed as the fallback's `topic_term_seed_kernel`. Delegates to
# the extension's `seeded_topic_term_probabilities` with the sizes coerced
# via Integer() and the floor via Float().
#
# @return [Object, nil] the extension's result; nil when the entry point is
#   missing or a StandardError occurs.
def rust_seeded_topic_term_probabilities(document_words, document_counts, topics, terms, min_probability)
  extension = ::Lda::RustBackend if defined?(::Lda::RustBackend)
  return nil unless extension.respond_to?(:seeded_topic_term_probabilities)

  extension.seeded_topic_term_probabilities(
    document_words,
    document_counts,
    Integer(topics),
    Integer(terms),
    Float(min_probability)
  )
rescue StandardError
  nil
end
|
|
605
|
+
end
|
|
606
|
+
end
|
|
607
|
+
end
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "lda-ruby/backends/base"
|
|
4
|
+
require "lda-ruby/backends/rust"
|
|
5
|
+
require "lda-ruby/backends/native"
|
|
6
|
+
require "lda-ruby/backends/pure_ruby"
|
|
7
|
+
|
|
8
|
+
module Lda
  module Backends
    class << self
      # Builds a backend instance for the requested mode.
      #
      # The mode comes from +requested+ or, when that is nil, from the
      # LDA_RUBY_BACKEND environment variable (default "auto").
      # - :auto prefers Rust, then Native (if available for +host+), then PureRuby.
      # - :rust and :native raise LoadError when unavailable.
      # - :pure always builds PureRuby.
      #
      # @param host [Object] forwarded to the Native backend checks/constructor.
      # @param requested [String, Symbol, nil] backend name or alias.
      # @param random_seed [Integer, nil] forwarded to the backend constructor.
      # @raise [LoadError] when an explicitly requested backend is unavailable.
      # @raise [ArgumentError] for an unrecognized mode.
      def build(host:, requested: nil, random_seed: nil)
        mode = normalize_mode(requested)

        case mode
        when :auto
          if Rust.available?
            Rust.new(random_seed: random_seed)
          elsif Native.available?(host)
            Native.new(host, random_seed: random_seed)
          else
            PureRuby.new(random_seed: random_seed)
          end
        when :rust
          raise LoadError, "Rust backend is unavailable for this environment" unless Rust.available?

          Rust.new(random_seed: random_seed)
        when :native
          raise LoadError, "Native backend is unavailable for this environment" unless Native.available?(host)

          Native.new(host, random_seed: random_seed)
        when :pure
          PureRuby.new(random_seed: random_seed)
        else
          # Report the value actually used: when +requested+ is nil it came
          # from LDA_RUBY_BACKEND, and reporting +requested+ would show "nil".
          raise ArgumentError, "Unknown backend mode: #{mode.inspect}"
        end
      end

      private

      # Maps +requested+ (or LDA_RUBY_BACKEND, defaulting to "auto") to one of
      # :auto, :native, :rust, :pure. Matching is case-insensitive and ignores
      # surrounding whitespace. Unrecognized values are returned unchanged.
      def normalize_mode(requested)
        raw_mode = requested || ENV.fetch("LDA_RUBY_BACKEND", "auto")

        case raw_mode.to_s.strip.downcase
        when "", "auto" then :auto
        when "native", "c" then :native
        when "rust", "rust_native" then :rust
        when "pure", "ruby", "pure_ruby" then :pure
        else raw_mode
        end
      end
    end
  end
end
|