eplacer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eplacer/__init__.py +0 -0
- eplacer/__main__.py +176 -0
- eplacer/data_prep.py +97 -0
- eplacer/external.py +60 -0
- eplacer/geographicRep.py +133 -0
- eplacer/models.py +234 -0
- eplacer/run_command.py +62 -0
- eplacer/run_model.py +610 -0
- eplacer/train_command.py +53 -0
- eplacer/train_evaluate.py +478 -0
- eplacer-0.1.0.dist-info/METADATA +143 -0
- eplacer-0.1.0.dist-info/RECORD +16 -0
- eplacer-0.1.0.dist-info/WHEEL +5 -0
- eplacer-0.1.0.dist-info/entry_points.txt +2 -0
- eplacer-0.1.0.dist-info/licenses/LICENSE.txt +5 -0
- eplacer-0.1.0.dist-info/top_level.txt +1 -0
eplacer/__init__.py
ADDED
|
File without changes
|
eplacer/__main__.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
#! /usr/bin/env python
|
|
2
|
+
'''
|
|
3
|
+
Usage:
|
|
4
|
+
eplacer [--version] [--help] <command> [<args>...]
|
|
5
|
+
|
|
6
|
+
Options:
|
|
7
|
+
-h, --help Generate Help Screen
|
|
8
|
+
-v, --version Get Version Number
|
|
9
|
+
|
|
10
|
+
General Commands:
|
|
11
|
+
train-model Trains a convolutional neural network to perform
|
|
12
|
+
a classification task to a specific taxonomic
|
|
13
|
+
group.
|
|
14
|
+
Options for classification are the following
|
|
15
|
+
- sequence-only
|
|
16
|
+
- sequence-geo
|
|
17
|
+
run-model Runs the convolutional neural network on new
|
|
18
|
+
data to assign taxonomy
|
|
19
|
+
|
|
20
|
+
See 'eplacer <command> --help' for more information on a command
|
|
21
|
+
'''
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
import sys
|
|
25
|
+
import os
|
|
26
|
+
from docopt import docopt
|
|
27
|
+
import multiprocessing
|
|
28
|
+
|
|
29
|
+
def _package_version():
    """Return the installed eplacer version string, or '' when it cannot be determined."""
    try:
        from importlib.metadata import PackageNotFoundError, version
    except ImportError:  # Python < 3.8 has no importlib.metadata
        return ''
    try:
        return version("eplacer")
    except PackageNotFoundError:
        return ''


def _resolve_out_dir(args, default):
    """
    Fill in ``--out`` with *default* when it was not given, then refuse to reuse
    an existing path unless ``--force`` was passed.

    Raises:
        Exception: if the output path already exists and ``--force`` is not set.
    """
    if not args['--out']:
        args['--out'] = default
    if os.path.exists(args['--out']) and not args['--force']:
        raise Exception('The path already exists! Exiting...\n')


def _require_path(value, unspecified_msg, missing_msg):
    """
    Check a required path option.

    Raises:
        Exception: with *unspecified_msg* if *value* is empty/None, or with
            *missing_msg* if it names a path that does not exist.
    """
    if not value:
        raise Exception(unspecified_msg)
    if not os.path.exists(value):
        raise Exception(missing_msg)


def _int_or(value, default):
    """Return ``int(value)``, or *default* when the option was not supplied."""
    return int(value) if value else default


def _float_or(value, default):
    """Return ``float(value)``, or *default* when the option was not supplied."""
    return float(value) if value else default


def _resolve_threads(requested):
    """
    Turn the ``--threads`` option into an int, capped at the machine's CPU count.

    When no value is given, all CPUs are used.
    """
    cpu_count = multiprocessing.cpu_count()
    if not requested:
        print(f"No threads requested. setting to {cpu_count}")
        return cpu_count
    threads = int(requested)
    if threads > cpu_count:
        print(f"Too many threads requested. setting to {cpu_count}")
        return cpu_count
    return threads


