sequenzo 0.1.31__cp310-cp310-macosx_10_9_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. _sequenzo_fastcluster.cpython-310-darwin.so +0 -0
  2. sequenzo/__init__.py +349 -0
  3. sequenzo/big_data/__init__.py +12 -0
  4. sequenzo/big_data/clara/__init__.py +26 -0
  5. sequenzo/big_data/clara/clara.py +476 -0
  6. sequenzo/big_data/clara/utils/__init__.py +27 -0
  7. sequenzo/big_data/clara/utils/aggregatecases.py +92 -0
  8. sequenzo/big_data/clara/utils/davies_bouldin.py +91 -0
  9. sequenzo/big_data/clara/utils/get_weighted_diss.cpython-310-darwin.so +0 -0
  10. sequenzo/big_data/clara/utils/wfcmdd.py +205 -0
  11. sequenzo/big_data/clara/visualization.py +88 -0
  12. sequenzo/clustering/KMedoids.py +178 -0
  13. sequenzo/clustering/__init__.py +30 -0
  14. sequenzo/clustering/clustering_c_code.cpython-310-darwin.so +0 -0
  15. sequenzo/clustering/hierarchical_clustering.py +1256 -0
  16. sequenzo/clustering/sequenzo_fastcluster/fastcluster.py +495 -0
  17. sequenzo/clustering/sequenzo_fastcluster/src/fastcluster.cpp +1877 -0
  18. sequenzo/clustering/sequenzo_fastcluster/src/fastcluster_python.cpp +1264 -0
  19. sequenzo/clustering/src/KMedoid.cpp +263 -0
  20. sequenzo/clustering/src/PAM.cpp +237 -0
  21. sequenzo/clustering/src/PAMonce.cpp +265 -0
  22. sequenzo/clustering/src/cluster_quality.cpp +496 -0
  23. sequenzo/clustering/src/cluster_quality.h +128 -0
  24. sequenzo/clustering/src/cluster_quality_backup.cpp +570 -0
  25. sequenzo/clustering/src/module.cpp +228 -0
  26. sequenzo/clustering/src/weightedinertia.cpp +111 -0
  27. sequenzo/clustering/utils/__init__.py +27 -0
  28. sequenzo/clustering/utils/disscenter.py +122 -0
  29. sequenzo/data_preprocessing/__init__.py +22 -0
  30. sequenzo/data_preprocessing/helpers.py +303 -0
  31. sequenzo/datasets/__init__.py +41 -0
  32. sequenzo/datasets/biofam.csv +2001 -0
  33. sequenzo/datasets/biofam_child_domain.csv +2001 -0
  34. sequenzo/datasets/biofam_left_domain.csv +2001 -0
  35. sequenzo/datasets/biofam_married_domain.csv +2001 -0
  36. sequenzo/datasets/chinese_colonial_territories.csv +12 -0
  37. sequenzo/datasets/country_co2_emissions.csv +194 -0
  38. sequenzo/datasets/country_co2_emissions_global_deciles.csv +195 -0
  39. sequenzo/datasets/country_co2_emissions_global_quintiles.csv +195 -0
  40. sequenzo/datasets/country_co2_emissions_local_deciles.csv +195 -0
  41. sequenzo/datasets/country_co2_emissions_local_quintiles.csv +195 -0
  42. sequenzo/datasets/country_gdp_per_capita.csv +194 -0
  43. sequenzo/datasets/dyadic_children.csv +61 -0
  44. sequenzo/datasets/dyadic_parents.csv +61 -0
  45. sequenzo/datasets/mvad.csv +713 -0
  46. sequenzo/datasets/pairfam_activity_by_month.csv +1028 -0
  47. sequenzo/datasets/pairfam_activity_by_year.csv +1028 -0
  48. sequenzo/datasets/pairfam_family_by_month.csv +1028 -0
  49. sequenzo/datasets/pairfam_family_by_year.csv +1028 -0
  50. sequenzo/datasets/political_science_aid_shock.csv +166 -0
  51. sequenzo/datasets/political_science_donor_fragmentation.csv +157 -0
  52. sequenzo/define_sequence_data.py +1400 -0
  53. sequenzo/dissimilarity_measures/__init__.py +31 -0
  54. sequenzo/dissimilarity_measures/c_code.cpython-310-darwin.so +0 -0
  55. sequenzo/dissimilarity_measures/get_distance_matrix.py +762 -0
  56. sequenzo/dissimilarity_measures/get_substitution_cost_matrix.py +246 -0
  57. sequenzo/dissimilarity_measures/src/DHDdistance.cpp +148 -0
  58. sequenzo/dissimilarity_measures/src/LCPdistance.cpp +114 -0
  59. sequenzo/dissimilarity_measures/src/LCPspellDistance.cpp +215 -0
  60. sequenzo/dissimilarity_measures/src/OMdistance.cpp +247 -0
  61. sequenzo/dissimilarity_measures/src/OMspellDistance.cpp +281 -0
  62. sequenzo/dissimilarity_measures/src/__init__.py +0 -0
  63. sequenzo/dissimilarity_measures/src/dist2matrix.cpp +63 -0
  64. sequenzo/dissimilarity_measures/src/dp_utils.h +160 -0
  65. sequenzo/dissimilarity_measures/src/module.cpp +40 -0
  66. sequenzo/dissimilarity_measures/src/setup.py +30 -0
  67. sequenzo/dissimilarity_measures/src/utils.h +25 -0
  68. sequenzo/dissimilarity_measures/src/xsimd/.github/cmake-test/main.cpp +6 -0
  69. sequenzo/dissimilarity_measures/src/xsimd/benchmark/main.cpp +159 -0
  70. sequenzo/dissimilarity_measures/src/xsimd/benchmark/xsimd_benchmark.hpp +565 -0
  71. sequenzo/dissimilarity_measures/src/xsimd/docs/source/conf.py +37 -0
  72. sequenzo/dissimilarity_measures/src/xsimd/examples/mandelbrot.cpp +330 -0
  73. sequenzo/dissimilarity_measures/src/xsimd/examples/pico_bench.hpp +246 -0
  74. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_arithmetic.hpp +266 -0
  75. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_complex.hpp +112 -0
  76. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_details.hpp +323 -0
  77. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_logical.hpp +218 -0
  78. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_math.hpp +2583 -0
  79. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_memory.hpp +880 -0
  80. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_rounding.hpp +72 -0
  81. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_swizzle.hpp +174 -0
  82. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_trigo.hpp +978 -0
  83. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx.hpp +1924 -0
  84. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx2.hpp +1144 -0
  85. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512bw.hpp +656 -0
  86. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512cd.hpp +28 -0
  87. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512dq.hpp +244 -0
  88. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512er.hpp +20 -0
  89. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512f.hpp +2650 -0
  90. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512ifma.hpp +20 -0
  91. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512pf.hpp +20 -0
  92. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vbmi.hpp +77 -0
  93. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vbmi2.hpp +131 -0
  94. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vnni_avx512bw.hpp +20 -0
  95. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vnni_avx512vbmi2.hpp +20 -0
  96. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avxvnni.hpp +20 -0
  97. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_common.hpp +24 -0
  98. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_common_fwd.hpp +77 -0
  99. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_constants.hpp +393 -0
  100. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_emulated.hpp +788 -0
  101. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma3_avx.hpp +93 -0
  102. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma3_avx2.hpp +46 -0
  103. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma3_sse.hpp +97 -0
  104. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma4.hpp +92 -0
  105. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_i8mm_neon64.hpp +17 -0
  106. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_isa.hpp +142 -0
  107. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_neon.hpp +3142 -0
  108. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_neon64.hpp +1543 -0
  109. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_rvv.hpp +1513 -0
  110. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_scalar.hpp +1260 -0
  111. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse2.hpp +2024 -0
  112. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse3.hpp +67 -0
  113. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse4_1.hpp +339 -0
  114. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse4_2.hpp +44 -0
  115. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_ssse3.hpp +186 -0
  116. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sve.hpp +1155 -0
  117. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_vsx.hpp +892 -0
  118. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_wasm.hpp +1780 -0
  119. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_arch.hpp +240 -0
  120. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_config.hpp +484 -0
  121. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_cpuid.hpp +269 -0
  122. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_inline.hpp +27 -0
  123. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/math/xsimd_rem_pio2.hpp +719 -0
  124. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/memory/xsimd_aligned_allocator.hpp +349 -0
  125. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/memory/xsimd_alignment.hpp +91 -0
  126. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_all_registers.hpp +55 -0
  127. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_api.hpp +2765 -0
  128. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx2_register.hpp +44 -0
  129. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512bw_register.hpp +51 -0
  130. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512cd_register.hpp +51 -0
  131. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512dq_register.hpp +51 -0
  132. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512er_register.hpp +51 -0
  133. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512f_register.hpp +77 -0
  134. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512ifma_register.hpp +51 -0
  135. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512pf_register.hpp +51 -0
  136. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512vbmi2_register.hpp +51 -0
  137. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512vbmi_register.hpp +51 -0
  138. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512vnni_avx512bw_register.hpp +54 -0
  139. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512vnni_avx512vbmi2_register.hpp +53 -0
  140. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx_register.hpp +64 -0
  141. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avxvnni_register.hpp +44 -0
  142. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_batch.hpp +1524 -0
  143. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_batch_constant.hpp +300 -0
  144. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_common_arch.hpp +47 -0
  145. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_emulated_register.hpp +80 -0
  146. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_fma3_avx2_register.hpp +50 -0
  147. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_fma3_avx_register.hpp +50 -0
  148. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_fma3_sse_register.hpp +50 -0
  149. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_fma4_register.hpp +50 -0
  150. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_i8mm_neon64_register.hpp +55 -0
  151. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_neon64_register.hpp +55 -0
  152. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_neon_register.hpp +154 -0
  153. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_register.hpp +94 -0
  154. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_rvv_register.hpp +506 -0
  155. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sse2_register.hpp +59 -0
  156. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sse3_register.hpp +49 -0
  157. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sse4_1_register.hpp +48 -0
  158. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sse4_2_register.hpp +48 -0
  159. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_ssse3_register.hpp +48 -0
  160. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sve_register.hpp +156 -0
  161. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_traits.hpp +337 -0
  162. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_utils.hpp +536 -0
  163. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_vsx_register.hpp +77 -0
  164. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_wasm_register.hpp +59 -0
  165. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/xsimd.hpp +75 -0
  166. sequenzo/dissimilarity_measures/src/xsimd/test/architectures/dummy.cpp +7 -0
  167. sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set.cpp +13 -0
  168. sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set_mean.cpp +24 -0
  169. sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set_mean_aligned.cpp +25 -0
  170. sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set_mean_arch_independent.cpp +28 -0
  171. sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set_mean_tag_dispatch.cpp +25 -0
  172. sequenzo/dissimilarity_measures/src/xsimd/test/doc/manipulating_abstract_batches.cpp +7 -0
  173. sequenzo/dissimilarity_measures/src/xsimd/test/doc/manipulating_parametric_batches.cpp +8 -0
  174. sequenzo/dissimilarity_measures/src/xsimd/test/doc/sum.hpp +31 -0
  175. sequenzo/dissimilarity_measures/src/xsimd/test/doc/sum_avx2.cpp +3 -0
  176. sequenzo/dissimilarity_measures/src/xsimd/test/doc/sum_sse2.cpp +3 -0
  177. sequenzo/dissimilarity_measures/src/xsimd/test/doc/writing_vectorized_code.cpp +11 -0
  178. sequenzo/dissimilarity_measures/src/xsimd/test/main.cpp +31 -0
  179. sequenzo/dissimilarity_measures/src/xsimd/test/test_api.cpp +230 -0
  180. sequenzo/dissimilarity_measures/src/xsimd/test/test_arch.cpp +217 -0
  181. sequenzo/dissimilarity_measures/src/xsimd/test/test_basic_math.cpp +183 -0
  182. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch.cpp +1049 -0
  183. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_bool.cpp +508 -0
  184. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_cast.cpp +409 -0
  185. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_complex.cpp +712 -0
  186. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_constant.cpp +286 -0
  187. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_float.cpp +141 -0
  188. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_int.cpp +365 -0
  189. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_manip.cpp +308 -0
  190. sequenzo/dissimilarity_measures/src/xsimd/test/test_bitwise_cast.cpp +222 -0
  191. sequenzo/dissimilarity_measures/src/xsimd/test/test_complex_exponential.cpp +226 -0
  192. sequenzo/dissimilarity_measures/src/xsimd/test/test_complex_hyperbolic.cpp +183 -0
  193. sequenzo/dissimilarity_measures/src/xsimd/test/test_complex_power.cpp +265 -0
  194. sequenzo/dissimilarity_measures/src/xsimd/test/test_complex_trigonometric.cpp +236 -0
  195. sequenzo/dissimilarity_measures/src/xsimd/test/test_conversion.cpp +248 -0
  196. sequenzo/dissimilarity_measures/src/xsimd/test/test_custom_default_arch.cpp +28 -0
  197. sequenzo/dissimilarity_measures/src/xsimd/test/test_error_gamma.cpp +170 -0
  198. sequenzo/dissimilarity_measures/src/xsimd/test/test_explicit_batch_instantiation.cpp +32 -0
  199. sequenzo/dissimilarity_measures/src/xsimd/test/test_exponential.cpp +202 -0
  200. sequenzo/dissimilarity_measures/src/xsimd/test/test_extract_pair.cpp +92 -0
  201. sequenzo/dissimilarity_measures/src/xsimd/test/test_fp_manipulation.cpp +77 -0
  202. sequenzo/dissimilarity_measures/src/xsimd/test/test_gnu_source.cpp +30 -0
  203. sequenzo/dissimilarity_measures/src/xsimd/test/test_hyperbolic.cpp +167 -0
  204. sequenzo/dissimilarity_measures/src/xsimd/test/test_load_store.cpp +304 -0
  205. sequenzo/dissimilarity_measures/src/xsimd/test/test_memory.cpp +61 -0
  206. sequenzo/dissimilarity_measures/src/xsimd/test/test_poly_evaluation.cpp +64 -0
  207. sequenzo/dissimilarity_measures/src/xsimd/test/test_power.cpp +184 -0
  208. sequenzo/dissimilarity_measures/src/xsimd/test/test_rounding.cpp +199 -0
  209. sequenzo/dissimilarity_measures/src/xsimd/test/test_select.cpp +101 -0
  210. sequenzo/dissimilarity_measures/src/xsimd/test/test_shuffle.cpp +760 -0
  211. sequenzo/dissimilarity_measures/src/xsimd/test/test_sum.cpp +4 -0
  212. sequenzo/dissimilarity_measures/src/xsimd/test/test_sum.hpp +34 -0
  213. sequenzo/dissimilarity_measures/src/xsimd/test/test_traits.cpp +172 -0
  214. sequenzo/dissimilarity_measures/src/xsimd/test/test_trigonometric.cpp +208 -0
  215. sequenzo/dissimilarity_measures/src/xsimd/test/test_utils.hpp +611 -0
  216. sequenzo/dissimilarity_measures/src/xsimd/test/test_wasm/test_wasm_playwright.py +123 -0
  217. sequenzo/dissimilarity_measures/src/xsimd/test/test_xsimd_api.cpp +1460 -0
  218. sequenzo/dissimilarity_measures/utils/__init__.py +16 -0
  219. sequenzo/dissimilarity_measures/utils/get_LCP_length_for_2_seq.py +44 -0
  220. sequenzo/dissimilarity_measures/utils/get_sm_trate_substitution_cost_matrix.cpython-310-darwin.so +0 -0
  221. sequenzo/dissimilarity_measures/utils/seqconc.cpython-310-darwin.so +0 -0
  222. sequenzo/dissimilarity_measures/utils/seqdss.cpython-310-darwin.so +0 -0
  223. sequenzo/dissimilarity_measures/utils/seqdur.cpython-310-darwin.so +0 -0
  224. sequenzo/dissimilarity_measures/utils/seqlength.cpython-310-darwin.so +0 -0
  225. sequenzo/multidomain/__init__.py +23 -0
  226. sequenzo/multidomain/association_between_domains.py +311 -0
  227. sequenzo/multidomain/cat.py +597 -0
  228. sequenzo/multidomain/combt.py +519 -0
  229. sequenzo/multidomain/dat.py +81 -0
  230. sequenzo/multidomain/idcd.py +139 -0
  231. sequenzo/multidomain/linked_polyad.py +292 -0
  232. sequenzo/openmp_setup.py +233 -0
  233. sequenzo/prefix_tree/__init__.py +62 -0
  234. sequenzo/prefix_tree/hub.py +114 -0
  235. sequenzo/prefix_tree/individual_level_indicators.py +1321 -0
  236. sequenzo/prefix_tree/spell_individual_level_indicators.py +580 -0
  237. sequenzo/prefix_tree/spell_level_indicators.py +297 -0
  238. sequenzo/prefix_tree/system_level_indicators.py +544 -0
  239. sequenzo/prefix_tree/utils.py +54 -0
  240. sequenzo/seqhmm/__init__.py +95 -0
  241. sequenzo/seqhmm/advanced_optimization.py +305 -0
  242. sequenzo/seqhmm/bootstrap.py +411 -0
  243. sequenzo/seqhmm/build_hmm.py +142 -0
  244. sequenzo/seqhmm/build_mhmm.py +136 -0
  245. sequenzo/seqhmm/build_nhmm.py +121 -0
  246. sequenzo/seqhmm/fit_mhmm.py +62 -0
  247. sequenzo/seqhmm/fit_model.py +61 -0
  248. sequenzo/seqhmm/fit_nhmm.py +76 -0
  249. sequenzo/seqhmm/formulas.py +289 -0
  250. sequenzo/seqhmm/forward_backward_nhmm.py +276 -0
  251. sequenzo/seqhmm/gradients_nhmm.py +306 -0
  252. sequenzo/seqhmm/hmm.py +291 -0
  253. sequenzo/seqhmm/mhmm.py +314 -0
  254. sequenzo/seqhmm/model_comparison.py +238 -0
  255. sequenzo/seqhmm/multichannel_em.py +282 -0
  256. sequenzo/seqhmm/multichannel_utils.py +138 -0
  257. sequenzo/seqhmm/nhmm.py +270 -0
  258. sequenzo/seqhmm/nhmm_utils.py +191 -0
  259. sequenzo/seqhmm/predict.py +137 -0
  260. sequenzo/seqhmm/predict_mhmm.py +142 -0
  261. sequenzo/seqhmm/simulate.py +878 -0
  262. sequenzo/seqhmm/utils.py +218 -0
  263. sequenzo/seqhmm/visualization.py +910 -0
  264. sequenzo/sequence_characteristics/__init__.py +40 -0
  265. sequenzo/sequence_characteristics/complexity_index.py +49 -0
  266. sequenzo/sequence_characteristics/overall_cross_sectional_entropy.py +220 -0
  267. sequenzo/sequence_characteristics/plot_characteristics.py +593 -0
  268. sequenzo/sequence_characteristics/simple_characteristics.py +311 -0
  269. sequenzo/sequence_characteristics/state_frequencies_and_entropy_per_sequence.py +39 -0
  270. sequenzo/sequence_characteristics/turbulence.py +155 -0
  271. sequenzo/sequence_characteristics/variance_of_spell_durations.py +86 -0
  272. sequenzo/sequence_characteristics/within_sequence_entropy.py +43 -0
  273. sequenzo/suffix_tree/__init__.py +66 -0
  274. sequenzo/suffix_tree/hub.py +114 -0
  275. sequenzo/suffix_tree/individual_level_indicators.py +1679 -0
  276. sequenzo/suffix_tree/spell_individual_level_indicators.py +493 -0
  277. sequenzo/suffix_tree/spell_level_indicators.py +248 -0
  278. sequenzo/suffix_tree/system_level_indicators.py +535 -0
  279. sequenzo/suffix_tree/utils.py +56 -0
  280. sequenzo/version_check.py +283 -0
  281. sequenzo/visualization/__init__.py +29 -0
  282. sequenzo/visualization/plot_mean_time.py +222 -0
  283. sequenzo/visualization/plot_modal_state.py +276 -0
  284. sequenzo/visualization/plot_most_frequent_sequences.py +147 -0
  285. sequenzo/visualization/plot_relative_frequency.py +405 -0
  286. sequenzo/visualization/plot_sequence_index.py +1175 -0
  287. sequenzo/visualization/plot_single_medoid.py +153 -0
  288. sequenzo/visualization/plot_state_distribution.py +651 -0
  289. sequenzo/visualization/plot_transition_matrix.py +190 -0
  290. sequenzo/visualization/utils/__init__.py +23 -0
  291. sequenzo/visualization/utils/utils.py +310 -0
  292. sequenzo/with_event_history_analysis/__init__.py +35 -0
  293. sequenzo/with_event_history_analysis/sequence_analysis_multi_state_model.py +850 -0
  294. sequenzo/with_event_history_analysis/sequence_history_analysis.py +283 -0
  295. sequenzo-0.1.31.dist-info/METADATA +286 -0
  296. sequenzo-0.1.31.dist-info/RECORD +299 -0
  297. sequenzo-0.1.31.dist-info/WHEEL +5 -0
  298. sequenzo-0.1.31.dist-info/licenses/LICENSE +28 -0
  299. sequenzo-0.1.31.dist-info/top_level.txt +2 -0
