risk-network 0.0.8b27__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +2 -2
- risk/annotations/__init__.py +2 -2
- risk/annotations/annotations.py +195 -118
- risk/annotations/io.py +47 -31
- risk/log/__init__.py +4 -2
- risk/log/console.py +3 -1
- risk/log/{params.py → parameters.py} +17 -42
- risk/neighborhoods/__init__.py +3 -5
- risk/neighborhoods/api.py +442 -0
- risk/neighborhoods/community.py +324 -101
- risk/neighborhoods/domains.py +125 -52
- risk/neighborhoods/neighborhoods.py +177 -165
- risk/network/__init__.py +1 -3
- risk/network/geometry.py +71 -89
- risk/network/graph/__init__.py +6 -0
- risk/network/graph/api.py +200 -0
- risk/network/{graph.py → graph/graph.py} +90 -40
- risk/network/graph/summary.py +254 -0
- risk/network/io.py +103 -114
- risk/network/plotter/__init__.py +6 -0
- risk/network/plotter/api.py +54 -0
- risk/network/{plot → plotter}/canvas.py +9 -8
- risk/network/{plot → plotter}/contour.py +27 -24
- risk/network/{plot → plotter}/labels.py +73 -78
- risk/network/{plot → plotter}/network.py +45 -39
- risk/network/{plot → plotter}/plotter.py +23 -17
- risk/network/{plot/utils/color.py → plotter/utils/colors.py} +114 -122
- risk/network/{plot → plotter}/utils/layout.py +10 -7
- risk/risk.py +11 -500
- risk/stats/__init__.py +10 -4
- risk/stats/permutation/__init__.py +1 -1
- risk/stats/permutation/permutation.py +44 -38
- risk/stats/permutation/test_functions.py +26 -18
- risk/stats/{stats.py → significance.py} +17 -15
- risk/stats/stat_tests.py +267 -0
- {risk_network-0.0.8b27.dist-info → risk_network-0.0.9.dist-info}/METADATA +31 -46
- risk_network-0.0.9.dist-info/RECORD +40 -0
- {risk_network-0.0.8b27.dist-info → risk_network-0.0.9.dist-info}/WHEEL +1 -1
- risk/constants.py +0 -31
- risk/network/plot/__init__.py +0 -6
- risk/stats/hypergeom.py +0 -54
- risk/stats/poisson.py +0 -44
- risk_network-0.0.8b27.dist-info/RECORD +0 -37
- {risk_network-0.0.8b27.dist-info → risk_network-0.0.9.dist-info}/LICENSE +0 -0
- {risk_network-0.0.8b27.dist-info → risk_network-0.0.9.dist-info}/top_level.txt +0 -0
risk/__init__.py
CHANGED
risk/annotations/__init__.py
CHANGED
@@ -3,5 +3,5 @@ risk/annotations
 ~~~~~~~~~~~~~~~~
 """

-from .annotations import define_top_annotations, get_weighted_description
-from .io import AnnotationsIO
+from risk.annotations.annotations import define_top_annotations, get_weighted_description
+from risk.annotations.io import AnnotationsIO
risk/annotations/annotations.py
CHANGED
@@ -3,7 +3,9 @@ risk/annotations/annotations
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 """

+import os
 import re
+import zipfile
 from collections import Counter
 from itertools import compress
 from typing import Any, Dict, List, Set
@@ -12,140 +14,197 @@ import networkx as nx
 import nltk
 import numpy as np
 import pandas as pd
-from nltk.tokenize import word_tokenize
 from nltk.corpus import stopwords
+from nltk.stem import WordNetLemmatizer
+from nltk.tokenize import word_tokenize

+from risk.log import logger
+from scipy.sparse import coo_matrix

-
-
+
+def ensure_nltk_resource(resource: str) -> None:
+    """Ensure the specified NLTK resource is available."""
+    # Define the path to the resource within the NLTK data directory
+    resource_path = f"corpora/{resource}"
+    # Check if the resource is already available.
     try:
-        nltk.data.find(
+        nltk.data.find(resource_path)
+        return
     except LookupError:
-
+        print(f"Resource '{resource}' not found. Attempting to download...")

+    # Download the resource.
+    nltk.download(resource)
+    # Check again after downloading.
     try:
-        nltk.data.find(
+        nltk.data.find(resource_path)
+        return
     except LookupError:
-
+        print(f"Resource '{resource}' still not found after download. Checking for a ZIP file...")
+
+    # Look for a ZIP file in all known NLTK data directories.
+    for data_path in nltk.data.path:
+        zip_path = os.path.join(data_path, "corpora", f"{resource}.zip")
+        if os.path.isfile(zip_path):
+            print(f"Found ZIP file for '{resource}' at: {zip_path}")
+            target_dir = os.path.join(data_path, "corpora")
+            with zipfile.ZipFile(zip_path, "r") as z:
+                z.extractall(path=target_dir)
+            print(f"Unzipped '{resource}' successfully.")
+            break  # Stop after unzipping the first found ZIP.
+
+    # Final check: Try to load the resource one last time.
+    try:
+        nltk.data.find(resource_path)
+        print(f"Resource '{resource}' is now available.")
+    except LookupError:
+        raise LookupError(f"Resource '{resource}' could not be found, downloaded, or unzipped.")


-# Ensure
-
-
-
+# Ensure the NLTK stopwords and WordNet resources are available
+ensure_nltk_resource("stopwords")
+ensure_nltk_resource("wordnet")
+# Use NLTK's stopwords - load all languages
+STOP_WORDS = set(word for lang in stopwords.fileids() for word in stopwords.words(lang))
+# Initialize the WordNet lemmatizer, which is used for normalizing words
+LEMMATIZER = WordNetLemmatizer()


-def load_annotations(
-
+def load_annotations(
+    network: nx.Graph, annotations_input: Dict[str, Any], min_nodes_per_term: int = 2
+) -> Dict[str, Any]:
+    """Convert annotations input to a sparse matrix and reindex based on the network's node labels.

     Args:
         network (nx.Graph): The network graph.
         annotations_input (Dict[str, Any]): A dictionary with annotations.
+        min_nodes_per_term (int, optional): The minimum number of network nodes required for each annotation
+            term to be included. Defaults to 2.

     Returns:
-        Dict[str, Any]: A dictionary containing ordered nodes, ordered annotations, and the binary annotations
+        Dict[str, Any]: A dictionary containing ordered nodes, ordered annotations, and the sparse binary annotations
+            matrix.
+
+    Raises:
+        ValueError: If no annotations are found for the nodes in the network.
+        ValueError: If no annotations have at least min_nodes_per_term nodes in the network.
     """
-    #
-
-
-
-    #
-
-
-
-
-
-
-
-
-
-
-
+    # Step 1: Map nodes and annotations to indices
+    node_label_order = [attr["label"] for _, attr in network.nodes(data=True) if "label" in attr]
+    node_to_idx = {node: i for i, node in enumerate(node_label_order)}
+    annotation_to_idx = {annotation: i for i, annotation in enumerate(annotations_input)}
+    # Step 2: Construct a sparse binary matrix directly
+    row = []
+    col = []
+    data = []
+    for annotation, nodes in annotations_input.items():
+        for node in nodes:
+            if node in node_to_idx and annotation in annotation_to_idx:
+                row.append(node_to_idx[node])
+                col.append(annotation_to_idx[annotation])
+                data.append(1)
+
+    # Create a sparse binary matrix
+    num_nodes = len(node_to_idx)
+    num_annotations = len(annotation_to_idx)
+    annotations_pivot = coo_matrix((data, (row, col)), shape=(num_nodes, num_annotations)).tocsr()
+    # Step 3: Filter out annotations with fewer than min_nodes_per_term occurrences
+    valid_annotations = annotations_pivot.sum(axis=0).A1 >= min_nodes_per_term
+    annotations_pivot = annotations_pivot[:, valid_annotations]
+    # Step 4: Raise errors for empty matrices
+    if annotations_pivot.nnz == 0:
+        raise ValueError("No terms found in the annotation file for the nodes in the network.")
+
+    num_remaining_annotations = annotations_pivot.shape[1]
+    if num_remaining_annotations == 0:
         raise ValueError(
-            "No
+            f"No annotation terms found with at least {min_nodes_per_term} nodes in the network."
         )

-    #
-
-
-
-
-
-
+    # Step 5: Extract ordered nodes and annotations
+    ordered_nodes = tuple(node_label_order)
+    ordered_annotations = tuple(
+        annotation for annotation, is_valid in zip(annotation_to_idx, valid_annotations) if is_valid
+    )
+
+    # Log the filtering details
+    logger.info(f"Minimum number of nodes per annotation term: {min_nodes_per_term}")
+    logger.info(f"Number of input annotation terms: {num_annotations}")
+    logger.info(f"Number of remaining annotation terms: {num_remaining_annotations}")

     return {
         "ordered_nodes": ordered_nodes,
         "ordered_annotations": ordered_annotations,
-        "matrix":
+        "matrix": annotations_pivot,
     }


 def define_top_annotations(
     network: nx.Graph,
     ordered_annotation_labels: List[str],
-
-
-
+    neighborhood_significance_sums: List[int],
+    significant_significance_matrix: np.ndarray,
+    significant_binary_significance_matrix: np.ndarray,
     min_cluster_size: int = 5,
     max_cluster_size: int = 1000,
 ) -> pd.DataFrame:
-    """Define top annotations based on neighborhood
+    """Define top annotations based on neighborhood significance sums and binary significance matrix.

     Args:
         network (NetworkX graph): The network graph.
         ordered_annotation_labels (list of str): List of ordered annotation labels.
-
-
-
+        neighborhood_significance_sums (list of int): List of neighborhood significance sums.
+        significant_significance_matrix (np.ndarray): Enrichment matrix below alpha threshold.
+        significant_binary_significance_matrix (np.ndarray): Binary significance matrix below alpha threshold.
         min_cluster_size (int, optional): Minimum cluster size. Defaults to 5.
         max_cluster_size (int, optional): Maximum cluster size. Defaults to 1000.

     Returns:
         pd.DataFrame: DataFrame with top annotations and their properties.
     """
-    # Sum the columns of the significant
-
-    # Create DataFrame to store annotations, their neighborhood
-
+    # Sum the columns of the significant significance matrix (positive floating point values)
+    significant_significance_scores = significant_significance_matrix.sum(axis=0)
+    # Create DataFrame to store annotations, their neighborhood significance sums, and significance scores
+    annotations_significance_matrix = pd.DataFrame(
         {
             "id": range(len(ordered_annotation_labels)),
             "full_terms": ordered_annotation_labels,
-            "
-            "
+            "significant_neighborhood_significance_sums": neighborhood_significance_sums,
+            "significant_significance_score": significant_significance_scores,
         }
     )
-
+    annotations_significance_matrix["significant_annotations"] = False
     # Apply size constraints to identify potential significant annotations
-
+    annotations_significance_matrix.loc[
         (
-
+            annotations_significance_matrix["significant_neighborhood_significance_sums"]
             >= min_cluster_size
         )
         & (
-
+            annotations_significance_matrix["significant_neighborhood_significance_sums"]
             <= max_cluster_size
         ),
         "significant_annotations",
     ] = True
     # Initialize columns for connected components analysis
-
-
-
+    annotations_significance_matrix["num_connected_components"] = 0
+    annotations_significance_matrix["size_connected_components"] = None
+    annotations_significance_matrix["size_connected_components"] = annotations_significance_matrix[
         "size_connected_components"
     ].astype(object)
-
+    annotations_significance_matrix["num_large_connected_components"] = 0

-    for attribute in
-
+    for attribute in annotations_significance_matrix.index.values[
+        annotations_significance_matrix["significant_annotations"]
     ]:
-        # Identify
-
-            compress(list(network),
+        # Identify significant neighborhoods based on the binary significance matrix
+        significant_neighborhoods = list(
+            compress(list(network), significant_binary_significance_matrix[:, attribute])
         )
-
-        # Analyze connected components within the
+        significant_network = nx.subgraph(network, significant_neighborhoods)
+        # Analyze connected components within the significant subnetwork
         connected_components = sorted(
-            nx.connected_components(
+            nx.connected_components(significant_network), key=len, reverse=True
         )
         size_connected_components = np.array([len(c) for c in connected_components])

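For orientation, a minimal standalone sketch of the sparse-matrix construction that the new load_annotations performs, following the diff above; the toy network, node labels, and annotation terms here are invented for illustration only:

import networkx as nx
from scipy.sparse import coo_matrix

# Hypothetical two-node network with "label" attributes, as load_annotations expects.
network = nx.Graph()
network.add_node(0, label="YAL001C")
network.add_node(1, label="YAL002W")
annotations_input = {"DNA repair": ["YAL001C", "YAL002W"], "transport": ["YAL002W"]}

node_labels = [attr["label"] for _, attr in network.nodes(data=True) if "label" in attr]
node_to_idx = {node: i for i, node in enumerate(node_labels)}
annotation_to_idx = {term: j for j, term in enumerate(annotations_input)}

# Collect (node, term) pairs as COO triplets, then convert to CSR.
row, col, data = [], [], []
for term, nodes in annotations_input.items():
    for node in nodes:
        if node in node_to_idx:
            row.append(node_to_idx[node])
            col.append(annotation_to_idx[term])
            data.append(1)

matrix = coo_matrix((data, (row, col)), shape=(len(node_to_idx), len(annotation_to_idx))).tocsr()
# Keep only terms annotating at least min_nodes_per_term nodes (2 by default).
valid = matrix.sum(axis=0).A1 >= 2
matrix = matrix[:, valid]
print(matrix.toarray())  # [[1], [1]] -- only "DNA repair" survives the filter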
@@ -159,125 +218,143 @@ def define_top_annotations(
         num_large_connected_components = len(filtered_size_connected_components)

         # Assign the number of connected components
-
+        annotations_significance_matrix.loc[attribute, "num_connected_components"] = (
             num_connected_components
         )
         # Filter out attributes with more than one connected component
-
-
+        annotations_significance_matrix.loc[
+            annotations_significance_matrix["num_connected_components"] > 1,
+            "significant_annotations",
         ] = False
         # Assign the number of large connected components
-
+        annotations_significance_matrix.loc[attribute, "num_large_connected_components"] = (
             num_large_connected_components
         )
         # Assign the size of connected components, ensuring it is always a list
-
+        annotations_significance_matrix.at[attribute, "size_connected_components"] = (
             filtered_size_connected_components.tolist()
         )

-    return
+    return annotations_significance_matrix


 def get_weighted_description(words_column: pd.Series, scores_column: pd.Series) -> str:
     """Generate a weighted description from words and their corresponding scores,
-
+    using improved weighting logic with normalization, lemmatization, and aggregation.

     Args:
-        words_column (pd.Series): A pandas Series containing strings to process.
-        scores_column (pd.Series): A pandas Series containing
+        words_column (pd.Series): A pandas Series containing strings (phrases) to process.
+        scores_column (pd.Series): A pandas Series containing significance scores to weigh the terms.

     Returns:
-        str: A coherent description formed from the most frequent and significant words
+        str: A coherent description formed from the most frequent and significant words.
     """
-    #
+    # Normalize significance scores to [0,1]. If all scores are identical, use 1.
     if scores_column.max() == scores_column.min():
-        normalized_scores = pd.Series([1] * len(scores_column))
+        normalized_scores = pd.Series([1] * len(scores_column), index=scores_column.index)
     else:
-        # Normalize the enrichment scores to be between 0 and 1
         normalized_scores = (scores_column - scores_column.min()) / (
             scores_column.max() - scores_column.min()
         )

-    #
+    # Accumulate weighted counts for each token (after cleaning and lemmatization)
+    weighted_counts = {}
+    for phrase, score in zip(words_column, normalized_scores):
+        # Tokenize the phrase
+        tokens = word_tokenize(str(phrase))
+        # Determine the weight (scale factor; here multiplying normalized score by 10)
+        weight = max(1, int((0 if pd.isna(score) else score) * 10))
+        for token in tokens:
+            # Clean token: lowercase and remove extraneous punctuation (but preserve intra-word hyphens)
+            token_clean = re.sub(r"[^\w\-]", "", token).strip()
+            if not token_clean:
+                continue
+            # Skip tokens that are pure numbers
+            if token_clean.isdigit():
+                continue
+            # Skip stopwords
+            if token_clean in STOP_WORDS:
+                continue
+            # Lemmatize the token to merge similar forms
+            token_norm = LEMMATIZER.lemmatize(token_clean)
+            weighted_counts[token_norm] = weighted_counts.get(token_norm, 0) + weight
+
+    # Reconstruct a weighted token list by repeating each token by its aggregated count.
     weighted_words = []
-    for
-
-
-
-        weighted_words.extend([word] * weight)
-
-    # Tokenize the weighted words, but preserve number-word patterns like '4-alpha'
-    tokens = word_tokenize(" ".join(weighted_words))
-    # Ensure we treat "4-alpha" or other "number-word" patterns as single tokens
+    for token, count in weighted_counts.items():
+        weighted_words.extend([token] * count)
+
+    # Combine tokens that match number-word patterns (e.g. "4-alpha") and remove pure numeric tokens.
     combined_tokens = []
-    for token in
-        # Match patterns like '4-alpha' or '5-hydroxy' and keep them together
+    for token in weighted_words:
         if re.match(r"^\d+-\w+", token):
             combined_tokens.append(token)
-        elif token.replace(".", "", 1).isdigit():
-            # Ignore pure numbers as descriptions unless necessary
+        elif token.replace(".", "", 1).isdigit():
             continue
         else:
             combined_tokens.append(token)

-    #
+    # If the only token is numeric, return a default value.
     if len(combined_tokens) == 1 and combined_tokens[0].isdigit():
-        return "N/A"
+        return "N/A"

-    # Simplify the
+    # Simplify the token list to remove near-duplicates based on the Jaccard index.
     simplified_words = _simplify_word_list(combined_tokens)
+    # Generate a coherent description from the simplified words.
     description = _generate_coherent_description(simplified_words)

     return description


 def _simplify_word_list(words: List[str], threshold: float = 0.80) -> List[str]:
-    """Filter out words that are too similar based on the Jaccard index,
+    """Filter out words that are too similar based on the Jaccard index,
+    keeping the word with the higher aggregated count.

     Args:
-        words (
+        words (List[str]): The list of tokens to be filtered.
         threshold (float, optional): The similarity threshold for the Jaccard index. Defaults to 0.80.

     Returns:
-
+        List[str]: A list of filtered words, where similar words are reduced to the most frequent one.
     """
-    # Count the occurrences
+    # Count the occurrences (which reflect the weighted importance)
     word_counts = Counter(words)
     filtered_words = []
     used_words = set()
-
-
+
+    # Iterate through words sorted by descending weighted frequency
+    for word in sorted(word_counts, key=lambda w: word_counts[w], reverse=True):
         if word in used_words:
             continue

         word_set = set(word)
-        # Find similar words based on the Jaccard index
+        # Find similar words (including the current word) based on the Jaccard index
         similar_words = [
             other_word
             for other_word in word_counts
             if _calculate_jaccard_index(word_set, set(other_word)) >= threshold
         ]
-        #
+        # Choose the word with the highest weighted count among the similar group
         similar_words.sort(key=lambda w: word_counts[w], reverse=True)
         best_word = similar_words[0]
         filtered_words.append(best_word)
         used_words.update(similar_words)

+    # Preserve the original order (by frequency) from the filtered set
     final_words = [word for word in words if word in filtered_words]

     return final_words


 def _calculate_jaccard_index(set1: Set[Any], set2: Set[Any]) -> float:
-    """Calculate the Jaccard
+    """Calculate the Jaccard index between two sets.

     Args:
-        set1 (
-        set2 (
+        set1 (Set[Any]): The first set.
+        set2 (Set[Any]): The second set.

     Returns:
-        float: The Jaccard
-            Returns 0 if the union of the sets is empty.
+        float: The Jaccard index (intersection over union). Returns 0 if the union is empty.
     """
     intersection = len(set1.intersection(set2))
     union = len(set1.union(set2))
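To make the Jaccard-based de-duplication in _simplify_word_list concrete, here is a small self-contained sketch; it re-implements the logic outside the package, so the helper names and example tokens are illustrative, not part of the library:

from collections import Counter

def jaccard(set1, set2):
    # Intersection over union of the two character sets; 0 if the union is empty.
    union = set1 | set2
    return len(set1 & set2) / len(union) if union else 0.0

def simplify(words, threshold=0.80):
    # Keep the most frequent token from each group of near-duplicates,
    # mirroring _simplify_word_list in the diff above.
    counts = Counter(words)
    kept, used = [], set()
    for word in sorted(counts, key=counts.get, reverse=True):
        if word in used:
            continue
        group = [w for w in counts if jaccard(set(word), set(w)) >= threshold]
        group.sort(key=counts.get, reverse=True)
        kept.append(group[0])
        used.update(group)
    return [w for w in words if w in kept]

tokens = ["transport", "transporter", "transport", "binding"]
print(simplify(tokens))  # ['transport', 'transport', 'binding'] -- "transporter" merges into "transport"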
@@ -285,28 +362,28 @@ def _calculate_jaccard_index(set1: Set[Any], set2: Set[Any]) -> float:


 def _generate_coherent_description(words: List[str]) -> str:
-    """Generate a coherent description from a list of words
+    """Generate a coherent description from a list of words.
+
     If there is only one unique entry, return it directly.
+    Otherwise, order the words by frequency and join them into a single string.

     Args:
-        words (List): A list of
+        words (List[str]): A list of tokens.

     Returns:
-        str: A coherent description
+        str: A coherent, space-separated description.
     """
-    # If there are no words, return a keyword indicating no data is available
     if not words:
         return "N/A"

-    # If there
+    # If there is only one unique word, return it directly
     unique_words = set(words)
     if len(unique_words) == 1:
         return list(unique_words)[0]

-    # Count
+    # Count weighted occurrences and sort in descending order.
     word_counts = Counter(words)
     most_common_words = [word for word, _ in word_counts.most_common()]
-    # Join the most common words to form a coherent description based on frequency
     description = " ".join(most_common_words)

     return description
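Putting the pieces together, a hedged end-to-end usage sketch of the public get_weighted_description helper exported by risk.annotations (per the __init__.py diff above). It assumes risk-network 0.0.9 is installed and that NLTK can fetch its stopwords, WordNet, and tokenizer data on first import; the phrases and scores are made up:

import pandas as pd
from risk.annotations import get_weighted_description

# Hypothetical cluster terms and their significance scores.
words = pd.Series(["DNA repair complex", "DNA damage repair", "nucleotide excision repair"])
scores = pd.Series([12.0, 8.5, 3.2])

# Higher-scoring phrases contribute more weight to the tokens in the final description.
print(get_weighted_description(words, scores))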
|