lda-ruby 0.4.0 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -109,6 +109,10 @@ module Lda
109
109
 
110
110
  private
111
111
 
112
# Draws a fresh seed from this backend's random source (@random).
#
# @return [Integer] a value in 0..(2**63 - 1), both ends inclusive
def next_random_seed
  largest_seed = (1 << 63) - 1 # 9_223_372_036_854_775_807
  @random.rand(0..largest_seed)
end
115
+
112
116
  def normalize!(weights)
113
117
  total = weights.sum.to_f
114
118
 
@@ -46,61 +46,48 @@ module Lda
46
46
  end
47
47
 
48
48
  def em(start)
49
- return nil if @corpus.nil? || @corpus.num_docs.zero?
50
-
51
- topics = Integer(num_topics)
52
- raise ArgumentError, "num_topics must be greater than zero" if topics <= 0
53
-
54
- terms = max_term_index + 1
55
- raise ArgumentError, "corpus must contain terms" if terms <= 0
49
+ em_input = build_em_input(start)
50
+ return nil if em_input.nil?
56
51
 
57
- document_words = @corpus.documents.map { |document| document.words.map(&:to_i) }
58
- document_counts = @corpus.documents.map { |document| document.counts.map(&:to_f) }
52
+ run_em_iterations(em_input)
53
+ nil
54
+ end
59
55
 
60
- @beta_probabilities =
61
- if start.to_s.strip.casecmp("seeded").zero? || start.to_s.strip.casecmp("deterministic").zero?
62
- seeded_topic_term_probabilities(topics, terms, document_words, document_counts)
63
- else
64
- initial_topic_term_probabilities(topics, terms)
65
- end
56
# Builds the EM input snapshot for +start+ once, so that the Rust
# orchestration and the Ruby fallback can share it instead of sampling the
# random initialization a second time.
#
# @param start [Object] initialization mode, forwarded to build_em_input
# @return [Hash, nil] whatever build_em_input returns for +start+
def rust_em_input(start)
  snapshot = build_em_input(start)
  snapshot
end
66
61
 
67
- document_totals = document_counts.map { |counts| counts.sum.to_f }
68
- document_lengths = document_words.map(&:length)
62
# Computes just the initial beta (topic-term) matrix, for Rust compatibility
# paths that already hold a cached corpus snapshot.
#
# The "seeded" and "deterministic" modes (compared case-insensitively after
# stripping surrounding whitespace) use seeded_topic_term_probabilities;
# every other mode uses initial_topic_term_probabilities.
#
# @param start [#to_s] initialization mode
# @param document_words [Array] per-document word lists (seeded modes only)
# @param document_counts [Array] per-document counts (seeded modes only)
# @param topics [Object] topic count, coerced with Integer()
# @param terms [Object] term count, coerced with Integer()
# @return [Object] the value returned by the selected initializer
def rust_initial_beta_probabilities(start, document_words, document_counts, topics, terms)
  mode = start.to_s.strip
  seeded = %w[seeded deterministic].any? { |name| mode.casecmp(name).zero? }

  topic_count = Integer(topics)
  term_count = Integer(terms)

  return initial_topic_term_probabilities(topic_count, term_count) unless seeded

  seeded_topic_term_probabilities(topic_count, term_count, document_words, document_counts)
end
71
78
 
72
- Integer(em_max_iter).times do
73
- if @trusted_kernel_outputs && @corpus_iteration_kernel
74
- current_gamma, current_phi, topic_term_counts = infer_corpus_iteration(
75
- nil,
76
- document_words,
77
- document_counts,
78
- document_totals,
79
- document_lengths,
80
- topics,
81
- terms
82
- )
83
- else
84
- topic_term_counts = Array.new(topics) { Array.new(terms, MIN_PROBABILITY) }
85
- current_gamma, current_phi, topic_term_counts = infer_corpus_iteration(
86
- topic_term_counts,
87
- document_words,
88
- document_counts,
89
- document_totals,
90
- document_lengths,
91
- topics,
92
- terms
93
- )
94
- end
79
+ def em_from_input(em_input)
80
+ return nil if em_input.nil?
95
81
 
96
- @beta_probabilities, @beta_log = finalize_topic_term_counts(topic_term_counts)
97
- @gamma = current_gamma
98
- @phi = current_phi
82
+ run_em_iterations(em_input)
83
+ nil
84
+ end
99
85
 
