sequenzo 0.1.31__cp310-cp310-macosx_10_9_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- _sequenzo_fastcluster.cpython-310-darwin.so +0 -0
- sequenzo/__init__.py +349 -0
- sequenzo/big_data/__init__.py +12 -0
- sequenzo/big_data/clara/__init__.py +26 -0
- sequenzo/big_data/clara/clara.py +476 -0
- sequenzo/big_data/clara/utils/__init__.py +27 -0
- sequenzo/big_data/clara/utils/aggregatecases.py +92 -0
- sequenzo/big_data/clara/utils/davies_bouldin.py +91 -0
- sequenzo/big_data/clara/utils/get_weighted_diss.cpython-310-darwin.so +0 -0
- sequenzo/big_data/clara/utils/wfcmdd.py +205 -0
- sequenzo/big_data/clara/visualization.py +88 -0
- sequenzo/clustering/KMedoids.py +178 -0
- sequenzo/clustering/__init__.py +30 -0
- sequenzo/clustering/clustering_c_code.cpython-310-darwin.so +0 -0
- sequenzo/clustering/hierarchical_clustering.py +1256 -0
- sequenzo/clustering/sequenzo_fastcluster/fastcluster.py +495 -0
- sequenzo/clustering/sequenzo_fastcluster/src/fastcluster.cpp +1877 -0
- sequenzo/clustering/sequenzo_fastcluster/src/fastcluster_python.cpp +1264 -0
- sequenzo/clustering/src/KMedoid.cpp +263 -0
- sequenzo/clustering/src/PAM.cpp +237 -0
- sequenzo/clustering/src/PAMonce.cpp +265 -0
- sequenzo/clustering/src/cluster_quality.cpp +496 -0
- sequenzo/clustering/src/cluster_quality.h +128 -0
- sequenzo/clustering/src/cluster_quality_backup.cpp +570 -0
- sequenzo/clustering/src/module.cpp +228 -0
- sequenzo/clustering/src/weightedinertia.cpp +111 -0
- sequenzo/clustering/utils/__init__.py +27 -0
- sequenzo/clustering/utils/disscenter.py +122 -0
- sequenzo/data_preprocessing/__init__.py +22 -0
- sequenzo/data_preprocessing/helpers.py +303 -0
- sequenzo/datasets/__init__.py +41 -0
- sequenzo/datasets/biofam.csv +2001 -0
- sequenzo/datasets/biofam_child_domain.csv +2001 -0
- sequenzo/datasets/biofam_left_domain.csv +2001 -0
- sequenzo/datasets/biofam_married_domain.csv +2001 -0
- sequenzo/datasets/chinese_colonial_territories.csv +12 -0
- sequenzo/datasets/country_co2_emissions.csv +194 -0
- sequenzo/datasets/country_co2_emissions_global_deciles.csv +195 -0
- sequenzo/datasets/country_co2_emissions_global_quintiles.csv +195 -0
- sequenzo/datasets/country_co2_emissions_local_deciles.csv +195 -0
- sequenzo/datasets/country_co2_emissions_local_quintiles.csv +195 -0
- sequenzo/datasets/country_gdp_per_capita.csv +194 -0
- sequenzo/datasets/dyadic_children.csv +61 -0
- sequenzo/datasets/dyadic_parents.csv +61 -0
- sequenzo/datasets/mvad.csv +713 -0
- sequenzo/datasets/pairfam_activity_by_month.csv +1028 -0
- sequenzo/datasets/pairfam_activity_by_year.csv +1028 -0
- sequenzo/datasets/pairfam_family_by_month.csv +1028 -0
- sequenzo/datasets/pairfam_family_by_year.csv +1028 -0
- sequenzo/datasets/political_science_aid_shock.csv +166 -0
- sequenzo/datasets/political_science_donor_fragmentation.csv +157 -0
- sequenzo/define_sequence_data.py +1400 -0
- sequenzo/dissimilarity_measures/__init__.py +31 -0
- sequenzo/dissimilarity_measures/c_code.cpython-310-darwin.so +0 -0
- sequenzo/dissimilarity_measures/get_distance_matrix.py +762 -0
- sequenzo/dissimilarity_measures/get_substitution_cost_matrix.py +246 -0
- sequenzo/dissimilarity_measures/src/DHDdistance.cpp +148 -0
- sequenzo/dissimilarity_measures/src/LCPdistance.cpp +114 -0
- sequenzo/dissimilarity_measures/src/LCPspellDistance.cpp +215 -0
- sequenzo/dissimilarity_measures/src/OMdistance.cpp +247 -0
- sequenzo/dissimilarity_measures/src/OMspellDistance.cpp +281 -0
- sequenzo/dissimilarity_measures/src/__init__.py +0 -0
- sequenzo/dissimilarity_measures/src/dist2matrix.cpp +63 -0
- sequenzo/dissimilarity_measures/src/dp_utils.h +160 -0
- sequenzo/dissimilarity_measures/src/module.cpp +40 -0
- sequenzo/dissimilarity_measures/src/setup.py +30 -0
- sequenzo/dissimilarity_measures/src/utils.h +25 -0
- sequenzo/dissimilarity_measures/src/xsimd/.github/cmake-test/main.cpp +6 -0
- sequenzo/dissimilarity_measures/src/xsimd/benchmark/main.cpp +159 -0
- sequenzo/dissimilarity_measures/src/xsimd/benchmark/xsimd_benchmark.hpp +565 -0
- sequenzo/dissimilarity_measures/src/xsimd/docs/source/conf.py +37 -0
- sequenzo/dissimilarity_measures/src/xsimd/examples/mandelbrot.cpp +330 -0
- sequenzo/dissimilarity_measures/src/xsimd/examples/pico_bench.hpp +246 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_arithmetic.hpp +266 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_complex.hpp +112 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_details.hpp +323 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_logical.hpp +218 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_math.hpp +2583 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_memory.hpp +880 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_rounding.hpp +72 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_swizzle.hpp +174 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_trigo.hpp +978 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx.hpp +1924 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx2.hpp +1144 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512bw.hpp +656 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512cd.hpp +28 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512dq.hpp +244 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512er.hpp +20 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512f.hpp +2650 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512ifma.hpp +20 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512pf.hpp +20 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vbmi.hpp +77 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vbmi2.hpp +131 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vnni_avx512bw.hpp +20 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vnni_avx512vbmi2.hpp +20 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avxvnni.hpp +20 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_common.hpp +24 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_common_fwd.hpp +77 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_constants.hpp +393 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_emulated.hpp +788 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma3_avx.hpp +93 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma3_avx2.hpp +46 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma3_sse.hpp +97 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma4.hpp +92 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_i8mm_neon64.hpp +17 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_isa.hpp +142 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_neon.hpp +3142 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_neon64.hpp +1543 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_rvv.hpp +1513 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_scalar.hpp +1260 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse2.hpp +2024 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse3.hpp +67 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse4_1.hpp +339 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse4_2.hpp +44 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_ssse3.hpp +186 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sve.hpp +1155 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_vsx.hpp +892 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_wasm.hpp +1780 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_arch.hpp +240 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_config.hpp +484 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_cpuid.hpp +269 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_inline.hpp +27 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/math/xsimd_rem_pio2.hpp +719 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/memory/xsimd_aligned_allocator.hpp +349 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/memory/xsimd_alignment.hpp +91 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_all_registers.hpp +55 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_api.hpp +2765 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx2_register.hpp +44 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512bw_register.hpp +51 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512cd_register.hpp +51 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512dq_register.hpp +51 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512er_register.hpp +51 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512f_register.hpp +77 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512ifma_register.hpp +51 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512pf_register.hpp +51 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512vbmi2_register.hpp +51 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512vbmi_register.hpp +51 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512vnni_avx512bw_register.hpp +54 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx512vnni_avx512vbmi2_register.hpp +53 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avx_register.hpp +64 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_avxvnni_register.hpp +44 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_batch.hpp +1524 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_batch_constant.hpp +300 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_common_arch.hpp +47 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_emulated_register.hpp +80 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_fma3_avx2_register.hpp +50 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_fma3_avx_register.hpp +50 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_fma3_sse_register.hpp +50 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_fma4_register.hpp +50 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_i8mm_neon64_register.hpp +55 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_neon64_register.hpp +55 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_neon_register.hpp +154 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_register.hpp +94 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_rvv_register.hpp +506 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sse2_register.hpp +59 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sse3_register.hpp +49 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sse4_1_register.hpp +48 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sse4_2_register.hpp +48 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_ssse3_register.hpp +48 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_sve_register.hpp +156 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_traits.hpp +337 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_utils.hpp +536 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_vsx_register.hpp +77 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_wasm_register.hpp +59 -0
- sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/xsimd.hpp +75 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/architectures/dummy.cpp +7 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set.cpp +13 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set_mean.cpp +24 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set_mean_aligned.cpp +25 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set_mean_arch_independent.cpp +28 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/explicit_use_of_an_instruction_set_mean_tag_dispatch.cpp +25 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/manipulating_abstract_batches.cpp +7 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/manipulating_parametric_batches.cpp +8 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/sum.hpp +31 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/sum_avx2.cpp +3 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/sum_sse2.cpp +3 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/doc/writing_vectorized_code.cpp +11 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/main.cpp +31 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_api.cpp +230 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_arch.cpp +217 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_basic_math.cpp +183 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_batch.cpp +1049 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_bool.cpp +508 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_cast.cpp +409 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_complex.cpp +712 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_constant.cpp +286 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_float.cpp +141 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_int.cpp +365 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_manip.cpp +308 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_bitwise_cast.cpp +222 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_complex_exponential.cpp +226 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_complex_hyperbolic.cpp +183 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_complex_power.cpp +265 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_complex_trigonometric.cpp +236 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_conversion.cpp +248 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_custom_default_arch.cpp +28 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_error_gamma.cpp +170 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_explicit_batch_instantiation.cpp +32 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_exponential.cpp +202 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_extract_pair.cpp +92 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_fp_manipulation.cpp +77 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_gnu_source.cpp +30 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_hyperbolic.cpp +167 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_load_store.cpp +304 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_memory.cpp +61 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_poly_evaluation.cpp +64 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_power.cpp +184 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_rounding.cpp +199 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_select.cpp +101 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_shuffle.cpp +760 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_sum.cpp +4 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_sum.hpp +34 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_traits.cpp +172 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_trigonometric.cpp +208 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_utils.hpp +611 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_wasm/test_wasm_playwright.py +123 -0
- sequenzo/dissimilarity_measures/src/xsimd/test/test_xsimd_api.cpp +1460 -0
- sequenzo/dissimilarity_measures/utils/__init__.py +16 -0
- sequenzo/dissimilarity_measures/utils/get_LCP_length_for_2_seq.py +44 -0
- sequenzo/dissimilarity_measures/utils/get_sm_trate_substitution_cost_matrix.cpython-310-darwin.so +0 -0
- sequenzo/dissimilarity_measures/utils/seqconc.cpython-310-darwin.so +0 -0
- sequenzo/dissimilarity_measures/utils/seqdss.cpython-310-darwin.so +0 -0
- sequenzo/dissimilarity_measures/utils/seqdur.cpython-310-darwin.so +0 -0
- sequenzo/dissimilarity_measures/utils/seqlength.cpython-310-darwin.so +0 -0
- sequenzo/multidomain/__init__.py +23 -0
- sequenzo/multidomain/association_between_domains.py +311 -0
- sequenzo/multidomain/cat.py +597 -0
- sequenzo/multidomain/combt.py +519 -0
- sequenzo/multidomain/dat.py +81 -0
- sequenzo/multidomain/idcd.py +139 -0
- sequenzo/multidomain/linked_polyad.py +292 -0
- sequenzo/openmp_setup.py +233 -0
- sequenzo/prefix_tree/__init__.py +62 -0
- sequenzo/prefix_tree/hub.py +114 -0
- sequenzo/prefix_tree/individual_level_indicators.py +1321 -0
- sequenzo/prefix_tree/spell_individual_level_indicators.py +580 -0
- sequenzo/prefix_tree/spell_level_indicators.py +297 -0
- sequenzo/prefix_tree/system_level_indicators.py +544 -0
- sequenzo/prefix_tree/utils.py +54 -0
- sequenzo/seqhmm/__init__.py +95 -0
- sequenzo/seqhmm/advanced_optimization.py +305 -0
- sequenzo/seqhmm/bootstrap.py +411 -0
- sequenzo/seqhmm/build_hmm.py +142 -0
- sequenzo/seqhmm/build_mhmm.py +136 -0
- sequenzo/seqhmm/build_nhmm.py +121 -0
- sequenzo/seqhmm/fit_mhmm.py +62 -0
- sequenzo/seqhmm/fit_model.py +61 -0
- sequenzo/seqhmm/fit_nhmm.py +76 -0
- sequenzo/seqhmm/formulas.py +289 -0
- sequenzo/seqhmm/forward_backward_nhmm.py +276 -0
- sequenzo/seqhmm/gradients_nhmm.py +306 -0
- sequenzo/seqhmm/hmm.py +291 -0
- sequenzo/seqhmm/mhmm.py +314 -0
- sequenzo/seqhmm/model_comparison.py +238 -0
- sequenzo/seqhmm/multichannel_em.py +282 -0
- sequenzo/seqhmm/multichannel_utils.py +138 -0
- sequenzo/seqhmm/nhmm.py +270 -0
- sequenzo/seqhmm/nhmm_utils.py +191 -0
- sequenzo/seqhmm/predict.py +137 -0
- sequenzo/seqhmm/predict_mhmm.py +142 -0
- sequenzo/seqhmm/simulate.py +878 -0
- sequenzo/seqhmm/utils.py +218 -0
- sequenzo/seqhmm/visualization.py +910 -0
- sequenzo/sequence_characteristics/__init__.py +40 -0
- sequenzo/sequence_characteristics/complexity_index.py +49 -0
- sequenzo/sequence_characteristics/overall_cross_sectional_entropy.py +220 -0
- sequenzo/sequence_characteristics/plot_characteristics.py +593 -0
- sequenzo/sequence_characteristics/simple_characteristics.py +311 -0
- sequenzo/sequence_characteristics/state_frequencies_and_entropy_per_sequence.py +39 -0
- sequenzo/sequence_characteristics/turbulence.py +155 -0
- sequenzo/sequence_characteristics/variance_of_spell_durations.py +86 -0
- sequenzo/sequence_characteristics/within_sequence_entropy.py +43 -0
- sequenzo/suffix_tree/__init__.py +66 -0
- sequenzo/suffix_tree/hub.py +114 -0
- sequenzo/suffix_tree/individual_level_indicators.py +1679 -0
- sequenzo/suffix_tree/spell_individual_level_indicators.py +493 -0
- sequenzo/suffix_tree/spell_level_indicators.py +248 -0
- sequenzo/suffix_tree/system_level_indicators.py +535 -0
- sequenzo/suffix_tree/utils.py +56 -0
- sequenzo/version_check.py +283 -0
- sequenzo/visualization/__init__.py +29 -0
- sequenzo/visualization/plot_mean_time.py +222 -0
- sequenzo/visualization/plot_modal_state.py +276 -0
- sequenzo/visualization/plot_most_frequent_sequences.py +147 -0
- sequenzo/visualization/plot_relative_frequency.py +405 -0
- sequenzo/visualization/plot_sequence_index.py +1175 -0
- sequenzo/visualization/plot_single_medoid.py +153 -0
- sequenzo/visualization/plot_state_distribution.py +651 -0
- sequenzo/visualization/plot_transition_matrix.py +190 -0
- sequenzo/visualization/utils/__init__.py +23 -0
- sequenzo/visualization/utils/utils.py +310 -0
- sequenzo/with_event_history_analysis/__init__.py +35 -0
- sequenzo/with_event_history_analysis/sequence_analysis_multi_state_model.py +850 -0
- sequenzo/with_event_history_analysis/sequence_history_analysis.py +283 -0
- sequenzo-0.1.31.dist-info/METADATA +286 -0
- sequenzo-0.1.31.dist-info/RECORD +299 -0
- sequenzo-0.1.31.dist-info/WHEEL +5 -0
- sequenzo-0.1.31.dist-info/licenses/LICENSE +28 -0
- sequenzo-0.1.31.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,1175 @@
|
|
|
1
|
+
"""
|
|
2
|
+
@Author : Yuqi Liang 梁彧祺
|
|
3
|
+
@File : plot_sequence_index.py
|
|
4
|
+
@Time : 29/12/2024 09:08
|
|
5
|
+
@Desc :
|
|
6
|
+
Generate sequence index plots.
|
|
7
|
+
"""
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
10
|
+
import matplotlib.pyplot as plt
|
|
11
|
+
import matplotlib.gridspec as gridspec
|
|
12
|
+
|
|
13
|
+
# Use relative import to avoid circular import when top-level package imports visualization
|
|
14
|
+
from ..define_sequence_data import SequenceData
|
|
15
|
+
from sequenzo.visualization.utils import (
|
|
16
|
+
set_up_time_labels_for_x_axis,
|
|
17
|
+
save_figure_to_buffer,
|
|
18
|
+
create_standalone_legend,
|
|
19
|
+
combine_plot_with_legend,
|
|
20
|
+
save_and_show_results,
|
|
21
|
+
determine_layout,
|
|
22
|
+
show_plot_title
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def smart_sort_groups(groups):
    """
    Sort group labels so that labels beginning with an integer come first,
    ordered by that integer value, followed by all remaining labels.

    Ties (same leading integer, or no leading integer at all) are broken by
    comparing the labels' string forms.

    :param groups: Iterable of group names (any type; inspected via ``str()``)
    :return: New list containing the group names in sorted order
    """
    import re

    # Compiled once and reused by every key computation below.
    leading_int = re.compile(r'^(\d+)')

    def _rank(label):
        text = str(label)
        found = leading_int.match(text)
        if found is None:
            # Labels without a numeric prefix sort after all numeric ones.
            return (float('inf'), text)
        return (int(found.group(1)), text)

    return sorted(groups, key=_rank)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _cmdscale(D):
|
|
46
|
+
"""
|
|
47
|
+
Classic Multidimensional Scaling (MDS), equivalent to R's cmdscale()
|
|
48
|
+
|
|
49
|
+
:param D: A NxN symmetric distance matrix
|
|
50
|
+
:return: Y, a Nxd coordinate matrix, where d is the largest positive eigenvalues' count
|
|
51
|
+
"""
|
|
52
|
+
n = len(D)
|
|
53
|
+
|
|
54
|
+
# Step 1: Compute the centering matrix
|
|
55
|
+
H = np.eye(n) - np.ones((n, n)) / n
|
|
56
|
+
|
|
57
|
+
# Step 2: Compute the double centered distance matrix
|
|
58
|
+
B = -0.5 * H @ (D ** 2) @ H
|
|
59
|
+
|
|
60
|
+
# Step 3: Compute eigenvalues and eigenvectors
|
|
61
|
+
eigvals, eigvecs = np.linalg.eigh(B)
|
|
62
|
+
|
|
63
|
+
# Step 4: Sort eigenvalues and eigenvectors in descending order
|
|
64
|
+
idx = np.argsort(eigvals)[::-1]
|
|
65
|
+
eigvals = eigvals[idx]
|
|
66
|
+
eigvecs = eigvecs[:, idx]
|
|
67
|
+
|
|
68
|
+
# Step 5: Select only positive eigenvalues
|
|
69
|
+
w, = np.where(eigvals > 0)
|
|
70
|
+
if len(w) > 0:
|
|
71
|
+
L = np.diag(np.sqrt(eigvals[w]))
|
|
72
|
+
V = eigvecs[:, w]
|
|
73
|
+
return V @ L # Return the MDS coordinates
|
|
74
|
+
else:
|
|
75
|
+
# Fallback if no positive eigenvalues
|
|
76
|
+
return np.zeros((n, 1))
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _find_most_frequent_sequence(sequences):
|
|
80
|
+
"""
|
|
81
|
+
Find the most frequent sequence in the dataset.
|
|
82
|
+
|
|
83
|
+
:param sequences: numpy array of sequences
|
|
84
|
+
:return: index of the most frequent sequence
|
|
85
|
+
"""
|
|
86
|
+
from collections import Counter
|
|
87
|
+
|
|
88
|
+
# Convert sequences to tuples for hashing
|
|
89
|
+
seq_tuples = [tuple(seq) for seq in sequences]
|
|
90
|
+
|
|
91
|
+
# Count frequencies
|
|
92
|
+
counter = Counter(seq_tuples)
|
|
93
|
+
|
|
94
|
+
# Find the most frequent sequence
|
|
95
|
+
most_frequent = counter.most_common(1)[0][0]
|
|
96
|
+
|
|
97
|
+
# Find the index of this sequence in the original array
|
|
98
|
+
for i, seq in enumerate(seq_tuples):
|
|
99
|
+
if seq == most_frequent:
|
|
100
|
+
return i
|
|
101
|
+
|
|
102
|
+
return 0 # Fallback
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _select_sequences_subset(seqdata, sequence_selection, n_sequences, sort_by, sort_by_weight, weights, mask=None):
|
|
106
|
+
"""
|
|
107
|
+
Select a subset of sequences based on the selection method.
|
|
108
|
+
|
|
109
|
+
:param seqdata: SequenceData object
|
|
110
|
+
:param sequence_selection: Selection method ("all", "first_n", "last_n", or list of IDs)
|
|
111
|
+
:param n_sequences: Number of sequences for "first_n" or "last_n"
|
|
112
|
+
:param sort_by: Sorting method to use before selection
|
|
113
|
+
:param sort_by_weight: Whether to sort by weight
|
|
114
|
+
:param weights: Sequence weights
|
|
115
|
+
:param mask: Optional mask for pre-filtering sequences
|
|
116
|
+
:return: Boolean mask for selected sequences
|
|
117
|
+
"""
|
|
118
|
+
# Start with all sequences or pre-filtered mask
|
|
119
|
+
if mask is None:
|
|
120
|
+
mask = np.ones(len(seqdata.values), dtype=bool)
|
|
121
|
+
|
|
122
|
+
# If "all", return the current mask
|
|
123
|
+
if sequence_selection == "all":
|
|
124
|
+
return mask
|
|
125
|
+
|
|
126
|
+
# Get indices of sequences that pass the mask
|
|
127
|
+
valid_indices = np.where(mask)[0]
|
|
128
|
+
|
|
129
|
+
# Handle ID list selection
|
|
130
|
+
if isinstance(sequence_selection, list):
|
|
131
|
+
# Convert list to set for faster lookup
|
|
132
|
+
selected_ids = set(sequence_selection)
|
|
133
|
+
|
|
134
|
+
# Find indices of sequences with matching IDs
|
|
135
|
+
selected_mask = np.zeros(len(seqdata.values), dtype=bool)
|
|
136
|
+
if hasattr(seqdata, 'ids') and seqdata.ids is not None:
|
|
137
|
+
for i in valid_indices:
|
|
138
|
+
if seqdata.ids[i] in selected_ids:
|
|
139
|
+
selected_mask[i] = True
|
|
140
|
+
else:
|
|
141
|
+
print("Warning: sequence_selection provided as ID list but seqdata has no IDs. Using all sequences.")
|
|
142
|
+
return mask
|
|
143
|
+
|
|
144
|
+
return selected_mask
|
|
145
|
+
|
|
146
|
+
# For "first_n" or "last_n", we need to sort first
|
|
147
|
+
if sequence_selection in ["first_n", "last_n"]:
|
|
148
|
+
# Get the subset of data based on current mask
|
|
149
|
+
subset_seqdata = seqdata
|
|
150
|
+
subset_weights = weights
|
|
151
|
+
|
|
152
|
+
if not np.all(mask):
|
|
153
|
+
# Create subset if mask is not all True
|
|
154
|
+
subset_values = seqdata.values[mask]
|
|
155
|
+
subset_ids = seqdata.ids[mask] if hasattr(seqdata, 'ids') and seqdata.ids is not None else None
|
|
156
|
+
|
|
157
|
+
# Use original seqdata for structure, just work with filtered values
|
|
158
|
+
subset_seqdata = seqdata # Keep original structure
|
|
159
|
+
|
|
160
|
+
if weights is not None:
|
|
161
|
+
subset_weights = weights[mask]
|
|
162
|
+
|
|
163
|
+
# Apply sorting to get the order
|
|
164
|
+
distance_matrix = None
|
|
165
|
+
if sort_by in ["mds", "distance_to_most_frequent"]:
|
|
166
|
+
try:
|
|
167
|
+
from sequenzo.dissimilarity_measures.get_distance_matrix import get_distance_matrix
|
|
168
|
+
distance_matrix = get_distance_matrix(
|
|
169
|
+
seqdata=subset_seqdata,
|
|
170
|
+
method="OM",
|
|
171
|
+
sm="CONSTANT",
|
|
172
|
+
indel="auto"
|
|
173
|
+
)
|
|
174
|
+
if hasattr(distance_matrix, 'values'):
|
|
175
|
+
distance_matrix = distance_matrix.values
|
|
176
|
+
except ImportError:
|
|
177
|
+
print(f"Warning: Cannot compute distance matrix for '{sort_by}' sorting. Using unsorted order.")
|
|
178
|
+
sort_by = "unsorted"
|
|
179
|
+
|
|
180
|
+
# Apply sorting to the masked subset
|
|
181
|
+
if sort_by_weight and subset_weights is not None:
|
|
182
|
+
# Sort by weight on the subset
|
|
183
|
+
sorted_indices = np.argsort(-subset_weights)
|
|
184
|
+
else:
|
|
185
|
+
# Sort on the subset values
|
|
186
|
+
if sort_by == "unsorted" or sort_by == "none":
|
|
187
|
+
sorted_indices = np.arange(len(valid_indices))
|
|
188
|
+
elif sort_by == "lexicographic":
|
|
189
|
+
subset_values = seqdata.values[mask]
|
|
190
|
+
vals = subset_values.astype(float, copy=True)
|
|
191
|
+
vals = np.nan_to_num(vals, nan=np.inf)
|
|
192
|
+
sorted_indices = np.lexsort(vals.T[::-1])
|
|
193
|
+
elif sort_by in ["mds", "distance_to_most_frequent"]:
|
|
194
|
+
# For complex sorting that requires distance matrix,
|
|
195
|
+
# we'll fall back to simple lexicographic for now
|
|
196
|
+
subset_values = seqdata.values[mask]
|
|
197
|
+
vals = subset_values.astype(float, copy=True)
|
|
198
|
+
vals = np.nan_to_num(vals, nan=np.inf)
|
|
199
|
+
sorted_indices = np.lexsort(vals.T[::-1])
|
|
200
|
+
print(f"Warning: {sort_by} sorting simplified to lexicographic for sequence selection")
|
|
201
|
+
else:
|
|
202
|
+
sorted_indices = np.arange(len(valid_indices))
|
|
203
|
+
|
|
204
|
+
# Select first_n or last_n
|
|
205
|
+
n_available = len(sorted_indices)
|
|
206
|
+
n_to_select = min(n_sequences, n_available)
|
|
207
|
+
|
|
208
|
+
if sequence_selection == "first_n":
|
|
209
|
+
selected_subset_indices = sorted_indices[:n_to_select]
|
|
210
|
+
elif sequence_selection == "last_n":
|
|
211
|
+
selected_subset_indices = sorted_indices[-n_to_select:]
|
|
212
|
+
|
|
213
|
+
# Map back to original indices
|
|
214
|
+
original_indices = valid_indices[selected_subset_indices]
|
|
215
|
+
|
|
216
|
+
# Create final mask
|
|
217
|
+
final_mask = np.zeros(len(seqdata.values), dtype=bool)
|
|
218
|
+
final_mask[original_indices] = True
|
|
219
|
+
|
|
220
|
+
return final_mask
|
|
221
|
+
|
|
222
|
+
else:
|
|
223
|
+
raise ValueError(f"Unsupported sequence_selection: {sequence_selection}. "
|
|
224
|
+
f"Supported options: 'all', 'first_n', 'last_n', or list of IDs")
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def sort_sequences_by_method(seqdata, method="unsorted", mask=None, distance_matrix=None, weights=None):
    """
    Sort sequences in SequenceData based on specified method.

    :param seqdata: SequenceData object (only its ``values`` array is read here)
    :param method: str, sorting method - "unsorted" (alias "none"), "lexicographic",
        "mds", "distance_to_most_frequent"
    :param mask: np.array(bool), if provided, sort only this subset
    :param distance_matrix: array-like, required for "mds" and "distance_to_most_frequent".
        It may cover either the full sample or only the masked subset; a full-sample
        matrix is restricted to the masked rows/columns automatically. It is ignored
        by "unsorted" and "lexicographic".
    :param weights: np.array, optional weights for sequences (currently unused;
        reserved for weighted MDS)
    :return: np.array of sorting indices. Without ``mask`` they index the rows of
        ``seqdata.values``; with ``mask`` they index positions within the masked
        subset (``seqdata.values[mask]``). Ties keep their input order.
    :raises ValueError: if ``method`` is unsupported, if a distance-based method is
        requested without ``distance_matrix``, or if ``distance_matrix`` is not a
        square 2-D array matching either the full sample or the masked subset.
    """
    values = np.array(seqdata.values, copy=True)
    n_total = len(values)

    if mask is not None:
        mask = np.asarray(mask)
        values = values[mask]
    n_sequences = len(values)

    if method in ("unsorted", "none"):
        # Keep original order (R default)
        return np.arange(n_sequences)

    if method == "lexicographic":
        # Lexicographic sorting (NaN-safe): missing values are pushed to the end.
        vals = np.nan_to_num(values.astype(float, copy=True), nan=np.inf)
        if vals.size == 0:
            # No rows or no time columns: nothing to compare, keep input order
            # (np.lexsort rejects an empty key list).
            return np.arange(n_sequences)
        # np.lexsort uses the LAST key as the primary one, so the columns are
        # reversed to make the first time point primary. lexsort is stable.
        return np.lexsort(vals.T[::-1])

    if method not in ("mds", "distance_to_most_frequent"):
        raise ValueError(f"Unsupported sorting method: {method}. "
                         f"Supported methods are: 'unsorted', 'lexicographic', 'mds', 'distance_to_most_frequent'")

    if distance_matrix is None:
        if method == "mds":
            raise ValueError("Distance matrix is required for MDS sorting")
        raise ValueError("Distance matrix is required for distance_to_most_frequent sorting")

    dist = np.asarray(distance_matrix)
    if dist.ndim != 2 or dist.shape[0] != dist.shape[1]:
        raise ValueError(f"distance_matrix must be a square 2-D array, got shape {dist.shape}")

    if dist.shape[0] != n_sequences:
        if mask is not None and dist.shape[0] == n_total:
            # Full-sample matrix: restrict it to the masked rows and columns.
            masked_indices = np.where(mask)[0]
            dist = dist[np.ix_(masked_indices, masked_indices)]
        else:
            raise ValueError(
                f"distance_matrix has size {dist.shape[0]}, which matches neither the "
                f"number of sequences being sorted ({n_sequences}) nor the full sample ({n_total})"
            )

    if method == "mds":
        # TODO: Support weighted MDS (TraMineR's wcmdscale analogue) when weights are provided.
        mds_coords = _cmdscale(dist)
        # Sort by first MDS dimension; stable so ties keep their input order.
        return np.argsort(mds_coords[:, 0], kind="stable")

    # distance_to_most_frequent: sort by distance to the most frequent sequence.
    most_freq_idx = _find_most_frequent_sequence(values)
    distances = dist[most_freq_idx, :]
    # Stable sort: identical sequences (distance ties, e.g. 0) keep input order.
    return np.argsort(distances, kind="stable")
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def plot_sequence_index(seqdata: SequenceData,
|
|
293
|
+
# Grouping parameters
|
|
294
|
+
group_by_column=None,
|
|
295
|
+
group_dataframe=None,
|
|
296
|
+
group_column_name=None,
|
|
297
|
+
group_labels=None,
|
|
298
|
+
# Other parameters
|
|
299
|
+
sort_by="lexicographic",
|
|
300
|
+
sort_by_weight=False,
|
|
301
|
+
weights="auto",
|
|
302
|
+
figsize=(10, 6),
|
|
303
|
+
plot_style="standard",
|
|
304
|
+
title=None,
|
|
305
|
+
xlabel="Time",
|
|
306
|
+
ylabel="Sequences",
|
|
307
|
+
save_as=None,
|
|
308
|
+
dpi=200,
|
|
309
|
+
layout='column',
|
|
310
|
+
nrows: int = None,
|
|
311
|
+
ncols: int = None,
|
|
312
|
+
group_order=None,
|
|
313
|
+
sort_groups='auto',
|
|
314
|
+
fontsize=12,
|
|
315
|
+
show_group_titles: bool = True,
|
|
316
|
+
include_legend: bool = True,
|
|
317
|
+
sequence_selection="all",
|
|
318
|
+
n_sequences=10,
|
|
319
|
+
show_sequence_ids=False,
|
|
320
|
+
sort_by_ids=None,
|
|
321
|
+
return_sorted_ids=False,
|
|
322
|
+
show_title=True,
|
|
323
|
+
proportional_scaling=False,
|
|
324
|
+
hide_y_axis=False
|
|
325
|
+
):
|
|
326
|
+
"""Creates sequence index plots, optionally grouped by categories.
|
|
327
|
+
|
|
328
|
+
This function creates index plots that visualize sequences as horizontal lines,
|
|
329
|
+
with different sorting options matching R's TraMineR functionality.
|
|
330
|
+
|
|
331
|
+
**Two API modes for grouping:**
|
|
332
|
+
|
|
333
|
+
1. **Simplified API** (when grouping info is already in the data):
|
|
334
|
+
```python
|
|
335
|
+
plot_sequence_index(seqdata, group_by_column="Cluster", group_labels=cluster_labels)
|
|
336
|
+
```
|
|
337
|
+
|
|
338
|
+
2. **Complete API** (when grouping info is in a separate dataframe):
|
|
339
|
+
```python
|
|
340
|
+
plot_sequence_index(seqdata, group_dataframe=membership_df,
|
|
341
|
+
group_column_name="Cluster", group_labels=cluster_labels)
|
|
342
|
+
```
|
|
343
|
+
|
|
344
|
+
:param seqdata: SequenceData object containing sequence information
|
|
345
|
+
|
|
346
|
+
**New API parameters (recommended):**
|
|
347
|
+
:param group_by_column: (str, optional) Column name from seqdata.data to group by.
|
|
348
|
+
Use this when grouping information is already in your data.
|
|
349
|
+
Example: "Cluster", "sex", "education"
|
|
350
|
+
:param group_dataframe: (pd.DataFrame, optional) Separate dataframe containing grouping information.
|
|
351
|
+
Use this when grouping info is in a separate table (e.g., clustering results).
|
|
352
|
+
Must contain ID column and grouping column.
|
|
353
|
+
:param group_column_name: (str, optional) Name of the grouping column in group_dataframe.
|
|
354
|
+
Required when using group_dataframe.
|
|
355
|
+
:param group_labels: (dict, optional) Custom labels for group values.
|
|
356
|
+
Example: {1: "Late Family Formation", 2: "Early Partnership"}
|
|
357
|
+
Maps original values to display labels.
|
|
358
|
+
|
|
359
|
+
:param sort_by: Sorting method for sequences within groups:
|
|
360
|
+
- 'unsorted' or 'none': Keep original order (R TraMineR default)
|
|
361
|
+
- 'lexicographic': Sort sequences lexicographically
|
|
362
|
+
- 'mds': Sort by first MDS dimension (requires distance computation)
|
|
363
|
+
- 'distance_to_most_frequent': Sort by distance to most frequent sequence
|
|
364
|
+
:param sort_by_weight: If True, sort sequences by weight (descending), overrides sort_by
|
|
365
|
+
:param weights: (np.ndarray or "auto") Weights for sequences. If "auto", uses seqdata.weights if available
|
|
366
|
+
:param figsize: Size of each subplot figure (only used when plot_style="custom")
|
|
367
|
+
:param plot_style: Plot aspect style:
|
|
368
|
+
- 'standard': Standard proportions (10, 6) - balanced view
|
|
369
|
+
- 'compact': Compact/vertical proportions (8, 8) - more vertical like R plots
|
|
370
|
+
- 'wide': Wide proportions (12, 4) - emphasizes time progression
|
|
371
|
+
- 'narrow': Narrow/tall proportions (8, 10) - moderately vertical
|
|
372
|
+
- 'custom': Use the provided figsize parameter
|
|
373
|
+
:param title: Title for the plot (if None, default titles will be used)
|
|
374
|
+
:param xlabel: Label for the x-axis
|
|
375
|
+
:param ylabel: Label for the y-axis
|
|
376
|
+
:param save_as: File path to save the plot (if None, plot will be shown)
|
|
377
|
+
:param dpi: DPI for saved image
|
|
378
|
+
:param layout: Layout style - 'column' (default, 3xn), 'grid' (nxn)
|
|
379
|
+
:param group_order: List, manually specify group order (overrides sort_groups)
|
|
380
|
+
:param sort_groups: String, sorting method: 'auto'(smart numeric), 'numeric'(numeric prefix), 'alpha'(alphabetical), 'none'(original order)
|
|
381
|
+
:param fontsize: Base font size for text elements (titles use fontsize+2, ticks use fontsize-2)
|
|
382
|
+
:param show_group_titles: Whether to show group titles
|
|
383
|
+
:param include_legend: Whether to include legend in the plot (True by default)
|
|
384
|
+
:param sequence_selection: Method for selecting sequences to visualize:
|
|
385
|
+
- "all": Show all sequences (default)
|
|
386
|
+
- "first_n": Show first n sequences from each group
|
|
387
|
+
- "last_n": Show last n sequences from each group
|
|
388
|
+
- list: List of specific sequence IDs to show
|
|
389
|
+
:param n_sequences: Number of sequences to show when using "first_n" or "last_n" (default: 10)
|
|
390
|
+
:param show_sequence_ids: If True, show actual sequence IDs on y-axis instead of sequence numbers.
|
|
391
|
+
Most useful when sequence_selection is a list of IDs (default: False)
|
|
392
|
+
:param sort_by_ids: (list or np.ndarray, optional) Custom ID order for sorting sequences.
|
|
393
|
+
When provided, sequences will be sorted to match this ID order, overriding
|
|
394
|
+
the sort_by parameter. This is useful for aligning multiple plots so that
|
|
395
|
+
the same IDs appear in the same row across different visualizations.
|
|
396
|
+
Example: sort_by_ids=[1, 3, 2, 5, 4] will sort sequences by this exact order.
|
|
397
|
+
:param return_sorted_ids: (bool, default: False) If True, returns the sorted ID order after plotting.
|
|
398
|
+
This is useful for multidomain analysis where you want to use the sorted
|
|
399
|
+
IDs from the first plot to align subsequent plots.
|
|
400
|
+
Returns a dictionary with group names as keys and sorted ID arrays as values
|
|
401
|
+
(for grouped plots), or a single array of sorted IDs (for single plots).
|
|
402
|
+
:param show_title: (bool, default: True) If False, suppresses the main title display even if title parameter is provided.
|
|
403
|
+
This allows you to control title visibility separately from providing a title string.
|
|
404
|
+
:param proportional_scaling: (bool, default: False) If True, scales subplot heights proportionally based on
|
|
405
|
+
the number of sequences in each group. Useful when groups have very different sizes.
|
|
406
|
+
Only applies to grouped plots with layout='column'.
|
|
407
|
+
:param hide_y_axis: (bool, default: False) If True, hides y-axis ticks, labels, and spine for all subplots.
|
|
408
|
+
Useful when using proportional_scaling to create cleaner visualizations.
|
|
409
|
+
|
|
410
|
+
Note: For 'mds' and 'distance_to_most_frequent' sorting, distance matrices are computed
|
|
411
|
+
automatically using Optimal Matching (OM) with constant substitution costs.
|
|
412
|
+
"""
|
|
413
|
+
# Determine figure size based on plot style
|
|
414
|
+
style_sizes = {
|
|
415
|
+
'standard': (10, 6), # Balanced view
|
|
416
|
+
'compact': (8, 8), # More square, like R plots
|
|
417
|
+
'wide': (12, 4), # Wide, emphasizes time
|
|
418
|
+
'narrow': (8, 10), # Moderately vertical
|
|
419
|
+
'custom': figsize # User-provided
|
|
420
|
+
}
|
|
421
|
+
|
|
422
|
+
if plot_style not in style_sizes:
|
|
423
|
+
raise ValueError(f"Invalid plot_style '{plot_style}'. "
|
|
424
|
+
f"Supported styles: {list(style_sizes.keys())}")
|
|
425
|
+
|
|
426
|
+
# Special validation for custom plot style
|
|
427
|
+
if plot_style == 'custom' and figsize == (10, 6):
|
|
428
|
+
raise ValueError(
|
|
429
|
+
"When using plot_style='custom', you must explicitly provide a figsize parameter "
|
|
430
|
+
"that differs from the default (10, 6). "
|
|
431
|
+
"Suggested custom sizes:\n"
|
|
432
|
+
" - For wide plots: figsize=(15, 5)\n"
|
|
433
|
+
" - For tall plots: figsize=(7, 12)\n"
|
|
434
|
+
" - For square plots: figsize=(9, 9)\n"
|
|
435
|
+
" - For small plots: figsize=(6, 4)\n"
|
|
436
|
+
"Example: plot_sequence_index(data, plot_style='custom', figsize=(12, 8))"
|
|
437
|
+
)
|
|
438
|
+
|
|
439
|
+
actual_figsize = style_sizes[plot_style]
|
|
440
|
+
|
|
441
|
+
# Handle the simplified API: group_by_column
|
|
442
|
+
if group_by_column is not None:
|
|
443
|
+
# Validate that the column exists in the original data
|
|
444
|
+
if group_by_column not in seqdata.data.columns:
|
|
445
|
+
available_cols = [col for col in seqdata.data.columns if col not in seqdata.time and col != seqdata.id_col]
|
|
446
|
+
raise ValueError(
|
|
447
|
+
f"Column '{group_by_column}' not found in the data. "
|
|
448
|
+
f"Available columns for grouping: {available_cols}"
|
|
449
|
+
)
|
|
450
|
+
|
|
451
|
+
# Automatically create group_dataframe and group_column_name from the simplified API
|
|
452
|
+
group_dataframe = seqdata.data[[seqdata.id_col, group_by_column]].copy()
|
|
453
|
+
group_dataframe.columns = ['Entity ID', 'Category']
|
|
454
|
+
group_column_name = 'Category'
|
|
455
|
+
|
|
456
|
+
# Handle group labels - flexible and user-controllable
|
|
457
|
+
unique_values = seqdata.data[group_by_column].unique()
|
|
458
|
+
|
|
459
|
+
if group_labels is not None:
|
|
460
|
+
# User provided custom labels - use them
|
|
461
|
+
missing_keys = set(unique_values) - set(group_labels.keys())
|
|
462
|
+
if missing_keys:
|
|
463
|
+
raise ValueError(
|
|
464
|
+
f"group_labels missing mappings for values: {missing_keys}. "
|
|
465
|
+
f"Please provide labels for all unique values in '{group_by_column}': {sorted(unique_values)}"
|
|
466
|
+
)
|
|
467
|
+
group_dataframe['Category'] = group_dataframe['Category'].map(group_labels)
|
|
468
|
+
else:
|
|
469
|
+
# No custom labels provided - use smart defaults
|
|
470
|
+
if all(isinstance(v, (int, float, np.integer, np.floating)) and not pd.isna(v) for v in unique_values):
|
|
471
|
+
# Numeric values - keep as is (user can provide group_labels if they want custom names)
|
|
472
|
+
pass
|
|
473
|
+
# For string/categorical values, keep original values
|
|
474
|
+
# This handles cases where users already have meaningful labels like "Male"/"Female"
|
|
475
|
+
|
|
476
|
+
print(f"[>] Creating grouped plots by '{group_by_column}' with {len(unique_values)} categories")
|
|
477
|
+
|
|
478
|
+
# If no grouping information, create a single plot
|
|
479
|
+
if group_dataframe is None or group_column_name is None:
|
|
480
|
+
return _sequence_index_plot_single(seqdata, sort_by, sort_by_weight, weights, actual_figsize, plot_style, title, xlabel, ylabel, save_as, dpi, fontsize, include_legend, sequence_selection, n_sequences, show_sequence_ids, sort_by_ids, return_sorted_ids, show_title)
|
|
481
|
+
|
|
482
|
+
# Process weights
|
|
483
|
+
if isinstance(weights, str) and weights == "auto":
|
|
484
|
+
weights = getattr(seqdata, "weights", None)
|
|
485
|
+
|
|
486
|
+
if weights is not None:
|
|
487
|
+
weights = np.asarray(weights, dtype=float).reshape(-1)
|
|
488
|
+
if len(weights) != len(seqdata.values):
|
|
489
|
+
raise ValueError("Length of weights must equal number of sequences.")
|
|
490
|
+
|
|
491
|
+
# Ensure ID columns match (convert if needed)
|
|
492
|
+
id_col_name = "Entity ID" if "Entity ID" in group_dataframe.columns else group_dataframe.columns[0]
|
|
493
|
+
|
|
494
|
+
# Apply group_labels if provided (for group_dataframe API)
|
|
495
|
+
if group_labels is not None and group_column_name in group_dataframe.columns:
|
|
496
|
+
# Validate that all values in the group column have labels
|
|
497
|
+
unique_values = group_dataframe[group_column_name].unique()
|
|
498
|
+
missing_keys = set(unique_values) - set(group_labels.keys())
|
|
499
|
+
|
|
500
|
+
# Track if we performed auto-remapping (to avoid double copying)
|
|
501
|
+
remapping_performed = False
|
|
502
|
+
|
|
503
|
+
# Auto-detect and fix: if missing_keys exist and they look like medoid indices
|
|
504
|
+
# (e.g., large integers that don't match group_labels keys), automatically remap them
|
|
505
|
+
if missing_keys:
|
|
506
|
+
# Check if missing_keys look like medoid indices (large integers > expected cluster count)
|
|
507
|
+
expected_cluster_count = len(group_labels)
|
|
508
|
+
missing_values_list = list(missing_keys)
|
|
509
|
+
|
|
510
|
+
# Check if all missing values are numeric and larger than expected cluster count
|
|
511
|
+
# This suggests they might be medoid indices from KMedoids
|
|
512
|
+
all_numeric = all(isinstance(v, (int, float, np.integer, np.floating)) and not pd.isna(v)
|
|
513
|
+
for v in missing_values_list)
|
|
514
|
+
all_large = all(isinstance(v, (int, float, np.integer, np.floating)) and not pd.isna(v)
|
|
515
|
+
and (v > expected_cluster_count or v < 1) for v in missing_values_list)
|
|
516
|
+
|
|
517
|
+
if all_numeric and all_large and len(missing_values_list) == expected_cluster_count:
|
|
518
|
+
# This looks like medoid indices - auto-remap to 1-k
|
|
519
|
+
print(f"[>] Detected medoid indices in '{group_column_name}'. "
|
|
520
|
+
f"Automatically remapping {sorted(missing_values_list)} to cluster labels 1-{expected_cluster_count}.")
|
|
521
|
+
|
|
522
|
+
# Create mapping from medoid indices to cluster labels 1-k
|
|
523
|
+
sorted_missing = sorted(missing_values_list)
|
|
524
|
+
medoid_to_cluster = {val: idx + 1 for idx, val in enumerate(sorted_missing)}
|
|
525
|
+
|
|
526
|
+
# Apply the remapping
|
|
527
|
+
group_dataframe = group_dataframe.copy() # Avoid modifying original
|
|
528
|
+
remapping_performed = True
|
|
529
|
+
group_dataframe[group_column_name] = group_dataframe[group_column_name].map(medoid_to_cluster)
|
|
530
|
+
|
|
531
|
+
# Now verify that all values match group_labels keys
|
|
532
|
+
unique_values_after_remap = group_dataframe[group_column_name].unique()
|
|
533
|
+
missing_keys_after = set(unique_values_after_remap) - set(group_labels.keys())
|
|
534
|
+
if missing_keys_after:
|
|
535
|
+
raise ValueError(
|
|
536
|
+
f"After auto-remapping, group_labels still missing mappings for values: {missing_keys_after}. "
|
|
537
|
+
f"Please provide labels for all unique values in '{group_column_name}': {sorted(unique_values_after_remap)}"
|
|
538
|
+
)
|
|
539
|
+
else:
|
|
540
|
+
# Not medoid indices - raise error as before
|
|
541
|
+
raise ValueError(
|
|
542
|
+
f"group_labels missing mappings for values: {missing_keys}. "
|
|
543
|
+
f"Please provide labels for all unique values in '{group_column_name}': {sorted(unique_values)}"
|
|
544
|
+
)
|
|
545
|
+
|
|
546
|
+
# Apply the labels mapping
|
|
547
|
+
# Only copy if we haven't already copied above (during medoid remapping)
|
|
548
|
+
if not remapping_performed:
|
|
549
|
+
group_dataframe = group_dataframe.copy() # Avoid modifying original
|
|
550
|
+
|
|
551
|
+
group_dataframe[group_column_name] = group_dataframe[group_column_name].map(group_labels)
|
|
552
|
+
|
|
553
|
+
# Get unique groups and sort them based on user preference
|
|
554
|
+
if group_order:
|
|
555
|
+
# Use manually specified order, filter out non-existing groups
|
|
556
|
+
groups = [g for g in group_order if g in group_dataframe[group_column_name].unique()]
|
|
557
|
+
missing_groups = [g for g in group_dataframe[group_column_name].unique() if g not in group_order]
|
|
558
|
+
if missing_groups:
|
|
559
|
+
print(f"[Warning] Groups not in group_order will be excluded: {missing_groups}")
|
|
560
|
+
elif group_labels is not None:
|
|
561
|
+
# If group_labels is provided, use its key order to determine groups order
|
|
562
|
+
# This ensures subplot order matches the order in group_labels dictionary
|
|
563
|
+
# Note: group_labels keys are original values, values are labels (which become groups)
|
|
564
|
+
mapped_labels = []
|
|
565
|
+
available_labels = set(group_dataframe[group_column_name].unique())
|
|
566
|
+
|
|
567
|
+
# Iterate through group_labels in order (Python 3.7+ dicts maintain insertion order)
|
|
568
|
+
for original_key, label_value in group_labels.items():
|
|
569
|
+
# Check if this label exists in the mapped dataframe
|
|
570
|
+
if label_value in available_labels:
|
|
571
|
+
mapped_labels.append(label_value)
|
|
572
|
+
|
|
573
|
+
# Also check for any labels in dataframe that weren't in group_labels
|
|
574
|
+
missing_in_labels = available_labels - set(mapped_labels)
|
|
575
|
+
if missing_in_labels:
|
|
576
|
+
print(f"[Warning] Some groups in data are not in group_labels and will be excluded: {missing_in_labels}")
|
|
577
|
+
|
|
578
|
+
groups = mapped_labels
|
|
579
|
+
elif sort_groups == 'numeric' or sort_groups == 'auto':
|
|
580
|
+
groups = smart_sort_groups(group_dataframe[group_column_name].unique())
|
|
581
|
+
elif sort_groups == 'alpha':
|
|
582
|
+
groups = sorted(group_dataframe[group_column_name].unique())
|
|
583
|
+
elif sort_groups == 'none':
|
|
584
|
+
groups = list(group_dataframe[group_column_name].unique())
|
|
585
|
+
else:
|
|
586
|
+
raise ValueError(f"Invalid sort_groups value: {sort_groups}. Use 'auto', 'numeric', 'alpha', or 'none'.")
|
|
587
|
+
|
|
588
|
+
num_groups = len(groups)
|
|
589
|
+
|
|
590
|
+
# Calculate figure size and layout based on number of groups and specified layout
|
|
591
|
+
nrows, ncols = determine_layout(num_groups, layout=layout, nrows=nrows, ncols=ncols)
|
|
592
|
+
|
|
593
|
+
# Calculate height ratios for proportional scaling if enabled
|
|
594
|
+
if proportional_scaling and layout == 'column':
|
|
595
|
+
# First pass: collect group sizes
|
|
596
|
+
group_sizes = []
|
|
597
|
+
for group in groups:
|
|
598
|
+
group_ids = group_dataframe[group_dataframe[group_column_name] == group][id_col_name].values
|
|
599
|
+
mask = np.isin(seqdata.ids, group_ids)
|
|
600
|
+
if np.any(mask):
|
|
601
|
+
mask = _select_sequences_subset(seqdata, sequence_selection, n_sequences, sort_by, sort_by_weight, weights, mask)
|
|
602
|
+
group_sizes.append(int(np.sum(mask)))
|
|
603
|
+
else:
|
|
604
|
+
group_sizes.append(1)
|
|
605
|
+
|
|
606
|
+
# Calculate height ratios (min 0.3 to avoid too small subplots)
|
|
607
|
+
if len(group_sizes) > 0:
|
|
608
|
+
max_size = max(group_sizes)
|
|
609
|
+
height_ratios = [max(0.3, size / max_size) for size in group_sizes]
|
|
610
|
+
max_ratio = max(height_ratios)
|
|
611
|
+
height_ratios = [h / max_ratio for h in height_ratios]
|
|
612
|
+
else:
|
|
613
|
+
height_ratios = [1.0] * num_groups
|
|
614
|
+
|
|
615
|
+
# Use gridspec for proportional heights
|
|
616
|
+
# Increase hspace to prevent x-axis labels from overlapping with subplot above
|
|
617
|
+
# Use larger hspace for column layout to accommodate group titles and x-axis labels
|
|
618
|
+
hspace_value = 0.4 if layout == 'column' else 0.25
|
|
619
|
+
fig = plt.figure(figsize=(actual_figsize[0], actual_figsize[1] * sum(height_ratios) / len(height_ratios) * num_groups))
|
|
620
|
+
gs = gridspec.GridSpec(nrows=num_groups, ncols=1, figure=fig,
|
|
621
|
+
height_ratios=height_ratios, hspace=hspace_value, wspace=0.15)
|
|
622
|
+
axes = [fig.add_subplot(gs[i]) for i in range(num_groups)]
|
|
623
|
+
else:
|
|
624
|
+
# Standard subplots with equal heights
|
|
625
|
+
# Increase hspace for column layout to prevent x-axis labels from overlapping
|
|
626
|
+
hspace_value = 0.4 if layout == 'column' else 0.25
|
|
627
|
+
fig, axes = plt.subplots(
|
|
628
|
+
nrows=nrows,
|
|
629
|
+
ncols=ncols,
|
|
630
|
+
figsize=(actual_figsize[0] * ncols, actual_figsize[1] * nrows),
|
|
631
|
+
gridspec_kw={'wspace': 0.15, 'hspace': hspace_value}
|
|
632
|
+
)
|
|
633
|
+
axes = axes.flatten()
|
|
634
|
+
|
|
635
|
+
# Dictionary to store sorted IDs for each group (if return_sorted_ids is True)
|
|
636
|
+
# Use OrderedDict or list to maintain group order
|
|
637
|
+
sorted_ids_by_group = {}
|
|
638
|
+
group_order_list = [] # Track the order of groups as they are processed
|
|
639
|
+
|
|
640
|
+
# Create a plot for each group
|
|
641
|
+
for i, group in enumerate(groups):
|
|
642
|
+
# Get IDs for this group
|
|
643
|
+
group_ids = group_dataframe[group_dataframe[group_column_name] == group][id_col_name].values
|
|
644
|
+
|
|
645
|
+
# Match IDs with sequence data
|
|
646
|
+
mask = np.isin(seqdata.ids, group_ids)
|
|
647
|
+
if not np.any(mask):
|
|
648
|
+
print(f"Warning: No matching sequences found for group '{group}'")
|
|
649
|
+
continue
|
|
650
|
+
|
|
651
|
+
# Apply sequence selection to this group
|
|
652
|
+
mask = _select_sequences_subset(seqdata, sequence_selection, n_sequences, sort_by, sort_by_weight, weights, mask)
|
|
653
|
+
|
|
654
|
+
# Extract sequences for this group
|
|
655
|
+
group_sequences = seqdata.values[mask]
|
|
656
|
+
|
|
657
|
+
# Track group IDs for y-axis labels
|
|
658
|
+
group_ids_for_labels = None
|
|
659
|
+
if hasattr(seqdata, 'ids') and seqdata.ids is not None and show_sequence_ids:
|
|
660
|
+
group_ids_for_labels = seqdata.ids[mask]
|
|
661
|
+
|
|
662
|
+
# Get weights for this group
|
|
663
|
+
if weights is not None:
|
|
664
|
+
group_weights = weights[mask]
|
|
665
|
+
else:
|
|
666
|
+
group_weights = None
|
|
667
|
+
|
|
668
|
+
# Handle NaN values for better visualization
|
|
669
|
+
if np.isnan(group_sequences).any():
|
|
670
|
+
# Map NaN to a dedicated state code with proper masking
|
|
671
|
+
group_sequences = group_sequences.astype(float)
|
|
672
|
+
group_sequences[np.isnan(group_sequences)] = np.nan
|
|
673
|
+
|
|
674
|
+
# Get the IDs for this group (after selection)
|
|
675
|
+
group_ids_after_selection = seqdata.ids[mask]
|
|
676
|
+
|
|
677
|
+
# Determine sorting method: sort_by_ids takes priority if provided
|
|
678
|
+
if sort_by_ids is not None:
|
|
679
|
+
# Sort by custom ID order
|
|
680
|
+
# Convert sort_by_ids to numpy array for easier handling
|
|
681
|
+
sort_by_ids_array = np.asarray(sort_by_ids)
|
|
682
|
+
|
|
683
|
+
# Create a mapping from ID to position in sort_by_ids
|
|
684
|
+
# IDs not in sort_by_ids will be placed at the end
|
|
685
|
+
id_to_position = {id_val: pos for pos, id_val in enumerate(sort_by_ids_array)}
|
|
686
|
+
|
|
687
|
+
# Get positions for each ID in the current group
|
|
688
|
+
# IDs not in sort_by_ids get a very large position value (placed at end)
|
|
689
|
+
max_position = len(sort_by_ids_array)
|
|
690
|
+
positions = np.array([id_to_position.get(id_val, max_position + i)
|
|
691
|
+
for i, id_val in enumerate(group_ids_after_selection)])
|
|
692
|
+
|
|
693
|
+
# Sort by position (ascending order)
|
|
694
|
+
sorted_indices = np.argsort(positions)
|
|
695
|
+
|
|
696
|
+
# Warn if some IDs in the group are not in sort_by_ids
|
|
697
|
+
missing_ids = set(group_ids_after_selection) - set(sort_by_ids_array)
|
|
698
|
+
if missing_ids:
|
|
699
|
+
print(f"[Warning] Group '{group}': {len(missing_ids)} IDs not found in sort_by_ids, "
|
|
700
|
+
f"they will be placed at the end: {list(missing_ids)[:5]}{'...' if len(missing_ids) > 5 else ''}")
|
|
701
|
+
|
|
702
|
+
elif sort_by_weight and group_weights is not None:
|
|
703
|
+
# Sort by weight (descending)
|
|
704
|
+
sorted_indices = np.argsort(-group_weights)
|
|
705
|
+
else:
|
|
706
|
+
# For group plots, we'll use simpler sorting to avoid complex object creation
|
|
707
|
+
if sort_by == "lexicographic":
|
|
708
|
+
vals = group_sequences.astype(float, copy=True)
|
|
709
|
+
vals = np.nan_to_num(vals, nan=np.inf)
|
|
710
|
+
sorted_indices = np.lexsort(vals.T[::-1])
|
|
711
|
+
elif sort_by in ["mds", "distance_to_most_frequent"]:
|
|
712
|
+
# Fallback to lexicographic for complex sorting methods
|
|
713
|
+
print(f"Warning: {sort_by} sorting simplified to lexicographic for grouped plots with sequence selection")
|
|
714
|
+
vals = group_sequences.astype(float, copy=True)
|
|
715
|
+
vals = np.nan_to_num(vals, nan=np.inf)
|
|
716
|
+
sorted_indices = np.lexsort(vals.T[::-1])
|
|
717
|
+
else:
|
|
718
|
+
# unsorted or other methods
|
|
719
|
+
sorted_indices = np.arange(len(group_sequences))
|
|
720
|
+
|
|
721
|
+
sorted_data = group_sequences[sorted_indices]
|
|
722
|
+
|
|
723
|
+
# Track sorted IDs for y-axis labels if needed
|
|
724
|
+
sorted_group_ids = None
|
|
725
|
+
if group_ids_for_labels is not None and show_sequence_ids:
|
|
726
|
+
sorted_group_ids = group_ids_for_labels[sorted_indices]
|
|
727
|
+
|
|
728
|
+
# Store sorted IDs for this group if return_sorted_ids is True
|
|
729
|
+
if return_sorted_ids:
|
|
730
|
+
sorted_ids_by_group[group] = group_ids_after_selection[sorted_indices]
|
|
731
|
+
# Track group order (only add once per group)
|
|
732
|
+
if group not in group_order_list:
|
|
733
|
+
group_order_list.append(group)
|
|
734
|
+
|
|
735
|
+
# Plot on the corresponding axis
|
|
736
|
+
ax = axes[i]
|
|
737
|
+
# Use masked array for better NaN handling
|
|
738
|
+
data = sorted_data.astype(float)
|
|
739
|
+
data[data < 1] = np.nan
|
|
740
|
+
|
|
741
|
+
# Check for all-missing or all-invalid data
|
|
742
|
+
if np.all(~np.isfinite(data)):
|
|
743
|
+
print(f"Warning: all values missing/invalid for group '{group}'")
|
|
744
|
+
ax.axis('off')
|
|
745
|
+
continue
|
|
746
|
+
|
|
747
|
+
im = ax.imshow(np.ma.masked_invalid(data), aspect='auto', cmap=seqdata.get_colormap(),
|
|
748
|
+
interpolation='nearest', vmin=1, vmax=len(seqdata.states))
|
|
749
|
+
|
|
750
|
+
# Remove grid lines
|
|
751
|
+
ax.grid(False)
|
|
752
|
+
|
|
753
|
+
# Set up time labels
|
|
754
|
+
set_up_time_labels_for_x_axis(seqdata, ax)
|
|
755
|
+
|
|
756
|
+
# Enhance y-axis aesthetics - evenly spaced ticks including the last sequence
|
|
757
|
+
num_sequences = sorted_data.shape[0]
|
|
758
|
+
|
|
759
|
+
# Determine tick positions and labels
|
|
760
|
+
if show_sequence_ids and sorted_group_ids is not None:
|
|
761
|
+
# Show sequence IDs instead of sequence numbers
|
|
762
|
+
# For large number of sequences, show fewer ticks to avoid overcrowding
|
|
763
|
+
if num_sequences <= 10:
|
|
764
|
+
ytick_positions = np.arange(num_sequences)
|
|
765
|
+
ytick_labels = [str(sid) for sid in sorted_group_ids]
|
|
766
|
+
else:
|
|
767
|
+
# Show subset of IDs for readability
|
|
768
|
+
if plot_style == "narrow":
|
|
769
|
+
num_ticks = min(8, num_sequences)
|
|
770
|
+
else:
|
|
771
|
+
num_ticks = min(11, num_sequences)
|
|
772
|
+
ytick_positions = np.linspace(0, num_sequences - 1, num=num_ticks, dtype=int)
|
|
773
|
+
ytick_positions = np.unique(ytick_positions)
|
|
774
|
+
ytick_labels = [str(sorted_group_ids[pos]) for pos in ytick_positions]
|
|
775
|
+
else:
|
|
776
|
+
# Default behavior: show sequence numbers
|
|
777
|
+
if plot_style == "narrow":
|
|
778
|
+
num_ticks = min(8, num_sequences) # Fewer ticks for narrow plots
|
|
779
|
+
else:
|
|
780
|
+
num_ticks = min(11, num_sequences)
|
|
781
|
+
ytick_positions = np.linspace(0, num_sequences - 1, num=num_ticks, dtype=int)
|
|
782
|
+
ytick_positions = np.unique(ytick_positions)
|
|
783
|
+
ytick_labels = (ytick_positions + 1).astype(int)
|
|
784
|
+
|
|
785
|
+
# Hide y-axis if requested
|
|
786
|
+
if hide_y_axis:
|
|
787
|
+
ax.set_yticks([])
|
|
788
|
+
ax.set_yticklabels([])
|
|
789
|
+
ax.spines['left'].set_visible(False)
|
|
790
|
+
else:
|
|
791
|
+
ax.set_yticks(ytick_positions)
|
|
792
|
+
ax.set_yticklabels(ytick_labels, fontsize=fontsize-2, color='black')
|
|
793
|
+
|
|
794
|
+
# Customize axis style
|
|
795
|
+
ax.spines['top'].set_visible(False)
|
|
796
|
+
ax.spines['right'].set_visible(False)
|
|
797
|
+
if not hide_y_axis:
|
|
798
|
+
ax.spines['left'].set_color('gray')
|
|
799
|
+
ax.spines['left'].set_linewidth(0.7)
|
|
800
|
+
ax.spines['left'].set_position(('outward', 5))
|
|
801
|
+
ax.spines['bottom'].set_color('gray')
|
|
802
|
+
ax.spines['bottom'].set_linewidth(0.7)
|
|
803
|
+
|
|
804
|
+
# Move spines slightly away from the plot area for better aesthetics
|
|
805
|
+
ax.spines['bottom'].set_position(('outward', 5))
|
|
806
|
+
|
|
807
|
+
# Ensure ticks are always visible regardless of plot style
|
|
808
|
+
ax.tick_params(axis='x', colors='gray', length=4, width=0.7, which='major')
|
|
809
|
+
if not hide_y_axis:
|
|
810
|
+
ax.tick_params(axis='y', colors='gray', length=4, width=0.7, which='major')
|
|
811
|
+
|
|
812
|
+
# Force tick visibility for narrow plot styles
|
|
813
|
+
ax.xaxis.set_ticks_position('bottom')
|
|
814
|
+
if not hide_y_axis:
|
|
815
|
+
ax.yaxis.set_ticks_position('left')
|
|
816
|
+
ax.tick_params(axis='both', which='major', direction='out')
|
|
817
|
+
|
|
818
|
+
# Add group title with weight information
|
|
819
|
+
# Check if we have effective weights (not all 1.0) and they were provided by user
|
|
820
|
+
original_weights = getattr(seqdata, "weights", None)
|
|
821
|
+
if original_weights is not None and not np.allclose(original_weights, 1.0) and group_weights is not None:
|
|
822
|
+
sum_w = float(group_weights.sum())
|
|
823
|
+
group_title = f"{group} (n = {num_sequences}, total weight = {sum_w:.1f})"
|
|
824
|
+
else:
|
|
825
|
+
group_title = f"{group} (n = {num_sequences})"
|
|
826
|
+
if show_group_titles:
|
|
827
|
+
show_plot_title(ax, group_title, show=True, fontsize=fontsize, loc='right')
|
|
828
|
+
|
|
829
|
+
# Add axis labels
|
|
830
|
+
if i % ncols == 0 and not hide_y_axis:
|
|
831
|
+
ax.set_ylabel(ylabel, fontsize=fontsize, labelpad=10, color='black')
|
|
832
|
+
|
|
833
|
+
# For column layout, only show x-axis label on the bottom subplot
|
|
834
|
+
# For grid layout, show x-axis label on bottom row subplots
|
|
835
|
+
if layout == 'column':
|
|
836
|
+
# Only show xlabel on the last (bottom) subplot
|
|
837
|
+
if i == num_groups - 1:
|
|
838
|
+
ax.set_xlabel(xlabel, fontsize=fontsize, labelpad=10, color='black')
|
|
839
|
+
else:
|
|
840
|
+
# For grid layout, show xlabel on bottom row
|
|
841
|
+
if i >= num_groups - ncols:
|
|
842
|
+
ax.set_xlabel(xlabel, fontsize=fontsize, labelpad=10, color='black')
|
|
843
|
+
|
|
844
|
+
# Hide unused subplots (not needed for proportional scaling with column layout)
|
|
845
|
+
if not (proportional_scaling and layout == 'column'):
|
|
846
|
+
for j in range(i + 1, len(axes)):
|
|
847
|
+
axes[j].set_visible(False)
|
|
848
|
+
|
|
849
|
+
# Add a common title if provided and show_title is True
|
|
850
|
+
if title and show_title:
|
|
851
|
+
fig.suptitle(title, fontsize=fontsize+2, y=1.02)
|
|
852
|
+
|
|
853
|
+
# Adjust layout to remove tight_layout warning and eliminate extra right space
|
|
854
|
+
# Increase hspace for column layout to prevent x-axis labels from overlapping with subplot above
|
|
855
|
+
if proportional_scaling and layout == 'column':
|
|
856
|
+
fig.subplots_adjust(left=0.08, right=0.98, bottom=0.1, top=0.9, wspace=0.15, hspace=0.4)
|
|
857
|
+
else:
|
|
858
|
+
hspace_value = 0.4 if layout == 'column' else 0.25
|
|
859
|
+
fig.subplots_adjust(wspace=0.15, hspace=hspace_value, bottom=0.1, top=0.9, right=0.98, left=0.08)
|
|
860
|
+
|
|
861
|
+
# Save main figure to memory
|
|
862
|
+
main_buffer = save_figure_to_buffer(fig, dpi=dpi)
|
|
863
|
+
|
|
864
|
+
if include_legend:
|
|
865
|
+
# Create standalone legend
|
|
866
|
+
colors = seqdata.color_map_by_label
|
|
867
|
+
legend_buffer = create_standalone_legend(
|
|
868
|
+
colors=colors,
|
|
869
|
+
labels=seqdata.labels,
|
|
870
|
+
ncol=min(5, len(seqdata.states)),
|
|
871
|
+
figsize=(actual_figsize[0] * ncols, 1),
|
|
872
|
+
fontsize=fontsize-2,
|
|
873
|
+
dpi=dpi
|
|
874
|
+
)
|
|
875
|
+
|
|
876
|
+
# Combine plot with legend
|
|
877
|
+
if save_as and not save_as.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf')):
|
|
878
|
+
save_as = save_as + '.png'
|
|
879
|
+
|
|
880
|
+
combined_img = combine_plot_with_legend(
|
|
881
|
+
main_buffer,
|
|
882
|
+
legend_buffer,
|
|
883
|
+
output_path=save_as,
|
|
884
|
+
dpi=dpi,
|
|
885
|
+
padding=20
|
|
886
|
+
)
|
|
887
|
+
|
|
888
|
+
# Display combined image
|
|
889
|
+
plt.figure(figsize=(actual_figsize[0] * ncols, actual_figsize[1] * nrows + 1))
|
|
890
|
+
plt.imshow(combined_img)
|
|
891
|
+
plt.axis('off')
|
|
892
|
+
plt.show()
|
|
893
|
+
plt.close()
|
|
894
|
+
else:
|
|
895
|
+
# Display plot without legend
|
|
896
|
+
if save_as and not save_as.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf')):
|
|
897
|
+
save_as = save_as + '.png'
|
|
898
|
+
|
|
899
|
+
# Save or show the main plot directly
|
|
900
|
+
plt.figure(figsize=(actual_figsize[0] * ncols, actual_figsize[1] * nrows))
|
|
901
|
+
plt.imshow(main_buffer)
|
|
902
|
+
plt.axis('off')
|
|
903
|
+
|
|
904
|
+
if save_as:
|
|
905
|
+
plt.savefig(save_as, dpi=dpi, bbox_inches='tight')
|
|
906
|
+
plt.show()
|
|
907
|
+
plt.close()
|
|
908
|
+
|
|
909
|
+
# Return sorted IDs if requested
|
|
910
|
+
# Return as dictionary with group order preserved (groups processed in plot order)
|
|
911
|
+
if return_sorted_ids:
|
|
912
|
+
# Create an ordered dictionary or return the dictionary with group_order_list for reference
|
|
913
|
+
# For simplicity, return the dictionary (Python 3.7+ maintains insertion order)
|
|
914
|
+
# But we'll also return group_order_list as metadata if needed
|
|
915
|
+
return sorted_ids_by_group
|
|
916
|
+
|
|
917
|
+
|
|
918
|
+
def _sequence_index_plot_single(seqdata: SequenceData,
                                sort_by="unsorted",
                                sort_by_weight=False,
                                weights="auto",
                                figsize=(10, 6),
                                plot_style="standard",
                                title=None,
                                xlabel="Time",
                                ylabel="Sequences",
                                save_as=None,
                                dpi=200,
                                fontsize=12,
                                include_legend=True,
                                sequence_selection="all",
                                n_sequences=10,
                                show_sequence_ids=False,
                                sort_by_ids=None,
                                return_sorted_ids=False,
                                show_title=True):
    """Create a sequence index plot for a single (ungrouped) set of sequences.

    Uses ``imshow`` for fast rendering: each image row is one sequence and each
    column one time point, coloured with ``seqdata.get_colormap()``.

    :param seqdata: SequenceData object containing sequence information.
    :param sort_by: Sorting method ('unsorted', 'lexicographic', 'mds',
        'distance_to_most_frequent'). 'mds' and 'distance_to_most_frequent'
        are simplified to lexicographic ordering here (a warning is printed).
    :param sort_by_weight: If True and weights are available, sort sequences by
        weight (descending); overrides ``sort_by``.
    :param weights: (np.ndarray or "auto") Weights for sequences. If "auto",
        uses ``seqdata.weights`` if available.
    :param figsize: (tuple) Size of the figure (only used when plot_style="custom").
    :param plot_style: Plot aspect style ('standard', 'compact', 'wide', 'narrow', 'custom').
    :param title: (str) Title for the plot.
    :param xlabel: (str) Label for the x-axis.
    :param ylabel: (str) Label for the y-axis.
    :param save_as: File path to save the plot.
    :param dpi: DPI for saved image.
    :param fontsize: Base font size; tick labels use ``fontsize - 2`` and the
        title ``fontsize + 2``.
    :param include_legend: Whether to include the state legend (True by default).
    :param sequence_selection: Method for selecting sequences ("all", "first_n",
        "last_n", or list of IDs).
    :param n_sequences: Number of sequences for "first_n" or "last_n".
    :param show_sequence_ids: If True, show actual sequence IDs on the y-axis
        instead of sequence numbers (falls back to numbers if the data has no IDs).
    :param sort_by_ids: Optional explicit ID order. Takes priority over
        ``sort_by_weight`` and ``sort_by``. Selected IDs not listed are placed at
        the end in their original relative order. Requires ``seqdata.ids``.
    :param return_sorted_ids: If True, return the IDs of the plotted sequences
        in plotted (image row) order.
    :param show_title: If False, suppress the title even when ``title`` is given.

    :return: The sorted IDs (np.ndarray) when ``return_sorted_ids`` is True and
        ``seqdata.ids`` is available; otherwise None.
    :raises ValueError: On an unknown ``plot_style``, 'custom' style without an
        explicit ``figsize``, a weights length mismatch, or ``sort_by_ids`` given
        for data without sequence IDs.
    """
    # Figure size for each preset style; 'custom' uses the caller's figsize.
    style_sizes = {
        'standard': (10, 6),  # Balanced view
        'compact': (8, 8),    # More square, like R plots
        'wide': (12, 4),      # Wide, emphasizes time
        'narrow': (8, 10),    # Moderately vertical
        'custom': figsize,    # User-provided
    }

    if plot_style not in style_sizes:
        raise ValueError(f"Invalid plot_style '{plot_style}'. "
                         f"Supported styles: {list(style_sizes.keys())}")

    # 'custom' only makes sense when the caller actually supplied a figsize.
    if plot_style == 'custom' and figsize == (10, 6):
        raise ValueError(
            "When using plot_style='custom', you must explicitly provide a figsize parameter "
            "that differs from the default (10, 6). "
            "Suggested custom sizes:\n"
            " - For wide plots: figsize=(15, 5)\n"
            " - For tall plots: figsize=(7, 12)\n"
            " - For square plots: figsize=(9, 9)\n"
            " - For small plots: figsize=(6, 4)\n"
            "Example: plot_sequence_index(data, plot_style='custom', figsize=(12, 8))"
        )

    actual_figsize = style_sizes[plot_style]

    # Resolve weights: "auto" means use the weights stored on the data, if any.
    if isinstance(weights, str) and weights == "auto":
        weights = getattr(seqdata, "weights", None)

    if weights is not None:
        weights = np.asarray(weights, dtype=float).reshape(-1)
        if len(weights) != len(seqdata.values):
            raise ValueError("Length of weights must equal number of sequences.")

    # Coerce IDs to an array so boolean-mask and fancy indexing below also work
    # when the IDs are stored as a plain list.
    all_ids = getattr(seqdata, 'ids', None)
    if all_ids is not None:
        all_ids = np.asarray(all_ids)

    if sort_by_ids is not None and all_ids is None:
        raise ValueError("sort_by_ids requires the SequenceData object to have sequence IDs.")

    # Boolean mask over all sequences selecting the rows to plot.
    selection_mask = _select_sequences_subset(seqdata, sequence_selection, n_sequences,
                                              sort_by, sort_by_weight, weights)

    if np.all(selection_mask):
        sequence_values = seqdata.values.copy()
        selected_ids = all_ids
    else:
        sequence_values = seqdata.values[selection_mask].copy()
        selected_ids = all_ids[selection_mask] if all_ids is not None else None
        if weights is not None:
            weights = weights[selection_mask]

    # Determine row order. Priority: sort_by_ids > sort_by_weight > sort_by.
    if sort_by_ids is not None:
        sort_by_ids_array = np.asarray(sort_by_ids)
        id_to_position = {id_val: pos for pos, id_val in enumerate(sort_by_ids_array)}

        # IDs missing from sort_by_ids get positions after every listed ID,
        # offset by their original index so their relative order is kept.
        max_position = len(sort_by_ids_array)
        positions = np.array([id_to_position.get(id_val, max_position + i)
                              for i, id_val in enumerate(selected_ids)])
        sorted_indices = np.argsort(positions, kind="stable")

        missing_ids = set(selected_ids) - set(sort_by_ids_array)
        if missing_ids:
            print(f"[Warning] {len(missing_ids)} IDs not found in sort_by_ids, "
                  f"they will be placed at the end: {list(missing_ids)[:5]}{'...' if len(missing_ids) > 5 else ''}")

    elif sort_by_weight and weights is not None:
        # Heaviest sequences first.
        sorted_indices = np.argsort(-weights)

    elif sort_by in ("lexicographic", "mds", "distance_to_most_frequent"):
        if sort_by != "lexicographic":
            # The complex methods are not implemented for this path.
            print(f"Warning: {sort_by} sorting simplified to lexicographic for sequence selection")
        # np.lexsort treats the LAST key as primary, so reverse the time axis to
        # make the first time point the primary key; NaN (missing) sorts last.
        vals = np.nan_to_num(sequence_values.astype(float), nan=np.inf)
        sorted_indices = np.lexsort(vals.T[::-1])

    else:
        # 'unsorted' or any other value: keep the selection order.
        sorted_indices = np.arange(len(sequence_values))

    sorted_data = sequence_values[sorted_indices]
    sorted_ids = selected_ids[sorted_indices] if selected_ids is not None else None

    fig, ax = plt.subplots(figsize=actual_figsize)

    # Values below 1 fall outside the colour range (vmin=1) and are masked as
    # missing so they render blank.
    data = sorted_data.astype(float)
    data[data < 1] = np.nan

    if not np.isfinite(data).any():
        print("Warning: all values missing/invalid in sequence data")
        ax.axis('off')
        # Nothing can be drawn: release the figure instead of leaving it open.
        plt.close(fig)
        return sorted_ids if return_sorted_ids else None

    ax.imshow(np.ma.masked_invalid(data), aspect='auto', cmap=seqdata.get_colormap(),
              interpolation='nearest', vmin=1, vmax=len(seqdata.states))

    # Disable background grid and all axis guide lines
    ax.grid(False)

    # x-axis: time labels
    set_up_time_labels_for_x_axis(seqdata, ax)

    # y-axis: evenly spaced ticks including the last sequence; narrow plots get fewer.
    num_sequences = sorted_data.shape[0]
    use_id_labels = show_sequence_ids and sorted_ids is not None

    if use_id_labels and num_sequences <= 10:
        # Few sequences: label every row with its ID.
        ytick_positions = np.arange(num_sequences)
    else:
        num_ticks = min(8 if plot_style == "narrow" else 11, num_sequences)
        ytick_positions = np.unique(
            np.linspace(0, num_sequences - 1, num=num_ticks, dtype=int)
        )

    if use_id_labels:
        ytick_labels = [str(sorted_ids[pos]) for pos in ytick_positions]
    else:
        # 1-based sequence numbers
        ytick_labels = (ytick_positions + 1).astype(int)

    ax.set_yticks(ytick_positions)
    ax.set_yticklabels(ytick_labels, fontsize=fontsize-2, color='black')

    # Customize axis line styles and ticks
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_color('gray')
    ax.spines['bottom'].set_color('gray')
    ax.spines['left'].set_linewidth(0.7)
    ax.spines['bottom'].set_linewidth(0.7)

    # Move spines slightly away from the plot area for better aesthetics
    ax.spines['left'].set_position(('outward', 5))
    ax.spines['bottom'].set_position(('outward', 5))

    # Ensure ticks are always visible regardless of plot style
    ax.tick_params(axis='x', colors='gray', length=4, width=0.7, which='major')
    ax.tick_params(axis='y', colors='gray', length=4, width=0.7, which='major')

    # Force tick visibility for narrow plot styles
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    ax.tick_params(axis='both', which='major', direction='out')

    ax.set_xlabel(xlabel, fontsize=fontsize, labelpad=10, color='black')
    ax.set_ylabel(ylabel, fontsize=fontsize, labelpad=10, color='black')

    if title is not None and show_title:
        # Report total weight only when the data carries non-trivial weights.
        original_weights = getattr(seqdata, "weights", None)
        if original_weights is not None and not np.allclose(original_weights, 1.0) and weights is not None:
            sum_w = float(weights.sum())
            title_suffix = f" (n = {num_sequences}, total weight = {sum_w:.1f})"
        else:
            title_suffix = f" (n = {num_sequences})"
        ax.set_title(title + title_suffix, fontsize=fontsize+2, color='black')

    # Use legend from SequenceData if requested
    if include_legend:
        ax.legend(*seqdata.get_legend(), bbox_to_anchor=(1.05, 1), loc='upper left')

    save_and_show_results(save_as, dpi=dpi)

    if return_sorted_ids:
        return sorted_ids
|