@@ -0,0 +1,1175 @@
1
+ """
2
+ @Author : Yuqi Liang 梁彧祺
3
+ @File : plot_sequence_index.py
4
+ @Time : 29/12/2024 09:08
5
+ @Desc :
6
+ Generate sequence index plots.
7
+ """
8
+ import numpy as np
9
+ import pandas as pd
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib.gridspec as gridspec
12
+
13
+ # Use relative import to avoid circular import when top-level package imports visualization
14
+ from ..define_sequence_data import SequenceData
15
+ from sequenzo.visualization.utils import (
16
+ set_up_time_labels_for_x_axis,
17
+ save_figure_to_buffer,
18
+ create_standalone_legend,
19
+ combine_plot_with_legend,
20
+ save_and_show_results,
21
+ determine_layout,
22
+ show_plot_title
23
+ )
24
+
25
+
26
def smart_sort_groups(groups):
    """
    Order group names so that names starting with digits come first.

    Names with a leading run of digits are ranked by that number (as an int).
    Ties, and all names without a numeric prefix, are ordered by their string
    form. Names without a numeric prefix always follow the numeric ones.

    :param groups: Iterable of group names (any objects; compared via ``str``)
    :return: A new list containing the group names in sorted order
    """
    import re

    leading_digits = re.compile(r'^(\d+)')

    def _rank(name):
        text = str(name)
        found = leading_digits.match(text)
        if found is None:
            # No numeric prefix: push behind every numbered group.
            return (float('inf'), text)
        return (int(found.group(1)), text)

    ordered = list(groups)
    ordered.sort(key=_rank)
    return ordered