100
- break if previous_gamma && average_gamma_shift(previous_gamma, current_gamma) <= Float(em_convergence)
101
-
102
- previous_gamma = current_gamma
103
- end
86
# Stores an externally computed EM result on this backend.
#
# @param beta_probabilities [Object] assigned to @beta_probabilities
# @param beta_log [Object] assigned to @beta_log
# @param gamma [Object] assigned to @gamma
# @param phi [Object] assigned to @phi
# @return [nil]
def apply_em_state(beta_probabilities:, beta_log:, gamma:, phi:)
  @beta_probabilities, @beta_log = beta_probabilities, beta_log
  @gamma, @phi = gamma, phi
  nil
end
@@ -147,6 +134,81 @@ module Lda
147
134
 
148
135
  private
149
136
 
137
# Collects everything one EM run needs into a single Hash: the corpus as
# plain arrays, derived totals/lengths, and the initial beta matrix.
#
# @param start [Object] initialization mode, forwarded to
#   rust_initial_beta_probabilities
# @return [Hash, nil] nil when there is no corpus or it has zero documents;
#   otherwise a Hash with keys :topics, :terms, :document_words,
#   :document_counts, :document_totals, :document_lengths,
#   :initial_beta_probabilities and :min_probability
# @raise [ArgumentError] when num_topics or the term count is not positive
def build_em_input(start)
  return nil if @corpus.nil? || @corpus.num_docs.zero?

  topics = Integer(num_topics)
  raise ArgumentError, "num_topics must be greater than zero" unless topics.positive?

  terms = max_term_index + 1
  raise ArgumentError, "corpus must contain terms" if terms <= 0

  # Words are coerced to Integer and counts to Float, document by document.
  word_rows = @corpus.documents.map { |document| document.words.map(&:to_i) }
  count_rows = @corpus.documents.map { |document| document.counts.map(&:to_f) }

  em_input = {
    topics: topics,
    terms: terms,
    document_words: word_rows,
    document_counts: count_rows
  }
  em_input[:document_totals] = count_rows.map { |row| row.sum.to_f }
  em_input[:document_lengths] = word_rows.map(&:length)
  em_input[:initial_beta_probabilities] =
    rust_initial_beta_probabilities(start, word_rows, count_rows, topics, terms)
  em_input[:min_probability] = MIN_PROBABILITY
  em_input
end
166
+
167
# Runs the EM loop over a prepared input Hash, updating @beta_probabilities,
# @beta_log, @gamma and @phi after every iteration.
#
# The loop runs at most Integer(em_max_iter) times and stops early once
# average_gamma_shift between consecutive iterations is at or below
# Float(em_convergence).
#
# @param em_input [Hash] must contain :topics, :terms, :document_words,
#   :document_counts, :document_totals, :document_lengths and
#   :initial_beta_probabilities (read with Hash#fetch)
# @return [Integer, nil] the value of Integer#times: nil after an early
#   break, the iteration count otherwise
# @raise [KeyError] when a required key is missing
def run_em_iterations(em_input)
  topics, terms, document_words, document_counts, document_totals, document_lengths =
    %i[topics terms document_words document_counts document_totals document_lengths]
    .map { |key| em_input.fetch(key) }

  @beta_probabilities = em_input.fetch(:initial_beta_probabilities)
  last_gamma = nil

  Integer(em_max_iter).times do
    # With trusted kernel outputs and a corpus iteration kernel installed, nil
    # is passed instead of a pre-filled accumulator; otherwise every cell
    # starts at MIN_PROBABILITY.
    accumulator =
      if @trusted_kernel_outputs && @corpus_iteration_kernel
        nil
      else
        Array.new(topics) { Array.new(terms, MIN_PROBABILITY) }
      end

    gamma, phi, topic_term_counts = infer_corpus_iteration(
      accumulator,
      document_words,
      document_counts,
      document_totals,
      document_lengths,
      topics,
      terms
    )

    @beta_probabilities, @beta_log = finalize_topic_term_counts(topic_term_counts)
    @gamma = gamma
    @phi = phi

    break if last_gamma && average_gamma_shift(last_gamma, gamma) <= Float(em_convergence)

    last_gamma = gamma
  end
end
211
+
150
212
  def max_term_index
151
213
  return -1 if @corpus.nil? || @corpus.documents.empty?
152
214
 
@@ -4,6 +4,7 @@ module Lda
4
4
  module Backends
