phasenet 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- phasenet-0.2.0/.gitignore +12 -0
- phasenet-0.2.0/PKG-INFO +162 -0
- phasenet-0.2.0/README.md +133 -0
- phasenet-0.2.0/examples/ceed_training_plots.py +181 -0
- phasenet-0.2.0/examples/italy/prepare_inventory.py +233 -0
- phasenet-0.2.0/examples/italy/prepare_mseed_list.py +118 -0
- phasenet-0.2.0/examples/italy/run_all_years.sh +64 -0
- phasenet-0.2.0/examples/italy/upload_loop.sh +12 -0
- phasenet-0.2.0/examples/italy/upload_phasenet_origin.sh +47 -0
- phasenet-0.2.0/phasenet/__init__.py +3 -0
- phasenet-0.2.0/phasenet/data/__init__.py +7 -0
- phasenet-0.2.0/phasenet/data/ceed.py +2294 -0
- phasenet-0.2.0/phasenet/data/das.py +2555 -0
- phasenet-0.2.0/phasenet/data/transforms.py +776 -0
- phasenet-0.2.0/phasenet/models/__init__.py +8 -0
- phasenet-0.2.0/phasenet/models/phasenet.py +480 -0
- phasenet-0.2.0/phasenet/models/phasenet_das.py +75 -0
- phasenet-0.2.0/phasenet/models/phasenet_das_plus.py +46 -0
- phasenet-0.2.0/phasenet/models/phasenet_plus.py +27 -0
- phasenet-0.2.0/phasenet/models/phasenet_tf.py +21 -0
- phasenet-0.2.0/phasenet/models/phasenet_tf_plus.py +27 -0
- phasenet-0.2.0/phasenet/models/prompt/__init__.py +11 -0
- phasenet-0.2.0/phasenet/models/prompt/common.py +43 -0
- phasenet-0.2.0/phasenet/models/prompt/mask_decoder.py +185 -0
- phasenet-0.2.0/phasenet/models/prompt/prompt_encoder.py +242 -0
- phasenet-0.2.0/phasenet/models/prompt/transformer.py +232 -0
- phasenet-0.2.0/phasenet/models/unet.py +1316 -0
- phasenet-0.2.0/phasenet/models/unet2018.py +253 -0
- phasenet-0.2.0/phasenet/utils/__init__.py +3 -0
- phasenet-0.2.0/phasenet/utils/detect_peaks_cpu.py +207 -0
- phasenet-0.2.0/phasenet/utils/inference.py +405 -0
- phasenet-0.2.0/phasenet/utils/postprocess.py +711 -0
- phasenet-0.2.0/phasenet/utils/visualization.py +1465 -0
- phasenet-0.2.0/phasenet.egg-info/PKG-INFO +162 -0
- phasenet-0.2.0/phasenet.egg-info/SOURCES.txt +54 -0
- phasenet-0.2.0/phasenet.egg-info/dependency_links.txt +1 -0
- phasenet-0.2.0/phasenet.egg-info/entry_points.txt +3 -0
- phasenet-0.2.0/phasenet.egg-info/requires.txt +15 -0
- phasenet-0.2.0/phasenet.egg-info/top_level.txt +1 -0
- phasenet-0.2.0/predict.py +630 -0
- phasenet-0.2.0/pyproject.toml +47 -0
- phasenet-0.2.0/requirements.txt +15 -0
- phasenet-0.2.0/scripts/predict_ceed.py +489 -0
- phasenet-0.2.0/scripts/predict_ceed.sh +41 -0
- phasenet-0.2.0/scripts/predict_das.py +607 -0
- phasenet-0.2.0/scripts/predict_das.sh +55 -0
- phasenet-0.2.0/scripts/predict_mseed.py +487 -0
- phasenet-0.2.0/scripts/semisupervised_das.sh +221 -0
- phasenet-0.2.0/scripts/train_ceed.sh +21 -0
- phasenet-0.2.0/scripts/train_das.sh +57 -0
- phasenet-0.2.0/setup.cfg +4 -0
- phasenet-0.2.0/tests/benchmark_dataloader.py +373 -0
- phasenet-0.2.0/tests/test_augmentations.py +234 -0
- phasenet-0.2.0/tests/test_labels.py +288 -0
- phasenet-0.2.0/train.py +807 -0
- phasenet-0.2.0/utils.py +488 -0
phasenet-0.2.0/PKG-INFO
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: phasenet
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: A PyTorch implementation of PhaseNet for seismic and DAS phase picking
|
|
5
|
+
Author: Weiqiang Zhu
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/AI4EPS/phasenet-pytorch
|
|
8
|
+
Project-URL: Repository, https://github.com/AI4EPS/phasenet-pytorch
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Classifier: Topic :: Scientific/Engineering
|
|
12
|
+
Requires-Python: >=3.8
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
Requires-Dist: torch
|
|
15
|
+
Requires-Dist: torchvision
|
|
16
|
+
Requires-Dist: einops
|
|
17
|
+
Requires-Dist: numpy
|
|
18
|
+
Requires-Dist: scipy
|
|
19
|
+
Requires-Dist: h5py
|
|
20
|
+
Requires-Dist: matplotlib
|
|
21
|
+
Requires-Dist: pandas
|
|
22
|
+
Requires-Dist: tqdm
|
|
23
|
+
Requires-Dist: fsspec
|
|
24
|
+
Requires-Dist: obspy
|
|
25
|
+
Requires-Dist: gcsfs
|
|
26
|
+
Requires-Dist: datasets
|
|
27
|
+
Requires-Dist: pyarrow
|
|
28
|
+
Requires-Dist: wandb
|
|
29
|
+
|
|
30
|
+
# PhaseNet-PyTorch
|
|
31
|
+
|
|
32
|
+
PyTorch implementation of PhaseNet for seismic and DAS phase picking, event detection, and polarity classification.
|
|
33
|
+
|
|
34
|
+
## Models
|
|
35
|
+
|
|
36
|
+
| Model | Features | Data Type |
|
|
37
|
+
|-------|----------|-----------|
|
|
38
|
+
| `phasenet` | Phase (P/S) picking | Seismic 3-component |
|
|
39
|
+
| `phasenet_tf` | Phase picking + STFT spectrogram | Seismic 3-component |
|
|
40
|
+
| `phasenet_plus` | Phase + polarity + event detection | Seismic 3-component |
|
|
41
|
+
| `phasenet_tf_plus` | Phase + polarity + event detection + STFT | Seismic 3-component |
|
|
42
|
+
| `phasenet_das` | Phase picking | DAS single-channel |
|
|
43
|
+
| `phasenet_das_plus` | Phase + event detection | DAS single-channel |
|
|
44
|
+
|
|
45
|
+
The `_tf` variants add a Short-Time Fourier Transform (STFT) branch that extracts frequency features alongside the temporal waveform, improving performance on noisy data.
|
|
46
|
+
|
|
47
|
+
## Prediction
|
|
48
|
+
|
|
49
|
+
### CEED (Seismic) Prediction
|
|
50
|
+
|
|
51
|
+
```bash
|
|
52
|
+
# Demo: process a few events with plots
|
|
53
|
+
python scripts/predict_ceed.py --n-events 5
|
|
54
|
+
|
|
55
|
+
# Process all days for a year (saves parquets to results/ceed/)
|
|
56
|
+
python scripts/predict_ceed.py --all --year 2025 --output-dir results/ceed
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
Output: one parquet file per day in `results/ceed/{region}/`, with columns:
|
|
60
|
+
`event_id, station_id, waveform_index, origin_id, origin_index, origin_time, phase_index, phase_time, phase_score, phase_type, phase_polarity`
|
|
61
|
+
|
|
62
|
+
### DAS Prediction
|
|
63
|
+
|
|
64
|
+
```bash
|
|
65
|
+
# Predict with base PhaseNet on a HuggingFace subset
|
|
66
|
+
python scripts/predict_das.py --subset arcata --plot
|
|
67
|
+
|
|
68
|
+
# Predict with a trained DAS model on local data
|
|
69
|
+
python scripts/predict_das.py \
|
|
70
|
+
--data-dir data/quakeflow_das/arcata/data \
|
|
71
|
+
--model-type phasenet_das_plus \
|
|
72
|
+
--checkpoint output/train_das_arcata/checkpoint.pth \
|
|
73
|
+
--output-dir results/das/train_das_arcata/arcata \
|
|
74
|
+
--no-ema --plot
|
|
75
|
+
|
|
76
|
+
# Predict from a file list
|
|
77
|
+
python scripts/predict_das.py \
|
|
78
|
+
--file-list file_list.txt \
|
|
79
|
+
--model-type phasenet_das_plus \
|
|
80
|
+
--checkpoint output/model.pth \
|
|
81
|
+
--plot
|
|
82
|
+
|
|
83
|
+
# Multi-GPU prediction
|
|
84
|
+
bash scripts/predict_das.sh phasenet_das_plus arcata 8 output/model.pth
|
|
85
|
+
```
|
|
86
|
+
|
|
87
|
+
Output: one parquet per event in `results/das/{model_name}/{subset}/`, with columns:
|
|
88
|
+
`event_id, channel_index, origin_id, origin_index, origin_time, phase_index, phase_time, phase_score, phase_type, dt_s, ps_center, ps_interval`
|
|
89
|
+
|
|
90
|
+
By default, picks are associated using P-S pairing. Use `--use-event-head` to associate via the model's event detection head instead.
|
|
91
|
+
|
|
92
|
+
## Training
|
|
93
|
+
|
|
94
|
+
### CEED (Seismic) Training
|
|
95
|
+
|
|
96
|
+
```bash
|
|
97
|
+
python train.py \
|
|
98
|
+
--model phasenet_plus \
|
|
99
|
+
--dataset-type ceed \
|
|
100
|
+
--label-path results/ceed \
|
|
101
|
+
--nx 16 \
|
|
102
|
+
--max-iters 100000 \
|
|
103
|
+
--batch-size 8 \
|
|
104
|
+
--workers 4 \
|
|
105
|
+
--lr 3e-4 \
|
|
106
|
+
--eval-interval 5000 \
|
|
107
|
+
--output-dir output/train_ceed
|
|
108
|
+
```
|
|
109
|
+
|
|
110
|
+
### DAS Training
|
|
111
|
+
|
|
112
|
+
```bash
|
|
113
|
+
# Using the training script
|
|
114
|
+
bash scripts/train_das.sh 0 arcata v26
|
|
115
|
+
|
|
116
|
+
# Or directly
|
|
117
|
+
python train.py \
|
|
118
|
+
--model phasenet_das_plus \
|
|
119
|
+
--dataset-type das \
|
|
120
|
+
--data-path data/quakeflow_das/arcata/data \
|
|
121
|
+
--label-path results/das/phasenet/arcata/picks \
|
|
122
|
+
--label-list results/das/phasenet/arcata/labels.txt \
|
|
123
|
+
--nx 2048 --nt 4096 \
|
|
124
|
+
--num-patch 16 \
|
|
125
|
+
--max-iters 50000 \
|
|
126
|
+
--batch-size 2 --workers 8 \
|
|
127
|
+
--lr 1e-4 --weight-decay 0.01 \
|
|
128
|
+
--model-ema --model-ema-decay 0.999 \
|
|
129
|
+
--eval-interval 1000 --save-interval 1000 \
|
|
130
|
+
--output-dir output/train_das_arcata_v26
|
|
131
|
+
```
|
|
132
|
+
|
|
133
|
+
### Key Training Options
|
|
134
|
+
|
|
135
|
+
| Option | Description | Default |
|
|
136
|
+
|--------|-------------|---------|
|
|
137
|
+
| `--num-patch N` | Random crops per DAS sample (amortizes IO) | 2 |
|
|
138
|
+
| `--model-ema` | Enable exponential moving average | off |
|
|
139
|
+
| `--gradient-accumulation-steps N` | Accumulate gradients for larger effective batch | 1 |
|
|
140
|
+
| `--clip-grad-norm V` | Gradient clipping | 1.0 |
|
|
141
|
+
| `--compile` | Enable torch.compile | off |
|
|
142
|
+
| `--resume --checkpoint PATH` | Resume from checkpoint | - |
|
|
143
|
+
| `--reset-lr` | Reset LR schedule when resuming | off |
|
|
144
|
+
|
|
145
|
+
## Semi-supervised Training
|
|
146
|
+
|
|
147
|
+
Iterative self-training pipeline for DAS: predict → train on predictions → predict with new model → repeat.
|
|
148
|
+
|
|
149
|
+
```bash
|
|
150
|
+
# Start from PhaseNet (train DAS model from scratch)
|
|
151
|
+
bash scripts/semisupervised_das.sh arcata 5 0
|
|
152
|
+
|
|
153
|
+
# Start from a pretrained DAS model
|
|
154
|
+
bash scripts/semisupervised_das.sh arcata 5 0 phasenet_das output/train_das_v26/checkpoint.pth
|
|
155
|
+
```
|
|
156
|
+
|
|
157
|
+
Arguments: `[subset] [num_iterations] [gpu] [start_from] [checkpoint]`
|
|
158
|
+
|
|
159
|
+
- **From phasenet**: iteration 0 predicts with PhaseNet, iteration 1 trains from scratch (10k steps, warmup), iterations 2+ continue (1k steps, no warmup)
|
|
160
|
+
- **From phasenet_das**: iteration 0 predicts with pretrained DAS model, iterations 1+ continue (1k steps, no warmup)
|
|
161
|
+
|
|
162
|
+
Results are saved to `results/semisupervised_das/` and checkpoints to `output/semisupervised_das_v{N}/`.
|
phasenet-0.2.0/README.md
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
# PhaseNet-PyTorch
|
|
2
|
+
|
|
3
|
+
PyTorch implementation of PhaseNet for seismic and DAS phase picking, event detection, and polarity classification.
|
|
4
|
+
|
|
5
|
+
## Models
|
|
6
|
+
|
|
7
|
+
| Model | Features | Data Type |
|
|
8
|
+
|-------|----------|-----------|
|
|
9
|
+
| `phasenet` | Phase (P/S) picking | Seismic 3-component |
|
|
10
|
+
| `phasenet_tf` | Phase picking + STFT spectrogram | Seismic 3-component |
|
|
11
|
+
| `phasenet_plus` | Phase + polarity + event detection | Seismic 3-component |
|
|
12
|
+
| `phasenet_tf_plus` | Phase + polarity + event detection + STFT | Seismic 3-component |
|
|
13
|
+
| `phasenet_das` | Phase picking | DAS single-channel |
|
|
14
|
+
| `phasenet_das_plus` | Phase + event detection | DAS single-channel |
|
|
15
|
+
|
|
16
|
+
The `_tf` variants add a Short-Time Fourier Transform (STFT) branch that extracts frequency features alongside the temporal waveform, improving performance on noisy data.
|
|
17
|
+
|
|
18
|
+
## Prediction
|
|
19
|
+
|
|
20
|
+
### CEED (Seismic) Prediction
|
|
21
|
+
|
|
22
|
+
```bash
|
|
23
|
+
# Demo: process a few events with plots
|
|
24
|
+
python scripts/predict_ceed.py --n-events 5
|
|
25
|
+
|
|
26
|
+
# Process all days for a year (saves parquets to results/ceed/)
|
|
27
|
+
python scripts/predict_ceed.py --all --year 2025 --output-dir results/ceed
|
|
28
|
+
```
|
|
29
|
+
|
|
30
|
+
Output: one parquet file per day in `results/ceed/{region}/`, with columns:
|
|
31
|
+
`event_id, station_id, waveform_index, origin_id, origin_index, origin_time, phase_index, phase_time, phase_score, phase_type, phase_polarity`
|
|
32
|
+
|
|
33
|
+
### DAS Prediction
|
|
34
|
+
|
|
35
|
+
```bash
|
|
36
|
+
# Predict with base PhaseNet on a HuggingFace subset
|
|
37
|
+
python scripts/predict_das.py --subset arcata --plot
|
|
38
|
+
|
|
39
|
+
# Predict with a trained DAS model on local data
|
|
40
|
+
python scripts/predict_das.py \
|
|
41
|
+
--data-dir data/quakeflow_das/arcata/data \
|
|
42
|
+
--model-type phasenet_das_plus \
|
|
43
|
+
--checkpoint output/train_das_arcata/checkpoint.pth \
|
|
44
|
+
--output-dir results/das/train_das_arcata/arcata \
|
|
45
|
+
--no-ema --plot
|
|
46
|
+
|
|
47
|
+
# Predict from a file list
|
|
48
|
+
python scripts/predict_das.py \
|
|
49
|
+
--file-list file_list.txt \
|
|
50
|
+
--model-type phasenet_das_plus \
|
|
51
|
+
--checkpoint output/model.pth \
|
|
52
|
+
--plot
|
|
53
|
+
|
|
54
|
+
# Multi-GPU prediction
|
|
55
|
+
bash scripts/predict_das.sh phasenet_das_plus arcata 8 output/model.pth
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
Output: one parquet per event in `results/das/{model_name}/{subset}/`, with columns:
|
|
59
|
+
`event_id, channel_index, origin_id, origin_index, origin_time, phase_index, phase_time, phase_score, phase_type, dt_s, ps_center, ps_interval`
|
|
60
|
+
|
|
61
|
+
By default, picks are associated using P-S pairing. Use `--use-event-head` to associate via the model's event detection head instead.
|
|
62
|
+
|
|
63
|
+
## Training
|
|
64
|
+
|
|
65
|
+
### CEED (Seismic) Training
|
|
66
|
+
|
|
67
|
+
```bash
|
|
68
|
+
python train.py \
|
|
69
|
+
--model phasenet_plus \
|
|
70
|
+
--dataset-type ceed \
|
|
71
|
+
--label-path results/ceed \
|
|
72
|
+
--nx 16 \
|
|
73
|
+
--max-iters 100000 \
|
|
74
|
+
--batch-size 8 \
|
|
75
|
+
--workers 4 \
|
|
76
|
+
--lr 3e-4 \
|
|
77
|
+
--eval-interval 5000 \
|
|
78
|
+
--output-dir output/train_ceed
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
### DAS Training
|
|
82
|
+
|
|
83
|
+
```bash
|
|
84
|
+
# Using the training script
|
|
85
|
+
bash scripts/train_das.sh 0 arcata v26
|
|
86
|
+
|
|
87
|
+
# Or directly
|
|
88
|
+
python train.py \
|
|
89
|
+
--model phasenet_das_plus \
|
|
90
|
+
--dataset-type das \
|
|
91
|
+
--data-path data/quakeflow_das/arcata/data \
|
|
92
|
+
--label-path results/das/phasenet/arcata/picks \
|
|
93
|
+
--label-list results/das/phasenet/arcata/labels.txt \
|
|
94
|
+
--nx 2048 --nt 4096 \
|
|
95
|
+
--num-patch 16 \
|
|
96
|
+
--max-iters 50000 \
|
|
97
|
+
--batch-size 2 --workers 8 \
|
|
98
|
+
--lr 1e-4 --weight-decay 0.01 \
|
|
99
|
+
--model-ema --model-ema-decay 0.999 \
|
|
100
|
+
--eval-interval 1000 --save-interval 1000 \
|
|
101
|
+
--output-dir output/train_das_arcata_v26
|
|
102
|
+
```
|
|
103
|
+
|
|
104
|
+
### Key Training Options
|
|
105
|
+
|
|
106
|
+
| Option | Description | Default |
|
|
107
|
+
|--------|-------------|---------|
|
|
108
|
+
| `--num-patch N` | Random crops per DAS sample (amortizes IO) | 2 |
|
|
109
|
+
| `--model-ema` | Enable exponential moving average | off |
|
|
110
|
+
| `--gradient-accumulation-steps N` | Accumulate gradients for larger effective batch | 1 |
|
|
111
|
+
| `--clip-grad-norm V` | Gradient clipping | 1.0 |
|
|
112
|
+
| `--compile` | Enable torch.compile | off |
|
|
113
|
+
| `--resume --checkpoint PATH` | Resume from checkpoint | - |
|
|
114
|
+
| `--reset-lr` | Reset LR schedule when resuming | off |
|
|
115
|
+
|
|
116
|
+
## Semi-supervised Training
|
|
117
|
+
|
|
118
|
+
Iterative self-training pipeline for DAS: predict → train on predictions → predict with new model → repeat.
|
|
119
|
+
|
|
120
|
+
```bash
|
|
121
|
+
# Start from PhaseNet (train DAS model from scratch)
|
|
122
|
+
bash scripts/semisupervised_das.sh arcata 5 0
|
|
123
|
+
|
|
124
|
+
# Start from a pretrained DAS model
|
|
125
|
+
bash scripts/semisupervised_das.sh arcata 5 0 phasenet_das output/train_das_v26/checkpoint.pth
|
|
126
|
+
```
|
|
127
|
+
|
|
128
|
+
Arguments: `[subset] [num_iterations] [gpu] [start_from] [checkpoint]`
|
|
129
|
+
|
|
130
|
+
- **From phasenet**: iteration 0 predicts with PhaseNet, iteration 1 trains from scratch (10k steps, warmup), iterations 2+ continue (1k steps, no warmup)
|
|
131
|
+
- **From phasenet_das**: iteration 0 predicts with pretrained DAS model, iterations 1+ continue (1k steps, no warmup)
|
|
132
|
+
|
|
133
|
+
Results are saved to `results/semisupervised_das/` and checkpoints to `output/semisupervised_das_v{N}/`.
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
"""Generate training label plots from CEED prediction results.
|
|
3
|
+
|
|
4
|
+
Loads a prediction parquet, joins with original waveforms from GCS,
|
|
5
|
+
and plots top-N events with training labels.
|
|
6
|
+
|
|
7
|
+
Usage:
|
|
8
|
+
python examples/ceed_training_plots.py \
|
|
9
|
+
--parquet results/ceed/phasenet2018/NC/2025_001.parquet \
|
|
10
|
+
--region NC --year 2025 --day 001
|
|
11
|
+
"""
|
|
12
|
+
import argparse
|
|
13
|
+
import json
|
|
14
|
+
import os
|
|
15
|
+
import sys
|
|
16
|
+
|
|
17
|
+
sys.path.insert(0, ".")
|
|
18
|
+
|
|
19
|
+
import gcsfs
|
|
20
|
+
import numpy as np
|
|
21
|
+
import pandas as pd
|
|
22
|
+
import pyarrow.parquet as pq
|
|
23
|
+
|
|
24
|
+
from phasenet.data.ceed import (
|
|
25
|
+
records_to_sample, default_train_transforms, plot_demo, Target,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
BUCKETS = {
|
|
29
|
+
"NC": "quakeflow_dataset/NCEDC",
|
|
30
|
+
"SC": "quakeflow_dataset/SCEDC",
|
|
31
|
+
}
|
|
32
|
+
GCS_CRED_PATH = os.path.expanduser("~/.config/gcloud/application_default_credentials.json")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def get_storage_options():
    """Build storage options for GCS access.

    Returns a dict with a ``token`` entry loaded from the local
    application-default-credentials file when it exists, otherwise an
    empty dict (anonymous / default access).
    """
    if not os.path.exists(GCS_CRED_PATH):
        return {}
    with open(GCS_CRED_PATH) as f:
        return {"token": json.load(f)}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def load_day_waveforms(fs, bucket, year, day):
    """Return {waveform_index: ndarray (3, nt)} for a day parquet.

    Streams the day's waveform parquet in batches, reinterprets the Arrow
    list-of-list float32 column as a raw buffer (zero-copy view, then one
    copy per batch), and keeps only rows that carry a non-empty event_id.

    Args:
        fs: fsspec-compatible filesystem (e.g. gcsfs.GCSFileSystem).
        bucket: bucket/prefix string, e.g. "quakeflow_dataset/NCEDC".
        year: year component of the parquet path.
        day: zero-padded day-of-year string, e.g. "001".

    Returns:
        dict mapping global row index -> float32 ndarray of shape (3, nt).
    """
    path = f"{bucket}/waveform_parquet/{year}/{day}.parquet"
    pf = pq.ParquetFile(fs.open(path))
    non_wave_cols = [c for c in pf.schema_arrow.names if c != "waveform"]
    # Derive the byte-offset scale from the dtype instead of hard-coding 4,
    # so the dtype assumption lives in exactly one place.
    itemsize = np.dtype(np.float32).itemsize
    result = {}
    row_offset = 0
    for batch in pf.iter_batches(batch_size=256):
        n = len(batch)
        # Flatten list<list<float32>> down to the underlying value array.
        flat = batch.column("waveform").flatten().flatten()
        # buffers()[1] is the value buffer; assumes no nulls in the
        # flattened values -- TODO confirm against the upstream schema.
        buf = flat.buffers()[1]
        waveforms = np.frombuffer(
            buf, dtype=np.float32, offset=flat.offset * itemsize
        ).reshape(n, 3, -1).copy()
        meta = batch.select(non_wave_cols).to_pydict()
        for i in range(n):
            # Keep only rows tagged with an event id (falsy ids are skipped,
            # same truthiness test as the original `or ""` idiom).
            if meta["event_id"][i]:
                result[row_offset + i] = waveforms[i]
        row_offset += n
    return result
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def generate_training_plots(parquet_path, region, fs, year, day,
                            n_events=3, figure_dir="figures/ceed"):
    """Load label parquet, join with original waveforms, plot training labels.

    Args:
        parquet_path: path to a prediction parquet (per-pick rows).
        region: key into BUCKETS ("NC" or "SC") selecting the source bucket.
        fs: fsspec-compatible filesystem used to read the day parquet.
        year: year of the source day parquet.
        day: zero-padded day-of-year string (e.g. "001").
        n_events: number of top events (by pick count) to plot.
        figure_dir: root output directory; figures go to figure_dir/region.
    """
    print(f"\n=== Training label plots for {region} ===")

    df = pd.read_parquet(parquet_path)
    print(f" {len(df)} pick rows, "
          f"{df.groupby(['event_id','origin_index']).ngroups} detected origins")

    bucket = BUCKETS[region]
    print(" Loading waveforms from original parquet...")
    # waveform_index in the predictions refers to the row index of the day
    # parquet, so this dict joins predictions back to their raw traces.
    wave_by_idx = load_day_waveforms(fs, bucket, year, day)
    print(f" {len(wave_by_idx)} waveforms loaded")

    # Second pass over the same parquet for lightweight metadata only
    # (no waveform column), keyed by the same row index.
    raw_pq = pq.read_table(
        fs.open(f"{bucket}/waveform_parquet/{year}/{day}.parquet"),
        columns=["event_id", "network", "station", "distance_km", "begin_time"],
    ).to_pydict()
    n_raw = len(raw_pq["event_id"])
    meta_by_idx = {i: {k: raw_pq[k][i] for k in raw_pq} for i in range(n_raw)}

    # Events with the most picks are plotted first.
    top_eids = (
        df.groupby("event_id")["phase_index"].count()
        .nlargest(n_events).index
    )

    transforms = default_train_transforms(crop_length=4096, enable_stacking=False)
    fig_dir = os.path.join(figure_dir, region)

    for src_eid in top_eids:
        ev_df = df[df["event_id"] == src_eid]
        # Only picks linked to a detected origin carry training labels.
        linked = ev_df[ev_df["origin_index"].notna()]
        valid_widxs = set(linked["waveform_index"].unique())

        if not valid_widxs:
            print(f" [WARN] {src_eid}: no stations with model picks")
            continue

        # Build per-station records, tagging each with its source row index
        # (_widx) so picks can be mapped to station positions after sorting.
        recs_tagged = []
        for widx in valid_widxs:
            wave = wave_by_idx.get(widx)
            meta = meta_by_idx.get(widx, {})
            if wave is None:
                continue
            recs_tagged.append({
                "event_id": src_eid,
                "network": meta.get("network", ""),
                "station": meta.get("station", ""),
                "distance_km": meta.get("distance_km"),
                "begin_time": meta.get("begin_time"),
                "waveform": wave,
                "_widx": widx,
            })

        if not recs_tagged:
            print(f" [WARN] {src_eid}: no waveforms found")
            continue

        # Nearest stations first; missing distances sort to the front (0.0).
        recs_tagged.sort(key=lambda r: r.get("distance_km") or 0.0)
        widx_to_sta = {r["_widx"]: i for i, r in enumerate(recs_tagged)}
        # Strip the internal _widx tag before handing records downstream.
        recs = [{k: v for k, v in r.items() if k != "_widx"} for r in recs_tagged]

        p_picks_all, s_picks_all = [], []
        ps_centers_all, ps_intervals_all = [], []

        # Collect (station_index, sample_index) pick pairs, plus P-S
        # center/interval labels per origin where both phases exist.
        for widx, sta_grp in linked.groupby("waveform_index"):
            sta_idx = widx_to_sta.get(widx)
            if sta_idx is None:
                continue
            for pidx in sta_grp.loc[sta_grp["phase_type"] == "P", "phase_index"].dropna():
                p_picks_all.append((sta_idx, float(pidx)))
            for sidx in sta_grp.loc[sta_grp["phase_type"] == "S", "phase_index"].dropna():
                s_picks_all.append((sta_idx, float(sidx)))
            for _, o_grp in sta_grp.groupby("origin_index"):
                p_o = o_grp.loc[o_grp["phase_type"] == "P", "phase_index"].dropna()
                s_o = o_grp.loc[o_grp["phase_type"] == "S", "phase_index"].dropna()
                if p_o.empty or s_o.empty:
                    continue
                # First P/S pick of the origin defines the pair.
                p_i, s_i = float(p_o.iloc[0]), float(s_o.iloc[0])
                ps_centers_all.append((sta_idx, (p_i + s_i) / 2.0))
                ps_intervals_all.append((sta_idx, s_i - p_i))

        # Plotting is best-effort per event: a failure on one event should
        # not abort the remaining events.
        try:
            sample = records_to_sample(recs)
            if p_picks_all or s_picks_all:
                sample.targets = [Target(
                    p_picks=p_picks_all, s_picks=s_picks_all,
                    ps_centers=ps_centers_all, ps_intervals=ps_intervals_all,
                    event_id=src_eid,
                )]
            n_origins = int(linked["origin_index"].nunique())
            print(f" {src_eid}: {len(recs)} stations origins={n_origins} "
                  f"waveform {sample.waveform.shape}")
            plot_demo(sample, transforms, event_id=src_eid,
                      output_dir=fig_dir, n_augmented=2)
        except Exception as e:
            print(f" [WARN] {src_eid}: {e}")
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def main():
    """CLI entry point: parse arguments and run the plotting pipeline."""
    arg_specs = [
        (("--parquet",), {"required": True, "help": "Path to prediction parquet"}),
        (("--region",), {"default": "NC", "help": "Region (NC or SC)"}),
        (("--year",), {"type": int, "required": True}),
        (("--day",), {"required": True, "help": "Day string (e.g. 001)"}),
        (("--n-events",), {"type": int, "default": 3}),
        (("--figure-dir",), {"default": "figures/ceed"}),
    ]
    parser = argparse.ArgumentParser(description="CEED training label plots")
    for flags, kwargs in arg_specs:
        parser.add_argument(*flags, **kwargs)
    opts = parser.parse_args()

    # Authenticated GCS filesystem when local ADC credentials exist,
    # default/anonymous access otherwise.
    filesystem = gcsfs.GCSFileSystem(**get_storage_options())
    generate_training_plots(
        opts.parquet, opts.region, filesystem,
        opts.year, opts.day, opts.n_events, opts.figure_dir,
    )


if __name__ == "__main__":
    main()
|