sequenzo 0.1.31__cp310-cp310-macosx_10_9_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. _sequenzo_fastcluster.cpython-310-darwin.so +0 -0
  2. sequenzo/__init__.py +349 -0
  3. sequenzo/big_data/__init__.py +12 -0
  4. sequenzo/big_data/clara/__init__.py +26 -0
  5. sequenzo/big_data/clara/clara.py +476 -0
  6. sequenzo/big_data/clara/utils/__init__.py +27 -0
  7. sequenzo/big_data/clara/utils/aggregatecases.py +92 -0
  8. sequenzo/big_data/clara/utils/davies_bouldin.py +91 -0
  9. sequenzo/big_data/clara/utils/get_weighted_diss.cpython-310-darwin.so +0 -0
  10. sequenzo/big_data/clara/utils/wfcmdd.py +205 -0
  11. sequenzo/big_data/clara/visualization.py +88 -0
  12. sequenzo/clustering/KMedoids.py +178 -0
  13. sequenzo/clustering/__init__.py +30 -0
  14. sequenzo/clustering/clustering_c_code.cpython-310-darwin.so +0 -0
  15. sequenzo/clustering/hierarchical_clustering.py +1256 -0
  16. sequenzo/clustering/sequenzo_fastcluster/fastcluster.py +495 -0
  17. sequenzo/clustering/sequenzo_fastcluster/src/fastcluster.cpp +1877 -0
  18. sequenzo/clustering/sequenzo_fastcluster/src/fastcluster_python.cpp +1264 -0
  19. sequenzo/clustering/src/KMedoid.cpp +263 -0
  20. sequenzo/clustering/src/PAM.cpp +237 -0
  21. sequenzo/clustering/src/PAMonce.cpp +265 -0
  22. sequenzo/clustering/src/cluster_quality.cpp +496 -0
  23. sequenzo/clustering/src/cluster_quality.h +128 -0
  24. sequenzo/clustering/src/cluster_quality_backup.cpp +570 -0
  25. sequenzo/clustering/src/module.cpp +228 -0
  26. sequenzo/clustering/src/weightedinertia.cpp +111 -0
  27. sequenzo/clustering/utils/__init__.py +27 -0
  28. sequenzo/clustering/utils/disscenter.py +122 -0
  29. sequenzo/data_preprocessing/__init__.py +22 -0
  30. sequenzo/data_preprocessing/helpers.py +303 -0
  31. sequenzo/datasets/__init__.py +41 -0
  32. sequenzo/datasets/biofam.csv +2001 -0
  33. sequenzo/datasets/biofam_child_domain.csv +2001 -0
  34. sequenzo/datasets/biofam_left_domain.csv +2001 -0
  35. sequenzo/datasets/biofam_married_domain.csv +2001 -0
  36. sequenzo/datasets/chinese_colonial_territories.csv +12 -0
  37. sequenzo/datasets/country_co2_emissions.csv +194 -0
  38. sequenzo/datasets/country_co2_emissions_global_deciles.csv +195 -0
  39. sequenzo/datasets/country_co2_emissions_global_quintiles.csv +195 -0
  40. sequenzo/datasets/country_co2_emissions_local_deciles.csv +195 -0
  41. sequenzo/datasets/country_co2_emissions_local_quintiles.csv +195 -0
  42. sequenzo/datasets/country_gdp_per_capita.csv +194 -0
  43. sequenzo/datasets/dyadic_children.csv +61 -0
  44. sequenzo/datasets/dyadic_parents.csv +61 -0
  45. sequenzo/datasets/mvad.csv +713 -0
  46. sequenzo/datasets/pairfam_activity_by_month.csv +1028 -0
  47. sequenzo/datasets/pairfam_activity_by_year.csv +1028 -0
  48. sequenzo/datasets/pairfam_family_by_month.csv +1028 -0
  49. sequenzo/datasets/pairfam_family_by_year.csv +1028 -0
  50. sequenzo/datasets/political_science_aid_shock.csv +166 -0
  51. sequenzo/datasets/political_science_donor_fragmentation.csv +157 -0
  52. sequenzo/define_sequence_data.py +1400 -0
  53. sequenzo/dissimilarity_measures/__init__.py +31 -0
  54. sequenzo/dissimilarity_measures/c_code.cpython-310-darwin.so +0 -0
  55. sequenzo/dissimilarity_measures/get_distance_matrix.py +762 -0
  56. sequenzo/dissimilarity_measures/get_substitution_cost_matrix.py +246 -0
  57. sequenzo/dissimilarity_measures/src/DHDdistance.cpp +148 -0
  58. sequenzo/dissimilarity_measures/src/LCPdistance.cpp +114 -0
  59. sequenzo/dissimilarity_measures/src/LCPspellDistance.cpp +215 -0
  60. sequenzo/dissimilarity_measures/src/OMdistance.cpp +247 -0
  61. sequenzo/dissimilarity_measures/src/OMspellDistance.cpp +281 -0
  62. sequenzo/dissimilarity_measures/src/__init__.py +0 -0
  63. sequenzo/dissimilarity_measures/src/dist2matrix.cpp +63 -0
  64. sequenzo/dissimilarity_measures/src/dp_utils.h +160 -0
  65. sequenzo/dissimilarity_measures/src/module.cpp +40 -0
  66. sequenzo/dissimilarity_measures/src/setup.py +30 -0
  67. sequenzo/dissimilarity_measures/src/utils.h +25 -0
  68. sequenzo/dissimilarity_measures/src/xsimd/.github/cmake-test/main.cpp +6 -0
  69. sequenzo/dissimilarity_measures/src/xsimd/benchmark/main.cpp +159 -0
  70. sequenzo/dissimilarity_measures/src/xsimd/benchmark/xsimd_benchmark.hpp +565 -0
  71. sequenzo/dissimilarity_measures/src/xsimd/docs/source/conf.py +37 -0
  72. sequenzo/dissimilarity_measures/src/xsimd/examples/mandelbrot.cpp +330 -0
  73. sequenzo/dissimilarity_measures/src/xsimd/examples/pico_bench.hpp +246 -0
  74. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_arithmetic.hpp +266 -0
  75. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_complex.hpp +112 -0
  76. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_details.hpp +323 -0
  77. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_logical.hpp +218 -0
  78. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_math.hpp +2583 -0
  79. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_memory.hpp +880 -0
  80. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_rounding.hpp +72 -0
  81. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_swizzle.hpp +174 -0
  82. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_trigo.hpp +978 -0
  83. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx.hpp +1924 -0
  84. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx2.hpp +1144 -0
  85. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512bw.hpp +656 -0
  86. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512cd.hpp +28 -0
  87. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512dq.hpp +244 -0
  88. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512er.hpp +20 -0
  89. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512f.hpp +2650 -0
  90. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512ifma.hpp +20 -0
  91. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512pf.hpp +20 -0
  92. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vbmi.hpp +77 -0
  93. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vbmi2.hpp +131 -0
  94. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vnni_avx512bw.hpp +20 -0
  95. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vnni_avx512vbmi2.hpp +20 -0
  96. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avxvnni.hpp +20 -0
  97. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_common.hpp +24 -0
  98. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_common_fwd.hpp +77 -0
  99. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_constants.hpp +393 -0
  100. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_emulated.hpp +788 -0
  101. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma3_avx.hpp +93 -0
  102. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma3_avx2.hpp +46 -0
  103. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma3_sse.hpp +97 -0
  104. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma4.hpp +92 -0
  105. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_i8mm_neon64.hpp +17 -0
  106. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_isa.hpp +142 -0
  107. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_neon.hpp +3142 -0
  108. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_neon64.hpp +1543 -0
  109. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_rvv.hpp +1513 -0
  110. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_scalar.hpp +1260 -0
  111. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse2.hpp +2024 -0
  112. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse3.hpp +67 -0
  113. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse4_1.hpp +339 -0
  114. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse4_2.hpp +44 -0
  115. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_ssse3.hpp +186 -0
  116. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sve.hpp +1155 -0
  117. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_vsx.hpp +892 -0
  118. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_wasm.hpp +1780 -0
  119. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_arch.hpp +240 -0
  120. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_config.hpp +484 -0
  121. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_cpuid.hpp +269 -0
  122. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_inline.hpp +27 -0
  123. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/math/xsimd_rem_pio2.hpp +719 -0
  124. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/memory/xsimd_aligned_allocator.hpp +349 -0
  125. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/memory/xsimd_alignment.hpp +91 -0
  126. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_all_registers.hpp +55 -0
  127. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_api.hpp +2765 -0
  128. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx2_register.hpp +44 -0
  129. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512bw_register.hpp +51 -0
  130. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512cd_register.hpp +51 -0
  131. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512dq_register.hpp +51 -0
  132. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512er_register.hpp +51 -0
  133. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512f_register.hpp +77 -0
  134. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512ifma_register.hpp +51 -0
  135. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512pf_register.hpp +51 -0
  136. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512vbmi2_register.hpp +51 -0
  137. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512vbmi_register.hpp +51 -0
  138. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512vnni_avx512bw_register.hpp +54 -0
  139. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512vnni_avx512vbmi2_register.hpp +53 -0
  140. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx_register.hpp +64 -0
  141. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avxvnni_register.hpp +44 -0
  142. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_batch.hpp +1524 -0
  143. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_batch_constant.hpp +300 -0
  144. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_common_arch.hpp +47 -0
  145. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_emulated_register.hpp +80 -0
  146. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_fma3_avx2_register.hpp +50 -0
  147. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_fma3_avx_register.hpp +50 -0
  148. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_fma3_sse_register.hpp +50 -0
  149. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_fma4_register.hpp +50 -0
  150. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_i8mm_neon64_register.hpp +55 -0
  151. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_neon64_register.hpp +55 -0
  152. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_neon_register.hpp +154 -0
  153. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_register.hpp +94 -0
  154. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_rvv_register.hpp +506 -0
  155. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sse2_register.hpp +59 -0
  156. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sse3_register.hpp +49 -0
  157. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sse4_1_register.hpp +48 -0
  158. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sse4_2_register.hpp +48 -0
  159. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_ssse3_register.hpp +48 -0
  160. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sve_register.hpp +156 -0
  161. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_traits.hpp +337 -0
  162. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_utils.hpp +536 -0
  163. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_vsx_register.hpp +77 -0
  164. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_wasm_register.hpp +59 -0
  165. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/xsimd.hpp +75 -0
  166. sequenzo/dissimilarity_measures/src/xsimd/test/architectures/dummy.cpp +7 -0
  167. sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set.cpp +13 -0
  168. sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set_mean.cpp +24 -0
  169. sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set_mean_aligned.cpp +25 -0
  170. sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set_mean_arch_independent.cpp +28 -0
  171. sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set_mean_tag_dispatch.cpp +25 -0
  172. sequenzo/dissimilarity_measures/src/xsimd/test/doc/manipulating_abstract_batches.cpp +7 -0
  173. sequenzo/dissimilarity_measures/src/xsimd/test/doc/manipulating_parametric_batches.cpp +8 -0
  174. sequenzo/dissimilarity_measures/src/xsimd/test/doc/sum.hpp +31 -0
  175. sequenzo/dissimilarity_measures/src/xsimd/test/doc/sum_avx2.cpp +3 -0
  176. sequenzo/dissimilarity_measures/src/xsimd/test/doc/sum_sse2.cpp +3 -0
  177. sequenzo/dissimilarity_measures/src/xsimd/test/doc/writing_vectorized_code.cpp +11 -0
  178. sequenzo/dissimilarity_measures/src/xsimd/test/main.cpp +31 -0
  179. sequenzo/dissimilarity_measures/src/xsimd/test/test_api.cpp +230 -0
  180. sequenzo/dissimilarity_measures/src/xsimd/test/test_arch.cpp +217 -0
  181. sequenzo/dissimilarity_measures/src/xsimd/test/test_basic_math.cpp +183 -0
  182. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch.cpp +1049 -0
  183. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_bool.cpp +508 -0
  184. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_cast.cpp +409 -0
  185. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_complex.cpp +712 -0
  186. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_constant.cpp +286 -0
  187. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_float.cpp +141 -0
  188. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_int.cpp +365 -0
  189. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_manip.cpp +308 -0
  190. sequenzo/dissimilarity_measures/src/xsimd/test/test_bitwise_cast.cpp +222 -0
  191. sequenzo/dissimilarity_measures/src/xsimd/test/test_complex_exponential.cpp +226 -0
  192. sequenzo/dissimilarity_measures/src/xsimd/test/test_complex_hyperbolic.cpp +183 -0
  193. sequenzo/dissimilarity_measures/src/xsimd/test/test_complex_power.cpp +265 -0
  194. sequenzo/dissimilarity_measures/src/xsimd/test/test_complex_trigonometric.cpp +236 -0
  195. sequenzo/dissimilarity_measures/src/xsimd/test/test_conversion.cpp +248 -0
  196. sequenzo/dissimilarity_measures/src/xsimd/test/test_custom_default_arch.cpp +28 -0
  197. sequenzo/dissimilarity_measures/src/xsimd/test/test_error_gamma.cpp +170 -0
  198. sequenzo/dissimilarity_measures/src/xsimd/test/test_explicit_batch_instantiation.cpp +32 -0
  199. sequenzo/dissimilarity_measures/src/xsimd/test/test_exponential.cpp +202 -0
  200. sequenzo/dissimilarity_measures/src/xsimd/test/test_extract_pair.cpp +92 -0
  201. sequenzo/dissimilarity_measures/src/xsimd/test/test_fp_manipulation.cpp +77 -0
  202. sequenzo/dissimilarity_measures/src/xsimd/test/test_gnu_source.cpp +30 -0
  203. sequenzo/dissimilarity_measures/src/xsimd/test/test_hyperbolic.cpp +167 -0
  204. sequenzo/dissimilarity_measures/src/xsimd/test/test_load_store.cpp +304 -0
  205. sequenzo/dissimilarity_measures/src/xsimd/test/test_memory.cpp +61 -0
  206. sequenzo/dissimilarity_measures/src/xsimd/test/test_poly_evaluation.cpp +64 -0
  207. sequenzo/dissimilarity_measures/src/xsimd/test/test_power.cpp +184 -0
  208. sequenzo/dissimilarity_measures/src/xsimd/test/test_rounding.cpp +199 -0
  209. sequenzo/dissimilarity_measures/src/xsimd/test/test_select.cpp +101 -0
  210. sequenzo/dissimilarity_measures/src/xsimd/test/test_shuffle.cpp +760 -0
  211. sequenzo/dissimilarity_measures/src/xsimd/test/test_sum.cpp +4 -0
  212. sequenzo/dissimilarity_measures/src/xsimd/test/test_sum.hpp +34 -0
  213. sequenzo/dissimilarity_measures/src/xsimd/test/test_traits.cpp +172 -0
  214. sequenzo/dissimilarity_measures/src/xsimd/test/test_trigonometric.cpp +208 -0
  215. sequenzo/dissimilarity_measures/src/xsimd/test/test_utils.hpp +611 -0
  216. sequenzo/dissimilarity_measures/src/xsimd/test/test_wasm/test_wasm_playwright.py +123 -0
  217. sequenzo/dissimilarity_measures/src/xsimd/test/test_xsimd_api.cpp +1460 -0
  218. sequenzo/dissimilarity_measures/utils/__init__.py +16 -0
  219. sequenzo/dissimilarity_measures/utils/get_LCP_length_for_2_seq.py +44 -0
  220. sequenzo/dissimilarity_measures/utils/get_sm_trate_substitution_cost_matrix.cpython-310-darwin.so +0 -0
  221. sequenzo/dissimilarity_measures/utils/seqconc.cpython-310-darwin.so +0 -0
  222. sequenzo/dissimilarity_measures/utils/seqdss.cpython-310-darwin.so +0 -0
  223. sequenzo/dissimilarity_measures/utils/seqdur.cpython-310-darwin.so +0 -0
  224. sequenzo/dissimilarity_measures/utils/seqlength.cpython-310-darwin.so +0 -0
  225. sequenzo/multidomain/__init__.py +23 -0
  226. sequenzo/multidomain/association_between_domains.py +311 -0
  227. sequenzo/multidomain/cat.py +597 -0
  228. sequenzo/multidomain/combt.py +519 -0
  229. sequenzo/multidomain/dat.py +81 -0
  230. sequenzo/multidomain/idcd.py +139 -0
  231. sequenzo/multidomain/linked_polyad.py +292 -0
  232. sequenzo/openmp_setup.py +233 -0
  233. sequenzo/prefix_tree/__init__.py +62 -0
  234. sequenzo/prefix_tree/hub.py +114 -0
  235. sequenzo/prefix_tree/individual_level_indicators.py +1321 -0
  236. sequenzo/prefix_tree/spell_individual_level_indicators.py +580 -0
  237. sequenzo/prefix_tree/spell_level_indicators.py +297 -0
  238. sequenzo/prefix_tree/system_level_indicators.py +544 -0
  239. sequenzo/prefix_tree/utils.py +54 -0
  240. sequenzo/seqhmm/__init__.py +95 -0
  241. sequenzo/seqhmm/advanced_optimization.py +305 -0
  242. sequenzo/seqhmm/bootstrap.py +411 -0
  243. sequenzo/seqhmm/build_hmm.py +142 -0
  244. sequenzo/seqhmm/build_mhmm.py +136 -0
  245. sequenzo/seqhmm/build_nhmm.py +121 -0
  246. sequenzo/seqhmm/fit_mhmm.py +62 -0
  247. sequenzo/seqhmm/fit_model.py +61 -0
  248. sequenzo/seqhmm/fit_nhmm.py +76 -0
  249. sequenzo/seqhmm/formulas.py +289 -0
  250. sequenzo/seqhmm/forward_backward_nhmm.py +276 -0
  251. sequenzo/seqhmm/gradients_nhmm.py +306 -0
  252. sequenzo/seqhmm/hmm.py +291 -0
  253. sequenzo/seqhmm/mhmm.py +314 -0
  254. sequenzo/seqhmm/model_comparison.py +238 -0
  255. sequenzo/seqhmm/multichannel_em.py +282 -0
  256. sequenzo/seqhmm/multichannel_utils.py +138 -0
  257. sequenzo/seqhmm/nhmm.py +270 -0
  258. sequenzo/seqhmm/nhmm_utils.py +191 -0
  259. sequenzo/seqhmm/predict.py +137 -0
  260. sequenzo/seqhmm/predict_mhmm.py +142 -0
  261. sequenzo/seqhmm/simulate.py +878 -0
  262. sequenzo/seqhmm/utils.py +218 -0
  263. sequenzo/seqhmm/visualization.py +910 -0
  264. sequenzo/sequence_characteristics/__init__.py +40 -0
  265. sequenzo/sequence_characteristics/complexity_index.py +49 -0
  266. sequenzo/sequence_characteristics/overall_cross_sectional_entropy.py +220 -0
  267. sequenzo/sequence_characteristics/plot_characteristics.py +593 -0
  268. sequenzo/sequence_characteristics/simple_characteristics.py +311 -0
  269. sequenzo/sequence_characteristics/state_frequencies_and_entropy_per_sequence.py +39 -0
  270. sequenzo/sequence_characteristics/turbulence.py +155 -0
  271. sequenzo/sequence_characteristics/variance_of_spell_durations.py +86 -0
  272. sequenzo/sequence_characteristics/within_sequence_entropy.py +43 -0
  273. sequenzo/suffix_tree/__init__.py +66 -0
  274. sequenzo/suffix_tree/hub.py +114 -0
  275. sequenzo/suffix_tree/individual_level_indicators.py +1679 -0
  276. sequenzo/suffix_tree/spell_individual_level_indicators.py +493 -0
  277. sequenzo/suffix_tree/spell_level_indicators.py +248 -0
  278. sequenzo/suffix_tree/system_level_indicators.py +535 -0
  279. sequenzo/suffix_tree/utils.py +56 -0
  280. sequenzo/version_check.py +283 -0
  281. sequenzo/visualization/__init__.py +29 -0
  282. sequenzo/visualization/plot_mean_time.py +222 -0
  283. sequenzo/visualization/plot_modal_state.py +276 -0
  284. sequenzo/visualization/plot_most_frequent_sequences.py +147 -0
  285. sequenzo/visualization/plot_relative_frequency.py +405 -0
  286. sequenzo/visualization/plot_sequence_index.py +1175 -0
  287. sequenzo/visualization/plot_single_medoid.py +153 -0
  288. sequenzo/visualization/plot_state_distribution.py +651 -0
  289. sequenzo/visualization/plot_transition_matrix.py +190 -0
  290. sequenzo/visualization/utils/__init__.py +23 -0
  291. sequenzo/visualization/utils/utils.py +310 -0
  292. sequenzo/with_event_history_analysis/__init__.py +35 -0
  293. sequenzo/with_event_history_analysis/sequence_analysis_multi_state_model.py +850 -0
  294. sequenzo/with_event_history_analysis/sequence_history_analysis.py +283 -0
  295. sequenzo-0.1.31.dist-info/METADATA +286 -0
  296. sequenzo-0.1.31.dist-info/RECORD +299 -0
  297. sequenzo-0.1.31.dist-info/WHEEL +5 -0
  298. sequenzo-0.1.31.dist-info/licenses/LICENSE +28 -0
  299. sequenzo-0.1.31.dist-info/top_level.txt +2 -0
