linearrf-1.0.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
File without changes
@@ -0,0 +1,16 @@
+ Metadata-Version: 2.1
+ Name: linearrf
+ Version: 1.0.0
+ Summary: A python libary to build Random Forests with Linear Models at the leaves.
+ Author-email: Marian Biermann <marianbiermann@gmx.de>
+ Project-URL: homepage, https://github.com/marianbiermann/lrf
+ Keywords: ml,rf,linear model,tree,dart,model tree,linear tree
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python
+ Classifier: Programming Language :: Python :: 3
+ Requires-Python: >=3.7
+ Description-Content-Type: text/markdown
+ License-File: LICENSE.md
+ Requires-Dist: numpy>=1.20.3
+ Provides-Extra: dev
+ Requires-Dist: sklearn; extra == "dev"
File without changes
@@ -0,0 +1,27 @@
+ [build-system]
+ requires = ["setuptools>=61.0.0", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "linearrf"
+ version = "1.0.0"
+ description = "A python libary to build Random Forests with Linear Models at the leaves."
+ readme = "README.md"
+ authors = [{ name = "Marian Biermann", email = "marianbiermann@gmx.de" }]
+ license = { file = "LICENSE" }
+ classifiers = [
+     "License :: OSI Approved :: MIT License",
+     "Programming Language :: Python",
+     "Programming Language :: Python :: 3",
+ ]
+ keywords = ["ml", "rf", "linear model", "tree", "dart", "model tree", "linear tree"]
+ dependencies = [
+     "numpy >= 1.20.3",
+ ]
+ requires-python = ">=3.7"
+
+ [project.optional-dependencies]
+ dev = ["sklearn"]
+
+ [project.urls]
+ homepage = "https://github.com/marianbiermann/lrf"
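Given the [project.optional-dependencies] table above, the dev extra follows the standard pip extras syntax, e.g. pip install 'linearrf[dev]', which pulls in the sklearn requirement also listed in requires.txt further down.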
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
@@ -0,0 +1,16 @@
+ Metadata-Version: 2.1
+ Name: linearrf
+ Version: 1.0.0
+ Summary: A python libary to build Random Forests with Linear Models at the leaves.
+ Author-email: Marian Biermann <marianbiermann@gmx.de>
+ Project-URL: homepage, https://github.com/marianbiermann/lrf
+ Keywords: ml,rf,linear model,tree,dart,model tree,linear tree
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python
+ Classifier: Programming Language :: Python :: 3
+ Requires-Python: >=3.7
+ Description-Content-Type: text/markdown
+ License-File: LICENSE.md
+ Requires-Dist: numpy>=1.20.3
+ Provides-Extra: dev
+ Requires-Dist: sklearn; extra == "dev"
@@ -0,0 +1,17 @@
+ LICENSE.md
+ MANIFEST.in
+ README.md
+ pyproject.toml
+ src/linearrf.egg-info/PKG-INFO
+ src/linearrf.egg-info/SOURCES.txt
+ src/linearrf.egg-info/dependency_links.txt
+ src/linearrf.egg-info/requires.txt
+ src/linearrf.egg-info/top_level.txt
+ src/lrf/__init__.py
+ src/lrf/_base_lrf.py
+ src/lrf/_bfgs.py
+ src/lrf/_criterion.py
+ src/lrf/_linear_models.py
+ src/lrf/_node.py
+ src/lrf/_preprocessor.py
+ src/lrf/lrf.py
@@ -0,0 +1,4 @@
+ numpy>=1.20.3
+
+ [dev]
+ sklearn
@@ -0,0 +1 @@
+ from .lrf import LRFClassifier, LRFRegressor
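A minimal usage sketch, assuming the exported LRFRegressor accepts the constructor arguments of _LinearRandomForest shown in _base_lrf.py below (the actual public signatures may differ):

    import numpy as np
    from lrf import LRFRegressor

    rng = np.random.default_rng(0)
    x = rng.normal(size=(400, 3))
    y = x @ np.array([1.5, -2.0, 0.5]) + rng.normal(scale=0.1, size=400)

    # keyword arguments assumed to mirror _LinearRandomForest.__init__ below
    model = LRFRegressor(n_estimators=10, max_depth=3, n_jobs=1, random_state=0)
    model.fit(x, y)
    print(model.export_text(tree=0))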
@@ -0,0 +1,360 @@
+ import copy
+ import datetime
+ import time
+ from concurrent.futures import ProcessPoolExecutor, as_completed
+ from multiprocessing import cpu_count
+ from typing import List, Union
+
+ import numpy as np
+
+ from lrf._linear_models import Regressor, Classifier
+ from lrf._node import Node
+
+
+ class _LinearRandomForest:
+     def __init__(self, linear_model: Union[Regressor, Classifier] = None, n_estimators: int = 100, max_depth: int = 5,
+                  criterion: str = None, n_splits: int = 15, split_samples_to_features_ratio: float = 4.5,
+                  leaf_samples_to_features_ratio: float = 2.0, min_abs_improvement: float = 5*10**(-4),
+                  warm_start: bool = True, n_jobs: int = -1, random_state: int = None, verbose: bool = False,
+                  classification: bool = False):
+         self.linear_model = linear_model
+         self.n_estimators = n_estimators
+         self.max_depth = max_depth
+         self.criterion = criterion
+         self.n_splits = n_splits
+         self.split_samples_to_features_ratio = split_samples_to_features_ratio
+         self.leaf_samples_to_features_ratio = leaf_samples_to_features_ratio
+         self.min_abs_improvement = min_abs_improvement
+         self.warm_start = warm_start
+         self.n_jobs = n_jobs
+         self.random_state = random_state
+         self.verbose = verbose
+         self.classification = classification
+
+     def _init_more_attributes(self, y):
+         if self.classification:
+             self.classes_ = None
+
+         self.forest = None
+         self.min_samples_split = None
+         self.min_samples_leaf = None
+
+         if self.max_depth is None:
+             self.max_depth = 10 ** 32
+
+         if self.n_jobs == -1 or self.n_jobs == 0:
+             self.n_jobs = cpu_count()
+         else:
+             self.n_jobs = min(self.n_jobs, cpu_count())
+
+         if self.split_samples_to_features_ratio < self.leaf_samples_to_features_ratio * 2:
+             self.split_samples_to_features_ratio = self.leaf_samples_to_features_ratio * 2
+
+         self.min_samples_split = None
+         self.min_samples_leaf = None
+
+         self.total_data_points = y.shape[0]
+
+     def fit(self, x: np.ndarray, y: np.ndarray):
+
+         assert y.ndim == 1
+         assert not np.all(y == y[0])
+
+         self._init_more_attributes(y=y)
+
+         if self.classification:
+             self._check_targets_classification(y)
+
+         random_state_list = np.random.default_rng(self.random_state).integers(2**63, size=self.n_estimators)
+
+         self.min_samples_split = self.split_samples_to_features_ratio * x.shape[1]
+         self.min_samples_leaf = self.leaf_samples_to_features_ratio * x.shape[1]
+
+         forest = []
+
+         # add intercept here and not inside linear model for performance reasons
+         x = np.insert(x, 0, 1, axis=1)
+
+         # parallel process combinations of chunks of the data
+         with ProcessPoolExecutor(max_workers=self.n_jobs) as executor:
+             if self.verbose:
+                 print('\nStart growing trees...')
+             finished_tasks = 0
+             start_time = time.time()
+
+             results = [executor.submit(self._grow_tree, x=x, y=y, random_state=i) for i in random_state_list]
+
+             # collect the results and print the progress
+             for r in as_completed(results):
+                 # collecting results
+                 grown_tree = r.result()
+                 forest.append(grown_tree)
+
+                 # printing progress
+                 if self.verbose:
+                     finished_tasks += 1
+                     self._print_progress(frac=finished_tasks/self.n_estimators, start_time=start_time)
+
+         self.forest = forest
+
+         if self.verbose:
+             elapsed_seconds = round(time.time() - start_time)
+             print('Finished planting the forest in {} '.format(str(datetime.timedelta(seconds=elapsed_seconds))))
+
+     @staticmethod
+     def _print_progress(frac: float, start_time: float):
+         """
+         Prints the progress of the parallel multiprocessing
+         Args:
+             frac (float): Fraction of tasks which are already finished
+             start_time (float): Time (as returned by time.time()) at which fitting started
+
+         """
+         elapsed_seconds = round(time.time() - start_time)
+         remaining_seconds = round(elapsed_seconds / frac - elapsed_seconds)
+         print('LRF - Progress: {}%, [{}<{}]'.format(
+             round(100 * frac, 2),
+             str(datetime.timedelta(seconds=elapsed_seconds)),
+             str(datetime.timedelta(seconds=remaining_seconds))
+         ), end='\r')
+
+     def _grow_tree(self, x: np.ndarray, y: np.ndarray, random_state: int):
+
+         rng = np.random.default_rng(random_state)
+
+         idx = rng.choice(np.arange(x.shape[0]), x.shape[0])
+         x = x[idx]
+         y = y[idx]
+
+         tree = self._root_node(x=x, y=y)
+
+         # split
+         tree = self._split(node=tree, x=x, y=y, depth=0, rng=rng)
+
+         return tree
+
+     def _root_node(self, x: np.ndarray, y: np.ndarray):
+         # initial linear model
+         root_model = copy.deepcopy(self.linear_model)
+
+         if isinstance(root_model, (Regressor, Classifier)):
+             root_model.fit(x, y, None)
+         else:
+             root_model.fit(x, y)
+
+         if self.criterion == 'cross_entropy':
+             y_pred = root_model.predict_proba(x)
+         elif (self.criterion == 'neg_roc_auc') or (self.criterion == 'neg_pr_auc'):
+             y_pred = root_model.predict_proba(x)[:, 1]
+         else:
+             y_pred = root_model.predict(x)
+
+         metric = self._calculate_metric(y_true=y, y_pred=y_pred)
+
+         # create node object
+         tree = Node(depth=0, metric=metric, model=root_model)
+
+         return tree
+
+     def _split(self, node: Node, x: np.ndarray, y: np.ndarray, depth: int, rng: np.random.Generator):
+         if (depth == self.max_depth) or np.all(np.all(x == x[0, :], axis=1)) or (x.shape[0] < self.min_samples_split):
+             return node
+         else:
+             split = self._find_best_split(x=x, y=y, last_metric=node.metric, old_coefs=node.model.coef_, rng=rng)
+
+             if split.get('threshold') is not None:
+                 node.threshold = split['threshold']
+                 node.split_col_idx = split['column']
+
+                 left_node = Node(depth=depth + 1, model=split['model_left'], metric=split['metric_left'])
+                 left_node = self._split(node=left_node, x=split['x_left'], y=split['y_left'],
+                                         depth=depth + 1, rng=rng)
+
+                 right_node = Node(depth=depth + 1, model=split['model_right'], metric=split['metric_right'])
+                 right_node = self._split(node=right_node, x=split['x_right'], y=split['y_right'],
+                                          depth=depth + 1, rng=rng)
+
+                 node.left_node = left_node
+                 node.right_node = right_node
+                 node.model = None
+
+             return node
+
+     def _find_best_split(self, x: np.ndarray, y: np.ndarray,
+                          last_metric: float, old_coefs: np.ndarray, rng: np.random.Generator):
+         split = {}
+
+         random_col_ids = rng.choice(np.arange(1, (x.shape[1])), int(round(np.sqrt(x.shape[1] - 1))), replace=False)
+
+         for col in random_col_ids:
+             split_candidates = self._split_values(x[:, col], rng=rng)
+
+             for thresh in split_candidates:
+                 left_idx = x[:, col] <= thresh
+                 left_idx, right_idx = left_idx.nonzero()[0], (~left_idx).nonzero()[0]
+
+                 if x[:, col].max() == thresh:
+                     continue
+
+                 x_left, y_left = x.take(left_idx, axis=0), y.take(left_idx, axis=0)
+                 x_right, y_right = x.take(right_idx, axis=0), y.take(right_idx, axis=0)
+
+                 if np.all(y_left == y_left[0]) or np.all(y_right == y_right[0]):
+                     continue
+
+                 observations_left, observations_right = y_left.shape[0], y_right.shape[0]
+
+                 if (
+                     observations_left < self.min_samples_leaf
+                 ) or (
+                     observations_right < self.min_samples_leaf
+                 ):
+                     continue
+
+                 # initialize models
+                 model_left, model_right = copy.deepcopy(self.linear_model), copy.deepcopy(self.linear_model)
+
+                 # fit models
+                 if self.warm_start and isinstance(model_left, (Regressor, Classifier)) and isinstance(
+                         model_right, (Regressor, Classifier)):
+                     model_left.fit(x_left, y_left, initial_coefs=old_coefs)
+                     model_right.fit(x_right, y_right, initial_coefs=old_coefs)
+                 else:
+                     model_left.fit(x_left, y_left, None)
+                     model_right.fit(x_right, y_right, None)
+
+                 # get prediction for these nodes
+                 if self.criterion == 'cross_entropy':
+                     y_pred_left = model_left.predict_proba(x_left)
+                     y_pred_right = model_right.predict_proba(x_right)
+                 elif (self.criterion == 'neg_roc_auc') or (self.criterion == 'neg_pr_auc'):
+                     y_pred_left = model_left.predict_proba(x_left)[:, 1]
+                     y_pred_right = model_right.predict_proba(x_right)[:, 1]
+                 else:
+                     y_pred_left = model_left.predict(x_left)
+                     y_pred_right = model_right.predict(x_right)
+
+                 # get metrics for these nodes
+                 metric_left = self._calculate_metric(y_true=y_left, y_pred=y_pred_left)
+                 metric_right = self._calculate_metric(y_true=y_right, y_pred=y_pred_right)
+
+                 new_metric = ((metric_left * observations_left + metric_right * observations_right)
+                               / (observations_left + observations_right))
+                 better_split = new_metric < (last_metric - self.min_abs_improvement)
+
+                 if better_split:
+                     last_metric = new_metric
+
+                     split = {'column': col,
+                              'threshold': thresh,
+                              'model_left': model_left,
+                              'model_right': model_right,
+                              'x_right': x_right,
+                              'y_right': y_right,
+                              'x_left': x_left,
+                              'y_left': y_left,
+                              'metric_left': metric_left,
+                              'metric_right': metric_right}
+
+         return split
+
+     def _split_values(self, values: np.ndarray, rng: np.random.Generator) -> List:
+         unique_values = np.unique(values)
+         if unique_values.shape[0] <= self.n_splits:
+             split_values = unique_values.tolist()
+         else:
+             perc_splits = np.ceil(2 * self.n_splits / 3)
+             perc_splits = np.unique(np.percentile(values, np.arange(50 / perc_splits, 100, 100 / perc_splits),
+                                                   method='closest_observation'))
+
+             n_smart_splits = self.n_splits - perc_splits.shape[0]
+
+             diff = np.diff(unique_values, prepend=unique_values[0])
+             std = diff[1:].std()
+             length = unique_values.shape[0]
+
+             mask = np.array([False] * length)
+             k = np.arange(0, 15, 0.02)
+             for j in k:
+                 mask += diff > (0.5 - j / 200) * length ** (1 / (2 + j)) * std
+                 if np.count_nonzero(mask) >= n_smart_splits:
+                     break
+
+             mask = mask.nonzero()[0]
+             smart_splits = unique_values.take(mask, axis=0)
+             if smart_splits.shape[0] > n_smart_splits:
+                 smart_splits = rng.choice(smart_splits, n_smart_splits, replace=False)
+
+             split_values = perc_splits.tolist() + smart_splits.tolist()
+
+         max_value = unique_values[-1]
+         split_values = [val for val in split_values if val != max_value]
+         return split_values
+
+     def predict(self, x: np.ndarray):
+         raise NotImplementedError()
+
+     def _calculate_metric(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
+         raise NotImplementedError()
+
+     def export_text(self, tree: int = None, column_names: List[str] = None, ndigits: int = 5):
+         txt = ''
+         if tree is None:
+             for i, node in enumerate(self.forest):
+                 txt += 'Tree {}:\n'.format(i)
+                 txt += self._node_to_text(node=node, column_names=column_names, ndigits=ndigits)
+                 txt += '\n' + '\n'
+         else:
+             txt += 'Tree {}:\n'.format(tree)
+             node = self.forest[tree]
+             txt += self._node_to_text(node=node, column_names=column_names, ndigits=ndigits)
+
+         return txt
+
+     def _node_to_text(self, node: Node, column_names: List[str] = None, ndigits: int = 3):
+
+         txt = ''.join(['| ']*node.depth)
+         txt += '|---'
+
+         if node.model is None:
+             if column_names is None:
+                 col = 'col_{}'.format(node.split_col_idx - 1)
+             else:
+                 col = column_names[node.split_col_idx - 1]
+
+             txt += ' '.join([col, '<', str(round(node.threshold, ndigits))])
+             txt += '\n'
+
+             txt += self._node_to_text(node=node.left_node, column_names=column_names, ndigits=ndigits)
+
+             txt += ''.join(['| '] * node.depth)
+             txt += '|---'
+             txt += ' '.join([col, '>=', str(round(node.threshold, ndigits))])
+             txt += '\n'
+
+             txt += self._node_to_text(node=node.right_node, column_names=column_names, ndigits=ndigits)
+
+         else:
+             intercept = node.model.coef_[0]
+             weights = node.model.coef_[1:]
+             weights = ['+' + str(round(w, ndigits)) if w > 0 else str(round(w, ndigits)) for w in weights]
+             if column_names is None:
+                 cols = ['col_{}'.format(i) for i in range(len(weights))]
+             else:
+                 cols = column_names
+
+             weights_and_cols = ' '.join(['*'.join(p) for p in list(zip(weights, cols))])
+
+             txt += ' '.join(['model: y =', str(round(intercept, ndigits)), weights_and_cols])
+
+             txt += '\n'
+
+         return txt
+
+     def set_params(self, **parameters):
+         for parameter, value in parameters.items():
+             setattr(self, parameter, value)
+         return self
+
+     def _check_targets_classification(self, y: np.ndarray):
+         self.classes_ = np.unique(y)
+         assert issubclass(self.classes_.dtype.type, np.integer), 'Please convert targets to integer values'
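For reference, the split rule in _find_best_split above accepts a candidate threshold only when the sample-weighted child metric improves on the parent metric by at least min_abs_improvement:

    (n_left * metric_left + n_right * metric_right) / (n_left + n_right) < parent_metric - min_abs_improvement

Since last_metric is then updated to this weighted value, every later candidate has to beat the best split found so far, not just the original parent metric.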
@@ -0,0 +1,176 @@
+ import numpy as np
+
+ from lrf._criterion import cross_entropy
+
+
+ class BFGS:
+     def __init__(self, n_iter: int = 100, tol: float = 10**(-4), intercept: bool = True):
+         self.n_iter = n_iter
+         self.tol = tol
+         self.intercept = intercept
+
+     def classification(self, x: np.ndarray, y_true: np.ndarray, coef_: np.ndarray, C: float = 1.0):
+         y_true = y_true[:, np.newaxis]
+
+         coef_ = coef_[:, np.newaxis]
+         new_grad = self._grad_cross_entropy_logistic(y_true=y_true, x=x, coef_=coef_, C=C,
+                                                      y_pred=self._sigmoid(x@coef_))
+
+         H_inv = np.eye(coef_.shape[0]) / 0.2
+
+         alpha = 1
+         for _ in range(self.n_iter):
+
+             grad = new_grad
+
+             direction = -H_inv @ grad
+
+             alpha, new_grad = self._line_search(x=x, y=y_true, coef_=coef_, direction=direction,
+                                                 grad=grad, C=C, alpha=alpha)
+
+             if alpha is None:
+                 break
+
+             s = alpha * direction
+
+             change_mask = coef_ != 0
+             change = np.abs(s[change_mask] / coef_[change_mask]) if np.count_nonzero(change_mask) > 0 else 1
+
+             coef_ += s
+
+             if (np.max(change) <= self.tol) or np.all(new_grad == 0):
+                 break
+             else:
+                 grad_diff = new_grad - grad
+
+                 st_grad_diff = s.T @ grad_diff
+
+                 A = ((st_grad_diff + grad_diff.T @ H_inv @ grad_diff) * (s @ s.T)) / (st_grad_diff**2)
+
+                 B = (H_inv @ grad_diff @ s.T + s @ grad_diff.T @ H_inv) / st_grad_diff
+
+                 H_inv += A - B
+
+         return coef_.flatten()
+
+     @staticmethod
+     def _sigmoid(y: np.ndarray):
+         """
+         Sigmoid function to map input to values between 0 and 1 on the characteristic s-shaped curve (sigmoid curve).
+         This is the probability for the positive class.
+
+         Args:
+             y: np.ndarray
+                 Input values, which will be mapped to values between 0 and 1.
+
+         Returns:
+             np.ndarray: Returns the probability for the positive class.
+         """
+         return np.exp(-np.logaddexp(0, -y))
+
+     def _grad_cross_entropy_logistic(self, y_true: np.ndarray, x: np.ndarray, y_pred: np.ndarray,
+                                      coef_: np.ndarray, C: float):
+
+         weights = coef_.copy()
+         if self.intercept:
+             weights[0] = 0
+
+         norm = np.linalg.norm(weights)
+         if norm != 0.0:
+             penalty = np.einsum('ij->', weights) / (C * norm)
+         else:
+             penalty = 0
+
+         grad = x.T @ (y_pred - y_true) + penalty
+         norm = np.linalg.norm(grad)
+         if norm == 0:
+             grad = np.zeros(grad.shape)
+         else:
+             grad /= norm
+         return grad
+
+     def _armijo(self, y: np.ndarray, y_pred: np.ndarray, coef_: np.ndarray, alpha: float, C: float,
+                 c1: float, grad_dir: float, cross_entropy_value: float):
+
+         penalty = self.get_penalty(coef=coef_, C=C)
+
+         left_armijo = cross_entropy(y_true=y, y_pred=y_pred, penalty=penalty)
+         right_armijo = cross_entropy_value + c1 * alpha * grad_dir
+
+         armijo = left_armijo <= right_armijo
+
+         return armijo
+
+     def _wolfe(self, x: np.ndarray, y: np.ndarray, coef_: np.ndarray, alpha: float, C: float,
+                direction: np.ndarray, c1: float, c2: float, grad_dir: float, cross_entropy_value: float,
+                x_coef: np.ndarray, x_direction: np.ndarray):
+
+         y_pred = self._sigmoid(x_coef + alpha*x_direction)
+
+         armijo = self._armijo(y=y, coef_=coef_, alpha=alpha, c1=c1, grad_dir=grad_dir,
+                               cross_entropy_value=cross_entropy_value, C=C, y_pred=y_pred)
+
+         if armijo:
+             grad = self._grad_cross_entropy_logistic(y_true=y, x=x, coef_=coef_ + alpha * direction, C=C, y_pred=y_pred)
+             left_curvature = (direction.T @ grad).item()
+             right_curvature = c2 * grad_dir
+
+             # since wolfe conditions are armijo and weak/strong curvature, the curvature directly implies weak or
+             # strong wolfe since armijo is given to be True at this point
+             weak_wolfe = left_curvature >= right_curvature
+             strong_wolfe = np.abs(left_curvature) <= np.abs(right_curvature)
+         else:
+             weak_wolfe, strong_wolfe = False, False
+             grad = None
+
+         return weak_wolfe, strong_wolfe, grad
+
+     def _line_search(self, x: np.ndarray, y: np.ndarray, coef_: np.ndarray,
+                      direction: np.ndarray, grad: np.ndarray, C: float,
+                      c1: float = 10 ** (-4), c2: float = 0.9,
+                      alpha_upper: float = 2.0, alpha_lower: float = 10**-10, alpha: float = 1.0,
+                      n_iter: int = 10):
+
+         grad_dir = (direction.T @ grad).item()
+         x_coef = x @ coef_
+         x_direction = x @ direction
+
+         penalty = self.get_penalty(coef=coef_, C=C)
+
+         cross_entropy_value = cross_entropy(y_true=y, y_pred=self._sigmoid(x_coef), penalty=penalty)
+
+         weak_wolfe_value, grad_value = 0, 0
+         for _ in range(n_iter):
+             weak_wolfe, strong_wolfe, grad = self._wolfe(x=x, y=y, coef_=coef_, alpha=alpha, direction=direction, c1=c1,
+                                                          c2=c2, grad_dir=grad_dir,
+                                                          cross_entropy_value=cross_entropy_value,
+                                                          C=C, x_coef=x_coef, x_direction=x_direction)
+
+             if strong_wolfe:
+                 break
+             else:
+                 if weak_wolfe and alpha > weak_wolfe_value:
+                     weak_wolfe_value = alpha
+                     grad_value = grad
+
+                     alpha_lower = alpha
+                 else:
+                     alpha_upper = alpha
+
+                 alpha = (alpha_lower + alpha_upper) / 2
+         else:
+             if weak_wolfe_value != 0:
+                 alpha = weak_wolfe_value
+                 grad = grad_value
+             else:
+                 alpha, grad = None, None
+
+         return alpha, grad
+
+     def get_penalty(self, coef: np.ndarray, C: float):
+         if self.intercept:
+             penalty = np.linalg.norm(coef[1:]) / C
+         else:
+             penalty = np.linalg.norm(coef) / C
+
+         return penalty
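For reference, the H_inv update in BFGS.classification above is the standard BFGS inverse-Hessian update. Writing the step as s = alpha * direction and the gradient difference as y = new_grad - grad, the quantities A and B in the code implement, in LaTeX notation,

    H_{k+1}^{-1} = H_k^{-1} + \frac{(s^\top y + y^\top H_k^{-1} y)\, s s^\top}{(s^\top y)^2} - \frac{H_k^{-1} y\, s^\top + s\, y^\top H_k^{-1}}{s^\top y}

with st_grad_diff playing the role of s^\top y.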