43
+
44
+
45
def _cmdscale(D):
    """
    Classical (Torgerson) multidimensional scaling, analogous to R's cmdscale().

    The squared distances are double-centred into a Gram matrix, which is
    eigen-decomposed; coordinates are built from the eigenvectors whose
    eigenvalues are positive beyond floating-point round-off.

    :param D: An NxN symmetric distance matrix (ndarray, nested list or
        anything ``np.asarray`` accepts, e.g. a DataFrame)
    :return: Y, an Nxd ndarray of coordinates. d is the number of eigenvalues
        greater than 1e-10 times the largest eigenvalue magnitude. Columns are
        ordered by decreasing eigenvalue. Each column's sign is fixed so that
        its largest-magnitude entry is positive, which makes the orientation
        deterministic. If no eigenvalue qualifies (e.g. all distances are
        zero), an Nx1 array of zeros is returned.
    :raises ValueError: if D is not a square 2-D matrix
    """
    D = np.asarray(D, dtype=float)
    if D.ndim != 2 or D.shape[0] != D.shape[1]:
        raise ValueError(f"Distance matrix must be square, got shape {D.shape}")
    n = D.shape[0]
    if n == 0:
        return np.zeros((0, 1))

    # Double centring, B = -1/2 * H D^2 H with H = I - 11^T / n.
    # Written with row/column means: O(n^2) work and no extra NxN matrix.
    D2 = D ** 2
    B = -0.5 * (D2
                - D2.mean(axis=0, keepdims=True)
                - D2.mean(axis=1, keepdims=True)
                + D2.mean())

    eigvals, eigvecs = np.linalg.eigh(B)

    # eigh returns eigenvalues in ascending order; we want descending.
    idx = np.argsort(eigvals)[::-1]
    eigvals = eigvals[idx]
    eigvecs = eigvecs[:, idx]

    # Keep only eigenvalues that are positive beyond round-off noise.
    # Otherwise numerically-zero eigenvalues would yield meaningless axes.
    scale = np.abs(eigvals).max()
    positive = eigvals > 1e-10 * scale
    if not positive.any():
        # Fallback if no positive eigenvalues
        return np.zeros((n, 1))

    vecs = eigvecs[:, positive]

    # Eigenvector signs are arbitrary (and may vary between LAPACK builds).
    # Orient each axis so its largest-magnitude component is positive; this
    # keeps downstream sort orders reproducible.
    pivot = np.argmax(np.abs(vecs), axis=0)
    signs = np.sign(vecs[pivot, np.arange(vecs.shape[1])])
    signs[signs == 0] = 1.0

    # Equivalent to V @ diag(sqrt(lambda)).
    return vecs * signs * np.sqrt(eigvals[positive])
