reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,268 @@
1
+ import ipywidgets as W
2
+ from IPython.display import display, HTML
3
+ from typing import List, Dict, Tuple, Callable, Optional, Any
4
+
5
+
6
class CustomSelect:
    """
    A styled select widget for displaying tabular data with monospace formatting.

    Features:
    - Clean, modern styling with hover effects
    - Automatic column width calculation
    - Header display
    - Click callback support
    - Easy data updates
    """

    # Class-wide flag: the <style> block only needs to be displayed once, so
    # every instance after the first skips the injection.
    _css_injected = False

    # Text placed between padded columns. Shared by the header and the rows so
    # the two always line up.
    _column_separator = " "

    @classmethod
    def _inject_css(cls):
        """Display the widget stylesheet the first time any instance is created."""
        if not cls._css_injected:
            display(HTML("""
            <style>
            .custom-select-container {
                font-family: 'SF Mono', 'Monaco', 'Consolas', monospace;
            }

            .custom-select-header {
                margin: 0;
                padding: 8px 12px;
                font-family: 'SF Mono', 'Monaco', 'Consolas', monospace;
                font-weight: 600;
                font-size: 13px;
                background: linear-gradient(to bottom, #f8f9fa 0%, #e9ecef 100%);
                border: 1px solid #dee2e6;
                border-bottom: 2px solid #adb5bd;
                border-radius: 6px 6px 0 0;
                color: #495057;
                letter-spacing: 0.3px;
            }

            .custom-select select {
                font-family: 'SF Mono', 'Monaco', 'Consolas', monospace;
                font-size: 13px;
                background-color: #ffffff;
                border: 1px solid #dee2e6;
                border-top: none;
                border-radius: 0 0 6px 6px;
                padding: 6px 12px;
                color: #212529;
                line-height: 1.6;
                transition: all 0.2s ease;
            }

            .custom-select select:hover {
                background-color: #f8f9fa;
            }

            .custom-select select:focus {
                outline: none;
                border-color: #0d6efd;
                box-shadow: 0 0 0 3px rgba(13, 110, 253, 0.1);
                background-color: #ffffff;
            }

            .custom-select select option {
                padding: 4px 8px;
            }

            .custom-select select option:hover {
                background-color: #e7f1ff;
            }

            .custom-select-details {
                margin-top: 12px;
                padding: 12px 16px;
                background-color: #f8f9fa;
                border: 1px solid #dee2e6;
                border-radius: 6px;
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
                font-size: 14px;
                color: #495057;
                box-shadow: 0 1px 3px rgba(0,0,0,0.05);
            }

            .custom-select-details i {
                color: #6c757d;
            }

            .custom-select-details b {
                color: #0d6efd;
                font-weight: 600;
            }
            </style>
            """))
            cls._css_injected = True

    def __init__(
        self,
        data: List[Dict[str, Any]],
        columns: List[Tuple[str, str]],
        width: str = "auto",
        max_rows: int = 10,
        show_details: bool = False,
        details_template: Optional[str] = None,
        layout: Optional[W.Layout] = None,
    ):
        """
        Initialize the CustomSelect widget.

        Args:
            data: List of dictionaries containing row data
            columns: List of tuples (header_label, data_key)
            width: CSS width for the select widget (e.g., '600px', 'auto')
            max_rows: Maximum number of visible rows
            show_details: Whether to show the details panel below
            details_template: Custom HTML template for details; it is filled
                with ``details_template.format(**row)``
            layout: Layout for the VBox containing the widget
        """
        self._inject_css()

        self.data = data
        self.columns = columns
        self.width = width
        self.max_rows = max_rows
        self.show_details = show_details
        self.details_template = details_template
        self._callback: Optional[Callable] = None

        # Child widgets: header (<pre> HTML), the list itself, optional details panel.
        self._header_widget = W.HTML()
        self._select_widget = W.Select(layout=W.Layout(width=width))
        self._select_widget.add_class("custom-select")

        self._details_widget = W.HTML("<i>Select a row to view details…</i>")
        self._details_widget.add_class("custom-select-details")

        self._select_widget.observe(self._on_selection_change, names="value")

        # Populate header and options from the initial data.
        self._update_display()

        widgets = [self._header_widget, self._select_widget]
        if self.show_details:
            widgets.append(self._details_widget)

        if layout is not None:
            self.container = W.VBox(widgets, layout=layout)
        else:
            self.container = W.VBox(widgets)

        self.container.add_class("custom-select-container")

    def _calculate_column_widths(self) -> Dict[str, int]:
        """Return, per data key, the widest of its header label and its cell values."""
        widths = {}
        for header, key in self.columns:
            header_len = len(str(header))
            if self.data:
                max_data_len = max(len(str(row.get(key, ""))) for row in self.data)
                widths[key] = max(header_len, max_data_len)
            else:
                widths[key] = header_len
        return widths

    def _format_row(self, row: Dict[str, Any], widths: Dict[str, int]) -> str:
        """Render one row as left-justified, separator-joined column cells."""
        parts = [str(row.get(key, "")).ljust(widths[key]) for _, key in self.columns]
        return self._column_separator.join(parts)

    def _update_display(self):
        """Update the header and select options based on current data."""
        import html  # only needed to escape the header text

        if not self.data:
            self._header_widget.value = "<pre class='custom-select-header'><i>No data</i></pre>"
            self._select_widget.options = []
            return

        widths = self._calculate_column_widths()

        # Labels go through str() exactly as in _calculate_column_widths so
        # non-str labels are padded consistently. The text is HTML-escaped so
        # '<' or '&' cannot break the <pre> markup; escaping does not change
        # the rendered width, so alignment with the rows is preserved.
        header_text = self._column_separator.join(
            str(label).ljust(widths[key]) for label, key in self.columns
        )
        self._header_widget.value = (
            f"<pre class='custom-select-header'>{html.escape(header_text, quote=False)}</pre>"
        )

        # Each option is (display label, row dict); the row dict becomes the
        # widget's value when selected.
        self._select_widget.options = [(self._format_row(row, widths), row) for row in self.data]
        self._select_widget.rows = min(self.max_rows, len(self.data))

    def _on_selection_change(self, change):
        """Refresh the details panel and invoke the user callback for a new selection."""
        if change["name"] == "value" and change["new"] is not None:
            row = change["new"]

            if self.show_details:
                if self.details_template:
                    self._details_widget.value = self.details_template.format(**row)
                else:
                    details_parts = [f"<b>{k}:</b> {v}" for k, v in row.items()]
                    self._details_widget.value = " | ".join(details_parts)

            if self._callback:
                self._callback(row)

    def set_data(self, data: List[Dict[str, Any]]):
        """
        Update the data displayed in the select widget.

        Args:
            data: New list of dictionaries containing row data
        """
        # Store a copy of exactly the rows given. Adding no placeholder row
        # keeps the visible rows, set_selected_index bounds and __repr__ in
        # agreement with the caller's data.
        self.data = list(data)
        self._update_display()
        if self.show_details:
            self._details_widget.value = "<i>Select a row to view details…</i>"

    def set_columns(self, columns: List[Tuple[str, str]]):
        """
        Update the columns definition.

        Args:
            columns: New list of tuples (header_label, data_key)
        """
        self.columns = columns
        self._update_display()

    def on_select(self, callback: Callable[[Dict[str, Any]], None]):
        """
        Register a callback function to be called when a row is selected.

        Args:
            callback: Function that receives the selected row dictionary

        Returns:
            self, so calls can be chained.
        """
        self._callback = callback
        return self

    def get_selected(self) -> Optional[Dict[str, Any]]:
        """Get the currently selected row, or None if nothing is selected."""
        return self._select_widget.value

    def set_selected_index(self, index: int):
        """Select a row by index; out-of-range indices are ignored."""
        if 0 <= index < len(self.data):
            self._select_widget.value = self.data[index]

    def clear_selection(self):
        """Clear the current selection and reset the details panel."""
        self._select_widget.value = None
        if self.show_details:
            self._details_widget.value = "<i>Select a row to view details…</i>"

    def display(self):
        """Display the widget."""
        display(self.container)

    def __repr__(self):
        return f"CustomSelect(rows={len(self.data)}, columns={len(self.columns)})"