def _train_model(argv):
    """Validate ``train-model`` arguments, fill in defaults and start training."""
    import eplacer.train_command

    args = docopt(eplacer.train_command.__doc__, argv=argv)
    _resolve_out_dir(args, "data/models/")
    _require_path(args['--fasta'],
                  'No fasta file specified! Exiting...\n',
                  'No fasta file exists at this location! Exiting...\n')
    _require_path(args['--taxa'],
                  'No taxa file specified! Exiting...\n',
                  'No taxa file exists at this location! Exiting...\n')
    # Default to the species level, which may or may not suit the data.
    if not args['--taxlevel']:
        args['--taxlevel'] = "SPECIES"

    # Supplying --geoData selects the sequence + geography model.
    if args['--geoData']:
        if not os.path.exists(args['--geoData']):
            raise Exception("GeoData path does not exist! Exiting\n")
        sys.stdout.write("Setting mode to train on "
                         "sequence and geographic data\n")
        mode = 'sequence_geo'
        kernel = _int_or(args['--kernel'], 3)
        sigma = _int_or(args['--sigma'], 1)
        precision = _float_or(args['--precision'], 3)
    else:
        sys.stdout.write("Setting mode to train on "
                         "sequence data only\n")
        mode = 'sequence'

    maskrate = _float_or(args['--maskrate'], 0)
    augments = _int_or(args['--augments'], 5)
    print("augments: ", augments)

    # Deferred until after argument validation; only needed to actually train.
    import eplacer.train_evaluate

    if mode == 'sequence':
        # Pass the parsed values (with defaults applied), not the raw option strings.
        eplacer.train_evaluate.train_sequence(
            args['--fasta'], args['--taxa'], args['--taxlevel'], args['--out'],
            augments, maskrate)
    else:
        eplacer.train_evaluate.train_sequenceOBIS(
            args['--fasta'], args['--taxa'], args['--taxlevel'], args['--out'],
            args['--geoData'], augments, maskrate, sigma, kernel, precision)


def _run_model(argv):
    """Validate ``run-model`` arguments, fill in defaults and run the trained model."""
    import eplacer.run_command
    import eplacer.run_model

    args = docopt(eplacer.run_command.__doc__, argv=argv)
    _resolve_out_dir(args, "result/models/")
    _require_path(args['--fasta'],
                  'No fasta file specified! Exiting...\n',
                  'No fasta file exists at this location! Exiting...\n')
    _require_path(args['--blast'],
                  'No blast result file specified! Exiting...\n',
                  'No blast result file exists at this location! Exiting...\n')
    _require_path(args['--model'],
                  'No model file specified! Exiting...\n',
                  'No model file exists at this location! Exiting...\n')
    # NOTE: the taxon key file is always the SPECIES one, independent of --taxlevel.
    args['--taxfile'] = os.path.join(str(args['--model']), "taxa_key_SPECIES.tsv")
    if not os.path.exists(args['--taxfile']):
        raise Exception('No taxfile file exists at this location! Your model directory may be corrupted. Exiting...\n')
    if not args['--taxlevel']:
        args['--taxlevel'] = "SPECIES"
    args['--threads'] = _resolve_threads(args['--threads'])
    maskrate = _float_or(args['--maskrate'], 0)

    # Only the sequence + geography model can be run; geoData is mandatory.
    if not args['--geoData']:
        raise Exception('No geoData available. Exiting...\n')
    if not os.path.exists(args['--geoData']):
        raise Exception("GeoData path does not exist! Exiting\n")
    sys.stdout.write("Setting mode to run on "
                     "sequence and geographic data\n")
    _require_path(args['--counts'],
                  'Abundance matrix not specified! Exiting\n',
                  'The path to the count matrix does not exist! Exiting\n')
    kernel = _int_or(args['--kernel'], 3)
    sigma = _int_or(args['--sigma'], 1)
    # Convert a user-supplied value so it has the same type as the 0.9 default.
    args['--confidence'] = _float_or(args['--confidence'], 0.9)

    eplacer.run_model.gen_model_output_OBIS(
        args['--confidence'], args['--blast'], args['--out'], args['--fasta'],
        args['--model'], args['--taxlevel'], args['--taxfile'], args['--geoData'],
        args['--counts'], maskrate, sigma, kernel, args['--threads'])


