photo-stack-finder 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. orchestrator/__init__.py +2 -2
  2. orchestrator/app.py +6 -11
  3. orchestrator/build_pipeline.py +19 -21
  4. orchestrator/orchestrator_runner.py +11 -8
  5. orchestrator/pipeline_builder.py +126 -126
  6. orchestrator/pipeline_orchestrator.py +604 -604
  7. orchestrator/review_persistence.py +162 -162
  8. orchestrator/static/orchestrator.css +76 -76
  9. orchestrator/static/orchestrator.html +11 -5
  10. orchestrator/static/orchestrator.js +3 -1
  11. overlap_metrics/__init__.py +1 -1
  12. overlap_metrics/config.py +135 -135
  13. overlap_metrics/core.py +284 -284
  14. overlap_metrics/estimators.py +292 -292
  15. overlap_metrics/metrics.py +307 -307
  16. overlap_metrics/registry.py +99 -99
  17. overlap_metrics/utils.py +104 -104
  18. photo_compare/__init__.py +1 -1
  19. photo_compare/base.py +285 -285
  20. photo_compare/config.py +225 -225
  21. photo_compare/distance.py +15 -15
  22. photo_compare/feature_methods.py +173 -173
  23. photo_compare/file_hash.py +29 -29
  24. photo_compare/hash_methods.py +99 -99
  25. photo_compare/histogram_methods.py +118 -118
  26. photo_compare/pixel_methods.py +58 -58
  27. photo_compare/structural_methods.py +104 -104
  28. photo_compare/types.py +28 -28
  29. {photo_stack_finder-0.1.7.dist-info → photo_stack_finder-0.1.8.dist-info}/METADATA +21 -22
  30. photo_stack_finder-0.1.8.dist-info/RECORD +75 -0
  31. scripts/orchestrate.py +12 -10
  32. utils/__init__.py +4 -3
  33. utils/base_pipeline_stage.py +171 -171
  34. utils/base_ports.py +176 -176
  35. utils/benchmark_utils.py +823 -823
  36. utils/channel.py +74 -74
  37. utils/comparison_gates.py +40 -21
  38. utils/compute_benchmarks.py +355 -355
  39. utils/compute_identical.py +94 -24
  40. utils/compute_indices.py +235 -235
  41. utils/compute_perceptual_hash.py +127 -127
  42. utils/compute_perceptual_match.py +240 -240
  43. utils/compute_sha_bins.py +64 -20
  44. utils/compute_template_similarity.py +1 -1
  45. utils/compute_versions.py +483 -483
  46. utils/config.py +8 -5
  47. utils/data_io.py +83 -83
  48. utils/graph_context.py +44 -44
  49. utils/logger.py +2 -2
  50. utils/models.py +2 -2
  51. utils/photo_file.py +90 -91
  52. utils/pipeline_graph.py +334 -334
  53. utils/pipeline_stage.py +408 -408
  54. utils/plot_helpers.py +123 -123
  55. utils/ports.py +136 -136
  56. utils/progress.py +415 -415
  57. utils/report_builder.py +139 -139
  58. utils/review_types.py +55 -55
  59. utils/review_utils.py +10 -19
  60. utils/sequence.py +10 -8
  61. utils/sequence_clustering.py +1 -1
  62. utils/template.py +57 -57
  63. utils/template_parsing.py +71 -0
  64. photo_stack_finder-0.1.7.dist-info/RECORD +0 -74
  65. {photo_stack_finder-0.1.7.dist-info → photo_stack_finder-0.1.8.dist-info}/WHEEL +0 -0
  66. {photo_stack_finder-0.1.7.dist-info → photo_stack_finder-0.1.8.dist-info}/entry_points.txt +0 -0
  67. {photo_stack_finder-0.1.7.dist-info → photo_stack_finder-0.1.8.dist-info}/licenses/LICENSE +0 -0
  68. {photo_stack_finder-0.1.7.dist-info → photo_stack_finder-0.1.8.dist-info}/top_level.txt +0 -0
