lda-ruby 0.3.9 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. checksums.yaml +5 -13
  2. data/CHANGELOG.md +16 -0
  3. data/Gemfile +9 -0
  4. data/README.md +126 -3
  5. data/VERSION.yml +3 -3
  6. data/docs/modernization-handoff.md +233 -0
  7. data/docs/porting-strategy.md +148 -0
  8. data/docs/precompiled-platform-policy.md +81 -0
  9. data/docs/precompiled-target-evaluation.md +67 -0
  10. data/docs/release-runbook.md +192 -0
  11. data/docs/rust-orchestration-guardrails.md +50 -0
  12. data/ext/lda-ruby/cokus.c +10 -11
  13. data/ext/lda-ruby/cokus.h +3 -3
  14. data/ext/lda-ruby/extconf.rb +10 -6
  15. data/ext/lda-ruby/lda-inference.c +23 -7
  16. data/ext/lda-ruby/utils.c +8 -0
  17. data/ext/lda-ruby-rust/Cargo.toml +12 -0
  18. data/ext/lda-ruby-rust/README.md +73 -0
  19. data/ext/lda-ruby-rust/extconf.rb +135 -0
  20. data/ext/lda-ruby-rust/include/strings.h +35 -0
  21. data/ext/lda-ruby-rust/src/lib.rs +1263 -0
  22. data/lda-ruby.gemspec +0 -0
  23. data/lib/lda-ruby/backends/base.rb +133 -0
  24. data/lib/lda-ruby/backends/native.rb +158 -0
  25. data/lib/lda-ruby/backends/pure_ruby.rb +675 -0
  26. data/lib/lda-ruby/backends/rust.rb +607 -0
  27. data/lib/lda-ruby/backends.rb +58 -0
  28. data/lib/lda-ruby/corpus/corpus.rb +17 -15
  29. data/lib/lda-ruby/corpus/data_corpus.rb +2 -2
  30. data/lib/lda-ruby/corpus/directory_corpus.rb +2 -2
  31. data/lib/lda-ruby/corpus/text_corpus.rb +2 -2
  32. data/lib/lda-ruby/document/document.rb +6 -6
  33. data/lib/lda-ruby/document/text_document.rb +5 -4
  34. data/lib/lda-ruby/rust_build_policy.rb +21 -0
  35. data/lib/lda-ruby/version.rb +5 -0
  36. data/lib/lda-ruby.rb +293 -48
  37. data/test/backend_compatibility_test.rb +146 -0
  38. data/test/backends_selection_test.rb +100 -0
  39. data/test/benchmark_scripts_test.rb +23 -0
  40. data/test/gemspec_test.rb +27 -0
  41. data/test/lda_ruby_test.rb +49 -11
  42. data/test/packaged_gem_smoke_test.rb +33 -0
  43. data/test/pure_ruby_orchestration_test.rb +109 -0
  44. data/test/release_scripts_test.rb +93 -0
  45. data/test/rust_build_policy_test.rb +23 -0
  46. data/test/rust_orchestration_test.rb +911 -0
  47. data/test/simple_pipeline_test.rb +22 -0
  48. data/test/simple_yaml.rb +1 -7
  49. data/test/test_helper.rb +5 -6
  50. metadata +54 -38
  51. data/Rakefile +0 -61
  52. data/ext/lda-ruby/Makefile +0 -181
  53. data/test/data/.gitignore +0 -2
  54. data/test/simple_test.rb +0 -26
@@ -0,0 +1,607 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Lda
4
+ module Backends
5
+ class Rust < Base
6
# Names of the model settings whose readers and writers are forwarded to the fallback backend.
SETTINGS = [
  :max_iter,
  :convergence,
  :em_max_iter,
  :em_convergence,
  :num_topics,
  :init_alpha,
  :est_alpha,
  :verbose
].freeze

# Value handed to the Rust entry points wherever they take a `min_probability` argument.
MIN_PROBABILITY = 1.0e-12
8
+
9
# Reports whether the Rust extension can be used in this process.
#
# Returns false when the extension flag or the RustBackend module is missing,
# otherwise whatever RustBackend.available? reports (true when the module has
# no such probe). Any StandardError raised while probing counts as unavailable.
def self.available?
  extension_loaded = defined?(::Lda::RUST_EXTENSION_LOADED) && ::Lda::RUST_EXTENSION_LOADED
  return false unless extension_loaded
  return false unless defined?(::Lda::RustBackend)

  backend = ::Lda::RustBackend
  backend.respond_to?(:available?) ? backend.available? : true
