workbench 0.8.212__py3-none-any.whl → 0.8.217__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +257 -80
  3. workbench/algorithms/dataframe/projection_2d.py +38 -21
  4. workbench/algorithms/dataframe/proximity.py +75 -150
  5. workbench/algorithms/graph/light/proximity_graph.py +5 -5
  6. workbench/algorithms/models/cleanlab_model.py +382 -0
  7. workbench/algorithms/models/noise_model.py +2 -2
  8. workbench/api/__init__.py +3 -0
  9. workbench/api/endpoint.py +10 -5
  10. workbench/api/feature_set.py +76 -6
  11. workbench/api/meta_model.py +289 -0
  12. workbench/api/model.py +43 -4
  13. workbench/core/artifacts/endpoint_core.py +75 -129
  14. workbench/core/artifacts/feature_set_core.py +1 -1
  15. workbench/core/artifacts/model_core.py +6 -4
  16. workbench/core/pipelines/pipeline_executor.py +1 -1
  17. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +30 -10
  18. workbench/model_script_utils/pytorch_utils.py +11 -1
  19. workbench/model_scripts/chemprop/chemprop.template +145 -69
  20. workbench/model_scripts/chemprop/generated_model_script.py +147 -71
  21. workbench/model_scripts/custom_models/chem_info/fingerprints.py +7 -3
  22. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  23. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
  24. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  25. workbench/model_scripts/custom_models/uq_models/meta_uq.template +6 -6
  26. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  27. workbench/model_scripts/meta_model/meta_model.template +209 -0
  28. workbench/model_scripts/pytorch_model/generated_model_script.py +42 -24
  29. workbench/model_scripts/pytorch_model/pytorch.template +42 -24
  30. workbench/model_scripts/pytorch_model/pytorch_utils.py +11 -1
  31. workbench/model_scripts/script_generation.py +4 -0
  32. workbench/model_scripts/xgb_model/generated_model_script.py +169 -158
  33. workbench/model_scripts/xgb_model/xgb_model.template +163 -152
  34. workbench/repl/workbench_shell.py +0 -5
  35. workbench/scripts/endpoint_test.py +2 -2
  36. workbench/utils/chem_utils/fingerprints.py +7 -3
  37. workbench/utils/chemprop_utils.py +23 -5
  38. workbench/utils/meta_model_simulator.py +471 -0
  39. workbench/utils/metrics_utils.py +94 -10
  40. workbench/utils/model_utils.py +91 -9
  41. workbench/utils/pytorch_utils.py +1 -1
  42. workbench/web_interface/components/plugins/scatter_plot.py +4 -8
  43. {workbench-0.8.212.dist-info → workbench-0.8.217.dist-info}/METADATA +2 -1
  44. {workbench-0.8.212.dist-info → workbench-0.8.217.dist-info}/RECORD +48 -43
  45. workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
  46. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
  47. {workbench-0.8.212.dist-info → workbench-0.8.217.dist-info}/WHEEL +0 -0
  48. {workbench-0.8.212.dist-info → workbench-0.8.217.dist-info}/entry_points.txt +0 -0
  49. {workbench-0.8.212.dist-info → workbench-0.8.217.dist-info}/licenses/LICENSE +0 -0
  50. {workbench-0.8.212.dist-info → workbench-0.8.217.dist-info}/top_level.txt +0 -0
@@ -1,132 +1,243 @@
1
1
  import pandas as pd
2
2
  import numpy as np
3
3
  from sklearn.neighbors import NearestNeighbors
4
- from typing import Union, List
4
+ from typing import Union, List, Optional
5
5
  import logging
6
6
 
7
7
  # Workbench Imports
8
8
  from workbench.algorithms.dataframe.proximity import Proximity
9
+ from workbench.algorithms.dataframe.projection_2d import Projection2D
10
+ from workbench.utils.chem_utils.fingerprints import compute_morgan_fingerprints
9
11
 
10
12
  # Set up logging
11
13
  log = logging.getLogger("workbench")
12
14
 
13
15
 
14
16
  class FingerprintProximity(Proximity):
