masster 0.4.11-py3-none-any.whl → 0.4.13-py3-none-any.whl

This diff compares publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of masster might be problematic.

masster/study/id.py CHANGED
@@ -16,13 +16,7 @@ def lib_load(
16
16
  polarity: str | None = None,
17
17
  adducts: list | None = None,
18
18
  ):
19
- """Load a co # Add compound and formula count columns
20
- if "consensus_uid" in result_df.columns:
21
- # Calculate counts per consensus_uid
22
- count_stats = result_df.group_by("consensus_uid").agg([
23
- pl.col("cmpd_uid").filter(pl.col("cmpd_uid").is_not_null()).n_unique().alias("num_cmpds") if "cmpd_uid" in result_df.columns else pl.lit(None).alias("num_cmpds"),
24
- pl.col("formula").filter(pl.col("formula").is_not_null()).n_unique().alias("num_formulas") if "formula" in result_df.columns else pl.lit(None).alias("num_formulas")
25
- ])library into the study.
19
+ """Load a compound library into the study.
26
20
 
27
21
  Args:
28
22
  study: Study instance
@@ -117,14 +111,17 @@ def lib_load(
117
111
  study.lib_df = (
118
112
  filtered_lf.clone()
119
113
  if hasattr(filtered_lf, "clone")
120
- else pl.DataFrame(filtered_lf)
114
+ else pl.DataFrame(filtered_lf.to_dict() if hasattr(filtered_lf, 'to_dict') else filtered_lf)
121
115
  )
122
116
  except Exception:
123
- study.lib_df = (
124
- pl.from_pandas(filtered_lf)
125
- if hasattr(filtered_lf, "to_pandas")
126
- else pl.DataFrame(filtered_lf)
127
- )
117
+ try:
118
+ study.lib_df = (
119
+ pl.from_pandas(filtered_lf)
120
+ if hasattr(filtered_lf, "to_pandas")
121
+ else pl.DataFrame(filtered_lf.to_dict() if hasattr(filtered_lf, 'to_dict') else filtered_lf)
122
+ )
123
+ except Exception:
124
+ study.lib_df = pl.DataFrame()
128
125
 
129
126
  # Store this operation in history
130
127
  if hasattr(study, "store_history"):
@@ -134,29 +131,8 @@ def lib_load(
134
131
  )
135
132
 
136
133
 
137
- def identify(study, features=None, params=None, **kwargs):
138
- """Identify consensus features against the loaded library.
139
-
140
- Matches consensus_df.mz against lib_df.mz within mz_tolerance. If rt_tolerance
141
- is provided and both consensus and library entries have rt values, RT is
142
- used as an additional filter.
143
-
144
- Args:
145
- study: Study instance
146
- features: Optional DataFrame or list of consensus_uids to identify.
147
- If None, identifies all consensus features.
148
- params: Optional identify_defaults instance with matching tolerances and scoring parameters.
149
- If None, uses default parameters.
150
- **kwargs: Individual parameter overrides (mz_tol, rt_tol, heteroatom_penalty,
151
- multiple_formulas_penalty, multiple_compounds_penalty, heteroatoms)
152
-
153
- The resulting DataFrame is stored as study.id_df. Columns:
154
- - consensus_uid
155
- - lib_uid
156
- - mz_delta
157
- - rt_delta (nullable)
158
- - score (adduct probability from _get_adducts with penalties applied)
159
- """
134
+ def _setup_identify_parameters(params, kwargs):
135
+ """Setup identification parameters with fallbacks and overrides."""
160
136
  # Import defaults class
161
137
  try:
162
138
  from masster.study.defaults.identify_def import identify_defaults
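Note: _setup_identify_parameters applies keyword overrides with a hasattr/setattr loop, so unknown keys are silently ignored rather than raising. A small self-contained sketch of that pattern; IdentifyDefaults is a stand-in for the real identify_defaults class, with the 0.01/2.0 defaults taken from the getattr fallbacks used later in identify():

class IdentifyDefaults:
    def __init__(self):
        self.mz_tol = 0.01   # Da
        self.rt_tol = 2.0    # min

def apply_overrides(params, **kwargs):
    # Only attributes that already exist on the defaults object are overridden.
    for name, value in kwargs.items():
        if hasattr(params, name):
            setattr(params, name, value)
    return params

params = apply_overrides(IdentifyDefaults(), mz_tol=0.005, not_a_param=1)
assert params.mz_tol == 0.005 and params.rt_tol == 2.0
assert not hasattr(params, "not_a_param")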
@@ -184,19 +160,251 @@ def identify(study, features=None, params=None, **kwargs):
184
160
  for param_name, value in kwargs.items():
185
161
  if hasattr(params, param_name):
186
162
  setattr(params, param_name, value)
163
+
164
+ return params
187
165
 
188
- # Get effective tolerances from params (now possibly overridden)
189
- effective_mz_tol = getattr(params, "mz_tol", 0.01)
190
- effective_rt_tol = getattr(params, "rt_tol", 2.0)
191
- # Get logger from study if available
192
- logger = getattr(study, "logger", None)
193
166
 
167
+ def _smart_reset_id_results(study, target_uids, logger):
168
+ """Smart reset of identification results - only clear what's being re-identified."""
169
+ if target_uids is not None:
170
+ # Selective reset: only clear results for features being re-identified
171
+ if hasattr(study, "id_df") and study.id_df is not None and not study.id_df.is_empty():
172
+ study.id_df = study.id_df.filter(
173
+ ~pl.col("consensus_uid").is_in(target_uids)
174
+ )
175
+ if logger:
176
+ logger.debug(f"Cleared previous results for {len(target_uids)} specific features")
177
+ elif not hasattr(study, "id_df"):
178
+ study.id_df = pl.DataFrame()
179
+ else:
180
+ # Full reset: clear all results
181
+ study.id_df = pl.DataFrame()
182
+ if logger:
183
+ logger.debug("Cleared all previous identification results")
184
+
185
+
186
+ def _get_cached_adduct_probabilities(study, logger):
187
+ """Get adduct probabilities with caching to avoid repeated expensive computation."""
188
+ # Check if we have cached results and cache key matches current parameters
189
+ current_cache_key = _get_adduct_cache_key(study)
190
+
191
+ if (hasattr(study, '_cached_adduct_probs') and
192
+ hasattr(study, '_cached_adduct_key') and
193
+ study._cached_adduct_key == current_cache_key):
194
+ if logger:
195
+ logger.debug("Using cached adduct probabilities")
196
+ return study._cached_adduct_probs
197
+
198
+ # Compute and cache
199
+ if logger:
200
+ logger.debug("Computing adduct probabilities...")
201
+ adduct_prob_map = _get_adduct_probabilities(study)
202
+ study._cached_adduct_probs = adduct_prob_map
203
+ study._cached_adduct_key = current_cache_key
204
+
194
205
  if logger:
206
+ logger.debug(f"Computed and cached probabilities for {len(adduct_prob_map)} adducts")
207
+ return adduct_prob_map
208
+
209
+
210
+ def _get_adduct_cache_key(study):
211
+ """Generate a cache key based on adduct-related parameters."""
212
+ if hasattr(study, 'parameters') and hasattr(study.parameters, 'adducts'):
213
+ adducts_str = '|'.join(sorted(study.parameters.adducts)) if study.parameters.adducts else ""
214
+ min_prob = getattr(study.parameters, 'adduct_min_probability', 0.04)
215
+ return f"adducts:{adducts_str}:min_prob:{min_prob}"
216
+ return "default"
217
+
218
+
219
+ def clear_identification_cache(study):
220
+ """Clear cached identification data (useful when parameters change)."""
221
+ cache_attrs = ['_cached_adduct_probs', '_cached_adduct_key']
222
+ for attr in cache_attrs:
223
+ if hasattr(study, attr):
224
+ delattr(study, attr)
225
+
226
+
227
+ def _perform_identification_matching(consensus_to_process, study, effective_mz_tol, effective_rt_tol, adduct_prob_map, logger):
228
+ """Perform optimized identification matching using vectorized operations where possible."""
229
+ results = []
230
+
231
+ # Get library data as arrays for faster access
232
+ lib_df = study.lib_df
233
+
234
+ if logger:
235
+ consensus_count = len(consensus_to_process)
236
+ lib_count = len(lib_df)
195
237
  logger.debug(
196
- f"Starting identification with mz_tolerance={effective_mz_tol}, rt_tolerance={effective_rt_tol}",
238
+ f"Identifying {consensus_count} consensus features against {lib_count} library entries",
239
+ )
240
+
241
+ # Process each consensus feature
242
+ for cons_row in consensus_to_process.iter_rows(named=True):
243
+ cons_uid = cons_row.get("consensus_uid")
244
+ cons_mz = cons_row.get("mz")
245
+ cons_rt = cons_row.get("rt")
246
+
247
+ if cons_mz is None:
248
+ if logger:
249
+ logger.debug(f"Skipping consensus feature {cons_uid} - no m/z value")
250
+ results.append({"consensus_uid": cons_uid, "matches": []})
251
+ continue
252
+
253
+ # Find matches using vectorized filtering
254
+ matches = _find_matches_vectorized(
255
+ lib_df, cons_mz, cons_rt, effective_mz_tol, effective_rt_tol, logger, cons_uid
197
256
  )
257
+
258
+ # Convert matches to result format
259
+ match_results = []
260
+ if not matches.is_empty():
261
+ for match_row in matches.iter_rows(named=True):
262
+ mz_delta = abs(cons_mz - match_row.get("mz")) if match_row.get("mz") is not None else None
263
+ lib_rt = match_row.get("rt")
264
+ rt_delta = (
265
+ abs(cons_rt - lib_rt)
266
+ if (cons_rt is not None and lib_rt is not None)
267
+ else None
268
+ )
269
+
270
+ # Get adduct probability from cached map
271
+ adduct = match_row.get("adduct")
272
+ score = adduct_prob_map.get(adduct, 1.0) if adduct else 1.0
273
+
274
+ match_results.append({
275
+ "lib_uid": match_row.get("lib_uid"),
276
+ "mz_delta": mz_delta,
277
+ "rt_delta": rt_delta,
278
+ "matcher": "ms1",
279
+ "score": score,
280
+ })
281
+
282
+ results.append({"consensus_uid": cons_uid, "matches": match_results})
283
+
284
+ return results
285
+
286
+
287
+ def _find_matches_vectorized(lib_df, cons_mz, cons_rt, mz_tol, rt_tol, logger, cons_uid):
288
+ """Find library matches using optimized vectorized operations."""
289
+ # Filter by m/z tolerance using vectorized operations
290
+ matches = lib_df.filter(
291
+ (pl.col("mz") >= cons_mz - mz_tol) & (pl.col("mz") <= cons_mz + mz_tol)
292
+ )
293
+
294
+ initial_match_count = len(matches)
295
+
296
+ # Apply RT filter if available
297
+ if rt_tol is not None and cons_rt is not None and not matches.is_empty():
298
+ rt_matches = matches.filter(
299
+ pl.col("rt").is_not_null() &
300
+ (pl.col("rt") >= cons_rt - rt_tol) &
301
+ (pl.col("rt") <= cons_rt + rt_tol)
302
+ )
303
+
304
+ if not rt_matches.is_empty():
305
+ matches = rt_matches
306
+ if logger:
307
+ logger.debug(
308
+ f"Consensus {cons_uid}: {initial_match_count} m/z matches, {len(matches)} after RT filter"
309
+ )
310
+ else:
311
+ if logger:
312
+ logger.debug(
313
+ f"Consensus {cons_uid}: {initial_match_count} m/z matches, 0 after RT filter - using m/z matches only"
314
+ )
315
+
316
+ # Optimized deduplication using Polars operations
317
+ if not matches.is_empty() and len(matches) > 1:
318
+ if "formula" in matches.columns and "adduct" in matches.columns:
319
+ pre_dedup_count = len(matches)
320
+
321
+ # Use Polars group_by with maintain_order for consistent results
322
+ matches = (
323
+ matches
324
+ .sort("lib_uid") # Ensure consistent ordering
325
+ .group_by(["formula", "adduct"], maintain_order=True)
326
+ .first()
327
+ )
328
+
329
+ post_dedup_count = len(matches)
330
+ if logger and post_dedup_count < pre_dedup_count:
331
+ logger.debug(
332
+ f"Consensus {cons_uid}: deduplicated {pre_dedup_count} to {post_dedup_count} matches"
333
+ )
334
+
335
+ return matches
336
+
337
+
338
+ def _update_identification_results(study, results, logger):
339
+ """Update study.id_df with new identification results."""
340
+ # Flatten results into records
341
+ records = []
342
+ for result in results:
343
+ consensus_uid = result["consensus_uid"]
344
+ for match in result["matches"]:
345
+ records.append({
346
+ "consensus_uid": consensus_uid,
347
+ "lib_uid": match["lib_uid"],
348
+ "mz_delta": match["mz_delta"],
349
+ "rt_delta": match["rt_delta"],
350
+ "matcher": match["matcher"],
351
+ "score": match["score"],
352
+ })
353
+
354
+ # Convert to DataFrame and append to existing results
355
+ new_results_df = pl.DataFrame(records) if records else pl.DataFrame()
356
+
357
+ if not new_results_df.is_empty():
358
+ if hasattr(study, "id_df") and study.id_df is not None and not study.id_df.is_empty():
359
+ study.id_df = pl.concat([study.id_df, new_results_df])
360
+ else:
361
+ study.id_df = new_results_df
362
+
363
+ if logger:
364
+ logger.debug(f"Added {len(records)} identification results to study.id_df")
365
+ elif not hasattr(study, "id_df"):
366
+ study.id_df = pl.DataFrame()
367
+
368
+
369
+ def _finalize_identification_results(study, params, logger):
370
+ """Apply final scoring adjustments and update consensus columns."""
371
+ # Apply scoring adjustments based on compound and formula counts
372
+ _apply_scoring_adjustments(study, params)
373
+
374
+ # Update consensus_df with top-scoring identification results
375
+ _update_consensus_id_columns(study, logger)
376
+
377
+
378
+ def _store_identification_history(study, effective_mz_tol, effective_rt_tol, target_uids, params, kwargs):
379
+ """Store identification operation in study history."""
380
+ if hasattr(study, "store_history"):
381
+ history_params = {"mz_tol": effective_mz_tol, "rt_tol": effective_rt_tol}
382
+ if target_uids is not None:
383
+ history_params["features"] = target_uids
384
+ if params is not None and hasattr(params, "to_dict"):
385
+ history_params["params"] = params.to_dict()
386
+ if kwargs:
387
+ history_params["kwargs"] = kwargs
388
+ study.store_history(["identify"], history_params)
389
+
390
+
391
+ def _validate_identify_inputs(study, logger=None):
392
+ """Validate inputs for identification process."""
393
+ if getattr(study, "consensus_df", None) is None or study.consensus_df.is_empty():
394
+ if logger:
395
+ logger.warning("No consensus features found for identification")
396
+ return False
397
+
398
+ if getattr(study, "lib_df", None) is None or study.lib_df.is_empty():
399
+ if logger:
400
+ logger.error("Library (study.lib_df) is empty; call lib_load() first")
401
+ raise ValueError("Library (study.lib_df) is empty; call lib_load() first")
402
+
403
+ return True
404
+
198
405
 
199
- # Determine which features to process
406
+ def _prepare_consensus_features(study, features, logger=None):
407
+ """Prepare consensus features for identification."""
200
408
  target_uids = None
201
409
  if features is not None:
202
410
  if hasattr(features, "columns"): # DataFrame-like
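Note: _find_matches_vectorized above replaces per-row Python comparisons with whole-column boolean filters. A self-contained toy example of the same m/z window plus optional RT refinement; the column names mirror lib_df, but the numbers are made up:

import polars as pl

lib_df = pl.DataFrame({
    "lib_uid": [1, 2, 3],
    "mz": [180.0634, 180.0637, 181.0712],
    "rt": [5.1, None, 7.9],
})

cons_mz, cons_rt, mz_tol, rt_tol = 180.0635, 5.0, 0.005, 0.5

# Vectorized m/z window over the whole library.
matches = lib_df.filter(
    (pl.col("mz") >= cons_mz - mz_tol) & (pl.col("mz") <= cons_mz + mz_tol)
)

# The RT refinement is only kept when it leaves at least one hit, mirroring the
# "fall back to m/z-only matches" behaviour above.
rt_matches = matches.filter(
    pl.col("rt").is_not_null()
    & (pl.col("rt") >= cons_rt - rt_tol)
    & (pl.col("rt") <= cons_rt + rt_tol)
)
if not rt_matches.is_empty():
    matches = rt_matches

print(matches)  # only lib_uid 1 survives both filters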
@@ -219,38 +427,6 @@ def identify(study, features=None, params=None, **kwargs):
219
427
  if logger:
220
428
  logger.debug(f"Identifying {len(target_uids)} specified features")
221
429
 
222
- # Clear previous identification results for target features only
223
- if hasattr(study, "id_df") and not study.id_df.is_empty():
224
- if target_uids is not None:
225
- # Keep results for features NOT being re-identified
226
- study.id_df = study.id_df.filter(
227
- ~pl.col("consensus_uid").is_in(target_uids),
228
- )
229
- if logger:
230
- logger.debug(
231
- f"Cleared previous identification results for {len(target_uids)} features",
232
- )
233
- else:
234
- # Clear all results if no specific features specified
235
- study.id_df = pl.DataFrame()
236
- if logger:
237
- logger.debug("Cleared all previous identification results")
238
- elif not hasattr(study, "id_df"):
239
- study.id_df = pl.DataFrame()
240
- if logger:
241
- logger.debug("Initialized empty id_df")
242
-
243
- # Validate inputs
244
- if getattr(study, "consensus_df", None) is None or study.consensus_df.is_empty():
245
- if logger:
246
- logger.warning("No consensus features found for identification")
247
- return
248
-
249
- if getattr(study, "lib_df", None) is None or study.lib_df.is_empty():
250
- if logger:
251
- logger.error("Library (study.lib_df) is empty; call lib_load() first")
252
- raise ValueError("Library (study.lib_df) is empty; call lib_load() first")
253
-
254
430
  # Filter consensus features if target_uids specified
255
431
  consensus_to_process = study.consensus_df
256
432
  if target_uids is not None:
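Note: both the removed block above and the new _smart_reset_id_results helper keep existing results for features that are not being re-identified by anti-filtering on consensus_uid. Toy illustration with made-up data:

import polars as pl

id_df = pl.DataFrame({
    "consensus_uid": [10, 10, 11],
    "lib_uid": [1, 2, 3],
    "score": [0.9, 0.8, 0.7],
})

target_uids = [10]  # features about to be re-identified

# Drop rows for the targeted features; results for everything else survive.
id_df = id_df.filter(~pl.col("consensus_uid").is_in(target_uids))
print(id_df)  # only the consensus_uid 11 row remains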
@@ -262,270 +438,294 @@ def identify(study, features=None, params=None, **kwargs):
262
438
  logger.warning(
263
439
  "No consensus features found matching specified features",
264
440
  )
265
- return
441
+ return None, target_uids
442
+
443
+ return consensus_to_process, target_uids
266
444
 
267
- consensus_count = len(consensus_to_process)
268
- lib_count = len(study.lib_df)
269
445
 
270
- if logger:
271
- if target_uids is not None:
272
- logger.debug(
273
- f"Identifying {consensus_count} specified consensus features against {lib_count} library entries",
274
- )
275
- else:
276
- logger.debug(
277
- f"Identifying {consensus_count} consensus features against {lib_count} library entries",
278
- )
279
446
 
280
- # Get adduct probabilities
281
- adducts_df = study._get_adducts()
447
+
448
+ def _get_adduct_probabilities(study):
449
+ """Get adduct probabilities from _get_adducts() results."""
450
+ adducts_df = _get_adducts(study)
282
451
  adduct_prob_map = {}
283
452
  if not adducts_df.is_empty():
284
453
  for row in adducts_df.iter_rows(named=True):
285
454
  adduct_prob_map[row.get("name")] = row.get("probability", 1.0)
455
+ return adduct_prob_map
286
456
 
287
- results = []
288
- features_with_matches = 0
289
- total_matches = 0
290
- rt_filtered_compounds = 0
291
- multiply_charged_filtered = 0
292
457
 
293
- # Iterate consensus rows and find matching lib rows by m/z +/- tolerance
294
- for cons in consensus_to_process.iter_rows(named=True):
295
- cons_mz = cons.get("mz")
296
- cons_rt = cons.get("rt")
297
- cons_uid = cons.get("consensus_uid")
298
458
 
299
- if cons_mz is None:
300
- if logger:
301
- logger.debug(f"Skipping consensus feature {cons_uid} - no m/z value")
302
- continue
459
+ def _create_identification_results(consensus_to_process, study, effective_mz_tol, effective_rt_tol, adduct_prob_map, logger=None):
460
+ """Create identification results by matching consensus features against library (DEPRECATED - use optimized version)."""
461
+ # This function is now deprecated in favor of _perform_identification_matching
462
+ # Keep for backward compatibility but redirect to optimized version
463
+ results = _perform_identification_matching(
464
+ consensus_to_process, study, effective_mz_tol, effective_rt_tol, adduct_prob_map, logger
465
+ )
466
+
467
+ # Convert to legacy format for compatibility
468
+ legacy_results = []
469
+ features_with_matches = 0
470
+ total_matches = 0
471
+
472
+ for result in results:
473
+ if result["matches"]:
474
+ features_with_matches += 1
475
+ total_matches += len(result["matches"])
476
+
477
+ for match in result["matches"]:
478
+ legacy_results.append({
479
+ "consensus_uid": result["consensus_uid"],
480
+ "lib_uid": match["lib_uid"],
481
+ "mz_delta": match["mz_delta"],
482
+ "rt_delta": match["rt_delta"],
483
+ "matcher": match["matcher"],
484
+ "score": match["score"],
485
+ })
486
+
487
+ return legacy_results, features_with_matches, total_matches
488
+
489
+
490
+ def _apply_scoring_adjustments(study, params):
491
+ """Apply scoring adjustments based on compound and formula counts using optimized operations."""
492
+ if (
493
+ not study.id_df.is_empty()
494
+ and hasattr(study, "lib_df")
495
+ and not study.lib_df.is_empty()
496
+ ):
497
+ # Get penalty parameters
498
+ heteroatoms = getattr(params, "heteroatoms", ["Cl", "Br", "F", "I"])
499
+ heteroatom_penalty = getattr(params, "heteroatom_penalty", 0.7)
500
+ formulas_penalty = getattr(params, "multiple_formulas_penalty", 0.8)
501
+ compounds_penalty = getattr(params, "multiple_compounds_penalty", 0.8)
303
502
 
304
- # Filter lib by mz window
305
- matches = study.lib_df.filter(
306
- (pl.col("mz") >= cons_mz - effective_mz_tol)
307
- & (pl.col("mz") <= cons_mz + effective_mz_tol),
503
+ # Single join to get all needed library information
504
+ lib_columns = ["lib_uid", "cmpd_uid", "formula"]
505
+ id_with_lib = study.id_df.join(
506
+ study.lib_df.select(lib_columns),
507
+ on="lib_uid",
508
+ how="left",
308
509
  )
309
510
 
310
- initial_matches = len(matches)
511
+ # Calculate all statistics in one group_by operation
512
+ stats = id_with_lib.group_by("consensus_uid").agg([
513
+ pl.col("cmpd_uid").n_unique().alias("num_cmpds"),
514
+ pl.col("formula").filter(pl.col("formula").is_not_null()).n_unique().alias("num_formulas"),
515
+ ])
311
516
 
312
- # If rt_tol provided and consensus RT present, prefer rt-filtered hits
313
- if effective_rt_tol is not None and cons_rt is not None:
314
- rt_matches = matches.filter(
315
- pl.col("rt").is_not_null()
316
- & (pl.col("rt") >= cons_rt - effective_rt_tol)
317
- & (pl.col("rt") <= cons_rt + effective_rt_tol),
318
- )
319
- if not rt_matches.is_empty():
320
- matches = rt_matches
321
- if logger:
322
- logger.debug(
323
- f"Consensus {cons_uid}: {initial_matches} m/z matches, {len(matches)} after RT filter",
324
- )
325
- else:
326
- if logger:
327
- logger.debug(
328
- f"Consensus {cons_uid}: {initial_matches} m/z matches, 0 after RT filter - using m/z matches only",
329
- )
517
+ # Join stats back and apply all penalties in one with_columns operation
518
+ heteroatom_conditions = [pl.col("formula").str.contains(atom) for atom in heteroatoms]
519
+ has_heteroatoms = pl.fold(
520
+ acc=pl.lit(False),
521
+ function=lambda acc, x: acc | x,
522
+ exprs=heteroatom_conditions
523
+ ) if heteroatom_conditions else pl.lit(False)
330
524
 
331
- # Apply scoring-based filtering system
332
- if not matches.is_empty():
333
- filtered_matches = matches.clone()
334
- else:
335
- filtered_matches = pl.DataFrame()
525
+ study.id_df = (
526
+ id_with_lib
527
+ .join(stats, on="consensus_uid", how="left")
528
+ .with_columns([
529
+ # Apply all penalties in sequence using case-when chains
530
+ pl.when(pl.col("formula").is_not_null() & has_heteroatoms)
531
+ .then(pl.col("score") * heteroatom_penalty)
532
+ .otherwise(pl.col("score"))
533
+ .alias("score_temp1")
534
+ ])
535
+ .with_columns([
536
+ pl.when(pl.col("num_formulas") > 1)
537
+ .then(pl.col("score_temp1") * formulas_penalty)
538
+ .otherwise(pl.col("score_temp1"))
539
+ .alias("score_temp2")
540
+ ])
541
+ .with_columns([
542
+ pl.when(pl.col("num_cmpds") > 1)
543
+ .then(pl.col("score_temp2") * compounds_penalty)
544
+ .otherwise(pl.col("score_temp2"))
545
+ .round(4)
546
+ .alias("score")
547
+ ])
548
+ .select([
549
+ "consensus_uid",
550
+ "lib_uid",
551
+ "mz_delta",
552
+ "rt_delta",
553
+ "matcher",
554
+ "score",
555
+ ])
556
+ )
336
557
 
337
- if not filtered_matches.is_empty():
338
- features_with_matches += 1
339
- feature_match_count = len(filtered_matches)
340
- total_matches += feature_match_count
341
558
 
559
+ def _update_consensus_id_columns(study, logger=None):
560
+ """Update consensus_df with top-scoring identification results using safe in-place updates."""
561
+ try:
562
+ if not hasattr(study, "id_df") or study.id_df is None or study.id_df.is_empty():
342
563
  if logger:
343
- logger.debug(
344
- f"Consensus {cons_uid} (mz={cons_mz:.5f}): {feature_match_count} library matches",
564
+ logger.debug("No identification results to process")
565
+ return
566
+
567
+ if not hasattr(study, "lib_df") or study.lib_df is None or study.lib_df.is_empty():
568
+ if logger:
569
+ logger.debug("No library data available")
570
+ return
571
+
572
+ if not hasattr(study, "consensus_df") or study.consensus_df is None or study.consensus_df.is_empty():
573
+ if logger:
574
+ logger.debug("No consensus data available")
575
+ return
576
+
577
+ # Get library columns we need
578
+ lib_columns = ["lib_uid", "name", "adduct"]
579
+ if "class" in study.lib_df.columns:
580
+ lib_columns.append("class")
581
+
582
+ # Get top-scoring identification for each consensus feature
583
+ top_ids = (
584
+ study.id_df
585
+ .sort(["consensus_uid", "score"], descending=[False, True])
586
+ .group_by("consensus_uid", maintain_order=True)
587
+ .first()
588
+ .join(study.lib_df.select(lib_columns), on="lib_uid", how="left")
589
+ .select([
590
+ "consensus_uid",
591
+ "name",
592
+ pl.col("class").alias("id_top_class") if "class" in lib_columns else pl.lit(None, dtype=pl.String).alias("id_top_class"),
593
+ pl.col("adduct").alias("id_top_adduct"),
594
+ pl.col("score").alias("id_top_score")
595
+ ])
596
+ .rename({"name": "id_top_name"})
597
+ )
598
+
599
+ # Ensure we have the id_top columns in consensus_df
600
+ for col_name, dtype in [
601
+ ("id_top_name", pl.String),
602
+ ("id_top_class", pl.String),
603
+ ("id_top_adduct", pl.String),
604
+ ("id_top_score", pl.Float64)
605
+ ]:
606
+ if col_name not in study.consensus_df.columns:
607
+ study.consensus_df = study.consensus_df.with_columns(
608
+ pl.lit(None, dtype=dtype).alias(col_name)
345
609
  )
346
610
 
347
- for m in filtered_matches.iter_rows(named=True):
348
- mz_delta = abs(cons_mz - m.get("mz")) if m.get("mz") is not None else None
349
- lib_rt = m.get("rt")
350
- rt_delta = (
351
- abs(cons_rt - lib_rt)
352
- if (cons_rt is not None and lib_rt is not None)
353
- else None
354
- )
611
+ # Create a mapping dictionary for efficient updates
612
+ id_mapping = {}
613
+ for row in top_ids.iter_rows(named=True):
614
+ consensus_uid = row["consensus_uid"]
615
+ id_mapping[consensus_uid] = {
616
+ "id_top_name": row["id_top_name"],
617
+ "id_top_class": row["id_top_class"],
618
+ "id_top_adduct": row["id_top_adduct"],
619
+ "id_top_score": row["id_top_score"]
620
+ }
621
+
622
+ # Update consensus_df using map_elements (safer than join for avoiding duplicates)
623
+ if id_mapping:
624
+ study.consensus_df = study.consensus_df.with_columns([
625
+ pl.col("consensus_uid").map_elements(
626
+ lambda uid: id_mapping.get(uid, {}).get("id_top_name"),
627
+ return_dtype=pl.String
628
+ ).alias("id_top_name"),
629
+ pl.col("consensus_uid").map_elements(
630
+ lambda uid: id_mapping.get(uid, {}).get("id_top_class"),
631
+ return_dtype=pl.String
632
+ ).alias("id_top_class"),
633
+ pl.col("consensus_uid").map_elements(
634
+ lambda uid: id_mapping.get(uid, {}).get("id_top_adduct"),
635
+ return_dtype=pl.String
636
+ ).alias("id_top_adduct"),
637
+ pl.col("consensus_uid").map_elements(
638
+ lambda uid: id_mapping.get(uid, {}).get("id_top_score"),
639
+ return_dtype=pl.Float64
640
+ ).alias("id_top_score")
641
+ ])
355
642
 
356
- # Get adduct probability from _get_adducts() results
357
- adduct = m.get("adduct")
358
- score = adduct_prob_map.get(adduct, 1.0) if adduct else 1.0
643
+ if logger:
644
+ num_updated = len(id_mapping)
645
+ logger.debug(f"Updated consensus_df with top identifications for {num_updated} features")
646
+
647
+ except Exception as e:
648
+ if logger:
649
+ logger.error(f"Error updating consensus_df with identification results: {e}")
650
+ # Don't re-raise to avoid breaking the identification process
359
651
 
360
- results.append(
361
- {
362
- "consensus_uid": cons.get("consensus_uid"),
363
- "lib_uid": m.get("lib_uid"),
364
- "mz_delta": mz_delta,
365
- "rt_delta": rt_delta,
366
- "matcher": "ms1",
367
- "score": score,
368
- },
369
- )
370
652
 
371
- # Merge new results with existing results
372
- new_results_df = pl.DataFrame(results) if results else pl.DataFrame()
373
653
 
374
- if not new_results_df.is_empty():
375
- if hasattr(study, "id_df") and not study.id_df.is_empty():
376
- # Concatenate new results with existing results
377
- study.id_df = pl.concat([study.id_df, new_results_df])
378
- else:
379
- # First results
380
- study.id_df = new_results_df
381
654
 
382
- # Apply scoring adjustments based on compound and formula counts
383
- if (
384
- not study.id_df.is_empty()
385
- and hasattr(study, "lib_df")
386
- and not study.lib_df.is_empty()
387
- ):
388
- # Join with lib_df to get compound and formula information
389
- id_with_lib = study.id_df.join(
390
- study.lib_df.select(["lib_uid", "cmpd_uid", "formula"]),
391
- on="lib_uid",
392
- how="left",
393
- )
655
+ def identify(study, features=None, params=None, **kwargs):
656
+ """Identify consensus features against the loaded library.
394
657
 
395
- # Calculate counts per consensus_uid
396
- count_stats = id_with_lib.group_by("consensus_uid").agg(
397
- [
398
- pl.col("cmpd_uid").n_unique().alias("num_cmpds"),
399
- pl.col("formula")
400
- .filter(pl.col("formula").is_not_null())
401
- .n_unique()
402
- .alias("num_formulas"),
403
- ],
404
- )
658
+ Matches consensus_df.mz against lib_df.mz within mz_tolerance. If rt_tolerance
659
+ is provided and both consensus and library entries have rt values, RT is
660
+ used as an additional filter.
405
661
 
406
- # Join counts back to id_df
407
- id_with_counts = study.id_df.join(count_stats, on="consensus_uid", how="left")
662
+ Args:
663
+ study: Study instance
664
+ features: Optional DataFrame or list of consensus_uids to identify.
665
+ If None, identifies all consensus features.
666
+ params: Optional identify_defaults instance with matching tolerances and scoring parameters.
667
+ If None, uses default parameters.
668
+ **kwargs: Individual parameter overrides (mz_tol, rt_tol, heteroatom_penalty,
669
+ multiple_formulas_penalty, multiple_compounds_penalty, heteroatoms)
408
670
 
409
- # Join with lib_df again to get formula information for heteroatom penalty
410
- id_with_formula = id_with_counts.join(
411
- study.lib_df.select(["lib_uid", "formula"]),
412
- on="lib_uid",
413
- how="left",
671
+ The resulting DataFrame is stored as study.id_df. Columns:
672
+ - consensus_uid
673
+ - lib_uid
674
+ - mz_delta
675
+ - rt_delta (nullable)
676
+ - score (adduct probability from _get_adducts with penalties applied)
677
+ """
678
+ # Get logger from study if available
679
+ logger = getattr(study, "logger", None)
680
+
681
+ # Setup parameters early
682
+ params = _setup_identify_parameters(params, kwargs)
683
+ effective_mz_tol = getattr(params, "mz_tol", 0.01)
684
+ effective_rt_tol = getattr(params, "rt_tol", 2.0)
685
+
686
+ if logger:
687
+ logger.debug(
688
+ f"Starting identification with mz_tolerance={effective_mz_tol}, rt_tolerance={effective_rt_tol}",
414
689
  )
415
690
 
416
- # Apply scoring penalties
417
- heteroatoms = getattr(params, "heteroatoms", ["Cl", "Br", "F", "I"])
418
- heteroatom_penalty = getattr(params, "heteroatom_penalty", 0.7)
419
- formulas_penalty = getattr(params, "multiple_formulas_penalty", 0.8)
420
- compounds_penalty = getattr(params, "multiple_compounds_penalty", 0.8)
691
+ # Validate inputs early
692
+ if not _validate_identify_inputs(study, logger):
693
+ return
421
694
 
422
- # Build heteroatom condition
423
- heteroatom_condition = None
424
- for atom in heteroatoms:
425
- atom_condition = pl.col("formula").str.contains(atom)
426
- if heteroatom_condition is None:
427
- heteroatom_condition = atom_condition
428
- else:
429
- heteroatom_condition = heteroatom_condition | atom_condition
695
+ # Prepare consensus features and determine target UIDs early
696
+ consensus_to_process, target_uids = _prepare_consensus_features(study, features, logger)
697
+ if consensus_to_process is None:
698
+ return
430
699
 
431
- # Apply penalties
432
- study.id_df = (
433
- id_with_formula.with_columns(
434
- [
435
- # Heteroatom penalty: if formula contains specified heteroatoms, apply penalty
436
- pl.when(
437
- pl.col("formula").is_not_null() & heteroatom_condition,
438
- )
439
- .then(pl.col("score") * heteroatom_penalty)
440
- .otherwise(pl.col("score"))
441
- .alias("score_temp0"),
442
- ],
443
- )
444
- .with_columns(
445
- [
446
- # If num_formulas > 1, apply multiple formulas penalty
447
- pl.when(pl.col("num_formulas") > 1)
448
- .then(pl.col("score_temp0") * formulas_penalty)
449
- .otherwise(pl.col("score_temp0"))
450
- .alias("score_temp1"),
451
- ],
452
- )
453
- .with_columns(
454
- [
455
- # If num_cmpds > 1, apply multiple compounds penalty
456
- pl.when(pl.col("num_cmpds") > 1)
457
- .then(pl.col("score_temp1") * compounds_penalty)
458
- .otherwise(pl.col("score_temp1"))
459
- .round(4) # Round to 4 decimal places
460
- .alias("score"),
461
- ],
462
- )
463
- .select(
464
- [
465
- "consensus_uid",
466
- "lib_uid",
467
- "mz_delta",
468
- "rt_delta",
469
- "matcher",
470
- "score",
471
- ],
472
- )
473
- )
700
+ # Smart reset of id_df: only clear results for features being re-identified
701
+ _smart_reset_id_results(study, target_uids, logger)
702
+
703
+ # Cache adduct probabilities (expensive operation)
704
+ adduct_prob_map = _get_cached_adduct_probabilities(study, logger)
474
705
 
475
- # Store this operation in history
476
- if hasattr(study, "store_history"):
477
- history_params = {"mz_tol": effective_mz_tol, "rt_tol": effective_rt_tol}
478
- if features is not None:
479
- history_params["features"] = target_uids
480
- if params is not None and hasattr(params, "to_dict"):
481
- history_params["params"] = params.to_dict()
482
- if kwargs:
483
- history_params["kwargs"] = kwargs
484
- study.store_history(["identify"], history_params)
706
+ # Perform identification with optimized matching
707
+ results = _perform_identification_matching(
708
+ consensus_to_process, study, effective_mz_tol, effective_rt_tol, adduct_prob_map, logger
709
+ )
485
710
 
486
- if logger:
487
- if rt_filtered_compounds > 0:
488
- logger.debug(
489
- f"RT consistency filtering applied to {rt_filtered_compounds} compound groups",
490
- )
711
+ # Update or append results to study.id_df
712
+ _update_identification_results(study, results, logger)
491
713
 
492
- if multiply_charged_filtered > 0:
493
- logger.debug(
494
- f"Excluded {multiply_charged_filtered} multiply charged adducts (no [M+H]+ or [M-H]- coeluting)",
495
- )
714
+ # Apply scoring adjustments and update consensus columns
715
+ _finalize_identification_results(study, params, logger)
496
716
 
717
+ # Store operation in history
718
+ _store_identification_history(study, effective_mz_tol, effective_rt_tol, target_uids, params, kwargs)
719
+
720
+ # Log final statistics
721
+ consensus_count = len(consensus_to_process)
722
+ if logger:
723
+ features_with_matches = len([r for r in results if len(r["matches"]) > 0])
724
+ total_matches = sum(len(r["matches"]) for r in results)
497
725
  logger.info(
498
726
  f"Identification completed: {features_with_matches}/{consensus_count} features matched, {total_matches} total identifications",
499
727
  )
500
728
 
501
- if total_matches > 0:
502
- # Calculate some statistics
503
- mz_deltas = [r["mz_delta"] for r in results if r["mz_delta"] is not None]
504
- rt_deltas = [r["rt_delta"] for r in results if r["rt_delta"] is not None]
505
- scores = [r["score"] for r in results if r["score"] is not None]
506
-
507
- if mz_deltas:
508
- avg_mz_delta = sum(mz_deltas) / len(mz_deltas)
509
- max_mz_delta = max(mz_deltas)
510
- logger.debug(
511
- f"m/z accuracy: average Δ={avg_mz_delta:.5f} Da, max Δ={max_mz_delta:.5f} Da",
512
- )
513
-
514
- if rt_deltas:
515
- avg_rt_delta = sum(rt_deltas) / len(rt_deltas)
516
- max_rt_delta = max(rt_deltas)
517
- logger.debug(
518
- f"RT accuracy: average Δ={avg_rt_delta:.2f} min, max Δ={max_rt_delta:.2f} min",
519
- )
520
-
521
- if scores:
522
- avg_score = sum(scores) / len(scores)
523
- min_score = min(scores)
524
- max_score = max(scores)
525
- logger.debug(
526
- f"Adduct probability scores: average={avg_score:.3f}, min={min_score:.3f}, max={max_score:.3f}",
527
- )
528
-
529
729
 
530
730
  def get_id(study, features=None) -> pl.DataFrame:
531
731
  """Get identification results with comprehensive annotation data.
@@ -795,6 +995,7 @@ def id_reset(study):
795
995
  Removes:
796
996
  - study.id_df (identification results DataFrame)
797
997
  - 'identify' from study.history
998
+ - Resets id_top_* columns in consensus_df to null
798
999
 
799
1000
  Args:
800
1001
  study: Study instance to reset
@@ -808,6 +1009,23 @@ def id_reset(study):
808
1009
  logger.debug("Removing id_df")
809
1010
  delattr(study, "id_df")
810
1011
 
1012
+ # Reset id_top_* columns in consensus_df
1013
+ if hasattr(study, "consensus_df") and not study.consensus_df.is_empty():
1014
+ if logger:
1015
+ logger.debug("Resetting id_top_* columns in consensus_df")
1016
+
1017
+ # Check which columns exist before trying to update them
1018
+ id_columns_to_reset = []
1019
+ for col in ["id_top_name", "id_top_class", "id_top_adduct", "id_top_score"]:
1020
+ if col in study.consensus_df.columns:
1021
+ if col == "id_top_score":
1022
+ id_columns_to_reset.append(pl.lit(None, dtype=pl.Float64).alias(col))
1023
+ else:
1024
+ id_columns_to_reset.append(pl.lit(None, dtype=pl.String).alias(col))
1025
+
1026
+ if id_columns_to_reset:
1027
+ study.consensus_df = study.consensus_df.with_columns(id_columns_to_reset)
1028
+
811
1029
  # Remove identify from history
812
1030
  if hasattr(study, "history") and "identify" in study.history:
813
1031
  if logger:
@@ -827,6 +1045,7 @@ def lib_reset(study):
827
1045
  - study._lib (library object reference)
828
1046
  - 'identify' from study.history
829
1047
  - 'lib_load' from study.history (if exists)
1048
+ - Resets id_top_* columns in consensus_df to null
830
1049
 
831
1050
  Args:
832
1051
  study: Study instance to reset
@@ -852,6 +1071,23 @@ def lib_reset(study):
852
1071
  logger.debug("Removing _lib reference")
853
1072
  delattr(study, "_lib")
854
1073
 
1074
+ # Reset id_top_* columns in consensus_df
1075
+ if hasattr(study, "consensus_df") and not study.consensus_df.is_empty():
1076
+ if logger:
1077
+ logger.debug("Resetting id_top_* columns in consensus_df")
1078
+
1079
+ # Check which columns exist before trying to update them
1080
+ id_columns_to_reset = []
1081
+ for col in ["id_top_name", "id_top_class", "id_top_adduct", "id_top_score"]:
1082
+ if col in study.consensus_df.columns:
1083
+ if col == "id_top_score":
1084
+ id_columns_to_reset.append(pl.lit(None, dtype=pl.Float64).alias(col))
1085
+ else:
1086
+ id_columns_to_reset.append(pl.lit(None, dtype=pl.String).alias(col))
1087
+
1088
+ if id_columns_to_reset:
1089
+ study.consensus_df = study.consensus_df.with_columns(id_columns_to_reset)
1090
+
855
1091
  # Remove from history
856
1092
  if hasattr(study, "history"):
857
1093
  if "identify" in study.history:
@@ -868,7 +1104,7 @@ def lib_reset(study):
868
1104
  logger.info("Library and identification data reset completed")
869
1105
 
870
1106
 
871
- def _get_adducts(self, adducts_list: list = None, **kwargs):
1107
+ def _get_adducts(study, adducts_list: list = None, **kwargs):
872
1108
  """
873
1109
  Generate comprehensive adduct specifications for study-level adduct filtering.
874
1110
 
@@ -901,10 +1137,11 @@ def _get_adducts(self, adducts_list: list = None, **kwargs):
901
1137
  # Import required modules
902
1138
 
903
1139
  # Use provided adducts list or get from study parameters
904
- if adducts_list is None:
905
- adducts_list = (
906
- self.parameters.adducts
907
- if hasattr(self.parameters, "adducts") and self.parameters.adducts
1140
+ adducts_list_to_use = adducts_list
1141
+ if adducts_list_to_use is None:
1142
+ adducts_list_to_use = (
1143
+ study.parameters.adducts
1144
+ if hasattr(study.parameters, "adducts") and study.parameters.adducts
908
1145
  else []
909
1146
  )
910
1147
 
@@ -914,13 +1151,13 @@ def _get_adducts(self, adducts_list: list = None, **kwargs):
914
1151
  max_combinations = kwargs.get("max_combinations", 3) # Up to 3 combinations
915
1152
  min_probability = kwargs.get(
916
1153
  "min_probability",
917
- getattr(self.parameters, "adduct_min_probability", 0.04),
1154
+ getattr(study.parameters, "adduct_min_probability", 0.04),
918
1155
  )
919
1156
 
920
1157
  # Parse base adduct specifications
921
1158
  base_specs = []
922
1159
 
923
- for adduct_str in adducts_list:
1160
+ for adduct_str in adducts_list_to_use:
924
1161
  if not isinstance(adduct_str, str) or ":" not in adduct_str:
925
1162
  continue
926
1163
 
@@ -934,7 +1171,7 @@ def _get_adducts(self, adducts_list: list = None, **kwargs):
934
1171
  probability = float(parts[2])
935
1172
 
936
1173
  # Calculate mass shift from formula
937
- mass_shift = self._calculate_formula_mass_shift(formula_part)
1174
+ mass_shift = _calculate_formula_mass_shift(study, formula_part)
938
1175
 
939
1176
  base_specs.append(
940
1177
  {
@@ -972,7 +1209,7 @@ def _get_adducts(self, adducts_list: list = None, **kwargs):
972
1209
  # 1. Single adducts (filter out neutral adducts with charge == 0)
973
1210
  for spec in base_specs:
974
1211
  if charge_min <= spec["charge"] <= charge_max and spec["charge"] != 0:
975
- formatted_name = self._format_adduct_name([spec])
1212
+ formatted_name = _format_adduct_name([spec])
976
1213
  combinations_list.append(
977
1214
  {
978
1215
  "components": [spec],
@@ -991,15 +1228,16 @@ def _get_adducts(self, adducts_list: list = None, **kwargs):
991
1228
  total_charge = base_charge * multiplier
992
1229
  if charge_min <= total_charge <= charge_max and total_charge != 0:
993
1230
  components = [spec] * multiplier
994
- formatted_name = self._format_adduct_name(components)
1231
+ formatted_name = _format_adduct_name(components)
1232
+ probability_multiplied = float(spec["probability"]) ** multiplier
995
1233
 
996
1234
  combinations_list.append(
997
1235
  {
998
1236
  "components": components,
999
1237
  "formatted_name": formatted_name,
1000
- "total_mass_shift": spec["mass_shift"] * multiplier,
1238
+ "total_mass_shift": float(spec["mass_shift"]) * multiplier,
1001
1239
  "total_charge": total_charge,
1002
- "combined_probability": spec["probability"] ** multiplier,
1240
+ "combined_probability": probability_multiplied,
1003
1241
  "complexity": multiplier,
1004
1242
  },
1005
1243
  )
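Note: the change in this hunk makes the multiplier arithmetic use explicit floats before exponentiation. Worked example with an illustrative proton adduct spec; the numbers below are not masster defaults:

spec = {"mass_shift": 1.00728, "charge": 1, "probability": 0.9}
multiplier = 2  # e.g. a doubled adduct such as [M+2H]2+

total_mass_shift = float(spec["mass_shift"]) * multiplier        # 2.01456
total_charge = spec["charge"] * multiplier                       # 2
combined_probability = float(spec["probability"]) ** multiplier  # 0.81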
@@ -1012,16 +1250,16 @@ def _get_adducts(self, adducts_list: list = None, **kwargs):
1012
1250
  total_charge = pos_spec["charge"] + neut_spec["charge"]
1013
1251
  if charge_min <= total_charge <= charge_max and total_charge != 0:
1014
1252
  components = [pos_spec, neut_spec]
1015
- formatted_name = self._format_adduct_name(components)
1253
+ formatted_name = _format_adduct_name(components)
1016
1254
  combinations_list.append(
1017
1255
  {
1018
1256
  "components": components,
1019
1257
  "formatted_name": formatted_name,
1020
- "total_mass_shift": pos_spec["mass_shift"]
1021
- + neut_spec["mass_shift"],
1258
+ "total_mass_shift": float(pos_spec["mass_shift"])
1259
+ + float(neut_spec["mass_shift"]),
1022
1260
  "total_charge": total_charge,
1023
- "combined_probability": pos_spec["probability"]
1024
- * neut_spec["probability"],
1261
+ "combined_probability": float(pos_spec["probability"])
1262
+ * float(neut_spec["probability"]),
1025
1263
  "complexity": 2,
1026
1264
  },
1027
1265
  )
@@ -1051,9 +1289,11 @@ def _get_adducts(self, adducts_list: list = None, **kwargs):
1051
1289
  adducts_df = adducts_df.filter(pl.col("probability") >= min_probability)
1052
1290
  adducts_after_filter = len(adducts_df)
1053
1291
 
1054
- self.logger.debug(
1055
- f"Study adducts: generated {adducts_before_filter}, filtered to {adducts_after_filter} (min_prob={min_probability})",
1056
- )
1292
+ logger = getattr(study, "logger", None)
1293
+ if logger:
1294
+ logger.debug(
1295
+ f"Study adducts: generated {adducts_before_filter}, filtered to {adducts_after_filter} (min_prob={min_probability})",
1296
+ )
1057
1297
 
1058
1298
  else:
1059
1299
  # Return empty DataFrame with correct schema
@@ -1070,7 +1310,7 @@ def _get_adducts(self, adducts_list: list = None, **kwargs):
1070
1310
  return adducts_df
1071
1311
 
1072
1312
 
1073
- def _calculate_formula_mass_shift(self, formula: str) -> float:
1313
+ def _calculate_formula_mass_shift(study, formula: str) -> float:
1074
1314
  """Calculate mass shift from formula string like "+H", "-H2O", "+Na-H", etc."""
1075
1315
  # Standard atomic masses
1076
1316
  atomic_masses = {
@@ -1121,7 +1361,7 @@ def _calculate_formula_mass_shift(self, formula: str) -> float:
1121
1361
  continue
1122
1362
 
1123
1363
  # Parse element and count (e.g., "H2O" -> H:2, O:1)
1124
- elements = self._parse_element_counts(part)
1364
+ elements = _parse_element_counts(part)
1125
1365
 
1126
1366
  for element, count in elements.items():
1127
1367
  if element in atomic_masses:
@@ -1130,9 +1370,9 @@ def _calculate_formula_mass_shift(self, formula: str) -> float:
1130
1370
  return total_mass
1131
1371
 
1132
1372
 
1133
- def _parse_element_counts(self, formula_part: str) -> dict[str, int]:
1373
+ def _parse_element_counts(formula_part: str) -> dict[str, int]:
1134
1374
  """Parse element counts from a formula part like 'H2O' -> {'H': 2, 'O': 1}"""
1135
- elements = {}
1375
+ elements: dict[str, int] = {}
1136
1376
  i = 0
1137
1377
 
1138
1378
  while i < len(formula_part):
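Note: _parse_element_counts walks the formula part character by character; its documented contract ('H2O' -> {'H': 2, 'O': 1}) can also be met with a short regex, shown here purely as an equivalent sketch rather than the actual implementation:

import re

def parse_element_counts(formula_part: str) -> dict[str, int]:
    # One capital letter, optional lowercase letters, optional digit count.
    counts: dict[str, int] = {}
    for element, num in re.findall(r"([A-Z][a-z]*)(\d*)", formula_part):
        counts[element] = counts.get(element, 0) + (int(num) if num else 1)
    return counts

assert parse_element_counts("H2O") == {"H": 2, "O": 1}
assert parse_element_counts("C2H3O2") == {"C": 2, "H": 3, "O": 2}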
@@ -1156,7 +1396,7 @@ def _parse_element_counts(self, formula_part: str) -> dict[str, int]:
1156
1396
  return elements
1157
1397
 
1158
1398
 
1159
- def _format_adduct_name(self, components: list[dict]) -> str:
1399
+ def _format_adduct_name(components: list[dict]) -> str:
1160
1400
  """Format adduct name from components like [M+H]1+ or [M+2H]2+"""
1161
1401
  if not components:
1162
1402
  return "[M]"