@@ -0,0 +1,910 @@
1
+ """
2
+ @Author : Yuqi Liang 梁彧祺
3
+ @File : visualization.py
4
+ @Time : 2025-11-18 07:25
5
+ @Desc : Visualization functions for HMM models
6
+
7
+ This module provides visualization functions for HMM models, similar to
8
+ seqHMM's plot.hmm() and plot.mhmm() functions in R.
9
+ """
10
+
11
+ import numpy as np
12
+ import matplotlib.pyplot as plt
13
+ import matplotlib.patches as mpatches
14
+ from matplotlib.patches import FancyArrowPatch, Circle, Wedge
15
+ from matplotlib.collections import PatchCollection
16
+ from typing import Optional, List, Union
17
+ from .hmm import HMM
18
+ from .mhmm import MHMM
19
+
20
+ # Try to import networkx for network layout, but make it optional
21
+ try:
22
+ import networkx as nx
23
+ HAS_NETWORKX = True
24
+ except ImportError:
25
+ HAS_NETWORKX = False
26
+
27
+
28
def plot_hmm(
    model: HMM,
    which: str = 'transition',
    figsize: Optional[tuple] = None,
    ax: Optional[plt.Axes] = None,
    # Network plot parameters (similar to R's plot.hmm)
    vertex_size: float = 50,
    vertex_label_dist: float = 1.5,
    edge_curved: Union[bool, float] = 0.5,
    edge_label_cex: float = 0.8,
    vertex_label: Union[str, List[str]] = 'initial.probs',
    loops: bool = False,
    trim: float = 1e-15,
    combine_slices: float = 0.05,
    with_legend: Union[bool, str] = 'bottom',
    layout: str = 'horizontal',
    **kwargs
) -> plt.Figure:
    """
    Plot HMM model parameters.

    This function visualizes HMM model parameters, including:
    - Transition probability matrix
    - Emission probability matrix
    - Initial state probabilities
    - Network graph (similar to R's plot.hmm())

    It is similar to seqHMM's plot.hmm() function in R.

    Args:
        model: Fitted HMM model object. ``model.log_likelihood`` must not be
            None (i.e. the model must have been fitted).
        which: What to plot. Options:
            - 'transition': Transition probability matrix (default)
            - 'emission': Emission probability matrix
            - 'initial': Initial state probabilities
            - 'network': Network graph with pie chart nodes (like R's plot.hmm)
            - 'all': Transition, emission and initial plots side by side in
              a new 1x3 figure
        figsize: Figure size tuple (width, height). If None, a default size
            is used for the selected plot ((15, 5) for 'all').
        ax: Optional matplotlib axes to plot on. If None, creates a new
            figure. Ignored when ``which='all'``, which always creates its
            own figure with three axes.

        Network plot parameters (only used when which='network'):
        vertex_size: Size of vertices (nodes). Default 50.
        vertex_label_dist: Distance of vertex labels from center. Default 1.5.
        edge_curved: Whether to plot curved edges. Can be bool or float
            (curvature). Default 0.5.
        edge_label_cex: Character expansion factor for edge labels. Default 0.8.
        vertex_label: Labels for vertices. Options: 'initial.probs', 'names',
            or a custom list. Default 'initial.probs'.
        loops: Whether to plot self-loops (transitions back to same state).
            Default False.
        trim: Minimum transition probability to plot. Default 1e-15.
        combine_slices: Emission probabilities below this are combined into
            'others'. Default 0.05.
        with_legend: Whether and where to plot legend. Options: True, False,
            'bottom', 'top', 'left', 'right'. Default 'bottom'.
        layout: Layout of vertices. Options: 'horizontal', 'vertical'.
            Default 'horizontal'.
        **kwargs: Additional keyword arguments forwarded to the network plot
            (not used by the other ``which`` options).

    Returns:
        matplotlib Figure: The figure object.

    Raises:
        ValueError: If the model has not been fitted, or if ``which`` is not
            one of the supported options.

    Examples:
        >>> import matplotlib.pyplot as plt
        >>> from sequenzo import SequenceData, load_dataset
        >>> from sequenzo.seqhmm import build_hmm, fit_model, plot_hmm
        >>>
        >>> # Load and prepare data
        >>> df = load_dataset('mvad')
        >>> seq = SequenceData(df, time=range(15, 86), states=['EM', 'FE', 'HE', 'JL', 'SC', 'TR'])
        >>>
        >>> # Build and fit model
        >>> hmm = build_hmm(seq, n_states=4, random_state=42)
        >>> hmm = fit_model(hmm)
        >>>
        >>> # Plot network graph (like R's plot.hmm)
        >>> plot_hmm(hmm, which='network', vertex_size=50, edge_curved=0.5)
        >>> plt.show()
        >>>
        >>> # Plot transition matrix
        >>> plot_hmm(hmm, which='transition')
        >>> plt.show()
    """
    if model.log_likelihood is None:
        raise ValueError("Model must be fitted before plotting. Use fit_model() first.")

    if which == 'network':
        return _plot_hmm_network(
            model, figsize=figsize, ax=ax,
            vertex_size=vertex_size,
            vertex_label_dist=vertex_label_dist,
            edge_curved=edge_curved,
            edge_label_cex=edge_label_cex,
            vertex_label=vertex_label,
            loops=loops,
            trim=trim,
            combine_slices=combine_slices,
            with_legend=with_legend,
            layout=layout,
            **kwargs
        )
    elif which == 'all':
        # Three panels cannot share one caller-supplied Axes, so 'all' always
        # builds its own 1x3 figure.
        if figsize is None:
            figsize = (15, 5)
        fig, panel_axes = plt.subplots(1, 3, figsize=figsize)

        # Each helper applies the same spine styling (hidden top/right,
        # light-grey bottom/left) to the axes it draws on.
        _plot_transition_matrix(model, ax=panel_axes[0])
        _plot_emission_matrix(model, ax=panel_axes[1])
        _plot_initial_probs(model, ax=panel_axes[2])

        # Lay out this specific figure rather than pyplot's "current" one.
        fig.tight_layout()
        return fig

    elif which == 'transition':
        return _plot_transition_matrix(model, figsize=figsize, ax=ax)
    elif which == 'emission':
        return _plot_emission_matrix(model, figsize=figsize, ax=ax)
    elif which == 'initial':
        return _plot_initial_probs(model, figsize=figsize, ax=ax)
    else:
        raise ValueError(f"Unknown 'which' option: {which}. Must be 'transition', 'emission', 'initial', 'network', or 'all'.")
