tomoto 0.1.4 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/ext/tomoto/ct.cpp +8 -4
- data/ext/tomoto/dmr.cpp +10 -4
- data/ext/tomoto/dt.cpp +13 -4
- data/ext/tomoto/extconf.rb +1 -1
- data/ext/tomoto/gdmr.cpp +14 -6
- data/ext/tomoto/hdp.cpp +9 -4
- data/ext/tomoto/hlda.cpp +9 -4
- data/ext/tomoto/hpa.cpp +9 -4
- data/ext/tomoto/lda.cpp +8 -4
- data/ext/tomoto/llda.cpp +8 -4
- data/ext/tomoto/mglda.cpp +11 -1
- data/ext/tomoto/pa.cpp +9 -4
- data/ext/tomoto/plda.cpp +8 -4
- data/ext/tomoto/slda.cpp +13 -5
- data/lib/tomoto/gdmr.rb +2 -2
- data/lib/tomoto/version.rb +1 -1
- data/vendor/EigenRand/EigenRand/Core.h +6 -1107
- data/vendor/EigenRand/EigenRand/Dists/Basic.h +490 -43
- data/vendor/EigenRand/EigenRand/Dists/Discrete.h +916 -285
- data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +85 -36
- data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +1038 -290
- data/vendor/EigenRand/EigenRand/EigenRand +2 -2
- data/vendor/EigenRand/EigenRand/Macro.h +4 -4
- data/vendor/EigenRand/EigenRand/MorePacketMath.h +54 -22
- data/vendor/EigenRand/EigenRand/MvDists/Multinomial.h +222 -0
- data/vendor/EigenRand/EigenRand/MvDists/MvNormal.h +492 -0
- data/vendor/EigenRand/EigenRand/PacketFilter.h +2 -2
- data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +2 -2
- data/vendor/EigenRand/EigenRand/RandUtils.h +65 -11
- data/vendor/EigenRand/EigenRand/doc.h +142 -25
- data/vendor/EigenRand/LICENSE +1 -1
- data/vendor/EigenRand/README.md +109 -24
- data/vendor/tomotopy/README.kr.rst +27 -6
- data/vendor/tomotopy/README.rst +29 -8
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +60 -12
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +2 -2
- data/vendor/tomotopy/src/Labeling/Phraser.hpp +33 -21
- data/vendor/tomotopy/src/TopicModel/CT.h +8 -5
- data/vendor/tomotopy/src/TopicModel/CTModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +29 -23
- data/vendor/tomotopy/src/TopicModel/DMR.h +33 -4
- data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +231 -57
- data/vendor/tomotopy/src/TopicModel/DT.h +24 -5
- data/vendor/tomotopy/src/TopicModel/DTModel.cpp +2 -8
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +41 -28
- data/vendor/tomotopy/src/TopicModel/GDMR.h +31 -5
- data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +2 -7
- data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +211 -104
- data/vendor/tomotopy/src/TopicModel/HDP.h +11 -2
- data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +52 -45
- data/vendor/tomotopy/src/TopicModel/HLDA.h +11 -2
- data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +13 -16
- data/vendor/tomotopy/src/TopicModel/HPA.h +5 -2
- data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +51 -21
- data/vendor/tomotopy/src/TopicModel/LDA.h +9 -2
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +8 -8
- data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +70 -28
- data/vendor/tomotopy/src/TopicModel/LLDA.h +1 -2
- data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +22 -12
- data/vendor/tomotopy/src/TopicModel/MGLDA.h +12 -3
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +2 -10
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +42 -19
- data/vendor/tomotopy/src/TopicModel/PA.h +9 -4
- data/vendor/tomotopy/src/TopicModel/PAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +48 -25
- data/vendor/tomotopy/src/TopicModel/PLDA.h +13 -2
- data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +27 -19
- data/vendor/tomotopy/src/TopicModel/PT.h +12 -5
- data/vendor/tomotopy/src/TopicModel/PTModel.cpp +2 -3
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +29 -14
- data/vendor/tomotopy/src/TopicModel/SLDA.h +18 -6
- data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +2 -10
- data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +93 -43
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +58 -23
- data/vendor/tomotopy/src/Utils/AliasMethod.hpp +6 -6
- data/vendor/tomotopy/src/Utils/Dictionary.h +11 -0
- data/vendor/tomotopy/src/Utils/SharedString.hpp +26 -1
- data/vendor/tomotopy/src/Utils/Trie.hpp +46 -21
- data/vendor/tomotopy/src/Utils/Utils.hpp +99 -14
- data/vendor/tomotopy/src/Utils/exception.h +1 -1
- data/vendor/tomotopy/src/Utils/math.h +5 -7
- data/vendor/tomotopy/src/Utils/serializer.hpp +329 -201
- data/vendor/tomotopy/src/Utils/text.hpp +8 -0
- data/vendor/tomotopy/src/Utils/tvector.hpp +49 -7
- metadata +9 -7
|
@@ -2,8 +2,8 @@
|
|
|
2
2
|
* @file Macro.h
|
|
3
3
|
* @author bab2min (bab2min@gmail.com)
|
|
4
4
|
* @brief
|
|
5
|
-
* @version 0.
|
|
6
|
-
* @date 2020-07
|
|
5
|
+
* @version 0.3.0
|
|
6
|
+
* @date 2020-10-07
|
|
7
7
|
*
|
|
8
8
|
* @copyright Copyright (c) 2020
|
|
9
9
|
*
|
|
@@ -13,8 +13,8 @@
|
|
|
13
13
|
#define EIGENRAND_MACRO_H
|
|
14
14
|
|
|
15
15
|
#define EIGENRAND_WORLD_VERSION 0
|
|
16
|
-
#define EIGENRAND_MAJOR_VERSION
|
|
17
|
-
#define EIGENRAND_MINOR_VERSION
|
|
16
|
+
#define EIGENRAND_MAJOR_VERSION 3
|
|
17
|
+
#define EIGENRAND_MINOR_VERSION 2
|
|
18
18
|
|
|
19
19
|
#if EIGEN_VERSION_AT_LEAST(3,3,7)
|
|
20
20
|
#else
|
|
@@ -2,8 +2,8 @@
|
|
|
2
2
|
* @file MorePacketMath.h
|
|
3
3
|
* @author bab2min (bab2min@gmail.com)
|
|
4
4
|
* @brief
|
|
5
|
-
* @version 0.
|
|
6
|
-
* @date 2020-
|
|
5
|
+
* @version 0.3.0
|
|
6
|
+
* @date 2020-10-07
|
|
7
7
|
*
|
|
8
8
|
* @copyright Copyright (c) 2020
|
|
9
9
|
*
|
|
@@ -195,8 +195,8 @@ namespace Eigen
|
|
|
195
195
|
template<typename Packet>
|
|
196
196
|
EIGEN_STRONG_INLINE Packet pcmple(const Packet& a, const Packet& b);
|
|
197
197
|
|
|
198
|
-
template<typename Packet>
|
|
199
|
-
EIGEN_STRONG_INLINE Packet pblendv(const
|
|
198
|
+
template<typename PacketIf, typename Packet>
|
|
199
|
+
EIGEN_STRONG_INLINE Packet pblendv(const PacketIf& ifPacket, const Packet& thenPacket, const Packet& elsePacket);
|
|
200
200
|
|
|
201
201
|
template<typename Packet>
|
|
202
202
|
EIGEN_STRONG_INLINE Packet pgather(const int* addr, const Packet& index);
|
|
@@ -240,7 +240,7 @@ namespace Eigen
|
|
|
240
240
|
return psub(reinterpret_to_double(por(pand(x, lower), upper)), one);
|
|
241
241
|
}
|
|
242
242
|
|
|
243
|
-
template<typename
|
|
243
|
+
template<typename _Scalar>
|
|
244
244
|
struct bit_scalar;
|
|
245
245
|
|
|
246
246
|
template<>
|
|
@@ -412,7 +412,7 @@ namespace Eigen
|
|
|
412
412
|
Packet4i a1, a2, b1, b2;
|
|
413
413
|
split_two(a, a1, a2);
|
|
414
414
|
split_two(b, b1, b2);
|
|
415
|
-
return combine_two(_mm_cmpeq_epi32(a1, b1), _mm_cmpeq_epi32(a2, b2));
|
|
415
|
+
return combine_two((Packet4i)_mm_cmpeq_epi32(a1, b1), (Packet4i)_mm_cmpeq_epi32(a2, b2));
|
|
416
416
|
#endif
|
|
417
417
|
}
|
|
418
418
|
|
|
@@ -424,7 +424,7 @@ namespace Eigen
|
|
|
424
424
|
#else
|
|
425
425
|
Packet4i a1, a2;
|
|
426
426
|
split_two(a, a1, a2);
|
|
427
|
-
return combine_two(_mm_slli_epi32(a1, b), _mm_slli_epi32(a2, b));
|
|
427
|
+
return combine_two((Packet4i)_mm_slli_epi32(a1, b), (Packet4i)_mm_slli_epi32(a2, b));
|
|
428
428
|
#endif
|
|
429
429
|
}
|
|
430
430
|
|
|
@@ -436,7 +436,7 @@ namespace Eigen
|
|
|
436
436
|
#else
|
|
437
437
|
Packet4i a1, a2;
|
|
438
438
|
split_two(a, a1, a2);
|
|
439
|
-
return combine_two(_mm_srli_epi32(a1, b), _mm_srli_epi32(a2, b));
|
|
439
|
+
return combine_two((Packet4i)_mm_srli_epi32(a1, b), (Packet4i)_mm_srli_epi32(a2, b));
|
|
440
440
|
#endif
|
|
441
441
|
}
|
|
442
442
|
|
|
@@ -448,7 +448,7 @@ namespace Eigen
|
|
|
448
448
|
#else
|
|
449
449
|
Packet4i a1, a2;
|
|
450
450
|
split_two(a, a1, a2);
|
|
451
|
-
return combine_two(_mm_slli_epi64(a1, b), _mm_slli_epi64(a2, b));
|
|
451
|
+
return combine_two((Packet4i)_mm_slli_epi64(a1, b), (Packet4i)_mm_slli_epi64(a2, b));
|
|
452
452
|
#endif
|
|
453
453
|
}
|
|
454
454
|
|
|
@@ -460,7 +460,7 @@ namespace Eigen
|
|
|
460
460
|
#else
|
|
461
461
|
Packet4i a1, a2;
|
|
462
462
|
split_two(a, a1, a2);
|
|
463
|
-
return combine_two(_mm_srli_epi64(a1, b), _mm_srli_epi64(a2, b));
|
|
463
|
+
return combine_two((Packet4i)_mm_srli_epi64(a1, b), (Packet4i)_mm_srli_epi64(a2, b));
|
|
464
464
|
#endif
|
|
465
465
|
}
|
|
466
466
|
|
|
@@ -472,7 +472,7 @@ namespace Eigen
|
|
|
472
472
|
Packet4i a1, a2, b1, b2;
|
|
473
473
|
split_two(a, a1, a2);
|
|
474
474
|
split_two(b, b1, b2);
|
|
475
|
-
return combine_two(_mm_add_epi32(a1, b1), _mm_add_epi32(a2, b2));
|
|
475
|
+
return combine_two((Packet4i)_mm_add_epi32(a1, b1), (Packet4i)_mm_add_epi32(a2, b2));
|
|
476
476
|
#endif
|
|
477
477
|
}
|
|
478
478
|
|
|
@@ -484,7 +484,7 @@ namespace Eigen
|
|
|
484
484
|
Packet4i a1, a2, b1, b2;
|
|
485
485
|
split_two(a, a1, a2);
|
|
486
486
|
split_two(b, b1, b2);
|
|
487
|
-
return combine_two(_mm_sub_epi32(a1, b1), _mm_sub_epi32(a2, b2));
|
|
487
|
+
return combine_two((Packet4i)_mm_sub_epi32(a1, b1), (Packet4i)_mm_sub_epi32(a2, b2));
|
|
488
488
|
#endif
|
|
489
489
|
}
|
|
490
490
|
|
|
@@ -493,7 +493,7 @@ namespace Eigen
|
|
|
493
493
|
#ifdef EIGEN_VECTORIZE_AVX2
|
|
494
494
|
return _mm256_and_si256(a, b);
|
|
495
495
|
#else
|
|
496
|
-
return reinterpret_to_int(_mm256_and_ps(reinterpret_to_float(a), reinterpret_to_float(b)));
|
|
496
|
+
return reinterpret_to_int((Packet8f)_mm256_and_ps(reinterpret_to_float(a), reinterpret_to_float(b)));
|
|
497
497
|
#endif
|
|
498
498
|
}
|
|
499
499
|
|
|
@@ -502,7 +502,7 @@ namespace Eigen
|
|
|
502
502
|
#ifdef EIGEN_VECTORIZE_AVX2
|
|
503
503
|
return _mm256_andnot_si256(a, b);
|
|
504
504
|
#else
|
|
505
|
-
return reinterpret_to_int(_mm256_andnot_ps(reinterpret_to_float(a), reinterpret_to_float(b)));
|
|
505
|
+
return reinterpret_to_int((Packet8f)_mm256_andnot_ps(reinterpret_to_float(a), reinterpret_to_float(b)));
|
|
506
506
|
#endif
|
|
507
507
|
}
|
|
508
508
|
|
|
@@ -511,7 +511,7 @@ namespace Eigen
|
|
|
511
511
|
#ifdef EIGEN_VECTORIZE_AVX2
|
|
512
512
|
return _mm256_or_si256(a, b);
|
|
513
513
|
#else
|
|
514
|
-
return reinterpret_to_int(_mm256_or_ps(reinterpret_to_float(a), reinterpret_to_float(b)));
|
|
514
|
+
return reinterpret_to_int((Packet8f)_mm256_or_ps(reinterpret_to_float(a), reinterpret_to_float(b)));
|
|
515
515
|
#endif
|
|
516
516
|
}
|
|
517
517
|
|
|
@@ -520,14 +520,21 @@ namespace Eigen
|
|
|
520
520
|
#ifdef EIGEN_VECTORIZE_AVX2
|
|
521
521
|
return _mm256_xor_si256(a, b);
|
|
522
522
|
#else
|
|
523
|
-
return reinterpret_to_int(_mm256_xor_ps(reinterpret_to_float(a), reinterpret_to_float(b)));
|
|
523
|
+
return reinterpret_to_int((Packet8f)_mm256_xor_ps(reinterpret_to_float(a), reinterpret_to_float(b)));
|
|
524
524
|
#endif
|
|
525
525
|
}
|
|
526
526
|
|
|
527
527
|
template<>
|
|
528
528
|
EIGEN_STRONG_INLINE Packet8i pcmplt<Packet8i>(const Packet8i& a, const Packet8i& b)
|
|
529
529
|
{
|
|
530
|
+
#ifdef EIGEN_VECTORIZE_AVX2
|
|
530
531
|
return _mm256_cmpgt_epi32(b, a);
|
|
532
|
+
#else
|
|
533
|
+
Packet4i a1, a2, b1, b2;
|
|
534
|
+
split_two(a, a1, a2);
|
|
535
|
+
split_two(b, b1, b2);
|
|
536
|
+
return combine_two((Packet4i)_mm_cmpgt_epi32(b1, a1), (Packet4i)_mm_cmpgt_epi32(b2, a2));
|
|
537
|
+
#endif
|
|
531
538
|
}
|
|
532
539
|
|
|
533
540
|
template<>
|
|
@@ -560,6 +567,12 @@ namespace Eigen
|
|
|
560
567
|
return _mm256_blendv_ps(elsePacket, thenPacket, ifPacket);
|
|
561
568
|
}
|
|
562
569
|
|
|
570
|
+
template<>
|
|
571
|
+
EIGEN_STRONG_INLINE Packet8f pblendv(const Packet8i& ifPacket, const Packet8f& thenPacket, const Packet8f& elsePacket)
|
|
572
|
+
{
|
|
573
|
+
return pblendv(_mm256_castsi256_ps(ifPacket), thenPacket, elsePacket);
|
|
574
|
+
}
|
|
575
|
+
|
|
563
576
|
template<>
|
|
564
577
|
EIGEN_STRONG_INLINE Packet8i pblendv(const Packet8i& ifPacket, const Packet8i& thenPacket, const Packet8i& elsePacket)
|
|
565
578
|
{
|
|
@@ -576,6 +589,12 @@ namespace Eigen
|
|
|
576
589
|
return _mm256_blendv_pd(elsePacket, thenPacket, ifPacket);
|
|
577
590
|
}
|
|
578
591
|
|
|
592
|
+
template<>
|
|
593
|
+
EIGEN_STRONG_INLINE Packet4d pblendv(const Packet8i& ifPacket, const Packet4d& thenPacket, const Packet4d& elsePacket)
|
|
594
|
+
{
|
|
595
|
+
return pblendv(_mm256_castsi256_pd(ifPacket), thenPacket, elsePacket);
|
|
596
|
+
}
|
|
597
|
+
|
|
579
598
|
template<>
|
|
580
599
|
EIGEN_STRONG_INLINE Packet8i pgather<Packet8i>(const int* addr, const Packet8i& index)
|
|
581
600
|
{
|
|
@@ -660,7 +679,7 @@ namespace Eigen
|
|
|
660
679
|
Packet4i a1, a2, b1, b2;
|
|
661
680
|
split_two(a, a1, a2);
|
|
662
681
|
split_two(b, b1, b2);
|
|
663
|
-
return combine_two(_mm_cmpeq_epi64(a1, b1), _mm_cmpeq_epi64(a2, b2));
|
|
682
|
+
return combine_two((Packet4i)_mm_cmpeq_epi64(a1, b1), (Packet4i)_mm_cmpeq_epi64(a2, b2));
|
|
664
683
|
#endif
|
|
665
684
|
}
|
|
666
685
|
|
|
@@ -842,6 +861,12 @@ namespace Eigen
|
|
|
842
861
|
#endif
|
|
843
862
|
}
|
|
844
863
|
|
|
864
|
+
template<>
|
|
865
|
+
EIGEN_STRONG_INLINE Packet4f pblendv(const Packet4i& ifPacket, const Packet4f& thenPacket, const Packet4f& elsePacket)
|
|
866
|
+
{
|
|
867
|
+
return pblendv(_mm_castsi128_ps(ifPacket), thenPacket, elsePacket);
|
|
868
|
+
}
|
|
869
|
+
|
|
845
870
|
template<>
|
|
846
871
|
EIGEN_STRONG_INLINE Packet4i pblendv(const Packet4i& ifPacket, const Packet4i& thenPacket, const Packet4i& elsePacket)
|
|
847
872
|
{
|
|
@@ -862,6 +887,13 @@ namespace Eigen
|
|
|
862
887
|
#endif
|
|
863
888
|
}
|
|
864
889
|
|
|
890
|
+
|
|
891
|
+
template<>
|
|
892
|
+
EIGEN_STRONG_INLINE Packet2d pblendv(const Packet4i& ifPacket, const Packet2d& thenPacket, const Packet2d& elsePacket)
|
|
893
|
+
{
|
|
894
|
+
return pblendv(_mm_castsi128_pd(ifPacket), thenPacket, elsePacket);
|
|
895
|
+
}
|
|
896
|
+
|
|
865
897
|
template<>
|
|
866
898
|
EIGEN_STRONG_INLINE Packet4i pgather<Packet4i>(const int* addr, const Packet4i& index)
|
|
867
899
|
{
|
|
@@ -869,7 +901,7 @@ namespace Eigen
|
|
|
869
901
|
return _mm_i32gather_epi32(addr, index, 4);
|
|
870
902
|
#else
|
|
871
903
|
uint32_t u[4];
|
|
872
|
-
_mm_storeu_si128((
|
|
904
|
+
_mm_storeu_si128((__m128i*)u, index);
|
|
873
905
|
return _mm_setr_epi32(addr[u[0]], addr[u[1]], addr[u[2]], addr[u[3]]);
|
|
874
906
|
#endif
|
|
875
907
|
}
|
|
@@ -881,7 +913,7 @@ namespace Eigen
|
|
|
881
913
|
return _mm_i32gather_ps(addr, index, 4);
|
|
882
914
|
#else
|
|
883
915
|
uint32_t u[4];
|
|
884
|
-
_mm_storeu_si128((
|
|
916
|
+
_mm_storeu_si128((__m128i*)u, index);
|
|
885
917
|
return _mm_setr_ps(addr[u[0]], addr[u[1]], addr[u[2]], addr[u[3]]);
|
|
886
918
|
#endif
|
|
887
919
|
}
|
|
@@ -893,7 +925,7 @@ namespace Eigen
|
|
|
893
925
|
return _mm_i32gather_pd(addr, index, 8);
|
|
894
926
|
#else
|
|
895
927
|
uint32_t u[4];
|
|
896
|
-
_mm_storeu_si128((
|
|
928
|
+
_mm_storeu_si128((__m128i*)u, index);
|
|
897
929
|
if (upperhalf)
|
|
898
930
|
{
|
|
899
931
|
return _mm_setr_pd(addr[u[2]], addr[u[3]]);
|
|
@@ -920,7 +952,7 @@ namespace Eigen
|
|
|
920
952
|
template<>
|
|
921
953
|
EIGEN_STRONG_INLINE int pmovemask<Packet4i>(const Packet4i& a)
|
|
922
954
|
{
|
|
923
|
-
return pmovemask(_mm_castsi128_ps(a));
|
|
955
|
+
return pmovemask((Packet4f)_mm_castsi128_ps(a));
|
|
924
956
|
}
|
|
925
957
|
|
|
926
958
|
template<>
|
|
@@ -958,7 +990,7 @@ namespace Eigen
|
|
|
958
990
|
return _mm_cmpeq_epi64(a, b);
|
|
959
991
|
#else
|
|
960
992
|
Packet4i c = _mm_cmpeq_epi32(a, b);
|
|
961
|
-
return pand(c, _mm_shuffle_epi32(c, _MM_SHUFFLE(2, 3, 0, 1)));
|
|
993
|
+
return pand(c, (Packet4i)_mm_shuffle_epi32(c, _MM_SHUFFLE(2, 3, 0, 1)));
|
|
962
994
|
#endif
|
|
963
995
|
}
|
|
964
996
|
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file Multinomial.h
|
|
3
|
+
* @author bab2min (bab2min@gmail.com)
|
|
4
|
+
* @brief
|
|
5
|
+
* @version 0.3.0
|
|
6
|
+
* @date 2020-10-07
|
|
7
|
+
*
|
|
8
|
+
* @copyright Copyright (c) 2020
|
|
9
|
+
*
|
|
10
|
+
*/
|
|
11
|
+
|
|
12
|
+
#ifndef EIGENRAND_MVDISTS_MULTINOMIAL_H
|
|
13
|
+
#define EIGENRAND_MVDISTS_MULTINOMIAL_H
|
|
14
|
+
|
|
15
|
+
namespace Eigen
|
|
16
|
+
{
|
|
17
|
+
namespace Rand
|
|
18
|
+
{
|
|
19
|
+
/**
|
|
20
|
+
* @brief Generator of real vectors on a multinomial distribution
|
|
21
|
+
*
|
|
22
|
+
* @tparam _Scalar
|
|
23
|
+
* @tparam Dim number of dimensions, or `Eigen::Dynamic`
|
|
24
|
+
*/
|
|
25
|
+
template<typename _Scalar = int32_t, Index Dim = -1>
|
|
26
|
+
class MultinomialGen : public MvVecGenBase<MultinomialGen<_Scalar, Dim>, _Scalar, Dim>
|
|
27
|
+
{
|
|
28
|
+
static_assert(std::is_same<_Scalar, int32_t>::value, "`MultinomialGen` needs integral types.");
|
|
29
|
+
_Scalar trials;
|
|
30
|
+
Matrix<double, Dim, 1> probs;
|
|
31
|
+
DiscreteGen<_Scalar> discrete;
|
|
32
|
+
public:
|
|
33
|
+
/**
|
|
34
|
+
* @brief Construct a new multinomial generator
|
|
35
|
+
*
|
|
36
|
+
* @tparam WeightTy
|
|
37
|
+
* @param _trials the number of trials
|
|
38
|
+
* @param _weights the weights of each category, `(Dim, 1)` shape matrix or vector
|
|
39
|
+
*/
|
|
40
|
+
template<typename WeightTy>
|
|
41
|
+
MultinomialGen(_Scalar _trials, const MatrixBase<WeightTy>& _weights)
|
|
42
|
+
: trials{ _trials }, probs{ _weights.template cast<double>() }, discrete(probs.data(), probs.data() + probs.size())
|
|
43
|
+
{
|
|
44
|
+
eigen_assert(_weights.cols() == 1);
|
|
45
|
+
for (Index i = 0; i < probs.size(); ++i)
|
|
46
|
+
{
|
|
47
|
+
eigen_assert(probs[i] >= 0);
|
|
48
|
+
}
|
|
49
|
+
probs /= probs.sum();
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
MultinomialGen(const MultinomialGen&) = default;
|
|
53
|
+
MultinomialGen(MultinomialGen&&) = default;
|
|
54
|
+
|
|
55
|
+
MultinomialGen& operator=(const MultinomialGen&) = default;
|
|
56
|
+
MultinomialGen& operator=(MultinomialGen&&) = default;
|
|
57
|
+
|
|
58
|
+
Index dims() const { return probs.rows(); }
|
|
59
|
+
|
|
60
|
+
template<typename Urng>
|
|
61
|
+
inline Matrix<_Scalar, Dim, -1> generate(Urng&& urng, Index samples)
|
|
62
|
+
{
|
|
63
|
+
const Index dim = probs.size();
|
|
64
|
+
Matrix<_Scalar, Dim, -1> ret(dim, samples);
|
|
65
|
+
//if (trials < 2500)
|
|
66
|
+
{
|
|
67
|
+
for (Index j = 0; j < samples; ++j)
|
|
68
|
+
{
|
|
69
|
+
ret.col(j) = generate(urng);
|
|
70
|
+
}
|
|
71
|
+
}
|
|
72
|
+
/*else
|
|
73
|
+
{
|
|
74
|
+
ret.row(0) = binomial<Matrix<_Scalar, -1, 1>>(samples, 1, urng, trials, probs[0]).eval().transpose();
|
|
75
|
+
for (Index j = 0; j < samples; ++j)
|
|
76
|
+
{
|
|
77
|
+
double rest_p = 1 - probs[0];
|
|
78
|
+
_Scalar t = trials - ret(0, j);
|
|
79
|
+
for (Index i = 1; i < dim - 1; ++i)
|
|
80
|
+
{
|
|
81
|
+
ret(i, j) = binomial<Matrix<_Scalar, 1, 1>>(1, 1, urng, t, probs[i] / rest_p)(0);
|
|
82
|
+
t -= ret(i, j);
|
|
83
|
+
rest_p -= probs[i];
|
|
84
|
+
}
|
|
85
|
+
ret(dim - 1, j) = 0;
|
|
86
|
+
}
|
|
87
|
+
ret.row(dim - 1).setZero();
|
|
88
|
+
ret.row(dim - 1).array() = trials - ret.colwise().sum().array();
|
|
89
|
+
}*/
|
|
90
|
+
return ret;
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
template<typename Urng>
|
|
94
|
+
inline Matrix<_Scalar, Dim, 1> generate(Urng&& urng)
|
|
95
|
+
{
|
|
96
|
+
const Index dim = probs.size();
|
|
97
|
+
Matrix<_Scalar, Dim, 1> ret(dim);
|
|
98
|
+
//if (trials < 2500)
|
|
99
|
+
{
|
|
100
|
+
ret.setZero();
|
|
101
|
+
auto d = discrete.template generate<Matrix<_Scalar, -1, 1>>(trials, 1, urng).eval();
|
|
102
|
+
for (Index i = 0; i < trials; ++i)
|
|
103
|
+
{
|
|
104
|
+
ret[d[i]] += 1;
|
|
105
|
+
}
|
|
106
|
+
}
|
|
107
|
+
/*else
|
|
108
|
+
{
|
|
109
|
+
double rest_p = 1;
|
|
110
|
+
_Scalar t = trials;
|
|
111
|
+
for (Index i = 0; i < dim - 1; ++i)
|
|
112
|
+
{
|
|
113
|
+
ret[i] = binomial<Matrix<_Scalar, 1, 1>>(1, 1, urng, t, probs[i] / rest_p)(0);
|
|
114
|
+
t -= ret[i];
|
|
115
|
+
rest_p -= probs[i];
|
|
116
|
+
}
|
|
117
|
+
ret[dim - 1] = 0;
|
|
118
|
+
ret[dim - 1] = trials - ret.sum();
|
|
119
|
+
}*/
|
|
120
|
+
return ret;
|
|
121
|
+
}
|
|
122
|
+
};
|
|
123
|
+
|
|
124
|
+
/**
|
|
125
|
+
* @brief helper function constructing Eigen::Rand::MultinomialGen
|
|
126
|
+
*
|
|
127
|
+
* @tparam IntTy
|
|
128
|
+
* @tparam WeightTy
|
|
129
|
+
* @param trials the number of trials
|
|
130
|
+
* @param probs The weights of each category with shape `(Dim, 1)` of matrix or vector.
|
|
131
|
+
* The number of entries determines the dimensionality of the distribution
|
|
132
|
+
* @return an instance of MultinomialGen in the appropriate type
|
|
133
|
+
*/
|
|
134
|
+
template<typename IntTy, typename WeightTy>
|
|
135
|
+
inline auto makeMultinomialGen(IntTy trials, const MatrixBase<WeightTy>& probs)
|
|
136
|
+
-> MultinomialGen<IntTy, MatrixBase<WeightTy>::RowsAtCompileTime>
|
|
137
|
+
{
|
|
138
|
+
return MultinomialGen<IntTy, MatrixBase<WeightTy>::RowsAtCompileTime>{ trials, probs };
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
/**
|
|
142
|
+
* @brief Generator of reals on a Dirichlet distribution
|
|
143
|
+
*
|
|
144
|
+
* @tparam _Scalar
|
|
145
|
+
* @tparam Dim number of dimensions, or `Eigen::Dynamic`
|
|
146
|
+
*/
|
|
147
|
+
template<typename _Scalar, Index Dim = -1>
|
|
148
|
+
class DirichletGen : public MvVecGenBase<DirichletGen<_Scalar, Dim>, _Scalar, Dim>
|
|
149
|
+
{
|
|
150
|
+
Matrix<_Scalar, Dim, 1> alpha;
|
|
151
|
+
std::vector<GammaGen<_Scalar>> gammas;
|
|
152
|
+
public:
|
|
153
|
+
/**
|
|
154
|
+
* @brief Construct a new Dirichlet generator
|
|
155
|
+
*
|
|
156
|
+
* @tparam AlphaTy
|
|
157
|
+
* @param _alpha the concentration parameters with shape `(Dim, 1)` matrix or vector
|
|
158
|
+
*/
|
|
159
|
+
template<typename AlphaTy>
|
|
160
|
+
DirichletGen(const MatrixBase<AlphaTy>& _alpha)
|
|
161
|
+
: alpha{ _alpha }
|
|
162
|
+
{
|
|
163
|
+
eigen_assert(_alpha.cols() == 1);
|
|
164
|
+
for (Index i = 0; i < alpha.size(); ++i)
|
|
165
|
+
{
|
|
166
|
+
eigen_assert(alpha[i] > 0);
|
|
167
|
+
gammas.emplace_back(alpha[i]);
|
|
168
|
+
}
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
DirichletGen(const DirichletGen&) = default;
|
|
172
|
+
DirichletGen(DirichletGen&&) = default;
|
|
173
|
+
|
|
174
|
+
Index dims() const { return alpha.rows(); }
|
|
175
|
+
|
|
176
|
+
template<typename Urng>
|
|
177
|
+
inline Matrix<_Scalar, Dim, -1> generate(Urng&& urng, Index samples)
|
|
178
|
+
{
|
|
179
|
+
const Index dim = alpha.size();
|
|
180
|
+
Matrix<_Scalar, Dim, -1> ret(dim, samples);
|
|
181
|
+
Matrix<_Scalar, -1, 1> tmp(samples);
|
|
182
|
+
for (Index i = 0; i < dim; ++i)
|
|
183
|
+
{
|
|
184
|
+
tmp = gammas[i].generateLike(tmp, urng);
|
|
185
|
+
ret.row(i) = tmp.transpose();
|
|
186
|
+
}
|
|
187
|
+
ret.array().rowwise() /= ret.array().colwise().sum();
|
|
188
|
+
return ret;
|
|
189
|
+
}
|
|
190
|
+
|
|
191
|
+
template<typename Urng>
|
|
192
|
+
inline Matrix<_Scalar, Dim, 1> generate(Urng&& urng)
|
|
193
|
+
{
|
|
194
|
+
const Index dim = alpha.size();
|
|
195
|
+
Matrix<_Scalar, Dim, 1> ret(dim);
|
|
196
|
+
for (Index i = 0; i < dim; ++i)
|
|
197
|
+
{
|
|
198
|
+
ret[i] = gammas[i].template generate<Matrix<_Scalar, 1, 1>>(1, 1, urng)(0);
|
|
199
|
+
}
|
|
200
|
+
ret /= ret.sum();
|
|
201
|
+
return ret;
|
|
202
|
+
}
|
|
203
|
+
};
|
|
204
|
+
|
|
205
|
+
/**
|
|
206
|
+
* @brief helper function constructing Eigen::Rand::DirichletGen
|
|
207
|
+
*
|
|
208
|
+
* @tparam AlphaTy
|
|
209
|
+
* @param alpha The concentration parameters with shape `(Dim, 1)` of matrix or vector.
|
|
210
|
+
* The number of entries determines the dimensionality of the distribution.
|
|
211
|
+
* @return an instance of MultinomialGen in the appropriate type
|
|
212
|
+
*/
|
|
213
|
+
template<typename AlphaTy>
|
|
214
|
+
inline auto makeDirichletGen(const MatrixBase<AlphaTy>& alpha)
|
|
215
|
+
-> DirichletGen<typename MatrixBase<AlphaTy>::Scalar, MatrixBase<AlphaTy>::RowsAtCompileTime>
|
|
216
|
+
{
|
|
217
|
+
return DirichletGen<typename MatrixBase<AlphaTy>::Scalar, MatrixBase<AlphaTy>::RowsAtCompileTime>{ alpha };
|
|
218
|
+
}
|
|
219
|
+
}
|
|
220
|
+
}
|
|
221
|
+
|
|
222
|
+
#endif
|