rescue StandardError
  false
end
21
+
22
# Generate a reader and a writer for every setting; both forward to @fallback.
SETTINGS.each do |setting_name|
  writer_name = "#{setting_name}="

  define_method(setting_name) { @fallback.public_send(setting_name) }
  define_method(writer_name) { |value| @fallback.public_send(writer_name, value) }
end
31
+
32
# Builds the backend around a PureRuby fallback whose kernel hooks point at
# this object's Rust wrappers.
#
# @param random_seed [Object, nil] forwarded to the superclass and to PureRuby.new
# @raise [LoadError] when the Rust backend is not available
def initialize(random_seed: nil)
  super(random_seed: random_seed)
  raise LoadError, "Rust backend is unavailable for this environment" unless self.class.available?

  # No corpus has been registered with the Rust side yet.
  @rust_corpus_session_id = nil
  @rust_corpus_terms = nil
  @rust_document_lengths = nil
  @rust_document_words = nil
  @rust_document_counts = nil

  @fallback = PureRuby.new(random_seed: random_seed)

  # Hook name on the fallback => private wrapper method on this object.
  {
    topic_weights_kernel: :rust_topic_weights_for_word,
    topic_term_accumulator_kernel: :rust_accumulate_topic_term_counts,
    document_inference_kernel: :rust_infer_document,
    corpus_iteration_kernel: :rust_infer_corpus_iteration,
    topic_term_finalizer_kernel: :rust_finalize_topic_term_counts,
    gamma_shift_kernel: :rust_average_gamma_shift,
    topic_document_probability_kernel: :rust_topic_document_probability,
    topic_term_seed_kernel: :rust_seeded_topic_term_probabilities
  }.each do |hook, wrapper|
    @fallback.public_send("#{hook}=", method(wrapper))
  end

  @fallback.trusted_kernel_outputs = true
end
53
+
54
# Identifier of this backend.
#
# @return [String] always "rust"
def name
  "rust"
end
57
+
58
# Stores the corpus on this backend and on the fallback, then re-registers the
# Rust corpus session, passing along the id that was active before the change.
#
# @param corpus [Object, nil] corpus to use from now on
# @return [true]
def corpus=(corpus)
  stale_session_id = @rust_corpus_session_id

  @corpus = corpus
  @fallback.corpus = @corpus
  register_rust_corpus_session(stale_session_id)

  true
end
65
+
66
# Lets the fallback load the corpus file, then adopts the fallback's corpus
# through #corpus= so the Rust session is refreshed as well.
#
# @param filename [Object] passed through to the fallback loader
# @return [Object] whatever the fallback loader returned
def fast_load_corpus_from_file(filename)
  @fallback.fast_load_corpus_from_file(filename).tap do
    self.corpus = @fallback.corpus
  end
end
71
+
72
# Lets the fallback read the settings file, then re-adopts the fallback's
# corpus through #corpus=.
#
# @param settings_file [Object] passed through to the fallback
# @return [Object] whatever the fallback returned
def load_settings(settings_file)
  @fallback.load_settings(settings_file).tap do
    self.corpus = @fallback.corpus
  end
end
77
+
78
# Forwards the full configuration tuple, unchanged and in the same order, to
# the fallback backend and returns its result.
def set_config(init_alpha, num_topics, max_iter, convergence, em_max_iter, em_convergence, est_alpha)
  @fallback.set_config(
    init_alpha,
    num_topics,
    max_iter,
    convergence,
    em_max_iter,
    em_convergence,
    est_alpha
  )
end
81
+
82
# Runs EM. The Rust-orchestrated paths are tried first; when none of them
# handled the run, the fallback's own #em is called with the original argument.
#
# @param start [#to_s] start mode; its string form is what the Rust paths receive
# @return [nil, Object] nil when Rust handled the run, else the fallback's result
def em(start)
  mode = start.to_s
  rust_before_em(mode)

  rust_orchestrated_em(mode) ? nil : @fallback.em(start)
end
89
+
90
# Beta as currently held by the fallback backend.
#
# @return [Object] the fallback's #beta value
def beta
  @fallback.beta