def main():
    """
    Command-line entry point: dispatch to the ``train-model`` or ``run-model`` sub-command.

    The top-level usage text is this module's docstring; each sub-command parses
    its own docopt usage from ``eplacer.train_command`` / ``eplacer.run_command``.
    Invalid or missing inputs raise ``Exception`` with a message ending in "Exiting".
    An unknown command exits with status 1 and an error message.
    """
    args = docopt(__doc__,
                  version=_package_version(),
                  options_first=True)
    command = args['<command>']
    argv = [command] + args['<args>']

    if command == 'train-model':
        _train_model(argv)
        sys.exit()
    elif command == 'run-model':
        _run_model(argv)
    else:
        sys.exit(f"Unknown command '{command}'. See 'eplacer --help'.")


if __name__ == "__main__":
    main()
|
|
176
|
+
|
eplacer/data_prep.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This script defines some useful code for prepping datasets
|
|
3
|
+
for ePlacer
|
|
4
|
+
|
|
5
|
+
Author: Christopher Powers
|
|
6
|
+
Institution: NOAA NEFSC PEMAD PBB
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
from torch.utils.data import Dataset
|
|
12
|
+
import torch
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
def get_degenerate_bases():
    """
    Return a dictionary mapping IUPAC nucleotide codes to their canonical bases.

    Ambiguous codes map to every base they may stand for. ``N`` is deliberately
    mapped to the gap symbol ``-`` (treated as uninformative) rather than to all
    four bases. A new dictionary is built on every call.
    """
    return {
        'A': ['A'],
        'C': ['C'],
        'G': ['G'],
        'T': ['T'],
        'R': ['A', 'G'],  # Purine
        'Y': ['C', 'T'],  # Pyrimidine
        'M': ['A', 'C'],  # Amino
        'K': ['G', 'T'],  # Keto
        'S': ['C', 'G'],  # Strong
        'W': ['A', 'T'],  # Weak
        'H': ['A', 'C', 'T'],  # not G
        'B': ['C', 'G', 'T'],  # not A
        'V': ['A', 'C', 'G'],  # not T
        'D': ['A', 'G', 'T'],  # not C
        'N': ['-'],  # any base is not informative. Encode as gap
        '-': ['-'],  # gap
    }


# Per-mask-token cache of base -> 4-vector lookup tables, so the table is built
# once instead of on every encode_onehot call.
_ONEHOT_MAPPINGS = {}


def _onehot_mapping(mask_token):
    """Return (building and caching on first use) the lookup table for *mask_token*."""
    mapping = _ONEHOT_MAPPINGS.get(mask_token)
    if mapping is not None:
        return mapping

    mapping = {
        "A": (1., 0., 0., 0.),
        "C": (0., 1., 0., 0.),
        "G": (0., 0., 1., 0.),
        "T": (0., 0., 0., 1.),
        "-": (0., 0., 0., 0.),
        mask_token: (0.25, 0.25, 0.25, 0.25),
    }
    # Degenerate codes get the mean of their canonical bases' encodings
    # (e.g. R -> (A + G) / 2). Codes already present (canonical bases, gap,
    # the mask token) keep their explicit encoding.
    for base, possibilities in get_degenerate_bases().items():
        if base in mapping or not possibilities:
            continue
        total = np.zeros(4)
        for canonical_base in possibilities:
            if canonical_base in mapping:
                total += np.asarray(mapping[canonical_base])
        mapping[base] = tuple((total / len(possibilities)).tolist())

    _ONEHOT_MAPPINGS[mask_token] = mapping
    return mapping


