tomoto 0.1.4 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/ext/tomoto/ct.cpp +8 -4
  4. data/ext/tomoto/dmr.cpp +10 -4
  5. data/ext/tomoto/dt.cpp +13 -4
  6. data/ext/tomoto/extconf.rb +1 -1
  7. data/ext/tomoto/gdmr.cpp +14 -6
  8. data/ext/tomoto/hdp.cpp +9 -4
  9. data/ext/tomoto/hlda.cpp +9 -4
  10. data/ext/tomoto/hpa.cpp +9 -4
  11. data/ext/tomoto/lda.cpp +8 -4
  12. data/ext/tomoto/llda.cpp +8 -4
  13. data/ext/tomoto/mglda.cpp +11 -1
  14. data/ext/tomoto/pa.cpp +9 -4
  15. data/ext/tomoto/plda.cpp +8 -4
  16. data/ext/tomoto/slda.cpp +13 -5
  17. data/lib/tomoto/gdmr.rb +2 -2
  18. data/lib/tomoto/version.rb +1 -1
  19. data/vendor/EigenRand/EigenRand/Core.h +6 -1107
  20. data/vendor/EigenRand/EigenRand/Dists/Basic.h +490 -43
  21. data/vendor/EigenRand/EigenRand/Dists/Discrete.h +916 -285
  22. data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +85 -36
  23. data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +1038 -290
  24. data/vendor/EigenRand/EigenRand/EigenRand +2 -2
  25. data/vendor/EigenRand/EigenRand/Macro.h +4 -4
  26. data/vendor/EigenRand/EigenRand/MorePacketMath.h +54 -22
  27. data/vendor/EigenRand/EigenRand/MvDists/Multinomial.h +222 -0
  28. data/vendor/EigenRand/EigenRand/MvDists/MvNormal.h +492 -0
  29. data/vendor/EigenRand/EigenRand/PacketFilter.h +2 -2
  30. data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +2 -2
  31. data/vendor/EigenRand/EigenRand/RandUtils.h +65 -11
  32. data/vendor/EigenRand/EigenRand/doc.h +142 -25
  33. data/vendor/EigenRand/LICENSE +1 -1
  34. data/vendor/EigenRand/README.md +109 -24
  35. data/vendor/tomotopy/README.kr.rst +27 -6
  36. data/vendor/tomotopy/README.rst +29 -8
  37. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +60 -12
  38. data/vendor/tomotopy/src/Labeling/FoRelevance.h +2 -2
  39. data/vendor/tomotopy/src/Labeling/Phraser.hpp +33 -21
  40. data/vendor/tomotopy/src/TopicModel/CT.h +8 -5
  41. data/vendor/tomotopy/src/TopicModel/CTModel.cpp +2 -6
  42. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +29 -23
  43. data/vendor/tomotopy/src/TopicModel/DMR.h +33 -4
  44. data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +2 -6
  45. data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +231 -57
  46. data/vendor/tomotopy/src/TopicModel/DT.h +24 -5
  47. data/vendor/tomotopy/src/TopicModel/DTModel.cpp +2 -8
  48. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +41 -28
  49. data/vendor/tomotopy/src/TopicModel/GDMR.h +31 -5
  50. data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +2 -7
  51. data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +211 -104
  52. data/vendor/tomotopy/src/TopicModel/HDP.h +11 -2
  53. data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +2 -6
  54. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +52 -45
  55. data/vendor/tomotopy/src/TopicModel/HLDA.h +11 -2
  56. data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +2 -6
  57. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +13 -16
  58. data/vendor/tomotopy/src/TopicModel/HPA.h +5 -2
  59. data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +2 -6
  60. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +51 -21
  61. data/vendor/tomotopy/src/TopicModel/LDA.h +9 -2
  62. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +8 -8
  63. data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +2 -6
  64. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +70 -28
  65. data/vendor/tomotopy/src/TopicModel/LLDA.h +1 -2
  66. data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +2 -6
  67. data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +22 -12
  68. data/vendor/tomotopy/src/TopicModel/MGLDA.h +12 -3
  69. data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +2 -10
  70. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +42 -19
  71. data/vendor/tomotopy/src/TopicModel/PA.h +9 -4
  72. data/vendor/tomotopy/src/TopicModel/PAModel.cpp +2 -6
  73. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +48 -25
  74. data/vendor/tomotopy/src/TopicModel/PLDA.h +13 -2
  75. data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +2 -6
  76. data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +27 -19
  77. data/vendor/tomotopy/src/TopicModel/PT.h +12 -5
  78. data/vendor/tomotopy/src/TopicModel/PTModel.cpp +2 -3
  79. data/vendor/tomotopy/src/TopicModel/PTModel.hpp +29 -14
  80. data/vendor/tomotopy/src/TopicModel/SLDA.h +18 -6
  81. data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +2 -10
  82. data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +93 -43
  83. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +58 -23
  84. data/vendor/tomotopy/src/Utils/AliasMethod.hpp +6 -6
  85. data/vendor/tomotopy/src/Utils/Dictionary.h +11 -0
  86. data/vendor/tomotopy/src/Utils/SharedString.hpp +26 -1
  87. data/vendor/tomotopy/src/Utils/Trie.hpp +46 -21
  88. data/vendor/tomotopy/src/Utils/Utils.hpp +99 -14
  89. data/vendor/tomotopy/src/Utils/exception.h +1 -1
  90. data/vendor/tomotopy/src/Utils/math.h +5 -7
  91. data/vendor/tomotopy/src/Utils/serializer.hpp +329 -201
  92. data/vendor/tomotopy/src/Utils/text.hpp +8 -0
  93. data/vendor/tomotopy/src/Utils/tvector.hpp +49 -7
  94. metadata +9 -7