end
93
+
94
# Gamma as currently held by the fallback backend.
#
# @return [Object] the fallback's #gamma value
def gamma
  @fallback.gamma
end
97
+
98
# Asks the fallback backend to compute phi.
#
# @return [Object] the fallback's #compute_phi result
def compute_phi
  @fallback.compute_phi
end
101
+
102
# Model as reported by the fallback backend.
#
# @return [Object] the fallback's #model value
def model
  @fallback.model
end
105
+
106
+ private
107
+
108
# Tries the Rust EM strategies in order: managed corpus session, start-seed
# run, then the beta-based run. Stops at the first one reporting success.
#
# @param start [Object] start mode handed to each strategy
# @return [true, Object] true once an earlier strategy succeeds, otherwise
#   the result of the beta-based strategy
def rust_orchestrated_em(start)
  return true if rust_orchestrated_em_with_managed_corpus(start)
  return true if rust_orchestrated_em_with_start_seed(start)

  rust_orchestrated_em_with_beta(start)
end
117
+
118
# Runs EM through a corpus session held by the Rust side.
#
# Prefers RustBackend.run_em_on_session_with_corpus, which also receives the
# Ruby-side corpus snapshot and returns five values, the first being a session
# id. Otherwise uses RustBackend.run_em_on_session, which needs an existing
# session id. A validated result is applied to the fallback.
#
# @param start [#to_s] start mode
# @return [Boolean] true when a validated Rust result was applied; false when
#   this path is unusable, the result is rejected, or a StandardError occurs
def rust_orchestrated_em_with_managed_corpus(start)
  return false unless defined?(::Lda::RustBackend) && ensure_rust_corpus_snapshot

  backend = ::Lda::RustBackend
  seed = Integer(next_random_seed)

  if backend.respond_to?(:run_em_on_session_with_corpus)
    result = backend.run_em_on_session_with_corpus(
      Integer(@rust_corpus_session_id || 0),
      @rust_document_words,
      @rust_document_counts,
      Integer(@rust_corpus_terms),
      start.to_s,
      *current_rust_session_config_signature,
      seed
    )
    return false unless result.is_a?(Array) && result.size == 5

    reported_session_id, *em_state = result
    state_ok = valid_rust_em_output?(
      em_state, @rust_document_lengths, Integer(num_topics), Integer(@rust_corpus_terms)
    )
    return false unless state_ok

    # Keep the reported id only when it is a positive number; otherwise forget the session.
    usable_id = reported_session_id.is_a?(Numeric) && reported_session_id.positive?
    @rust_corpus_session_id = usable_id ? Integer(reported_session_id) : nil
  else
    return false unless backend.respond_to?(:run_em_on_session)
    return false unless @rust_corpus_session_id

    em_state = backend.run_em_on_session(
      Integer(@rust_corpus_session_id),
      start.to_s,
      *current_rust_session_config_signature,
      seed
    )
    state_ok = valid_rust_em_output?(
      em_state, @rust_document_lengths, Integer(num_topics), Integer(@rust_corpus_terms)
    )
    return false unless state_ok
  end

  beta_probabilities, beta_log, gamma, phi = em_state
  @fallback.apply_em_state(
    beta_probabilities: beta_probabilities,
    beta_log: beta_log,
    gamma: gamma,
    phi: phi
  )
  true
rescue StandardError
  false
end
176
+
177
# Runs EM through RustBackend.run_em_with_start_seed using the Ruby-side
# corpus snapshot, and applies a validated result to the fallback.
#
# @param start [#to_s] start mode
# @return [Boolean] true when a validated Rust result was applied; false when
#   this path is unusable, num_topics is not positive, the result is rejected,
#   or a StandardError occurs
def rust_orchestrated_em_with_start_seed(start)
  backend_ready = defined?(::Lda::RustBackend) && ::Lda::RustBackend.respond_to?(:run_em_with_start_seed)
  return false unless backend_ready
  return false unless ensure_rust_corpus_snapshot

  topic_count = Integer(num_topics)
  return false unless topic_count.positive?

  seed = Integer(next_random_seed)
  arguments = [
    start.to_s,
    @rust_document_words,
    @rust_document_counts,
    topic_count,
    Integer(@rust_corpus_terms),
    Integer(max_iter),
    Float(convergence),
    Integer(em_max_iter),
    Float(em_convergence),
    Float(init_alpha),
    MIN_PROBABILITY,
    seed
  ]
  em_state = ::Lda::RustBackend.run_em_with_start_seed(*arguments)

  state_ok = valid_rust_em_output?(em_state, @rust_document_lengths, topic_count, Integer(@rust_corpus_terms))
  return false unless state_ok

  beta_probabilities, beta_log, gamma, phi = em_state
  @fallback.apply_em_state(
    beta_probabilities: beta_probabilities,
    beta_log: beta_log,
    gamma: gamma,
    phi: phi
  )
  true
