lda-ruby 0.3.9 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +5 -13
- data/CHANGELOG.md +8 -0
- data/Gemfile +9 -0
- data/README.md +123 -3
- data/VERSION.yml +3 -3
- data/docs/modernization-handoff.md +190 -0
- data/docs/porting-strategy.md +127 -0
- data/docs/precompiled-platform-policy.md +68 -0
- data/docs/release-runbook.md +157 -0
- data/ext/lda-ruby/extconf.rb +10 -6
- data/ext/lda-ruby/lda-inference.c +21 -5
- data/ext/lda-ruby-rust/Cargo.toml +12 -0
- data/ext/lda-ruby-rust/README.md +48 -0
- data/ext/lda-ruby-rust/extconf.rb +123 -0
- data/ext/lda-ruby-rust/src/lib.rs +456 -0
- data/lda-ruby.gemspec +0 -0
- data/lib/lda-ruby/backends/base.rb +129 -0
- data/lib/lda-ruby/backends/native.rb +158 -0
- data/lib/lda-ruby/backends/pure_ruby.rb +613 -0
- data/lib/lda-ruby/backends/rust.rb +226 -0
- data/lib/lda-ruby/backends.rb +58 -0
- data/lib/lda-ruby/corpus/corpus.rb +17 -15
- data/lib/lda-ruby/corpus/data_corpus.rb +2 -2
- data/lib/lda-ruby/corpus/directory_corpus.rb +2 -2
- data/lib/lda-ruby/corpus/text_corpus.rb +2 -2
- data/lib/lda-ruby/document/document.rb +6 -6
- data/lib/lda-ruby/document/text_document.rb +5 -4
- data/lib/lda-ruby/rust_build_policy.rb +21 -0
- data/lib/lda-ruby/version.rb +5 -0
- data/lib/lda-ruby.rb +293 -48
- data/test/backend_compatibility_test.rb +146 -0
- data/test/backends_selection_test.rb +100 -0
- data/test/gemspec_test.rb +27 -0
- data/test/lda_ruby_test.rb +49 -11
- data/test/packaged_gem_smoke_test.rb +33 -0
- data/test/release_scripts_test.rb +54 -0
- data/test/rust_build_policy_test.rb +23 -0
- data/test/simple_pipeline_test.rb +22 -0
- data/test/simple_yaml.rb +1 -7
- data/test/test_helper.rb +5 -6
- metadata +48 -38
- data/Rakefile +0 -61
- data/ext/lda-ruby/Makefile +0 -181
- data/test/data/.gitignore +0 -2
- data/test/simple_test.rb +0 -26
data/lib/lda-ruby.rb
CHANGED
|
@@ -1,29 +1,125 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
require
|
|
4
|
-
require
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "lda-ruby/version"
|
|
4
|
+
require "rbconfig"
|
|
5
|
+
|
|
6
|
+
rust_extension_loaded = false
|
|
7
|
+
rust_dlext = RbConfig::CONFIG.fetch("DLEXT")
|
|
8
|
+
|
|
9
|
+
[
|
|
10
|
+
"lda_ruby_rust",
|
|
11
|
+
"../ext/lda-ruby-rust/target/release/lda_ruby_rust",
|
|
12
|
+
"../ext/lda-ruby-rust/target/release/lda_ruby_rust.#{rust_dlext}",
|
|
13
|
+
"../ext/lda-ruby-rust/target/debug/lda_ruby_rust",
|
|
14
|
+
"../ext/lda-ruby-rust/target/debug/lda_ruby_rust.#{rust_dlext}"
|
|
15
|
+
].each do |rust_extension_candidate|
|
|
16
|
+
begin
|
|
17
|
+
if rust_extension_candidate.start_with?("../")
|
|
18
|
+
require_relative rust_extension_candidate
|
|
19
|
+
else
|
|
20
|
+
require rust_extension_candidate
|
|
21
|
+
end
|
|
22
|
+
|
|
23
|
+
rust_extension_loaded = true
|
|
24
|
+
break
|
|
25
|
+
rescue LoadError
|
|
26
|
+
next
|
|
27
|
+
end
|
|
28
|
+
end
|
|
29
|
+
|
|
30
|
+
native_extension_loaded = false
|
|
31
|
+
|
|
32
|
+
begin
|
|
33
|
+
require "lda-ruby/lda"
|
|
34
|
+
native_extension_loaded = true
|
|
35
|
+
rescue LoadError
|
|
36
|
+
begin
|
|
37
|
+
require_relative "../ext/lda-ruby/lda"
|
|
38
|
+
native_extension_loaded = true
|
|
39
|
+
rescue LoadError
|
|
40
|
+
native_extension_loaded = false
|
|
41
|
+
end
|
|
42
|
+
end
|
|
43
|
+
|
|
44
|
+
LDA_RUBY_NATIVE_EXTENSION_LOADED = native_extension_loaded unless defined?(LDA_RUBY_NATIVE_EXTENSION_LOADED)
|
|
45
|
+
LDA_RUBY_RUST_EXTENSION_LOADED = rust_extension_loaded unless defined?(LDA_RUBY_RUST_EXTENSION_LOADED)
|
|
46
|
+
|
|
47
|
+
require "lda-ruby/document/document"
|
|
48
|
+
require "lda-ruby/document/data_document"
|
|
49
|
+
require "lda-ruby/document/text_document"
|
|
50
|
+
require "lda-ruby/corpus/corpus"
|
|
51
|
+
require "lda-ruby/corpus/data_corpus"
|
|
52
|
+
require "lda-ruby/corpus/text_corpus"
|
|
53
|
+
require "lda-ruby/corpus/directory_corpus"
|
|
54
|
+
require "lda-ruby/vocabulary"
|
|
55
|
+
require "lda-ruby/backends"
|
|
12
56
|
|
|
13
57
|
module Lda
|
|
58
|
+
RUST_EXTENSION_LOADED = LDA_RUBY_RUST_EXTENSION_LOADED unless const_defined?(:RUST_EXTENSION_LOADED)
|
|
59
|
+
NATIVE_EXTENSION_LOADED = LDA_RUBY_NATIVE_EXTENSION_LOADED unless const_defined?(:NATIVE_EXTENSION_LOADED)
|
|
60
|
+
|
|
14
61
|
class Lda
|
|
15
|
-
|
|
62
|
+
NATIVE_ALIAS_MAP = {
|
|
63
|
+
fast_load_corpus_from_file: :__native_fast_load_corpus_from_file,
|
|
64
|
+
"corpus=": :__native_set_corpus,
|
|
65
|
+
em: :__native_em,
|
|
66
|
+
load_settings: :__native_load_settings,
|
|
67
|
+
set_config: :__native_set_config,
|
|
68
|
+
max_iter: :__native_max_iter,
|
|
69
|
+
"max_iter=": :__native_set_max_iter,
|
|
70
|
+
convergence: :__native_convergence,
|
|
71
|
+
"convergence=": :__native_set_convergence,
|
|
72
|
+
em_max_iter: :__native_em_max_iter,
|
|
73
|
+
"em_max_iter=": :__native_set_em_max_iter,
|
|
74
|
+
em_convergence: :__native_em_convergence,
|
|
75
|
+
"em_convergence=": :__native_set_em_convergence,
|
|
76
|
+
init_alpha: :__native_init_alpha,
|
|
77
|
+
"init_alpha=": :__native_set_init_alpha,
|
|
78
|
+
est_alpha: :__native_est_alpha,
|
|
79
|
+
"est_alpha=": :__native_set_est_alpha,
|
|
80
|
+
num_topics: :__native_num_topics,
|
|
81
|
+
"num_topics=": :__native_set_num_topics,
|
|
82
|
+
verbose: :__native_verbose,
|
|
83
|
+
"verbose=": :__native_set_verbose,
|
|
84
|
+
beta: :__native_beta,
|
|
85
|
+
gamma: :__native_gamma,
|
|
86
|
+
compute_phi: :__native_compute_phi,
|
|
87
|
+
model: :__native_model
|
|
88
|
+
}.freeze
|
|
89
|
+
|
|
90
|
+
NATIVE_ALIAS_MAP.each do |native_name, alias_name|
|
|
91
|
+
next unless method_defined?(native_name)
|
|
92
|
+
|
|
93
|
+
alias_method alias_name, native_name
|
|
94
|
+
private alias_name
|
|
95
|
+
end
|
|
96
|
+
|
|
97
|
+
attr_reader :vocab, :corpus, :backend
|
|
98
|
+
|
|
99
|
+
def initialize(corpus, backend: nil, random_seed: nil)
|
|
100
|
+
@backend = Backends.build(host: self, requested: backend, random_seed: random_seed)
|
|
16
101
|
|
|
17
|
-
def initialize(corpus)
|
|
18
102
|
load_default_settings
|
|
19
103
|
|
|
20
104
|
@vocab = nil
|
|
21
105
|
self.corpus = corpus
|
|
22
|
-
@vocab = corpus.vocabulary.to_a if corpus.vocabulary
|
|
106
|
+
@vocab = corpus.vocabulary.to_a if corpus.respond_to?(:vocabulary) && corpus.vocabulary
|
|
23
107
|
|
|
24
108
|
@phi = nil
|
|
25
109
|
end
|
|
26
110
|
|
|
111
|
+
def backend_name
|
|
112
|
+
@backend.name
|
|
113
|
+
end
|
|
114
|
+
|
|
115
|
+
def native_backend?
|
|
116
|
+
backend_name == "native"
|
|
117
|
+
end
|
|
118
|
+
|
|
119
|
+
def rust_backend?
|
|
120
|
+
backend_name == "rust"
|
|
121
|
+
end
|
|
122
|
+
|
|
27
123
|
def load_default_settings
|
|
28
124
|
self.max_iter = 20
|
|
29
125
|
self.convergence = 1e-6
|
|
@@ -36,25 +132,138 @@ module Lda
|
|
|
36
132
|
[20, 1e-6, 100, 1e-4, 20, 0.3, 1]
|
|
37
133
|
end
|
|
38
134
|
|
|
39
|
-
def
|
|
40
|
-
@
|
|
41
|
-
|
|
135
|
+
def set_config(init_alpha, num_topics, max_iter, convergence, em_max_iter, em_convergence = self.em_convergence, est_alpha = self.est_alpha)
|
|
136
|
+
@backend.set_config(
|
|
137
|
+
Float(init_alpha),
|
|
138
|
+
Integer(num_topics),
|
|
139
|
+
Integer(max_iter),
|
|
140
|
+
Float(convergence),
|
|
141
|
+
Integer(em_max_iter),
|
|
142
|
+
Float(em_convergence),
|
|
143
|
+
Integer(est_alpha)
|
|
144
|
+
)
|
|
145
|
+
end
|
|
146
|
+
|
|
147
|
+
def max_iter
|
|
148
|
+
@backend.max_iter
|
|
149
|
+
end
|
|
150
|
+
|
|
151
|
+
def max_iter=(value)
|
|
152
|
+
@backend.max_iter = Integer(value)
|
|
153
|
+
end
|
|
154
|
+
|
|
155
|
+
def convergence
|
|
156
|
+
@backend.convergence
|
|
157
|
+
end
|
|
158
|
+
|
|
159
|
+
def convergence=(value)
|
|
160
|
+
@backend.convergence = Float(value)
|
|
161
|
+
end
|
|
162
|
+
|
|
163
|
+
def em_max_iter
|
|
164
|
+
@backend.em_max_iter
|
|
165
|
+
end
|
|
166
|
+
|
|
167
|
+
def em_max_iter=(value)
|
|
168
|
+
@backend.em_max_iter = Integer(value)
|
|
169
|
+
end
|
|
170
|
+
|
|
171
|
+
def em_convergence
|
|
172
|
+
@backend.em_convergence
|
|
173
|
+
end
|
|
174
|
+
|
|
175
|
+
def em_convergence=(value)
|
|
176
|
+
@backend.em_convergence = Float(value)
|
|
177
|
+
end
|
|
178
|
+
|
|
179
|
+
def num_topics
|
|
180
|
+
@backend.num_topics
|
|
181
|
+
end
|
|
182
|
+
|
|
183
|
+
def num_topics=(value)
|
|
184
|
+
@backend.num_topics = Integer(value)
|
|
185
|
+
end
|
|
42
186
|
|
|
187
|
+
def init_alpha
|
|
188
|
+
@backend.init_alpha
|
|
189
|
+
end
|
|
190
|
+
|
|
191
|
+
def init_alpha=(value)
|
|
192
|
+
@backend.init_alpha = Float(value)
|
|
193
|
+
end
|
|
194
|
+
|
|
195
|
+
def est_alpha
|
|
196
|
+
@backend.est_alpha
|
|
197
|
+
end
|
|
198
|
+
|
|
199
|
+
def est_alpha=(value)
|
|
200
|
+
@backend.est_alpha = Integer(value)
|
|
201
|
+
end
|
|
202
|
+
|
|
203
|
+
def verbose
|
|
204
|
+
@backend.verbose
|
|
205
|
+
end
|
|
206
|
+
|
|
207
|
+
def verbose=(value)
|
|
208
|
+
@backend.verbose = !!value
|
|
209
|
+
end
|
|
210
|
+
|
|
211
|
+
def corpus=(corpus)
|
|
212
|
+
@corpus = corpus
|
|
213
|
+
@backend.corpus = corpus
|
|
43
214
|
true
|
|
44
215
|
end
|
|
45
216
|
|
|
217
|
+
def load_corpus(filename)
|
|
218
|
+
fast_load_corpus_from_file(filename)
|
|
219
|
+
end
|
|
220
|
+
|
|
221
|
+
def fast_load_corpus_from_file(filename)
|
|
222
|
+
loaded = @backend.fast_load_corpus_from_file(filename)
|
|
223
|
+
|
|
224
|
+
if @backend.corpus
|
|
225
|
+
@corpus = @backend.corpus
|
|
226
|
+
@vocab = @corpus.vocabulary.to_a if @corpus.respond_to?(:vocabulary) && @corpus.vocabulary
|
|
227
|
+
elsif @corpus.nil?
|
|
228
|
+
@corpus = DataCorpus.new(filename)
|
|
229
|
+
end
|
|
230
|
+
|
|
231
|
+
!!loaded
|
|
232
|
+
end
|
|
233
|
+
|
|
234
|
+
def load_settings(settings_file)
|
|
235
|
+
@backend.load_settings(settings_file)
|
|
236
|
+
end
|
|
237
|
+
|
|
46
238
|
def load_vocabulary(vocab)
|
|
47
239
|
if vocab.is_a?(Array)
|
|
48
|
-
@vocab = Marshal
|
|
240
|
+
@vocab = Marshal.load(Marshal.dump(vocab)) # deep clone array
|
|
49
241
|
elsif vocab.is_a?(Vocabulary)
|
|
50
242
|
@vocab = vocab.to_a
|
|
51
243
|
else
|
|
52
|
-
@vocab = File.
|
|
244
|
+
@vocab = File.read(vocab).split(/\s+/)
|
|
53
245
|
end
|
|
54
246
|
|
|
55
247
|
true
|
|
56
248
|
end
|
|
57
249
|
|
|
250
|
+
def em(start = "random")
|
|
251
|
+
@phi = nil
|
|
252
|
+
@backend.em(start.to_s)
|
|
253
|
+
end
|
|
254
|
+
|
|
255
|
+
def beta
|
|
256
|
+
@backend.beta
|
|
257
|
+
end
|
|
258
|
+
|
|
259
|
+
def gamma
|
|
260
|
+
@backend.gamma
|
|
261
|
+
end
|
|
262
|
+
|
|
263
|
+
def model
|
|
264
|
+
@backend.model
|
|
265
|
+
end
|
|
266
|
+
|
|
58
267
|
#
|
|
59
268
|
# Visualization method for printing out the top +words_per_topic+ words
|
|
60
269
|
# for each topic.
|
|
@@ -62,14 +271,18 @@ module Lda
|
|
|
62
271
|
# See also +top_words+.
|
|
63
272
|
#
|
|
64
273
|
def print_topics(words_per_topic = 10)
|
|
65
|
-
raise
|
|
274
|
+
raise "No vocabulary loaded." unless @vocab
|
|
66
275
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
276
|
+
beta.each_with_index do |topic, topic_num|
|
|
277
|
+
indices = topic
|
|
278
|
+
.each_with_index
|
|
279
|
+
.sort_by { |score, _index| score }
|
|
280
|
+
.reverse
|
|
281
|
+
.first(words_per_topic)
|
|
282
|
+
.map { |_score, index| index }
|
|
70
283
|
|
|
71
284
|
puts "Topic #{topic_num}"
|
|
72
|
-
puts "\t#{indices.map {|i| @vocab[i]}.join("\n\t")}"
|
|
285
|
+
puts "\t#{indices.map { |i| @vocab[i] }.join("\n\t")}"
|
|
73
286
|
puts ""
|
|
74
287
|
end
|
|
75
288
|
|
|
@@ -87,21 +300,24 @@ module Lda
|
|
|
87
300
|
# See also +print_topics+.
|
|
88
301
|
#
|
|
89
302
|
def top_word_indices(words_per_topic = 10)
|
|
90
|
-
raise
|
|
303
|
+
raise "No vocabulary loaded." unless @vocab
|
|
91
304
|
|
|
92
|
-
|
|
93
|
-
topics = Hash.new
|
|
94
|
-
indices = (0...@vocab.size).to_a
|
|
305
|
+
topics = {}
|
|
95
306
|
|
|
96
|
-
|
|
97
|
-
topics[topic_num] =
|
|
307
|
+
beta.each_with_index do |topic, topic_num|
|
|
308
|
+
topics[topic_num] = topic
|
|
309
|
+
.each_with_index
|
|
310
|
+
.sort_by { |score, _index| score }
|
|
311
|
+
.reverse
|
|
312
|
+
.first(words_per_topic)
|
|
313
|
+
.map { |_score, index| index }
|
|
98
314
|
end
|
|
99
315
|
|
|
100
316
|
topics
|
|
101
317
|
end
|
|
102
318
|
|
|
103
319
|
def top_words(words_per_topic = 10)
|
|
104
|
-
output =
|
|
320
|
+
output = {}
|
|
105
321
|
|
|
106
322
|
topics = top_word_indices(words_per_topic)
|
|
107
323
|
topics.each_pair do |topic_num, words|
|
|
@@ -118,49 +334,78 @@ module Lda
|
|
|
118
334
|
# after the first call, so if it needs to be recomputed, set the +recompute+
|
|
119
335
|
# value to true.
|
|
120
336
|
#
|
|
121
|
-
def phi(recompute=false)
|
|
122
|
-
if @phi.nil? || recompute
|
|
123
|
-
@phi = self.compute_phi
|
|
124
|
-
end
|
|
337
|
+
def phi(recompute = false)
|
|
338
|
+
@phi = compute_phi if @phi.nil? || recompute
|
|
125
339
|
|
|
126
340
|
@phi
|
|
127
341
|
end
|
|
128
342
|
|
|
343
|
+
def compute_phi
|
|
344
|
+
@backend.compute_phi
|
|
345
|
+
end
|
|
346
|
+
|
|
129
347
|
#
|
|
130
348
|
# Compute the average log probability for each topic for each document in the corpus.
|
|
131
349
|
# This method returns a matrix: num_docs x num_topics with the average log probability
|
|
132
350
|
# for the topic in the document.
|
|
133
351
|
#
|
|
134
352
|
def compute_topic_document_probability
|
|
135
|
-
|
|
353
|
+
phi_matrix = phi
|
|
354
|
+
document_counts = @corpus.documents.map(&:counts)
|
|
355
|
+
|
|
356
|
+
backend_output = @backend.topic_document_probability(phi_matrix, document_counts)
|
|
357
|
+
if valid_topic_document_probability_output?(backend_output, document_counts.size, num_topics)
|
|
358
|
+
return backend_output
|
|
359
|
+
end
|
|
360
|
+
|
|
361
|
+
outp = []
|
|
136
362
|
|
|
137
363
|
@corpus.documents.each_with_index do |doc, idx|
|
|
138
|
-
tops = [0.0] *
|
|
139
|
-
ttl
|
|
140
|
-
|
|
364
|
+
tops = [0.0] * num_topics
|
|
365
|
+
ttl = doc.counts.inject(0.0) { |sum, i| sum + i }
|
|
366
|
+
|
|
367
|
+
phi_matrix[idx].each_with_index do |word_dist, word_idx|
|
|
141
368
|
word_dist.each_with_index do |top_prob, top_idx|
|
|
142
|
-
tops[top_idx] += Math.log(top_prob) * doc.counts[word_idx]
|
|
369
|
+
tops[top_idx] += Math.log([top_prob, 1e-300].max) * doc.counts[word_idx]
|
|
143
370
|
end
|
|
144
371
|
end
|
|
145
|
-
|
|
372
|
+
|
|
373
|
+
tops = tops.map { |i| i / ttl }
|
|
146
374
|
outp << tops
|
|
147
375
|
end
|
|
148
376
|
|
|
149
377
|
outp
|
|
150
378
|
end
|
|
151
379
|
|
|
380
|
+
def valid_topic_document_probability_output?(output, expected_docs, expected_topics)
|
|
381
|
+
return false unless output.is_a?(Array)
|
|
382
|
+
return false unless output.size == expected_docs
|
|
383
|
+
|
|
384
|
+
output.each do |row|
|
|
385
|
+
return false unless row.is_a?(Array)
|
|
386
|
+
return false unless row.size == expected_topics
|
|
387
|
+
row.each do |value|
|
|
388
|
+
return false unless value.is_a?(Numeric)
|
|
389
|
+
return false unless value.finite?
|
|
390
|
+
end
|
|
391
|
+
end
|
|
392
|
+
|
|
393
|
+
true
|
|
394
|
+
end
|
|
395
|
+
|
|
152
396
|
#
|
|
153
397
|
# String representation displaying current settings.
|
|
154
398
|
#
|
|
155
399
|
def to_s
|
|
156
400
|
outp = ["LDA Settings:"]
|
|
157
|
-
outp << " Initial alpha: %0.6f"
|
|
158
|
-
outp << " # of topics: %d"
|
|
159
|
-
outp << " Max iterations: %d"
|
|
160
|
-
outp << " Convergence: %0.6f"
|
|
161
|
-
outp << "EM max iterations: %d"
|
|
162
|
-
outp << " EM convergence: %0.6f"
|
|
163
|
-
outp << " Estimate alpha: %d"
|
|
401
|
+
outp << format(" Initial alpha: %0.6f", init_alpha)
|
|
402
|
+
outp << format(" # of topics: %d", num_topics)
|
|
403
|
+
outp << format(" Max iterations: %d", max_iter)
|
|
404
|
+
outp << format(" Convergence: %0.6f", convergence)
|
|
405
|
+
outp << format("EM max iterations: %d", em_max_iter)
|
|
406
|
+
outp << format(" EM convergence: %0.6f", em_convergence)
|
|
407
|
+
outp << format(" Estimate alpha: %d", est_alpha)
|
|
408
|
+
outp << format(" Backend: %s", backend_name)
|
|
164
409
|
|
|
165
410
|
outp.join("\n")
|
|
166
411
|
end
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
require_relative "test_helper"
|
|
2
|
+
|
|
3
|
+
class BackendCompatibilityTest < Test::Unit::TestCase
|
|
4
|
+
FIXTURE_DOCUMENTS = [
|
|
5
|
+
"apple banana apple banana fruit sweet fruit",
|
|
6
|
+
"truck wheel truck road engine metal road",
|
|
7
|
+
"ruby code gem ruby class module test",
|
|
8
|
+
"banana fruit apple orchard fresh sweet",
|
|
9
|
+
"engine road truck wheel fuel highway",
|
|
10
|
+
"module ruby class object gem code"
|
|
11
|
+
].freeze
|
|
12
|
+
|
|
13
|
+
def setup
|
|
14
|
+
@corpus = Lda::TextCorpus.new(FIXTURE_DOCUMENTS)
|
|
15
|
+
end
|
|
16
|
+
|
|
17
|
+
def test_pure_backend_seeded_fixture
|
|
18
|
+
lda = build_and_train(:pure)
|
|
19
|
+
|
|
20
|
+
assert_equal "pure_ruby", lda.backend_name
|
|
21
|
+
assert_backend_output_valid(lda)
|
|
22
|
+
end
|
|
23
|
+
|
|
24
|
+
def test_native_backend_seeded_fixture
|
|
25
|
+
return unless Lda::NATIVE_EXTENSION_LOADED
|
|
26
|
+
|
|
27
|
+
lda = build_and_train(:native)
|
|
28
|
+
|
|
29
|
+
assert_equal "native", lda.backend_name
|
|
30
|
+
assert_backend_output_valid(lda)
|
|
31
|
+
end
|
|
32
|
+
|
|
33
|
+
def test_native_and_pure_backend_agree_on_shapes
|
|
34
|
+
return unless Lda::NATIVE_EXTENSION_LOADED
|
|
35
|
+
|
|
36
|
+
native = build_and_train(:native)
|
|
37
|
+
pure = build_and_train(:pure)
|
|
38
|
+
|
|
39
|
+
assert_equal native.model[0], pure.model[0]
|
|
40
|
+
assert_equal native.model[1], pure.model[1]
|
|
41
|
+
assert_equal native.beta.size, pure.beta.size
|
|
42
|
+
assert_equal native.gamma.size, pure.gamma.size
|
|
43
|
+
assert_equal native.phi.size, pure.phi.size
|
|
44
|
+
end
|
|
45
|
+
|
|
46
|
+
def test_rust_backend_seeded_fixture
|
|
47
|
+
return unless Lda::RUST_EXTENSION_LOADED
|
|
48
|
+
|
|
49
|
+
rust = build_and_train(:rust)
|
|
50
|
+
|
|
51
|
+
assert_equal "rust", rust.backend_name
|
|
52
|
+
assert_backend_output_valid(rust)
|
|
53
|
+
end
|
|
54
|
+
|
|
55
|
+
def test_rust_and_pure_backend_numeric_parity
|
|
56
|
+
return unless Lda::RUST_EXTENSION_LOADED
|
|
57
|
+
|
|
58
|
+
pure = build_and_train(:pure)
|
|
59
|
+
rust = build_and_train(:rust)
|
|
60
|
+
|
|
61
|
+
assert_nested_close(pure.gamma, rust.gamma, 1e-9)
|
|
62
|
+
assert_nested_close(pure.beta, rust.beta, 1e-9)
|
|
63
|
+
assert_nested_close(pure.phi, rust.phi, 1e-9)
|
|
64
|
+
assert_nested_close(
|
|
65
|
+
exponentiate_nested(pure.compute_topic_document_probability),
|
|
66
|
+
exponentiate_nested(rust.compute_topic_document_probability),
|
|
67
|
+
1e-6
|
|
68
|
+
)
|
|
69
|
+
end
|
|
70
|
+
|
|
71
|
+
private
|
|
72
|
+
|
|
73
|
+
def build_and_train(backend)
|
|
74
|
+
lda = Lda::Lda.new(@corpus, backend: backend, random_seed: 1234)
|
|
75
|
+
lda.verbose = false
|
|
76
|
+
lda.num_topics = 3
|
|
77
|
+
lda.max_iter = 25
|
|
78
|
+
lda.em_max_iter = 40
|
|
79
|
+
lda.convergence = 1e-5
|
|
80
|
+
lda.em_convergence = 1e-4
|
|
81
|
+
lda.em("seeded")
|
|
82
|
+
lda
|
|
83
|
+
end
|
|
84
|
+
|
|
85
|
+
def assert_backend_output_valid(lda)
|
|
86
|
+
assert_equal 3, lda.model[0]
|
|
87
|
+
assert lda.model[1] > 0
|
|
88
|
+
|
|
89
|
+
assert_equal @corpus.num_docs, lda.gamma.size
|
|
90
|
+
lda.gamma.each do |topic_weights|
|
|
91
|
+
assert_equal 3, topic_weights.size
|
|
92
|
+
topic_weights.each do |weight|
|
|
93
|
+
assert weight.is_a?(Numeric)
|
|
94
|
+
assert weight.finite?
|
|
95
|
+
assert weight.positive?
|
|
96
|
+
end
|
|
97
|
+
end
|
|
98
|
+
|
|
99
|
+
assert_equal 3, lda.beta.size
|
|
100
|
+
lda.beta.each do |topic_log_probs|
|
|
101
|
+
assert topic_log_probs.size > 0
|
|
102
|
+
probabilities = topic_log_probs.map { |log_prob| Math.exp(log_prob) }
|
|
103
|
+
assert_in_delta 1.0, probabilities.sum, 1e-3
|
|
104
|
+
end
|
|
105
|
+
|
|
106
|
+
phi = lda.phi
|
|
107
|
+
assert_equal @corpus.num_docs, phi.size
|
|
108
|
+
phi.each_with_index do |doc_phi, doc_index|
|
|
109
|
+
assert_equal @corpus.documents[doc_index].length, doc_phi.size
|
|
110
|
+
doc_phi.each do |word_topic_distribution|
|
|
111
|
+
assert_equal 3, word_topic_distribution.size
|
|
112
|
+
assert_in_delta 1.0, word_topic_distribution.sum, 1e-3
|
|
113
|
+
end
|
|
114
|
+
end
|
|
115
|
+
|
|
116
|
+
probabilities = lda.compute_topic_document_probability
|
|
117
|
+
assert_equal @corpus.num_docs, probabilities.size
|
|
118
|
+
probabilities.each do |row|
|
|
119
|
+
assert_equal 3, row.size
|
|
120
|
+
row.each { |value| assert value.finite? }
|
|
121
|
+
end
|
|
122
|
+
|
|
123
|
+
top_words = lda.top_words(4)
|
|
124
|
+
assert_equal 3, top_words.size
|
|
125
|
+
top_words.each_value { |words| assert_equal 4, words.size }
|
|
126
|
+
end
|
|
127
|
+
|
|
128
|
+
def assert_nested_close(left, right, tolerance)
|
|
129
|
+
assert_equal left.class, right.class
|
|
130
|
+
|
|
131
|
+
if left.is_a?(Array)
|
|
132
|
+
assert_equal left.size, right.size
|
|
133
|
+
left.each_with_index do |left_item, index|
|
|
134
|
+
assert_nested_close(left_item, right[index], tolerance)
|
|
135
|
+
end
|
|
136
|
+
else
|
|
137
|
+
assert_in_delta left.to_f, right.to_f, tolerance
|
|
138
|
+
end
|
|
139
|
+
end
|
|
140
|
+
|
|
141
|
+
def exponentiate_nested(value)
|
|
142
|
+
return Math.exp(value.to_f) unless value.is_a?(Array)
|
|
143
|
+
|
|
144
|
+
value.map { |item| exponentiate_nested(item) }
|
|
145
|
+
end
|
|
146
|
+
end
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
require_relative "test_helper"
|
|
2
|
+
|
|
3
|
+
class BackendsSelectionTest < Test::Unit::TestCase
|
|
4
|
+
RUST_ALIAS = :__test_original_rust_available__
|
|
5
|
+
NATIVE_ALIAS = :__test_original_native_available__
|
|
6
|
+
|
|
7
|
+
setup do
|
|
8
|
+
@host = Object.new
|
|
9
|
+
@rust_singleton = Lda::Backends::Rust.singleton_class
|
|
10
|
+
@native_singleton = Lda::Backends::Native.singleton_class
|
|
11
|
+
|
|
12
|
+
@rust_singleton.send(:alias_method, RUST_ALIAS, :available?)
|
|
13
|
+
@native_singleton.send(:alias_method, NATIVE_ALIAS, :available?)
|
|
14
|
+
@previous_env_backend = ENV["LDA_RUBY_BACKEND"]
|
|
15
|
+
end
|
|
16
|
+
|
|
17
|
+
teardown do
|
|
18
|
+
restore_availability_stubs
|
|
19
|
+
ENV["LDA_RUBY_BACKEND"] = @previous_env_backend
|
|
20
|
+
end
|
|
21
|
+
|
|
22
|
+
should "prefer rust over native in auto mode when both are available" do
|
|
23
|
+
stub_rust_available(true)
|
|
24
|
+
stub_native_available(true)
|
|
25
|
+
|
|
26
|
+
backend = Lda::Backends.build(host: @host, requested: :auto)
|
|
27
|
+
assert_instance_of Lda::Backends::Rust, backend
|
|
28
|
+
end
|
|
29
|
+
|
|
30
|
+
should "fall back to native in auto mode when rust is unavailable" do
|
|
31
|
+
stub_rust_available(false)
|
|
32
|
+
stub_native_available(true)
|
|
33
|
+
|
|
34
|
+
backend = Lda::Backends.build(host: @host, requested: :auto)
|
|
35
|
+
assert_instance_of Lda::Backends::Native, backend
|
|
36
|
+
end
|
|
37
|
+
|
|
38
|
+
should "fall back to pure in auto mode when rust and native are unavailable" do
|
|
39
|
+
stub_rust_available(false)
|
|
40
|
+
stub_native_available(false)
|
|
41
|
+
|
|
42
|
+
backend = Lda::Backends.build(host: @host, requested: :auto)
|
|
43
|
+
assert_instance_of Lda::Backends::PureRuby, backend
|
|
44
|
+
end
|
|
45
|
+
|
|
46
|
+
should "respect LDA_RUBY_BACKEND env override when requested mode is nil" do
|
|
47
|
+
stub_rust_available(true)
|
|
48
|
+
stub_native_available(true)
|
|
49
|
+
ENV["LDA_RUBY_BACKEND"] = "pure_ruby"
|
|
50
|
+
|
|
51
|
+
backend = Lda::Backends.build(host: @host, requested: nil)
|
|
52
|
+
assert_instance_of Lda::Backends::PureRuby, backend
|
|
53
|
+
end
|
|
54
|
+
|
|
55
|
+
should "raise for unknown backend mode" do
|
|
56
|
+
stub_rust_available(false)
|
|
57
|
+
stub_native_available(false)
|
|
58
|
+
|
|
59
|
+
error = assert_raise(ArgumentError) do
|
|
60
|
+
Lda::Backends.build(host: @host, requested: :unknown_backend)
|
|
61
|
+
end
|
|
62
|
+
|
|
63
|
+
assert_match(/Unknown backend mode/i, error.message)
|
|
64
|
+
end
|
|
65
|
+
|
|
66
|
+
private
|
|
67
|
+
|
|
68
|
+
def stub_rust_available(value)
|
|
69
|
+
silence_redefinition_warnings do
|
|
70
|
+
@rust_singleton.send(:define_method, :available?) do
|
|
71
|
+
value
|
|
72
|
+
end
|
|
73
|
+
end
|
|
74
|
+
end
|
|
75
|
+
|
|
76
|
+
def stub_native_available(value)
|
|
77
|
+
silence_redefinition_warnings do
|
|
78
|
+
@native_singleton.send(:define_method, :available?) do |_host|
|
|
79
|
+
value
|
|
80
|
+
end
|
|
81
|
+
end
|
|
82
|
+
end
|
|
83
|
+
|
|
84
|
+
def restore_availability_stubs
|
|
85
|
+
silence_redefinition_warnings do
|
|
86
|
+
@rust_singleton.send(:alias_method, :available?, RUST_ALIAS)
|
|
87
|
+
@native_singleton.send(:alias_method, :available?, NATIVE_ALIAS)
|
|
88
|
+
end
|
|
89
|
+
@rust_singleton.send(:remove_method, RUST_ALIAS)
|
|
90
|
+
@native_singleton.send(:remove_method, NATIVE_ALIAS)
|
|
91
|
+
end
|
|
92
|
+
|
|
93
|
+
def silence_redefinition_warnings
|
|
94
|
+
previous_verbose = $VERBOSE
|
|
95
|
+
$VERBOSE = nil
|
|
96
|
+
yield
|
|
97
|
+
ensure
|
|
98
|
+
$VERBOSE = previous_verbose
|
|
99
|
+
end
|
|
100
|
+
end
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
require_relative "test_helper"
|
|
2
|
+
|
|
3
|
+
class GemspecTest < Test::Unit::TestCase
|
|
4
|
+
def test_gemspec_excludes_local_rust_build_artifacts
|
|
5
|
+
spec = Gem::Specification.load(File.expand_path("../lda-ruby.gemspec", __dir__))
|
|
6
|
+
assert_not_nil spec
|
|
7
|
+
|
|
8
|
+
rust_target_files = spec.files.grep(%r{\Aext/lda-ruby-rust/target/})
|
|
9
|
+
assert_equal [], rust_target_files
|
|
10
|
+
assert(!spec.files.include?("ext/lda-ruby-rust/Cargo.lock"))
|
|
11
|
+
assert(!spec.files.include?("ext/lda-ruby-rust/Makefile"))
|
|
12
|
+
end
|
|
13
|
+
|
|
14
|
+
def test_gemspec_declares_rust_extconf
|
|
15
|
+
spec = Gem::Specification.load(File.expand_path("../lda-ruby.gemspec", __dir__))
|
|
16
|
+
assert_not_nil spec
|
|
17
|
+
|
|
18
|
+
assert(spec.extensions.include?("ext/lda-ruby-rust/extconf.rb"))
|
|
19
|
+
end
|
|
20
|
+
|
|
21
|
+
def test_gemspec_includes_release_runbook
|
|
22
|
+
spec = Gem::Specification.load(File.expand_path("../lda-ruby.gemspec", __dir__))
|
|
23
|
+
assert_not_nil spec
|
|
24
|
+
|
|
25
|
+
assert(spec.files.include?("docs/release-runbook.md"))
|
|
26
|
+
end
|
|
27
|
+
end
|