psyke 0.8.9.dev48__py3-none-any.whl → 1.0.4.dev10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. psyke/__init__.py +112 -24
  2. psyke/clustering/__init__.py +4 -0
  3. psyke/clustering/cream/__init__.py +2 -6
  4. psyke/clustering/exact/__init__.py +10 -7
  5. psyke/clustering/utils.py +0 -1
  6. psyke/extraction/__init__.py +6 -2
  7. psyke/extraction/cart/{predictor.py → CartPredictor.py} +52 -7
  8. psyke/extraction/cart/FairTree.py +205 -0
  9. psyke/extraction/cart/FairTreePredictor.py +56 -0
  10. psyke/extraction/cart/__init__.py +27 -52
  11. psyke/extraction/hypercubic/__init__.py +58 -7
  12. psyke/extraction/hypercubic/creepy/__init__.py +14 -6
  13. psyke/extraction/hypercubic/ginger/__init__.py +100 -0
  14. psyke/extraction/hypercubic/gridex/__init__.py +6 -48
  15. psyke/extraction/hypercubic/gridrex/__init__.py +2 -2
  16. psyke/extraction/hypercubic/hypercube.py +33 -26
  17. psyke/extraction/hypercubic/iter/__init__.py +5 -0
  18. psyke/extraction/hypercubic/strategy.py +13 -9
  19. psyke/extraction/real/__init__.py +21 -22
  20. psyke/extraction/real/utils.py +2 -2
  21. psyke/extraction/trepan/__init__.py +19 -15
  22. psyke/genetic/__init__.py +0 -0
  23. psyke/genetic/fgin/__init__.py +74 -0
  24. psyke/genetic/gin/__init__.py +144 -0
  25. psyke/hypercubepredictor.py +4 -2
  26. psyke/tuning/pedro/__init__.py +4 -2
  27. psyke/utils/logic.py +4 -8
  28. {psyke-0.8.9.dev48.dist-info → psyke-1.0.4.dev10.dist-info}/METADATA +39 -19
  29. psyke-1.0.4.dev10.dist-info/RECORD +46 -0
  30. {psyke-0.8.9.dev48.dist-info → psyke-1.0.4.dev10.dist-info}/WHEEL +1 -1
  31. {psyke-0.8.9.dev48.dist-info → psyke-1.0.4.dev10.dist-info/licenses}/LICENSE +2 -1
  32. psyke-0.8.9.dev48.dist-info/RECORD +0 -40
  33. {psyke-0.8.9.dev48.dist-info → psyke-1.0.4.dev10.dist-info}/top_level.txt +0 -0
psyke/__init__.py CHANGED
@@ -5,16 +5,20 @@ from enum import Enum
5
5
 
6
6
  import numpy as np
7
7
  import pandas as pd
8
+ from matplotlib import pyplot as plt
8
9
  from sklearn.linear_model import LinearRegression
9
10
  from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, f1_score, accuracy_score, \
10
11
  adjusted_rand_score, adjusted_mutual_info_score, v_measure_score, fowlkes_mallows_score
12
+ from tuprolog.solve.prolog import prolog_solver
11
13
 
12
14
  from psyke.schema import DiscreteFeature
13
15
  from psyke.utils import get_default_random_seed, Target, get_int_precision
14
- from tuprolog.theory import Theory
16
+ from tuprolog.theory import Theory, mutable_theory
15
17
  from typing import Iterable
16
18
  import logging
17
19
 
20
+ from psyke.utils.logic import get_in_rule, data_to_struct, get_not_in_rule
21
+
18
22
  logging.basicConfig(level=logging.WARN)
19
23
  logger = logging.getLogger('psyke')
20
24
 
@@ -52,7 +56,7 @@ class EvaluableModel(object):
52
56
  """
53
57
  Predicts the output values of every sample in dataset.
54
58
 
55
- :param dataframe: is the set of instances to predict.
59
+ :param dataframe: the set of instances to predict.
56
60
  :return: a list of predictions.
57
61
  """
58
62
  return self.__convert(self._predict(dataframe))
@@ -61,7 +65,7 @@ class EvaluableModel(object):
61
65
  raise NotImplementedError('predict')
62
66
 
63
67
  def __convert(self, ys: Iterable) -> Iterable:
64
- if self.normalization is not None and not isinstance([p for p in ys if p is not None][0], str):
68
+ if self.normalization is not None and len(ys) > 0 and not isinstance([p for p in ys if p is not None][0], str):
65
69
  m, s = self.normalization[list(self.normalization.keys())[-1]]
66
70
  ys = [prediction if prediction is None else prediction * s + m for prediction in ys]
67
71
  return ys