rescue StandardError
  false
end
219
+
220
# Runs EM through RustBackend.run_em starting from explicit initial beta
# probabilities.
#
# A nil EM input counts as handled. When the Rust result is rejected, or a
# StandardError is raised after the input was built, the same input is run
# through the fallback's #em_from_input instead.
#
# @param start [Object] start mode
# @return [Boolean] false only when RustBackend.run_em is unavailable or the
#   failure happened before an EM input existed
def rust_orchestrated_em_with_beta(start)
  return false unless defined?(::Lda::RustBackend) && ::Lda::RustBackend.respond_to?(:run_em)

  em_input = rust_beta_em_input(start)
  return true if em_input.nil?

  em_state = ::Lda::RustBackend.run_em(
    em_input.fetch(:initial_beta_probabilities),
    em_input.fetch(:document_words),
    em_input.fetch(:document_counts),
    Integer(max_iter),
    Float(convergence),
    Integer(em_max_iter),
    Float(em_convergence),
    Float(init_alpha),
    Float(em_input.fetch(:min_probability))
  )

  state_ok = valid_rust_em_output?(
    em_state,
    em_input.fetch(:document_lengths),
    em_input.fetch(:topics),
    em_input.fetch(:terms)
  )

  if state_ok
    beta_probabilities, beta_log, gamma, phi = em_state
    @fallback.apply_em_state(
      beta_probabilities: beta_probabilities,
      beta_log: beta_log,
      gamma: gamma,
      phi: phi
    )
  else
    @fallback.em_from_input(em_input)
  end

  true
rescue StandardError
  # em_input is still nil when the failure happened before it was built.
  return false unless em_input

  @fallback.em_from_input(em_input)
  true
end

# Builds the EM input for #rust_orchestrated_em_with_beta: from the corpus
# snapshot when one is available and the fallback answers to
# rust_initial_beta_probabilities, otherwise from the fallback's #rust_em_input.
def rust_beta_em_input(start)
  unless ensure_rust_corpus_snapshot && @fallback.respond_to?(:rust_initial_beta_probabilities)
    return @fallback.rust_em_input(start)
  end

  topics = Integer(num_topics)
  terms = Integer(@rust_corpus_terms)
  initial_beta_probabilities = @fallback.rust_initial_beta_probabilities(
    start,
    @rust_document_words,
    @rust_document_counts,
    topics,
    terms
  )

  {
    topics: topics,
    terms: terms,
    document_words: @rust_document_words,
    document_counts: @rust_document_counts,
    document_totals: @rust_document_counts.map { |counts| counts.sum.to_f },
    document_lengths: @rust_document_lengths,
    initial_beta_probabilities: initial_beta_probabilities,
    min_probability: MIN_PROBABILITY
  }
end
290
+
291
# Builds a plain-Ruby snapshot of the corpus for the Rust entry points.
#
# @return [Hash, nil] nil when there is no corpus or it has zero documents;
#   otherwise :topics, :terms, per-document :document_words (integers),
#   :document_counts (floats), :document_lengths and :min_probability
# @raise [ArgumentError] when num_topics or the derived term count is not positive
def rust_em_corpus_input
  return nil if @corpus.nil? || @corpus.num_docs.zero?

  topic_count = Integer(num_topics)
  raise ArgumentError, "num_topics must be greater than zero" if topic_count <= 0

  # Term ids are zero-based, so the vocabulary size is the largest id plus one.
  term_count = max_term_index + 1
  raise ArgumentError, "corpus must contain terms" if term_count <= 0

  words_per_document = @corpus.documents.map do |document|
    document.words.map(&:to_i)
  end
  counts_per_document = @corpus.documents.map do |document|
    document.counts.map(&:to_f)
  end

  {
    topics: topic_count,
    terms: term_count,
    document_words: words_per_document,
    document_counts: counts_per_document,
    document_lengths: words_per_document.map(&:length),
    min_probability: MIN_PROBABILITY
  }
