eqcctpro 0.4.6__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eqcctpro/__init__.py +14 -2
- eqcctpro/eqcct_tf_models.py +407 -0
- eqcctpro/functionality.py +1424 -0
- eqcctpro/parallelization.py +1312 -0
- eqcctpro/seisbench_models.py +279 -0
- eqcctpro/tools.py +968 -0
- eqcctpro-0.7.0.dist-info/METADATA +312 -0
- eqcctpro-0.7.0.dist-info/RECORD +10 -0
- {eqcctpro-0.4.6.dist-info → eqcctpro-0.7.0.dist-info}/WHEEL +1 -1
- eqcctpro-0.4.6.dist-info/METADATA +0 -373
- eqcctpro-0.4.6.dist-info/RECORD +0 -5
- {eqcctpro-0.4.6.dist-info → eqcctpro-0.7.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1312 @@
|
|
|
1
|
+
"""
|
|
2
|
+
parallelization.py has access to all Ray functions: mseed_predictor(), parallel_predict(), ModelActor(), and their dependencies.
|
|
3
|
+
It is a level of abstraction so we can make the code more concise and cleaner
|
|
4
|
+
"""
|
|
5
|
+
import os
|
|
6
|
+
import ray
|
|
7
|
+
import csv
|
|
8
|
+
import sys
|
|
9
|
+
import ast
|
|
10
|
+
import math
|
|
11
|
+
import time
|
|
12
|
+
import json
|
|
13
|
+
import queue
|
|
14
|
+
import obspy
|
|
15
|
+
import psutil
|
|
16
|
+
import random
|
|
17
|
+
import numbers
|
|
18
|
+
import logging
|
|
19
|
+
import platform
|
|
20
|
+
import traceback
|
|
21
|
+
import numpy as np
|
|
22
|
+
from .tools import *
|
|
23
|
+
from os import listdir
|
|
24
|
+
from obspy import UTCDateTime
|
|
25
|
+
from datetime import datetime, timedelta
|
|
26
|
+
from logging.handlers import QueueHandler
|
|
27
|
+
|
|
28
|
+
# Dictionary of VRAM requirements (MB) for SeisBench models
|
|
29
|
+
# Format: (parent_model_name, child_model_name): vram_mb
|
|
30
|
+
# This will be populated with values provided by the user later
|
|
31
|
+
SEISBENCH_MODEL_VRAM_MB = {
    # Keyed by (parent_model_name, child_model_name); values are VRAM in MB.
    ('PhaseNet', 'original'): 2000.0,
    ('EQTransformer', 'stead'): 2500.0,
}


def get_seisbench_model_vram_mb(parent_model_name, child_model_name, default_mb=2000.0):
    """
    Look up the VRAM requirement registered for a SeisBench model.

    Parameters
    ----------
    parent_model_name : str
        First element of the lookup key (e.g. 'PhaseNet').

    child_model_name : str
        Second element of the lookup key (e.g. 'original').

    default_mb : float, default=2000.0
        Value returned when the (parent, child) pair is not in
        SEISBENCH_MODEL_VRAM_MB.

    Returns
    -------
    float
        VRAM requirement in MB for the pair, or `default_mb` if unregistered.
    """
    try:
        return SEISBENCH_MODEL_VRAM_MB[(parent_model_name, child_model_name)]
    except KeyError:
        return default_mb
|
|
43
|
+
|
|
44
|
+
def parse_time_range(time_string):
    """
    Parse a 'START_END' time range string.

    Both halves must be formatted as ``%Y%m%dT%H%M%SZ``
    (e.g. ``20240101T000000Z_20240101T001000Z``).

    Parameters
    ----------
    time_string : str
        Start and end timestamps joined by a single underscore.

    Returns
    -------
    tuple
        ``(start_time, end_time, time_delta)`` as two naive ``datetime``
        objects and their ``timedelta`` difference, or ``(None, None, None)``
        when the input cannot be parsed (wrong format, wrong number of
        underscore-separated parts, or not a string at all).
    """
    time_format = "%Y%m%dT%H%M%SZ"
    try:
        start_str, end_str = time_string.split('_')
        start_time = datetime.strptime(start_str, time_format)
        end_time = datetime.strptime(end_str, time_format)
    except (ValueError, AttributeError, TypeError):
        # ValueError: bad format or wrong number of parts.
        # AttributeError/TypeError: input is not a str (e.g. None or bytes);
        # report it with the same sentinel instead of crashing the caller.
        return None, None, None

    return start_time, end_time, end_time - start_time
|
|
58
|
+
|
|
59
|
+
def _mseed2nparray(args, files_list, station):
    """
    Read miniSEED files for one station and cut them into 60 s, 3-column windows.

    Parameters
    ----------
    args : dict
        Reads ``args["stations_filters"]`` (None, or a table with ``sta``,
        ``hp`` and ``lp`` columns giving per-station band-pass corners) and
        ``args["overlap"]`` (the window step is ``int(60 - overlap * 60)`` seconds).

    files_list : list of str
        Paths of the miniSEED files to read. The first path is also used to
        build ``meta["trace_name"]`` from its last two '/'-separated parts.

    station : str
        Station code used to look up the band-pass corners and as a fallback
        ``receiver_code``.

    Returns
    -------
    None
        If ``files_list`` yields no traces or any file produces an empty stream.
    tuple
        ``(meta, data_set, freqmin, freqmax)`` where ``data_set`` maps each
        window start-time string to a ``(6000, 3)`` array whose columns are
        filled from channels ending in E/1, N/2 and Z respectively.

    Raises
    ------
    ValueError
        If ``args["overlap"]`` gives a window step of zero or less, which
        would otherwise loop forever.
    """
    st = obspy.Stream()
    # Read and process files
    for file in files_list:
        temp_st = obspy.read(file)
        try:
            temp_st.merge(fill_value=0)
        except Exception:
            # NOTE(review): the retry repeats the exact same call. It is kept
            # because a partially failed merge may leave the stream in a state
            # where a second merge succeeds -- confirm whether a resample
            # before retrying was intended.
            temp_st.merge(fill_value=0)
        temp_st.detrend('demean')
        if temp_st:
            st += temp_st
        else:
            return None  # No data to process, return early

    # An empty files_list would otherwise fail with IndexError on st[0] below.
    if len(st) == 0:
        return None

    # Apply taper and bandpass filter
    max_percentage = 5 / (st[0].stats.delta * st[0].stats.npts)  # 5s of data will be tapered
    st.taper(max_percentage=max_percentage, type='cosine')
    freqmin = 1.0
    freqmax = 45.0
    if args["stations_filters"] is not None:
        try:
            df_filters = args["stations_filters"]
            station_row = df_filters[df_filters.sta == station].iloc[0]
            freqmin = station_row["hp"]
            freqmax = station_row["lp"]
        except Exception:
            # Best effort: stations without a usable filter row keep the
            # corner(s) that were not successfully overridden.
            pass
    st.filter(type='bandpass', freqmin=freqmin, freqmax=freqmax, corners=2, zerophase=True)

    # Interpolate if necessary
    if any(tr.stats.sampling_rate != 100.0 for tr in st):
        try:
            st.interpolate(100, method="linear")
        except Exception:
            st = _resampling(st)

    # Trim stream to the common start and end times
    st.trim(min(tr.stats.starttime for tr in st), max(tr.stats.endtime for tr in st), pad=True, fill_value=0)
    start_time = st[0].stats.starttime
    end_time = st[0].stats.endtime

    # Prepare metadata
    first_file_parts = files_list[0].split('/')
    meta = {
        "start_time": start_time,
        "end_time": end_time,
        "trace_name": f"{first_file_parts[-2]}/{first_file_parts[-1]}"
    }

    # Prepare component mapping and types
    data_set = {}
    st_times = []
    components = {tr.stats.channel[-1]: tr for tr in st}
    time_shift = int(60 - (args['overlap'] * 60))
    if time_shift <= 0:
        # A non-positive step would never advance current_time below.
        raise ValueError(
            f"overlap={args['overlap']} gives a window step of {time_shift} s; it must be positive"
        )

    # Define preferred components for each column
    components_list = [
        ['E', '1'],  # Column 0
        ['N', '2'],  # Column 1
        ['Z']        # Column 2
    ]

    current_time = start_time
    while current_time < end_time:
        window_end = current_time + 60
        key = str(current_time).replace('T', ' ').replace('Z', '')
        st_times.append(key)
        npz_data = np.zeros((6000, 3))

        for col_idx, comp_options in enumerate(components_list):
            for comp in comp_options:
                if comp in components:
                    # slice() does not modify the source trace, and the samples
                    # are copied when assigned into npz_data, so the full-trace
                    # copy() per window is unnecessary.
                    tr = components[comp].slice(current_time, window_end)
                    data = tr.data[:6000]
                    # Pad with zeros if data is shorter than 6000 samples
                    if len(data) < 6000:
                        data = np.pad(data, (0, 6000 - len(data)), 'constant')
                    npz_data[:, col_idx] = data
                    break  # Stop after finding the first available component

        data_set[key] = npz_data
        current_time += time_shift

    meta["trace_start_time"] = st_times

    # Metadata population with default placeholders for now
    try:
        receiver_code = st[0].stats.station
    except Exception:
        receiver_code = station
    meta.update({
        "receiver_code": receiver_code,
        "instrument_type": 0,
        "network_code": 0,
        "receiver_latitude": 0,
        "receiver_longitude": 0,
        "receiver_elevation_m": 0
    })

    return meta, data_set, freqmin, freqmax
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _output_writter_prediction(meta, csvPr, Ppicks, Pprob, Spicks, Sprob, detection_memory,prob_memory,predict_writer, idx, cq, cqq):
|
|
169
|
+
|
|
170
|
+
"""
|
|
171
|
+
|
|
172
|
+
Writes the detection & picking results into a CSV file.
|
|
173
|
+
|
|
174
|
+
Parameters
|
|
175
|
+
----------
|
|
176
|
+
dataset: hdf5 obj
|
|
177
|
+
Dataset object of the trace.
|
|
178
|
+
|
|
179
|
+
predict_writer: obj
|
|
180
|
+
For writing out the detection/picking results in the CSV file.
|
|
181
|
+
|
|
182
|
+
csvPr: obj
|
|
183
|
+
For writing out the detection/picking results in the CSV file.
|
|
184
|
+
|
|
185
|
+
matches: dic
|
|
186
|
+
It contains the information for the detected and picked event.
|
|
187
|
+
|
|
188
|
+
snr: list of two floats
|
|
189
|
+
Estimated signal to noise ratios for picked P and S phases.
|
|
190
|
+
|
|
191
|
+
detection_memory : list
|
|
192
|
+
Keep the track of detected events.
|
|
193
|
+
|
|
194
|
+
Returns
|
|
195
|
+
-------
|
|
196
|
+
detection_memory : list
|
|
197
|
+
Keep the track of detected events.
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
"""
|
|
201
|
+
|
|
202
|
+
station_name = meta["receiver_code"]
|
|
203
|
+
station_lat = meta["receiver_latitude"]
|
|
204
|
+
station_lon = meta["receiver_longitude"]
|
|
205
|
+
station_elv = meta["receiver_elevation_m"]
|
|
206
|
+
start_time = meta["trace_start_time"][idx]
|
|
207
|
+
station_name = "{:<4}".format(station_name)
|
|
208
|
+
network_name = meta["network_code"]
|
|
209
|
+
network_name = "{:<2}".format(network_name)
|
|
210
|
+
instrument_type = meta["instrument_type"]
|
|
211
|
+
instrument_type = "{:<2}".format(instrument_type)
|
|
212
|
+
|
|
213
|
+
try:
|
|
214
|
+
start_time = datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S.%f')
|
|
215
|
+
except Exception:
|
|
216
|
+
start_time = datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S')
|
|
217
|
+
|
|
218
|
+
def _date_convertor(r):
|
|
219
|
+
if isinstance(r, str):
|
|
220
|
+
mls = r.split('.')
|
|
221
|
+
if len(mls) == 1:
|
|
222
|
+
new_t = datetime.strptime(r, '%Y-%m-%d %H:%M:%S')
|
|
223
|
+
else:
|
|
224
|
+
new_t = datetime.strptime(r, '%Y-%m-%d %H:%M:%S.%f')
|
|
225
|
+
else:
|
|
226
|
+
new_t = r
|
|
227
|
+
|
|
228
|
+
return new_t
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
p_time = []
|
|
232
|
+
p_prob = []
|
|
233
|
+
PdateTime = []
|
|
234
|
+
if Ppicks[0]!=None:
|
|
235
|
+
#for iP in range(len(Ppicks)):
|
|
236
|
+
#if Ppicks[iP]!=None:
|
|
237
|
+
p_time.append(start_time+timedelta(seconds= Ppicks[0]/100))
|
|
238
|
+
p_prob.append(Pprob[0])
|
|
239
|
+
PdateTime.append(_date_convertor(start_time+timedelta(seconds= Ppicks[0]/100)))
|
|
240
|
+
detection_memory.append(p_time)
|
|
241
|
+
prob_memory.append(p_prob)
|
|
242
|
+
else:
|
|
243
|
+
p_time.append(None)
|
|
244
|
+
p_prob.append(None)
|
|
245
|
+
PdateTime.append(None)
|
|
246
|
+
|
|
247
|
+
s_time = []
|
|
248
|
+
s_prob = []
|
|
249
|
+
SdateTime=[]
|
|
250
|
+
if Spicks[0]!=None:
|
|
251
|
+
#for iS in range(len(Spicks)):
|
|
252
|
+
#if Spicks[iS]!=None:
|
|
253
|
+
s_time.append(start_time+timedelta(seconds= Spicks[0]/100))
|
|
254
|
+
s_prob.append(Sprob[0])
|
|
255
|
+
SdateTime.append(_date_convertor(start_time+timedelta(seconds= Spicks[0]/100)))
|
|
256
|
+
else:
|
|
257
|
+
s_time.append(None)
|
|
258
|
+
s_prob.append(None)
|
|
259
|
+
SdateTime.append(None)
|
|
260
|
+
|
|
261
|
+
SdateTime = np.array(SdateTime)
|
|
262
|
+
s_prob = np.array(s_prob)
|
|
263
|
+
|
|
264
|
+
p_prob = np.array(p_prob)
|
|
265
|
+
PdateTime = np.array(PdateTime)
|
|
266
|
+
|
|
267
|
+
predict_writer.writerow([meta["trace_name"],
|
|
268
|
+
network_name,
|
|
269
|
+
station_name,
|
|
270
|
+
instrument_type,
|
|
271
|
+
station_lat,
|
|
272
|
+
station_lon,
|
|
273
|
+
station_elv,
|
|
274
|
+
PdateTime[0],
|
|
275
|
+
p_prob[0],
|
|
276
|
+
SdateTime[0],
|
|
277
|
+
s_prob[0]
|
|
278
|
+
])
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
csvPr.flush()
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
return detection_memory,prob_memory
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def _get_snr(data, pat, window=200):
|
|
289
|
+
|
|
290
|
+
"""
|
|
291
|
+
|
|
292
|
+
Estimates SNR.
|
|
293
|
+
|
|
294
|
+
Parameters
|
|
295
|
+
----------
|
|
296
|
+
data : numpy array
|
|
297
|
+
3 component data.
|
|
298
|
+
|
|
299
|
+
pat: positive integer
|
|
300
|
+
Sample point where a specific phase arrives.
|
|
301
|
+
|
|
302
|
+
window: positive integer, default=200
|
|
303
|
+
The length of the window for calculating the SNR (in the sample).
|
|
304
|
+
|
|
305
|
+
Returns
|
|
306
|
+
--------
|
|
307
|
+
snr : {float, None}
|
|
308
|
+
Estimated SNR in db.
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
"""
|
|
312
|
+
import math
|
|
313
|
+
snr = None
|
|
314
|
+
if pat:
|
|
315
|
+
try:
|
|
316
|
+
if int(pat) >= window and (int(pat)+window) < len(data):
|
|
317
|
+
nw1 = data[int(pat)-window : int(pat)];
|
|
318
|
+
sw1 = data[int(pat) : int(pat)+window];
|
|
319
|
+
snr = round(10*math.log10((np.percentile(sw1,95)/np.percentile(nw1,95))**2), 1)
|
|
320
|
+
elif int(pat) < window and (int(pat)+window) < len(data):
|
|
321
|
+
window = int(pat)
|
|
322
|
+
nw1 = data[int(pat)-window : int(pat)];
|
|
323
|
+
sw1 = data[int(pat) : int(pat)+window];
|
|
324
|
+
snr = round(10*math.log10((np.percentile(sw1,95)/np.percentile(nw1,95))**2), 1)
|
|
325
|
+
elif (int(pat)+window) > len(data):
|
|
326
|
+
window = len(data)-int(pat)
|
|
327
|
+
nw1 = data[int(pat)-window : int(pat)];
|
|
328
|
+
sw1 = data[int(pat) : int(pat)+window];
|
|
329
|
+
snr = round(10*math.log10((np.percentile(sw1,95)/np.percentile(nw1,95))**2), 1)
|
|
330
|
+
except Exception:
|
|
331
|
+
pass
|
|
332
|
+
return snr
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def _detect_peaks(x, mph=None, mpd=1, threshold=0, edge='rising', kpsh=False, valley=False):
|
|
336
|
+
|
|
337
|
+
"""
|
|
338
|
+
|
|
339
|
+
Detect peaks in data based on their amplitude and other features.
|
|
340
|
+
|
|
341
|
+
Parameters
|
|
342
|
+
----------
|
|
343
|
+
x : 1D array_like
|
|
344
|
+
data.
|
|
345
|
+
|
|
346
|
+
mph : {None, number}, default=None
|
|
347
|
+
detect peaks that are greater than minimum peak height.
|
|
348
|
+
|
|
349
|
+
mpd : int, default=1
|
|
350
|
+
detect peaks that are at least separated by minimum peak distance (in number of data).
|
|
351
|
+
|
|
352
|
+
threshold : int, default=0
|
|
353
|
+
detect peaks (valleys) that are greater (smaller) than `threshold in relation to their immediate neighbors.
|
|
354
|
+
|
|
355
|
+
edge : str, default=rising
|
|
356
|
+
for a flat peak, keep only the rising edge ('rising'), only the falling edge ('falling'), both edges ('both'), or don't detect a flat peak (None).
|
|
357
|
+
|
|
358
|
+
kpsh : bool, default=False
|
|
359
|
+
keep peaks with same height even if they are closer than `mpd`.
|
|
360
|
+
|
|
361
|
+
valley : bool, default=False
|
|
362
|
+
if True (1), detect valleys (local minima) instead of peaks.
|
|
363
|
+
|
|
364
|
+
Returns
|
|
365
|
+
-------
|
|
366
|
+
ind : 1D array_like
|
|
367
|
+
indeces of the peaks in `x`.
|
|
368
|
+
|
|
369
|
+
Modified from
|
|
370
|
+
----------
|
|
371
|
+
.. [1] http://nbviewer.ipython.org/github/demotu/BMC/blob/master/notebooks/DetectPeaks.ipynb
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
"""
|
|
375
|
+
|
|
376
|
+
x = np.atleast_1d(x).astype('float64')
|
|
377
|
+
if x.size < 3:
|
|
378
|
+
return np.array([], dtype=int)
|
|
379
|
+
if valley:
|
|
380
|
+
x = -x
|
|
381
|
+
# find indices of all peaks
|
|
382
|
+
dx = x[1:] - x[:-1]
|
|
383
|
+
# handle NaN's
|
|
384
|
+
indnan = np.where(np.isnan(x))[0]
|
|
385
|
+
if indnan.size:
|
|
386
|
+
x[indnan] = np.inf
|
|
387
|
+
dx[np.where(np.isnan(dx))[0]] = np.inf
|
|
388
|
+
ine, ire, ife = np.array([[], [], []], dtype=int)
|
|
389
|
+
if not edge:
|
|
390
|
+
ine = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) > 0))[0]
|
|
391
|
+
else:
|
|
392
|
+
if edge.lower() in ['rising', 'both']:
|
|
393
|
+
ire = np.where((np.hstack((dx, 0)) <= 0) & (np.hstack((0, dx)) > 0))[0]
|
|
394
|
+
if edge.lower() in ['falling', 'both']:
|
|
395
|
+
ife = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) >= 0))[0]
|
|
396
|
+
ind = np.unique(np.hstack((ine, ire, ife)))
|
|
397
|
+
# handle NaN's
|
|
398
|
+
if ind.size and indnan.size:
|
|
399
|
+
# NaN's and values close to NaN's cannot be peaks
|
|
400
|
+
ind = ind[np.in1d(ind, np.unique(np.hstack((indnan, indnan-1, indnan+1))), invert=True)]
|
|
401
|
+
# first and last values of x cannot be peaks
|
|
402
|
+
if ind.size and ind[0] == 0:
|
|
403
|
+
ind = ind[1:]
|
|
404
|
+
if ind.size and ind[-1] == x.size-1:
|
|
405
|
+
ind = ind[:-1]
|
|
406
|
+
# remove peaks < minimum peak height
|
|
407
|
+
if ind.size and mph is not None:
|
|
408
|
+
ind = ind[x[ind] >= mph]
|
|
409
|
+
# remove peaks - neighbors < threshold
|
|
410
|
+
if ind.size and threshold > 0:
|
|
411
|
+
dx = np.min(np.vstack([x[ind]-x[ind-1], x[ind]-x[ind+1]]), axis=0)
|
|
412
|
+
ind = np.delete(ind, np.where(dx < threshold)[0])
|
|
413
|
+
# detect small peaks closer than minimum peak distance
|
|
414
|
+
if ind.size and mpd > 1:
|
|
415
|
+
ind = ind[np.argsort(x[ind])][::-1] # sort ind by peak height
|
|
416
|
+
idel = np.zeros(ind.size, dtype=bool)
|
|
417
|
+
for i in range(ind.size):
|
|
418
|
+
if not idel[i]:
|
|
419
|
+
# keep peaks with the same height if kpsh is True
|
|
420
|
+
idel = idel | (ind >= ind[i] - mpd) & (ind <= ind[i] + mpd) \
|
|
421
|
+
& (x[ind[i]] > x[ind] if kpsh else True)
|
|
422
|
+
idel[i] = 0 # Keep current peak
|
|
423
|
+
# remove the small peaks and sort back the indices by their occurrence
|
|
424
|
+
ind = np.sort(ind[~idel])
|
|
425
|
+
|
|
426
|
+
return ind
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def _picker(args, yh3, thr_type='P_threshold'):
    """
    Pick the single most probable phase arrival in a probability trace.

    Peaks of `yh3` at or above ``args[thr_type]`` are found with
    `_detect_peaks`; the peak with the highest probability (rounded to three
    decimals, first one on ties) is returned.

    Parameters
    ----------
    args : dic
        A dictionary containing all of the input parameters; only
        ``args[thr_type]`` is read here.

    yh3 : 1D array
        probability.

    thr_type : str, default='P_threshold'
        Key of `args` holding the minimum peak height.

    Returns
    --------
    Ppickall : list
        One-element list with the sample index of the pick, or ``[None]``
        when no peak reaches the threshold.

    Pproball : list
        One-element list with the rounded pick probability, or ``[None]``.
    """
    peak_indices = _detect_peaks(yh3, mph=args[thr_type], mpd=1)

    # No peak above threshold: report an explicit "no pick" instead of relying
    # on argmax of an empty array raising inside a bare except.
    if len(peak_indices) == 0:
        return [None], [None]

    # _detect_peaks never returns index 0, so every index is a usable pick.
    peak_probs = np.array([np.round(yh3[int(i)], 3) for i in peak_indices])
    best = np.argmax(peak_probs)

    return [peak_indices[best]], [np.max(peak_probs)]
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def _resampling(st):
    """
    Resample every trace of an ObsPy stream that is not at 100 Hz to 100 Hz.

    Traces sampled faster than 100 Hz (delta < 0.01 s) are low-pass filtered
    at 45 Hz before resampling. Each resampled trace is removed from the
    stream and re-appended, so resampled traces end up at the end of `st`.

    Parameters
    ----------
    st : obspy.Stream
        Stream to process; modified in place.

    Returns
    -------
    st : obspy.Stream
        The same stream object.
    """
    need_resampling = [tr for tr in st if tr.stats.sampling_rate != 100.0]
    for tr in need_resampling:
        # Remember the sample dtype before filtering/resampling change it.
        original_dtype = tr.data.dtype

        if tr.stats.delta < 0.01:
            tr.filter('lowpass', freq=45, zerophase=True)
        tr.resample(100)
        tr.stats.sampling_rate = 100
        tr.stats.delta = 0.01

        # The original `tr.data.dtype = 'int32'` reinterpreted the raw bytes
        # of the float samples as integers, which corrupts the values. Convert
        # the values instead, back to the dtype the trace arrived with.
        # NOTE(review): presumably the cast exists so resampled traces keep a
        # dtype compatible with their untouched siblings -- confirm.
        if tr.data.dtype != original_dtype:
            samples = tr.data
            if np.issubdtype(original_dtype, np.integer):
                samples = np.rint(samples)
            tr.data = samples.astype(original_dtype)

        st.remove(tr)
        st.append(tr)
    return st
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
def _normalize(data, mode = 'max'):
|
|
510
|
+
"""
|
|
511
|
+
|
|
512
|
+
Normalize 3D arrays.
|
|
513
|
+
|
|
514
|
+
Parameters
|
|
515
|
+
----------
|
|
516
|
+
data : 3D numpy array
|
|
517
|
+
3 component traces.
|
|
518
|
+
|
|
519
|
+
mode : str, default='std'
|
|
520
|
+
Mode of normalization. 'max' or 'std'
|
|
521
|
+
|
|
522
|
+
Returns
|
|
523
|
+
-------
|
|
524
|
+
data : 3D numpy array
|
|
525
|
+
normalized data.
|
|
526
|
+
|
|
527
|
+
"""
|
|
528
|
+
|
|
529
|
+
data -= np.mean(data, axis=0, keepdims=True)
|
|
530
|
+
if mode == 'max':
|
|
531
|
+
max_data = np.max(data, axis=0, keepdims=True)
|
|
532
|
+
assert(max_data.shape[-1] == data.shape[-1])
|
|
533
|
+
max_data[max_data == 0] = 1
|
|
534
|
+
data /= max_data
|
|
535
|
+
|
|
536
|
+
elif mode == 'std':
|
|
537
|
+
std_data = np.std(data, axis=0, keepdims=True)
|
|
538
|
+
assert(std_data.shape[-1] == data.shape[-1])
|
|
539
|
+
std_data[std_data == 0] = 1
|
|
540
|
+
data /= std_data
|
|
541
|
+
return data
|
|
542
|
+
|
|
543
|
+
@ray.remote
|
|
544
|
+
def mseed_predictor(input_dir='downloads_mseeds',
|
|
545
|
+
output_dir="detections",
|
|
546
|
+
P_threshold=0.1,
|
|
547
|
+
S_threshold=0.1,
|
|
548
|
+
normalization_mode='std',
|
|
549
|
+
dt=1,
|
|
550
|
+
batch_size=500,
|
|
551
|
+
overlap=0.3,
|
|
552
|
+
gpu_id=None,
|
|
553
|
+
gpu_limit=None,
|
|
554
|
+
overwrite=False,
|
|
555
|
+
log_queue=None,
|
|
556
|
+
stations2use=None,
|
|
557
|
+
stations_filters=None,
|
|
558
|
+
p_model=None,
|
|
559
|
+
s_model=None,
|
|
560
|
+
number_of_concurrent_station_predictions=None,
|
|
561
|
+
ray_cpus=None,
|
|
562
|
+
use_gpu=False,
|
|
563
|
+
gpu_memory_limit_mb=None,
|
|
564
|
+
testing_gpu=None,
|
|
565
|
+
test_csv_filepath=None,
|
|
566
|
+
specific_stations=None,
|
|
567
|
+
timechunk_id=None,
|
|
568
|
+
waveform_overlap=None,
|
|
569
|
+
total_timechunks=None,
|
|
570
|
+
number_of_concurrent_timechunk_predictions=None,
|
|
571
|
+
total_analysis_time=None,
|
|
572
|
+
intra_threads=None,
|
|
573
|
+
inter_threads=None,
|
|
574
|
+
timechunk_dt=None,
|
|
575
|
+
# SeisBench model parameters
|
|
576
|
+
model_type='eqcct',
|
|
577
|
+
seisbench_parent_model=None,
|
|
578
|
+
seisbench_child_model=None,
|
|
579
|
+
Detection_threshold=0.3):
|
|
580
|
+
|
|
581
|
+
"""
|
|
582
|
+
|
|
583
|
+
To perform fast detection directly on mseed data.
|
|
584
|
+
|
|
585
|
+
Parameters
|
|
586
|
+
----------
|
|
587
|
+
input_dir: str
|
|
588
|
+
Directory name containing hdf5 and csv files-preprocessed data.
|
|
589
|
+
|
|
590
|
+
input_model: str
|
|
591
|
+
Path to a trained model.
|
|
592
|
+
|
|
593
|
+
stations_json: str
|
|
594
|
+
Path to a JSON file containing station information.
|
|
595
|
+
|
|
596
|
+
output_dir: str
|
|
597
|
+
Output directory that will be generated.
|
|
598
|
+
|
|
599
|
+
P_threshold: float, default=0.1
|
|
600
|
+
A value which the P probabilities above it will be considered as P arrival.
|
|
601
|
+
|
|
602
|
+
S_threshold: float, default=0.1
|
|
603
|
+
A value which the S probabilities above it will be considered as S arrival.
|
|
604
|
+
|
|
605
|
+
normalization_mode: str, default=std
|
|
606
|
+
Mode of normalization for data preprocessing max maximum amplitude among three components std standard deviation.
|
|
607
|
+
|
|
608
|
+
batch_size: int, default=500
|
|
609
|
+
Batch size. This wont affect the speed much but can affect the performance. A value beteen 200 to 1000 is recommended.
|
|
610
|
+
|
|
611
|
+
overlap: float, default=0.3
|
|
612
|
+
If set the detection and picking are performed in overlapping windows.
|
|
613
|
+
|
|
614
|
+
gpu_id: int
|
|
615
|
+
Id of GPU used for the prediction. If using CPU set to None.
|
|
616
|
+
|
|
617
|
+
gpu_limit: float
|
|
618
|
+
Set the maximum percentage of memory usage for the GPU.
|
|
619
|
+
|
|
620
|
+
overwrite: Bolean, default=False
|
|
621
|
+
Overwrite your results automatically.
|
|
622
|
+
|
|
623
|
+
Returns
|
|
624
|
+
--------
|
|
625
|
+
|
|
626
|
+
"""
|
|
627
|
+
|
|
628
|
+
# Set up logger that will write logs to this native process and add them to the log.queue to be added back to the main logger outside of this Raylet
|
|
629
|
+
# worker logger ships records to driver
|
|
630
|
+
logger = logging.getLogger("eqcctpro.worker")
|
|
631
|
+
logger.setLevel(logging.INFO)
|
|
632
|
+
logger.handlers[:] = []
|
|
633
|
+
logger.propagate = False
|
|
634
|
+
log_handler = QueueHandler(log_queue)
|
|
635
|
+
if log_queue is not None:
|
|
636
|
+
logger.addHandler(log_handler) # Ray queue supports put()
|
|
637
|
+
|
|
638
|
+
# We set up the tf_environ again for the Raylets, who adopt their own import state and TF runtime when created.
|
|
639
|
+
# We want to ensure that they are configured properly so that they won't die (bad)
|
|
640
|
+
skip_tf = (model_type.lower() != 'eqcct')
|
|
641
|
+
if not use_gpu:
|
|
642
|
+
tf_environ(gpu_id=-1, intra_threads=intra_threads, inter_threads=inter_threads, logger=logger, skip_tf=skip_tf)
|
|
643
|
+
# tf_environ(gpu_id=1, gpu_memory_limit_mb=gpu_memory_limit_mb, gpus_to_use=gpu_id, intra_threads=intra_threads, inter_threads=inter_threads)
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
args = {
|
|
647
|
+
"input_dir": input_dir,
|
|
648
|
+
"output_dir": output_dir,
|
|
649
|
+
"P_threshold": P_threshold,
|
|
650
|
+
"S_threshold": S_threshold,
|
|
651
|
+
"normalization_mode": normalization_mode,
|
|
652
|
+
"dt": dt,
|
|
653
|
+
"overlap": overlap,
|
|
654
|
+
"batch_size": batch_size,
|
|
655
|
+
"overwrite": overwrite,
|
|
656
|
+
"gpu_id": gpu_id,
|
|
657
|
+
"gpu_limit": gpu_limit,
|
|
658
|
+
"p_model": p_model,
|
|
659
|
+
"s_model": s_model,
|
|
660
|
+
"stations_filters": stations_filters,
|
|
661
|
+
"model_type": model_type,
|
|
662
|
+
"seisbench_parent_model": seisbench_parent_model,
|
|
663
|
+
"seisbench_child_model": seisbench_child_model,
|
|
664
|
+
"Detection_threshold": Detection_threshold
|
|
665
|
+
}
|
|
666
|
+
|
|
667
|
+
logger.info(f"------- Hardware Configuration -------")
|
|
668
|
+
try:
|
|
669
|
+
process = psutil.Process(os.getpid())
|
|
670
|
+
process.cpu_affinity(ray_cpus) # ray_cpus should be a list of core IDs like [0, 1, 2]
|
|
671
|
+
logger.info(f"CPU affinity set to cores: {list(ray_cpus)}")
|
|
672
|
+
logger.info("")
|
|
673
|
+
except Exception as e:
|
|
674
|
+
logger.error(f"Failed to set CPU affinity. Reason: {e}")
|
|
675
|
+
logger.error("")
|
|
676
|
+
sys.exit(1)
|
|
677
|
+
|
|
678
|
+
out_dir = os.path.join(os.getcwd(), str(args['output_dir']))
|
|
679
|
+
try:
|
|
680
|
+
if platform.system() == 'Windows': station_list = [ev.split(".")[0] for ev in listdir(args['input_dir']) if ev.split("\\")[-1] != ".DS_Store"]
|
|
681
|
+
else: station_list = [ev.split(".")[0] for ev in listdir(args['input_dir']) if ev.split("/")[-1] != ".DS_Store"]
|
|
682
|
+
station_list = sorted(set(station_list))
|
|
683
|
+
except Exception as e:
|
|
684
|
+
logger.info(f"{e}")
|
|
685
|
+
return # To-Do: Fix so that it has a valid return?
|
|
686
|
+
# log.write(f"GPU ID: {args['gpu_id']}; Batch size: {args['batch_size']}")
|
|
687
|
+
logger.info(f"------- Data Preprocessing for EQCCTPro -------")
|
|
688
|
+
logger.info(f"{len(station_list)} station(s) in {args['input_dir']}")
|
|
689
|
+
|
|
690
|
+
if stations2use and stations2use <= len(station_list): # For System Evaluation Execution
|
|
691
|
+
station_list = random.sample(station_list, stations2use) # Randomly choose stations from the sample size
|
|
692
|
+
# log.write(f"Using {len(station_list)} station(s) after selection.")
|
|
693
|
+
|
|
694
|
+
if specific_stations is not None: station_list = [x for x in station_list if x in specific_stations] # For "One Use Run" Over a Given Set of Stations (Just Run EQCCTPro on specific_stations)
|
|
695
|
+
else: station_list = station_list # someone put None thinking that they would be able to run the whole directory in one go
|
|
696
|
+
logger.info(f"Using {len(station_list)} selected station(s): {station_list}.")
|
|
697
|
+
|
|
698
|
+
if not station_list or any(looks_like_timechunk_id(x) for x in station_list):
|
|
699
|
+
# Rebuild from the actual contents of the timechunk dir
|
|
700
|
+
station_list = build_station_list_from_dir(args['input_dir'])
|
|
701
|
+
logger.info(f"Station list rebuilt from directory because it contained a timechunk id or was empty.")
|
|
702
|
+
|
|
703
|
+
tasks_predictor = [[f"({i+1}/{len(station_list)})", station_list[i], out_dir, args] for i in range(len(station_list))]
|
|
704
|
+
|
|
705
|
+
if not tasks_predictor: return
|
|
706
|
+
|
|
707
|
+
# CREATE MODEL ACTOR(S) - Add this before the task loop
|
|
708
|
+
logger.info(f"Creating model actor(s)...")
|
|
709
|
+
|
|
710
|
+
model_type_lower = model_type.lower() if model_type else 'eqcct'
|
|
711
|
+
|
|
712
|
+
if model_type_lower == 'seisbench':
|
|
713
|
+
# Create SeisBench model actors
|
|
714
|
+
if use_gpu:
|
|
715
|
+
# Get VRAM requirement for this SeisBench model
|
|
716
|
+
model_vram_mb = get_seisbench_model_vram_mb(
|
|
717
|
+
seisbench_parent_model,
|
|
718
|
+
seisbench_child_model,
|
|
719
|
+
default_mb=2000.0
|
|
720
|
+
)
|
|
721
|
+
# Use max of requested VRAM or model requirement (similar to EQCCT logic)
|
|
722
|
+
model_vram_mb = max(gpu_memory_limit_mb, model_vram_mb) if gpu_memory_limit_mb else model_vram_mb
|
|
723
|
+
|
|
724
|
+
model_actors = []
|
|
725
|
+
logger.info(f"Using GPUs: {gpu_id}")
|
|
726
|
+
for gpu_idx in gpu_id:
|
|
727
|
+
logger.info(f"Creating SeisBenchModelActor on GPU {gpu_idx} with {model_vram_mb/1024:.2f}GB VRAM requirement...")
|
|
728
|
+
actor = SeisBenchModelActor.options(num_gpus=1, num_cpus=0).remote(
|
|
729
|
+
parent_model_name=seisbench_parent_model,
|
|
730
|
+
child_model_name=seisbench_child_model,
|
|
731
|
+
gpus_to_use=[gpu_idx],
|
|
732
|
+
use_gpu=True
|
|
733
|
+
)
|
|
734
|
+
try:
|
|
735
|
+
ray.get(actor.ready.remote())
|
|
736
|
+
except Exception as e:
|
|
737
|
+
logger.error(f"Failed to create SeisBenchModelActor on GPU {gpu_idx}: {e}")
|
|
738
|
+
raise
|
|
739
|
+
logger.info(f"SeisBenchModelActor created on GPU {gpu_idx}.")
|
|
740
|
+
model_actors.append(actor)
|
|
741
|
+
logger.info(f"Created {len(model_actors)} GPU-sized SeisBenchModelActor(s).")
|
|
742
|
+
else:
|
|
743
|
+
model_actors = [SeisBenchModelActor.options(num_cpus=1).remote(
|
|
744
|
+
parent_model_name=seisbench_parent_model,
|
|
745
|
+
child_model_name=seisbench_child_model,
|
|
746
|
+
gpus_to_use=False,
|
|
747
|
+
use_gpu=False
|
|
748
|
+
)]
|
|
749
|
+
ray.get(model_actors[0].ready.remote())
|
|
750
|
+
logger.info(f"Created a 1 CPU-sized SeisBenchModelActor")
|
|
751
|
+
else:
|
|
752
|
+
# Create EQCCT model actors (original logic)
|
|
753
|
+
if use_gpu:
|
|
754
|
+
# Allocate more VRAM to model actors (they need to hold the full model)
|
|
755
|
+
# Reserve ~2-3GB per model actor, adjust based on your model size
|
|
756
|
+
model_vram_mb = max(gpu_memory_limit_mb, 3000) # At least VRAM or 3GB for EQCCT (subject to change)
|
|
757
|
+
|
|
758
|
+
# Create one model actor per GPU
|
|
759
|
+
model_actors = []
|
|
760
|
+
logger.info(f"Using GPUs: {gpu_id}")
|
|
761
|
+
for gpu_idx in gpu_id: # gpu_id is a list of GPU IDs and gpu_idx is the current GPU ID in the loop
|
|
762
|
+
logger.info(f"Creating ModelActor on GPU {gpu_idx} with {model_vram_mb/1024:.2f}GB VRAM limit...")
|
|
763
|
+
actor = ModelActor.options(num_gpus=1, num_cpus=0).remote(gpus_to_use=[gpu_idx], p_model_path=p_model, s_model_path=s_model, gpu_memory_limit_mb=model_vram_mb, use_gpu=True)
|
|
764
|
+
# Wait for __init__ to complete and raise if error
|
|
765
|
+
try:
|
|
766
|
+
ray.get(actor.ready.remote())
|
|
767
|
+
except Exception as e:
|
|
768
|
+
logger.error(f"Failed to create ModelActor on GPU {gpu_idx}: {e}")
|
|
769
|
+
raise
|
|
770
|
+
logger.info(f"ModelActor created on GPU {gpu_idx}.")
|
|
771
|
+
model_actors.append(actor)
|
|
772
|
+
|
|
773
|
+
logger.info(f"Created {len(model_actors)} GPU-sized ModelActor(s).")
|
|
774
|
+
# Using CUDA_VISIBLE_DEVICES is not a reliable way to report which physical GPU is being used bc Ray can overwrite, clear, or remap the assigned GPU so that each worker sees them as local indices (often starting from 0)
|
|
775
|
+
logger.info(f"[ModelActor] Model successfully loaded onto {'GPU' if use_gpu else 'CPU'}.") # Better way to log is to use ray.get_gpu_ids()
|
|
776
|
+
else:
|
|
777
|
+
# Create CPU model actor
|
|
778
|
+
model_actors = [ModelActor.options(num_cpus=1).remote(p_model_path=p_model, s_model_path=s_model, gpu_memory_limit_mb=None, use_gpu=False)]
|
|
779
|
+
logger.info(f"Created a 1 CPU-sized ModelActor")
|
|
780
|
+
|
|
781
|
+
# Submit tasks to ray in a queue
|
|
782
|
+
tasks_queue = []
|
|
783
|
+
max_pending_tasks = number_of_concurrent_station_predictions
|
|
784
|
+
logger.info(f"Starting EQCCTPro parallelized waveform processing...")
|
|
785
|
+
logger.info("")
|
|
786
|
+
start_time = time.time()
|
|
787
|
+
model_type_lower = model_type.lower() if model_type else 'eqcct'
|
|
788
|
+
if model_type_lower == 'seisbench':
|
|
789
|
+
logger.info(f"------- Analyzing Seismic Waveforms for P and S Picks via SeisBench ({seisbench_parent_model} - {seisbench_child_model}) -------")
|
|
790
|
+
else:
|
|
791
|
+
logger.info(f"------- Analyzing Seismic Waveforms for P and S Picks via EQCCT -------")
|
|
792
|
+
|
|
793
|
+
if timechunk_id is None:
|
|
794
|
+
# derive from the path if caller forgot to pass it
|
|
795
|
+
cand = os.path.basename(input_dir)
|
|
796
|
+
if "_" in cand and len(cand) >= 10:
|
|
797
|
+
timechunk_id = cand
|
|
798
|
+
else:
|
|
799
|
+
raise ValueError("timechunk_id is None and could not be inferred from input_dir; "
|
|
800
|
+
"expected a dir named like YYYYMMDDThhmmssZ_YYYYMMDDThhmmssZ")
|
|
801
|
+
starttime, endtime, time_delta = parse_time_range(timechunk_id)
|
|
802
|
+
|
|
803
|
+
logger.info(f"Analyzing {time_delta} minute timechunk from {starttime} to {endtime} ({waveform_overlap} min overlap)")
|
|
804
|
+
logger.info(f"Processing a total of {len(tasks_predictor)} stations, {max_pending_tasks} at a time.")
|
|
805
|
+
|
|
806
|
+
|
|
807
|
+
# Concurrent Prediction(s) Parallel Processing
|
|
808
|
+
try:
|
|
809
|
+
for i in range(len(tasks_predictor)):
|
|
810
|
+
while True:
|
|
811
|
+
# Add new task to queue while max is not reached
|
|
812
|
+
if len(tasks_queue) < max_pending_tasks:
|
|
813
|
+
# SELECT WHICH MODEL ACTOR TO USE (round-robin across GPUs)
|
|
814
|
+
model_actor = model_actors[i % len(model_actors)]
|
|
815
|
+
|
|
816
|
+
# Route to appropriate prediction function based on model type
|
|
817
|
+
if model_type_lower == 'seisbench':
|
|
818
|
+
# SeisBench models use parallel_predict_seisbench
|
|
819
|
+
if use_gpu is False:
|
|
820
|
+
tasks_queue.append(parallel_predict_seisbench.options(num_cpus=0).remote(tasks_predictor[i], model_actor, False))
|
|
821
|
+
elif use_gpu is True:
|
|
822
|
+
# Don't allocate GPUs to workers, only to model actors
|
|
823
|
+
# Use num_cpus=0 to avoid deadlocks when Ray has limited CPUs
|
|
824
|
+
tasks_queue.append(parallel_predict_seisbench.options(num_cpus=0, num_gpus=0).remote(tasks_predictor[i], model_actor, True))
|
|
825
|
+
else:
|
|
826
|
+
# EQCCT models use parallel_predict (original)
|
|
827
|
+
if use_gpu is False:
|
|
828
|
+
tasks_queue.append(parallel_predict.options(num_cpus=0).remote(tasks_predictor[i], model_actor, False))
|
|
829
|
+
elif use_gpu is True:
|
|
830
|
+
# Don't allocate GPUs to workers, only to model actors
|
|
831
|
+
# Use num_cpus=0 to avoid deadlocks when Ray has limited CPUs
|
|
832
|
+
tasks_queue.append(parallel_predict.options(num_cpus=0, num_gpus=0).remote(tasks_predictor[i], model_actor, True))
|
|
833
|
+
break
|
|
834
|
+
# If there are more tasks than maximum, just process them
|
|
835
|
+
else:
|
|
836
|
+
tasks_finished, tasks_queue = ray.wait(tasks_queue, num_returns=1, timeout=None)
|
|
837
|
+
for finished_task in tasks_finished:
|
|
838
|
+
log_entry = ray.get(finished_task)
|
|
839
|
+
logger.info(f'{log_entry}')
|
|
840
|
+
|
|
841
|
+
# After adding all the tasks to queue, process what's left
|
|
842
|
+
while tasks_queue:
|
|
843
|
+
tasks_finished, tasks_queue = ray.wait(tasks_queue, num_returns=1, timeout=None)
|
|
844
|
+
for finished_task in tasks_finished:
|
|
845
|
+
log_entry = ray.get(finished_task)
|
|
846
|
+
logger.info(f'{log_entry}')
|
|
847
|
+
logger.info("")
|
|
848
|
+
|
|
849
|
+
except Exception as e:
|
|
850
|
+
# Catch any error in the parallel processing
|
|
851
|
+
logger.error(f"ERROR in parallel processing at {datetime.now()}")
|
|
852
|
+
logger.error(f"Error: {str(e)}")
|
|
853
|
+
logger.error(traceback.format_exc())
|
|
854
|
+
raise # Re-raise to see the error
|
|
855
|
+
|
|
856
|
+
logger.info(f"------- Parallel Station Waveform Processing Complete For {starttime} to {endtime} Timechunk-------")
|
|
857
|
+
end_time = time.time()
|
|
858
|
+
logger.info(f"Picks saved at {output_dir}Process Runtime: {end_time - start_time:.2f} s")
|
|
859
|
+
|
|
860
|
+
if testing_gpu is not None:
|
|
861
|
+
# Guard: make sure CPUs is an int, not a list
|
|
862
|
+
num_ray_cpus = len(ray_cpus) if isinstance(ray_cpus, (list, tuple)) else int(len(list(ray_cpus)))
|
|
863
|
+
|
|
864
|
+
# Parse the timechunk_id to get start/end times
|
|
865
|
+
if timechunk_id:
|
|
866
|
+
starttime, endtime, time_delta = parse_time_range(timechunk_id)
|
|
867
|
+
timechunk_length_min = time_delta.total_seconds() / 60.0 if time_delta else None
|
|
868
|
+
else:
|
|
869
|
+
timechunk_length_min = None
|
|
870
|
+
|
|
871
|
+
# Determine model name for logging
|
|
872
|
+
if model_type_lower == 'seisbench':
|
|
873
|
+
model_used = f"{seisbench_parent_model}/{seisbench_child_model}"
|
|
874
|
+
else:
|
|
875
|
+
model_used = "eqcct"
|
|
876
|
+
|
|
877
|
+
# To-Do: Add column for CPU IDs
|
|
878
|
+
trial_data = {
|
|
879
|
+
"Trial Number": None, # Will be auto-filled by append_trial_row
|
|
880
|
+
"Stations Used": str(station_list),
|
|
881
|
+
"Number of Stations Used": len(station_list),
|
|
882
|
+
"Number of CPUs Allocated for Ray to Use": num_ray_cpus,
|
|
883
|
+
"Intra-parallelism Threads": intra_threads if intra_threads is not None else "",
|
|
884
|
+
"Inter-parallelism Threads": inter_threads if inter_threads is not None else "",
|
|
885
|
+
"GPUs Used": json.dumps(list(gpu_id)) if (use_gpu and gpu_id is not None) else "[]",
|
|
886
|
+
"Inference Actor Memory Limit (MB)": float(model_vram_mb) if (use_gpu and gpu_memory_limit_mb is not None) else "",
|
|
887
|
+
"Total Waveform Analysis Timespace (min)": float(total_analysis_time.total_seconds() / 60.0) if hasattr(total_analysis_time, "total_seconds") else (float(total_analysis_time) if total_analysis_time else ""),
|
|
888
|
+
"Total Number of Timechunks": int(total_timechunks) if total_timechunks is not None else "",
|
|
889
|
+
"Concurrent Timechunks Used": int(number_of_concurrent_timechunk_predictions) if number_of_concurrent_timechunk_predictions is not None else "",
|
|
890
|
+
"Length of Timechunk (min)": timechunk_length_min if timechunk_length_min is not None else "",
|
|
891
|
+
"Number of Concurrent Station Tasks": int(number_of_concurrent_station_predictions) if number_of_concurrent_station_predictions is not None else "",
|
|
892
|
+
"Total Run time for Picker (s)": round(end_time - start_time, 6),
|
|
893
|
+
"Model Used": model_used,
|
|
894
|
+
"Trial Success": "",
|
|
895
|
+
"Error Message": str(""),
|
|
896
|
+
}
|
|
897
|
+
|
|
898
|
+
append_trial_row(csv_path=test_csv_filepath, trial_data=trial_data)
|
|
899
|
+
logger.info(f"Successfully saved trial data to CSV at {test_csv_filepath}")
|
|
900
|
+
|
|
901
|
+
return "Successfully ran EQCCTPro, exiting..."
|
|
902
|
+
|
|
903
|
+
|
|
904
|
+
@ray.remote
class ModelActor:
    """Ray actor that loads the EQCCT model once and serves predictions from it.

    Worker tasks hold a handle to this actor and call ``predict`` or
    ``predict_from_arrays`` remotely instead of loading their own copy of
    the model.
    """

    def __init__(self, p_model_path, s_model_path, gpus_to_use=False, intra_threads=1, inter_threads=1, gpu_memory_limit_mb=None, use_gpu=True):
        """Set up actor-local logging, optionally configure the GPU, and load the model.

        Parameters
        ----------
        p_model_path, s_model_path :
            Model paths forwarded to ``load_eqcct_model``.
        gpus_to_use :
            Sequence of GPU IDs; only its first element is forwarded to
            ``tf_environ`` (0 is used when the value is falsy).
        intra_threads, inter_threads :
            Thread settings forwarded to ``tf_environ``.
        gpu_memory_limit_mb :
            Forwarded to ``tf_environ`` as ``vram_limit_mb``.
        use_gpu :
            When falsy, ``tf_environ`` is not called at all.
        """
        # Dedicated, non-propagating logger with a single stream handler.
        actor_log = logging.getLogger("eqcctpro.model_actor")
        self.logger = actor_log
        actor_log.setLevel(logging.INFO)
        actor_log.handlers[:] = []
        actor_log.propagate = False
        actor_log.addHandler(logging.StreamHandler())

        for banner in (
            "=== ModelActor __init__ STARTED ===",
            f"p_model_path = {p_model_path}",
            f"s_model_path = {s_model_path}",
            f"Exists? P: {os.path.exists(p_model_path)}, S: {os.path.exists(s_model_path)}",
        ):
            actor_log.info(banner)

        if use_gpu:
            # One GPU per actor: configure TensorFlow before the model is loaded.
            try:
                actor_log.info("Calling tf_environ...")
                tf_settings = {
                    "gpu_id": gpus_to_use[0] if gpus_to_use else 0,
                    "gpus_to_use": None,  # First visible GPU only
                    "vram_limit_mb": gpu_memory_limit_mb,
                    "intra_threads": intra_threads,
                    "inter_threads": inter_threads,
                    "log_device": True,
                    "logger": actor_log,
                }
                tf_environ(**tf_settings)
                actor_log.info("tf_environ finished.")
            except RuntimeError as e:
                # Only RuntimeError is tolerated here; it is logged and
                # model loading continues.
                actor_log.error(f"[ModelActor] Error setting memory limit: {e}")

        # The model is loaded exactly once per actor.
        actor_log.info("Importing/load_eqcct_model...")
        from .eqcct_tf_models import load_eqcct_model
        self.model = load_eqcct_model(p_model_path, s_model_path)
        actor_log.info("Model loaded.")

    def ready(self):
        """Return True; callers use this to block until ``__init__`` has finished."""
        return True

    def predict(self, data_generator):
        """Run the loaded model on ``data_generator`` with progress output disabled."""
        loaded_model = self.model
        return loaded_model.predict(data_generator, verbose=0)

    def predict_from_arrays(self, trace_start_time, data_set, batch_size, norm_mode):
        """Build a ``PreLoadGeneratorTest`` inside the actor and predict on it.

        Returns whatever ``self.model.predict`` returns for that generator.
        """
        from .eqcct_tf_models import PreLoadGeneratorTest
        generator_options = {"batch_size": batch_size, "norm_mode": norm_mode}
        array_generator = PreLoadGeneratorTest(trace_start_time, data_set, **generator_options)
        return self.model.predict(array_generator, verbose=0)
