lda-ruby 0.4.0 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/README.md +4 -1
- data/VERSION.yml +1 -1
- data/docs/modernization-handoff.md +68 -25
- data/docs/porting-strategy.md +23 -2
- data/docs/precompiled-platform-policy.md +15 -2
- data/docs/precompiled-target-evaluation.md +67 -0
- data/docs/release-runbook.md +41 -6
- data/docs/rust-orchestration-guardrails.md +50 -0
- data/ext/lda-ruby/cokus.c +10 -11
- data/ext/lda-ruby/cokus.h +3 -3
- data/ext/lda-ruby/lda-inference.c +2 -2
- data/ext/lda-ruby/utils.c +8 -0
- data/ext/lda-ruby-rust/README.md +25 -0
- data/ext/lda-ruby-rust/extconf.rb +25 -13
- data/ext/lda-ruby-rust/include/strings.h +35 -0
- data/ext/lda-ruby-rust/src/lib.rs +816 -9
- data/lib/lda-ruby/backends/base.rb +4 -0
- data/lib/lda-ruby/backends/pure_ruby.rb +110 -48
- data/lib/lda-ruby/backends/rust.rb +384 -3
- data/lib/lda-ruby/version.rb +1 -1
- data/test/benchmark_scripts_test.rb +23 -0
- data/test/pure_ruby_orchestration_test.rb +109 -0
- data/test/release_scripts_test.rb +39 -0
- data/test/rust_orchestration_test.rb +911 -0
- metadata +8 -2
|
@@ -46,61 +46,48 @@ module Lda
|
|
|
46
46
|
end
|
|
47
47
|
|
|
48
48
|
def em(start)
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
topics = Integer(num_topics)
|
|
52
|
-
raise ArgumentError, "num_topics must be greater than zero" if topics <= 0
|
|
53
|
-
|
|
54
|
-
terms = max_term_index + 1
|
|
55
|
-
raise ArgumentError, "corpus must contain terms" if terms <= 0
|
|
49
|
+
em_input = build_em_input(start)
|
|
50
|
+
return nil if em_input.nil?
|
|
56
51
|
|
|
57
|
-
|
|
58
|
-
|
|
52
|
+
run_em_iterations(em_input)
|
|
53
|
+
nil
|
|
54
|
+
end
|
|
59
55
|
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
end
|
|
56
|
+
# Returns an EM input snapshot that can be reused by Rust orchestration
|
|
57
|
+
# and Ruby fallback paths without re-sampling random initialization.
|
|
58
|
+
def rust_em_input(start)
|
|
59
|
+
build_em_input(start)
|
|
60
|
+
end
|
|
66
61
|
|
|
67
|
-
|
|
68
|
-
|
|
62
|
+
# Returns only the initial beta matrix for Rust compatibility paths that
|
|
63
|
+
# already hold a cached corpus snapshot.
|
|
64
|
+
def rust_initial_beta_probabilities(start, document_words, document_counts, topics, terms)
|
|
65
|
+
start_mode = start.to_s
|
|
69
66
|
|
|
70
|
-
|
|
67
|
+
if start_mode.strip.casecmp("seeded").zero? || start_mode.strip.casecmp("deterministic").zero?
|
|
68
|
+
seeded_topic_term_probabilities(
|
|
69
|
+
Integer(topics),
|
|
70
|
+
Integer(terms),
|
|
71
|
+
document_words,
|
|
72
|
+
document_counts
|
|
73
|
+
)
|
|
74
|
+
else
|
|
75
|
+
initial_topic_term_probabilities(Integer(topics), Integer(terms))
|
|
76
|
+
end
|
|
77
|
+
end
|
|
71
78
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
current_gamma, current_phi, topic_term_counts = infer_corpus_iteration(
|
|
75
|
-
nil,
|
|
76
|
-
document_words,
|
|
77
|
-
document_counts,
|
|
78
|
-
document_totals,
|
|
79
|
-
document_lengths,
|
|
80
|
-
topics,
|
|
81
|
-
terms
|
|
82
|
-
)
|
|
83
|
-
else
|
|
84
|
-
topic_term_counts = Array.new(topics) { Array.new(terms, MIN_PROBABILITY) }
|
|
85
|
-
current_gamma, current_phi, topic_term_counts = infer_corpus_iteration(
|
|
86
|
-
topic_term_counts,
|
|
87
|
-
document_words,
|
|
88
|
-
document_counts,
|
|
89
|
-
document_totals,
|
|
90
|
-
document_lengths,
|
|
91
|
-
topics,
|
|
92
|
-
terms
|
|
93
|
-
)
|
|
94
|
-
end
|
|
79
|
+
def em_from_input(em_input)
|
|
80
|
+
return nil if em_input.nil?
|
|
95
81
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
82
|
+
run_em_iterations(em_input)
|
|
83
|
+
nil
|
|
84
|
+
end
|
|
99
85
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
86
|
+
def apply_em_state(beta_probabilities:, beta_log:, gamma:, phi:)
|
|
87
|
+
@beta_probabilities = beta_probabilities
|
|
88
|
+
@beta_log = beta_log
|
|
89
|
+
@gamma = gamma
|
|
90
|
+
@phi = phi
|
|
104
91
|
|
|
105
92
|
nil
|
|
106
93
|
end
|
|
@@ -147,6 +134,81 @@ module Lda
|
|
|
147
134
|
|
|
148
135
|
private
|
|
149
136
|
|
|
137
|
+
def build_em_input(start)
|
|
138
|
+
return nil if @corpus.nil? || @corpus.num_docs.zero?
|
|
139
|
+
|
|
140
|
+
topics = Integer(num_topics)
|
|
141
|
+
raise ArgumentError, "num_topics must be greater than zero" if topics <= 0
|
|
142
|
+
|
|
143
|
+
terms = max_term_index + 1
|
|
144
|
+
raise ArgumentError, "corpus must contain terms" if terms <= 0
|
|
145
|
+
|
|
146
|
+
document_words = @corpus.documents.map { |document| document.words.map(&:to_i) }
|
|
147
|
+
document_counts = @corpus.documents.map { |document| document.counts.map(&:to_f) }
|
|
148
|
+
|
|
149
|
+
{
|
|
150
|
+
topics: topics,
|
|
151
|
+
terms: terms,
|
|
152
|
+
document_words: document_words,
|
|
153
|
+
document_counts: document_counts,
|
|
154
|
+
document_totals: document_counts.map { |counts| counts.sum.to_f },
|
|
155
|
+
document_lengths: document_words.map(&:length),
|
|
156
|
+
initial_beta_probabilities: rust_initial_beta_probabilities(
|
|
157
|
+
start,
|
|
158
|
+
document_words,
|
|
159
|
+
document_counts,
|
|
160
|
+
topics,
|
|
161
|
+
terms
|
|
162
|
+
),
|
|
163
|
+
min_probability: MIN_PROBABILITY
|
|
164
|
+
}
|
|
165
|
+
end
|
|
166
|
+
|
|
167
|
+
def run_em_iterations(em_input)
|
|
168
|
+
topics = em_input.fetch(:topics)
|
|
169
|
+
terms = em_input.fetch(:terms)
|
|
170
|
+
document_words = em_input.fetch(:document_words)
|
|
171
|
+
document_counts = em_input.fetch(:document_counts)
|
|
172
|
+
document_totals = em_input.fetch(:document_totals)
|
|
173
|
+
document_lengths = em_input.fetch(:document_lengths)
|
|
174
|
+
|
|
175
|
+
@beta_probabilities = em_input.fetch(:initial_beta_probabilities)
|
|
176
|
+
previous_gamma = nil
|
|
177
|
+
|
|
178
|
+
Integer(em_max_iter).times do
|
|
179
|
+
if @trusted_kernel_outputs && @corpus_iteration_kernel
|
|
180
|
+
current_gamma, current_phi, topic_term_counts = infer_corpus_iteration(
|
|
181
|
+
nil,
|
|
182
|
+
document_words,
|
|
183
|
+
document_counts,
|
|
184
|
+
document_totals,
|
|
185
|
+
document_lengths,
|
|
186
|
+
topics,
|
|
187
|
+
terms
|
|
188
|
+
)
|
|
189
|
+
else
|
|
190
|
+
topic_term_counts = Array.new(topics) { Array.new(terms, MIN_PROBABILITY) }
|
|
191
|
+
current_gamma, current_phi, topic_term_counts = infer_corpus_iteration(
|
|
192
|
+
topic_term_counts,
|
|
193
|
+
document_words,
|
|
194
|
+
document_counts,
|
|
195
|
+
document_totals,
|
|
196
|
+
document_lengths,
|
|
197
|
+
topics,
|
|
198
|
+
terms
|
|
199
|
+
)
|
|
200
|
+
end
|
|
201
|
+
|
|
202
|
+
@beta_probabilities, @beta_log = finalize_topic_term_counts(topic_term_counts)
|
|
203
|
+
@gamma = current_gamma
|
|
204
|
+
@phi = current_phi
|
|
205
|
+
|
|
206
|
+
break if previous_gamma && average_gamma_shift(previous_gamma, current_gamma) <= Float(em_convergence)
|
|
207
|
+
|
|
208
|
+
previous_gamma = current_gamma
|
|
209
|
+
end
|
|
210
|
+
end
|
|
211
|
+
|
|
150
212
|
def max_term_index
|
|
151
213
|
return -1 if @corpus.nil? || @corpus.documents.empty?
|
|
152
214
|
|
|
@@ -4,6 +4,7 @@ module Lda
|
|
|
4
4
|
module Backends
|
|
5
5
|
class Rust < Base
|
|
6
6
|
SETTINGS = %i[max_iter convergence em_max_iter em_convergence num_topics init_alpha est_alpha verbose].freeze
|
|
7
|
+
MIN_PROBABILITY = 1e-12
|
|
7
8
|
|
|
8
9
|
def self.available?
|
|
9
10
|
return false unless defined?(::Lda::RUST_EXTENSION_LOADED) && ::Lda::RUST_EXTENSION_LOADED
|
|
@@ -32,6 +33,12 @@ module Lda
|
|
|
32
33
|
super(random_seed: random_seed)
|
|
33
34
|
raise LoadError, "Rust backend is unavailable for this environment" unless self.class.available?
|
|
34
35
|
|
|
36
|
+
@rust_corpus_session_id = nil
|
|
37
|
+
@rust_corpus_terms = nil
|
|
38
|
+
@rust_document_lengths = nil
|
|
39
|
+
@rust_document_words = nil
|
|
40
|
+
@rust_document_counts = nil
|
|
41
|
+
|
|
35
42
|
@fallback = PureRuby.new(random_seed: random_seed)
|
|
36
43
|
@fallback.topic_weights_kernel = method(:rust_topic_weights_for_word)
|
|
37
44
|
@fallback.topic_term_accumulator_kernel = method(:rust_accumulate_topic_term_counts)
|
|
@@ -49,20 +56,22 @@ module Lda
|
|
|
49
56
|
end
|
|
50
57
|
|
|
51
58
|
def corpus=(corpus)
|
|
59
|
+
previous_session_id = @rust_corpus_session_id
|
|
52
60
|
@corpus = corpus
|
|
53
61
|
@fallback.corpus = corpus
|
|
62
|
+
register_rust_corpus_session(previous_session_id)
|
|
54
63
|
true
|
|
55
64
|
end
|
|
56
65
|
|
|
57
66
|
def fast_load_corpus_from_file(filename)
|
|
58
67
|
loaded = @fallback.fast_load_corpus_from_file(filename)
|
|
59
|
-
|
|
68
|
+
self.corpus = @fallback.corpus
|
|
60
69
|
loaded
|
|
61
70
|
end
|
|
62
71
|
|
|
63
72
|
def load_settings(settings_file)
|
|
64
73
|
loaded = @fallback.load_settings(settings_file)
|
|
65
|
-
|
|
74
|
+
self.corpus = @fallback.corpus
|
|
66
75
|
loaded
|
|
67
76
|
end
|
|
68
77
|
|
|
@@ -71,7 +80,10 @@ module Lda
|
|
|
71
80
|
end
|
|
72
81
|
|
|
73
82
|
def em(start)
|
|
74
|
-
|
|
83
|
+
start_mode = start.to_s
|
|
84
|
+
rust_before_em(start_mode)
|
|
85
|
+
return nil if rust_orchestrated_em(start_mode)
|
|
86
|
+
|
|
75
87
|
@fallback.em(start)
|
|
76
88
|
end
|
|
77
89
|
|
|
@@ -93,6 +105,375 @@ module Lda
|
|
|
93
105
|
|
|
94
106
|
private
|
|
95
107
|
|
|
108
|
+
def rust_orchestrated_em(start)
|
|
109
|
+
managed_orchestrated = rust_orchestrated_em_with_managed_corpus(start)
|
|
110
|
+
return true if managed_orchestrated
|
|
111
|
+
|
|
112
|
+
direct_orchestrated = rust_orchestrated_em_with_start_seed(start)
|
|
113
|
+
return true if direct_orchestrated
|
|
114
|
+
|
|
115
|
+
rust_orchestrated_em_with_beta(start)
|
|
116
|
+
end
|
|
117
|
+
|
|
118
|
+
def rust_orchestrated_em_with_managed_corpus(start)
|
|
119
|
+
return false unless defined?(::Lda::RustBackend)
|
|
120
|
+
return false unless ensure_rust_corpus_snapshot
|
|
121
|
+
|
|
122
|
+
random_seed = Integer(next_random_seed)
|
|
123
|
+
if ::Lda::RustBackend.respond_to?(:run_em_on_session_with_corpus)
|
|
124
|
+
managed_output = ::Lda::RustBackend.run_em_on_session_with_corpus(
|
|
125
|
+
Integer(@rust_corpus_session_id || 0),
|
|
126
|
+
@rust_document_words,
|
|
127
|
+
@rust_document_counts,
|
|
128
|
+
Integer(@rust_corpus_terms),
|
|
129
|
+
start.to_s,
|
|
130
|
+
*current_rust_session_config_signature,
|
|
131
|
+
random_seed
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
return false unless managed_output.is_a?(Array) && managed_output.size == 5
|
|
135
|
+
|
|
136
|
+
session_id, beta_probabilities, beta_log, gamma, phi = managed_output
|
|
137
|
+
output = [beta_probabilities, beta_log, gamma, phi]
|
|
138
|
+
return false unless valid_rust_em_output?(output, @rust_document_lengths, Integer(num_topics), Integer(@rust_corpus_terms))
|
|
139
|
+
|
|
140
|
+
@rust_corpus_session_id =
|
|
141
|
+
if session_id.is_a?(Numeric) && session_id.positive?
|
|
142
|
+
Integer(session_id)
|
|
143
|
+
end
|
|
144
|
+
@fallback.apply_em_state(
|
|
145
|
+
beta_probabilities: beta_probabilities,
|
|
146
|
+
beta_log: beta_log,
|
|
147
|
+
gamma: gamma,
|
|
148
|
+
phi: phi
|
|
149
|
+
)
|
|
150
|
+
return true
|
|
151
|
+
end
|
|
152
|
+
|
|
153
|
+
return false unless ::Lda::RustBackend.respond_to?(:run_em_on_session)
|
|
154
|
+
return false unless @rust_corpus_session_id
|
|
155
|
+
|
|
156
|
+
output = ::Lda::RustBackend.run_em_on_session(
|
|
157
|
+
Integer(@rust_corpus_session_id),
|
|
158
|
+
start.to_s,
|
|
159
|
+
*current_rust_session_config_signature,
|
|
160
|
+
random_seed
|
|
161
|
+
)
|
|
162
|
+
|
|
163
|
+
return false unless valid_rust_em_output?(output, @rust_document_lengths, Integer(num_topics), Integer(@rust_corpus_terms))
|
|
164
|
+
|
|
165
|
+
beta_probabilities, beta_log, gamma, phi = output
|
|
166
|
+
@fallback.apply_em_state(
|
|
167
|
+
beta_probabilities: beta_probabilities,
|
|
168
|
+
beta_log: beta_log,
|
|
169
|
+
gamma: gamma,
|
|
170
|
+
phi: phi
|
|
171
|
+
)
|
|
172
|
+
true
|
|
173
|
+
rescue StandardError
|
|
174
|
+
false
|
|
175
|
+
end
|
|
176
|
+
|
|
177
|
+
def rust_orchestrated_em_with_start_seed(start)
|
|
178
|
+
return false unless defined?(::Lda::RustBackend)
|
|
179
|
+
return false unless ::Lda::RustBackend.respond_to?(:run_em_with_start_seed)
|
|
180
|
+
return false unless ensure_rust_corpus_snapshot
|
|
181
|
+
|
|
182
|
+
topics = Integer(num_topics)
|
|
183
|
+
return false unless topics.positive?
|
|
184
|
+
|
|
185
|
+
random_seed = Integer(next_random_seed)
|
|
186
|
+
output = ::Lda::RustBackend.run_em_with_start_seed(
|
|
187
|
+
start.to_s,
|
|
188
|
+
@rust_document_words,
|
|
189
|
+
@rust_document_counts,
|
|
190
|
+
topics,
|
|
191
|
+
Integer(@rust_corpus_terms),
|
|
192
|
+
Integer(max_iter),
|
|
193
|
+
Float(convergence),
|
|
194
|
+
Integer(em_max_iter),
|
|
195
|
+
Float(em_convergence),
|
|
196
|
+
Float(init_alpha),
|
|
197
|
+
MIN_PROBABILITY,
|
|
198
|
+
random_seed
|
|
199
|
+
)
|
|
200
|
+
|
|
201
|
+
return false unless valid_rust_em_output?(
|
|
202
|
+
output,
|
|
203
|
+
@rust_document_lengths,
|
|
204
|
+
topics,
|
|
205
|
+
Integer(@rust_corpus_terms)
|
|
206
|
+
)
|
|
207
|
+
|
|
208
|
+
beta_probabilities, beta_log, gamma, phi = output
|
|
209
|
+
@fallback.apply_em_state(
|
|
210
|
+
beta_probabilities: beta_probabilities,
|
|
211
|
+
beta_log: beta_log,
|
|
212
|
+
gamma: gamma,
|
|
213
|
+
phi: phi
|
|
214
|
+
)
|
|
215
|
+
true
|
|
216
|
+
rescue StandardError
|
|
217
|
+
false
|
|
218
|
+
end
|
|
219
|
+
|
|
220
|
+
def rust_orchestrated_em_with_beta(start)
|
|
221
|
+
return false unless defined?(::Lda::RustBackend)
|
|
222
|
+
return false unless ::Lda::RustBackend.respond_to?(:run_em)
|
|
223
|
+
|
|
224
|
+
em_input =
|
|
225
|
+
if ensure_rust_corpus_snapshot && @fallback.respond_to?(:rust_initial_beta_probabilities)
|
|
226
|
+
topics = Integer(num_topics)
|
|
227
|
+
terms = Integer(@rust_corpus_terms)
|
|
228
|
+
initial_beta_probabilities = @fallback.rust_initial_beta_probabilities(
|
|
229
|
+
start,
|
|
230
|
+
@rust_document_words,
|
|
231
|
+
@rust_document_counts,
|
|
232
|
+
topics,
|
|
233
|
+
terms
|
|
234
|
+
)
|
|
235
|
+
|
|
236
|
+
{
|
|
237
|
+
topics: topics,
|
|
238
|
+
terms: terms,
|
|
239
|
+
document_words: @rust_document_words,
|
|
240
|
+
document_counts: @rust_document_counts,
|
|
241
|
+
document_totals: @rust_document_counts.map { |counts| counts.sum.to_f },
|
|
242
|
+
document_lengths: @rust_document_lengths,
|
|
243
|
+
initial_beta_probabilities: initial_beta_probabilities,
|
|
244
|
+
min_probability: MIN_PROBABILITY
|
|
245
|
+
}
|
|
246
|
+
else
|
|
247
|
+
@fallback.rust_em_input(start)
|
|
248
|
+
end
|
|
249
|
+
|
|
250
|
+
return true if em_input.nil?
|
|
251
|
+
|
|
252
|
+
output = ::Lda::RustBackend.run_em(
|
|
253
|
+
em_input.fetch(:initial_beta_probabilities),
|
|
254
|
+
em_input.fetch(:document_words),
|
|
255
|
+
em_input.fetch(:document_counts),
|
|
256
|
+
Integer(max_iter),
|
|
257
|
+
Float(convergence),
|
|
258
|
+
Integer(em_max_iter),
|
|
259
|
+
Float(em_convergence),
|
|
260
|
+
Float(init_alpha),
|
|
261
|
+
Float(em_input.fetch(:min_probability))
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
unless valid_rust_em_output?(
|
|
265
|
+
output,
|
|
266
|
+
em_input.fetch(:document_lengths),
|
|
267
|
+
em_input.fetch(:topics),
|
|
268
|
+
em_input.fetch(:terms)
|
|
269
|
+
)
|
|
270
|
+
@fallback.em_from_input(em_input)
|
|
271
|
+
return true
|
|
272
|
+
end
|
|
273
|
+
|
|
274
|
+
beta_probabilities, beta_log, gamma, phi = output
|
|
275
|
+
@fallback.apply_em_state(
|
|
276
|
+
beta_probabilities: beta_probabilities,
|
|
277
|
+
beta_log: beta_log,
|
|
278
|
+
gamma: gamma,
|
|
279
|
+
phi: phi
|
|
280
|
+
)
|
|
281
|
+
true
|
|
282
|
+
rescue StandardError
|
|
283
|
+
if defined?(em_input) && em_input
|
|
284
|
+
@fallback.em_from_input(em_input)
|
|
285
|
+
return true
|
|
286
|
+
end
|
|
287
|
+
|
|
288
|
+
false
|
|
289
|
+
end
|
|
290
|
+
|
|
291
|
+
def rust_em_corpus_input
|
|
292
|
+
return nil if @corpus.nil? || @corpus.num_docs.zero?
|
|
293
|
+
|
|
294
|
+
topics = Integer(num_topics)
|
|
295
|
+
raise ArgumentError, "num_topics must be greater than zero" if topics <= 0
|
|
296
|
+
|
|
297
|
+
terms = max_term_index + 1
|
|
298
|
+
raise ArgumentError, "corpus must contain terms" if terms <= 0
|
|
299
|
+
|
|
300
|
+
document_words = @corpus.documents.map { |document| document.words.map(&:to_i) }
|
|
301
|
+
document_counts = @corpus.documents.map { |document| document.counts.map(&:to_f) }
|
|
302
|
+
|
|
303
|
+
{
|
|
304
|
+
topics: topics,
|
|
305
|
+
terms: terms,
|
|
306
|
+
document_words: document_words,
|
|
307
|
+
document_counts: document_counts,
|
|
308
|
+
document_lengths: document_words.map(&:length),
|
|
309
|
+
min_probability: MIN_PROBABILITY
|
|
310
|
+
}
|
|
311
|
+
end
|
|
312
|
+
|
|
313
|
+
def max_term_index
|
|
314
|
+
return -1 if @corpus.nil? || @corpus.documents.empty?
|
|
315
|
+
|
|
316
|
+
@corpus.documents
|
|
317
|
+
.flat_map(&:words)
|
|
318
|
+
.max || -1
|
|
319
|
+
end
|
|
320
|
+
|
|
321
|
+
def register_rust_corpus_session(previous_session_id = nil)
|
|
322
|
+
@rust_corpus_session_id = nil
|
|
323
|
+
@rust_corpus_terms = nil
|
|
324
|
+
@rust_document_lengths = nil
|
|
325
|
+
@rust_document_words = nil
|
|
326
|
+
@rust_document_counts = nil
|
|
327
|
+
|
|
328
|
+
if @corpus.nil?
|
|
329
|
+
drop_rust_corpus_session_by_id(previous_session_id)
|
|
330
|
+
return
|
|
331
|
+
end
|
|
332
|
+
|
|
333
|
+
return unless defined?(::Lda::RustBackend)
|
|
334
|
+
|
|
335
|
+
em_input = rust_em_corpus_input
|
|
336
|
+
if em_input.nil?
|
|
337
|
+
drop_rust_corpus_session_by_id(previous_session_id)
|
|
338
|
+
return
|
|
339
|
+
end
|
|
340
|
+
|
|
341
|
+
@rust_corpus_terms = Integer(em_input.fetch(:terms))
|
|
342
|
+
@rust_document_lengths = em_input.fetch(:document_lengths)
|
|
343
|
+
@rust_document_words = em_input.fetch(:document_words)
|
|
344
|
+
@rust_document_counts = em_input.fetch(:document_counts)
|
|
345
|
+
|
|
346
|
+
session_id =
|
|
347
|
+
if ::Lda::RustBackend.respond_to?(:replace_corpus_session)
|
|
348
|
+
::Lda::RustBackend.replace_corpus_session(
|
|
349
|
+
Integer(previous_session_id || 0),
|
|
350
|
+
@rust_document_words,
|
|
351
|
+
@rust_document_counts,
|
|
352
|
+
Integer(@rust_corpus_terms)
|
|
353
|
+
)
|
|
354
|
+
elsif ::Lda::RustBackend.respond_to?(:create_corpus_session)
|
|
355
|
+
drop_rust_corpus_session_by_id(previous_session_id)
|
|
356
|
+
::Lda::RustBackend.create_corpus_session(
|
|
357
|
+
@rust_document_words,
|
|
358
|
+
@rust_document_counts,
|
|
359
|
+
Integer(@rust_corpus_terms)
|
|
360
|
+
)
|
|
361
|
+
end
|
|
362
|
+
|
|
363
|
+
unless session_id.is_a?(Numeric) && session_id.positive?
|
|
364
|
+
drop_rust_corpus_session_by_id(previous_session_id)
|
|
365
|
+
return
|
|
366
|
+
end
|
|
367
|
+
|
|
368
|
+
@rust_corpus_session_id = Integer(session_id)
|
|
369
|
+
rescue StandardError
|
|
370
|
+
@rust_corpus_session_id = nil
|
|
371
|
+
@rust_corpus_terms = nil
|
|
372
|
+
@rust_document_lengths = nil
|
|
373
|
+
@rust_document_words = nil
|
|
374
|
+
@rust_document_counts = nil
|
|
375
|
+
drop_rust_corpus_session_by_id(previous_session_id)
|
|
376
|
+
end
|
|
377
|
+
|
|
378
|
+
def ensure_rust_corpus_snapshot
|
|
379
|
+
has_session_data = @rust_corpus_terms && @rust_document_lengths && @rust_document_words && @rust_document_counts
|
|
380
|
+
return true if has_session_data
|
|
381
|
+
|
|
382
|
+
register_rust_corpus_session(@rust_corpus_session_id)
|
|
383
|
+
@rust_corpus_terms && @rust_document_lengths && @rust_document_words && @rust_document_counts
|
|
384
|
+
rescue StandardError
|
|
385
|
+
false
|
|
386
|
+
end
|
|
387
|
+
|
|
388
|
+
def release_rust_corpus_session
|
|
389
|
+
session_id = @rust_corpus_session_id
|
|
390
|
+
|
|
391
|
+
@rust_corpus_session_id = nil
|
|
392
|
+
@rust_corpus_terms = nil
|
|
393
|
+
@rust_document_lengths = nil
|
|
394
|
+
@rust_document_words = nil
|
|
395
|
+
@rust_document_counts = nil
|
|
396
|
+
|
|
397
|
+
drop_rust_corpus_session_by_id(session_id)
|
|
398
|
+
rescue StandardError
|
|
399
|
+
nil
|
|
400
|
+
end
|
|
401
|
+
|
|
402
|
+
def drop_rust_corpus_session_by_id(session_id)
|
|
403
|
+
return unless session_id
|
|
404
|
+
return unless defined?(::Lda::RustBackend)
|
|
405
|
+
return unless ::Lda::RustBackend.respond_to?(:drop_corpus_session)
|
|
406
|
+
|
|
407
|
+
::Lda::RustBackend.drop_corpus_session(Integer(session_id))
|
|
408
|
+
rescue StandardError
|
|
409
|
+
nil
|
|
410
|
+
end
|
|
411
|
+
|
|
412
|
+
def current_rust_session_config_signature
|
|
413
|
+
[
|
|
414
|
+
Integer(num_topics),
|
|
415
|
+
Integer(max_iter),
|
|
416
|
+
Float(convergence),
|
|
417
|
+
Integer(em_max_iter),
|
|
418
|
+
Float(em_convergence),
|
|
419
|
+
Float(init_alpha),
|
|
420
|
+
MIN_PROBABILITY
|
|
421
|
+
]
|
|
422
|
+
end
|
|
423
|
+
|
|
424
|
+
def valid_rust_em_output?(output, document_lengths, topics, terms)
|
|
425
|
+
return false unless output.is_a?(Array)
|
|
426
|
+
return false unless output.size == 4
|
|
427
|
+
|
|
428
|
+
beta_probabilities, beta_log, gamma, phi = output
|
|
429
|
+
|
|
430
|
+
valid_topic_term_matrix?(beta_probabilities, topics, terms) &&
|
|
431
|
+
valid_topic_term_matrix?(beta_log, topics, terms) &&
|
|
432
|
+
valid_gamma_matrix?(gamma, document_lengths.size, topics) &&
|
|
433
|
+
valid_phi_tensor?(phi, document_lengths, topics)
|
|
434
|
+
end
|
|
435
|
+
|
|
436
|
+
def valid_topic_term_matrix?(matrix, topics, terms)
|
|
437
|
+
return false unless matrix.is_a?(Array)
|
|
438
|
+
return false unless matrix.size == topics
|
|
439
|
+
|
|
440
|
+
matrix.all? do |row|
|
|
441
|
+
row.is_a?(Array) &&
|
|
442
|
+
row.size == terms &&
|
|
443
|
+
row.all? { |value| finite_numeric?(value) }
|
|
444
|
+
end
|
|
445
|
+
end
|
|
446
|
+
|
|
447
|
+
def valid_gamma_matrix?(gamma, expected_docs, topics)
|
|
448
|
+
return false unless gamma.is_a?(Array)
|
|
449
|
+
return false unless gamma.size == expected_docs
|
|
450
|
+
|
|
451
|
+
gamma.all? do |row|
|
|
452
|
+
row.is_a?(Array) &&
|
|
453
|
+
row.size == topics &&
|
|
454
|
+
row.all? { |value| finite_numeric?(value) && value.positive? }
|
|
455
|
+
end
|
|
456
|
+
end
|
|
457
|
+
|
|
458
|
+
def valid_phi_tensor?(phi, document_lengths, topics)
|
|
459
|
+
return false unless phi.is_a?(Array)
|
|
460
|
+
return false unless phi.size == document_lengths.size
|
|
461
|
+
|
|
462
|
+
phi.each_with_index.all? do |doc_phi, doc_index|
|
|
463
|
+
doc_phi.is_a?(Array) &&
|
|
464
|
+
doc_phi.size == document_lengths[doc_index] &&
|
|
465
|
+
doc_phi.all? do |row|
|
|
466
|
+
row.is_a?(Array) &&
|
|
467
|
+
row.size == topics &&
|
|
468
|
+
row.all? { |value| finite_numeric?(value) }
|
|
469
|
+
end
|
|
470
|
+
end
|
|
471
|
+
end
|
|
472
|
+
|
|
473
|
+
def finite_numeric?(value)
|
|
474
|
+
value.is_a?(Numeric) && value.finite?
|
|
475
|
+
end
|
|
476
|
+
|
|
96
477
|
def rust_before_em(start)
|
|
97
478
|
return unless defined?(::Lda::RustBackend)
|
|
98
479
|
return unless ::Lda::RustBackend.respond_to?(:before_em)
|
data/lib/lda-ruby/version.rb
CHANGED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require_relative "test_helper"
|
|
4
|
+
require "open3"
|
|
5
|
+
|
|
6
|
+
class BenchmarkScriptsTest < Test::Unit::TestCase
|
|
7
|
+
def setup
|
|
8
|
+
@repo_root = File.expand_path("..", __dir__)
|
|
9
|
+
@check_rust_benchmark = File.join(@repo_root, "bin", "check-rust-benchmark")
|
|
10
|
+
end
|
|
11
|
+
|
|
12
|
+
def test_check_rust_benchmark_help
|
|
13
|
+
stdout, stderr, status = Open3.capture3(@check_rust_benchmark, "--help", chdir: @repo_root)
|
|
14
|
+
assert(status.success?, stderr)
|
|
15
|
+
assert_match(/Usage: \.\/bin\/check-rust-benchmark/, stdout)
|
|
16
|
+
end
|
|
17
|
+
|
|
18
|
+
def test_check_rust_benchmark_rejects_unknown_argument
|
|
19
|
+
_stdout, stderr, status = Open3.capture3(@check_rust_benchmark, "--unknown", chdir: @repo_root)
|
|
20
|
+
assert(!status.success?, "expected check-rust-benchmark to fail for unknown args")
|
|
21
|
+
assert_match(/Unknown argument/, stderr)
|
|
22
|
+
end
|
|
23
|
+
end
|