5
5
  class Rust < Base
6
6
  SETTINGS = %i[max_iter convergence em_max_iter em_convergence num_topics init_alpha est_alpha verbose].freeze
7
+ MIN_PROBABILITY = 1e-12
7
8
 
8
9
  def self.available?
9
10
  return false unless defined?(::Lda::RUST_EXTENSION_LOADED) && ::Lda::RUST_EXTENSION_LOADED
@@ -32,6 +33,12 @@ module Lda
32
33
  super(random_seed: random_seed)
33
34
  raise LoadError, "Rust backend is unavailable for this environment" unless self.class.available?
34
35
 
36
+ @rust_corpus_session_id = nil
37
+ @rust_corpus_terms = nil
38
+ @rust_document_lengths = nil
39
+ @rust_document_words = nil
40
+ @rust_document_counts = nil
41
+
35
42
  @fallback = PureRuby.new(random_seed: random_seed)
36
43
  @fallback.topic_weights_kernel = method(:rust_topic_weights_for_word)
37
44
  @fallback.topic_term_accumulator_kernel = method(:rust_accumulate_topic_term_counts)
@@ -49,20 +56,22 @@ module Lda
49
56
  end
50
57
 
51
58
# Assigns the corpus on this backend and on the fallback backend, then
# re-registers the Rust corpus session, handing over the id that was stored
# before the assignment.
#
# @param corpus [Object, nil] the new corpus
# @return [true]
def corpus=(corpus)
  stale_session_id = @rust_corpus_session_id

  @corpus = corpus
  @fallback.corpus = corpus
  register_rust_corpus_session(stale_session_id)

  true
end
56
65
 
57
66
  def fast_load_corpus_from_file(filename)
58
67
  loaded = @fallback.fast_load_corpus_from_file(filename)
59
- @corpus = @fallback.corpus
68
+ self.corpus = @fallback.corpus
60
69
  loaded
61
70
  end
62
71
 
63
72
  def load_settings(settings_file)
64
73
  loaded = @fallback.load_settings(settings_file)
65
- @corpus = @fallback.corpus
74
+ self.corpus = @fallback.corpus
66
75
  loaded
67
76
  end
68
77
 
@@ -71,7 +80,10 @@ module Lda
71
80
  end
72
81
 
73
82
  def em(start)
74
- rust_before_em(start)
83
+ start_mode = start.to_s
84
+ rust_before_em(start_mode)
85
+ return nil if rust_orchestrated_em(start_mode)
86
+
75
87
  @fallback.em(start)
76
88
  end
77
89
 
@@ -93,6 +105,375 @@ module Lda
93
105
 
94
106
  private
95
107
 
108
# Tries the Rust EM strategies in order: managed corpus session first, then
# the start-seed entry point, then the precomputed-beta entry point.
#
# @param start [Object] initialization mode, passed to each strategy
# @return [Object] true as soon as one of the first two strategies returns a
#   truthy value; otherwise the return value of the beta strategy
def rust_orchestrated_em(start)
  return true if rust_orchestrated_em_with_managed_corpus(start)
  return true if rust_orchestrated_em_with_start_seed(start)

  rust_orchestrated_em_with_beta(start)
