datasmryzr 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datasmryzr/__about__.py +4 -0
- datasmryzr/__init__.py +3 -0
- datasmryzr/annotate.py +270 -0
- datasmryzr/clusters.py +315 -0
- datasmryzr/core_genome.py +297 -0
- datasmryzr/datasmryzr.py +97 -0
- datasmryzr/distances.py +106 -0
- datasmryzr/pangenome.py +211 -0
- datasmryzr/smryz.py +472 -0
- datasmryzr/summary.py +110 -0
- datasmryzr/tables.py +243 -0
- datasmryzr/templates/base_config.json +28 -0
- datasmryzr/templates/report.html.j2 +899 -0
- datasmryzr/tree.py +27 -0
- datasmryzr/utils.py +94 -0
- datasmryzr-0.0.1.dist-info/METADATA +54 -0
- datasmryzr-0.0.1.dist-info/RECORD +20 -0
- datasmryzr-0.0.1.dist-info/WHEEL +4 -0
- datasmryzr-0.0.1.dist-info/entry_points.txt +2 -0
- datasmryzr-0.0.1.dist-info/licenses/LICENSE.txt +9 -0
datasmryzr/__about__.py
ADDED
datasmryzr/__init__.py
ADDED
datasmryzr/annotate.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module provides functions for generating metadata annotations
|
|
3
|
+
and legends for a DataFrame, mapping metadata variables to colors.
|
|
4
|
+
"""
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import json
|
|
7
|
+
from mycolorpy import colorlist as mcp
|
|
8
|
+
from datasmryzr import utils
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _open_file(file_path:str) -> pd.DataFrame:
    """
    Read a delimited text file into a DataFrame.

    The delimiter is sniffed automatically (``sep=None`` with the Python
    parsing engine), so the same call handles comma-, tab- or
    semicolon-separated metadata files.

    Args:
        file_path (str): Path to the delimited file.

    Returns:
        pd.DataFrame: The parsed table.
    """
    return pd.read_csv(file_path, engine='python', sep=None)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _check_vals(df:pd.DataFrame,
                cols:list,
                cfg:dict) -> list:
    """
    Select the requested columns that can be used as categorical metadata.

    A requested column is kept when it exists in ``df`` and either
    * it is the first (identifier) column of ``df``, or
    * every one of its unique values is a string (non-numerical data), or
    * it is listed in ``cfg['categorical_columns']``.

    Args:
        df (pd.DataFrame): The input DataFrame to check.
        cols (list): Column names to validate, in the order to return them.
        cfg (dict): Configuration; the optional key 'categorical_columns'
            lists columns to treat as categorical even if numerical. A
            missing key is treated as an empty list.

    Returns:
        list: The accepted column names, in the order given by ``cols``.

    Raises:
        ValueError: If no column is accepted - with one message when some of
            the columns exist but none qualifies, another when none of them
            exist in the DataFrame.
    """
    id_col = df.columns[0]
    categorical = cfg.get('categorical_columns') or []

    final_cols = []
    found_any = False
    for col in cols:
        if col not in df.columns:
            continue
        found_any = True
        # Every value must be a string; previously only the last unique value
        # decided, so a numeric column whose gaps were filled with "NA" could
        # be accepted (or rejected) depending on value order.
        is_string = col == id_col or all(
            isinstance(val, str) for val in df[col].unique()
        )
        if is_string or col in categorical:
            final_cols.append(col)

    if final_cols:
        return final_cols

    joined = ', '.join(str(c) for c in cols)
    if found_any:
        raise ValueError(
            f"Columns {joined} do not contain any valid values. "
            "Please check the column names."
        )
    raise ValueError(
        f"None of the columns {joined} are in the dataframe or "
        "in the correct format - only non-numerical data can be included. "
        "Please check the column names."
    )
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _get_cols(cols: list, df: pd.DataFrame, cfg: dict) -> list:
    """
    Work out which metadata columns to use and validate them.

    Args:
        cols (list): Column names to use, or the string "all" to consider
            every column of ``df``.
        df (pd.DataFrame): The metadata table.
        cfg (dict): Configuration forwarded unchanged to ``_check_vals``.

    Returns:
        list: The column names accepted by ``_check_vals``.
    """
    candidates = df.columns.tolist() if cols == "all" else cols
    return _check_vals(df=df, cols=candidates, cfg=cfg)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _get_colors(df:pd.DataFrame,
                cols:list) -> dict:
    """
    Build the palette of CSS-named colours used by the metadata legend.

    For each column in ``cols``, one colour per unique value is requested
    from mycolorpy's "tab20b" colormap. The colours from all columns are
    merged (duplicates dropped) in first-seen order.

    Args:
        df (pd.DataFrame): The input DataFrame containing the data.
        cols (list): Column names whose unique values need colours.

    Returns:
        dict: Maps a CSS-safe name (the colour code with "#" replaced by
        "a") to the colour code itself, in deterministic first-seen order.
    """
    # A dict (not a set) keeps insertion order. Set iteration order for str
    # depends on per-process hash randomisation, which made the value ->
    # colour assignment in the legend change from run to run.
    ordered_colors: dict = {}
    for col in cols:
        n_unique = len(df[col].unique())
        for color in mcp.gen_color(cmap="tab20b", n=n_unique):
            ordered_colors.setdefault(color, None)

    colors_css: dict = {}
    for color in ordered_colors:
        # CSS class names cannot contain "#", so swap it for a letter.
        colors_css.setdefault(color.replace("#", "a"), color)
    return colors_css
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _make_legend(df:pd.DataFrame,
                 cols:list,
                 color_css:dict) -> dict:
    """
    Pair each column's unique values with colour names, in order.

    Args:
        df (pd.DataFrame): The input DataFrame containing the data.
        cols (list): Columns to build legends for.
        color_css (dict): Colour name -> CSS colour; only the keys (in
            iteration order) are used here.

    Returns:
        dict: Column name -> list of single-entry dicts ``{value: colour
        name}``, following the order of ``df[col].unique()``.

    Notes:
        - Values equal to "NA" are left out of the legend.
        - Pairing stops at the shorter of the value list and the colour
          list, so surplus values get no legend entry.
    """
    palette = list(color_css)
    legend: dict = {}
    for column in cols:
        values = [v for v in df[column].unique() if v != "NA"]
        legend[column] = [{value: name} for value, name in zip(values, palette)]
    return legend
|
|
161
|
+
|
|
162
|
+
def _get_metadata_tree(df:pd.DataFrame,
                       cols:list,
                       legend: dict,
                       color_css:dict) -> dict:
    """
    Build the per-tip metadata mapping (colour and label per column).

    The first column of ``df`` holds the tip labels. For every row, each
    column in ``cols`` (except the tip-label column itself) gets an entry
    with the row's value as label and a colour resolved through ``legend``
    and ``color_css``.

    Args:
        df (pd.DataFrame): Input data; the first column is the tip label.
        cols (list): Columns to include; the tip-label column is skipped.
        legend (dict): Column name -> list of single-entry dicts mapping a
            value to a colour name (the structure ``_make_legend`` builds),
            e.g. ``{"col1": [{"value1": "red"}, {"value3": "blue"}]}``.
        color_css (dict): Colour name -> CSS colour code,
            e.g. ``{"red": "#FF0000", "blue": "#0000FF"}``.

    Returns:
        dict: Tip label -> {column: {"colour": css_colour, "label": value}}.
        Values without a legend entry (e.g. "NA"), columns absent from
        ``legend``, and colour names missing from ``color_css`` all get
        "white".

    Example:
        With the two example arguments above and a frame with rows
        ``A, value1`` and ``B, value3`` for columns ``tiplabel, col1``::

            {"A": {"col1": {"colour": "#FF0000", "label": "value1"}},
             "B": {"col1": {"colour": "#0000FF", "label": "value3"}}}
    """
    tiplabel = df.columns[0]
    meta_cols = [col for col in cols if col != tiplabel]

    # Flatten each column's legend into one dict up front instead of scanning
    # the list for every cell; setdefault keeps the first mapping for a value,
    # matching the previous first-match lookup.
    lookups = {}
    for col in meta_cols:
        lookup = {}
        for entry in legend.get(col, ()):
            for value, colour_name in entry.items():
                lookup.setdefault(value, colour_name)
        lookups[col] = lookup

    metadata_tree = {}
    for _, row in df.iterrows():
        metadata_tree[row[tiplabel]] = {
            col: {
                "colour": color_css.get(lookups[col].get(row[col], "white"), "white"),
                "label": row[col],
            }
            for col in meta_cols
        }
    return metadata_tree
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def construct_annotations(path: str,
                          cols: list,
                          config:str) -> dict:
    """
    Assemble every metadata structure the report needs for annotation.

    When ``path`` is falsy, empty placeholders are returned and neither the
    file nor the configuration is read. Otherwise the file is loaded (missing
    cells filled with "NA"), the configuration is loaded through
    ``utils.get_config(config)``, and columns, colours, legend and per-tip
    metadata are derived from it.

    Args:
        path (str): Path to the metadata file; falsy means "no metadata".
        cols (list): Column names to use, or "all".
        config (str): Configuration reference passed to ``utils.get_config``;
            the column validation reads its 'categorical_columns' entry.

    Returns:
        dict: With keys
            - "metadata_tree" (dict): tip label -> per-column colour/label.
            - "metadata_columns" (list): the accepted metadata columns.
            - "colors_css" (dict): CSS-safe colour name -> colour code.
            - "legend": column -> list of {value: colour name} (a dict);
              note the placeholder returned for a falsy ``path`` is an
              empty list instead.
    """
    if path:
        table = _open_file(path).fillna("NA")
        settings = utils.get_config(config)
        metadata_columns = _get_cols(cols=cols, df=table, cfg=settings)
        colors_css = _get_colors(df=table, cols=metadata_columns)
        legend = _make_legend(df=table, cols=metadata_columns, color_css=colors_css)
        return {
            "metadata_tree": _get_metadata_tree(
                df=table, cols=metadata_columns, legend=legend, color_css=colors_css
            ),
            "metadata_columns": metadata_columns,
            "colors_css": colors_css,
            "legend": legend,
        }
    return {
        "metadata_tree": {},
        "metadata_columns": [],
        "colors_css": {},
        "legend": [],
    }
|
|
270
|
+
|
datasmryzr/clusters.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module provides functions for processing pairwise distances between isolates,
|
|
3
|
+
including generating histograms and heatmaps for visualization.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import pathlib
|
|
8
|
+
import json
|
|
9
|
+
import altair as alt
|
|
10
|
+
from datasmryzr.utils import check_file_exists
|
|
11
|
+
from datasmryzr.distances import _get_distances
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _get_cluster_table(
        clusters: str
        ) -> pd.DataFrame:
    """
    Load the cluster-assignment table with every column read as text.

    The delimiter is sniffed automatically. Any failure while reading (bad
    path, unreadable or malformed file) is printed and an empty DataFrame is
    returned instead, so callers can simply test ``.empty``.

    Args:
        clusters (str): Path to the cluster table.

    Returns:
        pd.DataFrame: The table (all values as strings), or an empty frame.
    """
    try:
        table = pd.read_csv(clusters, dtype=str, engine='python', sep=None)
    except Exception as err:
        print(err)
        table = pd.DataFrame()
    return table
|
|
24
|
+
|
|
25
|
+
def _get_distance_data(
        cluster_df: pd.DataFrame,
        distances_df: pd.DataFrame
        ) -> pd.DataFrame:
    """
    Unimplemented placeholder.

    Both arguments are ignored and ``None`` is returned, despite the
    ``pd.DataFrame`` annotation. No caller is visible in this module.
    """
    return None
|
|
31
|
+
|
|
32
|
+
def _combine_cluster_ids(
        clusters:pd.DataFrame) -> pd.DataFrame:
    """
    Make cluster IDs hierarchical by prefixing each with its parent's ID.

    Threshold columns (``Tx:<n>``) are processed from the largest threshold
    to the smallest. For each consecutive pair, every value in the
    smaller-threshold column becomes ``"<parent>:<child>"``, where the parent
    is the (already combined) value at the next-larger threshold. Rows whose
    parent value contains "UC" (unclustered) copy the parent value instead.

    The frame is modified in place and also returned.

    Args:
        clusters (pd.DataFrame): Cluster table with string ``Tx:<n>`` columns.

    Returns:
        pd.DataFrame: The same frame with combined IDs.
    """
    if clusters.empty:
        # Nothing to combine. A zero-row frame previously crashed because
        # apply() returned a two-column frame that was assigned to one column.
        return clusters

    thresholds = _get_thresholds(clusters)
    for parent_th, child_th in zip(thresholds, thresholds[1:]):
        # Re-read the parent each time: it was itself combined in the
        # previous iteration, giving "a:b:c" style chains.
        parent = clusters[f"Tx:{parent_th}"]
        child_col = f"Tx:{child_th}"
        unclustered = parent.astype(str).str.contains("UC", regex=False)
        clusters[child_col] = parent.where(
            unclustered, parent + ":" + clusters[child_col]
        )
    return clusters
|
|
43
|
+
|
|
44
|
+
def _get_thresholds(clusters: pd.DataFrame) -> list:
    """
    Return the SNP thresholds encoded in ``Tx:<int>`` column names.

    Only columns named exactly ``Tx:`` followed by an integer are used.
    Other columns - even ones that merely contain "Tx", or have a
    non-integer suffix - are ignored instead of raising.

    Args:
        clusters (pd.DataFrame): Cluster table.

    Returns:
        list: Integer thresholds, largest first (empty if none).
    """
    thresholds = []
    for label in clusters.columns:
        name = str(label)  # column labels need not be strings
        if not name.startswith("Tx:"):
            continue
        try:
            thresholds.append(int(name.split(":", 1)[1]))
        except ValueError:
            continue
    return sorted(thresholds, reverse=True)
|
|
49
|
+
|
|
50
|
+
def _create_tree_for_traversal(
        clusters: pd.DataFrame) -> dict:
    """
    Build an adjacency map of the cluster hierarchy for traversal.

    The synthetic root 'all' lists the distinct values (other than "UC") of
    the largest-threshold column. Each cluster ID at one threshold maps to
    the distinct IDs of its rows at the next-smaller threshold. Clusters at
    the smallest threshold map to an empty list.

    Args:
        clusters (pd.DataFrame): Cluster table with ``Tx:<n>`` columns.

    Returns:
        dict: Node ID -> list of child node IDs. ``{'all': []}`` when the
        table has no threshold columns (previously an IndexError).
    """
    thresholds = _get_thresholds(clusters)
    if not thresholds:
        return {'all': []}

    top_col = f"Tx:{thresholds[0]}"
    tree = {'all': [c for c in clusters[top_col].unique() if c != "UC"]}

    for level, threshold in enumerate(thresholds):
        col = f"Tx:{threshold}"
        # Progressively discard rows unclustered ("UC") at this level; deeper
        # levels are examined only for the remaining rows.
        clusters = clusters[~clusters[col].str.contains("UC")]
        has_next = level + 1 < len(thresholds)
        next_col = f"Tx:{thresholds[level + 1]}" if has_next else None
        for cl in clusters[col].unique():
            if next_col is None:
                tree[cl] = []
            else:
                tree[cl] = list(clusters.loc[clusters[col] == cl, next_col].unique())
    return tree
|
|
65
|
+
|
|
66
|
+
def _construct_table_dict(tree, node, clusters, visited=None):
    """
    Recursively build the nested cluster table (``_children`` rows).

    Args:
        tree (dict): Adjacency map node ID -> list of child node IDs.
        node (str): Node to expand (e.g. the root 'all').
        clusters (pd.DataFrame): Cluster table. A node's size is the number
            of rows equal to ``node`` in the last column where it occurs
            (0 if it occurs nowhere, as for the root).
        visited (set | None): Nodes already expanded; the same set is shared
            across the whole recursion so each node is expanded once.

    Returns:
        dict: ``{'Cluster ID': node, 'Num seqs': size, '_children': [...]}``
        where children whose ID contains "UC" are omitted, and a node whose
        own ID contains "UC" gets no children.
    """
    size = 0
    for col in clusters.columns:
        matches = clusters[col] == node
        if matches.any():
            # No break: as before, the last column containing the node wins.
            size = int(matches.sum())

    if visited is None:
        visited = set()
    visited.add(node)

    data = {'Cluster ID': node, 'Num seqs': size, '_children': []}
    if "UC" not in node:
        for child in tree[node]:
            if child not in visited and "UC" not in child:
                data['_children'].append(
                    _construct_table_dict(
                        tree=tree, node=child, clusters=clusters, visited=visited
                    )
                )
    return data
|
|
94
|
+
|
|
95
|
+
def _distance_column_specs(ccols: list, id_col: str) -> list:
    """Build column definitions: a searchable text column for ``id_col``, '<=' number filters for the rest."""
    specs = []
    for col in ccols:
        if col == id_col:
            specs.append({'field': col, 'title': col, 'type': 'string',
                          'headerFilter': 'input',
                          'headerFilterPlaceholder': f'Search {col}',
                          'formatter': "textarea"})
        else:
            specs.append({'field': col, 'title': col, 'type': 'number',
                          'headerFilter': 'number', 'headerFilterFunc': "<=",
                          'headerFilterPlaceholder': 'Less than ...',
                          'formatter': "number"})
    return specs


def get_cluster_distances(
        clusters: str,
        distances: str
        ) -> dict:
    """
    Build a pairwise-distance sub-table for every cluster.

    Args:
        clusters (str): Path to the cluster-assignment table.
        distances (str): Path to a tab-separated distance matrix whose first
            column labels the rows and whose other columns are isolates.

    Returns:
        dict: Cluster ID -> ``{'table': list of row dicts, 'columns':
        column specs}``. Empty when ``check_file_exists(distances)`` is
        false, the cluster table is empty, or it has no ``Tx:<n>`` columns.
    """
    cluster_df = _get_cluster_table(clusters)
    cluster_df = _combine_cluster_ids(cluster_df)
    thresholds = _get_thresholds(cluster_df)
    dists = {}
    if not check_file_exists(distances) or cluster_df.empty or not thresholds:
        return dists

    distances_df = pd.read_csv(distances, sep="\t")
    tree = _create_tree_for_traversal(cluster_df)
    id_col = distances_df.columns[0]
    # "Isolate" / "ID" used to be hard-coded; keep them when present and fall
    # back to each table's first column instead of raising KeyError.
    row_col = "Isolate" if "Isolate" in distances_df.columns else id_col
    member_col = "ID" if "ID" in cluster_df.columns else cluster_df.columns[0]
    matrix_cols = set(distances_df.columns)

    for cl in tree:
        for th in thresholds:
            th_col = f"Tx:{th}"
            if cl not in cluster_df[th_col].values:
                continue
            isolates = list(cluster_df.loc[cluster_df[th_col] == cl, member_col])
            # Isolates with no column in the matrix would make the column
            # selection raise KeyError; select only the ones present.
            ccols = [row_col] + [iso for iso in isolates if iso in matrix_cols]
            dd = distances_df.loc[distances_df[row_col].isin(isolates), ccols]
            dists[cl] = {
                'table': dd.to_dict(orient='records'),
                'columns': _distance_column_specs(ccols, id_col),
            }
    return dists
|
|
134
|
+
|
|
135
|
+
def _save_cluster_table(cluster_table: dict) -> str:
    """
    Write the cluster table as indented JSON to ``clusters.json``.

    The file is created (or overwritten) in the current working directory.

    Args:
        cluster_table: Any JSON-serialisable structure; callers in this
            module pass the list of top-level cluster rows.

    Returns:
        str: The absolute path of the written file.
    """
    destination = pathlib.Path.cwd().joinpath("clusters.json")
    with destination.open('w') as handle:
        json.dump(cluster_table, handle, indent=4)
    return f"{destination}"
|
|
140
|
+
|
|
141
|
+
def get_cluster_table(
        clusters: str,
        distances: str
        ) -> str:
    """
    Build the nested cluster table and save it to ``clusters.json``.

    Args:
        clusters (str): Path to the cluster-assignment table.
        distances (str): Distances input passed to ``_get_distances``.

    Returns:
        str: Path of the saved JSON file. An empty dict is returned instead
        when the cluster table or distances are empty, or when the cluster
        table has no ``Tx:<n>`` threshold columns (the traversal tree cannot
        be built without them).
    """
    distances_df = _get_distances(distances)
    cluster_df = _get_cluster_table(clusters)

    if cluster_df.empty or distances_df.empty or not _get_thresholds(cluster_df):
        return {}

    cluster_df = _combine_cluster_ids(cluster_df)
    tree = _create_tree_for_traversal(cluster_df)
    cluster_table = _construct_table_dict(tree=tree, node='all', clusters=cluster_df)
    # The synthetic 'all' root is dropped; only its children are saved.
    return _save_cluster_table(cluster_table['_children'])
|
|
161
|
+
|
|
162
|
+
def _get_clustered(clusters: pd.DataFrame, threshold: int) -> list:
    """
    List the isolates that are clustered at ``threshold``.

    Args:
        clusters (pd.DataFrame): Cluster table; its first column holds the
            isolate IDs.
        threshold (int): SNP threshold; column ``Tx:<threshold>`` is read.

    Returns:
        list: First-column IDs whose ``Tx:<threshold>`` value does not
        contain "UC". Rows with a missing assignment count as unclustered.
    """
    assignments = clusters[f"Tx:{threshold}"]
    # na=True keeps the mask boolean when an assignment is missing; a NaN in
    # the mask previously made the ~ operator fail.
    unclustered = assignments.str.contains("UC", regex=False, na=True)
    return clusters.loc[~unclustered, clusters.columns[0]].tolist()
|
|
165
|
+
|
|
166
|
+
def _cluster_statistics(
        cluster_df: pd.DataFrame,
        distances_df: pd.DataFrame,
        thresholds: list,
        id_col: str = None
        ) -> pd.DataFrame:
    """
    Collect intra- and inter-cluster pairwise distances for each threshold.

    For each threshold ``th``, ``distances_df`` is first restricted to pairs
    where both isolates are clustered at ``th``. Then, for each cluster
    ``cl`` in column ``Tx:<th>`` (values containing "UC" are skipped):

    * "Intra-cluster distance" rows have both isolates in ``cl``;
    * "Inter-cluster distance" rows have exactly one isolate in ``cl``, so
      the other isolate belongs to a different cluster.

    Args:
        cluster_df (pd.DataFrame): Cluster table with ``Tx:<n>`` columns.
        distances_df (pd.DataFrame): Long-format distances with at least
            ``Isolate1`` and ``Isolate2`` columns.
        thresholds (list): Thresholds to process.
        id_col (str, optional): Column of ``cluster_df`` holding isolate IDs.
            Defaults to its first column.

    Returns:
        pd.DataFrame: The selected distance rows plus the columns ``pair``
        (both isolate IDs sorted and joined with "_"), ``Cluster ID``,
        ``SNP Threshold`` and ``Measurement``. Duplicate
        (``pair``, ``Cluster ID``) rows are dropped.

    Raises:
        ValueError: If no threshold has any distances between clustered
            isolates.
    """
    if id_col is None:
        id_col = cluster_df.columns[0]

    frames = []
    for th in thresholds:
        th_col = f"Tx:{th}"
        clustered = _get_clustered(cluster_df, th)
        cdf = distances_df[
            distances_df["Isolate1"].isin(clustered)
            & distances_df["Isolate2"].isin(clustered)
        ]
        if cdf.empty:
            continue

        for cl in cluster_df[th_col].unique():
            if not isinstance(cl, str) or "UC" in cl:
                continue
            members = cluster_df.loc[cluster_df[th_col] == cl, id_col]
            first_in = cdf["Isolate1"].isin(members)
            second_in = cdf["Isolate2"].isin(members)
            # XOR = exactly one endpoint in cl; both endpoints are clustered,
            # so the other one lies in another cluster. (The old per-isolate
            # loop excluded the other cluster's members instead, yielding no
            # inter-cluster rows at all when there were only two clusters.)
            selections = (
                (first_in & second_in, "Intra-cluster distance"),
                (first_in ^ second_in, "Inter-cluster distance"),
            )
            for mask, measurement in selections:
                subset = cdf[mask]
                pairs = [
                    "_".join(sorted((str(a), str(b))))
                    for a, b in zip(subset["Isolate1"], subset["Isolate2"])
                ]
                # assign() returns a new frame, avoiding chained-assignment
                # writes into a slice of distances_df.
                frames.append(subset.assign(
                    pair=pairs,
                    **{
                        "Cluster ID": f"{cl}",
                        "SNP Threshold": th,
                        "Measurement": measurement,
                    },
                ))

    if not frames:
        raise ValueError("No pairwise distances found between clustered isolates.")
    cdf_all = pd.concat(frames, ignore_index=True)
    return cdf_all.drop_duplicates(subset=["pair", "Cluster ID"])
|
|
218
|
+
|
|
219
|
+
def _cluster_size_charts(clusters: pd.DataFrame, id_col: str, th: int) -> list:
    """Build the two isolate-count bar charts at ``th``: the grey unclustered bar, then one bar per cluster."""
    field = f"Tx_{th}"
    clustered_list = _get_clustered(clusters, th)
    in_cluster = clusters[id_col].isin(clustered_list)
    tmp = clusters[in_cluster].rename(columns={f"Tx:{th}": field})
    uc = clusters[~in_cluster].rename(columns={f"Tx:{th}": field}).assign(**{field: "UC"})
    # NOTE(review): num_cls counts clustered rows (isolates), not distinct
    # clusters; it only scales the unclustered chart's width - confirm intent.
    num_cls = tmp.shape[0]
    y_max = clusters.shape[0]

    clustered_graph = alt.Chart(tmp).mark_bar().encode(
        x=alt.X(f'{field}:N', title=None),
        y=alt.Y('count():Q', title=None).scale(domain=[0, y_max]),
        # Was 'Tx_{th}:N' without the f-prefix, naming a non-existent field.
        color=alt.Color(f'{field}:N', scale=alt.Scale(scheme='viridis')).legend(None),
        tooltip=[f'{field}:N', 'count():Q'],
    ).properties(
        width=200,
        title="Number sequences per cluster",
    )
    unclustered_graph = alt.Chart(uc).mark_bar(color="grey").encode(
        x=alt.X(f'{field}:N', title=None),
        y=alt.Y('count():Q', title="Number of Isolates").scale(domain=[0, y_max]),
        tooltip=[f'{field}:N', 'count():Q'],
    ).properties(
        width=300 / (num_cls + 1),
    )
    return [unclustered_graph, clustered_graph]


def _distance_chart(data: pd.DataFrame, measurement: str, th: int):
    """Overlay a min-max box plot and points of one distance measurement, per cluster, at ``th``."""
    box = alt.Chart(data).mark_boxplot(extent='min-max', opacity=.3).encode(
        y=alt.Y("Distance:Q", sort=None, title=f"Pairwise Distance (threshold: {th})"),
        x=alt.X('Cluster ID:N'),
        color=alt.Color('Cluster ID').scale(scheme='viridis').legend(None),
    )
    scatter = alt.Chart(data).mark_circle(size=80).encode(
        y=alt.Y("Distance:Q", sort=None),
        x=alt.X('Cluster ID:N'),
        color=alt.Color('Cluster ID').scale(scheme='viridis').legend(None),
        tooltip=['pair', 'Distance'],
    )
    return (box + scatter).properties(title=f"{measurement}", width=500)


def _generate_cluster_graphs(
        cdf_all: pd.DataFrame,
        clusters: pd.DataFrame,
        id_col: str,
        thresholds: list
        ) -> str:
    """
    Render the per-threshold cluster summary charts as Vega-Lite JSON.

    For every threshold a row is produced holding: the unclustered bar, the
    per-cluster size bars, and the intra- and inter-cluster distance charts
    (independent y scales). The rows are stacked vertically.

    Args:
        cdf_all (pd.DataFrame): Distances with ``Distance``, ``pair``,
            ``Cluster ID``, ``SNP Threshold`` and ``Measurement`` columns.
        clusters (pd.DataFrame): Cluster table with ``Tx:<n>`` columns.
        id_col (str): Column of ``clusters`` holding isolate IDs.
        thresholds (list): Thresholds to chart, in display order.

    Returns:
        str: The combined chart serialised with ``Chart.to_json()``.
    """
    charts = []
    for th in thresholds:
        graphs = _cluster_size_charts(clusters, id_col, th)
        at_threshold = cdf_all[cdf_all["SNP Threshold"] == th]
        for m in ["Intra-cluster distance", "Inter-cluster distance"]:
            graphs.append(_distance_chart(at_threshold[at_threshold["Measurement"] == m], m, th))
        graph = alt.hconcat(*graphs).resolve_scale(
            y='independent'
        ).properties(
            title=alt.Title(f"SNP threshold {th}", anchor='start', fontSize=20, dy=-10, baseline='middle')
        )
        charts.append(graph)

    final_chart = alt.vconcat(*charts).configure_axis(
        grid=False
    ).configure_view(
        stroke=None
    )
    return final_chart.to_json()
|
|
295
|
+
|
|
296
|
+
def get_cluster_graphs(
        clusters: str,
        distances: str
        ) -> dict:
    """
    Build the cluster summary charts for the report.

    Args:
        clusters (str): Path to the cluster-assignment table.
        distances (str): Distances input passed to ``_get_distances``.

    Returns:
        The Vega-Lite JSON string from ``_generate_cluster_graphs``, or an
        empty dict when either input is empty or when computing the
        statistics/charts raises (the error is printed).
    """
    distances_df = _get_distances(distances)
    cluster_df = _get_cluster_table(clusters)
    # Mirrors get_cluster_table; an empty cluster frame previously crashed at
    # columns[0] outside the try block.
    if cluster_df.empty or distances_df.empty:
        return {}

    thresholds = _get_thresholds(cluster_df)
    try:
        cluster_df = _combine_cluster_ids(cluster_df)
        id_col = cluster_df.columns[0]
        cdf_all = _cluster_statistics(
            cluster_df=cluster_df, distances_df=distances_df,
            thresholds=thresholds, id_col=id_col,
        )
        return _generate_cluster_graphs(
            cdf_all=cdf_all, clusters=cluster_df,
            id_col=id_col, thresholds=thresholds,
        )
    except Exception as e:
        # Best effort: the report is still produced without the charts.
        print(e)
        return {}
|
|
313
|
+
|
|
314
|
+
# <button class="btn btn-sm btn-outline-secondary" style= "margin:2px;" id="information-button" data-bs-toggle="modal" data-bs-target="#myModal"><i class="bi bi-info-circle" style = "font-size: 1.2rem;"></i> Info</button>
|
|
315
|
+
|