151
+
152
+
153
def _plot_transition_matrix(
    model: HMM,
    figsize: Optional[tuple] = None,
    ax: Optional[plt.Axes] = None
) -> plt.Figure:
    """
    Plot the transition probability matrix as an annotated heatmap.

    Rows are the "from" states and columns the "to" states; every cell is
    labelled with its probability to two decimals.

    Args:
        model: Fitted HMM. Reads ``transition_probs`` (indexed ``[from, to]``),
            ``n_states`` and ``state_names``.
        figsize: Figure size used only when ``ax`` is None. Defaults to (8, 6).
        ax: Axes to draw on. If None, a new figure and axes are created.

    Returns:
        matplotlib Figure: The figure that contains the plot.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 6) if figsize is None else figsize)
    else:
        fig = ax.figure

    probs = model.transition_probs
    n_states = model.n_states

    # Pin the color scale to [0, 1], the full range of a probability.
    im = ax.imshow(probs, cmap='Blues', aspect='auto', vmin=0, vmax=1)

    # Create the colorbar via the axes' own figure instead of plt.colorbar(),
    # which goes through pyplot's current figure (and creates an empty one if
    # none exists, e.g. when ``ax`` comes from a Figure built without pyplot).
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label('Transition Probability', rotation=270, labelpad=20, fontsize=10)
    cbar.outline.set_visible(False)

    ax.set_xticks(range(n_states))
    ax.set_yticks(range(n_states))
    ax.set_xticklabels(model.state_names, rotation=45, ha='right', fontsize=9)
    ax.set_yticklabels(model.state_names, fontsize=9)

    for i in range(n_states):
        for j in range(n_states):
            value = probs[i, j]
            # White text on the darker (higher-probability) half of the colormap.
            ax.text(j, i, f'{value:.2f}',
                    ha="center", va="center",
                    color="black" if value < 0.5 else "white",
                    fontsize=9, weight='medium')

    ax.set_xlabel('To State', fontsize=10)
    ax.set_ylabel('From State', fontsize=10)
    ax.set_title('Transition Probability Matrix', fontsize=11, pad=10, weight='medium')

    # Remove top and right spines for a cleaner look.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_color('#cccccc')
    ax.spines['left'].set_color('#cccccc')

    return fig
199
+
200
+
201
def _plot_emission_matrix(
    model: HMM,
    figsize: Optional[tuple] = None,
    ax: Optional[plt.Axes] = None
) -> plt.Figure:
    """
    Plot the emission probability matrix as a heatmap.

    Rows are hidden states (``model.state_names``) and columns are observed
    symbols (``model.alphabet``). Cell values are written into the heatmap
    only when the matrix is small enough to stay legible (at most 10 states
    and 15 symbols).

    Args:
        model: Fitted HMM. Reads ``emission_probs`` (indexed
            ``[state, symbol]``), ``n_states``, ``n_symbols``,
            ``state_names`` and ``alphabet``.
        figsize: Figure size used only when ``ax`` is None. Defaults to (10, 6).
        ax: Axes to draw on. If None, a new figure and axes are created.

    Returns:
        matplotlib Figure: The figure that contains the plot.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 6) if figsize is None else figsize)
    else:
        fig = ax.figure

    probs = model.emission_probs
    n_states, n_symbols = model.n_states, model.n_symbols

    # Pin the color scale to [0, 1], the full range of a probability.
    im = ax.imshow(probs, cmap='YlOrRd', aspect='auto', vmin=0, vmax=1)

    # Create the colorbar via the axes' own figure instead of plt.colorbar(),
    # which goes through pyplot's current figure (and creates an empty one if
    # none exists, e.g. when ``ax`` comes from a Figure built without pyplot).
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label('Emission Probability', rotation=270, labelpad=20, fontsize=10)
    cbar.outline.set_visible(False)

    ax.set_xticks(range(n_symbols))
    ax.set_yticks(range(n_states))
    ax.set_xticklabels(model.alphabet, rotation=45, ha='right', fontsize=9)
    ax.set_yticklabels(model.state_names, fontsize=9)

    # Only annotate cells when the matrix is small enough to stay readable.
    if n_states <= 10 and n_symbols <= 15:
        for i in range(n_states):
            for j in range(n_symbols):
                value = probs[i, j]
                # White text on the darker (higher-probability) half of the colormap.
                ax.text(j, i, f'{value:.2f}',
                        ha="center", va="center",
                        color="black" if value < 0.5 else "white",
                        fontsize=8, weight='medium')

    ax.set_xlabel('Observed Symbol', fontsize=10)
    ax.set_ylabel('Hidden State', fontsize=10)
    ax.set_title('Emission Probability Matrix', fontsize=11, pad=10, weight='medium')

    # Remove top and right spines for a cleaner look.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_color('#cccccc')
    ax.spines['left'].set_color('#cccccc')

    return fig