end
117
+
118
# Runs EM through the Rust extension using the corpus snapshot cached on this
# backend and, on success, copies the result into the fallback backend.
#
# RustBackend.run_em_on_session_with_corpus is preferred; its five-element
# reply starts with a session id that replaces @rust_corpus_session_id.
# Otherwise RustBackend.run_em_on_session is used, which requires a stored
# session id.
#
# @param start [#to_s] initialization mode, sent to Rust as a String
# @return [Boolean] true when a validated result was applied; false when the
#   extension, an entry point or the snapshot is missing, when the output
#   fails validation, or when any StandardError is raised
def rust_orchestrated_em_with_managed_corpus(start)
  return false unless defined?(::Lda::RustBackend) && ensure_rust_corpus_snapshot

  backend = ::Lda::RustBackend
  # The seed is drawn before either entry point is probed.
  seed = Integer(next_random_seed)

  if backend.respond_to?(:run_em_on_session_with_corpus)
    reply = backend.run_em_on_session_with_corpus(
      Integer(@rust_corpus_session_id || 0), # 0 is sent when no id is stored
      @rust_document_words,
      @rust_document_counts,
      Integer(@rust_corpus_terms),
      start.to_s,
      *current_rust_session_config_signature,
      seed
    )
    return false unless reply.is_a?(Array) && reply.size == 5

    reported_session_id, *em_state = reply
    state_is_valid = valid_rust_em_output?(
      em_state,
      @rust_document_lengths,
      Integer(num_topics),
      Integer(@rust_corpus_terms)
    )
    return false unless state_is_valid

    # Only a positive numeric id is kept; anything else clears the stored id.
    usable_id = reported_session_id.is_a?(Numeric) && reported_session_id.positive?
    @rust_corpus_session_id = usable_id ? Integer(reported_session_id) : nil

    beta_probabilities, beta_log, gamma, phi = em_state
    @fallback.apply_em_state(
      beta_probabilities: beta_probabilities,
      beta_log: beta_log,
      gamma: gamma,
      phi: phi
    )
    return true
  end

  return false unless backend.respond_to?(:run_em_on_session) && @rust_corpus_session_id

  em_state = backend.run_em_on_session(
    Integer(@rust_corpus_session_id),
    start.to_s,
    *current_rust_session_config_signature,
    seed
  )
  state_is_valid = valid_rust_em_output?(
    em_state,
    @rust_document_lengths,
    Integer(num_topics),
    Integer(@rust_corpus_terms)
  )
  return false unless state_is_valid

  beta_probabilities, beta_log, gamma, phi = em_state
  @fallback.apply_em_state(
    beta_probabilities: beta_probabilities,
    beta_log: beta_log,
    gamma: gamma,
    phi: phi
  )
  true
rescue StandardError
  false
end
176
+
177
# Runs EM through RustBackend.run_em_with_start_seed, passing the cached
# corpus snapshot, the current settings and a freshly drawn random seed, and
# copies a validated result into the fallback backend.
#
# @param start [#to_s] initialization mode, sent to Rust as a String
# @return [Boolean] true when a validated result was applied; false when the
#   extension, the entry point or the snapshot is missing, when num_topics
#   is not positive, when validation fails, or on any StandardError
def rust_orchestrated_em_with_start_seed(start)
  return false unless defined?(::Lda::RustBackend)

  backend = ::Lda::RustBackend
  return false unless backend.respond_to?(:run_em_with_start_seed)
  return false unless ensure_rust_corpus_snapshot

  topic_count = Integer(num_topics)
  return false unless topic_count.positive?

  seed = Integer(next_random_seed)
  mode = start.to_s
  term_count = Integer(@rust_corpus_terms)

  em_state = backend.run_em_with_start_seed(
    mode,
    @rust_document_words,
    @rust_document_counts,
    topic_count,
    term_count,
    Integer(max_iter),
    Float(convergence),
    Integer(em_max_iter),
    Float(em_convergence),
    Float(init_alpha),
    MIN_PROBABILITY,
    seed
  )

  unless valid_rust_em_output?(em_state, @rust_document_lengths, topic_count, Integer(@rust_corpus_terms))
    return false
  end

  beta_probabilities, beta_log, gamma, phi = em_state
  @fallback.apply_em_state(
    beta_probabilities: beta_probabilities,
    beta_log: beta_log,
    gamma: gamma,
    phi: phi
  )
  true
rescue StandardError
  false
end
219
+
220
# Runs EM through RustBackend.run_em from a precomputed initial beta matrix
# and falls back to the Ruby EM loop on the same input when the Rust call
# raises or returns output that fails validation.
#
# The Ruby fallback is invoked at most once per call: an error raised by the
# fallback itself propagates to the caller instead of triggering a second
# fallback run.
#
# @param start [Object] initialization mode
# @return [Boolean] false when the extension or its run_em entry point is
#   missing, or when the EM input could not be assembled; true otherwise
#   (including when there is no corpus to train on)
def rust_orchestrated_em_with_beta(start)
  em_input = nil

  begin
    return false unless defined?(::Lda::RustBackend)
    return false unless ::Lda::RustBackend.respond_to?(:run_em)

    em_input = rust_beta_em_input(start)
    # A nil input means there is nothing to train on; report it as handled.
    return true if em_input.nil?

    output = ::Lda::RustBackend.run_em(
      em_input.fetch(:initial_beta_probabilities),
      em_input.fetch(:document_words),
      em_input.fetch(:document_counts),
      Integer(max_iter),
      Float(convergence),
      Integer(em_max_iter),
      Float(em_convergence),
      Float(init_alpha),
      Float(em_input.fetch(:min_probability))
    )

    output_is_valid = valid_rust_em_output?(
      output,
      em_input.fetch(:document_lengths),
      em_input.fetch(:topics),
      em_input.fetch(:terms)
    )

    if output_is_valid
      beta_probabilities, beta_log, gamma, phi = output
      @fallback.apply_em_state(
        beta_probabilities: beta_probabilities,
        beta_log: beta_log,
        gamma: gamma,
        phi: phi
      )
      return true
    end
  rescue StandardError
    # Without an input there is nothing to hand to the Ruby loop.
    return false if em_input.nil?
  end

  # Reached on invalid Rust output or on an error after the input was built.
  @fallback.em_from_input(em_input)
  true
