phasenet-0.2.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. phasenet-0.2.0/.gitignore +12 -0
  2. phasenet-0.2.0/PKG-INFO +162 -0
  3. phasenet-0.2.0/README.md +133 -0
  4. phasenet-0.2.0/examples/ceed_training_plots.py +181 -0
  5. phasenet-0.2.0/examples/italy/prepare_inventory.py +233 -0
  6. phasenet-0.2.0/examples/italy/prepare_mseed_list.py +118 -0
  7. phasenet-0.2.0/examples/italy/run_all_years.sh +64 -0
  8. phasenet-0.2.0/examples/italy/upload_loop.sh +12 -0
  9. phasenet-0.2.0/examples/italy/upload_phasenet_origin.sh +47 -0
  10. phasenet-0.2.0/phasenet/__init__.py +3 -0
  11. phasenet-0.2.0/phasenet/data/__init__.py +7 -0
  12. phasenet-0.2.0/phasenet/data/ceed.py +2294 -0
  13. phasenet-0.2.0/phasenet/data/das.py +2555 -0
  14. phasenet-0.2.0/phasenet/data/transforms.py +776 -0
  15. phasenet-0.2.0/phasenet/models/__init__.py +8 -0
  16. phasenet-0.2.0/phasenet/models/phasenet.py +480 -0
  17. phasenet-0.2.0/phasenet/models/phasenet_das.py +75 -0
  18. phasenet-0.2.0/phasenet/models/phasenet_das_plus.py +46 -0
  19. phasenet-0.2.0/phasenet/models/phasenet_plus.py +27 -0
  20. phasenet-0.2.0/phasenet/models/phasenet_tf.py +21 -0
  21. phasenet-0.2.0/phasenet/models/phasenet_tf_plus.py +27 -0
  22. phasenet-0.2.0/phasenet/models/prompt/__init__.py +11 -0
  23. phasenet-0.2.0/phasenet/models/prompt/common.py +43 -0
  24. phasenet-0.2.0/phasenet/models/prompt/mask_decoder.py +185 -0
  25. phasenet-0.2.0/phasenet/models/prompt/prompt_encoder.py +242 -0
  26. phasenet-0.2.0/phasenet/models/prompt/transformer.py +232 -0
  27. phasenet-0.2.0/phasenet/models/unet.py +1316 -0
  28. phasenet-0.2.0/phasenet/models/unet2018.py +253 -0
  29. phasenet-0.2.0/phasenet/utils/__init__.py +3 -0
  30. phasenet-0.2.0/phasenet/utils/detect_peaks_cpu.py +207 -0
  31. phasenet-0.2.0/phasenet/utils/inference.py +405 -0
  32. phasenet-0.2.0/phasenet/utils/postprocess.py +711 -0
  33. phasenet-0.2.0/phasenet/utils/visualization.py +1465 -0
  34. phasenet-0.2.0/phasenet.egg-info/PKG-INFO +162 -0
  35. phasenet-0.2.0/phasenet.egg-info/SOURCES.txt +54 -0
  36. phasenet-0.2.0/phasenet.egg-info/dependency_links.txt +1 -0
  37. phasenet-0.2.0/phasenet.egg-info/entry_points.txt +3 -0
  38. phasenet-0.2.0/phasenet.egg-info/requires.txt +15 -0
  39. phasenet-0.2.0/phasenet.egg-info/top_level.txt +1 -0
  40. phasenet-0.2.0/predict.py +630 -0
  41. phasenet-0.2.0/pyproject.toml +47 -0
  42. phasenet-0.2.0/requirements.txt +15 -0
  43. phasenet-0.2.0/scripts/predict_ceed.py +489 -0
  44. phasenet-0.2.0/scripts/predict_ceed.sh +41 -0
  45. phasenet-0.2.0/scripts/predict_das.py +607 -0
  46. phasenet-0.2.0/scripts/predict_das.sh +55 -0
  47. phasenet-0.2.0/scripts/predict_mseed.py +487 -0
  48. phasenet-0.2.0/scripts/semisupervised_das.sh +221 -0
  49. phasenet-0.2.0/scripts/train_ceed.sh +21 -0
  50. phasenet-0.2.0/scripts/train_das.sh +57 -0
  51. phasenet-0.2.0/setup.cfg +4 -0
  52. phasenet-0.2.0/tests/benchmark_dataloader.py +373 -0
  53. phasenet-0.2.0/tests/test_augmentations.py +234 -0
  54. phasenet-0.2.0/tests/test_labels.py +288 -0
  55. phasenet-0.2.0/train.py +807 -0
  56. phasenet-0.2.0/utils.py +488 -0
