path-boost 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,596 @@
1
+ import ast
2
+ import numbers
3
+ from collections import defaultdict
4
+ from typing import Iterable
5
+
6
+ import networkx as nx
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+
11
class ExtendedBoostingMatrix:
    """Helpers to build and extend the "extended boosting matrix" (EBM) of a graph dataset.

    Every EBM column is named ``"<path>_<attribute>"``. In that name:

    * ``<path>`` is the ``str`` of a tuple of node labels, e.g. ``"('C', 'O')"``.
    * ``<attribute>`` is either a numeric node/edge attribute of the last step of
      the path, or :attr:`frequency_column_name` (the number of occurrences of the
      path in the graph).

    Each row of a generated frame corresponds to one graph of the dataset, in
    dataset order.
    """

    # attribute suffix of the columns that count how many times a path occurs in a graph
    frequency_column_name: str = "n_times_present"

    def __init__(self):
        pass

    @staticmethod
    def generate_new_columns_from_columns_names(
        dataset: list[nx.Graph],
        columns_names: Iterable[str],
        main_label_name: str,
        ebm_to_be_expanded: pd.DataFrame | None = None,
        replace_nan_with=np.nan,
    ) -> pd.DataFrame:
        """Recompute the EBM columns named in ``columns_names`` on ``dataset``.

        Parameters
        ----------
        dataset : list of nx.Graph
            Graphs to compute the columns for; one output row per graph.
        columns_names : iterable of str
            EBM column names (``"<path>_<attribute>"``) to generate. It may be a
            one-shot iterator, because it is materialized once.
        main_label_name : str
            Node attribute holding the label that paths are matched against.
        ebm_to_be_expanded : pd.DataFrame or None, optional
            When it contains the frequency column of a path's parent (the path
            without its last label), graphs whose parent frequency is not > 0
            are skipped when searching for that path.
        replace_nan_with : optional
            Value for entries that cannot be computed, including requested
            columns that could not be generated at all.

        Returns
        -------
        pd.DataFrame
            One row per graph: the generated columns, followed by any requested
            column that could not be generated (filled with ``replace_nan_with``).
        """
        # materialize once: the names are iterated twice (grouping + missing check)
        columns_names = list(columns_names)
        expected_n_rows = len(dataset)

        # group the requested columns by the path encoded in their name
        columns_for_each_path = defaultdict(list)
        for column in columns_names:
            path_labels = ExtendedBoostingMatrix.get_path_from_column_name(column)
            columns_for_each_path[path_labels].append(column)

        generated_frames = []
        for path_labels, path_columns in columns_for_each_path.items():
            parent_frequency_column_name = (
                ExtendedBoostingMatrix.generate_frequency_column_name_for_path(
                    path_labels[:-1]
                )
            )
            if (
                ebm_to_be_expanded is not None
                and parent_frequency_column_name in ebm_to_be_expanded.columns
            ):
                # positional values: the EBM index does not have to be 0..n-1
                frequency_column = ebm_to_be_expanded[
                    parent_frequency_column_name
                ].to_numpy()
            else:
                # no parent information available: search every graph
                frequency_column = None

            expanded_columns = (
                ExtendedBoostingMatrix.generate_new_columns_from_path_labels(
                    dataset=dataset,
                    path_labels=path_labels,
                    main_label_name=main_label_name,
                    frequency_column=frequency_column,
                    replace_nan_with=replace_nan_with,
                )
            )
            # keep only the requested columns that were actually produced
            generated_frames.append(
                expanded_columns[
                    [col for col in path_columns if col in expanded_columns.columns]
                ]
            )

        if generated_frames:
            new_columns = pd.concat(generated_frames, axis=1)
        else:
            # keep the shape consistent with the dataset even when no path yields columns
            new_columns = pd.DataFrame(index=range(expected_n_rows))

        present_columns = set(new_columns.columns)
        missing_columns = [
            column for column in columns_names if column not in present_columns
        ]
        if missing_columns:
            missing_df = pd.DataFrame(
                {col: [replace_nan_with] * expected_n_rows for col in missing_columns},
                index=new_columns.index,
            )
            new_columns = pd.concat([new_columns, missing_df], axis=1)

        return new_columns

    @staticmethod
    def generate_new_columns_from_path_labels(
        dataset: list[nx.Graph],
        path_labels: list,
        main_label_name: str,
        frequency_column,
        replace_nan_with=np.nan,
    ) -> pd.DataFrame:
        """Compute the columns of one labelled path for every graph of ``dataset``.

        For each searched graph, all occurrences of ``path_labels`` are found.
        Each numeric attribute of the last node/edge of the occurrences gives one
        column holding the mean over the occurrences. A frequency column holds
        the number of occurrences.

        Parameters
        ----------
        dataset : list of nx.Graph
            Graphs to search.
        path_labels : list or tuple
            Sequence of labels defining the path.
        main_label_name : str
            Node attribute holding the labels.
        frequency_column : sequence of numbers or None
            Only a gate, not the output frequency. A graph is searched only if
            its entry (by position) is > 0. ``None`` means all graphs are searched.
        replace_nan_with : optional
            Fill value for entries of graphs where the value is missing.

        Returns
        -------
        pd.DataFrame
            One row per graph. It has no columns when no graph was searched.
        """
        n_graphs = len(dataset)
        if frequency_column is None:
            frequency_values = [1] * n_graphs
        else:
            # positional access works for lists, arrays and Series with any index
            frequency_values = list(frequency_column)

        frequency_column_name = (
            ExtendedBoostingMatrix.generate_frequency_column_name_for_path(path_labels)
        )
        # each column starts as one empty list per graph; empty lists become NaN later
        columns_for_dataframe = defaultdict(lambda: [[] for _ in range(n_graphs)])

        for graph_number, graph in enumerate(dataset):
            # `not (x > 0)` also skips NaN gates
            if not (frequency_values[graph_number] > 0):
                continue
            paths_found = ExtendedBoostingMatrix.find_labelled_path_in_nx_graph(
                graph=graph,
                path_labels=path_labels,
                main_label_name=main_label_name,
            )
            values_per_column = defaultdict(list)
            for path in paths_found:
                attributes = (
                    ExtendedBoostingMatrix.get_attributes_of_last_part_of_the_path(
                        graph=graph, path=path
                    )
                )
                for key, value in attributes.items():
                    column_name = ExtendedBoostingMatrix.generate_name_of_columns_for(
                        path_label=path_labels, attributes=[key]
                    )[0]
                    values_per_column[column_name].append(value)
            for column_name, values in values_per_column.items():
                columns_for_dataframe[column_name][graph_number] = (
                    ExtendedBoostingMatrix.combine_attribute_value_of_multiple_paths_in_the_same_graph(
                        values
                    )
                )
            columns_for_dataframe[frequency_column_name][graph_number] = len(
                paths_found
            )

        if not columns_for_dataframe:
            # no graph was searched: preserve row count for downstream concatenation
            return pd.DataFrame(index=range(n_graphs))

        new_df_columns = pd.DataFrame(dict(columns_for_dataframe))
        return ExtendedBoostingMatrix._remove_empty_list_values_from_df(
            new_df_columns, default_value=replace_nan_with
        )

    @staticmethod
    def new_columns_to_expand_ebm_dataframe_with_path(
        dataset: list[nx.Graph],
        selected_path: list | tuple,
        main_label_name: str,
        df_to_be_expanded: pd.DataFrame,
        replace_nan_with=np.nan,
    ) -> pd.DataFrame:
        """Build the columns of every one-step extension of ``selected_path``.

        Occurrences of ``selected_path`` are searched only in graphs where its
        frequency column in ``df_to_be_expanded`` is > 0. Each occurrence is
        extended by one neighbouring node that is not already on it. The extended
        path ``selected_path + (label of the new node,)`` then gets:

        * one column per numeric attribute of the new node/edge, holding the mean
          over the occurrences in the graph;
        * a frequency column, holding the number of extended occurrences.

        Returns
        -------
        pd.DataFrame
            One row per graph. It has no columns when no extension exists anywhere.

        Raises
        ------
        KeyError
            If ``df_to_be_expanded`` has no frequency column for ``selected_path``.
        """
        selected_path = tuple(selected_path)
        n_graphs = len(dataset)

        path_frequency_column_name = (
            ExtendedBoostingMatrix.generate_frequency_column_name_for_path(
                selected_path
            )
        )
        path_frequency_column = df_to_be_expanded[path_frequency_column_name]

        # occurrences (node tuples) of selected_path, per graph
        effective_paths_in_graphs = ExtendedBoostingMatrix.find_paths_in_dataset(
            dataset=dataset,
            path_labels=selected_path,
            id_label_name=main_label_name,
            frequency_list=path_frequency_column.to_numpy(),
        )

        columns_for_dataframe = defaultdict(lambda: [[] for _ in range(n_graphs)])
        frequency_column_names = set()
        for graph_number, (graph, paths_in_graph) in enumerate(
            zip(dataset, effective_paths_in_graphs)
        ):
            values_per_column = defaultdict(list)
            for path in paths_in_graph:
                new_paths = ExtendedBoostingMatrix.extend_path(
                    nodes_path=path, graph=graph
                )
                for new_path in new_paths or []:
                    # take the label from the node itself: the attribute dict below
                    # keeps only numeric values and may contain an edge attribute
                    # with the same name as the label
                    new_label = graph.nodes[new_path[-1]][main_label_name]
                    extended_path = selected_path + (new_label,)
                    attributes = (
                        ExtendedBoostingMatrix.get_attributes_of_last_part_of_the_path(
                            graph=graph, path=new_path
                        )
                    )
                    for key, value in attributes.items():
                        column_name = (
                            ExtendedBoostingMatrix.generate_name_of_columns_for(
                                extended_path, [key]
                            )[0]
                        )
                        values_per_column[column_name].append(value)
                    frequency_name = (
                        ExtendedBoostingMatrix.generate_frequency_column_name_for_path(
                            extended_path
                        )
                    )
                    frequency_column_names.add(frequency_name)
                    values_per_column[frequency_name].append(1)

            for column_name, values in values_per_column.items():
                if column_name in frequency_column_names:
                    value = len(values)
                else:
                    value = ExtendedBoostingMatrix.combine_attribute_value_of_multiple_paths_in_the_same_graph(
                        values
                    )
                columns_for_dataframe[column_name][graph_number] = value

        if not columns_for_dataframe:
            # no extension anywhere: preserve row count for downstream concatenation
            return pd.DataFrame(index=range(n_graphs))

        new_df_columns = pd.DataFrame(dict(columns_for_dataframe))
        return ExtendedBoostingMatrix._remove_empty_list_values_from_df(
            new_df_columns, default_value=replace_nan_with
        )

    @staticmethod
    def extend_path(nodes_path: tuple, graph: nx.Graph):
        """Return every extension of ``nodes_path`` by one neighbour of its last node.

        Neighbours already on the path are excluded, so the extensions stay
        simple paths. ``nodes_path`` may be a list or a tuple, and the
        extensions are tuples. Returns ``None`` when the path is empty or cannot
        be extended.
        """
        if len(nodes_path) == 0:
            return None
        nodes_path = tuple(nodes_path)
        on_path = set(nodes_path)
        new_paths = [
            nodes_path + (neighbor,)
            for neighbor in graph.neighbors(nodes_path[-1])
            if neighbor not in on_path
        ]
        return new_paths or None

    @staticmethod
    def find_paths_in_dataset(
        dataset: list[nx.Graph],
        path_labels: list | tuple,
        id_label_name: str,
        frequency_list=None,
    ) -> list[list]:
        """
        Finds paths in a dataset of graphs based on provided path labels.

        For each NetworkX graph, all simple paths whose node labels follow
        ``path_labels`` are collected. ``frequency_list`` optionally gates which
        graphs are searched.

        Parameters
        ----------
        dataset : list of nx.Graph
            List of NetworkX graph objects representing the dataset.
        path_labels : list or tuple
            Sequence of labels defining the desired path to search for in each graph.
        id_label_name : str
            The node attribute name used to match labels in the graphs.
        frequency_list : sequence of numbers or None, optional
            One value per graph, read by position. A graph is searched only if
            its value is > 0. The default ``None`` searches every graph.

        Returns
        -------
        list of list
            One list per graph with the paths found. Each path is a tuple of node
            IDs following ``path_labels``. The list is empty for graphs that were
            not searched.

        Raises
        ------
        ValueError
            If ``frequency_list`` does not have one entry per graph.
        """
        if frequency_list is None:
            frequency_values = [1] * len(dataset)
        else:
            frequency_values = list(frequency_list)
        if len(frequency_values) != len(dataset):
            raise ValueError(
                f"frequency_list has {len(frequency_values)} entries "
                f"but the dataset has {len(dataset)} graphs"
            )
        return [
            ExtendedBoostingMatrix.find_labelled_path_in_nx_graph(
                graph=graph,
                path_labels=path_labels,
                main_label_name=id_label_name,
            )
            if frequency > 0
            else []
            for graph, frequency in zip(dataset, frequency_values)
        ]

    @staticmethod
    def find_labelled_path_in_nx_graph(
        graph: nx.Graph, path_labels: list, main_label_name: str
    ) -> list[tuple[int]]:
        """Return all simple paths (node tuples) of ``graph`` whose labels follow ``path_labels``.

        Starting nodes are tried in ``graph.nodes`` order. An empty
        ``path_labels`` yields no paths.
        """
        if len(path_labels) == 0:
            return []
        first_label = path_labels[0]
        found_paths = []
        for node, attributes in graph.nodes(data=True):
            if attributes.get(main_label_name) == first_label:
                found_paths.extend(
                    ExtendedBoostingMatrix._find_labelled_path_in_nx_graph_from_starting_node(
                        graph=graph,
                        path_labels=path_labels,
                        main_label_name=main_label_name,
                        starting_node=node,
                    )
                )
        return found_paths

    @staticmethod
    def _find_labelled_path_in_nx_graph_from_starting_node(
        graph: nx.Graph,
        path_labels: list,
        main_label_name: str,
        starting_node: int,
        path=None,
        visited_nodes: set | None = None,
    ) -> list[tuple[int]]:
        """Depth-first search for label sequence ``path_labels`` beginning at ``starting_node``.

        ``path`` holds the nodes already matched (prefix), and ``visited_nodes``
        holds the nodes that may not be reused. Neither argument is mutated.
        Returns a list of node tuples (``path`` + the newly matched nodes).
        """
        if path is None:
            path = []
        if visited_nodes is None:
            visited_nodes = set()

        if starting_node in visited_nodes:
            return []
        if graph.nodes[starting_node].get(main_label_name) != path_labels[0]:
            return []

        path = path + [starting_node]
        # fresh set per branch so sibling branches do not see each other's nodes
        visited_nodes = visited_nodes | {starting_node}
        if len(path_labels) == 1:
            return [tuple(path)]

        # path_labels[0] was just matched, so the next wanted label is path_labels[1]
        next_label = path_labels[1]
        remaining_labels = path_labels[1:]
        paths_found: list = []
        for neighbour in graph.neighbors(starting_node):
            if neighbour in visited_nodes:
                continue
            if graph.nodes[neighbour].get(main_label_name) != next_label:
                continue
            paths_found.extend(
                ExtendedBoostingMatrix._find_labelled_path_in_nx_graph_from_starting_node(
                    graph=graph,
                    path_labels=remaining_labels,
                    main_label_name=main_label_name,
                    starting_node=neighbour,
                    path=path,
                    visited_nodes=visited_nodes,
                )
            )
        return paths_found

    @staticmethod
    def get_attributes_of_node(graph: nx.Graph, node_id: int) -> dict | None:
        """Return the numeric attributes of ``node_id``, or ``None`` if the node is absent."""
        node_attributes = graph.nodes.get(node_id)
        if node_attributes is None:
            return None
        return {
            k: v for k, v in node_attributes.items() if isinstance(v, numbers.Number)
        }

    @staticmethod
    def get_edge_attributes_of_nx_graph(
        graph: nx.Graph, last_edge: tuple
    ) -> dict | None:
        """Return the numeric attributes of ``last_edge``, or ``None`` if the edge is absent.

        For multigraphs, a ``(u, v)`` edge is looked up as ``(u, v, 0)``, so only
        the parallel edge with key 0 is considered.
        """
        if isinstance(graph, nx.MultiGraph) and len(last_edge) == 2:
            last_edge = (last_edge[0], last_edge[1], 0)
        edge_attributes = graph.edges.get(last_edge)
        if edge_attributes is None:
            return None
        return {
            k: v for k, v in edge_attributes.items() if isinstance(v, numbers.Number)
        }

    @staticmethod
    def get_attributes_of_last_part_of_the_path(
        graph: nx.Graph, path: list | tuple
    ) -> dict:
        """
        Retrieves and combines attributes of the LAST node in a given path and,
        when the path has at least two nodes, of the edge connecting the last two
        nodes. Only numeric attributes are kept. On a name clash, the edge
        attribute overwrites the node attribute.

        Parameters
        graph: nx.Graph
            The NetworkX graph from which attributes are retrieved.
        path: list[int] | tuple[int]
            A sequence of node identifiers representing a path in the graph.

        Returns
        dict
            The combined numeric attributes. It is empty for an empty path.
        """
        path_attributes = {}
        if len(path) == 0:
            return path_attributes

        node_attributes = ExtendedBoostingMatrix.get_attributes_of_node(graph, path[-1])
        if node_attributes is not None:
            path_attributes.update(node_attributes)
        if len(path) > 1:
            edge_attributes = ExtendedBoostingMatrix.get_edge_attributes_of_nx_graph(
                graph, (path[-2], path[-1])
            )
            if edge_attributes is not None:
                path_attributes.update(edge_attributes)
        return path_attributes

    @staticmethod
    def generate_name_of_columns_for(
        path_label: tuple | list, attributes: Iterable | None
    ) -> list[str] | None:
        """Return ``"<str(path tuple)>_<attribute>"`` for each attribute, or ``None``.

        A top-level list path is converted to a tuple, so the names can always
        be parsed back by :meth:`get_path_from_column_name`.
        """
        if attributes is None:
            return None
        if isinstance(path_label, list):
            path_label = tuple(path_label)
        # numpy scalars are converted to Python ones: their str()/repr() (e.g.
        # "np.int64(3)" in NumPy 2) would not be parseable by ast.literal_eval
        path_in_python_numbers = ExtendedBoostingMatrix._parse_to_python_numbers(
            path_label
        )
        return [
            str(path_in_python_numbers) + "_" + str(attribute)
            for attribute in attributes
        ]

    @staticmethod
    def _parse_to_python_numbers(obj):
        """Recursively convert numpy scalars inside tuples/lists to the equivalent Python scalars."""
        if isinstance(obj, tuple):
            return tuple(
                ExtendedBoostingMatrix._parse_to_python_numbers(x) for x in obj
            )
        if isinstance(obj, list):
            return [ExtendedBoostingMatrix._parse_to_python_numbers(x) for x in obj]
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.str_):
            return str(obj)
        return obj

    @staticmethod
    def combine_attribute_value_of_multiple_paths_in_the_same_graph(
        values: list[numbers.Number],
    ):
        """Aggregate one attribute over several occurrences in a graph (arithmetic mean)."""
        return np.mean(values)

    @staticmethod
    def initialize_boosting_matrix_with_anchor_nodes_attributes(
        dataset: list[nx.Graph],
        list_anchor_nodes_labels: list,
        id_label_name: str,
        replace_nan_with=np.nan,
    ) -> pd.DataFrame:
        """Build the initial EBM from the single-label (anchor) paths.

        Each anchor label may be a scalar label or an iterable of labels; a
        string counts as a scalar. The columns of every anchor path are
        concatenated. An empty label list yields a frame with one row per graph
        and no columns.
        """
        frames = []
        for anchor_node_label in list_anchor_nodes_labels:
            if isinstance(anchor_node_label, str) or not hasattr(
                anchor_node_label, "__iter__"
            ):
                anchor_path = (anchor_node_label,)
            else:
                anchor_path = tuple(anchor_node_label)

            frames.append(
                ExtendedBoostingMatrix.generate_new_columns_from_path_labels(
                    dataset=dataset,
                    path_labels=anchor_path,
                    frequency_column=None,
                    main_label_name=id_label_name,
                    replace_nan_with=replace_nan_with,
                )
            )

        if not frames:
            return pd.DataFrame(index=range(len(dataset)))

        return ExtendedBoostingMatrix._remove_empty_list_values_from_df(
            pd.concat(frames, axis=1), default_value=replace_nan_with
        )

    @staticmethod
    def _is_missing_fill_value(value) -> bool:
        """True when ``value`` means "leave missing entries as NaN" (``None`` or a float NaN)."""
        if value is None:
            return True
        return isinstance(value, (float, np.floating)) and bool(np.isnan(value))

    @staticmethod
    def _remove_empty_list_values_from_df(
        df: pd.DataFrame, default_value=np.nan
    ) -> pd.DataFrame:
        """Turn empty-list placeholders into missing values and return a new frame.

        The steps are:

        1. Empty lists become NaN.
        2. Frequency columns (name contains :attr:`frequency_column_name`) get
           NaN -> 0 and are cast to ``int``.
        3. The remaining NaNs are filled with ``default_value``, unless it is
           ``None`` or NaN.
        """
        modified_df = df.map(
            lambda x: np.nan if isinstance(x, list) and len(x) == 0 else x
        )
        # a path absent from a graph leaves NaN in its "(path)_n_times_present" entry: that means 0
        columns_to_replace = [
            col
            for col in modified_df.columns
            if ExtendedBoostingMatrix.frequency_column_name in col
        ]
        if columns_to_replace:
            modified_df[columns_to_replace] = (
                modified_df[columns_to_replace].fillna(0).astype(int)
            )
        if not ExtendedBoostingMatrix._is_missing_fill_value(default_value):
            modified_df = modified_df.fillna(default_value)
        return modified_df

    @staticmethod
    def _split_column_name(column_name: str) -> tuple[tuple, str] | None:
        """Split ``"<path repr>_<attribute>"`` into ``(path, attribute)``; ``None`` if not parseable.

        The split point is the first ``")_"`` whose prefix evaluates to a tuple.
        Labels that contain ``"_"`` or ``")_"`` inside strings are therefore
        handled: such a prefix is an unterminated literal and is skipped.
        """
        search_from = 0
        while True:
            split_at = column_name.find(")_", search_from)
            if split_at == -1:
                return None
            try:
                path = ast.literal_eval(column_name[: split_at + 1])
            except (ValueError, SyntaxError, TypeError):
                path = None
            if isinstance(path, tuple):
                return path, column_name[split_at + 2 :]
            search_from = split_at + 1

    @staticmethod
    def get_attribute_name_from_column_name(column_name: str) -> str:
        """Return the attribute part of an EBM column name."""
        parsed = ExtendedBoostingMatrix._split_column_name(column_name)
        if parsed is None:
            # not an EBM-style name: fall back to "everything after the first underscore"
            return column_name.split("_", 1)[1]
        return parsed[1]

    @staticmethod
    def get_path_from_column_name(column_name: str) -> tuple:
        """Return the label tuple encoded in an EBM column name.

        Raises
        ------
        ValueError
            If the name does not start with a tuple literal followed by ``"_"``.
        """
        parsed = ExtendedBoostingMatrix._split_column_name(column_name)
        if parsed is None:
            raise ValueError(f"error in the columns name: {column_name!r}")
        return parsed[0]

    @staticmethod
    def get_frequency_boosting_matrix(
        train_ebm_dataframe: pd.DataFrame,
    ) -> pd.DataFrame:
        """Return only the frequency columns of an EBM."""
        selected_columns = [
            column
            for column in train_ebm_dataframe.columns
            if ExtendedBoostingMatrix.frequency_column_name in column
        ]
        return train_ebm_dataframe[selected_columns]

    @staticmethod
    def get_columns_related_to_path(path: tuple, columns_names: list[str]) -> list[str]:
        """Return the columns whose path is a prefix of ``path`` (``path`` itself included).

        ``path`` may be a list or a tuple. The columns keep their input order.
        """
        path = tuple(path)
        related_columns = []
        for column in columns_names:
            column_path = ExtendedBoostingMatrix.get_path_from_column_name(column)
            if path[: len(column_path)] == column_path:
                related_columns.append(column)
        return related_columns

    @staticmethod
    def generate_frequency_column_name_for_path(path_label: tuple) -> str:
        """
        Generates a frequency column name for a given path.

        Parameters
        ----------
        path_label : tuple
            The label sequence of the path for which to generate the frequency column name.

        Returns
        -------
        str
            The generated frequency column name.
        """
        return ExtendedBoostingMatrix.generate_name_of_columns_for(
            path_label=path_label,
            attributes=[ExtendedBoostingMatrix.frequency_column_name],
        )[0]