@@ -85,7 +89,7 @@ class EvaluableModel(object):
85
89
  def score(self, dataframe: pd.DataFrame, predictor=None, fidelity: bool = False, completeness: bool = True,
86
90
  brute: bool = False, criterion: str = 'corners', n: int = 2,
87
91
  task: EvaluableModel.Task = Task.CLASSIFICATION,
88
- scoring_function: Iterable[EvaluableModel.Score] = [ClassificationScore.ACCURACY]):
92
+ scoring_function: Iterable[EvaluableModel.Score] = (ClassificationScore.ACCURACY, )):
89
93
  extracted = np.array(
90
94
  self.predict(dataframe.iloc[:, :-1]) if not brute else
91
95
  self.brute_predict(dataframe.iloc[:, :-1], criterion, n)
@@ -151,42 +155,113 @@ class Extractor(EvaluableModel, ABC):
151
155
  def __init__(self, predictor, discretization: Iterable[DiscreteFeature] = None, normalization=None):
152
156
  super().__init__(discretization, normalization)
153
157
  self.predictor = predictor
158
+ self.theory = None
154
159
 
155
160
  def extract(self, dataframe: pd.DataFrame) -> Theory:
156
161
  """
157
162
  Extracts rules from the underlying predictor.
158
163
 
159
- :param dataframe: is the set of instances to be used for the extraction.
164
+ :param dataframe: the set of instances to be used for the extraction.
160
165
  :return: the theory created from the extracted rules.
161
166
  """
162
167
  raise NotImplementedError('extract')
163
168
 
164
- def predict_why(self, data: dict[str, float], verbose=True):
169
+ def predict_why(self, data: dict[str, float], verbose: bool = True):
165
170
  """
166
171
  Provides a prediction and the corresponding explanation.
167
- :param data: is the instance to predict.
168
- :param verbose: if the explanation has to be printed.
172
+ :param data: the instance to predict.
173
+ :param verbose: if True the explanation is printed.
169
174
  """
170
175
  raise NotImplementedError('predict_why')
171
176
 
172
- def predict_counter(self, data: dict[str, float], verbose=True, only_first=True):
177
+ def predict_counter(self, data: dict[str, float], verbose: bool = True, only_first: bool = True):
173
178
  """
174
179
  Provides a prediction and counterfactual explanations.
175
- :param data: is the instance to predict.
176
- :param verbose: if the counterfactual explanation has to be printed.
177
- :param only_first: if only the closest counterfactual explanation is provided for each distinct class.
180
+ :param data: the instance to predict.
181
+ :param verbose: if True the counterfactual explanation is printed.
182
+ :param only_first: if True only the closest counterfactual explanation is provided for each distinct class.
178
183
  """
179
184
  raise NotImplementedError('predict_counter')
180
185
 
186
def plot_fairness(self, dataframe: pd.DataFrame, groups: dict[str, list], colormap='seismic_r', filename=None,
                  figsize=(5, 4)):
    """
    Provides a visual estimation of the fairness exhibited by an extractor with respect to the specified groups.

    Rules of ``self.theory`` are evaluated in order: each rule is matched against the instances not yet covered
    by a previous rule, and for every group the share of its instances covered by the rule is reported
    (as a percentage of the initial size of the group).

    :param dataframe: the set of instances to be used for the estimation.
    :param groups: the set of relevant groups to consider, as a mapping from the group name to a boolean mask
        over the rows of dataframe (the masks are combined with ``&`` and indexed with ``~``, so they are
        presumably NumPy arrays or pandas Series -- TODO confirm).
    :param colormap: the colormap to use for the plot.
    :param filename: if not None, name used to save the plot.
    :param figsize: size of the plot.
    :raises ValueError: if no theory is stored in the extractor.
    """
    if self.theory is None:
        raise ValueError('plot_fairness requires a theory, but self.theory is None')

    # Group sizes are measured once, on the full dataframe: they are the denominators of every share.
    counts = {group: len(dataframe[idx_g]) for group, idx_g in groups.items()}
    output = {'labels': []}
    for group in groups:
        output[group] = []

    for clause in self.theory.clauses:
        if len(dataframe) == 0:
            break
        solver = prolog_solver(static_kb=mutable_theory(clause).assertZ(get_in_rule()).assertZ(get_not_in_rule()))
        idx = np.array([query.is_yes for query in
                        [solver.solveOnce(data_to_struct(data)) for _, data in dataframe.iterrows()]])
        output['labels'].append(str(clause.head.args[-1]))
        for group, idx_g in groups.items():
            covered = len(dataframe[idx & idx_g])
            # An empty group cannot be affected by any rule: report 0 instead of dividing by zero.
            output[group].append(covered / counts[group] if counts[group] > 0 else 0.0)
        # Instances covered by this rule are not available to the following ones.
        dataframe = dataframe[~idx]
        groups = {group: indices[~idx] for group, indices in groups.items()}

    binary = len(set(output['labels'])) == 2
    labels = sorted(set(output['labels']))
    data = np.vstack([output[group] for group in groups]).T * 100
    if binary:
        # With two outcomes, the first one (in sorted order) is drawn on the negative side of the colormap.
        data[np.array(output['labels']) == labels[0]] *= -1

    plt.figure(figsize=figsize)
    plt.imshow(data, cmap=colormap, vmin=-100 if binary else 0, vmax=100)
    ax = plt.gca()

    ax.set_xticks(range(len(groups)), labels=list(groups))
    ax.set_yticks(range(len(output['labels'])),
                  labels=[f'Rule {i + 1}\n{l}' for i, l in enumerate(output['labels'])])

    plt.xlabel('Groups')
    plt.ylabel('Rules')
    plt.title("Rule set impact on groups")

    # One annotation per cell; the sign only encodes the outcome, so the absolute value is shown.
    for (i, j), value in np.ndenumerate(data):
        ax.text(j, i, f'{abs(value):.2f}%', ha="center", va="center", color="k")

    # Minor ticks placed between the cells are used to draw the grid separating them.
    ax.set_xticks([i + .5 for i in range(len(groups))], minor=True)
    ax.set_yticks([i + .5 for i in range(len(output['labels']))], minor=True)
    ax.grid(which='minor', color='k', linestyle='-', linewidth=.8)
    ax.tick_params(which='minor', bottom=False, left=False)
    cbarticks = np.linspace(-100 if binary else 0, 100, 9 if binary else 11, dtype=int)
    cbar = plt.colorbar(fraction=0.046, label='Affected samples (%)', ticks=cbarticks)
    if binary:
        ticklabels = [str(-i) if i < 0 else str(i) for i in cbarticks]
        ticklabels[0] += f' {labels[0]}'
        ticklabels[-1] += f' {labels[-1]}'
        cbar.ax.set_yticklabels(ticklabels)

    plt.tight_layout()
    if filename is not None:
        plt.savefig(filename, dpi=500)
    plt.show()
252
+
253
def make_fair(self, features: Iterable[str]):
    """
    Placeholder for fairness support: this implementation always fails.

    :param features: the features to be handled (not used by this implementation).
    :raises NotImplementedError: always, with a message that includes the concrete class name.
    """
    extractor_name = type(self).__name__
    raise NotImplementedError(f'Fairness for {extractor_name} is not supported at the moment')
255
+
181
256
  def mae(self, dataframe: pd.DataFrame, predictor=None, brute: bool = False, criterion: str = 'center',
182
257
  n: int = 3) -> float:
183
258
  """
184
259
  Calculates the predictions' MAE w.r.t. the instances given as input.
185
260
 
186
- :param dataframe: is the set of instances to be used to calculate the mean absolute error.
261
+ :param dataframe: the set of instances to be used to calculate the mean absolute error.
187
262
  :param predictor: if provided, its predictions on the dataframe are taken instead of the dataframe instances.
188
263
  :param brute: if True, a brute prediction is executed.
189
- :param criterion: creterion for brute prediction.
264
+ :param criterion: criterion for brute prediction.
190
265
  :param n: number of points for brute prediction with 'perimeter' criterion.
191
266
  :return: the mean absolute error (MAE) of the predictions.
192
267
  """
@@ -198,10 +273,10 @@ class Extractor(EvaluableModel, ABC):
198
273
  """
199
274
  Calculates the predictions' MSE w.r.t. the instances given as input.
200
275
 
201
- :param dataframe: is the set of instances to be used to calculate the mean squared error.
276
+ :param dataframe: the set of instances to be used to calculate the mean squared error.
202
277
  :param predictor: if provided, its predictions on the dataframe are taken instead of the dataframe instances.
203
278
  :param brute: if True, a brute prediction is executed.
204
- :param criterion: creterion for brute prediction.
279
+ :param criterion: criterion for brute prediction.
205
280
  :param n: number of points for brute prediction with 'perimeter' criterion.
206
281
  :return: the mean squared error (MSE) of the predictions.
207
282
  """
@@ -213,10 +288,10 @@ class Extractor(EvaluableModel, ABC):
213
288
  """
214
289
  Calculates the predictions' R2 score w.r.t. the instances given as input.
215
290
 
216
- :param dataframe: is the set of instances to be used to calculate the R2 score.
291
+ :param dataframe: the set of instances to be used to calculate the R2 score.
217
292
  :param predictor: if provided, its predictions on the dataframe are taken instead of the dataframe instances.
218
293
  :param brute: if True, a brute prediction is executed.
219
- :param criterion: creterion for brute prediction.
294
+ :param criterion: criterion for brute prediction.
220
295
  :param n: number of points for brute prediction with 'perimeter' criterion.
221
296
  :return: the R2 score of the predictions.
222
297
  """
@@ -224,14 +299,14 @@ class Extractor(EvaluableModel, ABC):
224
299
  Extractor.Task.REGRESSION, [Extractor.RegressionScore.R2])[Extractor.RegressionScore.R2][-1]