248
+
249
+
250
def _plot_initial_probs(
    model: HMM,
    figsize: Optional[tuple] = None,
    ax: Optional[plt.Axes] = None
) -> plt.Figure:
    """Draw the initial hidden-state distribution as a labelled bar chart.

    One bar is drawn per hidden state, with its probability printed (3
    decimals) just above it. The y axis extends to 1.2 times the largest
    probability to leave room for the labels.

    Args:
        model: Fitted HMM exposing ``initial_probs``, ``n_states`` and
            ``state_names``.
        figsize: Size of a newly created figure (default ``(8, 5)``);
            ignored when ``ax`` is given.
        ax: Optional axes to draw into; a new figure is created otherwise.

    Returns:
        matplotlib Figure that owns the axes.
    """
    if ax is not None:
        fig = ax.figure
    else:
        fig, ax = plt.subplots(figsize=(8, 5) if figsize is None else figsize)

    probs = model.initial_probs
    slots = range(model.n_states)
    bars = ax.bar(slots, probs,
                  color='#4A90E2', alpha=0.8, edgecolor='white', linewidth=1.5)

    # Print each probability centred just above its bar.
    for bar, value in zip(bars, probs):
        centre = bar.get_x() + bar.get_width() / 2.
        ax.text(centre, bar.get_height(), f'{value:.3f}',
                ha='center', va='bottom', fontsize=9, weight='medium')

    ax.set_xticks(slots)
    ax.set_xticklabels(model.state_names, rotation=45, ha='right', fontsize=9)
    ax.set_ylabel('Probability', fontsize=10)
    ax.set_title('Initial State Probabilities', fontsize=11, pad=10, weight='medium')
    ax.set_ylim(0, max(probs) * 1.2)
    ax.grid(axis='y', alpha=0.2, linestyle='--', linewidth=0.5)

    # Hide the top/right frame and soften the remaining spines.
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    for side in ('bottom', 'left'):
        ax.spines[side].set_color('#cccccc')

    return fig
