path-boost 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,153 @@
1
+ import logging
2
+ import os
3
+ from datetime import datetime
4
+
5
+ from matplotlib import pyplot as plt
6
+ from matplotlib.ticker import MaxNLocator
7
+
8
+ logger = logging.getLogger("path_boost")
9
+
10
+
11
def plot_training_and_eval_errors(
    learning_rate: float,
    train_mse: list,
    mse_eval_set: list | None = None,
    skip_first_n_iterations: int | bool = False,
    show=True,
    save=False,
    save_path: str | None = None,
):
    """
    Plot the training Mean Squared Error (MSE) and, if given, the MSE of each
    evaluation set over the boosting iterations.

    Parameters
    ----------
    learning_rate : float
        The learning rate used during training. It is only read when
        ``skip_first_n_iterations`` is ``True``, where it determines how many
        initial iterations are skipped.
    train_mse : list[float]
        A list of training MSE values, where each element corresponds to an iteration.
    mse_eval_set : list[list[float]] | None, default=None
        A list of lists, where each inner list contains MSE values for an evaluation
        set over iterations. If None, only training MSE is plotted. An evaluation
        set that is empty, or whose first plotted value is None, is not drawn.
    skip_first_n_iterations : int | bool, default=False
        If True, the first ``int(2 / learning_rate)`` iterations are skipped.
        If an integer, that many initial iterations are skipped.
        If False or 0, all iterations are plotted.
        A curve that does not have more points than the number to skip is
        plotted in full, starting at iteration 0.
    show : bool, default=True
        If True, the plot is displayed.
    save : bool, default=False
        If True, the plot is saved to a timestamped PNG file named
        ``training_and_eval_errors_<timestamp>.png``.
    save_path : str | None, default=None
        The directory where the plot will be saved; it is created if missing.
        If None, the current working directory is used.
    """
    # Resolve how many leading iterations to hide. bool is tested first because
    # bool is a subclass of int and True must not be read as "skip 1".
    if isinstance(skip_first_n_iterations, bool):
        n = int(2 / learning_rate) if skip_first_n_iterations else 0
    else:
        n = skip_first_n_iterations

    plt.figure(figsize=(12, 6))

    # Plot training errors. The x-axis starts at the number of iterations that
    # were really dropped, so the tick labels always match the iteration index.
    train_offset, train_curve = _drop_leading_iterations(train_mse, n)
    plt.plot(
        range(train_offset, train_offset + len(train_curve)),
        train_curve,
        label="Training Error",
        marker="",
    )

    # Plot evaluation set errors if available. Each set is trimmed on its own
    # length instead of assuming every set is as long as the first one.
    if mse_eval_set is not None:
        for eval_set_index, eval_errors in enumerate(mse_eval_set):
            eval_offset, eval_curve = _drop_leading_iterations(eval_errors, n)
            if len(eval_curve) == 0 or eval_curve[0] is None:
                continue
            plt.plot(
                range(eval_offset, eval_offset + len(eval_curve)),
                eval_curve,
                label=f"Evaluation Set {eval_set_index + 1}",
                marker="",
            )

    plt.xlabel("Iteration")
    plt.ylabel("Mean Squared Error")
    plt.title("Training and Evaluation Set Errors Over Iterations")
    plt.legend()
    plt.grid(True)

    # Ensure x-axis only shows integers
    plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))

    if save:
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        filename = f"training_and_eval_errors_{timestamp}.png"
        if save_path:
            # exist_ok avoids the check-then-create race of exists() + makedirs()
            os.makedirs(save_path, exist_ok=True)
            filename = os.path.join(save_path, filename)
        plt.savefig(filename)
        logger.info("Plot saved to %s", filename)

    if show:
        plt.show()


def _drop_leading_iterations(errors, n_skip):
    """Return ``(first_iteration, values)`` for one curve after dropping ``n_skip`` leading points.

    The curve is returned untouched, with ``first_iteration == 0``, when it does
    not contain more than ``n_skip`` points.
    """
    if len(errors) > n_skip:
        return n_skip, errors[n_skip:]
    return 0, errors
