eplacer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eplacer/__init__.py ADDED
File without changes
eplacer/__main__.py ADDED
@@ -0,0 +1,176 @@
1
+ #! /usr/bin/env python
2
+ '''
3
+ Usage:
4
+ eplacer [--version] [--help] <command> [<args>...]
5
+
6
+ Options:
7
+ -h, --help Generate Help Screen
8
+ -v, --version Get Version Number
9
+
10
+ General Commands:
11
+ train-model Trains a convolutional neural network to perform
12
+ a classification task to a specific taxonomic
13
+ group.
14
+ Options for classification are the following
15
+ - sequence-only
16
+ - sequence-geo
17
+ run-model Runs the convolutional neural network on new
18
+ data to assign taxonomy
19
+
20
+ See 'eplacer <command> --help' for more information on a command
21
+ '''
22
+
23
+
24
+ import sys
25
+ import os
26
+ from docopt import docopt
27
+ import multiprocessing
28
+
29
def _package_version():
    """Return the installed ``eplacer`` distribution version, or ``'unknown'``.

    docopt only answers ``--version`` when it is given a non-empty version
    string, so passing ``''`` disables the flag advertised in the usage text.
    """
    try:
        from importlib.metadata import PackageNotFoundError, version
    except ImportError:  # interpreter without importlib.metadata
        return 'unknown'
    try:
        return version('eplacer')
    except PackageNotFoundError:
        return 'unknown'


def _check_output_dir(args, default_dir):
    """Fill in the default ``--out`` and refuse to reuse an existing path.

    Mutates ``args`` in place. Raises ``Exception`` when the output path
    already exists and ``--force`` was not given.
    """
    if not args['--out']:
        args['--out'] = default_dir
    # `== False` is kept on purpose: it mirrors the original comparison for
    # whatever value docopt stores under '--force'.
    if os.path.exists(args['--out']) and args['--force'] == False:  # noqa: E712
        raise Exception('The path already exists! Exiting...\n')


def _require_existing(path, missing_msg, unspecified_msg):
    """Raise ``Exception`` unless ``path`` was supplied and exists on disk."""
    if not path:
        raise Exception(unspecified_msg)
    if not os.path.exists(path):
        raise Exception(missing_msg)


def _parse_option(raw, default, cast):
    """Return ``cast(raw)`` when the option was supplied, else ``default``."""
    return cast(raw) if raw else default


def _train_model(argv):
    """Validate the ``train-model`` options and launch training."""
    import eplacer.train_command

    args = docopt(eplacer.train_command.__doc__, argv=argv)

    _check_output_dir(args, "data/models/")
    _require_existing(args['--fasta'],
                      'No fasta file exists at this location! Exiting...\n',
                      'No fasta file specified! Exiting...\n')
    _require_existing(args['--taxa'],
                      'No taxa file exists at this location! Exiting...\n',
                      'No taxa file specified! Exiting...\n')

    # Default to the species level. Which may or may not work
    if not args['--taxlevel']:
        args['--taxlevel'] = "SPECIES"

    # The presence of geographic data selects the training mode.
    use_geo = bool(args['--geoData'])
    if use_geo:
        if not os.path.exists(args['--geoData']):
            raise Exception("GeoData path does not exist! Exiting\n")
        sys.stdout.write("Setting mode to train on "
                         "sequence and geographic data\n")
        kernel = _parse_option(args['--kernel'], 3, int)
        sigma = _parse_option(args['--sigma'], 1, int)
        precision = _parse_option(args['--precision'], 3, float)
    else:
        sys.stdout.write("Setting mode to train on "
                         "sequence data only\n")

    maskrate = _parse_option(args['--maskrate'], 0, float)
    augments = _parse_option(args['--augments'], 5, int)
    print("augments: ", augments)

    # Imported explicitly (and late, so option errors surface first):
    # importing eplacer.train_command alone does not guarantee that the
    # eplacer.train_evaluate submodule is loaded.
    import eplacer.train_evaluate

    if use_geo:
        eplacer.train_evaluate.train_sequenceOBIS(
            args['--fasta'], args['--taxa'], args['--taxlevel'],
            args['--out'], args['--geoData'],
            augments, maskrate, sigma, kernel, precision)
    else:
        # Pass the parsed/defaulted values, as the geo branch does, rather
        # than the raw option strings (which are None when omitted).
        eplacer.train_evaluate.train_sequence(
            args['--fasta'], args['--taxa'], args['--taxlevel'],
            args['--out'], augments, maskrate)


