smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +7 -6
- smftools/_version.py +1 -1
- smftools/cli/cli_flows.py +94 -0
- smftools/cli/hmm_adata.py +338 -0
- smftools/cli/load_adata.py +577 -0
- smftools/cli/preprocess_adata.py +363 -0
- smftools/cli/spatial_adata.py +564 -0
- smftools/cli_entry.py +435 -0
- smftools/config/__init__.py +1 -0
- smftools/config/conversion.yaml +38 -0
- smftools/config/deaminase.yaml +61 -0
- smftools/config/default.yaml +264 -0
- smftools/config/direct.yaml +41 -0
- smftools/config/discover_input_files.py +115 -0
- smftools/config/experiment_config.py +1288 -0
- smftools/hmm/HMM.py +1576 -0
- smftools/hmm/__init__.py +20 -0
- smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
- smftools/hmm/call_hmm_peaks.py +106 -0
- smftools/{tools → hmm}/display_hmm.py +3 -3
- smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
- smftools/{tools → hmm}/train_hmm.py +1 -1
- smftools/informatics/__init__.py +13 -9
- smftools/informatics/archived/deaminase_smf.py +132 -0
- smftools/informatics/archived/fast5_to_pod5.py +43 -0
- smftools/informatics/archived/helpers/archived/__init__.py +71 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
- smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
- smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
- smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
- smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
- smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
- smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
- smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
- smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
- smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
- smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
- smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
- smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
- smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
- smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
- smftools/informatics/bam_functions.py +812 -0
- smftools/informatics/basecalling.py +67 -0
- smftools/informatics/bed_functions.py +366 -0
- smftools/informatics/binarize_converted_base_identities.py +172 -0
- smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
- smftools/informatics/fasta_functions.py +255 -0
- smftools/informatics/h5ad_functions.py +197 -0
- smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
- smftools/informatics/modkit_functions.py +129 -0
- smftools/informatics/ohe.py +160 -0
- smftools/informatics/pod5_functions.py +224 -0
- smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
- smftools/machine_learning/__init__.py +12 -0
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +234 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +31 -0
- smftools/machine_learning/evaluation/evaluators.py +223 -0
- smftools/machine_learning/inference/__init__.py +3 -0
- smftools/machine_learning/inference/inference_utils.py +27 -0
- smftools/machine_learning/inference/lightning_inference.py +68 -0
- smftools/machine_learning/inference/sklearn_inference.py +55 -0
- smftools/machine_learning/inference/sliding_window_inference.py +114 -0
- smftools/machine_learning/models/base.py +295 -0
- smftools/machine_learning/models/cnn.py +138 -0
- smftools/machine_learning/models/lightning_base.py +345 -0
- smftools/machine_learning/models/mlp.py +26 -0
- smftools/{tools → machine_learning}/models/positional.py +3 -2
- smftools/{tools → machine_learning}/models/rnn.py +2 -1
- smftools/machine_learning/models/sklearn_models.py +273 -0
- smftools/machine_learning/models/transformer.py +303 -0
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +135 -0
- smftools/machine_learning/training/train_sklearn_model.py +114 -0
- smftools/plotting/__init__.py +4 -1
- smftools/plotting/autocorrelation_plotting.py +609 -0
- smftools/plotting/general_plotting.py +1292 -140
- smftools/plotting/hmm_plotting.py +260 -0
- smftools/plotting/qc_plotting.py +270 -0
- smftools/preprocessing/__init__.py +15 -8
- smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
- smftools/preprocessing/append_base_context.py +122 -0
- smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
- smftools/preprocessing/binarize.py +17 -0
- smftools/preprocessing/binarize_on_Youden.py +2 -2
- smftools/preprocessing/calculate_complexity_II.py +248 -0
- smftools/preprocessing/calculate_coverage.py +10 -1
- smftools/preprocessing/calculate_position_Youden.py +1 -1
- smftools/preprocessing/calculate_read_modification_stats.py +101 -0
- smftools/preprocessing/clean_NaN.py +17 -1
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
- smftools/preprocessing/flag_duplicate_reads.py +1326 -124
- smftools/preprocessing/invert_adata.py +12 -5
- smftools/preprocessing/load_sample_sheet.py +19 -4
- smftools/readwrite.py +1021 -89
- smftools/tools/__init__.py +3 -32
- smftools/tools/calculate_umap.py +5 -5
- smftools/tools/general_tools.py +3 -3
- smftools/tools/position_stats.py +468 -106
- smftools/tools/read_stats.py +115 -1
- smftools/tools/spatial_autocorrelation.py +562 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
- smftools-0.2.3.dist-info/RECORD +173 -0
- smftools-0.2.3.dist-info/entry_points.txt +2 -0
- smftools/informatics/fast5_to_pod5.py +0 -21
- smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
- smftools/informatics/helpers/__init__.py +0 -74
- smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
- smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
- smftools/informatics/helpers/bam_qc.py +0 -66
- smftools/informatics/helpers/bed_to_bigwig.py +0 -39
- smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
- smftools/informatics/helpers/index_fasta.py +0 -12
- smftools/informatics/helpers/make_dirs.py +0 -21
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
- smftools/informatics/load_adata.py +0 -182
- smftools/informatics/readwrite.py +0 -106
- smftools/informatics/subsample_fasta_from_bed.py +0 -47
- smftools/preprocessing/append_C_context.py +0 -82
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
- smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
- smftools/preprocessing/filter_reads_on_length.py +0 -51
- smftools/tools/call_hmm_peaks.py +0 -105
- smftools/tools/data/__init__.py +0 -2
- smftools/tools/data/anndata_data_module.py +0 -90
- smftools/tools/inference/__init__.py +0 -1
- smftools/tools/inference/lightning_inference.py +0 -41
- smftools/tools/models/base.py +0 -14
- smftools/tools/models/cnn.py +0 -34
- smftools/tools/models/lightning_base.py +0 -41
- smftools/tools/models/mlp.py +0 -17
- smftools/tools/models/sklearn_models.py +0 -40
- smftools/tools/models/transformer.py +0 -133
- smftools/tools/training/__init__.py +0 -1
- smftools/tools/training/train_lightning_model.py +0 -47
- smftools-0.1.7.dist-info/RECORD +0 -136
- /smftools/{tools/evaluation → cli}/__init__.py +0 -0
- /smftools/{tools → hmm}/calculate_distances.py +0 -0
- /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
- /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
- /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
- /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
- /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
- /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
- /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
- /smftools/{tools → machine_learning}/models/__init__.py +0 -0
- /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
- /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
- /smftools/{tools → machine_learning}/utils/device.py +0 -0
- /smftools/{tools → machine_learning}/utils/grl.py +0 -0
- /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
- /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
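Note on the reorganization: the moves above shift the HMM utilities out of smftools.tools into a new smftools.hmm subpackage, move the classifier stack into smftools.machine_learning, and archive most of the old informatics helpers. A minimal sketch of how downstream imports would shift, assuming the new subpackages re-export these helpers (the exact public names are assumptions inferred from the moved file names, not confirmed API):

    # Hypothetical import updates for code written against 0.1.7; verify against the 0.2.3 __init__ exports.
    # 0.1.7-style imports (modules now moved or archived):
    #   from smftools.tools.apply_hmm_batched import apply_hmm_batched
    #   from smftools.tools.models.cnn import CNNClassifier
    # 0.2.3-style imports (assumed locations based on the file moves above):
    from smftools import hmm                # smftools/hmm/ now holds apply_hmm_batched, call_hmm_peaks, HMM
    from smftools import machine_learning   # smftools/machine_learning/ now holds models, training, inference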
smftools/preprocessing/append_C_context.py
DELETED
@@ -1,82 +0,0 @@
## append_C_context

## Conversion SMF Specific
# Read methylation QC
def append_C_context(adata, obs_column='Reference', use_consensus=False, native=False):
    """
    Adds Cytosine context to the position within the given category. When use_consensus is True, it uses the consensus sequence, otherwise it defaults to the FASTA sequence.

    Parameters:
        adata (AnnData): The input adata object.
        obs_column (str): The observation column in which to stratify on. Default is 'Reference', which should not be changed for most purposes.
        use_consensus (bool): A truth statement indicating whether to use the consensus sequence from the reads mapped to the reference. If False, the reference FASTA is used instead.
        native (bool): If False, perform conversion SMF assumptions. If True, perform native SMF assumptions

    Returns:
        None
    """
    import numpy as np
    import anndata as ad

    print('Adding Cytosine context based on reference FASTA sequence for sample')

    site_types = ['GpC_site', 'CpG_site', 'ambiguous_GpC_CpG_site', 'other_C', 'any_C_site']
    categories = adata.obs[obs_column].cat.categories
    for cat in categories:
        # Assess if the strand is the top or bottom strand converted
        if 'top' in cat:
            strand = 'top'
        elif 'bottom' in cat:
            strand = 'bottom'

        if native:
            basename = cat.split(f"_{strand}")[0]
            if use_consensus:
                sequence = adata.uns[f'{basename}_consensus_sequence']
            else:
                # This sequence is the unconverted FASTA sequence of the original input FASTA for the locus
                sequence = adata.uns[f'{basename}_FASTA_sequence']
        else:
            basename = cat.split(f"_{strand}")[0]
            if use_consensus:
                sequence = adata.uns[f'{basename}_consensus_sequence']
            else:
                # This sequence is the unconverted FASTA sequence of the original input FASTA for the locus
                sequence = adata.uns[f'{basename}_FASTA_sequence']
        # Init a dict keyed by reference site type that points to a bool of whether the position is that site type.
        boolean_dict = {}
        for site_type in site_types:
            boolean_dict[f'{cat}_{site_type}'] = np.full(len(sequence), False, dtype=bool)

        if strand == 'top':
            # Iterate through the sequence and apply the criteria
            for i in range(1, len(sequence) - 1):
                if sequence[i] == 'C':
                    boolean_dict[f'{cat}_any_C_site'][i] = True
                    if sequence[i - 1] == 'G' and sequence[i + 1] != 'G':
                        boolean_dict[f'{cat}_GpC_site'][i] = True
                    elif sequence[i - 1] == 'G' and sequence[i + 1] == 'G':
                        boolean_dict[f'{cat}_ambiguous_GpC_CpG_site'][i] = True
                    elif sequence[i - 1] != 'G' and sequence[i + 1] == 'G':
                        boolean_dict[f'{cat}_CpG_site'][i] = True
                    elif sequence[i - 1] != 'G' and sequence[i + 1] != 'G':
                        boolean_dict[f'{cat}_other_C'][i] = True
        elif strand == 'bottom':
            # Iterate through the sequence and apply the criteria
            for i in range(1, len(sequence) - 1):
                if sequence[i] == 'G':
                    boolean_dict[f'{cat}_any_C_site'][i] = True
                    if sequence[i + 1] == 'C' and sequence[i - 1] != 'C':
                        boolean_dict[f'{cat}_GpC_site'][i] = True
                    elif sequence[i - 1] == 'C' and sequence[i + 1] == 'C':
                        boolean_dict[f'{cat}_ambiguous_GpC_CpG_site'][i] = True
                    elif sequence[i - 1] == 'C' and sequence[i + 1] != 'C':
                        boolean_dict[f'{cat}_CpG_site'][i] = True
                    elif sequence[i - 1] != 'C' and sequence[i + 1] != 'C':
                        boolean_dict[f'{cat}_other_C'][i] = True
        else:
            print('Error: top or bottom strand of conversion could not be determined. Ensure this value is in the Reference name.')

        for site_type in site_types:
            adata.var[f'{cat}_{site_type}'] = boolean_dict[f'{cat}_{site_type}'].astype(bool)
            adata.obsm[f'{cat}_{site_type}'] = adata[:, adata.var[f'{cat}_{site_type}'] == True].X
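Illustrative usage (not part of the diff): a minimal sketch of how this removed helper was invoked, based on the signature and docstring above; the input .h5ad path is a hypothetical placeholder.

    import anndata as ad

    adata = ad.read_h5ad("converted_reads.h5ad")  # hypothetical AnnData produced by the load step
    # Annotate per-Reference GpC/CpG/other-C context from the stored FASTA sequences (modifies adata in place)
    append_C_context(adata, obs_column='Reference', use_consensus=False, native=False)
    gpc_columns = [c for c in adata.var.columns if c.endswith('_GpC_site')]
    print(gpc_columns)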
smftools/preprocessing/calculate_converted_read_methylation_stats.py
DELETED
@@ -1,94 +0,0 @@
## calculate_converted_read_methylation_stats

## Conversion SMF Specific
# Read methylation QC

def calculate_converted_read_methylation_stats(adata, reference_column, sample_names_col):
    """
    Adds methylation statistics for each read. Indicates whether the read GpC methylation exceeded other_C methylation (background false positives).

    Parameters:
        adata (AnnData): An adata object
        reference_column (str): String representing the name of the Reference column to use
        sample_names_col (str): String representing the name of the sample name column to use

    Returns:
        None
    """
    import numpy as np
    import anndata as ad
    import pandas as pd

    print('Calculating read level methylation statistics')

    references = set(adata.obs[reference_column])
    sample_names = set(adata.obs[sample_names_col])

    site_types = ['GpC_site', 'CpG_site', 'ambiguous_GpC_CpG_site', 'other_C']

    for site_type in site_types:
        adata.obs[f'{site_type}_row_methylation_sums'] = pd.Series(0, index=adata.obs_names, dtype=int)
        adata.obs[f'{site_type}_row_methylation_means'] = pd.Series(np.nan, index=adata.obs_names, dtype=float)
        adata.obs[f'number_valid_{site_type}_in_read'] = pd.Series(0, index=adata.obs_names, dtype=int)
        adata.obs[f'fraction_valid_{site_type}_in_range'] = pd.Series(np.nan, index=adata.obs_names, dtype=float)
    for cat in references:
        cat_subset = adata[adata.obs[reference_column] == cat].copy()
        for site_type in site_types:
            print(f'Iterating over {cat}_{site_type}')
            observation_matrix = cat_subset.obsm[f'{cat}_{site_type}']
            number_valid_positions_in_read = np.nansum(~np.isnan(observation_matrix), axis=1)
            row_methylation_sums = np.nansum(observation_matrix, axis=1)
            number_valid_positions_in_read[number_valid_positions_in_read == 0] = 1
            fraction_valid_positions_in_range = number_valid_positions_in_read / np.max(number_valid_positions_in_read)
            row_methylation_means = np.divide(row_methylation_sums, number_valid_positions_in_read)
            temp_obs_data = pd.DataFrame({f'number_valid_{site_type}_in_read': number_valid_positions_in_read,
                                          f'fraction_valid_{site_type}_in_range': fraction_valid_positions_in_range,
                                          f'{site_type}_row_methylation_sums': row_methylation_sums,
                                          f'{site_type}_row_methylation_means': row_methylation_means}, index=cat_subset.obs.index)
            adata.obs.update(temp_obs_data)
    # Indicate whether the read-level GpC methylation rate exceeds the false methylation rate of the read
    pass_array = np.array(adata.obs[f'GpC_site_row_methylation_means'] > adata.obs[f'other_C_row_methylation_means'])
    adata.obs['GpC_above_other_C'] = pd.Series(pass_array, index=adata.obs.index, dtype=bool)

    # Below should be a plotting function
    # adata.uns['methylation_dict'] = {}
    # n_bins = 50
    # site_types_to_analyze = ['GpC_site', 'CpG_site', 'ambiguous_GpC_CpG_site', 'other_C']

    # for reference in references:
    #     reference_adata = adata[adata.obs[reference_column] == reference].copy()
    #     split_reference = reference.split('_')[0][1:]
    #     for sample in sample_names:
    #         sample_adata = reference_adata[reference_adata.obs[sample_names_col] == sample].copy()
    #         for site_type in site_types_to_analyze:
    #             methylation_data = sample_adata.obs[f'{site_type}_row_methylation_means']
    #             max_meth = np.max(sample_adata.obs[f'{site_type}_row_methylation_sums'])
    #             if not np.isnan(max_meth):
    #                 n_bins = int(max_meth // 2)
    #             else:
    #                 n_bins = 1
    #             mean = np.mean(methylation_data)
    #             median = np.median(methylation_data)
    #             stdev = np.std(methylation_data)
    #             adata.uns['methylation_dict'][f'{reference}_{sample}_{site_type}'] = [mean, median, stdev]
    #             if show_methylation_histogram or save_methylation_histogram:
    #                 fig, ax = plt.subplots(figsize=(6, 4))
    #                 count, bins, patches = plt.hist(methylation_data, bins=n_bins, weights=np.ones(len(methylation_data)) / len(methylation_data), alpha=0.7, color='blue', edgecolor='black')
    #                 plt.axvline(median, color='red', linestyle='dashed', linewidth=1)
    #                 plt.text(median + stdev, max(count)*0.8, f'Median: {median:.2f}', color='red')
    #                 plt.axvline(median - stdev, color='green', linestyle='dashed', linewidth=1, label=f'Stdev: {stdev:.2f}')
    #                 plt.axvline(median + stdev, color='green', linestyle='dashed', linewidth=1)
    #                 plt.text(median + stdev + 0.05, max(count) / 3, f'+1 Stdev: {stdev:.2f}', color='green')
    #                 plt.xlabel('Fraction methylated')
    #                 plt.ylabel('Proportion')
    #                 title = f'Distribution of {methylation_data.shape[0]} read {site_type} methylation means \nfor {sample} sample on {split_reference} after filtering'
    #                 plt.title(title, pad=20)
    #                 plt.xlim(-0.05, 1.05)  # Set x-axis range from 0 to 1
    #                 ax.spines['right'].set_visible(False)
    #                 ax.spines['top'].set_visible(False)
    #                 save_name = output_directory + f'/{readwrite.date_string()} {title}'
    #                 if save_methylation_histogram:
    #                     plt.savefig(save_name, bbox_inches='tight', pad_inches=0.1)
    #                     plt.close()
    #                 else:
    #                     plt.show()
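Illustrative usage (not part of the diff): a sketch of the removed read-level QC call, assuming append_C_context has already populated the per-reference obsm matrices it reads; the sample column name is a hypothetical placeholder.

    # Compute per-read methylation sums/means per site type and flag reads whose GpC signal exceeds background
    calculate_converted_read_methylation_stats(adata, reference_column='Reference', sample_names_col='Sample')
    print(adata.obs['GpC_above_other_C'].value_counts())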
smftools/preprocessing/filter_converted_reads_on_methylation.py
DELETED
@@ -1,44 +0,0 @@
## filter_converted_reads_on_methylation

## Conversion SMF Specific
def filter_converted_reads_on_methylation(adata, valid_SMF_site_threshold=0.8, min_SMF_threshold=0.025, max_SMF_threshold=0.975):
    """
    Filter adata object using minimum thresholds for valid SMF site fraction in read, as well as minimum methylation content in read.

    Parameters:
        adata (AnnData): An adata object.
        valid_SMF_site_threshold (float): A minimum proportion of valid SMF sites that must be present in the read. Default is 0.8
        min_SMF_threshold (float): A minimum read methylation level. Default is 0.025
    Returns:
        Anndata
    """
    import numpy as np
    import anndata as ad
    import pandas as pd

    if valid_SMF_site_threshold:
        # Keep reads that have over a given valid GpC site content
        adata = adata[adata.obs['fraction_valid_GpC_site_in_range'] > valid_SMF_site_threshold].copy()

    if min_SMF_threshold:
        # Keep reads with SMF methylation over background methylation.
        below_background = (~adata.obs['GpC_above_other_C']).sum()
        print(f'Removing {below_background} reads that have GpC conversion below background conversion rate')
        adata = adata[adata.obs['GpC_above_other_C'] == True].copy()
        # Keep reads over a defined methylation threshold
        s0 = adata.shape[0]
        adata = adata[adata.obs['GpC_site_row_methylation_means'] > min_SMF_threshold].copy()
        s1 = adata.shape[0]
        below_threshold = s0 - s1
        print(f'Removing {below_threshold} reads that have GpC conversion below a minimum threshold conversion rate')

    if max_SMF_threshold:
        # Keep reads below a defined methylation threshold
        s0 = adata.shape[0]
        adata = adata[adata.obs['GpC_site_row_methylation_means'] < max_SMF_threshold].copy()
        s1 = adata.shape[0]
        above_threshold = s0 - s1
        print(f'Removing {above_threshold} reads that have GpC conversion above a maximum threshold conversion rate')

    return adata
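Illustrative usage (not part of the diff): the removed filter applied with its documented defaults, after the QC columns above have been computed.

    # Keep reads with >80% valid GpC sites and GpC methylation between the 0.025 and 0.975 thresholds
    adata = filter_converted_reads_on_methylation(adata,
                                                  valid_SMF_site_threshold=0.8,
                                                  min_SMF_threshold=0.025,
                                                  max_SMF_threshold=0.975)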
smftools/preprocessing/filter_reads_on_length.py
DELETED
@@ -1,51 +0,0 @@
## filter_reads_on_length

def filter_reads_on_length(adata, filter_on_coordinates=False, min_read_length=2700, max_read_length=3200):
    """
    Filters the adata object to keep a defined coordinate window, as well as reads that are over a minimum threshold in length.

    Parameters:
        adata (AnnData): An adata object.
        filter_on_coordinates (bool | list): If False, skips filtering. Otherwise, provide a list containing integers representing the lower and upper bound coordinates to filter on. Default is False.
        min_read_length (int): The minimum read length to keep in the filtered dataset. Default is 2700.
        max_read_length (int): The maximum query read length to keep in the filtered dataset. Default is 3200.

    Returns:
        adata
    """
    import numpy as np
    import anndata as ad
    import pandas as pd

    if filter_on_coordinates:
        lower_bound, upper_bound = filter_on_coordinates
        # Extract the position information from the adata object as an np array
        var_names_arr = adata.var_names.astype(int).to_numpy()
        # Find the upper bound coordinate that is closest to the specified value
        closest_end_index = np.argmin(np.abs(var_names_arr - upper_bound))
        upper_bound = int(adata.var_names[closest_end_index])
        # Find the lower bound coordinate that is closest to the specified value
        closest_start_index = np.argmin(np.abs(var_names_arr - lower_bound))
        lower_bound = int(adata.var_names[closest_start_index])
        # Get a list of positional indexes that encompass the lower and upper bounds of the dataset
        position_list = list(range(lower_bound, upper_bound + 1))
        position_list = [str(pos) for pos in position_list]
        position_set = set(position_list)
        print(f'Subsetting adata to keep data between coordinates {lower_bound} and {upper_bound}')
        adata = adata[:, adata.var_names.isin(position_set)].copy()

    if min_read_length:
        print(f'Subsetting adata to keep reads longer than {min_read_length}')
        s0 = adata.shape[0]
        adata = adata[adata.obs['read_length'] > min_read_length].copy()
        s1 = adata.shape[0]
        print(f'Removed {s0-s1} reads less than {min_read_length} basepairs in length')

    if max_read_length:
        print(f'Subsetting adata to keep reads shorter than {max_read_length}')
        s0 = adata.shape[0]
        adata = adata[adata.obs['read_length'] < max_read_length].copy()
        s1 = adata.shape[0]
        print(f'Removed {s0-s1} reads greater than {max_read_length} basepairs in length')

    return adata
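Illustrative usage (not part of the diff): a sketch of the removed length/coordinate filter; the coordinate window is a hypothetical example, the length bounds are the documented defaults.

    # Restrict to a coordinate window and to reads between 2700 and 3200 bp
    adata = filter_reads_on_length(adata,
                                   filter_on_coordinates=[1000, 4000],  # hypothetical window
                                   min_read_length=2700,
                                   max_read_length=3200)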
smftools/tools/call_hmm_peaks.py
DELETED
@@ -1,105 +0,0 @@
def call_hmm_peaks(adata, feature_configs, obs_column='Reference_strand', site_types=['GpC_site', 'CpG_site'], save_plot=False, output_dir=None, date_tag=None):
    """
    Calls peaks from HMM feature layers and annotates them into the AnnData object.

    Parameters:
        adata : AnnData object with HMM layers (from apply_hmm)
        feature_configs : dict
            min_distance : minimum distance between peaks
            peak_width : window size around peak centers
            peak_prominence : required peak prominence
            peak_threshold : threshold for labeling a read as "present" at a peak
        site_types : list of var site types to aggregate
        save_plot : whether to save the plot
        output_dir : path to save the figure if save_plot=True
        date_tag : optional tag for filename
    """
    import matplotlib.pyplot as plt
    from scipy.signal import find_peaks
    import os
    import numpy as np

    peak_columns = []

    for feature_layer, config in feature_configs.items():
        min_distance = config.get('min_distance', 200)
        peak_width = config.get('peak_width', 200)
        peak_prominence = config.get('peak_prominence', 0.2)
        peak_threshold = config.get('peak_threshold', 0.8)

        # 1️⃣ Calculate mean intensity profile
        matrix = adata.layers[feature_layer]
        means = np.mean(matrix, axis=0)
        feature_peak_columns = []

        # 2️⃣ Peak calling
        peak_centers, _ = find_peaks(means, prominence=peak_prominence, distance=min_distance)
        adata.uns[f'{feature_layer} peak_centers'] = peak_centers

        # 3️⃣ Plot
        plt.figure(figsize=(6, 3))
        plt.plot(range(len(means)), means)
        plt.title(f"{feature_layer} density with peak calls")
        plt.xlabel("Genomic position")
        plt.ylabel("Mean feature density")
        y = max(means) / 2
        for i, center in enumerate(peak_centers):
            plus_minus_width = peak_width // 2
            start = center - plus_minus_width
            end = center + plus_minus_width
            plt.axvspan(start, end, color='purple', alpha=0.2)
            plt.axvline(center, color='red', linestyle='--')
            if i % 2:
                aligned = [end, 'left']
            else:
                aligned = [start, 'right']
            plt.text(aligned[0], 0, f"Peak {i}\n{center}", color='red', ha=aligned[1])

        if save_plot and output_dir:
            filename = f"{output_dir}/{date_tag or 'output'}_{feature_layer}_peaks.png"
            plt.savefig(filename, bbox_inches='tight')
            print(f"Saved plot to {filename}")
        else:
            plt.show()

        # 4️⃣ Annotate peaks back into adata.obs
        for center in peak_centers:
            half_width = peak_width // 2
            start, end = center - half_width, center + half_width
            colname = f'{feature_layer}_peak_{center}'
            peak_columns.append(colname)
            feature_peak_columns.append(colname)

            adata.var[colname] = (
                (adata.var_names.astype(int) >= start) &
                (adata.var_names.astype(int) <= end)
            )

            # Feature layer intensity around peak
            mean_values = np.mean(matrix[:, start:end+1], axis=1)
            sum_values = np.sum(matrix[:, start:end+1], axis=1)
            adata.obs[f'mean_{feature_layer}_around_{center}'] = mean_values
            adata.obs[f'sum_{feature_layer}_around_{center}'] = sum_values
            adata.obs[f'{feature_layer}_present_at_{center}'] = mean_values > peak_threshold

            # Site-type based aggregation
            for site_type in site_types:
                adata.obs[f'{site_type}_sum_around_{center}'] = 0
                adata.obs[f'{site_type}_mean_around_{center}'] = np.nan

            references = adata.obs[obs_column].cat.categories
            for ref in adata.obs[obs_column].cat.categories:
                subset = adata[adata.obs[obs_column] == ref]
                for site_type in site_types:
                    mask = subset.var.get(f'{ref}_{site_type}', None)
                    if mask is not None:
                        region_mask = (subset.var_names[mask].astype(int) >= start) & (subset.var_names[mask].astype(int) <= end)
                        region = subset[:, mask].X[:, region_mask]
                        adata.obs.loc[subset.obs.index, f'{site_type}_sum_around_{center}'] = np.nansum(region, axis=1)
                        adata.obs.loc[subset.obs.index, f'{site_type}_mean_around_{center}'] = np.nanmean(region, axis=1)

        adata.var[f'is_in_any_{feature_layer}_peak'] = adata.var[feature_peak_columns].any(axis=1)
        print(f"✅ Peak annotation completed for {feature_layer} with {len(peak_centers)} peaks.")

    # Combine all peaks into a single "is_in_any_peak" column
    adata.var['is_in_any_peak'] = adata.var[peak_columns].any(axis=1)
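Illustrative usage (not part of the diff): a sketch of how the removed peak caller was configured, using the defaults read via config.get above; the HMM layer name is a hypothetical placeholder for a layer produced by apply_hmm.

    feature_configs = {
        'hmm_accessible': {        # hypothetical HMM feature layer in adata.layers
            'min_distance': 200,
            'peak_width': 200,
            'peak_prominence': 0.2,
            'peak_threshold': 0.8,
        }
    }
    call_hmm_peaks(adata, feature_configs, obs_column='Reference_strand',
                   site_types=['GpC_site', 'CpG_site'], save_plot=False)
    print(adata.uns['hmm_accessible peak_centers'])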
smftools/tools/data/__init__.py
DELETED
smftools/tools/data/anndata_data_module.py
DELETED
@@ -1,90 +0,0 @@
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
import pytorch_lightning as pl
import numpy as np
import pandas as pd

class AnnDataModule(pl.LightningDataModule):
    def __init__(self, adata, tensor_source="X", tensor_key=None, label_col="labels",
                 batch_size=64, train_frac=0.7, random_seed=42, split_col='train_val_split', split_save_path=None, load_existing_split=False,
                 inference_mode=False):
        super().__init__()
        self.adata = adata  # The adata object
        self.tensor_source = tensor_source  # X, layers, obsm
        self.tensor_key = tensor_key  # name of the layer or obsm key
        self.label_col = label_col  # name of the label column in obs
        self.batch_size = batch_size
        self.train_frac = train_frac
        self.random_seed = random_seed
        self.split_col = split_col  # Name of obs column to store "train"/"val"
        self.split_save_path = split_save_path  # Where to save the obs_names and train/test split logging
        self.load_existing_split = load_existing_split  # Whether to load from an existing split
        self.inference_mode = inference_mode  # Whether to load the AnnDataModule in inference mode.

    def setup(self, stage=None):
        # Load feature matrix
        if self.tensor_source == "X":
            X = self.adata.X
        elif self.tensor_source == "layers":
            assert self.tensor_key in self.adata.layers, f"Layer '{self.tensor_key}' not found."
            X = self.adata.layers[self.tensor_key]
        elif self.tensor_source == "obsm":
            assert self.tensor_key in self.adata.obsm, f"obsm key '{self.tensor_key}' not found."
            X = self.adata.obsm[self.tensor_key]
        else:
            raise ValueError(f"Invalid tensor_source: {self.tensor_source}")

        # Convert to tensor
        X_tensor = torch.tensor(X, dtype=torch.float32)

        if self.inference_mode:
            self.infer_dataset = TensorDataset(X_tensor)

        else:
            # Load and encode labels
            y = self.adata.obs[self.label_col]
            if y.dtype.name == 'category':
                y = y.cat.codes
            y_tensor = torch.tensor(y.values, dtype=torch.long)

            # Use existing split
            if self.load_existing_split:
                split_df = pd.read_csv(self.split_save_path, index_col=0)
                assert self.split_col in split_df.columns, f"'{self.split_col}' column missing in split file."
                self.adata.obs[self.split_col] = split_df.loc[self.adata.obs_names][self.split_col].values

            # If no split exists, create one
            if self.split_col not in self.adata.obs:
                full_dataset = TensorDataset(X_tensor, y_tensor)
                n_train = int(self.train_frac * len(full_dataset))
                n_val = len(full_dataset) - n_train
                self.train_set, self.val_set = random_split(
                    full_dataset, [n_train, n_val],
                    generator=torch.Generator().manual_seed(self.random_seed)
                )
                # Assign split labels
                split_array = np.full(len(self.adata), "val", dtype=object)
                train_idx = self.train_set.indices if hasattr(self.train_set, "indices") else self.train_set._indices
                split_array[train_idx] = "train"
                self.adata.obs[self.split_col] = split_array

                # Save to disk
                if self.split_save_path:
                    self.adata.obs[[self.split_col]].to_csv(self.split_save_path)
            else:
                split_labels = self.adata.obs[self.split_col].values
                train_mask = split_labels == "train"
                val_mask = split_labels == "val"
                self.train_set = TensorDataset(X_tensor[train_mask], y_tensor[train_mask])
                self.val_set = TensorDataset(X_tensor[val_mask], y_tensor[val_mask])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size)

    def predict_dataloader(self):
        if not self.inference_mode:
            raise RuntimeError("predict_dataloader only available in inference mode.")
        return DataLoader(self.infer_dataset, batch_size=self.batch_size)
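Illustrative usage (not part of the diff): a minimal sketch of building the removed LightningDataModule from an AnnData layer; the layer and label column names are hypothetical placeholders.

    # 70/30 train/val split recorded in adata.obs['train_val_split'] and saved to CSV
    dm = AnnDataModule(adata, tensor_source="layers", tensor_key="binary_methylation",  # hypothetical layer
                       label_col="labels", batch_size=64, train_frac=0.7,
                       split_save_path="train_val_split.csv")
    dm.setup()
    train_loader = dm.train_dataloader()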
smftools/tools/inference/__init__.py
DELETED
@@ -1 +0,0 @@
from .lightning_inference import run_lightning_inference
smftools/tools/inference/lightning_inference.py
DELETED
@@ -1,41 +0,0 @@
import torch
import pandas as pd
import numpy as np
from pytorch_lightning import Trainer

def run_lightning_inference(
    adata,
    model,
    datamodule,
    label_col="labels",
    prefix="model"
):

    # Get class labels
    if label_col in adata.obs and pd.api.types.is_categorical_dtype(adata.obs[label_col]):
        class_labels = adata.obs[label_col].cat.categories.tolist()
    else:
        raise ValueError("label_col must be a categorical column in adata.obs")

    # Run predictions
    trainer = Trainer(accelerator="auto", devices=1, logger=False, enable_checkpointing=False)
    preds = trainer.predict(model, datamodule=datamodule)
    probs = torch.cat(preds, dim=0).cpu().numpy()  # (N, C)
    pred_class_idx = probs.argmax(axis=1)
    pred_class_labels = [class_labels[i] for i in pred_class_idx]
    pred_class_probs = probs[np.arange(len(probs)), pred_class_idx]

    # Construct full prefix with label_col
    full_prefix = f"{prefix}_{label_col}"

    # Store predictions in obs
    adata.obs[f"{full_prefix}_pred"] = pred_class_idx
    adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(pred_class_labels, categories=class_labels)
    adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs

    # Per-class probabilities
    for i, class_name in enumerate(class_labels):
        adata.obs[f"{full_prefix}_prob_{class_name}"] = probs[:, i]

    # Full probability matrix in obsm
    adata.obsm[f"{full_prefix}_pred_prob_all"] = probs
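Illustrative usage (not part of the diff): a sketch of the removed inference helper, assuming a trained LightningModule (here called trained_model, hypothetical) and an AnnDataModule built in inference mode; adata.obs['labels'] must be categorical.

    infer_dm = AnnDataModule(adata, tensor_source="X", label_col="labels", inference_mode=True)
    infer_dm.setup()
    run_lightning_inference(adata, model=trained_model, datamodule=infer_dm,
                            label_col="labels", prefix="cnn")
    # Output columns follow the f"{prefix}_{label_col}" scheme used above
    print(adata.obs[["cnn_labels_pred_label", "cnn_labels_pred_prob"]].head())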
smftools/tools/models/base.py
DELETED
@@ -1,14 +0,0 @@
import torch.nn as nn
from ..utils.device import detect_device

class BaseTorchModel(nn.Module):
    """
    Minimal base class for torch models that:
    - Stores device
    - Moves model to detected device on init
    """
    def __init__(self, dropout_rate=0.2):
        super().__init__()
        self.device = detect_device()  # detects available devices
        self.dropout_rate = dropout_rate  # default dropout rate to be used in regularization.
        self.to(self.device)  # move model to device
smftools/tools/models/cnn.py
DELETED
@@ -1,34 +0,0 @@
import torch
import torch.nn as nn
from .base import BaseTorchModel

class CNNClassifier(BaseTorchModel):
    def __init__(self, input_size, num_classes, **kwargs):
        super().__init__(**kwargs)
        # Define convolutional layers
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
        # Define activation function
        self.relu = nn.ReLU()

        # Determine the flattened size dynamically
        dummy_input = torch.zeros(1, 1, input_size).to(self.device)
        with torch.no_grad():
            dummy_output = self._forward_conv(dummy_input)
        flattened_size = dummy_output.view(1, -1).shape[1]

        # Define fully connected layers
        self.fc1 = nn.Linear(flattened_size, 64)
        self.fc2 = nn.Linear(64, num_classes)

    def _forward_conv(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return x

    def forward(self, x):
        x = x.unsqueeze(1)  # [B, 1, L]
        x = self._forward_conv(x)
        x = x.view(x.size(0), -1)  # flatten
        x = self.relu(self.fc1(x))
        return self.fc2(x)
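Illustrative usage (not part of the diff): a shape check for the removed CNN classifier; the read length and class count are hypothetical.

    import torch

    model = CNNClassifier(input_size=3000, num_classes=2)   # BaseTorchModel moves it to the detected device
    x = torch.randn(8, 3000).to(model.device)               # batch of 8 reads, 3000 positions each
    logits = model(x)                                       # forward() adds the channel dim -> shape [8, 2]
    print(logits.shape)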
smftools/tools/models/lightning_base.py
DELETED
@@ -1,41 +0,0 @@
import torch
import pytorch_lightning as pl

class TorchClassifierWrapper(pl.LightningModule):
    def __init__(
        self,
        model: torch.nn.Module,
        optimizer_cls=torch.optim.AdamW,
        optimizer_kwargs=None,
        criterion_cls=torch.nn.CrossEntropyLoss,
        criterion_kwargs=None,
        lr: float = 1e-3,
    ):
        super().__init__()
        self.model = model
        self.save_hyperparameters(ignore=['model'])  # logs all except actual model instance
        self.optimizer_cls = optimizer_cls
        self.optimizer_kwargs = optimizer_kwargs or {}
        self.criterion = criterion_cls(**(criterion_kwargs or {}))
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log_dict({"val_loss": loss, "val_acc": acc}, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return self.optimizer_cls(self.parameters(), lr=self.lr, **self.optimizer_kwargs)
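Illustrative usage (not part of the diff): wrapping a plain torch classifier in the removed Lightning wrapper and fitting it with a Trainer; dm stands for an AnnDataModule as sketched earlier and the hyperparameters are hypothetical.

    import pytorch_lightning as pl

    wrapper = TorchClassifierWrapper(CNNClassifier(input_size=3000, num_classes=2), lr=1e-3)
    trainer = pl.Trainer(max_epochs=5, accelerator="auto", devices=1)
    trainer.fit(wrapper, datamodule=dm)   # dm: an AnnDataModule providing train/val dataloaders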
smftools/tools/models/mlp.py
DELETED
@@ -1,17 +0,0 @@
import torch
import torch.nn as nn
from .base import BaseTorchModel

class MLPClassifier(BaseTorchModel):
    def __init__(self, input_dim, num_classes, hidden_sizes=(128, 64), **kwargs):
        super().__init__(**kwargs)
        layers = []
        prev = input_dim
        for h in hidden_sizes:
            layers.extend([nn.Linear(prev, h), nn.ReLU(), nn.Dropout(self.dropout_rate)])
            prev = h
        layers.append(nn.Linear(prev, num_classes))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
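Illustrative usage (not part of the diff): the MLP counterpart takes the flattened read vector directly; the sizes are hypothetical.

    import torch

    mlp = MLPClassifier(input_dim=3000, num_classes=2, hidden_sizes=(128, 64))
    logits = mlp(torch.randn(4, 3000).to(mlp.device))   # -> shape [4, 2]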