288
+
289
+
290
def plot_mhmm(
    model: MHMM,
    which: str = 'clusters',
    figsize: Optional[tuple] = None,
    ax: Optional[plt.Axes] = None
) -> plt.Figure:
    """
    Plot the parameters of a fitted Mixture HMM.

    Mirrors seqHMM's ``plot.mhmm()`` in R. Depending on ``which`` it shows
    the cluster (mixture) probabilities, the per-cluster transition
    matrices, the per-cluster emission matrices, or all three side by side.

    Args:
        model: Fitted MHMM model object.
        which: What to plot. Options:
            - 'clusters': Cluster probabilities (default)
            - 'transition': Transition matrices for all clusters
            - 'emission': Emission matrices for all clusters
            - 'all': All plots, in a new 1 x 3 figure (``ax`` is ignored)
        figsize: Figure size tuple (width, height). If None, uses default
            (``(18, 6)`` for 'all').
        ax: Optional matplotlib axes to plot on. If None, creates new figure.

    Returns:
        matplotlib Figure: The figure object

    Raises:
        ValueError: If the model has not been fitted (``log_likelihood`` is
            None) or ``which`` is not one of the options above.
    """
    if model.log_likelihood is None:
        raise ValueError("Model must be fitted before plotting. Use fit_mhmm() first.")

    if which == 'all':
        fig, (ax_clusters, ax_transitions, ax_emissions) = plt.subplots(
            1, 3, figsize=(18, 6) if figsize is None else figsize
        )
        _plot_cluster_probs(model, ax=ax_clusters)
        _plot_mhmm_transitions(model, ax=ax_transitions)
        _plot_mhmm_emissions(model, ax=ax_emissions)
        plt.tight_layout()
        return fig

    # Single-panel options, tried in order with plain equality.
    single_panel = (
        ('clusters', _plot_cluster_probs),
        ('transition', _plot_mhmm_transitions),
        ('emission', _plot_mhmm_emissions),
    )
    for option, plotter in single_panel:
        if which == option:
            return plotter(model, figsize=figsize, ax=ax)

    raise ValueError(
        f"Unknown 'which' option: {which}. Must be 'clusters', 'transition', 'emission', or 'all'."
    )