File without changes
@@ -0,0 +1,30 @@
1
+ import abc
2
+
3
+
4
class BaseLearnerClassInterface(metaclass=abc.ABCMeta):
    """Abstract interface for base learners exposing ``fit`` and ``predict``.

    Besides explicit subclassing, :meth:`__subclasshook__` makes any class that
    defines (or inherits) both ``fit`` and ``predict`` a virtual subclass for
    ``issubclass``/``isinstance`` checks.
    """

    @abc.abstractmethod
    def fit(self, *args, **kwargs):
        """Fit the learner. At least two positional arguments (X and y) are required.

        Subclasses may call ``super().fit(*args, **kwargs)`` to reuse this check.

        Raises
        ------
        TypeError
            If fewer than two positional arguments are given.
        """
        if len(args) < 2:
            raise TypeError(
                "fit() method requires at least two positional arguments (X and y)"
            )

    @abc.abstractmethod
    def predict(self, *args, **kwargs):
        """Predict with the learner. At least one positional argument (X) is required.

        Subclasses may call ``super().predict(*args, **kwargs)`` to reuse this check.

        Raises
        ------
        TypeError
            If no positional argument is given.
        """
        if len(args) < 1:
            raise TypeError(
                "predict() method requires at least one positional argument (X)"
            )

    @classmethod
    def __subclasshook__(cls, C):
        # structural check only for the interface itself, not for its subclasses
        if cls is BaseLearnerClassInterface:
            if any("fit" in B.__dict__ for B in C.__mro__) and any(
                "predict" in B.__dict__ for B in C.__mro__
            ):
                return True
        return NotImplemented
@@ -0,0 +1,27 @@
1
+ import abc
2
+
3
+
4
class SelectorClassInterface(metaclass=abc.ABCMeta):
    """Abstract interface for feature selectors.

    A selector provides ``fit(X, y)``, ``predict(X)`` and a
    ``feature_importances_`` property. Any class that defines (or inherits) all
    three names is treated as a virtual subclass by ``issubclass``/``isinstance``.
    """

    @abc.abstractmethod
    def fit(self, X, y):
        """Fit the selector on features ``X`` and target ``y``."""

    @abc.abstractmethod
    def predict(self, X):
        """Predict targets for ``X``."""

    @property
    @abc.abstractmethod
    def feature_importances_(self):
        """Importance of each feature, in the implementation's format."""

    @classmethod
    def __subclasshook__(cls, C):
        # the structural check applies only to the interface itself
        if cls is not SelectorClassInterface:
            return NotImplemented
        required_names = ("fit", "predict", "feature_importances_")
        if all(
            any(name in klass.__dict__ for klass in C.__mro__)
            for name in required_names
        ):
            return True
        return NotImplemented