225
300
 
226
301
  def accuracy(self, dataframe: pd.DataFrame, predictor=None, brute: bool = False, criterion: str = 'center',
227
- n: int = 3) -> float:
302
+ n: int = 3) -> float:
228
303
  """
229
304
  Calculates the predictions' accuracy classification score w.r.t. the instances given as input.
230
305
 
231
- :param dataframe: is the set of instances to be used to calculate the accuracy classification score.
306
+ :param dataframe: the set of instances to be used to calculate the accuracy classification score.
232
307
  :param predictor: if provided, its predictions on the dataframe are taken instead of the dataframe instances.
233
308
  :param brute: if True, a brute prediction is executed.
234
- :param criterion: creterion for brute prediction.
309
+ :param criterion: criterion for brute prediction.
235
310
  :param n: number of points for brute prediction with 'perimeter' criterion.
236
311
  :return: the accuracy classification score of the predictions.
237
312
  """
@@ -244,10 +319,10 @@ class Extractor(EvaluableModel, ABC):
244
319
  """
245
320
  Calculates the predictions' F1 score w.r.t. the instances given as input.
246
321
 
247
- :param dataframe: is the set of instances to be used to calculate the F1 score.
322
+ :param dataframe: the set of instances to be used to calculate the F1 score.
248
323
  :param predictor: if provided, its predictions on the dataframe are taken instead of the dataframe instances.
249
324
  :param brute: if True, a brute prediction is executed.
250
- :param criterion: creterion for brute prediction.
325
+ :param criterion: criterion for brute prediction.
251
326
  :param n: number of points for brute prediction with 'perimeter' criterion.
252
327
  :return: the F1 score of the predictions.
253
328
  """
@@ -319,6 +394,19 @@ class Extractor(EvaluableModel, ABC):
319
394
  from psyke.extraction.hypercubic.hex import HEx
