risk-network 0.0.7b11__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +1 -1
- risk/annotations/__init__.py +1 -1
- risk/annotations/annotations.py +86 -54
- risk/annotations/io.py +14 -14
- risk/log/__init__.py +1 -1
- risk/log/console.py +139 -0
- risk/log/params.py +6 -6
- risk/neighborhoods/community.py +68 -61
- risk/neighborhoods/domains.py +43 -20
- risk/neighborhoods/neighborhoods.py +136 -71
- risk/network/geometry.py +5 -2
- risk/network/graph.py +69 -235
- risk/network/io.py +56 -18
- risk/network/plot/__init__.py +6 -0
- risk/network/plot/canvas.py +290 -0
- risk/network/plot/contour.py +327 -0
- risk/network/plot/labels.py +929 -0
- risk/network/plot/network.py +288 -0
- risk/network/plot/plotter.py +137 -0
- risk/network/plot/utils/color.py +424 -0
- risk/network/plot/utils/layout.py +91 -0
- risk/risk.py +84 -58
- risk/stats/hypergeom.py +1 -1
- risk/stats/permutation/permutation.py +21 -8
- risk/stats/poisson.py +2 -2
- risk/stats/stats.py +12 -10
- {risk_network-0.0.7b11.dist-info → risk_network-0.0.8.dist-info}/METADATA +84 -21
- risk_network-0.0.8.dist-info/RECORD +37 -0
- {risk_network-0.0.7b11.dist-info → risk_network-0.0.8.dist-info}/WHEEL +1 -1
- risk/log/config.py +0 -48
- risk/network/plot.py +0 -1343
- risk_network-0.0.7b11.dist-info/RECORD +0 -30
- {risk_network-0.0.7b11.dist-info → risk_network-0.0.8.dist-info}/LICENSE +0 -0
- {risk_network-0.0.7b11.dist-info → risk_network-0.0.8.dist-info}/top_level.txt +0 -0
risk/__init__.py
CHANGED
risk/annotations/__init__.py
CHANGED
risk/annotations/annotations.py
CHANGED
@@ -3,8 +3,9 @@ risk/annotations/annotations
|
|
3
3
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
4
4
|
"""
|
5
5
|
|
6
|
+
import re
|
6
7
|
from collections import Counter
|
7
|
-
from itertools import compress
|
8
|
+
from itertools import compress
|
8
9
|
from typing import Any, Dict, List, Set
|
9
10
|
|
10
11
|
import networkx as nx
|
@@ -30,27 +31,30 @@ def _setup_nltk():
|
|
30
31
|
|
31
32
|
# Ensure you have the necessary NLTK data
|
32
33
|
_setup_nltk()
|
34
|
+
# Initialize English stopwords
|
35
|
+
stop_words = set(stopwords.words("english"))
|
33
36
|
|
34
37
|
|
35
38
|
def load_annotations(network: nx.Graph, annotations_input: Dict[str, Any]) -> Dict[str, Any]:
|
36
39
|
"""Convert annotations input to a DataFrame and reindex based on the network's node labels.
|
37
40
|
|
38
41
|
Args:
|
39
|
-
|
42
|
+
network (nx.Graph): The network graph.
|
43
|
+
annotations_input (Dict[str, Any]): A dictionary with annotations.
|
40
44
|
|
41
45
|
Returns:
|
42
|
-
|
46
|
+
Dict[str, Any]: A dictionary containing ordered nodes, ordered annotations, and the binary annotations matrix.
|
43
47
|
"""
|
44
48
|
# Flatten the dictionary to a list of tuples for easier DataFrame creation
|
45
49
|
flattened_annotations = [
|
46
50
|
(node, annotation) for annotation, nodes in annotations_input.items() for node in nodes
|
47
51
|
]
|
48
52
|
# Create a DataFrame from the flattened list
|
49
|
-
annotations = pd.DataFrame(flattened_annotations, columns=["
|
50
|
-
annotations["
|
53
|
+
annotations = pd.DataFrame(flattened_annotations, columns=["node", "annotations"])
|
54
|
+
annotations["is_member"] = 1
|
51
55
|
# Pivot to create a binary matrix with nodes as rows and annotations as columns
|
52
56
|
annotations_pivot = annotations.pivot_table(
|
53
|
-
index="
|
57
|
+
index="node", columns="annotations", values="is_member", fill_value=0, dropna=False
|
54
58
|
)
|
55
59
|
# Reindex the annotations matrix based on the node labels from the network
|
56
60
|
node_label_order = list(nx.get_node_attributes(network, "label").values())
|
@@ -80,7 +84,8 @@ def define_top_annotations(
|
|
80
84
|
network: nx.Graph,
|
81
85
|
ordered_annotation_labels: List[str],
|
82
86
|
neighborhood_enrichment_sums: List[int],
|
83
|
-
|
87
|
+
significant_enrichment_matrix: np.ndarray,
|
88
|
+
significant_binary_enrichment_matrix: np.ndarray,
|
84
89
|
min_cluster_size: int = 5,
|
85
90
|
max_cluster_size: int = 1000,
|
86
91
|
) -> pd.DataFrame:
|
@@ -90,42 +95,52 @@ def define_top_annotations(
|
|
90
95
|
network (NetworkX graph): The network graph.
|
91
96
|
ordered_annotation_labels (list of str): List of ordered annotation labels.
|
92
97
|
neighborhood_enrichment_sums (list of int): List of neighborhood enrichment sums.
|
93
|
-
|
98
|
+
significant_enrichment_matrix (np.ndarray): Enrichment matrix below alpha threshold.
|
99
|
+
significant_binary_enrichment_matrix (np.ndarray): Binary enrichment matrix below alpha threshold.
|
94
100
|
min_cluster_size (int, optional): Minimum cluster size. Defaults to 5.
|
95
101
|
max_cluster_size (int, optional): Maximum cluster size. Defaults to 1000.
|
96
102
|
|
97
103
|
Returns:
|
98
104
|
pd.DataFrame: DataFrame with top annotations and their properties.
|
99
105
|
"""
|
100
|
-
#
|
106
|
+
# Sum the columns of the significant enrichment matrix (positive floating point values)
|
107
|
+
significant_enrichment_scores = significant_enrichment_matrix.sum(axis=0)
|
108
|
+
# Create DataFrame to store annotations, their neighborhood enrichment sums, and enrichment scores
|
101
109
|
annotations_enrichment_matrix = pd.DataFrame(
|
102
110
|
{
|
103
111
|
"id": range(len(ordered_annotation_labels)),
|
104
|
-
"
|
105
|
-
"
|
112
|
+
"full_terms": ordered_annotation_labels,
|
113
|
+
"significant_neighborhood_enrichment_sums": neighborhood_enrichment_sums,
|
114
|
+
"significant_enrichment_score": significant_enrichment_scores,
|
106
115
|
}
|
107
116
|
)
|
108
|
-
annotations_enrichment_matrix["
|
109
|
-
# Apply size constraints to identify potential
|
117
|
+
annotations_enrichment_matrix["significant_annotations"] = False
|
118
|
+
# Apply size constraints to identify potential significant annotations
|
110
119
|
annotations_enrichment_matrix.loc[
|
111
|
-
(
|
112
|
-
|
113
|
-
|
120
|
+
(
|
121
|
+
annotations_enrichment_matrix["significant_neighborhood_enrichment_sums"]
|
122
|
+
>= min_cluster_size
|
123
|
+
)
|
124
|
+
& (
|
125
|
+
annotations_enrichment_matrix["significant_neighborhood_enrichment_sums"]
|
126
|
+
<= max_cluster_size
|
127
|
+
),
|
128
|
+
"significant_annotations",
|
114
129
|
] = True
|
115
130
|
# Initialize columns for connected components analysis
|
116
|
-
annotations_enrichment_matrix["
|
117
|
-
annotations_enrichment_matrix["
|
118
|
-
annotations_enrichment_matrix["
|
119
|
-
"
|
131
|
+
annotations_enrichment_matrix["num_connected_components"] = 0
|
132
|
+
annotations_enrichment_matrix["size_connected_components"] = None
|
133
|
+
annotations_enrichment_matrix["size_connected_components"] = annotations_enrichment_matrix[
|
134
|
+
"size_connected_components"
|
120
135
|
].astype(object)
|
121
|
-
annotations_enrichment_matrix["
|
136
|
+
annotations_enrichment_matrix["num_large_connected_components"] = 0
|
122
137
|
|
123
138
|
for attribute in annotations_enrichment_matrix.index.values[
|
124
|
-
annotations_enrichment_matrix["
|
139
|
+
annotations_enrichment_matrix["significant_annotations"]
|
125
140
|
]:
|
126
141
|
# Identify enriched neighborhoods based on the binary enrichment matrix
|
127
142
|
enriched_neighborhoods = list(
|
128
|
-
compress(list(network),
|
143
|
+
compress(list(network), significant_binary_enrichment_matrix[:, attribute])
|
129
144
|
)
|
130
145
|
enriched_network = nx.subgraph(network, enriched_neighborhoods)
|
131
146
|
# Analyze connected components within the enriched subnetwork
|
@@ -144,57 +159,74 @@ def define_top_annotations(
|
|
144
159
|
num_large_connected_components = len(filtered_size_connected_components)
|
145
160
|
|
146
161
|
# Assign the number of connected components
|
147
|
-
annotations_enrichment_matrix.loc[attribute, "
|
162
|
+
annotations_enrichment_matrix.loc[attribute, "num_connected_components"] = (
|
148
163
|
num_connected_components
|
149
164
|
)
|
150
165
|
# Filter out attributes with more than one connected component
|
151
166
|
annotations_enrichment_matrix.loc[
|
152
|
-
annotations_enrichment_matrix["
|
167
|
+
annotations_enrichment_matrix["num_connected_components"] > 1, "significant_annotations"
|
153
168
|
] = False
|
154
169
|
# Assign the number of large connected components
|
155
|
-
annotations_enrichment_matrix.loc[attribute, "
|
170
|
+
annotations_enrichment_matrix.loc[attribute, "num_large_connected_components"] = (
|
156
171
|
num_large_connected_components
|
157
172
|
)
|
158
173
|
# Assign the size of connected components, ensuring it is always a list
|
159
|
-
annotations_enrichment_matrix.at[attribute, "
|
174
|
+
annotations_enrichment_matrix.at[attribute, "size_connected_components"] = (
|
160
175
|
filtered_size_connected_components.tolist()
|
161
176
|
)
|
162
177
|
|
163
178
|
return annotations_enrichment_matrix
|
164
179
|
|
165
180
|
|
166
|
-
def
|
167
|
-
"""
|
168
|
-
|
181
|
+
def get_weighted_description(words_column: pd.Series, scores_column: pd.Series) -> str:
|
182
|
+
"""Generate a weighted description from words and their corresponding scores,
|
183
|
+
with support for stopwords filtering and improved weighting logic.
|
169
184
|
|
170
185
|
Args:
|
171
186
|
words_column (pd.Series): A pandas Series containing strings to process.
|
187
|
+
scores_column (pd.Series): A pandas Series containing enrichment scores to weigh the terms.
|
172
188
|
|
173
189
|
Returns:
|
174
|
-
str: A coherent description formed from the most frequent and significant words.
|
190
|
+
str: A coherent description formed from the most frequent and significant words, weighed by enrichment scores.
|
175
191
|
"""
|
176
|
-
#
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
192
|
+
# Handle case where all scores are the same
|
193
|
+
if scores_column.max() == scores_column.min():
|
194
|
+
normalized_scores = pd.Series([1] * len(scores_column))
|
195
|
+
else:
|
196
|
+
# Normalize the enrichment scores to be between 0 and 1
|
197
|
+
normalized_scores = (scores_column - scores_column.min()) / (
|
198
|
+
scores_column.max() - scores_column.min()
|
199
|
+
)
|
200
|
+
|
201
|
+
# Combine words and normalized scores to create weighted words
|
202
|
+
weighted_words = []
|
203
|
+
for word, score in zip(words_column, normalized_scores):
|
204
|
+
word = str(word)
|
205
|
+
if word not in stop_words: # Skip stopwords
|
206
|
+
weight = max(1, int((0 if pd.isna(score) else score) * 10))
|
207
|
+
weighted_words.extend([word] * weight)
|
208
|
+
|
209
|
+
# Tokenize the weighted words, but preserve number-word patterns like '4-alpha'
|
210
|
+
tokens = word_tokenize(" ".join(weighted_words))
|
211
|
+
# Ensure we treat "4-alpha" or other "number-word" patterns as single tokens
|
212
|
+
combined_tokens = []
|
213
|
+
for token in tokens:
|
214
|
+
# Match patterns like '4-alpha' or '5-hydroxy' and keep them together
|
215
|
+
if re.match(r"^\d+-\w+", token):
|
216
|
+
combined_tokens.append(token)
|
217
|
+
elif token.replace(".", "", 1).isdigit(): # Handle pure numeric tokens
|
218
|
+
# Ignore pure numbers as descriptions unless necessary
|
219
|
+
continue
|
220
|
+
else:
|
221
|
+
combined_tokens.append(token)
|
222
|
+
|
223
|
+
# Prevent descriptions like just '4' from being selected
|
224
|
+
if len(combined_tokens) == 1 and combined_tokens[0].isdigit():
|
225
|
+
return "N/A" # Return "N/A" for cases where it's just a number
|
226
|
+
|
227
|
+
# Simplify the word list and generate the description
|
228
|
+
simplified_words = _simplify_word_list(combined_tokens)
|
229
|
+
description = _generate_coherent_description(simplified_words)
|
198
230
|
|
199
231
|
return description
|
200
232
|
|
@@ -257,7 +289,7 @@ def _generate_coherent_description(words: List[str]) -> str:
|
|
257
289
|
If there is only one unique entry, return it directly.
|
258
290
|
|
259
291
|
Args:
|
260
|
-
words (
|
292
|
+
words (List): A list of words or numerical string values.
|
261
293
|
|
262
294
|
Returns:
|
263
295
|
str: A coherent description formed by arranging the words in a logical sequence.
|
risk/annotations/io.py
CHANGED
@@ -25,15 +25,15 @@ class AnnotationsIO:
|
|
25
25
|
def __init__(self):
|
26
26
|
pass
|
27
27
|
|
28
|
-
def load_json_annotation(self,
|
28
|
+
def load_json_annotation(self, network: nx.Graph, filepath: str) -> Dict[str, Any]:
|
29
29
|
"""Load annotations from a JSON file and convert them to a DataFrame.
|
30
30
|
|
31
31
|
Args:
|
32
|
-
filepath (str): Path to the JSON annotations file.
|
33
32
|
network (NetworkX graph): The network to which the annotations are related.
|
33
|
+
filepath (str): Path to the JSON annotations file.
|
34
34
|
|
35
35
|
Returns:
|
36
|
-
|
36
|
+
Dict[str, Any]: A dictionary containing ordered nodes, ordered annotations, and the annotations matrix.
|
37
37
|
"""
|
38
38
|
filetype = "JSON"
|
39
39
|
# Log the loading of the JSON file
|
@@ -49,8 +49,8 @@ class AnnotationsIO:
|
|
49
49
|
|
50
50
|
def load_excel_annotation(
|
51
51
|
self,
|
52
|
-
filepath: str,
|
53
52
|
network: nx.Graph,
|
53
|
+
filepath: str,
|
54
54
|
label_colname: str = "label",
|
55
55
|
nodes_colname: str = "nodes",
|
56
56
|
sheet_name: str = "Sheet1",
|
@@ -59,8 +59,8 @@ class AnnotationsIO:
|
|
59
59
|
"""Load annotations from an Excel file and associate them with the network.
|
60
60
|
|
61
61
|
Args:
|
62
|
-
filepath (str): Path to the Excel annotations file.
|
63
62
|
network (nx.Graph): The NetworkX graph to which the annotations are related.
|
63
|
+
filepath (str): Path to the Excel annotations file.
|
64
64
|
label_colname (str): Name of the column containing the labels (e.g., GO terms).
|
65
65
|
nodes_colname (str): Name of the column containing the nodes associated with each label.
|
66
66
|
sheet_name (str, optional): The name of the Excel sheet to load (default is 'Sheet1').
|
@@ -87,8 +87,8 @@ class AnnotationsIO:
|
|
87
87
|
|
88
88
|
def load_csv_annotation(
|
89
89
|
self,
|
90
|
-
filepath: str,
|
91
90
|
network: nx.Graph,
|
91
|
+
filepath: str,
|
92
92
|
label_colname: str = "label",
|
93
93
|
nodes_colname: str = "nodes",
|
94
94
|
nodes_delimiter: str = ";",
|
@@ -96,8 +96,8 @@ class AnnotationsIO:
|
|
96
96
|
"""Load annotations from a CSV file and associate them with the network.
|
97
97
|
|
98
98
|
Args:
|
99
|
-
filepath (str): Path to the CSV annotations file.
|
100
99
|
network (nx.Graph): The NetworkX graph to which the annotations are related.
|
100
|
+
filepath (str): Path to the CSV annotations file.
|
101
101
|
label_colname (str): Name of the column containing the labels (e.g., GO terms).
|
102
102
|
nodes_colname (str): Name of the column containing the nodes associated with each label.
|
103
103
|
nodes_delimiter (str, optional): Delimiter used to separate multiple nodes within the nodes column (default is ';').
|
@@ -121,8 +121,8 @@ class AnnotationsIO:
|
|
121
121
|
|
122
122
|
def load_tsv_annotation(
|
123
123
|
self,
|
124
|
-
filepath: str,
|
125
124
|
network: nx.Graph,
|
125
|
+
filepath: str,
|
126
126
|
label_colname: str = "label",
|
127
127
|
nodes_colname: str = "nodes",
|
128
128
|
nodes_delimiter: str = ";",
|
@@ -130,8 +130,8 @@ class AnnotationsIO:
|
|
130
130
|
"""Load annotations from a TSV file and associate them with the network.
|
131
131
|
|
132
132
|
Args:
|
133
|
-
filepath (str): Path to the TSV annotations file.
|
134
133
|
network (nx.Graph): The NetworkX graph to which the annotations are related.
|
134
|
+
filepath (str): Path to the TSV annotations file.
|
135
135
|
label_colname (str): Name of the column containing the labels (e.g., GO terms).
|
136
136
|
nodes_colname (str): Name of the column containing the nodes associated with each label.
|
137
137
|
nodes_delimiter (str, optional): Delimiter used to separate multiple nodes within the nodes column (default is ';').
|
@@ -153,15 +153,15 @@ class AnnotationsIO:
|
|
153
153
|
# Load the annotations into the provided network
|
154
154
|
return load_annotations(network, annotations_input)
|
155
155
|
|
156
|
-
def load_dict_annotation(self, content: Dict[str, Any]
|
156
|
+
def load_dict_annotation(self, network: nx.Graph, content: Dict[str, Any]) -> Dict[str, Any]:
|
157
157
|
"""Load annotations from a provided dictionary and convert them to a dictionary annotation.
|
158
158
|
|
159
159
|
Args:
|
160
|
-
content (dict): The annotations dictionary to load.
|
161
160
|
network (NetworkX graph): The network to which the annotations are related.
|
161
|
+
content (Dict[str, Any]): The annotations dictionary to load.
|
162
162
|
|
163
163
|
Returns:
|
164
|
-
|
164
|
+
Dict[str, Any]: A dictionary containing ordered nodes, ordered annotations, and the annotations matrix.
|
165
165
|
"""
|
166
166
|
# Ensure the input content is a dictionary
|
167
167
|
if not isinstance(content, dict):
|
@@ -219,6 +219,6 @@ def _log_loading(filetype: str, filepath: str = "") -> None:
|
|
219
219
|
filepath (str, optional): The path to the file being loaded.
|
220
220
|
"""
|
221
221
|
log_header("Loading annotations")
|
222
|
-
logger.
|
222
|
+
logger.debug(f"Filetype: {filetype}")
|
223
223
|
if filepath:
|
224
|
-
logger.
|
224
|
+
logger.debug(f"Filepath: {filepath}")
|
risk/log/__init__.py
CHANGED
risk/log/console.py
ADDED
@@ -0,0 +1,139 @@
|
|
1
|
+
"""
|
2
|
+
risk/log/console
|
3
|
+
~~~~~~~~~~~~~~~~
|
4
|
+
"""
|
5
|
+
|
6
|
+
import logging
|
7
|
+
|
8
|
+
|
9
|
+
def in_jupyter():
|
10
|
+
"""Check if the code is running in a Jupyter notebook environment.
|
11
|
+
|
12
|
+
Returns:
|
13
|
+
bool: True if running in a Jupyter notebook or QtConsole, False otherwise.
|
14
|
+
"""
|
15
|
+
try:
|
16
|
+
shell = get_ipython().__class__.__name__
|
17
|
+
if shell == "ZMQInteractiveShell": # Jupyter Notebook or QtConsole
|
18
|
+
return True
|
19
|
+
elif shell == "TerminalInteractiveShell": # Terminal running IPython
|
20
|
+
return False
|
21
|
+
except NameError:
|
22
|
+
return False # Not in Jupyter
|
23
|
+
|
24
|
+
|
25
|
+
# Define the MockLogger class to replicate logging behavior with print statements in Jupyter
|
26
|
+
class MockLogger:
|
27
|
+
"""MockLogger: A lightweight logger replacement using print statements in Jupyter.
|
28
|
+
|
29
|
+
The MockLogger class replicates the behavior of a standard logger using print statements
|
30
|
+
to display messages. This is primarily used in a Jupyter environment to show outputs
|
31
|
+
directly in the notebook. The class supports logging levels such as `info`, `debug`,
|
32
|
+
`warning`, and `error`, while the `verbose` attribute controls whether `info` and `debug` messages are displayed; warnings and errors are always printed.
|
33
|
+
"""
|
34
|
+
|
35
|
+
def __init__(self, verbose: bool = True):
|
36
|
+
"""Initialize the MockLogger with verbosity settings.
|
37
|
+
|
38
|
+
Args:
|
39
|
+
verbose (bool): If True, display all log messages (info, debug, warning).
|
40
|
+
If False, display only warning and error messages. Defaults to True.
|
41
|
+
"""
|
42
|
+
self.verbose = verbose
|
43
|
+
|
44
|
+
def info(self, message: str) -> None:
|
45
|
+
"""Display an informational message.
|
46
|
+
|
47
|
+
Args:
|
48
|
+
message (str): The informational message to be printed.
|
49
|
+
"""
|
50
|
+
if self.verbose:
|
51
|
+
print(message)
|
52
|
+
|
53
|
+
def debug(self, message: str) -> None:
|
54
|
+
"""Display a debug message.
|
55
|
+
|
56
|
+
Args:
|
57
|
+
message (str): The debug message to be printed.
|
58
|
+
"""
|
59
|
+
if self.verbose:
|
60
|
+
print(message)
|
61
|
+
|
62
|
+
def warning(self, message: str) -> None:
|
63
|
+
"""Display a warning message.
|
64
|
+
|
65
|
+
Args:
|
66
|
+
message (str): The warning message to be printed.
|
67
|
+
"""
|
68
|
+
print(message)
|
69
|
+
|
70
|
+
def error(self, message: str) -> None:
|
71
|
+
"""Display an error message.
|
72
|
+
|
73
|
+
Args:
|
74
|
+
message (str): The error message to be printed.
|
75
|
+
"""
|
76
|
+
print(message)
|
77
|
+
|
78
|
+
def setLevel(self, level: int) -> None:
|
79
|
+
"""Adjust verbosity based on the logging level.
|
80
|
+
|
81
|
+
Args:
|
82
|
+
level (int): Logging level to control message display.
|
83
|
+
- logging.DEBUG sets verbose to True (show all messages).
|
84
|
+
- logging.WARNING sets verbose to False (show only warning, error, and critical messages).
|
85
|
+
"""
|
86
|
+
if level == logging.DEBUG:
|
87
|
+
self.verbose = True # Show all messages
|
88
|
+
elif level == logging.WARNING:
|
89
|
+
self.verbose = False # Suppress all except warning, error, and critical messages
|
90
|
+
|
91
|
+
|
92
|
+
# Set up logger based on environment
|
93
|
+
if not in_jupyter():
|
94
|
+
# Set up logger normally for .py files or terminal environments
|
95
|
+
logger = logging.getLogger("risk_logger")
|
96
|
+
logger.setLevel(logging.DEBUG)
|
97
|
+
console_handler = logging.StreamHandler()
|
98
|
+
console_handler.setLevel(logging.DEBUG)
|
99
|
+
console_handler.setFormatter(logging.Formatter("%(message)s"))
|
100
|
+
|
101
|
+
if not logger.hasHandlers():
|
102
|
+
logger.addHandler(console_handler)
|
103
|
+
else:
|
104
|
+
# If in Jupyter, use the MockLogger
|
105
|
+
logger = MockLogger()
|
106
|
+
|
107
|
+
|
108
|
+
def set_global_verbosity(verbose):
|
109
|
+
"""Set the global verbosity level for the logger.
|
110
|
+
|
111
|
+
Args:
|
112
|
+
verbose (bool): Whether to display all log messages (True) or only warning, error, and critical messages (False).
|
113
|
+
|
114
|
+
Returns:
|
115
|
+
None
|
116
|
+
"""
|
117
|
+
if not isinstance(logger, MockLogger):
|
118
|
+
# For the regular logger, adjust logging levels
|
119
|
+
if verbose:
|
120
|
+
logger.setLevel(logging.DEBUG) # Show all messages
|
121
|
+
console_handler.setLevel(logging.DEBUG)
|
122
|
+
else:
|
123
|
+
logger.setLevel(logging.WARNING) # Show only warning, error, and critical messages
|
124
|
+
console_handler.setLevel(logging.WARNING)
|
125
|
+
else:
|
126
|
+
# For the MockLogger, set verbosity directly
|
127
|
+
logger.setLevel(logging.DEBUG if verbose else logging.WARNING)
|
128
|
+
|
129
|
+
|
130
|
+
def log_header(input_string: str) -> None:
|
131
|
+
"""Log the input string as a header with a line of dashes above and below it.
|
132
|
+
|
133
|
+
Args:
|
134
|
+
input_string (str): The string to be printed as a header.
|
135
|
+
"""
|
136
|
+
border = "-" * len(input_string)
|
137
|
+
logger.info(border)
|
138
|
+
logger.info(input_string)
|
139
|
+
logger.info(border)
|
risk/log/params.py
CHANGED
@@ -12,7 +12,7 @@ from typing import Any, Dict
|
|
12
12
|
|
13
13
|
import numpy as np
|
14
14
|
|
15
|
-
from .
|
15
|
+
from .console import logger, log_header
|
16
16
|
|
17
17
|
# Suppress all warnings - this is to resolve warnings from multiprocessing
|
18
18
|
warnings.filterwarnings("ignore")
|
@@ -159,7 +159,7 @@ class Params:
|
|
159
159
|
"""Load and process various parameters, converting any np.ndarray values to lists.
|
160
160
|
|
161
161
|
Returns:
|
162
|
-
|
162
|
+
Dict[str, Any]: A dictionary containing the processed parameters.
|
163
163
|
"""
|
164
164
|
log_header("Loading parameters")
|
165
165
|
return _convert_ndarray_to_list(
|
@@ -174,14 +174,14 @@ class Params:
|
|
174
174
|
)
|
175
175
|
|
176
176
|
|
177
|
-
def _convert_ndarray_to_list(d: Any) -> Any:
|
177
|
+
def _convert_ndarray_to_list(d: Dict[str, Any]) -> Dict[str, Any]:
|
178
178
|
"""Recursively convert all np.ndarray values in the dictionary to lists.
|
179
179
|
|
180
180
|
Args:
|
181
|
-
d (
|
181
|
+
d (Dict[str, Any]): The dictionary to process.
|
182
182
|
|
183
183
|
Returns:
|
184
|
-
|
184
|
+
Dict[str, Any]: The processed dictionary with np.ndarray values converted to lists.
|
185
185
|
"""
|
186
186
|
if isinstance(d, dict):
|
187
187
|
# Recursively process each value in the dictionary
|
@@ -193,5 +193,5 @@ def _convert_ndarray_to_list(d: Any) -> Any:
|
|
193
193
|
# Convert numpy arrays to lists
|
194
194
|
return d.tolist()
|
195
195
|
else:
|
196
|
-
# Return the value unchanged if it's not a dict,
|
196
|
+
# Return the value unchanged if it's not a dict, List, or ndarray
|
197
197
|
return d
|