masster-0.5.27-py3-none-any.whl → masster-0.6.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of masster might be problematic.

masster/sample/id.py ADDED
@@ -0,0 +1,1160 @@
"""sample/id.py

Identification helpers for Sample: load a Lib and identify features
by matching m/z (and optionally RT).
"""

from __future__ import annotations

import polars as pl


def lib_load(
    sample,
    lib_source,
    polarity: str | None = None,
    adducts: list | None = None,
    iso: str | None = None,
):
    """Load a compound library into the sample.

    Args:
        sample: Sample instance
        lib_source: either a CSV/JSON file path (str) or a Lib instance
        polarity: ionization polarity ("positive" or "negative") - used when lib_source is a CSV/JSON path.
            If None, uses sample.polarity automatically.
        adducts: specific adducts to generate - used when lib_source is a CSV/JSON path
        iso: isotope generation mode ("13C" to generate 13C isotopes, None for no isotopes)

    Side effects:
        sets sample.lib_df to a Polars DataFrame and stores the lib object on
        sample._lib for later reference.
    """
    # Lazy import to avoid circular imports at module import time
    try:
        from masster.lib.lib import Lib
    except Exception:
        Lib = None

    if lib_source is None:
        raise ValueError("lib_source must be a CSV/JSON file path (str) or a Lib instance")

    # Use sample polarity if not explicitly provided
    if polarity is None:
        sample_polarity = getattr(sample, "polarity", "positive")
        # Normalize polarity names
        if sample_polarity in ["pos", "positive"]:
            polarity = "positive"
        elif sample_polarity in ["neg", "negative"]:
            polarity = "negative"
        else:
            polarity = "positive"  # Default fallback
        logger = getattr(sample, "logger", None)
        if logger:
            logger.debug(f"Using sample polarity: {polarity}")

    # Handle string input (CSV or JSON file path)
    if isinstance(lib_source, str):
        if Lib is None:
            raise ImportError(
                "Could not import masster.lib.lib.Lib - required for CSV/JSON loading",
            )

        lib_obj = Lib()

        # Determine file type by extension
        if lib_source.lower().endswith(".json"):
            lib_obj.import_json(lib_source, polarity=polarity, adducts=adducts)
        elif lib_source.lower().endswith(".csv"):
            lib_obj.import_csv(lib_source, polarity=polarity, adducts=adducts)
        else:
            # Default to CSV behavior for backward compatibility
            lib_obj.import_csv(lib_source, polarity=polarity, adducts=adducts)

    # Handle Lib instance
    elif Lib is not None and isinstance(lib_source, Lib):
        lib_obj = lib_source

    # Handle other objects with lib_df attribute
    elif hasattr(lib_source, "lib_df"):
        lib_obj = lib_source

    else:
        raise TypeError(
            "lib_source must be a CSV/JSON file path (str), a masster.lib.Lib instance, or have a 'lib_df' attribute",
        )

    # Ensure lib_df is populated
    lf = getattr(lib_obj, "lib_df", None)
    if lf is None or (hasattr(lf, "is_empty") and lf.is_empty()):
        raise ValueError("Library has no data populated in lib_df")

    # Filter by polarity to match sample
    # Map polarity to charge signs
    if polarity == "positive":
        target_charges = [1, 2]  # positive charges
    elif polarity == "negative":
        target_charges = [-1, -2]  # negative charges
    else:
        target_charges = [-2, -1, 1, 2]  # all charges

    # Filter library entries by charge sign (which corresponds to polarity)
    filtered_lf = lf.filter(pl.col("z").is_in(target_charges))

    if filtered_lf.is_empty():
        print(
            f"Warning: No library entries found for polarity '{polarity}'. Using all entries.",
        )
        filtered_lf = lf

    # Store pointer and DataFrame on sample
    sample._lib = lib_obj

    # Add source_id column with filename (without path) if loading from CSV/JSON
    if isinstance(lib_source, str):
        import os

        filename_only = os.path.basename(lib_source)
        filtered_lf = filtered_lf.with_columns(pl.lit(filename_only).alias("source_id"))

    # Ensure required columns exist and set correct values
    required_columns = {"quant_group": pl.Int64, "iso": pl.Int64}

    for col_name, col_dtype in required_columns.items():
        if col_name == "quant_group":
            # Set quant_group using cmpd_uid (same for isotopomers of same compound)
            if "cmpd_uid" in filtered_lf.columns:
                filtered_lf = filtered_lf.with_columns(pl.col("cmpd_uid").cast(col_dtype).alias("quant_group"))
            else:
                # Fallback to lib_uid if cmpd_uid doesn't exist
                filtered_lf = filtered_lf.with_columns(pl.col("lib_uid").cast(col_dtype).alias("quant_group"))
        elif col_name == "iso":
            if col_name not in filtered_lf.columns:
                # Default to zero for iso
                filtered_lf = filtered_lf.with_columns(pl.lit(0).cast(col_dtype).alias(col_name))

    # Generate 13C isotopes if requested
    original_count = len(filtered_lf)
    if iso == "13C":
        filtered_lf = _generate_13c_isotopes(filtered_lf)
        # Update the log message to show the correct count after isotope generation
        if isinstance(lib_source, str):
            import os

            filename_only = os.path.basename(lib_source)
            print(
                f"Generated 13C isotopes: {len(filtered_lf)} total entries ({original_count} original + {len(filtered_lf) - original_count} isotopes) from {filename_only}"
            )

    # Store library as Polars DataFrame
    sample.lib_df = filtered_lf

    # Store this operation in history
    if hasattr(sample, "store_history"):
        sample.store_history(
            ["lib_load"],
            {"lib_source": str(lib_source), "polarity": polarity, "adducts": adducts, "iso": iso},
        )


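# Editor's note: a minimal lib_load sketch (illustrative only; `smp` is a
# hypothetical Sample instance and "compounds.csv" a hypothetical library file):
#
#     lib_load(smp, "compounds.csv", adducts=["[M+H]+", "[M+Na]+"], iso="13C")
#     print(len(smp.lib_df))  # entries after polarity filtering and 13C expansion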
def identify(sample, features=None, params=None, **kwargs):
    """Identify features against the loaded library.

    Matches features_df.mz against lib_df.mz within mz_tol. If rt_tol is
    provided and both the feature and the library entry have rt values, RT is
    used as an additional filter.

    Args:
        sample: Sample instance
        features: Optional DataFrame or list of feature_uids to identify.
            If None, identifies all features.
        params: Optional identify_defaults instance with matching tolerances and scoring parameters.
            If None, uses default parameters.
        **kwargs: Individual parameter overrides (mz_tol, rt_tol, heteroatom_penalty,
            multiple_formulas_penalty, multiple_compounds_penalty, heteroatoms)

    The resulting DataFrame is stored as sample.id_df. Columns:
        - feature_uid
        - lib_uid
        - mz_delta
        - rt_delta (nullable)
        - matcher
        - score (library x adduct probability on a 0-100 scale, with penalties applied)
    """
    # Get logger from sample if available
    logger = getattr(sample, "logger", None)

    # Set up parameters
    params = _setup_identify_parameters(params, kwargs)
    effective_mz_tol = getattr(params, "mz_tol", 0.01)
    effective_rt_tol = getattr(params, "rt_tol", 2.0)

    if logger:
        logger.debug(
            f"Starting identification with mz_tolerance={effective_mz_tol}, rt_tolerance={effective_rt_tol}",
        )

    # Validate inputs
    if not _validate_identify_inputs(sample, logger):
        return

    # Prepare features and determine target UIDs
    features_to_process, target_uids = _prepare_features(sample, features, logger)
    if features_to_process is None:
        return

    # Smart reset of id_df: only clear results for features being re-identified
    _smart_reset_id_results(sample, target_uids, logger)

    # Cache adduct probabilities (expensive operation)
    adduct_prob_map = _get_cached_adduct_probabilities(sample, logger)

    # Perform identification with optimized matching
    results = _perform_identification_matching(
        features_to_process, sample, effective_mz_tol, effective_rt_tol, adduct_prob_map, logger
    )

    # Update or append results to sample.id_df
    _update_identification_results(sample, results, logger)

    # Apply scoring adjustments
    _finalize_identification_results(sample, params, logger)

    # Store operation in history
    _store_identification_history(sample, effective_mz_tol, effective_rt_tol, target_uids, params, kwargs)

    # Log final statistics
    features_count = len(features_to_process)
    if logger:
        features_with_matches = len([r for r in results if len(r["matches"]) > 0])
        total_matches = sum(len(r["matches"]) for r in results)
        logger.success(
            f"Identification completed: {features_with_matches}/{features_count} features matched, {total_matches} total identifications",
        )


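# Editor's note: identify accepts either an identify_defaults instance or
# keyword overrides; given the legacy-name mapping in
# _setup_identify_parameters, both calls below are equivalent on a
# hypothetical Sample `smp` with a library loaded (illustrative sketch):
#
#     identify(smp, mz_tol=0.005, rt_tol=1.0)
#     identify(smp, mz_tolerance=0.005, rt_tolerance=1.0)  # legacy names, remapped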
def get_id(sample, features=None) -> pl.DataFrame:
    """Get identification results with comprehensive annotation data.

    Combines identification results (sample.id_df) with library information to provide
    comprehensive identification data including names, adducts, formulas, etc.

    Args:
        sample: Sample instance with id_df and lib_df populated
        features: Optional DataFrame or list of feature_uids to filter results.
            If None, returns all identification results.

    Returns:
        Polars DataFrame with columns:
        - feature_uid
        - lib_uid
        - mz (feature m/z)
        - rt (feature RT)
        - name (compound name from library)
        - shortname (short name from library, if available)
        - class (compound class from library, if available)
        - formula (molecular formula from library)
        - adduct (adduct type from library)
        - smiles (SMILES notation from library)
        - mz_delta (absolute m/z difference)
        - rt_delta (absolute RT difference, nullable)
        - Additional library columns if available (inchi, inchikey, etc.)

    Raises:
        ValueError: If sample.id_df or sample.lib_df are empty
    """
    # Validate inputs
    if getattr(sample, "id_df", None) is None or sample.id_df.is_empty():
        raise ValueError(
            "Identification results (sample.id_df) are empty; call identify() first",
        )

    if getattr(sample, "lib_df", None) is None or sample.lib_df.is_empty():
        raise ValueError("Library (sample.lib_df) is empty; call lib_load() first")

    if getattr(sample, "features_df", None) is None or sample.features_df.is_empty():
        raise ValueError("Features (sample.features_df) are empty")

    # Start with identification results
    result_df = sample.id_df.clone()

    # Filter by features if provided
    if features is not None:
        if hasattr(features, "columns"):  # DataFrame-like
            if "feature_uid" in features.columns:
                uids = features["feature_uid"].unique().to_list()
            else:
                raise ValueError(
                    "features DataFrame must contain 'feature_uid' column",
                )
        elif hasattr(features, "__iter__") and not isinstance(
            features,
            str,
        ):  # List-like
            uids = list(features)
        else:
            raise ValueError(
                "features must be a DataFrame with 'feature_uid' column or a list of UIDs",
            )

        result_df = result_df.filter(pl.col("feature_uid").is_in(uids))

    if result_df.is_empty():
        return pl.DataFrame()

    # Join with features_df to get feature m/z and RT
    features_cols = ["feature_uid", "mz", "rt"]
    # Only select columns that exist in features_df
    available_features_cols = [col for col in features_cols if col in sample.features_df.columns]

    result_df = result_df.join(
        sample.features_df.select(available_features_cols),
        on="feature_uid",
        how="left",
        suffix="_feature",
    )

    # Join with lib_df to get library information
    lib_cols = [
        "lib_uid",
        "name",
        "shortname",
        "class",
        "formula",
        "adduct",
        "smiles",
        "cmpd_uid",
        "inchikey",
    ]
    # Add optional columns if they exist
    optional_lib_cols = ["inchi", "db_id", "db"]
    for col in optional_lib_cols:
        if col in sample.lib_df.columns:
            lib_cols.append(col)

    # Only select columns that exist in lib_df
    available_lib_cols = [col for col in lib_cols if col in sample.lib_df.columns]

    result_df = result_df.join(
        sample.lib_df.select(available_lib_cols),
        on="lib_uid",
        how="left",
        suffix="_lib",
    )

    # Reorder columns for better readability
    column_order = [
        "feature_uid",
        "cmpd_uid" if "cmpd_uid" in result_df.columns else None,
        "lib_uid",
        "name" if "name" in result_df.columns else None,
        "shortname" if "shortname" in result_df.columns else None,
        "class" if "class" in result_df.columns else None,
        "formula" if "formula" in result_df.columns else None,
        "adduct" if "adduct" in result_df.columns else None,
        "mz" if "mz" in result_df.columns else None,
        "mz_delta",
        "rt" if "rt" in result_df.columns else None,
        "rt_delta",
        "matcher" if "matcher" in result_df.columns else None,
        "score" if "score" in result_df.columns else None,
        "id_source" if "id_source" in result_df.columns else None,
        "smiles" if "smiles" in result_df.columns else None,
        "inchikey" if "inchikey" in result_df.columns else None,
    ]

    # Add any remaining columns
    remaining_cols = [col for col in result_df.columns if col not in column_order]
    column_order.extend(remaining_cols)

    # Filter out None values and select existing columns
    final_column_order = [col for col in column_order if col is not None and col in result_df.columns]

    result_df = result_df.select(final_column_order)

    return result_df


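# Editor's note: get_id returns the fully annotated table and can be
# restricted to a subset of features (illustrative sketch on a hypothetical
# Sample `smp` after identify()):
#
#     ids = get_id(smp)                      # all identified features
#     ids = get_id(smp, features=[1, 2, 3])  # only these feature_uids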
def id_reset(sample):
    """Reset identification data and remove from history.

    Removes:
    - sample.id_df (identification results DataFrame)
    - Resets id_top_* columns in features_df to null
    - 'identify' from sample.history

    Args:
        sample: Sample instance to reset
    """
    # Get logger from sample if available
    logger = getattr(sample, "logger", None)

    # Remove id_df
    if hasattr(sample, "id_df"):
        if logger:
            logger.debug("Removing id_df")
        delattr(sample, "id_df")

    # Reset id_top_* columns in features_df
    if hasattr(sample, "features_df") and sample.features_df is not None and not sample.features_df.is_empty():
        if logger:
            logger.debug("Resetting id_top_* columns in features_df")

        # Check which columns exist before trying to update them
        id_columns_to_reset = []
        for col in ["id_top_name", "id_top_class", "id_top_adduct", "id_top_score"]:
            if col in sample.features_df.columns:
                if col == "id_top_score":
                    id_columns_to_reset.append(pl.lit(None, dtype=pl.Float64).alias(col))
                else:
                    id_columns_to_reset.append(pl.lit(None, dtype=pl.String).alias(col))

        if id_columns_to_reset:
            sample.features_df = sample.features_df.with_columns(id_columns_to_reset)

    # Remove identify from history
    if hasattr(sample, "history") and "identify" in sample.history:
        if logger:
            logger.debug("Removing 'identify' from history")
        del sample.history["identify"]

    if logger:
        logger.info("Identification data reset completed")


def lib_reset(sample):
    """Reset library and identification data and remove from history.

    Removes:
    - sample.id_df (identification results DataFrame)
    - sample.lib_df (library DataFrame)
    - sample._lib (library object reference)
    - Resets id_top_* columns in features_df to null
    - 'identify' from sample.history
    - 'lib_load' from sample.history (if exists)

    Args:
        sample: Sample instance to reset
    """
    # Get logger from sample if available
    logger = getattr(sample, "logger", None)

    # Remove id_df
    if hasattr(sample, "id_df"):
        if logger:
            logger.debug("Removing id_df")
        delattr(sample, "id_df")

    # Remove lib_df
    if hasattr(sample, "lib_df"):
        if logger:
            logger.debug("Removing lib_df")
        delattr(sample, "lib_df")

    # Remove lib object reference
    if hasattr(sample, "_lib"):
        if logger:
            logger.debug("Removing _lib reference")
        delattr(sample, "_lib")

    # Reset id_top_* columns in features_df
    if hasattr(sample, "features_df") and sample.features_df is not None and not sample.features_df.is_empty():
        if logger:
            logger.debug("Resetting id_top_* columns in features_df")

        # Check which columns exist before trying to update them
        id_columns_to_reset = []
        for col in ["id_top_name", "id_top_class", "id_top_adduct", "id_top_score"]:
            if col in sample.features_df.columns:
                if col == "id_top_score":
                    id_columns_to_reset.append(pl.lit(None, dtype=pl.Float64).alias(col))
                else:
                    id_columns_to_reset.append(pl.lit(None, dtype=pl.String).alias(col))

        if id_columns_to_reset:
            sample.features_df = sample.features_df.with_columns(id_columns_to_reset)

    # Remove from history
    if hasattr(sample, "history"):
        if "identify" in sample.history:
            if logger:
                logger.debug("Removing 'identify' from history")
            del sample.history["identify"]

        if "lib_load" in sample.history:
            if logger:
                logger.debug("Removing 'lib_load' from history")
            del sample.history["lib_load"]

    if logger:
        logger.info("Library and identification data reset completed")


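# Editor's note: id_reset clears only identification results, while lib_reset
# also drops the loaded library (illustrative sketch):
#
#     id_reset(smp)   # keeps smp.lib_df; identify() can be re-run directly
#     lib_reset(smp)  # drops smp.lib_df too; lib_load() is required again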
# Helper functions (private)


def _setup_identify_parameters(params, kwargs):
    """Set up identification parameters with fallbacks and overrides."""
    # Import defaults class
    try:
        from masster.sample.defaults.identify_def import identify_defaults
    except ImportError:
        identify_defaults = None

    # Use provided params or create defaults
    if params is None:
        if identify_defaults is not None:
            params = identify_defaults()
        else:
            # Fallback if imports fail
            class FallbackParams:
                mz_tol = 0.01
                rt_tol = 2.0
                heteroatom_penalty = 0.7
                multiple_formulas_penalty = 0.8
                multiple_compounds_penalty = 0.8
                heteroatoms = ["Cl", "Br", "F", "I"]

            params = FallbackParams()

    # Override parameters with any provided kwargs
    if kwargs:
        # Handle parameter name mapping for backwards compatibility
        param_mapping = {"rt_tolerance": "rt_tol", "mz_tolerance": "mz_tol"}

        for param_name, value in kwargs.items():
            # Check if we need to map the parameter name
            mapped_name = param_mapping.get(param_name, param_name)

            if hasattr(params, mapped_name):
                setattr(params, mapped_name, value)
            elif hasattr(params, param_name):
                setattr(params, param_name, value)

    return params


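# Editor's note: the kwargs remapping in _setup_identify_parameters means the
# legacy names resolve onto the new attributes (illustrative sketch):
#
#     params = _setup_identify_parameters(None, {"rt_tolerance": 1.5})
#     assert params.rt_tol == 1.5  # "rt_tolerance" was mapped to "rt_tol"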
def _smart_reset_id_results(sample, target_uids, logger):
    """Smart reset of identification results - only clear what's being re-identified."""
    if target_uids is not None:
        # Selective reset: only clear results for features being re-identified
        if hasattr(sample, "id_df") and sample.id_df is not None and not sample.id_df.is_empty():
            sample.id_df = sample.id_df.filter(~pl.col("feature_uid").is_in(target_uids))
            if logger:
                logger.debug(f"Cleared previous results for {len(target_uids)} specific features")
        elif not hasattr(sample, "id_df"):
            sample.id_df = pl.DataFrame()
    else:
        # Full reset: clear all results
        sample.id_df = pl.DataFrame()
        if logger:
            logger.debug("Cleared all previous identification results")


def _get_cached_adduct_probabilities(sample, logger):
    """Get adduct probabilities with caching to avoid repeated expensive computation."""
    # Check if we have cached results and cache key matches current parameters
    current_cache_key = _get_adduct_cache_key(sample)

    if (
        hasattr(sample, "_cached_adduct_probs")
        and hasattr(sample, "_cached_adduct_key")
        and sample._cached_adduct_key == current_cache_key
    ):
        if logger:
            logger.debug("Using cached adduct probabilities")
        return sample._cached_adduct_probs

    # Compute and cache
    if logger:
        logger.debug("Computing adduct probabilities...")
    adduct_prob_map = _get_adduct_probabilities(sample)
    sample._cached_adduct_probs = adduct_prob_map
    sample._cached_adduct_key = current_cache_key

    if logger:
        logger.debug(f"Computed and cached probabilities for {len(adduct_prob_map)} adducts")
    return adduct_prob_map


def _get_adduct_cache_key(sample):
    """Generate a cache key based on adduct-related parameters."""
    if hasattr(sample, "parameters") and hasattr(sample.parameters, "adducts"):
        adducts_str = "|".join(sorted(sample.parameters.adducts)) if sample.parameters.adducts else ""
        min_prob = getattr(sample.parameters, "adduct_min_probability", 0.04)
        return f"adducts:{adducts_str}:min_prob:{min_prob}"
    return "default"


def clear_identification_cache(sample):
    """Clear cached identification data (useful when parameters change)."""
    cache_attrs = ["_cached_adduct_probs", "_cached_adduct_key"]
    for attr in cache_attrs:
        if hasattr(sample, attr):
            delattr(sample, attr)


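# Editor's note: the cache key ties cached adduct probabilities to the current
# adduct parameters, so changing those parameters invalidates the cache
# implicitly; clear_identification_cache forces recomputation. Illustrative
# sketch, assuming smp.parameters.adducts == ["[M+H]+", "[M+Na]+"]:
#
#     _get_adduct_cache_key(smp)      # -> "adducts:[M+H]+|[M+Na]+:min_prob:0.04"
#     clear_identification_cache(smp) # next identify() recomputes probabilities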
def _perform_identification_matching(
    features_to_process, sample, effective_mz_tol, effective_rt_tol, adduct_prob_map, logger
):
    """Perform optimized identification matching using vectorized operations where possible."""
    results = []

    # Get library data as arrays for faster access
    lib_df = sample.lib_df

    if logger:
        features_count = len(features_to_process)
        lib_count = len(lib_df)
        logger.debug(
            f"Identifying {features_count} features against {lib_count} library entries",
        )

    # Process each feature
    for feat_row in features_to_process.iter_rows(named=True):
        feat_uid = feat_row.get("feature_uid")
        feat_mz = feat_row.get("mz")
        feat_rt = feat_row.get("rt")

        if feat_mz is None:
            if logger:
                logger.debug(f"Skipping feature {feat_uid} - no m/z value")
            results.append({"feature_uid": feat_uid, "matches": []})
            continue

        # Find matches using vectorized filtering
        matches = _find_matches_vectorized(
            lib_df, feat_mz, feat_rt, effective_mz_tol, effective_rt_tol, logger, feat_uid
        )

        # Convert matches to result format
        match_results = []
        if not matches.is_empty():
            for match_row in matches.iter_rows(named=True):
                mz_delta = abs(feat_mz - match_row.get("mz")) if match_row.get("mz") is not None else None
                lib_rt = match_row.get("rt")
                rt_delta = abs(feat_rt - lib_rt) if (feat_rt is not None and lib_rt is not None) else None

                # Get library probability as base score, then multiply by adduct probability
                lib_probability = match_row.get("probability", 1.0) if match_row.get("probability") is not None else 1.0
                adduct = match_row.get("adduct")
                adduct_probability = adduct_prob_map.get(adduct, 1.0) if adduct else 1.0
                score = lib_probability * adduct_probability
                # Scale to 0-100 and round to 1 decimal place
                score = round(score * 100.0, 1)

                match_results.append({
                    "lib_uid": match_row.get("lib_uid"),
                    "mz_delta": mz_delta,
                    "rt_delta": rt_delta,
                    "matcher": "ms1",
                    "score": score,
                })

        results.append({"feature_uid": feat_uid, "matches": match_results})

    return results


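# Editor's note: a worked example of the score computed in
# _perform_identification_matching above. With a library probability of 0.9
# and an adduct probability of 0.35:
#
#     score = round(0.9 * 0.35 * 100.0, 1)  # -> 31.5 on the 0-100 scale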
def _find_matches_vectorized(lib_df, feat_mz, feat_rt, mz_tol, rt_tol, logger, feat_uid):
    """Find library matches using optimized vectorized operations."""
    # Filter by m/z tolerance using vectorized operations
    matches = lib_df.filter((pl.col("mz") >= feat_mz - mz_tol) & (pl.col("mz") <= feat_mz + mz_tol))

    initial_match_count = len(matches)

    # Apply RT filter if available
    if rt_tol is not None and feat_rt is not None and not matches.is_empty():
        # First, check if any m/z matches have RT data
        rt_candidates = matches.filter(pl.col("rt").is_not_null())

        if not rt_candidates.is_empty():
            # Apply RT filtering to candidates with RT data
            rt_matches = rt_candidates.filter((pl.col("rt") >= feat_rt - rt_tol) & (pl.col("rt") <= feat_rt + rt_tol))

            if not rt_matches.is_empty():
                matches = rt_matches
                if logger:
                    logger.debug(
                        f"Feature {feat_uid}: {initial_match_count} m/z matches, {len(rt_candidates)} with RT, {len(matches)} after RT filter"
                    )
            else:
                # NO FALLBACK - if RT filtering finds no matches, return empty
                matches = rt_matches  # This is empty
                if logger:
                    logger.debug(
                        f"Feature {feat_uid}: RT filtering eliminated all {len(rt_candidates)} candidates (rt_tol={rt_tol}s) - no matches returned"
                    )
        else:
            # No RT data in library matches - fall back to m/z-only matching
            if logger:
                logger.debug(
                    f"Feature {feat_uid}: {initial_match_count} m/z matches but none have library RT data - using m/z-only matching"
                )
            # Keep the m/z matches (don't return empty DataFrame)

    # Add stricter m/z validation - prioritize more accurate matches
    if not matches.is_empty():
        strict_mz_tol = mz_tol * 0.5  # Use 50% of tolerance as strict threshold
        strict_matches = matches.filter(
            (pl.col("mz") >= feat_mz - strict_mz_tol) & (pl.col("mz") <= feat_mz + strict_mz_tol)
        )

        if not strict_matches.is_empty():
            # Use strict matches if available
            matches = strict_matches
            if logger:
                logger.debug(
                    f"Feature {feat_uid}: Using {len(matches)} strict m/z matches (within {strict_mz_tol:.6f} Da)"
                )
        else:
            if logger:
                logger.debug(f"Feature {feat_uid}: No strict matches, using {len(matches)} loose matches")

    # Improved deduplication - prioritize by m/z accuracy
    if not matches.is_empty() and len(matches) > 1:
        if "formula" in matches.columns and "adduct" in matches.columns:
            pre_dedup_count = len(matches)

            # Calculate m/z error for sorting
            matches = matches.with_columns([(pl.col("mz") - feat_mz).abs().alias("mz_error_abs")])

            # Group by formula and adduct, but keep the most accurate m/z match
            matches = (
                matches.sort(["mz_error_abs", "lib_uid"])  # Sort by m/z accuracy first, then lib_uid for consistency
                .group_by(["formula", "adduct"], maintain_order=True)
                .first()
                .drop("mz_error_abs")  # Remove the temporary column
            )

            post_dedup_count = len(matches)
            if logger and post_dedup_count < pre_dedup_count:
                logger.debug(
                    f"Feature {feat_uid}: deduplicated {pre_dedup_count} to {post_dedup_count} matches (m/z accuracy prioritized)"
                )

    return matches


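# Editor's note: a worked example of the two-stage m/z window in
# _find_matches_vectorized above. With mz_tol = 0.01 Da and feat_mz = 300.0,
# candidates are first taken from [299.99, 300.01]; if any fall inside the
# strict window [299.995, 300.005] (0.5 * mz_tol), only those are kept.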
def _update_identification_results(sample, results, logger):
    """Update sample.id_df with new identification results."""
    # Flatten results into records
    records = []
    for result in results:
        feature_uid = result["feature_uid"]
        for match in result["matches"]:
            records.append({
                "feature_uid": feature_uid,
                "lib_uid": match["lib_uid"],
                "mz_delta": match["mz_delta"],
                "rt_delta": match["rt_delta"],
                "matcher": match["matcher"],
                "score": match["score"],
                "iso": 0,  # Default to zero
            })

    # Convert to DataFrame and append to existing results
    new_results_df = pl.DataFrame(records) if records else pl.DataFrame()

    if not new_results_df.is_empty():
        # Join with lib_df to get source_id and rename to id_source
        if hasattr(sample, "lib_df") and sample.lib_df is not None and not sample.lib_df.is_empty():
            if "source_id" in sample.lib_df.columns:
                new_results_df = new_results_df.join(
                    sample.lib_df.select(["lib_uid", "source_id"]),
                    on="lib_uid",
                    how="left"
                ).rename({"source_id": "id_source"})
                if logger:
                    logger.debug("Added 'id_source' column from library source_id")
            else:
                # Add id_source column with None if source_id not available
                new_results_df = new_results_df.with_columns(pl.lit(None, dtype=pl.String).alias("id_source"))
                if logger:
                    logger.debug("Library has no source_id; id_source set to None")
        else:
            # Add id_source column with None if lib_df not available
            new_results_df = new_results_df.with_columns(pl.lit(None, dtype=pl.String).alias("id_source"))

        if hasattr(sample, "id_df") and sample.id_df is not None and not sample.id_df.is_empty():
            # Check if existing id_df has the iso column
            if "iso" not in sample.id_df.columns:
                # Add iso column to existing id_df with default value 0
                sample.id_df = sample.id_df.with_columns(pl.lit(0).alias("iso"))
                if logger:
                    logger.debug("Added 'iso' column to existing id_df for schema compatibility")

            # Check if existing id_df has the id_source column
            if "id_source" not in sample.id_df.columns:
                # Add id_source column to existing id_df with None
                sample.id_df = sample.id_df.with_columns(pl.lit(None, dtype=pl.String).alias("id_source"))
                if logger:
                    logger.debug("Added 'id_source' column to existing id_df for schema compatibility")

            sample.id_df = pl.concat([sample.id_df, new_results_df])
        else:
            sample.id_df = new_results_df

        if logger:
            logger.debug(f"Added {len(records)} identification results to sample.id_df")
    elif not hasattr(sample, "id_df"):
        sample.id_df = pl.DataFrame()


def _finalize_identification_results(sample, params, logger):
    """Apply final scoring adjustments and update features columns."""
    # Apply scoring adjustments based on compound and formula counts
    _apply_scoring_adjustments(sample, params)

    # Update features_df with top-scoring identification results
    _update_features_id_columns(sample, logger)


def _store_identification_history(sample, effective_mz_tol, effective_rt_tol, target_uids, params, kwargs):
    """Store identification operation in sample history."""
    if hasattr(sample, "store_history"):
        history_params = {"mz_tol": effective_mz_tol, "rt_tol": effective_rt_tol}
        if target_uids is not None:
            history_params["features"] = target_uids
        if params is not None and hasattr(params, "to_dict"):
            history_params["params"] = params.to_dict()
        if kwargs:
            history_params["kwargs"] = kwargs
        sample.store_history(["identify"], history_params)


def _validate_identify_inputs(sample, logger=None):
    """Validate inputs for identification process."""
    if getattr(sample, "features_df", None) is None or sample.features_df.is_empty():
        if logger:
            logger.warning("No features found for identification")
        return False

    if getattr(sample, "lib_df", None) is None or sample.lib_df.is_empty():
        if logger:
            logger.error("Library (sample.lib_df) is empty; call lib_load() first")
        raise ValueError("Library (sample.lib_df) is empty; call lib_load() first")

    return True


def _prepare_features(sample, features, logger=None):
    """Prepare features for identification."""
    target_uids = None
    if features is not None:
        if hasattr(features, "columns"):  # DataFrame-like
            if "feature_uid" in features.columns:
                target_uids = features["feature_uid"].unique().to_list()
            else:
                raise ValueError(
                    "features DataFrame must contain 'feature_uid' column",
                )
        elif hasattr(features, "__iter__") and not isinstance(
            features,
            str,
        ):  # List-like
            target_uids = list(features)
        else:
            raise ValueError(
                "features must be a DataFrame with 'feature_uid' column or a list of UIDs",
            )

        if logger:
            logger.debug(f"Identifying {len(target_uids)} specified features")

    # Filter features if target_uids specified
    features_to_process = sample.features_df
    if target_uids is not None:
        features_to_process = sample.features_df.filter(
            pl.col("feature_uid").is_in(target_uids),
        )
        if features_to_process.is_empty():
            if logger:
                logger.warning(
                    "No features found matching specified features",
                )
            return None, target_uids

    return features_to_process, target_uids


def _get_adduct_probabilities(sample):
    """Get adduct probabilities from _get_adducts() results."""
    from masster.sample.adducts import _get_adducts

    adducts_df = _get_adducts(sample)
    adduct_prob_map = {}
    if not adducts_df.is_empty():
        for row in adducts_df.iter_rows(named=True):
            adduct_prob_map[row.get("name")] = row.get("probability", 1.0)
    return adduct_prob_map


def _apply_scoring_adjustments(sample, params):
    """Apply scoring adjustments based on compound and formula counts using optimized operations."""
    if not sample.id_df.is_empty() and hasattr(sample, "lib_df") and not sample.lib_df.is_empty():
        # Get penalty parameters
        heteroatoms = getattr(params, "heteroatoms", ["Cl", "Br", "F", "I"])
        heteroatom_penalty = getattr(params, "heteroatom_penalty", 0.7)
        formulas_penalty = getattr(params, "multiple_formulas_penalty", 0.8)
        compounds_penalty = getattr(params, "multiple_compounds_penalty", 0.8)

        # Single join to get all needed library information
        lib_columns = ["lib_uid", "cmpd_uid", "formula"]
        id_with_lib = sample.id_df.join(
            sample.lib_df.select(lib_columns),
            on="lib_uid",
            how="left",
        )

        # Calculate all statistics in one group_by operation
        stats = id_with_lib.group_by("feature_uid").agg([
            pl.col("cmpd_uid").n_unique().alias("num_cmpds"),
            pl.col("formula").filter(pl.col("formula").is_not_null()).n_unique().alias("num_formulas"),
        ])

        # Join stats back and apply all penalties in one with_columns operation
        heteroatom_conditions = [pl.col("formula").str.contains(atom) for atom in heteroatoms]
        has_heteroatoms = (
            pl.fold(acc=pl.lit(False), function=lambda acc, x: acc | x, exprs=heteroatom_conditions)
            if heteroatom_conditions
            else pl.lit(False)
        )

        # Keep optional result columns (iso, id_source) added upstream so the
        # final select does not silently drop them
        keep_cols = ["feature_uid", "lib_uid", "mz_delta", "rt_delta", "matcher", "score"]
        keep_cols += [c for c in ("iso", "id_source") if c in id_with_lib.columns]

        sample.id_df = (
            id_with_lib.join(stats, on="feature_uid", how="left")
            .with_columns([
                # Apply all penalties in sequence using case-when chains
                pl.when(pl.col("formula").is_not_null() & has_heteroatoms)
                .then(pl.col("score") * heteroatom_penalty)
                .otherwise(pl.col("score"))
                .alias("score_temp1")
            ])
            .with_columns([
                pl.when(pl.col("num_formulas") > 1)
                .then(pl.col("score_temp1") * formulas_penalty)
                .otherwise(pl.col("score_temp1"))
                .alias("score_temp2")
            ])
            .with_columns([
                pl.when(pl.col("num_cmpds") > 1)
                .then(pl.col("score_temp2") * compounds_penalty)
                .otherwise(pl.col("score_temp2"))
                .round(4)
                .alias("score")
            ])
            .select(keep_cols)
        )


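# Editor's note: worked penalty arithmetic for _apply_scoring_adjustments
# above, using the defaults (heteroatom 0.7, multiple formulas 0.8, multiple
# compounds 0.8). A match scored 90.0 whose formula contains Cl, where the
# feature has 2 candidate formulas and 2 candidate compounds:
#
#     90.0 * 0.7 * 0.8 * 0.8 = 40.32  # stored after round(..., 4)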
def _update_features_id_columns(sample, logger=None):
    """
    Update features_df with top-scoring identification results using safe in-place updates.
    """
    try:
        if not hasattr(sample, "id_df") or sample.id_df is None or sample.id_df.is_empty():
            if logger:
                logger.debug("No identification results to process")
            return

        if not hasattr(sample, "lib_df") or sample.lib_df is None or sample.lib_df.is_empty():
            if logger:
                logger.debug("No library data available")
            return

        if not hasattr(sample, "features_df") or sample.features_df is None or sample.features_df.is_empty():
            if logger:
                logger.debug("No features data available")
            return

        # Get library columns we need
        lib_columns = ["lib_uid", "name", "adduct"]
        if "class" in sample.lib_df.columns:
            lib_columns.append("class")

        # Get top-scoring identification for each feature
        top_ids = (
            sample.id_df.sort(["feature_uid", "score"], descending=[False, True])
            .group_by("feature_uid", maintain_order=True)
            .first()
            .join(sample.lib_df.select(lib_columns), on="lib_uid", how="left")
            .select([
                "feature_uid",
                "name",
                pl.col("class").alias("id_top_class")
                if "class" in lib_columns
                else pl.lit(None, dtype=pl.String).alias("id_top_class"),
                pl.col("adduct").alias("id_top_adduct"),
                pl.col("score").alias("id_top_score"),
            ])
            .rename({"name": "id_top_name"})
        )

        # Ensure we have the id_top columns in features_df
        for col_name, dtype in [
            ("id_top_name", pl.String),
            ("id_top_class", pl.String),
            ("id_top_adduct", pl.String),
            ("id_top_score", pl.Float64),
        ]:
            if col_name not in sample.features_df.columns:
                sample.features_df = sample.features_df.with_columns(pl.lit(None, dtype=dtype).alias(col_name))

        # Create a mapping dictionary for efficient updates
        id_mapping = {}
        for row in top_ids.iter_rows(named=True):
            feature_uid = row["feature_uid"]
            id_mapping[feature_uid] = {
                "id_top_name": row["id_top_name"],
                "id_top_class": row["id_top_class"],
                "id_top_adduct": row["id_top_adduct"],
                "id_top_score": row["id_top_score"],
            }

        # Update features_df using map_elements (safer than join for avoiding duplicates)
        if id_mapping:
            sample.features_df = sample.features_df.with_columns([
                pl.col("feature_uid")
                .map_elements(lambda uid: id_mapping.get(uid, {}).get("id_top_name"), return_dtype=pl.String)
                .alias("id_top_name"),
                pl.col("feature_uid")
                .map_elements(lambda uid: id_mapping.get(uid, {}).get("id_top_class"), return_dtype=pl.String)
                .alias("id_top_class"),
                pl.col("feature_uid")
                .map_elements(lambda uid: id_mapping.get(uid, {}).get("id_top_adduct"), return_dtype=pl.String)
                .alias("id_top_adduct"),
                pl.col("feature_uid")
                .map_elements(lambda uid: id_mapping.get(uid, {}).get("id_top_score"), return_dtype=pl.Float64)
                .alias("id_top_score"),
            ])

            if logger:
                num_updated = len(id_mapping)
                logger.debug(f"Updated features_df with top identifications for {num_updated} features")

    except Exception as e:
        if logger:
            logger.error(f"Error updating features_df with identification results: {e}")
        # Don't re-raise to avoid breaking the identification process


def _generate_13c_isotopes(lib_df):
    """
    Generate 13C isotope variants for library entries.

    For each compound with n carbon atoms, creates n+1 entries:
    - iso=0: original compound (no 13C)
    - iso=1: one 13C isotope (+1.00335 Da)
    - iso=2: two 13C isotopes (+2.00670 Da)
    - ...
    - iso=n: n 13C isotopes (+n*1.00335 Da)

    All isotopomers share the same quant_group.

    Args:
        lib_df: Polars DataFrame with library entries

    Returns:
        Polars DataFrame with additional 13C isotope entries
    """
    if lib_df.is_empty():
        return lib_df

    # First, ensure all original entries have iso=0
    original_df = lib_df.with_columns(pl.lit(0).alias("iso"))

    isotope_entries = []
    next_lib_uid = lib_df["lib_uid"].max() + 1 if len(lib_df) > 0 else 1

    # Mass difference for one 13C isotope
    c13_mass_shift = 1.00335  # Mass difference between 13C and 12C

    for row in original_df.iter_rows(named=True):
        formula = row.get("formula", "")
        if not formula:
            continue

        # Count carbon atoms in the formula
        carbon_count = _count_carbon_atoms(formula)
        if carbon_count == 0:
            continue

        # Get the original quant_group to keep it consistent across isotopes
        quant_group = row.get("quant_group", row.get("cmpd_uid", row.get("lib_uid", 1)))

        # Generate isotope variants (1 to n 13C atoms)
        for iso_num in range(1, carbon_count + 1):
            # Calculate mass shift for this number of 13C isotopes
            mass_shift = iso_num * c13_mass_shift

            # Create new entry
            isotope_entry = dict(row)  # Copy all fields
            isotope_entry["lib_uid"] = next_lib_uid
            isotope_entry["iso"] = iso_num
            isotope_entry["m"] = row["m"] + mass_shift
            isotope_entry["mz"] = (row["m"] + mass_shift) / abs(row["z"]) if row["z"] != 0 else row["m"] + mass_shift
            isotope_entry["quant_group"] = quant_group  # Keep same quant_group

            isotope_entries.append(isotope_entry)
            next_lib_uid += 1

    # Combine original entries (now with iso=0) with isotope entries
    if isotope_entries:
        isotope_df = pl.DataFrame(isotope_entries)
        # Ensure schema compatibility by aligning data types
        try:
            return pl.concat([original_df, isotope_df])
        except Exception:
            # If concat fails due to schema mismatch, convert to compatible types
            # Get common schema
            original_schema = original_df.schema
            isotope_schema = isotope_df.schema

            # Cast isotope_df columns to match original_df schema where possible
            cast_exprs = []
            for col_name in isotope_df.columns:
                if col_name in original_schema:
                    target_dtype = original_schema[col_name]
                    cast_exprs.append(pl.col(col_name).cast(target_dtype, strict=False))
                else:
                    cast_exprs.append(pl.col(col_name))

            isotope_df_cast = isotope_df.select(cast_exprs)
            return pl.concat([original_df, isotope_df_cast])
    else:
        return original_df


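# Editor's note: a worked example for _generate_13c_isotopes above. Glucose
# (C6H12O6, 6 carbons) yields 7 library entries per adduct, iso = 0..6, with
# mass shifts of iso * 1.00335 Da (e.g. iso=2 -> +2.00670 Da), all sharing
# one quant_group.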
def _count_carbon_atoms(formula: str) -> int:
    """
    Count the number of carbon atoms in a molecular formula.

    Args:
        formula: Molecular formula string like "C6H12O6"

    Returns:
        Number of carbon atoms
    """
    import re

    if not formula or not isinstance(formula, str):
        return 0

    # Look for carbon followed by an optional count: C with digits, or a bare
    # C (which means 1 carbon). The negative lookahead skips two-letter
    # elements such as Cl, Ca, Co or Cu, whose capital C is not a carbon atom.
    carbon_matches = re.findall(r"C(?![a-z])(\d*)", formula)

    total_carbons = 0
    for match in carbon_matches:
        if match == "":
            # Just 'C' without number means 1 carbon
            total_carbons += 1
        else:
            # 'C' followed by number
            total_carbons += int(match)

    return total_carbons
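

# Editor's note: quick behaviour check for _count_carbon_atoms (illustrative):
#
#     _count_carbon_atoms("C6H12O6")  # -> 6
#     _count_carbon_atoms("CCl4")     # -> 1; the C in "Cl" is chlorine, not carbon
#     _count_carbon_atoms("NaCl")     # -> 0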