def encode_onehot(seq, mask_token="N"):
    """
    One-hot encode a nucleotide sequence.

    Canonical bases A/C/G/T become unit vectors, gaps are all zeros, and
    *mask_token* is uniform (0.25 each). Other IUPAC degenerate codes are the
    average of their canonical bases. Any other character, including lowercase
    letters, is encoded as all zeros.

    Args:
        seq: Iterable of single-character base symbols (typically a str).
        mask_token: Symbol encoded as the uniform distribution.

    Returns:
        float64 ndarray of shape ``(len(seq), 4)``; ``(0, 4)`` for an empty sequence.
    """
    mapping = _onehot_mapping(mask_token)
    unknown = (0., 0., 0., 0.)
    rows = [mapping.get(base, unknown) for base in seq]
    # reshape keeps the 2-D (n, 4) shape even when seq is empty.
    return np.array(rows, dtype=float).reshape(len(rows), 4)
|
|
64
|
+
|
|
65
|
+
class SeqGeoDataset(Dataset):
    """
    Dataset pairing one-hot encoded sequences with per-sample geographic features.

    Each item is ``(seq, geo, label)``:

    - ``seq`` is the float32 tensor from ``encode_onehot`` (one row of 4 values per base).
    - ``geo`` is the sample's geographic feature array as float32.
    - ``label`` is an int64 class id.
    """

    def __init__(self, sequences, labels, geo_data):
        """
        Encode and stack all samples once up front.

        Args:
            sequences: Iterable of sequence strings. They must all have the same
                length so that their encodings can be stacked.
            labels: Iterable of integer class ids, one per sequence.
            geo_data: Iterable of NumPy arrays of identical shape, one per sequence.

        Raises:
            ValueError: if the three inputs do not have the same number of samples.
        """
        self.seqs = list(sequences)
        self.taxa_labels = list(labels)
        self.geo = list(geo_data)
        if not (len(self.seqs) == len(self.taxa_labels) == len(self.geo)):
            raise ValueError(
                "sequences, labels and geo_data must have the same length "
                f"(got {len(self.seqs)}, {len(self.taxa_labels)}, {len(self.geo)})")

        self.ohe_seqs = torch.stack(
            [torch.from_numpy(encode_onehot(s)).float() for s in self.seqs])
        self.ohe_geo = torch.stack([torch.from_numpy(g).float() for g in self.geo])
        # Build int64 directly; torch.Tensor(...).long() would pass through
        # float32 and lose precision for very large ids.
        self.labels = torch.tensor(self.taxa_labels, dtype=torch.long)

    def __len__(self): return len(self.seqs)

    def __getitem__(self, idx):
        """Return the ``(seq, geo, label)`` triple at position *idx*."""
        seq = self.ohe_seqs[idx]
        label = self.labels[idx]
        geo = self.ohe_geo[idx]

        return seq, geo, label
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
|
eplacer/external.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This script runs mafft as a subprocess and generates
|
|
3
|
+
an alignment.
|
|
4
|
+
|
|
5
|
+
Author: Christopher Powers
|
|
6
|
+
Institution: NOAA NEFSC PEMAD PBB
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
import subprocess
|
|
11
|
+
|
|
12
|
+
def run_mafft(input, reference, moutput, subset_output, threads):
    """
    Add the query sequences in *input* to the *reference* alignment with MAFFT.
    Then write only the (aligned) query sequences to *subset_output*.

    MAFFT is invoked, without a shell, as::

        mafft --add <input> --adjustdirection --thread <threads> --keeplength --reorder <reference>

    with its standard output written to *moutput*. Records that MAFFT renamed
    with an ``_R_`` prefix (reverse-complemented because of ``--adjustdirection``)
    have that prefix stripped, so they match the query IDs again.

    Args:
        input: Path to the query FASTA file.
        reference: Path to the reference alignment.
        moutput: Path that receives MAFFT's full output alignment.
        subset_output: Path that receives only the records whose IDs appear in *input*.
        threads: Value passed to ``--thread``.

    Returns:
        dict mapping every record ID in the MAFFT output (reference and query)
        to its upper-cased aligned sequence.

    Raises:
        FileNotFoundError: if the ``mafft`` executable cannot be found.
        subprocess.CalledProcessError: if MAFFT exits with a non-zero status.
    """
    # Query IDs, used later to pick the added sequences out of the full alignment.
    with open(input, "r") as infile:
        names = {line[1:].rstrip() for line in infile if line.startswith(">")}

    print("Beginning subprocess...")
    print("Aligning with mafft...")
    # Argument list + stdout redirection instead of a shell string: paths with
    # spaces or shell metacharacters are passed through verbatim.
    command = ["mafft", "--add", input, "--adjustdirection", "--thread", str(threads),
               "--keeplength", "--reorder", reference]
    with open(moutput, "w") as aligned:
        try:
            subprocess.run(command, stdout=aligned, check=True)
        except subprocess.CalledProcessError as e:
            print(f"MAFFT execution failed with error: {e}")
            raise
        except FileNotFoundError:
            print("MAFFT not found. Is it installed/in the path?")
            raise

    # Parse the full alignment (multi-line FASTA) into id -> sequence.
    seqdict = {}
    key = None
    chunks = []
    with open(moutput, "r") as infile:
        for line in infile:
            line = line.rstrip()
            if line.startswith(">"):
                if key is not None:
                    seqdict[key] = "".join(chunks).upper()
                key = line[4:] if line.startswith(">_R_") else line[1:]
                chunks = []
            else:
                chunks.append(line)
    # Store the final record too; it has no following header to trigger the save.
    if key is not None:
        seqdict[key] = "".join(chunks).upper()

    with open(subset_output, "w") as outfile:
        for name, seq in seqdict.items():
            if name in names:
                outfile.write(f">{name}\n{seq}\n")

    return seqdict