end

# Assembles the EM input Hash for RustBackend.run_em, preferring the cached corpus snapshot.
def rust_beta_em_input(start)
  unless ensure_rust_corpus_snapshot && @fallback.respond_to?(:rust_initial_beta_probabilities)
    return @fallback.rust_em_input(start)
  end

  topics = Integer(num_topics)
  terms = Integer(@rust_corpus_terms)
  initial_beta_probabilities = @fallback.rust_initial_beta_probabilities(
    start,
    @rust_document_words,
    @rust_document_counts,
    topics,
    terms
  )

  {
    topics: topics,
    terms: terms,
    document_words: @rust_document_words,
    document_counts: @rust_document_counts,
    document_totals: @rust_document_counts.map { |counts| counts.sum.to_f },
    document_lengths: @rust_document_lengths,
    initial_beta_probabilities: initial_beta_probabilities,
    min_probability: MIN_PROBABILITY
  }
end
290
+
291
# Snapshots the current corpus as plain arrays for the Rust extension.
#
# @return [Hash, nil] nil when there is no corpus or it has zero documents;
#   otherwise a Hash with keys :topics, :terms, :document_words,
#   :document_counts, :document_lengths and :min_probability
# @raise [ArgumentError] when num_topics or the term count is not positive
def rust_em_corpus_input
  return nil if @corpus.nil? || @corpus.num_docs.zero?

  topics = Integer(num_topics)
  raise ArgumentError, "num_topics must be greater than zero" unless topics.positive?

  terms = max_term_index + 1
  raise ArgumentError, "corpus must contain terms" if terms <= 0

  # Words are coerced to Integer and counts to Float, document by document.
  word_rows = @corpus.documents.map { |document| document.words.map(&:to_i) }
  count_rows = @corpus.documents.map { |document| document.counts.map(&:to_f) }

  snapshot = { topics: topics, terms: terms }
  snapshot[:document_words] = word_rows
  snapshot[:document_counts] = count_rows
  snapshot[:document_lengths] = word_rows.map(&:length)
  snapshot[:min_probability] = MIN_PROBABILITY
  snapshot
end
312
+
313
# Finds the largest word value used anywhere in the corpus.
#
# @return [Object] the maximum over all documents' words, or -1 when there
#   is no corpus, no documents, or no words at all
def max_term_index
  return -1 if @corpus.nil? || @corpus.documents.empty?

  largest_word = @corpus.documents.flat_map(&:words).max
  largest_word || -1
end
320
+
321
# Rebuilds the cached corpus snapshot (@rust_corpus_terms,
# @rust_document_lengths, @rust_document_words, @rust_document_counts) and
# registers it with the Rust extension, storing the resulting session id in
# @rust_corpus_session_id.
#
# All five instance variables are cleared first. The previous session is
# dropped when there is no corpus, when no snapshot could be built, when
# only create_corpus_session is available, when Rust does not return a
# positive numeric id, or when a StandardError is raised.
#
# @param previous_session_id [Object, nil] id of the session being replaced;
#   replace_corpus_session receives 0 when it is nil
# @return [Integer, nil] the new session id when one was stored
def register_rust_corpus_session(previous_session_id = nil)
  @rust_corpus_session_id = @rust_corpus_terms = @rust_document_lengths = nil
  @rust_document_words = @rust_document_counts = nil

  if @corpus.nil?
    drop_rust_corpus_session_by_id(previous_session_id)
    return
  end

  return unless defined?(::Lda::RustBackend)

  snapshot = rust_em_corpus_input
  if snapshot.nil?
    drop_rust_corpus_session_by_id(previous_session_id)
    return
  end

  @rust_corpus_terms = Integer(snapshot.fetch(:terms))
  @rust_document_lengths = snapshot.fetch(:document_lengths)
  @rust_document_words = snapshot.fetch(:document_words)
  @rust_document_counts = snapshot.fetch(:document_counts)

  backend = ::Lda::RustBackend
  new_session_id =
    if backend.respond_to?(:replace_corpus_session)
      backend.replace_corpus_session(
        Integer(previous_session_id || 0),
        @rust_document_words,
        @rust_document_counts,
        Integer(@rust_corpus_terms)
      )
    elsif backend.respond_to?(:create_corpus_session)
      # Without a replace entry point the old session is dropped explicitly.
      drop_rust_corpus_session_by_id(previous_session_id)
      backend.create_corpus_session(
        @rust_document_words,
        @rust_document_counts,
        Integer(@rust_corpus_terms)
      )
    end

  if new_session_id.is_a?(Numeric) && new_session_id.positive?
    @rust_corpus_session_id = Integer(new_session_id)
  else
    # The snapshot stays cached; only the session id remains unset.
    drop_rust_corpus_session_by_id(previous_session_id)
    nil
  end