@@ -0,0 +1,12 @@
1
+ figures_backup/
2
+ figures/
3
+ results/
4
+ output/
5
+ __pycache__/
6
+ CLAUDE.md
7
+ *.pth
8
+ *.png
9
+ *.pdf
10
+ wandb/
11
+ ToDelete/
12
+
@@ -0,0 +1,162 @@
1
+ Metadata-Version: 2.4
2
+ Name: phasenet
3
+ Version: 0.2.0
4
+ Summary: A PyTorch implementation of PhaseNet for seismic and DAS phase picking
5
+ Author: Weiqiang Zhu
6
+ License-Expression: MIT
7
+ Project-URL: Homepage, https://github.com/AI4EPS/phasenet-pytorch
8
+ Project-URL: Repository, https://github.com/AI4EPS/phasenet-pytorch
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: Operating System :: OS Independent
11
+ Classifier: Topic :: Scientific/Engineering
12
+ Requires-Python: >=3.8
13
+ Description-Content-Type: text/markdown
14
+ Requires-Dist: torch
15
+ Requires-Dist: torchvision
16
+ Requires-Dist: einops
17
+ Requires-Dist: numpy
18
+ Requires-Dist: scipy
19
+ Requires-Dist: h5py
20
+ Requires-Dist: matplotlib
21
+ Requires-Dist: pandas
22
+ Requires-Dist: tqdm
23
+ Requires-Dist: fsspec
24
+ Requires-Dist: obspy
25
+ Requires-Dist: gcsfs
26
+ Requires-Dist: datasets
27
+ Requires-Dist: pyarrow
28
+ Requires-Dist: wandb
29
+
30
+ # PhaseNet-PyTorch
31
+
32
+ PyTorch implementation of PhaseNet for seismic and DAS phase picking, event detection, and polarity classification.
33
+
34
+ ## Models
35
+
36
+ | Model | Features | Data Type |
37
+ |-------|----------|-----------|
38
+ | `phasenet` | Phase (P/S) picking | Seismic 3-component |
39
+ | `phasenet_tf` | Phase picking + STFT spectrogram | Seismic 3-component |
40
+ | `phasenet_plus` | Phase + polarity + event detection | Seismic 3-component |
41
+ | `phasenet_tf_plus` | Phase + polarity + event detection + STFT | Seismic 3-component |
42
+ | `phasenet_das` | Phase picking | DAS single-channel |
43
+ | `phasenet_das_plus` | Phase + event detection | DAS single-channel |
44
+
45
+ The `_tf` variants add a Short-Time Fourier Transform (STFT) branch that extracts frequency features alongside the temporal waveform, improving performance on noisy data.
46
+
47
+ ## Prediction
48
+
49
+ ### CEED (Seismic) Prediction
50
+
51
+ ```bash
52
+ # Demo: process a few events with plots
53
+ python scripts/predict_ceed.py --n-events 5
54
+
55
+ # Process all days for a year (saves parquets to results/ceed/)
56
+ python scripts/predict_ceed.py --all --year 2025 --output-dir results/ceed
57
+ ```
58
+
59
+ Output: one parquet per day file in `results/ceed/{region}/`, with columns:
60
+ `event_id, station_id, waveform_index, origin_id, origin_index, origin_time, phase_index, phase_time, phase_score, phase_type, phase_polarity`
61
+
62
+ ### DAS Prediction
63
+
64
+ ```bash
65
+ # Predict with base PhaseNet on a HuggingFace subset
66
+ python scripts/predict_das.py --subset arcata --plot
67
+
68
+ # Predict with a trained DAS model on local data
69
+ python scripts/predict_das.py \
70
+ --data-dir data/quakeflow_das/arcata/data \
71
+ --model-type phasenet_das_plus \
72
+ --checkpoint output/train_das_arcata/checkpoint.pth \
73
+ --output-dir results/das/train_das_arcata/arcata \
74
+ --no-ema --plot
75
+
76
+ # Predict from a file list
77
+ python scripts/predict_das.py \
78
+ --file-list file_list.txt \
79
+ --model-type phasenet_das_plus \
80
+ --checkpoint output/model.pth \
81
+ --plot
82
+
83
+ # Multi-GPU prediction
84
+ bash scripts/predict_das.sh phasenet_das_plus arcata 8 output/model.pth
85
+ ```
86
+
87
+ Output: one parquet per event in `results/das/{model_name}/{subset}/`, with columns:
88
+ `event_id, channel_index, origin_id, origin_index, origin_time, phase_index, phase_time, phase_score, phase_type, dt_s, ps_center, ps_interval`
89
+
90
+ By default, picks are associated using P-S pairing. Use `--use-event-head` to associate via the model's event detection head instead.
91
+
92
+ ## Training
93
+
94
+ ### CEED (Seismic) Training
95
+
96
+ ```bash
97
+ python train.py \
98
+ --model phasenet_plus \
99
+ --dataset-type ceed \
100
+ --label-path results/ceed \
101
+ --nx 16 \
102
+ --max-iters 100000 \
103
+ --batch-size 8 \
104
+ --workers 4 \
105
+ --lr 3e-4 \
106
+ --eval-interval 5000 \
107
+ --output-dir output/train_ceed
108
+ ```
109
+
110
+ ### DAS Training
111
+
112
+ ```bash
113
+ # Using the training script
114
+ bash scripts/train_das.sh 0 arcata v26
115
+
116
+ # Or directly
117
+ python train.py \
118
+ --model phasenet_das_plus \
119
+ --dataset-type das \
120
+ --data-path data/quakeflow_das/arcata/data \
121
+ --label-path results/das/phasenet/arcata/picks \
122
+ --label-list results/das/phasenet/arcata/labels.txt \
123
+ --nx 2048 --nt 4096 \
124
+ --num-patch 16 \
125
+ --max-iters 50000 \
126
+ --batch-size 2 --workers 8 \
127
+ --lr 1e-4 --weight-decay 0.01 \
128
+ --model-ema --model-ema-decay 0.999 \
129
+ --eval-interval 1000 --save-interval 1000 \
130
+ --output-dir output/train_das_arcata_v26
131
+ ```
132
+
133
+ ### Key Training Options
134
+
135
+ | Option | Description | Default |
136
+ |--------|-------------|---------|
137
+ | `--num-patch N` | Random crops per DAS sample (amortizes IO) | 2 |
138
+ | `--model-ema` | Enable exponential moving average | off |
139
+ | `--gradient-accumulation-steps N` | Accumulate gradients for larger effective batch | 1 |
140
+ | `--clip-grad-norm V` | Gradient clipping | 1.0 |
141
+ | `--compile` | Enable torch.compile | off |
142
+ | `--resume --checkpoint PATH` | Resume from checkpoint | - |
143
+ | `--reset-lr` | Reset LR schedule when resuming | off |
144
+
145
+ ## Semi-supervised Training
146
+
147
+ Iterative self-training pipeline for DAS: predict → train on predictions → predict with new model → repeat.
148
+
149
+ ```bash
150
+ # Start from PhaseNet (train DAS model from scratch)
151
+ bash scripts/semisupervised_das.sh arcata 5 0
152
+
153
+ # Start from a pretrained DAS model
154
+ bash scripts/semisupervised_das.sh arcata 5 0 phasenet_das output/train_das_v26/checkpoint.pth
155
+ ```
156
+
157
+ Arguments: `[subset] [num_iterations] [gpu] [start_from] [checkpoint]`
158
+
159
+ - **From phasenet**: iteration 0 predicts with PhaseNet, iteration 1 trains from scratch (10k steps, warmup), iterations 2+ continue (1k steps, no warmup)
160
+ - **From phasenet_das**: iteration 0 predicts with pretrained DAS model, iterations 1+ continue (1k steps, no warmup)
161
+
162
+ Results are saved to `results/semisupervised_das/` and checkpoints to `output/semisupervised_das_v{N}/`.
@@ -0,0 +1,133 @@
1
+ # PhaseNet-PyTorch
2
+
3
+ PyTorch implementation of PhaseNet for seismic and DAS phase picking, event detection, and polarity classification.
4
+
5
+ ## Models
6
+
7
+ | Model | Features | Data Type |
8
+ |-------|----------|-----------|
9
+ | `phasenet` | Phase (P/S) picking | Seismic 3-component |
10
+ | `phasenet_tf` | Phase picking + STFT spectrogram | Seismic 3-component |
11
+ | `phasenet_plus` | Phase + polarity + event detection | Seismic 3-component |
12
+ | `phasenet_tf_plus` | Phase + polarity + event detection + STFT | Seismic 3-component |
13
+ | `phasenet_das` | Phase picking | DAS single-channel |
14
+ | `phasenet_das_plus` | Phase + event detection | DAS single-channel |
15
+
16
+ The `_tf` variants add a Short-Time Fourier Transform (STFT) branch that extracts frequency features alongside the temporal waveform, improving performance on noisy data.
17
+
18
+ ## Prediction
19
+
20
+ ### CEED (Seismic) Prediction
21
+
22
+ ```bash
23
+ # Demo: process a few events with plots
24
+ python scripts/predict_ceed.py --n-events 5
25
+
26
+ # Process all days for a year (saves parquets to results/ceed/)
27
+ python scripts/predict_ceed.py --all --year 2025 --output-dir results/ceed
28
+ ```
29
+
30
+ Output: one parquet per day file in `results/ceed/{region}/`, with columns:
31
+ `event_id, station_id, waveform_index, origin_id, origin_index, origin_time, phase_index, phase_time, phase_score, phase_type, phase_polarity`
32
+
33
+ ### DAS Prediction
34
+
35
+ ```bash
36
+ # Predict with base PhaseNet on a HuggingFace subset
37
+ python scripts/predict_das.py --subset arcata --plot
38
+
39
+ # Predict with a trained DAS model on local data
40
+ python scripts/predict_das.py \
41
+ --data-dir data/quakeflow_das/arcata/data \
42
+ --model-type phasenet_das_plus \
43
+ --checkpoint output/train_das_arcata/checkpoint.pth \
44
+ --output-dir results/das/train_das_arcata/arcata \
45
+ --no-ema --plot
46
+
47
+ # Predict from a file list
48
+ python scripts/predict_das.py \
49
+ --file-list file_list.txt \
50
+ --model-type phasenet_das_plus \
51
+ --checkpoint output/model.pth \
52
+ --plot
53
+
54
+ # Multi-GPU prediction
55
+ bash scripts/predict_das.sh phasenet_das_plus arcata 8 output/model.pth
56
+ ```
57
+
58
+ Output: one parquet per event in `results/das/{model_name}/{subset}/`, with columns:
59
+ `event_id, channel_index, origin_id, origin_index, origin_time, phase_index, phase_time, phase_score, phase_type, dt_s, ps_center, ps_interval`
60
+
61
+ By default, picks are associated using P-S pairing. Use `--use-event-head` to associate via the model's event detection head instead.
62
+
63
+ ## Training
64
+
65
+ ### CEED (Seismic) Training
66
+
67
+ ```bash
68
+ python train.py \
69
+ --model phasenet_plus \
70
+ --dataset-type ceed \
71
+ --label-path results/ceed \
72
+ --nx 16 \
73
+ --max-iters 100000 \
74
+ --batch-size 8 \
75
+ --workers 4 \
76
+ --lr 3e-4 \
77
+ --eval-interval 5000 \
78
+ --output-dir output/train_ceed
79
+ ```
80
+
81
+ ### DAS Training
82
+
83
+ ```bash
84
+ # Using the training script
85
+ bash scripts/train_das.sh 0 arcata v26
86
+
87
+ # Or directly
88
+ python train.py \
89
+ --model phasenet_das_plus \
90
+ --dataset-type das \
91
+ --data-path data/quakeflow_das/arcata/data \
92
+ --label-path results/das/phasenet/arcata/picks \
93
+ --label-list results/das/phasenet/arcata/labels.txt \
94
+ --nx 2048 --nt 4096 \
95
+ --num-patch 16 \
96
+ --max-iters 50000 \
97
+ --batch-size 2 --workers 8 \
98
+ --lr 1e-4 --weight-decay 0.01 \
99
+ --model-ema --model-ema-decay 0.999 \
100
+ --eval-interval 1000 --save-interval 1000 \
101
+ --output-dir output/train_das_arcata_v26
102
+ ```
103
+
104
+ ### Key Training Options
105
+
106
+ | Option | Description | Default |
107
+ |--------|-------------|---------|
108
+ | `--num-patch N` | Random crops per DAS sample (amortizes IO) | 2 |
109
+ | `--model-ema` | Enable exponential moving average | off |
110
+ | `--gradient-accumulation-steps N` | Accumulate gradients for larger effective batch | 1 |
111
+ | `--clip-grad-norm V` | Gradient clipping | 1.0 |
112
+ | `--compile` | Enable torch.compile | off |
113
+ | `--resume --checkpoint PATH` | Resume from checkpoint | - |
114
+ | `--reset-lr` | Reset LR schedule when resuming | off |
115
+
116
+ ## Semi-supervised Training
117
+
118
+ Iterative self-training pipeline for DAS: predict → train on predictions → predict with new model → repeat.
119
+
120
+ ```bash
121
+ # Start from PhaseNet (train DAS model from scratch)
122
+ bash scripts/semisupervised_das.sh arcata 5 0
123
+
124
+ # Start from a pretrained DAS model
125
+ bash scripts/semisupervised_das.sh arcata 5 0 phasenet_das output/train_das_v26/checkpoint.pth
126
+ ```
127
+
128
+ Arguments: `[subset] [num_iterations] [gpu] [start_from] [checkpoint]`
129
+
130
+ - **From phasenet**: iteration 0 predicts with PhaseNet, iteration 1 trains from scratch (10k steps, warmup), iterations 2+ continue (1k steps, no warmup)
131
+ - **From phasenet_das**: iteration 0 predicts with pretrained DAS model, iterations 1+ continue (1k steps, no warmup)
132
+
133
+ Results are saved to `results/semisupervised_das/` and checkpoints to `output/semisupervised_das_v{N}/`.
@@ -0,0 +1,181 @@
1
+ #!/usr/bin/env python
2
+ """Generate training label plots from CEED prediction results.
3
+
4
+ Loads a prediction parquet, joins with original waveforms from GCS,
5
+ and plots top-N events with training labels.
6
+
7
+ Usage:
8
+ python examples/ceed_training_plots.py \
9
+ --parquet results/ceed/phasenet2018/NC/2025_001.parquet \
10
+ --region NC --year 2025 --day 001
11
+ """
12
+ import argparse
13
+ import json
14
+ import os
15
+ import sys
16
+
17
+ sys.path.insert(0, ".")
18
+
19
+ import gcsfs
20
+ import numpy as np
21
+ import pandas as pd
22
+ import pyarrow.parquet as pq
23
+
24
+ from phasenet.data.ceed import (
25
+ records_to_sample, default_train_transforms, plot_demo, Target,
26
+ )
27
+
28
# Map region code -> GCS bucket prefix that holds the original waveform parquets.
BUCKETS = {
    "NC": "quakeflow_dataset/NCEDC",
    "SC": "quakeflow_dataset/SCEDC",
}
# Default location of gcloud application-default credentials, used for GCS auth.
GCS_CRED_PATH = os.path.expanduser("~/.config/gcloud/application_default_credentials.json")
33
+
34
+
35
def get_storage_options():
    """Return GCS storage options built from local gcloud credentials.

    Falls back to an empty dict (default/anonymous auth) when no credential
    file exists at ``GCS_CRED_PATH``.
    """
    if not os.path.exists(GCS_CRED_PATH):
        return {}
    with open(GCS_CRED_PATH) as cred_file:
        token = json.load(cred_file)
    return {"token": token}