17
+ """Proximity computations for binary fingerprints using Tanimoto similarity.
18
+
19
+ Note: Tanimoto similarity is equivalent to Jaccard similarity for binary vectors.
20
+ Tanimoto(A, B) = |A ∩ B| / |A ∪ B|
21
+ """
22
+
15
23
  def __init__(
16
- self, df: pd.DataFrame, id_column: Union[int, str], fingerprint_column: str, n_neighbors: int = 5
24
+ self,
25
+ df: pd.DataFrame,
26
+ id_column: str,
27
+ fingerprint_column: Optional[str] = None,
28
+ target: Optional[str] = None,
29
+ include_all_columns: bool = False,
30
+ radius: int = 2,
31
+ n_bits: int = 1024,
32
+ counts: bool = False,
17
33
  ) -> None:
18
34
  """
19
35
  Initialize the FingerprintProximity class for binary fingerprint similarity.
20
36
 
21
37
  Args:
22
- df (pd.DataFrame): DataFrame containing fingerprints.
23
- id_column (Union[int, str]): Name of the column used as an identifier.
24
- fingerprint_column (str): Name of the column containing fingerprints.
25
- n_neighbors (int): Default number of neighbors to compute.
38
+ df: DataFrame containing fingerprints or SMILES.
39
+ id_column: Name of the column used as an identifier.
40
+ fingerprint_column: Name of the column containing fingerprints (bit strings).
41
+ If None, looks for existing "fingerprint" column or computes from SMILES.
42
+ target: Name of the target column. Defaults to None.
43
+ include_all_columns: Include all DataFrame columns in neighbor results. Defaults to False.
44
+ radius: Radius for Morgan fingerprint computation (default: 2).
45
+ n_bits: Number of bits for fingerprint (default: 1024).
46
+ counts: Whether to use count simulation (default: False).
47
+ """
48
+ # Store fingerprint computation parameters
49
+ self._fp_radius = radius
50
+ self._fp_n_bits = n_bits
51
+ self._fp_counts = counts
52
+
53
+ # Store the requested fingerprint column (may be None)
54
+ self._fingerprint_column_arg = fingerprint_column
55
+
56
+ # Determine fingerprint column name (but don't compute yet - that happens in _prepare_data)
57
+ self.fingerprint_column = self._resolve_fingerprint_column_name(df, fingerprint_column)
58
+
59
+ # Call parent constructor with fingerprint_column as the only "feature"
60
+ super().__init__(
61
+ df,
62
+ id_column=id_column,
63
+ features=[self.fingerprint_column],
64
+ target=target,
65
+ include_all_columns=include_all_columns,
66
+ )
67
+
68
+ @staticmethod
69
+ def _resolve_fingerprint_column_name(df: pd.DataFrame, fingerprint_column: Optional[str]) -> str:
26
70
  """
27
- self.fingerprint_column = fingerprint_column
71
+ Determine the fingerprint column name, validating it exists or can be computed.
28
72
 
29
- # Call the parent class constructor
30
- super().__init__(df, id_column=id_column, features=[fingerprint_column], n_neighbors=n_neighbors)
73
+ Args:
74
+ df: Input DataFrame.
75
+ fingerprint_column: Explicitly specified fingerprint column, or None.
31
76
 
32
- # Override the build_proximity_model method
33
- def build_proximity_model(self) -> None:
77
+ Returns:
78
+ Name of the fingerprint column to use.
79
+
80
+ Raises:
81
+ ValueError: If no fingerprint column exists and no SMILES column found.
34
82
  """