77
+
78
+
79
def _find_most_frequent_sequence(sequences):
    """
    Find the most frequent sequence in the dataset.

    Rows are compared element-wise. Float NaN entries (missing values) are
    treated as equal to each other, so identical sequences with gaps in the
    same positions are counted together. When several sequences share the
    highest frequency, the one that appears first wins.

    :param sequences: 2-D array-like of sequences (iterated row by row)
    :return: index of the first occurrence of the most frequent sequence,
        or 0 when ``sequences`` is empty
    """
    from collections import Counter

    # Every NaN scalar is a distinct, self-unequal object, so raw tuples with
    # NaNs never match each other. Replace NaNs with one shared token.
    missing = object()

    def _hashable(seq):
        return tuple(
            missing if isinstance(v, (float, np.floating)) and v != v else v
            for v in seq
        )

    keys = [_hashable(seq) for seq in sequences]
    if not keys:
        return 0  # Nothing to count; keep the historical fallback index.

    # most_common(1) returns the first-inserted key among equal counts.
    most_frequent = Counter(keys).most_common(1)[0][0]
    return keys.index(most_frequent)
103
+
104
+
105
def _lexicographic_order(values):
    """
    Return the row order that sorts ``values`` lexicographically.

    The first column is the primary key. NaNs are mapped to +inf so they sort
    after every real value.
    """
    vals = np.nan_to_num(np.asarray(values).astype(float, copy=True), nan=np.inf)
    # np.lexsort treats the LAST key as primary, hence the column reversal.
    return np.lexsort(vals.T[::-1])


def _select_sequences_subset(seqdata, sequence_selection, n_sequences, sort_by, sort_by_weight, weights, mask=None):
    """
    Select a subset of sequences based on the selection method.

    :param seqdata: SequenceData object (reads ``values``; for ID selection
        also ``ids``, if present and not None)
    :param sequence_selection: "all", "first_n", "last_n", or a collection of
        IDs (list, tuple, set, NumPy array, pandas Index or Series)
    :param n_sequences: Number of sequences for "first_n" or "last_n";
        values <= 0 select nothing
    :param sort_by: Ordering applied within the mask before taking the first
        or last n. "lexicographic" sorts by state codes; "mds" and
        "distance_to_most_frequent" are approximated by lexicographic order
        (a warning is printed). Anything else keeps the original order.
    :param sort_by_weight: If true and ``weights`` is not None, order by
        descending weight instead of ``sort_by``
    :param weights: Sequence weights aligned with ``seqdata.values``, or None
    :param mask: Optional boolean mask for pre-filtering sequences
    :return: Boolean mask of length ``len(seqdata.values)`` marking the
        selected sequences
    :raises ValueError: for an unsupported ``sequence_selection``
    """
    n_total = len(seqdata.values)
    # Start with all sequences or the pre-filtered mask.
    mask = np.ones(n_total, dtype=bool) if mask is None else np.asarray(mask, dtype=bool)

    # Test for strings first: comparing an ndarray of IDs to "all" would be an
    # element-wise (ambiguous) comparison.
    if isinstance(sequence_selection, str):
        if sequence_selection == "all":
            return mask

        if sequence_selection in ("first_n", "last_n"):
            valid_indices = np.flatnonzero(mask)

            # Order the masked subset (positions are relative to valid_indices).
            if sort_by_weight and weights is not None:
                subset_weights = np.asarray(weights)[mask]
                # Stable, so equal weights keep their original order.
                order = np.argsort(-subset_weights, kind="stable")
            elif sort_by == "lexicographic":
                order = _lexicographic_order(np.asarray(seqdata.values)[mask])
            elif sort_by in ("mds", "distance_to_most_frequent"):
                # Distance-based orders are not computed here. Building a full
                # distance matrix only to approximate it would waste time.
                order = _lexicographic_order(np.asarray(seqdata.values)[mask])
                print(f"Warning: {sort_by} sorting simplified to lexicographic for sequence selection")
            else:
                # "unsorted", "none" or unknown: keep original order.
                order = np.arange(len(valid_indices))

            # Clamp to [0, available]; an explicit start index avoids the
            # arr[-0:] pitfall, which would select everything for n == 0.
            n_take = max(0, min(int(n_sequences), len(order)))
            if sequence_selection == "first_n":
                chosen = order[:n_take]
            else:
                chosen = order[len(order) - n_take:]

            final_mask = np.zeros(n_total, dtype=bool)
            final_mask[valid_indices[chosen]] = True
            return final_mask

    elif isinstance(sequence_selection, (list, tuple, set, frozenset, np.ndarray, pd.Index, pd.Series)):
        ids = getattr(seqdata, 'ids', None)
        if ids is None:
            print("Warning: sequence_selection provided as ID list but seqdata has no IDs. Using all sequences.")
            return mask

        wanted = set(sequence_selection)  # O(1) membership tests
        hits = [i for i in np.flatnonzero(mask) if ids[i] in wanted]
        selected_mask = np.zeros(n_total, dtype=bool)
        selected_mask[np.asarray(hits, dtype=np.intp)] = True
        return selected_mask

    raise ValueError(f"Unsupported sequence_selection: {sequence_selection}. "
                     f"Supported options: 'all', 'first_n', 'last_n', or list of IDs")
225
+
226
+
227
def sort_sequences_by_method(seqdata, method="unsorted", mask=None, distance_matrix=None, weights=None):
    """
    Sort sequences in SequenceData based on specified method.

    :param seqdata: SequenceData object (only ``seqdata.values`` is read)
    :param method: str, sorting method - "unsorted" (alias "none"),
        "lexicographic", "mds", "distance_to_most_frequent"
    :param mask: array-like of bool; if provided, sort only this subset
    :param distance_matrix: array-like (ndarray, DataFrame, ...), required for
        "mds" and "distance_to_most_frequent". It may cover either the masked
        subset or the full sample; a full-sample matrix is sliced to the subset.
    :param weights: np.array, optional weights for sequences (currently unused)
    :return: np.array of sorting indices. These are positions within the
        masked subset, or within the full sample when ``mask`` is None.
        Distance-based methods use a stable sort, so ties keep their order.
    :raises ValueError: for an unsupported method, a missing distance matrix,
        or a distance matrix whose size matches neither the subset nor the
        full sample
    """
    values = np.asarray(seqdata.values)
    n_total = len(values)

    if mask is not None:
        mask = np.asarray(mask, dtype=bool)
        values = values[mask]  # boolean indexing already returns a copy
    n_sequences = len(values)

    if method == "unsorted" or method == "none":
        # Keep original order (R default)
        return np.arange(n_sequences)

    if method == "lexicographic":
        # NaN-safe: NaNs become +inf so they sort after every real state.
        vals = np.nan_to_num(values.astype(float, copy=True), nan=np.inf)
        # np.lexsort uses the LAST key as primary; reverse the columns so the
        # first time point is the primary key.
        return np.lexsort(vals.T[::-1])

    if method not in ("mds", "distance_to_most_frequent"):
        raise ValueError(f"Unsupported sorting method: {method}. "
                         f"Supported methods are: 'unsorted', 'lexicographic', 'mds', 'distance_to_most_frequent'")

    if distance_matrix is None:
        if method == "mds":
            raise ValueError("Distance matrix is required for MDS sorting")
        raise ValueError("Distance matrix is required for distance_to_most_frequent sorting")

    # Positional row indexing below needs a plain ndarray; this also accepts
    # DataFrames and nested lists.
    dm = np.asarray(distance_matrix, dtype=float)

    # A full-sample matrix combined with a partial mask: slice to the subset.
    if mask is not None and dm.ndim == 2 and dm.shape[0] == n_total != n_sequences:
        keep = np.flatnonzero(mask)
        dm = dm[np.ix_(keep, keep)]

    if dm.shape != (n_sequences, n_sequences):
        raise ValueError(
            f"distance_matrix has shape {dm.shape}; expected ({n_sequences}, {n_sequences}) "
            f"for the selected sequences or ({n_total}, {n_total}) for the full sample"
        )

    if n_sequences == 0:
        return np.arange(0)

    if method == "mds":
        # TODO: Support weighted MDS (TraMineR's wcmdscale analogue) when weights are provided.
        # Sort by the first MDS dimension.
        first_axis = _cmdscale(dm)[:, 0]
        return np.argsort(first_axis, kind="stable")

    # distance_to_most_frequent: ascending distance to the modal sequence.
    most_freq_idx = _find_most_frequent_sequence(values)
    return np.argsort(dm[most_freq_idx, :], kind="stable")
