lda-ruby 0.4.0 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,7 @@
1
1
  use magnus::{define_module, function, Error, Module, Object};
2
+ use std::collections::HashMap;
3
+ use std::sync::atomic::{AtomicU64, Ordering};
4
+ use std::sync::{Arc, Mutex, OnceLock};
2
5
 
3
6
  fn available() -> bool {
4
7
  true
@@ -40,6 +43,99 @@ fn normalize_in_place(weights: &mut [f64]) {
40
43
  }
41
44
  }
42
45
 
46
/// EM settings remembered on a corpus session.
///
/// Two configurations are equal only when every field is equal, so a config
/// containing a NaN float never compares equal to anything (itself included).
#[derive(Clone, PartialEq)]
struct SessionConfig {
    /// Number of topics; forwarded to the initial beta construction.
    topics: usize,
    /// Iteration limit forwarded to the per-corpus inference pass.
    max_iter: i64,
    /// Convergence threshold forwarded to the per-corpus inference pass.
    convergence: f64,
    /// Upper bound on outer EM iterations.
    em_max_iter: i64,
    /// Outer EM loop stops once the average gamma shift is at or below this.
    em_convergence: f64,
    /// Alpha value forwarded to the per-corpus inference pass.
    init_alpha: f64,
    /// Probability floor forwarded to initialisation and normalisation.
    min_probability: f64,
}
56
+
57
/// Owned copy of a corpus held by a session.
struct CorpusSessionData {
    /// One inner vector per document, holding word indices.
    // NOTE(review): indices are presumably term ids in `0..terms` — confirm against callers.
    document_words: Vec<Vec<usize>>,
    /// One inner vector per document, holding the count for each entry of
    /// `document_words`.
    // NOTE(review): assumed to be parallel to `document_words` — confirm against callers.
    document_counts: Vec<Vec<f64>>,
    /// Vocabulary size passed in when the session data was built.
    terms: usize,
}
62
+
63
+ struct CorpusSession {
64
+ data: Arc<CorpusSessionData>,
65
+ config: Option<SessionConfig>,
66
+ }
67
+
68
+ static CORPUS_SESSIONS: OnceLock<Mutex<HashMap<u64, CorpusSession>>> = OnceLock::new();
69
+ static NEXT_CORPUS_SESSION_ID: AtomicU64 = AtomicU64::new(1);
70
+
71
+ fn corpus_sessions() -> &'static Mutex<HashMap<u64, CorpusSession>> {
72
+ CORPUS_SESSIONS.get_or_init(|| Mutex::new(HashMap::new()))
73
+ }
74
+
75
+ fn corpus_session_count() -> i64 {
76
+ match corpus_sessions().lock() {
77
+ Ok(sessions) => sessions.len() as i64,
78
+ Err(_) => 0,
79
+ }
80
+ }
81
+
82
+ fn corpus_session_exists(session_id: i64) -> bool {
83
+ if session_id <= 0 {
84
+ return false;
85
+ }
86
+
87
+ let session_key = session_id as u64;
88
+ match corpus_sessions().lock() {
89
+ Ok(sessions) => sessions.contains_key(&session_key),
90
+ Err(_) => false,
91
+ }
92
+ }
93
+
94
/// Builds the "nothing computed" EM result: four empty collections in the
/// order `(beta_probabilities, beta_log, gamma, phi)`.
fn empty_em_output() -> (Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<Vec<f64>>>) {
    (vec![], vec![], vec![], vec![])
}
97
+
98
/// Builds the "nothing computed" result for the managed-session entry point:
/// session id 0 followed by four empty collections
/// `(beta_probabilities, beta_log, gamma, phi)`.
fn empty_managed_session_em_output() -> (
    i64,
    Vec<Vec<f64>>,
    Vec<Vec<f64>>,
    Vec<Vec<f64>>,
    Vec<Vec<Vec<f64>>>,
) {
    (0, vec![], vec![], vec![], vec![])
}
108
+
109
/// Small deterministic pseudo-random generator: a 64-bit xorshift state
/// (shifts 12 / 25 / 27) whose output is scrambled by a constant multiply.
struct XorShift64 {
    state: u64,
}

impl XorShift64 {
    /// Creates a generator from `seed`, reinterpreting its bits as `u64`.
    ///
    /// A zero seed is replaced by a fixed non-zero constant, because an
    /// all-zero state would stay zero under the xorshift steps.
    fn new(seed: i64) -> Self {
        const ZERO_SEED_REPLACEMENT: u64 = 0x9E37_79B9_7F4A_7C15;

        let state = match seed as u64 {
            0 => ZERO_SEED_REPLACEMENT,
            bits => bits,
        };
        Self { state }
    }