end
312
+
313
# Largest word id used by any document of the corpus.
#
# @return [Object] the maximum word id, or -1 when there is no corpus, no
#   documents, or no words at all
def max_term_index
  return -1 if @corpus.nil? || @corpus.documents.empty?

  all_words = @corpus.documents.flat_map { |document| document.words }
  all_words.max || -1
end
320
+
321
# Rebuilds the Ruby-side corpus snapshot and (re)creates the corpus session
# on the Rust side.
#
# All session state is cleared first. The previous session is dropped when
# there is no corpus, when no snapshot can be built, when the backend hands
# back no positive session id, or when a StandardError is raised. With
# replace_corpus_session the previous id is handed to the backend; with
# create_corpus_session the previous session is dropped before creating.
#
# @param previous_session_id [Object, nil] id of the session being replaced
def register_rust_corpus_session(previous_session_id = nil)
  @rust_corpus_session_id = nil
  @rust_corpus_terms = nil
  @rust_document_lengths = nil
  @rust_document_words = nil
  @rust_document_counts = nil

  if @corpus.nil?
    drop_rust_corpus_session_by_id(previous_session_id)
    return
  end

  return unless defined?(::Lda::RustBackend)

  snapshot = rust_em_corpus_input
  if snapshot.nil?
    drop_rust_corpus_session_by_id(previous_session_id)
    return
  end

  @rust_corpus_terms = Integer(snapshot.fetch(:terms))
  @rust_document_lengths = snapshot.fetch(:document_lengths)
  @rust_document_words = snapshot.fetch(:document_words)
  @rust_document_counts = snapshot.fetch(:document_counts)

  backend = ::Lda::RustBackend
  new_session_id =
    if backend.respond_to?(:replace_corpus_session)
      backend.replace_corpus_session(
        Integer(previous_session_id || 0),
        @rust_document_words,
        @rust_document_counts,
        Integer(@rust_corpus_terms)
      )
    elsif backend.respond_to?(:create_corpus_session)
      drop_rust_corpus_session_by_id(previous_session_id)
      backend.create_corpus_session(
        @rust_document_words,
        @rust_document_counts,
        Integer(@rust_corpus_terms)
      )
    end

  if new_session_id.is_a?(Numeric) && new_session_id.positive?
    @rust_corpus_session_id = Integer(new_session_id)
  else
    # The snapshot stays in place; only the session id remains unset.
    drop_rust_corpus_session_by_id(previous_session_id)
    nil
  end
rescue StandardError
  @rust_corpus_session_id = nil
  @rust_corpus_terms = nil
  @rust_document_lengths = nil
  @rust_document_words = nil
  @rust_document_counts = nil
  drop_rust_corpus_session_by_id(previous_session_id)
end
377
+
378
# Makes sure the Ruby-side corpus snapshot (term count plus per-document
# lengths, words and counts) is present, re-registering the corpus session
# when any part of it is missing.
#
# @return [Boolean] true when the snapshot is available afterwards; false
#   when it is still missing or a StandardError was raised
def ensure_rust_corpus_snapshot
  # One predicate for both checks, so the result is always a strict Boolean
  # rather than whatever the last instance variable happens to hold.
  snapshot_present = lambda do
    [@rust_corpus_terms, @rust_document_lengths, @rust_document_words, @rust_document_counts].all?
  end
  return true if snapshot_present.call

  register_rust_corpus_session(@rust_corpus_session_id)
  snapshot_present.call
rescue StandardError
  false
end
387
+
388
# Forgets the corpus snapshot and session id, then asks the Rust side to
# drop the session that was active. A StandardError yields nil.
def release_rust_corpus_session
  stale_session_id = @rust_corpus_session_id

  @rust_corpus_session_id = @rust_corpus_terms = nil
  @rust_document_lengths = @rust_document_words = @rust_document_counts = nil

  drop_rust_corpus_session_by_id(stale_session_id)
rescue StandardError
  nil