def _run_model(argv):
    """Validate the ``run-model`` options and run taxonomy assignment."""
    import eplacer.run_command
    import eplacer.run_model

    args = docopt(eplacer.run_command.__doc__, argv=argv)

    _check_output_dir(args, "result/models/")
    _require_existing(args['--fasta'],
                      'No fasta file exists at this location! Exiting...\n',
                      'No fasta file specified! Exiting...\n')
    _require_existing(args['--blast'],
                      'No blast result file exists at this location! Exiting...\n',
                      'No blast result file specified! Exiting...\n')
    _require_existing(args['--model'],
                      'No model file exists at this location! Exiting...\n',
                      'No model file specified! Exiting...\n')

    # NOTE(review): the key file name is fixed to SPECIES regardless of
    # --taxlevel -- confirm this is intended.
    args['--taxfile'] = str(args['--model']) + "/taxa_key_SPECIES.tsv"
    if not os.path.exists(args['--taxfile']):
        raise Exception('No taxfile file exists at this location! Your model directory may be corrupted. Exiting...\n')

    if not args['--taxlevel']:
        args['--taxlevel'] = "SPECIES"

    # Always hand an int thread count downstream, capped at the CPU count.
    cpu_count = multiprocessing.cpu_count()
    if args['--threads']:
        requested = int(args['--threads'])
        if requested > cpu_count:
            requested = cpu_count
            print(f"Too many threads requested. setting to {cpu_count}")
        args['--threads'] = requested
    else:
        args['--threads'] = cpu_count
        print(f"No threads requested. setting to {cpu_count}")

    maskrate = _parse_option(args['--maskrate'], 0, float)

    # Geographic data is mandatory for running the model.
    if not args['--geoData']:
        raise Exception('No geoData available. Exiting...\n')
    if not os.path.exists(args['--geoData']):
        raise Exception("GeoData path does not exist! Exiting\n")
    sys.stdout.write("Setting mode to train on "
                     "sequence and geographic data\n")
    _require_existing(args['--counts'],
                      'The path to the count matrix does not exist! Exiting\n',
                      'Abundance matrix not specified! Exiting\n')
    kernel = _parse_option(args['--kernel'], 3, int)
    sigma = _parse_option(args['--sigma'], 1, int)
    # Keep the type consistent with the float default.
    args['--confidence'] = _parse_option(args['--confidence'], 0.9, float)

    eplacer.run_model.gen_model_output_OBIS(
        args['--confidence'], args['--blast'], args['--out'],
        args['--fasta'], args['--model'], args['--taxlevel'],
        args['--taxfile'], args['--geoData'], args['--counts'],
        maskrate, sigma, kernel, args['--threads'])


def main():
    """Command-line entry point: dispatch to ``train-model`` or ``run-model``.

    The top-level usage text is parsed with ``options_first=True`` so that
    everything after ``<command>`` is forwarded untouched to the
    sub-command's own docopt parser.

    Raises:
        Exception: when a required input is missing or an output path
            already exists without ``--force``.
        SystemExit: after ``train-model`` completes (as before), or when the
            command is not recognised.
    """
    args = docopt(__doc__,
                  version=_package_version(),
                  options_first=True)
    command = args['<command>']
    argv = [command] + args['<args>']
    if command == 'train-model':
        _train_model(argv)
        # sys.exit instead of the site-provided exit() builtin, which is not
        # guaranteed to exist (e.g. under `python -S`).
        sys.exit()
    elif command == 'run-model':
        _run_model(argv)
    else:
        sys.exit(f"'{command}' is not an eplacer command. "
                 "See 'eplacer --help'.")
176
+
eplacer/data_prep.py ADDED
@@ -0,0 +1,97 @@
1
+ """
2
+ This script defines some useful code for prepping datasets
3
+ for ePlacer
4
+
5
+ Author: Christopher Powers
6
+ Institution: NOAA NEFSC PEMAD PBB
7
+ """
8
+
9
+
10
+ import numpy as np
11
+ from torch.utils.data import Dataset
12
+ import torch
13
+ import numpy as np
14
+
15
def get_degenerate_bases():
    """
    Build the IUPAC nucleotide lookup table.

    Each key is a single-letter code and each value is the list of symbols
    that code stands for. A fresh dictionary is returned on every call.
    """
    expansions = (
        ('A', 'A'),
        ('C', 'C'),
        ('G', 'G'),
        ('T', 'T'),
        ('R', 'AG'),   # Purine
        ('Y', 'CT'),   # Pyrimidine
        ('M', 'AC'),   # Amino
        ('K', 'GT'),   # Keto
        ('S', 'CG'),   # Strong
        ('W', 'AT'),   # Weak
        ('H', 'ACT'),  # not G
        ('B', 'CGT'),  # not A
        ('V', 'ACG'),  # not T
        ('D', 'AGT'),  # not C
        ('N', '-'),    # any base is not informative. Encode as gap
        ('-', '-'),    # gap
    )
    return {code: list(members) for code, members in expansions}