@@ -0,0 +1,241 @@
1
+ """
2
+ Log Widget for Jupyter Interfaces
3
+
4
+ A clean, reusable log widget with print redirection capabilities.
5
+ """
6
+
7
+ import sys
8
+ import contextlib
9
+ from io import StringIO
10
+ from typing import Optional
11
+ import ipywidgets as widgets
12
+
13
+
14
class LogWidget:
    """
    A clean log widget with print redirection capabilities.

    Features:
    - Hidden by default, auto-shows when messages arrive
    - Clear and toggle visibility controls
    - Context manager for print redirection
    - Clean, professional styling

    Example:
    ```python
    log_widget = LogWidget()

    # Use in layout
    layout = widgets.VBox([
        main_content,
        log_widget.widget
    ])

    # Redirect prints
    with log_widget.capture_prints():
        print("This goes to the log widget")

    # Or use the convenience method
    log_widget.log("Direct message to log")
    ```
    """

    def __init__(self,
                 height: str = '150px',
                 hidden_by_default: bool = True,
                 auto_show_on_message: bool = True):
        """
        Initialize the log widget.

        Args:
            height: Height of the log output area
            hidden_by_default: Whether to start with log hidden
            auto_show_on_message: Whether to auto-show log when messages arrive
        """
        self.auto_show_on_message = auto_show_on_message
        self._create_widgets(height, hidden_by_default)
        self._setup_event_handlers()

    def _create_widgets(self, height: str, hidden_by_default: bool):
        """Create the output area, the Clear/Toggle buttons and the container VBox."""
        # Visibility is controlled through the CSS `display` property of the
        # output area: 'none' hides it, '' restores the default.
        self.output = widgets.Output(
            layout=widgets.Layout(
                height=height,
                width='100%',
                border='1px solid #ccc',
                overflow='auto',
                display='none' if hidden_by_default else ''
            )
        )

        self.clear_button = widgets.Button(
            description="Clear Log",
            button_style='warning',
            tooltip='Clear all log messages',
            layout=widgets.Layout(width='100px')
        )

        self.toggle_button = widgets.Button(
            description="Show Log" if hidden_by_default else "Hide Log",
            button_style='info',
            tooltip='Toggle log visibility',
            layout=widgets.Layout(width='100px')
        )

        label = widgets.HTML("<b>Log Messages:</b>")

        controls = widgets.HBox([
            label,
            widgets.HTML("&nbsp;" * 10),  # Spacer
            self.clear_button,
            self.toggle_button
        ])

        # Public composite widget to embed in a layout.
        self.widget = widgets.VBox([
            controls,
            self.output
        ])

    def _setup_event_handlers(self):
        """Wire the Clear and Toggle buttons to their actions."""
        def clear_log(_):
            self.output.clear_output()

        def toggle_log(_):
            if self.output.layout.display == 'none':
                self.show()
            else:
                self.hide()

        self.clear_button.on_click(clear_log)
        self.toggle_button.on_click(toggle_log)

    def show(self):
        """Show the log output area"""
        self.output.layout.display = ''
        self.toggle_button.description = "Hide Log"

    def hide(self):
        """Hide the log output area"""
        self.output.layout.display = 'none'
        self.toggle_button.description = "Show Log"

    def is_visible(self) -> bool:
        """Check if log is currently visible"""
        return self.output.layout.display != 'none'

    def clear(self):
        """Clear all log messages"""
        self.output.clear_output()

    def log(self, message: str):
        """
        Add a message directly to the log.

        Args:
            message: Message to add to the log
        """
        if self.auto_show_on_message and not self.is_visible():
            self.show()

        with self.output:
            print(message)

    @contextlib.contextmanager
    def capture_prints(self):
        """
        Context manager to redirect print statements to the log widget.

        Everything written to stdout inside the block is buffered and appended
        to the log when the block ends, including when it ends with an
        exception. Output printed before a failure is usually the most useful
        for diagnosing it. The exception itself is still propagated.

        Example:
            with log_widget.capture_prints():
                print("This goes to the log")
                some_function_that_prints()
        """
        if self.output is None:
            # Fallback to normal printing if widget not available
            yield
            return

        captured_output = StringIO()
        try:
            with contextlib.redirect_stdout(captured_output):
                yield
        finally:
            # Runs after stdout has been restored, on success and on error alike.
            self._flush_captured(captured_output.getvalue())

    def _flush_captured(self, content: str):
        """Append captured stdout text to the log, ignoring whitespace-only content."""
        if not content.strip():
            return
        if self.auto_show_on_message and not self.is_visible():
            self.show()
        with self.output:
            print(content.rstrip())  # Remove trailing newlines