35
- Prepare the fingerprint data for nearest neighbor calculations.
83
+ # If explicitly provided, validate it exists
84
+ if fingerprint_column is not None:
85
+ if fingerprint_column not in df.columns:
86
+ raise ValueError(f"Fingerprint column '{fingerprint_column}' not found in DataFrame")
87
+ return fingerprint_column
88
+
89
+ # Check for existing "fingerprint" column
90
+ if "fingerprint" in df.columns:
91
+ log.info("Using existing 'fingerprint' column")
92
+ return "fingerprint"
93
+
94
+ # Will need to compute from SMILES - validate SMILES column exists
95
+ smiles_column = next((col for col in df.columns if col.lower() == "smiles"), None)
96
+ if smiles_column is None:
97
+ raise ValueError(
98
+ "No fingerprint column provided and no SMILES column found. "
99
+ "Either provide a fingerprint_column or include a 'smiles' column in the DataFrame."
100
+ )
101
+
102
+ # Fingerprints will be computed in _prepare_data
103
+ return "fingerprint"
104
+
105
+ def _prepare_data(self) -> None:
106
+ """Compute fingerprints from SMILES if needed."""
107
+ # If fingerprint column doesn't exist yet, compute it
108
+ if self.fingerprint_column not in self.df.columns:
109
+ log.info(f"Computing Morgan fingerprints (radius={self._fp_radius}, n_bits={self._fp_n_bits})...")
110
+ self.df = compute_morgan_fingerprints(
111
+ self.df, radius=self._fp_radius, n_bits=self._fp_n_bits, counts=self._fp_counts
112
+ )
113
+
114
+ def _build_model(self) -> None:
115
+ """
116
+ Build the fingerprint proximity model for Tanimoto similarity.
36
117
  Converts fingerprint strings to binary arrays and initializes NearestNeighbors.
118
+
119
+ Note: sklearn uses Jaccard distance internally (1 - Tanimoto similarity).
120
+ We convert back to Tanimoto similarity in the output methods.
37
121
  """
38
122
  log.info("Converting fingerprints to binary feature matrix...")
39
- # self.proximity_type = ProximityType.SIMILARITY
40
123
 
41
- # Convert fingerprint strings to binary arrays
124
+ # Convert fingerprint strings to binary arrays and store for later use
125
+ self.X = self._fingerprints_to_matrix(self.df)
42
126
 
43
- fingerprint_bits = self.df[self.fingerprint_column].apply(
44
- lambda fp: np.array([int(bit) for bit in fp], dtype=np.bool_)
45
- )
46
- self.X = np.vstack(fingerprint_bits)
127
+ # sklearn uses Jaccard distance = 1 - Tanimoto similarity
128
+ # We convert to Tanimoto similarity in neighbors() and _precompute_metrics()
129
+ log.info("Building NearestNeighbors model (Jaccard/Tanimoto metric, BallTree)...")
130
+ self.nn = NearestNeighbors(metric="jaccard", algorithm="ball_tree").fit(self.X)
47
131
 
48
- # Use Jaccard similarity for binary fingerprints
49
- log.info("Computing NearestNeighbors with Jaccard metric...")
50
- self.nn = NearestNeighbors(metric="jaccard", n_neighbors=self.n_neighbors + 1).fit(self.X)
132
+ def _transform_features(self, df: pd.DataFrame) -> np.ndarray:
133
+ """
134
+ Transform fingerprints to binary matrix for querying.
135
+
136
+ Args:
137
+ df: DataFrame containing fingerprints to transform.
138
+
139
+ Returns:
140
+ Binary feature matrix for the fingerprints.
141
+ """
142
+ return self._fingerprints_to_matrix(df)
51
143
 
52
- # Override the prep_features_for_query method
53
- def prep_features_for_query(self, query_df: pd.DataFrame) -> np.ndarray:
144
+ def _fingerprints_to_matrix(self, df: pd.DataFrame) -> np.ndarray:
54
145
  """
55
- Prepare the query DataFrame by converting fingerprints to binary arrays.
146
+ Convert fingerprint strings to a binary numpy matrix.
56
147
 
57
148
  Args:
58
- query_df (pd.DataFrame): DataFrame containing query fingerprints.
149
+ df: DataFrame containing fingerprint column.
59
150
 
60
151
  Returns:
61
- np.ndarray: Binary feature matrix for the query fingerprints.
152
+ 2D numpy array of binary fingerprint bits.
62
153
  """