320
395
  return HEx(predictor, grid, min_examples, threshold, output, discretization, normalization, seed)
321
396
 
397
@staticmethod
def ginger(predictor, features: Iterable[str], sigmas: Iterable[float], max_slices: int, min_rules: int = 1,
           max_poly: int = 1, alpha: float = 0.5, indpb: float = 0.5, tournsize: int = 3, metric: str = 'R2',
           n_gen: int = 50, n_pop: int = 50, threshold=None, valid=None, output=Target.REGRESSION,
           normalization: dict[str, tuple[float, float]] = None,
           seed: int = get_default_random_seed()) -> Extractor:
    """
    Creates a new GInGER extractor.

    All the arguments are forwarded unchanged to the GInGER constructor, positionally and in the same order
    in which they are declared here.

    :return: the new GInGER extractor.
    """
    # The import happens at call time, inside the factory, rather than at module import time.
    from psyke.extraction.hypercubic.ginger import GInGER
    arguments = (
        predictor, features, sigmas, max_slices, min_rules, max_poly, alpha, indpb, tournsize, metric,
        n_gen, n_pop, threshold, valid, output, normalization, seed,
    )
    return GInGER(*arguments)
409
+
322
410
  @staticmethod
323
411
  def gridrex(predictor, grid, min_examples: int = 250, threshold: float = 0.1,
324
412
  normalization: dict[str, tuple[float, float]] = None,
@@ -331,7 +419,7 @@ class Extractor(EvaluableModel, ABC):
331
419
 
332
420
  @staticmethod
333
421
  def creepy(predictor, clustering, depth: int, error_threshold: float, output: Target = Target.CONSTANT,
334
- gauss_components: int = 2, ranks: [(str, float)] = [], ignore_threshold: float = 0.0,
422
+ gauss_components: int = 2, ranks: Iterable[(str, float)] = tuple(), ignore_threshold: float = 0.0,
335
423
  discretization=None, normalization: dict[str, tuple[float, float]] = None,
336
424
  seed: int = get_default_random_seed()) -> Extractor:
337
425
  """
@@ -10,6 +10,10 @@ class HyperCubeClustering(HyperCubePredictor, Clustering, ABC):
10
10
 
11
11
  def __init__(self, output: Target = Target.CONSTANT, discretization=None, normalization=None):
12
12
  HyperCubePredictor.__init__(self, output=output, discretization=discretization, normalization=normalization)
13
+ self._protected_features = []
13
14
 
14
15
  def get_hypercubes(self) -> Iterable[HyperCube]:
15
16
  raise NotImplementedError('get_hypercubes')
17
+
18
def make_fair(self, features: Iterable[str]):
    """
    Sets the protected features of the clustering.

    The names are copied into a new list: the stored value is read more than once afterwards, so a one-shot
    iterable (e.g., a generator) must not be kept as it is. A single string is treated as one feature name
    rather than as a sequence of characters.

    :param features: the names of the protected features.
    """
    self._protected_features = [features] if isinstance(features, str) else list(features)
@@ -46,11 +46,7 @@ class CREAM(ExACT):
46
46
  def _iterate(self, surrounding: Node) -> Iterable[HyperCube]:
47
47
  to_split = [(self.error_threshold * 10, 1, 1, surrounding)]
48
48
  while len(to_split) > 0:
49
- to_split.sort(reverse=True)
50
- (_, depth, _, node) = to_split.pop()
51
- data = ExACT._remove_string_label(node.dataframe)
52
- gauss_params = select_gaussian_mixture(data, self.gauss_components)
53
- gauss_pred = gauss_params[2].predict(data)
49
+ node, depth, gauss_pred, gauss_params = self._get_gauss_predictions(to_split)
54
50
  cubes = self.__eligible_cubes(gauss_pred, node, gauss_params[1])
55
51
  if len(cubes) < 1:
56
52
  continue
@@ -65,4 +61,4 @@ class CREAM(ExACT):
65
61
  (error, depth + 1, np.random.uniform(), n) for (n, error) in
66
62
  zip(node.children, [right[0].diversity, left[0].diversity]) if error > self.error_threshold
67
63
  ]
68
- return self._node_to_cubes(surrounding)
64
+ return self._node_to_cubes(surrounding)
@@ -54,13 +54,13 @@ class ExACT(HyperCubeClustering, ABC):
54
54
  dbscan_pred = DBSCAN(eps=select_dbscan_epsilon(data, clusters)).fit_predict(data.iloc[:, :-1])
55
55
  return HyperCube.create_surrounding_cube(
56
56
  dataframe.iloc[np.where(dbscan_pred == Counter(dbscan_pred).most_common(1)[0][0])],
57
- True, self._output
57
+ True, self._output, self._protected_features
58
58
  )
59
59
 
60
60
  def fit(self, dataframe: pd.DataFrame):
61
61
  np.random.seed(self.seed)
62
62
  self._predictor.fit(dataframe.iloc[:, :-1], dataframe.iloc[:, -1])
63
- self._surrounding = HyperCube.create_surrounding_cube(dataframe, True, self._output)
63
+ self._surrounding = HyperCube.create_surrounding_cube(dataframe, True, self._output, self._protected_features)
64
64
  self._hypercubes = self._iterate(Node(dataframe, self._surrounding))
65
65
 
66
66
  def get_hypercubes(self) -> Iterable[HyperCube]:
@@ -79,14 +79,17 @@ class ExACT(HyperCubeClustering, ABC):
79
79
  enumerate(dataframe.iloc[:, -1].unique())
80
80
  ).items()}}) if isinstance(dataframe.iloc[0, -1], str) else dataframe
81
81
 
82
def _get_gauss_predictions(self, to_split):
    """
    Extracts the next node from the queue and clusters its data with a Gaussian mixture.

    The queue is sorted in place in descending order and its last entry is removed. The protected features
    are dropped from the data of the node before selecting the mixture and before predicting with it.

    :param to_split: the list of entries still to be processed; the node is the fourth item of each entry and
        the depth is the second one.
    :return: the extracted node, its depth, the predictions of the selected mixture and the whole result of
        select_gaussian_mixture.
    """
    to_split.sort(reverse=True)
    _, depth, _, node = to_split.pop()
    samples = ExACT._remove_string_label(node.dataframe).drop(self._protected_features, axis=1)
    gauss_params = select_gaussian_mixture(samples, self.gauss_components)
    # The fitted mixture is the third item returned by select_gaussian_mixture.
    gauss_pred = gauss_params[2].predict(samples)
    return node, depth, gauss_pred, gauss_params
88
+
82
89
  def _iterate(self, surrounding: Node) -> Iterable[HyperCube]:
83
90
  to_split = [(self.error_threshold * 10, 1, 1, surrounding)]
84
91
  while len(to_split) > 0:
85
- to_split.sort(reverse=True)
86
- (_, depth, _, node) = to_split.pop()
87
- data = ExACT._remove_string_label(node.dataframe)
88
- gauss_params = select_gaussian_mixture(data, self.gauss_components)
89
- gauss_pred = gauss_params[2].predict(data)
92
+ node, depth, gauss_pred, gauss_params = self._get_gauss_predictions(to_split)
90
93
  cubes, indices = self.__eligible_cubes(gauss_pred, node, gauss_params[1])
91
94
  cubes = [(c.volume(), len(idx), i, idx, c) for i, (c, idx) in enumerate(zip(cubes, indices))
92
95
  if (idx is not None) and (not node.cube.equal(c))]
psyke/clustering/utils.py CHANGED
@@ -11,7 +11,6 @@ def select_gaussian_mixture(data: pd.DataFrame, max_components) -> tuple[float,
11
11
  try:
12
12
  models = [GaussianMixture(n_components=n).fit(data) for n in components if n <= len(data)]
13
13
  except ValueError:
14
- print(data)
15
14
  print(len(data))
16
15
  return min([(m.bic(data) / (i + 2), (i + 2), m) for i, m in enumerate(models)])
17
16
 
@@ -11,11 +11,15 @@ class PedagogicalExtractor(Extractor, ABC):
11
11
  def __init__(self, predictor, discretization=None, normalization=None):
12
12
  Extractor.__init__(self, predictor=predictor, discretization=discretization, normalization=normalization)
13
13
 
14
- def extract(self, dataframe: pd.DataFrame) -> Theory:
14
+ def _substitute_output(self, dataframe: pd.DataFrame) -> pd.DataFrame:
15
15
  new_y = pd.DataFrame(self.predictor.predict(dataframe.iloc[:, :-1])).set_index(dataframe.index)
16
16
  data = dataframe.iloc[:, :-1].copy().join(new_y)
17
17
  data.columns = dataframe.columns
18
- return self._extract(data)
18
+ return data
19
+
20
def extract(self, dataframe: pd.DataFrame) -> Theory:
    """
    Extracts a theory from the dataframe and stores it in the extractor.

    The dataframe is first processed by _substitute_output, then handed to _extract; the result is kept in
    self.theory.

    :param dataframe: the set of instances to be used for the extraction.
    :return: the extracted theory.
    """
    relabelled = self._substitute_output(dataframe)
    theory = self._extract(relabelled)
    self.theory = theory
    return theory
19
23
 
20
24
  def _extract(self, dataframe: pd.DataFrame) -> Theory:
21
25
  raise NotImplementedError('extract')
@@ -1,11 +1,14 @@
1
- from collections import Iterable
1
+ from collections.abc import Iterable
2
2
  from typing import Union, Any
3
3
  import numpy as np
4
+ import pandas as pd
4
5
  from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
5
- from psyke.schema import Value, LessThan, GreaterThan, SchemaException
6
+ from tuprolog.core import clause, Var, Struct
7
+ from tuprolog.theory import Theory, mutable_theory
6
8
 
7
- LeafConstraints = dict[str, list[Value]]
8
- LeafSequence = Iterable[tuple[LeafConstraints, Any]]
9
+ from psyke.extraction.cart import LeafConstraints, LeafSequence
10
+ from psyke.schema import LessThan, GreaterThan, SchemaException, DiscreteFeature
11
+ from psyke.utils.logic import create_variable_list, create_head, create_term
9
12
 
10
13
 
11
14
  class CartPredictor:
@@ -14,11 +17,12 @@ class CartPredictor:
14
17
  """
15
18
 
16
19
  def __init__(self, predictor: Union[DecisionTreeClassifier, DecisionTreeRegressor] = DecisionTreeClassifier(),
17
- normalization=None):
20
+ discretization=None, normalization=None):
18
21
  self._predictor = predictor
