goad-py 0.6.0__cp38-abi3-musllinux_1_2_aarch64.whl → 0.7.0__cp38-abi3-musllinux_1_2_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of goad-py might be problematic. See the package's registry page for more details.

@@ -0,0 +1,499 @@
1
+ """
2
+ Rich-based convergence display for GOAD.
3
+
4
+ This module provides a unified, reusable display system for convergence tracking
5
+ that works with both standard convergence and PHIPS convergence modes.
6
+ """
7
+
8
+ from typing import Callable, Dict, List, Optional, Tuple
9
+
10
+ import numpy as np
11
+ from rich.columns import Columns
12
+ from rich.console import Console, Group
13
+ from rich.live import Live
14
+ from rich.progress import BarColumn, Progress, TextColumn
15
+ from rich.spinner import Spinner
16
+ from rich.text import Text
17
+
18
+
19
class ConvergenceVariable:
    """
    Base class for convergence variables.

    Encapsulates the logic for calculating progress, formatting display values,
    and checking convergence for a single scalar variable.

    The quantity compared against ``tolerance`` is ``sem / |mean|`` when
    ``tolerance_type`` is ``"relative"`` and ``sem`` itself when it is
    ``"absolute"``.
    """

    def __init__(
        self,
        name: str,
        tolerance: float,
        tolerance_type: str = "relative",
    ):
        """
        Initialize a convergence variable.

        Args:
            name: Variable name (e.g., "asymmetry", "S11", "phips_dscs")
            tolerance: Convergence tolerance threshold; must be > 0
            tolerance_type: "relative" or "absolute"

        Raises:
            ValueError: If ``tolerance_type`` is not one of the two supported
                modes, or ``tolerance`` is not a positive number (NaN included).
        """
        if tolerance_type not in {"relative", "absolute"}:
            raise ValueError(f"Invalid tolerance_type '{tolerance_type}'")

        # ``not tolerance > 0`` (rather than ``tolerance <= 0``) also rejects
        # NaN, which would otherwise make every later comparison silently False.
        if not tolerance > 0:
            raise ValueError(f"Tolerance must be positive, got {tolerance}")

        self.name = name
        self.tolerance = tolerance
        self.tolerance_type = tolerance_type

    def _progress_error(self, mean: float, sem: float) -> Optional[float]:
        """
        Return the error measure used for progress, or None if it is undefined.

        In relative mode the error is ``sem / |mean|``; a zero or non-finite
        mean gives no meaningful relative SEM, so None is returned. In
        absolute mode the error is ``sem`` unchanged.
        """
        if self.tolerance_type == "relative":
            if mean == 0 or not np.isfinite(mean):
                return None
            return sem / abs(mean)
        return sem

    def calculate_progress(self, mean: float, sem: float) -> float:
        """
        Calculate progress percentage (0-100) based on current SEM.

        Uses sqrt formula for smoother progression: sqrt(target/current) * 100,
        capped at 100.

        Args:
            mean: Current mean value
            sem: Current SEM value

        Returns:
            Progress percentage (0-100). 0 when the error is undefined (zero
            or non-finite mean in relative mode), NaN, infinite or negative;
            100 when the error is exactly zero.
        """
        error = self._progress_error(mean, sem)
        if error is None or not np.isfinite(error) or error < 0:
            return 0.0
        if error == 0:
            # A zero SEM already satisfies is_converged(); report it as complete
            # instead of 0% (the sqrt formula is undefined at zero).
            return 100.0
        return min(100.0, float(np.sqrt(self.tolerance / error)) * 100.0)

    def format_sem_info(self, mean: float, sem: float) -> str:
        """
        Format SEM info string for display.

        Args:
            mean: Current mean value
            sem: Current SEM value

        Returns:
            Formatted string like "[SEM: 19.7% / 20.0%]" or "[SEM: 0.0042 / 0.0050]".
            In relative mode with a zero mean the raw SEM is shown instead of
            a percentage.
        """
        if self.tolerance_type == "relative":
            if mean != 0:
                relative_sem = sem / abs(mean)
                current_str = f"{relative_sem * 100:.1f}%"
            else:
                current_str = f"{sem:.4g}"
            target_str = f"{self.tolerance * 100:.1f}%"
        else:
            current_str = f"{sem:.4g}"
            target_str = f"{self.tolerance:.4g}"

        return f"[SEM: {current_str} / {target_str}]"

    def is_converged(self, mean: float, sem: float) -> bool:
        """
        Check if this variable has converged.

        Args:
            mean: Current mean value
            sem: Current SEM value

        Returns:
            True if the error (relative or absolute SEM) is strictly below the
            tolerance. A zero mean in relative mode is never converged.
        """
        if self.tolerance_type == "relative":
            if mean != 0:
                relative_sem = sem / abs(mean)
                return relative_sem < self.tolerance
            return False
        else:
            return sem < self.tolerance
