eqcctpro 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -0,0 +1,915 @@
+ """
+ parallelization.py has access to all Ray functions: mseed_predictor(), parallel_predict(), ModelActor(), and their dependencies.
+ It is a layer of abstraction that keeps the code concise and clean.
+ """
+ import os
+ import ray
+ import csv
+ import sys
+ import ast
+ import glob      # used by parallel_predict
+ import math
+ import time
+ import json
+ import queue
+ import obspy
+ import psutil
+ import random
+ import shutil    # used by parallel_predict
+ import numbers
+ import logging
+ import platform
+ import traceback
+ import numpy as np
+ from .tools import *
+ from os import listdir
+ from obspy import UTCDateTime
+ from datetime import datetime, timedelta
+ from logging.handlers import QueueHandler
+
+ def parse_time_range(time_string):
+     """
+     Parses a time range string of the form 'YYYYMMDDThhmmssZ_YYYYMMDDThhmmssZ'
+     and returns the start time, end time, and time delta.
+     """
+     try:
+         start_str, end_str = time_string.split('_')
+         start_time = datetime.strptime(start_str, "%Y%m%dT%H%M%SZ")
+         end_time = datetime.strptime(end_str, "%Y%m%dT%H%M%SZ")
+         time_delta = end_time - start_time
+
+         return start_time, end_time, time_delta
+
+     except ValueError:
+         return None, None, None  # Malformed input; the caller must handle the None triple
+
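+ # Usage sketch (illustrative values, not part of the original module):
+ #   >>> parse_time_range("20230101T000000Z_20230101T010000Z")
+ #   (datetime(2023, 1, 1, 0, 0), datetime(2023, 1, 1, 1, 0), timedelta(seconds=3600))
+ #   >>> parse_time_range("not-a-range")   # malformed input
+ #   (None, None, None)
+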
+ def _mseed2nparray(args, files_list, station):
+     'read miniSEED files from a list of file paths and return metadata, a dict of numpy arrays keyed by time slice, and the bandpass corners used'
+
+     st = obspy.Stream()
+     # Read and process files
+     for file in files_list:
+         temp_st = obspy.read(file)
+         try:
+             temp_st.merge(fill_value=0)
+         except Exception:
+             # merge can fail on mixed sampling rates; resample, then retry the merge
+             temp_st = _resampling(temp_st)
+             temp_st.merge(fill_value=0)
+         temp_st.detrend('demean')
+         if temp_st:
+             st += temp_st
+         else:
+             return None  # No data to process, return early
+
+     # Apply taper and bandpass filter
+     max_percentage = 5 / (st[0].stats.delta * st[0].stats.npts)  # 5 s of data will be tapered
+     st.taper(max_percentage=max_percentage, type='cosine')
+     freqmin = 1.0
+     freqmax = 45.0
+     if args["stations_filters"] is not None:
+         try:
+             df_filters = args["stations_filters"]
+             freqmin = df_filters[df_filters.sta == station].iloc[0]["hp"]
+             freqmax = df_filters[df_filters.sta == station].iloc[0]["lp"]
+         except Exception:
+             pass
+     st.filter(type='bandpass', freqmin=freqmin, freqmax=freqmax, corners=2, zerophase=True)
+
+     # Interpolate if necessary
+     if any(tr.stats.sampling_rate != 100.0 for tr in st):
+         try:
+             st.interpolate(100, method="linear")
+         except Exception:
+             st = _resampling(st)
+
+     # Trim stream to the common start and end times
+     st.trim(min(tr.stats.starttime for tr in st), max(tr.stats.endtime for tr in st), pad=True, fill_value=0)
+     start_time = st[0].stats.starttime
+     end_time = st[0].stats.endtime
+
+     # Prepare metadata
+     meta = {
+         "start_time": start_time,
+         "end_time": end_time,
+         "trace_name": f"{files_list[0].split('/')[-2]}/{files_list[0].split('/')[-1]}"
+     }
+
+     # Prepare component mapping and types
+     data_set = {}
+     st_times = []
+     components = {tr.stats.channel[-1]: tr for tr in st}
+     time_shift = int(60 - (args['overlap'] * 60))
+
+     # Define preferred components for each column
+     components_list = [
+         ['E', '1'],  # Column 0
+         ['N', '2'],  # Column 1
+         ['Z']        # Column 2
+     ]
+
+     current_time = start_time
+     while current_time < end_time:
+         window_end = current_time + 60
+         st_times.append(str(current_time).replace('T', ' ').replace('Z', ''))
+         npz_data = np.zeros((6000, 3))
+
+         for col_idx, comp_options in enumerate(components_list):
+             for comp in comp_options:
+                 if comp in components:
+                     tr = components[comp].copy().slice(current_time, window_end)
+                     data = tr.data[:6000]
+                     # Pad with zeros if data is shorter than 6000 samples
+                     if len(data) < 6000:
+                         data = np.pad(data, (0, 6000 - len(data)), 'constant')
+                     npz_data[:, col_idx] = data
+                     break  # Stop after finding the first available component
+
+         key = str(current_time).replace('T', ' ').replace('Z', '')
+         data_set[key] = npz_data
+         current_time += time_shift
+
+     meta["trace_start_time"] = st_times
+
+     # Metadata population with default placeholders for now
+     try:
+         meta.update({
+             "receiver_code": st[0].stats.station,
+             "instrument_type": 0,
+             "network_code": 0,
+             "receiver_latitude": 0,
+             "receiver_longitude": 0,
+             "receiver_elevation_m": 0
+         })
+     except Exception:
+         meta.update({
+             "receiver_code": station,
+             "instrument_type": 0,
+             "network_code": 0,
+             "receiver_latitude": 0,
+             "receiver_longitude": 0,
+             "receiver_elevation_m": 0
+         })
+
+     return meta, data_set, freqmin, freqmax
+
+
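+ # Window stride sketch (worked example, assuming the 60 s windows above):
+ # time_shift = int(60 - overlap * 60), so overlap=0.3 gives a 42 s stride and
+ # consecutive windows share 60 - 42 = 18 s of data.
+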
+ def _output_writter_prediction(meta, csvPr, Ppicks, Pprob, Spicks, Sprob, detection_memory, prob_memory, predict_writer, idx, cq, cqq):
+     """
+     Writes the detection & picking results into a CSV file.
+
+     Parameters
+     ----------
+     meta: dict
+         Metadata of the trace (receiver, network, and start-time information).
+
+     csvPr: obj
+         Open file handle of the output CSV; flushed after each row.
+
+     Ppicks, Pprob: list
+         P-pick sample offsets and their probabilities.
+
+     Spicks, Sprob: list
+         S-pick sample offsets and their probabilities.
+
+     detection_memory: list
+         Keeps track of detected events.
+
+     prob_memory: list
+         Keeps track of pick probabilities.
+
+     predict_writer: obj
+         CSV writer used for writing out the detection/picking results.
+
+     idx: int
+         Index of the current window in meta["trace_start_time"].
+
+     cq, cqq: int
+         Window counts from the caller (currently unused).
+
+     Returns
+     -------
+     detection_memory, prob_memory: list
+         Updated tracking lists.
+     """
+
+     station_name = meta["receiver_code"]
+     station_lat = meta["receiver_latitude"]
+     station_lon = meta["receiver_longitude"]
+     station_elv = meta["receiver_elevation_m"]
+     start_time = meta["trace_start_time"][idx]
+     station_name = "{:<4}".format(station_name)
+     network_name = meta["network_code"]
+     network_name = "{:<2}".format(network_name)
+     instrument_type = meta["instrument_type"]
+     instrument_type = "{:<2}".format(instrument_type)
+
+     try:
+         start_time = datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S.%f')
+     except Exception:
+         start_time = datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S')
+
+     def _date_convertor(r):
+         if isinstance(r, str):
+             mls = r.split('.')
+             if len(mls) == 1:
+                 new_t = datetime.strptime(r, '%Y-%m-%d %H:%M:%S')
+             else:
+                 new_t = datetime.strptime(r, '%Y-%m-%d %H:%M:%S.%f')
+         else:
+             new_t = r
+
+         return new_t
+
+     p_time = []
+     p_prob = []
+     PdateTime = []
+     if Ppicks[0] is not None:
+         p_time.append(start_time + timedelta(seconds=Ppicks[0] / 100))
+         p_prob.append(Pprob[0])
+         PdateTime.append(_date_convertor(start_time + timedelta(seconds=Ppicks[0] / 100)))
+         detection_memory.append(p_time)
+         prob_memory.append(p_prob)
+     else:
+         p_time.append(None)
+         p_prob.append(None)
+         PdateTime.append(None)
+
+     s_time = []
+     s_prob = []
+     SdateTime = []
+     if Spicks[0] is not None:
+         s_time.append(start_time + timedelta(seconds=Spicks[0] / 100))
+         s_prob.append(Sprob[0])
+         SdateTime.append(_date_convertor(start_time + timedelta(seconds=Spicks[0] / 100)))
+     else:
+         s_time.append(None)
+         s_prob.append(None)
+         SdateTime.append(None)
+
+     SdateTime = np.array(SdateTime)
+     s_prob = np.array(s_prob)
+
+     p_prob = np.array(p_prob)
+     PdateTime = np.array(PdateTime)
+
+     predict_writer.writerow([meta["trace_name"],
+                              network_name,
+                              station_name,
+                              instrument_type,
+                              station_lat,
+                              station_lon,
+                              station_elv,
+                              PdateTime[0],
+                              p_prob[0],
+                              SdateTime[0],
+                              s_prob[0]
+                              ])
+
+     csvPr.flush()
+
+     return detection_memory, prob_memory
+
+
+ def _get_snr(data, pat, window=200):
+     """
+     Estimates SNR.
+
+     Parameters
+     ----------
+     data: numpy array
+         3 component data.
+
+     pat: positive integer
+         Sample point where a specific phase arrives.
+
+     window: positive integer, default=200
+         The length of the window for calculating the SNR (in samples).
+
+     Returns
+     -------
+     snr: {float, None}
+         Estimated SNR in dB.
+     """
+     snr = None
+     if pat:
+         try:
+             if int(pat) >= window and (int(pat) + window) < len(data):
+                 nw1 = data[int(pat) - window: int(pat)]
+                 sw1 = data[int(pat): int(pat) + window]
+                 snr = round(10 * math.log10((np.percentile(sw1, 95) / np.percentile(nw1, 95)) ** 2), 1)
+             elif int(pat) < window and (int(pat) + window) < len(data):
+                 window = int(pat)
+                 nw1 = data[int(pat) - window: int(pat)]
+                 sw1 = data[int(pat): int(pat) + window]
+                 snr = round(10 * math.log10((np.percentile(sw1, 95) / np.percentile(nw1, 95)) ** 2), 1)
+             elif (int(pat) + window) > len(data):
+                 window = len(data) - int(pat)
+                 nw1 = data[int(pat) - window: int(pat)]
+                 sw1 = data[int(pat): int(pat) + window]
+                 snr = round(10 * math.log10((np.percentile(sw1, 95) / np.percentile(nw1, 95)) ** 2), 1)
+         except Exception:
+             pass
+     return snr
+
+
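+ # SNR sketch (illustrative values): with noise p95 = 1 and signal p95 = 10,
+ # the estimate is 10 * log10((10 / 1) ** 2) = 20.0 dB.
+ #   >>> data = np.concatenate([np.ones(400), 10 * np.ones(400)])
+ #   >>> _get_snr(data, pat=400, window=200)
+ #   20.0
+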
+ def _detect_peaks(x, mph=None, mpd=1, threshold=0, edge='rising', kpsh=False, valley=False):
+     """
+     Detect peaks in data based on their amplitude and other features.
+
+     Parameters
+     ----------
+     x: 1D array_like
+         data.
+
+     mph: {None, number}, default=None
+         detect peaks that are greater than minimum peak height.
+
+     mpd: int, default=1
+         detect peaks that are at least separated by minimum peak distance (in number of samples).
+
+     threshold: int, default=0
+         detect peaks (valleys) that are greater (smaller) than `threshold` in relation to their immediate neighbors.
+
+     edge: str, default='rising'
+         for a flat peak, keep only the rising edge ('rising'), only the falling edge ('falling'), both edges ('both'), or don't detect a flat peak (None).
+
+     kpsh: bool, default=False
+         keep peaks with same height even if they are closer than `mpd`.
+
+     valley: bool, default=False
+         if True, detect valleys (local minima) instead of peaks.
+
+     Returns
+     -------
+     ind: 1D array_like
+         indices of the peaks in `x`.
+
+     Modified from
+     -------------
+     .. [1] http://nbviewer.ipython.org/github/demotu/BMC/blob/master/notebooks/DetectPeaks.ipynb
+     """
+
+     x = np.atleast_1d(x).astype('float64')
+     if x.size < 3:
+         return np.array([], dtype=int)
+     if valley:
+         x = -x
+     # find indices of all peaks
+     dx = x[1:] - x[:-1]
+     # handle NaN's
+     indnan = np.where(np.isnan(x))[0]
+     if indnan.size:
+         x[indnan] = np.inf
+         dx[np.where(np.isnan(dx))[0]] = np.inf
+     ine, ire, ife = np.array([[], [], []], dtype=int)
+     if not edge:
+         ine = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) > 0))[0]
+     else:
+         if edge.lower() in ['rising', 'both']:
+             ire = np.where((np.hstack((dx, 0)) <= 0) & (np.hstack((0, dx)) > 0))[0]
+         if edge.lower() in ['falling', 'both']:
+             ife = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) >= 0))[0]
+     ind = np.unique(np.hstack((ine, ire, ife)))
+     # handle NaN's
+     if ind.size and indnan.size:
+         # NaN's and values close to NaN's cannot be peaks
+         ind = ind[np.in1d(ind, np.unique(np.hstack((indnan, indnan - 1, indnan + 1))), invert=True)]
+     # first and last values of x cannot be peaks
+     if ind.size and ind[0] == 0:
+         ind = ind[1:]
+     if ind.size and ind[-1] == x.size - 1:
+         ind = ind[:-1]
+     # remove peaks < minimum peak height
+     if ind.size and mph is not None:
+         ind = ind[x[ind] >= mph]
+     # remove peaks - neighbors < threshold
+     if ind.size and threshold > 0:
+         dx = np.min(np.vstack([x[ind] - x[ind - 1], x[ind] - x[ind + 1]]), axis=0)
+         ind = np.delete(ind, np.where(dx < threshold)[0])
+     # detect small peaks closer than minimum peak distance
+     if ind.size and mpd > 1:
+         ind = ind[np.argsort(x[ind])][::-1]  # sort ind by peak height
+         idel = np.zeros(ind.size, dtype=bool)
+         for i in range(ind.size):
+             if not idel[i]:
+                 # keep peaks with the same height if kpsh is True
+                 idel = idel | (ind >= ind[i] - mpd) & (ind <= ind[i] + mpd) \
+                     & (x[ind[i]] > x[ind] if kpsh else True)
+                 idel[i] = 0  # keep the current peak
+         # remove the small peaks and sort back the indices by their occurrence
+         ind = np.sort(ind[~idel])
+
+     return ind
+
+
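+ # Peak-detection sketch (illustrative values): rising-edge peaks above mph survive.
+ #   >>> x = np.array([0, 1, 0, 2, 0, 3, 0], dtype=float)
+ #   >>> _detect_peaks(x, mph=1.5)
+ #   array([3, 5])
+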
+ def _picker(args, yh3, thr_type='P_threshold'):
+     """
+     Performs detection and picking.
+
+     Parameters
+     ----------
+     args: dict
+         A dictionary containing all of the input parameters, including the
+         probability threshold keyed by `thr_type`.
+
+     yh3: 1D array
+         Pick probability trace.
+
+     Returns
+     -------
+     Ppickall: list
+         Sample offset of the best pick (or None).
+     Pproball: list
+         Probability of the best pick (or None).
+     """
+     P_PICKall = []
+     Ppickall = []
+     Pproball = []
+     perrorall = []
+
+     sP_arr = _detect_peaks(yh3, mph=args[thr_type], mpd=1)
+
+     P_PICKS = []
+     pick_errors = []
+     if len(sP_arr) > 0:
+         P_uncertainty = None
+
+         for pick in range(len(sP_arr)):
+             sauto = sP_arr[pick]
+
+             if sauto:
+                 P_prob = np.round(yh3[int(sauto)], 3)
+                 P_PICKS.append([sauto, P_prob, P_uncertainty])
+
+     so = []
+     si = []
+     P_PICKS = np.array(P_PICKS)
+     P_PICKall.append(P_PICKS)
+     for ij in P_PICKS:
+         so.append(ij[1])
+         si.append(ij[0])
+     try:
+         so = np.array(so)
+         inds = np.argmax(so)
+         swave = si[inds]
+         Ppickall.append(swave)
+         Pproball.append(np.max(so))
+     except Exception:
+         # argmax on an empty array raises; no picks above threshold
+         Ppickall.append(None)
+         Pproball.append(None)
+
+     return Ppickall, Pproball
+
+
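+ # Picking sketch (illustrative values; printed types are approximate):
+ # the highest peak above the threshold wins.
+ #   >>> probs = np.zeros(100); probs[40] = 0.6; probs[70] = 0.9
+ #   >>> _picker({'P_threshold': 0.5}, probs)
+ #   ([70], [0.9])
+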
+ def _resampling(st):
+     'perform resampling on ObsPy stream objects'
+
+     need_resampling = [tr for tr in st if tr.stats.sampling_rate != 100.0]
+     if len(need_resampling) > 0:
+         for indx, tr in enumerate(need_resampling):
+             if tr.stats.delta < 0.01:
+                 tr.filter('lowpass', freq=45, zerophase=True)
+             tr.resample(100)
+             tr.stats.sampling_rate = 100
+             tr.stats.delta = 0.01
+             tr.data = tr.data.astype('int32')  # cast the samples; assigning to .dtype would reinterpret the raw buffer
+             st.remove(tr)
+             st.append(tr)
+     return st
+
+
+ def _normalize(data, mode='max'):
+     """
+     Normalize 3D arrays.
+
+     Parameters
+     ----------
+     data: 3D numpy array
+         3 component traces.
+
+     mode: str, default='max'
+         Mode of normalization, 'max' or 'std'.
+
+     Returns
+     -------
+     data: 3D numpy array
+         normalized data.
+     """
+
+     data -= np.mean(data, axis=0, keepdims=True)
+     if mode == 'max':
+         max_data = np.max(data, axis=0, keepdims=True)
+         assert max_data.shape[-1] == data.shape[-1]
+         max_data[max_data == 0] = 1
+         data /= max_data
+
+     elif mode == 'std':
+         std_data = np.std(data, axis=0, keepdims=True)
+         assert std_data.shape[-1] == data.shape[-1]
+         std_data[std_data == 0] = 1
+         data /= std_data
+     return data
+
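+ # Normalization sketch (illustrative; note the in-place update, hence the copy):
+ #   >>> x = np.array([[0., 0., 0.], [2., 4., 8.]])   # (samples, channels)
+ #   >>> _normalize(x.copy(), mode='max')
+ #   array([[-1., -1., -1.],
+ #          [ 1.,  1.,  1.]])
+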
+ @ray.remote
+ def mseed_predictor(input_dir='downloads_mseeds',
+                     output_dir="detections",
+                     P_threshold=0.1,
+                     S_threshold=0.1,
+                     normalization_mode='std',
+                     dt=1,
+                     batch_size=500,
+                     overlap=0.3,
+                     gpu_id=None,
+                     gpu_limit=None,
+                     overwrite=False,
+                     log_queue=None,
+                     stations2use=None,
+                     stations_filters=None,
+                     p_model=None,
+                     s_model=None,
+                     number_of_concurrent_station_predictions=None,
+                     ray_cpus=None,
+                     use_gpu=False,
+                     gpu_memory_limit_mb=None,
+                     testing_gpu=None,
+                     test_csv_filepath=None,
+                     specific_stations=None,
+                     timechunk_id=None,
+                     waveform_overlap=None,
+                     total_timechunks=None,
+                     number_of_concurrent_timechunk_predictions=None,
+                     total_analysis_time=None,
+                     intra_threads=None,
+                     inter_threads=None,
+                     timechunk_dt=None):
+     """
+     Performs fast detection directly on mseed data.
+
+     Parameters
+     ----------
+     input_dir: str
+         Directory containing one subdirectory of miniSEED files per station.
+
+     output_dir: str
+         Output directory that will be generated.
+
+     P_threshold: float, default=0.1
+         Probability threshold above which a pick is declared a P arrival.
+
+     S_threshold: float, default=0.1
+         Probability threshold above which a pick is declared an S arrival.
+
+     normalization_mode: str, default='std'
+         Mode of normalization for data preprocessing: 'max' (maximum amplitude among the three components) or 'std' (standard deviation).
+
+     batch_size: int, default=500
+         Batch size. This won't affect the speed much but can affect the performance. A value between 200 and 1000 is recommended.
+
+     overlap: float, default=0.3
+         If set, detection and picking are performed in overlapping windows.
+
+     gpu_id: list of int
+         IDs of the GPUs used for the prediction. If using CPU, set to None.
+
+     gpu_limit: float
+         Maximum fraction of memory usage for the GPU.
+
+     overwrite: Boolean, default=False
+         Overwrite existing results automatically.
+
+     Returns
+     -------
+     str
+         A status message on success.
+     """
+
+     # Set up a logger that writes logs in this native process and adds them to log_queue,
+     # so they reach the main logger outside of this Raylet (worker ships records to the driver).
+     logger = logging.getLogger("eqcctpro.worker")
+     logger.setLevel(logging.INFO)
+     logger.handlers[:] = []
+     logger.propagate = False
+     if log_queue is not None:
+         logger.addHandler(QueueHandler(log_queue))  # Ray queue supports put()
+
+     # Set up tf_environ again for the Raylets, which adopt their own import state and TF
+     # runtime when created. They must be configured properly so that they won't die.
+     if not use_gpu:
+         tf_environ(gpu_id=-1, intra_threads=intra_threads, inter_threads=inter_threads, logger=logger)
+
+     args = {
+         "input_dir": input_dir,
+         "output_dir": output_dir,
+         "P_threshold": P_threshold,
+         "S_threshold": S_threshold,
+         "normalization_mode": normalization_mode,
+         "dt": dt,
+         "overlap": overlap,
+         "batch_size": batch_size,
+         "overwrite": overwrite,
+         "gpu_id": gpu_id,
+         "gpu_limit": gpu_limit,
+         "p_model": p_model,
+         "s_model": s_model,
+         "stations_filters": stations_filters
+     }
+
+     logger.info("------- Hardware Configuration -------")
+     try:
+         process = psutil.Process(os.getpid())
+         process.cpu_affinity(ray_cpus)  # ray_cpus should be a list of core IDs like [0, 1, 2]
+         logger.info(f"CPU affinity set to cores: {list(ray_cpus)}")
+         logger.info("")
+     except Exception as e:
+         logger.error(f"Failed to set CPU affinity. Reason: {e}")
+         logger.error("")
+         sys.exit(1)
+
+     out_dir = os.path.join(os.getcwd(), str(args['output_dir']))
+     try:
+         if platform.system() == 'Windows':
+             station_list = [ev.split(".")[0] for ev in listdir(args['input_dir']) if ev.split("\\")[-1] != ".DS_Store"]
+         else:
+             station_list = [ev.split(".")[0] for ev in listdir(args['input_dir']) if ev.split("/")[-1] != ".DS_Store"]
+         station_list = sorted(set(station_list))
+     except Exception as e:
+         logger.info(f"{e}")
+         return  # To-Do: fix so that it has a valid return?
+
+     logger.info("------- Data Preprocessing for EQCCTPro -------")
+     logger.info(f"{len(station_list)} station(s) in {args['input_dir']}")
+
+     if stations2use and stations2use <= len(station_list):  # For system-evaluation execution
+         station_list = random.sample(station_list, stations2use)  # Randomly choose stations from the sample size
+
+     # For a "one use run" over a given set of stations, keep only specific_stations;
+     # with None, the whole directory is processed.
+     if specific_stations is not None:
+         station_list = [x for x in station_list if x in specific_stations]
+     logger.info(f"Using {len(station_list)} selected station(s): {station_list}.")
+
+     if not station_list or any(looks_like_timechunk_id(x) for x in station_list):
+         # Rebuild from the actual contents of the timechunk dir
+         station_list = build_station_list_from_dir(args['input_dir'])
+         logger.info("Station list rebuilt from directory because it contained a timechunk id or was empty.")
+
+     tasks_predictor = [[f"({i+1}/{len(station_list)})", station_list[i], out_dir, args] for i in range(len(station_list))]
+
+     if not tasks_predictor:
+         return
+
+     # CREATE MODEL ACTOR(S) before the task loop
+     logger.info("Creating model actor(s)...")
+
+     if use_gpu:
+         # Allocate more VRAM to model actors (they need to hold the full model):
+         # 2x the per-task VRAM, capped at ~3 GB; adjust based on your model size.
+         model_vram_mb = min(gpu_memory_limit_mb * 2, 3000)
+
+         # Create one model actor per GPU
+         model_actors = []
+         for _ in gpu_id:
+             actor = ModelActor.options(num_gpus=1, num_cpus=0).remote(gpus_to_use=gpu_id, p_model_path=p_model, s_model_path=s_model, gpu_memory_limit_mb=model_vram_mb, use_gpu=True, logger=logger)
+             model_actors.append(actor)
+
+         logger.info(f"Created {len(model_actors)} GPU model actor(s) with {model_vram_mb/1024:.2f}GB VRAM each")
+     else:
+         # Create one CPU model actor
+         model_actors = [ModelActor.options(num_cpus=1).remote(p_model_path=p_model, s_model_path=s_model, gpu_memory_limit_mb=None, use_gpu=False, logger=logger)]
+         logger.info("Created 1 CPU model actor")
+
+     # Submit tasks to Ray in a queue
+     tasks_queue = []
+     max_pending_tasks = number_of_concurrent_station_predictions
+     logger.info("Starting EQCCTPro parallelized waveform processing...")
+     logger.info("")
+     start_time = time.time()
+     logger.info("------- Analyzing Seismic Waveforms for P and S Picks via EQCCT -------")
+
+     if timechunk_id is None:
+         # Derive from the path if the caller forgot to pass it
+         cand = os.path.basename(input_dir)
+         if "_" in cand and len(cand) >= 10:
+             timechunk_id = cand
+         else:
+             raise ValueError("timechunk_id is None and could not be inferred from input_dir; "
+                              "expected a dir named like YYYYMMDDThhmmssZ_YYYYMMDDThhmmssZ")
+     starttime, endtime, time_delta = parse_time_range(timechunk_id)
+
+     logger.info(f"Analyzing {time_delta} timechunk from {starttime} to {endtime} ({waveform_overlap} min overlap)")
+     logger.info(f"Processing a total of {len(tasks_predictor)} stations, {max_pending_tasks} at a time.")
+
+     # Concurrent prediction(s) parallel processing
+     try:
+         for i in range(len(tasks_predictor)):
+             while True:
+                 # Add a new task to the queue while the maximum is not reached
+                 if len(tasks_queue) < max_pending_tasks:
+                     # Select which model actor to use (round-robin across GPUs)
+                     model_actor = model_actors[i % len(model_actors)]
+
+                     if use_gpu is False:
+                         tasks_queue.append(parallel_predict.options(num_cpus=0).remote(tasks_predictor[i], model_actor, False, None))
+                     elif use_gpu is True:
+                         # Don't allocate GPUs to workers, only to model actors
+                         tasks_queue.append(parallel_predict.options(num_cpus=0, num_gpus=0).remote(tasks_predictor[i], model_actor, True, gpu_memory_limit_mb))
+                     break
+                 # If there are more tasks than the maximum, just process them
+                 else:
+                     tasks_finished, tasks_queue = ray.wait(tasks_queue, num_returns=1, timeout=None)
+                     for finished_task in tasks_finished:
+                         log_entry = ray.get(finished_task)
+                         logger.info(f'{log_entry}')
+
+         # After adding all the tasks to the queue, process what's left
+         while tasks_queue:
+             tasks_finished, tasks_queue = ray.wait(tasks_queue, num_returns=1, timeout=None)
+             for finished_task in tasks_finished:
+                 log_entry = ray.get(finished_task)
+                 logger.info(f'{log_entry}')
+         logger.info("")
+
+     except Exception as e:
+         # Catch any error in the parallel processing
+         logger.error(f"ERROR in parallel processing at {datetime.now()}")
+         logger.error(f"Error: {str(e)}")
+         logger.error(traceback.format_exc())
+         raise  # Re-raise to surface the error
+
+     logger.info(f"------- Parallel Station Waveform Processing Complete For {starttime} to {endtime} Timechunk -------")
+     end_time = time.time()
+     logger.info(f"Picks saved at {output_dir}. Process runtime: {end_time - start_time:.2f} s")
+
+     if testing_gpu is not None:
+         # Guard: ray_cpus may be a list/tuple or another iterable of core IDs; record a count
+         num_ray_cpus = len(ray_cpus) if isinstance(ray_cpus, (list, tuple)) else len(list(ray_cpus))
+
+         # Parse the timechunk_id to get start/end times
+         if timechunk_id:
+             starttime, endtime, time_delta = parse_time_range(timechunk_id)
+             timechunk_length_min = time_delta.total_seconds() / 60.0 if time_delta else None
+         else:
+             timechunk_length_min = None
+
+         # To-Do: add a column for CPU IDs
+         trial_data = {
+             "Trial Number": None,  # Will be auto-filled by append_trial_row
+             "Stations Used": str(station_list),
+             "Number of Stations Used": len(station_list),
+             "Number of CPUs Allocated for Ray to Use": num_ray_cpus,
+             "Intra-parallelism Threads": intra_threads if intra_threads is not None else "",
+             "Inter-parallelism Threads": inter_threads if inter_threads is not None else "",
+             "GPUs Used": json.dumps(list(gpu_id)) if (use_gpu and gpu_id is not None) else "[]",
+             "VRAM Used Per Task": float(gpu_memory_limit_mb) if (use_gpu and gpu_memory_limit_mb is not None) else "",
+             "Total Waveform Analysis Timespace (min)": float(total_analysis_time.total_seconds() / 60.0) if hasattr(total_analysis_time, "total_seconds") else (float(total_analysis_time) if total_analysis_time else ""),
+             "Total Number of Timechunks": int(total_timechunks) if total_timechunks is not None else "",
+             "Concurrent Timechunks Used": int(number_of_concurrent_timechunk_predictions) if number_of_concurrent_timechunk_predictions is not None else "",
+             "Length of Timechunk (min)": timechunk_length_min if timechunk_length_min is not None else "",
+             "Number of Concurrent Station Tasks": int(number_of_concurrent_station_predictions) if number_of_concurrent_station_predictions is not None else "",
+             "Total Run time for Picker (s)": round(end_time - start_time, 6),
+             "Trial Success": "",
+             "Error Message": "",
+         }
+
+         append_trial_row(csv_path=test_csv_filepath, trial_data=trial_data)
+         logger.info(f"Successfully saved trial data to CSV at {test_csv_filepath}")
+
+     return "Successfully ran EQCCTPro, exiting..."
+
+
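+ # Driver sketch (illustrative; the paths, core IDs, and model files below are
+ # assumptions, not part of this module):
+ #   import ray
+ #   ray.init()
+ #   result = ray.get(mseed_predictor.remote(
+ #       input_dir="mseeds/20230101T000000Z_20230101T010000Z",  # basename doubles as timechunk_id
+ #       output_dir="detections",
+ #       number_of_concurrent_station_predictions=4,
+ #       ray_cpus=[0, 1, 2, 3],
+ #       use_gpu=False,
+ #       p_model="models/p_model.h5",
+ #       s_model="models/s_model.h5"))
+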
+ @ray.remote
+ class ModelActor:
+     def __init__(self, p_model_path, s_model_path, gpus_to_use=False, intra_threads=1, inter_threads=1, gpu_memory_limit_mb=None, use_gpu=True, logger=None):
+         if use_gpu and gpu_memory_limit_mb:
+             # Configure GPU memory for this actor
+             try:
+                 tf_environ(
+                     gpu_id=-1,
+                     gpus_to_use=gpus_to_use,
+                     vram_limit_mb=gpu_memory_limit_mb,
+                     intra_threads=intra_threads,
+                     inter_threads=inter_threads,
+                     log_device=False,
+                     logger=None)
+             except RuntimeError as e:
+                 logger.error(f"[ModelActor] Error setting memory limit: {e}")
+
+         # Load the model once
+         from .eqcct_tf_models import load_eqcct_model
+         self.model = load_eqcct_model(p_model_path, s_model_path)
+         logger.info("[ModelActor] Model loaded successfully")
+
+     def predict(self, data_generator):
+         """Perform prediction using the loaded model."""
+         return self.model.predict(data_generator, verbose=0)
+
+
827
+ @ray.remote
828
+ def parallel_predict(predict_args, model_actor, gpu=False, gpu_memory_limit_mb=None):
829
+ """
830
+ Modified to use shared ModelActor instead of loading model per task
831
+ """
832
+ # --- QUIET TF C++/Python LOGS BEFORE ANY TF IMPORT ---
833
+ # We were getting info messages from TF because we were importing it natively from eqcct_tf_models
834
+ # We need to supress TF first before we import it fully
835
+ os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3") # 3=ERROR
836
+ os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0") # hide oneDNN banner
837
+ if not gpu:
838
+ os.environ.setdefault("CUDA_VISIBLE_DEVICES", "-1") # don't probe CUDA on CPU tasks
839
+
840
+ # Python-side TF/absl logging
841
+ try:
842
+ import tensorflow as tf
843
+ tf.get_logger().setLevel(logging.ERROR)
844
+ try:
845
+ from absl import logging as absl_logging
846
+ absl_logging.set_verbosity(absl_logging.ERROR)
847
+ except Exception:
848
+ pass
849
+ except Exception:
850
+ # If eqcct_tf_models imports TF later, env vars above will still suppress C++ logs.
851
+ pass
852
+
853
+ from .eqcct_tf_models import Patches, PatchEncoder, StochasticDepth, PreLoadGeneratorTest, load_eqcct_model
854
+ pos, station, out_dir, args = predict_args
855
+
856
+ # NOTE: We removed the model loading code that was causing OOM errors
857
+ # The model is now shared via the model_actor
858
+
859
+ save_dir = os.path.join(out_dir, str(station)+'_outputs')
860
+ csv_filename = os.path.join(save_dir,'X_prediction_results.csv')
861
+
862
+ if os.path.isfile(csv_filename):
863
+ if args['overwrite']:
864
+ shutil.rmtree(save_dir)
865
+ else:
866
+ return f"{pos} {station}: Skipped (already exists - overwrite=False)."
867
+
868
+ os.makedirs(save_dir)
869
+ csvPr_gen = open(csv_filename, 'w')
870
+ predict_writer = csv.writer(csvPr_gen, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
871
+ predict_writer.writerow(['file_name',
872
+ 'network',
873
+ 'station',
874
+ 'instrument_type',
875
+ 'station_lat',
876
+ 'station_lon',
877
+ 'station_elv',
878
+ 'p_arrival_time',
879
+ 'p_probability',
880
+ 's_arrival_time',
881
+ 's_probability'])
882
+ csvPr_gen.flush()
883
+
884
+ start_Predicting = time.time()
885
+ files_list = glob.glob(f"{args['input_dir']}/{station}/*mseed")
886
+
887
+ try:
888
+ meta, data_set, hp, lp = _mseed2nparray(args, files_list, station)
889
+ except Exception:
890
+ return f"{pos} {station}: FAILED reading mSEED."
891
+
892
+ try:
893
+ params_pred = {'batch_size': args["batch_size"], 'norm_mode': args["normalization_mode"]}
894
+ pred_generator = PreLoadGeneratorTest(meta["trace_start_time"], data_set, **params_pred)
895
+
896
+ # USE THE SHARED MODEL ACTOR INSTEAD OF LOADING MODEL
897
+ predP, predS = ray.get(model_actor.predict.remote(pred_generator))
898
+
899
+ detection_memory = []
900
+ prob_memory = []
901
+ for ix in range(len(predP)):
902
+ Ppicks, Pprob = _picker(args, predP[ix,:, 0])
903
+ Spicks, Sprob = _picker(args, predS[ix,:, 0], 'S_threshold')
904
+
905
+ detection_memory, prob_memory = _output_writter_prediction(
906
+ meta, csvPr_gen, Ppicks, Pprob, Spicks, Sprob,
907
+ detection_memory, prob_memory, predict_writer, ix, len(predP), len(predS)
908
+ )
909
+
910
+ end_Predicting = time.time()
911
+ delta = (end_Predicting - start_Predicting)
912
+ return f"{pos} {station}: Finished the prediction in {round(delta,2)}s. (HP={hp}, LP={lp})"
913
+
914
+ except Exception as exp:
915
+ return f"{pos} {station}: FAILED the prediction. {exp}"