    /// Advances the state and returns the next 64-bit output.
    fn next_u64(&mut self) -> u64 {
        const OUTPUT_MULTIPLIER: u64 = 0x2545_F491_4F6C_DD1D;

        self.state ^= self.state >> 12;
        self.state ^= self.state << 25;
        self.state ^= self.state >> 27;
        self.state.wrapping_mul(OUTPUT_MULTIPLIER)
    }

    /// Returns the next value as an `f64` in `[0, 1)`.
    fn next_f64_unit(&mut self) -> f64 {
        // 2^53: an f64 represents every integer below this exactly.
        const TWO_POW_53: f64 = (1_u64 << 53) as f64;

        // Drop the low 11 bits so exactly 53 random bits remain.
        let top_bits = self.next_u64() >> 11;
        top_bits as f64 / TWO_POW_53
    }
}
138
+
43
139
  fn compute_topic_weights(
44
140
  beta_probabilities: &[Vec<f64>],
45
141
  gamma: &[f64],
@@ -158,7 +254,7 @@ fn normalize_topic_term_counts(
158
254
  (beta_probabilities, beta_log)
159
255
  }
160
256
 
161
- fn average_gamma_shift(previous_gamma: Vec<Vec<f64>>, current_gamma: Vec<Vec<f64>>) -> f64 {
257
+ fn average_gamma_shift_internal(previous_gamma: &[Vec<f64>], current_gamma: &[Vec<f64>]) -> f64 {
162
258
  let mut sum = 0.0_f64;
163
259
  let mut count = 0_usize;
164
260
 
@@ -183,6 +279,10 @@ fn average_gamma_shift(previous_gamma: Vec<Vec<f64>>, current_gamma: Vec<Vec<f64
183
279
  }
184
280
  }
185
281
 
282
+ fn average_gamma_shift(previous_gamma: Vec<Vec<f64>>, current_gamma: Vec<Vec<f64>>) -> f64 {
283
+ average_gamma_shift_internal(previous_gamma.as_slice(), current_gamma.as_slice())
284
+ }
285
+
186
286
  fn topic_document_probability(
187
287
  phi_tensor: Vec<Vec<Vec<f64>>>,
188
288
  document_counts: Vec<Vec<f64>>,
@@ -222,9 +322,9 @@ fn topic_document_probability(
222
322
  output
223
323
  }
224
324
 
225
- fn seeded_topic_term_probabilities(
226
- document_words: Vec<Vec<usize>>,
227
- document_counts: Vec<Vec<f64>>,
325
+ fn seeded_topic_term_probabilities_internal(
326
+ document_words: &[Vec<usize>],
327
+ document_counts: &[Vec<f64>],
228
328
  topics: usize,
229
329
  terms: usize,
230
330
  min_probability: f64,
@@ -264,6 +364,427 @@ fn seeded_topic_term_probabilities(
264
364
  topic_term_counts
265
365
  }
266
366
 
367
+ fn seeded_topic_term_probabilities(
368
+ document_words: Vec<Vec<usize>>,
369
+ document_counts: Vec<Vec<f64>>,
370
+ topics: usize,
371
+ terms: usize,
372
+ min_probability: f64,
373
+ ) -> Vec<Vec<f64>> {
374
+ seeded_topic_term_probabilities_internal(
375
+ document_words.as_slice(),
376
+ document_counts.as_slice(),
377
+ topics,
378
+ terms,
379
+ min_probability,
380
+ )
381
+ }
382
+
383
+ fn random_topic_term_probabilities(
384
+ topics: usize,
385
+ terms: usize,
386
+ min_probability: f64,
387
+ random_seed: i64,
388
+ ) -> Vec<Vec<f64>> {
389
+ if topics == 0 || terms == 0 {
390
+ return Vec::new();
391
+ }
392
+
393
+ let floor = floor_value(min_probability);
394
+ let mut rng = XorShift64::new(random_seed);
395
+ let mut matrix = Vec::with_capacity(topics);
396
+
397
+ for _ in 0..topics {
398
+ let mut weights = Vec::with_capacity(terms);
399
+ for _ in 0..terms {
400
+ weights.push(rng.next_f64_unit() + floor);
401
+ }
402
+ normalize_in_place(&mut weights);
403
+ matrix.push(weights);
404
+ }
405
+
406
+ matrix
407
+ }
408
+
409
+ fn corpus_session_data(
410
+ document_words: &[Vec<usize>],
411
+ document_counts: &[Vec<f64>],
412
+ terms: usize,
413
+ ) -> Arc<CorpusSessionData> {
414
+ Arc::new(CorpusSessionData {
415
+ document_words: document_words.to_vec(),
416
+ document_counts: document_counts.to_vec(),
417
+ terms,
418
+ })
419
+ }
420
+
421
+ fn create_corpus_session_internal(
422
+ document_words: &[Vec<usize>],
423
+ document_counts: &[Vec<f64>],
424
+ terms: usize,
425
+ ) -> i64 {
426
+ let session_id = NEXT_CORPUS_SESSION_ID.fetch_add(1, Ordering::Relaxed);
427
+ let session = CorpusSession {
428
+ data: corpus_session_data(document_words, document_counts, terms),
429
+ config: None,
430
+ };
431
+
432
+ match corpus_sessions().lock() {
433
+ Ok(mut sessions) => {
434
+ sessions.insert(session_id, session);
435
+ session_id as i64
436
+ }
437
+ Err(_) => 0,
438
+ }
439
+ }
440
+
441
+ fn create_corpus_session(
442
+ document_words: Vec<Vec<usize>>,
443
+ document_counts: Vec<Vec<f64>>,
444
+ terms: usize,
445
+ ) -> i64 {
446
+ create_corpus_session_internal(document_words.as_slice(), document_counts.as_slice(), terms)
447
+ }
448
+
449
+ fn replace_corpus_session_internal(
450
+ session_id: i64,
451
+ document_words: &[Vec<usize>],
452
+ document_counts: &[Vec<f64>],
453
+ terms: usize,
454
+ ) -> i64 {
455
+ if terms == 0 {
456
+ return 0;
457
+ }
458
+
459
+ let replacement_data = corpus_session_data(document_words, document_counts, terms);
460
+ match corpus_sessions().lock() {
461
+ Ok(mut sessions) => {
462
+ if session_id > 0 {
463
+ let session_key = session_id as u64;
464
+ if let Some(session) = sessions.get_mut(&session_key) {
465
+ session.data = replacement_data;
466
+ session.config = None;
467
+ return session_id;
468
+ }
469
+ }
470
+
471
+ let new_session_id = NEXT_CORPUS_SESSION_ID.fetch_add(1, Ordering::Relaxed);
472
+ sessions.insert(
473
+ new_session_id,
474
+ CorpusSession {
475
+ data: replacement_data,
476
+ config: None,
477
+ },
478
+ );
479
+ new_session_id as i64
480
+ }
481
+ Err(_) => 0,
482
+ }
483
+ }
484
+
485
+ fn replace_corpus_session(
486
+ session_id: i64,
487
+ document_words: Vec<Vec<usize>>,
488
+ document_counts: Vec<Vec<f64>>,
489
+ terms: usize,
490
+ ) -> i64 {
491
+ replace_corpus_session_internal(
492
+ session_id,
493
+ document_words.as_slice(),
494
+ document_counts.as_slice(),
495
+ terms,
496
+ )
497
+ }
498
+
499
+ fn ensure_corpus_session(
500
+ session_id: i64,
501
+ document_words: &[Vec<usize>],
502
+ document_counts: &[Vec<f64>],
503
+ terms: usize,
504
+ ) -> i64 {
505
+ if terms == 0 {
506
+ return 0;
507
+ }
508
+
509
+ if session_id > 0 && corpus_session_exists(session_id) {
510
+ return session_id;
511
+ }
512
+
513
+ create_corpus_session_internal(document_words, document_counts, terms)
514
+ }
515
+
516
+ fn drop_corpus_session(session_id: i64) -> bool {
517
+ if session_id <= 0 {
518
+ return false;
519
+ }
520
+
521
+ let session_key = session_id as u64;
522
+ match corpus_sessions().lock() {
523
+ Ok(mut sessions) => sessions.remove(&session_key).is_some(),
524
+ Err(_) => false,
525
+ }
526
+ }
527
+
528
+ fn configure_corpus_session(
529
+ session_id: i64,
530
+ topics: usize,
531
+ max_iter: i64,
532
+ convergence: f64,
533
+ em_max_iter: i64,
534
+ em_convergence: f64,
535
+ init_alpha: f64,
536
+ min_probability: f64,
537
+ ) -> bool {
538
+ if session_id <= 0 || topics == 0 {
539
+ return false;
540
+ }
541
+
542
+ let session_key = session_id as u64;
543
+ match corpus_sessions().lock() {
544
+ Ok(mut sessions) => {
545
+ let Some(session) = sessions.get_mut(&session_key) else {
546
+ return false;
547
+ };
548
+
549
+ session.config = Some(SessionConfig {
550
+ topics,
551
+ max_iter,
552
+ convergence,
553
+ em_max_iter,
554
+ em_convergence,
555
+ init_alpha,
556
+ min_probability,
557
+ });
558
+
559
+ true
560
+ }
561
+ Err(_) => false,
562
+ }
563
+ }
564
+
565
+ fn run_em_on_session_with_start_seed(
566
+ session_id: i64,
567
+ start: String,
568
+ topics: usize,
569
+ max_iter: i64,
570
+ convergence: f64,
571
+ em_max_iter: i64,
572
+ em_convergence: f64,
573
+ init_alpha: f64,
574
+ min_probability: f64,
575
+ random_seed: i64,
576
+ ) -> (Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<Vec<f64>>>) {
577
+ if session_id <= 0 {
578
+ return empty_em_output();
579
+ }
580
+
581
+ let session_key = session_id as u64;
582
+ let session_data = match corpus_sessions().lock() {
583
+ Ok(sessions) => sessions
584
+ .get(&session_key)
585
+ .map(|session| Arc::clone(&session.data)),
586
+ Err(_) => None,
587
+ };
588
+
589
+ let Some(session_data) = session_data else {
590
+ return empty_em_output();
591
+ };
592
+
593
+ run_em_with_start_seed_internal(
594
+ start.as_str(),
595
+ session_data.document_words.as_slice(),
596
+ session_data.document_counts.as_slice(),
597
+ topics,
598
+ session_data.terms,
599
+ max_iter,
600
+ convergence,
601
+ em_max_iter,
602
+ em_convergence,
603
+ init_alpha,
604
+ min_probability,
605
+ random_seed,
606
+ )
607
+ }
608
+
609
+ fn run_em_on_session(
610
+ session_id: i64,
611
+ start: String,
612
+ topics: usize,
613
+ max_iter: i64,
614
+ convergence: f64,
615
+ em_max_iter: i64,
616
+ em_convergence: f64,
617
+ init_alpha: f64,
618
+ min_probability: f64,
619
+ random_seed: i64,
620
+ ) -> (Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<Vec<f64>>>) {
621
+ if session_id <= 0 || topics == 0 {
622
+ return empty_em_output();
623
+ }
624
+
625
+ let desired_config = SessionConfig {
626
+ topics,
627
+ max_iter,
628
+ convergence,
629
+ em_max_iter,
630
+ em_convergence,
631
+ init_alpha,
632
+ min_probability,
633
+ };
634
+
635
+ let session_key = session_id as u64;
636
+ let session_data = match corpus_sessions().lock() {
637
+ Ok(mut sessions) => {
638
+ let Some(session) = sessions.get_mut(&session_key) else {
639
+ return empty_em_output();
640
+ };
641
+
642
+ if session.config.as_ref() != Some(&desired_config) {
643
+ session.config = Some(desired_config.clone());
644
+ }
645
+
646
+ Arc::clone(&session.data)
647
+ }
648
+ Err(_) => return empty_em_output(),
649
+ };
650
+
651
+ run_em_with_start_seed_internal(
652
+ start.as_str(),
653
+ session_data.document_words.as_slice(),
654
+ session_data.document_counts.as_slice(),
655
+ desired_config.topics,
656
+ session_data.terms,
657
+ desired_config.max_iter,
658
+ desired_config.convergence,
659
+ desired_config.em_max_iter,
660
+ desired_config.em_convergence,
661
+ desired_config.init_alpha,
662
+ desired_config.min_probability,
663
+ random_seed,
664
+ )
665
+ }
666
+
667
+ fn run_em_on_session_with_corpus(
668
+ session_id: i64,
669
+ document_words: Vec<Vec<usize>>,
670
+ document_counts: Vec<Vec<f64>>,
671
+ terms: usize,
672
+ start: String,
673
+ topics: usize,
674
+ max_iter: i64,
675
+ convergence: f64,
676
+ em_max_iter: i64,
677
+ em_convergence: f64,
678
+ init_alpha: f64,
679
+ min_probability: f64,
680
+ random_seed: i64,
681
+ ) -> (
682
+ i64,
683
+ Vec<Vec<f64>>,
684
+ Vec<Vec<f64>>,
685
+ Vec<Vec<f64>>,
686
+ Vec<Vec<Vec<f64>>>,
687
+ ) {
688
+ if topics == 0 || terms == 0 {
689
+ return empty_managed_session_em_output();
690
+ }
691
+
692
+ let active_session_id = ensure_corpus_session(
693
+ session_id,
694
+ document_words.as_slice(),
695
+ document_counts.as_slice(),
696
+ terms,
697
+ );
698
+
699
+ if active_session_id > 0 {
700
+ let (beta_probabilities, beta_log, gamma, phi) = run_em_on_session(
701
+ active_session_id,
702
+ start.clone(),
703
+ topics,
704
+ max_iter,
705
+ convergence,
706
+ em_max_iter,
707
+ em_convergence,
708
+ init_alpha,
709
+ min_probability,
710
+ random_seed,
711
+ );
712
+
713
+ if !(beta_probabilities.is_empty()
714
+ && beta_log.is_empty()
715
+ && gamma.is_empty()
716
+ && phi.is_empty())
717
+ {
718
+ return (active_session_id, beta_probabilities, beta_log, gamma, phi);
719
+ }
720
+ }
721
+
722
+ let (beta_probabilities, beta_log, gamma, phi) = run_em_with_start_seed_internal(
723
+ start.as_str(),
724
+ document_words.as_slice(),
725
+ document_counts.as_slice(),
726
+ topics,
727
+ terms,
728
+ max_iter,
729
+ convergence,
730
+ em_max_iter,
731
+ em_convergence,
732
+ init_alpha,
733
+ min_probability,
734
+ random_seed,
735
+ );
736
+
737
+ if beta_probabilities.is_empty() && beta_log.is_empty() && gamma.is_empty() && phi.is_empty() {
738
+ return empty_managed_session_em_output();
739
+ }
740
+
741
+ (active_session_id, beta_probabilities, beta_log, gamma, phi)
742
+ }
743
+
744
+ fn run_em_on_session_start(
745
+ session_id: i64,
746
+ start: String,
747
+ random_seed: i64,
748
+ ) -> (Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<Vec<f64>>>) {
749
+ if session_id <= 0 {
750
+ return empty_em_output();
751
+ }
752
+
753
+ let session_key = session_id as u64;
754
+ let session_data = match corpus_sessions().lock() {
755
+ Ok(sessions) => sessions.get(&session_key).map(|session| {
756
+ (
757
+ Arc::clone(&session.data),
758
+ session.config.clone(),
759
+ )
760
+ }),
761
+ Err(_) => None,
762
+ };
763
+
764
+ let Some((session_data, config)) = session_data else {
765
+ return empty_em_output();
766
+ };
767
+
768
+ let Some(config) = config else {
769
+ return empty_em_output();
770
+ };
771
+
772
+ run_em_with_start_seed_internal(
773
+ start.as_str(),
774
+ session_data.document_words.as_slice(),
775
+ session_data.document_counts.as_slice(),
776
+ config.topics,
777
+ session_data.terms,
778
+ config.max_iter,
779
+ config.convergence,
780
+ config.em_max_iter,
781
+ config.em_convergence,
782
+ config.init_alpha,
783
+ config.min_probability,
784
+ random_seed,
785
+ )
786
+ }
787
+
267
788
  fn infer_document_internal(
268
789
  beta_probabilities: &[Vec<f64>],
269
790
  gamma_initial: &[f64],
@@ -360,10 +881,10 @@ fn infer_document(
360
881
  output
361
882
  }
362
883
 
363
- fn infer_corpus_iteration(
364
- beta_probabilities: Vec<Vec<f64>>,
365
- document_words: Vec<Vec<usize>>,
366
- document_counts: Vec<Vec<f64>>,
884
+ fn infer_corpus_iteration_internal(
885
+ beta_probabilities: &[Vec<f64>],
886
+ document_words: &[Vec<usize>],
887
+ document_counts: &[Vec<f64>],
367
888
  max_iter: i64,
368
889
  convergence: f64,
369
890
  min_probability: f64,
@@ -392,7 +913,7 @@ fn infer_corpus_iteration(
392
913
  let gamma_initial = vec![init_alpha_value + (total / topics as f64); topics];
393
914
 
394
915
  let (gamma_d, phi_d) = infer_document_internal(
395
- beta_probabilities.as_slice(),
916
+ beta_probabilities,
396
917
  gamma_initial.as_slice(),
397
918
  words.as_slice(),
398
919
  counts.as_slice(),
@@ -416,6 +937,264 @@ fn infer_corpus_iteration(
416
937
  (gamma_matrix, phi_tensor, topic_term_counts)
417
938
  }
418
939
 
940
+ fn infer_corpus_iteration(
941
+ beta_probabilities: Vec<Vec<f64>>,
942
+ document_words: Vec<Vec<usize>>,
943
+ document_counts: Vec<Vec<f64>>,
944
+ max_iter: i64,
945
+ convergence: f64,
946
+ min_probability: f64,
947
+ init_alpha: f64,
948
+ ) -> (Vec<Vec<f64>>, Vec<Vec<Vec<f64>>>, Vec<Vec<f64>>) {
949
+ infer_corpus_iteration_internal(
950
+ beta_probabilities.as_slice(),
951
+ document_words.as_slice(),
952
+ document_counts.as_slice(),
953
+ max_iter,
954
+ convergence,
955
+ min_probability,
956
+ init_alpha,
957
+ )
958
+ }
959
+
960
/// Reports whether `start` selects seeded initialisation.
///
/// Surrounding whitespace is ignored and the comparison is ASCII
/// case-insensitive. Both `"seeded"` and `"deterministic"` are accepted.
fn start_uses_seeded_initialization(start: &str) -> bool {
    // Compare in place rather than building a lowercase copy on every call.
    let mode = start.trim();
    mode.eq_ignore_ascii_case("seeded") || mode.eq_ignore_ascii_case("deterministic")
}
964
+
965
/// Reports whether `start` selects random initialisation, i.e. equals
/// `"random"` once surrounding whitespace is removed, ignoring ASCII case.
fn start_uses_random_initialization(start: &str) -> bool {
    let mode = start.trim();
    "random".eq_ignore_ascii_case(mode)
}
968
+
969
+ fn run_em_internal(
970
+ mut beta_probabilities: Vec<Vec<f64>>,
971
+ document_words: &[Vec<usize>],
972
+ document_counts: &[Vec<f64>],
973
+ max_iter: i64,
974
+ convergence: f64,
975
+ em_max_iter: i64,
976
+ em_convergence: f64,
977
+ init_alpha: f64,
978
+ min_probability: f64,
979
+ ) -> (Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<Vec<f64>>>) {
980
+ let em_max_iter_value = if em_max_iter <= 0 { 0 } else { em_max_iter as usize };
981
+ let em_convergence_value = if em_convergence.is_finite() && em_convergence >= 0.0 {
982
+ em_convergence
983
+ } else {
984
+ 1.0e-4
985
+ };
986
+
987
+ let mut previous_gamma: Option<Vec<Vec<f64>>> = None;
988
+ let mut beta_log: Vec<Vec<f64>> = Vec::new();
989
+ let mut gamma: Vec<Vec<f64>> = Vec::new();
990
+ let mut phi: Vec<Vec<Vec<f64>>> = Vec::new();
991
+
992
+ for _ in 0..em_max_iter_value {
993
+ let (current_gamma, current_phi, topic_term_counts) = infer_corpus_iteration_internal(
994
+ beta_probabilities.as_slice(),
995
+ document_words,
996
+ document_counts,
997
+ max_iter,
998
+ convergence,
999
+ min_probability,
1000
+ init_alpha,
1001
+ );
1002
+
1003
+ let (next_beta_probabilities, next_beta_log) =
1004
+ normalize_topic_term_counts(topic_term_counts, min_probability);
1005
+ let should_stop = previous_gamma
1006
+ .as_ref()
1007
+ .map(|prev| {
1008
+ average_gamma_shift_internal(prev.as_slice(), current_gamma.as_slice())
1009
+ <= em_convergence_value
1010
+ })
1011
+ .unwrap_or(false);
1012
+
1013
+ beta_probabilities = next_beta_probabilities;
1014
+ beta_log = next_beta_log;
1015
+ gamma = current_gamma;
1016
+ phi = current_phi;
1017
+
1018
+ if should_stop {
1019
+ break;
1020
+ }
1021
+
1022
+ previous_gamma = Some(gamma.clone());
1023
+ }
1024
+
1025
+ (beta_probabilities, beta_log, gamma, phi)
1026
+ }
1027
+
1028
+ fn run_em(
1029
+ beta_probabilities: Vec<Vec<f64>>,
1030
+ document_words: Vec<Vec<usize>>,
1031
+ document_counts: Vec<Vec<f64>>,
1032
+ max_iter: i64,
1033
+ convergence: f64,
1034
+ em_max_iter: i64,
1035
+ em_convergence: f64,
1036
+ init_alpha: f64,
1037
+ min_probability: f64,
1038
+ ) -> (Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<Vec<f64>>>) {
1039
+ run_em_internal(
1040
+ beta_probabilities,
1041
+ document_words.as_slice(),
1042
+ document_counts.as_slice(),
1043
+ max_iter,
1044
+ convergence,
1045
+ em_max_iter,
1046
+ em_convergence,
1047
+ init_alpha,
1048
+ min_probability,
1049
+ )
1050
+ }
1051
+
1052
+ fn run_em_with_start_internal(
1053
+ start: &str,
1054
+ document_words: &[Vec<usize>],
1055
+ document_counts: &[Vec<f64>],
1056
+ topics: usize,
1057
+ terms: usize,
1058
+ max_iter: i64,
1059
+ convergence: f64,
1060
+ em_max_iter: i64,
1061
+ em_convergence: f64,
1062
+ init_alpha: f64,
1063
+ min_probability: f64,
1064
+ ) -> (Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<Vec<f64>>>) {
1065
+ let initial_beta =
1066
+ if start_uses_seeded_initialization(start) || start_uses_random_initialization(start) {
1067
+ seeded_topic_term_probabilities_internal(
1068
+ document_words,
1069
+ document_counts,
1070
+ topics,
1071
+ terms,
1072
+ min_probability,
1073
+ )
1074
+ } else {
1075
+ // Unknown start modes default to seeded initialization for a stable fallback.
1076
+ seeded_topic_term_probabilities_internal(
1077
+ document_words,
1078
+ document_counts,
1079
+ topics,
1080
+ terms,
1081
+ min_probability,
1082
+ )
1083
+ };
1084
+
1085
+ run_em_internal(
1086
+ initial_beta,
1087
+ document_words,
1088
+ document_counts,
1089
+ max_iter,
1090
+ convergence,
1091
+ em_max_iter,
1092
+ em_convergence,
1093
+ init_alpha,
1094
+ min_probability,
1095
+ )
1096
+ }
1097
+
1098
+ fn run_em_with_start(
1099
+ start: String,
1100
+ document_words: Vec<Vec<usize>>,
1101
+ document_counts: Vec<Vec<f64>>,
1102
+ topics: usize,
1103
+ terms: usize,
1104
+ max_iter: i64,
1105
+ convergence: f64,
1106
+ em_max_iter: i64,
1107
+ em_convergence: f64,
1108
+ init_alpha: f64,
1109
+ min_probability: f64,
1110
+ ) -> (Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<Vec<f64>>>) {
1111
+ run_em_with_start_internal(
1112
+ start.as_str(),
1113
+ document_words.as_slice(),
1114
+ document_counts.as_slice(),
1115
+ topics,
1116
+ terms,
1117
+ max_iter,
1118
+ convergence,
1119
+ em_max_iter,
1120
+ em_convergence,
1121
+ init_alpha,
1122
+ min_probability,
1123
+ )
1124
+ }
1125
+
1126
+ fn run_em_with_start_seed_internal(
1127
+ start: &str,
1128
+ document_words: &[Vec<usize>],
1129
+ document_counts: &[Vec<f64>],
1130
+ topics: usize,
1131
+ terms: usize,
1132
+ max_iter: i64,
1133
+ convergence: f64,
1134
+ em_max_iter: i64,
1135
+ em_convergence: f64,
1136
+ init_alpha: f64,
1137
+ min_probability: f64,
1138
+ random_seed: i64,
1139
+ ) -> (Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<Vec<f64>>>) {
1140
+ let initial_beta = if start_uses_seeded_initialization(start) {
1141
+ seeded_topic_term_probabilities_internal(
1142
+ document_words,
1143
+ document_counts,
1144
+ topics,
1145
+ terms,
1146
+ min_probability,
1147
+ )
1148
+ } else if start_uses_random_initialization(start) {
1149
+ random_topic_term_probabilities(topics, terms, min_probability, random_seed)
1150
+ } else {
1151
+ // Unknown start modes follow Ruby's non-seeded fallback behavior.
1152
+ random_topic_term_probabilities(topics, terms, min_probability, random_seed)
1153
+ };
1154
+
1155
+ run_em_internal(
1156
+ initial_beta,
1157
+ document_words,
1158
+ document_counts,
1159
+ max_iter,
1160
+ convergence,
1161
+ em_max_iter,
1162
+ em_convergence,
1163
+ init_alpha,
1164
+ min_probability,
1165
+ )
1166
+ }
1167
+
1168
+ fn run_em_with_start_seed(
1169
+ start: String,
1170
+ document_words: Vec<Vec<usize>>,
1171
+ document_counts: Vec<Vec<f64>>,
1172
+ topics: usize,
1173
+ terms: usize,
1174
+ max_iter: i64,
1175
+ convergence: f64,
1176
+ em_max_iter: i64,
1177
+ em_convergence: f64,
1178
+ init_alpha: f64,
1179
+ min_probability: f64,
1180
+ random_seed: i64,
1181
+ ) -> (Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<Vec<f64>>>) {
1182
+ run_em_with_start_seed_internal(
1183
+ start.as_str(),
1184
+ document_words.as_slice(),
1185
+ document_counts.as_slice(),
1186
+ topics,
1187
+ terms,
1188
+ max_iter,
1189
+ convergence,
1190
+ em_max_iter,
1191
+ em_convergence,
1192
+ init_alpha,
1193
+ min_probability,
1194
+ random_seed,
1195
+ )
1196
+ }
1197
+
419
1198
  #[magnus::init]
420
1199
  fn init() -> Result<(), Error> {
421
1200
  let lda_module = define_module("Lda")?;
@@ -423,6 +1202,8 @@ fn init() -> Result<(), Error> {
423
1202
 
424
1203
  rust_backend_module.define_singleton_method("available?", function!(available, 0))?;
425
1204
  rust_backend_module.define_singleton_method("abi_version", function!(abi_version, 0))?;
1205
+ rust_backend_module.define_singleton_method("corpus_session_count", function!(corpus_session_count, 0))?;
1206
+ rust_backend_module.define_singleton_method("corpus_session_exists", function!(corpus_session_exists, 1))?;
426
1207
  rust_backend_module.define_singleton_method("before_em", function!(before_em, 3))?;
427
1208
  rust_backend_module.define_singleton_method(
428
1209
  "topic_weights_for_word",
@@ -451,6 +1232,32 @@ fn init() -> Result<(), Error> {
451
1232
  "seeded_topic_term_probabilities",
452
1233
  function!(seeded_topic_term_probabilities, 5),
453
1234
  )?;
1235
+ rust_backend_module.define_singleton_method(
1236
+ "random_topic_term_probabilities",
1237
+ function!(random_topic_term_probabilities, 4),
1238
+ )?;
1239
+ rust_backend_module
1240
+ .define_singleton_method("create_corpus_session", function!(create_corpus_session, 3))?;
1241
+ rust_backend_module
1242
+ .define_singleton_method("replace_corpus_session", function!(replace_corpus_session, 4))?;
1243
+ rust_backend_module
1244
+ .define_singleton_method("drop_corpus_session", function!(drop_corpus_session, 1))?;
1245
+ rust_backend_module
1246
+ .define_singleton_method("configure_corpus_session", function!(configure_corpus_session, 8))?;
1247
+ rust_backend_module.define_singleton_method("run_em", function!(run_em, 9))?;
1248
+ rust_backend_module
1249
+ .define_singleton_method("run_em_with_start", function!(run_em_with_start, 11))?;
1250
+ rust_backend_module
1251
+ .define_singleton_method("run_em_with_start_seed", function!(run_em_with_start_seed, 12))?;
1252
+ rust_backend_module.define_singleton_method(
1253
+ "run_em_on_session_with_start_seed",
1254
+ function!(run_em_on_session_with_start_seed, 10),
1255
+ )?;
1256
+ rust_backend_module.define_singleton_method("run_em_on_session", function!(run_em_on_session, 10))?;
1257
+ rust_backend_module
1258
+ .define_singleton_method("run_em_on_session_with_corpus", function!(run_em_on_session_with_corpus, 13))?;
1259
+ rust_backend_module
1260
+ .define_singleton_method("run_em_on_session_start", function!(run_em_on_session_start, 3))?;
454
1261
 
455
1262
  Ok(())
456
1263
  }