eqcctpro 0.6.2__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1312 @@
1
+ """
2
+ parallelization.py has access to all Ray functions: mseed_predictor(), parallel_predict(), ModelActor(), and their dependencies.
3
+ It is a layer of abstraction that keeps the calling code concise and clean.
4
+ """
5
+ import os
6
+ import ray
7
+ import csv
8
+ import sys
9
+ import ast
10
+ import math
11
+ import time
12
+ import json
13
+ import queue
14
+ import obspy
15
+ import psutil
16
+ import random
17
+ import numbers
18
+ import logging
19
+ import platform
20
+ import traceback
21
+ import numpy as np
22
+ from .tools import *
23
+ from os import listdir
24
+ from obspy import UTCDateTime
25
+ from datetime import datetime, timedelta
26
+ from logging.handlers import QueueHandler
27
+
28
+ # Dictionary of VRAM requirements (MB) for SeisBench models
29
+ # Format: (parent_model_name, child_model_name): vram_mb
30
+ # This will be populated with values provided by the user later
31
+ SEISBENCH_MODEL_VRAM_MB = {
32
+ # Example entries:
33
+ ('PhaseNet', 'original'): 2000.0,
34
+ ('EQTransformer', 'stead'): 2500.0,
35
+ }
36
+
37
+ def get_seisbench_model_vram_mb(parent_model_name, child_model_name, default_mb=2000.0):
38
+ """
39
+ Get VRAM requirement for a SeisBench model.
40
+ """
41
+ key = (parent_model_name, child_model_name)
42
+ return SEISBENCH_MODEL_VRAM_MB.get(key, default_mb)
43
+
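+ # Usage sketch (values come from the placeholder table above; unknown pairs fall back to default_mb):
+ # get_seisbench_model_vram_mb('PhaseNet', 'original')           # -> 2000.0
+ # get_seisbench_model_vram_mb('PhaseNet', 'instance', 1500.0)   # -> 1500.0 (hypothetical pair not in the table)
+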
44
+ def parse_time_range(time_string):
45
+ """
46
+ Parses a time range string and returns start time, end time, and time delta.
47
+ """
48
+ try:
49
+ start_str, end_str = time_string.split('_')
50
+ start_time = datetime.strptime(start_str, "%Y%m%dT%H%M%SZ")
51
+ end_time = datetime.strptime(end_str, "%Y%m%dT%H%M%SZ")
52
+ time_delta = end_time - start_time
53
+
54
+ return start_time, end_time, time_delta
55
+
56
+ except ValueError:
57
+ return None, None, None  # Malformed time string; callers must handle the None sentinels.
58
+
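+ # Example: parse_time_range("20240101T000000Z_20240101T003000Z")
+ # returns (datetime(2024, 1, 1, 0, 0), datetime(2024, 1, 1, 0, 30), timedelta(minutes=30))
+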
59
+ def _mseed2nparray(args, files_list, station):
60
+ 'Read miniSEED files for one station and return a metadata dict, a dict of 60 s numpy windows keyed by start time, and the bandpass corners (freqmin, freqmax)'
61
+
62
+ st = obspy.Stream()
63
+ # Read and process files
64
+ for file in files_list:
65
+ temp_st = obspy.read(file)
66
+ try:
67
+ temp_st.merge(fill_value=0)
69
+ except Exception:
70
+ # merge() can fail when traces have mismatched sampling rates; resample, then retry
+ temp_st = _resampling(temp_st)
+ temp_st.merge(fill_value=0)
70
+ temp_st.detrend('demean')
71
+ if temp_st:
72
+ st += temp_st
73
+ else:
74
+ return None # No data to process, return early
75
+
76
+ # Apply taper and bandpass filter
77
+ max_percentage = 5 / (st[0].stats.delta * st[0].stats.npts) # 5s of data will be tapered
78
+ st.taper(max_percentage=max_percentage, type='cosine')
79
+ freqmin = 1.0
80
+ freqmax = 45.0
81
+ if args["stations_filters"] is not None:
82
+ try:
83
+ df_filters = args["stations_filters"]
84
+ freqmin = df_filters[df_filters.sta == station].iloc[0]["hp"]
85
+ freqmax = df_filters[df_filters.sta == station].iloc[0]["lp"]
86
+ except Exception:
88
+ pass  # station not found in the filter table; keep the 1-45 Hz defaults
88
+ st.filter(type='bandpass', freqmin=freqmin, freqmax=freqmax, corners=2, zerophase=True)
89
+
90
+ # Interpolate if necessary
91
+ if any(tr.stats.sampling_rate != 100.0 for tr in st):
92
+ try:
93
+ st.interpolate(100, method="linear")
94
+ except Exception:
95
+ st = _resampling(st)
96
+
97
+ # Trim stream to the common start and end times
98
+ st.trim(min(tr.stats.starttime for tr in st), max(tr.stats.endtime for tr in st), pad=True, fill_value=0)
99
+ start_time = st[0].stats.starttime
100
+ end_time = st[0].stats.endtime
101
+
102
+ # Prepare metadata
103
+ meta = {
104
+ "start_time": start_time,
105
+ "end_time": end_time,
106
+ "trace_name": f"{files_list[0].split('/')[-2]}/{files_list[0].split('/')[-1]}"
107
+ }
108
+
109
+ # Prepare component mapping and types
110
+ data_set = {}
111
+ st_times = []
112
+ components = {tr.stats.channel[-1]: tr for tr in st}
113
+ time_shift = int(60 - (args['overlap'] * 60))
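+ # e.g. overlap=0.3 gives time_shift = int(60 - 18) = 42 s between consecutive 60 s windows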
114
+
115
+ # Define preferred components for each column
116
+ components_list = [
117
+ ['E', '1'], # Column 0
118
+ ['N', '2'], # Column 1
119
+ ['Z'] # Column 2
120
+ ]
121
+
122
+ current_time = start_time
123
+ while current_time < end_time:
124
+ window_end = current_time + 60
125
+ st_times.append(str(current_time).replace('T', ' ').replace('Z', ''))
126
+ npz_data = np.zeros((6000, 3))
127
+
128
+ for col_idx, comp_options in enumerate(components_list):
129
+ for comp in comp_options:
130
+ if comp in components:
131
+ tr = components[comp].copy().slice(current_time, window_end)
132
+ data = tr.data[:6000]
133
+ # Pad with zeros if data is shorter than 6000 samples
134
+ if len(data) < 6000:
135
+ data = np.pad(data, (0, 6000 - len(data)), 'constant')
136
+ npz_data[:, col_idx] = data
137
+ break # Stop after finding the first available component
138
+
139
+ key = str(current_time).replace('T', ' ').replace('Z', '')
140
+ data_set[key] = npz_data
141
+ current_time += time_shift
142
+
143
+ meta["trace_start_time"] = st_times
144
+
145
+ # Metadata population with default placeholders for now
146
+ try:
147
+ meta.update({
148
+ "receiver_code": st[0].stats.station,
149
+ "instrument_type": 0,
150
+ "network_code": 0,
151
+ "receiver_latitude": 0,
152
+ "receiver_longitude": 0,
153
+ "receiver_elevation_m": 0
154
+ })
155
+ except Exception:
156
+ meta.update({
157
+ "receiver_code": station,
158
+ "instrument_type": 0,
159
+ "network_code": 0,
160
+ "receiver_latitude": 0,
161
+ "receiver_longitude": 0,
162
+ "receiver_elevation_m": 0
163
+ })
164
+
165
+ return meta, data_set, freqmin, freqmax
166
+
167
+
168
+ def _output_writter_prediction(meta, csvPr, Ppicks, Pprob, Spicks, Sprob, detection_memory,prob_memory,predict_writer, idx, cq, cqq):
169
+
170
+ """
171
+
172
+ Writes the detection & picking results into a CSV file.
173
+
174
+ Parameters
175
+ ----------
176
+ dataset: hdf5 obj
177
+ Dataset object of the trace.
178
+
179
+ predict_writer: obj
180
+ For writing out the detection/picking results in the CSV file.
181
+
182
+ csvPr: obj
183
+ For writing out the detection/picking results in the CSV file.
184
+
185
+ matches: dic
186
+ It contains the information for the detected and picked event.
187
+
188
+ snr: list of two floats
189
+ Estimated signal to noise ratios for picked P and S phases.
190
+
191
+ detection_memory : list
192
+ Keep the track of detected events.
193
+
194
+ Returns
195
+ -------
196
+ detection_memory : list
197
+ Keep the track of detected events.
198
+
199
+
200
+ """
201
+
202
+ station_name = meta["receiver_code"]
203
+ station_lat = meta["receiver_latitude"]
204
+ station_lon = meta["receiver_longitude"]
205
+ station_elv = meta["receiver_elevation_m"]
206
+ start_time = meta["trace_start_time"][idx]
207
+ station_name = "{:<4}".format(station_name)
208
+ network_name = meta["network_code"]
209
+ network_name = "{:<2}".format(network_name)
210
+ instrument_type = meta["instrument_type"]
211
+ instrument_type = "{:<2}".format(instrument_type)
212
+
213
+ try:
214
+ start_time = datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S.%f')
215
+ except Exception:
216
+ start_time = datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S')
217
+
218
+ def _date_convertor(r):
219
+ if isinstance(r, str):
220
+ mls = r.split('.')
221
+ if len(mls) == 1:
222
+ new_t = datetime.strptime(r, '%Y-%m-%d %H:%M:%S')
223
+ else:
224
+ new_t = datetime.strptime(r, '%Y-%m-%d %H:%M:%S.%f')
225
+ else:
226
+ new_t = r
227
+
228
+ return new_t
229
+
230
+
231
+ p_time = []
232
+ p_prob = []
233
+ PdateTime = []
234
+ if Ppicks[0] is not None:
237
+ p_time.append(start_time+timedelta(seconds= Ppicks[0]/100))
238
+ p_prob.append(Pprob[0])
239
+ PdateTime.append(_date_convertor(start_time+timedelta(seconds= Ppicks[0]/100)))
240
+ detection_memory.append(p_time)
241
+ prob_memory.append(p_prob)
242
+ else:
243
+ p_time.append(None)
244
+ p_prob.append(None)
245
+ PdateTime.append(None)
246
+
247
+ s_time = []
248
+ s_prob = []
249
+ SdateTime=[]
250
+ if Spicks[0] is not None:
253
+ s_time.append(start_time+timedelta(seconds= Spicks[0]/100))
254
+ s_prob.append(Sprob[0])
255
+ SdateTime.append(_date_convertor(start_time+timedelta(seconds= Spicks[0]/100)))
256
+ else:
257
+ s_time.append(None)
258
+ s_prob.append(None)
259
+ SdateTime.append(None)
260
+
261
+ SdateTime = np.array(SdateTime)
262
+ s_prob = np.array(s_prob)
263
+
264
+ p_prob = np.array(p_prob)
265
+ PdateTime = np.array(PdateTime)
266
+
267
+ predict_writer.writerow([meta["trace_name"],
268
+ network_name,
269
+ station_name,
270
+ instrument_type,
271
+ station_lat,
272
+ station_lon,
273
+ station_elv,
274
+ PdateTime[0],
275
+ p_prob[0],
276
+ SdateTime[0],
277
+ s_prob[0]
278
+ ])
279
+
280
+
281
+
282
+ csvPr.flush()
283
+
284
+
285
+ return detection_memory,prob_memory
286
+
287
+
288
+ def _get_snr(data, pat, window=200):
289
+
290
+ """
291
+
292
+ Estimates SNR.
293
+
294
+ Parameters
295
+ ----------
296
+ data : numpy array
297
+ 3 component data.
298
+
299
+ pat: positive integer
300
+ Sample point where a specific phase arrives.
301
+
302
+ window: positive integer, default=200
303
+ The length of the window for calculating the SNR (in the sample).
304
+
305
+ Returns
306
+ --------
307
+ snr : {float, None}
308
+ Estimated SNR in db.
309
+
310
+
311
+ """
312
+ snr = None
+ if pat:
+ try:
+ pat = int(pat)
+ # Shrink the window so both the noise and signal slices stay within the trace
+ window = min(window, pat, len(data) - pat)
+ nw1 = data[pat - window:pat]
+ sw1 = data[pat:pat + window]
+ snr = round(10 * math.log10((np.percentile(sw1, 95) / np.percentile(nw1, 95)) ** 2), 1)
+ except Exception:
+ pass
332
+ return snr
333
+
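+ # Worked example: if the 95th-percentile amplitude after the pick is 10x the
+ # level before it, snr = 10 * log10(10 ** 2) = 20.0 dB.
+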
334
+
335
+ def _detect_peaks(x, mph=None, mpd=1, threshold=0, edge='rising', kpsh=False, valley=False):
336
+
337
+ """
338
+
339
+ Detect peaks in data based on their amplitude and other features.
340
+
341
+ Parameters
342
+ ----------
343
+ x : 1D array_like
344
+ data.
345
+
346
+ mph : {None, number}, default=None
347
+ detect peaks that are greater than minimum peak height.
348
+
349
+ mpd : int, default=1
350
+ detect peaks that are at least separated by minimum peak distance (in number of data).
351
+
352
+ threshold : int, default=0
353
+ detect peaks (valleys) that are greater (smaller) than `threshold in relation to their immediate neighbors.
354
+
355
+ edge : str, default=rising
356
+ for a flat peak, keep only the rising edge ('rising'), only the falling edge ('falling'), both edges ('both'), or don't detect a flat peak (None).
357
+
358
+ kpsh : bool, default=False
359
+ keep peaks with same height even if they are closer than `mpd`.
360
+
361
+ valley : bool, default=False
362
+ if True (1), detect valleys (local minima) instead of peaks.
363
+
364
+ Returns
365
+ -------
366
+ ind : 1D array_like
367
+ indeces of the peaks in `x`.
368
+
369
+ Modified from
370
+ ----------
371
+ .. [1] http://nbviewer.ipython.org/github/demotu/BMC/blob/master/notebooks/DetectPeaks.ipynb
372
+
373
+
374
+ """
375
+
376
+ x = np.atleast_1d(x).astype('float64')
377
+ if x.size < 3:
378
+ return np.array([], dtype=int)
379
+ if valley:
380
+ x = -x
381
+ # find indices of all peaks
382
+ dx = x[1:] - x[:-1]
383
+ # handle NaN's
384
+ indnan = np.where(np.isnan(x))[0]
385
+ if indnan.size:
386
+ x[indnan] = np.inf
387
+ dx[np.where(np.isnan(dx))[0]] = np.inf
388
+ ine, ire, ife = np.array([[], [], []], dtype=int)
389
+ if not edge:
390
+ ine = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) > 0))[0]
391
+ else:
392
+ if edge.lower() in ['rising', 'both']:
393
+ ire = np.where((np.hstack((dx, 0)) <= 0) & (np.hstack((0, dx)) > 0))[0]
394
+ if edge.lower() in ['falling', 'both']:
395
+ ife = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) >= 0))[0]
396
+ ind = np.unique(np.hstack((ine, ire, ife)))
397
+ # handle NaN's
398
+ if ind.size and indnan.size:
399
+ # NaN's and values close to NaN's cannot be peaks
400
+ ind = ind[np.in1d(ind, np.unique(np.hstack((indnan, indnan-1, indnan+1))), invert=True)]
401
+ # first and last values of x cannot be peaks
402
+ if ind.size and ind[0] == 0:
403
+ ind = ind[1:]
404
+ if ind.size and ind[-1] == x.size-1:
405
+ ind = ind[:-1]
406
+ # remove peaks < minimum peak height
407
+ if ind.size and mph is not None:
408
+ ind = ind[x[ind] >= mph]
409
+ # remove peaks - neighbors < threshold
410
+ if ind.size and threshold > 0:
411
+ dx = np.min(np.vstack([x[ind]-x[ind-1], x[ind]-x[ind+1]]), axis=0)
412
+ ind = np.delete(ind, np.where(dx < threshold)[0])
413
+ # detect small peaks closer than minimum peak distance
414
+ if ind.size and mpd > 1:
415
+ ind = ind[np.argsort(x[ind])][::-1] # sort ind by peak height
416
+ idel = np.zeros(ind.size, dtype=bool)
417
+ for i in range(ind.size):
418
+ if not idel[i]:
419
+ # keep peaks with the same height if kpsh is True
420
+ idel = idel | (ind >= ind[i] - mpd) & (ind <= ind[i] + mpd) \
421
+ & (x[ind[i]] > x[ind] if kpsh else True)
422
+ idel[i] = 0 # Keep current peak
423
+ # remove the small peaks and sort back the indices by their occurrence
424
+ ind = np.sort(ind[~idel])
425
+
426
+ return ind
427
+
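+ # Example: _detect_peaks(np.array([0.1, 0.7, 0.2, 0.9, 0.3]), mph=0.5)
+ # returns array([1, 3]) -- both local maxima exceed the 0.5 height threshold.
+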
428
+
429
+ def _picker(args, yh3, thr_type='P_threshold'):
430
+ """
431
+ Performs detection and picking.
432
+
433
+ Parameters
434
+ ----------
435
+ args : dic
436
+ A dictionary containing all of the input parameters.
437
+
438
+ yh1 : 1D array
439
+ probability.
440
+
441
+ Returns
442
+ --------
443
+ Ppickall: Pick.
444
+ Pproball: Pick Probability.
445
+
446
+ """
447
+ P_PICKall=[]
448
+ Ppickall=[]
449
+ Pproball = []
450
+ perrorall=[]
451
+
452
+ sP_arr = _detect_peaks(yh3, mph=args[thr_type], mpd=1)
453
+
454
+ P_PICKS = []
455
+ pick_errors = []
456
+ if len(sP_arr) > 0:
457
+ P_uncertainty = None
458
+
459
+ for pick in range(len(sP_arr)):
460
+ sauto = sP_arr[pick]
461
+
462
+
463
+ if sauto:
464
+ P_prob = np.round(yh3[int(sauto)], 3)
465
+ P_PICKS.append([sauto,P_prob, P_uncertainty])
466
+
467
+ so=[]
468
+ si=[]
469
+ P_PICKS = np.array(P_PICKS)
470
+ P_PICKall.append(P_PICKS)
471
+ for ij in P_PICKS:
472
+ so.append(ij[1])
473
+ si.append(ij[0])
474
+ try:
475
+ so = np.array(so)
476
+ inds = np.argmax(so)
477
+ swave = si[inds]
478
+ Ppickall.append((swave))
479
+ Pproball.append((np.max(so)))
480
+ except:
481
+ Ppickall.append(None)
482
+ Pproball.append(None)
483
+
484
487
+
488
+ return Ppickall, Pproball
489
+
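+ # Sketch: for a trace with peaks of 0.42 (sample 1200) and 0.65 (sample 3300)
+ # above the threshold, _picker returns ([3300], [0.65]) -- the strongest peak only.
+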
490
+
491
+ def _resampling(st):
492
+ 'Resample every trace in an ObsPy Stream object to 100 Hz'
493
+
494
+ need_resampling = [tr for tr in st if tr.stats.sampling_rate != 100.0]
495
+ if len(need_resampling) > 0:
496
+ # print('resampling ...', flush=True)
497
+ for tr in need_resampling:
498
+ if tr.stats.delta < 0.01:
499
+ tr.filter('lowpass',freq=45,zerophase=True)
500
+ tr.resample(100)
501
+ tr.stats.sampling_rate = 100
502
+ tr.stats.delta = 0.01
503
+ tr.data = tr.data.astype('int32')  # cast values; assigning to .dtype would reinterpret the raw buffer
504
+ st.remove(tr)
505
+ st.append(tr)
506
+ return st
507
+
508
+
509
+ def _normalize(data, mode = 'max'):
510
+ """
511
+
512
+ Normalize 3D arrays.
513
+
514
+ Parameters
515
+ ----------
516
+ data : 3D numpy array
517
+ 3 component traces.
518
+
519
+ mode : str, default='max'
520
+ Mode of normalization. 'max' or 'std'
521
+
522
+ Returns
523
+ -------
524
+ data : 3D numpy array
525
+ normalized data.
526
+
527
+ """
528
+
529
+ data -= np.mean(data, axis=0, keepdims=True)
530
+ if mode == 'max':
531
+ max_data = np.max(data, axis=0, keepdims=True)
532
+ assert(max_data.shape[-1] == data.shape[-1])
533
+ max_data[max_data == 0] = 1
534
+ data /= max_data
535
+
536
+ elif mode == 'std':
537
+ std_data = np.std(data, axis=0, keepdims=True)
538
+ assert(std_data.shape[-1] == data.shape[-1])
539
+ std_data[std_data == 0] = 1
540
+ data /= std_data
541
+ return data
542
+
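+ # Sketch: for a (6000, 3) window, _normalize(window, mode='std') demeans each
+ # component column and scales it to unit standard deviation.
+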
543
+ @ray.remote
544
+ def mseed_predictor(input_dir='downloads_mseeds',
545
+ output_dir="detections",
546
+ P_threshold=0.1,
547
+ S_threshold=0.1,
548
+ normalization_mode='std',
549
+ dt=1,
550
+ batch_size=500,
551
+ overlap=0.3,
552
+ gpu_id=None,
553
+ gpu_limit=None,
554
+ overwrite=False,
555
+ log_queue=None,
556
+ stations2use=None,
557
+ stations_filters=None,
558
+ p_model=None,
559
+ s_model=None,
560
+ number_of_concurrent_station_predictions=None,
561
+ ray_cpus=None,
562
+ use_gpu=False,
563
+ gpu_memory_limit_mb=None,
564
+ testing_gpu=None,
565
+ test_csv_filepath=None,
566
+ specific_stations=None,
567
+ timechunk_id=None,
568
+ waveform_overlap=None,
569
+ total_timechunks=None,
570
+ number_of_concurrent_timechunk_predictions=None,
571
+ total_analysis_time=None,
572
+ intra_threads=None,
573
+ inter_threads=None,
574
+ timechunk_dt=None,
575
+ # SeisBench model parameters
576
+ model_type='eqcct',
577
+ seisbench_parent_model=None,
578
+ seisbench_child_model=None,
579
+ Detection_threshold=0.3):
580
+
581
+ """
582
+
583
+ To perform fast detection directly on mseed data.
584
+
585
+ Parameters
586
+ ----------
587
+ input_dir: str
588
+ Directory name containing hdf5 and csv files-preprocessed data.
589
+
590
+ input_model: str
591
+ Path to a trained model.
592
+
593
+ stations_json: str
594
+ Path to a JSON file containing station information.
595
+
596
+ output_dir: str
597
+ Output directory that will be generated.
598
+
599
+ P_threshold: float, default=0.1
600
+ A value which the P probabilities above it will be considered as P arrival.
601
+
602
+ S_threshold: float, default=0.1
603
+ A value which the S probabilities above it will be considered as S arrival.
604
+
605
+ normalization_mode: str, default=std
606
+ Mode of normalization for data preprocessing max maximum amplitude among three components std standard deviation.
607
+
608
+ batch_size: int, default=500
609
+ Batch size. This wont affect the speed much but can affect the performance. A value beteen 200 to 1000 is recommended.
610
+
611
+ overlap: float, default=0.3
612
+ If set the detection and picking are performed in overlapping windows.
613
+
614
+ gpu_id: int
615
+ Id of GPU used for the prediction. If using CPU set to None.
616
+
617
+ gpu_limit: float
618
+ Set the maximum percentage of memory usage for the GPU.
619
+
620
+ overwrite: Bolean, default=False
621
+ Overwrite your results automatically.
622
+
623
+ Returns
624
+ --------
625
+
626
+ """
627
+
628
+ # Set up logger that will write logs to this native process and add them to the log.queue to be added back to the main logger outside of this Raylet
629
+ # worker logger ships records to driver
630
+ logger = logging.getLogger("eqcctpro.worker")
631
+ logger.setLevel(logging.INFO)
632
+ logger.handlers[:] = []
633
+ logger.propagate = False
634
+ if log_queue is not None:
+ logger.addHandler(QueueHandler(log_queue))  # Ray queue supports put()
637
+
638
+ # We set up the tf_environ again for the Raylets, who adopt their own import state and TF runtime when created.
639
+ # We want to ensure that they are configured properly so that the workers do not die
640
+ skip_tf = (model_type.lower() != 'eqcct')
641
+ if not use_gpu:
642
+ tf_environ(gpu_id=-1, intra_threads=intra_threads, inter_threads=inter_threads, logger=logger, skip_tf=skip_tf)
643
+ # tf_environ(gpu_id=1, gpu_memory_limit_mb=gpu_memory_limit_mb, gpus_to_use=gpu_id, intra_threads=intra_threads, inter_threads=inter_threads)
644
+
645
+
646
+ args = {
647
+ "input_dir": input_dir,
648
+ "output_dir": output_dir,
649
+ "P_threshold": P_threshold,
650
+ "S_threshold": S_threshold,
651
+ "normalization_mode": normalization_mode,
652
+ "dt": dt,
653
+ "overlap": overlap,
654
+ "batch_size": batch_size,
655
+ "overwrite": overwrite,
656
+ "gpu_id": gpu_id,
657
+ "gpu_limit": gpu_limit,
658
+ "p_model": p_model,
659
+ "s_model": s_model,
660
+ "stations_filters": stations_filters,
661
+ "model_type": model_type,
662
+ "seisbench_parent_model": seisbench_parent_model,
663
+ "seisbench_child_model": seisbench_child_model,
664
+ "Detection_threshold": Detection_threshold
665
+ }
666
+
667
+ logger.info(f"------- Hardware Configuration -------")
668
+ try:
669
+ process = psutil.Process(os.getpid())
670
+ process.cpu_affinity(ray_cpus) # ray_cpus should be a list of core IDs like [0, 1, 2]
671
+ logger.info(f"CPU affinity set to cores: {list(ray_cpus)}")
672
+ logger.info("")
673
+ except Exception as e:
674
+ logger.error(f"Failed to set CPU affinity. Reason: {e}")
675
+ logger.error("")
676
+ sys.exit(1)
677
+
678
+ out_dir = os.path.join(os.getcwd(), str(args['output_dir']))
679
+ try:
680
+ if platform.system() == 'Windows':
+ station_list = [ev.split(".")[0] for ev in listdir(args['input_dir']) if ev.split("\\")[-1] != ".DS_Store"]
+ else:
+ station_list = [ev.split(".")[0] for ev in listdir(args['input_dir']) if ev.split("/")[-1] != ".DS_Store"]
682
+ station_list = sorted(set(station_list))
683
+ except Exception as e:
684
+ logger.info(f"{e}")
685
+ return # To-Do: Fix so that it has a valid return?
686
+ # log.write(f"GPU ID: {args['gpu_id']}; Batch size: {args['batch_size']}")
687
+ logger.info(f"------- Data Preprocessing for EQCCTPro -------")
688
+ logger.info(f"{len(station_list)} station(s) in {args['input_dir']}")
689
+
690
+ if stations2use and stations2use <= len(station_list): # For System Evaluation Execution
691
+ station_list = random.sample(station_list, stations2use) # Randomly choose stations from the sample size
692
+ # log.write(f"Using {len(station_list)} station(s) after selection.")
693
+
694
+ # For a "One Use Run" over a given set of stations, keep only specific_stations;
+ # when specific_stations is None the whole directory is processed.
+ if specific_stations is not None:
+ station_list = [x for x in station_list if x in specific_stations]
696
+ logger.info(f"Using {len(station_list)} selected station(s): {station_list}.")
697
+
698
+ if not station_list or any(looks_like_timechunk_id(x) for x in station_list):
699
+ # Rebuild from the actual contents of the timechunk dir
700
+ station_list = build_station_list_from_dir(args['input_dir'])
701
+ logger.info(f"Station list rebuilt from directory because it contained a timechunk id or was empty.")
702
+
703
+ tasks_predictor = [[f"({i+1}/{len(station_list)})", station_list[i], out_dir, args] for i in range(len(station_list))]
704
+
705
+ if not tasks_predictor: return
706
+
707
+ # CREATE MODEL ACTOR(S) - Add this before the task loop
708
+ logger.info(f"Creating model actor(s)...")
709
+
710
+ model_type_lower = model_type.lower() if model_type else 'eqcct'
711
+
712
+ if model_type_lower == 'seisbench':
713
+ # Create SeisBench model actors
714
+ if use_gpu:
715
+ # Get VRAM requirement for this SeisBench model
716
+ model_vram_mb = get_seisbench_model_vram_mb(
717
+ seisbench_parent_model,
718
+ seisbench_child_model,
719
+ default_mb=2000.0
720
+ )
721
+ # Use max of requested VRAM or model requirement (similar to EQCCT logic)
722
+ model_vram_mb = max(gpu_memory_limit_mb, model_vram_mb) if gpu_memory_limit_mb else model_vram_mb
723
+
724
+ model_actors = []
725
+ logger.info(f"Using GPUs: {gpu_id}")
726
+ for gpu_idx in gpu_id:
727
+ logger.info(f"Creating SeisBenchModelActor on GPU {gpu_idx} with {model_vram_mb/1024:.2f}GB VRAM requirement...")
728
+ actor = SeisBenchModelActor.options(num_gpus=1, num_cpus=0).remote(
729
+ parent_model_name=seisbench_parent_model,
730
+ child_model_name=seisbench_child_model,
731
+ gpus_to_use=[gpu_idx],
732
+ use_gpu=True
733
+ )
734
+ try:
735
+ ray.get(actor.ready.remote())
736
+ except Exception as e:
737
+ logger.error(f"Failed to create SeisBenchModelActor on GPU {gpu_idx}: {e}")
738
+ raise
739
+ logger.info(f"SeisBenchModelActor created on GPU {gpu_idx}.")
740
+ model_actors.append(actor)
741
+ logger.info(f"Created {len(model_actors)} GPU-sized SeisBenchModelActor(s).")
742
+ else:
743
+ model_actors = [SeisBenchModelActor.options(num_cpus=1).remote(
744
+ parent_model_name=seisbench_parent_model,
745
+ child_model_name=seisbench_child_model,
746
+ gpus_to_use=False,
747
+ use_gpu=False
748
+ )]
749
+ ray.get(model_actors[0].ready.remote())
750
+ logger.info(f"Created a 1 CPU-sized SeisBenchModelActor")
751
+ else:
752
+ # Create EQCCT model actors (original logic)
753
+ if use_gpu:
754
+ # Allocate more VRAM to model actors (they need to hold the full model)
755
+ # Reserve ~2-3GB per model actor, adjust based on your model size
756
+ model_vram_mb = max(gpu_memory_limit_mb, 3000) if gpu_memory_limit_mb else 3000  # At least the requested VRAM or 3GB for EQCCT (subject to change)
757
+
758
+ # Create one model actor per GPU
759
+ model_actors = []
760
+ logger.info(f"Using GPUs: {gpu_id}")
761
+ for gpu_idx in gpu_id: # gpu_id is a list of GPU IDs and gpu_idx is the current GPU ID in the loop
762
+ logger.info(f"Creating ModelActor on GPU {gpu_idx} with {model_vram_mb/1024:.2f}GB VRAM limit...")
763
+ actor = ModelActor.options(num_gpus=1, num_cpus=0).remote(gpus_to_use=[gpu_idx], p_model_path=p_model, s_model_path=s_model, gpu_memory_limit_mb=model_vram_mb, use_gpu=True)
764
+ # Wait for __init__ to complete and raise if error
765
+ try:
766
+ ray.get(actor.ready.remote())
767
+ except Exception as e:
768
+ logger.error(f"Failed to create ModelActor on GPU {gpu_idx}: {e}")
769
+ raise
770
+ logger.info(f"ModelActor created on GPU {gpu_idx}.")
771
+ model_actors.append(actor)
772
+
773
+ logger.info(f"Created {len(model_actors)} GPU-sized ModelActor(s).")
774
+ # Using CUDA_VISIBLE_DEVICES is not a reliable way to report which physical GPU is being used bc Ray can overwrite, clear, or remap the assigned GPU so that each worker sees them as local indices (often starting from 0)
775
+ logger.info(f"[ModelActor] Model successfully loaded onto {'GPU' if use_gpu else 'CPU'}.") # Better way to log is to use ray.get_gpu_ids()
776
+ else:
777
+ # Create CPU model actor
778
+ model_actors = [ModelActor.options(num_cpus=1).remote(p_model_path=p_model, s_model_path=s_model, gpu_memory_limit_mb=None, use_gpu=False)]
779
+ logger.info(f"Created a 1 CPU-sized ModelActor")
780
+
781
+ # Submit tasks to ray in a queue
782
+ tasks_queue = []
783
+ max_pending_tasks = number_of_concurrent_station_predictions
784
+ logger.info(f"Starting EQCCTPro parallelized waveform processing...")
785
+ logger.info("")
786
+ start_time = time.time()
787
+ model_type_lower = model_type.lower() if model_type else 'eqcct'
788
+ if model_type_lower == 'seisbench':
789
+ logger.info(f"------- Analyzing Seismic Waveforms for P and S Picks via SeisBench ({seisbench_parent_model} - {seisbench_child_model}) -------")
790
+ else:
791
+ logger.info(f"------- Analyzing Seismic Waveforms for P and S Picks via EQCCT -------")
792
+
793
+ if timechunk_id is None:
794
+ # derive from the path if caller forgot to pass it
795
+ cand = os.path.basename(input_dir)
796
+ if "_" in cand and len(cand) >= 10:
797
+ timechunk_id = cand
798
+ else:
799
+ raise ValueError("timechunk_id is None and could not be inferred from input_dir; "
800
+ "expected a dir named like YYYYMMDDThhmmssZ_YYYYMMDDThhmmssZ")
801
+ starttime, endtime, time_delta = parse_time_range(timechunk_id)
802
+
803
+ logger.info(f"Analyzing {time_delta} minute timechunk from {starttime} to {endtime} ({waveform_overlap} min overlap)")
804
+ logger.info(f"Processing a total of {len(tasks_predictor)} stations, {max_pending_tasks} at a time.")
805
+
806
+
807
+ # Concurrent Prediction(s) Parallel Processing
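+ # Pattern: keep at most max_pending_tasks futures in flight; once the cap is
+ # reached, ray.wait() blocks for one result before the next task is submitted.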
808
+ try:
809
+ for i in range(len(tasks_predictor)):
810
+ while True:
811
+ # Add new task to queue while max is not reached
812
+ if len(tasks_queue) < max_pending_tasks:
813
+ # SELECT WHICH MODEL ACTOR TO USE (round-robin across GPUs)
814
+ model_actor = model_actors[i % len(model_actors)]
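+ # e.g. with two model actors, stations 0, 2, 4, ... use actor 0 and stations 1, 3, 5, ... use actor 1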
815
+
816
+ # Route to appropriate prediction function based on model type
817
+ if model_type_lower == 'seisbench':
818
+ # SeisBench models use parallel_predict_seisbench
819
+ if use_gpu is False:
820
+ tasks_queue.append(parallel_predict_seisbench.options(num_cpus=0).remote(tasks_predictor[i], model_actor, False))
821
+ elif use_gpu is True:
822
+ # Don't allocate GPUs to workers, only to model actors
823
+ # Use num_cpus=0 to avoid deadlocks when Ray has limited CPUs
824
+ tasks_queue.append(parallel_predict_seisbench.options(num_cpus=0, num_gpus=0).remote(tasks_predictor[i], model_actor, True))
825
+ else:
826
+ # EQCCT models use parallel_predict (original)
827
+ if use_gpu is False:
828
+ tasks_queue.append(parallel_predict.options(num_cpus=0).remote(tasks_predictor[i], model_actor, False))
829
+ elif use_gpu is True:
830
+ # Don't allocate GPUs to workers, only to model actors
831
+ # Use num_cpus=0 to avoid deadlocks when Ray has limited CPUs
832
+ tasks_queue.append(parallel_predict.options(num_cpus=0, num_gpus=0).remote(tasks_predictor[i], model_actor, True))
833
+ break
834
+ # If there are more tasks than maximum, just process them
835
+ else:
836
+ tasks_finished, tasks_queue = ray.wait(tasks_queue, num_returns=1, timeout=None)
837
+ for finished_task in tasks_finished:
838
+ log_entry = ray.get(finished_task)
839
+ logger.info(f'{log_entry}')
840
+
841
+ # After adding all the tasks to queue, process what's left
842
+ while tasks_queue:
843
+ tasks_finished, tasks_queue = ray.wait(tasks_queue, num_returns=1, timeout=None)
844
+ for finished_task in tasks_finished:
845
+ log_entry = ray.get(finished_task)
846
+ logger.info(f'{log_entry}')
847
+ logger.info("")
848
+
849
+ except Exception as e:
850
+ # Catch any error in the parallel processing
851
+ logger.error(f"ERROR in parallel processing at {datetime.now()}")
852
+ logger.error(f"Error: {str(e)}")
853
+ logger.error(traceback.format_exc())
854
+ raise # Re-raise to see the error
855
+
856
+ logger.info(f"------- Parallel Station Waveform Processing Complete For {starttime} to {endtime} Timechunk-------")
857
+ end_time = time.time()
858
+ logger.info(f"Picks saved at {output_dir}Process Runtime: {end_time - start_time:.2f} s")
859
+
860
+ if testing_gpu is not None:
861
+ # Guard: derive a CPU count whether ray_cpus is a list/tuple of core IDs or another iterable
862
+ num_ray_cpus = len(ray_cpus) if isinstance(ray_cpus, (list, tuple)) else len(list(ray_cpus))
863
+
864
+ # Parse the timechunk_id to get start/end times
865
+ if timechunk_id:
866
+ starttime, endtime, time_delta = parse_time_range(timechunk_id)
867
+ timechunk_length_min = time_delta.total_seconds() / 60.0 if time_delta else None
868
+ else:
869
+ timechunk_length_min = None
870
+
871
+ # Determine model name for logging
872
+ if model_type_lower == 'seisbench':
873
+ model_used = f"{seisbench_parent_model}/{seisbench_child_model}"
874
+ else:
875
+ model_used = "eqcct"
876
+
877
+ # To-Do: Add column for CPU IDs
878
+ trial_data = {
879
+ "Trial Number": None, # Will be auto-filled by append_trial_row
880
+ "Stations Used": str(station_list),
881
+ "Number of Stations Used": len(station_list),
882
+ "Number of CPUs Allocated for Ray to Use": num_ray_cpus,
883
+ "Intra-parallelism Threads": intra_threads if intra_threads is not None else "",
884
+ "Inter-parallelism Threads": inter_threads if inter_threads is not None else "",
885
+ "GPUs Used": json.dumps(list(gpu_id)) if (use_gpu and gpu_id is not None) else "[]",
886
+ "Inference Actor Memory Limit (MB)": float(model_vram_mb) if (use_gpu and gpu_memory_limit_mb is not None) else "",
887
+ "Total Waveform Analysis Timespace (min)": float(total_analysis_time.total_seconds() / 60.0) if hasattr(total_analysis_time, "total_seconds") else (float(total_analysis_time) if total_analysis_time else ""),
888
+ "Total Number of Timechunks": int(total_timechunks) if total_timechunks is not None else "",
889
+ "Concurrent Timechunks Used": int(number_of_concurrent_timechunk_predictions) if number_of_concurrent_timechunk_predictions is not None else "",
890
+ "Length of Timechunk (min)": timechunk_length_min if timechunk_length_min is not None else "",
891
+ "Number of Concurrent Station Tasks": int(number_of_concurrent_station_predictions) if number_of_concurrent_station_predictions is not None else "",
892
+ "Total Run time for Picker (s)": round(end_time - start_time, 6),
893
+ "Model Used": model_used,
894
+ "Trial Success": "",
895
+ "Error Message": str(""),
896
+ }
897
+
898
+ append_trial_row(csv_path=test_csv_filepath, trial_data=trial_data)
899
+ logger.info(f"Successfully saved trial data to CSV at {test_csv_filepath}")
900
+
901
+ return "Successfully ran EQCCTPro, exiting..."
902
+
903
+
904
+ @ray.remote
905
+ class ModelActor:
906
+ def __init__(self, p_model_path, s_model_path, gpus_to_use=False, intra_threads=1, inter_threads=1, gpu_memory_limit_mb=None, use_gpu=True):
907
+ self.logger = logging.getLogger("eqcctpro.model_actor")
908
+ self.logger.setLevel(logging.INFO)
909
+ self.logger.handlers[:] = []
910
+ self.logger.propagate = False
911
+ self.logger.addHandler(logging.StreamHandler())
912
+
913
+ self.logger.info("=== ModelActor __init__ STARTED ===")
914
+ self.logger.info(f"p_model_path = {p_model_path}")
915
+ self.logger.info(f"s_model_path = {s_model_path}")
916
+ self.logger.info(f"Exists? P: {os.path.exists(p_model_path)}, S: {os.path.exists(s_model_path)}")
917
+
918
+ if use_gpu:
919
+ # Configure GPU memory for this actor
920
+ # We want one GPU per actor
921
+ try:
922
+ self.logger.info("Calling tf_environ...")
923
+ tf_environ(
924
+ gpu_id=gpus_to_use[0] if gpus_to_use else 0,
925
+ gpus_to_use=None, # First visible GPU only
926
+ vram_limit_mb=gpu_memory_limit_mb,
927
+ intra_threads=intra_threads,
928
+ inter_threads=inter_threads,
929
+ log_device=True,
930
+ logger=self.logger)
931
+ self.logger.info("tf_environ finished.")
932
+ except RuntimeError as e:
933
+ self.logger.error(f"[ModelActor] Error setting memory limit: {e}")
934
+
935
+ # Load the model once
936
+ self.logger.info("Importing/load_eqcct_model...")
937
+ from .eqcct_tf_models import load_eqcct_model
938
+ self.model = load_eqcct_model(p_model_path, s_model_path)
939
+ self.logger.info("Model loaded.")
940
+
941
+ def ready(self):
942
+ """Simple method to check if the actor is ready"""
943
+ return True
944
+
945
+ def predict(self, data_generator):
946
+ """Perform prediction using the loaded model"""
947
+ return self.model.predict(data_generator, verbose=0)
948
+
949
+ def predict_from_arrays(self, trace_start_time, data_set, batch_size, norm_mode):
950
+ from .eqcct_tf_models import PreLoadGeneratorTest
951
+ pred_generator = PreLoadGeneratorTest(trace_start_time, data_set,
952
+ batch_size=batch_size, norm_mode=norm_mode)
953
+ return self.model.predict(pred_generator, verbose=0)
954
+
955
+
956
+ @ray.remote
957
+ class SeisBenchModelActor:
958
+ """
959
+ Ray actor for SeisBench models that loads the model once and shares it across predictions.
960
+ Similar to ModelActor but for SeisBench models (PyTorch-based).
961
+ """
962
+ def __init__(self, parent_model_name, child_model_name, gpus_to_use=False, use_gpu=True):
963
+ self.logger = logging.getLogger("eqcctpro.seisbench_model_actor")
964
+ self.logger.setLevel(logging.INFO)
965
+ self.logger.handlers[:] = []
966
+ self.logger.propagate = False
967
+ self.logger.addHandler(logging.StreamHandler())
968
+
969
+ self.logger.info("=== SeisBenchModelActor __init__ STARTED ===")
970
+ self.logger.info(f"parent_model_name = {parent_model_name}")
971
+ self.logger.info(f"child_model_name = {child_model_name}")
972
+ self.use_gpu = use_gpu
973
+ self.gpus_to_use = gpus_to_use
974
+
975
+ # Set device for PyTorch (SeisBench uses PyTorch)
976
+ try:
977
+ import torch
978
+ except ImportError:
979
+ self.logger.error("PyTorch (torch) is not installed. SeisBench models require PyTorch.")
980
+ raise ImportError("PyTorch (torch) is not installed. Please install it to use SeisBench models.")
981
+
982
+ if use_gpu:
983
+ # When using Ray with num_gpus=1, the assigned GPU is always visible as cuda:0
984
+ # regardless of its physical ID (0, 1, etc.) because Ray sets CUDA_VISIBLE_DEVICES.
985
+ self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
986
+ self.logger.info(f"Using device: {self.device} (mapped by Ray from physical {gpus_to_use})")
987
+ else:
988
+ self.device = torch.device('cpu')
989
+ self.logger.info("Using CPU device")
990
+
991
+ # Load the SeisBench model
992
+ self.logger.info("Loading SeisBench model...")
993
+ from .seisbench_models import SeisBenchModels
994
+ self.model_wrapper = SeisBenchModels(parent_model_name, child_model_name)
995
+ self.model_wrapper.load_model()
996
+
997
+ # Move model to device if using GPU
998
+ if use_gpu:
999
+ try:
1000
+ if hasattr(self.model_wrapper.model, 'to'):
1001
+ self.model_wrapper.model.to(self.device)
1002
+ self.logger.info(f"Model moved to {self.device}")
1003
+ except Exception as e:
1004
+ self.logger.warning(f"Could not move model to GPU: {e}")
1005
+
1006
+ self.logger.info("SeisBench model loaded successfully.")
1007
+
1008
+ def ready(self):
1009
+ """Simple method to check if the actor is ready"""
1010
+ return True
1011
+
1012
+ def classify(self, stream, P_threshold=0.3, S_threshold=0.3, Detection_threshold=0.3, **kwargs):
1013
+ """
1014
+ Classify a stream and return picks.
1015
+
1016
+ Parameters:
1017
+ -----------
1018
+ stream : obspy.Stream
1019
+ 3-component ObsPy Stream
1020
+ P_threshold : float
1021
+ P phase detection threshold
1022
+ S_threshold : float
1023
+ S phase detection threshold
1024
+ Detection_threshold : float
1025
+ Detection threshold
1026
+ **kwargs : dict
1027
+ Additional arguments for model.classify()
1028
+
1029
+ Returns:
1030
+ --------
1031
+ ClassifyOutput
1032
+ Object containing picks
1033
+ """
1034
+ return self.model_wrapper.classify(
1035
+ stream,
1036
+ P_threshold=P_threshold,
1037
+ S_threshold=S_threshold,
1038
+ Detection_threshold=Detection_threshold,
1039
+ **kwargs
1040
+ )
1041
+
1042
+
1043
+ @ray.remote
1044
+ def parallel_predict_seisbench(predict_args, model_actor, gpu=False):
1045
+ """
1046
+ Prediction function for SeisBench models.
1047
+ Uses mseed2stream_3c for preprocessing and SeisBenchModelActor for predictions.
1048
+ """
1049
+ import glob
1050
+ import shutil
1051
+ import csv
1052
+ import logging
1053
+ from logging.handlers import QueueHandler
1054
+ from pathlib import Path
1055
+ from .seisbench_models import mseed2stream_3c
1056
+
1057
+ pos, station, out_dir, args = predict_args
1058
+
1059
+ # Set up logger to forward to the main listener
1060
+ logger = logging.getLogger(f"eqcctpro.worker.{station}")
1061
+ logger.setLevel(logging.INFO)
1062
+ if args.get('log_queue') is not None:
1063
+ logger.addHandler(QueueHandler(args['log_queue']))
1064
+
1065
+ save_dir = os.path.join(out_dir, str(station)+'_outputs')
1066
+ csv_filename = os.path.join(save_dir,'X_prediction_results.csv')
1067
+
1068
+ if os.path.isfile(csv_filename):
1069
+ if args['overwrite']:
1070
+ shutil.rmtree(save_dir)
1071
+ else:
1072
+ return f"{pos} {station}: Skipped (already exists - overwrite=False)."
1073
+
1074
+ os.makedirs(save_dir, exist_ok=True)
1075
+ csvPr_gen = open(csv_filename, 'w')
1076
+ predict_writer = csv.writer(csvPr_gen, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
1077
+ predict_writer.writerow(['file_name',
1078
+ 'network',
1079
+ 'station',
1080
+ 'instrument_type',
1081
+ 'station_lat',
1082
+ 'station_lon',
1083
+ 'station_elv',
1084
+ 'p_arrival_time',
1085
+ 'p_probability',
1086
+ 's_arrival_time',
1087
+ 's_probability'])
1088
+ csvPr_gen.flush()
1089
+
1090
+ start_Predicting = time.time()
1091
+ files_list = glob.glob(f"{args['input_dir']}/{station}/*mseed")
1092
+
1093
+ if not files_list:
1094
+ csvPr_gen.close()
1095
+ return f"{pos} {station}: FAILED - No mSEED files found."
1096
+
1097
+ try:
1098
+ # Use SeisBench preprocessing
1099
+ stream3c, freqmin, freqmax = mseed2stream_3c(args, files_list, station)
1100
+ except Exception as e:
1101
+ csvPr_gen.close()
1102
+ return f"{pos} {station}: FAILED reading mSEED: {str(e)}"
1103
+
1104
+ try:
1105
+ # Get picks from SeisBench model
1106
+ # Use ray.get with a timeout or just normally if we fixed the CPU deadlock
1107
+ classify_output = ray.get(model_actor.classify.remote(
1108
+ stream3c,
1109
+ P_threshold=args.get('P_threshold', 0.3),
1110
+ S_threshold=args.get('S_threshold', 0.3),
1111
+ Detection_threshold=args.get('Detection_threshold', 0.3),
1112
+ strict=False,
1113
+ flexible_horizontal_components=True
1114
+ ))
1115
+
1116
+ # Extract metadata from stream
1117
+ station_code = stream3c[0].stats.station if len(stream3c) > 0 else station
1118
+ network_code = stream3c[0].stats.network if len(stream3c) > 0 else ""
1119
+ # Try to get coordinates from stream metadata if available
1120
+ station_lat = getattr(stream3c[0].stats, 'coordinates', {}).get('latitude', 0.0) if len(stream3c) > 0 else 0.0
1121
+ station_lon = getattr(stream3c[0].stats, 'coordinates', {}).get('longitude', 0.0) if len(stream3c) > 0 else 0.0
1122
+ station_elv = getattr(stream3c[0].stats, 'coordinates', {}).get('elevation', 0.0) if len(stream3c) > 0 else 0.0
1123
+
1124
+ # Extract picks from ClassifyOutput
1125
+ picks = classify_output.picks if hasattr(classify_output, 'picks') else []
1126
+
1127
+ # Group picks for the CSV output.
+ # SeisBench emits individual picks; to match the EQCCT output style we pair
+ # each P with a nearby S where possible and write unpaired picks on their own rows.
1131
+
1132
+ p_picks = [p for p in picks if getattr(p, 'phase', 'P').upper() == 'P']
1133
+ s_picks = [p for p in picks if getattr(p, 'phase', 'P').upper() == 'S']
1134
+
1135
+ # Simple pairing: for each P, find the first S that comes after it within 30s
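+ # e.g. a P at 10:00:05 pairs with the first unused S in (10:00:05, 10:00:35);
+ # anything left unpaired is written on its own row further below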
1136
+ used_s = set()
1137
+ for p in p_picks:
1138
+ # Robust attribute extraction for SeisBench Pick objects
1139
+ p_time = getattr(p, 'peak_time', getattr(p, 'start_time', getattr(p, 'time', None)))
1140
+ p_prob = getattr(p, 'peak_value', getattr(p, 'score', getattr(p, 'value', 0.0)))
1141
+
1142
+ if p_time is None:
1143
+ continue
1144
+
1145
+ match_s = None
1146
+ for s in s_picks:
1147
+ s_time = getattr(s, 'peak_time', getattr(s, 'start_time', getattr(s, 'time', None)))
1148
+ if s not in used_s and s_time and 0 < (s_time - p_time) < 30:
1149
+ match_s = s
1150
+ used_s.add(s)
1151
+ break
1152
+
1153
+ if match_s:
1154
+ ms_time = getattr(match_s, 'peak_time', getattr(match_s, 'start_time', getattr(match_s, 'time', None)))
1155
+ ms_prob = getattr(match_s, 'peak_value', getattr(match_s, 'score', getattr(match_s, 'value', 0.0)))
1156
+ s_time_str = ms_time.strftime('%Y-%m-%d %H:%M:%S.%f') if ms_time else ''
1157
+ s_prob_str = f"{ms_prob:.6f}"
1158
+ else:
1159
+ s_time_str = ''
1160
+ s_prob_str = ''
1161
+
1162
+ predict_writer.writerow([
1163
+ station_code,
1164
+ network_code,
1165
+ station_code,
1166
+ 0, # instrument_type
1167
+ station_lat,
1168
+ station_lon,
1169
+ station_elv,
1170
+ p_time.strftime('%Y-%m-%d %H:%M:%S.%f'),
1171
+ f"{p_prob:.6f}",
1172
+ s_time_str,
1173
+ s_prob_str
1174
+ ])
1175
+
1176
+ # Write remaining S picks
1177
+ for s in s_picks:
1178
+ if s not in used_s:
1179
+ s_time = getattr(s, 'peak_time', getattr(s, 'start_time', getattr(s, 'time', None)))
1180
+ s_prob = getattr(s, 'peak_value', getattr(s, 'score', getattr(s, 'value', 0.0)))
1181
+ if s_time:
1182
+ predict_writer.writerow([
1183
+ station_code,
1184
+ network_code,
1185
+ station_code,
1186
+ 0, # instrument_type
1187
+ station_lat,
1188
+ station_lon,
1189
+ station_elv,
1190
+ '',
1191
+ '',
1192
+ s_time.strftime('%Y-%m-%d %H:%M:%S.%f'),
1193
+ f"{s_prob:.6f}"
1194
+ ])
1195
+
1196
+ # If no picks found at all, write one row with station info
1197
+ if not picks:
1198
+ predict_writer.writerow([
1199
+ station_code,
1200
+ network_code,
1201
+ station_code,
1202
+ 0, # instrument_type
1203
+ station_lat,
1204
+ station_lon,
1205
+ station_elv,
1206
+ '', '', '', ''
1207
+ ])
1208
+
1209
+ csvPr_gen.flush()
1210
+ csvPr_gen.close()
1211
+
1212
+ end_Predicting = time.time()
1213
+ delta = (end_Predicting - start_Predicting)
1214
+ return f"{pos} {station}: Finished the prediction in {round(delta,2)}s. (HP={freqmin}, LP={freqmax}, picks={len(picks)})"
1215
+
1216
+ except Exception as exp:
1217
+ if 'csvPr_gen' in locals():
1218
+ csvPr_gen.close()
1219
+ return f"{pos} {station}: FAILED the prediction. {exp}"
1220
+
1221
+
1222
+ @ray.remote
1223
+ def parallel_predict(predict_args, model_actor, gpu=False):
1224
+ """
1225
+ Modified to use shared ModelActor instead of loading model per task
1226
+ """
1227
+ # --- QUIET TF C++/Python LOGS BEFORE ANY TF IMPORT ---
1228
+ # We were getting info messages from TF because eqcct_tf_models imports it directly
1230
+ # We need to suppress TF before it is fully imported
1230
+ os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3") # 3=ERROR
1231
+ os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0") # hide oneDNN banner
1232
+ if not gpu:
1233
+ os.environ.setdefault("CUDA_VISIBLE_DEVICES", "-1") # don't probe CUDA on CPU tasks
1234
+
1235
+ # Python-side TF/absl logging
1236
+ try:
1237
+ import tensorflow as tf
1238
+ tf.get_logger().setLevel(logging.ERROR)
1239
+ try:
1240
+ from absl import logging as absl_logging
1241
+ absl_logging.set_verbosity(absl_logging.ERROR)
1242
+ except Exception:
1243
+ pass
1244
+ except Exception:
1245
+ # If eqcct_tf_models imports TF later, env vars above will still suppress C++ logs.
1246
+ pass
1247
+
1248
+ import glob
+ import shutil
+ from .eqcct_tf_models import Patches, PatchEncoder, StochasticDepth, PreLoadGeneratorTest, load_eqcct_model
1249
+ pos, station, out_dir, args = predict_args
1250
+
1251
+ # NOTE: We removed the model loading code that was causing OOM errors
1252
+ # The model is now shared via the model_actor
1253
+
1254
+ save_dir = os.path.join(out_dir, str(station)+'_outputs')
1255
+ csv_filename = os.path.join(save_dir,'X_prediction_results.csv')
1256
+
1257
+ if os.path.isfile(csv_filename):
1258
+ if args['overwrite']:
1259
+ shutil.rmtree(save_dir)
1260
+ else:
1261
+ return f"{pos} {station}: Skipped (already exists - overwrite=False)."
1262
+
1263
+ os.makedirs(save_dir, exist_ok=True)
1264
+ csvPr_gen = open(csv_filename, 'w')
1265
+ predict_writer = csv.writer(csvPr_gen, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
1266
+ predict_writer.writerow(['file_name',
1267
+ 'network',
1268
+ 'station',
1269
+ 'instrument_type',
1270
+ 'station_lat',
1271
+ 'station_lon',
1272
+ 'station_elv',
1273
+ 'p_arrival_time',
1274
+ 'p_probability',
1275
+ 's_arrival_time',
1276
+ 's_probability'])
1277
+ csvPr_gen.flush()
1278
+
1279
+ start_Predicting = time.time()
1280
+ files_list = glob.glob(f"{args['input_dir']}/{station}/*mseed")
1281
+
1282
+ try:
1283
+ meta, data_set, hp, lp = _mseed2nparray(args, files_list, station)
1284
+ except Exception:
1285
+ return f"{pos} {station}: FAILED reading mSEED."
1286
+
1287
+ try:
1288
+ # USE THE SHARED MODEL ACTOR INSTEAD OF LOADING THE MODEL LOCALLY
+ # (the actor builds its own PreLoadGeneratorTest from the raw arrays)
1293
+ predP, predS = ray.get(model_actor.predict_from_arrays.remote(
1294
+ meta["trace_start_time"], data_set, args["batch_size"], args["normalization_mode"]))
1295
+
1296
+ detection_memory = []
1297
+ prob_memory = []
1298
+ for ix in range(len(predP)):
1299
+ Ppicks, Pprob = _picker(args, predP[ix,:, 0])
1300
+ Spicks, Sprob = _picker(args, predS[ix,:, 0], 'S_threshold')
1301
+
1302
+ detection_memory, prob_memory = _output_writter_prediction(
1303
+ meta, csvPr_gen, Ppicks, Pprob, Spicks, Sprob,
1304
+ detection_memory, prob_memory, predict_writer, ix, len(predP), len(predS)
1305
+ )
1306
+
1307
+ end_Predicting = time.time()
1308
+ delta = (end_Predicting - start_Predicting)
1309
+ return f"{pos} {station}: Finished the prediction in {round(delta,2)}s. (HP={hp}, LP={lp})"
1310
+
1311
+ except Exception as exp:
1312
+ return f"{pos} {station}: FAILED the prediction. {exp}"