22
+ self.discretization = discretization
19
23
  self.normalization = normalization
20
24
 
21
- def __get_constraints(self, nodes: Iterable[(int, bool)]) -> LeafConstraints:
25
+ def __get_constraints(self, nodes: Iterable[tuple[int, bool]]) -> LeafConstraints:
22
26
  thresholds = [self._predictor.tree_.threshold[i[0]] for i in nodes]
23
27
  features = [self._predictor.feature_names_in_[self._predictor.tree_.feature[node[0]]] for node in nodes]
24
28
  conditions = [node[1] for node in nodes]
@@ -48,7 +52,7 @@ class CartPredictor:
48
52
  else:
49
53
  return self._predictor.tree_.value[node]
50
54
 
51
- def __path(self, node: int, path=None) -> Iterable[(int, bool)]:
55
+ def __path(self, node: int, path=None) -> Iterable[tuple[int, bool]]:
52
56
  path = [] if path is None else path
53
57
  if node == 0:
54
58
  return path
@@ -62,6 +66,47 @@ class CartPredictor:
62
66
  def predict(self, data) -> Iterable:
63
67
  return self._predictor.predict(data)
64
68
 
69
+ @staticmethod
70
+ def _simplify_nodes(nodes: list) -> Iterable:
71
+ simplified = [nodes.pop(0)]
72
+ while len(nodes) > 0:
73
+ first_node = nodes[0][0]
74
+ for k, conditions in first_node.items():
75
+ for condition in conditions:
76
+ if all(k in node[0] and condition in node[0][k] for node in nodes):
77
+ [node[0][k].remove(condition) for node in nodes]
78
+ simplified.append(nodes.pop(0))
79
+ return [({k: v for k, v in rule.items() if v != []}, prediction) for rule, prediction in simplified]
80
+
81
def _create_body(self, variables: dict[str, Var], conditions: LeafConstraints) -> Iterable[Struct]:
    """
    Builds the terms of a clause body, one for each condition.

    Without a discretization, a term is created from the variable of the constrained feature and the
    condition. With a discretization, the constrained name is looked up among the admissible values of the
    discrete features: the term is created from the variable of the first matching feature, the matching
    admissible value and a flag telling whether the condition is a GreaterThan.

    :param variables: the variables, indexed by feature name.
    :param conditions: the conditions to translate, indexed by feature name.
    :return: the list of created terms.
    """
    body = []
    for feature_name, cond_list in conditions.items():
        for condition in cond_list:
            if self.discretization:
                # An IndexError is raised here if no discrete feature admits feature_name.
                feature = [d for d in self.discretization if feature_name in d.admissible_values][0]
                term = create_term(variables[feature.name],
                                   feature.admissible_values[feature_name],
                                   isinstance(condition, GreaterThan))
            else:
                term = create_term(variables[feature_name], condition)
            body.append(term)
    return body