end
401
+
402
# Asks RustBackend to drop the given corpus session. Does nothing when the id
# is nil/false or the backend has no drop_corpus_session.
#
# @param session_id [Object, nil] id to drop; coerced with Integer()
# @return [Object, nil] the backend's result, or nil when nothing was done or
#   a StandardError was raised
def drop_rust_corpus_session_by_id(session_id)
  return unless session_id && defined?(::Lda::RustBackend)

  backend = ::Lda::RustBackend
  backend.drop_corpus_session(Integer(session_id)) if backend.respond_to?(:drop_corpus_session)
rescue StandardError
  nil
end
411
+
412
# Current EM configuration as the positional tuple splatted into the
# session-based Rust calls.
#
# @return [Array] num_topics, max_iter, convergence, em_max_iter,
#   em_convergence, init_alpha (each coerced to Integer or Float) followed by
#   MIN_PROBABILITY
def current_rust_session_config_signature
  topic_count = Integer(num_topics)
  inference_iterations = Integer(max_iter)
  inference_tolerance = Float(convergence)
  em_iterations = Integer(em_max_iter)
  em_tolerance = Float(em_convergence)
  alpha = Float(init_alpha)

  [
    topic_count,
    inference_iterations,
    inference_tolerance,
    em_iterations,
    em_tolerance,
    alpha,
    MIN_PROBABILITY
  ]
end
423
+
424
# Checks that an EM result is a four-element array holding two topic-by-term
# matrices, a gamma matrix and a phi tensor of the expected shapes.
#
# @param output [Object] candidate [beta_probabilities, beta_log, gamma, phi]
# @param document_lengths [Array] expected number of phi rows per document
# @param topics [Integer] expected topic count
# @param terms [Integer] expected term count
# @return [Boolean]
def valid_rust_em_output?(output, document_lengths, topics, terms)
  return false unless output.is_a?(Array) && output.size == 4

  beta_probabilities, beta_log, gamma, phi = output

  return false unless valid_topic_term_matrix?(beta_probabilities, topics, terms)
  return false unless valid_topic_term_matrix?(beta_log, topics, terms)
  return false unless valid_gamma_matrix?(gamma, document_lengths.size, topics)

  valid_phi_tensor?(phi, document_lengths, topics)
end
435
+
436
# True when the matrix is an array of exactly `topics` rows, each an array of
# exactly `terms` finite numbers.
#
# @return [Boolean]
def valid_topic_term_matrix?(matrix, topics, terms)
  return false unless matrix.is_a?(Array) && matrix.size == topics

  matrix.none? do |row|
    !row.is_a?(Array) ||
      row.size != terms ||
      row.any? { |value| !finite_numeric?(value) }
  end
end
446
+
447
# True when gamma is an array of exactly `expected_docs` rows, each an array
# of exactly `topics` finite, strictly positive numbers.
#
# @return [Boolean]
def valid_gamma_matrix?(gamma, expected_docs, topics)
  return false unless gamma.is_a?(Array) && gamma.size == expected_docs

  gamma.none? do |row|
    !row.is_a?(Array) ||
      row.size != topics ||
      row.any? { |value| !(finite_numeric?(value) && value.positive?) }
  end
end
457
+
458
# True when phi has one entry per document, each entry has as many rows as
# the matching value in document_lengths, and every row is an array of
# exactly `topics` finite numbers.
#
# @return [Boolean]
def valid_phi_tensor?(phi, document_lengths, topics)
  return false unless phi.is_a?(Array) && phi.size == document_lengths.size

  phi.zip(document_lengths).all? do |doc_phi, expected_rows|
    next false unless doc_phi.is_a?(Array) && doc_phi.size == expected_rows

    doc_phi.all? do |row|
      row.is_a?(Array) && row.size == topics && row.all? { |value| finite_numeric?(value) }
    end
  end
end
472
+
473
# True when the value is a Numeric that reports itself as finite.
#
# @return [Boolean] false for non-numeric values
def finite_numeric?(value)
  value.is_a?(Numeric) ? value.finite? : false
end
476
+
477
# Calls RustBackend.before_em, when the backend has it, with the start mode
# and the corpus's document and term counts (0 each when there is no corpus).
#
# @param start [#to_s] start mode
# @return [Object, nil] the hook's result; nil when the hook is missing or a
#   StandardError was raised
def rust_before_em(start)
  return unless defined?(::Lda::RustBackend) && ::Lda::RustBackend.respond_to?(:before_em)

  mode = start.to_s
  document_total = @corpus&.num_docs.to_i
  term_total = @corpus&.num_terms.to_i

  ::Lda::RustBackend.before_em(mode, document_total, term_total)
