smftools 0.1.6__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. smftools/__init__.py +34 -0
  2. smftools/_settings.py +20 -0
  3. smftools/_version.py +1 -0
  4. smftools/cli.py +184 -0
  5. smftools/config/__init__.py +1 -0
  6. smftools/config/conversion.yaml +33 -0
  7. smftools/config/deaminase.yaml +56 -0
  8. smftools/config/default.yaml +253 -0
  9. smftools/config/direct.yaml +17 -0
  10. smftools/config/experiment_config.py +1191 -0
  11. smftools/datasets/F1_hybrid_NKG2A_enhander_promoter_GpC_conversion_SMF.h5ad.gz +0 -0
  12. smftools/datasets/F1_sample_sheet.csv +5 -0
  13. smftools/datasets/__init__.py +9 -0
  14. smftools/datasets/dCas9_m6A_invitro_kinetics.h5ad.gz +0 -0
  15. smftools/datasets/datasets.py +28 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/hmm/apply_hmm_batched.py +242 -0
  19. smftools/hmm/calculate_distances.py +18 -0
  20. smftools/hmm/call_hmm_peaks.py +106 -0
  21. smftools/hmm/display_hmm.py +18 -0
  22. smftools/hmm/hmm_readwrite.py +16 -0
  23. smftools/hmm/nucleosome_hmm_refinement.py +104 -0
  24. smftools/hmm/train_hmm.py +78 -0
  25. smftools/informatics/__init__.py +14 -0
  26. smftools/informatics/archived/bam_conversion.py +59 -0
  27. smftools/informatics/archived/bam_direct.py +63 -0
  28. smftools/informatics/archived/basecalls_to_adata.py +71 -0
  29. smftools/informatics/archived/conversion_smf.py +132 -0
  30. smftools/informatics/archived/deaminase_smf.py +132 -0
  31. smftools/informatics/archived/direct_smf.py +137 -0
  32. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  33. smftools/informatics/basecall_pod5s.py +80 -0
  34. smftools/informatics/fast5_to_pod5.py +24 -0
  35. smftools/informatics/helpers/__init__.py +73 -0
  36. smftools/informatics/helpers/align_and_sort_BAM.py +86 -0
  37. smftools/informatics/helpers/aligned_BAM_to_bed.py +85 -0
  38. smftools/informatics/helpers/archived/informatics.py +260 -0
  39. smftools/informatics/helpers/archived/load_adata.py +516 -0
  40. smftools/informatics/helpers/bam_qc.py +66 -0
  41. smftools/informatics/helpers/bed_to_bigwig.py +39 -0
  42. smftools/informatics/helpers/binarize_converted_base_identities.py +172 -0
  43. smftools/informatics/helpers/canoncall.py +34 -0
  44. smftools/informatics/helpers/complement_base_list.py +21 -0
  45. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +378 -0
  46. smftools/informatics/helpers/converted_BAM_to_adata.py +245 -0
  47. smftools/informatics/helpers/converted_BAM_to_adata_II.py +505 -0
  48. smftools/informatics/helpers/count_aligned_reads.py +43 -0
  49. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  50. smftools/informatics/helpers/discover_input_files.py +100 -0
  51. smftools/informatics/helpers/extract_base_identities.py +70 -0
  52. smftools/informatics/helpers/extract_mods.py +83 -0
  53. smftools/informatics/helpers/extract_read_features_from_bam.py +33 -0
  54. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  55. smftools/informatics/helpers/extract_readnames_from_BAM.py +22 -0
  56. smftools/informatics/helpers/find_conversion_sites.py +51 -0
  57. smftools/informatics/helpers/generate_converted_FASTA.py +99 -0
  58. smftools/informatics/helpers/get_chromosome_lengths.py +32 -0
  59. smftools/informatics/helpers/get_native_references.py +28 -0
  60. smftools/informatics/helpers/index_fasta.py +12 -0
  61. smftools/informatics/helpers/make_dirs.py +21 -0
  62. smftools/informatics/helpers/make_modbed.py +27 -0
  63. smftools/informatics/helpers/modQC.py +27 -0
  64. smftools/informatics/helpers/modcall.py +36 -0
  65. smftools/informatics/helpers/modkit_extract_to_adata.py +887 -0
  66. smftools/informatics/helpers/ohe_batching.py +76 -0
  67. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  68. smftools/informatics/helpers/one_hot_decode.py +27 -0
  69. smftools/informatics/helpers/one_hot_encode.py +57 -0
  70. smftools/informatics/helpers/plot_bed_histograms.py +269 -0
  71. smftools/informatics/helpers/run_multiqc.py +28 -0
  72. smftools/informatics/helpers/separate_bam_by_bc.py +43 -0
  73. smftools/informatics/helpers/split_and_index_BAM.py +32 -0
  74. smftools/informatics/readwrite.py +106 -0
  75. smftools/informatics/subsample_fasta_from_bed.py +47 -0
  76. smftools/informatics/subsample_pod5.py +104 -0
  77. smftools/load_adata.py +1346 -0
  78. smftools/machine_learning/__init__.py +12 -0
  79. smftools/machine_learning/data/__init__.py +2 -0
  80. smftools/machine_learning/data/anndata_data_module.py +234 -0
  81. smftools/machine_learning/data/preprocessing.py +6 -0
  82. smftools/machine_learning/evaluation/__init__.py +2 -0
  83. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  84. smftools/machine_learning/evaluation/evaluators.py +223 -0
  85. smftools/machine_learning/inference/__init__.py +3 -0
  86. smftools/machine_learning/inference/inference_utils.py +27 -0
  87. smftools/machine_learning/inference/lightning_inference.py +68 -0
  88. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  89. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  90. smftools/machine_learning/models/__init__.py +9 -0
  91. smftools/machine_learning/models/base.py +295 -0
  92. smftools/machine_learning/models/cnn.py +138 -0
  93. smftools/machine_learning/models/lightning_base.py +345 -0
  94. smftools/machine_learning/models/mlp.py +26 -0
  95. smftools/machine_learning/models/positional.py +18 -0
  96. smftools/machine_learning/models/rnn.py +17 -0
  97. smftools/machine_learning/models/sklearn_models.py +273 -0
  98. smftools/machine_learning/models/transformer.py +303 -0
  99. smftools/machine_learning/models/wrappers.py +20 -0
  100. smftools/machine_learning/training/__init__.py +2 -0
  101. smftools/machine_learning/training/train_lightning_model.py +135 -0
  102. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  103. smftools/machine_learning/utils/__init__.py +2 -0
  104. smftools/machine_learning/utils/device.py +10 -0
  105. smftools/machine_learning/utils/grl.py +14 -0
  106. smftools/plotting/__init__.py +18 -0
  107. smftools/plotting/autocorrelation_plotting.py +611 -0
  108. smftools/plotting/classifiers.py +355 -0
  109. smftools/plotting/general_plotting.py +682 -0
  110. smftools/plotting/hmm_plotting.py +260 -0
  111. smftools/plotting/position_stats.py +462 -0
  112. smftools/plotting/qc_plotting.py +270 -0
  113. smftools/preprocessing/__init__.py +38 -0
  114. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  115. smftools/preprocessing/append_base_context.py +122 -0
  116. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  117. smftools/preprocessing/archives/mark_duplicates.py +146 -0
  118. smftools/preprocessing/archives/preprocessing.py +614 -0
  119. smftools/preprocessing/archives/remove_duplicates.py +21 -0
  120. smftools/preprocessing/binarize_on_Youden.py +45 -0
  121. smftools/preprocessing/binary_layers_to_ohe.py +40 -0
  122. smftools/preprocessing/calculate_complexity.py +72 -0
  123. smftools/preprocessing/calculate_complexity_II.py +248 -0
  124. smftools/preprocessing/calculate_consensus.py +47 -0
  125. smftools/preprocessing/calculate_coverage.py +51 -0
  126. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  127. smftools/preprocessing/calculate_pairwise_hamming_distances.py +27 -0
  128. smftools/preprocessing/calculate_position_Youden.py +115 -0
  129. smftools/preprocessing/calculate_read_length_stats.py +79 -0
  130. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  131. smftools/preprocessing/clean_NaN.py +62 -0
  132. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  133. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  134. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  135. smftools/preprocessing/flag_duplicate_reads.py +1351 -0
  136. smftools/preprocessing/invert_adata.py +37 -0
  137. smftools/preprocessing/load_sample_sheet.py +53 -0
  138. smftools/preprocessing/make_dirs.py +21 -0
  139. smftools/preprocessing/min_non_diagonal.py +25 -0
  140. smftools/preprocessing/recipes.py +127 -0
  141. smftools/preprocessing/subsample_adata.py +58 -0
  142. smftools/readwrite.py +1004 -0
  143. smftools/tools/__init__.py +20 -0
  144. smftools/tools/archived/apply_hmm.py +202 -0
  145. smftools/tools/archived/classifiers.py +787 -0
  146. smftools/tools/archived/classify_methylated_features.py +66 -0
  147. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  148. smftools/tools/archived/subset_adata_v1.py +32 -0
  149. smftools/tools/archived/subset_adata_v2.py +46 -0
  150. smftools/tools/calculate_umap.py +62 -0
  151. smftools/tools/cluster_adata_on_methylation.py +105 -0
  152. smftools/tools/general_tools.py +69 -0
  153. smftools/tools/position_stats.py +601 -0
  154. smftools/tools/read_stats.py +184 -0
  155. smftools/tools/spatial_autocorrelation.py +562 -0
  156. smftools/tools/subset_adata.py +28 -0
  157. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/METADATA +9 -2
  158. smftools-0.2.1.dist-info/RECORD +161 -0
  159. smftools-0.2.1.dist-info/entry_points.txt +2 -0
  160. smftools-0.1.6.dist-info/RECORD +0 -4
  161. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
  162. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
smftools/preprocessing/invert_adata.py
@@ -0,0 +1,37 @@
+ ## invert_adata
+
+ def invert_adata(adata, uns_flag='adata_positions_inverted', force_redo=False):
+     """
+     Inverts the AnnData object along the column (variable) axis.
+
+     Parameters:
+         adata (AnnData): An AnnData object.
+         uns_flag (str): Key in adata.uns used to mark that inversion has already run.
+         force_redo (bool): If True, invert again even when the flag is set.
+
+     Returns:
+         AnnData: A new AnnData object with inverted column ordering.
+     """
+
+     # Only run if not already performed
+     already = bool(adata.uns.get(uns_flag, False))
+     if already and not force_redo:
+         # Inversion already performed; nothing to do
+         return adata
+
+     print("Inverting AnnData along the column axis...")
+
+     # Reverse the order of columns (variables)
+     inverted_adata = adata[:, ::-1].copy()
+
+     # Keep the original (ascending) labels on the reversed columns
+     inverted_adata.var_names = adata.var_names
+
+     # Optional: store the original coordinate of each column for reference
+     inverted_adata.var["Original_var_names"] = adata.var_names[::-1]
+
+     # Mark as done
+     inverted_adata.uns[uns_flag] = True
+
+     print("Inversion complete!")
+     return inverted_adata
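A minimal usage sketch (a hypothetical toy AnnData, not from the package's tests): the columns are reversed while the labels keep their original order, and the uns flag makes a second call a no-op.

    import numpy as np
    import anndata as ad

    # Hypothetical 3-read x 4-position matrix with coordinate-style labels.
    adata = ad.AnnData(np.arange(12, dtype=float).reshape(3, 4))
    adata.var_names = ["100", "101", "102", "103"]

    inverted = invert_adata(adata)
    print(inverted.X[0])                             # [3. 2. 1. 0.] -- columns reversed
    print(list(inverted.var["Original_var_names"]))  # ['103', '102', '101', '100']
    print(invert_adata(inverted).X[0])               # [3. 2. 1. 0.] -- flag short-circuits the call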
smftools/preprocessing/load_sample_sheet.py
@@ -0,0 +1,55 @@
+ def load_sample_sheet(adata,
+                       sample_sheet_path,
+                       mapping_key_column='obs_names',
+                       as_category=True,
+                       uns_flag='sample_sheet_loaded',
+                       force_reload=True
+                       ):
+     """
+     Loads a sample sheet CSV and maps metadata into the AnnData object as categorical columns.
+
+     Parameters:
+         adata (AnnData): The AnnData object to append sample information to.
+         sample_sheet_path (str): Path to the CSV file.
+         mapping_key_column (str): Column name in the CSV to map against adata.obs_names or an existing obs column.
+         as_category (bool): If True, added columns are cast as pandas Categorical.
+         uns_flag (str): Key in adata.uns used to mark that the sheet has been loaded.
+         force_reload (bool): If True, reload even when the flag is already set.
+
+     Returns:
+         AnnData: Updated AnnData object.
+     """
+     import pandas as pd
+
+     # Only run if not already performed
+     already = bool(adata.uns.get(uns_flag, False))
+     if already and not force_reload:
+         # Sample sheet already loaded; nothing to do
+         return adata
+
+     print('Loading sample sheet...')
+     df = pd.read_csv(sample_sheet_path)
+     df[mapping_key_column] = df[mapping_key_column].astype(str)
+
+     # If matching against obs_names directly
+     if mapping_key_column == 'obs_names':
+         key_series = adata.obs_names.astype(str)
+     else:
+         key_series = adata.obs[mapping_key_column].astype(str)
+
+     value_columns = [col for col in df.columns if col != mapping_key_column]
+
+     print(f'Appending metadata columns: {value_columns}')
+     df = df.set_index(mapping_key_column)
+
+     for col in value_columns:
+         mapped = key_series.map(df[col])
+         if as_category:
+             mapped = mapped.astype('category')
+         adata.obs[col] = mapped
+
+     # Mark as done
+     adata.uns[uns_flag] = True
+
+     print('Sample sheet metadata successfully added as categories.' if as_category else 'Metadata added.')
+     return adata
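A usage sketch with made-up barcodes and metadata (the CSV path and column names are illustrative only):

    import numpy as np
    import pandas as pd
    import anndata as ad

    # Toy AnnData with an obs column to join on.
    adata = ad.AnnData(np.zeros((3, 2)))
    adata.obs["Sample"] = ["bc1", "bc2", "bc1"]

    # Sample sheet keyed by the same column.
    pd.DataFrame({"Sample": ["bc1", "bc2"],
                  "Genotype": ["WT", "KO"]}).to_csv("sheet.csv", index=False)

    load_sample_sheet(adata, "sheet.csv", mapping_key_column="Sample")
    print(adata.obs["Genotype"].tolist())  # ['WT', 'KO', 'WT'], stored as a categorical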
smftools/preprocessing/make_dirs.py
@@ -0,0 +1,21 @@
+ ## make_dirs
+
+ # General
+ def make_dirs(directories):
+     """
+     Takes a list of file paths and makes new directories if the directory does not already exist.
+
+     Parameters:
+         directories (list): A list of directories to make.
+
+     Returns:
+         None
+     """
+     import os
+
+     for directory in directories:
+         if not os.path.isdir(directory):
+             os.mkdir(directory)
+             print(f"Directory '{directory}' created successfully.")
+         else:
+             print(f"Directory '{directory}' already exists.")
smftools/preprocessing/min_non_diagonal.py
@@ -0,0 +1,25 @@
+ ## min_non_diagonal
+
+ def min_non_diagonal(matrix):
+     """
+     Takes a square matrix and returns the smallest value in each row with the diagonal masked.
+
+     Parameters:
+         matrix (ndarray): A 2D ndarray.
+
+     Returns:
+         min_values (list): A list of minimum values from each row of the matrix.
+     """
+     import numpy as np
+
+     n = matrix.shape[0]
+     min_values = []
+     for i in range(n):
+         # Mask to exclude the diagonal element
+         row_mask = np.ones(n, dtype=bool)
+         row_mask[i] = False
+         # Extract the row excluding the diagonal element
+         row = matrix[i, row_mask]
+         # Find the minimum value in the row
+         min_values.append(np.min(row))
+     return min_values
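For example, on a toy symmetric distance matrix (e.g. pairwise Hamming distances between reads), the result is each row's nearest-neighbor distance:

    import numpy as np

    D = np.array([[0.0, 0.2, 0.5],
                  [0.2, 0.0, 0.1],
                  [0.5, 0.1, 0.0]])

    print(min_non_diagonal(D))  # [0.2, 0.1, 0.1]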
smftools/preprocessing/recipes.py
@@ -0,0 +1,126 @@
+ # recipes
+
+ def recipe_1_Kissiov_and_McKenna_2025(adata, sample_sheet_path, output_directory, mapping_key_column='Sample', reference_column='Reference', sample_names_col='Sample_names', invert=True):
+     """
+     The first part of the preprocessing workflow applied to the smf.inform.pod_to_adata() output derived from Kissiov_and_McKenna_2025.
+
+     Performs the following tasks:
+         1) Loads a sample sheet CSV to append metadata mappings to the adata object.
+         2) Appends a boolean indicating whether each position in var_names is within a given reference.
+         3) Appends the cytosine context to each position from each reference.
+         4) Calculates read-level methylation statistics.
+         5) Calculates read length statistics (start position, end position, read length).
+         6) Optionally inverts the adata to flip the position coordinate orientation.
+         7) Adds new layers containing NaN-replaced variants of adata.X (fill_closest, nan0_0minus1, nan1_12).
+         8) Returns a dictionary (including the possibly inverted adata) to pass the variable namespace to the parent scope.
+
+     Parameters:
+         adata (AnnData): The AnnData object to use as input.
+         sample_sheet_path (str): Path to the sample sheet CSV containing the sample metadata.
+         output_directory (str): Path to the output directory for plots.
+         mapping_key_column (str): The column name to use as the mapping keys for applying the sample sheet metadata.
+         reference_column (str): The name of the reference column to use.
+         sample_names_col (str): The name of the sample name column to use.
+         invert (bool): Whether to invert the positional coordinates of the adata object.
+
+     Returns:
+         variables (dict): A dictionary of variables (including the updated adata) to append to the parent scope.
+     """
+     import anndata as ad
+     import pandas as pd
+     import numpy as np
+     from .load_sample_sheet import load_sample_sheet
+     from .calculate_coverage import calculate_coverage
+     from .append_C_context import append_C_context
+     from .calculate_converted_read_methylation_stats import calculate_converted_read_methylation_stats
+     from .invert_adata import invert_adata
+     from .calculate_read_length_stats import calculate_read_length_stats
+     from .clean_NaN import clean_NaN
+
+     # Clean up some of the Reference metadata and save variable names that point to sets of values in the column.
+     adata.obs[reference_column] = adata.obs[reference_column].astype('category')
+     references = adata.obs[reference_column].cat.categories
+     split_references = [(reference, reference.split('_')[0][1:]) for reference in references]
+     reference_mapping = {k: v for k, v in split_references}
+     adata.obs[f'{reference_column}_short'] = adata.obs[reference_column].map(reference_mapping)
+     short_references = set(adata.obs[f'{reference_column}_short'])
+     binary_layers = list(adata.layers.keys())
+
+     # Load sample sheet metadata
+     load_sample_sheet(adata, sample_sheet_path, mapping_key_column)
+
+     # Hold the sample names set
+     adata.obs[sample_names_col] = adata.obs[sample_names_col].astype('category')
+     sample_names = adata.obs[sample_names_col].cat.categories
+
+     # Add position-level metadata
+     calculate_coverage(adata, obs_column=reference_column)
+     adata.var['SNP_position'] = ((adata.var[f'N_{reference_column}_with_position'] > 0) & (adata.var[f'N_{reference_column}_with_position'] < len(references))).astype(bool)
+
+     # Append cytosine context to the reference positions based on the conversion strand.
+     append_C_context(adata, obs_column=reference_column, use_consensus=False)
+
+     # Calculate read-level methylation statistics. Assess if GpC methylation level is above other_C methylation level as a QC.
+     calculate_converted_read_methylation_stats(adata, reference_column, sample_names_col)
+
+     # Calculate read length statistics
+     upper_bound, lower_bound = calculate_read_length_stats(adata, reference_column, sample_names_col)
+
+     # Invert the adata object (ie flip the strand orientation for visualization)
+     if invert:
+         adata = invert_adata(adata)
+
+     # NaN replacement strategies stored in additional layers. Having layer=None uses adata.X
+     clean_NaN(adata, layer=None)
+
+     variables = {
+         "adata": adata,  # invert_adata returns a new object, so hand it back to the caller
+         "short_references": short_references,
+         "binary_layers": binary_layers,
+         "sample_names": sample_names,
+         "upper_bound": upper_bound,
+         "lower_bound": lower_bound,
+         "references": references
+     }
+     return variables
+
+ def recipe_2_Kissiov_and_McKenna_2025(adata, output_directory, binary_layers, distance_thresholds={}, reference_column='Reference', sample_names_col='Sample_names'):
+     """
+     The second part of the preprocessing workflow, applied to an adata that has already been preprocessed by recipe_1_Kissiov_and_McKenna_2025.
+
+     Performs the following tasks:
+         1) Marks putative PCR duplicates using pairwise Hamming distance metrics.
+         2) Performs a complexity analysis of the library based on the PCR duplicate detection rate.
+         3) Removes PCR duplicates from the adata.
+         4) Returns two adata objects: one with the filtered reads and one with the duplicate reads.
+
+     Parameters:
+         adata (AnnData): The AnnData object to use as input.
+         output_directory (str): Path to the output directory for plots.
+         binary_layers (list): A list of layers to use for the binary encoding of read sequences. Used for duplicate detection.
+         distance_thresholds (dict): A dictionary keyed by obs_column categories that points to a float corresponding to the distance threshold to apply. Default is an empty dict.
+         reference_column (str): The name of the reference column to use.
+         sample_names_col (str): The name of the sample name column to use.
+
+     Returns:
+         filtered_adata (AnnData): An AnnData object containing the filtered reads.
+         duplicates (AnnData): An AnnData object containing the duplicate reads.
+     """
+     import anndata as ad
+     import pandas as pd
+     import numpy as np
+     from .mark_duplicates import mark_duplicates
+     from .calculate_complexity import calculate_complexity
+     from .remove_duplicates import remove_duplicates
+
+     # TODO: add a way to remove reads below a given read quality (based on NaN content), plus a way to pull per-read quality from the BAM files.
+
+     # Duplicate detection using pairwise Hamming distance across reads
+     mark_duplicates(adata, binary_layers, obs_column=reference_column, sample_col=sample_names_col, distance_thresholds=distance_thresholds, method='N_masked_distances')
+
+     # Complexity analysis using the marked duplicates and the Lander-Waterman algorithm
+     calculate_complexity(adata, output_directory, obs_column=reference_column, sample_col=sample_names_col, plot=True, save_plot=False)
+
+     # Remove duplicate reads and store them in a new AnnData object named duplicates.
+     filtered_adata, duplicates = remove_duplicates(adata)
+     return filtered_adata, duplicates
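A hypothetical driver script chaining the two recipes (the input path is illustrative, and re-export of the recipes at the smftools.preprocessing namespace is an assumption based on the package's __init__.py):

    import anndata as ad
    import smftools.preprocessing as pp

    adata = ad.read_h5ad("pod_to_adata_output.h5ad")  # hypothetical input file

    variables = pp.recipe_1_Kissiov_and_McKenna_2025(
        adata, sample_sheet_path="F1_sample_sheet.csv",
        output_directory="plots", mapping_key_column="Sample")
    adata = variables["adata"]  # pick up the (possibly inverted) object

    filtered_adata, duplicates = pp.recipe_2_Kissiov_and_McKenna_2025(
        adata, output_directory="plots",
        binary_layers=variables["binary_layers"])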
smftools/preprocessing/subsample_adata.py
@@ -0,0 +1,58 @@
+ def subsample_adata(adata, obs_columns=None, max_samples=2000, random_seed=42):
+     """
+     Subsamples an AnnData object so that each unique combination of categories
+     in the given `obs_columns` has at most `max_samples` observations.
+     If `obs_columns` is None or empty, the function randomly subsamples the entire dataset.
+
+     Parameters:
+         adata (AnnData): The AnnData object to subsample.
+         obs_columns (list of str, optional): List of observation column names to group by.
+         max_samples (int): The maximum number of observations per category combination.
+         random_seed (int): Random seed for reproducibility.
+
+     Returns:
+         AnnData: A new AnnData object with subsampled observations.
+     """
+     import anndata as ad
+     import numpy as np
+
+     np.random.seed(random_seed)  # Ensure reproducibility
+
+     if not obs_columns:  # If no obs columns are given, sample globally
+         if adata.shape[0] > max_samples:
+             sampled_indices = np.random.choice(adata.obs.index, max_samples, replace=False)
+         else:
+             sampled_indices = adata.obs.index  # Keep all if fewer than max_samples
+
+         return adata[sampled_indices].copy()
+
+     sampled_indices = []
+
+     # Get unique category combinations from all specified obs columns
+     unique_combinations = adata.obs[obs_columns].drop_duplicates()
+
+     for _, row in unique_combinations.iterrows():
+         # Build filter condition dynamically for multiple columns
+         condition = (adata.obs[obs_columns] == row.values).all(axis=1)
+
+         # Get indices for the current category combination
+         subset_indices = adata.obs[condition].index.to_numpy()
+
+         # Subsample or take all
+         if len(subset_indices) > max_samples:
+             sampled = np.random.choice(subset_indices, max_samples, replace=False)
+         else:
+             sampled = subset_indices  # Keep all if fewer than max_samples
+
+         sampled_indices.extend(sampled)
+
+     # ⚠ Handle backed mode: materialize the subset in memory
+     if adata.isbacked:
+         print("⚠ Detected backed mode. Subset will be loaded fully into memory.")
+         subset = adata[sampled_indices]
+         subset = subset.to_memory()
+     else:
+         subset = adata[sampled_indices]
+
+     # Return a new AnnData object with only the selected indices (already subset above)
+     return subset.copy()
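A quick sketch with fabricated labels, capping each reference at three reads:

    import numpy as np
    import anndata as ad

    adata = ad.AnnData(np.random.rand(10, 4))
    adata.obs["Reference"] = ["chr1"] * 6 + ["chr2"] * 4

    sub = subsample_adata(adata, obs_columns=["Reference"], max_samples=3)
    print(sub.obs["Reference"].value_counts())  # at most 3 observations per reference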