92
+
93
def create_theory(self, data: pd.DataFrame, simplify: bool = True) -> Theory:
    """
    Builds a theory with one clause for each (constraints, prediction) pair obtained by iterating over self.

    The head of each clause is named after the last column of data. If a normalization is available for
    that column, the prediction is mapped back with ``prediction * s + m``, where ``(m, s)`` is the pair
    stored for the column.

    :param data: the dataframe providing the feature names and, as last column, the output name.
    :param simplify: if True the pairs are processed by _simplify_nodes before being translated.
    :return: the theory containing the created clauses.
    """
    leaves = list(self)
    if simplify:
        leaves = self._simplify_nodes(leaves)
    theory = mutable_theory()
    for constraints, prediction in leaves:
        target = data.columns[-1]
        if self.normalization is not None and target in self.normalization:
            shift, scale = self.normalization[target]
            prediction = prediction * scale + shift
        # A new set of variables is created for every clause.
        variables = create_variable_list(self.discretization, data)
        head = create_head(target, list(variables.values()), prediction)
        theory.assertZ(clause(head, self._create_body(variables, constraints)))
    return theory
109
+
65
110
  @property
66
111
  def predictor(self) -> Union[DecisionTreeClassifier, DecisionTreeRegressor]:
67
112
  return self._predictor
@@ -0,0 +1,205 @@
1
+ import numpy as np
2
+ from collections import Counter
3
+
4
+ from sklearn.metrics import accuracy_score, r2_score
5
+
6
+
7
class Node:
    """
    Node of a binary tree.

    A node holds the feature and threshold of a split together with its two children, and/or an output value.
    A node is considered a leaf exactly when its value is not None.
    """

    def __init__(self, feature=None, threshold=None, left=None, right=None, *, value=None):
        """
        :param feature: the feature tested by the node.
        :param threshold: the threshold the feature is compared to.
        :param left: the left child.
        :param right: the right child.
        :param value: the output of the node (keyword-only); None for nodes that are not leaves.
        """
        self.feature, self.threshold = feature, threshold
        self.left, self.right = left, right
        self.value = value

    def is_leaf_node(self):
        """
        :return: True if the node has a value, False otherwise.
        """
        return not (self.value is None)