290
+
291
+
292
def plot_sequence_index(seqdata: SequenceData,
                        # Grouping parameters
                        group_by_column=None,
                        group_dataframe=None,
                        group_column_name=None,
                        group_labels=None,
                        # Other parameters
                        sort_by="lexicographic",
                        sort_by_weight=False,
                        weights="auto",
                        figsize=(10, 6),
                        plot_style="standard",
                        title=None,
                        xlabel="Time",
                        ylabel="Sequences",
                        save_as=None,
                        dpi=200,
                        layout='column',
                        nrows: int = None,
                        ncols: int = None,
                        group_order=None,
                        sort_groups='auto',
                        fontsize=12,
                        show_group_titles: bool = True,
                        include_legend: bool = True,
                        sequence_selection="all",
                        n_sequences=10,
                        show_sequence_ids=False,
                        sort_by_ids=None,
                        return_sorted_ids=False,
                        show_title=True,
                        proportional_scaling=False,
                        hide_y_axis=False
                        ):
    """Creates sequence index plots, optionally grouped by categories.

    This function creates index plots that visualize sequences as horizontal lines,
    with different sorting options matching R's TraMineR functionality.

    **Two API modes for grouping:**

    1. **Simplified API** (when grouping info is already in the data):
    ```python
    plot_sequence_index(seqdata, group_by_column="Cluster", group_labels=cluster_labels)
    ```

    2. **Complete API** (when grouping info is in a separate dataframe):
    ```python
    plot_sequence_index(seqdata, group_dataframe=membership_df,
                        group_column_name="Cluster", group_labels=cluster_labels)
    ```

    :param seqdata: SequenceData object containing sequence information

    **Grouping parameters:**
    :param group_by_column: (str, optional) Column name from seqdata.data to group by.
        Use this when grouping information is already in your data.
        Example: "Cluster", "sex", "education"
    :param group_dataframe: (pd.DataFrame, optional) Separate dataframe containing grouping information.
        Use this when grouping info is in a separate table (e.g., clustering results).
        Must contain an ID column ("Entity ID", otherwise the first column is used) and the grouping column.
        The caller's dataframe is never modified.
    :param group_column_name: (str, optional) Name of the grouping column in group_dataframe.
        Required when using group_dataframe.
    :param group_labels: (dict, optional) Custom labels for group values.
        Example: {1: "Late Family Formation", 2: "Early Partnership"}
        Maps original values to display labels. Every value present in the grouping column
        needs a label. With group_dataframe, if the column instead holds exactly len(group_labels)
        numeric values that all lie outside 1..k (e.g. KMedoids medoid indices), they are
        remapped in ascending order to 1..k before labelling. The dict order also fixes the
        subplot order (unless group_order is given).

    :param sort_by: Sorting method for sequences within groups (default: 'lexicographic'):
        - 'unsorted' or 'none': Keep original order (R TraMineR default)
        - 'lexicographic': Sort sequences lexicographically
        - 'mds': Sort by first MDS dimension
        - 'distance_to_most_frequent': Sort by distance to most frequent sequence
        In grouped plots, 'mds' and 'distance_to_most_frequent' currently fall back to
        lexicographic ordering (a warning is printed).
    :param sort_by_weight: If True, sort sequences by weight (descending; ties keep their
        original order), overrides sort_by
    :param weights: (np.ndarray or "auto") Weights for sequences. If "auto", uses seqdata.weights if available
    :param figsize: Size of each subplot figure (only used when plot_style="custom")
    :param plot_style: Plot aspect style:
        - 'standard': Standard proportions (10, 6) - balanced view
        - 'compact': Compact/vertical proportions (8, 8) - more vertical like R plots
        - 'wide': Wide proportions (12, 4) - emphasizes time progression
        - 'narrow': Narrow/tall proportions (8, 10) - moderately vertical
        - 'custom': Use the provided figsize parameter
    :param title: Title for the plot (if None, no main title is drawn)
    :param xlabel: Label for the x-axis
    :param ylabel: Label for the y-axis
    :param save_as: File path to save the plot; '.png' is appended when the extension is not
        .png/.jpg/.jpeg/.pdf. The plot is always shown as well.
    :param dpi: DPI for saved image
    :param layout: Layout style - 'column' (default, 3xn), 'grid' (nxn)
    :param nrows: (int, optional) Number of subplot rows, forwarded to determine_layout
    :param ncols: (int, optional) Number of subplot columns, forwarded to determine_layout
    :param group_order: List, manually specify group order (overrides sort_groups); groups not listed are excluded
    :param sort_groups: String, sorting method: 'auto'(smart numeric), 'numeric'(numeric prefix), 'alpha'(alphabetical), 'none'(original order)
    :param fontsize: Base font size for text elements (titles use fontsize+2, ticks use fontsize-2)
    :param show_group_titles: Whether to show group titles
    :param include_legend: Whether to include legend in the plot (True by default)
    :param sequence_selection: Method for selecting sequences to visualize:
        - "all": Show all sequences (default)
        - "first_n": Show first n sequences from each group
        - "last_n": Show last n sequences from each group
        - list: List of specific sequence IDs to show
    :param n_sequences: Number of sequences to show when using "first_n" or "last_n" (default: 10)
    :param show_sequence_ids: If True, show actual sequence IDs on y-axis instead of sequence numbers.
        Most useful when sequence_selection is a list of IDs (default: False)
    :param sort_by_ids: (list or np.ndarray, optional) Custom ID order for sorting sequences.
        When provided, sequences are sorted to match this ID order, overriding sort_by and
        sort_by_weight. IDs not listed are placed at the end (a warning is printed).
        This is useful for aligning multiple plots so that the same IDs appear in the same
        row across different visualizations.
        Example: sort_by_ids=[1, 3, 2, 5, 4] will sort sequences by this exact order.
    :param return_sorted_ids: (bool, default: False) If True, returns the sorted ID order after plotting.
        This is useful for multidomain analysis where you want to use the sorted
        IDs from the first plot to align subsequent plots.
    :param show_title: (bool, default: True) If False, suppresses the main title display even if title parameter is provided.
    :param proportional_scaling: (bool, default: False) If True, scales subplot heights proportionally to
        the number of plotted sequences in each group (minimum 30% of the largest).
        Only applies to grouped plots with layout='column'.
    :param hide_y_axis: (bool, default: False) If True, hides y-axis ticks, labels, and spine for all subplots.
        Only applies to grouped plots.

    :return: For grouped plots, a dict {group label: array of sorted IDs} (in plot order) when
        return_sorted_ids is True, otherwise None. For ungrouped plots, whatever
        _sequence_index_plot_single returns (documented as a single array of sorted IDs
        when return_sorted_ids is True).
    :raises ValueError: for an invalid plot_style / sort_groups, a 'custom' style without a custom
        figsize, an unknown group_by_column, incomplete group_labels, mismatched weights, or when
        no groups remain to be plotted.

    Note: For ungrouped plots, 'mds' and 'distance_to_most_frequent' sorting compute distance
    matrices automatically using Optimal Matching (OM) with constant substitution costs.
    """
    # Determine figure size based on plot style
    style_sizes = {
        'standard': (10, 6),  # Balanced view
        'compact': (8, 8),  # More square, like R plots
        'wide': (12, 4),  # Wide, emphasizes time
        'narrow': (8, 10),  # Moderately vertical
        'custom': figsize  # User-provided
    }

    if plot_style not in style_sizes:
        raise ValueError(f"Invalid plot_style '{plot_style}'. "
                         f"Supported styles: {list(style_sizes.keys())}")

    # 'custom' is only meaningful if the caller changed figsize from its default
    if plot_style == 'custom' and figsize == (10, 6):
        raise ValueError(
            "When using plot_style='custom', you must explicitly provide a figsize parameter "
            "that differs from the default (10, 6). "
            "Suggested custom sizes:\n"
            " - For wide plots: figsize=(15, 5)\n"
            " - For tall plots: figsize=(7, 12)\n"
            " - For square plots: figsize=(9, 9)\n"
            " - For small plots: figsize=(6, 4)\n"
            "Example: plot_sequence_index(data, plot_style='custom', figsize=(12, 8))"
        )

    actual_figsize = style_sizes[plot_style]

    # Set once group_labels has been applied to the grouping column, so the
    # group_dataframe-API mapping below does not map the labels a second time
    # (which would treat every label as an unmapped value and fail).
    labels_already_applied = False

    # Simplified API: the grouping column lives in seqdata.data
    if group_by_column is not None:
        if group_by_column not in seqdata.data.columns:
            available_cols = [col for col in seqdata.data.columns if col not in seqdata.time and col != seqdata.id_col]
            raise ValueError(
                f"Column '{group_by_column}' not found in the data. "
                f"Available columns for grouping: {available_cols}"
            )

        # Build the (ID, group) table expected by the rest of the function
        group_dataframe = seqdata.data[[seqdata.id_col, group_by_column]].copy()
        group_dataframe.columns = ['Entity ID', 'Category']
        group_column_name = 'Category'

        unique_values = seqdata.data[group_by_column].unique()

        if group_labels is not None:
            missing_keys = set(unique_values) - set(group_labels.keys())
            if missing_keys:
                raise ValueError(
                    f"group_labels missing mappings for values: {missing_keys}. "
                    f"Please provide labels for all unique values in '{group_by_column}': {sorted(unique_values)}"
                )
            group_dataframe['Category'] = group_dataframe['Category'].map(group_labels)
            labels_already_applied = True
        # Without group_labels the raw column values (numbers, or strings such as
        # "Male"/"Female") are used directly as group names.

        print(f"[>] Creating grouped plots by '{group_by_column}' with {len(unique_values)} categories")

    # No grouping information: delegate to the single-plot implementation
    if group_dataframe is None or group_column_name is None:
        return _sequence_index_plot_single(
            seqdata,
            sort_by=sort_by,
            sort_by_weight=sort_by_weight,
            weights=weights,
            figsize=actual_figsize,
            plot_style=plot_style,
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            save_as=save_as,
            dpi=dpi,
            fontsize=fontsize,
            include_legend=include_legend,
            sequence_selection=sequence_selection,
            n_sequences=n_sequences,
            show_sequence_ids=show_sequence_ids,
            sort_by_ids=sort_by_ids,
            return_sorted_ids=return_sorted_ids,
            show_title=show_title,
        )

    # Process weights
    if isinstance(weights, str) and weights == "auto":
        weights = getattr(seqdata, "weights", None)

    if weights is not None:
        weights = np.asarray(weights, dtype=float).reshape(-1)
        if len(weights) != len(seqdata.values):
            raise ValueError("Length of weights must equal number of sequences.")

    # Use "Entity ID" when present, otherwise assume the first column holds the IDs
    id_col_name = "Entity ID" if "Entity ID" in group_dataframe.columns else group_dataframe.columns[0]

    # Apply group_labels for the group_dataframe API (skipped if already applied above)
    if (group_labels is not None and not labels_already_applied
            and group_column_name in group_dataframe.columns):
        group_dataframe = group_dataframe.copy()  # Avoid modifying the caller's dataframe
        unique_values = group_dataframe[group_column_name].unique()
        missing_keys = set(unique_values) - set(group_labels.keys())

        if missing_keys:
            # Heuristic for KMedoids-style output, where rows carry the index of their medoid
            # instead of a 1..k cluster number: exactly k unmapped numeric values, all outside 1..k.
            expected_cluster_count = len(group_labels)
            missing_values_list = list(missing_keys)
            looks_like_medoid_indices = (
                len(missing_values_list) == expected_cluster_count
                and all(
                    isinstance(v, (int, float, np.integer, np.floating))
                    and not pd.isna(v)
                    and (v > expected_cluster_count or v < 1)
                    for v in missing_values_list
                )
            )

            if not looks_like_medoid_indices:
                raise ValueError(
                    f"group_labels missing mappings for values: {missing_keys}. "
                    f"Please provide labels for all unique values in '{group_column_name}': {sorted(unique_values)}"
                )

            sorted_missing = sorted(missing_values_list)
            print(f"[>] Detected medoid indices in '{group_column_name}'. "
                  f"Automatically remapping {sorted_missing} to cluster labels 1-{expected_cluster_count}.")

            # Smallest medoid index -> cluster 1, next -> cluster 2, ...
            medoid_to_cluster = {val: idx + 1 for idx, val in enumerate(sorted_missing)}
            group_dataframe[group_column_name] = group_dataframe[group_column_name].map(medoid_to_cluster)

            unique_values_after_remap = group_dataframe[group_column_name].unique()
            missing_keys_after = set(unique_values_after_remap) - set(group_labels.keys())
            if missing_keys_after:
                raise ValueError(
                    f"After auto-remapping, group_labels still missing mappings for values: {missing_keys_after}. "
                    f"Please provide labels for all unique values in '{group_column_name}': {sorted(unique_values_after_remap)}"
                )

        group_dataframe[group_column_name] = group_dataframe[group_column_name].map(group_labels)

    # Determine the groups to plot and their order
    present_groups = group_dataframe[group_column_name].unique()
    if group_order:
        # Use manually specified order, filter out non-existing groups
        groups = [g for g in group_order if g in present_groups]
        missing_groups = [g for g in present_groups if g not in group_order]
        if missing_groups:
            print(f"[Warning] Groups not in group_order will be excluded: {missing_groups}")
    elif group_labels is not None:
        # Follow the insertion order of group_labels (its values are the group names now).
        # dict.fromkeys de-duplicates labels shared by several keys, so no group is drawn twice.
        available_labels = set(present_groups)
        groups = list(dict.fromkeys(
            label_value for label_value in group_labels.values() if label_value in available_labels
        ))
        missing_in_labels = available_labels - set(groups)
        if missing_in_labels:
            print(f"[Warning] Some groups in data are not in group_labels and will be excluded: {missing_in_labels}")
    elif sort_groups == 'numeric' or sort_groups == 'auto':
        groups = smart_sort_groups(present_groups)
    elif sort_groups == 'alpha':
        groups = sorted(present_groups)
    elif sort_groups == 'none':
        groups = list(present_groups)
    else:
        raise ValueError(f"Invalid sort_groups value: {sort_groups}. Use 'auto', 'numeric', 'alpha', or 'none'.")

    num_groups = len(groups)
    if num_groups == 0:
        raise ValueError("No groups to plot: none of the requested groups are present in the grouping data.")

    # Resolve each group's selected rows once; reused for proportional sizing and plotting.
    # None marks a group with no matching sequences.
    group_masks = []
    for group in groups:
        group_ids = group_dataframe[group_dataframe[group_column_name] == group][id_col_name].values
        group_mask = np.isin(seqdata.ids, group_ids)
        if np.any(group_mask):
            group_mask = _select_sequences_subset(seqdata, sequence_selection, n_sequences,
                                                  sort_by, sort_by_weight, weights, group_mask)
            group_masks.append(group_mask)
        else:
            group_masks.append(None)

    # Calculate layout based on number of groups and specified layout
    nrows, ncols = determine_layout(num_groups, layout=layout, nrows=nrows, ncols=ncols)

    use_proportional = proportional_scaling and layout == 'column'
    # Larger vertical spacing in column layout keeps x-axis labels clear of the subplot above
    hspace_value = 0.4 if layout == 'column' else 0.25

    if use_proportional:
        group_sizes = [int(np.sum(m)) if m is not None else 1 for m in group_masks]
        # Guard against every group selecting zero rows
        max_size = max(max(group_sizes), 1)
        # Minimum ratio 0.3 avoids unreadably thin subplots
        height_ratios = [max(0.3, size / max_size) for size in group_sizes]
        max_ratio = max(height_ratios)
        height_ratios = [h / max_ratio for h in height_ratios]

        fig = plt.figure(figsize=(actual_figsize[0], actual_figsize[1] * sum(height_ratios)))
        gs = gridspec.GridSpec(nrows=num_groups, ncols=1, figure=fig,
                               height_ratios=height_ratios, hspace=hspace_value, wspace=0.15)
        axes = [fig.add_subplot(gs[k]) for k in range(num_groups)]
    else:
        # squeeze=False always yields a 2-D array, so flatten() also works for a single subplot
        fig, axes = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            figsize=(actual_figsize[0] * ncols, actual_figsize[1] * nrows),
            gridspec_kw={'wspace': 0.15, 'hspace': hspace_value},
            squeeze=False
        )
        axes = axes.flatten()

    # Loop-invariant lookups for custom ID ordering
    if sort_by_ids is not None:
        sort_by_ids_array = np.asarray(sort_by_ids)
        id_to_position = {id_val: pos for pos, id_val in enumerate(sort_by_ids_array)}
        known_sort_ids = set(sort_by_ids_array)

    # Only mention total weight in titles when the data carries non-trivial weights
    original_weights = getattr(seqdata, "weights", None)
    has_effective_weights = original_weights is not None and not np.allclose(original_weights, 1.0)

    # Sorted IDs per group (dict keeps plot order)
    sorted_ids_by_group = {}

    for i, group in enumerate(groups):
        ax = axes[i]
        mask = group_masks[i]
        if mask is None:
            print(f"Warning: No matching sequences found for group '{group}'")
            ax.axis('off')
            continue

        group_sequences = seqdata.values[mask]
        group_ids_after_selection = seqdata.ids[mask]
        group_weights = weights[mask] if weights is not None else None

        # Row order within the group: sort_by_ids > sort_by_weight > sort_by
        if sort_by_ids is not None:
            # IDs absent from sort_by_ids get positions past the end, keeping their relative order
            unknown_offset = len(sort_by_ids_array)
            positions = np.array([id_to_position.get(id_val, unknown_offset + k)
                                  for k, id_val in enumerate(group_ids_after_selection)])
            sorted_indices = np.argsort(positions, kind='stable')

            missing_ids = set(group_ids_after_selection) - known_sort_ids
            if missing_ids:
                print(f"[Warning] Group '{group}': {len(missing_ids)} IDs not found in sort_by_ids, "
                      f"they will be placed at the end: {list(missing_ids)[:5]}{'...' if len(missing_ids) > 5 else ''}")
        elif sort_by_weight and group_weights is not None:
            # Descending weight; stable so equal weights keep their original order
            sorted_indices = np.argsort(-group_weights, kind='stable')
        elif sort_by in ("lexicographic", "mds", "distance_to_most_frequent"):
            if sort_by != "lexicographic":
                # Distance-based methods are not computed per group; fall back to lexicographic
                print(f"Warning: {sort_by} sorting simplified to lexicographic for grouped plots with sequence selection")
            # NaNs sort last; lexsort uses the LAST key as primary, hence the column reversal
            vals = np.nan_to_num(group_sequences.astype(float, copy=True), nan=np.inf)
            sorted_indices = np.lexsort(vals.T[::-1])
        else:
            # unsorted or other methods
            sorted_indices = np.arange(len(group_sequences))

        sorted_data = group_sequences[sorted_indices]
        sorted_group_ids = group_ids_after_selection[sorted_indices] if show_sequence_ids else None

        if return_sorted_ids:
            sorted_ids_by_group[group] = group_ids_after_selection[sorted_indices]

        # State codes below 1 are treated as missing and left blank
        data = sorted_data.astype(float)
        data[data < 1] = np.nan

        if np.all(~np.isfinite(data)):
            print(f"Warning: all values missing/invalid for group '{group}'")
            ax.axis('off')
            continue

        ax.imshow(np.ma.masked_invalid(data), aspect='auto', cmap=seqdata.get_colormap(),
                  interpolation='nearest', vmin=1, vmax=len(seqdata.states))
        ax.grid(False)

        set_up_time_labels_for_x_axis(seqdata, ax)

        # Y ticks: every ID for small groups, otherwise evenly spaced ticks including the last row
        num_sequences = sorted_data.shape[0]
        max_ticks = 8 if plot_style == "narrow" else 11
        if sorted_group_ids is not None and num_sequences <= 10:
            ytick_positions = np.arange(num_sequences)
            ytick_labels = [str(sid) for sid in sorted_group_ids]
        else:
            num_ticks = min(max_ticks, num_sequences)
            ytick_positions = np.unique(np.linspace(0, num_sequences - 1, num=num_ticks, dtype=int))
            if sorted_group_ids is not None:
                ytick_labels = [str(sorted_group_ids[pos]) for pos in ytick_positions]
            else:
                ytick_labels = (ytick_positions + 1).astype(int)

        if hide_y_axis:
            ax.set_yticks([])
            ax.set_yticklabels([])
            ax.spines['left'].set_visible(False)
        else:
            ax.set_yticks(ytick_positions)
            ax.set_yticklabels(ytick_labels, fontsize=fontsize-2, color='black')

        # Axis styling: light spines pushed slightly outward
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        if not hide_y_axis:
            ax.spines['left'].set_color('gray')
            ax.spines['left'].set_linewidth(0.7)
            ax.spines['left'].set_position(('outward', 5))
        ax.spines['bottom'].set_color('gray')
        ax.spines['bottom'].set_linewidth(0.7)
        ax.spines['bottom'].set_position(('outward', 5))

        # Ensure ticks are always visible regardless of plot style
        ax.tick_params(axis='x', colors='gray', length=4, width=0.7, which='major')
        if not hide_y_axis:
            ax.tick_params(axis='y', colors='gray', length=4, width=0.7, which='major')
        ax.xaxis.set_ticks_position('bottom')
        if not hide_y_axis:
            ax.yaxis.set_ticks_position('left')
        ax.tick_params(axis='both', which='major', direction='out')

        # Group title with size (and total weight when weights are meaningful)
        if has_effective_weights and group_weights is not None:
            sum_w = float(group_weights.sum())
            group_title = f"{group} (n = {num_sequences}, total weight = {sum_w:.1f})"
        else:
            group_title = f"{group} (n = {num_sequences})"
        if show_group_titles:
            show_plot_title(ax, group_title, show=True, fontsize=fontsize, loc='right')

        # Y label on the first subplot of each row
        if i % ncols == 0 and not hide_y_axis:
            ax.set_ylabel(ylabel, fontsize=fontsize, labelpad=10, color='black')

        # X label only on the bottom subplot (column) or bottom row (grid)
        is_bottom = (i == num_groups - 1) if layout == 'column' else (i >= num_groups - ncols)
        if is_bottom:
            ax.set_xlabel(xlabel, fontsize=fontsize, labelpad=10, color='black')

    # Hide unused grid cells (the proportional layout creates exactly num_groups axes)
    if not use_proportional:
        for unused_ax in axes[num_groups:]:
            unused_ax.set_visible(False)

    if title and show_title:
        fig.suptitle(title, fontsize=fontsize+2, y=1.02)

    # Explicit margins instead of tight_layout (avoids its warning and extra right space)
    fig.subplots_adjust(left=0.08, right=0.98, bottom=0.1, top=0.9, wspace=0.15, hspace=hspace_value)

    # Render the main figure to memory, then release it so it is not displayed a second time
    main_buffer = save_figure_to_buffer(fig, dpi=dpi)
    plt.close(fig)

    if save_as and not save_as.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf')):
        save_as = save_as + '.png'

    if include_legend:
        legend_buffer = create_standalone_legend(
            colors=seqdata.color_map_by_label,
            labels=seqdata.labels,
            ncol=min(5, len(seqdata.states)),
            figsize=(actual_figsize[0] * ncols, 1),
            fontsize=fontsize-2,
            dpi=dpi
        )

        # Stack plot and legend (written to save_as when given)
        combined_img = combine_plot_with_legend(
            main_buffer,
            legend_buffer,
            output_path=save_as,
            dpi=dpi,
            padding=20
        )

        plt.figure(figsize=(actual_figsize[0] * ncols, actual_figsize[1] * nrows + 1))
        plt.imshow(combined_img)
        plt.axis('off')
        plt.show()
        plt.close()
    else:
        plt.figure(figsize=(actual_figsize[0] * ncols, actual_figsize[1] * nrows))
        plt.imshow(main_buffer)
        plt.axis('off')

        if save_as:
            plt.savefig(save_as, dpi=dpi, bbox_inches='tight')
        plt.show()
        plt.close()

    if return_sorted_ids:
        return sorted_ids_by_group
    return None