@@ -1,307 +1,307 @@
1
- """Concrete separation/overlap metrics for overlap_metrics library."""
2
-
3
- from __future__ import annotations
4
-
5
- import time
6
-
7
- import numpy as np
8
- import numpy.typing as npt
9
- from scipy import stats
10
-
11
- from .config import NUMERICS
12
- from .core import (
13
- PDF,
14
- DensityEstimatorBase,
15
- MetricBase,
16
- MetricResult,
17
- SampleBasedMetric,
18
- ScoreSamples,
19
- )
20
- from .utils import kl_divergence, make_grid
21
-
22
-
23
- class SeparationOVL(MetricBase):
24
- """1 - Overlap coefficient: higher values = better separation."""
25
-
26
- def __init__(self) -> None:
27
- super().__init__(name="separation_ovl", lower_is_better=False, bounds=(0.0, 1.0))
28
-
29
- def from_pdfs(
30
- self,
31
- p: PDF,
32
- q: PDF,
33
- n_grid: int,
34
- grid: str,
35
- ) -> MetricResult:
36
- """Compute 1 - OVL where OVL = ∫ min(p,q) dx."""
37
- xs: npt.NDArray[np.float64] = make_grid(n_grid=n_grid, mode=grid)
38
- p_vals: npt.NDArray[np.float64] = np.maximum(p(xs), 0.0)
39
- q_vals: npt.NDArray[np.float64] = np.maximum(q(xs), 0.0)
40
-
41
- # Compute overlap coefficient
42
- min_vals: npt.NDArray[np.float64] = np.minimum(p_vals, q_vals)
43
- ovl: float = float(np.trapezoid(min_vals, xs))
44
-
45
- # Separation is 1 - overlap
46
- separation: float = 1.0 - ovl
47
-
48
- return MetricResult(
49
- name=self.name,
50
- value=separation,
51
- lower_is_better=self.lower_is_better,
52
- bounds=self.bounds,
53
- estimator_name="pdf_based",
54
- details={"ovl": ovl},
55
- meta={},
56
- )
57
-
58
-
59
- class BhattacharyyaDistance(MetricBase):
60
- """Bhattacharyya distance: -ln(BC) where BC = ∫ sqrt(p*q) dx."""
61
-
62
- def __init__(self) -> None:
63
- super().__init__(
64
- name="bhattacharyya_distance",
65
- lower_is_better=False,
66
- bounds=(0.0, float("inf")),
67
- )
68
-
69
- def from_pdfs(
70
- self,
71
- p: PDF,
72
- q: PDF,
73
- n_grid: int,
74
- grid: str,
75
- ) -> MetricResult:
76
- """Compute Bhattacharyya distance."""
77
- xs: npt.NDArray[np.float64] = make_grid(n_grid=n_grid, mode=grid)
78
- p_vals: npt.NDArray[np.float64] = np.maximum(p(xs), 0.0)
79
- q_vals: npt.NDArray[np.float64] = np.maximum(q(xs), 0.0)
80
-
81
- # Compute Bhattacharyya coefficient
82
- sqrt_product: npt.NDArray[np.float64] = np.sqrt(p_vals * q_vals)
83
- bc: float = float(np.trapezoid(sqrt_product, xs))
84
-
85
- # Distance is -ln(BC) with safe floor
86
- bc_safe: float = max(bc, NUMERICS.LOG_FLOOR)
87
- db: float = -np.log(bc_safe)
88
-
89
- return MetricResult(
90
- name=self.name,
91
- value=db,
92
- lower_is_better=self.lower_is_better,
93
- bounds=self.bounds,
94
- estimator_name="pdf_based",
95
- details={"bc": bc},
96
- meta={},
97
- )
98
-
99
-
100
- class JensenShannon(MetricBase):
101
- """Jensen-Shannon divergence normalized by ln(2)."""
102
-
103
- def __init__(self) -> None:
104
- super().__init__(name="js_divergence", lower_is_better=False, bounds=(0.0, 1.0))
105
-
106
- def from_pdfs(
107
- self,
108
- p: PDF,
109
- q: PDF,
110
- n_grid: int,
111
- grid: str,
112
- ) -> MetricResult:
113
- """Compute Jensen-Shannon divergence."""
114
- xs: npt.NDArray[np.float64] = make_grid(n_grid=n_grid, mode=grid)
115
- p_vals: npt.NDArray[np.float64] = np.maximum(p(xs), NUMERICS.LOG_FLOOR)
116
- q_vals: npt.NDArray[np.float64] = np.maximum(q(xs), NUMERICS.LOG_FLOOR)
117
-
118
- # Midpoint distribution
119
- m_vals: npt.NDArray[np.float64] = 0.5 * (p_vals + q_vals)
120
-
121
- # Compute uniform dx for KL calculations
122
- dx: float = xs[1] - xs[0] if len(xs) > 1 else 1.0
123
- dx_array: npt.NDArray[np.float64] = np.full_like(xs, dx)
124
-
125
- # JS = 0.5 * KL(P||M) + 0.5 * KL(Q||M)
126
- kl_pm: float = kl_divergence(p_vals, m_vals, dx_array)
127
- kl_qm: float = kl_divergence(q_vals, m_vals, dx_array)
128
- js_raw: float = 0.5 * kl_pm + 0.5 * kl_qm
129
-
130
- # Normalize by ln(2) to get value in [0,1]
131
- js_normalized: float = js_raw / np.log(2.0)
132
-
133
- return MetricResult(
134
- name=self.name,
135
- value=js_normalized,
136
- lower_is_better=self.lower_is_better,
137
- bounds=self.bounds,
138
- estimator_name="pdf_based",
139
- details={"js_raw": js_raw},
140
- meta={},
141
- )
142
-
143
-
144
- class HellingerDistance(MetricBase):
145
- """Hellinger distance: sqrt(1 - BC) where BC is Bhattacharyya coefficient."""
146
-
147
- def __init__(self) -> None:
148
- super().__init__(name="hellinger_distance", lower_is_better=False, bounds=(0.0, 1.0))
149
-
150
- def from_pdfs(
151
- self,
152
- p: PDF,
153
- q: PDF,
154
- n_grid: int,
155
- grid: str,
156
- ) -> MetricResult:
157
- """Compute Hellinger distance."""
158
- xs: npt.NDArray[np.float64] = make_grid(n_grid=n_grid, mode=grid)
159
- p_vals: npt.NDArray[np.float64] = np.maximum(p(xs), 0.0)
160
- q_vals: npt.NDArray[np.float64] = np.maximum(q(xs), 0.0)
161
-
162
- # Compute Bhattacharyya coefficient
163
- sqrt_product: npt.NDArray[np.float64] = np.sqrt(p_vals * q_vals)
164
- bc: float = float(np.trapezoid(sqrt_product, xs))
165
-
166
- # Hellinger distance
167
- hellinger: float = np.sqrt(max(0.0, 1.0 - bc))
168
-
169
- return MetricResult(
170
- name=self.name,
171
- value=hellinger,
172
- lower_is_better=self.lower_is_better,
173
- bounds=self.bounds,
174
- estimator_name="pdf_based",
175
- details={"bc": bc},
176
- meta={},
177
- )
178
-
179
-
180
- class TotalVariation(MetricBase):
181
- """Total variation distance: 0.5 * ∫ |p - q| dx."""
182
-
183
- def __init__(self) -> None:
184
- super().__init__(name="total_variation", lower_is_better=False, bounds=(0.0, 1.0))
185
-
186
- def from_pdfs(
187
- self,
188
- p: PDF,
189
- q: PDF,
190
- n_grid: int,
191
- grid: str,
192
- ) -> MetricResult:
193
- """Compute total variation distance."""
194
- xs: npt.NDArray[np.float64] = make_grid(n_grid=n_grid, mode=grid)
195
- p_vals: npt.NDArray[np.float64] = np.maximum(p(xs), 0.0)
196
- q_vals: npt.NDArray[np.float64] = np.maximum(q(xs), 0.0)
197
-
198
- # TV = 0.5 * ∫ |p - q| dx
199
- abs_diff: npt.NDArray[np.float64] = np.abs(p_vals - q_vals)
200
- tv: float = 0.5 * float(np.trapezoid(abs_diff, xs))
201
-
202
- return MetricResult(
203
- name=self.name,
204
- value=tv,
205
- lower_is_better=self.lower_is_better,
206
- bounds=self.bounds,
207
- estimator_name="pdf_based",
208
- details={},
209
- meta={},
210
- )
211
-
212
-
213
- class Wasserstein1D(SampleBasedMetric):
214
- """1D Wasserstein distance computed directly from samples."""
215
-
216
- def __init__(self) -> None:
217
- super().__init__(name="wasserstein_1d", lower_is_better=False, bounds=(0.0, float("inf")))
218
-
219
- def from_samples(
220
- self,
221
- samples: ScoreSamples,
222
- estimator: DensityEstimatorBase,
223
- n_grid: int,
224
- grid: str,
225
- weights_pos: npt.NDArray[np.float64] | None = None,
226
- weights_neg: npt.NDArray[np.float64] | None = None,
227
- random_state: int | None = None,
228
- ) -> MetricResult:
229
- """Compute Wasserstein distance directly from samples."""
230
- start_time: float = time.perf_counter()
231
-
232
- # Use scipy's implementation
233
- if weights_pos is not None or weights_neg is not None:
234
- # scipy 1.9+ supports weights
235
- try:
236
- wasserstein_dist: float = stats.wasserstein_distance(
237
- samples.pos,
238
- samples.neg,
239
- u_weights=weights_pos,
240
- v_weights=weights_neg,
241
- )
242
- except TypeError:
243
- # Fallback for older scipy versions
244
- wasserstein_dist = stats.wasserstein_distance(samples.pos, samples.neg)
245
- else:
246
- wasserstein_dist = stats.wasserstein_distance(samples.pos, samples.neg)
247
-
248
- runtime_ms: float = (time.perf_counter() - start_time) * 1000.0
249
-
250
- return MetricResult(
251
- name=self.name,
252
- value=wasserstein_dist,
253
- lower_is_better=self.lower_is_better,
254
- bounds=self.bounds,
255
- estimator_name=estimator.name,
256
- details={},
257
- meta={
258
- "n_pos": float(len(samples.pos)),
259
- "n_neg": float(len(samples.neg)),
260
- "runtime_ms": runtime_ms,
261
- },
262
- )
263
-
264
-
265
- class KSStatistic(SampleBasedMetric):
266
- """Kolmogorov-Smirnov test statistic (two-sample)."""
267
-
268
- def __init__(self) -> None:
269
- super().__init__(name="ks_stat", lower_is_better=False, bounds=(0.0, 1.0))
270
-
271
- def from_samples(
272
- self,
273
- samples: ScoreSamples,
274
- estimator: DensityEstimatorBase,
275
- n_grid: int,
276
- grid: str,
277
- weights_pos: npt.NDArray[np.float64] | None = None,
278
- weights_neg: npt.NDArray[np.float64] | None = None,
279
- random_state: int | None = None,
280
- ) -> MetricResult:
281
- """Compute KS statistic directly from samples."""
282
- start_time: float = time.perf_counter()
283
-
284
- # Note: scipy's ks_2samp doesn't support weights
285
- if weights_pos is not None or weights_neg is not None:
286
- # Could implement weighted KS in future, for now ignore weights
287
- pass
288
-
289
- # Compute KS statistic
290
- ks_result = stats.ks_2samp(samples.pos, samples.neg)
291
- ks_stat: float = float(ks_result.statistic)
292
-
293
- runtime_ms: float = (time.perf_counter() - start_time) * 1000.0
294
-
295
- return MetricResult(
296
- name=self.name,
297
- value=ks_stat,
298
- lower_is_better=self.lower_is_better,
299
- bounds=self.bounds,
300
- estimator_name=estimator.name,
301
- details={"p_value": float(ks_result.pvalue)},
302
- meta={
303
- "n_pos": float(len(samples.pos)),
304
- "n_neg": float(len(samples.neg)),
305
- "runtime_ms": runtime_ms,
306
- },
307
- )
1
+ """Concrete separation/overlap metrics for overlap_metrics library."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+ from scipy import stats
10
+
11
+ from .config import NUMERICS
12
+ from .core import (
13
+ PDF,
14
+ DensityEstimatorBase,
15
+ MetricBase,
16
+ MetricResult,
17
+ SampleBasedMetric,
18
+ ScoreSamples,
19
+ )
20
+ from .utils import kl_divergence, make_grid
21
+
22
+
23
class SeparationOVL(MetricBase):
    """Separation score defined as 1 - OVL (overlap coefficient).

    Higher values indicate better separation between the two densities.
    """

    def __init__(self) -> None:
        super().__init__(name="separation_ovl", lower_is_better=False, bounds=(0.0, 1.0))

    def from_pdfs(
        self,
        p: PDF,
        q: PDF,
        n_grid: int,
        grid: str,
    ) -> MetricResult:
        """Compute 1 - OVL where OVL = ∫ min(p,q) dx."""
        grid_pts: npt.NDArray[np.float64] = make_grid(n_grid=n_grid, mode=grid)
        # Clamp any negative density artefacts to zero before integrating.
        dens_p: npt.NDArray[np.float64] = np.maximum(p(grid_pts), 0.0)
        dens_q: npt.NDArray[np.float64] = np.maximum(q(grid_pts), 0.0)

        # Overlap coefficient: trapezoidal integral of the pointwise minimum.
        pointwise_min: npt.NDArray[np.float64] = np.minimum(dens_p, dens_q)
        ovl: float = float(np.trapezoid(pointwise_min, grid_pts))

        # Separation is the complement of the overlap.
        return MetricResult(
            name=self.name,
            value=1.0 - ovl,
            lower_is_better=self.lower_is_better,
            bounds=self.bounds,
            estimator_name="pdf_based",
            details={"ovl": ovl},
            meta={},
        )
57
+
58
+
59
class BhattacharyyaDistance(MetricBase):
    """Bhattacharyya distance: -ln(BC) where BC = ∫ sqrt(p*q) dx.

    The distance is 0 when the densities coincide and grows without bound
    as they separate (upper bound is +inf).
    """

    def __init__(self) -> None:
        super().__init__(
            name="bhattacharyya_distance",
            lower_is_better=False,
            bounds=(0.0, float("inf")),
        )

    def from_pdfs(
        self,
        p: PDF,
        q: PDF,
        n_grid: int,
        grid: str,
    ) -> MetricResult:
        """Compute the Bhattacharyya distance on the evaluation grid.

        Args:
            p: Density of the first distribution.
            q: Density of the second distribution.
            n_grid: Number of grid points used for integration.
            grid: Grid mode passed through to ``make_grid``.

        Returns:
            ``MetricResult`` with the distance as ``value`` and the raw
            Bhattacharyya coefficient under ``details["bc"]``.
        """
        xs: npt.NDArray[np.float64] = make_grid(n_grid=n_grid, mode=grid)
        # Clamp negative density artefacts before taking the square root.
        p_vals: npt.NDArray[np.float64] = np.maximum(p(xs), 0.0)
        q_vals: npt.NDArray[np.float64] = np.maximum(q(xs), 0.0)

        # Bhattacharyya coefficient BC = ∫ sqrt(p*q) dx (trapezoidal rule).
        sqrt_product: npt.NDArray[np.float64] = np.sqrt(p_vals * q_vals)
        bc: float = float(np.trapezoid(sqrt_product, xs))

        # Floor BC before the log so BC ≈ 0 yields a large finite distance
        # instead of inf/nan. Wrap in float(): np.log returns np.float64,
        # which would otherwise leak into ``value`` despite the annotation
        # (the other metrics in this module consistently coerce via float()).
        bc_safe: float = max(bc, NUMERICS.LOG_FLOOR)
        db: float = float(-np.log(bc_safe))

        return MetricResult(
            name=self.name,
            value=db,
            lower_is_better=self.lower_is_better,
            bounds=self.bounds,
            estimator_name="pdf_based",
            details={"bc": bc},
            meta={},
        )
98
+
99
+
100
class JensenShannon(MetricBase):
    """Jensen-Shannon divergence normalized by ln(2) to lie in [0, 1]."""

    def __init__(self) -> None:
        super().__init__(name="js_divergence", lower_is_better=False, bounds=(0.0, 1.0))

    def from_pdfs(
        self,
        p: PDF,
        q: PDF,
        n_grid: int,
        grid: str,
    ) -> MetricResult:
        """Compute the normalized Jensen-Shannon divergence.

        Args:
            p: Density of the first distribution.
            q: Density of the second distribution.
            n_grid: Number of grid points used for integration.
            grid: Grid mode passed through to ``make_grid``.

        Returns:
            ``MetricResult`` with the ln(2)-normalized JS divergence as
            ``value`` and the unnormalized value under ``details["js_raw"]``.
        """
        xs: npt.NDArray[np.float64] = make_grid(n_grid=n_grid, mode=grid)
        # Floor densities at LOG_FLOOR (not 0) so the KL logs stay finite.
        p_vals: npt.NDArray[np.float64] = np.maximum(p(xs), NUMERICS.LOG_FLOOR)
        q_vals: npt.NDArray[np.float64] = np.maximum(q(xs), NUMERICS.LOG_FLOOR)

        # Midpoint (mixture) distribution M = (P + Q) / 2.
        m_vals: npt.NDArray[np.float64] = 0.5 * (p_vals + q_vals)

        # Uniform dx for KL calculations. Wrap in float(): indexing a numpy
        # array yields np.float64, which the annotation does not intend.
        # NOTE(review): this assumes the grid is uniformly spaced; if
        # make_grid can return a non-uniform grid for some ``grid`` modes,
        # a per-interval spacing would be needed — confirm against make_grid.
        dx: float = float(xs[1] - xs[0]) if len(xs) > 1 else 1.0
        dx_array: npt.NDArray[np.float64] = np.full_like(xs, dx)

        # JS = 0.5 * KL(P||M) + 0.5 * KL(Q||M)
        kl_pm: float = kl_divergence(p_vals, m_vals, dx_array)
        kl_qm: float = kl_divergence(q_vals, m_vals, dx_array)
        js_raw: float = 0.5 * kl_pm + 0.5 * kl_qm

        # Normalize by ln(2) so the value lies in [0, 1]; float() keeps the
        # result a plain Python float (np.log(2.0) is np.float64).
        js_normalized: float = float(js_raw / np.log(2.0))

        return MetricResult(
            name=self.name,
            value=js_normalized,
            lower_is_better=self.lower_is_better,
            bounds=self.bounds,
            estimator_name="pdf_based",
            details={"js_raw": js_raw},
            meta={},
        )
142
+
143
+
144
class HellingerDistance(MetricBase):
    """Hellinger distance: sqrt(1 - BC) where BC is the Bhattacharyya coefficient."""

    def __init__(self) -> None:
        super().__init__(name="hellinger_distance", lower_is_better=False, bounds=(0.0, 1.0))

    def from_pdfs(
        self,
        p: PDF,
        q: PDF,
        n_grid: int,
        grid: str,
    ) -> MetricResult:
        """Compute the Hellinger distance on the evaluation grid.

        Args:
            p: Density of the first distribution.
            q: Density of the second distribution.
            n_grid: Number of grid points used for integration.
            grid: Grid mode passed through to ``make_grid``.

        Returns:
            ``MetricResult`` with the Hellinger distance as ``value`` and the
            Bhattacharyya coefficient under ``details["bc"]``.
        """
        xs: npt.NDArray[np.float64] = make_grid(n_grid=n_grid, mode=grid)
        # Clamp negative density artefacts before taking the square root.
        p_vals: npt.NDArray[np.float64] = np.maximum(p(xs), 0.0)
        q_vals: npt.NDArray[np.float64] = np.maximum(q(xs), 0.0)

        # Bhattacharyya coefficient BC = ∫ sqrt(p*q) dx (trapezoidal rule).
        sqrt_product: npt.NDArray[np.float64] = np.sqrt(p_vals * q_vals)
        bc: float = float(np.trapezoid(sqrt_product, xs))

        # Hellinger distance; max(0, ...) guards against BC slightly > 1 from
        # integration error. Wrap in float(): np.sqrt returns np.float64,
        # which would otherwise leak into ``value`` despite the annotation.
        hellinger: float = float(np.sqrt(max(0.0, 1.0 - bc)))

        return MetricResult(
            name=self.name,
            value=hellinger,
            lower_is_better=self.lower_is_better,
            bounds=self.bounds,
            estimator_name="pdf_based",
            details={"bc": bc},
            meta={},
        )
178
+
179
+
180
class TotalVariation(MetricBase):
    """Total variation distance: 0.5 * ∫ |p - q| dx."""

    def __init__(self) -> None:
        super().__init__(name="total_variation", lower_is_better=False, bounds=(0.0, 1.0))

    def from_pdfs(
        self,
        p: PDF,
        q: PDF,
        n_grid: int,
        grid: str,
    ) -> MetricResult:
        """Compute the total variation distance on the evaluation grid."""
        grid_pts: npt.NDArray[np.float64] = make_grid(n_grid=n_grid, mode=grid)
        # Clamp any negative density artefacts to zero before integrating.
        dens_p: npt.NDArray[np.float64] = np.maximum(p(grid_pts), 0.0)
        dens_q: npt.NDArray[np.float64] = np.maximum(q(grid_pts), 0.0)

        # TV = 0.5 * ∫ |p - q| dx, integrated with the trapezoidal rule.
        diff_magnitude: npt.NDArray[np.float64] = np.abs(dens_p - dens_q)
        tv: float = 0.5 * float(np.trapezoid(diff_magnitude, grid_pts))

        return MetricResult(
            name=self.name,
            value=tv,
            lower_is_better=self.lower_is_better,
            bounds=self.bounds,
            estimator_name="pdf_based",
            details={},
            meta={},
        )
211
+
212
+
213
class Wasserstein1D(SampleBasedMetric):
    """1D Wasserstein (earth mover's) distance computed directly from samples."""

    def __init__(self) -> None:
        super().__init__(name="wasserstein_1d", lower_is_better=False, bounds=(0.0, float("inf")))

    def from_samples(
        self,
        samples: ScoreSamples,
        estimator: DensityEstimatorBase,
        n_grid: int,
        grid: str,
        weights_pos: npt.NDArray[np.float64] | None = None,
        weights_neg: npt.NDArray[np.float64] | None = None,
        random_state: int | None = None,
    ) -> MetricResult:
        """Compute the Wasserstein distance directly from samples.

        Args:
            samples: Positive/negative score samples to compare.
            estimator: Density estimator (unused here beyond its name; this
                metric works on raw samples).
            n_grid: Unused; kept for the SampleBasedMetric interface.
            grid: Unused; kept for the SampleBasedMetric interface.
            weights_pos: Optional weights for the positive samples.
            weights_neg: Optional weights for the negative samples.
            random_state: Unused; kept for the SampleBasedMetric interface.

        Returns:
            ``MetricResult`` with the distance as ``value`` and sample counts
            plus runtime under ``meta``.
        """
        start_time: float = time.perf_counter()

        # Use scipy's implementation. Wrap results in float(): scipy returns
        # np.float64, and the rest of this module coerces metric values to
        # plain Python floats.
        if weights_pos is not None or weights_neg is not None:
            # scipy 1.9+ supports weights
            try:
                wasserstein_dist: float = float(
                    stats.wasserstein_distance(
                        samples.pos,
                        samples.neg,
                        u_weights=weights_pos,
                        v_weights=weights_neg,
                    )
                )
            except TypeError:
                # Fallback for older scipy versions without weight support:
                # weights are silently ignored.
                wasserstein_dist = float(stats.wasserstein_distance(samples.pos, samples.neg))
        else:
            wasserstein_dist = float(stats.wasserstein_distance(samples.pos, samples.neg))

        runtime_ms: float = (time.perf_counter() - start_time) * 1000.0

        return MetricResult(
            name=self.name,
            value=wasserstein_dist,
            lower_is_better=self.lower_is_better,
            bounds=self.bounds,
            estimator_name=estimator.name,
            details={},
            meta={
                "n_pos": float(len(samples.pos)),
                "n_neg": float(len(samples.neg)),
                "runtime_ms": runtime_ms,
            },
        )
263
+
264
+
265
class KSStatistic(SampleBasedMetric):
    """Two-sample Kolmogorov-Smirnov test statistic."""

    def __init__(self) -> None:
        super().__init__(name="ks_stat", lower_is_better=False, bounds=(0.0, 1.0))

    def from_samples(
        self,
        samples: ScoreSamples,
        estimator: DensityEstimatorBase,
        n_grid: int,
        grid: str,
        weights_pos: npt.NDArray[np.float64] | None = None,
        weights_neg: npt.NDArray[np.float64] | None = None,
        random_state: int | None = None,
    ) -> MetricResult:
        """Compute the KS statistic directly from samples."""
        t0: float = time.perf_counter()

        # scipy's ks_2samp has no weight support; if weights were supplied we
        # deliberately ignore them for now (weighted KS could be added later).
        if weights_pos is not None or weights_neg is not None:
            pass

        # Two-sample KS test over the positive and negative score samples.
        ks_result = stats.ks_2samp(samples.pos, samples.neg)
        statistic_value: float = float(ks_result.statistic)

        elapsed_ms: float = (time.perf_counter() - t0) * 1000.0

        return MetricResult(
            name=self.name,
            value=statistic_value,
            lower_is_better=self.lower_is_better,
            bounds=self.bounds,
            estimator_name=estimator.name,
            details={"p_value": float(ks_result.pvalue)},
            meta={
                "n_pos": float(len(samples.pos)),
                "n_neg": float(len(samples.neg)),
                "runtime_ms": elapsed_ms,
            },
        )