rescue StandardError
  @rust_corpus_session_id = @rust_corpus_terms = @rust_document_lengths = nil
  @rust_document_words = @rust_document_counts = nil
  drop_rust_corpus_session_by_id(previous_session_id)
end
377
+
378
# Makes sure the cached corpus snapshot (@rust_corpus_terms,
# @rust_document_lengths, @rust_document_words, @rust_document_counts) is
# populated, re-registering the corpus session once when it is not.
#
# @return [Boolean] strictly true when all four snapshot values are set
#   (before or after re-registration); strictly false otherwise, including
#   when a StandardError is raised
def ensure_rust_corpus_snapshot
  snapshot_ready = lambda do
    [@rust_corpus_terms, @rust_document_lengths, @rust_document_words, @rust_document_counts].all?
  end

  return true if snapshot_ready.call

  register_rust_corpus_session(@rust_corpus_session_id)
  snapshot_ready.call
rescue StandardError
  false
end
387
+
388
# Forgets the cached corpus snapshot and session id, then asks for the
# session that was stored to be dropped.
#
# @return [Object, nil] the result of drop_rust_corpus_session_by_id, or nil
#   when a StandardError is raised
def release_rust_corpus_session
  released_id, @rust_corpus_session_id = @rust_corpus_session_id, nil
  @rust_corpus_terms = @rust_document_lengths = nil
  @rust_document_words = @rust_document_counts = nil

  drop_rust_corpus_session_by_id(released_id)
rescue StandardError
  nil
end
401
+
402
# Best-effort release of a Rust corpus session.
#
# Does nothing when +session_id+ is nil/false, when the Rust extension is
# not loaded, or when it has no drop_corpus_session entry point.
#
# @param session_id [Object, nil] id to drop, coerced with Integer()
# @return [Object, nil] the extension's return value, or nil when nothing
#   was dropped or a StandardError was raised
def drop_rust_corpus_session_by_id(session_id)
  backend_can_drop =
    session_id &&
    defined?(::Lda::RustBackend) &&
    ::Lda::RustBackend.respond_to?(:drop_corpus_session)
  return unless backend_can_drop

  ::Lda::RustBackend.drop_corpus_session(Integer(session_id))
rescue StandardError
  nil
end
411
+
412
# Lists the current settings in the positional order the Rust session entry
# points receive them: num_topics, max_iter, convergence, em_max_iter,
# em_convergence, init_alpha, then MIN_PROBABILITY.
#
# @return [Array] seven values, coerced with Integer()/Float() as appropriate
# @raise [ArgumentError, TypeError] when a setting cannot be coerced
def current_rust_session_config_signature
  signature = []
  signature << Integer(num_topics)
  signature << Integer(max_iter)
  signature << Float(convergence)
  signature << Integer(em_max_iter)
  signature << Float(em_convergence)
  signature << Float(init_alpha)
  signature << MIN_PROBABILITY
