lda-ruby 0.4.0-x86_64-linux
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +61 -0
- data/Gemfile +9 -0
- data/README.md +157 -0
- data/VERSION.yml +5 -0
- data/docs/modernization-handoff.md +190 -0
- data/docs/porting-strategy.md +127 -0
- data/docs/precompiled-platform-policy.md +68 -0
- data/docs/release-runbook.md +157 -0
- data/ext/lda-ruby/cokus.c +145 -0
- data/ext/lda-ruby/cokus.h +27 -0
- data/ext/lda-ruby/extconf.rb +13 -0
- data/ext/lda-ruby/lda-alpha.c +96 -0
- data/ext/lda-ruby/lda-alpha.h +21 -0
- data/ext/lda-ruby/lda-data.c +67 -0
- data/ext/lda-ruby/lda-data.h +14 -0
- data/ext/lda-ruby/lda-inference.c +1023 -0
- data/ext/lda-ruby/lda-inference.h +63 -0
- data/ext/lda-ruby/lda-model.c +345 -0
- data/ext/lda-ruby/lda-model.h +31 -0
- data/ext/lda-ruby/lda.h +54 -0
- data/ext/lda-ruby/utils.c +111 -0
- data/ext/lda-ruby/utils.h +18 -0
- data/ext/lda-ruby-rust/Cargo.toml +12 -0
- data/ext/lda-ruby-rust/README.md +48 -0
- data/ext/lda-ruby-rust/extconf.rb +123 -0
- data/ext/lda-ruby-rust/src/lib.rs +456 -0
- data/lda-ruby.gemspec +78 -0
- data/lib/lda-ruby/backends/base.rb +129 -0
- data/lib/lda-ruby/backends/native.rb +158 -0
- data/lib/lda-ruby/backends/pure_ruby.rb +613 -0
- data/lib/lda-ruby/backends/rust.rb +226 -0
- data/lib/lda-ruby/backends.rb +58 -0
- data/lib/lda-ruby/config/stopwords.yml +571 -0
- data/lib/lda-ruby/corpus/corpus.rb +45 -0
- data/lib/lda-ruby/corpus/data_corpus.rb +22 -0
- data/lib/lda-ruby/corpus/directory_corpus.rb +25 -0
- data/lib/lda-ruby/corpus/text_corpus.rb +27 -0
- data/lib/lda-ruby/document/data_document.rb +30 -0
- data/lib/lda-ruby/document/document.rb +40 -0
- data/lib/lda-ruby/document/text_document.rb +39 -0
- data/lib/lda-ruby/lda.so +0 -0
- data/lib/lda-ruby/rust_build_policy.rb +21 -0
- data/lib/lda-ruby/version.rb +5 -0
- data/lib/lda-ruby/vocabulary.rb +46 -0
- data/lib/lda-ruby.rb +413 -0
- data/lib/lda_ruby_rust.so +0 -0
- data/license.txt +504 -0
- data/test/backend_compatibility_test.rb +146 -0
- data/test/backends_selection_test.rb +100 -0
- data/test/data/docs.dat +46 -0
- data/test/data/sample.rb +20 -0
- data/test/data/wiki-test-docs.yml +123 -0
- data/test/gemspec_test.rb +27 -0
- data/test/lda_ruby_test.rb +319 -0
- data/test/packaged_gem_smoke_test.rb +33 -0
- data/test/release_scripts_test.rb +54 -0
- data/test/rust_build_policy_test.rb +23 -0
- data/test/simple_pipeline_test.rb +22 -0
- data/test/simple_yaml.rb +17 -0
- data/test/test_helper.rb +10 -0
- metadata +111 -0
|
@@ -0,0 +1,613 @@
|
|
|
1
|
+
# frozen_string_literal: true

