masster 0.3.13__py3-none-any.whl → 0.3.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of masster might be problematic.

masster/study/helpers.py CHANGED
@@ -1,3 +1,18 @@
+"""
+helpers.py
+
+This module contains helper functions for the Study class that handle various operations
+like data retrieval, filtering, compression, and utility functions.
+
+The functions are organized into the following sections:
+1. Chromatogram extraction functions (BPC, TIC, EIC, chrom matrix)
+2. Data retrieval helper functions (get_sample, get_consensus, etc.)
+3. UID helper functions (_get_*_uids)
+4. Data filtering and selection functions
+5. Data compression and restoration functions
+6. Utility functions (reset, naming, colors, schema ordering)
+"""
+
 from __future__ import annotations
 
 import os
@@ -10,6 +25,11 @@ from tqdm import tqdm
 from masster.chromatogram import Chromatogram
 
 
+# =====================================================================================
+# CHROMATOGRAM EXTRACTION FUNCTIONS
+# =====================================================================================
+
+
 def get_bpc(owner, sample=None, rt_unit="s", label=None, original=False):
     """
     Return a Chromatogram object containing the Base Peak Chromatogram (BPC).
@@ -96,7 +116,6 @@ def get_bpc(owner, sample=None, rt_unit="s", label=None, original=False):
     if (mapping_rows is None or mapping_rows.is_empty()) and hasattr(s, "sample_path"):
         # attempt to match by sample_path or file name
         try:
-            sample_paths = feats.select(["sample_uid", "sample_name", "sample_path"])  # type: ignore[arg-type]
             # find row where sample_path matches
             mapping_rows = feats.filter(pl.col("sample_path") == getattr(s, "file", None))
         except Exception:
@@ -192,7 +211,7 @@ def get_tic(owner, sample=None, label=None):
     return chrom
 
 
-def get_eic(owner, sample=None, mz=None, mz_tol=0.01, rt_unit="s", label=None):
+def get_eic(owner, sample=None, mz=None, mz_tol=None, rt_unit="s", label=None):
     """
     Return a Chromatogram object containing the Extracted Ion Chromatogram (EIC) for a target m/z.
 
@@ -206,13 +225,20 @@ def get_eic(owner, sample=None, mz=None, mz_tol=0.01, rt_unit="s", label=None):
         owner: Study or Sample instance
         sample: Sample identifier (required if owner is Study)
         mz (float): Target m/z value
-        mz_tol (float): m/z tolerance (default 0.01)
+        mz_tol (float): m/z tolerance. If None, uses owner.parameters.eic_mz_tol (for Study) or defaults to 0.01
         rt_unit (str): Retention time unit for the chromatogram
         label (str): Optional label for the chromatogram
 
     Returns:
         Chromatogram
     """
+    # Use default mz_tol from study parameters if not provided
+    if mz_tol is None:
+        if hasattr(owner, 'parameters') and hasattr(owner.parameters, 'eic_mz_tol'):
+            mz_tol = owner.parameters.eic_mz_tol
+        else:
+            mz_tol = 0.01  # fallback default
+
     if mz is None:
         raise ValueError("mz must be provided for EIC computation")
 
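A minimal usage sketch of the new tolerance fallback (the Study instance `study` and the sample name are hypothetical; `parameters.eic_mz_tol` is the attribute the hunk above reads):

    # mz_tol omitted: get_eic() now falls back to study.parameters.eic_mz_tol, else 0.01
    eic = get_eic(study, sample="sample_01", mz=301.1412)
    # an explicit tolerance still takes precedence over the default
    eic = get_eic(study, sample="sample_01", mz=301.1412, mz_tol=0.005)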
@@ -290,6 +316,9 @@ def get_eic(owner, sample=None, mz=None, mz_tol=0.01, rt_unit="s", label=None):
     return chrom
 
 
+# =====================================================================================
+# DATA RETRIEVAL AND MATRIX FUNCTIONS
+# =====================================================================================
 
 
 def get_chrom(self, uids=None, samples=None):
@@ -393,10 +422,14 @@ def get_chrom(self, uids=None, samples=None):
     # Create Polars DataFrame with complex objects
     df2_pivoted = pl.DataFrame(pivot_data)
 
-    # Return as Polars DataFrame (can handle complex objects like Chromatogram)
     return df2_pivoted
 
 
+# =====================================================================================
+# UTILITY AND CONFIGURATION FUNCTIONS
+# =====================================================================================
+
+
 def set_folder(self, folder):
     """
     Set the folder for saving and loading files.
@@ -424,6 +457,12 @@ def align_reset(self):
     # Ensure column order is maintained after with_columns operation
     self._ensure_features_df_schema_order()
 
+
+# =====================================================================================
+# DATA RETRIEVAL HELPER FUNCTIONS
+# =====================================================================================
+
+
 # TODO I don't get this param
 def get_consensus(self, quant="chrom_area"):
     if self.consensus_df is None:
@@ -555,6 +594,11 @@ def get_consensus_matches(self, uids=None):
     return matches
 
 
+# =====================================================================================
+# UID HELPER FUNCTIONS
+# =====================================================================================
+
+
 def fill_reset(self):
     # remove all features with filled=True
     if self.features_df is None:
@@ -757,6 +801,11 @@ def get_orphans(self):
     return not_in_consensus
 
 
+# =====================================================================================
+# DATA COMPRESSION AND RESTORATION FUNCTIONS
+# =====================================================================================
+
+
 def compress(self, features=True, ms2=True, chrom=False, ms2_max=5):
     """
     Perform compress_features, compress_ms2, and compress_chrom operations.
@@ -1251,7 +1300,12 @@ def compress_chrom(self):
     self.logger.info(f"Compressed chromatograms: cleared {non_null_count} chromatogram objects from features_df")
 
 
-def name_replace(self, replace_dict):
+# =====================================================================================
+# SAMPLE MANAGEMENT AND NAMING FUNCTIONS
+# =====================================================================================
+
+
+def sample_name_replace(self, replace_dict):
     """
     Replace sample names in samples_df based on a dictionary mapping.
 
@@ -1317,7 +1371,7 @@ def name_replace(self, replace_dict):
     self.logger.info(f"Successfully replaced {replaced_count} sample names")
 
 
-def name_reset(self):
+def sample_name_reset(self):
     """
     Reset sample names to the basename of sample_path without extensions.
 
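A minimal sketch of the renamed helpers (the replacement mapping is hypothetical; the method names come from the hunks above):

    study.sample_name_replace({"20240101_blank_01": "blank_01"})  # formerly name_replace
    study.sample_name_reset()  # formerly name_reset; back to basename of sample_path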
@@ -1399,7 +1453,7 @@ def set_source(self, filename):
     failed_count = 0
 
     # Get all current file_source values
-    current_sources = self.samples_df.get_column("file_source").to_list()
+    current_sources = self.samples_df.get_column("sample_source").to_list()
     sample_names = self.samples_df.get_column("sample_name").to_list()
 
     new_sources = []
@@ -1447,6 +1501,11 @@ def set_source(self, filename):
         self.logger.warning(f"Failed to update file_source for {failed_count} samples")
 
 
+# =====================================================================================
+# DATA FILTERING AND SELECTION FUNCTIONS
+# =====================================================================================
+
+
 def features_select(
     self,
     mz=None,
@@ -1872,13 +1931,21 @@ def consensus_select(
     chrom_prominence_scaled_mean=None,
     chrom_height_scaled_mean=None,
     rt_delta_mean=None,
+    sortby=None,
+    descending=True,
 ):
     """
     Select consensus features from consensus_df based on specified criteria and return the filtered DataFrame.
 
     Parameters:
-        mz: m/z range filter (tuple for range, single value for minimum)
-        rt: retention time range filter (tuple for range, single value for minimum)
+        mz: m/z filter with flexible formats:
+            - float: m/z value ± default tolerance (uses study.parameters.eic_mz_tol)
+            - tuple (mz_min, mz_max): range where mz_max > mz_min
+            - tuple (mz_center, mz_tol): range where mz_tol < mz_center (interpreted as mz_center ± mz_tol)
+        rt: retention time filter with flexible formats:
+            - float: RT value ± default tolerance (uses study.parameters.eic_rt_tol)
+            - tuple (rt_min, rt_max): range where rt_max > rt_min
+            - tuple (rt_center, rt_tol): range where rt_tol < rt_center (interpreted as rt_center ± rt_tol)
         inty_mean: mean intensity filter (tuple for range, single value for minimum)
         consensus_uid: consensus UID filter (list, single value, or tuple for range)
         consensus_id: consensus ID filter (list or single value)
@@ -1891,6 +1958,8 @@ def consensus_select(
         chrom_prominence_scaled_mean: mean scaled chromatogram prominence filter (tuple for range, single value for minimum)
         chrom_height_scaled_mean: mean scaled chromatogram height filter (tuple for range, single value for minimum)
         rt_delta_mean: mean RT delta filter (tuple for range, single value for minimum)
+        sortby: column name(s) to sort by (string, list of strings, or None for no sorting)
+        descending: sort direction (True for descending, False for ascending, default is True)
 
     Returns:
         polars.DataFrame: Filtered consensus DataFrame
@@ -1905,11 +1974,32 @@ def consensus_select(
     # Filter by m/z
     if mz is not None:
         consensus_len_before_filter = len(consensus)
+
         if isinstance(mz, tuple) and len(mz) == 2:
-            min_mz, max_mz = mz
+            # Check if second value is smaller than first (indicating mz, mz_tol format)
+            if mz[1] < mz[0]:
+                # First is mz, second is mz_tol
+                mz_center, mz_tol = mz
+                min_mz = mz_center - mz_tol
+                max_mz = mz_center + mz_tol
+            else:
+                # Standard (min_mz, max_mz) format
+                min_mz, max_mz = mz
             consensus = consensus.filter((pl.col("mz") >= min_mz) & (pl.col("mz") <= max_mz))
         else:
-            consensus = consensus.filter(pl.col("mz") >= mz)
+            # Single float value - use default mz tolerance from study parameters
+            default_mz_tol = getattr(self, 'parameters', None)
+            if default_mz_tol and hasattr(default_mz_tol, 'eic_mz_tol'):
+                default_mz_tol = default_mz_tol.eic_mz_tol
+            else:
+                # Fallback to align_defaults if study parameters not available
+                from masster.study.defaults.align_def import align_defaults
+                default_mz_tol = align_defaults().mz_max_diff
+
+            min_mz = mz - default_mz_tol
+            max_mz = mz + default_mz_tol
+            consensus = consensus.filter((pl.col("mz") >= min_mz) & (pl.col("mz") <= max_mz))
+
         self.logger.debug(
             f"Selected consensus by mz. Consensus removed: {consensus_len_before_filter - len(consensus)}",
         )
@@ -1917,11 +2007,32 @@ def consensus_select(
     # Filter by retention time
     if rt is not None:
         consensus_len_before_filter = len(consensus)
+
         if isinstance(rt, tuple) and len(rt) == 2:
-            min_rt, max_rt = rt
+            # Check if second value is smaller than first (indicating rt, rt_tol format)
+            if rt[1] < rt[0]:
+                # First is rt, second is rt_tol
+                rt_center, rt_tol = rt
+                min_rt = rt_center - rt_tol
+                max_rt = rt_center + rt_tol
+            else:
+                # Standard (min_rt, max_rt) format
+                min_rt, max_rt = rt
             consensus = consensus.filter((pl.col("rt") >= min_rt) & (pl.col("rt") <= max_rt))
         else:
-            consensus = consensus.filter(pl.col("rt") >= rt)
+            # Single float value - use default rt tolerance from study parameters
+            default_rt_tol = getattr(self, 'parameters', None)
+            if default_rt_tol and hasattr(default_rt_tol, 'eic_rt_tol'):
+                default_rt_tol = default_rt_tol.eic_rt_tol
+            else:
+                # Fallback to align_defaults if study parameters not available
+                from masster.study.defaults.align_def import align_defaults
+                default_rt_tol = align_defaults().rt_max_diff
+
+            min_rt = rt - default_rt_tol
+            max_rt = rt + default_rt_tol
+            consensus = consensus.filter((pl.col("rt") >= min_rt) & (pl.col("rt") <= max_rt))
+
         self.logger.debug(
             f"Selected consensus by rt. Consensus removed: {consensus_len_before_filter - len(consensus)}",
         )
@@ -2118,6 +2229,27 @@ def consensus_select(
     else:
         self.logger.info(f"Selected consensus features. Features remaining: {len(consensus)} (from {initial_count})")
 
+    # Sort the results if sortby is specified
+    if sortby is not None:
+        if isinstance(sortby, str):
+            # Single column
+            if sortby in consensus.columns:
+                consensus = consensus.sort(sortby, descending=descending)
+            else:
+                self.logger.warning(f"Sort column '{sortby}' not found in consensus DataFrame")
+        elif isinstance(sortby, (list, tuple)):
+            # Multiple columns
+            valid_columns = [col for col in sortby if col in consensus.columns]
+            invalid_columns = [col for col in sortby if col not in consensus.columns]
+
+            if invalid_columns:
+                self.logger.warning(f"Sort columns not found in consensus DataFrame: {invalid_columns}")
+
+            if valid_columns:
+                consensus = consensus.sort(valid_columns, descending=descending)
+        else:
+            self.logger.warning(f"Invalid sortby parameter type: {type(sortby)}. Expected str, list, or tuple.")
+
     return consensus
 
 
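A usage sketch of the flexible mz/rt tuple handling and the new sorting options introduced above (all filter values hypothetical):

    # (center, tol): a second element smaller than the first is read as center ± tol
    hits = study.consensus_select(mz=(301.1412, 0.01), rt=(245.0, 10.0), sortby="inty_mean")
    # (min, max) ranges still work; sortby also accepts a list of columns
    hits = study.consensus_select(mz=(300.0, 302.0), sortby=["inty_mean", "rt"], descending=False)
    # a bare float now means value ± default tolerance rather than a lower bound
    hits = study.consensus_select(mz=301.1412)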
@@ -2222,3 +2354,832 @@ def consensus_delete(self, consensus):
         None (modifies self.consensus_df and related DataFrames in place)
     """
     self.consensus_filter(consensus)
+
+
+# =====================================================================================
+# SAMPLE MANAGEMENT AND DELETION FUNCTIONS
+# =====================================================================================
+
+
+def samples_select(
+    self,
+    sample_uid=None,
+    sample_name=None,
+    sample_type=None,
+    sample_group=None,
+    sample_batch=None,
+    sample_sequence=None,
+    num_features=None,
+    num_ms1=None,
+    num_ms2=None,
+):
+    """
+    Select samples from samples_df based on specified criteria and return the filtered DataFrame.
+
+    Parameters:
+        sample_uid: sample UID filter (list, single value, or tuple for range)
+        sample_name: sample name filter (list or single value)
+        sample_type: sample type filter (list or single value)
+        sample_group: sample group filter (list or single value)
+        sample_batch: sample batch filter (list, single value, or tuple for range)
+        sample_sequence: sample sequence filter (list, single value, or tuple for range)
+        num_features: number of features filter (tuple for range, single value for minimum)
+        num_ms1: number of MS1 spectra filter (tuple for range, single value for minimum)
+        num_ms2: number of MS2 spectra filter (tuple for range, single value for minimum)
+
+    Returns:
+        polars.DataFrame: Filtered samples DataFrame
+    """
+    if self.samples_df is None or self.samples_df.is_empty():
+        self.logger.warning("No samples found in study.")
+        return pl.DataFrame()
+
+    # Early return if no filters provided
+    filter_params = [
+        sample_uid,
+        sample_name,
+        sample_type,
+        sample_group,
+        sample_batch,
+        sample_sequence,
+        num_features,
+        num_ms1,
+        num_ms2,
+    ]
+    if all(param is None for param in filter_params):
+        return self.samples_df.clone()
+
+    initial_count = len(self.samples_df)
+
+    # Pre-check available columns once for efficiency
+    available_columns = set(self.samples_df.columns)
+
+    # Build all filter conditions first, then apply them all at once
+    filter_conditions = []
+    warnings = []
+
+    # Filter by sample_uid
+    if sample_uid is not None:
+        if isinstance(sample_uid, (list, tuple)):
+            if len(sample_uid) == 2 and not isinstance(sample_uid, list):
+                # Treat as range
+                min_uid, max_uid = sample_uid
+                filter_conditions.append((pl.col("sample_uid") >= min_uid) & (pl.col("sample_uid") <= max_uid))
+            else:
+                # Treat as list
+                filter_conditions.append(pl.col("sample_uid").is_in(sample_uid))
+        else:
+            filter_conditions.append(pl.col("sample_uid") == sample_uid)
+
+    # Filter by sample_name
+    if sample_name is not None:
+        if isinstance(sample_name, list):
+            filter_conditions.append(pl.col("sample_name").is_in(sample_name))
+        else:
+            filter_conditions.append(pl.col("sample_name") == sample_name)
+
+    # Filter by sample_type
+    if sample_type is not None:
+        if "sample_type" in available_columns:
+            if isinstance(sample_type, list):
+                filter_conditions.append(pl.col("sample_type").is_in(sample_type))
+            else:
+                filter_conditions.append(pl.col("sample_type") == sample_type)
+        else:
+            warnings.append("'sample_type' column not found in samples_df")
+
+    # Filter by sample_group
+    if sample_group is not None:
+        if "sample_group" in available_columns:
+            if isinstance(sample_group, list):
+                filter_conditions.append(pl.col("sample_group").is_in(sample_group))
+            else:
+                filter_conditions.append(pl.col("sample_group") == sample_group)
+        else:
+            warnings.append("'sample_group' column not found in samples_df")
+
+    # Filter by sample_batch
+    if sample_batch is not None:
+        if "sample_batch" in available_columns:
+            if isinstance(sample_batch, (list, tuple)):
+                if len(sample_batch) == 2 and not isinstance(sample_batch, list):
+                    # Treat as range
+                    min_batch, max_batch = sample_batch
+                    filter_conditions.append((pl.col("sample_batch") >= min_batch) & (pl.col("sample_batch") <= max_batch))
+                else:
+                    # Treat as list
+                    filter_conditions.append(pl.col("sample_batch").is_in(sample_batch))
+            else:
+                filter_conditions.append(pl.col("sample_batch") == sample_batch)
+        else:
+            warnings.append("'sample_batch' column not found in samples_df")
+
+    # Filter by sample_sequence
+    if sample_sequence is not None:
+        if "sample_sequence" in available_columns:
+            if isinstance(sample_sequence, (list, tuple)):
+                if len(sample_sequence) == 2 and not isinstance(sample_sequence, list):
+                    # Treat as range
+                    min_seq, max_seq = sample_sequence
+                    filter_conditions.append((pl.col("sample_sequence") >= min_seq) & (pl.col("sample_sequence") <= max_seq))
+                else:
+                    # Treat as list
+                    filter_conditions.append(pl.col("sample_sequence").is_in(sample_sequence))
+            else:
+                filter_conditions.append(pl.col("sample_sequence") == sample_sequence)
+        else:
+            warnings.append("'sample_sequence' column not found in samples_df")
+
+    # Filter by num_features
+    if num_features is not None:
+        if "num_features" in available_columns:
+            if isinstance(num_features, tuple) and len(num_features) == 2:
+                min_features, max_features = num_features
+                filter_conditions.append((pl.col("num_features") >= min_features) & (pl.col("num_features") <= max_features))
+            else:
+                filter_conditions.append(pl.col("num_features") >= num_features)
+        else:
+            warnings.append("'num_features' column not found in samples_df")
+
+    # Filter by num_ms1
+    if num_ms1 is not None:
+        if "num_ms1" in available_columns:
+            if isinstance(num_ms1, tuple) and len(num_ms1) == 2:
+                min_ms1, max_ms1 = num_ms1
+                filter_conditions.append((pl.col("num_ms1") >= min_ms1) & (pl.col("num_ms1") <= max_ms1))
+            else:
+                filter_conditions.append(pl.col("num_ms1") >= num_ms1)
+        else:
+            warnings.append("'num_ms1' column not found in samples_df")
+
+    # Filter by num_ms2
+    if num_ms2 is not None:
+        if "num_ms2" in available_columns:
+            if isinstance(num_ms2, tuple) and len(num_ms2) == 2:
+                min_ms2, max_ms2 = num_ms2
+                filter_conditions.append((pl.col("num_ms2") >= min_ms2) & (pl.col("num_ms2") <= max_ms2))
+            else:
+                filter_conditions.append(pl.col("num_ms2") >= num_ms2)
+        else:
+            warnings.append("'num_ms2' column not found in samples_df")
+
+    # Log all warnings once at the end for efficiency
+    for warning in warnings:
+        self.logger.warning(warning)
+
+    # Apply all filters at once using lazy evaluation for optimal performance
+    if filter_conditions:
+        # Combine all conditions with AND
+        combined_filter = filter_conditions[0]
+        for condition in filter_conditions[1:]:
+            combined_filter = combined_filter & condition
+
+        # Apply the combined filter using lazy evaluation
+        samples = self.samples_df.lazy().filter(combined_filter).collect()
+    else:
+        samples = self.samples_df.clone()
+
+    final_count = len(samples)
+
+    if final_count == 0:
+        self.logger.warning("No samples remaining after applying selection criteria.")
+    else:
+        self.logger.info(f"Samples selected: {final_count} (out of {initial_count})")
+
+    return samples
+
+
+def samples_delete(self, samples):
+    """
+    Delete samples and all related data from the study based on sample identifiers.
+
+    This function eliminates all data related to the specified samples (and their sample_uids)
+    from all dataframes including:
+    - samples_df: Removes the sample rows
+    - features_df: Removes all features belonging to these samples
+    - consensus_mapping_df: Removes mappings for features from these samples
+    - consensus_ms2: Removes MS2 spectra for features from these samples
+    - feature_maps: Removes the corresponding feature maps
+
+    Also updates map_id values to maintain sequential indices after deletion.
+
+    Parameters:
+        samples: Samples to delete. Can be:
+            - list of int: List of sample_uids to delete
+            - polars.DataFrame: DataFrame obtained from samples_select (will use sample_uid column)
+            - int: Single sample_uid to delete
+
+    Returns:
+        None (modifies study DataFrames and feature_maps in place)
+    """
+    if self.samples_df is None or self.samples_df.is_empty():
+        self.logger.warning("No samples found in study.")
+        return
+
+    # Early return if no samples provided
+    if samples is None:
+        self.logger.warning("No samples provided for deletion.")
+        return
+
+    initial_sample_count = len(self.samples_df)
+
+    # Determine sample_uids to remove
+    if isinstance(samples, pl.DataFrame):
+        if "sample_uid" not in samples.columns:
+            self.logger.error("samples DataFrame must contain 'sample_uid' column")
+            return
+        sample_uids_to_remove = samples["sample_uid"].to_list()
+    elif isinstance(samples, (list, tuple)):
+        sample_uids_to_remove = list(samples)  # Convert tuple to list if needed
+    elif isinstance(samples, int):
+        sample_uids_to_remove = [samples]
+    else:
+        self.logger.error("samples parameter must be a DataFrame, list, tuple, or int")
+        return
+
+    # Early return if no UIDs to remove
+    if not sample_uids_to_remove:
+        self.logger.warning("No sample UIDs provided for deletion.")
+        return
+
+    # Convert to set for faster lookup if list is large
+    if len(sample_uids_to_remove) > 100:
+        sample_uids_set = set(sample_uids_to_remove)
+        # Use the set for filtering if it's significantly smaller
+        if len(sample_uids_set) < len(sample_uids_to_remove) * 0.8:
+            sample_uids_to_remove = list(sample_uids_set)
+
+    self.logger.info(f"Deleting {len(sample_uids_to_remove)} samples and all related data...")
+
+    # Get feature_uids that need to be removed from features_df
+    feature_uids_to_remove = []
+    initial_features_count = 0
+    if self.features_df is not None and not self.features_df.is_empty():
+        initial_features_count = len(self.features_df)
+        feature_uids_to_remove = self.features_df.filter(
+            pl.col("sample_uid").is_in(sample_uids_to_remove),
+        )["feature_uid"].to_list()
+
+    # Get map_ids to remove from feature_maps (needed before samples_df deletion)
+    map_ids_to_remove = []
+    if hasattr(self, 'feature_maps') and self.feature_maps is not None:
+        # Get map_ids for samples to be deleted
+        map_ids_df = self.samples_df.filter(
+            pl.col("sample_uid").is_in(sample_uids_to_remove)
+        ).select("map_id")
+        if not map_ids_df.is_empty():
+            map_ids_to_remove = map_ids_df["map_id"].to_list()
+
+    # 1. Remove samples from samples_df
+    self.samples_df = self.samples_df.filter(
+        ~pl.col("sample_uid").is_in(sample_uids_to_remove),
+    )
+
+    # 2. Remove corresponding features from features_df
+    removed_features_count = 0
+    if feature_uids_to_remove and self.features_df is not None and not self.features_df.is_empty():
+        self.features_df = self.features_df.filter(
+            ~pl.col("sample_uid").is_in(sample_uids_to_remove),
+        )
+        removed_features_count = initial_features_count - len(self.features_df)
+
+    # 3. Remove from consensus_mapping_df
+    removed_mapping_count = 0
+    if feature_uids_to_remove and self.consensus_mapping_df is not None and not self.consensus_mapping_df.is_empty():
+        initial_mapping_count = len(self.consensus_mapping_df)
+        self.consensus_mapping_df = self.consensus_mapping_df.filter(
+            ~pl.col("feature_uid").is_in(feature_uids_to_remove),
+        )
+        removed_mapping_count = initial_mapping_count - len(self.consensus_mapping_df)
+
+    # 4. Remove from consensus_ms2 if it exists
+    removed_ms2_count = 0
+    if hasattr(self, "consensus_ms2") and self.consensus_ms2 is not None and not self.consensus_ms2.is_empty():
+        initial_ms2_count = len(self.consensus_ms2)
+        self.consensus_ms2 = self.consensus_ms2.filter(
+            ~pl.col("sample_uid").is_in(sample_uids_to_remove),
+        )
+        removed_ms2_count = initial_ms2_count - len(self.consensus_ms2)
+
+    # 5. Remove from feature_maps and update map_id
+    removed_maps_count = 0
+    if hasattr(self, 'feature_maps') and self.feature_maps is not None and map_ids_to_remove:
+        # Remove feature maps in reverse order to maintain indices
+        for map_id in sorted(map_ids_to_remove, reverse=True):
+            if 0 <= map_id < len(self.feature_maps):
+                self.feature_maps.pop(map_id)
+                removed_maps_count += 1
+
+        # Update map_id values in samples_df to maintain sequential indices
+        if len(self.samples_df) > 0:
+            new_map_ids = list(range(len(self.samples_df)))
+            self.samples_df = self.samples_df.with_columns(
+                pl.lit(new_map_ids).alias("map_id")
+            )
+
+    # Calculate and log results
+    removed_sample_count = initial_sample_count - len(self.samples_df)
+    final_sample_count = len(self.samples_df)
+
+    # Create comprehensive summary message
+    summary_parts = [
+        f"Deleted {removed_sample_count} samples",
+    ]
+
+    if removed_features_count > 0:
+        summary_parts.append(f"{removed_features_count} features")
+
+    if removed_mapping_count > 0:
+        summary_parts.append(f"{removed_mapping_count} consensus mappings")
+
+    if removed_ms2_count > 0:
+        summary_parts.append(f"{removed_ms2_count} MS2 spectra")
+
+    if removed_maps_count > 0:
+        summary_parts.append(f"{removed_maps_count} feature maps")
+
+    summary_parts.append(f"Remaining samples: {final_sample_count}")
+
+    self.logger.info(". ".join(summary_parts))
+
+    # Update map_id indices if needed
+    if removed_maps_count > 0 and final_sample_count > 0:
+        self.logger.debug(f"Updated map_id values to range from 0 to {final_sample_count - 1}")
+
+
+# =====================================================================================
+# COLOR PALETTE AND VISUALIZATION FUNCTIONS
+# =====================================================================================
+
+
+def sample_color(self, by=None, palette="Turbo256"):
+    """
+    Set sample colors in the sample_color column of samples_df.
+
+    When a new sample is added, this function resets all colors picking from the specified palette.
+    The default palette is Turbo256.
+
+    Parameters:
+        by (str or list, optional): Property to base colors on. Options:
+            - 'sample_uid': Use sample_uid values to assign colors
+            - 'sample_index': Use sample index (position) to assign colors
+            - 'sample_type': Use sample_type values to assign colors
+            - 'sample_name': Use sample_name values to assign colors
+            - list of colors: Use provided list of hex color codes
+            - None: Use sequential colors from palette (default)
+        palette (str): Color palette to use. Options:
+            - 'Turbo256': Turbo colormap (256 colors, perceptually uniform)
+            - 'Viridis256': Viridis colormap (256 colors, perceptually uniform)
+            - 'Plasma256': Plasma colormap (256 colors, perceptually uniform)
+            - 'Inferno256': Inferno colormap (256 colors, perceptually uniform)
+            - 'Magma256': Magma colormap (256 colors, perceptually uniform)
+            - 'Cividis256': Cividis colormap (256 colors, colorblind-friendly)
+            - 'Set1': Qualitative palette (9 distinct colors)
+            - 'Set2': Qualitative palette (8 distinct colors)
+            - 'Set3': Qualitative palette (12 distinct colors)
+            - 'Tab10': Tableau 10 palette (10 distinct colors)
+            - 'Tab20': Tableau 20 palette (20 distinct colors)
+            - 'Dark2': Dark qualitative palette (8 colors)
+            - 'Paired': Paired qualitative palette (12 colors)
+            - 'Spectral': Spectral diverging colormap
+            - 'Rainbow': Rainbow colormap
+            - 'Coolwarm': Cool-warm diverging colormap
+            - 'Seismic': Seismic diverging colormap
+            - Any other colormap name supported by the cmap library
+
+        For a complete catalog of available colormaps, see:
+        https://cmap-docs.readthedocs.io/en/latest/catalog/
+
+    Returns:
+        None (modifies self.samples_df in place)
+
+    Example:
+        # Set colors based on sample type
+        study.sample_color(by='sample_type', palette='Set1')
+
+        # Set colors using a custom color list
+        study.sample_color(by=['#FF0000', '#00FF00', '#0000FF'])
+
+        # Reset to default Turbo256 sequential colors
+        study.sample_color()
+    """
+    if self.samples_df is None or len(self.samples_df) == 0:
+        self.logger.warning("No samples found in study.")
+        return
+
+    sample_count = len(self.samples_df)
+
+    # Handle custom color list
+    if isinstance(by, list):
+        if len(by) < sample_count:
+            self.logger.warning(f"Provided color list has {len(by)} colors but {sample_count} samples. Repeating colors.")
+            # Cycle through the provided colors if there aren't enough
+            colors = []
+            for i in range(sample_count):
+                colors.append(by[i % len(by)])
+        else:
+            colors = by[:sample_count]
+    else:
+        # Use the new approach: sample colors evenly from the whole colormap
+        if by is None:
+            # Sequential colors evenly sampled from the colormap
+            try:
+                colors = _sample_colors_from_colormap(palette, sample_count)
+            except ValueError as e:
+                self.logger.error(f"Error sampling colors from colormap: {e}")
+                return
+
+        elif by == 'sample_uid':
+            # Use sample_uid to determine position in evenly sampled colormap
+            sample_uids = self.samples_df['sample_uid'].to_list()
+            try:
+                # Sample colors evenly for the number of samples
+                palette_colors = _sample_colors_from_colormap(palette, sample_count)
+                colors = []
+                for uid in sample_uids:
+                    # Use modulo to cycle through evenly sampled colors
+                    color_index = uid % len(palette_colors)
+                    colors.append(palette_colors[color_index])
+            except ValueError as e:
+                self.logger.error(f"Error sampling colors from colormap: {e}")
+                return
+
+        elif by == 'sample_index':
+            # Use sample index (position in DataFrame) with evenly sampled colors
+            try:
+                colors = _sample_colors_from_colormap(palette, sample_count)
+            except ValueError as e:
+                self.logger.error(f"Error sampling colors from colormap: {e}")
+                return
+
+        elif by == 'sample_type':
+            # Use sample_type to assign colors - same type gets same color
+            # Sample colors evenly across colormap for unique types
+            sample_types = self.samples_df['sample_type'].to_list()
+            unique_types = list(set([t for t in sample_types if t is not None]))
+
+            try:
+                # Sample colors evenly for unique types
+                type_colors = _sample_colors_from_colormap(palette, len(unique_types))
+                type_to_color = {}
+
+                for i, sample_type in enumerate(unique_types):
+                    type_to_color[sample_type] = type_colors[i]
+
+                colors = []
+                for sample_type in sample_types:
+                    if sample_type is None:
+                        # Default to first color for None
+                        colors.append(type_colors[0] if type_colors else "#000000")
+                    else:
+                        colors.append(type_to_color[sample_type])
+            except ValueError as e:
+                self.logger.error(f"Error sampling colors from colormap: {e}")
+                return
+
+        elif by == 'sample_name':
+            # Use sample_name to assign colors - same name gets same color (unlikely but possible)
+            # Sample colors evenly across colormap for unique names
+            sample_names = self.samples_df['sample_name'].to_list()
+            unique_names = list(set([n for n in sample_names if n is not None]))
+
+            try:
+                # Sample colors evenly for unique names
+                name_colors = _sample_colors_from_colormap(palette, len(unique_names))
+                name_to_color = {}
+
+                for i, sample_name in enumerate(unique_names):
+                    name_to_color[sample_name] = name_colors[i]
+
+                colors = []
+                for sample_name in sample_names:
+                    if sample_name is None:
+                        # Default to first color for None
+                        colors.append(name_colors[0] if name_colors else "#000000")
+                    else:
+                        colors.append(name_to_color[sample_name])
+            except ValueError as e:
+                self.logger.error(f"Error sampling colors from colormap: {e}")
+                return
+        else:
+            self.logger.error(f"Invalid by value: {by}. Must be 'sample_uid', 'sample_index', 'sample_type', 'sample_name', a list of colors, or None.")
+            return
+
+    # Update the sample_color column
+    self.samples_df = self.samples_df.with_columns(
+        pl.Series("sample_color", colors).alias("sample_color")
+    )
+
+    if isinstance(by, list):
+        self.logger.debug(f"Set sample colors using provided color list ({len(by)} colors)")
+    elif by is None:
+        self.logger.debug(f"Set sequential sample colors using {palette} palette")
+    else:
+        self.logger.debug(f"Set sample colors based on {by} using {palette} palette")
+
+
+def sample_color_reset(self):
+    """
+    Reset sample colors to default coloring using the 'turbo' colormap.
+
+    This function assigns colors by distributing samples evenly across the full
+    turbo colormap range, ensuring maximum color diversity and visual distinction
+    between samples.
+
+    Returns:
+        None (modifies self.samples_df in place)
+    """
+    if self.samples_df is None or len(self.samples_df) == 0:
+        self.logger.warning("No samples found in study.")
+        return
+
+    try:
+        from cmap import Colormap
+
+        # Use turbo colormap
+        cm = Colormap('turbo')
+
+        # Get sample count and assign colors evenly distributed across colormap
+        n_samples = len(self.samples_df)
+        colors = []
+
+        # Distribute samples evenly across the full colormap range
+        for i in range(n_samples):
+            # Evenly distribute samples across colormap (avoiding endpoints to prevent white/black)
+            normalized_value = (i + 0.5) / n_samples  # +0.5 to center samples in their bins
+            # Optionally, map to a subset of colormap to avoid extreme colors
+            # Use 10% to 90% of colormap range for better color diversity
+            normalized_value = 0.1 + (normalized_value * 0.8)
+
+            color_rgba = cm(normalized_value)
+
+            # Convert RGBA to hex
+            if len(color_rgba) >= 3:
+                r, g, b = color_rgba[:3]
+                # Convert to 0-255 range if needed
+                if max(color_rgba[:3]) <= 1.0:
+                    r, g, b = int(r * 255), int(g * 255), int(b * 255)
+                hex_color = f"#{r:02x}{g:02x}{b:02x}"
+                colors.append(hex_color)
+
+        # Update the sample_color column
+        self.samples_df = self.samples_df.with_columns(
+            pl.Series("sample_color", colors).alias("sample_color")
+        )
+
+        self.logger.debug(f"Reset sample colors using turbo colormap with even distribution ({n_samples} samples)")
+
+    except ImportError:
+        self.logger.error("cmap library is required for sample color reset. Install with: pip install cmap")
+    except Exception as e:
+        self.logger.error(f"Failed to reset sample colors: {e}")
+
+
+def _get_color_palette(palette_name):
+    """
+    Get color palette as a list of hex color codes using the cmap library.
+
+    Parameters:
+        palette_name (str): Name of the palette
+
+    Returns:
+        list: List of hex color codes
+
+    Raises:
+        ValueError: If palette_name is not supported
+    """
+    try:
+        from cmap import Colormap
+    except ImportError:
+        raise ValueError("cmap library is required for color palettes. Install with: pip install cmap")
+
+    # Map common palette names to cmap names
+    palette_mapping = {
+        # Scientific colormaps
+        "Turbo256": "turbo",
+        "Viridis256": "viridis",
+        "Plasma256": "plasma",
+        "Inferno256": "inferno",
+        "Magma256": "magma",
+        "Cividis256": "cividis",
+
+        # Qualitative palettes
+        "Set1": "Set1",
+        "Set2": "Set2",
+        "Set3": "Set3",
+        "Tab10": "tab10",
+        "Tab20": "tab20",
+        "Dark2": "Dark2",
+        "Paired": "Paired",
+
+        # Additional useful palettes
+        "Spectral": "Spectral",
+        "Rainbow": "rainbow",
+        "Coolwarm": "coolwarm",
+        "Seismic": "seismic",
+    }
+
+    # Get the cmap name
+    cmap_name = palette_mapping.get(palette_name, palette_name.lower())
+
+    try:
+        # Create colormap
+        cm = Colormap(cmap_name)
+
+        # Determine number of colors to generate
+        if "256" in palette_name:
+            n_colors = 256
+        elif palette_name in ["Set1"]:
+            n_colors = 9
+        elif palette_name in ["Set2", "Dark2"]:
+            n_colors = 8
+        elif palette_name in ["Set3", "Paired"]:
+            n_colors = 12
+        elif palette_name in ["Tab10"]:
+            n_colors = 10
+        elif palette_name in ["Tab20"]:
+            n_colors = 20
+        else:
+            n_colors = 256  # Default for continuous colormaps
+
+        # Generate colors
+        if n_colors <= 20:
+            # For discrete palettes, use evenly spaced indices
+            indices = [i / (n_colors - 1) for i in range(n_colors)]
+        else:
+            # For continuous palettes, use full range
+            indices = [i / (n_colors - 1) for i in range(n_colors)]
+
+        # Get colors as RGBA and convert to hex
+        colors = cm(indices)
+        hex_colors = []
+
+        for color in colors:
+            if len(color) >= 3:  # RGBA or RGB
+                r, g, b = color[:3]
+                # Convert to 0-255 range if needed
+                if max(color[:3]) <= 1.0:
+                    r, g, b = int(r * 255), int(g * 255), int(b * 255)
+                hex_color = f"#{r:02x}{g:02x}{b:02x}"
+                hex_colors.append(hex_color)
+
+        return hex_colors
+
+    except Exception as e:
+        raise ValueError(f"Failed to create colormap '{cmap_name}': {e}. "
+                         f"Available palettes: {list(palette_mapping.keys())}")
+
+
+def _sample_colors_from_colormap(palette_name, n_colors):
+    """
+    Sample colors evenly from the whole colormap range, similar to sample_color_reset.
+
+    Parameters:
+        palette_name (str): Name of the palette/colormap
+        n_colors (int): Number of colors to sample
+
+    Returns:
+        list: List of hex color codes sampled evenly from the colormap
+
+    Raises:
+        ValueError: If palette_name is not supported
+    """
+    try:
+        from cmap import Colormap
+    except ImportError:
+        raise ValueError("cmap library is required for color palettes. Install with: pip install cmap")
+
+    # Map common palette names to cmap names (same as _get_color_palette)
+    palette_mapping = {
+        # Scientific colormaps
+        "Turbo256": "turbo",
+        "Viridis256": "viridis",
+        "Plasma256": "plasma",
+        "Inferno256": "inferno",
+        "Magma256": "magma",
+        "Cividis256": "cividis",
+
+        # Qualitative palettes
+        "Set1": "Set1",
+        "Set2": "Set2",
+        "Set3": "Set3",
+        "Tab10": "tab10",
+        "Tab20": "tab20",
+        "Dark2": "Dark2",
+        "Paired": "Paired",
+
+        # Additional useful palettes
+        "Spectral": "Spectral",
+        "Rainbow": "rainbow",
+        "Coolwarm": "coolwarm",
+        "Seismic": "seismic",
+    }
+
+    # Get the cmap name
+    cmap_name = palette_mapping.get(palette_name, palette_name.lower())
+
+    try:
+        # Create colormap
+        cm = Colormap(cmap_name)
+
+        colors = []
+
+        # Distribute samples evenly across the full colormap range (same approach as sample_color_reset)
+        for i in range(n_colors):
+            # Evenly distribute samples across colormap (avoiding endpoints to prevent white/black)
+            normalized_value = (i + 0.5) / n_colors  # +0.5 to center samples in their bins
+            # Map to a subset of colormap to avoid extreme colors (use 10% to 90% range)
+            normalized_value = 0.1 + (normalized_value * 0.8)
+
+            color_rgba = cm(normalized_value)
+
+            # Convert RGBA to hex
+            if len(color_rgba) >= 3:
+                r, g, b = color_rgba[:3]
+                # Convert to 0-255 range if needed
+                if max(color_rgba[:3]) <= 1.0:
+                    r, g, b = int(r * 255), int(g * 255), int(b * 255)
+                hex_color = f"#{r:02x}{g:02x}{b:02x}"
+                colors.append(hex_color)
+
+        return colors
+
+    except Exception as e:
+        raise ValueError(f"Failed to create colormap '{cmap_name}': {e}. "
+                         f"Available palettes: {list(palette_mapping.keys())}")
+
+
+def _matplotlib_to_hex(color_dict):
+    """Convert matplotlib color dictionary to list of hex colors."""
+    return list(color_dict.values())
+
+
+# =====================================================================================
+# SCHEMA AND DATA STRUCTURE FUNCTIONS
+# =====================================================================================
+
+
+def _ensure_features_df_schema_order(self):
+    """
+    Ensure features_df columns are ordered according to study5_schema.json.
+
+    This method should be called after operations that might scramble the column order.
+    """
+    if self.features_df is None or self.features_df.is_empty():
+        return
+
+    try:
+        import os
+        import json
+        from masster.study.h5 import _reorder_columns_by_schema
+
+        # Load schema
+        schema_path = os.path.join(os.path.dirname(__file__), "study5_schema.json")
+        with open(schema_path, 'r') as f:
+            schema = json.load(f)
+
+        # Reorder columns to match schema
+        self.features_df = _reorder_columns_by_schema(self.features_df, schema, 'features_df')
+
+    except Exception as e:
+        self.logger.warning(f"Failed to reorder features_df columns: {e}")
+
+
+def migrate_map_id_to_index(self):
+    """
+    Migrate map_id from string-based OpenMS unique IDs to integer indices.
+
+    This function converts the map_id column from string type (with OpenMS unique IDs)
+    to integer type where each map_id corresponds to the index of the feature map
+    in self.features_maps.
+
+    This migration is needed for studies that were created before the map_id format
+    change from OpenMS unique IDs to feature map indices.
+    """
+    if self.samples_df is None or self.samples_df.is_empty():
+        self.logger.warning("No samples to migrate")
+        return
+
+    # Check if migration is needed
+    current_dtype = self.samples_df['map_id'].dtype
+    if current_dtype == pl.Int64:
+        self.logger.info("map_id column is already Int64 type - no migration needed")
+        return
+
+    self.logger.info("Migrating map_id from string-based OpenMS IDs to integer indices")
+
+    # Create new map_id values based on sample order
+    # Each sample gets a map_id that corresponds to its position in features_maps
+    sample_count = len(self.samples_df)
+    new_map_ids = list(range(sample_count))
+
+    # Update the map_id column
+    self.samples_df = self.samples_df.with_columns(
+        pl.lit(new_map_ids).alias("map_id")
+    )
+
+    # Ensure the column is Int64 type
+    self.samples_df = self.samples_df.cast({"map_id": pl.Int64})
+
+    self.logger.info(f"Successfully migrated {sample_count} samples to indexed map_id format")
+    self.logger.info(f"map_id now ranges from 0 to {sample_count - 1}")
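A combined sketch of the sample-management helpers added in this release (filter values hypothetical; behavior as documented in the hunks above):

    # select blank samples with few features, then drop them and all dependent data
    blanks = study.samples_select(sample_type="blank", num_features=(0, 50))
    study.samples_delete(blanks)  # also prunes features, consensus mappings, MS2, feature maps

    # recolor the remaining samples and normalize a legacy string map_id column
    study.sample_color(by="sample_type", palette="Set1")
    study.migrate_map_id_to_index()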