distclassipy 0.2.1__py3-none-any.whl → 0.2.2a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
distclassipy/__init__.py CHANGED
@@ -26,9 +26,9 @@ from .classifier import (
26
26
  DistanceMetricClassifier,
27
27
  EnsembleDistanceClassifier,
28
28
  )
29
- from .distances import Distance, _ALL_METRICS
29
+ from .distances import _ALL_METRICS
30
30
 
31
- __version__ = "0.2.1"
31
+ __version__ = "0.2.2a1"
32
32
 
33
33
  __all__ = [
34
34
  "DistanceMetricClassifier",
@@ -46,14 +46,15 @@ from sklearn.base import BaseEstimator, ClassifierMixin
46
46
  from sklearn.metrics import accuracy_score
47
47
  from sklearn.model_selection import train_test_split
48
48
  from sklearn.utils.multiclass import unique_labels
49
- from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
49
+ from sklearn.utils.validation import check_is_fitted, check_array
50
50
 
51
- from .distances import Distance, _ALL_METRICS
51
+ from . import distances
52
+ from .distances import _ALL_METRICS
52
53
 
53
54
  # Hardcoded source packages to check for distance metrics.
54
55
  METRIC_SOURCES_ = {
55
56
  "scipy.spatial.distance": scipy.spatial.distance,
56
- "distances.Distance": Distance(),
57
+ "distclassipy.distances": distances,
57
58
  }
58
59
 
59
60
 
@@ -61,7 +62,7 @@ def initialize_metric_function(metric):
61
62
  """Set the metric function based on the provided metric.
62
63
 
63
64
  If the metric is a string, the function will look for a corresponding
64
- function in scipy.spatial.distance or distances.Distance. If the metric
65
+ function in scipy.spatial.distance or distclassipy.distances. If the metric
65
66
  is a function, it will be used directly.
66
67
  """
67
68
  if callable(metric):
@@ -95,7 +96,7 @@ def initialize_metric_function(metric):
95
96
  raise ValueError(
96
97
  f"{metric} metric not found. Please pass a string of the "
97
98
  "name of a metric in scipy.spatial.distance or "
98
- "distances.Distance, or pass a metric function directly. For a "
99
+ "distclassipy.distances, or pass a metric function directly. For a "
99
100
  "list of available metrics, see: "
100
101
  "https://sidchaini.github.io/DistClassiPy/distances.html or "
101
102
  "https://docs.scipy.org/doc/scipy/reference/spatial.distance.html"
@@ -103,7 +104,7 @@ def initialize_metric_function(metric):
103
104
  return metric_fn_, metric_arg_
104
105
 
105
106
 
106
- class DistanceMetricClassifier(BaseEstimator, ClassifierMixin):
107
+ class DistanceMetricClassifier(ClassifierMixin, BaseEstimator):
107
108
  """A distance-based classifier that supports different distance metrics.
108
109
 
109
110
  The distance metric classifier determines the similarity between features in a
@@ -151,11 +152,13 @@ class DistanceMetricClassifier(BaseEstimator, ClassifierMixin):
151
152
 
152
153
  def __init__(
153
154
  self,
155
+ metric: str | Callable = None,
154
156
  scale: bool = True,
155
157
  central_stat: str = "median",
156
158
  dispersion_stat: str = "std",
157
159
  ) -> None:
158
160
  """Initialize the classifier with specified parameters."""
161
+ self.metric = metric
159
162
  self.scale = scale
160
163
  self.central_stat = central_stat
161
164
  self.dispersion_stat = dispersion_stat
@@ -186,11 +189,8 @@ class DistanceMetricClassifier(BaseEstimator, ClassifierMixin):
186
189
  self : object
187
190
  Fitted estimator.
188
191
  """
189
- X, y = check_X_y(X, y)
192
+ X, y = self._validate_data(X, y)
190
193
  self.classes_ = unique_labels(y)
191
- self.n_features_in_ = X.shape[
192
- 1
193
- ] # Number of features seen during fit - required for sklearn compatibility.
194
194
 
195
195
  if feat_labels is None:
196
196
  feat_labels = [f"Feature_{x}" for x in range(X.shape[1])]
@@ -242,7 +242,7 @@ class DistanceMetricClassifier(BaseEstimator, ClassifierMixin):
242
242
  def predict(
243
243
  self,
244
244
  X: np.array,
245
- metric: str | Callable = "euclidean",
245
+ metric: str | Callable = None,
246
246
  ) -> np.ndarray:
247
247
  """Predict the class labels for the provided X.
248
248
 
@@ -269,7 +269,7 @@ class DistanceMetricClassifier(BaseEstimator, ClassifierMixin):
269
269
  See Also
270
270
  --------
271
271
  scipy.spatial.dist : Other distance metrics provided in SciPy
272
- distclassipy.Distance : Distance metrics included with DistClassiPy
272
+ distclassipy.distances : Distance metrics included with DistClassiPy
273
273
 
274
274
  Notes
275
275
  -----
@@ -277,9 +277,15 @@ class DistanceMetricClassifier(BaseEstimator, ClassifierMixin):
277
277
  which allows SciPy to use an optimized C version of the code instead of the
278
278
  slower Python version.
279
279
  """
280
- check_is_fitted(self, "is_fitted_")
281
- X = check_array(X)
282
- metric_fn_, metric_arg_ = initialize_metric_function(metric)
280
+ check_is_fitted(self)
281
+ X = self._validate_data(X, reset=False)
282
+
283
+ metric_to_use = metric if metric is not None else self.metric
284
+ if metric_to_use is None:
285
+ # defaults to euclidean
286
+ metric_to_use = "euclidean"
287
+ metric_fn_, metric_arg_ = initialize_metric_function(metric_to_use)
288
+
283
289
  if not self.scale:
284
290
  dist_arr = scipy.spatial.distance.cdist(
285
291
  XA=X, XB=self.df_centroid_.to_numpy(), metric=metric_arg_
@@ -309,7 +315,7 @@ class DistanceMetricClassifier(BaseEstimator, ClassifierMixin):
309
315
  def predict_and_analyse(
310
316
  self,
311
317
  X: np.array,
312
- metric: str | Callable = "euclidean",
318
+ metric: str | Callable = None,
313
319
  ) -> np.ndarray:
314
320
  """Predict the class labels for the provided X and perform analysis.
315
321
 
@@ -336,7 +342,7 @@ class DistanceMetricClassifier(BaseEstimator, ClassifierMixin):
336
342
  See Also
337
343
  --------
338
344
  scipy.spatial.dist : Other distance metrics provided in SciPy
339
- distclassipy.Distance : Distance metrics included with DistClassiPy
345
+ distclassipy.distances : Distance metrics included with DistClassiPy
340
346
 
341
347
  Notes
342
348
  -----
@@ -345,10 +351,14 @@ class DistanceMetricClassifier(BaseEstimator, ClassifierMixin):
345
351
  of the slower Python version.
346
352
 
347
353
  """
348
- check_is_fitted(self, "is_fitted_")
349
- X = check_array(X)
354
+ check_is_fitted(self)
355
+ X = self._validate_data(X, reset=False)
350
356
 
351
- metric_fn_, metric_arg_ = initialize_metric_function(metric)
357
+ metric_to_use = metric if metric is not None else self.metric
358
+ if metric_to_use is None:
359
+ # defaults to euclidean
360
+ metric_to_use = "euclidean"
361
+ metric_fn_, metric_arg_ = initialize_metric_function(metric_to_use)
352
362
 
353
363
  if not self.scale:
354
364
  dist_arr = scipy.spatial.distance.cdist(
@@ -409,7 +419,7 @@ class DistanceMetricClassifier(BaseEstimator, ClassifierMixin):
409
419
 
410
420
  return self.confidence_df_.to_numpy()
411
421
 
412
- def score(self, X, y, metric: str | Callable = "euclidean") -> float:
422
+ def score(self, X, y, metric: str | Callable = None) -> float:
413
423
  """Return the mean accuracy on the given test data and labels.
414
424
 
415
425
  Parameters
@@ -426,11 +436,12 @@ class DistanceMetricClassifier(BaseEstimator, ClassifierMixin):
426
436
  score : float
427
437
  Mean accuracy of self.predict(X) wrt. y.
428
438
  """
429
- y_pred = self.predict(X, metric=metric)
439
+ metric_to_use = metric if metric is not None else self.metric
440
+ y_pred = self.predict(X, metric=metric_to_use)
430
441
  return accuracy_score(y, y_pred)
431
442
 
432
443
 
433
- class EnsembleDistanceClassifier(BaseEstimator, ClassifierMixin):
444
+ class EnsembleDistanceClassifier(ClassifierMixin, BaseEstimator):
434
445
  """An ensemble classifier that uses different metrics for each quantile.
435
446
 
436
447
  This classifier splits the data into quantiles based on a specified
@@ -532,8 +543,8 @@ class EnsembleDistanceClassifier(BaseEstimator, ClassifierMixin):
532
543
  predictions : np.ndarray
533
544
  The predicted class labels.
534
545
  """
535
- check_is_fitted(self, "is_fitted_")
536
- X = check_array(X)
546
+ check_is_fitted(self)
547
+ X = self._validate_data(X, reset=False)
537
548
 
538
549
  # notes for pred during best:
539
550
  # option 1: