Moral88 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Moral88/regression.py CHANGED
@@ -31,6 +31,7 @@ class DataValidator:
31
31
  """
32
32
  Returns the number of samples in the array.
33
33
  """
34
+ array = np.asarray(array)
34
35
  if hasattr(array, 'shape') and len(array.shape) > 0:
35
36
  return array.shape[0]
36
37
  else:
@@ -101,12 +102,16 @@ class DataValidator:
101
102
 
102
103
 
103
104
  class Metrics:
104
- def mean_bias_deviation(self, y_true, y_pred, library=None):
105
+ def mean_bias_deviation(self, y_true, y_pred, library=None, flatten=True):
105
106
  """
106
107
  Computes Mean Bias Deviation (MBD).
107
108
  """
108
109
  y_true, y_pred = self.validator.validate_mae_mse_inputs(y_true, y_pred, library)
109
110
 
111
+ if flatten and y_true.ndim > 1:
112
+ y_true = y_true.flatten()
113
+ y_pred = y_pred.flatten()
114
+
110
115
  if library == 'sklearn':
111
116
  # Sklearn does not have a direct implementation for MBD
112
117
  raise NotImplementedError("Mean Bias Deviation is not implemented in sklearn.")
@@ -130,12 +135,16 @@ class Metrics:
130
135
  def __init__(self):
131
136
  self.validator = DataValidator()
132
137
 
133
- def r2_score(self, y_true, y_pred, sample_weight=None, library=None):
138
+ def r2_score(self, y_true, y_pred, sample_weight=None, library=None, flatten=True):
134
139
  """
135
140
  Computes R2 score.
136
141
  """
137
142
  y_true, y_pred, sample_weight = self.validator.validate_r2_score_inputs(y_true, y_pred, sample_weight)
138
143
 
144
+ if flatten and y_true.ndim > 1:
145
+ y_true = y_true.flatten()
146
+ y_pred = y_pred.flatten()
147
+
139
148
  if library == 'sklearn':
140
149
  from sklearn.metrics import r2_score as sklearn_r2
141
150
  return sklearn_r2(y_true, y_pred, sample_weight=sample_weight)
@@ -166,6 +175,9 @@ class Metrics:
166
175
  if threshold is not None:
167
176
  y_pred = np.clip(y_pred, threshold[0], threshold[1])
168
177
 
178
+ if y_true.ndim > 1 and flatten:
179
+ y_true = y_true.flatten()
180
+ y_pred = y_pred.flatten()
169
181
  absolute_errors = np.abs(y_true - y_pred)
170
182
 
171
183
  if method == 'mean':
@@ -177,9 +189,9 @@ class Metrics:
177
189
  else:
178
190
  raise ValueError("Invalid method. Choose from {'mean', 'sum', 'none'}.")
179
191
 
180
- if normalize and method != 'none':
181
- range_y = np.ptp(y_true)
182
- result = result / max(abs(range_y), 1)
192
+ # if normalize and method != 'none':
193
+ # range_y = np.ptp(y_true)
194
+ # result = result / max(abs(range_y), 1)
183
195
 
184
196
  return result
185
197
 
@@ -213,6 +225,9 @@ class Metrics:
213
225
  if threshold is not None:
214
226
  y_pred = np.clip(y_pred, threshold[0], threshold[1])
215
227
 
228
+ if y_true.ndim > 1 and flatten:
229
+ y_true = y_true.flatten()
230
+ y_pred = y_pred.flatten()
216
231
  squared_errors = (y_true - y_pred) ** 2
217
232
 
218
233
  if method == 'mean':
@@ -224,9 +239,9 @@ class Metrics:
224
239
  else:
225
240
  raise ValueError("Invalid method. Choose from {'mean', 'sum', 'none'}.")
226
241
 
227
- if normalize and method != 'none':
228
- range_y = np.ptp(y_true)
229
- result = result / max(abs(range_y), 1)
242
+ # if normalize and method != 'none':
243
+ # range_y = np.ptp(y_true)
244
+ # result = result / max(abs(range_y), 1)
230
245
 
231
246
  return result
232
247
 
@@ -296,7 +311,7 @@ class Metrics:
296
311
 
297
312
  return np.mean(np.abs((y_true - y_pred) / np.clip(y_true, 1e-8, None))) * 100
298
313
 
299
- def explained_variance_score(self, y_true, y_pred, library=None):
314
+ def explained_variance_score(self, y_true, y_pred, library=None, flatten=True):
300
315
  """
301
316
  Computes Explained Variance Score.
302
317
  """
@@ -326,6 +341,58 @@ class Metrics:
326
341
  denominator = np.var(y_true)
327
342
  return 1 - numerator / denominator if denominator != 0 else 0
328
343
 