@@ -2,8 +2,8 @@
2
2
  * @file EigenRand
3
3
  * @author bab2min (bab2min@gmail.com)
4
4
  * @brief
5
- * @version 0.2.0
6
- * @date 2020-06-22
5
+ * @version 0.3.0
6
+ * @date 2020-10-07
7
7
  *
8
8
  * @copyright Copyright (c) 2020
9
9
  *
@@ -2,8 +2,8 @@
2
2
  * @file Macro.h
3
3
  * @author bab2min (bab2min@gmail.com)
4
4
  * @brief
5
- * @version 0.2.1
6
- * @date 2020-07-11
5
+ * @version 0.3.0
6
+ * @date 2020-10-07
7
7
  *
8
8
  * @copyright Copyright (c) 2020
9
9
  *
@@ -13,8 +13,8 @@
13
13
  #define EIGENRAND_MACRO_H
14
14
 
15
15
  #define EIGENRAND_WORLD_VERSION 0
16
- #define EIGENRAND_MAJOR_VERSION 2
17
- #define EIGENRAND_MINOR_VERSION 0
16
+ #define EIGENRAND_MAJOR_VERSION 3
17
+ #define EIGENRAND_MINOR_VERSION 2
18
18
 
19
19
  #if EIGEN_VERSION_AT_LEAST(3,3,7)
20
20
  #else
@@ -2,8 +2,8 @@
2
2
  * @file MorePacketMath.h
3
3
  * @author bab2min (bab2min@gmail.com)
4
4
  * @brief
5
- * @version 0.2.0
6
- * @date 2020-06-22
5
+ * @version 0.3.0
6
+ * @date 2020-10-07
7
7
  *
8
8
  * @copyright Copyright (c) 2020
9
9
  *
@@ -195,8 +195,8 @@ namespace Eigen
195
195
  template<typename Packet>
196
196
  EIGEN_STRONG_INLINE Packet pcmple(const Packet& a, const Packet& b);
197
197
 
198
- template<typename Packet>
199
- EIGEN_STRONG_INLINE Packet pblendv(const Packet& ifPacket, const Packet& thenPacket, const Packet& elsePacket);
198
+ template<typename PacketIf, typename Packet>
199
+ EIGEN_STRONG_INLINE Packet pblendv(const PacketIf& ifPacket, const Packet& thenPacket, const Packet& elsePacket);
200
200
 
201
201
  template<typename Packet>
202
202
  EIGEN_STRONG_INLINE Packet pgather(const int* addr, const Packet& index);
@@ -240,7 +240,7 @@ namespace Eigen
240
240
  return psub(reinterpret_to_double(por(pand(x, lower), upper)), one);
241
241
  }
242
242
 
243
- template<typename Scalar>
243
+ template<typename _Scalar>
244
244
  struct bit_scalar;
245
245
 
246
246
  template<>
@@ -412,7 +412,7 @@ namespace Eigen
412
412
  Packet4i a1, a2, b1, b2;