|
|
954
|
+
|
|
955
|
+
|
|
956
|
+
@ray.remote
class SeisBenchModelActor:
    """Ray actor holding one loaded SeisBench (PyTorch) model.

    Counterpart of ModelActor: the model is loaded once in ``__init__`` and
    shared by every task that calls ``classify`` on the actor handle.
    """

    def __init__(self, parent_model_name, child_model_name, gpus_to_use=False, use_gpu=True):
        """Configure logging, choose the torch device, and load the model.

        Parameters
        ----------
        parent_model_name, child_model_name :
            Forwarded to ``SeisBenchModels``.
        gpus_to_use :
            Stored on the actor and echoed in the device log line.
        use_gpu :
            Selects ``cuda:0`` when CUDA is available, otherwise the CPU.

        Raises
        ------
        ImportError
            If ``torch`` cannot be imported.
        """
        actor_log = logging.getLogger("eqcctpro.seisbench_model_actor")
        self.logger = actor_log
        actor_log.setLevel(logging.INFO)
        actor_log.handlers[:] = []
        actor_log.propagate = False
        actor_log.addHandler(logging.StreamHandler())

        actor_log.info("=== SeisBenchModelActor __init__ STARTED ===")
        actor_log.info(f"parent_model_name = {parent_model_name}")
        actor_log.info(f"child_model_name = {child_model_name}")
        self.use_gpu = use_gpu
        self.gpus_to_use = gpus_to_use

        # SeisBench runs on PyTorch, so torch is a hard requirement here.
        try:
            import torch
        except ImportError:
            actor_log.error("PyTorch (torch) is not installed. SeisBench models require PyTorch.")
            raise ImportError("PyTorch (torch) is not installed. Please install it to use SeisBench models.")

        if not use_gpu:
            self.device = torch.device('cpu')
            actor_log.info("Using CPU device")
        else:
            # With num_gpus=1 Ray sets CUDA_VISIBLE_DEVICES, so the assigned
            # GPU is addressed as cuda:0 whatever its physical ID is.
            if torch.cuda.is_available():
                chosen_device = 'cuda:0'
            else:
                chosen_device = 'cpu'
            self.device = torch.device(chosen_device)
            actor_log.info(f"Using device: {self.device} (mapped by Ray from physical {gpus_to_use})")

        actor_log.info("Loading SeisBench model...")
        from .seisbench_models import SeisBenchModels
        self.model_wrapper = SeisBenchModels(parent_model_name, child_model_name)
        self.model_wrapper.load_model()

        if use_gpu:
            # Best effort: a failed transfer is logged as a warning, not raised.
            try:
                network = self.model_wrapper.model
                if hasattr(network, 'to'):
                    network.to(self.device)
                    actor_log.info(f"Model moved to {self.device}")
            except Exception as e:
                actor_log.warning(f"Could not move model to GPU: {e}")

        actor_log.info("SeisBench model loaded successfully.")

    def ready(self):
        """Return True; callers use this to block until ``__init__`` has finished."""
        return True

    def classify(self, stream, P_threshold=0.3, S_threshold=0.3, Detection_threshold=0.3, **kwargs):
        """Forward a stream and thresholds to the wrapped model's ``classify``.

        Parameters
        ----------
        stream :
            Waveform stream handed to the wrapper unchanged.
        P_threshold, S_threshold, Detection_threshold : float
            Thresholds passed through by keyword.
        **kwargs :
            Any additional keyword arguments, passed through untouched.

        Returns
        -------
        The return value of ``self.model_wrapper.classify``.
        """
        threshold_options = {
            "P_threshold": P_threshold,
            "S_threshold": S_threshold,
            "Detection_threshold": Detection_threshold,
        }
        return self.model_wrapper.classify(stream, **threshold_options, **kwargs)