104
+
105
+
106
def plot_variable_importance_utils(
    variable_importance: dict,
    parameters_variable_importance: dict,
    top_n: int | None = None,
    show: bool = True,
):
    """
    Plot the variable importance scores as a horizontal bar chart.

    Features are drawn from the highest score (top) to the lowest (bottom).

    Parameters
    ----------
    variable_importance : dict
        A dictionary where keys are feature identifiers and values are their
        corresponding importance scores. Tuple keys are rendered as their
        elements joined by commas; any other key is rendered with ``str``.
        An empty dictionary produces an empty chart.
    parameters_variable_importance : dict
        A dictionary containing the parameters that were used for computing
        variable importance. Its ``"criterion"`` entry (a string) is used as
        the prefix of the plot title.
    top_n : int | None, default=None
        If given, only the ``top_n`` highest-scoring features are plotted.
        If None, every feature is plotted.
    show : bool, default=True
        If True, the plot is displayed.
    """

    assert isinstance(
        variable_importance, dict
    ), "Variable importance should be a dictionary."
    ranked = sorted(variable_importance.items(), key=lambda item: item[1], reverse=True)
    if top_n is not None:
        ranked = ranked[:top_n]

    # Build labels and values separately instead of unpacking zip(*ranked):
    # unpacking fails with ValueError when there is nothing to plot
    # (empty dictionary or top_n == 0).
    labels = [
        ",".join(map(str, feature)) if isinstance(feature, tuple) else str(feature)
        for feature, _ in ranked
    ]
    values = [score for _, score in ranked]

    plt.figure(figsize=(10, 6))
    plt.barh(labels, values, color="skyblue")
    plt.xlabel("Importance Score")
    plt.title(parameters_variable_importance["criterion"] + " Variable Importance")
    # barh draws the first item at the bottom; invert so the best feature is on top
    plt.gca().invert_yaxis()
    plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
    if show:
        plt.show()
@@ -0,0 +1,223 @@
1
+ import inspect
2
+ import logging
3
+ import numbers
4
+ import warnings
5
+ from typing import Iterable
6
+
7
+ import networkx as nx
8
+ import numpy as np
9
+ from sklearn.tree import DecisionTreeRegressor
10
+ from sklearn.utils.validation import validate_data
11
+
12
+ from .classes.interfaces.interface_base_learner import BaseLearnerClassInterface
13
+ from .classes.interfaces.interface_selector import SelectorClassInterface
14
+
15
+ logger = logging.getLogger("path_boost")
16
+
17
+
18
def check_interface(class_to_be_checked, interface_class):
    """Raise ``TypeError`` unless ``class_to_be_checked`` subclasses ``interface_class``.

    When the subclass test fails, the error message lists every plain function
    and every non-routine attribute that ``interface_class`` exposes and
    ``class_to_be_checked`` does not (functions first, then attributes, each
    group in the name order returned by ``inspect.getmembers``).

    Returns ``None`` when the class is a subclass of the interface.
    """
    if issubclass(class_to_be_checked, interface_class):
        return

    def _names(cls, predicate):
        # Names of the members of ``cls`` accepted by ``predicate``.
        return [name for name, _ in inspect.getmembers(cls, predicate=predicate)]

    def _is_plain_attribute(member):
        return not inspect.isroutine(member)

    absent = []
    # Functions are reported before data attributes.
    for predicate in (inspect.isfunction, _is_plain_attribute):
        required = _names(interface_class, predicate)
        provided = set(_names(class_to_be_checked, predicate))
        absent.extend(name for name in required if name not in provided)

    raise TypeError(
        f"{class_to_be_checked.__name__} must implement {interface_class.__name__}."
        f" Missing items: {absent}"
    )