|
|
60
|
+
|
eplacer/geographicRep.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains helper classes used
|
|
3
|
+
to represent the distribution of a species geographically.
|
|
4
|
+
|
|
5
|
+
Author: Christopher Powers
|
|
6
|
+
Institution: NOAA NEFSC PEMAD PBB
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pygeohash
|
|
11
|
+
from collections import defaultdict
|
|
12
|
+
from scipy.spatial import cKDTree
|
|
13
|
+
|
|
14
|
+
class SpeciesGeoEncoder:
    """
    Encode species occurrence locations as frequency vectors over a fixed lat/lon grid.

    The grid is a regular lattice of ``lat_divisions x lon_divisions`` points,
    each ``int(32 * precision / 2)``, spanning ``[min_lat, max_lat]`` and
    ``[min_lon, max_lon]``. Every grid point is labelled with a geohash of
    length ``int(precision)``.

    Occurrences are snapped to their nearest grid point with a KD-tree over raw
    ``(lat, lon)`` pairs. This uses Euclidean distance in degrees and does not
    wrap around the antimeridian.
    """

    def __init__(self, precision=3, min_lat=-90, max_lat=90, min_lon=-180, max_lon=180):
        """
        Build the fixed grid, its geohash labels and a KD-tree for nearest-point lookup.

        Args:
            precision: Grid resolution knob; also used (after ``int``) as geohash length.
            min_lat, max_lat, min_lon, max_lon: Grid bounds in degrees.
        """
        self.precision = precision
        self.min_lat = min_lat
        self.max_lat = max_lat
        self.min_lon = min_lon
        self.max_lon = max_lon
        self.total_precision = 32 * precision
        self.lat_divisions = int(self.total_precision / 2)
        self.lon_divisions = int(self.total_precision / 2)

        # Create fixed grid of geohashes and corresponding lat/lon points
        self.grid_geohashes, self.grid_points = self._create_geohash_grid()
        self.feature_dimension = len(self.grid_geohashes)

        # Create KD-tree for efficient nearest neighbor search
        self.kdtree = cKDTree(self.grid_points)

        # Create mapping from geohash to index for faster lookup
        self.geohash_to_index = {ghash: idx for idx, ghash in enumerate(self.grid_geohashes)}

    def _create_geohash_grid(self):
        """Return ``(geohashes, points)`` for the evenly spaced grid, latitude-major order."""
        geohashes = []
        points = []

        lats = np.linspace(self.min_lat, self.max_lat, self.lat_divisions)
        lons = np.linspace(self.min_lon, self.max_lon, self.lon_divisions)

        for lat in lats:
            for lon in lons:
                geohashes.append(pygeohash.encode(lat, lon, precision=int(self.precision)))
                points.append([lat, lon])
        return geohashes, np.array(points)

    def encode_species(self, species_locations, min_frequency=0.00165, verbose=False):
        """
        Encode each species' occurrences as a frequency vector over the grid.

        Args:
            species_locations: Mapping of species -> sequence of ``(lat, lon)``
                pairs (list or NumPy array). Species with no locations get an
                all-zero vector.
            min_frequency: Normalized frequencies strictly below this value are
                set to zero. The default 0.00165 is the value the original code
                hard-coded; its comment calls it "1/100 of the null
                distribution", but the derivation is not shown here.
            verbose: If true, print each species and its non-zero frequencies
                before and after thresholding (debug output).

        Returns:
            ``defaultdict`` mapping species -> float array of length
            ``feature_dimension``. Missing keys yield zero vectors. After
            thresholding, a vector need not sum to 1.
        """
        encoded_data = defaultdict(lambda: np.zeros(self.feature_dimension))

        for species, locations in species_locations.items():
            # Use len() rather than truthiness so NumPy arrays are accepted as well.
            if locations is None or len(locations) == 0:
                encoded_data[species] = np.zeros(self.feature_dimension)
                continue

            # Snap every occurrence to its nearest grid point, then count hits per point.
            _, nearest = self.kdtree.query(np.asarray(locations))
            counts = np.bincount(np.ravel(nearest),
                                 minlength=self.feature_dimension).astype(float)

            total = counts.sum()
            if total > 0:
                counts /= total
            if verbose:
                print(species)
                print(counts[np.nonzero(counts)])

            # Drop grid cells whose share of occurrences is negligible.
            counts[counts < min_frequency] = 0
            if verbose:
                print(counts[np.nonzero(counts)])
            encoded_data[species] = counts

        return encoded_data

    def get_feature_dimension(self):
        """Return the fixed dimension of the feature space"""
        return self.feature_dimension

    def save_grid(self, filepath):
        """Print the grid parameters and save them (with grid points/geohashes) via ``np.save``."""
        grid_info = {
            'precision': self.precision,
            'min_lat': self.min_lat,
            'max_lat': self.max_lat,
            'min_lon': self.min_lon,
            'max_lon': self.max_lon,
            'lat_divisions': self.lat_divisions,
            'lon_divisions': self.lon_divisions,
            'grid_geohashes': self.grid_geohashes,
            'grid_points': self.grid_points,
            'feature_dimension': self.feature_dimension
        }
        print("Saving Grid Info: ")
        for name, value in grid_info.items():
            if name != "grid_geohashes":
                print("{}:{}".format(name, value))
        np.save(filepath, grid_info)

    @classmethod
    def load_grid(cls, filepath):
        """
        Rebuild an encoder from a file written by ``save_grid``.

        Raises:
            ValueError: if the rebuilt grid points differ from the saved ones.
        """
        # NOTE(security): allow_pickle=True unpickles arbitrary objects; only load
        # grid files from trusted sources.
        grid_info = np.load(filepath, allow_pickle=True).item()
        encoder = cls(
            precision=grid_info['precision'],
            min_lat=grid_info['min_lat'],
            max_lat=grid_info['max_lat'],
            min_lon=grid_info['min_lon'],
            max_lon=grid_info['max_lon']
        )
        # Explicit check instead of assert, which is stripped under `python -O`.
        if not np.array_equal(encoder.grid_points, grid_info['grid_points']):
            raise ValueError("Saved grid points do not match the grid rebuilt from its parameters")
        return encoder
|