346
+
347
+
348
def _plot_cluster_probs(
    model: MHMM,
    figsize: Optional[tuple] = None,
    ax: Optional[plt.Axes] = None
) -> plt.Figure:
    """Draw the mixture (cluster) probabilities of an MHMM as a bar chart.

    Each bar is labelled with its probability (3 decimals). The y axis
    extends to 1.2 times the largest probability to leave room for the
    labels.

    Args:
        model: Fitted MHMM exposing ``cluster_probs``, ``n_clusters`` and
            ``cluster_names``.
        figsize: Size of a newly created figure (default ``(8, 5)``);
            ignored when ``ax`` is given.
        ax: Optional axes to draw into; a new figure is created otherwise.

    Returns:
        matplotlib Figure that owns the axes.
    """
    if ax is not None:
        fig = ax.figure
    else:
        fig, ax = plt.subplots(figsize=(8, 5) if figsize is None else figsize)

    probs = model.cluster_probs
    slots = range(model.n_clusters)
    bars = ax.bar(slots, probs, color='steelblue', alpha=0.7)

    # Print each probability centred just above its bar.
    for bar, value in zip(bars, probs):
        centre = bar.get_x() + bar.get_width() / 2.
        ax.text(centre, bar.get_height(), f'{value:.3f}',
                ha='center', va='bottom')

    ax.set_xticks(slots)
    ax.set_xticklabels(model.cluster_names, rotation=45, ha='right')
    ax.set_ylabel('Probability')
    ax.set_title('Cluster Probabilities')
    ax.set_ylim(0, max(probs) * 1.2)
    ax.grid(axis='y', alpha=0.3)

    return fig
380
+
381
+
382
def _mhmm_cluster_panels(
    model: MHMM,
    figsize: Optional[tuple] = None,
    ax: Optional[plt.Axes] = None
):
    """Create one heatmap panel per MHMM cluster.

    Returns ``(fig, panels, colorbar_parent)``:

    - With no ``ax``, a new ``1 x n_clusters`` figure is created (default size
      ``(6 * n_clusters, 6)``). The colorbar parent is the list of panels.
    - With an ``ax`` and one cluster, that axes is the only panel.
    - With an ``ax`` and several clusters, the axes is hidden and split into
      side-by-side inset panels. Drawing every cluster into the same axes
      would overwrite all but the last one. The colorbar is attached to the
      hidden parent axes; the insets are positioned in its axes coordinates,
      so they follow when it is shrunk.
    """
    n_clusters = model.n_clusters
    if ax is None:
        if figsize is None:
            figsize = (6 * n_clusters, 6)
        fig, grid = plt.subplots(1, n_clusters, figsize=figsize, squeeze=False)
        panels = list(grid[0])
        return fig, panels, panels

    fig = ax.figure
    if n_clusters == 1:
        return fig, [ax], [ax]
    if n_clusters < 1:
        return fig, [], []

    ax.set_axis_off()
    slot = 1.0 / n_clusters
    # Leave a left margin in every slot for that panel's y tick labels.
    gap = 0.2 * slot
    panels = [ax.inset_axes([k * slot + gap, 0.0, slot - gap, 1.0])
              for k in range(n_clusters)]
    return fig, panels, ax


def _draw_cluster_heatmap(panel, probs, n_rows, n_cols, row_labels, col_labels,
                          cmap, annotate):
    """Draw one [0, 1] probability heatmap on ``panel`` and return the image."""
    image = panel.imshow(probs, cmap=cmap, aspect='auto', vmin=0, vmax=1)

    panel.set_xticks(range(n_cols))
    panel.set_yticks(range(n_rows))
    panel.set_xticklabels(col_labels, rotation=45, ha='right', fontsize=8)
    panel.set_yticklabels(row_labels, fontsize=8)

    if annotate:
        for i in range(n_rows):
            for j in range(n_cols):
                value = probs[i, j]
                panel.text(j, i, f'{value:.2f}',
                           ha="center", va="center",
                           color="black" if value < 0.5 else "white",
                           fontsize=7)
    return image


def _finish_cluster_figure(fig, image, colorbar_parent, n_clusters):
    """Tidy ``fig``'s layout, then add a shared horizontal colorbar for several clusters."""
    # Lay out first: the colorbar axes made for a list of parents is not a
    # gridspec subplot, so a later tight_layout could not account for it.
    fig.tight_layout()
    if n_clusters > 1 and image is not None:
        fig.colorbar(image, ax=colorbar_parent, orientation='horizontal', pad=0.1)


def _plot_mhmm_transitions(
    model: MHMM,
    figsize: Optional[tuple] = None,
    ax: Optional[plt.Axes] = None
) -> plt.Figure:
    """Plot one transition-probability heatmap per cluster of a Mixture HMM.

    Args:
        model: Fitted MHMM; each ``model.clusters[k]`` provides
            ``transition_probs``, ``n_states`` and ``state_names``, and
            ``model.cluster_names[k]`` is used in the panel title.
        figsize: Size of a newly created figure (default
            ``(6 * n_clusters, 6)``); ignored when ``ax`` is given.
        ax: Optional axes to draw into. With several clusters it is split
            into side-by-side panels instead of being drawn over repeatedly.

    Returns:
        matplotlib Figure containing the panels.
    """
    fig, panels, colorbar_parent = _mhmm_cluster_panels(model, figsize, ax)

    image = None
    for k, panel in enumerate(panels):
        cluster = model.clusters[k]
        image = _draw_cluster_heatmap(
            panel, cluster.transition_probs,
            n_rows=cluster.n_states, n_cols=cluster.n_states,
            row_labels=cluster.state_names, col_labels=cluster.state_names,
            cmap='Blues', annotate=True,
        )
        panel.set_xlabel('To State')
        panel.set_ylabel('From State')
        panel.set_title(f'{model.cluster_names[k]}\nTransition Matrix')

    _finish_cluster_figure(fig, image, colorbar_parent, model.n_clusters)
    return fig


def _plot_mhmm_emissions(
    model: MHMM,
    figsize: Optional[tuple] = None,
    ax: Optional[plt.Axes] = None
) -> plt.Figure:
    """Plot one emission-probability heatmap per cluster of a Mixture HMM.

    Args:
        model: Fitted MHMM; each ``model.clusters[k]`` provides
            ``emission_probs``, ``n_states``, ``n_symbols``, ``state_names``
            and ``alphabet``, and ``model.cluster_names[k]`` is used in the
            panel title.
        figsize: Size of a newly created figure (default
            ``(6 * n_clusters, 6)``); ignored when ``ax`` is given.
        ax: Optional axes to draw into. With several clusters it is split
            into side-by-side panels instead of being drawn over repeatedly.

    Returns:
        matplotlib Figure containing the panels.
    """
    fig, panels, colorbar_parent = _mhmm_cluster_panels(model, figsize, ax)

    image = None
    for k, panel in enumerate(panels):
        cluster = model.clusters[k]
        image = _draw_cluster_heatmap(
            panel, cluster.emission_probs,
            n_rows=cluster.n_states, n_cols=cluster.n_symbols,
            row_labels=cluster.state_names, col_labels=cluster.alphabet,
            cmap='YlOrRd',
            # Cell annotations become unreadable on large matrices.
            annotate=cluster.n_states <= 10 and cluster.n_symbols <= 15,
        )
        panel.set_xlabel('Observed Symbol')
        panel.set_ylabel('Hidden State')
        panel.set_title(f'{model.cluster_names[k]}\nEmission Matrix')

    _finish_cluster_figure(fig, image, colorbar_parent, model.n_clusters)
    return fig