413
413
  split_two(a, a1, a2);
414
414
  split_two(b, b1, b2);
415
- return combine_two(_mm_cmpeq_epi32(a1, b1), _mm_cmpeq_epi32(a2, b2));
415
+ return combine_two((Packet4i)_mm_cmpeq_epi32(a1, b1), (Packet4i)_mm_cmpeq_epi32(a2, b2));
416
416
  #endif
417
417
  }
418
418
 
@@ -424,7 +424,7 @@ namespace Eigen
424
424
  #else
425
425
  Packet4i a1, a2;
426
426
  split_two(a, a1, a2);
427
- return combine_two(_mm_slli_epi32(a1, b), _mm_slli_epi32(a2, b));
427
+ return combine_two((Packet4i)_mm_slli_epi32(a1, b), (Packet4i)_mm_slli_epi32(a2, b));
428
428
  #endif
429
429
  }
430
430
 
@@ -436,7 +436,7 @@ namespace Eigen
436
436
  #else
437
437
  Packet4i a1, a2;
438
438
  split_two(a, a1, a2);
439
- return combine_two(_mm_srli_epi32(a1, b), _mm_srli_epi32(a2, b));
439
+ return combine_two((Packet4i)_mm_srli_epi32(a1, b), (Packet4i)_mm_srli_epi32(a2, b));
440
440
  #endif
441
441
  }
442
442
 
@@ -448,7 +448,7 @@ namespace Eigen
448
448
  #else
449
449
  Packet4i a1, a2;
450
450
  split_two(a, a1, a2);
451
- return combine_two(_mm_slli_epi64(a1, b), _mm_slli_epi64(a2, b));
451
+ return combine_two((Packet4i)_mm_slli_epi64(a1, b), (Packet4i)_mm_slli_epi64(a2, b));
452
452
  #endif
453
453
  }
454
454
 
@@ -460,7 +460,7 @@ namespace Eigen
460
460
  #else
461
461
  Packet4i a1, a2;
462
462
  split_two(a, a1, a2);
463
- return combine_two(_mm_srli_epi64(a1, b), _mm_srli_epi64(a2, b));
463
+ return combine_two((Packet4i)_mm_srli_epi64(a1, b), (Packet4i)_mm_srli_epi64(a2, b));
464
464
  #endif
465
465
  }
466
466
 
@@ -472,7 +472,7 @@ namespace Eigen
472
472
  Packet4i a1, a2, b1, b2;
473
473
  split_two(a, a1, a2);
474
474
  split_two(b, b1, b2);
475
- return combine_two(_mm_add_epi32(a1, b1), _mm_add_epi32(a2, b2));
475
+ return combine_two((Packet4i)_mm_add_epi32(a1, b1), (Packet4i)_mm_add_epi32(a2, b2));
476
476
  #endif
477
477
  }
478
478
 
@@ -484,7 +484,7 @@ namespace Eigen
484
484
  Packet4i a1, a2, b1, b2;
485
485
  split_two(a, a1, a2);
486
486
  split_two(b, b1, b2);
487
- return combine_two(_mm_sub_epi32(a1, b1), _mm_sub_epi32(a2, b2));
487
+ return combine_two((Packet4i)_mm_sub_epi32(a1, b1), (Packet4i)_mm_sub_epi32(a2, b2));
488
488
  #endif
489
489
  }
490
490
 
@@ -493,7 +493,7 @@ namespace Eigen
493
493
  #ifdef EIGEN_VECTORIZE_AVX2
494
494
  return _mm256_and_si256(a, b);
495
495
  #else
496
- return reinterpret_to_int(_mm256_and_ps(reinterpret_to_float(a), reinterpret_to_float(b)));
496
+ return reinterpret_to_int((Packet8f)_mm256_and_ps(reinterpret_to_float(a), reinterpret_to_float(b)));
497
497
  #endif
498
498
  }
499
499
 
@@ -502,7 +502,7 @@ namespace Eigen
502
502
  #ifdef EIGEN_VECTORIZE_AVX2
503
503
  return _mm256_andnot_si256(a, b);
504
504
  #else
505
- return reinterpret_to_int(_mm256_andnot_ps(reinterpret_to_float(a), reinterpret_to_float(b)));
505
+ return reinterpret_to_int((Packet8f)_mm256_andnot_ps(reinterpret_to_float(a), reinterpret_to_float(b)));
506
506
  #endif