37
+
38
def encode_onehot(seq, mask_token = "N"):
    """
    Encode one sequence as a (len(seq), 4) array over the columns A, C, G, T.

    Canonical bases become unit rows, the gap "-" becomes an all-zero row and
    ``mask_token`` becomes a uniform 0.25 row. Every other code listed by
    ``get_degenerate_bases`` is encoded as the sum of its members' rows
    divided by the number of members. Symbols that appear in neither table
    (lowercase letters included) are encoded as all zeros.

    Args:
        seq: iterable of single-character symbols (typically a string).
        mask_token: symbol to encode as the uniform row.

    Returns:
        numpy float array of shape (len(seq), 4); an empty sequence yields
        shape (0, 4) rather than a rank-1 empty array.
    """
    mapping = {
        "A": [1., 0., 0., 0.],
        "C": [0., 1., 0., 0.],
        "G": [0., 0., 1., 0.],
        "T": [0., 0., 0., 1.],
        "-": [0., 0., 0., 0.],
        mask_token: [0.25, 0.25, 0.25, 0.25]
    }

    # Derive rows for the degenerate codes that are not already defined
    # above (so the mask token keeps its uniform row).
    for base, possibilities in get_degenerate_bases().items():
        if base in mapping or len(possibilities) == 0:
            continue
        avg_encoding = np.zeros(4)
        for canonical_base in possibilities:
            if canonical_base in mapping:
                avg_encoding += np.array(mapping[canonical_base])
        mapping[base] = (avg_encoding / len(possibilities)).tolist()

    unknown = [0., 0., 0., 0.]
    rows = [mapping.get(base, unknown) for base in seq]
    # reshape keeps the result rank-2 even when `seq` is empty.
    return np.array(rows, dtype=float).reshape(-1, 4)
64
+
65
class SeqGeoDataset(Dataset):
    """
    Dataset that stores the one-hot encoded data alongside
    the geographic data

    Each item is the tuple ``(sequence, geo, label)`` taken from the
    pre-built tensors ``ohe_seqs``, ``ohe_geo`` and ``labels``.
    """
    def __init__(self, sequences, labels, geo_data):
        """
        Copy the inputs and build the tensors once, up front.

        Args:
            sequences: indexable collection of sequences accepted by
                ``encode_onehot``. They are stacked, so all encodings must
                share one shape.
            labels: indexable collection of numeric class labels.
            geo_data: indexable collection of numpy arrays (they are passed
                to ``torch.from_numpy``), all of one shape.
        """
        # Positional indexing (rather than plain iteration) is kept on
        # purpose so containers with custom __getitem__ behave as before.
        self.seqs = [sequences[i] for i in range(len(sequences))]
        self.taxa_labels = [labels[i] for i in range(len(labels))]
        self.geo = [geo_data[i] for i in range(len(geo_data))]

        # Each tensor is built exactly once; the original encoded every
        # sequence twice.
        self.ohe_seqs = torch.stack([torch.from_numpy(encode_onehot(x)).float() for x in self.seqs])
        self.ohe_geo = torch.stack([torch.from_numpy(x).float() for x in self.geo])
        self.labels = torch.Tensor(self.taxa_labels).long()

    def __len__(self):
        """Return the number of stored sequences."""
        return len(self.seqs)

    def __getitem__(self, idx):
        """Return ``(one-hot sequence, geo vector, label)`` for ``idx``."""
        seq = self.ohe_seqs[idx]
        label = self.labels[idx]
        geo = self.ohe_geo[idx]

        return seq, geo, label
