sequenzo-0.1.17-cp39-cp39-macosx_10_9_universal2.whl → sequenzo-0.1.18-cp39-cp39-macosx_10_9_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of sequenzo might be problematic.
Files changed (86)
  1. sequenzo/__init__.py +25 -1
  2. sequenzo/big_data/clara/clara.py +1 -1
  3. sequenzo/big_data/clara/utils/get_weighted_diss.c +157 -157
  4. sequenzo/big_data/clara/utils/get_weighted_diss.cpython-39-darwin.so +0 -0
  5. sequenzo/clustering/hierarchical_clustering.py +202 -8
  6. sequenzo/define_sequence_data.py +34 -2
  7. sequenzo/dissimilarity_measures/c_code.cpython-39-darwin.so +0 -0
  8. sequenzo/dissimilarity_measures/get_substitution_cost_matrix.py +1 -1
  9. sequenzo/dissimilarity_measures/src/DHDdistance.cpp +13 -37
  10. sequenzo/dissimilarity_measures/src/LCPdistance.cpp +13 -37
  11. sequenzo/dissimilarity_measures/src/OMdistance.cpp +12 -47
  12. sequenzo/dissimilarity_measures/src/OMspellDistance.cpp +103 -67
  13. sequenzo/dissimilarity_measures/src/dp_utils.h +160 -0
  14. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_arithmetic.hpp +41 -16
  15. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_complex.hpp +4 -0
  16. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_details.hpp +7 -0
  17. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_logical.hpp +10 -0
  18. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_math.hpp +127 -43
  19. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_memory.hpp +30 -2
  20. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_swizzle.hpp +174 -0
  21. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_trigo.hpp +14 -5
  22. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx.hpp +111 -54
  23. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx2.hpp +131 -9
  24. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512bw.hpp +11 -113
  25. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512dq.hpp +39 -7
  26. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512f.hpp +336 -30
  27. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vbmi.hpp +9 -37
  28. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vbmi2.hpp +58 -0
  29. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_common.hpp +1 -0
  30. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_common_fwd.hpp +35 -2
  31. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_constants.hpp +3 -1
  32. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_emulated.hpp +17 -0
  33. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma3_avx.hpp +13 -0
  34. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma3_sse.hpp +18 -0
  35. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma4.hpp +13 -0
  36. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_isa.hpp +8 -0
  37. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_neon.hpp +363 -34
  38. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_neon64.hpp +7 -0
  39. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_rvv.hpp +13 -0
  40. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_scalar.hpp +41 -4
  41. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse2.hpp +252 -16
  42. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse3.hpp +9 -0
  43. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_ssse3.hpp +12 -1
  44. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sve.hpp +7 -0
  45. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_vsx.hpp +892 -0
  46. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_wasm.hpp +78 -1
  47. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_arch.hpp +3 -1
  48. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_config.hpp +13 -2
  49. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_cpuid.hpp +5 -0
  50. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_inline.hpp +5 -1
  51. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_all_registers.hpp +2 -0
  52. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_api.hpp +64 -1
  53. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_batch.hpp +36 -0
  54. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_rvv_register.hpp +40 -31
  55. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_traits.hpp +8 -0
  56. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_vsx_register.hpp +77 -0
  57. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/xsimd.hpp +6 -0
  58. sequenzo/dissimilarity_measures/utils/get_sm_trate_substitution_cost_matrix.c +157 -157
  59. sequenzo/dissimilarity_measures/utils/get_sm_trate_substitution_cost_matrix.cpython-39-darwin.so +0 -0
  60. sequenzo/dissimilarity_measures/utils/seqconc.c +157 -157
  61. sequenzo/dissimilarity_measures/utils/seqconc.cpython-39-darwin.so +0 -0
  62. sequenzo/dissimilarity_measures/utils/seqdss.c +157 -157
  63. sequenzo/dissimilarity_measures/utils/seqdss.cpython-39-darwin.so +0 -0
  64. sequenzo/dissimilarity_measures/utils/seqdur.c +157 -157
  65. sequenzo/dissimilarity_measures/utils/seqdur.cpython-39-darwin.so +0 -0
  66. sequenzo/dissimilarity_measures/utils/seqlength.c +157 -157
  67. sequenzo/dissimilarity_measures/utils/seqlength.cpython-39-darwin.so +0 -0
  68. sequenzo/sequence_characteristics/__init__.py +4 -0
  69. sequenzo/sequence_characteristics/complexity_index.py +17 -57
  70. sequenzo/sequence_characteristics/overall_cross_sectional_entropy.py +177 -111
  71. sequenzo/sequence_characteristics/plot_characteristics.py +30 -11
  72. sequenzo/sequence_characteristics/simple_characteristics.py +1 -0
  73. sequenzo/sequence_characteristics/state_frequencies_and_entropy_per_sequence.py +9 -3
  74. sequenzo/sequence_characteristics/turbulence.py +47 -67
  75. sequenzo/sequence_characteristics/variance_of_spell_durations.py +19 -9
  76. sequenzo/sequence_characteristics/within_sequence_entropy.py +5 -58
  77. sequenzo/visualization/plot_sequence_index.py +58 -35
  78. sequenzo/visualization/plot_state_distribution.py +57 -36
  79. sequenzo/with_event_history_analysis/__init__.py +35 -0
  80. sequenzo/with_event_history_analysis/sequence_analysis_multi_state_model.py +850 -0
  81. sequenzo/with_event_history_analysis/sequence_history_analysis.py +283 -0
  82. {sequenzo-0.1.17.dist-info → sequenzo-0.1.18.dist-info}/METADATA +7 -6
  83. {sequenzo-0.1.17.dist-info → sequenzo-0.1.18.dist-info}/RECORD +86 -79
  84. {sequenzo-0.1.17.dist-info → sequenzo-0.1.18.dist-info}/WHEEL +0 -0
  85. {sequenzo-0.1.17.dist-info → sequenzo-0.1.18.dist-info}/licenses/LICENSE +0 -0
  86. {sequenzo-0.1.17.dist-info → sequenzo-0.1.18.dist-info}/top_level.txt +0 -0
@@ -75,6 +75,12 @@ from scipy.spatial.distance import squareform
  # sklearn metrics no longer needed - using C++ implementation
  from fastcluster import linkage

+ import rpy2.robjects as ro
+ from rpy2.robjects import numpy2ri
+ from rpy2.robjects.packages import importr
+ from rpy2.robjects.conversion import localconverter
+ from rpy2.robjects import FloatVector
+
  # Import C++ cluster quality functions
  try:
      from . import clustering_c_code
@@ -84,7 +90,7 @@ except ImportError:
      print("[!] Warning: C++ cluster quality functions not available. Using Python fallback.")

  # Corrected imports: Use relative imports *within* the package.
- from ..visualization.utils import save_and_show_results
+ from sequenzo.visualization.utils import save_and_show_results

  # Global flag to ensure Ward warning is only shown once per session
  _WARD_WARNING_SHOWN = False
@@ -259,6 +265,79 @@ def _clean_distance_matrix(matrix):
      return matrix


+ def _hclust_to_linkage_matrix(linkage_matrix):
+     """
+     Convert an R `hclust` object to a SciPy-compatible linkage matrix.
+
+     This function takes an `hclust` object returned by R (e.g., from
+     `fastcluster::hclust`) and converts it into the standard linkage matrix
+     format used by SciPy (`scipy.cluster.hierarchy.linkage`), which can be
+     used for dendrogram plotting or further clustering analysis in Python.
+
+     Parameters
+     ----------
+     linkage_matrix : rpy2.robjects.ListVector
+         An R `hclust` object. Expected to contain at least the following fields:
+         - 'merge': ndarray of shape (n-1, 2), indicating which clusters are merged
+           at each step (negative indices for original observations,
+           positive indices for previously merged clusters).
+         - 'height': ndarray of shape (n-1,), distances at which merges occur.
+         - 'order': ordering of the leaves.
+
+     Returns
+     -------
+     Z : numpy.ndarray, shape (n-1, 4), dtype=float
+         A SciPy-compatible linkage matrix where each row represents a merge:
+         - Z[i, 0] : index of the first cluster (0-based)
+         - Z[i, 1] : index of the second cluster (0-based)
+         - Z[i, 2] : distance between the merged clusters
+         - Z[i, 3] : total number of original samples in the newly formed cluster
+
+     Notes
+     -----
+     - The conversion handles the difference in indexing:
+       - In R's `hclust`, negative numbers in 'merge' indicate original samples
+         and positive numbers indicate previously merged clusters (1-based).
+       - In the returned SciPy linkage matrix, all indices are converted to 0-based.
+     - The function iteratively tracks cluster sizes to populate the fourth column
+       (sample counts) required by SciPy.
+     """
+
+     n = len(linkage_matrix.rx2("order"))  # number of observations
+     merge = np.array(linkage_matrix.rx2("merge"), dtype=int)  # (n-1, 2)
+     height = np.array(linkage_matrix.rx2("height"), dtype=float)
+
+     cluster_sizes = np.ones(n, dtype=int)  # each original observation starts with size 1
+     Z = np.zeros((n - 1, 4), dtype=float)
+
+     for i in range(n - 1):
+         a, b = merge[i]
+
+         # In R's hclust, negative entries denote original observations
+         if a < 0:
+             idx1 = -a - 1  # convert to 0-based
+             size1 = 1
+         else:
+             idx1 = n + a - 1  # previously merged cluster, 0-based
+             size1 = cluster_sizes[idx1]
+
+         if b < 0:
+             idx2 = -b - 1
+             size2 = 1
+         else:
+             idx2 = n + b - 1
+             size2 = cluster_sizes[idx2]
+
+         Z[i, 0] = idx1
+         Z[i, 1] = idx2
+         Z[i, 2] = height[i]
+         Z[i, 3] = size1 + size2
+
+         # Track the size of the newly formed cluster for later merges
+         cluster_sizes = np.append(cluster_sizes, size1 + size2)
+
+     return Z
+
  class Cluster:
      def __init__(self,
                   matrix,
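To make the R-versus-SciPy index bookkeeping documented above concrete, here is a small worked sketch (an editor's illustration, not part of the package) that applies the same conversion to an invented three-observation `hclust` result:

    import numpy as np

    # Invented hclust output for 3 observations:
    #   step 1: observations 1 and 2 merge at height 0.5  -> R encodes -1, -2
    #   step 2: observation 3 joins the step-1 cluster    -> R encodes -3, +1
    merge = np.array([[-1, -2],
                      [-3,  1]])
    height = np.array([0.5, 1.2])

    n = 3
    cluster_sizes = np.ones(n, dtype=int)
    Z = np.zeros((n - 1, 4))
    for i, (a, b) in enumerate(merge):
        idx1 = -a - 1 if a < 0 else n + a - 1   # 0-based leaf or previously formed cluster
        idx2 = -b - 1 if b < 0 else n + b - 1
        size1 = 1 if a < 0 else cluster_sizes[idx1]
        size2 = 1 if b < 0 else cluster_sizes[idx2]
        Z[i] = [idx1, idx2, height[i], size1 + size2]
        cluster_sizes = np.append(cluster_sizes, size1 + size2)

    print(Z)
    # [[0.  1.  0.5 2. ]   <- leaves 0 and 1 form SciPy cluster n + 0 = 3
    #  [2.  3.  1.2 3. ]]  <- leaf 2 joins cluster 3

Row 0 creates SciPy cluster index n + 0 = 3, which is why the second row references index 3 rather than R's 1-based "+1".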
@@ -358,12 +437,27 @@ class Cluster:
          try:
              # Map our method names to fastcluster's expected method names
              fastcluster_method = self._map_method_name(self.clustering_method)
-             linkage_matrix = linkage(self.condensed_matrix, method=fastcluster_method)
-
+
+             if self.clustering_method == "ward_d" or self.clustering_method == "ward":
+                 fastcluster_r = importr("fastcluster")
+
+                 # Convert full_matrix to an R matrix (built directly from the Python array) to avoid rpy2 errors on very long vectors
+                 # Flatten with 'F' (column-major) to match R's memory layout
+                 full_matrix_r = ro.r.matrix(ro.FloatVector(self.full_matrix.flatten('F')),
+                                             nrow=self.full_matrix.shape[0], ncol=self.full_matrix.shape[1])
+                 r_om = ro.r['as.dist'](full_matrix_r)
+
+                 linkage_matrix = fastcluster_r.hclust(r_om, method="ward.D")
+
+                 linkage_matrix = _hclust_to_linkage_matrix(linkage_matrix)
+
+             else:
+                 linkage_matrix = linkage(self.condensed_matrix, method=fastcluster_method)
+
              # Apply Ward D correction if needed (divide distances by 2 for classic Ward)
-             if self.clustering_method == "ward_d":
-                 linkage_matrix = self._apply_ward_d_correction(linkage_matrix)
-
+             # if self.clustering_method == "ward_d":
+             #     linkage_matrix = self._apply_ward_d_correction(linkage_matrix)
+
          except Exception as e:
              raise RuntimeError(
                  f"Failed to compute linkage with method '{self.clustering_method}'. "
@@ -1080,5 +1174,105 @@ class ClusterResults:
          save_and_show_results(save_as, dpi)


-
-
+ # For Xinyi's test, because she can't debug in Jupyter:
+ # Traceback (most recent call last):
+ #   File "/Applications/PyCharm.app/Contents/plugins/python-ce/helpers/pydev/_pydevd_bundle/pydevd_comm.py", line 736, in make_thread_stack_str
+ #     append('file="%s" line="%s">' % (make_valid_xml_value(my_file), lineno))
+ #   File "/Applications/PyCharm.app/Contents/plugins/python-ce/helpers/pydev/_pydevd_bundle/pydevd_xml.py", line 36, in make_valid_xml_value
+ #     return s.replace("&", "&amp;").replace('<', '&lt;').replace('>', '&gt;').replace('"', '&quot;')
+ # AttributeError: 'tuple' object has no attribute 'replace'
+
+ if __name__ == '__main__':
+     # Import necessary libraries
+     # Your calling code (e.g., in a script or notebook)
+
+     from sequenzo import *  # Import the package
+     import pandas as pd  # Data manipulation
+     import numpy as np
+
+     # List all the available datasets in Sequenzo
+     # Now access the package's functions directly:
+     print('Available datasets in Sequenzo: ', list_datasets())
+
+     # Load the data that we would like to explore in this tutorial
+     # `df` is short for `dataframe`, which is a common variable name for a dataset
+     # df = load_dataset('country_co2_emissions')
+     df = load_dataset('mvad')
+
+     # List of time columns
+     time_list = ['Jul.93', 'Aug.93', 'Sep.93', 'Oct.93', 'Nov.93', 'Dec.93',
+                  'Jan.94', 'Feb.94', 'Mar.94', 'Apr.94', 'May.94', 'Jun.94', 'Jul.94',
+                  'Aug.94', 'Sep.94', 'Oct.94', 'Nov.94', 'Dec.94', 'Jan.95', 'Feb.95',
+                  'Mar.95', 'Apr.95', 'May.95', 'Jun.95', 'Jul.95', 'Aug.95', 'Sep.95',
+                  'Oct.95', 'Nov.95', 'Dec.95', 'Jan.96', 'Feb.96', 'Mar.96', 'Apr.96',
+                  'May.96', 'Jun.96', 'Jul.96', 'Aug.96', 'Sep.96', 'Oct.96', 'Nov.96',
+                  'Dec.96', 'Jan.97', 'Feb.97', 'Mar.97', 'Apr.97', 'May.97', 'Jun.97',
+                  'Jul.97', 'Aug.97', 'Sep.97', 'Oct.97', 'Nov.97', 'Dec.97', 'Jan.98',
+                  'Feb.98', 'Mar.98', 'Apr.98', 'May.98', 'Jun.98', 'Jul.98', 'Aug.98',
+                  'Sep.98', 'Oct.98', 'Nov.98', 'Dec.98', 'Jan.99', 'Feb.99', 'Mar.99',
+                  'Apr.99', 'May.99', 'Jun.99']
+
+     # Method 1: use pandas to collect all unique values
+     time_states_df = df[time_list]
+     all_unique_states = set()
+
+     for col in time_list:
+         unique_vals = df[col].dropna().unique()  # Remove NaN values
+         all_unique_states.update(unique_vals)
+
+     # Convert to a sorted list
+     states = sorted(list(all_unique_states))
+     print("All unique states:")
+     for i, state in enumerate(states, 1):
+         print(f"{i:2d}. {state}")
+
+     print(f"\nstates list:")
+     print(f"states = {states}")
+
+     # Create a SequenceData object
+
+     # Define the time-span variable
+     time_list = ['Jul.93', 'Aug.93', 'Sep.93', 'Oct.93', 'Nov.93', 'Dec.93',
+                  'Jan.94', 'Feb.94', 'Mar.94', 'Apr.94', 'May.94', 'Jun.94', 'Jul.94',
+                  'Aug.94', 'Sep.94', 'Oct.94', 'Nov.94', 'Dec.94', 'Jan.95', 'Feb.95',
+                  'Mar.95', 'Apr.95', 'May.95', 'Jun.95', 'Jul.95', 'Aug.95', 'Sep.95',
+                  'Oct.95', 'Nov.95', 'Dec.95', 'Jan.96', 'Feb.96', 'Mar.96', 'Apr.96',
+                  'May.96', 'Jun.96', 'Jul.96', 'Aug.96', 'Sep.96', 'Oct.96', 'Nov.96',
+                  'Dec.96', 'Jan.97', 'Feb.97', 'Mar.97', 'Apr.97', 'May.97', 'Jun.97',
+                  'Jul.97', 'Aug.97', 'Sep.97', 'Oct.97', 'Nov.97', 'Dec.97', 'Jan.98',
+                  'Feb.98', 'Mar.98', 'Apr.98', 'May.98', 'Jun.98', 'Jul.98', 'Aug.98',
+                  'Sep.98', 'Oct.98', 'Nov.98', 'Dec.98', 'Jan.99', 'Feb.99', 'Mar.99',
+                  'Apr.99', 'May.99', 'Jun.99']
+
+     states = ['FE', 'HE', 'employment', 'joblessness', 'school', 'training']
+     labels = ['further education', 'higher education', 'employment', 'joblessness', 'school', 'training']
+
+     # TODO: add a try/except: if a passed parameter does not exist, ask the user to pass the right ones
+     # sequence_data = SequenceData(df, time=time, time_type="year", id_col="country", ids=df['country'].values, states=states)
+
+     sequence_data = SequenceData(df,
+                                  time=time_list,
+                                  id_col="id",
+                                  states=states,
+                                  labels=labels,
+                                  )
+
+     om = get_distance_matrix(sequence_data,
+                              method="OM",
+                              sm="CONSTANT",
+                              indel=1)
+
+     cluster = Cluster(om, sequence_data.ids, clustering_method='ward_d')
+     cluster.plot_dendrogram(xlabel="Individuals", ylabel="Distance")
+
+     # Create a ClusterQuality object to evaluate clustering quality
+     cluster_quality = ClusterQuality(cluster)
+     cluster_quality.compute_cluster_quality_scores()
+     cluster_quality.plot_cqi_scores(norm='zscore')
+     summary_table = cluster_quality.get_cqi_table()
+     print(summary_table)
+
+     table = cluster_quality.get_cluster_range_table()
+     # table.to_csv("cluster_quality_table.csv")
+
+     print(table)
@@ -325,7 +325,23 @@ class SequenceData:
          if non_missing_states <= 20:
              non_missing_color_list = sns.color_palette("Spectral", non_missing_states)
          else:
-             non_missing_color_list = sns.color_palette("cubehelix", non_missing_states)
+             # Use a more elegant color palette for many states - combination of viridis and pastel colors
+             if non_missing_states <= 40:
+                 # Use viridis for up to 40 states (more colorful than cubehelix)
+                 non_missing_color_list = sns.color_palette("viridis", non_missing_states)
+             else:
+                 # For very large state counts, use a custom palette combining multiple schemes
+                 viridis_colors = sns.color_palette("viridis", min(non_missing_states // 2, 20))
+                 pastel_colors = sns.color_palette("Set3", min(non_missing_states // 2, 12))
+                 tab20_colors = sns.color_palette("tab20", min(non_missing_states // 3, 20))
+
+                 # Combine and extend the palette
+                 combined_colors = viridis_colors + pastel_colors + tab20_colors
+                 # If we need more colors, cycle through the combined palette
+                 while len(combined_colors) < non_missing_states:
+                     combined_colors.extend(combined_colors[:min(len(combined_colors), non_missing_states - len(combined_colors))])
+
+                 non_missing_color_list = combined_colors[:non_missing_states]

          if reverse_colors:
              non_missing_color_list = list(reversed(non_missing_color_list))
@@ -342,7 +358,23 @@ class SequenceData:
          if num_states <= 20:
              color_list = sns.color_palette("Spectral", num_states)
          else:
-             color_list = sns.color_palette("cubehelix", num_states)
+             # Use a more elegant color palette for many states - combination of viridis and pastel colors
+             if num_states <= 40:
+                 # Use viridis for up to 40 states (more colorful than cubehelix)
+                 color_list = sns.color_palette("viridis", num_states)
+             else:
+                 # For very large state counts, use a custom palette combining multiple schemes
+                 viridis_colors = sns.color_palette("viridis", min(num_states // 2, 20))
+                 pastel_colors = sns.color_palette("Set3", min(num_states // 2, 12))
+                 tab20_colors = sns.color_palette("tab20", min(num_states // 3, 20))
+
+                 # Combine and extend the palette
+                 combined_colors = viridis_colors + pastel_colors + tab20_colors
+                 # If we need more colors, cycle through the combined palette
+                 while len(combined_colors) < num_states:
+                     combined_colors.extend(combined_colors[:min(len(combined_colors), num_states - len(combined_colors))])
+
+                 color_list = combined_colors[:num_states]

          if reverse_colors:
              color_list = list(reversed(color_list))
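Read as a standalone helper, the palette fallback introduced in these two hunks amounts to the following sketch (the function name is hypothetical; the thresholds and palettes mirror the diff above):

    import seaborn as sns

    def pick_state_palette(n_states):
        """Hypothetical helper mirroring the fallback chain used above."""
        if n_states <= 20:
            return sns.color_palette("Spectral", n_states)
        if n_states <= 40:
            return sns.color_palette("viridis", n_states)
        # Very many states: concatenate several schemes and cycle until long enough
        colors = (sns.color_palette("viridis", min(n_states // 2, 20))
                  + sns.color_palette("Set3", min(n_states // 2, 12))
                  + sns.color_palette("tab20", min(n_states // 3, 20)))
        while len(colors) < n_states:
            colors.extend(colors[:n_states - len(colors)])
        return colors[:n_states]

    print(len(pick_state_palette(55)))  # 55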
@@ -142,7 +142,7 @@ def get_substitution_cost_matrix(seqdata, method, cval=None, miss_cost=None, tim
      # ================================
      if method in ["INDELS", "INDELSLOG"]:
          if time_varying:
-             indels = get_cross_sectional_entropy(seqdata)['Frequencies']
+             indels = get_cross_sectional_entropy(seqdata, return_format="dict")['Frequencies']
          else:
              ww = seqdata.weights
              if ww is None:
@@ -4,6 +4,7 @@
  #include <cmath>
  #include <iostream>
  #include "utils.h"
+ #include "dp_utils.h"
  #ifdef _OPENMP
  #include <omp.h>
  #endif
@@ -104,26 +105,11 @@ public:

      py::array_t<double> compute_all_distances() {
          try {
-             auto buffer = dist_matrix.mutable_unchecked<2>();
-
-             #pragma omp parallel
-             {
-                 #pragma omp for schedule(guided)
-                 for (int i = 0; i < nseq; i++) {
-                     for (int j = i; j < nseq; j++) {
-                         buffer(i, j) = compute_distance(i, j);
-                     }
-                 }
-             }
-
-             #pragma omp for schedule(static)
-             for (int i = 0; i < nseq; ++i) {
-                 for (int j = i + 1; j < nseq; ++j) {
-                     buffer(j, i) = buffer(i, j);
-                 }
-             }
-
-             return dist_matrix;
+             return dp_utils::compute_all_distances_simple(
+                 nseq,
+                 dist_matrix,
+                 [this](int i, int j){ return this->compute_distance(i, j); }
+             );
          } catch (const std::exception& e) {
              py::print("Error in compute_all_distances: ", e.what());
              throw;
@@ -132,23 +118,13 @@ public:

      py::array_t<double> compute_refseq_distances() {
          try {
-             auto buffer = refdist_matrix.mutable_unchecked<2>();
-
-             #pragma omp parallel
-             {
-                 #pragma omp for schedule(guided)
-                 for (int rseq = rseq1; rseq < rseq2; rseq ++) {
-                     for (int is = 0; is < nseq; is ++) {
-                         if(is == rseq){
-                             buffer(is, rseq-rseq1) = 0;
-                         }else{
-                             buffer(is, rseq-rseq1) = compute_distance(is, rseq);
-                         }
-                     }
-                 }
-             }
-
-             return refdist_matrix;
+             return dp_utils::compute_refseq_distances_simple(
+                 nseq,
+                 rseq1,
+                 rseq2,
+                 refdist_matrix,
+                 [this](int is, int rseq){ return this->compute_distance(is, rseq); }
+             );
          } catch (const std::exception& e) {
              py::print("Error in compute_all_distances: ", e.what());
              throw;
@@ -3,6 +3,7 @@
  #include <vector>
  #include <iostream>
  #include "utils.h"
+ #include "dp_utils.h"

  namespace py = pybind11;

@@ -71,26 +72,11 @@ public:

      py::array_t<double> compute_all_distances() {
          try {
-             auto buffer = dist_matrix.mutable_unchecked<2>();
-
-             #pragma omp parallel
-             {
-                 #pragma omp for schedule(static)
-                 for (int i = 0; i < nseq; i++) {
-                     for (int j = i; j < nseq; j++) {
-                         buffer(i, j) = compute_distance(i, j);
-                     }
-                 }
-             }
-
-             #pragma omp for schedule(static)
-             for (int i = 0; i < nseq; ++i) {
-                 for (int j = i + 1; j < nseq; ++j) {
-                     buffer(j, i) = buffer(i, j);
-                 }
-             }
-
-             return dist_matrix;
+             return dp_utils::compute_all_distances_simple(
+                 nseq,
+                 dist_matrix,
+                 [this](int i, int j){ return this->compute_distance(i, j); }
+             );
          } catch (const std::exception& e) {
              py::print("Error in compute_all_distances: ", e.what());
              throw;
@@ -99,23 +85,13 @@ public:

      py::array_t<double> compute_refseq_distances() {
          try {
-             auto buffer = refdist_matrix.mutable_unchecked<2>();
-
-             #pragma omp parallel
-             {
-                 #pragma omp for schedule(guided)
-                 for (int rseq = rseq1; rseq < rseq2; rseq ++) {
-                     for (int is = 0; is < nseq; is ++) {
-                         if(is == rseq){
-                             buffer(is, rseq-rseq1) = 0;
-                         }else{
-                             buffer(is, rseq-rseq1) = compute_distance(is, rseq);
-                         }
-                     }
-                 }
-             }
-
-             return refdist_matrix;
+             return dp_utils::compute_refseq_distances_simple(
+                 nseq,
+                 rseq1,
+                 rseq2,
+                 refdist_matrix,
+                 [this](int is, int rseq){ return this->compute_distance(is, rseq); }
+             );
          } catch (const std::exception& e) {
              py::print("Error in compute_all_distances: ", e.what());
              throw;
@@ -5,6 +5,7 @@
  #include <cmath>
  #include <iostream>
  #include "utils.h"
+ #include "dp_utils.h"
  #ifdef _OPENMP
  #include <omp.h>
  #endif
@@ -71,24 +72,6 @@ public:
          }
      }

-     // Aligned allocation helpers
-     #ifdef _WIN32
-     inline double* aligned_alloc_double(size_t size, size_t align=64) {
-         return reinterpret_cast<double*>(_aligned_malloc(size * sizeof(double), align));
-     }
-     inline void aligned_free_double(double* ptr) {
-         _aligned_free(ptr);
-     }
-     #else
-     inline double* aligned_alloc_double(size_t size, size_t align=64) {
-         void* ptr = nullptr;
-         if(posix_memalign(&ptr, align, size*sizeof(double)) != 0) throw std::bad_alloc();
-         return reinterpret_cast<double*>(ptr);
-     }
-     inline void aligned_free_double(double* ptr) { free(ptr); }
-     #endif
-
-
      double compute_distance(int is, int js, double* prev, double* curr) {
          try {
@@ -198,34 +181,14 @@ public:

      py::array_t<double> compute_all_distances() {
          try {
-             auto buffer = dist_matrix.mutable_unchecked<2>();
-
-             #pragma omp parallel
-             {
-                 // Each thread allocates its own prev/curr buffers
-                 double* prev = aligned_alloc_double(fmatsize);
-                 double* curr = aligned_alloc_double(fmatsize);
-
-                 #pragma omp for schedule(static)
-                 for (int i = 0; i < nseq; i++) {
-                     for (int j = i; j < nseq; j++) {
-                         buffer(i, j) = compute_distance(i, j, prev, curr);
-                     }
-                 }
-
-                 aligned_free_double(prev);
-                 aligned_free_double(curr);
-             }
-
-             // Fill the symmetric half
-             #pragma omp parallel for schedule(static)
-             for(int i = 0; i < nseq; i++) {
-                 for(int j = i+1; j < nseq; j++) {
-                     buffer(j, i) = buffer(i, j);
+             return dp_utils::compute_all_distances(
+                 nseq,
+                 fmatsize,
+                 dist_matrix,
+                 [this](int i, int j, double* prev, double* curr) {
+                     return this->compute_distance(i, j, prev, curr);
                  }
-             }
-
-             return dist_matrix;
+             );
          } catch (const std::exception& e) {
              py::print("Error in compute_all_distances: ", e.what());
              throw;
@@ -238,8 +201,8 @@ public:

          #pragma omp parallel
          {
-             double* prev = aligned_alloc_double(2 * seqlen + 1);
-             double* curr = aligned_alloc_double(2 * seqlen + 1);
+             double* prev = dp_utils::aligned_alloc_double(static_cast<size_t>(fmatsize));
+             double* curr = dp_utils::aligned_alloc_double(static_cast<size_t>(fmatsize));

              #pragma omp for schedule(static)
              for (int rseq = rseq1; rseq < rseq2; rseq ++) {
@@ -252,6 +215,8 @@ public:
                      buffer(is, rseq - rseq1) = cmpres;
                  }
              }
+             dp_utils::aligned_free_double(prev);
+             dp_utils::aligned_free_double(curr);
          }

          return refdist_matrix;