40
+
41
+
42
def load_day_waveforms(fs, bucket, year, day):
    """Return {waveform_index: ndarray (3, nt)} for a day parquet.

    Streams the day's waveform parquet in record batches and decodes the
    "waveform" column directly from the underlying Arrow buffer, avoiding a
    per-row Python-object conversion.

    Args:
        fs: fsspec-style filesystem (e.g. gcsfs.GCSFileSystem) used to open the file.
        bucket: bucket/prefix under which ``waveform_parquet/{year}/{day}.parquet`` lives.
        year: year component of the parquet path.
        day: day-of-year string (e.g. "001").

    Returns:
        dict mapping global row index within the day file -> float32 array of
        shape (3, nt). Rows whose event_id is empty/missing are skipped.
    """
    path = f"{bucket}/waveform_parquet/{year}/{day}.parquet"
    pf = pq.ParquetFile(fs.open(path))
    # Every column except the bulky waveform column; used for per-row metadata.
    non_wave_cols = [c for c in pf.schema_arrow.names if c != "waveform"]
    result = {}
    row_offset = 0  # running global row index across batches
    for batch in pf.iter_batches(batch_size=256):
        n = len(batch)
        col = batch.column("waveform")
        # Collapse the nested list<list<float>> column to its flat values array.
        flat = col.flatten().flatten()
        # buffers()[1] is the Arrow values buffer (buffers()[0] is the validity
        # bitmap). NOTE(review): this assumes non-null float32 values with every
        # row the same length — `flat.offset * 4` hardcodes 4-byte elements and
        # reshape(n, 3, -1) assumes 3 components per row; confirm against the
        # writer's schema.
        buf = flat.buffers()[1]
        waveforms = np.frombuffer(
            buf, dtype=np.float32, offset=flat.offset * 4
        ).reshape(n, 3, -1).copy()  # copy() detaches the array from the Arrow buffer
        meta = batch.select(non_wave_cols).to_pydict()
        for i in range(n):
            # Skip rows with an empty event_id.
            eid = meta["event_id"][i] or ""
            if eid:
                result[row_offset + i] = waveforms[i]
        row_offset += n
    return result