+
119
+
120
class ArrayConvergenceVariable(ConvergenceVariable):
    """
    Convergence variable for array data (Mueller elements, PHIPS detectors).

    Tracks convergence across multiple bins/detectors, reporting worst-case
    progress. When ``indices`` is set, only those positions are considered,
    and any index returned by this class refers to a position in that
    *selected* sub-array, not in the full input array.
    """

    def __init__(
        self,
        name: str,
        tolerance: float,
        tolerance_type: str = "relative",
        indices: Optional[List[int]] = None,
    ):
        """
        Initialize an array convergence variable.

        Args:
            name: Variable name
            tolerance: Convergence tolerance threshold
            tolerance_type: "relative" or "absolute"
            indices: Specific indices to check (None = all)
        """
        super().__init__(name, tolerance, tolerance_type)
        self.indices = indices

    def _select(
        self, mean_array: np.ndarray, sem_array: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return the inputs as float arrays restricted to ``self.indices``."""
        means = np.asarray(mean_array, dtype=float)
        sems = np.asarray(sem_array, dtype=float)
        # Only index non-empty input: an empty array combined with explicit
        # indices would raise IndexError instead of reporting "no data".
        if self.indices is not None and means.size > 0:
            means = means[self.indices]
            sems = sems[self.indices]
        return means, sems

    def _errors(self, means: np.ndarray, sems: np.ndarray) -> np.ndarray:
        """
        Per-bin error compared against the tolerance.

        Relative mode uses ``sem / |mean|``; bins with a zero mean have no
        defined relative SEM and are reported as +inf, so they count as
        unconverged and dominate the worst-case search. Absolute mode returns
        the SEM array unchanged.
        """
        if self.tolerance_type == "relative":
            # np.where evaluates both branches, so silence the divide-by-zero /
            # 0/0 warnings produced for the zero-mean bins it then discards.
            with np.errstate(divide="ignore", invalid="ignore"):
                ratios = sems / np.abs(means)
            return np.where(means != 0, ratios, np.inf)
        return sems

    def calculate_progress_array(
        self, mean_array: np.ndarray, sem_array: np.ndarray
    ) -> Tuple[float, int, float]:
        """
        Calculate progress for array data based on worst-case bin/detector.

        Args:
            mean_array: Array of mean values
            sem_array: Array of SEM values

        Returns:
            Tuple of (progress_percentage, worst_index, worst_sem).
            ``worst_index`` is a position in the ``indices``-selected
            sub-array (or the full array when ``indices`` is None) and
            ``worst_sem`` is already relative in relative mode. For empty
            input (or an empty selection) returns ``(0.0, -1, inf)``.
        """
        means, sems = self._select(mean_array, sem_array)
        if means.size == 0:
            return 0.0, -1, np.inf

        errors = self._errors(means, sems)
        # np.argmax picks the first NaN if any, which then yields 0% below.
        worst_idx = int(np.argmax(errors))
        worst_sem = float(errors.flat[worst_idx])

        if not np.isfinite(worst_sem) or worst_sem < 0:
            progress = 0.0
        elif worst_sem == 0:
            # Every selected bin has zero error, i.e. all are converged
            # (matching count_converged); avoid the 0% fall-through.
            progress = 100.0
        else:
            progress = min(100.0, float(np.sqrt(self.tolerance / worst_sem)) * 100.0)

        return progress, worst_idx, worst_sem

    def format_sem_info_array(self, worst_sem: float) -> str:
        """
        Format SEM info for array data.

        Args:
            worst_sem: Worst-case SEM (already relative if tolerance_type is relative)

        Returns:
            Formatted SEM string, e.g. "[SEM: 3.25% / 5.0%]"
        """
        if self.tolerance_type == "relative":
            current_str = f"{worst_sem * 100:.2f}%"
            target_str = f"{self.tolerance * 100:.1f}%"
        else:
            current_str = f"{worst_sem:.4g}"
            target_str = f"{self.tolerance:.4g}"

        return f"[SEM: {current_str} / {target_str}]"

    def count_converged(
        self, mean_array: np.ndarray, sem_array: np.ndarray
    ) -> Tuple[int, int]:
        """
        Count how many bins/detectors have converged.

        Args:
            mean_array: Array of mean values
            sem_array: Array of SEM values

        Returns:
            Tuple of (converged_count, total_count) over the selected bins.
        """
        means, sems = self._select(mean_array, sem_array)
        if means.size == 0:
            return 0, 0

        converged_mask = self._errors(means, sems) < self.tolerance
        return int(np.count_nonzero(converged_mask)), int(means.size)

    def is_converged_array(self, mean_array: np.ndarray, sem_array: np.ndarray) -> bool:
        """
        Check if all bins/detectors have converged.

        Args:
            mean_array: Array of mean values
            sem_array: Array of SEM values

        Returns:
            True if every selected bin has converged. False when there are no
            bins to check (no data is not convergence).
        """
        converged_count, total_count = self.count_converged(mean_array, sem_array)
        return total_count > 0 and converged_count == total_count
253
+
254
+
255
class ConvergenceDisplay:
    """
    Rich-based display for convergence progress.

    Provides a unified, reusable display system that works with different
    convergence modes (standard, PHIPS, etc.). ``build_display`` returns a
    renderable snapshot of the current state; ``create_live_context`` returns
    a Rich ``Live`` that can show it.
    """

    # Character width of each variable's progress bar. Shared by the Rich
    # BarColumn and the hand-rendered bar so the two cannot drift apart.
    _BAR_WIDTH = 25

    def __init__(
        self,
        variables: List[ConvergenceVariable],
        batch_size: int,
        min_batches: int,
        convergence_type: str = "standard",
        console: Optional[Console] = None,
    ):
        """
        Initialize convergence display.

        Args:
            variables: List of convergence variables to track
            batch_size: Number of orientations per batch
            min_batches: Minimum batches before convergence check
            convergence_type: Display string for convergence mode
            console: Optional Rich console (creates one if None)
        """
        self.variables = variables
        self.batch_size = batch_size
        self.min_batches = min_batches
        self.convergence_type = convergence_type
        self._console = console or Console()

        # Progress tracking (created lazily on the first build_display call)
        self._progress: Optional[Progress] = None
        self._progress_tasks: Dict[str, int] = {}

        # Reused across rebuilds so the title spinner animates continuously
        # instead of restarting from its first frame every time the display
        # is rebuilt.
        self._spinner = Spinner("aesthetic", style="cyan")

    def _initialize_progress(self):
        """Initialize Rich Progress instance and task IDs (idempotent)."""
        if self._progress is not None:
            return

        self._progress = Progress(
            TextColumn("[bold]{task.fields[variable]:<12}"),
            BarColumn(bar_width=self._BAR_WIDTH),
            TextColumn("[bold]{task.percentage:>3.0f}%"),
            TextColumn("[cyan]{task.fields[sem_info]}"),
            console=self._console,
            transient=False,
        )

        # One task per variable, keyed by variable name
        for var in self.variables:
            task_id = self._progress.add_task(
                "",
                total=100,
                variable=var.name,
                sem_info="[SEM: -- / --]",
            )
            self._progress_tasks[var.name] = task_id

    def build_display(
        self,
        iteration: int,
        n_orientations: int,
        get_stats: Callable[[str], Tuple[float, float]],
        get_array_stats: Optional[
            Callable[[str], Tuple[np.ndarray, np.ndarray]]
        ] = None,
        get_bin_labels: Optional[Callable[[str], np.ndarray]] = None,
        power_ratio: Optional[float] = None,
        geom_info: Optional[str] = None,
    ) -> Group:
        """
        Build rich display for current convergence state.

        Args:
            iteration: Current batch iteration number
            n_orientations: Total number of orientations processed
            get_stats: Callback to get (mean, sem) for a variable name
            get_array_stats: Optional callback to get (mean_array, sem_array) for array variables
            get_bin_labels: Optional callback to get bin labels (e.g., theta angles) for array variables
            power_ratio: Optional power ratio from solver
            geom_info: Optional geometry info (for ensemble mode)

        Returns:
            Rich Group containing the full display

        Raises:
            ValueError: If an array variable is tracked but ``get_array_stats``
                is None.
        """
        self._initialize_progress()

        separator = Text("━" * 70)
        title = self._build_title()
        header = self._build_header(iteration, n_orientations, power_ratio, geom_info)
        progress_lines = [
            self._build_variable_row(var, get_stats, get_array_stats, get_bin_labels)
            for var in self.variables
        ]

        return Group(
            separator,
            title,
            Text(""),  # Blank line
            header,
            separator,
            *progress_lines,
        )

    def _build_title(self) -> Columns:
        """Title line: mode label followed by the (persistent) spinner."""
        title_text = Text.assemble(
            ("GOAD: ", "bold cyan"),
            (f"[Convergence: {self.convergence_type}] ", "bold white"),
        )
        return Columns([title_text, self._spinner], expand=False, padding=(0, 1))

    def _build_header(
        self,
        iteration: int,
        n_orientations: int,
        power_ratio: Optional[float],
        geom_info: Optional[str],
    ) -> Text:
        """Header line: batch counter, orientation count, power ratio, geometry."""
        min_required = self.min_batches * self.batch_size

        batch_str = f"[Batch: {iteration}/{self.min_batches}]"

        # Orientation count turns green once the minimum has been reached
        orient_color = "green" if n_orientations >= min_required else "red"
        orient_text = Text(
            f"[Orient: {n_orientations}/{min_required} ({self.batch_size})]",
            style=orient_color,
        )

        if power_ratio is not None:
            power_color = self._get_power_ratio_color(power_ratio)
            power_text = Text(f"Power: {power_ratio:.3f}", style=power_color)
        else:
            power_text = Text("Power: N/A")

        header_parts = [batch_str, " ", orient_text, " ", power_text]
        if geom_info:
            header_parts.extend([" ", (f"[{geom_info}]", "dim")])

        return Text.assemble(*header_parts)

    def _build_variable_row(
        self,
        var: ConvergenceVariable,
        get_stats: Callable[[str], Tuple[float, float]],
        get_array_stats: Optional[Callable[[str], Tuple[np.ndarray, np.ndarray]]],
        get_bin_labels: Optional[Callable[[str], np.ndarray]],
    ) -> Text:
        """Update the variable's progress task and render its display line."""
        if isinstance(var, ArrayConvergenceVariable):
            mean_str, progress, sem_info = self._array_summary(
                var, get_array_stats, get_bin_labels
            )
        else:
            mean_str, progress, sem_info = self._scalar_summary(var, get_stats)

        task_id = self._progress_tasks[var.name]
        self._progress.update(task_id, completed=progress, sem_info=sem_info)

        # Rich clamps task.percentage to [0, 100]; draw the bar from it.
        task = self._progress.tasks[task_id]
        filled = int((task.percentage / 100) * self._BAR_WIDTH)
        bar = "█" * filled + "░" * (self._BAR_WIDTH - filled)

        variable_text = f"{task.fields['variable']:<12}"
        mean_display = f"{mean_str:>35}"
        progress_text = f"{int(progress)}%"

        return Text.assemble(
            f"{variable_text} ",
            (mean_display, "bold green"),
            f" [{bar}] {progress_text:>3} ",
            (sem_info, "cyan"),
        )

    def _array_summary(
        self,
        var: ArrayConvergenceVariable,
        get_array_stats: Optional[Callable[[str], Tuple[np.ndarray, np.ndarray]]],
        get_bin_labels: Optional[Callable[[str], np.ndarray]],
    ) -> Tuple[str, float, str]:
        """Return (summary text, progress %, SEM text) for an array variable."""
        if get_array_stats is None:
            raise ValueError(
                f"get_array_stats callback required for array variable '{var.name}'"
            )

        mean_array, sem_array = get_array_stats(var.name)
        if len(mean_array) == 0:
            return "--", 0.0, "[SEM: -- / --]"

        progress, worst_idx, worst_sem = var.calculate_progress_array(
            mean_array, sem_array
        )
        sem_info = var.format_sem_info_array(worst_sem)
        converged_count, total_count = var.count_converged(mean_array, sem_array)
        counts = f"({converged_count}/{total_count})"

        if worst_idx < 0:
            # No bin was selected (e.g. an empty ``indices`` list).
            return f"{counts} Worst: --", progress, sem_info

        # worst_idx is a position in the indices-filtered sub-array (the
        # variable applies ``indices`` itself, so get_array_stats returns the
        # full array). Translate back so the label/number refers to the same
        # bin. NOTE(review): assumes get_bin_labels is aligned with the full
        # get_array_stats output.
        bin_idx = worst_idx
        if var.indices is not None:
            bin_idx = var.indices[worst_idx]

        if get_bin_labels is not None:
            worst_label = get_bin_labels(var.name)[bin_idx]
            return f"{counts} Worst: θ={worst_label:.0f}°", progress, sem_info
        return f"{counts} Worst: #{bin_idx}", progress, sem_info

    def _scalar_summary(
        self,
        var: ConvergenceVariable,
        get_stats: Callable[[str], Tuple[float, float]],
    ) -> Tuple[str, float, str]:
        """Return (mean text, progress %, SEM text) for a scalar variable."""
        mean, sem = get_stats(var.name)
        progress = var.calculate_progress(mean, sem)
        sem_info = var.format_sem_info(mean, sem)
        # Non-finite means (inf or NaN) have no useful value to print.
        mean_str = f"{mean:.4f}" if np.isfinite(mean) else "--"
        return mean_str, progress, sem_info

    def _get_power_ratio_color(self, power_ratio: float) -> str:
        """Get color for power ratio: green >= 0.99, yellow >= 0.95, else red."""
        if power_ratio >= 0.99:
            return "green"
        elif power_ratio >= 0.95:
            return "yellow"
        else:
            return "red"

    def create_live_context(self, refresh_per_second: float = 1.3) -> Live:
        """
        Create a Rich Live context for auto-updating display.

        Args:
            refresh_per_second: Refresh rate (default: 1.3 fps for smooth animation)

        Returns:
            Rich Live context manager bound to this display's console
        """
        return Live(
            console=self._console,
            refresh_per_second=refresh_per_second,
            transient=False,
        )