module Lda
  module Backends
    # Pure-Ruby LDA backend: variational EM with no native extensions.
    #
    # Every expensive step (per-word topic weighting, document inference,
    # corpus iteration, count accumulation, finalization, convergence
    # measurement, seeding) can be delegated to an optional pluggable
    # "kernel" callable. Each kernel's output is shape-validated and, on
    # validation failure or any StandardError, the backend falls back to
    # the reference Ruby implementation — a misbehaving kernel can never
    # abort or corrupt a run. When +trusted_kernel_outputs+ is true,
    # validated kernel results are used as-is, skipping the defensive
    # copy/normalization pass.
    #
    # NOTE(review): relies on helpers from Base (num_topics, em_max_iter,
    # em_convergence, max_iter, convergence, init_alpha, @random,
    # normalize!, clone_matrix) — presumably normalize! scales an array
    # in place to sum to 1 and returns it; confirm against Base.
    class PureRuby < Base
      # Probability floor keeping values strictly positive (avoids log(0)).
      MIN_PROBABILITY = 1e-12

      # @param random_seed [Integer, nil] forwarded to Base to seed @random
      def initialize(random_seed: nil)
        super(random_seed: random_seed)
        @beta_probabilities = nil # topics x terms probability matrix
        @beta_log = nil           # element-wise log of @beta_probabilities
        @gamma = nil              # per-document variational Dirichlet params
        @phi = nil                # last document's/corpus per-word topic dists
        # Optional kernels; nil means "use the Ruby fallback implementation".
        @topic_weights_kernel = nil
        @topic_term_accumulator_kernel = nil
        @document_inference_kernel = nil
        @corpus_iteration_kernel = nil
        @topic_term_finalizer_kernel = nil
        @gamma_shift_kernel = nil
        @topic_document_probability_kernel = nil
        @topic_term_seed_kernel = nil
        @trusted_kernel_outputs = false
      end

      attr_writer :topic_weights_kernel,
                  :topic_term_accumulator_kernel,
                  :document_inference_kernel,
                  :corpus_iteration_kernel,
                  :topic_term_finalizer_kernel,
                  :gamma_shift_kernel,
                  :topic_document_probability_kernel,
                  :topic_term_seed_kernel,
                  :trusted_kernel_outputs

      # Backend identifier used for backend selection/reporting.
      def name
        "pure_ruby"
      end

      # Assigning a new corpus invalidates all cached model state.
      def corpus=(corpus)
        super
        @beta_probabilities = nil
        @beta_log = nil
        @gamma = nil
        @phi = nil
        true
      end

      # Run variational EM over the corpus.
      #
      # @param start [String, Symbol] "seeded"/"deterministic" selects the
      #   count-seeded initialization; anything else uses random init.
      # @return [nil] results are exposed via #beta, #gamma, #compute_phi
      # @raise [ArgumentError] if num_topics or the term count is not positive
      def em(start)
        return nil if @corpus.nil? || @corpus.num_docs.zero?

        topics = Integer(num_topics)
        raise ArgumentError, "num_topics must be greater than zero" if topics <= 0

        terms = max_term_index + 1
        raise ArgumentError, "corpus must contain terms" if terms <= 0

        document_words = @corpus.documents.map { |document| document.words.map(&:to_i) }
        document_counts = @corpus.documents.map { |document| document.counts.map(&:to_f) }

        @beta_probabilities =
          if start.to_s.strip.casecmp("seeded").zero? || start.to_s.strip.casecmp("deterministic").zero?
            seeded_topic_term_probabilities(topics, terms, document_words, document_counts)
          else
            initial_topic_term_probabilities(topics, terms)
          end

        document_totals = document_counts.map { |counts| counts.sum.to_f }
        document_lengths = document_words.map(&:length)

        previous_gamma = nil

        Integer(em_max_iter).times do
          if @trusted_kernel_outputs && @corpus_iteration_kernel
            # Trusted kernel path: let infer_corpus_iteration build its own
            # fallback counts only if the kernel output is rejected.
            current_gamma, current_phi, topic_term_counts = infer_corpus_iteration(
              nil,
              document_words,
              document_counts,
              document_totals,
              document_lengths,
              topics,
              terms
            )
          else
            topic_term_counts = Array.new(topics) { Array.new(terms, MIN_PROBABILITY) }
            current_gamma, current_phi, topic_term_counts = infer_corpus_iteration(
              topic_term_counts,
              document_words,
              document_counts,
              document_totals,
              document_lengths,
              topics,
              terms
            )
          end

          @beta_probabilities, @beta_log = finalize_topic_term_counts(topic_term_counts)
          @gamma = current_gamma
          @phi = current_phi

          # Converged when the mean absolute gamma change drops below the
          # EM threshold; never breaks on the first iteration.
          break if previous_gamma && average_gamma_shift(previous_gamma, current_gamma) <= Float(em_convergence)

          previous_gamma = current_gamma
        end

        nil
      end

      # Log topic-term matrix (empty until #em has run).
      def beta
        @beta_log || []
      end

      # Per-document variational Dirichlet parameters (empty until #em).
      def gamma
        @gamma || []
      end

      # Deep copy of the per-word topic distributions from the last EM pass.
      def compute_phi
        clone_matrix(@phi || [])
      end

      # Model summary: [topic count, vocabulary size, alpha].
      def model
        [Integer(num_topics), max_term_index + 1, Float(init_alpha)]
      end

      # Per-document topic probabilities (count-weighted mean log phi).
      # Uses the pluggable kernel when its output validates; otherwise the
      # reference implementation. Any kernel error falls back silently.
      def topic_document_probability(phi_matrix, document_counts)
        kernel_output = nil
        if @topic_document_probability_kernel
          kernel_output = @topic_document_probability_kernel.call(
            phi_matrix,
            document_counts,
            Integer(num_topics),
            MIN_PROBABILITY
          )
        end

        if valid_topic_document_probability_output?(kernel_output, document_counts.size, Integer(num_topics))
          if @trusted_kernel_outputs
            kernel_output
          else
            kernel_output.map { |row| row.map(&:to_f) }
          end
        else
          default_topic_document_probability(phi_matrix, document_counts)
        end
      rescue StandardError
        default_topic_document_probability(phi_matrix, document_counts)
      end

      private

      # Largest word index appearing anywhere in the corpus, or -1 if empty.
      def max_term_index
        return -1 if @corpus.nil? || @corpus.documents.empty?

        @corpus.documents
               .flat_map(&:words)
               .max || -1
      end

      # Random initialization: each topic gets a normalized random row.
      def initial_topic_term_probabilities(topics, terms)
        Array.new(topics) do
          weights = Array.new(terms) { @random.rand + MIN_PROBABILITY }
          normalize!(weights)
        end
      end

      # Deterministic seeded initialization, via kernel when available.
      def seeded_topic_term_probabilities(topics, terms, document_words, document_counts)
        kernel_output = nil
        if @topic_term_seed_kernel
          kernel_output = @topic_term_seed_kernel.call(
            document_words,
            document_counts,
            Integer(topics),
            Integer(terms),
            MIN_PROBABILITY
          )
        end

        if valid_seeded_topic_term_probabilities?(kernel_output, topics, terms)
          if @trusted_kernel_outputs
            kernel_output
          else
            kernel_output.map { |weights| normalize!(weights.map(&:to_f)) }
          end
        else
          default_seeded_topic_term_probabilities(topics, terms, document_words, document_counts)
        end
      rescue StandardError
        default_seeded_topic_term_probabilities(topics, terms, document_words, document_counts)
      end

      # Seed kernel output must be topics x terms of finite numerics.
      def valid_seeded_topic_term_probabilities?(matrix, expected_topics, expected_terms)
        return false unless matrix.is_a?(Array)
        return false unless matrix.size == expected_topics

        matrix.each do |row|
          return false unless row.is_a?(Array)
          return false unless row.size == expected_terms
          row.each do |value|
            return false unless value.is_a?(Numeric)
            return false unless value.finite?
          end
        end

        true
      end

      # Reference seeding: document d contributes its counts to topic d % K.
      def default_seeded_topic_term_probabilities(topics, terms, document_words, document_counts)
        topic_term_counts = Array.new(topics) { Array.new(terms, MIN_PROBABILITY) }

        document_words.each_with_index do |words, document_index|
          topic_index = document_index % topics
          counts = document_counts[document_index] || []

          words.each_with_index do |word_index, word_offset|
            next if word_index >= terms # ignore out-of-vocabulary indices

            topic_term_counts[topic_index][word_index] += counts[word_offset].to_f
          end
        end

        topic_term_counts.map { |weights| normalize!(weights) }
      end

      # Normalized topic weights for one word given the document's gamma.
      def topic_weights_for_word(word_index, gamma_d)
        kernel_weights = nil
        if @topic_weights_kernel
          kernel_weights = @topic_weights_kernel.call(@beta_probabilities, gamma_d, Integer(word_index), MIN_PROBABILITY)
        end

        weights =
          if valid_topic_weights?(kernel_weights, gamma_d.length)
            kernel_weights.map(&:to_f)
          else
            default_topic_weights_for_word(word_index, gamma_d)
          end

        normalize!(weights)
      rescue StandardError
        normalize!(default_topic_weights_for_word(word_index, gamma_d))
      end

      # FIX: previously only checked array-ness and size, unlike the other
      # kernel validators — a right-sized array of NaN/Infinity/non-numeric
      # values slipped through and corrupted normalize!/gamma updates
      # instead of triggering the designed fallback. Now requires every
      # element to be a finite Numeric, consistent with
      # valid_seeded_topic_term_probabilities? and
      # valid_topic_document_probability_output?.
      def valid_topic_weights?(weights, expected_size)
        weights.is_a?(Array) &&
          weights.size == expected_size &&
          weights.all? { |value| value.is_a?(Numeric) && value.finite? }
      end

      # Reference weights: beta[k][w] * max(gamma[k], floor) per topic k.
      def default_topic_weights_for_word(word_index, gamma_d)
        topics = gamma_d.length

        Array.new(topics) do |topic_index|
          @beta_probabilities[topic_index][word_index] * [gamma_d[topic_index], MIN_PROBABILITY].max
        end
      end

      # Variational inference for one document -> [gamma, phi].
      # Kernel output must be [gamma(topics), phi(words x topics)].
      def infer_document(gamma_initial, phi_initial, words, counts)
        kernel_output = nil

        if @document_inference_kernel
          kernel_output = @document_inference_kernel.call(
            @beta_probabilities,
            gamma_initial,
            words.map(&:to_i),
            counts.map(&:to_f),
            Integer(max_iter),
            Float(convergence),
            MIN_PROBABILITY,
            Float(init_alpha)
          )
        end

        if valid_document_inference_output?(kernel_output, gamma_initial.length, phi_initial.length)
          if @trusted_kernel_outputs
            [kernel_output[0], kernel_output[1]]
          else
            gamma_out = kernel_output[0].map(&:to_f)
            phi_out = kernel_output[1].map { |row| normalize!(row.map(&:to_f)) }
            [gamma_out, phi_out]
          end
        else
          default_infer_document(gamma_initial, phi_initial, words, counts)
        end
      rescue StandardError
        default_infer_document(gamma_initial, phi_initial, words, counts)
      end

      # Shape check for infer_document kernel output (values not checked).
      def valid_document_inference_output?(output, expected_topics, expected_length)
        return false unless output.is_a?(Array)
        return false unless output.size == 2

        gamma_out = output[0]
        phi_out = output[1]

        return false unless gamma_out.is_a?(Array) && gamma_out.size == expected_topics
        return false unless phi_out.is_a?(Array) && phi_out.size == expected_length

        phi_out.all? { |row| row.is_a?(Array) && row.size == expected_topics }
      end

      # Reference per-document inference: alternate phi and gamma updates
      # until the max-abs gamma change falls below the threshold.
      def default_infer_document(gamma_initial, phi_initial, words, counts)
        topics = gamma_initial.length
        gamma_d = gamma_initial.dup
        phi_d = phi_initial # mutated in place, row by row

        Integer(max_iter).times do
          gamma_next = Array.new(topics, Float(init_alpha))

          words.each_with_index do |word_index, word_offset|
            topic_weights = topic_weights_for_word(word_index, gamma_d)
            phi_d[word_offset] = topic_weights

            count = counts[word_offset].to_f
            topics.times do |topic_index|
              gamma_next[topic_index] += count * topic_weights[topic_index]
            end
          end

          gamma_shift = max_absolute_distance(gamma_d, gamma_next)
          gamma_d = gamma_next
          break if gamma_shift <= Float(convergence)
        end

        [gamma_d, phi_d]
      end

      # One EM E-step over the whole corpus.
      # Returns [gamma matrix, phi tensor, topic-term count matrix].
      def infer_corpus_iteration(
        topic_term_counts_initial,
        document_words,
        document_counts,
        document_totals,
        document_lengths,
        topics,
        terms
      )
        topic_term_counts_fallback =
          topic_term_counts_initial || Array.new(topics) { Array.new(terms, MIN_PROBABILITY) }
        kernel_output = nil

        if @corpus_iteration_kernel
          kernel_output = @corpus_iteration_kernel.call(
            @beta_probabilities,
            document_words,
            document_counts,
            Integer(max_iter),
            Float(convergence),
            MIN_PROBABILITY,
            Float(init_alpha)
          )
        end

        if valid_corpus_iteration_output?(kernel_output, document_words.size, document_lengths, topics, terms)
          if @trusted_kernel_outputs
            [kernel_output[0], kernel_output[1], kernel_output[2]]
          else
            current_gamma = kernel_output[0].map { |row| row.map(&:to_f) }
            current_phi = kernel_output[1].map do |doc_phi|
              doc_phi.map { |row| normalize!(row.map(&:to_f)) }
            end
            topic_term_counts = kernel_output[2].map { |row| row.map(&:to_f) }

            [current_gamma, current_phi, topic_term_counts]
          end
        else
          default_infer_corpus_iteration(
            topic_term_counts_fallback,
            document_words,
            document_counts,
            document_totals,
            topics
          )
        end
      rescue StandardError
        default_infer_corpus_iteration(
          topic_term_counts_fallback,
          document_words,
          document_counts,
          document_totals,
          topics
        )
      end

      # Shape check for the corpus-iteration kernel output triple.
      def valid_corpus_iteration_output?(output, expected_docs, expected_lengths, expected_topics, expected_terms)
        return false unless output.is_a?(Array)
        return false unless output.size == 3

        gamma_matrix = output[0]
        phi_tensor = output[1]
        topic_term_counts = output[2]

        return false unless gamma_matrix.is_a?(Array) && gamma_matrix.size == expected_docs
        return false unless phi_tensor.is_a?(Array) && phi_tensor.size == expected_docs
        return false unless topic_term_counts.is_a?(Array) && topic_term_counts.size == expected_topics

        gamma_matrix.each do |row|
          return false unless row.is_a?(Array) && row.size == expected_topics
        end

        phi_tensor.each_with_index do |doc_phi, index|
          return false unless doc_phi.is_a?(Array) && doc_phi.size == expected_lengths[index]
          doc_phi.each do |row|
            return false unless row.is_a?(Array) && row.size == expected_topics
          end
        end

        topic_term_counts.each do |row|
          return false unless row.is_a?(Array) && row.size == expected_terms
        end

        true
      end

      # Reference E-step: infer each document, accumulate topic-term counts.
      def default_infer_corpus_iteration(
        topic_term_counts_initial,
        document_words,
        document_counts,
        document_totals,
        topics
      )
        topic_term_counts = topic_term_counts_initial
        current_gamma = Array.new(document_words.size) { Array.new(topics, Float(init_alpha)) }
        current_phi = Array.new(document_words.size)

        document_words.each_with_index do |words, document_index|
          counts = document_counts[document_index]
          total = document_totals[document_index].to_f

          # Standard LDA initialization: gamma = alpha + N_d / K, phi uniform.
          gamma_d = Array.new(topics, Float(init_alpha) + (total / topics))
          phi_d = Array.new(words.length) { Array.new(topics, 1.0 / topics) }

          gamma_d, phi_d = infer_document(gamma_d, phi_d, words, counts)

          current_gamma[document_index] = gamma_d
          current_phi[document_index] = phi_d
          topic_term_counts = accumulate_topic_term_counts(topic_term_counts, phi_d, words, counts)
        end

        [current_gamma, current_phi, topic_term_counts]
      end

      # Add one document's count-weighted phi into the topic-term counts.
      def accumulate_topic_term_counts(topic_term_counts, phi_d, words, counts)
        kernel_counts = nil
        if @topic_term_accumulator_kernel
          kernel_counts = @topic_term_accumulator_kernel.call(
            topic_term_counts,
            phi_d,
            words.map(&:to_i),
            counts.map(&:to_f)
          )
        end

        if valid_topic_term_counts?(kernel_counts, topic_term_counts)
          kernel_counts
        else
          default_accumulate_topic_term_counts(topic_term_counts, phi_d, words, counts)
        end
      rescue StandardError
        default_accumulate_topic_term_counts(topic_term_counts, phi_d, words, counts)
      end

      # Candidate counts must mirror the reference matrix's shape exactly.
      def valid_topic_term_counts?(candidate, reference)
        return false unless candidate.is_a?(Array)
        return false unless candidate.size == reference.size

        candidate.each_with_index do |row, index|
          return false unless row.is_a?(Array)
          return false unless row.size == reference[index].size
        end

        true
      end

      # Reference accumulator: counts[k][w] += n(w) * phi[w][k] (in place).
      def default_accumulate_topic_term_counts(topic_term_counts, phi_d, words, counts)
        topics = topic_term_counts.size

        words.each_with_index do |word_index, word_offset|
          count = counts[word_offset].to_f
          next if count.zero?

          topics.times do |topic_index|
            topic_term_counts[topic_index][word_index] += count * phi_d[word_offset][topic_index]
          end
        end

        topic_term_counts
      end

      # M-step: turn raw counts into [beta probabilities, log beta].
      def finalize_topic_term_counts(topic_term_counts)
        kernel_output = nil
        if @topic_term_finalizer_kernel
          kernel_output = @topic_term_finalizer_kernel.call(topic_term_counts, MIN_PROBABILITY)
        end

        if valid_topic_term_finalization_output?(kernel_output, topic_term_counts)
          if @trusted_kernel_outputs
            [kernel_output[0], kernel_output[1]]
          else
            beta_probabilities = kernel_output[0].map { |row| row.map(&:to_f) }
            beta_log = kernel_output[1].map { |row| row.map(&:to_f) }
            [beta_probabilities, beta_log]
          end
        else
          default_finalize_topic_term_counts(topic_term_counts)
        end
      rescue StandardError
        default_finalize_topic_term_counts(topic_term_counts)
      end

      # Shape check for finalizer output: two matrices matching the counts.
      def valid_topic_term_finalization_output?(output, topic_term_counts)
        return false unless output.is_a?(Array)
        return false unless output.size == 2

        beta_probabilities = output[0]
        beta_log = output[1]
        return false unless beta_probabilities.is_a?(Array) && beta_log.is_a?(Array)
        return false unless beta_probabilities.size == topic_term_counts.size
        return false unless beta_log.size == topic_term_counts.size

        beta_probabilities.each_with_index do |row, index|
          return false unless row.is_a?(Array)
          return false unless row.size == topic_term_counts[index].size
        end

        beta_log.each_with_index do |row, index|
          return false unless row.is_a?(Array)
          return false unless row.size == topic_term_counts[index].size
        end

        true
      end

      # Reference finalizer: normalize each topic row, then take floored logs.
      def default_finalize_topic_term_counts(topic_term_counts)
        beta_probabilities = topic_term_counts.map { |weights| normalize!(weights) }
        beta_log = beta_probabilities.map do |topic_weights|
          topic_weights.map { |probability| Math.log([probability, MIN_PROBABILITY].max) }
        end

        [beta_probabilities, beta_log]
      end

      # Kernel output must be docs x topics of finite numerics.
      def valid_topic_document_probability_output?(output, expected_docs, expected_topics)
        return false unless output.is_a?(Array)
        return false unless output.size == expected_docs

        output.each do |row|
          return false unless row.is_a?(Array)
          return false unless row.size == expected_topics
        end

        output.each do |row|
          row.each do |value|
            return false unless value.is_a?(Numeric)
            return false unless value.finite?
          end
        end

        true
      end

      # Reference topic-document probability: count-weighted mean of
      # log(phi), normalized by the document's total word count.
      def default_topic_document_probability(phi_matrix, document_counts)
        topics = Integer(num_topics)
        output = []

        document_counts.each_with_index do |counts, doc_index|
          tops = Array.new(topics, 0.0)
          ttl = counts.inject(0.0) { |sum, value| sum + value.to_f }
          doc_phi = phi_matrix[doc_index] || []

          doc_phi.each_with_index do |word_dist, word_idx|
            count = counts[word_idx].to_f
            next if count.zero?

            topics.times do |topic_idx|
              top_prob = word_dist[topic_idx].to_f
              tops[topic_idx] += Math.log([top_prob, MIN_PROBABILITY].max) * count
            end
          end

          tops = tops.map { |value| value / ttl } if ttl.positive?
          output << tops
        end

        output
      end

      # L-infinity distance between two equal-length vectors.
      def max_absolute_distance(left, right)
        left.zip(right).map { |a, b| (a - b).abs }.max.to_f
      end

      # Mean absolute element-wise gamma change between EM iterations,
      # via the kernel when it returns a sane non-negative finite number.
      def average_gamma_shift(previous_gamma, current_gamma)
        kernel_shift = nil
        if @gamma_shift_kernel
          kernel_shift = @gamma_shift_kernel.call(previous_gamma, current_gamma)
        end

        if kernel_shift.is_a?(Numeric) && kernel_shift.finite? && kernel_shift >= 0.0
          kernel_shift.to_f
        else
          default_average_gamma_shift(previous_gamma, current_gamma)
        end
      rescue StandardError
        default_average_gamma_shift(previous_gamma, current_gamma)
      end

      # Reference gamma-shift: mean of |previous - current| over all cells.
      def default_average_gamma_shift(previous_gamma, current_gamma)
        deltas = []

        previous_gamma.each_with_index do |previous_row, row_index|
          previous_row.each_with_index do |previous_value, col_index|
            deltas << (previous_value - current_gamma[row_index][col_index]).abs
          end
        end

        return 0.0 if deltas.empty?

        deltas.sum / deltas.size.to_f
      end
    end
  end
end
|