916
+
917
+
918
def _sequence_index_plot_single(seqdata: SequenceData,
                                sort_by="unsorted",
                                sort_by_weight=False,
                                weights="auto",
                                figsize=(10, 6),
                                plot_style="standard",
                                title=None,
                                xlabel="Time",
                                ylabel="Sequences",
                                save_as=None,
                                dpi=200,
                                fontsize=12,
                                include_legend=True,
                                sequence_selection="all",
                                n_sequences=10,
                                show_sequence_ids=False,
                                sort_by_ids=None,
                                return_sorted_ids=False,
                                show_title=True):
    """Efficiently creates a sequence index plot using `imshow` for faster rendering.

    Each image row is one sequence and each column one time point. Cells are
    coloured with ``seqdata.get_colormap()`` over the range 1..len(seqdata.states).
    Values below 1 and NaN are masked out.

    :param seqdata: SequenceData object containing sequence information
    :param sort_by: Sorting method ('unsorted', 'lexicographic', 'mds', 'distance_to_most_frequent').
        'mds' and 'distance_to_most_frequent' currently fall back to lexicographic order
        and print a warning. Any other value leaves the sequences unsorted.
    :param sort_by_weight: If True and weights are available, sort sequences by weight
        (descending), overriding sort_by. Equal weights keep their original relative order.
    :param weights: (np.ndarray or "auto") Weights for sequences. If "auto", uses seqdata.weights if available
    :param figsize: (tuple): Size of the figure (only used when plot_style="custom").
    :param plot_style: Plot aspect style ('standard', 'compact', 'wide', 'narrow', 'custom')
    :param title: (str): Title for the plot. " (n = ...)" is appended. The total weight is
        also appended when seqdata.weights is non-uniform and weights are in use.
    :param xlabel: (str): Label for the x-axis.
    :param ylabel: (str): Label for the y-axis.
    :param save_as: File path to save the plot
    :param dpi: DPI for saved image
    :param fontsize: Base font size for axis labels. Tick labels use fontsize-2 and the
        title uses fontsize+2.
    :param include_legend: Whether to include legend in the plot (True by default)
    :param sequence_selection: Method for selecting sequences ("all", "first_n", "last_n", or list of IDs)
    :param n_sequences: Number of sequences for "first_n" or "last_n"
    :param show_sequence_ids: If True, show actual sequence IDs on y-axis instead of sequence numbers
    :param sort_by_ids: Optional explicit ID order. It takes priority over sort_by_weight and
        sort_by. Selected IDs not listed are placed at the end in their original order.
    :param return_sorted_ids: If True, return the IDs of the plotted sequences in plot order.
    :param show_title: If False, no title is drawn even when ``title`` is given.

    :return: The plotted sequence IDs in plot order (image row 0 first) when
        ``return_sorted_ids`` is True and ``seqdata`` has IDs; otherwise None.
    :raises ValueError: If ``plot_style`` is unknown, ``plot_style='custom'`` is used with
        the default figsize, the number of weights differs from the number of sequences,
        or ``sort_by_ids`` is given but ``seqdata`` has no IDs.
    """
    # ---- Figure size ------------------------------------------------------
    style_sizes = {
        'standard': (10, 6),  # Balanced view
        'compact': (8, 8),    # More square, like R plots
        'wide': (12, 4),      # Wide, emphasizes time
        'narrow': (8, 10),    # Moderately vertical
        'custom': figsize     # User-provided
    }

    if plot_style not in style_sizes:
        raise ValueError(f"Invalid plot_style '{plot_style}'. "
                         f"Supported styles: {list(style_sizes.keys())}")

    # 'custom' with the untouched default figsize is almost certainly a mistake.
    if plot_style == 'custom' and figsize == (10, 6):
        raise ValueError(
            "When using plot_style='custom', you must explicitly provide a figsize parameter "
            "that differs from the default (10, 6). "
            "Suggested custom sizes:\n"
            " - For wide plots: figsize=(15, 5)\n"
            " - For tall plots: figsize=(7, 12)\n"
            " - For square plots: figsize=(9, 9)\n"
            " - For small plots: figsize=(6, 4)\n"
            "Example: plot_sequence_index(data, plot_style='custom', figsize=(12, 8))"
        )

    actual_figsize = style_sizes[plot_style]

    # ---- Weights ----------------------------------------------------------
    if isinstance(weights, str) and weights == "auto":
        weights = getattr(seqdata, "weights", None)

    if weights is not None:
        weights = np.asarray(weights, dtype=float).reshape(-1)
        if len(weights) != len(seqdata.values):
            raise ValueError("Length of weights must equal number of sequences.")

    # ---- Selection --------------------------------------------------------
    selection_mask = _select_sequences_subset(seqdata, sequence_selection, n_sequences, sort_by, sort_by_weight, weights)

    has_ids = hasattr(seqdata, 'ids') and seqdata.ids is not None
    selected_ids = None  # IDs of the selected sequences (None when seqdata has no IDs)
    if not np.all(selection_mask):
        sequence_values = seqdata.values[selection_mask].copy()
        if has_ids:
            selected_ids = seqdata.ids[selection_mask]
        if weights is not None:
            weights = weights[selection_mask]
    else:
        sequence_values = seqdata.values.copy()
        if has_ids:
            selected_ids = seqdata.ids

    # Keep NaN as float so it can be masked below
    if np.isnan(sequence_values).any():
        sequence_values = sequence_values.astype(float)

    # ---- Ordering (priority: sort_by_ids > sort_by_weight > sort_by) ------
    if sort_by_ids is not None:
        if selected_ids is None:
            raise ValueError("sort_by_ids requires seqdata to provide sequence IDs (seqdata.ids).")
        sort_by_ids_array = np.asarray(sort_by_ids)
        id_to_position = {id_val: pos for pos, id_val in enumerate(sort_by_ids_array)}

        # Unlisted IDs get positions past the end of sort_by_ids, offset by their
        # original index so they keep their original relative order.
        max_position = len(sort_by_ids_array)
        positions = np.array([id_to_position.get(id_val, max_position + i)
                              for i, id_val in enumerate(selected_ids)])
        sorted_indices = np.argsort(positions, kind="stable")

        missing_ids = set(selected_ids) - set(sort_by_ids_array)
        if missing_ids:
            print(f"[Warning] {len(missing_ids)} IDs not found in sort_by_ids, "
                  f"they will be placed at the end: {list(missing_ids)[:5]}{'...' if len(missing_ids) > 5 else ''}")

    elif sort_by_weight and weights is not None:
        # Descending weight; the stable sort keeps ties in their original order.
        sorted_indices = np.argsort(-weights, kind="stable")
    else:
        if sort_by in ["mds", "distance_to_most_frequent"]:
            print(f"Warning: {sort_by} sorting simplified to lexicographic for sequence selection")
        if sort_by in ["lexicographic", "mds", "distance_to_most_frequent"]:
            # NaN -> +inf so missing states sort after every real state. np.lexsort
            # uses the LAST key as primary, so reverse the time columns to make the
            # first time point the primary key.
            vals = np.nan_to_num(sequence_values.astype(float, copy=True), nan=np.inf)
            sorted_indices = np.lexsort(vals.T[::-1])
        else:
            # "unsorted" or unrecognised methods: keep the current order
            sorted_indices = np.arange(len(sequence_values))

    sorted_data = sequence_values[sorted_indices]

    sorted_ids = None
    if selected_ids is not None and show_sequence_ids:
        sorted_ids = selected_ids[sorted_indices]

    # ---- Drawing ----------------------------------------------------------
    fig, ax = plt.subplots(figsize=actual_figsize)
    # Codes below 1 are treated as missing and masked out together with NaN.
    data = sorted_data.astype(float)
    data[data < 1] = np.nan

    if np.all(~np.isfinite(data)):
        print("Warning: all values missing/invalid in sequence data")
        # Nothing is drawn, saved or shown on this path. Close the figure so it
        # does not linger in pyplot's registry and appear on a later plt.show().
        plt.close(fig)
        if return_sorted_ids:
            return selected_ids[sorted_indices] if selected_ids is not None else None
        return None

    ax.imshow(np.ma.masked_invalid(data), aspect='auto', cmap=seqdata.get_colormap(),
              interpolation='nearest', vmin=1, vmax=len(seqdata.states))

    ax.grid(False)

    set_up_time_labels_for_x_axis(seqdata, ax)

    # ---- Y-axis ticks -----------------------------------------------------
    num_sequences = sorted_data.shape[0]
    use_id_labels = show_sequence_ids and sorted_ids is not None

    if use_id_labels and num_sequences <= 10:
        # Few sequences: label every row with its ID
        ytick_positions = np.arange(num_sequences)
    else:
        # Evenly spaced ticks that include the first and last row; fewer for narrow plots
        max_ticks = 8 if plot_style == "narrow" else 11
        num_ticks = min(max_ticks, num_sequences)
        ytick_positions = np.unique(np.linspace(0, num_sequences - 1, num=num_ticks, dtype=int))

    if use_id_labels:
        ytick_labels = [str(sorted_ids[pos]) for pos in ytick_positions]
    else:
        # 1-based sequence numbers
        ytick_labels = (ytick_positions + 1).astype(int)

    ax.set_yticks(ytick_positions)
    ax.set_yticklabels(ytick_labels, fontsize=fontsize-2, color='black')

    # ---- Spines and ticks -------------------------------------------------
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    for side in ('left', 'bottom'):
        spine = ax.spines[side]
        spine.set_color('gray')
        spine.set_linewidth(0.7)
        # Move the spine slightly away from the image
        spine.set_position(('outward', 5))

    # Keep ticks visible regardless of plot style
    ax.tick_params(axis='both', colors='gray', length=4, width=0.7, which='major')
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    ax.tick_params(axis='both', which='major', direction='out')

    # ---- Labels, title, legend --------------------------------------------
    ax.set_xlabel(xlabel, fontsize=fontsize, labelpad=10, color='black')
    ax.set_ylabel(ylabel, fontsize=fontsize, labelpad=10, color='black')

    if title is not None and show_title:
        display_title = title

        # Report the total weight only when seqdata carries non-uniform weights
        # and weights are actually in use for this plot.
        original_weights = getattr(seqdata, "weights", None)
        if original_weights is not None and not np.allclose(original_weights, 1.0) and weights is not None:
            sum_w = float(weights.sum())
            display_title += f" (n = {num_sequences}, total weight = {sum_w:.1f})"
        else:
            display_title += f" (n = {num_sequences})"

        ax.set_title(display_title, fontsize=fontsize+2, color='black')

    if include_legend:
        ax.legend(*seqdata.get_legend(), bbox_to_anchor=(1.05, 1), loc='upper left')

    save_and_show_results(save_as, dpi=dpi)

    if return_sorted_ids:
        return selected_ids[sorted_indices] if selected_ids is not None else None