rescue StandardError
  nil
end
485
+
486
# Wrapper around RustBackend.topic_weights_for_word. word_index is coerced
# with Integer() and min_probability with Float() before the call.
#
# @return [Object, nil] the backend's result; nil when the backend or the
#   function is missing, or a StandardError was raised
def rust_topic_weights_for_word(beta_probabilities, gamma, word_index, min_probability)
  return nil unless defined?(::Lda::RustBackend) && ::Lda::RustBackend.respond_to?(:topic_weights_for_word)

  arguments = [beta_probabilities, gamma, Integer(word_index), Float(min_probability)]
  ::Lda::RustBackend.topic_weights_for_word(*arguments)
rescue StandardError
  nil
end
499
+
500
# Wrapper around RustBackend.accumulate_topic_term_counts; all four arguments
# are passed through untouched.
#
# @return [Object, nil] the backend's result; nil when the backend or the
#   function is missing, or a StandardError was raised
def rust_accumulate_topic_term_counts(topic_term_counts, phi_d, words, counts)
  return nil unless defined?(::Lda::RustBackend) && ::Lda::RustBackend.respond_to?(:accumulate_topic_term_counts)

  ::Lda::RustBackend.accumulate_topic_term_counts(topic_term_counts, phi_d, words, counts)
rescue StandardError
  nil
end
513
+
514
# Wrapper around RustBackend.infer_document. The backend's flat array is
# split into its first element and the remaining elements.
#
# @return [Array, nil] [gamma, phi_rows]; nil when the backend or the function
#   is missing, the result is not a non-empty Array, or a StandardError was raised
def rust_infer_document(beta_probabilities, gamma_initial, words, counts, max_iter, convergence, min_probability, init_alpha)
  return nil unless defined?(::Lda::RustBackend) && ::Lda::RustBackend.respond_to?(:infer_document)

  raw = ::Lda::RustBackend.infer_document(
    beta_probabilities,
    gamma_initial,
    words,
    counts,
    Integer(max_iter),
    Float(convergence),
    Float(min_probability),
    Float(init_alpha)
  )
  return nil unless raw.is_a?(Array) && !raw.empty?

  gamma, *phi_rows = raw
  [gamma, phi_rows]
rescue StandardError
  nil
end
538
+
539
# Wrapper around RustBackend.infer_corpus_iteration. max_iter is coerced with
# Integer(); convergence, min_probability and init_alpha with Float().
#
# @return [Object, nil] the backend's result; nil when the backend or the
#   function is missing, or a StandardError was raised
def rust_infer_corpus_iteration(beta_probabilities, document_words, document_counts, max_iter, convergence, min_probability, init_alpha)
  return nil unless defined?(::Lda::RustBackend) && ::Lda::RustBackend.respond_to?(:infer_corpus_iteration)

  arguments = [
    beta_probabilities,
    document_words,
    document_counts,
    Integer(max_iter),
    Float(convergence),
    Float(min_probability),
    Float(init_alpha)
  ]
  ::Lda::RustBackend.infer_corpus_iteration(*arguments)
rescue StandardError
  nil
end
555
+
556
# Wrapper around RustBackend.normalize_topic_term_counts; min_probability is
# coerced with Float() before the call.
#
# @return [Object, nil] the backend's result; nil when the backend or the
#   function is missing, or a StandardError was raised
def rust_finalize_topic_term_counts(topic_term_counts, min_probability)
  return nil unless defined?(::Lda::RustBackend) && ::Lda::RustBackend.respond_to?(:normalize_topic_term_counts)

  floor = Float(min_probability)
  ::Lda::RustBackend.normalize_topic_term_counts(topic_term_counts, floor)
rescue StandardError
  nil
end
567
+
568
# Wrapper around RustBackend.average_gamma_shift; both gamma arguments are
# passed through untouched.
#
# @return [Object, nil] the backend's result; nil when the backend or the
#   function is missing, or a StandardError was raised
def rust_average_gamma_shift(previous_gamma, current_gamma)
  backend_ready = defined?(::Lda::RustBackend) && ::Lda::RustBackend.respond_to?(:average_gamma_shift)
  return nil unless backend_ready

  ::Lda::RustBackend.average_gamma_shift(previous_gamma, current_gamma)