95
+
96
+
97
+
eplacer/external.py ADDED
@@ -0,0 +1,60 @@
1
+ """
2
+ This script runs mafft as a subprocess and generates
3
+ an alignment.
4
+
5
+ Author: Christopher Powers
6
+ Institution: NOAA NEFSC PEMAD PBB
7
+ """
8
+
9
+
10
+ import subprocess
11
+
12
def run_mafft(input, reference, moutput, subset_output, threads):
    """
    Run the MAFFT alignment to add your sequences to a new fasta

    MAFFT is invoked as ``mafft --add <input> --adjustdirection --thread
    <threads> --keeplength --reorder <reference>`` with its standard output
    written to ``moutput``. The alignment is then parsed, and the records
    whose header matches a header of ``input`` are written to
    ``subset_output``.

    Args:
        input: path of the FASTA file with the sequences to add.
        reference: path of the existing reference alignment.
        moutput: path the full MAFFT alignment is written to.
        subset_output: path the aligned query records are written to.
        threads: thread count passed to ``--thread``.

    Returns:
        dict mapping every header in ``moutput`` (without ">" and without a
        leading "_R_" marker) to its upper-cased aligned sequence.

    A MAFFT failure is reported on stdout and parsing continues with
    whatever ``moutput`` holds, as before.
    """

    # Get initial IDs (a set: membership is tested once per aligned record)
    with open(input, "r") as infile:
        names = {line[1:].rstrip() for line in infile if line.startswith(">")}

    # Argument list + explicit stdout redirection instead of a shell string:
    # paths with spaces or shell metacharacters are passed through verbatim,
    # and a missing mafft binary now really raises FileNotFoundError.
    command = ["mafft", "--add", input, "--adjustdirection",
               "--thread", str(threads), "--keeplength", "--reorder",
               reference]
    with open(moutput, "w") as aligned:
        try:
            print("Beginning subprocess...")
            print("Aligning with mafft...")
            subprocess.run(command, stdout=aligned, check=True)
        except subprocess.CalledProcessError as e:
            print(f"MAFFT exection failed with error: {e}")
        except FileNotFoundError:
            print("MAFFT not found. Is it installed/in the path?")

    # subset the new file
    key = ''
    chunks = []
    seqdict = {}

    with open(moutput, "r") as infile:
        for line in infile:
            line = line.rstrip()
            if line.startswith(">"):
                if key != '':
                    seqdict[key] = "".join(chunks).upper()
                chunks = []
                # Drop the "_R_" marker so the key matches the original ID.
                key = line[4:] if line.startswith(">_R_") else line[1:]
            else:
                chunks.append(line)
    # Flush the final record; without this the last sequence of the
    # alignment was silently dropped.
    if key != '':
        seqdict[key] = "".join(chunks).upper()

    with open(subset_output, "w") as outfile:
        for s in seqdict:
            if s in names:
                outfile.write(f">{s}\n{seqdict[s]}\n")

    return seqdict
