tomoto 0.1.1 → 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +23 -2
- data/ext/tomoto/ext.cpp +52 -25
- data/lib/tomoto.rb +1 -0
- data/lib/tomoto/ct.rb +1 -1
- data/lib/tomoto/dmr.rb +5 -1
- data/lib/tomoto/dt.rb +1 -1
- data/lib/tomoto/gdmr.rb +1 -1
- data/lib/tomoto/hdp.rb +1 -1
- data/lib/tomoto/hlda.rb +14 -1
- data/lib/tomoto/hpa.rb +1 -1
- data/lib/tomoto/lda.rb +95 -3
- data/lib/tomoto/llda.rb +1 -1
- data/lib/tomoto/mglda.rb +1 -1
- data/lib/tomoto/pa.rb +1 -1
- data/lib/tomoto/plda.rb +1 -1
- data/lib/tomoto/slda.rb +1 -1
- data/lib/tomoto/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: dd4c36ff621f73c38bb066694a932f0a682c18591ddf05a9a0764bea0b6e4430
|
4
|
+
data.tar.gz: 551e56c4bc17fb5a3a0aeac0db055960fcc5e45bf097bf88c7cbf9046f958e7d
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 565a91d0bb6d48142f38dc3d9e798ddb99bf41fda32762295362075fba972eea6b56b6bde126eab74677eba5fd525581b68c5efa73361a46fcb0b2796ab63684
|
7
|
+
data.tar.gz: 415193e4eb6adbe5dce05328aadf9acb91f4acc50951484183a956455d7336f93961fe145465b1eeffaae78dad37ee1452defe832514c72b3c032860ed433cc8
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -23,7 +23,13 @@ model = Tomoto::LDA.new(k: 3)
|
|
23
23
|
model.add_doc("text from document one")
|
24
24
|
model.add_doc("text from document two")
|
25
25
|
model.add_doc("text from document three")
|
26
|
-
model.train(100)
|
26
|
+
model.train(100) # iterations
|
27
|
+
```
|
28
|
+
|
29
|
+
Get the summary
|
30
|
+
|
31
|
+
```ruby
|
32
|
+
model.summary
|
27
33
|
```
|
28
34
|
|
29
35
|
Get topic words
|
@@ -89,6 +95,11 @@ This library follows the [tomotopy API](https://bab2min.github.io/tomotopy/v0.9.
|
|
89
95
|
|
90
96
|
If a method or option you need isn’t supported, feel free to open an issue.
|
91
97
|
|
98
|
+
## Examples
|
99
|
+
|
100
|
+
- [LDA](examples/lda_basic.rb)
|
101
|
+
- [HDP](examples/hdp.rb)
|
102
|
+
|
92
103
|
## Tokenization
|
93
104
|
|
94
105
|
Documents are tokenized by whitespace by default, or you can perform your own tokenization.
|
@@ -99,12 +110,22 @@ model.add_doc(["tokens", "from", "document", "one"])
|
|
99
110
|
|
100
111
|
## Performance
|
101
112
|
|
102
|
-
tomoto uses AVX2, AVX, or SSE2 instructions to increase performance on machines that support it. Check
|
113
|
+
tomoto uses AVX2, AVX, or SSE2 instructions to increase performance on machines that support it. Check which instruction set architecture it’s using with:
|
103
114
|
|
104
115
|
```ruby
|
105
116
|
Tomoto.isa
|
106
117
|
```
|
107
118
|
|
119
|
+
## Parallelism
|
120
|
+
|
121
|
+
Choose a [parallelism algorithm](https://bab2min.github.io/tomotopy/v0.9.0/en/#parallel-sampling-algorithms) with:
|
122
|
+
|
123
|
+
```ruby
|
124
|
+
model.train(parallel: :partition)
|
125
|
+
```
|
126
|
+
|
127
|
+
Supported values are `:default`, `:none`, `:copy_merge`, and `:partition`.
|
128
|
+
|
108
129
|
## History
|
109
130
|
|
110
131
|
View the [changelog](https://github.com/ankane/tomoto/blob/master/CHANGELOG.md)
|
data/ext/tomoto/ext.cpp
CHANGED
@@ -31,7 +31,7 @@ using Rice::define_class_under;
|
|
31
31
|
using Rice::define_module;
|
32
32
|
|
33
33
|
template<>
|
34
|
-
Object to_ruby<std::vector<
|
34
|
+
Object to_ruby<std::vector<tomoto::Float>>(std::vector<tomoto::Float> const & x)
|
35
35
|
{
|
36
36
|
Array res;
|
37
37
|
for (auto const& v : x) {
|
@@ -73,13 +73,13 @@ std::vector<std::string> from_ruby<std::vector<std::string>>(Object x)
|
|
73
73
|
}
|
74
74
|
|
75
75
|
template<>
|
76
|
-
std::vector<
|
76
|
+
std::vector<tomoto::Float> from_ruby<std::vector<tomoto::Float>>(Object x)
|
77
77
|
{
|
78
78
|
Array a = Array(x);
|
79
|
-
std::vector<
|
79
|
+
std::vector<tomoto::Float> res;
|
80
80
|
res.reserve(a.size());
|
81
81
|
for (auto const& v : a) {
|
82
|
-
res.push_back(from_ruby<
|
82
|
+
res.push_back(from_ruby<tomoto::Float>(v));
|
83
83
|
}
|
84
84
|
return res;
|
85
85
|
}
|
@@ -117,7 +117,7 @@ void Init_ext()
|
|
117
117
|
Class rb_cLDA = define_class_under<tomoto::ILDAModel>(rb_mTomoto, "LDA")
|
118
118
|
.define_singleton_method(
|
119
119
|
"_new",
|
120
|
-
*[](size_t tw, size_t k,
|
120
|
+
*[](size_t tw, size_t k, tomoto::Float alpha, tomoto::Float eta, int seed) {
|
121
121
|
if (seed < 0) {
|
122
122
|
seed = std::random_device{}();
|
123
123
|
}
|
@@ -131,7 +131,11 @@ void Init_ext()
|
|
131
131
|
.define_method(
|
132
132
|
"alpha",
|
133
133
|
*[](tomoto::ILDAModel& self) {
|
134
|
-
|
134
|
+
Array res;
|
135
|
+
for (size_t i = 0; i < self.getK(); i++) {
|
136
|
+
res.push(self.getAlpha(i));
|
137
|
+
}
|
138
|
+
return res;
|
135
139
|
})
|
136
140
|
.define_method(
|
137
141
|
"burn_in",
|
@@ -246,8 +250,7 @@ void Init_ext()
|
|
246
250
|
})
|
247
251
|
.define_method(
|
248
252
|
"_train",
|
249
|
-
*[](tomoto::ILDAModel& self, size_t iteration, size_t workers) {
|
250
|
-
size_t ps = 0;
|
253
|
+
*[](tomoto::ILDAModel& self, size_t iteration, size_t workers, size_t ps) {
|
251
254
|
self.train(iteration, workers, (tomoto::ParallelScheme)ps);
|
252
255
|
})
|
253
256
|
.define_method(
|
@@ -321,7 +324,7 @@ void Init_ext()
|
|
321
324
|
Class rb_cCT = define_class_under<tomoto::ICTModel, tomoto::ILDAModel>(rb_mTomoto, "CT")
|
322
325
|
.define_singleton_method(
|
323
326
|
"_new",
|
324
|
-
*[](size_t tw, size_t k,
|
327
|
+
*[](size_t tw, size_t k, tomoto::Float alpha, tomoto::Float eta, int seed) {
|
325
328
|
if (seed < 0) {
|
326
329
|
seed = std::random_device{}();
|
327
330
|
}
|
@@ -368,7 +371,7 @@ void Init_ext()
|
|
368
371
|
Class rb_cDMR = define_class_under<tomoto::IDMRModel, tomoto::ILDAModel>(rb_mTomoto, "DMR")
|
369
372
|
.define_singleton_method(
|
370
373
|
"_new",
|
371
|
-
*[](size_t tw, size_t k,
|
374
|
+
*[](size_t tw, size_t k, tomoto::Float alpha, tomoto::Float sigma, tomoto::Float eta, tomoto::Float alpha_epsilon, int seed) {
|
372
375
|
if (seed < 0) {
|
373
376
|
seed = std::random_device{}();
|
374
377
|
}
|
@@ -386,7 +389,7 @@ void Init_ext()
|
|
386
389
|
})
|
387
390
|
.define_method(
|
388
391
|
"alpha_epsilon=",
|
389
|
-
*[](tomoto::IDMRModel& self,
|
392
|
+
*[](tomoto::IDMRModel& self, tomoto::Float value) {
|
390
393
|
self.setAlphaEps(value);
|
391
394
|
return value;
|
392
395
|
})
|
@@ -420,7 +423,7 @@ void Init_ext()
|
|
420
423
|
Class rb_cDT = define_class_under<tomoto::IDTModel, tomoto::ILDAModel>(rb_mTomoto, "DT")
|
421
424
|
.define_singleton_method(
|
422
425
|
"_new",
|
423
|
-
*[](size_t tw, size_t k, size_t t,
|
426
|
+
*[](size_t tw, size_t k, size_t t, tomoto::Float alphaVar, tomoto::Float etaVar, tomoto::Float phiVar, tomoto::Float shapeA, tomoto::Float shapeB, tomoto::Float shapeC) {
|
424
427
|
// Rice only supports 10 arguments
|
425
428
|
int seed = -1;
|
426
429
|
if (seed < 0) {
|
@@ -440,7 +443,7 @@ void Init_ext()
|
|
440
443
|
})
|
441
444
|
.define_method(
|
442
445
|
"lr_a=",
|
443
|
-
*[](tomoto::IDTModel& self,
|
446
|
+
*[](tomoto::IDTModel& self, tomoto::Float value) {
|
444
447
|
self.setShapeA(value);
|
445
448
|
return value;
|
446
449
|
})
|
@@ -451,7 +454,7 @@ void Init_ext()
|
|
451
454
|
})
|
452
455
|
.define_method(
|
453
456
|
"lr_b=",
|
454
|
-
*[](tomoto::IDTModel& self,
|
457
|
+
*[](tomoto::IDTModel& self, tomoto::Float value) {
|
455
458
|
self.setShapeB(value);
|
456
459
|
return value;
|
457
460
|
})
|
@@ -462,7 +465,7 @@ void Init_ext()
|
|
462
465
|
})
|
463
466
|
.define_method(
|
464
467
|
"lr_c=",
|
465
|
-
*[](tomoto::IDTModel& self,
|
468
|
+
*[](tomoto::IDTModel& self, tomoto::Float value) {
|
466
469
|
self.setShapeC(value);
|
467
470
|
return value;
|
468
471
|
})
|
@@ -480,7 +483,7 @@ void Init_ext()
|
|
480
483
|
Class rb_cGDMR = define_class_under<tomoto::IGDMRModel, tomoto::IDMRModel>(rb_mTomoto, "GDMR")
|
481
484
|
.define_singleton_method(
|
482
485
|
"_new",
|
483
|
-
*[](size_t tw, size_t k, std::vector<uint64_t> degrees,
|
486
|
+
*[](size_t tw, size_t k, std::vector<uint64_t> degrees, tomoto::Float alpha, tomoto::Float sigma, tomoto::Float sigma0, tomoto::Float eta, tomoto::Float alpha_epsilon, int seed) {
|
484
487
|
if (seed < 0) {
|
485
488
|
seed = std::random_device{}();
|
486
489
|
}
|
@@ -500,12 +503,17 @@ void Init_ext()
|
|
500
503
|
Class rb_cHDP = define_class_under<tomoto::IHDPModel, tomoto::ILDAModel>(rb_mTomoto, "HDP")
|
501
504
|
.define_singleton_method(
|
502
505
|
"_new",
|
503
|
-
*[](size_t tw, size_t k,
|
506
|
+
*[](size_t tw, size_t k, tomoto::Float alpha, tomoto::Float eta, tomoto::Float gamma, int seed) {
|
504
507
|
if (seed < 0) {
|
505
508
|
seed = std::random_device{}();
|
506
509
|
}
|
507
510
|
return tomoto::IHDPModel::create((tomoto::TermWeight)tw, k, alpha, eta, gamma, seed);
|
508
511
|
})
|
512
|
+
.define_method(
|
513
|
+
"alpha",
|
514
|
+
*[](tomoto::IHDPModel& self) {
|
515
|
+
return self.getAlpha();
|
516
|
+
})
|
509
517
|
.define_method(
|
510
518
|
"gamma",
|
511
519
|
*[](tomoto::IHDPModel& self) {
|
@@ -530,12 +538,21 @@ void Init_ext()
|
|
530
538
|
Class rb_cHLDA = define_class_under<tomoto::IHLDAModel, tomoto::ILDAModel>(rb_mTomoto, "HLDA")
|
531
539
|
.define_singleton_method(
|
532
540
|
"_new",
|
533
|
-
*[](size_t tw, size_t levelDepth,
|
541
|
+
*[](size_t tw, size_t levelDepth, tomoto::Float alpha, tomoto::Float eta, tomoto::Float gamma, int seed) {
|
534
542
|
if (seed < 0) {
|
535
543
|
seed = std::random_device{}();
|
536
544
|
}
|
537
545
|
return tomoto::IHLDAModel::create((tomoto::TermWeight)tw, levelDepth, alpha, eta, gamma, seed);
|
538
546
|
})
|
547
|
+
.define_method(
|
548
|
+
"alpha",
|
549
|
+
*[](tomoto::IHLDAModel& self) {
|
550
|
+
Array res;
|
551
|
+
for (size_t i = 0; i < self.getLevelDepth(); i++) {
|
552
|
+
res.push(self.getAlpha(i));
|
553
|
+
}
|
554
|
+
return res;
|
555
|
+
})
|
539
556
|
.define_method(
|
540
557
|
"_children_topics",
|
541
558
|
*[](tomoto::IHLDAModel& self, tomoto::Tid topic_id) {
|
@@ -580,7 +597,7 @@ void Init_ext()
|
|
580
597
|
Class rb_cPA = define_class_under<tomoto::IPAModel, tomoto::ILDAModel>(rb_mTomoto, "PA")
|
581
598
|
.define_singleton_method(
|
582
599
|
"_new",
|
583
|
-
*[](size_t tw, size_t k1, size_t k2,
|
600
|
+
*[](size_t tw, size_t k1, size_t k2, tomoto::Float alpha, tomoto::Float eta, int seed) {
|
584
601
|
if (seed < 0) {
|
585
602
|
seed = std::random_device{}();
|
586
603
|
}
|
@@ -600,17 +617,27 @@ void Init_ext()
|
|
600
617
|
Class rb_cHPA = define_class_under<tomoto::IHPAModel, tomoto::IPAModel>(rb_mTomoto, "HPA")
|
601
618
|
.define_singleton_method(
|
602
619
|
"_new",
|
603
|
-
*[](size_t tw, size_t k1, size_t k2,
|
620
|
+
*[](size_t tw, size_t k1, size_t k2, tomoto::Float alpha, tomoto::Float eta, int seed) {
|
604
621
|
if (seed < 0) {
|
605
622
|
seed = std::random_device{}();
|
606
623
|
}
|
607
624
|
return tomoto::IHPAModel::create((tomoto::TermWeight)tw, false, k1, k2, alpha, eta, seed);
|
625
|
+
})
|
626
|
+
.define_method(
|
627
|
+
"alpha",
|
628
|
+
*[](tomoto::IHPAModel& self) {
|
629
|
+
Array res;
|
630
|
+
// use <= to return k+1 elements
|
631
|
+
for (size_t i = 0; i <= self.getK(); i++) {
|
632
|
+
res.push(self.getAlpha(i));
|
633
|
+
}
|
634
|
+
return res;
|
608
635
|
});
|
609
636
|
|
610
637
|
Class rb_cMGLDA = define_class_under<tomoto::IMGLDAModel, tomoto::ILDAModel>(rb_mTomoto, "MGLDA")
|
611
638
|
.define_singleton_method(
|
612
639
|
"_new",
|
613
|
-
*[](size_t tw, size_t k_g, size_t k_l, size_t t,
|
640
|
+
*[](size_t tw, size_t k_g, size_t k_l, size_t t, tomoto::Float alpha_g, tomoto::Float alpha_l, tomoto::Float alpha_mg, tomoto::Float alpha_ml, tomoto::Float eta_g) {
|
614
641
|
return tomoto::IMGLDAModel::create((tomoto::TermWeight)tw, k_g, k_l, t, alpha_g, alpha_l, alpha_mg, alpha_ml, eta_g);
|
615
642
|
})
|
616
643
|
.define_method(
|
@@ -672,7 +699,7 @@ void Init_ext()
|
|
672
699
|
Class rb_cLLDA = define_class_under<tomoto::ILLDAModel, tomoto::ILDAModel>(rb_mTomoto, "LLDA")
|
673
700
|
.define_singleton_method(
|
674
701
|
"_new",
|
675
|
-
*[](size_t tw, size_t k,
|
702
|
+
*[](size_t tw, size_t k, tomoto::Float alpha, tomoto::Float eta, int seed) {
|
676
703
|
if (seed < 0) {
|
677
704
|
seed = std::random_device{}();
|
678
705
|
}
|
@@ -692,7 +719,7 @@ void Init_ext()
|
|
692
719
|
Class rb_cPLDA = define_class_under<tomoto::IPLDAModel, tomoto::ILLDAModel>(rb_mTomoto, "PLDA")
|
693
720
|
.define_singleton_method(
|
694
721
|
"_new",
|
695
|
-
*[](size_t tw, size_t latent_topics,
|
722
|
+
*[](size_t tw, size_t latent_topics, tomoto::Float alpha, tomoto::Float eta, int seed) {
|
696
723
|
if (seed < 0) {
|
697
724
|
seed = std::random_device{}();
|
698
725
|
}
|
@@ -712,7 +739,7 @@ void Init_ext()
|
|
712
739
|
Class rb_cSLDA = define_class_under<tomoto::ISLDAModel, tomoto::ILDAModel>(rb_mTomoto, "SLDA")
|
713
740
|
.define_singleton_method(
|
714
741
|
"_new",
|
715
|
-
*[](size_t tw, size_t k, Array rb_vars,
|
742
|
+
*[](size_t tw, size_t k, Array rb_vars, tomoto::Float alpha, tomoto::Float eta, std::vector<tomoto::Float> mu, std::vector<tomoto::Float> nu_sq, std::vector<tomoto::Float> glm_param, int seed) {
|
716
743
|
if (seed < 0) {
|
717
744
|
seed = std::random_device{}();
|
718
745
|
}
|
@@ -725,7 +752,7 @@ void Init_ext()
|
|
725
752
|
})
|
726
753
|
.define_method(
|
727
754
|
"_add_doc",
|
728
|
-
*[](tomoto::ISLDAModel& self, std::vector<std::string> words, std::vector<
|
755
|
+
*[](tomoto::ISLDAModel& self, std::vector<std::string> words, std::vector<tomoto::Float> y) {
|
729
756
|
self.addDoc(words, y);
|
730
757
|
})
|
731
758
|
.define_method(
|
data/lib/tomoto.rb
CHANGED
data/lib/tomoto/ct.rb
CHANGED
data/lib/tomoto/dmr.rb
CHANGED
@@ -5,7 +5,7 @@ module Tomoto
|
|
5
5
|
model.instance_variable_set(:@min_cf, min_cf)
|
6
6
|
model.instance_variable_set(:@min_df, min_df)
|
7
7
|
model.instance_variable_set(:@rm_top, rm_top)
|
8
|
-
model
|
8
|
+
init_params(model, binding)
|
9
9
|
end
|
10
10
|
|
11
11
|
def add_doc(doc, metadata: "")
|
@@ -19,5 +19,9 @@ module Tomoto
|
|
19
19
|
k.times.map { |i| _lambdas(i) }
|
20
20
|
end
|
21
21
|
end
|
22
|
+
|
23
|
+
def alpha
|
24
|
+
lambdas.map { |v| v.map { |v2| Math.exp(v2) } }
|
25
|
+
end
|
22
26
|
end
|
23
27
|
end
|
data/lib/tomoto/dt.rb
CHANGED
data/lib/tomoto/gdmr.rb
CHANGED
data/lib/tomoto/hdp.rb
CHANGED
data/lib/tomoto/hlda.rb
CHANGED
@@ -5,7 +5,7 @@ module Tomoto
|
|
5
5
|
model.instance_variable_set(:@min_cf, min_cf)
|
6
6
|
model.instance_variable_set(:@min_df, min_df)
|
7
7
|
model.instance_variable_set(:@rm_top, rm_top)
|
8
|
-
model
|
8
|
+
init_params(model, binding)
|
9
9
|
end
|
10
10
|
|
11
11
|
def children_topics(topic_id)
|
@@ -39,5 +39,18 @@ module Tomoto
|
|
39
39
|
raise "topic_id must be < K" if topic_id >= k
|
40
40
|
raise "train() should be called first" unless @prepared
|
41
41
|
end
|
42
|
+
|
43
|
+
def topics_info(summary, topic_word_top_n:)
|
44
|
+
counts = count_by_topics
|
45
|
+
|
46
|
+
nested_info = lambda do |k = 0, level = 0|
|
47
|
+
words = topic_words(k, top_n: topic_word_top_n).keys.join(" ")
|
48
|
+
summary << "| #{" " * level}##{k} (#{counts[k]}) : #{words}"
|
49
|
+
children_topics(k).sort.each do |c|
|
50
|
+
nested_info.call(c, level + 1)
|
51
|
+
end
|
52
|
+
end
|
53
|
+
nested_info.call
|
54
|
+
end
|
42
55
|
end
|
43
56
|
end
|
data/lib/tomoto/hpa.rb
CHANGED
data/lib/tomoto/lda.rb
CHANGED
@@ -5,7 +5,7 @@ module Tomoto
|
|
5
5
|
model.instance_variable_set(:@min_cf, min_cf)
|
6
6
|
model.instance_variable_set(:@min_df, min_df)
|
7
7
|
model.instance_variable_set(:@rm_top, rm_top)
|
8
|
-
model
|
8
|
+
init_params(model, binding)
|
9
9
|
end
|
10
10
|
|
11
11
|
def self.load(filename)
|
@@ -32,6 +32,42 @@ module Tomoto
|
|
32
32
|
_save(filename, full)
|
33
33
|
end
|
34
34
|
|
35
|
+
# returns string instead of printing
|
36
|
+
def summary(initial_hp: true, params: true, topic_word_top_n: 5)
|
37
|
+
summary = []
|
38
|
+
|
39
|
+
summary << "<Basic Info>"
|
40
|
+
basic_info(summary)
|
41
|
+
summary << "|"
|
42
|
+
|
43
|
+
summary << "<Training Info>"
|
44
|
+
training_info(summary)
|
45
|
+
summary << "|"
|
46
|
+
|
47
|
+
if initial_hp
|
48
|
+
summary << "<Initial Parameters>"
|
49
|
+
initial_params_info(summary)
|
50
|
+
summary << "|"
|
51
|
+
end
|
52
|
+
|
53
|
+
if params
|
54
|
+
summary << "<Parameters>"
|
55
|
+
params_info(summary)
|
56
|
+
summary << "|"
|
57
|
+
end
|
58
|
+
|
59
|
+
if topic_word_top_n > 0
|
60
|
+
summary << "<Topics>"
|
61
|
+
topics_info(summary, topic_word_top_n: topic_word_top_n)
|
62
|
+
summary << "|"
|
63
|
+
end
|
64
|
+
|
65
|
+
# skip ending |
|
66
|
+
summary.pop
|
67
|
+
|
68
|
+
summary.join("\n")
|
69
|
+
end
|
70
|
+
|
35
71
|
def topic_words(topic_id = nil, top_n: 10)
|
36
72
|
if topic_id
|
37
73
|
_topic_words(topic_id, top_n)
|
@@ -40,9 +76,9 @@ module Tomoto
|
|
40
76
|
end
|
41
77
|
end
|
42
78
|
|
43
|
-
def train(iterations = 10, workers: 0)
|
79
|
+
def train(iterations = 10, workers: 0, parallel: :default)
|
44
80
|
prepare
|
45
|
-
_train(iterations, workers)
|
81
|
+
_train(iterations, workers, to_ps(parallel))
|
46
82
|
end
|
47
83
|
|
48
84
|
def tw
|
@@ -64,12 +100,68 @@ module Tomoto
|
|
64
100
|
doc
|
65
101
|
end
|
66
102
|
|
103
|
+
def basic_info(summary)
|
104
|
+
sum = used_vocab_freq.sum.to_f
|
105
|
+
mapped = used_vocab_freq.map { |v| v / sum }
|
106
|
+
entropy = mapped.map { |v| v * Math.log(v) }.sum
|
107
|
+
|
108
|
+
summary << "| #{self.class.name.sub("Tomoto::", "")} (current version: #{VERSION})"
|
109
|
+
summary << "| #{num_docs} docs, #{num_words} words"
|
110
|
+
summary << "| Total Vocabs: #{vocabs.size}, Used Vocabs: #{used_vocabs.size}"
|
111
|
+
summary << "| Entropy of words: %.5f" % entropy
|
112
|
+
summary << "| Removed Vocabs: #{removed_top_words.any? ? removed_top_words.join(" ") : "<NA>"}"
|
113
|
+
end
|
114
|
+
|
115
|
+
def training_info(summary)
|
116
|
+
summary << "| Iterations: #{global_step}, Burn-in steps: #{burn_in}"
|
117
|
+
summary << "| Optimization Interval: #{optim_interval}"
|
118
|
+
summary << "| Log-likelihood per word: %.5f" % ll_per_word
|
119
|
+
end
|
120
|
+
|
121
|
+
def initial_params_info(summary)
|
122
|
+
if defined?(@init_params)
|
123
|
+
@init_params.each do |k, v|
|
124
|
+
summary << "| #{k}: #{v}"
|
125
|
+
end
|
126
|
+
else
|
127
|
+
summary << "| Not Available"
|
128
|
+
end
|
129
|
+
end
|
130
|
+
|
131
|
+
def params_info(summary)
|
132
|
+
summary << "| alpha (Dirichlet prior on the per-document topic distributions)"
|
133
|
+
summary << "| #{alpha}"
|
134
|
+
summary << "| eta (Dirichlet prior on the per-topic word distribution)"
|
135
|
+
summary << "| %.5f" % eta
|
136
|
+
end
|
137
|
+
|
138
|
+
def topics_info(summary, topic_word_top_n:)
|
139
|
+
counts = count_by_topics
|
140
|
+
topic_words(top_n: topic_word_top_n).each_with_index do |words, i|
|
141
|
+
summary << "| ##{i} (#{counts[i]}) : #{words.keys.join(" ")}"
|
142
|
+
end
|
143
|
+
end
|
144
|
+
|
145
|
+
def to_ps(ps)
|
146
|
+
PARALLEL_SCHEME.index(ps) || (raise ArgumentError, "Invalid parallel scheme: #{ps}")
|
147
|
+
end
|
148
|
+
|
67
149
|
class << self
|
68
150
|
private
|
69
151
|
|
70
152
|
def to_tw(tw)
|
71
153
|
TERM_WEIGHT.index(tw) || (raise ArgumentError, "Invalid tw: #{tw}")
|
72
154
|
end
|
155
|
+
|
156
|
+
def init_params(model, binding)
|
157
|
+
init_params = {}
|
158
|
+
method(:new).parameters.each do |v|
|
159
|
+
next if v[0] != :key
|
160
|
+
init_params[v[1]] = binding.local_variable_get(v[1]).inspect
|
161
|
+
end
|
162
|
+
model.instance_variable_set(:@init_params, init_params)
|
163
|
+
model
|
164
|
+
end
|
73
165
|
end
|
74
166
|
end
|
75
167
|
end
|
data/lib/tomoto/llda.rb
CHANGED
data/lib/tomoto/mglda.rb
CHANGED
data/lib/tomoto/pa.rb
CHANGED
data/lib/tomoto/plda.rb
CHANGED
data/lib/tomoto/slda.rb
CHANGED
data/lib/tomoto/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: tomoto
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.
|
4
|
+
version: 0.1.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-10-
|
11
|
+
date: 2020-10-11 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|