182
+
183
+
184
class LoggedOperation:
    """
    Decorator/context manager for operations that should log to a specific log widget.

    Example:
    ```python
    log_widget = LogWidget()

    # As context manager
    with LoggedOperation(log_widget):
        print("This operation is logged")
        do_something()

    # As decorator
    @LoggedOperation(log_widget)
    def my_function():
        print("Function output goes to log")
    ```
    """

    def __init__(self, log_widget: LogWidget):
        self.log_widget = log_widget
        # Active capture contexts, innermost last. A stack rather than a
        # single slot lets one instance be entered again while it is already
        # active, e.g. when a decorated function calls itself recursively.
        # Each __exit__ then closes the context its matching __enter__ opened.
        self._contexts = []

    def __enter__(self):
        context = self.log_widget.capture_prints()
        result = context.__enter__()
        self._contexts.append(context)
        # Kept for backward compatibility: the most recently entered context.
        self.context = context
        return result

    def __exit__(self, exc_type, exc_val, exc_tb):
        context = self._contexts.pop()
        if self._contexts:
            self.context = self._contexts[-1]
        return context.__exit__(exc_type, exc_val, exc_tb)

    def __call__(self, func):
        """Decorator functionality: run ``func`` with its prints captured to the log."""
        import functools  # preserves func's name/docstring on the wrapper

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return wrapper
220
+
221
+
222
+ # Convenience function for quick log widget creation
223
def create_log_widget(height: str = '150px',
                      hidden_by_default: bool = True,
                      auto_show_on_message: bool = True) -> LogWidget:
    """
    Build a LogWidget configured with the given settings.

    A thin convenience wrapper: every argument is forwarded unchanged to the
    LogWidget constructor.

    Args:
        height: CSS height of the log output area
        hidden_by_default: Start with the output area hidden
        auto_show_on_message: Reveal the output area whenever a message is logged

    Returns:
        The newly created LogWidget instance
    """
    settings = {
        'height': height,
        'hidden_by_default': hidden_by_default,
        'auto_show_on_message': auto_show_on_message,
    }
    return LogWidget(**settings)