reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -126
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -246
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -222
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -851
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +524 -98
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -16
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -248
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -191
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -14
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -17
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -428
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -401
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +98 -68
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.3.0.dist-info/RECORD +0 -86
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,495 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Jupyter Widget Model Selection Component for Reflectorch
|
|
3
|
+
|
|
4
|
+
This module contains the model browsing and selection component for
|
|
5
|
+
reflectometry analysis models from Hugging Face repositories.
|
|
6
|
+
|
|
7
|
+
Components:
|
|
8
|
+
- ModelSelection: Model browsing and selection from Hugging Face
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import html
from typing import Optional, Dict, Any, Callable

import ipywidgets as widgets
from huggingface_hub import HfApi

from reflectorch.extensions.jupyter.custom_select import CustomSelect
|
|
16
|
+
|
|
17
|
+
class ModelSelection:
    """Model selection component for Hugging Face models.

    Lists the models published under one of several Hugging Face organizations,
    shows details taken from each model card's ``metadata`` section, and hands the
    chosen model to an externally registered download callback
    (see ``set_download_callback``). The assembled UI is exposed as ``self.widget``.
    """

    def __init__(self, organization_list: tuple[str, ...] = ('reflectorch-ILL', )):
        """
        Initialize model selection component

        Args:
            organization_list: Tuple of Hugging Face organizations to browse models from.
                The first entry is the organization selected initially.

        Raises:
            ValueError: If ``organization_list`` is empty.
        """
        # Explicit check instead of ``assert``: asserts are stripped under ``python -O``,
        # which would turn this into an IndexError on the next line.
        if len(organization_list) == 0:
            raise ValueError("At least one organization must be provided")
        self.organization = organization_list[0]
        self.organization_list = organization_list
        self.hf_api = HfApi()
        self.models_data = []
        self.selected_config = None
        self.selected_model = None
        self._model_cache = {}  # Cache model info by organization
        self._download_callback = None  # External download handler

        self.widget = self._create_model_browser()

    def _create_model_details_template(self) -> str:
        """Create a comprehensive HTML template for model details.

        The ``{...}`` placeholders correspond to keys of the row dicts built in
        ``_display_models``. NOTE(review): presumably CustomSelect fills them per row;
        confirm against CustomSelect. ``misalignment_info`` and ``additional_metadata``
        are computed per row but not referenced by this template.
        """
        return """
        <div style="font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; line-height: 1.5;">
            <div style="margin-bottom: 12px;">
                <h4 style="margin: 0 0 8px 0; color: #0d6efd; font-size: 16px;">📋 {modelId}</h4>
                <div style="font-size: 13px; color: #6c757d;">
                    <b>Layers:</b> {num_layers} | <b>Parameterization:</b> {parameterization}
                </div>
            </div>

            <div style="margin-bottom: 12px;">
                <h5 style="margin: 0 0 6px 0; color: #495057; font-size: 14px;">Parameter Ranges</h5>
                <div style="background: #f8f9fa; border: 1px solid #dee2e6; border-radius: 4px; padding: 8px; font-size: 12px;">
                    {param_ranges_table}
                </div>
            </div>

            <div style="margin-bottom: 12px;">
                <h5 style="margin: 0 0 6px 0; color: #495057; font-size: 14px;">Prior Bound Widths</h5>
                <div style="background: #f8f9fa; border: 1px solid #dee2e6; border-radius: 4px; padding: 8px; font-size: 12px;">
                    {bound_width_ranges_table}
                </div>
            </div>
        </div>
        """

    def _create_model_browser(self) -> widgets.VBox:
        """Create the model browser widget.

        Builds all child widgets and stores them in ``self._widgets`` so that the
        event handlers and other methods can reach them. It then wires the events.
        """
        # Title and description
        title = widgets.HTML(
            "<h3>🤗 Model Selection</h3>",
            layout=widgets.Layout(margin='0px 0px 10px 0px')
        )

        description = widgets.HTML(
            "<p>Browse and select reflectometry models from Hugging Face repositories.</p>",
            layout=widgets.Layout(margin='0px 0px 15px 0px')
        )

        # Download button - initially disabled until a model is selected
        download_model_button = widgets.Button(
            description="📥 Download Model",
            button_style='success',
            disabled=True,
            tooltip='Select a model to enable download',
            layout=widgets.Layout(width='150px')
        )

        refresh_button = widgets.Button(
            description="🔄 Refresh",
            button_style='info',
            tooltip='Refresh model list (uses cache after first load)',
            layout=widgets.Layout(width='100px')
        )

        # Organization selection
        org_dropdown = widgets.Dropdown(
            options=[(org, org) for org in self.organization_list],
            value=self.organization,
            description='Organization:',
            layout=widgets.Layout(width='300px')
        )

        # Controls row
        controls_row = widgets.HBox([
            org_dropdown,
            refresh_button,
            download_model_button,
        ], layout=widgets.Layout(justify_content='flex-start', margin='0px 0px 15px 0px'))

        # Status
        status_label = widgets.HTML(
            "Click 'Refresh' to load models",
            layout=widgets.Layout(margin='5px 0px')
        )

        # Selection status widget
        selection_status = widgets.HTML(
            value="<i>Click a model row to select it</i>",
            layout=widgets.Layout(margin='5px 0px')
        )

        # Model selection table
        model_selector = CustomSelect(
            data=[],
            columns=[
                ('Model ID', 'modelId'),
                ('Layers', 'num_layers'),
                ('Parameterization', 'parameterization')
            ],
            layout=widgets.Layout(width='600px'),
            show_details=True,
            details_template=self._create_model_details_template()
        )

        # Store widget references for event handlers
        self._widgets = {
            'refresh_button': refresh_button,
            'download_button': download_model_button,
            'org_dropdown': org_dropdown,
            'status_label': status_label,
            'model_selector': model_selector,
            'selection_status': selection_status
        }

        # Setup event handlers
        self._setup_model_browser_events()

        # Main content area
        return widgets.VBox([
            title,
            description,
            controls_row,
            status_label,
            model_selector.container,
            selection_status,
        ])

    def _setup_model_browser_events(self):
        """Attach handlers to the refresh/download buttons, org dropdown and model table."""
        widgets_dict = self._widgets

        def on_refresh_click(b):
            """Handle refresh button click"""
            self._load_models()

        def on_org_change(change):
            """Handle organization dropdown change"""
            self.organization = change['new']
            self._load_models()
            # Clear selection when organization changes
            self._clear_selection()

        def on_model_select(row):
            """Handle model selection from CustomSelect (``row`` is None on deselect)."""
            if row is None:
                self._clear_selection()
                return

            # Row dicts are built in _display_models
            model_index = row.get('index')
            model_data = row.get('_model_data')

            if model_index is not None and model_data is not None:
                self._select_model(model_data, model_index)
                self._show_selected(row.get('modelId'))

        def on_download_click(b):
            """Handle download button click"""
            if self.selected_model and self._download_callback:
                self._download_callback(self.get_selected_model_info())

        # Connect event handlers
        widgets_dict['refresh_button'].on_click(on_refresh_click)
        widgets_dict['download_button'].on_click(on_download_click)
        widgets_dict['org_dropdown'].observe(on_org_change, names='value')
        widgets_dict['model_selector'].on_select(on_model_select)

    def _show_selected(self, model_id) -> None:
        """Report ``model_id`` as the current selection and enable the download button."""
        self._widgets['selection_status'].value = f"<b>Selected:</b> {self._escape(model_id)}"
        self._widgets['download_button'].disabled = False
        self._widgets['download_button'].tooltip = f"Download {model_id}"

    def select_model_by_index(self, index):
        """Public method to select a model by its position in ``self.models_data``.

        An out-of-range index clears the current selection and shows an error
        message in the selection status.
        """
        if 0 <= index < len(self.models_data):
            model = self.models_data[index]
            self._select_model(model, index)
            self._show_selected(model.get('modelId', 'Unknown'))
            # Update CustomSelect to match
            self._widgets['model_selector'].set_selected_index(index)
        else:
            # Clear first: _clear_selection resets the status text, which would
            # otherwise immediately overwrite the error message.
            self._clear_selection()
            self._widgets['selection_status'].value = (
                f"<i style='color: red;'>Invalid index: {self._escape(index)}</i>"
            )

    def _load_models(self):
        """Load models for ``self.organization`` synchronously and display them.

        Results are cached per organization in ``self._model_cache``. Once an
        organization has been loaded, later calls reuse the cached list instead of
        querying Hugging Face again. Errors are reported in the status label.
        """
        status_label = self._widgets['status_label']
        try:
            status_label.value = "🔄 Loading models..."

            if self.organization in self._model_cache:
                status_label.value = "🔄 Loading cached models..."
                models = self._model_cache[self.organization]
            else:
                status_label.value = f"🔄 Loading models from {self.organization}..."
                models = self._fetch_models(self.organization)
                # Cache the results
                self._model_cache[self.organization] = models

            self.models_data = models

            # Update UI
            self._display_models(models)

            status_label.value = f"✅ Loaded {len(models)} models"

        except Exception as e:
            status_label.value = f"❌ Error loading models: {self._escape(e)}"
            print(f"Error loading models: {e}")

    def _fetch_models(self, organization: str) -> list:
        """Query Hugging Face for all models of ``organization`` and convert them.

        Entries that cannot be converted are skipped with a printed warning. Errors
        raised while listing the models themselves propagate to the caller.
        """
        models = []
        for hf_model in self.hf_api.list_models(author=organization, cardData=True):
            try:
                # index == position in ``models``, so it stays valid even when entries are skipped
                models.append(self._build_model_info(hf_model, organization, index=len(models)))
            except Exception as e:
                print(f"Warning: Could not process model {hf_model.id}: {e}")
        return models

    @staticmethod
    def _build_model_info(hf_model, organization: str, index: int) -> Dict[str, Any]:
        """Convert one ``HfApi.list_models`` entry into this widget's model-info dict.

        Args:
            hf_model: Model entry from ``HfApi.list_models(..., cardData=True)``.
            organization: Organization the model was listed under (stored as ``'author'``).
            index: Position of the entry in the resulting model list.

        Returns:
            Dict holding the model id and the fields read from the card's ``metadata``
            section, with defaults for missing keys.
        """
        card_data = hf_model.card_data
        # ``card_data`` may be None (e.g. a repo without a model card), and a ``metadata:``
        # key without a value loads as None; treat both like an empty metadata section.
        metadata = card_data.get("metadata", {}) if card_data is not None else {}
        metadata = metadata or {}
        return {
            'index': index,
            'modelId': hf_model.id,
            'author': organization,
            'hf_model': hf_model,
            'metadata': metadata,
            'num_layers': metadata.get("number_of_layers", 0),
            'param_ranges': metadata.get("param_ranges", {}),
            'bound_width_ranges': metadata.get("bound_width_ranges", {}),
            'misalignment_included': metadata.get("shift_param_config", {}),
            'parameterization': metadata.get("parameterization", "slab"),
        }

    def _display_models(self, models):
        """Display models in the CustomSelect widget.

        Builds one row dict per model. The row carries the table columns, the
        pre-rendered HTML snippets for the details template, and the full model dict
        under ``'_model_data'`` for selection handling.
        """
        rows = []
        for i, model in enumerate(models):
            # NOTE(review): modelId/num_layers/parameterization are passed unescaped;
            # confirm whether CustomSelect escapes column values before rendering.
            rows.append({
                'index': i,
                'modelId': model.get('modelId'),
                'num_layers': str(model.get('num_layers', 'Unknown')),
                'parameterization': model.get('parameterization', 'Unknown'),
                'param_ranges_table': self._format_parameter_ranges(model.get('param_ranges', {})),
                'bound_width_ranges_table': self._format_bound_width_ranges(model.get('bound_width_ranges', {})),
                'misalignment_info': self._format_misalignment_info(model.get('misalignment_included', {})),
                'additional_metadata': self._format_additional_metadata(model.get('metadata', {})),
                '_model_data': model,  # Store full model data for selection
            })

        # Update CustomSelect widget with new data
        self._widgets['model_selector'].set_data(rows)

    @staticmethod
    def _escape(value) -> str:
        """Return ``str(value)`` with HTML special characters escaped.

        Model-card metadata is free-form text written by model authors, so it is
        escaped before being embedded into widget HTML.
        """
        return html.escape(str(value))

    def _format_ranges_table(self, ranges, empty_message: str) -> str:
        """Render a ``{name: [min, max]}`` mapping as a compact two-column HTML table.

        Two-element list/tuple values are shown as ``[min, max]``; any other value via
        ``str``. A falsy ``ranges`` yields ``empty_message``. A non-mapping value is shown
        escaped as-is, so one malformed model card cannot abort loading the whole list.
        """
        if not ranges:
            return empty_message
        if not hasattr(ranges, "items"):
            return self._escape(ranges)

        table_rows = []
        for param_type, value in ranges.items():
            if isinstance(value, (list, tuple)) and len(value) == 2:
                value_text = f"[{self._escape(value[0])}, {self._escape(value[1])}]"
            else:
                value_text = self._escape(value)
            table_rows.append(
                "<tr><td style='padding: 2px 8px 2px 0;'>"
                f"<b>{self._escape(param_type)}:</b></td>"
                f"<td style='padding: 2px 0;'>{value_text}</td></tr>"
            )

        rows_html = "".join(table_rows)
        return (
            '\n<table style="width: 100%; border-collapse: collapse; font-size: 11px;">\n'
            f"{rows_html}\n"
            "</table>\n"
        )

    def _format_parameter_ranges(self, param_ranges: dict) -> str:
        """Format parameter ranges as an HTML table"""
        return self._format_ranges_table(
            param_ranges, "<i>No parameter range information available</i>"
        )

    def _format_bound_width_ranges(self, bound_ranges: dict) -> str:
        """Format bound width ranges as an HTML table"""
        return self._format_ranges_table(
            bound_ranges, "<i>No bound width information available</i>"
        )

    def _format_misalignment_info(self, misalignment_config: dict) -> str:
        """Format misalignment/shift parameter information as an HTML snippet.

        Reads the optional keys ``enabled``, ``q_shift``, ``intensity_shift``,
        ``background_shift`` and ``param_ranges`` from ``misalignment_config``.
        """
        no_support = "<span style='color: #dc3545;'>❌ No misalignment support</span>"
        if not misalignment_config:
            return no_support
        if not hasattr(misalignment_config, "get"):
            # Malformed (non-mapping) configuration
            return "<i>Misalignment information not specified</i>"
        if not misalignment_config.get('enabled', False):
            return no_support

        info_parts = ["<span style='color: #198754;'>✅ Misalignment supported</span>"]

        # Add specific shift types if available
        shift_labels = (
            ('q_shift', "Q-shift"),
            ('intensity_shift', "Intensity shift"),
            ('background_shift', "Background shift"),
        )
        shift_types = [label for key, label in shift_labels if misalignment_config.get(key, False)]
        if shift_types:
            info_parts.append(f"<br><b>Types:</b> {', '.join(shift_types)}")

        # Add parameter ranges if available
        ranges = misalignment_config.get('param_ranges')
        if hasattr(ranges, "items"):
            range_info = [
                f"{self._escape(param)}: [{self._escape(val[0])}, {self._escape(val[1])}]"
                for param, val in ranges.items()
                if isinstance(val, (list, tuple)) and len(val) == 2
            ]
            if range_info:
                info_parts.append(f"<br><b>Ranges:</b> {', '.join(range_info)}")

        return ''.join(info_parts)

    def _format_additional_metadata(self, metadata: dict) -> str:
        """Format additional model metadata (Q-range, training size, version, ...) as HTML."""
        if not metadata:
            return "<i>No additional metadata available</i>"

        esc = self._escape
        info_parts = []

        # Q-range information
        q_min = metadata.get('q_min')
        q_max = metadata.get('q_max')
        if q_min is not None and q_max is not None:
            info_parts.append(f"<b>Q-range:</b> [{esc(q_min)}, {esc(q_max)}]")

        # Discretization info
        num_points = metadata.get('num_q_points')
        if num_points:
            info_parts.append(f"<b>Q-points:</b> {esc(num_points)}")

        # Training info. Thousands separators only apply to numbers: card metadata may
        # hold a string (e.g. "1M"), and f"{'1M':,}" raises ValueError, which would
        # abort loading the whole model list.
        training_data = metadata.get('training_data_size')
        if training_data:
            if isinstance(training_data, (int, float)):
                size_text = f"{training_data:,}"
            else:
                size_text = esc(training_data)
            info_parts.append(f"<b>Training size:</b> {size_text} samples")

        # Model architecture info
        model_type = metadata.get('model_type')
        if model_type:
            info_parts.append(f"<b>Architecture:</b> {esc(model_type)}")

        # Radiation type
        radiation = metadata.get('radiation_type', metadata.get('type'))
        if radiation:
            info_parts.append(f"<b>Radiation:</b> {esc(radiation)}")

        # Version info
        version = metadata.get('version')
        if version:
            info_parts.append(f"<b>Version:</b> {esc(version)}")

        # License info
        license_info = metadata.get('license')
        if license_info:
            info_parts.append(f"<b>License:</b> {esc(license_info)}")

        return '<br>'.join(info_parts) if info_parts else "<i>No additional information available</i>"

    @property
    def download_button(self) -> widgets.Button:
        """Get the download button widget for external access"""
        return self._widgets['download_button']

    def set_download_callback(self, callback: Callable[[Dict[str, Any]], None]):
        """Set the callback function for download button clicks.

        Args:
            callback: Function called with ``get_selected_model_info()`` when download is
                clicked while a model is selected.
        """
        self._download_callback = callback

    def _clear_selection(self):
        """Clear the current model selection and reset the related widgets."""
        self.selected_config = None
        self.selected_model = None
        self._widgets['download_button'].disabled = True
        self._widgets['download_button'].tooltip = 'Select a model to enable download'
        self._widgets['selection_status'].value = "<i>Click a model row to select it</i>"

    def _select_model(self, model, index: int):
        """Record ``model`` as the current selection.

        ``index`` is not used here; it is kept in the signature for existing callers.
        """
        model_id = model.get('modelId') if isinstance(model, dict) else model.modelId

        # Store selection - the model ID is the full repo ID for HF models
        self.selected_config = model_id
        self.selected_model = model_id

    @property
    def selected_model_config_name(self) -> Optional[str]:
        """Name part (after the last ``/``) of the selected model ID, or None if nothing is selected."""
        if self.selected_model is None:
            return None
        return self.selected_model.split('/')[-1]

    @property
    def selected_model_data(self) -> Optional[Dict[str, Any]]:
        """Model-info dict of the currently selected model, or None."""
        if self.selected_model is None:
            return None
        return next(
            (model for model in self.models_data if model.get('modelId') == self.selected_model),
            None,
        )

    def get_selected_model_info(self) -> Optional[Dict[str, Any]]:
        """Get information about the currently selected model (None if nothing is selected)."""
        if self.selected_config is None:
            return None

        return {
            'repo_id': self.organization,
            'config_name': self.selected_model_config_name,
            'model_name': self.selected_model,
            'model_data': self.selected_model_data
        }
|