goad-py 0.8.5__pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,532 @@
1
+ """
2
+ Rich-based convergence display for GOAD.
3
+
4
+ This module provides a unified, reusable display system for convergence tracking
5
+ that works with both standard convergence and PHIPS convergence modes.
6
+ """
7
+
8
+ from typing import Callable, Dict, List, Optional, Tuple
9
+
10
+ import numpy as np
11
+ from rich.columns import Columns
12
+ from rich.console import Console, Group
13
+ from rich.live import Live
14
+ from rich.progress import BarColumn, Progress, TextColumn
15
+ from rich.spinner import Spinner
16
+ from rich.text import Text
17
+
18
+
19
class ConvergenceVariable:
    """
    Base class for convergence variables.

    Encapsulates the logic for calculating progress, formatting display values,
    and checking convergence for a single variable.

    The quantity compared against ``tolerance`` (the "error metric") is
    ``sem / |mean|`` when ``tolerance_type == "relative"`` and ``sem`` when
    ``tolerance_type == "absolute"``. In relative mode a zero or non-finite
    mean yields an infinite metric, i.e. the variable is never converged.
    """

    def __init__(
        self,
        name: str,
        tolerance: float,
        tolerance_type: str = "relative",
    ):
        """
        Initialize a convergence variable.

        Args:
            name: Variable name (e.g., "asymmetry", "S11", "phips_dscs")
            tolerance: Convergence tolerance threshold; must be strictly
                positive (NaN is rejected)
            tolerance_type: "relative" or "absolute"

        Raises:
            ValueError: If ``tolerance_type`` is not "relative"/"absolute" or
                ``tolerance`` is not strictly positive.
        """
        self.name = name
        self.tolerance = tolerance
        self.tolerance_type = tolerance_type

        if tolerance_type not in {"relative", "absolute"}:
            raise ValueError(f"Invalid tolerance_type '{tolerance_type}'")

        # Written as ``not (tolerance > 0)`` instead of ``tolerance <= 0`` so
        # NaN (for which every comparison is False) is rejected as well;
        # a NaN tolerance would otherwise make convergence impossible.
        if not (tolerance > 0):
            raise ValueError(f"Tolerance must be positive, got {tolerance}")

    def _error_metric(self, mean: float, sem: float) -> float:
        """
        Return the value that is compared against ``self.tolerance``.

        Relative mode: ``sem / |mean|``, or ``inf`` when the mean is zero or
        non-finite (an infinite mean would otherwise give a spurious 0 and
        look converged). Absolute mode: ``sem`` unchanged.
        """
        if self.tolerance_type == "relative":
            if mean == 0 or not np.isfinite(mean):
                return float("inf")
            return sem / abs(mean)
        return sem

    def calculate_progress(self, mean: float, sem: float) -> float:
        """
        Calculate progress percentage (0-100) based on current SEM.

        Uses sqrt formula for smoother progression: sqrt(target/current) * 100

        Args:
            mean: Current mean value
            sem: Current SEM value

        Returns:
            Progress percentage (0-100); 0 when the error metric is
            non-positive, NaN or infinite.
        """
        error = self._error_metric(mean, sem)
        # NOTE(review): an exact-zero metric reports 0% even though
        # is_converged() treats it as converged; kept as-is because a zero
        # SEM may signal too few samples — confirm the intended behaviour.
        if error > 0 and not np.isinf(error):
            return min(100.0, np.sqrt(self.tolerance / error) * 100.0)
        return 0.0

    def format_sem_info(self, mean: float, sem: float) -> str:
        """
        Format SEM info string for display.

        Args:
            mean: Current mean value
            sem: Current SEM value

        Returns:
            Formatted string like "[SEM: 19.7% / 20.0%]" or "[SEM: 0.0042 / 0.0050]".
            In relative mode with a zero mean the raw SEM is shown, since no
            percentage can be formed.
        """
        if self.tolerance_type == "relative":
            if mean != 0:
                relative_sem = sem / abs(mean)
                current_str = f"{relative_sem * 100:.1f}%"
            else:
                current_str = f"{sem:.4g}"
            target_str = f"{self.tolerance * 100:.1f}%"
        else:
            current_str = f"{sem:.4g}"
            target_str = f"{self.tolerance:.4g}"

        return f"[SEM: {current_str} / {target_str}]"

    def is_converged(self, mean: float, sem: float) -> bool:
        """
        Check if this variable has converged.

        Args:
            mean: Current mean value
            sem: Current SEM value

        Returns:
            True if the error metric is strictly below the tolerance,
            False otherwise (always False for a zero/non-finite mean in
            relative mode).
        """
        return self._error_metric(mean, sem) < self.tolerance