507
507
  }
508
508
 
@@ -511,7 +511,7 @@ namespace Eigen
511
511
  #ifdef EIGEN_VECTORIZE_AVX2
512
512
  return _mm256_or_si256(a, b);
513
513
  #else
514
- return reinterpret_to_int(_mm256_or_ps(reinterpret_to_float(a), reinterpret_to_float(b)));
514
+ return reinterpret_to_int((Packet8f)_mm256_or_ps(reinterpret_to_float(a), reinterpret_to_float(b)));
515
515
  #endif
516
516
  }
517
517
 
@@ -520,14 +520,21 @@ namespace Eigen
520
520
  #ifdef EIGEN_VECTORIZE_AVX2
521
521
  return _mm256_xor_si256(a, b);
522
522
  #else
523
- return reinterpret_to_int(_mm256_xor_ps(reinterpret_to_float(a), reinterpret_to_float(b)));
523
+ return reinterpret_to_int((Packet8f)_mm256_xor_ps(reinterpret_to_float(a), reinterpret_to_float(b)));
524
524
  #endif
525
525
  }
526
526
 
527
527
  template<>
528
528
  EIGEN_STRONG_INLINE Packet8i pcmplt<Packet8i>(const Packet8i& a, const Packet8i& b)
529
529
  {
530
+ #ifdef EIGEN_VECTORIZE_AVX2
530
531
  return _mm256_cmpgt_epi32(b, a);
532
+ #else
533
+ Packet4i a1, a2, b1, b2;
534
+ split_two(a, a1, a2);
535
+ split_two(b, b1, b2);
536
+ return combine_two((Packet4i)_mm_cmpgt_epi32(b1, a1), (Packet4i)_mm_cmpgt_epi32(b2, a2));
537
+ #endif
531
538
  }
532
539
 
533
540
  template<>
@@ -560,6 +567,12 @@ namespace Eigen
560
567
  return _mm256_blendv_ps(elsePacket, thenPacket, ifPacket);
561
568
  }
562
569
 
570
+ template<>
571
+ EIGEN_STRONG_INLINE Packet8f pblendv(const Packet8i& ifPacket, const Packet8f& thenPacket, const Packet8f& elsePacket)
572
+ {
573
+ return pblendv(_mm256_castsi256_ps(ifPacket), thenPacket, elsePacket);
574
+ }
575
+
563
576
  template<>
564
577
  EIGEN_STRONG_INLINE Packet8i pblendv(const Packet8i& ifPacket, const Packet8i& thenPacket, const Packet8i& elsePacket)
565
578
  {
@@ -576,6 +589,12 @@ namespace Eigen
576
589
  return _mm256_blendv_pd(elsePacket, thenPacket, ifPacket);
577
590
  }
578
591
 
592
+ template<>
593
+ EIGEN_STRONG_INLINE Packet4d pblendv(const Packet8i& ifPacket, const Packet4d& thenPacket, const Packet4d& elsePacket)
594
+ {
595
+ return pblendv(_mm256_castsi256_pd(ifPacket), thenPacket, elsePacket);
596
+ }
597
+
579
598
  template<>
580
599
  EIGEN_STRONG_INLINE Packet8i pgather<Packet8i>(const int* addr, const Packet8i& index)
581
600
  {
@@ -660,7 +679,7 @@ namespace Eigen
660
679
  Packet4i a1, a2, b1, b2;
661
680
  split_two(a, a1, a2);
662
681
  split_two(b, b1, b2);
663
- return combine_two(_mm_cmpeq_epi64(a1, b1), _mm_cmpeq_epi64(a2, b2));
682
+ return combine_two((Packet4i)_mm_cmpeq_epi64(a1, b1), (Packet4i)_mm_cmpeq_epi64(a2, b2));
664
683
  #endif
665
684
  }
666
685
 
@@ -842,6 +861,12 @@ namespace Eigen
842
861
  #endif
843
862
  }
844
863
 
864
+ template<>
865
+ EIGEN_STRONG_INLINE Packet4f pblendv(const Packet4i& ifPacket, const Packet4f& thenPacket, const Packet4f& elsePacket)
866
+ {
867
+ return pblendv(_mm_castsi128_ps(ifPacket), thenPacket, elsePacket);
868
+ }
869
+
845
870
  template<>