344
+ def adjusted_r2_score(self, y_true, y_pred, n_features, library=None, flatten=True):
345
+ """
346
+ Computes Adjusted R-Squared Score.
347
+
348
+ Parameters:
349
+ y_true: array-like of shape (n_samples,)
350
+ Ground truth (correct) target values.
351
+
352
+ y_pred: array-like of shape (n_samples,)
353
+ Estimated target values.
354
+
355
+ n_features: int
356
+ Number of independent features in the model.
357
+
358
+ library: str, optional (default=None)
359
+ Library to use for computation. Supports {'sklearn', 'statsmodels', None}.
360
+
361
+ flatten: bool, optional (default=True)
362
+ If True, flattens multidimensional arrays before computation.
363
+ """
364
+ # Validate inputs
365
+ y_true, y_pred, _ = self.validator.validate_r2_score_inputs(y_true, y_pred)
366
+
367
+ # Ensure inputs are 1D arrays
368
+ if y_true.ndim == 0 or y_pred.ndim == 0:
369
+ y_true = np.array([y_true])
370
+ y_pred = np.array([y_pred])
371
+
372
+ if flatten and y_true.ndim > 1:
373
+ y_true = y_true.flatten()
374
+ y_pred = y_pred.flatten()
375
+
376
+ if library == 'sklearn':
377
+ from sklearn.metrics import r2_score
378
+ r2 = r2_score(y_true, y_pred)
379
+ elif library == 'statsmodels':
380
+ import statsmodels.api as sm
381
+ X = sm.add_constant(y_pred)
382
+ model = sm.OLS(y_true, X).fit()
383
+ r2 = model.rsquared
384
+ else:
385
+ numerator = np.sum((y_true - y_pred) ** 2)
386
+ denominator = np.sum((y_true - np.mean(y_true)) ** 2)
387
+ r2 = 1 - (numerator / denominator) if denominator != 0 else 0.0
388
+
389
+ n_samples = len(y_true)
390
+ if n_samples <= n_features + 1:
391
+ raise ValueError("Number of samples must be greater than number of features plus one for adjusted R-squared computation.")
392
+
393
+ adjusted_r2 = 1 - (1 - r2) * (n_samples - 1) / (n_samples - n_features - 1)
394
+ return adjusted_r2
395
+
329
396
  if __name__ == '__main__':
330
397
  # Example usage
331
398
  validator = DataValidator()
@@ -336,11 +403,16 @@ if __name__ == '__main__':
336
403
  print("1D array:", validator.is_1d_array(arr))
337
404
  print("Samples:", validator.check_samples(arr))
338
405
 
339
- # Test R2 score
406
+ # Test MAE, MSE, R2, MBD, EV, MAPE, RMSE
340
407
  y_true = [3, -0.5, 2, 7]
341
408
  y_pred = [2.5, 0.0, 2, 8]
342
- print("R2 Score:", metrics.r2_score(y_true, y_pred))
343
409
 
344
- # Test MAE and MSE
345
410
  print("Mean Absolute Error:", metrics.mean_absolute_error(y_true, y_pred))
346
411
  print("Mean Squared Error:", metrics.mean_squared_error(y_true, y_pred))
412
+ print("R2 Score:", metrics.r2_score(y_true, y_pred))
413
+ print("Mean Bias Deviation: ", metrics.mean_bias_deviation(y_true, y_pred))
414
+ print("Explained Variance Score: ", metrics.explained_variance_score(y_true, y_pred))
415
+ print("Mean Absolute Percentage Error: ", metrics.mean_absolute_percentage_error(y_true, y_pred))
416
+ print("Root Mean Squared Error: ", metrics.root_mean_squared_error(y_true, y_pred))
417
+ print("adjusted_r2_score: ", metrics.adjusted_r2_score(y_true, y_pred, 2))
418
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: Moral88
3
- Version: 0.5.0
3
+ Version: 0.6.0
4
4
  Summary: A library for regression evaluation metrics.
5
5
  Author: Morteza Alizadeh
6
6
  Author-email: alizadeh.c2m@gmail.com
@@ -0,0 +1,8 @@
1
+ Moral88/__init__.py,sha256=vb-aPc9ZbnYNSy9qq2fVESI63E10pYsCrDpnV8OHWkg,74
2
+ Moral88/regression.py,sha256=0aSRXLWur6tcC4xd806koyB2ktgPJodlOeXYCZZYDzE,17208
3
+ Moral88/segmentation.py,sha256=v0yqxdrKbM9LM7wVKLjJ4HrhrSrilNNeWS6-oK_27Ag,1363
4
+ Moral88-0.6.0.dist-info/LICENSE,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ Moral88-0.6.0.dist-info/METADATA,sha256=6Y1H8Qh9wnrZVUr2gnoBYMnF5EsXY6ijMoS9bFZ21bE,407
6
+ Moral88-0.6.0.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
7
+ Moral88-0.6.0.dist-info/top_level.txt,sha256=-dyn5iTprnSUHbtMpvRO-prJsIoaRxao7wlfCHLSsv4,8
8
+ Moral88-0.6.0.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- Moral88/__init__.py,sha256=vb-aPc9ZbnYNSy9qq2fVESI63E10pYsCrDpnV8OHWkg,74
2
- Moral88/regression.py,sha256=MjM3R1oqRWdlfo6Goc2NOT0UHeKGcQfdMyriqSvS5q4,14127
3
- Moral88/segmentation.py,sha256=v0yqxdrKbM9LM7wVKLjJ4HrhrSrilNNeWS6-oK_27Ag,1363
4
- Moral88-0.5.0.dist-info/LICENSE,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- Moral88-0.5.0.dist-info/METADATA,sha256=7_93ZrGO0rFBargNBBe7qvQzQnFi2BFN7RGss26ux3I,407
6
- Moral88-0.5.0.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
7
- Moral88-0.5.0.dist-info/top_level.txt,sha256=-dyn5iTprnSUHbtMpvRO-prJsIoaRxao7wlfCHLSsv4,8
8
- Moral88-0.5.0.dist-info/RECORD,,