64
+
65
+
66
def generate_training_plots(parquet_path, region, fs, year, day,
                            n_events=3, figure_dir="figures/ceed"):
    """Load label parquet, join with original waveforms, plot training labels.

    Args:
        parquet_path: path to the prediction/label parquet (one row per pick).
        region: region code, a key of ``BUCKETS`` ("NC" or "SC").
        fs: gcsfs filesystem used to read the original waveform parquet.
        year: year of the day file.
        day: day-of-year string (e.g. "001").
        n_events: number of events to plot, chosen by most pick rows.
        figure_dir: root output directory; figures go under ``{figure_dir}/{region}``.
    """
    print(f"\n=== Training label plots for {region} ===")

    df = pd.read_parquet(parquet_path)
    print(f" {len(df)} pick rows, "
          f"{df.groupby(['event_id','origin_index']).ngroups} detected origins")

    bucket = BUCKETS[region]
    print(" Loading waveforms from original parquet...")
    wave_by_idx = load_day_waveforms(fs, bucket, year, day)
    print(f" {len(wave_by_idx)} waveforms loaded")

    # Second pass over the same parquet, reading only lightweight metadata columns.
    raw_pq = pq.read_table(
        fs.open(f"{bucket}/waveform_parquet/{year}/{day}.parquet"),
        columns=["event_id", "network", "station", "distance_km", "begin_time"],
    ).to_pydict()
    n_raw = len(raw_pq["event_id"])
    meta_by_idx = {i: {k: raw_pq[k][i] for k in raw_pq} for i in range(n_raw)}

    # Rank events by total number of pick rows and keep the top n_events.
    top_eids = (
        df.groupby("event_id")["phase_index"].count()
        .nlargest(n_events).index
    )

    transforms = default_train_transforms(crop_length=4096, enable_stacking=False)
    fig_dir = os.path.join(figure_dir, region)

    for src_eid in top_eids:
        ev_df = df[df["event_id"] == src_eid]
        # Keep only picks that were associated to an origin (origin_index set).
        linked = ev_df[ev_df["origin_index"].notna()]
        valid_widxs = set(linked["waveform_index"].unique())

        if not valid_widxs:
            print(f" [WARN] {src_eid}: no stations with model picks")
            continue

        # One record per station (waveform + metadata), tagged with its
        # waveform index so picks can be mapped back after sorting.
        recs_tagged = []
        for widx in valid_widxs:
            wave = wave_by_idx.get(widx)
            meta = meta_by_idx.get(widx, {})
            if wave is None:
                continue
            recs_tagged.append({
                "event_id": src_eid,
                "network": meta.get("network", ""),
                "station": meta.get("station", ""),
                "distance_km": meta.get("distance_km"),
                "begin_time": meta.get("begin_time"),
                "waveform": wave,
                "_widx": widx,
            })

        if not recs_tagged:
            print(f" [WARN] {src_eid}: no waveforms found")
            continue

        # Sort stations by distance; a missing distance (None) sorts as 0.0,
        # i.e. first.
        recs_tagged.sort(key=lambda r: r.get("distance_km") or 0.0)
        widx_to_sta = {r["_widx"]: i for i, r in enumerate(recs_tagged)}
        # Drop the internal _widx tag before handing records downstream.
        recs = [{k: v for k, v in r.items() if k != "_widx"} for r in recs_tagged]

        # Collect (station_index, phase_index) tuples per phase type, plus
        # P-S pair centers/intervals per origin.
        p_picks_all, s_picks_all = [], []
        ps_centers_all, ps_intervals_all = [], []

        for widx, sta_grp in linked.groupby("waveform_index"):
            sta_idx = widx_to_sta.get(widx)
            if sta_idx is None:
                continue
            for pidx in sta_grp.loc[sta_grp["phase_type"] == "P", "phase_index"].dropna():
                p_picks_all.append((sta_idx, float(pidx)))
            for sidx in sta_grp.loc[sta_grp["phase_type"] == "S", "phase_index"].dropna():
                s_picks_all.append((sta_idx, float(sidx)))
            for _, o_grp in sta_grp.groupby("origin_index"):
                p_o = o_grp.loc[o_grp["phase_type"] == "P", "phase_index"].dropna()
                s_o = o_grp.loc[o_grp["phase_type"] == "S", "phase_index"].dropna()
                if p_o.empty or s_o.empty:
                    continue
                # Use the first P and the first S pick of the origin for the pair.
                p_i, s_i = float(p_o.iloc[0]), float(s_o.iloc[0])
                ps_centers_all.append((sta_idx, (p_i + s_i) / 2.0))
                ps_intervals_all.append((sta_idx, s_i - p_i))

        # Best-effort plotting: a failure on one event must not stop the rest.
        try:
            sample = records_to_sample(recs)
            if p_picks_all or s_picks_all:
                sample.targets = [Target(
                    p_picks=p_picks_all, s_picks=s_picks_all,
                    ps_centers=ps_centers_all, ps_intervals=ps_intervals_all,
                    event_id=src_eid,
                )]
            n_origins = int(linked["origin_index"].nunique())
            print(f" {src_eid}: {len(recs)} stations origins={n_origins} "
                  f"waveform {sample.waveform.shape}")
            plot_demo(sample, transforms, event_id=src_eid,
                      output_dir=fig_dir, n_augmented=2)
        except Exception as e:
            print(f" [WARN] {src_eid}: {e}")
163
+
164
+
165
def main():
    """Parse CLI arguments and generate CEED training-label plots."""
    arg_parser = argparse.ArgumentParser(description="CEED training label plots")
    arg_parser.add_argument("--parquet", required=True, help="Path to prediction parquet")
    arg_parser.add_argument("--region", default="NC", help="Region (NC or SC)")
    arg_parser.add_argument("--year", type=int, required=True)
    arg_parser.add_argument("--day", required=True, help="Day string (e.g. 001)")
    arg_parser.add_argument("--n-events", type=int, default=3)
    arg_parser.add_argument("--figure-dir", default="figures/ceed")
    opts = arg_parser.parse_args()

    # Authenticate against GCS with local gcloud credentials when available.
    gcs = gcsfs.GCSFileSystem(**get_storage_options())
    generate_training_plots(
        opts.parquet,
        opts.region,
        gcs,
        opts.year,
        opts.day,
        opts.n_events,
        opts.figure_dir,
    )
178
+
179
+
180
# Script entry point when run directly from the command line.
if __name__ == "__main__":
    main()