17
+
18
+
19
class FairTree:
    """
    Base class of binary decision trees whose splits are penalised according to the distribution of the
    protected attributes in the resulting children.

    Subclasses have to provide _estimate_output, score and a quality_function (the impurity measure used
    to calculate the gain of a split).
    """

    def __init__(self, max_depth=3, max_leaves=None, criterion=None, min_samples_split=2, lambda_penalty=0.0,
                 protected_attr=None):
        """
        :param max_depth: depth at which the growth of the tree is stopped.
        :param max_leaves: maximum amount of leaves; None disables both the check during the growth and
            the pruning.
        :param criterion: name of the quality criterion, stored for the subclasses.
        :param min_samples_split: nodes with fewer samples are not split.
        :param lambda_penalty: weight of the fairness penalty subtracted from the gain.
        :param protected_attr: the protected column(s); they are never used to split. None means that
            there are no protected columns.
        """
        self.max_depth = max_depth
        self.max_leaves = max_leaves
        self.min_samples_split = min_samples_split
        self.lambda_penalty = lambda_penalty
        self.protected_attr = protected_attr
        self.criterion = criterion
        self.root = None
        self.n_leaves = 0
        self.quality_function = None

    def fit(self, X, y):
        """
        Grows the tree and then prunes it until it has at most max_leaves leaves.

        :param X: the input features, as a pandas DataFrame.
        :param y: the corresponding outputs, as a pandas object aligned with X.
        :return: the tree itself.
        """
        self.n_leaves = 0
        self.root = self._grow_tree(X, y, depth=0)
        # max_leaves is None by default: in that case there is nothing to compare with, so no pruning.
        while self.max_leaves is not None and self.n_leaves > self.max_leaves:
            self.prune_least_important_leaf(X, y)
            self.n_leaves -= 1
        return self

    @staticmethod
    def _estimate_output(y):
        """
        Calculates the output of a leaf from the outputs of its samples. To be provided by the subclasses.
        """
        raise NotImplementedError

    def score(self, X, y):
        """
        Calculates the predictive score of the tree (the higher, the better). To be provided by the subclasses.
        """
        raise NotImplementedError

    def predict(self, X):
        """
        :param X: the instances to predict, as a pandas DataFrame.
        :return: an array with one prediction for each row of X.
        """
        return np.array([self._traverse_tree(x, self.root) for _, x in X.iterrows()])

    def _traverse_tree(self, x, node):
        """
        Follows the splits from node down to a leaf and returns the value of the leaf.
        """
        while not node.is_leaf_node():
            node = node.left if x[node.feature] <= node.threshold else node.right
        return node.value

    def _grow_tree(self, X, y, depth):
        """
        Recursively builds the (sub)tree for the given samples, updating the leaf counter.
        """
        if depth >= self.max_depth or X.shape[0] < self.min_samples_split or len(set(y.values.flatten())) == 1 or \
                (self.max_leaves is not None and self.n_leaves >= self.max_leaves):
            self.n_leaves += 1
            return Node(value=self._estimate_output(y))

        best_feature, best_threshold = self._best_split(X, y)
        if best_feature is None:
            # No threshold separates the samples into two non-empty children.
            self.n_leaves += 1
            return Node(value=self._estimate_output(y))

        left_idxs = X[best_feature] <= best_threshold
        right_idxs = X[best_feature] > best_threshold

        left = self._grow_tree(X[left_idxs], y[left_idxs], depth + 1)
        right = self._grow_tree(X[right_idxs], y[right_idxs], depth + 1)
        return Node(best_feature, best_threshold, left, right)

    @staticmethod
    def generate_thresholds(X, y):
        """
        Calculates the midpoints between consecutive values of X (in sorted order) having different outputs.

        :param X: the values of a single feature.
        :param y: the corresponding outputs.
        :return: an array with the candidate thresholds.
        """
        sorted_indices = np.argsort(X)
        X = np.array(X)[sorted_indices]
        y = np.array(y)[sorted_indices]
        return np.array([(X[i] + X[i + 1]) / 2.0 for i in range(len(X) - 1) if y[i] != y[i + 1]])

    def _best_split(self, X, y):
        """
        Searches the (feature, threshold) pair having the highest fair gain.

        Protected columns are never tested. The candidate thresholds of a feature are the distinct values
        of 25 evenly spaced quantiles of the feature.

        :return: the selected feature and threshold, or (None, None) if no valid split exists.
        """
        protected = [] if self.protected_attr is None else self.protected_attr
        # A set of names gives an exact comparison also when a single name is given as a string
        # (with `in` on the string itself, the test would be a substring search).
        excluded = {protected} if isinstance(protected, str) else set(protected)
        protected_values = X[protected]

        best_gain = -float('inf')
        split_idx, split_threshold = None, None

        for feature in X.columns:
            if feature in excluded:
                continue
            column = X[feature]
            for threshold in np.unique(np.quantile(column, np.linspace(0, 1, num=25))):
                left_idxs = column <= threshold
                right_idxs = column > threshold

                if left_idxs.sum() == 0 or right_idxs.sum() == 0:
                    continue

                gain = self._fair_gain(y, left_idxs, right_idxs, protected_values)

                if gain > best_gain:
                    best_gain = gain
                    split_idx = feature
                    split_threshold = threshold
        return split_idx, split_threshold

    @staticmethod
    def _disparity(group):
        """
        Calculates the absolute difference between the relative frequencies of the first two distinct
        items met while iterating over group (0 if there are fewer than two distinct items).

        NOTE(review): iterating over a pandas DataFrame yields its column labels, not its rows. If group is
        the DataFrame obtained by selecting a list of protected columns, the frequencies are those of the
        column names rather than of the protected values -- to be confirmed against the callers.
        """
        counts = Counter(group)
        if len(counts) <= 1:
            return 0.0
        values = np.array(list(counts.values())) / len(group)
        return np.abs(values[0] - values[1])

    def _fair_gain(self, y, left_idx, right_idx, protected):
        """
        Calculates the gain of a split: the decrease of the quality function from the parent to the
        (weighted) children, minus lambda_penalty times the sum of the disparities of the two children.
        """
        left_y, right_y = y[left_idx], y[right_idx]
        child = len(left_y) / len(y) * self.quality_function(left_y) + \
            len(right_y) / len(y) * self.quality_function(right_y)
        info_gain = self.quality_function(y) - child
        penalty = self._disparity(protected[left_idx]) + self._disparity(protected[right_idx])
        return info_gain - self.lambda_penalty * penalty

    @staticmethod
    def _match_path(x, path):
        """
        Tells whether the instance x satisfies all the (node, went_left) steps of path.
        """
        for node, left in path:
            if left and x[node.feature] > node.threshold:
                return False
            if not left and x[node.feature] <= node.threshold:
                return False
        return True

    @staticmethod
    def candidates(node, parent=None, is_left=None, path=None):
        """
        Collects the nodes that can be pruned, i.e., the internal nodes whose children are both leaves.

        :param node: the root of the (sub)tree to explore.
        :param parent: the parent of node, if any.
        :param is_left: True if node is the left child of parent, False if it is the right one.
        :param path: the (ancestor, went_left) steps leading from the root to node; empty if omitted.
        :return: a list of (node, parent, is_left, path) tuples.
        """
        path = [] if path is None else path
        if node is None or node.is_leaf_node():
            return []
        found = []
        if node.left.is_leaf_node() and node.right.is_leaf_node():
            found.append((node, parent, is_left, path))
        found += FairTree.candidates(node.left, node, True, path + [(node, True)])
        found += FairTree.candidates(node.right, node, False, path + [(node, False)])
        return found

    def prune_least_important_leaf(self, X, y):
        """
        Replaces with a single leaf the candidate node whose removal yields the best score on (X, y).

        Every candidate is temporarily turned into a leaf to measure the resulting score and then restored;
        in case of ties the last candidate is selected. Candidates reached by no sample are ignored. The
        tree is left untouched if there is no eligible candidate.
        """
        best_score = -np.inf
        best_prune = None

        for node, parent, is_left, path in self.candidates(self.root):
            original_left = node.left
            original_right = node.right

            # Outputs of the samples reaching the node, i.e., those of the merged leaf.
            merged_y = y[(X.apply(lambda x: self._match_path(x, path), axis=1))]
            if len(merged_y) == 0:
                continue
            new_value = self._estimate_output(merged_y)
            node.left = node.right = None
            node.value = new_value

            score = self.score(X, y)
            if score >= best_score:
                best_score = score
                best_prune = (node, new_value)

            # Undo the trial before testing the next candidate.
            node.left, node.right, node.value = original_left, original_right, None

        if best_prune is not None:
            pruned, value = best_prune
            pruned.left = pruned.right = None
            pruned.value = value