|
|
1041
|
+
|
|
1042
|
+
|
|
1043
|
+
@ray.remote
def parallel_predict_seisbench(predict_args, model_actor, gpu=False):
    """Pick P and S phases for one station with a shared SeisBenchModelActor.

    Waveforms are preprocessed with ``mseed2stream_3c``, classified through
    ``model_actor.classify``, and the picks are written to
    ``<out_dir>/<station>_outputs/X_prediction_results.csv``.

    Parameters
    ----------
    predict_args : tuple
        ``(pos, station, out_dir, args)``. ``args`` is a mapping; this
        function reads ``'overwrite'`` and ``'input_dir'`` and, optionally,
        ``'log_queue'``, ``'P_threshold'``, ``'S_threshold'`` and
        ``'Detection_threshold'`` (each threshold defaults to 0.3).
    model_actor :
        Actor handle exposing ``classify.remote``.
    gpu : bool
        Accepted for signature parity with ``parallel_predict``; not read here.

    Returns
    -------
    str
        A one-line status message (finished, skipped, or failed) that the
        caller logs. Failures are reported through this string rather than
        raised.
    """
    import glob
    import shutil
    import csv
    import logging
    from logging.handlers import QueueHandler
    from .seisbench_models import mseed2stream_3c

    def _first_set(obj, names, default=None):
        """Return the first attribute in *names* that exists and is not None."""
        for attr_name in names:
            candidate = getattr(obj, attr_name, None)
            if candidate is not None:
                return candidate
        return default

    time_attrs = ('peak_time', 'start_time', 'time')
    prob_attrs = ('peak_value', 'score', 'value')
    time_format = '%Y-%m-%d %H:%M:%S.%f'

    pos, station, out_dir, args = predict_args

    # Set up logger to forward to the main listener. Any QueueHandler left
    # over from an earlier task in this worker process is replaced, so the
    # handlers do not pile up and duplicate every record.
    logger = logging.getLogger(f"eqcctpro.worker.{station}")
    logger.setLevel(logging.INFO)
    log_queue = args.get('log_queue')
    if log_queue is not None:
        logger.handlers[:] = [h for h in logger.handlers if not isinstance(h, QueueHandler)]
        logger.addHandler(QueueHandler(log_queue))

    save_dir = os.path.join(out_dir, str(station) + '_outputs')
    csv_filename = os.path.join(save_dir, 'X_prediction_results.csv')

    if os.path.isfile(csv_filename):
        if args['overwrite']:
            shutil.rmtree(save_dir)
        else:
            return f"{pos} {station}: Skipped (already exists - overwrite=False)."

    os.makedirs(save_dir, exist_ok=True)

    # The context manager closes the CSV on every exit path. newline='' is
    # what the csv module asks for, so the writer controls line endings.
    with open(csv_filename, 'w', newline='') as csvPr_gen:
        predict_writer = csv.writer(csvPr_gen, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        predict_writer.writerow(['file_name',
                                 'network',
                                 'station',
                                 'instrument_type',
                                 'station_lat',
                                 'station_lon',
                                 'station_elv',
                                 'p_arrival_time',
                                 'p_probability',
                                 's_arrival_time',
                                 's_probability'])
        csvPr_gen.flush()

        start_Predicting = time.time()
        files_list = glob.glob(f"{args['input_dir']}/{station}/*mseed")

        if not files_list:
            return f"{pos} {station}: FAILED - No mSEED files found."

        try:
            # Use SeisBench preprocessing
            stream3c, freqmin, freqmax = mseed2stream_3c(args, files_list, station)
        except Exception as e:
            return f"{pos} {station}: FAILED reading mSEED: {str(e)}"

        try:
            # Blocks until the shared actor has classified this stream.
            classify_output = ray.get(model_actor.classify.remote(
                stream3c,
                P_threshold=args.get('P_threshold', 0.3),
                S_threshold=args.get('S_threshold', 0.3),
                Detection_threshold=args.get('Detection_threshold', 0.3),
                strict=False,
                flexible_horizontal_components=True
            ))

            # Station metadata comes from the first trace when there is one.
            if len(stream3c) > 0:
                stats = stream3c[0].stats
                station_code = stats.station
                network_code = stats.network
                coordinates = getattr(stats, 'coordinates', None) or {}
                station_lat = coordinates.get('latitude', 0.0)
                station_lon = coordinates.get('longitude', 0.0)
                station_elv = coordinates.get('elevation', 0.0)
            else:
                station_code = station
                network_code = ""
                station_lat = station_lon = station_elv = 0.0

            # The first CSV column (file_name) carries the station code.
            row_prefix = [station_code, network_code, station_code,
                          0,  # instrument_type
                          station_lat, station_lon, station_elv]

            picks = getattr(classify_output, 'picks', [])

            # Split by phase. A pick with no 'phase' attribute counts as P
            # (as before); a phase of None is neither P nor S and is skipped
            # instead of raising on .upper().
            p_picks = []
            s_picks = []
            for pick in picks:
                phase = getattr(pick, 'phase', 'P')
                label = '' if phase is None else str(phase).upper()
                if label == 'P':
                    p_picks.append(pick)
                elif label == 'S':
                    s_picks.append(pick)

            s_times = [_first_set(s, time_attrs) for s in s_picks]
            # Indices into s_picks already paired with a P pick. Tracking
            # indices avoids hashing or comparing the pick objects.
            used_s = set()

            # Simple pairing: for each P, find the first S that comes after it within 30s
            for p in p_picks:
                p_time = _first_set(p, time_attrs)
                if p_time is None:
                    continue
                p_prob = _first_set(p, prob_attrs, 0.0)

                match_idx = None
                for idx, s_time in enumerate(s_times):
                    if idx in used_s or s_time is None:
                        continue
                    if 0 < (s_time - p_time) < 30:
                        match_idx = idx
                        used_s.add(idx)
                        break

                if match_idx is None:
                    s_time_str = ''
                    s_prob_str = ''
                else:
                    s_time_str = s_times[match_idx].strftime(time_format)
                    s_prob_str = f"{_first_set(s_picks[match_idx], prob_attrs, 0.0):.6f}"

                predict_writer.writerow(row_prefix + [
                    p_time.strftime(time_format),
                    f"{p_prob:.6f}",
                    s_time_str,
                    s_prob_str,
                ])

            # Write remaining S picks
            for idx, s in enumerate(s_picks):
                if idx in used_s or s_times[idx] is None:
                    continue
                s_prob = _first_set(s, prob_attrs, 0.0)
                predict_writer.writerow(row_prefix + [
                    '',
                    '',
                    s_times[idx].strftime(time_format),
                    f"{s_prob:.6f}",
                ])

            # If no picks found at all, write one row with station info
            if not picks:
                predict_writer.writerow(row_prefix + ['', '', '', ''])

            csvPr_gen.flush()

            end_Predicting = time.time()
            delta = (end_Predicting - start_Predicting)
            return f"{pos} {station}: Finished the prediction in {round(delta,2)}s. (HP={freqmin}, LP={freqmax}, picks={len(picks)})"

        except Exception as exp:
            return f"{pos} {station}: FAILED the prediction. {exp}"
|
|
1220
|
+
|
|
1221
|
+
|
|
1222
|
+
@ray.remote
def parallel_predict(predict_args, model_actor, gpu=False):
    """Pick P and S phases for one station with a shared EQCCT ModelActor.

    The model is not loaded in this task. Arrays are sent to
    ``model_actor.predict_from_arrays`` and the returned predictions are
    turned into picks and written to
    ``<out_dir>/<station>_outputs/X_prediction_results.csv``.

    Parameters
    ----------
    predict_args : tuple
        ``(pos, station, out_dir, args)``. ``args`` is a mapping; this
        function reads ``'overwrite'``, ``'input_dir'``, ``'batch_size'``
        and ``'normalization_mode'`` and passes ``args`` on to
        ``_mseed2nparray`` and ``_picker``.
    model_actor :
        Actor handle exposing ``predict_from_arrays.remote``.
    gpu : bool
        When falsy, ``CUDA_VISIBLE_DEVICES`` defaults to ``"-1"`` so this
        task does not probe CUDA.

    Returns
    -------
    str
        A one-line status message (finished, skipped, or failed) that the
        caller logs. mSEED-read and prediction failures are reported through
        this string rather than raised.
    """
    # Imported locally so this function does not depend on these modules
    # reaching the module namespace by some other route.
    import glob
    import shutil

    # --- QUIET TF C++/Python LOGS BEFORE ANY TF IMPORT ---
    os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")   # 3=ERROR
    os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")  # hide oneDNN banner
    if not gpu:
        os.environ.setdefault("CUDA_VISIBLE_DEVICES", "-1")  # don't probe CUDA on CPU tasks

    # Python-side TF/absl logging. Deliberately best effort: failing to
    # quiet the logs must not fail the prediction.
    try:
        import tensorflow as tf
        tf.get_logger().setLevel(logging.ERROR)
        try:
            from absl import logging as absl_logging
            absl_logging.set_verbosity(absl_logging.ERROR)
        except Exception:
            pass
    except Exception:
        pass

    pos, station, out_dir, args = predict_args

    save_dir = os.path.join(out_dir, str(station) + '_outputs')
    csv_filename = os.path.join(save_dir, 'X_prediction_results.csv')

    if os.path.isfile(csv_filename):
        if args['overwrite']:
            shutil.rmtree(save_dir)
        else:
            return f"{pos} {station}: Skipped (already exists - overwrite=False)."

    # exist_ok=True: a leftover directory without the CSV (for example from
    # an interrupted run) must not abort the task.
    os.makedirs(save_dir, exist_ok=True)

    # The context manager closes the CSV on every exit path; before, the
    # handle was never closed. newline='' is what the csv module asks for.
    with open(csv_filename, 'w', newline='') as csvPr_gen:
        predict_writer = csv.writer(csvPr_gen, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        predict_writer.writerow(['file_name',
                                 'network',
                                 'station',
                                 'instrument_type',
                                 'station_lat',
                                 'station_lon',
                                 'station_elv',
                                 'p_arrival_time',
                                 'p_probability',
                                 's_arrival_time',
                                 's_probability'])
        csvPr_gen.flush()

        start_Predicting = time.time()
        files_list = glob.glob(f"{args['input_dir']}/{station}/*mseed")

        try:
            meta, data_set, hp, lp = _mseed2nparray(args, files_list, station)
        except Exception as exc:
            # Include the cause, as parallel_predict_seisbench does.
            return f"{pos} {station}: FAILED reading mSEED: {exc}"

        try:
            # USE THE SHARED MODEL ACTOR INSTEAD OF LOADING MODEL.
            # The actor builds the data generator itself, so only the raw
            # arrays and generator settings are sent.
            predP, predS = ray.get(model_actor.predict_from_arrays.remote(
                meta["trace_start_time"], data_set, args["batch_size"], args["normalization_mode"]))

            detection_memory = []
            prob_memory = []
            total_p = len(predP)
            total_s = len(predS)
            for ix in range(total_p):
                Ppicks, Pprob = _picker(args, predP[ix, :, 0])
                Spicks, Sprob = _picker(args, predS[ix, :, 0], 'S_threshold')

                detection_memory, prob_memory = _output_writter_prediction(
                    meta, csvPr_gen, Ppicks, Pprob, Spicks, Sprob,
                    detection_memory, prob_memory, predict_writer, ix, total_p, total_s
                )

            end_Predicting = time.time()
            delta = (end_Predicting - start_Predicting)
            return f"{pos} {station}: Finished the prediction in {round(delta,2)}s. (HP={hp}, LP={lp})"

        except Exception as exp:
            return f"{pos} {station}: FAILED the prediction. {exp}"
|