rescue StandardError
  nil
end
576
+
577
# Wrapper around RustBackend.topic_document_probability. num_topics is
# coerced with Integer() and min_probability with Float() before the call.
#
# @return [Object, nil] the backend's result; nil when the backend or the
#   function is missing, or a StandardError was raised
def rust_topic_document_probability(phi_matrix, document_counts, num_topics, min_probability)
  return nil unless defined?(::Lda::RustBackend) && ::Lda::RustBackend.respond_to?(:topic_document_probability)

  arguments = [phi_matrix, document_counts, Integer(num_topics), Float(min_probability)]
  ::Lda::RustBackend.topic_document_probability(*arguments)
rescue StandardError
  nil
end
590
+
591
# Wrapper around RustBackend.seeded_topic_term_probabilities. topics and
# terms are coerced with Integer(), min_probability with Float().
#
# @return [Object, nil] the backend's result; nil when the backend or the
#   function is missing, or a StandardError was raised
def rust_seeded_topic_term_probabilities(document_words, document_counts, topics, terms, min_probability)
  return nil unless defined?(::Lda::RustBackend) && ::Lda::RustBackend.respond_to?(:seeded_topic_term_probabilities)

  arguments = [
    document_words,
    document_counts,
    Integer(topics),
    Integer(terms),
    Float(min_probability)
  ]
  ::Lda::RustBackend.seeded_topic_term_probabilities(*arguments)
rescue StandardError
  nil
end
605
+ end
606
+ end
607
+ end
@@ -0,0 +1,58 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "lda-ruby/backends/base"
4
+ require "lda-ruby/backends/rust"
5
+ require "lda-ruby/backends/native"
6
+ require "lda-ruby/backends/pure_ruby"
7
+
8
module Lda
  # Factory that picks and instantiates an LDA backend.
  module Backends
    class << self
      # Builds a backend instance.
      #
      # @param host [Object] passed to Native.available? and Native.new
      # @param requested [#to_s, nil] backend mode; when nil the
      #   LDA_RUBY_BACKEND environment variable is used, defaulting to "auto"
      # @param random_seed [Object, nil] forwarded to the backend constructor
      # @return [Object] a Rust, Native or PureRuby backend instance
      # @raise [LoadError] when an explicitly requested rust/native backend is unavailable
      # @raise [ArgumentError] when the mode is not recognized
      def build(host:, requested: nil, random_seed: nil)
        mode = normalize_mode(requested)

        case mode
        when :auto
          # Preference order: Rust, then Native, then the pure-Ruby implementation.
          if Rust.available?
            Rust.new(random_seed: random_seed)
          elsif Native.available?(host)
            Native.new(host, random_seed: random_seed)
          else
            PureRuby.new(random_seed: random_seed)
          end
        when :rust
          raise LoadError, "Rust backend is unavailable for this environment" unless Rust.available?

          Rust.new(random_seed: random_seed)
        when :native
          raise LoadError, "Native backend is unavailable for this environment" unless Native.available?(host)

          Native.new(host, random_seed: random_seed)
        when :pure
          PureRuby.new(random_seed: random_seed)
        else
          # For unknown modes normalize_mode hands back the raw value, so report
          # that value: it is the caller's argument when one was given, and the
          # LDA_RUBY_BACKEND setting otherwise (previously reported as nil).
          raise ArgumentError, "Unknown backend mode: #{mode.inspect}"
        end
      end

      private

      # Maps a requested mode (or the LDA_RUBY_BACKEND environment variable
      # when nothing was requested) onto :auto, :native, :rust or :pure.
      # Matching ignores case and surrounding whitespace.
      #
      # @return [Symbol, Object] the normalized mode, or the raw value untouched
      #   when it is not recognized
      def normalize_mode(requested)
        raw_mode = requested || ENV.fetch("LDA_RUBY_BACKEND", "auto")

        case raw_mode.to_s.strip.downcase
        when "", "auto"
          :auto
        when "native", "c"
          :native
        when "rust", "rust_native"
          :rust
        when "pure", "ruby", "pure_ruby"
          :pure
        else
          raw_mode
        end
      end
    end
  end
end