60
+
@@ -0,0 +1,133 @@
1
+ """
2
+ This module provides several utilities in order
3
+ to represent the distribution of a species geographically.
4
+
5
+ Author: Christopher Powers
6
+ Institution: NOAA NEFSC PEMAD PBB
7
+ """
8
+
9
+ import numpy as np
10
+ import pygeohash
11
+ from collections import defaultdict
12
+ from scipy.spatial import cKDTree
13
+
14
class SpeciesGeoEncoder:
    """Encode species occurrence coordinates as vectors over a fixed grid.

    The grid is a regular lat/lon lattice. Each occurrence is assigned to
    its nearest lattice point through a KD-tree, and a species is
    represented by the normalised occurrence counts per lattice point.
    """

    # Normalised cell frequencies below this value are zeroed in
    # encode_species. NOTE(review): the original comment calls this "1/100
    # of the null distribution" -- confirm the derivation.
    MIN_FRACTION = 0.00165

    def __init__(self, precision=3, min_lat=-90, max_lat=90, min_lon=-180, max_lon=180):
        """
        Initialize the encoder with a fixed geographic grid

        Args:
            precision: geohash precision; it also sets the grid size, with
                ``int(32 * precision / 2)`` divisions along each axis.
            min_lat, max_lat, min_lon, max_lon: bounds of the grid.
        """
        self.precision = precision
        self.min_lat = min_lat
        self.max_lat = max_lat
        self.min_lon = min_lon
        self.max_lon = max_lon
        self.total_precision = 32*precision
        self.lat_divisions = int(self.total_precision/2)
        self.lon_divisions = int(self.total_precision/2)

        # Create fixed grid of geohashes and corresponding lat/lon points
        self.grid_geohashes, self.grid_points = self._create_geohash_grid()
        self.feature_dimension = len(self.grid_geohashes)

        # Create KD-tree for efficient nearest neighbor search
        self.kdtree = cKDTree(self.grid_points)

        # Create mapping from geohash to index for faster lookup
        # (when two grid points share a geohash, the later index wins).
        self.geohash_to_index = {ghash: idx for idx, ghash in enumerate(self.grid_geohashes)}

    def _create_geohash_grid(self):
        """Create a fixed grid of geohashes covering the area of interest

        Returns:
            (geohashes, points): a list with one geohash per grid point and
            an array of the matching ``[lat, lon]`` rows, ordered with
            latitude as the outer loop.
        """
        geohashes = []
        points = []

        # Create evenly spaced grid points
        lats = np.linspace(self.min_lat, self.max_lat, self.lat_divisions)
        lons = np.linspace(self.min_lon, self.max_lon, self.lon_divisions)

        # Generate geohash for each grid point
        for lat in lats:
            for lon in lons:
                ghash = pygeohash.encode(lat, lon, precision=int(self.precision))
                geohashes.append(ghash)
                points.append([lat, lon])
        return geohashes, np.array(points)

    def encode_species(self, species_locations):
        """
        Encode species location data using the fixed geohash grid

        Args:
            species_locations: mapping of species -> sequence (list or numpy
                array) of ``[lat, lon]`` points; may be empty or None.

        Returns:
            defaultdict mapping species -> vector of length
            ``feature_dimension``. The vector holds normalised occurrence
            counts with entries below ``MIN_FRACTION`` set to zero (it is
            not renormalised afterwards). Unknown species read as zeros.
        """
        encoded_data = defaultdict(lambda:np.zeros(self.feature_dimension))

        for species, locations in species_locations.items():
            # Initialize array for counts
            counts = np.zeros(self.feature_dimension)

            # len() instead of truthiness: `not array` raises for numpy
            # arrays holding more than one element.
            if locations is None or len(locations) == 0:
                encoded_data[species] = counts
                continue

            # Convert locations to numpy array for batch processing
            locations_array = np.array(locations)
            # Find nearest grid points for all locations at once
            _, indices = self.kdtree.query(locations_array)

            # Count occurrences using numpy
            unique_indices, occurrence_counts = np.unique(indices, return_counts=True)
            counts[unique_indices] = occurrence_counts

            # Normalize counts
            total = np.sum(counts)
            if total > 0:
                counts = counts / total
            print(species)
            print(counts[np.nonzero(counts)])
            # Drop cells whose frequency falls below the threshold.
            counts[counts < self.MIN_FRACTION] = 0
            print(counts[np.nonzero(counts)])
            encoded_data[species] = counts

        return encoded_data

    def get_feature_dimension(self):
        """Return the fixed dimension of the feature space"""
        return self.feature_dimension

    def save_grid(self, filepath):
        """Save the grid information to a file

        The grid is stored as a pickled dict through ``np.save``. numpy
        appends ``.npy`` to ``filepath`` when it lacks that extension.
        """
        grid_info = {
            'precision': self.precision,
            'min_lat': self.min_lat,
            'max_lat': self.max_lat,
            'min_lon': self.min_lon,
            'max_lon': self.max_lon,
            'lat_divisions': self.lat_divisions,
            'lon_divisions': self.lon_divisions,
            'grid_geohashes': self.grid_geohashes,
            'grid_points': self.grid_points,
            'feature_dimension': self.feature_dimension
        }
        print("Saving Grid Info: ")
        for i in grid_info:
            if i != "grid_geohashes":
                print("{}:{}".format(i, grid_info[i]))
        np.save(filepath, grid_info)

    @classmethod
    def load_grid(cls, filepath):
        """Load a saved grid

        Rebuilds the encoder from the stored bounds and precision and checks
        that the regenerated grid points equal the stored ones.

        Raises:
            FileNotFoundError: if neither ``filepath`` nor ``filepath``
                with ``.npy`` appended exists.
            AssertionError: if the regenerated grid differs from the file.
        """
        # SECURITY: allow_pickle=True unpickles arbitrary objects; only load
        # grid files from a trusted source.
        try:
            stored = np.load(filepath, allow_pickle=True)
        except FileNotFoundError:
            # save_grid() goes through np.save, which appends ".npy" when it
            # is missing, so retry with the name that was actually written.
            if str(filepath).endswith(".npy"):
                raise
            stored = np.load(f"{filepath}.npy", allow_pickle=True)
        grid_info = stored.item()
        encoder = cls(
            precision=grid_info['precision'],
            min_lat=grid_info['min_lat'],
            max_lat=grid_info['max_lat'],
            min_lon=grid_info['min_lon'],
            max_lon=grid_info['max_lon']
        )
        # Verify grid matches. Raised explicitly (same exception type as the
        # former assert) so the check survives `python -O`.
        if not np.array_equal(encoder.grid_points, grid_info['grid_points']):
            raise AssertionError(
                "Saved grid points do not match the regenerated grid")
        return encoder