workbench 0.8.162__py3-none-any.whl → 0.8.202__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of workbench might be problematic.

Files changed (113)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
  3. workbench/algorithms/dataframe/proximity.py +261 -235
  4. workbench/algorithms/graph/light/proximity_graph.py +10 -8
  5. workbench/api/__init__.py +2 -1
  6. workbench/api/compound.py +1 -1
  7. workbench/api/endpoint.py +11 -0
  8. workbench/api/feature_set.py +11 -8
  9. workbench/api/meta.py +5 -2
  10. workbench/api/model.py +16 -15
  11. workbench/api/monitor.py +1 -16
  12. workbench/core/artifacts/__init__.py +11 -2
  13. workbench/core/artifacts/artifact.py +11 -3
  14. workbench/core/artifacts/data_capture_core.py +355 -0
  15. workbench/core/artifacts/endpoint_core.py +256 -118
  16. workbench/core/artifacts/feature_set_core.py +265 -16
  17. workbench/core/artifacts/model_core.py +107 -60
  18. workbench/core/artifacts/monitor_core.py +33 -248
  19. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  20. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  21. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  22. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  23. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  24. workbench/core/transforms/features_to_model/features_to_model.py +42 -32
  25. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  26. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  27. workbench/core/views/training_view.py +113 -42
  28. workbench/core/views/view.py +53 -3
  29. workbench/core/views/view_utils.py +4 -4
  30. workbench/model_scripts/chemprop/chemprop.template +852 -0
  31. workbench/model_scripts/chemprop/generated_model_script.py +852 -0
  32. workbench/model_scripts/chemprop/requirements.txt +11 -0
  33. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  34. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  35. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  36. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  37. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  38. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  39. workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
  40. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  41. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  42. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  43. workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
  44. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  45. workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
  46. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  47. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  48. workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
  49. workbench/model_scripts/pytorch_model/pytorch.template +370 -187
  50. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  51. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  52. workbench/model_scripts/script_generation.py +17 -9
  53. workbench/model_scripts/uq_models/generated_model_script.py +605 -0
  54. workbench/model_scripts/uq_models/mapie.template +605 -0
  55. workbench/model_scripts/uq_models/requirements.txt +1 -0
  56. workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
  57. workbench/model_scripts/xgb_model/xgb_model.template +44 -46
  58. workbench/repl/workbench_shell.py +28 -14
  59. workbench/scripts/endpoint_test.py +162 -0
  60. workbench/scripts/lambda_test.py +73 -0
  61. workbench/scripts/ml_pipeline_batch.py +137 -0
  62. workbench/scripts/ml_pipeline_sqs.py +186 -0
  63. workbench/scripts/monitor_cloud_watch.py +20 -100
  64. workbench/utils/aws_utils.py +4 -3
  65. workbench/utils/chem_utils/__init__.py +0 -0
  66. workbench/utils/chem_utils/fingerprints.py +134 -0
  67. workbench/utils/chem_utils/misc.py +194 -0
  68. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  69. workbench/utils/chem_utils/mol_standardize.py +450 -0
  70. workbench/utils/chem_utils/mol_tagging.py +348 -0
  71. workbench/utils/chem_utils/projections.py +209 -0
  72. workbench/utils/chem_utils/salts.py +256 -0
  73. workbench/utils/chem_utils/sdf.py +292 -0
  74. workbench/utils/chem_utils/toxicity.py +250 -0
  75. workbench/utils/chem_utils/vis.py +253 -0
  76. workbench/utils/chemprop_utils.py +760 -0
  77. workbench/utils/cloudwatch_handler.py +1 -1
  78. workbench/utils/cloudwatch_utils.py +137 -0
  79. workbench/utils/config_manager.py +3 -7
  80. workbench/utils/endpoint_utils.py +5 -7
  81. workbench/utils/license_manager.py +2 -6
  82. workbench/utils/model_utils.py +95 -34
  83. workbench/utils/monitor_utils.py +44 -62
  84. workbench/utils/pandas_utils.py +3 -3
  85. workbench/utils/pytorch_utils.py +526 -0
  86. workbench/utils/shap_utils.py +10 -2
  87. workbench/utils/workbench_logging.py +0 -3
  88. workbench/utils/workbench_sqs.py +1 -1
  89. workbench/utils/xgboost_model_utils.py +371 -156
  90. workbench/web_interface/components/model_plot.py +7 -1
  91. workbench/web_interface/components/plugin_unit_test.py +5 -2
  92. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  93. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  94. workbench/web_interface/components/plugins/model_details.py +9 -7
  95. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  96. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
  97. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
  98. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
  99. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
  100. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  101. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  102. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  103. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  104. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  105. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  106. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  107. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  108. workbench/utils/chem_utils.py +0 -1556
  109. workbench/utils/execution_environment.py +0 -211
  110. workbench/utils/fast_inference.py +0 -167
  111. workbench/utils/resource_utils.py +0 -39
  112. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
  113. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
workbench/scripts/monitor_cloud_watch.py

@@ -4,8 +4,10 @@ import sys
 import time
 import argparse
 from datetime import datetime, timedelta, timezone
-from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+
+# Workbench Imports
 from workbench.utils.repl_utils import cprint, Spinner
+from workbench.utils.cloudwatch_utils import get_cloudwatch_client, get_active_log_streams, stream_log_events

 # Define the log levels to include all log levels above the specified level
 log_level_map = {
@@ -33,64 +35,6 @@ def date_display(dt):
     return dt.strftime("%Y-%m-%d %I:%M%p") + "(UTC)"


-def get_cloudwatch_client():
-    """Get the CloudWatch Logs client using the Workbench assumed role session."""
-    session = AWSAccountClamp().boto3_session
-    return session.client("logs")
-
-
-def get_active_log_streams(client, log_group_name, start_time_ms, stream_filter=None):
-    """Retrieve log streams that have events after the specified start time."""
-
-    # Get all the streams in the log group
-    active_streams = []
-    stream_params = {
-        "logGroupName": log_group_name,
-        "orderBy": "LastEventTime",
-        "descending": True,
-    }
-
-    # Loop to retrieve all log streams (maximum 50 per call)
-    while True:
-        response = client.describe_log_streams(**stream_params)
-        log_streams = response.get("logStreams", [])
-
-        for log_stream in log_streams:
-            log_stream_name = log_stream["logStreamName"]
-            last_event_timestamp = log_stream.get("lastEventTimestamp")
-
-            # Include streams with events since the specified start time
-            # Note: There's some issue where the last event timestamp is 'off'
-            #       so we're going to add 60 minutes from the last event timestamp
-            last_event_timestamp += 60 * 60 * 1000
-            if last_event_timestamp >= start_time_ms:
-                active_streams.append(log_stream_name)
-            else:
-                break  # Stop if we reach streams older than the start time
-
-        # Check if there are more streams to retrieve
-        if "nextToken" in response:
-            stream_params["nextToken"] = response["nextToken"]
-        else:
-            break
-
-    # Sort and report the active log streams
-    active_streams.sort()
-    if active_streams:
-        print("Active log streams:", len(active_streams))
-
-    # Filter the active streams by a substring if provided
-    if stream_filter and active_streams:
-        print(f"Filtering active log streams by '{stream_filter}'...")
-        active_streams = [stream for stream in active_streams if stream_filter in stream]
-
-    for stream in active_streams:
-        print(f"\t - {stream}")
-
-    # Return the active log streams
-    return active_streams
-
-
 def get_latest_log_events(client, log_group_name, start_time, end_time=None, stream_filter=None):
     """Retrieve the latest log events from the active/filtered log streams in a CloudWatch Logs group."""

@@ -99,11 +43,15 @@ def get_latest_log_events(client, log_group_name, start_time, end_time=None, str
         get_latest_log_events.first_run = True

     log_events = []
-    start_time_ms = int(start_time.timestamp() * 1000)  # Convert start_time to milliseconds
+    start_time_ms = int(start_time.timestamp() * 1000)
+
+    # Use the util function to get active streams
+    active_streams = get_active_log_streams(log_group_name, start_time_ms, stream_filter, client)

-    # Get the active log streams with events since start_time
-    active_streams = get_active_log_streams(client, log_group_name, start_time_ms, stream_filter)
     if active_streams:
+        print(f"Active log streams: {len(active_streams)}")
+        for stream in active_streams:
+            print(f"\t - {stream}")
         print(f"Processing log events from {date_display(start_time)} on {len(active_streams)} active log streams...")
         get_latest_log_events.first_run = False
     else:
@@ -114,50 +62,22 @@
         print("Monitoring for new events...")
         return log_events

-    # Iterate over the active streams and fetch log events
+    # Use the util function to stream events from each log stream
     for log_stream_name in active_streams:
-        params = {
-            "logGroupName": log_group_name,
-            "logStreamName": log_stream_name,
-            "startTime": start_time_ms,  # Use start_time in milliseconds
-            "startFromHead": True,  # Start from the nearest event to start_time
-        }
-        next_event_token = None
-        if end_time is not None:
-            params["endTime"] = int(end_time.timestamp() * 1000)
-
-        # Process the log events from this log stream
         spinner = Spinner("lightpurple", f"Pulling events from {log_stream_name}:")
         spinner.start()
         log_stream_events = 0

-        # Get the log events for the active log stream
-        while True:
-            if next_event_token:
-                params["nextToken"] = next_event_token
-                params.pop("startTime", None)  # Remove startTime when using nextToken
+        # Stream events using the util function
+        for event in stream_log_events(
+            log_group_name, log_stream_name, start_time, end_time, follow=False, client=client
+        ):
+            log_stream_events += 1
+            log_events.append(event)

-            # Fetch the log events (this call takes a while: optimize if we can)
-            events_response = client.get_log_events(**params)
+        spinner.stop()
+        print(f"Processed {log_stream_events} events from {log_stream_name} (Total: {len(log_events)})")

-            events = events_response.get("events", [])
-            for event in events:
-                event["logStreamName"] = log_stream_name
-
-            # Add the log stream events to our list of all log events
-            log_stream_events += len(events)
-            log_events.extend(events)
-
-            # Handle pagination for log events
-            next_event_token = events_response.get("nextForwardToken")
-
-            # Break the loop if there are no more events to fetch
-            if not next_event_token or next_event_token == params.get("nextToken"):
-                spinner.stop()
-                print(f"Processed {log_stream_events} events from {log_stream_name} (Total: {len(log_events)})")
-                break
-
-    # Return the log events
     return log_events


@@ -206,6 +126,7 @@ def monitor_log_group(
     print(f"Monitoring log group: {log_group_name} from {date_display(start_time)}")
     print(f"Log levels: {log_levels}")
    print(f"Search terms: {search_terms}")
+
    while True:
        # Get the latest log events with stream filtering if provided
        all_log_events = get_latest_log_events(client, log_group_name, start_time, end_time, stream_filter)
@@ -218,7 +139,6 @@

             # Check the search terms
             if not search_terms or any(term in event["message"].lower() for term in search_terms):
-
                 # Calculate the start and end index for this match
                 start_index = max(i - before, 0)
                 end_index = min(i + after, len(all_log_events) - 1)
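
The helpers removed above now live in the new workbench/utils/cloudwatch_utils.py (+137 lines, file 78 in the list). Going only by the call sites visible in these hunks, a minimal usage sketch of the relocated API might look like the following; the helper names, argument order, and the follow/client keywords are taken from the diff, while the log group name is a hypothetical placeholder.

# Hedged sketch: exercises the cloudwatch_utils helpers exactly as the call
# sites in this diff use them; "/my/log/group" is a hypothetical log group.
from datetime import datetime, timedelta, timezone

from workbench.utils.cloudwatch_utils import (
    get_cloudwatch_client,
    get_active_log_streams,
    stream_log_events,
)

client = get_cloudwatch_client()  # CloudWatch Logs client via the Workbench session
start_time = datetime.now(timezone.utc) - timedelta(hours=1)
start_time_ms = int(start_time.timestamp() * 1000)

# Stream names with events since start_time (third arg is an optional substring filter)
streams = get_active_log_streams("/my/log/group", start_time_ms, None, client)

for stream_name in streams:
    # Per the hunk above, stream_log_events yields events; follow=False returns once caught up
    for event in stream_log_events("/my/log/group", stream_name, start_time, None, follow=False, client=client):
        print(event["message"])

Note the argument-order change: the removed get_active_log_streams took the client first, while the new utility takes it last, after the filter.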
workbench/utils/aws_utils.py

@@ -55,7 +55,8 @@ def aws_throttle(func=None, retry_intervals=None):
     if func is None:
         return lambda f: aws_throttle(f, retry_intervals=retry_intervals)

-    service_hold_time = 2  # Seconds to wait before calling AWS function
+    # This is currently commented out (we might want to use it later)
+    # service_hold_time = 2  # Seconds to wait before calling AWS function
     default_intervals = [2**i for i in range(1, 9)]  # Default exponential backoff: 2, 4, 8... 256 seconds
     intervals = retry_intervals or default_intervals

@@ -64,8 +65,8 @@
         for attempt, delay in enumerate(intervals, start=1):
             try:
                 # Add sleep before calling AWS func if running as a service
-                if cm.running_as_service:
-                    time.sleep(service_hold_time)
+                # if cm.running_as_service:
+                #     time.sleep(service_hold_time)
                 return func(*args, **kwargs)
             except ClientError as e:
                 if e.response["Error"]["Code"] == "ThrottlingException":
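
For reference, aws_throttle wraps AWS calls in a retry loop that catches ThrottlingException and sleeps through the interval list above. A hedged sketch of applying the decorator, assuming both call forms implied by the `if func is None` branch (bare and with retry_intervals); the decorated functions and their boto3 calls are illustrative only:

# Hedged sketch: applying the decorator shown in the hunks above.
# The retry_intervals keyword and the default 2..256s backoff come from the
# diff; describe_my_tables / quick_lookup are hypothetical examples.
from workbench.utils.aws_utils import aws_throttle

@aws_throttle  # retries with the default 2, 4, 8 ... 256 second backoff
def describe_my_tables(glue_client, database):
    return glue_client.get_tables(DatabaseName=database)

@aws_throttle(retry_intervals=[1, 2, 4])  # custom, shorter backoff
def quick_lookup(s3_client, bucket):
    return s3_client.list_objects_v2(Bucket=bucket)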
workbench/utils/chem_utils/fingerprints.py

@@ -0,0 +1,134 @@
+"""Molecular fingerprint computation utilities"""
+
+import logging
+import pandas as pd
+
+# Molecular Descriptor Imports
+from rdkit import Chem
+from rdkit.Chem import rdFingerprintGenerator
+from rdkit.Chem.MolStandardize import rdMolStandardize
+
+# Set up the logger
+log = logging.getLogger("workbench")
+
+
+def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=True) -> pd.DataFrame:
+    """Compute and add Morgan fingerprints to the DataFrame.
+
+    Args:
+        df (pd.DataFrame): Input DataFrame containing SMILES strings.
+        radius (int): Radius for the Morgan fingerprint.
+        n_bits (int): Number of bits for the fingerprint.
+        counts (bool): Count simulation for the fingerprint.
+
+    Returns:
+        pd.DataFrame: The input DataFrame with the Morgan fingerprints added as bit strings.
+
+    Note:
+        See: https://greglandrum.github.io/rdkit-blog/posts/2021-07-06-simulating-counts.html
+    """
+    delete_mol_column = False
+
+    # Check for the SMILES column (case-insensitive)
+    smiles_column = next((col for col in df.columns if col.lower() == "smiles"), None)
+    if smiles_column is None:
+        raise ValueError("Input DataFrame must have a 'smiles' column")
+
+    # Sanity check the molecule column (sometimes it gets serialized, which doesn't work)
+    if "molecule" in df.columns and df["molecule"].dtype == "string":
+        log.warning("Detected serialized molecules in 'molecule' column. Removing...")
+        del df["molecule"]
+
+    # Convert SMILES to RDKit molecule objects (vectorized)
+    if "molecule" not in df.columns:
+        log.info("Converting SMILES to RDKit Molecules...")
+        delete_mol_column = True
+        df["molecule"] = df[smiles_column].apply(Chem.MolFromSmiles)
+        # Make sure our molecules are not None
+        failed_smiles = df[df["molecule"].isnull()][smiles_column].tolist()
+        if failed_smiles:
+            log.error(f"Failed to convert the following SMILES to molecules: {failed_smiles}")
+        df = df.dropna(subset=["molecule"])
+
+    # If we have fragments in our compounds, get the largest fragment before computing fingerprints
+    largest_frags = df["molecule"].apply(
+        lambda mol: rdMolStandardize.LargestFragmentChooser().choose(mol) if mol else None
+    )
+
+    # Create a Morgan fingerprint generator
+    if counts:
+        n_bits *= 4  # Multiply by 4 to simulate counts
+    morgan_generator = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=n_bits, countSimulation=counts)
+
+    # Compute Morgan fingerprints (vectorized)
+    fingerprints = largest_frags.apply(
+        lambda mol: (morgan_generator.GetFingerprint(mol).ToBitString() if mol else pd.NA)
+    )
+
+    # Add the fingerprints to the DataFrame
+    df["fingerprint"] = fingerprints
+
+    # Drop the intermediate 'molecule' column if it was added
+    if delete_mol_column:
+        del df["molecule"]
+    return df
+
+
+if __name__ == "__main__":
+    print("Running molecular fingerprint tests...")
+    print("Note: This requires molecular_screening module to be available")
+
+    # Test molecules
+    test_molecules = {
+        "aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
+        "caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
+        "glucose": "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O",  # With stereochemistry
+        "sodium_acetate": "CC(=O)[O-].[Na+]",  # Salt
+        "benzene": "c1ccccc1",
+        "butene_e": "C/C=C/C",  # E-butene
+        "butene_z": "C/C=C\\C",  # Z-butene
+    }
+
+    # Test 1: Morgan Fingerprints
+    print("\n1. Testing Morgan fingerprint generation...")
+
+    test_df = pd.DataFrame({"SMILES": list(test_molecules.values()), "name": list(test_molecules.keys())})
+
+    fp_df = compute_morgan_fingerprints(test_df.copy(), radius=2, n_bits=512, counts=False)
+
+    print("  Fingerprint generation results:")
+    for _, row in fp_df.iterrows():
+        fp = row.get("fingerprint", "N/A")
+        fp_len = len(fp) if fp != "N/A" else 0
+        print(f"    {row['name']:15} → {fp_len} bits")
+
+    # Test 2: Different fingerprint parameters
+    print("\n2. Testing different fingerprint parameters...")
+
+    # Test with counts enabled
+    fp_counts_df = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=256, counts=True)
+
+    print("  With count simulation (256 bits * 4):")
+    for _, row in fp_counts_df.iterrows():
+        fp = row.get("fingerprint", "N/A")
+        fp_len = len(fp) if fp != "N/A" else 0
+        print(f"    {row['name']:15} → {fp_len} bits")
+
+    # Test 3: Edge cases
+    print("\n3. Testing edge cases...")
+
+    # Invalid SMILES
+    invalid_df = pd.DataFrame({"SMILES": ["INVALID", ""]})
+    try:
+        fp_invalid = compute_morgan_fingerprints(invalid_df.copy())
+        print(f"  ✓ Invalid SMILES handled: {len(fp_invalid)} valid molecules")
+    except Exception as e:
+        print(f"  ✓ Invalid SMILES properly raised error: {type(e).__name__}")
+
+    # Test with pre-existing molecule column
+    mol_df = test_df.copy()
+    mol_df["molecule"] = mol_df["SMILES"].apply(Chem.MolFromSmiles)
+    fp_with_mol = compute_morgan_fingerprints(mol_df)
+    print(f"  ✓ Pre-existing molecule column handled: {len(fp_with_mol)} fingerprints generated")
+
+    print("\n✅ All fingerprint tests completed!")
workbench/utils/chem_utils/misc.py

@@ -0,0 +1,194 @@
+"""Miscellaneous processing functions for molecular data."""
+
+import logging
+import numpy as np
+import pandas as pd
+from typing import List, Optional
+
+# Set up the logger
+log = logging.getLogger("workbench")
+
+
+def geometric_mean(series: pd.Series) -> float:
+    """Computes the geometric mean manually to avoid using scipy."""
+    return np.exp(np.log(series).mean())
+
+
+def rollup_experimental_data(
+    df: pd.DataFrame, id: str, time: str, target: str, use_gmean: bool = False
+) -> pd.DataFrame:
+    """
+    Rolls up a dataset by selecting the largest time per unique ID and averaging the target value
+    if multiple records exist at that time. Supports both arithmetic and geometric mean.
+
+    Parameters:
+        df (pd.DataFrame): Input dataframe.
+        id (str): Column representing the unique molecule ID.
+        time (str): Column representing the time.
+        target (str): Column representing the target value.
+        use_gmean (bool): Whether to use the geometric mean instead of the arithmetic mean.
+
+    Returns:
+        pd.DataFrame: Rolled-up dataframe with all original columns retained.
+    """
+    # Find the max time per unique ID
+    max_time_df = df.groupby(id)[time].transform("max")
+    filtered_df = df[df[time] == max_time_df]
+
+    # Define aggregation function
+    agg_func = geometric_mean if use_gmean else np.mean
+
+    # Perform aggregation on all columns
+    agg_dict = {col: "first" for col in df.columns if col not in [target, id, time]}
+    agg_dict[target] = lambda x: agg_func(x) if len(x) > 1 else x.iloc[0]  # Apply mean or gmean
+
+    rolled_up_df = filtered_df.groupby([id, time]).agg(agg_dict).reset_index()
+    return rolled_up_df
+
+
+def micromolar_to_log(series_µM: pd.Series) -> pd.Series:
+    """
+    Convert a pandas Series of concentrations in µM (micromolar) to their logarithmic values (log10).
+
+    Parameters:
+        series_uM (pd.Series): Series of concentrations in micromolar.
+
+    Returns:
+        pd.Series: Series of logarithmic values (log10).
+    """
+    # Replace 0 or negative values with a small number to avoid log errors
+    adjusted_series = series_µM.clip(lower=1e-9)  # Alignment with another project
+
+    series_mol_per_l = adjusted_series * 1e-6  # Convert µM/L to mol/L
+    log_series = np.log10(series_mol_per_l)
+    return log_series
+
+
+def log_to_micromolar(log_series: pd.Series) -> pd.Series:
+    """
+    Convert a pandas Series of logarithmic values (log10) back to concentrations in µM (micromolar).
+
+    Parameters:
+        log_series (pd.Series): Series of logarithmic values (log10).
+
+    Returns:
+        pd.Series: Series of concentrations in micromolar.
+    """
+    series_mol_per_l = 10**log_series  # Convert log10 back to mol/L
+    series_µM = series_mol_per_l * 1e6  # Convert mol/L to µM
+    return series_µM
+
+
+def feature_resolution_issues(df: pd.DataFrame, features: List[str], show_cols: Optional[List[str]] = None) -> None:
+    """
+    Identify and print groups in a DataFrame where the given features have more than one unique SMILES,
+    sorted by group size (largest number of unique SMILES first).
+
+    Args:
+        df (pd.DataFrame): Input DataFrame containing SMILES strings.
+        features (List[str]): List of features to check.
+        show_cols (Optional[List[str]]): Columns to display; defaults to all columns.
+    """
+    # Check for the 'smiles' column (case-insensitive)
+    smiles_column = next((col for col in df.columns if col.lower() == "smiles"), None)
+    if smiles_column is None:
+        raise ValueError("Input DataFrame must have a 'smiles' column")
+
+    show_cols = show_cols if show_cols is not None else df.columns.tolist()
+
+    # Drop duplicates to keep only unique SMILES for each feature combination
+    unique_df = df.drop_duplicates(subset=[smiles_column] + features)
+
+    # Find groups with more than one unique SMILES
+    group_counts = unique_df.groupby(features).size()
+    collision_groups = group_counts[group_counts > 1].sort_values(ascending=False)
+
+    # Print each group in order of size (largest first)
+    for group, count in collision_groups.items():
+        # Get the rows for this group
+        if isinstance(group, tuple):
+            group_mask = (unique_df[features] == group).all(axis=1)
+        else:
+            group_mask = unique_df[features[0]] == group
+
+        group_df = unique_df[group_mask]
+
+        print(f"Feature Group (unique SMILES: {count}):")
+        print(group_df[show_cols])
+        print("\n")
+
+
+if __name__ == "__main__":
+    print("Running molecular processing and transformation tests...")
+    print("Note: This requires the molecular_filters module to be available")
+
+    # Test 1: Concentration conversions
+    print("\n1. Testing concentration conversions...")
+
+    # Test micromolar to log
+    test_conc = pd.Series([1.0, 10.0, 100.0, 1000.0, 0.001])
+    log_values = micromolar_to_log(test_conc)
+    back_to_uM = log_to_micromolar(log_values)
+
+    print("  µM → log10 → µM:")
+    for orig, log_val, back in zip(test_conc, log_values, back_to_uM):
+        print(f"    {orig:8.3f} µM → {log_val:6.2f} → {back:8.3f} µM")
+
+    # Test 2: Geometric mean
+    print("\n2. Testing geometric mean...")
+    test_series = pd.Series([2, 4, 8, 16])
+    geo_mean = geometric_mean(test_series)
+    arith_mean = np.mean(test_series)
+    print(f"  Series: {list(test_series)}")
+    print(f"  Arithmetic mean: {arith_mean:.2f}")
+    print(f"  Geometric mean: {geo_mean:.2f}")
+
+    # Test 3: Experimental data rollup
+    print("\n3. Testing experimental data rollup...")
+
+    # Create test data with multiple timepoints and replicates
+    test_data = pd.DataFrame(
+        {
+            "compound_id": ["A", "A", "A", "B", "B", "C", "C", "C"],
+            "time": [1, 2, 2, 1, 2, 1, 1, 2],
+            "activity": [10, 20, 22, 5, 8, 100, 110, 200],
+            "assay": ["kinase", "kinase", "kinase", "kinase", "kinase", "cell", "cell", "cell"],
+        }
+    )
+
+    # Rollup with arithmetic mean
+    rolled_arith = rollup_experimental_data(test_data, "compound_id", "time", "activity", use_gmean=False)
+    print("  Arithmetic mean rollup:")
+    print(rolled_arith[["compound_id", "time", "activity"]])
+
+    # Rollup with geometric mean
+    rolled_geo = rollup_experimental_data(test_data, "compound_id", "time", "activity", use_gmean=True)
+    print("\n  Geometric mean rollup:")
+    print(rolled_geo[["compound_id", "time", "activity"]])
+
+    # Test 4: Feature resolution issues
+    print("\n4. Testing feature resolution identification...")
+
+    # Create data with some duplicate features but different SMILES
+    resolution_df = pd.DataFrame(
+        {
+            "smiles": ["CCO", "C(C)O", "CC(C)O", "CCC(C)O", "CCCO"],
+            "assay_id": ["A1", "A1", "A2", "A2", "A3"],
+            "value": [1.0, 1.5, 2.0, 2.2, 3.0],
+        }
+    )
+
+    print("  Checking for feature collisions in 'assay_id':")
+    feature_resolution_issues(resolution_df, ["assay_id"], show_cols=["smiles", "assay_id", "value"])
+
+    # Test 7: Edge cases
+    print("\n7. Testing edge cases...")
+
+    # Zero and negative concentrations
+    edge_conc = pd.Series([0, -1, 1e-10])
+    edge_log = micromolar_to_log(edge_conc)
+    print("  Edge concentration handling:")
+    for c, l in zip(edge_conc, edge_log):
+        print(f"    {c:6.2e} µM → {l:6.2f}")
+
+    print("\n✅ All molecular processing tests completed!")