63
- fingerprint_bits = query_df[self.fingerprint_column].apply(
154
+ fingerprint_bits = df[self.fingerprint_column].apply(
64
155
  lambda fp: np.array([int(bit) for bit in fp], dtype=np.bool_)
65
156
  )
66
157
  return np.vstack(fingerprint_bits)
67
158
 
68
- def all_neighbors(
69
- self,
70
- min_similarity: float = None,
71
- include_self: bool = False,
72
- add_columns: List[str] = None,
73
- ) -> pd.DataFrame:
159
+ def _precompute_metrics(self) -> None:
160
+ """Precompute metrics, adding Tanimoto similarity alongside distance."""
161
+ # Call parent to compute nn_distance (Jaccard), nn_id, nn_target, nn_target_diff
162
+ super()._precompute_metrics()
163
+
164
+ # Add Tanimoto similarity (keep nn_distance for internal use by target_gradients)
165
+ self.df["nn_similarity"] = 1 - self.df["nn_distance"]
166
+
167
+ def _set_core_columns(self) -> None:
168
+ """Set core columns using nn_similarity instead of nn_distance."""
169
+ self.core_columns = [self.id_column, "nn_similarity", "nn_id"]
170
+ if self.target:
171
+ self.core_columns.extend([self.target, "nn_target", "nn_target_diff"])
172
+
173
+ def _project_2d(self) -> None:
174
+ """Project the fingerprint matrix to 2D for visualization using UMAP with Jaccard metric."""
175
+ self.df = Projection2D().fit_transform(self.df, feature_matrix=self.X, metric="jaccard")
176
+
177
+ def isolated(self, top_percent: float = 1.0) -> pd.DataFrame:
74
178
  """
75
- Find neighbors for all fingerprints in the dataset.
179
+ Find isolated data points based on Tanimoto similarity to nearest neighbor.
76
180
 
77
181
  Args:
78
- min_similarity: Minimum similarity threshold (0-1)
79
- include_self: Whether to include self in results
80
- add_columns: Additional columns to include in results
182
+ top_percent: Percentage of most isolated data points to return (e.g., 1.0 returns top 1%)
81
183
 
82
184
  Returns:
83
- DataFrame containing neighbors and similarities
185
+ DataFrame of observations with lowest Tanimoto similarity, sorted ascending
84
186
  """
187
+ # For Tanimoto similarity, isolated means LOW similarity to nearest neighbor
188
+ percentile = top_percent
189
+ threshold = np.percentile(self.df["nn_similarity"], percentile)
190
+ isolated = self.df[self.df["nn_similarity"] <= threshold].copy()
191
+ isolated = isolated.sort_values("nn_similarity", ascending=True).reset_index(drop=True)
192
+ return isolated if self.include_all_columns else isolated[self.core_columns]
85
193
 
86
- # Call the parent class method to find neighbors
87
- return self.neighbors(
88
- query_df=self.df,
89
- min_similarity=min_similarity,
90
- include_self=include_self,
91
- add_columns=add_columns,
194
+ def proximity_stats(self) -> pd.DataFrame:
195
+ """
196
+ Return distribution statistics for nearest neighbor Tanimoto similarity.
197
+
198
+ Returns:
199
+ DataFrame with similarity distribution statistics (count, mean, std, percentiles)
200
+ """
201
+ return (
202
+ self.df["nn_similarity"]
203
+ .describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
204
+ .to_frame()
92
205
  )
93
206
 
94
207
  def neighbors(
95
208
  self,
96
- query_df: pd.DataFrame,
97
- min_similarity: float = None,
98
- include_self: bool = False,
99
- add_columns: List[str] = None,
209
+ id_or_ids: Union[str, int, List[Union[str, int]]],
210
+ n_neighbors: Optional[int] = 5,
211
+ min_similarity: Optional[float] = None,
212
+ include_self: bool = True,
100
213
  ) -> pd.DataFrame:
101
214
  """
102
- Find neighbors for each row in the query DataFrame.
215
+ Return neighbors for ID(s) from the existing dataset.
103
216
 
104
217
  Args:
105
- query_df: DataFrame containing query fingerprints
106
- min_similarity: Minimum similarity threshold (0-1)
107
- include_self: Whether to include self in results (if present)
108
- add_columns: Additional columns to include in results
218
+ id_or_ids: Single ID or list of IDs to look up
219
+ n_neighbors: Number of neighbors to return (default: 5, ignored if min_similarity is set)
220
+ min_similarity: If provided, find all neighbors with Tanimoto similarity >= this value (0-1)
221
+ include_self: Whether to include self in results (default: True)
109
222
 
110
223
  Returns:
111
- DataFrame containing neighbors and similarities
112
-
113
- Note: The query DataFrame must include the feature columns. The id_column is optional.
224
+ DataFrame containing neighbors with Tanimoto similarity scores
114
225
  """
115
-
116
- # Calculate radius from similarity if provided
226
+ # Convert min_similarity to radius (Jaccard distance = 1 - Tanimoto similarity)
117
227
  radius = 1 - min_similarity if min_similarity is not None else None
118
228
 
119
- # Call the parent class method to find neighbors
229
+ # Call parent method (returns Jaccard distance)
120
230
  neighbors_df = super().neighbors(
121
- query_df=query_df,
231
+ id_or_ids=id_or_ids,
232
+ n_neighbors=n_neighbors,
122
233
  radius=radius,
123
234
  include_self=include_self,
124
- add_columns=add_columns,
125
235
  )
126
236
 
127
- # Convert distances to similarity
237
+ # Convert Jaccard distance to Tanimoto similarity
128
238
  neighbors_df["similarity"] = 1 - neighbors_df["distance"]
129
239
  neighbors_df.drop(columns=["distance"], inplace=True)
240
+
130
241
  return neighbors_df
131
242
 
132
243
 
@@ -135,28 +246,94 @@ if __name__ == "__main__":
135
246
  pd.set_option("display.max_columns", None)
136
247
  pd.set_option("display.width", 1000)
137
248
 
138
- # Example DataFrame
249
+ # Create an Example DataFrame with fingerprints
139
250
  data = {
140
- "id": ["a", "b", "c", "d"],
141
- "fingerprint": ["101010", "111010", "101110", "011100"],
251
+ "id": ["a", "b", "c", "d", "e"],
252
+ "fingerprint": ["101010", "111010", "101110", "011100", "000111"],
253
+ "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
254
+ "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
255
+ "target": [1, 0, 1, 0, 5],
142
256
  }
143
257
  df = pd.DataFrame(data)
144
258
 
145
- # Initialize the FingerprintProximity class
146
- proximity = FingerprintProximity(df, fingerprint_column="fingerprint", id_column="id", n_neighbors=3)
259
+ # Test basic FingerprintProximity with explicit fingerprint column
260
+ prox = FingerprintProximity(df, fingerprint_column="fingerprint", id_column="id", target="target")
261
+ print(prox.neighbors("a", n_neighbors=3))
262
+
263
+ # Test neighbors with similarity threshold
264
+ print(prox.neighbors("a", min_similarity=0.5))
265
+
266
+ # Test with include_all_columns=True
267
+ prox = FingerprintProximity(
268
+ df,
269
+ fingerprint_column="fingerprint",
270
+ id_column="id",
271
+ target="target",
272
+ include_all_columns=True,
273
+ )
274
+ print(prox.neighbors(["a", "b"]))
275
+
276
+ # Test on real data from Workbench
277
+ from workbench.api import FeatureSet, Model
278
+
279
+ fs = FeatureSet("aqsol_features")
280
+ model = Model("aqsol-regression")
281
+ df = fs.pull_dataframe()
282
+ prox = FingerprintProximity(df, id_column=fs.id_column, target=model.target())
283
+
284
+ print("\n" + "=" * 80)
285
+ print("Testing Neighbors...")
286
+ print("=" * 80)
287
+ test_id = df[fs.id_column].tolist()[0]
288
+ print(f"\nNeighbors for ID {test_id}:")
289
+ print(prox.neighbors(test_id))
290
+
291
+ print("\n" + "=" * 80)
292
+ print("Testing isolated compounds...")
293
+ print("=" * 80)
294
+
295
+ # Test isolated data in the top 1%
296
+ isolated_1pct = prox.isolated(top_percent=1.0)
297
+ print(f"\nTop 1% most isolated compounds (n={len(isolated_1pct)}):")
298
+ print(isolated_1pct)
299
+
300
+ # Test isolated data in the top 5%
301
+ isolated_5pct = prox.isolated(top_percent=5.0)
302
+ print(f"\nTop 5% most isolated compounds (n={len(isolated_5pct)}):")
303
+ print(isolated_5pct)
304
+
305
+ print("\n" + "=" * 80)
306
+ print("Testing target_gradients...")
307
+ print("=" * 80)
308
+
309
+ # Test with different parameters
310
+ gradients_1pct = prox.target_gradients(top_percent=1.0, min_delta=1.0)
311
+ print(f"\nTop 1% target gradients (min_delta=1.0) (n={len(gradients_1pct)}):")
312
+ print(gradients_1pct)
313
+
314
+ gradients_5pct = prox.target_gradients(top_percent=5.0, min_delta=5.0)
315
+ print(f"\nTop 5% target gradients (min_delta=5.0) (n={len(gradients_5pct)}):")
316
+ print(gradients_5pct)
317
+
318
+ # Test proximity_stats
319
+ print("\n" + "=" * 80)
320
+ print("Testing proximity_stats...")
321
+ print("=" * 80)
322
+ stats = prox.proximity_stats()
323
+ print(stats)
147
324
 
148
- # Test 1: All neighbors
149
- print("\n--- Test 1: All Neighbors ---")
150
- all_neighbors_df = proximity.all_neighbors()
151
- print(all_neighbors_df)
325
+ # Plot the similarity distribution using pandas
326
+ print("\n" + "=" * 80)
327
+ print("Plotting similarity distribution...")
328
+ print("=" * 80)
329
+ prox.df["nn_similarity"].hist(bins=50, figsize=(10, 6), edgecolor="black")
152
330
 
153
- # Test 2: Neighbors for a specific query
154
- print("\n--- Test 2: Neighbors for Query ---")
155
- query_df = pd.DataFrame({"id": ["a"], "fingerprint": ["101010"]})
156
- query_neighbors_df = proximity.neighbors(query_df=query_df)
157
- print(query_neighbors_df)
331
+ # Visualize the 2D projection
332
+ print("\n" + "=" * 80)
333
+ print("Visualizing 2D Projection...")
334
+ print("=" * 80)
335
+ from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
336
+ from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot
158
337
 
159
- # Test 3: Neighbors with similarity threshold
160
- print("\n--- Test 3: Neighbors with Minimum Similarity 0.5 ---")
161
- query_neighbors_sim_df = proximity.neighbors(query_df=query_df, min_similarity=0.5)
162
- print(query_neighbors_sim_df)
338
+ unit_test = PluginUnitTest(ScatterPlot, input_data=prox.df[:1000], x="x", y="y", color=model.target())
339
+ unit_test.run()
@@ -22,7 +22,14 @@ class Projection2D:
22
22
  self.log = logging.getLogger("workbench")
23
23
  self.projection_model = None
24
24
 
25
- def fit_transform(self, input_df: pd.DataFrame, features: list = None, projection: str = "UMAP") -> pd.DataFrame:
25
+ def fit_transform(
26
+ self,
27
+ input_df: pd.DataFrame,
28
+ features: list = None,
29
+ feature_matrix: np.ndarray = None,
30
+ metric: str = "euclidean",
31
+ projection: str = "UMAP",
32
+ ) -> pd.DataFrame:
26
33
  """Fit and transform a DataFrame using the selected dimensionality reduction method.
27
34
 
28
35
  This method creates a copy of the input DataFrame, processes the specified features
@@ -32,6 +39,9 @@ class Projection2D:
32
39
  Args:
33
40
  input_df (pd.DataFrame): The DataFrame containing features to project.
34
41
  features (list, optional): List of feature column names. If None, numeric columns are auto-selected.
42
+ feature_matrix (np.ndarray, optional): Pre-computed feature matrix. If provided, features is ignored
43
+ and no scaling is applied (caller is responsible for appropriate preprocessing).
44
+ metric (str, optional): Distance metric for UMAP (e.g., 'euclidean', 'jaccard'). Default 'euclidean'.
35
45
  projection (str, optional): The projection to use ('UMAP', 'TSNE', 'MDS' or 'PCA'). Default 'UMAP'.
36
46
 
37
47
  Returns:
@@ -40,36 +50,44 @@ class Projection2D:
40
50
  # Create a copy of the input DataFrame
41
51
  df = input_df.copy()
42
52
 
43
- # Auto-identify numeric features if none are provided
44
- if features is None:
45
- features = [col for col in df.select_dtypes(include="number").columns if not col.endswith("id")]
46
- self.log.info(f"Auto-identified numeric features: {features}")
47
-
48
- if len(features) < 2 or df.empty:
49
- self.log.critical("At least two numeric features are required, and DataFrame must not be empty.")
50
- return df
51
-
52
- # Process a copy of the feature data for projection
53
- X = df[features]
54
- X = X.apply(lambda col: col.fillna(col.mean()))
55
- X_scaled = StandardScaler().fit_transform(X)
53
+ # If a feature matrix is provided, use it directly (no scaling)
54
+ if feature_matrix is not None:
55
+ if len(feature_matrix) != len(df):
56
+ self.log.critical("feature_matrix length must match DataFrame length.")
57
+ return df
58
+ X_processed = feature_matrix
59
+ else:
60
+ # Auto-identify numeric features if none are provided
61
+ if features is None:
62
+ features = [col for col in df.select_dtypes(include="number").columns if not col.endswith("id")]
63
+ self.log.info(f"Auto-identified numeric features: {features}")
64
+
65
+ if len(features) < 2 or df.empty:
66
+ self.log.critical("At least two numeric features are required, and DataFrame must not be empty.")
67
+ return df
68
+
69
+ # Process a copy of the feature data for projection
70
+ X = df[features]
71
+ X = X.apply(lambda col: col.fillna(col.mean()))
72
+ X_processed = StandardScaler().fit_transform(X)
56
73
 
57
74
  # Select the projection method (using df for perplexity calculation)
58
- self.projection_model = self._get_projection_model(projection, df)
75
+ self.projection_model = self._get_projection_model(projection, df, metric=metric)
59
76
 
60
- # Apply the projection on the normalized data
61
- projection_result = self.projection_model.fit_transform(X_scaled)
77
+ # Apply the projection on the processed data
78
+ projection_result = self.projection_model.fit_transform(X_processed)
62
79
  df[["x", "y"]] = projection_result
63
80
 
64
81
  # Resolve coincident points and return the new DataFrame
65
82
  return self.resolve_coincident_points(df)
66
83
 
67
- def _get_projection_model(self, projection: str, df: pd.DataFrame):
84
+ def _get_projection_model(self, projection: str, df: pd.DataFrame, metric: str = "euclidean"):
68
85
  """Select and return the appropriate projection model.
69
86
 
70
87
  Args:
71
88
  projection (str): The projection method ('TSNE', 'MDS', 'PCA', or 'UMAP').
72
89
  df (pd.DataFrame): The DataFrame being transformed (used for computing perplexity).
90
+ metric (str): Distance metric for UMAP (default 'euclidean').
73
91
 
74
92
  Returns:
75
93
  A dimensionality reduction model instance.
@@ -88,8 +106,8 @@ class Projection2D:
88
106
  return PCA(n_components=2)
89
107
 
90
108
  if projection == "UMAP" and UMAP_AVAILABLE:
91
- self.log.info("Projection: UMAP")
92
- return umap.UMAP(n_components=2)
109
+ self.log.info(f"Projection: UMAP with metric={metric}")
110
+ return umap.UMAP(n_components=2, metric=metric)
93
111
 
94
112
  self.log.warning(
95
113
  f"Projection method '{projection}' not recognized or UMAP not available. Falling back to TSNE."
@@ -118,7 +136,6 @@ class Projection2D:
118
136
 
119
137
  # Find duplicates
120
138
  duplicated = rounded.duplicated(subset=["x_round", "y_round"], keep=False)
121
- print("Coincident Points found:", duplicated.sum())
122
139
  if not duplicated.any():
123
140
  return df
124
141