477
+
478
+
479
def _hmm_network_symbol_colors(model: HMM) -> list:
    """Return one fill colour per symbol of ``model.alphabet``.

    Uses ``model.observations.color_map`` when the model carries a non-empty
    one; symbols missing from it are drawn grey. Otherwise falls back to a
    fixed 10-colour palette, padded with grey for longer alphabets.
    """
    alphabet = list(model.alphabet)
    # getattr chain: a model without ``observations`` simply uses the palette.
    color_map = getattr(getattr(model, 'observations', None), 'color_map', None)
    if color_map:
        return [color_map.get(sym, '#808080') for sym in alphabet]
    palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b',
               '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
    return [palette[k] if k < len(palette) else '#808080' for k in range(len(alphabet))]


def _hmm_network_pie_data(emission_probs, colors: list, alphabet: list,
                          combine_slices: float) -> list:
    """Build the pie slices for every hidden state.

    For each row of ``emission_probs``, returns
    ``(values, slice_colors, legend_items, combined)``:

    - ``values``: the non-zero slice sizes, in symbol order.
    - ``slice_colors``: the colour of each slice, aligned with ``values``.
    - ``legend_items``: the ``(label, colour)`` pairs shown in that pie.
    - ``combined``: the probability mass folded into a trailing white
      'others' slice.

    When ``combine_slices > 0``, probabilities strictly below it are folded
    into 'others'.
    """
    pies = []
    for row in np.asarray(emission_probs, dtype=float):
        values = row.copy()
        combined = 0.0
        if combine_slices > 0:
            small = values < combine_slices
            combined = float(values[small].sum())
            values[small] = 0.0
        kept = np.flatnonzero(values > 0)
        slice_values = [float(values[j]) for j in kept]
        # Colours are selected with the same index set as the values, so a
        # zero/folded symbol never shifts the colours of the following slices.
        slice_colors = [colors[j] for j in kept]
        legend_items = [(alphabet[j], colors[j]) for j in kept]
        if combined > 0:
            slice_values.append(combined)
            slice_colors.append('white')
            legend_items.append(('others', 'white'))
        pies.append((np.asarray(slice_values, dtype=float), slice_colors,
                     legend_items, combined))
    return pies


def _hmm_network_edge_geometry(start_centre, end_centre, radius, curvature):
    """Return ``(start, end, control, label_xy)`` for an edge between two distinct nodes.

    The edge is shortened by ``radius`` at both ends so that it touches the
    node rims. ``control`` is the quadratic-Bezier control point, or ``None``
    for a straight edge (``curvature == 0``). ``label_xy`` lies on the drawn
    edge.
    """
    (x1, y1), (x2, y2) = start_centre, end_centre
    dx, dy = x2 - x1, y2 - y1
    dist = float(np.hypot(dx, dy))
    ux, uy = dx / dist, dy / dist
    start = (x1 + ux * radius, y1 + uy * radius)
    end = (x2 - ux * radius, y2 - uy * radius)
    mid_x, mid_y = (start[0] + end[0]) / 2, (start[1] + end[1]) / 2
    if curvature == 0:
        return start, end, None, (mid_x, mid_y)
    # Bend perpendicular to the direction of travel, so i->j and j->i bow to
    # opposite sides instead of overlapping.
    off_x = -uy * curvature * dist * 0.3
    off_y = ux * curvature * dist * 0.3
    control = (mid_x + off_x, mid_y + off_y)
    # A quadratic Bezier passes through mid + offset / 2 at t = 0.5.
    return start, end, control, (mid_x + off_x / 2, mid_y + off_y / 2)


def _hmm_network_loop_path(centre, radius, vertical):
    """Return ``(path, label_xy)`` for a self-loop drawn on a node's free side.

    The loop is a cubic Bezier that leaves and re-enters the rim at 60/120
    degrees and bulges about two radii outward. It sits above the node in a
    horizontal layout and to its right in a vertical one, because vertex
    labels use the opposite side.
    """
    from matplotlib.path import Path

    cx, cy = centre
    # (along-rim, outward) offsets, in units of the node radius.
    offsets = [(0.5, 0.866), (1.5, 2.5), (-1.5, 2.5), (-0.5, 0.866)]
    if vertical:
        points = [(cx + out * radius, cy + side * radius) for side, out in offsets]
        label_xy = (cx + 2.1 * radius, cy)
    else:
        points = [(cx + side * radius, cy + out * radius) for side, out in offsets]
        label_xy = (cx, cy + 2.1 * radius)
    codes = [Path.MOVETO, Path.CURVE4, Path.CURVE4, Path.CURVE4]
    return Path(points, codes), label_xy


