reflectorch 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. reflectorch/__init__.py +17 -0
  2. reflectorch/data_generation/__init__.py +128 -0
  3. reflectorch/data_generation/dataset.py +216 -0
  4. reflectorch/data_generation/likelihoods.py +80 -0
  5. reflectorch/data_generation/noise.py +471 -0
  6. reflectorch/data_generation/priors/__init__.py +60 -0
  7. reflectorch/data_generation/priors/base.py +55 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
  9. reflectorch/data_generation/priors/independent_priors.py +195 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -0
  12. reflectorch/data_generation/priors/no_constraints.py +206 -0
  13. reflectorch/data_generation/priors/parametric_models.py +842 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
  15. reflectorch/data_generation/priors/params.py +252 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +370 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -0
  19. reflectorch/data_generation/priors/utils.py +118 -0
  20. reflectorch/data_generation/process_data.py +41 -0
  21. reflectorch/data_generation/q_generator.py +280 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +71 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -0
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  29. reflectorch/data_generation/scale_curves.py +112 -0
  30. reflectorch/data_generation/smearing.py +99 -0
  31. reflectorch/data_generation/utils.py +223 -0
  32. reflectorch/extensions/__init__.py +0 -0
  33. reflectorch/extensions/jupyter/__init__.py +11 -0
  34. reflectorch/extensions/jupyter/api.py +85 -0
  35. reflectorch/extensions/jupyter/callbacks.py +34 -0
  36. reflectorch/extensions/jupyter/components.py +758 -0
  37. reflectorch/extensions/jupyter/custom_select.py +268 -0
  38. reflectorch/extensions/jupyter/log_widget.py +241 -0
  39. reflectorch/extensions/jupyter/model_selection.py +495 -0
  40. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  41. reflectorch/extensions/jupyter/widget.py +625 -0
  42. reflectorch/extensions/matplotlib/__init__.py +5 -0
  43. reflectorch/extensions/matplotlib/losses.py +32 -0
  44. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  45. reflectorch/inference/__init__.py +28 -0
  46. reflectorch/inference/inference_model.py +848 -0
  47. reflectorch/inference/input_interface.py +239 -0
  48. reflectorch/inference/loading_data.py +55 -0
  49. reflectorch/inference/multilayer_fitter.py +171 -0
  50. reflectorch/inference/multilayer_inference_model.py +193 -0
  51. reflectorch/inference/plotting.py +524 -0
  52. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  53. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  54. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  55. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  56. reflectorch/inference/preprocess_exp/interpolation.py +19 -0
  57. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  58. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  59. reflectorch/inference/query_matcher.py +82 -0
  60. reflectorch/inference/record_time.py +43 -0
  61. reflectorch/inference/sampler_solution.py +56 -0
  62. reflectorch/inference/scipy_fitter.py +364 -0
  63. reflectorch/inference/torch_fitter.py +87 -0
  64. reflectorch/ml/__init__.py +32 -0
  65. reflectorch/ml/basic_trainer.py +292 -0
  66. reflectorch/ml/callbacks.py +81 -0
  67. reflectorch/ml/dataloaders.py +27 -0
  68. reflectorch/ml/loggers.py +56 -0
  69. reflectorch/ml/schedulers.py +356 -0
  70. reflectorch/ml/trainers.py +201 -0
  71. reflectorch/ml/utils.py +2 -0
  72. reflectorch/models/__init__.py +16 -0
  73. reflectorch/models/activations.py +50 -0
  74. reflectorch/models/encoders/__init__.py +19 -0
  75. reflectorch/models/encoders/conv_encoder.py +219 -0
  76. reflectorch/models/encoders/conv_res_net.py +115 -0
  77. reflectorch/models/encoders/fno.py +134 -0
  78. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  79. reflectorch/models/networks/__init__.py +14 -0
  80. reflectorch/models/networks/mlp_networks.py +434 -0
  81. reflectorch/models/networks/residual_net.py +157 -0
  82. reflectorch/paths.py +29 -0
  83. reflectorch/runs/__init__.py +31 -0
  84. reflectorch/runs/config.py +25 -0
  85. reflectorch/runs/slurm_utils.py +93 -0
  86. reflectorch/runs/train.py +78 -0
  87. reflectorch/runs/utils.py +405 -0
  88. reflectorch/test_config.py +4 -0
  89. reflectorch/train.py +4 -0
  90. reflectorch/train_on_cluster.py +4 -0
  91. reflectorch/utils.py +98 -0
  92. reflectorch-1.5.1.dist-info/METADATA +151 -0
  93. reflectorch-1.5.1.dist-info/RECORD +96 -0
  94. reflectorch-1.5.1.dist-info/WHEEL +5 -0
  95. reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
  96. reflectorch-1.5.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,495 @@