53
+
54
+
55
+ def util_validate_data(
56
+ model,
57
+ X="no_validation",
58
+ y="no_validation",
59
+ reset=True,
60
+ validate_separately=False,
61
+ **check_params,
62
+ ):
63
+ # Validate the inputs and options of a path-boost model and write the normalised options back onto `model`.
64
+ # X and y default to the string sentinel "no_validation", meaning "not supplied, skip the corresponding checks".
65
+ # Optional settings are read from **check_params: parameters_variable_importance, list_anchor_nodes_labels,
66
+ # name_of_label_attribute, m_stops, eval_set and patience. Returns (X, y), X, or y depending on which were supplied.
67
+ if not np.array_equal(X, "no_validation"):  # NOTE(review): presumably array_equal is used so array-like inputs are not compared element-wise - confirm
68
+ assert isinstance(X, list) and all(isinstance(item, nx.Graph) for item in X)  # X must be a list of networkx graphs
69
+ if not np.array_equal(y, "no_validation"):
70
+ assert isinstance(y, Iterable) and all(
71
+ isinstance(item, numbers.Number) for item in y
72
+ )
73
+
74
+ # check that BaseLearnerClass and SelectorClass respect their respective interfaces (check_interface raises TypeError otherwise)
75
+ check_interface(model.BaseLearnerClass, BaseLearnerClassInterface)
76
+
77
+ check_interface(model.SelectorClass, SelectorClassInterface)
78
+
79
+ # ------------------------------------------------------------------------------------------------------
80
+ # fill in default kwargs when the base learner / selector class is a DecisionTreeRegressor subclass: keys already supplied are kept, only missing keys are added
81
+ if issubclass(model.BaseLearnerClass, DecisionTreeRegressor):
82
+ if model.kwargs_for_base_learner is None:
83
+ model.kwargs_for_base_learner = model._default_kwargs_for_base_learner  # NOTE(review): stores the defaults object itself, not a copy
84
+ else:
85
+ for key in model._default_kwargs_for_base_learner:
86
+ if key not in model.kwargs_for_base_learner:
87
+ model.kwargs_for_base_learner[key] = (
88
+ model._default_kwargs_for_base_learner[key]
89
+ )
90
+
91
+ if issubclass(model.SelectorClass, DecisionTreeRegressor):
92
+ if model.kwargs_for_selector is None:
93
+ model.kwargs_for_selector = model._default_kwargs_for_selector  # NOTE(review): stores the defaults object itself, not a copy
94
+ else:
95
+ for key in model._default_kwargs_for_selector:
96
+ if key not in model.kwargs_for_selector:
97
+ model.kwargs_for_selector[key] = model._default_kwargs_for_selector[
98
+ key
99
+ ]
100
+
101
+ # ------------------------------------------------------------------------------------------------------
102
+
103
+ if model.kwargs_for_selector is None:  # kwargs that are still None at this point become empty dicts
104
+ model.kwargs_for_selector = {}
105
+ if model.kwargs_for_base_learner is None:
106
+ model.kwargs_for_base_learner = {}
107
+
108
+ # validate the optional parameters_variable_importance dict: only the keys tested below are accepted, any other key raises ValueError
109
+ parameters_variable_importance = check_params.get(
110
+ "parameters_variable_importance", None
111
+ )
112
+ if parameters_variable_importance is not None:
113
+ assert isinstance(parameters_variable_importance, dict)
114
+ for key, value in parameters_variable_importance.items():
115
+ assert isinstance(key, str)
116
+ if key == "criterion":
117
+ assert value in ["absolute", "relative"]
118
+ elif key == "error_used":
119
+ assert value in ["mse", "mae"]
120
+ elif key == "use_correlation":
121
+ assert isinstance(value, bool)
122
+ elif key == "normalize":
123
+ assert isinstance(value, bool)
124
+ elif key == "normalization_value":
125
+ assert value is None or isinstance(value, float)  # NOTE(review): an int value such as 1 does not pass this check
126
+
127
+ else:
128
+ raise ValueError(f"Unknown parameter {key} for variable importance")
129
+
130
+ # check list_anchor_nodes_labels
131
+ list_anchor_nodes_labels = check_params.get("list_anchor_nodes_labels", None)
132
+ if list_anchor_nodes_labels is not None:
133
+ # Coerce each element to a tuple in place (this mutates the list passed by the caller): non-string iterables via tuple(...), anything else is wrapped in a 1-tuple
134
+ len_list_anchor_nodes_labels = len(list_anchor_nodes_labels)
135
+ for i in range(len_list_anchor_nodes_labels):
136
+ if not isinstance(list_anchor_nodes_labels[i], tuple):
137
+ if hasattr(list_anchor_nodes_labels[i], "__iter__") and not isinstance(
138
+ list_anchor_nodes_labels[i], str
139
+ ):
140
+ list_anchor_nodes_labels[i] = tuple(list_anchor_nodes_labels[i])
141
+ else:
142
+ list_anchor_nodes_labels[i] = tuple([list_anchor_nodes_labels[i]])
143
+
144
+ elif (
145
+ "list_anchor_nodes_labels" in check_params and list_anchor_nodes_labels is None
146
+ ):
147
+ # Warn unless both X was supplied and name_of_label_attribute is present in check_params
148
+ if not (
149
+ not np.array_equal(X, "no_validation")
150
+ and "name_of_label_attribute" in check_params
151
+ ):
152
+ warnings.warn(
153
+ "list_anchor_nodes_labels can not be validated because it is None."
154
+ )
155
+ else:
156
+ anchor_attribute_values = []  # collects the label attribute value of every node that carries it, over all graphs in X
157
+ if (
158
+ not np.array_equal(X, "no_validation")
159
+ and "name_of_label_attribute" in check_params
160
+ ):
161
+ label_attr = check_params["name_of_label_attribute"]
162
+ for graph in X:
163
+ for _, node_data in graph.nodes(data=True):
164
+ if label_attr in node_data:
165
+ anchor_attribute_values.append(node_data[label_attr])
166
+ model.list_anchor_nodes_labels = anchor_attribute_values  # NOTE(review): appears to be overwritten by the next assignment to the same attribute (with None) - confirm intent
167
+
168
+ model.list_anchor_nodes_labels = list_anchor_nodes_labels  # NOTE(review): nesting level is not visible in this rendering - confirm whether this runs on every call
169
+
170
+ # check m_stops: it must be a list of int/None entries with one entry per anchor label stored on the model
171
+ # only for cyclic path boost
172
+ m_stops = check_params.get("m_stops", None)
173
+ if m_stops is not None:
174
+ assert isinstance(m_stops, list) and all(
175
+ isinstance(item, (int, type(None))) for item in m_stops
176
+ )
177
+ assert len(m_stops) == len(model.list_anchor_nodes_labels)
178
+
179
+ # check eval sets: eval_set must be an iterable of 2-tuples, each holding a list of nx.Graph and an iterable of numbers
180
+ eval_set = check_params.get("eval_set", None)
181
+ if eval_set is not None:
182
+ assert isinstance(eval_set, Iterable) and all(
183
+ isinstance(eval_tuple, tuple) and len(eval_tuple) == 2
184
+ for eval_tuple in eval_set
185
+ )
186
+ for eval_tuple in eval_set:
187
+ assert isinstance(eval_tuple[0], list) and all(
188
+ isinstance(item, nx.Graph) for item in eval_tuple[0]
189
+ )
190
+ assert isinstance(eval_tuple[1], Iterable) and all(
191
+ isinstance(item, numbers.Number) for item in eval_tuple[1]
192
+ )
193
+
194
+ # check patience: it must be a non-negative int; without an eval_set a warning is issued and patience is replaced by None
195
+ patience = check_params.get("patience", None)
196
+ if patience is not None:
197
+ assert (
198
+ isinstance(patience, int) and patience >= 0
199
+ ), "patience must be a non-negative integer"
200
+ if check_params.get("eval_set", None) is None:
201
+ warnings.warn(
202
+ "patience is set to None because there is no eval_set provided"
203
+ )
204
+ patience = None
205
+ model.patience = patience
206
+
207
+ if not np.array_equal(y, "no_validation"):
208
+ validate_data(  # validate_data (imported from sklearn.utils.validation) is applied to y only; X is passed as the "no_validation" sentinel
209
+ model,
210
+ X="no_validation",
211
+ y=y,
212
+ reset=reset,
213
+ validate_separately=validate_separately,
214
+ )
215
+
216
+ if not np.array_equal(X, "no_validation") and not np.array_equal(
217
+ y, "no_validation"
218
+ ):
219
+ return X, y
220
+ elif not np.array_equal(X, "no_validation"):
221
+ return X
222
+ elif not np.array_equal(y, "no_validation"):
223
+ return y  # when neither X nor y was supplied no branch matches and the function ends without an explicit return