846
871
  EIGEN_STRONG_INLINE Packet4i pblendv(const Packet4i& ifPacket, const Packet4i& thenPacket, const Packet4i& elsePacket)
847
872
  {
@@ -862,6 +887,13 @@ namespace Eigen
862
887
  #endif
863
888
  }
864
889
 
890
+
891
+ template<>
892
+ EIGEN_STRONG_INLINE Packet2d pblendv(const Packet4i& ifPacket, const Packet2d& thenPacket, const Packet2d& elsePacket)
893
+ {
894
+ return pblendv(_mm_castsi128_pd(ifPacket), thenPacket, elsePacket);
895
+ }
896
+
865
897
  template<>
866
898
  EIGEN_STRONG_INLINE Packet4i pgather<Packet4i>(const int* addr, const Packet4i& index)
867
899
  {
@@ -869,7 +901,7 @@ namespace Eigen
869
901
  return _mm_i32gather_epi32(addr, index, 4);
870
902
  #else
871
903
  uint32_t u[4];
872
- _mm_storeu_si128((Packet4i*)u, index);
904
+ _mm_storeu_si128((__m128i*)u, index);
873
905
  return _mm_setr_epi32(addr[u[0]], addr[u[1]], addr[u[2]], addr[u[3]]);
874
906
  #endif
875
907
  }
@@ -881,7 +913,7 @@ namespace Eigen
881
913
  return _mm_i32gather_ps(addr, index, 4);
882
914
  #else
883
915
  uint32_t u[4];
884
- _mm_storeu_si128((Packet4i*)u, index);
916
+ _mm_storeu_si128((__m128i*)u, index);
885
917
  return _mm_setr_ps(addr[u[0]], addr[u[1]], addr[u[2]], addr[u[3]]);
886
918
  #endif
887
919
  }
@@ -893,7 +925,7 @@ namespace Eigen
893
925
  return _mm_i32gather_pd(addr, index, 8);
894
926
  #else
895
927
  uint32_t u[4];
896
- _mm_storeu_si128((Packet4i*)u, index);
928
+ _mm_storeu_si128((__m128i*)u, index);
897
929
  if (upperhalf)
898
930
  {
899
931
  return _mm_setr_pd(addr[u[2]], addr[u[3]]);
@@ -920,7 +952,7 @@ namespace Eigen
920
952
  template<>
921
953
  EIGEN_STRONG_INLINE int pmovemask<Packet4i>(const Packet4i& a)
922
954
  {
923
- return pmovemask(_mm_castsi128_ps(a));
955
+ return pmovemask((Packet4f)_mm_castsi128_ps(a));
924
956
  }
925
957
 
926
958
  template<>
@@ -958,7 +990,7 @@ namespace Eigen
958
990
  return _mm_cmpeq_epi64(a, b);
959
991
  #else
960
992
  Packet4i c = _mm_cmpeq_epi32(a, b);
961
- return pand(c, _mm_shuffle_epi32(c, _MM_SHUFFLE(2, 3, 0, 1)));
993
+ return pand(c, (Packet4i)_mm_shuffle_epi32(c, _MM_SHUFFLE(2, 3, 0, 1)));
962
994
  #endif
963
995
  }
964
996
 