1
+ """
2
+ Jupyter Widget Model Selection Component for Reflectorch
3
+
4
+ This module contains the model browsing and selection component for
5
+ reflectometry analysis models from Hugging Face repositories.
6
+
7
+ Components:
8
+ - ModelSelection: Model browsing and selection from Hugging Face
9
+ """
10
+
11
+ from typing import Optional, Dict, Any, Callable
12
+ import ipywidgets as widgets
13
+ from huggingface_hub import HfApi
14
+
15
+ from reflectorch.extensions.jupyter.custom_select import CustomSelect
16
+
17
class ModelSelection:
    """Model selection component for Hugging Face models.

    Lists the models published by one Hugging Face organization (chosen from
    ``organization_list``) in a :class:`CustomSelect` table whose detail pane
    shows metadata read from each model card. The current selection is exposed
    through :meth:`get_selected_model_info`; the download button forwards that
    information to the callback registered with :meth:`set_download_callback`.

    The assembled widget to display is available as ``self.widget``.
    """

    def __init__(self, organization_list: tuple[str, ...] = ('reflectorch-ILL', )):
        """
        Initialize model selection component

        Args:
            organization_list: Tuple of Hugging Face organizations to browse models
                from. The first entry is selected initially.

        Raises:
            ValueError: If ``organization_list`` is empty.
        """
        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if not organization_list:
            raise ValueError("At least one organization must be provided")
        self.organization = organization_list[0]
        self.organization_list = organization_list
        self.hf_api = HfApi()
        self.models_data = []
        self.selected_config = None
        self.selected_model = None
        self._model_cache = {}  # organization -> list of model-info dicts
        self._download_callback = None  # External download handler, see set_download_callback()

        self.widget = self._create_model_browser()

    @staticmethod
    def _escape(value: Any) -> str:
        """Return ``str(value)`` with HTML special characters escaped.

        Model-card metadata and exception messages are interpolated into HTML
        widgets, so they are escaped before rendering.
        """
        import html

        return html.escape(str(value))

    def _create_model_details_template(self) -> str:
        """Return the HTML template used by CustomSelect for the details pane.

        Placeholders (``{modelId}``, ``{num_layers}``, ``{parameterization}``,
        ``{param_ranges_table}``, ``{bound_width_ranges_table}``) are filled
        from the row dicts built in :meth:`_display_models`.
        """
        return """
        <div style="font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; line-height: 1.5;">
            <div style="margin-bottom: 12px;">
                <h4 style="margin: 0 0 8px 0; color: #0d6efd; font-size: 16px;">📋 {modelId}</h4>
                <div style="font-size: 13px; color: #6c757d;">
                    <b>Layers:</b> {num_layers} | <b>Parameterization:</b> {parameterization}
                </div>
            </div>

            <div style="margin-bottom: 12px;">
                <h5 style="margin: 0 0 6px 0; color: #495057; font-size: 14px;">Parameter Ranges</h5>
                <div style="background: #f8f9fa; border: 1px solid #dee2e6; border-radius: 4px; padding: 8px; font-size: 12px;">
                    {param_ranges_table}
                </div>
            </div>

            <div style="margin-bottom: 12px;">
                <h5 style="margin: 0 0 6px 0; color: #495057; font-size: 14px;">Prior Bound Widths</h5>
                <div style="background: #f8f9fa; border: 1px solid #dee2e6; border-radius: 4px; padding: 8px; font-size: 12px;">
                    {bound_width_ranges_table}
                </div>
            </div>
        </div>
        """

    def _create_model_browser(self) -> "widgets.VBox":
        """Build the model browser UI and connect its event handlers.

        The interactive widgets are stored in ``self._widgets`` so that the
        other methods can update them.

        Returns:
            The assembled VBox (title, description, controls, status, table,
            selection status).
        """
        # Title and description
        title = widgets.HTML(
            "<h3>🤗 Model Selection</h3>",
            layout=widgets.Layout(margin='0px 0px 10px 0px')
        )

        description = widgets.HTML(
            "<p>Browse and select reflectometry models from Hugging Face repositories.</p>",
            layout=widgets.Layout(margin='0px 0px 15px 0px')
        )

        # Download button - disabled until a model is selected
        download_model_button = widgets.Button(
            description="📥 Download Model",
            button_style='success',
            disabled=True,
            tooltip='Select a model to enable download',
            layout=widgets.Layout(width='150px')
        )

        refresh_button = widgets.Button(
            description="🔄 Refresh",
            button_style='info',
            tooltip='Refresh model list (uses cache after first load)',
            layout=widgets.Layout(width='100px')
        )

        # Organization selection
        org_dropdown = widgets.Dropdown(
            options=[(org, org) for org in self.organization_list],
            value=self.organization,
            description='Organization:',
            layout=widgets.Layout(width='300px')
        )

        controls_row = widgets.HBox([
            org_dropdown,
            refresh_button,
            download_model_button,
        ], layout=widgets.Layout(justify_content='flex-start', margin='0px 0px 15px 0px'))

        status_label = widgets.HTML(
            "Click 'Refresh' to load models",
            layout=widgets.Layout(margin='5px 0px')
        )

        selection_status = widgets.HTML(
            value="<i>Click a model row to select it</i>",
            layout=widgets.Layout(margin='5px 0px')
        )

        # Model table with a details pane rendered from the template
        model_selector = CustomSelect(
            data=[],
            columns=[
                ('Model ID', 'modelId'),
                ('Layers', 'num_layers'),
                ('Parameterization', 'parameterization')
            ],
            layout=widgets.Layout(width='600px'),
            show_details=True,
            details_template=self._create_model_details_template()
        )

        # Store widget references for event handlers and the other methods
        self._widgets = {
            'refresh_button': refresh_button,
            'download_button': download_model_button,
            'org_dropdown': org_dropdown,
            'status_label': status_label,
            'model_selector': model_selector,
            'selection_status': selection_status
        }

        self._setup_model_browser_events()

        return widgets.VBox([
            title,
            description,
            controls_row,
            status_label,
            model_selector.container,
            selection_status,
        ])

    def _setup_model_browser_events(self):
        """Connect the browser widgets in ``self._widgets`` to their handlers."""
        widgets_dict = self._widgets

        def on_refresh_click(_button):
            """Load (or reload from cache) the current organization's models."""
            self._load_models()

        def on_org_change(change):
            """Switch organization, reload the list and drop the old selection."""
            self.organization = change['new']
            self._load_models()
            # A model chosen under the previous organization is no longer valid.
            self._clear_selection()

        def on_model_select(row):
            """Mirror a CustomSelect row selection (or deselection) into this component."""
            if row is None:
                self._clear_selection()
                return

            # Rows are built by _display_models and carry their position and full model dict.
            model_index = row.get('index')
            model_data = row.get('_model_data')
            if model_index is not None and model_data is not None:
                self._apply_selection(model_data, model_index, row.get('modelId'))

        def on_download_click(_button):
            """Pass the selected model's info to the registered download callback, if any."""
            if self.selected_model and self._download_callback:
                self._download_callback(self.get_selected_model_info())

        widgets_dict['refresh_button'].on_click(on_refresh_click)
        widgets_dict['download_button'].on_click(on_download_click)
        widgets_dict['org_dropdown'].observe(on_org_change, names='value')
        widgets_dict['model_selector'].on_select(on_model_select)

    def _apply_selection(self, model: Dict[str, Any], index: int, model_id: Any) -> None:
        """Record ``model`` as selected and update the status line and download button.

        Shared by the row-click handler and :meth:`select_model_by_index`.
        """
        self._select_model(model, index)
        self._widgets['selection_status'].value = f"<b>Selected:</b> {self._escape(model_id)}"
        download_button = self._widgets['download_button']
        download_button.disabled = False
        download_button.tooltip = f"Download {model_id}"

    def select_model_by_index(self, index):
        """Public method to select a model by its position in ``models_data``.

        An out-of-range index clears the selection and shows an error message
        in the selection status line.
        """
        if not 0 <= index < len(self.models_data):
            # Clear first: _clear_selection() resets the status line, so doing it
            # afterwards would immediately overwrite the error message.
            self._clear_selection()
            self._widgets['selection_status'].value = (
                f"<i style='color: red;'>Invalid index: {self._escape(index)}</i>"
            )
            return

        model = self.models_data[index]
        self._apply_selection(model, index, model.get('modelId', 'Unknown'))
        # Keep the table's highlighted row in sync with the programmatic selection.
        self._widgets['model_selector'].set_selected_index(index)

    def _load_models(self):
        """Load the models of ``self.organization`` synchronously and display them.

        Results are cached per organization, so only the first load of an
        organization queries the Hugging Face Hub. On failure the list is
        emptied so that it never shows models of a different organization than
        ``self.organization``; the error is reported in the status label and
        printed.
        """
        status_label = self._widgets['status_label']
        try:
            status_label.value = "🔄 Loading models..."

            if self.organization in self._model_cache:
                status_label.value = "🔄 Loading cached models..."
                models = self._model_cache[self.organization]
            else:
                status_label.value = f"🔄 Loading models from {self._escape(self.organization)}..."
                models = self._fetch_models(self.organization)
                self._model_cache[self.organization] = models

            self.models_data = models
            self._display_models(models)

            status_label.value = f"✅ Loaded {len(models)} models"

        except Exception as e:
            status_label.value = f"❌ Error loading models: {self._escape(e)}"
            print(f"Error loading models: {e}")
            # The table may still show another organization's models; empty it
            # so selections stay consistent with self.organization.
            self.models_data = []
            self._widgets['model_selector'].set_data([])

    def _fetch_models(self, organization: str) -> list:
        """Query the Hub for ``organization``'s models and convert them to model-info dicts.

        Models whose card cannot be processed are skipped with a printed
        warning; errors from the Hub query itself propagate to the caller.
        """
        models = []
        for hf_model in self.hf_api.list_models(author=organization, cardData=True):
            try:
                # Index by position in the returned list so skipped models leave no gaps.
                models.append(self._build_model_info(hf_model, organization, len(models)))
            except Exception as e:
                print(f"Warning: Could not process model {hf_model.id}: {e}")
        return models

    @staticmethod
    def _build_model_info(hf_model: Any, organization: str, index: int) -> Dict[str, Any]:
        """Convert one Hub model entry into the model-info dict used by this component.

        Values are read from the ``metadata`` entry of the model card, with the
        same defaults as before (``slab`` parameterization, ``0`` layers, empty
        ranges). A repo without a model card is treated like a card without
        metadata.
        """
        card_data = hf_model.card_data
        metadata = (card_data.get("metadata") if card_data is not None else None) or {}
        return {
            'index': index,
            'modelId': hf_model.id,
            'author': organization,
            'hf_model': hf_model,
            'metadata': metadata,
            'num_layers': metadata.get("number_of_layers", 0),
            'param_ranges': metadata.get("param_ranges", {}),
            'bound_width_ranges': metadata.get("bound_width_ranges", {}),
            'misalignment_included': metadata.get("shift_param_config", {}),
            'parameterization': metadata.get("parameterization", "slab"),
        }

    def _display_models(self, models):
        """Show ``models`` (model-info dicts) in the CustomSelect table.

        Each row carries the table columns, the pre-rendered HTML snippets
        used by the details template, and the full model dict under
        ``'_model_data'`` for the selection handler.
        """
        rows = []
        for position, model in enumerate(models):
            rows.append({
                'index': position,
                'modelId': model.get('modelId'),
                'num_layers': str(model.get('num_layers', 'Unknown')),
                'parameterization': model.get('parameterization', 'Unknown'),
                'param_ranges_table': self._format_parameter_ranges(model.get('param_ranges', {})),
                'bound_width_ranges_table': self._format_bound_width_ranges(model.get('bound_width_ranges', {})),
                'misalignment_info': self._format_misalignment_info(model.get('misalignment_included', {})),
                'additional_metadata': self._format_additional_metadata(model.get('metadata', {})),
                '_model_data': model,
            })

        self._widgets['model_selector'].set_data(rows)

    def _format_ranges_table(self, ranges: Any, missing_message: str, empty_message: str) -> str:
        """Render a ``{name: [min, max]}`` mapping as a two-column HTML table.

        Entries whose value is not a two-element list/tuple are shown as-is.
        Names and values come from model cards and are HTML-escaped.

        Args:
            ranges: Mapping read from the model card metadata.
            missing_message: Returned when ``ranges`` is empty or missing.
            empty_message: Returned when no table row was produced.
        """
        if not ranges:
            return missing_message
        if not isinstance(ranges, dict):
            # Unexpected card layout: show the raw value instead of failing the whole list.
            return self._escape(ranges)

        table_rows = []
        for name, value in ranges.items():
            if isinstance(value, (list, tuple)) and len(value) == 2:
                low, high = value
                shown = f"[{self._escape(low)}, {self._escape(high)}]"
            else:
                shown = self._escape(value)
            table_rows.append(
                f"<tr><td style='padding: 2px 8px 2px 0;'><b>{self._escape(name)}:</b></td>"
                f"<td style='padding: 2px 0;'>{shown}</td></tr>"
            )

        if not table_rows:
            return empty_message

        return f"""
        <table style="width: 100%; border-collapse: collapse; font-size: 11px;">
        {''.join(table_rows)}
        </table>
        """

    def _format_parameter_ranges(self, param_ranges: dict) -> str:
        """Format parameter ranges as an HTML table"""
        return self._format_ranges_table(
            param_ranges,
            "<i>No parameter range information available</i>",
            "<i>No parameter ranges specified</i>",
        )

    def _format_bound_width_ranges(self, bound_ranges: dict) -> str:
        """Format bound width ranges as an HTML table"""
        return self._format_ranges_table(
            bound_ranges,
            "<i>No bound width information available</i>",
            "<i>No bound width ranges specified</i>",
        )

    def _format_misalignment_info(self, misalignment_config: dict) -> str:
        """Describe the misalignment/shift-parameter support of a model as HTML.

        Args:
            misalignment_config: The ``shift_param_config`` entry of the card
                metadata. Keys read: ``enabled``, ``q_shift``,
                ``intensity_shift``, ``background_shift`` and ``param_ranges``
                (mapping of parameter name to ``[min, max]``).
        """
        no_support = "<span style='color: #dc3545;'>❌ No misalignment support</span>"
        if not misalignment_config or not misalignment_config.get('enabled', False):
            return no_support

        info_parts = ["<span style='color: #198754;'>✅ Misalignment supported</span>"]

        shift_flags = (
            ('q_shift', "Q-shift"),
            ('intensity_shift', "Intensity shift"),
            ('background_shift', "Background shift"),
        )
        shift_types = [label for key, label in shift_flags if misalignment_config.get(key, False)]
        if shift_types:
            info_parts.append(f"<br><b>Types:</b> {', '.join(shift_types)}")

        ranges = misalignment_config.get('param_ranges')
        if isinstance(ranges, dict):
            range_info = [
                f"{self._escape(param)}: [{self._escape(value[0])}, {self._escape(value[1])}]"
                for param, value in ranges.items()
                if isinstance(value, (list, tuple)) and len(value) == 2
            ]
            if range_info:
                info_parts.append(f"<br><b>Ranges:</b> {', '.join(range_info)}")

        return ''.join(info_parts)

    def _format_additional_metadata(self, metadata: dict) -> str:
        """Format optional model-card metadata (Q-range, size, architecture, ...) as HTML lines."""
        if not metadata:
            return "<i>No additional metadata available</i>"

        info_parts = []

        q_min = metadata.get('q_min')
        q_max = metadata.get('q_max')
        if q_min is not None and q_max is not None:
            info_parts.append(f"<b>Q-range:</b> [{self._escape(q_min)}, {self._escape(q_max)}]")

        num_points = metadata.get('num_q_points')
        if num_points:
            info_parts.append(f"<b>Q-points:</b> {self._escape(num_points)}")

        training_data = metadata.get('training_data_size')
        if training_data:
            # Thousands separators only apply to numbers; the ',' format spec
            # raises ValueError for strings such as "1M".
            if isinstance(training_data, (int, float)):
                size_text = f"{training_data:,}"
            else:
                size_text = self._escape(training_data)
            info_parts.append(f"<b>Training size:</b> {size_text} samples")

        model_type = metadata.get('model_type')
        if model_type:
            info_parts.append(f"<b>Architecture:</b> {self._escape(model_type)}")

        radiation = metadata.get('radiation_type', metadata.get('type'))
        if radiation:
            info_parts.append(f"<b>Radiation:</b> {self._escape(radiation)}")

        version = metadata.get('version')
        if version:
            info_parts.append(f"<b>Version:</b> {self._escape(version)}")

        license_info = metadata.get('license')
        if license_info:
            info_parts.append(f"<b>License:</b> {self._escape(license_info)}")

        return '<br>'.join(info_parts) if info_parts else "<i>No additional information available</i>"

    @property
    def download_button(self) -> "widgets.Button":
        """Get the download button widget for external access"""
        return self._widgets['download_button']

    def set_download_callback(self, callback: Callable[[Dict[str, Any]], None]):
        """Set the callback function for download button clicks.

        Args:
            callback: Called with the dict returned by
                :meth:`get_selected_model_info` when download is clicked
                while a model is selected.
        """
        self._download_callback = callback

    def _clear_selection(self):
        """Forget the current selection, disable download and reset the status line."""
        self.selected_config = None
        self.selected_model = None
        self._widgets['download_button'].disabled = True
        self._widgets['download_button'].tooltip = 'Select a model to enable download'
        self._widgets['selection_status'].value = "<i>Click a model row to select it</i>"

    def _select_model(self, model, index: int):
        """Store ``model``'s full Hugging Face repo id as the current selection.

        ``index`` is accepted for interface compatibility and is not used.
        """
        model_id = model.get('modelId') if isinstance(model, dict) else model.modelId
        self.selected_config = model_id
        self.selected_model = model_id

    @property
    def selected_model_config_name(self) -> Optional[str]:
        """Last path segment of the selected model id, or ``None`` if nothing is selected."""
        if self.selected_model is None:
            return None
        return self.selected_model.split('/')[-1]

    @property
    def selected_model_data(self) -> Optional[Dict[str, Any]]:
        """Model-info dict of the selected model from ``models_data``, or ``None``."""
        if self.selected_model is None:
            return None
        for model in self.models_data:
            if model.get('modelId') == self.selected_model:
                return model
        return None

    def get_selected_model_info(self) -> Optional[Dict[str, Any]]:
        """Describe the current selection, or return ``None`` if nothing is selected.

        Returns:
            Dict with ``repo_id`` (the current organization), ``config_name``
            (last path segment of the model id), ``model_name`` (full model id)
            and ``model_data`` (the model-info dict, or ``None``).
        """
        if self.selected_config is None:
            return None

        return {
            'repo_id': self.organization,
            'config_name': self.selected_model_config_name,
            'model_name': self.selected_model,
            'model_data': self.selected_model_data
        }