sequenzo-0.1.17-cp39-cp39-win_amd64.whl → sequenzo-0.1.18-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of sequenzo might be problematic.

Files changed (101)
  1. sequenzo/__init__.py +25 -1
  2. sequenzo/big_data/clara/clara.py +1 -1
  3. sequenzo/big_data/clara/utils/get_weighted_diss.c +156 -156
  4. sequenzo/big_data/clara/utils/get_weighted_diss.cp39-win_amd64.pyd +0 -0
  5. sequenzo/clustering/clustering_c_code.cp39-win_amd64.pyd +0 -0
  6. sequenzo/clustering/hierarchical_clustering.py +202 -8
  7. sequenzo/define_sequence_data.py +34 -2
  8. sequenzo/dissimilarity_measures/c_code.cp39-win_amd64.pyd +0 -0
  9. sequenzo/dissimilarity_measures/get_substitution_cost_matrix.py +1 -1
  10. sequenzo/dissimilarity_measures/src/DHDdistance.cpp +13 -37
  11. sequenzo/dissimilarity_measures/src/LCPdistance.cpp +13 -37
  12. sequenzo/dissimilarity_measures/src/OMdistance.cpp +12 -47
  13. sequenzo/dissimilarity_measures/src/OMspellDistance.cpp +103 -67
  14. sequenzo/dissimilarity_measures/src/dp_utils.h +160 -0
  15. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_arithmetic.hpp +41 -16
  16. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_complex.hpp +4 -0
  17. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_details.hpp +7 -0
  18. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_logical.hpp +10 -0
  19. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_math.hpp +127 -43
  20. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_memory.hpp +30 -2
  21. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_swizzle.hpp +174 -0
  22. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/common/xsimd_common_trigo.hpp +14 -5
  23. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx.hpp +111 -54
  24. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx2.hpp +131 -9
  25. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512bw.hpp +11 -113
  26. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512dq.hpp +39 -7
  27. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512f.hpp +336 -30
  28. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vbmi.hpp +9 -37
  29. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_avx512vbmi2.hpp +58 -0
  30. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_common.hpp +1 -0
  31. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_common_fwd.hpp +35 -2
  32. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_constants.hpp +3 -1
  33. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_emulated.hpp +17 -0
  34. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma3_avx.hpp +13 -0
  35. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma3_sse.hpp +18 -0
  36. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_fma4.hpp +13 -0
  37. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_isa.hpp +8 -0
  38. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_neon.hpp +363 -34
  39. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_neon64.hpp +7 -0
  40. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_rvv.hpp +13 -0
  41. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_scalar.hpp +41 -4
  42. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse2.hpp +252 -16
  43. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sse3.hpp +9 -0
  44. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_ssse3.hpp +12 -1
  45. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_sve.hpp +7 -0
  46. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_vsx.hpp +892 -0
  47. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/arch/xsimd_wasm.hpp +78 -1
  48. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_arch.hpp +3 -1
  49. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_config.hpp +13 -2
  50. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_cpuid.hpp +5 -0
  51. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/config/xsimd_inline.hpp +5 -1
  52. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_all_registers.hpp +2 -0
  53. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_api.hpp +64 -1
  54. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_batch.hpp +36 -0
  55. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_rvv_register.hpp +40 -31
  56. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_traits.hpp +8 -0
  57. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/types/xsimd_vsx_register.hpp +77 -0
  58. sequenzo/dissimilarity_measures/src/xsimd/include/xsimd/xsimd.hpp +6 -0
  59. sequenzo/dissimilarity_measures/src/xsimd/test/test_basic_math.cpp +6 -0
  60. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch.cpp +54 -2
  61. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_bool.cpp +8 -0
  62. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_cast.cpp +11 -4
  63. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_complex.cpp +18 -0
  64. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_int.cpp +8 -14
  65. sequenzo/dissimilarity_measures/src/xsimd/test/test_batch_manip.cpp +216 -173
  66. sequenzo/dissimilarity_measures/src/xsimd/test/test_load_store.cpp +6 -0
  67. sequenzo/dissimilarity_measures/src/xsimd/test/test_memory.cpp +1 -1
  68. sequenzo/dissimilarity_measures/src/xsimd/test/test_power.cpp +7 -4
  69. sequenzo/dissimilarity_measures/src/xsimd/test/test_select.cpp +6 -2
  70. sequenzo/dissimilarity_measures/src/xsimd/test/test_shuffle.cpp +32 -18
  71. sequenzo/dissimilarity_measures/src/xsimd/test/test_utils.hpp +21 -24
  72. sequenzo/dissimilarity_measures/src/xsimd/test/test_xsimd_api.cpp +69 -9
  73. sequenzo/dissimilarity_measures/utils/get_sm_trate_substitution_cost_matrix.c +156 -156
  74. sequenzo/dissimilarity_measures/utils/get_sm_trate_substitution_cost_matrix.cp39-win_amd64.pyd +0 -0
  75. sequenzo/dissimilarity_measures/utils/seqconc.c +156 -156
  76. sequenzo/dissimilarity_measures/utils/seqconc.cp39-win_amd64.pyd +0 -0
  77. sequenzo/dissimilarity_measures/utils/seqdss.c +156 -156
  78. sequenzo/dissimilarity_measures/utils/seqdss.cp39-win_amd64.pyd +0 -0
  79. sequenzo/dissimilarity_measures/utils/seqdur.c +156 -156
  80. sequenzo/dissimilarity_measures/utils/seqdur.cp39-win_amd64.pyd +0 -0
  81. sequenzo/dissimilarity_measures/utils/seqlength.c +156 -156
  82. sequenzo/dissimilarity_measures/utils/seqlength.cp39-win_amd64.pyd +0 -0
  83. sequenzo/sequence_characteristics/__init__.py +4 -0
  84. sequenzo/sequence_characteristics/complexity_index.py +17 -57
  85. sequenzo/sequence_characteristics/overall_cross_sectional_entropy.py +177 -111
  86. sequenzo/sequence_characteristics/plot_characteristics.py +30 -11
  87. sequenzo/sequence_characteristics/simple_characteristics.py +1 -0
  88. sequenzo/sequence_characteristics/state_frequencies_and_entropy_per_sequence.py +9 -3
  89. sequenzo/sequence_characteristics/turbulence.py +47 -67
  90. sequenzo/sequence_characteristics/variance_of_spell_durations.py +19 -9
  91. sequenzo/sequence_characteristics/within_sequence_entropy.py +5 -58
  92. sequenzo/visualization/plot_sequence_index.py +58 -35
  93. sequenzo/visualization/plot_state_distribution.py +57 -36
  94. sequenzo/with_event_history_analysis/__init__.py +35 -0
  95. sequenzo/with_event_history_analysis/sequence_analysis_multi_state_model.py +850 -0
  96. sequenzo/with_event_history_analysis/sequence_history_analysis.py +283 -0
  97. {sequenzo-0.1.17.dist-info → sequenzo-0.1.18.dist-info}/METADATA +7 -6
  98. {sequenzo-0.1.17.dist-info → sequenzo-0.1.18.dist-info}/RECORD +101 -94
  99. {sequenzo-0.1.17.dist-info → sequenzo-0.1.18.dist-info}/WHEEL +0 -0
  100. {sequenzo-0.1.17.dist-info → sequenzo-0.1.18.dist-info}/licenses/LICENSE +0 -0
  101. {sequenzo-0.1.17.dist-info → sequenzo-0.1.18.dist-info}/top_level.txt +0 -0
--- a/sequenzo/sequence_characteristics/turbulence.py
+++ b/sequenzo/sequence_characteristics/turbulence.py
@@ -1,6 +1,6 @@
 """
-@Author : 李欣怡
-@File : seqST.py
+@Author : Xinyi Li, Yuqi Liang
+@File : turbulence.py
 @Time : 2025/9/24 14:09
 @Desc : Computes the sequence turbulence measure
 
@@ -27,7 +27,7 @@ def turb(x):
     Tux = np.log2(phi * ((s2max + 1) / (s2_tx + 1)))
     return Tux
 
-def get_turbulence(seqdata, norm=False, silent=True, type=1):
+def get_turbulence(seqdata, norm=False, silent=True, type=1, id_as_column=True):
     """
     Computes the sequence turbulence measure
 
@@ -41,15 +41,18 @@ def get_turbulence(seqdata, norm=False, silent=True, type=1):
         If True, suppresses the output messages.
     type : int, default 1
         Type of spell duration variance to be used. Can be either 1 or 2.
+    id_as_column : bool, default True
+        If True, the ID will be included as a separate column instead of as the index.
 
     Returns
     -------
     pd.DataFrame
         A DataFrame with one column containing the turbulence measure for each sequence.
+        If id_as_column=True, also includes an ID column.
     """
 
     if not hasattr(seqdata, 'seqdata'):
-        raise ValueError(" [!] data is NOT a sequence object, see SequenceData function to create one.")
+        raise ValueError("[!] data is NOT a sequence object, see SequenceData function to create one.")
 
     if not silent:
         print(f" - extracting symbols and durations ...")
@@ -70,7 +73,22 @@ def get_turbulence(seqdata, norm=False, silent=True, type=1):
     s2_tx_max = s2_tx['vmax']
     s2_tx = s2_tx['result']
 
-    tmp = pd.DataFrame({'phi': phi.flatten(), 's2_tx': s2_tx, 's2max': s2_tx_max})
+    # Extract phi values and ensure 1D array
+    if hasattr(phi, 'iloc'):
+        phi_values = phi.iloc[:, 0].values
+    elif hasattr(phi, 'values'):
+        phi_values = phi.values
+    else:
+        phi_values = phi
+
+    # Ensure phi_values is 1D
+    phi_values = np.asarray(phi_values).flatten()
+
+    # Extract 1D arrays from s2_tx and s2_tx_max DataFrames
+    s2_tx_values = s2_tx.iloc[:, 1].values if hasattr(s2_tx, 'iloc') else np.asarray(s2_tx).flatten()
+    s2_tx_max_values = s2_tx_max.iloc[:, 1].values if hasattr(s2_tx_max, 'iloc') else np.asarray(s2_tx_max).flatten()
+
+    tmp = pd.DataFrame({'phi': phi_values, 's2_tx': s2_tx_values, 's2max': s2_tx_max_values})
     Tx = tmp.apply(lambda row: turb([row['phi'], row['s2_tx'], row['s2max']]), axis=1).to_numpy()
 
     if norm:
@@ -95,7 +113,7 @@
     else:
         turb_phi = 2
 
-    if turb_phi.isna().any().any():
+    if hasattr(turb_phi, 'isna') and turb_phi.isna().any().any():
        turb_phi = 1e15  # use a large finite value to avoid a conversion warning
        print("[!] phi set as max float due to exceeding value when computing max turbulence.")
 
@@ -103,7 +121,19 @@
     turb_s2_max = turb_s2['vmax']
     turb_s2 = turb_s2['result']
 
-    tmp = pd.DataFrame({'phi': turb_phi.iloc[:, 0], 's2_tx': turb_s2, 's2max': turb_s2_max})
+    # Extract turb_phi values and ensure 1D
+    if hasattr(turb_phi, 'iloc'):
+        phi_value = turb_phi.iloc[:, 0].values
+    else:
+        phi_value = [turb_phi]
+
+    phi_value = np.asarray(phi_value).flatten()
+
+    # Extract 1D arrays from turb_s2 and turb_s2_max DataFrames
+    turb_s2_values = turb_s2.iloc[:, 1].values if hasattr(turb_s2, 'iloc') else np.asarray(turb_s2).flatten()
+    turb_s2_max_values = turb_s2_max.iloc[:, 1].values if hasattr(turb_s2_max, 'iloc') else np.asarray(turb_s2_max).flatten()
+
+    tmp = pd.DataFrame({'phi': phi_value, 's2_tx': turb_s2_values, 's2max': turb_s2_max_values})
     maxT = tmp.apply(lambda row: turb([row['phi'], row['s2_tx'], row['s2max']]), axis=1).to_numpy()
 
     Tx_zero = np.where(Tx < 1)[0]
@@ -112,64 +142,14 @@
     Tx[Tx_zero, :] = 0
 
     Tx_df = pd.DataFrame(Tx, index=seqdata.seqdata.index, columns=['Turbulence'])
-    return Tx_df
-
-if __name__ == "__main__":
-
-    from sequenzo import *
 
-    # ===============================
-    # Sohee
-    # ===============================
-    # df = pd.read_csv('D:/college/research/QiQi/sequenzo/data_and_output/orignal data/sohee/sequence_data.csv')
-    # # df = pd.read_csv('/Users/lei/Documents/Sequenzo_all_folders/sequence_data_sources/sohee/sequence_data.csv')
-    # time_list = list(df.columns)[1:133]
-    # states = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
-    # # states = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
-    # labels = ['FT+WC', 'FT+BC', 'PT+WC', 'PT+BC', 'U', 'OLF']
-    # sequence_data = SequenceData(df, time=time_list, states=states, labels=labels, id_col="PID")
-    # res = get_turbulence(sequence_data)
-
-    # ===============================
-    # kass
-    # ===============================
-    # df = pd.read_csv('D:/college/research/QiQi/sequenzo/files/orignal data/kass/wide_civil_final_df.csv')
-    # time_list = list(df.columns)[1:]
-    # states = ['Extensive Warfare', 'Limited Violence', 'No Violence', 'Pervasive Warfare', 'Prolonged Warfare',
-    #           'Serious Violence', 'Serious Warfare', 'Sporadic Violence', 'Technological Warfare', 'Total Warfare']
-    # sequence_data = SequenceData(df, time=time_list, states=states, id_col="COUNTRY")
-    # res = seqST(sequence_data)
-
-    # ===============================
-    # CO2
-    # ===============================
-    # df = pd.read_csv("D:/country_co2_emissions_missing.csv")
-    df = load_dataset('country_co2_emissions_local_deciles')
-    df.to_csv("D:/country_co2_emissions_local_deciles.csv", index=False)
-    _time = list(df.columns)[1:]
-    # states = ['Very Low', 'Low', 'Middle', 'High', 'Very High']
-    states = ['D1 (Very Low)', 'D10 (Very High)', 'D2', 'D3', 'D4', 'D5', 'D6', 'D7', 'D8', 'D9']
-    sequence_data = SequenceData(df, time=_time, id_col="country", states=states)
-    res = get_turbulence(sequence_data, norm=True, type=2)
-
-    # ===============================
-    # detailed
-    # ===============================
-    # df = pd.read_csv("D:/college/research/QiQi/sequenzo/data_and_output/sampled_data_sets/detailed_data/sampled_1000_data.csv")
-    # _time = list(df.columns)[4:]
-    # states = ['data', 'data & intensive math', 'hardware', 'research', 'software', 'software & hardware', 'support & test']
-    # sequence_data = SequenceData(df[['worker_id', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'C10']],
-    #                              time=_time, id_col="worker_id", states=states)
-    # res = seqST(sequence_data, norm=False, type=2)
-
-    # ===============================
-    # broad
-    # ===============================
-    # df = pd.read_csv("D:/college/research/QiQi/sequenzo/data_and_output/sampled_data_sets/broad_data/sampled_1000_data.csv")
-    # _time = list(df.columns)[4:]
-    # states = ['Non-computing', 'Non-technical computing', 'Technical computing']
-    # sequence_data = SequenceData(df[['worker_id', 'C1', 'C2', 'C3', 'C4', 'C5']],
-    #                              time=_time, id_col="worker_id", states=states)
-    # res = seqST(sequence_data, norm=True, type=2)
-
-    print(res)
+    # Handle ID display options
+    if id_as_column:
+        # Add ID as a separate column and reset index to numeric
+        Tx_df['ID'] = Tx_df.index
+        Tx_df = Tx_df[['ID', 'Turbulence']].reset_index(drop=True)
+    else:
+        # Always set index name to 'ID' for clarity
+        Tx_df.index.name = 'ID'
+
+    return Tx_df
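The practical effect of the new `id_as_column` flag is easiest to see in a small usage sketch. Only the `get_turbulence` signature and the `SequenceData(...)` call pattern come from this diff; the toy data and top-level import path are assumptions for illustration:

```python
# Minimal sketch of the id_as_column flag added in 0.1.18 (assumes
# sequenzo >= 0.1.18; toy DataFrame and import path are hypothetical).
import pandas as pd
from sequenzo import SequenceData, get_turbulence  # import path assumed

df = pd.DataFrame({
    "id": ["p1", "p2", "p3"],
    "T1": ["A", "B", "A"],
    "T2": ["A", "B", "B"],
    "T3": ["B", "B", "A"],
})
seq = SequenceData(df, time=["T1", "T2", "T3"], states=["A", "B"], id_col="id")

# Default (id_as_column=True): IDs come back as a regular 'ID' column.
res = get_turbulence(seq, norm=True, type=2)   # columns: ['ID', 'Turbulence']

# Opt out: keep IDs on the index (now named 'ID'), closer to the 0.1.17 shape.
res_indexed = get_turbulence(seq, norm=True, type=2, id_as_column=False)
```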
--- a/sequenzo/sequence_characteristics/variance_of_spell_durations.py
+++ b/sequenzo/sequence_characteristics/variance_of_spell_durations.py
@@ -1,5 +1,5 @@
 """
-@Author : 李欣怡
+@Author : Xinyi Li, Yuqi Liang
 @File : variance_of_spell_durations.py
 @Time : 2025/9/24 14:22
 @Desc : Variance of spell durations of individual state sequences.
@@ -22,9 +22,9 @@ from .simple_characteristics import cut_prefix
 
 def get_spell_duration_variance(seqdata, type=1):
     if not hasattr(seqdata, 'seqdata'):
-        raise ValueError(" [!] data is NOT a sequence object, see SequenceData function to create one.")
+        raise ValueError("[!] data is NOT a sequence object, see SequenceData function to create one.")
     if type not in [1, 2]:
-        raise ValueError(" [!] type must be 1 or 2.")
+        raise ValueError("[!] type must be 1 or 2.")
 
     with open(os.devnull, 'w') as fnull:
         with redirect_stdout(fnull):
@@ -33,7 +33,7 @@ def get_spell_duration_variance(seqdata, type=1):
     lgth = seqlength(seqdata)
     dlgth = seqlength(dss)
     sdist = get_state_freq_and_entropy_per_seq(seqdata)
-    nnvisit = (sdist==0).sum(axis=1)
+    nnvisit = (sdist.iloc[:, 1:]==0).sum(axis=1)
 
     def realvar(x):
         n = len(x)
@@ -57,7 +57,8 @@
     # ret = (np.nansum(ddur, axis=1) + nnvisit * (meand ** 2)) / (dlgth + nnvisit)
     ddur = pd.DataFrame(ddur.tolist())
     sum_sqdiff = np.nansum(ddur.to_numpy(), axis=1)
-    ret = (sum_sqdiff + nnvisit.to_numpy() * (meand.to_numpy() ** 2)) / (dlgth + nnvisit.to_numpy())
+    ret_values = (sum_sqdiff + nnvisit.to_numpy() * (meand.to_numpy() ** 2)) / (dlgth + nnvisit.to_numpy())
+    ret = pd.Series(ret_values, index=meand.index)
 
     alph = seqdata.states.copy()
     alph_size = len(alph)
@@ -67,10 +68,19 @@
     maxnnv = np.where(dlgth == 1, alph_size - 1, alph_size - 2)
 
     meand_max = meand.to_numpy() * (dlgth + nnvisit.to_numpy()) / (dlgth + maxnnv)
-    var_max = ((dlgth-1) * (1-meand_max)**2 + (lgth - dlgth + 1 - meand_max)**2 + maxnnv * meand_max**2) / (dlgth + maxnnv)
+    var_max_values = ((dlgth-1) * (1-meand_max)**2 + (lgth - dlgth + 1 - meand_max)**2 + maxnnv * meand_max**2) / (dlgth + maxnnv)
+    var_max = pd.Series(var_max_values, index=meand.index)
+
+    meand.index = seqdata.seqdata.index
+    ret.index = seqdata.seqdata.index
+    var_max.index = seqdata.seqdata.index
+
+    meand = meand.to_frame("meand")
+    ret = ret.to_frame("var_spell_dur")
+    var_max = var_max.to_frame("var_max")
 
     return {
-        "meand": meand,
-        "result": ret,
-        "vmax": var_max
+        "meand": meand.reset_index().rename(columns={"index": "ID"}),
+        "result": ret.reset_index().rename(columns={"index": "ID"}),
+        "vmax": var_max.reset_index().rename(columns={"index": "ID"}),
     }
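Because `get_spell_duration_variance` now returns DataFrames with an explicit ID column instead of bare Series, downstream code has to select the value column. A hedged consumption sketch (module import path is assumed from the file path above; the key names and the `var_spell_dur` column come from the diff; `seq` reuses the toy object from the earlier sketch):

```python
# Import path assumed from the changed file's location in the wheel.
from sequenzo.sequence_characteristics.variance_of_spell_durations import (
    get_spell_duration_variance,
)

var = get_spell_duration_variance(seq, type=2)  # `seq` from the sketch above
# 0.1.18: var["result"] is a DataFrame with columns like ['ID', 'var_spell_dur'].
result_values = var["result"]["var_spell_dur"].to_numpy()
# The 0.1.17 equivalent was simply var["result"].to_numpy(), since the
# entries were plain Series without an ID column.
```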
--- a/sequenzo/sequence_characteristics/within_sequence_entropy.py
+++ b/sequenzo/sequence_characteristics/within_sequence_entropy.py
@@ -19,7 +19,7 @@ from .state_frequencies_and_entropy_per_sequence import get_state_freq_and_entropy_per_seq
 
 def get_within_sequence_entropy(seqdata, norm=True, base=np.e, silent=True):
     if not isinstance(seqdata, SequenceData):
-        raise ValueError(" [!] data is NOT a sequence object, see SequenceData function to create one.")
+        raise ValueError("[!] data is NOT a sequence object, see SequenceData function to create one.")
 
     states = seqdata.states.copy()
 
@@ -29,68 +29,15 @@ def get_within_sequence_entropy(seqdata, norm=True, base=np.e, silent=True):
     with open(os.devnull, 'w') as fnull:
         with redirect_stdout(fnull):
             iseqtab = get_state_freq_and_entropy_per_seq(seqdata=seqdata)
+            iseqtab.index = seqdata.seqdata.index
 
-    ient = iseqtab.apply(lambda row: entropy(row, base=base), axis=1)
+    ient = iseqtab.iloc[:, 1:].apply(lambda row: entropy(row, base=base), axis=1)
 
     if norm:
         maxent = np.log(len(states))
         ient = ient / maxent
 
-    ient.columns = ['Entropy']
-    ient.index = seqdata.seqdata.index
+    ient = pd.DataFrame(ient, index=seqdata.seqdata.index, columns=['Entropy'])
+    ient = ient.reset_index().rename(columns={'index': 'ID'})
 
     return ient
-
-
-if __name__ == "__main__":
-    # ===============================
-    # Sohee
-    # ===============================
-    # df = pd.read_csv('D:/college/research/QiQi/sequenzo/data_and_output/orignal data/sohee/sequence_data.csv')
-    # time_list = list(df.columns)[1:133]
-    # states = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
-    # # states = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
-    # labels = ['FT+WC', 'FT+BC', 'PT+WC', 'PT+BC', 'U', 'OLF']
-    # sequence_data = SequenceData(df, time=time_list, states=states, labels=labels, id_col="PID")
-    # res = seqient(sequence_data)
-
-    # ===============================
-    # kass
-    # ===============================
-    # df = pd.read_csv('D:/college/research/QiQi/sequenzo/files/orignal data/kass/wide_civil_final_df.csv')
-    # time_list = list(df.columns)[1:]
-    # states = ['Extensive Warfare', 'Limited Violence', 'No Violence', 'Pervasive Warfare', 'Prolonged Warfare',
-    #           'Serious Violence', 'Serious Warfare', 'Sporadic Violence', 'Technological Warfare', 'Total Warfare']
-    # sequence_data = SequenceData(df, time=time_list, states=states, id_col="COUNTRY")
-    # res = seqient(sequence_data)
-
-    # ===============================
-    # CO2
-    # ===============================
-    # df = pd.read_csv("D:/country_co2_emissions_missing.csv")
-    # _time = list(df.columns)[1:]
-    # states = ['Very Low', 'Low', 'Middle', 'High', 'Very High']
-    # sequence_data = SequenceData(df, time=_time, id_col="country", states=states)
-    # res = seqient(sequence_data)
-
-    # ===============================
-    # detailed
-    # ===============================
-    # df = pd.read_csv("D:/college/research/QiQi/sequenzo/data_and_output/sampled_data_sets/detailed_data/sampled_1000_data.csv")
-    # _time = list(df.columns)[4:]
-    # states = ['data', 'data & intensive math', 'hardware', 'research', 'software', 'software & hardware', 'support & test']
-    # sequence_data = SequenceData(df[['worker_id', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'C10']],
-    #                              time=_time, id_col="worker_id", states=states)
-    # res = seqient(sequence_data)
-
-    # ===============================
-    # broad
-    # ===============================
-    df = pd.read_csv("D:/college/research/QiQi/sequenzo/data_and_output/sampled_data_sets/broad_data/sampled_1000_data.csv")
-    _time = list(df.columns)[4:]
-    states = ['Non-computing', 'Non-technical computing', 'Technical computing']
-    sequence_data = SequenceData(df[['worker_id', 'C1', 'C2', 'C3', 'C4', 'C5']],
-                                 time=_time, id_col="worker_id", states=states)
-    res = get_within_sequence_entropy(sequence_data)
-
-    print(res)
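The entropy function follows the same return-shape convention as the other sequence characteristics in this release. A short consumption sketch (import path assumed; `seq` and `res` reuse the toy objects from the first sketch):

```python
from sequenzo import get_within_sequence_entropy  # import path assumed

ent = get_within_sequence_entropy(seq, norm=True)  # `seq` from the first sketch
# 0.1.18 returns columns ['ID', 'Entropy'], so joins go through the ID column:
merged = ent.merge(res, on="ID")  # e.g., combine with the turbulence table
```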
--- a/sequenzo/visualization/plot_sequence_index.py
+++ b/sequenzo/visualization/plot_sequence_index.py
@@ -289,10 +289,12 @@ def sort_sequences_by_method(seqdata, method="unsorted", mask=None, distance_mat
 
 
 def plot_sequence_index(seqdata: SequenceData,
-                        show_by_category=None,
-                        category_labels=None,
-                        id_group_df=None,
-                        categories=None,
+                        # Grouping parameters
+                        group_by_column=None,
+                        group_dataframe=None,
+                        group_column_name=None,
+                        group_labels=None,
+                        # Other parameters
                         sort_by="lexicographic",
                         sort_by_weight=False,
                         weights="auto",
@@ -320,13 +322,34 @@
     This function creates index plots that visualize sequences as horizontal lines,
     with different sorting options matching R's TraMineR functionality.
 
+    **Two API modes for grouping:**
+
+    1. **Simplified API** (when grouping info is already in the data):
+       ```python
+       plot_sequence_index(seqdata, group_by_column="Cluster", group_labels=cluster_labels)
+       ```
+
+    2. **Complete API** (when grouping info is in a separate dataframe):
+       ```python
+       plot_sequence_index(seqdata, group_dataframe=membership_df,
+                           group_column_name="Cluster", group_labels=cluster_labels)
+       ```
+
     :param seqdata: SequenceData object containing sequence information
-    :param show_by_category: (str, optional) Simple way to create grouped plots.
-        Specify the column name from the original data (e.g., "sex", "education").
-        This will automatically create separate plots for each category.
-    :param category_labels: (dict, optional) Custom labels for category values.
-        Example: {0: "Female", 1: "Male"} or {"low": "Low Education", "high": "High Education"}.
-        If not provided, will use original values or auto-generate readable labels.
+
+    **New API parameters (recommended):**
+    :param group_by_column: (str, optional) Column name from seqdata.data to group by.
+        Use this when grouping information is already in your data.
+        Example: "Cluster", "sex", "education"
+    :param group_dataframe: (pd.DataFrame, optional) Separate dataframe containing grouping information.
+        Use this when grouping info is in a separate table (e.g., clustering results).
+        Must contain ID column and grouping column.
+    :param group_column_name: (str, optional) Name of the grouping column in group_dataframe.
+        Required when using group_dataframe.
+    :param group_labels: (dict, optional) Custom labels for group values.
+        Example: {1: "Late Family Formation", 2: "Early Partnership"}
+        Maps original values to display labels.
+
     :param sort_by: Sorting method for sequences within groups:
         - 'unsorted' or 'none': Keep original order (R TraMineR default)
         - 'lexicographic': Sort sequences lexicographically
@@ -392,45 +415,45 @@
 
     actual_figsize = style_sizes[plot_style]
 
-    # Handle the new simplified API: show_by_category
-    if show_by_category is not None:
+    # Handle the simplified API: group_by_column
+    if group_by_column is not None:
         # Validate that the column exists in the original data
-        if show_by_category not in seqdata.data.columns:
+        if group_by_column not in seqdata.data.columns:
             available_cols = [col for col in seqdata.data.columns if col not in seqdata.time and col != seqdata.id_col]
             raise ValueError(
-                f"Column '{show_by_category}' not found in the data. "
+                f"Column '{group_by_column}' not found in the data. "
                 f"Available columns for grouping: {available_cols}"
             )
 
-        # Automatically create id_group_df and categories from the simplified API
-        id_group_df = seqdata.data[[seqdata.id_col, show_by_category]].copy()
-        id_group_df.columns = ['Entity ID', 'Category']
-        categories = 'Category'
+        # Automatically create group_dataframe and group_column_name from the simplified API
+        group_dataframe = seqdata.data[[seqdata.id_col, group_by_column]].copy()
+        group_dataframe.columns = ['Entity ID', 'Category']
+        group_column_name = 'Category'
 
-        # Handle category labels - flexible and user-controllable
-        unique_values = seqdata.data[show_by_category].unique()
+        # Handle group labels - flexible and user-controllable
+        unique_values = seqdata.data[group_by_column].unique()
 
-        if category_labels is not None:
+        if group_labels is not None:
             # User provided custom labels - use them
-            missing_keys = set(unique_values) - set(category_labels.keys())
+            missing_keys = set(unique_values) - set(group_labels.keys())
             if missing_keys:
                 raise ValueError(
-                    f"category_labels missing mappings for values: {missing_keys}. "
-                    f"Please provide labels for all unique values in '{show_by_category}': {sorted(unique_values)}"
+                    f"group_labels missing mappings for values: {missing_keys}. "
+                    f"Please provide labels for all unique values in '{group_by_column}': {sorted(unique_values)}"
                 )
-            id_group_df['Category'] = id_group_df['Category'].map(category_labels)
+            group_dataframe['Category'] = group_dataframe['Category'].map(group_labels)
         else:
             # No custom labels provided - use smart defaults
             if all(isinstance(v, (int, float, np.integer, np.floating)) and not pd.isna(v) for v in unique_values):
-                # Numeric values - keep as is (user can provide category_labels if they want custom names)
+                # Numeric values - keep as is (user can provide group_labels if they want custom names)
                 pass
             # For string/categorical values, keep original values
             # This handles cases where users already have meaningful labels like "Male"/"Female"
 
-        print(f"[>] Creating grouped plots by '{show_by_category}' with {len(unique_values)} categories")
+        print(f"[>] Creating grouped plots by '{group_by_column}' with {len(unique_values)} categories")
 
     # If no grouping information, create a single plot
-    if id_group_df is None or categories is None:
+    if group_dataframe is None or group_column_name is None:
         return _sequence_index_plot_single(seqdata, sort_by, sort_by_weight, weights, actual_figsize, plot_style, title, xlabel, ylabel, save_as, dpi, fontsize, include_legend, sequence_selection, n_sequences, show_sequence_ids)
 
     # Process weights
@@ -443,21 +466,21 @@
             raise ValueError("Length of weights must equal number of sequences.")
 
     # Ensure ID columns match (convert if needed)
-    id_col_name = "Entity ID" if "Entity ID" in id_group_df.columns else id_group_df.columns[0]
+    id_col_name = "Entity ID" if "Entity ID" in group_dataframe.columns else group_dataframe.columns[0]
 
     # Get unique groups and sort them based on user preference
     if group_order:
         # Use manually specified order, filter out non-existing groups
-        groups = [g for g in group_order if g in id_group_df[categories].unique()]
-        missing_groups = [g for g in id_group_df[categories].unique() if g not in group_order]
+        groups = [g for g in group_order if g in group_dataframe[group_column_name].unique()]
+        missing_groups = [g for g in group_dataframe[group_column_name].unique() if g not in group_order]
         if missing_groups:
             print(f"[Warning] Groups not in group_order will be excluded: {missing_groups}")
     elif sort_groups == 'numeric' or sort_groups == 'auto':
-        groups = smart_sort_groups(id_group_df[categories].unique())
+        groups = smart_sort_groups(group_dataframe[group_column_name].unique())
     elif sort_groups == 'alpha':
-        groups = sorted(id_group_df[categories].unique())
+        groups = sorted(group_dataframe[group_column_name].unique())
     elif sort_groups == 'none':
-        groups = list(id_group_df[categories].unique())
+        groups = list(group_dataframe[group_column_name].unique())
     else:
         raise ValueError(f"Invalid sort_groups value: {sort_groups}. Use 'auto', 'numeric', 'alpha', or 'none'.")
 
@@ -477,7 +500,7 @@
     # Create a plot for each group
     for i, group in enumerate(groups):
         # Get IDs for this group
-        group_ids = id_group_df[id_group_df[categories] == group][id_col_name].values
+        group_ids = group_dataframe[group_dataframe[group_column_name] == group][id_col_name].values
 
         # Match IDs with sequence data
         mask = np.isin(seqdata.ids, group_ids)
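For the plotting functions, migration is a pure parameter rename. A self-contained sketch (old and new parameter names are taken from the diff; the import path and the toy data with a `sex` column are assumptions):

```python
# Migration sketch for the grouping-parameter renames in 0.1.18.
import pandas as pd
from sequenzo import SequenceData
from sequenzo.visualization import plot_sequence_index  # import path assumed

df = pd.DataFrame({
    "id": ["p1", "p2", "p3", "p4"],
    "sex": [0, 1, 0, 1],
    "T1": ["A", "B", "A", "B"],
    "T2": ["A", "B", "B", "A"],
})
seq = SequenceData(df, time=["T1", "T2"], states=["A", "B"], id_col="id")

# 0.1.17:
#   plot_sequence_index(seq, show_by_category="sex",
#                       category_labels={0: "Female", 1: "Male"})
# 0.1.18:
plot_sequence_index(seq, group_by_column="sex",
                    group_labels={0: "Female", 1: "Male"})
```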
--- a/sequenzo/visualization/plot_state_distribution.py
+++ b/sequenzo/visualization/plot_state_distribution.py
@@ -41,10 +41,12 @@ def smart_sort_groups(groups):
 
 
 def plot_state_distribution(seqdata: SequenceData,
-                            show_by_category=None,
-                            category_labels=None,
-                            id_group_df=None,
-                            categories=None,
+                            # Grouping parameters
+                            group_by_column=None,
+                            group_dataframe=None,
+                            group_column_name=None,
+                            group_labels=None,
+                            # Other parameters
                             weights="auto",
                             figsize=(12, 7),
                             plot_style="standard",
@@ -67,13 +69,33 @@ def plot_state_distribution(seqdata: SequenceData,
     Creates state distribution plots for different groups, showing how state
     prevalence changes over time within each group.
 
+    **Two API modes for grouping:**
+
+    1. **Simplified API** (when grouping info is already in the data):
+       ```python
+       plot_state_distribution(seqdata, group_by_column="Cluster", group_labels=cluster_labels)
+       ```
+
+    2. **Complete API** (when grouping info is in a separate dataframe):
+       ```python
+       plot_state_distribution(seqdata, group_dataframe=membership_df,
+                               group_column_name="Cluster", group_labels=cluster_labels)
+       ```
+
     :param seqdata: (SequenceData) A SequenceData object containing sequences
-    :param show_by_category: (str, optional) Simple way to create grouped plots.
-        Specify the column name from the original data (e.g., "sex", "education").
-        This will automatically create separate plots for each category.
-    :param category_labels: (dict, optional) Custom labels for category values.
-        Example: {0: "Female", 1: "Male"} or {"low": "Low Education", "high": "High Education"}.
-        If not provided, will use original values or auto-generate readable labels.
+
+    **Grouping parameters:**
+    :param group_by_column: (str, optional) Column name from seqdata.data to group by.
+        Use this when grouping information is already in your data.
+        Example: "Cluster", "sex", "education"
+    :param group_dataframe: (pd.DataFrame, optional) Separate dataframe containing grouping information.
+        Use this when grouping info is in a separate table (e.g., clustering results).
+        Must contain ID column and grouping column.
+    :param group_column_name: (str, optional) Name of the grouping column in group_dataframe.
+        Required when using group_dataframe.
+    :param group_labels: (dict, optional) Custom labels for group values.
+        Example: {1: "Late Family Formation", 2: "Early Partnership"}
+        Maps original values to display labels.
     :param weights: (np.ndarray or "auto") Weights for sequences. If "auto", uses seqdata.weights if available
     :param figsize: (tuple) Size of the figure (only used when plot_style="custom")
     :param plot_style: Plot aspect style:
@@ -122,46 +144,45 @@ def plot_state_distribution(seqdata: SequenceData,
 
     actual_figsize = style_sizes[plot_style]
 
-    # Handle the new simplified API: show_by_category
-    if show_by_category is not None:
-
+    # Handle the simplified API: group_by_column
+    if group_by_column is not None:
         # Validate that the column exists in the original data
-        if show_by_category not in seqdata.data.columns:
+        if group_by_column not in seqdata.data.columns:
             available_cols = [col for col in seqdata.data.columns if col not in seqdata.time and col != seqdata.id_col]
             raise ValueError(
-                f"Column '{show_by_category}' not found in the data. "
+                f"Column '{group_by_column}' not found in the data. "
                 f"Available columns for grouping: {available_cols}"
            )
 
-        # Automatically create id_group_df and categories from the simplified API
-        id_group_df = seqdata.data[[seqdata.id_col, show_by_category]].copy()
-        id_group_df.columns = ['Entity ID', 'Category']
-        categories = 'Category'
+        # Automatically create group_dataframe and group_column_name from the simplified API
+        group_dataframe = seqdata.data[[seqdata.id_col, group_by_column]].copy()
+        group_dataframe.columns = ['Entity ID', 'Category']
+        group_column_name = 'Category'
 
-        # Handle category labels - flexible and user-controllable
-        unique_values = seqdata.data[show_by_category].unique()
+        # Handle group labels - flexible and user-controllable
+        unique_values = seqdata.data[group_by_column].unique()
 
-        if category_labels is not None:
+        if group_labels is not None:
             # User provided custom labels - use them
-            missing_keys = set(unique_values) - set(category_labels.keys())
+            missing_keys = set(unique_values) - set(group_labels.keys())
             if missing_keys:
                 raise ValueError(
-                    f"category_labels missing mappings for values: {missing_keys}. "
-                    f"Please provide labels for all unique values in '{show_by_category}': {sorted(unique_values)}"
+                    f"group_labels missing mappings for values: {missing_keys}. "
+                    f"Please provide labels for all unique values in '{group_by_column}': {sorted(unique_values)}"
                 )
-            id_group_df['Category'] = id_group_df['Category'].map(category_labels)
+            group_dataframe['Category'] = group_dataframe['Category'].map(group_labels)
         else:
             # No custom labels provided - use smart defaults
             if all(isinstance(v, (int, float, np.integer, np.floating)) and not pd.isna(v) for v in unique_values):
-                # Numeric values - keep as is (user can provide category_labels if they want custom names)
+                # Numeric values - keep as is (user can provide group_labels if they want custom names)
                 pass
             # For string/categorical values, keep original values
             # This handles cases where users already have meaningful labels like "Male"/"Female"
 
-        print(f"[>] Creating grouped plots by '{show_by_category}' with {len(unique_values)} categories")
+        print(f"[>] Creating grouped plots by '{group_by_column}' with {len(unique_values)} categories")
 
     # If no grouping information, create a single plot
-    if id_group_df is None or categories is None:
+    if group_dataframe is None or group_column_name is None:
         return _plot_state_distribution_single(
             seqdata=seqdata, weights=weights, figsize=actual_figsize,
             plot_style=plot_style, title=title, xlabel=xlabel, ylabel=ylabel,
@@ -179,21 +200,21 @@ def plot_state_distribution(seqdata: SequenceData,
         raise ValueError("Length of weights must equal number of sequences.")
 
     # Ensure ID columns match (convert if needed)
-    id_col_name = "Entity ID" if "Entity ID" in id_group_df.columns else id_group_df.columns[0]
+    id_col_name = "Entity ID" if "Entity ID" in group_dataframe.columns else group_dataframe.columns[0]
 
     # Get unique groups and sort them based on user preference
     if group_order:
         # Use manually specified order, filter out non-existing groups
-        groups = [g for g in group_order if g in id_group_df[categories].unique()]
-        missing_groups = [g for g in id_group_df[categories].unique() if g not in group_order]
+        groups = [g for g in group_order if g in group_dataframe[group_column_name].unique()]
+        missing_groups = [g for g in group_dataframe[group_column_name].unique() if g not in group_order]
         if missing_groups:
             print(f"[Warning] Groups not in group_order will be excluded: {missing_groups}")
     elif sort_groups == 'numeric' or sort_groups == 'auto':
-        groups = smart_sort_groups(id_group_df[categories].unique())
+        groups = smart_sort_groups(group_dataframe[group_column_name].unique())
     elif sort_groups == 'alpha':
-        groups = sorted(id_group_df[categories].unique())
+        groups = sorted(group_dataframe[group_column_name].unique())
     elif sort_groups == 'none':
-        groups = list(id_group_df[categories].unique())
+        groups = list(group_dataframe[group_column_name].unique())
     else:
         raise ValueError(f"Invalid sort_groups value: {sort_groups}. Use 'auto', 'numeric', 'alpha', or 'none'.")
 
@@ -216,7 +237,7 @@ def plot_state_distribution(seqdata: SequenceData,
     # Process each group
     for i, group in enumerate(groups):
         # Get IDs for this group
-        group_ids = id_group_df[id_group_df[categories] == group][id_col_name].values
+        group_ids = group_dataframe[group_dataframe[group_column_name] == group][id_col_name].values
 
         # Match IDs with sequence data
         mask = np.isin(seqdata.ids, group_ids)
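The same renames apply to plot_state_distribution, including the "complete API" mode where group membership lives in a separate table (e.g., clustering output). A sketch under the same assumptions, with a hypothetical membership table; the parameter names, the "Entity ID" column convention, and the example labels come from the diff:

```python
import pandas as pd
from sequenzo.visualization import plot_state_distribution  # import path assumed

# Hypothetical clustering output: one row per sequence ID.
membership_df = pd.DataFrame({
    "Entity ID": ["p1", "p2", "p3", "p4"],
    "Cluster": [1, 1, 2, 2],
})

plot_state_distribution(
    seq,                                   # SequenceData from the sketch above
    group_dataframe=membership_df,
    group_column_name="Cluster",
    group_labels={1: "Late Family Formation", 2: "Early Partnership"},
)
```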