@@ -0,0 +1,222 @@
1
+ /**
2
+ * @file Multinomial.h
3
+ * @author bab2min (bab2min@gmail.com)
4
+ * @brief
5
+ * @version 0.3.0
6
+ * @date 2020-10-07
7
+ *
8
+ * @copyright Copyright (c) 2020
9
+ *
10
+ */
11
+
12
+ #ifndef EIGENRAND_MVDISTS_MULTINOMIAL_H
13
+ #define EIGENRAND_MVDISTS_MULTINOMIAL_H
14
+
15
+ namespace Eigen
16
+ {
17
+ namespace Rand
18
+ {
19
+ /**
20
+ * @brief Generator of integer vectors on a multinomial distribution
21
+ *
22
+ * @tparam _Scalar
23
+ * @tparam Dim number of dimensions, or `Eigen::Dynamic`
24
+ */
25
+ template<typename _Scalar = int32_t, Index Dim = -1>
26
+ class MultinomialGen : public MvVecGenBase<MultinomialGen<_Scalar, Dim>, _Scalar, Dim>
27
+ {
28
+ static_assert(std::is_same<_Scalar, int32_t>::value, "`MultinomialGen` needs integral types.");
29
+ _Scalar trials;
30
+ Matrix<double, Dim, 1> probs;
31
+ DiscreteGen<_Scalar> discrete;
32
+ public:
33
+ /**
34
+ * @brief Construct a new multinomial generator
35
+ *
36
+ * @tparam WeightTy
37
+ * @param _trials the number of trials
38
+ * @param _weights the weights of each category, `(Dim, 1)` shape matrix or vector
39
+ */
40
+ template<typename WeightTy>
41
+ MultinomialGen(_Scalar _trials, const MatrixBase<WeightTy>& _weights)
42
+ : trials{ _trials }, probs{ _weights.template cast<double>() }, discrete(probs.data(), probs.data() + probs.size())
43
+ {
44
+ eigen_assert(_weights.cols() == 1);
45
+ for (Index i = 0; i < probs.size(); ++i)
46
+ {
47
+ eigen_assert(probs[i] >= 0);
48
+ }
49
+ probs /= probs.sum();
50
+ }
51
+
52
+ MultinomialGen(const MultinomialGen&) = default;
53
+ MultinomialGen(MultinomialGen&&) = default;
54
+
55
+ MultinomialGen& operator=(const MultinomialGen&) = default;
56
+ MultinomialGen& operator=(MultinomialGen&&) = default;
57
+
58
+ Index dims() const { return probs.rows(); }
59
+
60
+ template<typename Urng>
61
+ inline Matrix<_Scalar, Dim, -1> generate(Urng&& urng, Index samples)
62
+ {
63
+ const Index dim = probs.size();
64
+ Matrix<_Scalar, Dim, -1> ret(dim, samples);
65
+ //if (trials < 2500)
66
+ {
67
+ for (Index j = 0; j < samples; ++j)
68
+ {
69
+ ret.col(j) = generate(urng);
70
+ }
71
+ }
72
+ /*else
73
+ {
74
+ ret.row(0) = binomial<Matrix<_Scalar, -1, 1>>(samples, 1, urng, trials, probs[0]).eval().transpose();
75
+ for (Index j = 0; j < samples; ++j)
76
+ {
77
+ double rest_p = 1 - probs[0];
78
+ _Scalar t = trials - ret(0, j);
79
+ for (Index i = 1; i < dim - 1; ++i)
80
+ {
81
+ ret(i, j) = binomial<Matrix<_Scalar, 1, 1>>(1, 1, urng, t, probs[i] / rest_p)(0);
82
+ t -= ret(i, j);
83
+ rest_p -= probs[i];
84
+ }
85
+ ret(dim - 1, j) = 0;
86
+ }
87
+ ret.row(dim - 1).setZero();
88
+ ret.row(dim - 1).array() = trials - ret.colwise().sum().array();
89
+ }*/
90
+ return ret;
91
+ }
92
+
93
+ template<typename Urng>
94
+ inline Matrix<_Scalar, Dim, 1> generate(Urng&& urng)
95
+ {
96
+ const Index dim = probs.size();
97
+ Matrix<_Scalar, Dim, 1> ret(dim);
98
+ //if (trials < 2500)
99
+ {
100
+ ret.setZero();
101
+ auto d = discrete.template generate<Matrix<_Scalar, -1, 1>>(trials, 1, urng).eval();
102
+ for (Index i = 0; i < trials; ++i)
103
+ {
104
+ ret[d[i]] += 1;
105
+ }
106
+ }
107
+ /*else
108
+ {
109
+ double rest_p = 1;
110
+ _Scalar t = trials;
111
+ for (Index i = 0; i < dim - 1; ++i)
112
+ {
113
+ ret[i] = binomial<Matrix<_Scalar, 1, 1>>(1, 1, urng, t, probs[i] / rest_p)(0);
114
+ t -= ret[i];
115
+ rest_p -= probs[i];
116
+ }
117
+ ret[dim - 1] = 0;
118
+ ret[dim - 1] = trials - ret.sum();
119
+ }*/
120
+ return ret;
121
+ }
122
+ };
123
+
124
+ /**
125
+ * @brief helper function constructing Eigen::Rand::MultinomialGen
126
+ *
127
+ * @tparam IntTy
128
+ * @tparam WeightTy
129
+ * @param trials the number of trials
130
+ * @param probs The weights of each category with shape `(Dim, 1)` of matrix or vector.
131
+ * The number of entries determines the dimensionality of the distribution
132
+ * @return an instance of MultinomialGen in the appropriate type
133
+ */
134
+ template<typename IntTy, typename WeightTy>
135
+ inline auto makeMultinomialGen(IntTy trials, const MatrixBase<WeightTy>& probs)
136
+ -> MultinomialGen<IntTy, MatrixBase<WeightTy>::RowsAtCompileTime>
137
+ {
138
+ return MultinomialGen<IntTy, MatrixBase<WeightTy>::RowsAtCompileTime>{ trials, probs };
139
+ }
140
+
141
+ /**
142
+ * @brief Generator of reals on a Dirichlet distribution
143
+ *
144
+ * @tparam _Scalar
145
+ * @tparam Dim number of dimensions, or `Eigen::Dynamic`
146
+ */
147
+ template<typename _Scalar, Index Dim = -1>
148
+ class DirichletGen : public MvVecGenBase<DirichletGen<_Scalar, Dim>, _Scalar, Dim>
149
+ {
150
+ Matrix<_Scalar, Dim, 1> alpha;
151
+ std::vector<GammaGen<_Scalar>> gammas;
152
+ public:
153
+ /**
154
+ * @brief Construct a new Dirichlet generator
155
+ *
156
+ * @tparam AlphaTy
157
+ * @param _alpha the concentration parameters with shape `(Dim, 1)` matrix or vector
158
+ */
159
+ template<typename AlphaTy>
160
+ DirichletGen(const MatrixBase<AlphaTy>& _alpha)
161
+ : alpha{ _alpha }
162
+ {
163
+ eigen_assert(_alpha.cols() == 1);
164
+ for (Index i = 0; i < alpha.size(); ++i)
165
+ {
166
+ eigen_assert(alpha[i] > 0);
167
+ gammas.emplace_back(alpha[i]);
168
+ }
169
+ }
170
+
171
+ DirichletGen(const DirichletGen&) = default;
172
+ DirichletGen(DirichletGen&&) = default;
173
+
174
+ Index dims() const { return alpha.rows(); }
175
+
176
+ template<typename Urng>
177
+ inline Matrix<_Scalar, Dim, -1> generate(Urng&& urng, Index samples)
178
+ {
179
+ const Index dim = alpha.size();
180
+ Matrix<_Scalar, Dim, -1> ret(dim, samples);
181
+ Matrix<_Scalar, -1, 1> tmp(samples);
182
+ for (Index i = 0; i < dim; ++i)
183
+ {
184
+ tmp = gammas[i].generateLike(tmp, urng);
185
+ ret.row(i) = tmp.transpose();
186
+ }
187
+ ret.array().rowwise() /= ret.array().colwise().sum();
188
+ return ret;
189
+ }
190
+
191
+ template<typename Urng>
192
+ inline Matrix<_Scalar, Dim, 1> generate(Urng&& urng)
193
+ {
194
+ const Index dim = alpha.size();
195
+ Matrix<_Scalar, Dim, 1> ret(dim);
196
+ for (Index i = 0; i < dim; ++i)
197
+ {
198
+ ret[i] = gammas[i].template generate<Matrix<_Scalar, 1, 1>>(1, 1, urng)(0);
199
+ }
200
+ ret /= ret.sum();
201
+ return ret;
202
+ }
203
+ };
204
+
205
+ /**
206
+ * @brief helper function constructing Eigen::Rand::DirichletGen
207
+ *
208
+ * @tparam AlphaTy
209
+ * @param alpha The concentration parameters with shape `(Dim, 1)` of matrix or vector.
210
+ * The number of entries determines the dimensionality of the distribution.
211
+ * @return an instance of DirichletGen in the appropriate type
212
+ */
213
+ template<typename AlphaTy>
214
+ inline auto makeDirichletGen(const MatrixBase<AlphaTy>& alpha)
215
+ -> DirichletGen<typename MatrixBase<AlphaTy>::Scalar, MatrixBase<AlphaTy>::RowsAtCompileTime>
216
+ {
217
+ return DirichletGen<typename MatrixBase<AlphaTy>::Scalar, MatrixBase<AlphaTy>::RowsAtCompileTime>{ alpha };
218
+ }
219
+ }
220
+ }
221
+
222
+ #endif