164
+
165
+
166
class FairTreeClassifier(FairTree):
    """
    Fair tree for classification tasks: leaves predict the most common class of their samples and the
    tree is scored through the classification accuracy.
    """

    def __init__(self, max_depth=3, max_leaves=None, criterion='entropy', min_samples_split=2, lambda_penalty=0.0,
                 protected_attr=None):
        """
        :param criterion: 'gini' to use the Gini impurity; any other value selects the entropy.

        The other parameters are forwarded unchanged to FairTree.
        """
        super().__init__(max_depth, max_leaves, criterion, min_samples_split, lambda_penalty, protected_attr)
        self.quality_function = self._gini if self.criterion == 'gini' else self._entropy

    @staticmethod
    def _estimate_output(y):
        """
        :return: the most common value in y.
        """
        return Counter(y.values.flatten()).most_common(1)[0][0]

    def score(self, X, y):
        """
        :return: the accuracy of the predictions for X with respect to y.
        """
        return accuracy_score(y.values.flatten(), self.predict(X))

    @staticmethod
    def _entropy(y):
        """
        :return: the entropy (in bits) of the distribution of the values in y.
        """
        # np.unique only reports values that occur at least once, so all the frequencies are positive.
        ps = np.unique(y, return_counts=True)[1] / len(y)
        return -np.sum(ps * np.log2(ps))

    @staticmethod
    def _gini(y):
        """
        :return: the Gini impurity of the values in y, i.e., one minus the sum of the squared frequencies.
        """
        # The square applies to each frequency, not to len(y) alone: `counts / len(y)**2` sums to 1 / len(y).
        ps = np.unique(y, return_counts=True)[1] / len(y)
        return 1.0 - np.sum(ps ** 2)
187
+
188
+
189
class FairTreeRegressor(FairTree):
    """
    Fair tree for regression tasks: leaves predict the mean output of their samples and the tree is
    scored through the R2 score.
    """

    def __init__(self, max_depth=3, max_leaves=None, criterion='mse', min_samples_split=2, lambda_penalty=0.0,
                 protected_attr=None):
        """
        All the parameters are forwarded unchanged to FairTree; the quality function is always the MSE.
        """
        super().__init__(max_depth, max_leaves, criterion, min_samples_split, lambda_penalty, protected_attr)
        self.quality_function = self._mse

    @staticmethod
    def _estimate_output(y):
        """
        :return: the mean of the values in y.
        """
        outputs = y.values.flatten()
        return np.mean(outputs)

    def score(self, X, y):
        """
        :return: the R2 score of the predictions for X with respect to y.
        """
        expected = y.values.flatten()
        return r2_score(expected, self.predict(X))

    @staticmethod
    def _mse(y):
        """
        :return: the mean squared deviation of the values in y from their mean.
        """
        outputs = y.values.flatten().astype(float)
        deviations = outputs - np.mean(outputs)
        return np.mean(deviations**2)