118
+
119
+
120
class ArrayConvergenceVariable(ConvergenceVariable):
    """
    Convergence variable for array data (Mueller elements, PHIPS detectors).

    Tracks convergence across multiple bins/detectors, reporting worst-case progress.
    Arrays are assumed to be 1-D with one entry per bin/detector (TODO confirm
    against callers); indices reported back always refer to positions in the
    caller's (unfiltered) array.
    """

    def __init__(
        self,
        name: str,
        tolerance: float,
        tolerance_type: str = "relative",
        indices: Optional[List[int]] = None,
    ):
        """
        Initialize an array convergence variable.

        Args:
            name: Variable name
            tolerance: Convergence tolerance threshold
            tolerance_type: "relative" or "absolute"
            indices: Specific indices to check (None = all)
        """
        super().__init__(name, tolerance, tolerance_type)
        self.indices = indices

    def _error_array(
        self, mean_array: np.ndarray, sem_array: np.ndarray
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Compute per-bin error metrics, restricted to ``self.indices`` if set.

        Returns:
            Tuple of (errors, positions). ``errors`` is ``sem / |mean|``
            (``inf`` where mean == 0) in relative mode, or ``sem`` in absolute
            mode. ``positions`` maps each entry of ``errors`` back to its index
            in the caller's array, or is None when no index filter applies.
        """
        mean = np.asarray(mean_array)
        sem = np.asarray(sem_array)
        positions = None
        if self.indices is not None:
            positions = np.arange(len(mean))[self.indices]
            mean = mean[self.indices]
            sem = sem[self.indices]

        if self.tolerance_type != "relative":
            return sem, positions

        abs_mean = np.abs(mean)
        # np.where evaluates both branches, so silence the divide-by-zero and
        # 0/0 warnings for bins whose quotient is discarded anyway.
        with np.errstate(divide="ignore", invalid="ignore"):
            errors = np.where(abs_mean != 0, sem / abs_mean, np.inf)
        return errors, positions

    def calculate_progress_array(
        self, mean_array: np.ndarray, sem_array: np.ndarray
    ) -> Tuple[float, int, float]:
        """
        Calculate progress for array data based on worst-case bin/detector.

        Args:
            mean_array: Array of mean values
            sem_array: Array of SEM values

        Returns:
            Tuple of (progress_percentage, worst_index, worst_sem). ``worst_index``
            is a position in the caller's unfiltered array (-1 when there is no
            data); ``worst_sem`` is relative when tolerance_type is relative.
        """
        if len(mean_array) == 0:
            return 0.0, -1, np.inf

        errors, positions = self._error_array(mean_array, sem_array)
        if errors.size == 0:
            # e.g. an empty ``indices`` selection: nothing to rank.
            return 0.0, -1, np.inf

        # np.argmax returns the first NaN if any bin is NaN, so bins with an
        # undefined metric are reported as the worst case.
        local_idx = int(np.argmax(errors))
        worst_sem = errors.flat[local_idx]
        # Map back to the caller's indexing so labels/bin numbers line up.
        worst_idx = local_idx if positions is None else int(positions[local_idx])

        # Calculate progress using worst SEM
        if worst_sem > 0 and not np.isinf(worst_sem):
            progress = min(100.0, np.sqrt(self.tolerance / worst_sem) * 100.0)
        else:
            progress = 0.0

        return progress, worst_idx, worst_sem

    def format_sem_info_array(self, worst_sem: float) -> str:
        """
        Format SEM info for array data.

        Args:
            worst_sem: Worst-case SEM (already relative if tolerance_type is relative)

        Returns:
            Formatted SEM string
        """
        if self.tolerance_type == "relative":
            current_str = f"{worst_sem * 100:.2f}%"
            target_str = f"{self.tolerance * 100:.1f}%"
        else:
            current_str = f"{worst_sem:.4g}"
            target_str = f"{self.tolerance:.4g}"

        return f"[SEM: {current_str} / {target_str}]"

    def count_converged(
        self, mean_array: np.ndarray, sem_array: np.ndarray
    ) -> Tuple[int, int]:
        """
        Count how many bins/detectors have converged.

        Args:
            mean_array: Array of mean values
            sem_array: Array of SEM values

        Returns:
            Tuple of (converged_count, total_count); (0, 0) when there is no data.
        """
        if len(mean_array) == 0:
            return 0, 0

        errors, _ = self._error_array(mean_array, sem_array)
        # NaN metrics compare False and therefore count as not converged.
        converged_count = int(np.count_nonzero(errors < self.tolerance))
        return converged_count, int(errors.size)

    def is_converged_array(self, mean_array: np.ndarray, sem_array: np.ndarray) -> bool:
        """
        Check if all bins/detectors have converged.

        Args:
            mean_array: Array of mean values
            sem_array: Array of SEM values

        Returns:
            True if there is at least one bin and every bin has converged,
            False otherwise.
        """
        converged_count, total_count = self.count_converged(mean_array, sem_array)
        # An empty selection must not count as converged (0 == 0 would be True).
        return total_count > 0 and converged_count == total_count
253
+
254
+
255
class ConvergenceDisplay:
    """
    Rich-based display for convergence progress.

    Provides a unified, reusable display system that works with different
    convergence modes (standard, PHIPS, etc.).

    If ``log_file`` is given, every frame produced by :meth:`build_display`
    is also written to that file. Call :meth:`close` (or use the instance as
    a context manager) to write the footer and release the file handle.
    """

    # Width of the progress bars, shared by the Rich BarColumn and the bar
    # rendered by hand for each variable line.
    _BAR_WIDTH = 25

    def __init__(
        self,
        variables: List[ConvergenceVariable],
        batch_size: int,
        min_batches: int,
        convergence_type: str = "standard",
        console: Optional[Console] = None,
        log_file: Optional[str] = None,
    ):
        """
        Initialize convergence display.

        Args:
            variables: List of convergence variables to track
            batch_size: Number of orientations per batch
            min_batches: Minimum batches before convergence check
            convergence_type: Display string for convergence mode
            console: Optional Rich console (creates one if None)
            log_file: Optional path to log file for convergence progress
                (opened for writing, truncating any existing file)
        """
        self.variables = variables
        self.batch_size = batch_size
        self.min_batches = min_batches
        self.convergence_type = convergence_type
        self._console = console or Console()

        # File logging
        self.log_file = log_file
        self._file_console = None
        self._file_handle = None
        if self.log_file:
            # The log contains non-ASCII glyphs (━, █, ░, θ, °), so pin the
            # encoding rather than depending on the platform default.
            self._file_handle = open(self.log_file, "w", encoding="utf-8")
            try:
                # Separate console so file output has a fixed width.
                self._file_console = Console(file=self._file_handle, width=120)
                # Write header
                self._file_console.print(f"GOAD Convergence Log - {convergence_type}")
                self._file_console.print("=" * 120)
            except Exception:
                # Don't leak the file handle if console setup fails.
                self._file_handle.close()
                self._file_handle = None
                self._file_console = None
                raise

        # Progress tracking
        self._progress: Optional[Progress] = None
        self._progress_tasks: Dict[str, int] = {}

    def __enter__(self) -> "ConvergenceDisplay":
        """Return self so the display can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the log file on exit; exceptions are not suppressed."""
        self.close()

    def _initialize_progress(self):
        """Initialize Rich Progress instance and task IDs (first call only)."""
        if self._progress is not None:
            return

        self._progress = Progress(
            TextColumn("[bold]{task.fields[variable]:<12}"),
            BarColumn(bar_width=self._BAR_WIDTH),
            TextColumn("[bold]{task.percentage:>3.0f}%"),
            TextColumn("[cyan]{task.fields[sem_info]}"),
            console=self._console,
            transient=False,
        )

        # Add tasks for each variable
        for var in self.variables:
            task_id = self._progress.add_task(
                "",
                total=100,
                variable=var.name,
                sem_info="[SEM: -- / --]",
            )
            self._progress_tasks[var.name] = task_id

    def build_display(
        self,
        iteration: int,
        n_orientations: int,
        get_stats: Callable[[str], Tuple[float, float]],
        get_array_stats: Optional[
            Callable[[str], Tuple[np.ndarray, np.ndarray]]
        ] = None,
        get_bin_labels: Optional[Callable[[str], np.ndarray]] = None,
        power_ratio: Optional[float] = None,
        geom_info: Optional[str] = None,
    ) -> Group:
        """
        Build rich display for current convergence state.

        Also updates the internal progress tasks and, when a log file is
        configured, appends the rendered frame to it.

        Args:
            iteration: Current batch iteration number
            n_orientations: Total number of orientations processed
            get_stats: Callback to get (mean, sem) for a variable name
            get_array_stats: Optional callback to get (mean_array, sem_array) for array variables
            get_bin_labels: Optional callback to get bin labels (e.g., theta angles) for array variables
            power_ratio: Optional power ratio from solver
            geom_info: Optional geometry info (for ensemble mode)

        Returns:
            Rich Group containing the full display

        Raises:
            ValueError: If an array variable is tracked but ``get_array_stats`` is None.
        """
        # Initialize progress on first call
        self._initialize_progress()

        title = self._build_title()
        header = self._build_header(iteration, n_orientations, power_ratio, geom_info)
        separator = Text("━" * 70)

        progress_lines = [
            self._build_variable_line(var, get_stats, get_array_stats, get_bin_labels)
            for var in self.variables
        ]

        display = Group(
            separator,
            title,
            Text(""),  # Blank line
            header,
            separator,
            *progress_lines,
        )

        # Log to file if enabled
        if self._file_console is not None:
            self._file_console.print(display)

        return display

    def _build_title(self) -> Columns:
        """Build the title row: the mode label followed by an inline spinner."""
        spinner = Spinner("aesthetic", style="cyan")
        title_text = Text.assemble(
            ("GOAD: ", "bold cyan"),
            (f"[Convergence: {self.convergence_type}] ", "bold white"),
        )
        return Columns([title_text, spinner], expand=False, padding=(0, 1))

    def _build_header(
        self,
        iteration: int,
        n_orientations: int,
        power_ratio: Optional[float],
        geom_info: Optional[str],
    ) -> Text:
        """Build the batch / orientation / power-ratio header line."""
        min_required = self.min_batches * self.batch_size
        batch_str = f"[Batch: {iteration}/{self.min_batches}]"

        # Orientation count turns green once the minimum has been reached.
        orient_color = "green" if n_orientations >= min_required else "red"
        orient_text = Text(
            f"[Orient: {n_orientations}/{min_required} ({self.batch_size})]",
            style=orient_color,
        )

        if power_ratio is not None:
            power_text = Text(
                f"Power: {power_ratio:.3f}",
                style=self._get_power_ratio_color(power_ratio),
            )
        else:
            power_text = Text("Power: N/A")

        header_parts = [batch_str, " ", orient_text, " ", power_text]
        if geom_info:
            header_parts.extend([" ", (f"[{geom_info}]", "dim")])
        return Text.assemble(*header_parts)

    @staticmethod
    def _format_bin_label(label) -> str:
        """Format a worst-bin label as ``θ=<value>°`` (numeric labels rounded)."""
        try:
            return f"θ={label:.0f}°"
        except (TypeError, ValueError):
            # Non-numeric labels (e.g. strings) cannot take the ``.0f`` spec.
            return f"θ={label}°"

    def _build_variable_line(
        self,
        var: ConvergenceVariable,
        get_stats: Callable[[str], Tuple[float, float]],
        get_array_stats: Optional[Callable[[str], Tuple[np.ndarray, np.ndarray]]],
        get_bin_labels: Optional[Callable[[str], np.ndarray]],
    ) -> Text:
        """Update ``var``'s progress task and render its one-line summary."""
        task_id = self._progress_tasks[var.name]

        if isinstance(var, ArrayConvergenceVariable):
            # Array variable (Mueller element, PHIPS detector)
            if get_array_stats is None:
                raise ValueError(
                    f"get_array_stats callback required for array variable '{var.name}'"
                )

            mean_array, sem_array = get_array_stats(var.name)

            if len(mean_array) == 0:
                progress = 0.0
                mean_str = "--"
                sem_info = "[SEM: -- / --]"
            else:
                # Calculate progress and find worst bin
                progress, worst_idx, worst_sem = var.calculate_progress_array(
                    mean_array, sem_array
                )
                sem_info = var.format_sem_info_array(worst_sem)
                converged_count, total_count = var.count_converged(
                    mean_array, sem_array
                )
                counts = f"({converged_count}/{total_count})"
                if get_bin_labels is not None:
                    worst_label = get_bin_labels(var.name)[worst_idx]
                    mean_str = f"{counts} Worst: {self._format_bin_label(worst_label)}"
                else:
                    mean_str = f"{counts} Worst: #{worst_idx}"
        else:
            # Scalar variable
            mean, sem = get_stats(var.name)
            progress = var.calculate_progress(mean, sem)
            sem_info = var.format_sem_info(mean, sem)
            # An infinite mean is rendered as "--" (presumably "no data yet").
            mean_str = "--" if np.isinf(mean) else f"{mean:.4f}"

        self._progress.update(task_id, completed=progress, sem_info=sem_info)

        # Render the bar directly from the clamped percentage (the same
        # clamping Rich applies to task.percentage with total=100).
        percentage = min(100.0, max(0.0, float(progress)))
        filled = int((percentage / 100) * self._BAR_WIDTH)
        bar = "█" * filled + "░" * (self._BAR_WIDTH - filled)
        progress_text = f"{int(progress)}%"

        return Text.assemble(
            f"{var.name:<12} ",
            (f"{mean_str:>35}", "bold green"),
            f" [{bar}] {progress_text:>3} ",
            (sem_info, "cyan"),
        )

    def _get_power_ratio_color(self, power_ratio: float) -> str:
        """Get color for power ratio: green >= 0.99, yellow >= 0.95, else red."""
        if power_ratio >= 0.99:
            return "green"
        elif power_ratio >= 0.95:
            return "yellow"
        else:
            return "red"

    def create_live_context(self, refresh_per_second: float = 1.3) -> Live:
        """
        Create a Rich Live context for auto-updating display.

        Args:
            refresh_per_second: Refresh rate (default: 1.3 fps for smooth animation)

        Returns:
            Rich Live context manager
        """
        return Live(
            console=self._console,
            refresh_per_second=refresh_per_second,
            transient=False,
        )

    def close(self):
        """
        Write the log footer and close the log file, if one is open.

        Safe to call repeatedly; the file is closed even if writing the
        footer fails.
        """
        # getattr: __del__ may run on an instance whose __init__ failed early.
        handle = getattr(self, "_file_handle", None)
        if handle is None:
            return
        file_console = self._file_console
        # Detach first so a failure below cannot trigger a second attempt
        # on a half-closed file (e.g. from __del__).
        self._file_handle = None
        self._file_console = None
        try:
            if file_console is not None:
                file_console.print("\n" + "=" * 120)
                file_console.print("End of convergence log")
        finally:
            handle.close()

    def __del__(self):
        """Best-effort cleanup when the object is garbage-collected."""
        self.close()