sequenzo 0.1.31__cp310-cp310-macosx_10_9_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- _sequenzo_fastcluster.cpython-310-darwin.so +0 -0
- sequenzo/__init__.py +349 -0
- sequenzo/big_data/__init__.py +12 -0
- sequenzo/big_data/clara/__init__.py +26 -0
- sequenzo/big_data/clara/clara.py +476 -0
- sequenzo/big_data/clara/utils/__init__.py +27 -0
- sequenzo/big_data/clara/utils/aggregatecases.py +92 -0
- sequenzo/big_data/clara/utils/davies_bouldin.py +91 -0
- sequenzo/big_data/clara/utils/get_weighted_diss.cpython-310-darwin.so +0 -0
- sequenzo/big_data/clara/utils/wfcmdd.py +205 -0
- sequenzo/big_data/clara/visualization.py +88 -0
- sequenzo/clustering/KMedoids.py +178 -0
- sequenzo/clustering/__init__.py +30 -0
- sequenzo/clustering/clustering_c_code.cpython-310-darwin.so +0 -0
- sequenzo/clustering/hierarchical_clustering.py +1256 -0
- sequenzo/clustering/sequenzo_fastcluster/fastcluster.py +495 -0
- sequenzo/clustering/sequenzo_fastcluster/src/fastcluster.cpp +1877 -0
- sequenzo/clustering/sequenzo_fastcluster/src/fastcluster_python.cpp +1264 -0
- sequenzo/clustering/src/KMedoid.cpp +263 -0
- sequenzo/clustering/src/PAM.cpp +237 -0
- sequenzo/clustering/src/PAMonce.cpp +265 -0
- sequenzo/clustering/src/cluster_quality.cpp +496 -0
- sequenzo/clustering/src/cluster_quality.h +128 -0
- sequenzo/clustering/src/cluster_quality_backup.cpp +570 -0
- sequenzo/clustering/src/module.cpp +228 -0
- sequenzo/clustering/src/weightedinertia.cpp +111 -0
- sequenzo/clustering/utils/__init__.py +27 -0
- sequenzo/clustering/utils/disscenter.py +122 -0
- sequenzo/data_preprocessing/__init__.py +22 -0
- sequenzo/data_preprocessing/helpers.py +303 -0
- sequenzo/datasets/__init__.py +41 -0
- sequenzo/datasets/biofam.csv +2001 -0
- sequenzo/datasets/biofam_child_domain.csv +2001 -0
- sequenzo/datasets/biofam_left_domain.csv +2001 -0
- sequenzo/datasets/biofam_married_domain.csv +2001 -0
- sequenzo/datasets/chinese_colonial_territories.csv +12 -0
- sequenzo/datasets/country_co2_emissions.csv +194 -0
- sequenzo/datasets/country_co2_emissions_global_deciles.csv +195 -0
- sequenzo/datasets/country_co2_emissions_global_quintiles.csv +195 -0
- sequenzo/datasets/country_co2_emissions_local_deciles.csv +195 -0
- sequenzo/datasets/country_co2_emissions_local_quintiles.csv +195 -0
- sequenzo/datasets/country_gdp_per_capita.csv +194 -0
- sequenzo/datasets/dyadic_children.csv +61 -0
- sequenzo/datasets/dyadic_parents.csv +61 -0
- sequenzo/datasets/mvad.csv +713 -0
- sequenzo/datasets/pairfam_activity_by_month.csv +1028 -0
- sequenzo/datasets/pairfam_activity_by_year.csv +1028 -0
- sequenzo/datasets/pairfam_family_by_month.csv +1028 -0
- sequenzo/datasets/pairfam_family_by_year.csv +1028 -0
- sequenzo/datasets/political_science_aid_shock.csv +166 -0
- sequenzo/datasets/political_science_donor_fragmentation.csv +157 -0
- sequenzo/define_sequence_data.py +1400 -0
- sequenzo/dissimilarity_measures/__init__.py +31 -0
- sequenzo/dissimilarity_measures/c_code.cpython-310-darwin.so +0 -0
- sequenzo/dissimilarity_measures/get_distance_matrix.py +762 -0
- sequenzo/dissimilarity_measures/get_substitution_cost_matrix.py +246 -0
- sequenzo/dissimilarity_measures/src/DHDdistance.cpp +148 -0
- sequenzo/dissimilarity_measures/src/LCPdistance.cpp +114 -0
- sequenzo/dissimilarity_measures/src/LCPspellDistance.cpp +215 -0
- sequenzo/dissimilarity_measures/src/OMdistance.cpp +247 -0
- sequenzo/dissimilarity_measures/src/OMspellDistance.cpp +281 -0
- sequenzo/dissimilarity_measures/src/__init__.py +0 -0
- sequenzo/dissimilarity_measures/src/dist2matrix.cpp +63 -0
- sequenzo/dissimilarity_measures/src/dp_utils.h +160 -0
- sequenzo/dissimilarity_measures/src/module.cpp +40 -0
- sequenzo/dissimilarity_measures/src/setup.py +30 -0
- sequenzo/dissimilarity_measures/src/utils.h +25 -0
- sequenzo/dissimilarity_measures/src/xsimd/.github/cmake-test/main.cpp +6 -0
- sequenzo/dissimilarity_measures/src/xsimd/benchmark/main.cpp +159 -0
- sequenzo/dissimilarity_measures/src/xsimd/benchmark/xsimd_benchmark.hpp +565 -0
- sequenzo/dissimilarity_measures/src/xsimd/docs/source/conf.py +37 -0
- sequenzo/dissimilarity_measures/src/xsimd/examples/mandelbrot.cpp +330 -0
- sequenzo/dissimilarity_measures/src/xsimd/examples/pico_bench.hpp +246 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_arithmetic.hpp +266 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_complex.hpp +112 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_details.hpp +323 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_logical.hpp +218 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_math.hpp +2583 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_memory.hpp +880 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_rounding.hpp +72 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_swizzle.hpp +174 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_trigo.hpp +978 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx.hpp +1924 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx2.hpp +1144 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512bw.hpp +656 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512cd.hpp +28 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512dq.hpp +244 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512er.hpp +20 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512f.hpp +2650 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512ifma.hpp +20 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512pf.hpp +20 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vbmi.hpp +77 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vbmi2.hpp +131 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vnni_avx512bw.hpp +20 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vnni_avx512vbmi2.hpp +20 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avxvnni.hpp +20 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_common.hpp +24 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_common_fwd.hpp +77 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_constants.hpp +393 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_emulated.hpp +788 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma3_avx.hpp +93 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma3_avx2.hpp +46 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma3_sse.hpp +97 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma4.hpp +92 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_i8mm_neon64.hpp +17 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_isa.hpp +142 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_neon.hpp +3142 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_neon64.hpp +1543 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_rvv.hpp +1513 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_scalar.hpp +1260 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse2.hpp +2024 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse3.hpp +67 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse4_1.hpp +339 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse4_2.hpp +44 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_ssse3.hpp +186 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sve.hpp +1155 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_vsx.hpp +892 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_wasm.hpp +1780 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_arch.hpp +240 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_config.hpp +484 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_cpuid.hpp +269 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_inline.hpp +27 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/math/xsimd_rem_pio2.hpp +719 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/memory/xsimd_aligned_allocator.hpp +349 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/memory/xsimd_alignment.hpp +91 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_all_registers.hpp +55 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_api.hpp +2765 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx2_register.hpp +44 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512bw_register.hpp +51 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512cd_register.hpp +51 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512dq_register.hpp +51 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512er_register.hpp +51 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512f_register.hpp +77 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512ifma_register.hpp +51 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512pf_register.hpp +51 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512vbmi2_register.hpp +51 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512vbmi_register.hpp +51 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512vnni_avx512bw_register.hpp +54 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512vnni_avx512vbmi2_register.hpp +53 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx_register.hpp +64 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avxvnni_register.hpp +44 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_batch.hpp +1524 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_batch_constant.hpp +300 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_common_arch.hpp +47 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_emulated_register.hpp +80 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_fma3_avx2_register.hpp +50 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_fma3_avx_register.hpp +50 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_fma3_sse_register.hpp +50 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_fma4_register.hpp +50 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_i8mm_neon64_register.hpp +55 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_neon64_register.hpp +55 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_neon_register.hpp +154 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_register.hpp +94 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_rvv_register.hpp +506 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sse2_register.hpp +59 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sse3_register.hpp +49 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sse4_1_register.hpp +48 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sse4_2_register.hpp +48 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_ssse3_register.hpp +48 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sve_register.hpp +156 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_traits.hpp +337 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_utils.hpp +536 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_vsx_register.hpp +77 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_wasm_register.hpp +59 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/xsimd.hpp +75 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/architectures/dummy.cpp +7 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set.cpp +13 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set_mean.cpp +24 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set_mean_aligned.cpp +25 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set_mean_arch_independent.cpp +28 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set_mean_tag_dispatch.cpp +25 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/manipulating_abstract_batches.cpp +7 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/manipulating_parametric_batches.cpp +8 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/sum.hpp +31 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/sum_avx2.cpp +3 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/sum_sse2.cpp +3 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/writing_vectorized_code.cpp +11 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/main.cpp +31 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_api.cpp +230 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_arch.cpp +217 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_basic_math.cpp +183 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_batch.cpp +1049 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_bool.cpp +508 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_cast.cpp +409 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_complex.cpp +712 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_constant.cpp +286 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_float.cpp +141 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_int.cpp +365 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_manip.cpp +308 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_bitwise_cast.cpp +222 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_complex_exponential.cpp +226 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_complex_hyperbolic.cpp +183 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_complex_power.cpp +265 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_complex_trigonometric.cpp +236 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_conversion.cpp +248 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_custom_default_arch.cpp +28 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_error_gamma.cpp +170 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_explicit_batch_instantiation.cpp +32 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_exponential.cpp +202 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_extract_pair.cpp +92 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_fp_manipulation.cpp +77 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_gnu_source.cpp +30 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_hyperbolic.cpp +167 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_load_store.cpp +304 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_memory.cpp +61 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_poly_evaluation.cpp +64 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_power.cpp +184 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_rounding.cpp +199 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_select.cpp +101 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_shuffle.cpp +760 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_sum.cpp +4 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_sum.hpp +34 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_traits.cpp +172 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_trigonometric.cpp +208 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_utils.hpp +611 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_wasm/test_wasm_playwright.py +123 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_xsimd_api.cpp +1460 -0
- sequenzo/dissimilarity_measures/utils/__init__.py +16 -0
- sequenzo/dissimilarity_measures/utils/get_LCP_length_for_2_seq.py +44 -0
- sequenzo/dissimilarity_measures/utils/get_sm_trate_substitution_cost_matrix.cpython-310-darwin.so +0 -0
- sequenzo/dissimilarity_measures/utils/seqconc.cpython-310-darwin.so +0 -0
- sequenzo/dissimilarity_measures/utils/seqdss.cpython-310-darwin.so +0 -0
- sequenzo/dissimilarity_measures/utils/seqdur.cpython-310-darwin.so +0 -0
- sequenzo/dissimilarity_measures/utils/seqlength.cpython-310-darwin.so +0 -0
- sequenzo/multidomain/__init__.py +23 -0
- sequenzo/multidomain/association_between_domains.py +311 -0
- sequenzo/multidomain/cat.py +597 -0
- sequenzo/multidomain/combt.py +519 -0
- sequenzo/multidomain/dat.py +81 -0
- sequenzo/multidomain/idcd.py +139 -0
- sequenzo/multidomain/linked_polyad.py +292 -0
- sequenzo/openmp_setup.py +233 -0
- sequenzo/prefix_tree/__init__.py +62 -0
- sequenzo/prefix_tree/hub.py +114 -0
- sequenzo/prefix_tree/individual_level_indicators.py +1321 -0
- sequenzo/prefix_tree/spell_individual_level_indicators.py +580 -0
- sequenzo/prefix_tree/spell_level_indicators.py +297 -0
- sequenzo/prefix_tree/system_level_indicators.py +544 -0
- sequenzo/prefix_tree/utils.py +54 -0
- sequenzo/seqhmm/__init__.py +95 -0
- sequenzo/seqhmm/advanced_optimization.py +305 -0
- sequenzo/seqhmm/bootstrap.py +411 -0
- sequenzo/seqhmm/build_hmm.py +142 -0
- sequenzo/seqhmm/build_mhmm.py +136 -0
- sequenzo/seqhmm/build_nhmm.py +121 -0
- sequenzo/seqhmm/fit_mhmm.py +62 -0
- sequenzo/seqhmm/fit_model.py +61 -0
- sequenzo/seqhmm/fit_nhmm.py +76 -0
- sequenzo/seqhmm/formulas.py +289 -0
- sequenzo/seqhmm/forward_backward_nhmm.py +276 -0
- sequenzo/seqhmm/gradients_nhmm.py +306 -0
- sequenzo/seqhmm/hmm.py +291 -0
- sequenzo/seqhmm/mhmm.py +314 -0
- sequenzo/seqhmm/model_comparison.py +238 -0
- sequenzo/seqhmm/multichannel_em.py +282 -0
- sequenzo/seqhmm/multichannel_utils.py +138 -0
- sequenzo/seqhmm/nhmm.py +270 -0
- sequenzo/seqhmm/nhmm_utils.py +191 -0
- sequenzo/seqhmm/predict.py +137 -0
- sequenzo/seqhmm/predict_mhmm.py +142 -0
- sequenzo/seqhmm/simulate.py +878 -0
- sequenzo/seqhmm/utils.py +218 -0
- sequenzo/seqhmm/visualization.py +910 -0
- sequenzo/sequence_characteristics/__init__.py +40 -0
- sequenzo/sequence_characteristics/complexity_index.py +49 -0
- sequenzo/sequence_characteristics/overall_cross_sectional_entropy.py +220 -0
- sequenzo/sequence_characteristics/plot_characteristics.py +593 -0
- sequenzo/sequence_characteristics/simple_characteristics.py +311 -0
- sequenzo/sequence_characteristics/state_frequencies_and_entropy_per_sequence.py +39 -0
- sequenzo/sequence_characteristics/turbulence.py +155 -0
- sequenzo/sequence_characteristics/variance_of_spell_durations.py +86 -0
- sequenzo/sequence_characteristics/within_sequence_entropy.py +43 -0
- sequenzo/suffix_tree/__init__.py +66 -0
- sequenzo/suffix_tree/hub.py +114 -0
- sequenzo/suffix_tree/individual_level_indicators.py +1679 -0
- sequenzo/suffix_tree/spell_individual_level_indicators.py +493 -0
- sequenzo/suffix_tree/spell_level_indicators.py +248 -0
- sequenzo/suffix_tree/system_level_indicators.py +535 -0
- sequenzo/suffix_tree/utils.py +56 -0
- sequenzo/version_check.py +283 -0
- sequenzo/visualization/__init__.py +29 -0
- sequenzo/visualization/plot_mean_time.py +222 -0
- sequenzo/visualization/plot_modal_state.py +276 -0
- sequenzo/visualization/plot_most_frequent_sequences.py +147 -0
- sequenzo/visualization/plot_relative_frequency.py +405 -0
- sequenzo/visualization/plot_sequence_index.py +1175 -0
- sequenzo/visualization/plot_single_medoid.py +153 -0
- sequenzo/visualization/plot_state_distribution.py +651 -0
- sequenzo/visualization/plot_transition_matrix.py +190 -0
- sequenzo/visualization/utils/__init__.py +23 -0
- sequenzo/visualization/utils/utils.py +310 -0
- sequenzo/with_event_history_analysis/__init__.py +35 -0
- sequenzo/with_event_history_analysis/sequence_analysis_multi_state_model.py +850 -0
- sequenzo/with_event_history_analysis/sequence_history_analysis.py +283 -0
- sequenzo-0.1.31.dist-info/METADATA +286 -0
- sequenzo-0.1.31.dist-info/RECORD +299 -0
- sequenzo-0.1.31.dist-info/WHEEL +5 -0
- sequenzo-0.1.31.dist-info/licenses/LICENSE +28 -0
- sequenzo-0.1.31.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,910 @@
|
|
|
1
|
+
"""
|
|
2
|
+
@Author : Yuqi Liang 梁彧祺
|
|
3
|
+
@File : visualization.py
|
|
4
|
+
@Time : 2025-11-18 07:25
|
|
5
|
+
@Desc : Visualization functions for HMM models
|
|
6
|
+
|
|
7
|
+
This module provides visualization functions for HMM models, similar to
|
|
8
|
+
seqHMM's plot.hmm() and plot.mhmm() functions in R.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import matplotlib.pyplot as plt
|
|
13
|
+
import matplotlib.patches as mpatches
|
|
14
|
+
from matplotlib.patches import FancyArrowPatch, Circle, Wedge
|
|
15
|
+
from matplotlib.collections import PatchCollection
|
|
16
|
+
from typing import Optional, List, Union
|
|
17
|
+
from .hmm import HMM
|
|
18
|
+
from .mhmm import MHMM
|
|
19
|
+
|
|
20
|
+
# Try to import networkx for network layout, but make it optional
|
|
21
|
+
try:
|
|
22
|
+
import networkx as nx
|
|
23
|
+
HAS_NETWORKX = True
|
|
24
|
+
except ImportError:
|
|
25
|
+
HAS_NETWORKX = False
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def plot_hmm(
    model: HMM,
    which: str = 'transition',
    figsize: Optional[tuple] = None,
    ax: Optional[plt.Axes] = None,
    # Network plot parameters (similar to R's plot.hmm)
    vertex_size: float = 50,
    vertex_label_dist: float = 1.5,
    edge_curved: Union[bool, float] = 0.5,
    edge_label_cex: float = 0.8,
    vertex_label: str = 'initial.probs',
    loops: bool = False,
    trim: float = 1e-15,
    combine_slices: float = 0.05,
    with_legend: Union[bool, str] = 'bottom',
    layout: str = 'horizontal',
    **kwargs
) -> plt.Figure:
    """
    Plot HMM model parameters.

    This function visualizes HMM model parameters, including:
    - Transition probability matrix
    - Emission probability matrix
    - Initial state probabilities
    - Network graph (similar to R's plot.hmm())

    It is similar to seqHMM's plot.hmm() function in R.

    Args:
        model: Fitted HMM model object (its ``log_likelihood`` must not be None).
        which: What to plot. Options:
            - 'transition': Transition probability matrix (default)
            - 'emission': Emission probability matrix
            - 'initial': Initial state probabilities
            - 'network': Network graph with pie chart nodes (like R's plot.hmm)
            - 'all': All three plots (transition, emission, initial) side by
              side in a new 1x3 figure. ``ax`` cannot hold three panels, so it
              is ignored in this mode (a ``UserWarning`` is emitted).
        figsize: Figure size tuple (width, height). If None, uses default.
        ax: Optional matplotlib axes to plot on. If None, creates new figure.

        # Network plot parameters (only used when which='network'):
        vertex_size: Size of vertices (nodes). Default 50.
        vertex_label_dist: Distance of vertex labels from center. Default 1.5.
        edge_curved: Whether to plot curved edges. Can be bool or float (curvature). Default 0.5.
        edge_label_cex: Character expansion factor for edge labels. Default 0.8.
        vertex_label: Labels for vertices. Options: 'initial.probs', 'names', or custom list. Default 'initial.probs'.
        loops: Whether to plot self-loops (transitions back to same state). Default False.
        trim: Minimum transition probability to plot. Default 1e-15.
        combine_slices: Emission probabilities below this are combined into 'others'. Default 0.05.
        with_legend: Whether and where to plot legend. Options: True, False, 'bottom', 'top', 'left', 'right'. Default 'bottom'.
        layout: Layout of vertices. Options: 'horizontal', 'vertical'. Default 'horizontal'.
        **kwargs: Additional arguments passed to network plot (ignored for
            the other ``which`` options).

    Returns:
        matplotlib Figure: The figure object

    Raises:
        ValueError: If the model has not been fitted, or ``which`` is not one
            of the supported options.

    Examples:
        >>> from sequenzo import SequenceData, load_dataset
        >>> from sequenzo.seqhmm import build_hmm, fit_model, plot_hmm
        >>>
        >>> # Load and prepare data
        >>> df = load_dataset('mvad')
        >>> seq = SequenceData(df, time=range(15, 86), states=['EM', 'FE', 'HE', 'JL', 'SC', 'TR'])
        >>>
        >>> # Build and fit model
        >>> hmm = build_hmm(seq, n_states=4, random_state=42)
        >>> hmm = fit_model(hmm)
        >>>
        >>> # Plot network graph (like R's plot.hmm)
        >>> plot_hmm(hmm, which='network', vertex_size=50, edge_curved=0.5)
        >>> plt.show()
        >>>
        >>> # Plot transition matrix
        >>> plot_hmm(hmm, which='transition')
        >>> plt.show()
    """
    if model.log_likelihood is None:
        raise ValueError("Model must be fitted before plotting. Use fit_model() first.")

    if which == 'network':
        return _plot_hmm_network(
            model, figsize=figsize, ax=ax,
            vertex_size=vertex_size,
            vertex_label_dist=vertex_label_dist,
            edge_curved=edge_curved,
            edge_label_cex=edge_label_cex,
            vertex_label=vertex_label,
            loops=loops,
            trim=trim,
            combine_slices=combine_slices,
            with_legend=with_legend,
            layout=layout,
            **kwargs
        )
    elif which == 'all':
        if ax is not None:
            # A single Axes cannot host three panels; tell the caller rather
            # than silently discarding the Axes they passed in.
            import warnings
            warnings.warn(
                "The 'ax' argument is ignored when which='all'; a new figure "
                "with three subplots is created instead.",
                stacklevel=2,
            )
        if figsize is None:
            figsize = (15, 5)
        fig, axes = plt.subplots(1, 3, figsize=figsize)

        # Each helper applies the same clean spine styling to its own axes,
        # so no extra styling pass is needed here.
        _plot_transition_matrix(model, ax=axes[0])
        _plot_emission_matrix(model, ax=axes[1])
        _plot_initial_probs(model, ax=axes[2])

        # Lay out the figure we are returning, not pyplot's "current" figure.
        fig.tight_layout()
        return fig

    elif which == 'transition':
        return _plot_transition_matrix(model, figsize=figsize, ax=ax)
    elif which == 'emission':
        return _plot_emission_matrix(model, figsize=figsize, ax=ax)
    elif which == 'initial':
        return _plot_initial_probs(model, figsize=figsize, ax=ax)
    else:
        raise ValueError(f"Unknown 'which' option: {which}. Must be 'transition', 'emission', 'initial', 'network', or 'all'.")
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _plot_transition_matrix(
    model: HMM,
    figsize: Optional[tuple] = None,
    ax: Optional[plt.Axes] = None
) -> plt.Figure:
    """Plot the transition probability matrix as an annotated heatmap.

    Rows are the "from" state and columns the "to" state. Every cell is
    annotated with its probability to two decimals.

    Args:
        model: HMM exposing ``transition_probs``, ``n_states`` and
            ``state_names``.
        figsize: Size of the new figure when ``ax`` is None. Defaults to (8, 6).
        ax: Existing axes to draw on. If None, a new figure is created.

    Returns:
        matplotlib Figure: The figure that contains the heatmap.
    """
    if ax is None:
        if figsize is None:
            figsize = (8, 6)
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    probs = np.asarray(model.transition_probs)

    # Fixed [0, 1] colour range so heatmaps of different models are comparable.
    im = ax.imshow(probs, cmap='Blues', aspect='auto', vmin=0, vmax=1)

    # Attach the colorbar via the owning figure. pyplot.colorbar goes through
    # gcf(), which can spawn a stray figure when `ax` is not on the current one.
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label('Transition Probability', rotation=270, labelpad=20, fontsize=10)
    cbar.outline.set_visible(False)

    # Set ticks and labels
    ax.set_xticks(range(model.n_states))
    ax.set_yticks(range(model.n_states))
    ax.set_xticklabels(model.state_names, rotation=45, ha='right', fontsize=9)
    ax.set_yticklabels(model.state_names, fontsize=9)

    # Annotate each cell. Dark cells (p >= 0.5) get white text for contrast.
    for i in range(model.n_states):
        for j in range(model.n_states):
            p = probs[i, j]
            ax.text(j, i, f'{p:.2f}',
                    ha="center", va="center",
                    color="black" if p < 0.5 else "white",
                    fontsize=9, weight='medium')

    ax.set_xlabel('To State', fontsize=10)
    ax.set_ylabel('From State', fontsize=10)
    ax.set_title('Transition Probability Matrix', fontsize=11, pad=10, weight='medium')

    # Remove top and right spines for cleaner look
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_color('#cccccc')
    ax.spines['left'].set_color('#cccccc')

    return fig
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def _plot_emission_matrix(
    model: HMM,
    figsize: Optional[tuple] = None,
    ax: Optional[plt.Axes] = None
) -> plt.Figure:
    """Plot the emission probability matrix as a heatmap.

    Rows are hidden states and columns are observed symbols. Cells are
    annotated with their probability only while the matrix is small enough
    (at most 10 states and 15 symbols) for the labels to stay readable.

    Args:
        model: HMM exposing ``emission_probs``, ``n_states``, ``n_symbols``,
            ``state_names`` and ``alphabet``.
        figsize: Size of the new figure when ``ax`` is None. Defaults to (10, 6).
        ax: Existing axes to draw on. If None, a new figure is created.

    Returns:
        matplotlib Figure: The figure that contains the heatmap.
    """
    # Largest matrix for which per-cell text annotations are still legible.
    max_annotated_states = 10
    max_annotated_symbols = 15

    if ax is None:
        if figsize is None:
            figsize = (10, 6)
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    probs = np.asarray(model.emission_probs)

    # Fixed [0, 1] colour range so heatmaps of different models are comparable.
    im = ax.imshow(probs, cmap='YlOrRd', aspect='auto', vmin=0, vmax=1)

    # Attach the colorbar via the owning figure rather than pyplot's gcf().
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label('Emission Probability', rotation=270, labelpad=20, fontsize=10)
    cbar.outline.set_visible(False)

    # Set ticks and labels
    ax.set_xticks(range(model.n_symbols))
    ax.set_yticks(range(model.n_states))
    ax.set_xticklabels(model.alphabet, rotation=45, ha='right', fontsize=9)
    ax.set_yticklabels(model.state_names, fontsize=9)

    # Add text annotations (only if matrix is not too large)
    if model.n_states <= max_annotated_states and model.n_symbols <= max_annotated_symbols:
        for i in range(model.n_states):
            for j in range(model.n_symbols):
                p = probs[i, j]
                ax.text(j, i, f'{p:.2f}',
                        ha="center", va="center",
                        color="black" if p < 0.5 else "white",
                        fontsize=8, weight='medium')

    ax.set_xlabel('Observed Symbol', fontsize=10)
    ax.set_ylabel('Hidden State', fontsize=10)
    ax.set_title('Emission Probability Matrix', fontsize=11, pad=10, weight='medium')

    # Remove top and right spines for cleaner look
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_color('#cccccc')
    ax.spines['left'].set_color('#cccccc')

    return fig
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def _plot_initial_probs(
    model: HMM,
    figsize: Optional[tuple] = None,
    ax: Optional[plt.Axes] = None
) -> plt.Figure:
    """Draw the initial hidden-state distribution of an HMM as labelled bars.

    Args:
        model: Fitted HMM; reads ``n_states``, ``initial_probs`` and
            ``state_names``.
        figsize: Size of the new figure when ``ax`` is None (default ``(8, 5)``).
        ax: Optional axes to draw into; a new figure is created when omitted.

    Returns:
        matplotlib Figure: The figure that holds the bar chart.
    """
    if ax is not None:
        fig = ax.figure
    else:
        fig, ax = plt.subplots(figsize=(8, 5) if figsize is None else figsize)

    probs = model.initial_probs
    positions = range(model.n_states)

    bars = ax.bar(positions, probs,
                  color='#4A90E2', alpha=0.8, edgecolor='white', linewidth=1.5)

    # Print each probability just above the top of its bar.
    for bar, value in zip(bars, probs):
        centre_x = bar.get_x() + bar.get_width() / 2.
        ax.text(centre_x, bar.get_height(), f'{value:.3f}',
                ha='center', va='bottom', fontsize=9, weight='medium')

    ax.set_xticks(positions)
    ax.set_xticklabels(model.state_names, rotation=45, ha='right', fontsize=9)
    ax.set_ylabel('Probability', fontsize=10)
    ax.set_title('Initial State Probabilities', fontsize=11, pad=10, weight='medium')
    # 20% head-room so the value labels are not cut off.
    ax.set_ylim(0, max(probs) * 1.2)
    ax.grid(axis='y', alpha=0.2, linestyle='--', linewidth=0.5)

    # Minimal frame: hide the top/right spines and soften the remaining ones.
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    for side in ('bottom', 'left'):
        ax.spines[side].set_color('#cccccc')

    return fig
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def plot_mhmm(
    model: MHMM,
    which: str = 'clusters',
    figsize: Optional[tuple] = None,
    ax: Optional[plt.Axes] = None
) -> plt.Figure:
    """
    Plot Mixture HMM model parameters.

    Mirrors seqHMM's ``plot.mhmm()`` in R and can show:
    - Cluster probabilities
    - Transition matrices for each cluster
    - Emission matrices for each cluster

    Args:
        model: Fitted MHMM model object (``log_likelihood`` must be set).
        which: What to plot. Options:
            - 'clusters': Cluster probabilities (default)
            - 'transition': Transition matrices for all clusters
            - 'emission': Emission matrices for all clusters
            - 'all': All three side by side in a new 1x3 figure
        figsize: Figure size tuple (width, height). If None, a per-plot
            default is used ((18, 6) for 'all').
        ax: Optional matplotlib axes to plot on. Ignored when
            ``which == 'all'``, which always creates its own figure.

    Returns:
        matplotlib Figure: The figure object

    Raises:
        ValueError: If the model is not fitted or ``which`` is unknown.
    """
    if model.log_likelihood is None:
        raise ValueError("Model must be fitted before plotting. Use fit_mhmm() first.")

    # Single-panel views, in the order they appear in the 'all' figure.
    panel_plotters = (
        ('clusters', _plot_cluster_probs),
        ('transition', _plot_mhmm_transitions),
        ('emission', _plot_mhmm_emissions),
    )

    for option, plotter in panel_plotters:
        if which == option:
            return plotter(model, figsize=figsize, ax=ax)

    if which != 'all':
        raise ValueError(
            f"Unknown 'which' option: {which}. Must be 'clusters', 'transition', 'emission', or 'all'."
        )

    fig, panels = plt.subplots(1, 3, figsize=(18, 6) if figsize is None else figsize)
    for (_, plotter), panel in zip(panel_plotters, panels):
        plotter(model, ax=panel)

    plt.tight_layout()
    return fig
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def _plot_cluster_probs(
    model: MHMM,
    figsize: Optional[tuple] = None,
    ax: Optional[plt.Axes] = None
) -> plt.Figure:
    """Draw the mixture (cluster) probabilities of an MHMM as labelled bars.

    Args:
        model: Fitted MHMM; reads ``n_clusters``, ``cluster_probs`` and
            ``cluster_names``.
        figsize: Size of the new figure when ``ax`` is None (default ``(8, 5)``).
        ax: Optional axes to draw into; a new figure is created when omitted.

    Returns:
        matplotlib Figure: The figure that holds the bar chart.
    """
    if ax is not None:
        fig = ax.figure
    else:
        fig, ax = plt.subplots(figsize=(8, 5) if figsize is None else figsize)

    weights = model.cluster_probs
    slots = range(model.n_clusters)

    bars = ax.bar(slots, weights, color='steelblue', alpha=0.7)

    # Annotate every bar with its probability, anchored at the bar top.
    for bar, weight in zip(bars, weights):
        ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                f'{weight:.3f}', ha='center', va='bottom')

    ax.set_xticks(slots)
    ax.set_xticklabels(model.cluster_names, rotation=45, ha='right')
    ax.set_ylabel('Probability')
    ax.set_title('Cluster Probabilities')
    # Leave 20% head-room above the tallest bar for its label.
    ax.set_ylim(0, max(weights) * 1.2)
    ax.grid(axis='y', alpha=0.3)

    return fig
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def _plot_mhmm_transitions(
    model: MHMM,
    figsize: Optional[tuple] = None,
    ax: Optional[plt.Axes] = None
) -> plt.Figure:
    """Plot the transition matrix of every cluster as a row of heatmaps.

    Args:
        model: Fitted MHMM. Reads ``n_clusters``, ``clusters`` and
            ``cluster_names``; each cluster supplies ``transition_probs``,
            ``n_states`` and ``state_names``.
        figsize: Size of the new figure when ``ax`` is None; defaults to
            ``(6 * n_clusters, 6)``.
        ax: Optional axes to draw into. With more than one cluster, the area of
            ``ax`` is split into one panel per cluster and ``ax`` itself is
            hidden, so the heatmaps no longer overlap.

    Returns:
        matplotlib Figure: The figure containing the heatmaps.
    """
    n_clusters = model.n_clusters

    if ax is None:
        if figsize is None:
            figsize = (6 * n_clusters, 6)
        fig, grid = plt.subplots(1, n_clusters, figsize=figsize, squeeze=False)
        axes = list(grid[0])
        cbar_parent = axes
    else:
        fig = ax.figure
        if n_clusters == 1:
            axes = [ax]
            cbar_parent = axes
        else:
            # Reusing ``ax`` for every cluster would stack all heatmaps (and
            # their tick labels and annotations) on the same axes; give each
            # cluster its own panel inside the area of ``ax`` instead.
            ax.set_axis_off()
            spec = ax.get_subplotspec() if hasattr(ax, 'get_subplotspec') else None
            if spec is not None:
                cells = spec.subgridspec(1, n_clusters, wspace=0.5)
                axes = [fig.add_subplot(cells[0, k]) for k in range(n_clusters)]
                cbar_parent = axes
            else:
                # ``ax`` is not gridspec-based: fall back to inset panels.
                width = 1.0 / n_clusters
                axes = [ax.inset_axes([(k + 0.1) * width, 0.0, 0.8 * width, 1.0])
                        for k in range(n_clusters)]
                cbar_parent = ax

    im = None
    for k, panel in enumerate(axes):
        cluster = model.clusters[k]
        trans_probs = cluster.transition_probs

        im = panel.imshow(trans_probs, cmap='Blues', aspect='auto', vmin=0, vmax=1)

        ticks = range(cluster.n_states)
        panel.set_xticks(ticks)
        panel.set_yticks(ticks)
        panel.set_xticklabels(cluster.state_names, rotation=45, ha='right', fontsize=8)
        panel.set_yticklabels(cluster.state_names, fontsize=8)

        # Write each probability in its cell; switch to white text on dark cells.
        for i in range(cluster.n_states):
            for j in range(cluster.n_states):
                value = trans_probs[i, j]
                panel.text(j, i, f'{value:.2f}',
                           ha="center", va="center",
                           color="black" if value < 0.5 else "white",
                           fontsize=7)

        panel.set_xlabel('To State')
        panel.set_ylabel('From State')
        panel.set_title(f'{model.cluster_names[k]}\nTransition Matrix')

    if n_clusters > 1:
        # All panels share vmin=0/vmax=1, so one colorbar serves them all.
        fig.colorbar(im, ax=cbar_parent, orientation='horizontal', pad=0.1)

    # Lay out the figure we drew on, not whatever pyplot considers "current".
    fig.tight_layout()
    return fig
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def _plot_mhmm_emissions(
    model: MHMM,
    figsize: Optional[tuple] = None,
    ax: Optional[plt.Axes] = None
) -> plt.Figure:
    """Plot the emission matrix of every cluster as a row of heatmaps.

    Args:
        model: Fitted MHMM. Reads ``n_clusters``, ``clusters`` and
            ``cluster_names``; each cluster supplies ``emission_probs``,
            ``n_states``, ``n_symbols``, ``alphabet`` and ``state_names``.
        figsize: Size of the new figure when ``ax`` is None; defaults to
            ``(6 * n_clusters, 6)``.
        ax: Optional axes to draw into. With more than one cluster, the area of
            ``ax`` is split into one panel per cluster and ``ax`` itself is
            hidden, so the heatmaps no longer overlap.

    Returns:
        matplotlib Figure: The figure containing the heatmaps.
    """
    n_clusters = model.n_clusters

    if ax is None:
        if figsize is None:
            figsize = (6 * n_clusters, 6)
        fig, grid = plt.subplots(1, n_clusters, figsize=figsize, squeeze=False)
        axes = list(grid[0])
        cbar_parent = axes
    else:
        fig = ax.figure
        if n_clusters == 1:
            axes = [ax]
            cbar_parent = axes
        else:
            # Reusing ``ax`` for every cluster would stack all heatmaps (and
            # their tick labels and annotations) on the same axes; give each
            # cluster its own panel inside the area of ``ax`` instead.
            ax.set_axis_off()
            spec = ax.get_subplotspec() if hasattr(ax, 'get_subplotspec') else None
            if spec is not None:
                cells = spec.subgridspec(1, n_clusters, wspace=0.5)
                axes = [fig.add_subplot(cells[0, k]) for k in range(n_clusters)]
                cbar_parent = axes
            else:
                # ``ax`` is not gridspec-based: fall back to inset panels.
                width = 1.0 / n_clusters
                axes = [ax.inset_axes([(k + 0.1) * width, 0.0, 0.8 * width, 1.0])
                        for k in range(n_clusters)]
                cbar_parent = ax

    im = None
    for k, panel in enumerate(axes):
        cluster = model.clusters[k]
        emission_probs = cluster.emission_probs

        im = panel.imshow(emission_probs, cmap='YlOrRd', aspect='auto', vmin=0, vmax=1)

        panel.set_xticks(range(cluster.n_symbols))
        panel.set_yticks(range(cluster.n_states))
        panel.set_xticklabels(cluster.alphabet, rotation=45, ha='right', fontsize=8)
        panel.set_yticklabels(cluster.state_names, fontsize=8)

        # Annotate cells only while the matrix is small enough to stay legible.
        if cluster.n_states <= 10 and cluster.n_symbols <= 15:
            for i in range(cluster.n_states):
                for j in range(cluster.n_symbols):
                    value = emission_probs[i, j]
                    panel.text(j, i, f'{value:.2f}',
                               ha="center", va="center",
                               color="black" if value < 0.5 else "white",
                               fontsize=7)

        panel.set_xlabel('Observed Symbol')
        panel.set_ylabel('Hidden State')
        panel.set_title(f'{model.cluster_names[k]}\nEmission Matrix')

    if n_clusters > 1:
        # All panels share vmin=0/vmax=1, so one colorbar serves them all.
        fig.colorbar(im, ax=cbar_parent, orientation='horizontal', pad=0.1)

    # Lay out the figure we drew on, not whatever pyplot considers "current".
    fig.tight_layout()
    return fig
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
def _plot_hmm_network(
    model: HMM,
    figsize: Optional[tuple] = None,
    ax: Optional[plt.Axes] = None,
    vertex_size: float = 50,
    vertex_label_dist: float = 1.5,
    edge_curved: Union[bool, float] = 0.5,
    edge_label_cex: float = 0.8,
    vertex_label: str = 'initial.probs',
    loops: bool = False,
    trim: float = 1e-15,
    combine_slices: float = 0.05,
    with_legend: Union[bool, str] = 'bottom',
    layout: str = 'horizontal',
    legend_prop: float = 0.5,
    **kwargs
) -> plt.Figure:
    """
    Plot HMM as a network graph with pie chart nodes (similar to R's plot.hmm).

    This function creates a directed graph where:
    - Nodes are pie charts showing emission probabilities for each hidden state
    - Edges are arrows showing transition probabilities between states
    - Node labels show initial probabilities or state names

    Args:
        model: Fitted HMM model object
        figsize: Figure size tuple (width, height); only used when ``ax`` is None
        ax: Optional matplotlib axes. When given, no separate legend panel is
            created; the legend (if any) is attached to ``ax``.
        vertex_size: Size of vertices (nodes). 50 gives a radius of a quarter
            of the distance between neighbouring nodes.
        vertex_label_dist: Distance of vertex labels from the vertex centre, in
            multiples of the vertex radius (1 = on the vertex edge), or 'auto'
            (1.4 radii).
        edge_curved: Edge curvature. True means 0.5, False/0 gives straight
            edges, other numbers set the curvature (negative bends the other way).
        edge_label_cex: Character expansion factor for edge labels
        vertex_label: 'initial.probs', 'names', or a list/tuple/array of custom
            labels (states beyond its length fall back to initial probabilities)
        loops: Whether to plot self-loops (above the node in a horizontal
            layout, to its right in a vertical one)
        trim: Minimum transition probability to plot
        combine_slices: Emission probabilities below this are combined into 'others'
        with_legend: False for no legend, True (same as 'bottom'), or one of
            'bottom', 'top', 'left', 'right'
        layout: Layout of vertices, 'horizontal' or 'vertical'; any other
            value is treated as 'horizontal'
        legend_prop: Proportion of figure used for the legend panel (0-1). Default 0.5.
        **kwargs: Accepted for compatibility with R's signature; ignored.

    Returns:
        matplotlib Figure: The figure object
    """
    # A dedicated legend panel (like R's layout()) only when we own the figure.
    legend_sides = ('bottom', 'top', 'left', 'right')
    if ax is None and isinstance(with_legend, str) and with_legend in legend_sides:
        legend_side = with_legend
    else:
        legend_side = None

    if ax is None:
        fig, ax, ax_legend = _hmm_net_create_axes(figsize, legend_side, legend_prop)
    else:
        fig = ax.figure
        ax_legend = None

    n_states = model.n_states
    vertical = layout == 'vertical'

    colors = _hmm_net_symbol_colors(model)
    pies, any_combined = _hmm_net_pie_slices(
        np.asarray(model.emission_probs, dtype=float),
        list(model.alphabet), colors, combine_slices)

    # Nodes sit one data unit apart; the radius is a fraction of that spacing.
    positions = [(0.0, -float(i)) if vertical else (float(i), 0.0)
                 for i in range(n_states)]
    node_radius = 0.25 * vertex_size / 50.0

    # Label anchors etc. that must stay inside the axes limits.
    extra_points = []

    # --- Edges (drawn first so the nodes cover their ends) ---
    transition_probs = np.array(model.transition_probs, dtype=float)
    transition_probs[transition_probs < trim] = 0
    if not loops:
        np.fill_diagonal(transition_probs, 0)

    edges = [(i, j, transition_probs[i, j])
             for i in range(n_states) for j in range(n_states)
             if transition_probs[i, j] > 0]

    if edges:
        max_trans = max(prob for _, _, prob in edges)
        curvature = _hmm_net_curvature(edge_curved)
        for i, j, prob in edges:
            if i == j:
                geometry, label_xy = _hmm_net_loop_geometry(
                    positions[i], node_radius, vertical)
            else:
                geometry, label_xy = _hmm_net_edge_geometry(
                    positions[i], positions[j], node_radius, curvature)

            ax.add_patch(FancyArrowPatch(
                **geometry,
                arrowstyle='->',
                # Width scaled like R: transitions * (7 / max(transitions)).
                lw=max(prob * 7.0 / max_trans, 0.8),
                color='#666666',
                alpha=0.8,
                zorder=1,
                mutation_scale=15,
            ))

            label = f'{prob:.3f}' if prob >= 0.001 else f'{prob:.2e}'
            ax.text(label_xy[0], label_xy[1], label,
                    fontsize=9 * edge_label_cex,
                    ha='center', va='center',
                    bbox=dict(boxstyle='round,pad=0.2', facecolor='white',
                              alpha=0.9, edgecolor='gray', linewidth=0.5),
                    zorder=3)
            extra_points.append(label_xy)

    # --- Nodes (pie charts) and their labels ---
    texts = _hmm_net_vertex_texts(vertex_label, model)
    if isinstance(vertex_label_dist, (int, float)):
        label_dist = vertex_label_dist * node_radius
    else:  # 'auto'
        label_dist = 1.4 * node_radius

    for (x, y), (values, slice_colors, _), text in zip(positions, pies, texts):
        total = values.sum()
        if values.size and total > 0:
            # Slice boundaries in degrees, counter-clockwise from 3 o'clock.
            bounds = 360.0 * np.concatenate(([0.0], np.cumsum(values) / total))
            for k, color in enumerate(slice_colors):
                ax.add_patch(Wedge((x, y), node_radius, bounds[k], bounds[k + 1],
                                   facecolor=color, edgecolor='black',
                                   linewidth=2, zorder=2))

        if vertical:
            label_xy, ha, va = (x - label_dist, y), 'right', 'center'
        else:
            label_xy, ha, va = (x, y - label_dist), 'center', 'top'
        ax.text(label_xy[0], label_xy[1], text,
                fontsize=11, ha=ha, va=va, weight='bold', zorder=4)
        extra_points.append(label_xy)

    # --- Axes limits: keep every node, edge label and node label visible ---
    ax.set_aspect('equal')
    ax.axis('off')
    bound_points = extra_points + [
        (x + s * node_radius, y + s * node_radius)
        for x, y in positions for s in (-1.0, 1.0)
    ]
    if bound_points:
        pad = node_radius  # room for outlines, text glyphs and label boxes
        xs = [p[0] for p in bound_points]
        ys = [p[1] for p in bound_points]
        ax.set_xlim(min(xs) - pad, max(xs) + pad)
        ax.set_ylim(min(ys) - pad, max(ys) + pad)

    # --- Legend ---
    if ax_legend is not None:
        ax_legend.axis('off')
        # Unique (label, colour) pairs in order of first appearance, like R.
        entries = list(dict.fromkeys(
            pair
            for _, slice_colors, slice_labels in pies
            for pair in zip(slice_labels, slice_colors)
        ))
        handles = [mpatches.Patch(facecolor=col, edgecolor='black', linewidth=1, label=label)
                   for label, col in entries]
        ncol = max(1, min(len(handles), 6)) if legend_side in ('bottom', 'top') else 1
        ax_legend.legend(handles=handles, loc='center',
                         ncol=ncol, frameon=True, fontsize=10,
                         handlelength=1.5, handletextpad=0.5)
    elif with_legend:
        # No dedicated panel: attach the legend to the graph axes.
        legend_labels = list(model.alphabet)
        legend_colors = list(colors)
        if any_combined:
            legend_labels.append('others')
            legend_colors.append('white')
        handles = [mpatches.Patch(facecolor=col, edgecolor='black', label=label)
                   for col, label in zip(legend_colors, legend_labels)]
        ncol = max(1, min(len(legend_labels), 6))
        placements = {
            'bottom': dict(loc='lower center', bbox_to_anchor=(0.5, -0.1), ncol=ncol),
            'top': dict(loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=ncol),
            'right': dict(loc='center left', bbox_to_anchor=(1.02, 0.5), ncol=1),
            'left': dict(loc='center right', bbox_to_anchor=(-0.02, 0.5), ncol=1),
        }
        side = 'bottom' if with_legend is True else with_legend
        if isinstance(side, str) and side in placements:
            ax.legend(handles=handles, frameon=True, fontsize=9, **placements[side])

    fig.tight_layout()
    return fig


def _hmm_net_create_axes(figsize, legend_side, legend_prop):
    """Create the figure, the graph axes and (if ``legend_side``) a legend axes."""
    if figsize is None:
        if legend_side in ('bottom', 'top'):
            figsize = (12, 8)
        elif legend_side in ('left', 'right'):
            figsize = (14, 6)
        else:
            figsize = (12, 6)

    if legend_side is None:
        fig, ax = plt.subplots(figsize=figsize)
        return fig, ax, None

    fig = plt.figure(figsize=figsize)
    graph_share = 1 - legend_prop
    graph_first = legend_side in ('bottom', 'right')
    ratios = [graph_share, legend_prop] if graph_first else [legend_prop, graph_share]
    if legend_side in ('bottom', 'top'):
        gs = fig.add_gridspec(2, 1, height_ratios=ratios, hspace=0.3)
    else:
        gs = fig.add_gridspec(1, 2, width_ratios=ratios, wspace=0.3)

    first = fig.add_subplot(gs[0])
    second = fig.add_subplot(gs[1])
    if graph_first:
        return fig, first, second
    return fig, second, first


def _hmm_net_symbol_colors(model):
    """Return one colour per observed symbol, in ``model.alphabet`` order."""
    color_map = getattr(model.observations, 'color_map', None)
    if color_map:
        return [color_map.get(sym, '#808080') for sym in model.alphabet]

    # Default palette (similar to TraMineR); extra symbols fall back to grey.
    palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b',
               '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
    n_symbols = len(model.alphabet)
    return palette[:n_symbols] + ['#808080'] * max(0, n_symbols - len(palette))


def _hmm_net_pie_slices(emission_probs, alphabet, colors, combine_slices):
    """Build per-state pie data as ``(values, colours, labels)`` triples.

    Emission probabilities below ``combine_slices`` are merged into a white
    'others' slice. Values, colours and labels are filtered with the same mask,
    so each slice keeps the colour of its own symbol. Also returns whether any
    state received an 'others' slice.
    """
    slices = []
    any_combined = False
    for row in emission_probs:
        probs = np.array(row, dtype=float)
        combined = 0.0
        if combine_slices > 0:
            small = probs < combine_slices
            combined = float(probs[small].sum())
            probs[small] = 0.0

        keep = probs > 0
        values = probs[keep].tolist()
        slice_colors = [col for col, kept in zip(colors, keep) if kept]
        slice_labels = [sym for sym, kept in zip(alphabet, keep) if kept]

        if combined > 0:
            values.append(combined)
            slice_colors.append('white')
            slice_labels.append('others')
            any_combined = True

        slices.append((np.array(values, dtype=float), slice_colors, slice_labels))
    return slices, any_combined


def _hmm_net_curvature(edge_curved):
    """Translate ``edge_curved`` into a numeric curvature (True -> 0.5)."""
    # bool is a subclass of int, so it must be checked before the numeric case.
    if isinstance(edge_curved, bool):
        return 0.5 if edge_curved else 0.0
    return float(edge_curved) if edge_curved else 0.0


def _hmm_net_edge_geometry(p_from, p_to, radius, curvature):
    """Return FancyArrowPatch geometry kwargs and the label point for an edge.

    The edge starts and ends on the node circles. A non-zero curvature bends it
    as a quadratic Bezier curve whose control point is offset perpendicularly
    from the chord midpoint.
    """
    from matplotlib.path import Path

    (x1, y1), (x2, y2) = p_from, p_to
    dist = float(np.hypot(x2 - x1, y2 - y1))
    ux, uy = (x2 - x1) / dist, (y2 - y1) / dist

    start = (x1 + ux * radius, y1 + uy * radius)
    end = (x2 - ux * radius, y2 - uy * radius)
    mid = ((start[0] + end[0]) / 2, (start[1] + end[1]) / 2)

    if curvature == 0:
        return {'posA': start, 'posB': end}, mid

    offset = curvature * dist * 0.3
    control = (mid[0] - uy * offset, mid[1] + ux * offset)
    path = Path([start, control, end], [Path.MOVETO, Path.CURVE3, Path.CURVE3])
    # A quadratic Bezier passes halfway between the chord midpoint and its
    # control point at t = 0.5, which is where the label goes.
    label_xy = (mid[0] - uy * offset / 2, mid[1] + ux * offset / 2)
    return {'path': path}, label_xy


def _hmm_net_loop_geometry(center, radius, vertical):
    """Return FancyArrowPatch geometry kwargs and the label point for a self-loop.

    The loop is a cubic Bezier leaving and re-entering the node circle; it is
    drawn above the node (horizontal layout) or to its right (vertical layout).
    """
    from matplotlib.path import Path

    cx, cy = center
    direction = 0.0 if vertical else np.pi / 2

    def _polar(angle, dist):
        return (cx + dist * np.cos(angle), cy + dist * np.sin(angle))

    start = _polar(direction + np.pi / 6, radius)
    ctrl1 = _polar(direction + np.pi / 4, 2.6 * radius)
    ctrl2 = _polar(direction - np.pi / 4, 2.6 * radius)
    end = _polar(direction - np.pi / 6, radius)
    path = Path([start, ctrl1, ctrl2, end],
                [Path.MOVETO, Path.CURVE4, Path.CURVE4, Path.CURVE4])
    # Point of the cubic Bezier at t = 0.5: (P0 + 3*P1 + 3*P2 + P3) / 8.
    label_xy = tuple((a + 3 * b + 3 * c + d) / 8
                     for a, b, c, d in zip(start, ctrl1, ctrl2, end))
    return {'path': path}, label_xy


def _hmm_net_vertex_texts(vertex_label, model):
    """Return the text drawn next to each node, one entry per hidden state."""
    n_states = model.n_states
    prob_texts = [f'{model.initial_probs[i]:.2f}' for i in range(n_states)]

    if isinstance(vertex_label, str):
        if vertex_label == 'names':
            names = model.state_names
            return [names[i] if i < len(names) else f'State {i+1}'
                    for i in range(n_states)]
        return prob_texts  # 'initial.probs' and unrecognised strings

    if isinstance(vertex_label, (list, tuple, np.ndarray)):
        return [str(vertex_label[i]) if i < len(vertex_label) else prob_texts[i]
                for i in range(n_states)]
    return prob_texts
|