end
423
+
424
# Checks that a Rust EM result is a four-element Array of
# [beta_probabilities, beta_log, gamma, phi] whose parts all pass their
# shape/value checks.
#
# @param output [Object] value returned by the Rust extension
# @param document_lengths [Array] expected word count per document
# @param topics [Integer] expected number of topics
# @param terms [Integer] expected number of terms
# @return [Boolean]
def valid_rust_em_output?(output, document_lengths, topics, terms)
  return false unless output.is_a?(Array) && output.size == 4

  beta_probabilities, beta_log, gamma, phi = output

  return false unless valid_topic_term_matrix?(beta_probabilities, topics, terms)
  return false unless valid_topic_term_matrix?(beta_log, topics, terms)
  return false unless valid_gamma_matrix?(gamma, document_lengths.size, topics)

  valid_phi_tensor?(phi, document_lengths, topics)
end
435
+
436
# Checks that +matrix+ is an Array of +topics+ rows, each an Array of
# +terms+ finite numeric values.
#
# @param matrix [Object] candidate matrix
# @param topics [Integer] expected row count
# @param terms [Integer] expected row length
# @return [Boolean]
def valid_topic_term_matrix?(matrix, topics, terms)
  return false unless matrix.is_a?(Array) && matrix.size == topics

  matrix.none? do |row|
    !row.is_a?(Array) || row.size != terms || row.any? { |value| !finite_numeric?(value) }
  end
end
446
+
447
# Checks that +gamma+ is an Array of +expected_docs+ rows, each an Array of
# +topics+ finite, strictly positive numeric values.
#
# @param gamma [Object] candidate matrix
# @param expected_docs [Integer] expected row count
# @param topics [Integer] expected row length
# @return [Boolean]
def valid_gamma_matrix?(gamma, expected_docs, topics)
  return false unless gamma.is_a?(Array) && gamma.size == expected_docs

  gamma.all? do |doc_gamma|
    next false unless doc_gamma.is_a?(Array) && doc_gamma.size == topics

    doc_gamma.all? { |weight| finite_numeric?(weight) && weight.positive? }
  end
end
457
+
458
# Checks that +phi+ has one entry per document, that each entry is an Array
# with as many rows as that document's expected length, and that every row
# is an Array of +topics+ finite numeric values.
#
# @param phi [Object] candidate nested Array
# @param document_lengths [Array<Integer>] expected row count per document
# @param topics [Integer] expected row length
# @return [Boolean]
def valid_phi_tensor?(phi, document_lengths, topics)
  return false unless phi.is_a?(Array) && phi.size == document_lengths.size

  phi.zip(document_lengths).all? do |doc_phi, expected_rows|
    next false unless doc_phi.is_a?(Array) && doc_phi.size == expected_rows

    doc_phi.all? do |row|
      row.is_a?(Array) && row.size == topics && row.all? { |value| finite_numeric?(value) }
    end
  end
end
472
+
473
# Tells whether +value+ is a Numeric whose finite? check passes.
#
# @param value [Object] any object
# @return [Boolean] false for non-Numeric values
def finite_numeric?(value)
  return false unless value.is_a?(Numeric)

  value.finite?
end
476
+
96
477
  def rust_before_em(start)
97
478
  return unless defined?(::Lda::RustBackend)
98
479
  return unless ::Lda::RustBackend.respond_to?(:before_em)
@@ -1,5 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module Lda
4
- VERSION = "0.4.0"
4
+ VERSION = "0.5.0"
5
5
  end
@@ -0,0 +1,23 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "test_helper"
4
+ require "open3"
5
+
6
# Smoke tests for the bin/check-rust-benchmark command-line script. Each test
# runs the script as a subprocess from the repository root.
class BenchmarkScriptsTest < Test::Unit::TestCase
  def setup
    @repo_root = File.expand_path("..", __dir__)
    @check_rust_benchmark = File.join(@repo_root, "bin", "check-rust-benchmark")
  end

  # --help must exit successfully and print the usage banner on stdout.
  def test_check_rust_benchmark_help
    out, err, exit_status = run_check_rust_benchmark("--help")

    assert(exit_status.success?, err)
    assert_match(/Usage: \.\/bin\/check-rust-benchmark/, out)
  end

  # An unrecognised flag must fail and be reported on stderr.
  def test_check_rust_benchmark_rejects_unknown_argument
    _out, err, exit_status = run_check_rust_benchmark("--unknown")

    assert(!exit_status.success?, "expected check-rust-benchmark to fail for unknown args")
    assert_match(/Unknown argument/, err)
  end

  private

  # Runs the script with +arguments+ and returns [stdout, stderr, status].
  def run_check_rust_benchmark(*arguments)
    Open3.capture3(@check_rust_benchmark, *arguments, chdir: @repo_root)
  end
end