def _plot_hmm_network(
    model: HMM,
    figsize: Optional[tuple] = None,
    ax: Optional[plt.Axes] = None,
    vertex_size: float = 50,
    vertex_label_dist: float = 1.5,
    edge_curved: Union[bool, float] = 0.5,
    edge_label_cex: float = 0.8,
    vertex_label: str = 'initial.probs',
    loops: bool = False,
    trim: float = 1e-15,
    combine_slices: float = 0.05,
    with_legend: Union[bool, str] = 'bottom',
    layout: str = 'horizontal',
    legend_prop: float = 0.5,
    **kwargs
) -> plt.Figure:
    """
    Plot HMM as a network graph with pie chart nodes (similar to R's plot.hmm).

    This function creates a directed graph where:
    - Nodes are pie charts showing emission probabilities for each hidden state
    - Edges are arrows showing transition probabilities between states, with
      widths scaled as in R (``p * 7 / max(p)``, floored at 0.8)
    - Node labels show initial probabilities or state names

    Nodes are spaced one data unit apart (left to right, or top to bottom for
    ``layout='vertical'``). The axis limits include the whole node, its
    label side and, with ``loops=True``, the self-loop arcs.

    Args:
        model: Fitted HMM model object
        figsize: Figure size tuple (width, height)
        ax: Optional matplotlib axes. When given, no separate legend panel is
            created and the legend (if any) is drawn on ``ax`` itself.
        vertex_size: Size of vertices; 50 gives a radius of a quarter of the
            node spacing.
        vertex_label_dist: Distance of vertex labels from the node centre, in
            units of the node radius (labels go below nodes, or to the left
            in a vertical layout). Non-numeric values such as 'auto' use 1.4.
        edge_curved: Whether to plot curved edges. ``True`` means curvature
            0.5; a number is used as the curvature; ``False``/0 draws
            straight edges.
        edge_label_cex: Character expansion factor for edge labels
        vertex_label: Labels for vertices ('initial.probs', 'names', or custom list)
        loops: Whether to plot self-loops
        trim: Transition probabilities below this are not plotted
        combine_slices: Emission probabilities below this are combined into 'others'
        with_legend: Whether and where to plot legend ('bottom', 'top',
            'left', 'right', True, or False)
        layout: Layout of vertices ('horizontal' or 'vertical'; anything else
            is treated as horizontal)
        legend_prop: Proportion of figure used for legend (0-1). Default 0.5.
        **kwargs: Accepted for compatibility; currently ignored.

    Returns:
        matplotlib Figure: The figure object
    """
    from matplotlib.path import Path

    # --- Figure / axes layout: a separate legend panel (like R's layout()) ---
    use_separate_legend = (isinstance(with_legend, str)
                           and with_legend in ('bottom', 'top', 'left', 'right'))
    ax_legend = None
    if ax is None:
        if figsize is None:
            if not use_separate_legend:
                figsize = (12, 6)
            elif with_legend in ('bottom', 'top'):
                figsize = (12, 8)
            else:
                figsize = (14, 6)
        if use_separate_legend:
            fig = plt.figure(figsize=figsize)
            legend_first = with_legend in ('top', 'left')
            ratios = ([legend_prop, 1 - legend_prop] if legend_first
                      else [1 - legend_prop, legend_prop])
            if with_legend in ('bottom', 'top'):
                grid = fig.add_gridspec(2, 1, height_ratios=ratios, hspace=0.3)
            else:
                grid = fig.add_gridspec(1, 2, width_ratios=ratios, wspace=0.3)
            if legend_first:
                ax_legend = fig.add_subplot(grid[0])
                ax = fig.add_subplot(grid[1])
            else:
                ax = fig.add_subplot(grid[0])
                ax_legend = fig.add_subplot(grid[1])
        else:
            fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure
        use_separate_legend = False

    # --- Model parameters ---
    n_states = model.n_states
    alphabet = list(model.alphabet)
    transition_probs = np.array(model.transition_probs, dtype=float)  # copy: trimmed below
    initial_probs = np.asarray(model.initial_probs, dtype=float)
    colors = _hmm_network_symbol_colors(model)

    transition_probs[transition_probs < trim] = 0.0
    if not loops:
        np.fill_diagonal(transition_probs, 0.0)

    # --- Geometry: unit-spaced nodes and limits that contain them ---
    vertical = layout == 'vertical'
    node_radius = (vertex_size / 50.0) * 0.25
    positions = [(0.0, -float(i)) if vertical else (float(i), 0.0)
                 for i in range(n_states)]

    if isinstance(vertex_label_dist, (int, float, np.integer, np.floating)):
        label_dist = float(vertex_label_dist) * node_radius
    else:  # e.g. 'auto': just outside the node rim
        label_dist = 1.4 * node_radius

    pad = 0.1
    along_lo = -(node_radius + pad)
    along_hi = (n_states - 1) + node_radius + pad
    # Labels sit on the negative cross side; self-loops bulge to the positive one.
    across_lo = -max(0.5, node_radius + pad, label_dist + pad)
    across_hi = max(0.5, node_radius + pad,
                    (2.3 * node_radius + pad) if loops else 0.0)
    if vertical:
        xlim, ylim = (across_lo, across_hi), (-along_hi, -along_lo)
    else:
        xlim, ylim = (along_lo, along_hi), (across_lo, across_hi)

    pies = _hmm_network_pie_data(model.emission_probs, colors, alphabet, combine_slices)
    any_combined = any(combined > 0 for *_, combined in pies)

    # --- Edges (drawn first so nodes sit on top) ---
    edges = [(i, j, float(transition_probs[i, j]))
             for i in range(n_states) for j in range(n_states)
             if transition_probs[i, j] > 0]
    max_prob = max((p for _, _, p in edges), default=0.0)

    # bool is a subclass of int, so handle it before treating the value as a number.
    if isinstance(edge_curved, bool):
        curvature = 0.5 if edge_curved else 0.0
    else:
        curvature = float(edge_curved) if edge_curved else 0.0

    arrow_style = dict(arrowstyle='->', color='#666666', alpha=0.8, zorder=1,
                       mutation_scale=15)
    label_box = dict(boxstyle='round,pad=0.2', facecolor='white',
                     alpha=0.9, edgecolor='gray', linewidth=0.5)

    for i, j, prob in edges:
        width = max(prob * (7.0 / max_prob), 0.8)
        if i == j:
            loop_path, label_xy = _hmm_network_loop_path(positions[i], node_radius, vertical)
            arrow = FancyArrowPatch(path=loop_path, lw=width, **arrow_style)
        else:
            start, end, control, label_xy = _hmm_network_edge_geometry(
                positions[i], positions[j], node_radius, curvature)
            if control is None:
                arrow = FancyArrowPatch(start, end, lw=width, **arrow_style)
            else:
                curve = Path([start, control, end],
                             [Path.MOVETO, Path.CURVE3, Path.CURVE3])
                arrow = FancyArrowPatch(path=curve, lw=width, **arrow_style)
        ax.add_patch(arrow)

        text = f'{prob:.3f}' if prob >= 0.001 else f'{prob:.2e}'
        ax.text(label_xy[0], label_xy[1], text,
                fontsize=9 * edge_label_cex, ha='center', va='center',
                bbox=label_box, zorder=3)

    # --- Nodes: pie charts of emission probabilities plus a label ---
    for i, ((x, y), (values, slice_colors, _, _)) in enumerate(zip(positions, pies)):
        total = float(values.sum())
        if values.size and total > 0:
            bounds = np.concatenate(([0.0], np.cumsum(values / total))) * 360.0
            for k, color in enumerate(slice_colors):
                ax.add_patch(Wedge((x, y), node_radius, bounds[k], bounds[k + 1],
                                   facecolor=color, edgecolor='black',
                                   linewidth=2, zorder=2))

        if isinstance(vertex_label, str) and vertex_label == 'names':
            label_text = (model.state_names[i] if i < len(model.state_names)
                          else f'State {i+1}')
        elif isinstance(vertex_label, (list, tuple)) and i < len(vertex_label):
            label_text = str(vertex_label[i])
        else:
            label_text = f'{initial_probs[i]:.2f}'

        if vertical:
            ax.text(x - label_dist, y, label_text, fontsize=11,
                    ha='right', va='center', weight='bold', zorder=4)
        else:
            ax.text(x, y - label_dist, label_text, fontsize=11,
                    ha='center', va='top', weight='bold', zorder=4)

    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)

    # --- Legend ---
    if use_separate_legend:
        ax_legend.axis('off')
        # Unique (label, colour) pairs in order of first appearance, as in R.
        items = list(dict.fromkeys(
            item for _, _, legend_items, _ in pies for item in legend_items))
        if items:
            handles = [mpatches.Patch(facecolor=col, edgecolor='black', linewidth=1, label=lab)
                       for lab, col in items]
            ncol = min(len(handles), 6) if with_legend in ('bottom', 'top') else 1
            ax_legend.legend(handles=handles, loc='center',
                             ncol=ncol, frameon=True, fontsize=10,
                             handlelength=1.5, handletextpad=0.5)
    elif with_legend:
        # Legend drawn on the plot axes itself (may overlap the graph).
        labels = alphabet + (['others'] if any_combined else [])
        fills = colors + (['white'] if any_combined else [])
        handles = [mpatches.Patch(facecolor=col, edgecolor='black', label=lab)
                   for lab, col in zip(labels, fills)]
        ncol = max(1, min(len(labels), 6))
        if with_legend is True or with_legend == 'bottom':
            ax.legend(handles=handles, loc='lower center', bbox_to_anchor=(0.5, -0.1),
                      ncol=ncol, frameon=True, fontsize=9)
        elif with_legend == 'top':
            ax.legend(handles=handles, loc='upper center', bbox_to_anchor=(0.5, 1.1),
                      ncol=ncol, frameon=True, fontsize=9)
        elif with_legend == 'right':
            ax.legend(handles=handles, loc='center left', bbox_to_anchor=(1.02, 0.5),
                      ncol=1, frameon=True, fontsize=9)
        elif with_legend == 'left':
            ax.legend(handles=handles, loc='center right', bbox_to_anchor=(-0.02, 0.5),
                      ncol=1, frameon=True, fontsize=9)

    # Lay out the figure that owns ``ax``, which is not necessarily pyplot's current one.
    fig.tight_layout()
    return fig