featrixsphere 0.2.5563__py3-none-any.whl → 0.2.5978__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,584 @@
+ """
+ FeatrixNotebookHelper for Jupyter notebook visualization.
+
+ Provides visualization methods for training metrics, embedding spaces,
+ and model analysis in Jupyter notebooks.
+ """
+
+ import logging
+ from typing import Dict, Any, Optional, List, Union, Tuple, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from .foundational_model import FoundationalModel
+     from .predictor import Predictor
+     from .http_client import ClientContext
+     import matplotlib
+     import plotly
+     import numpy as np
+
+ logger = logging.getLogger(__name__)
+
+
+ class FeatrixNotebookHelper:
+     """
+     Helper class for Jupyter notebook visualization.
+
+     Access via featrix.get_notebook() - no separate import needed!
+
+     Usage:
+         notebook = featrix.get_notebook()
+
+         # Visualize training
+         fig = notebook.training_loss(fm, style='notebook')
+         fig.show()
+
+         # 3D embedding space
+         fig = notebook.embedding_space_3d(fm, interactive=True)
+         fig.show()
+
+         # Training movie
+         movie = notebook.training_movie(fm, notebook_mode=True)
+     """
+
+     def __init__(self, ctx: Optional['ClientContext'] = None):
+         """Initialize notebook helper with client context."""
+         self._ctx = ctx
+
+     def training_loss(
+         self,
+         model: Union['FoundationalModel', 'Predictor'],
+         style: str = 'notebook',
+         show_learning_rate: bool = True,
+         smooth: bool = True,
+         figsize: Tuple[int, int] = (12, 6)
+     ) -> Any:
+         """
+         Plot training loss curves.
+
+         Args:
+             model: FoundationalModel or Predictor to visualize
+             style: Plot style ('notebook', 'paper', 'presentation')
+             show_learning_rate: Show learning rate on secondary axis
+             smooth: Apply smoothing to loss curves
+             figsize: Figure size (width, height)
+
+         Returns:
+             matplotlib Figure
+
+         Example:
+             fig = notebook.training_loss(fm, style='notebook')
+             fig.show()
+         """
+         try:
+             import matplotlib.pyplot as plt
+             import numpy as np
+         except ImportError:
+             raise ImportError("matplotlib is required for training_loss visualization")
+
+         # Get training metrics
+         metrics = model.get_training_metrics() if hasattr(model, 'get_training_metrics') else model.get_metrics()
+
+         # Apply style
+         if style == 'notebook':
+             plt.style.use('seaborn-v0_8-whitegrid')
+         elif style == 'paper':
+             plt.style.use('seaborn-v0_8-paper')
+         elif style == 'presentation':
+             plt.style.use('seaborn-v0_8-talk')
+
+         fig, ax1 = plt.subplots(figsize=figsize)
+
+         # Get loss history
+         loss_history = metrics.get('loss_history', metrics.get('training_loss', []))
+         epochs = list(range(1, len(loss_history) + 1))
+
+         # Apply smoothing if requested
+         if smooth and len(loss_history) > 10:
+             window = min(10, len(loss_history) // 5)
+             loss_smoothed = np.convolve(loss_history, np.ones(window)/window, mode='valid')
+             epochs_smoothed = epochs[window-1:]
+             ax1.plot(epochs, loss_history, alpha=0.3, color='blue', label='Loss (raw)')
+             ax1.plot(epochs_smoothed, loss_smoothed, color='blue', linewidth=2, label='Loss (smoothed)')
+         else:
+             ax1.plot(epochs, loss_history, color='blue', linewidth=2, label='Loss')
+
+         ax1.set_xlabel('Epoch')
+         ax1.set_ylabel('Loss', color='blue')
+         ax1.tick_params(axis='y', labelcolor='blue')
+
+         # Show learning rate if available and requested
+         if show_learning_rate:
+             lr_history = metrics.get('lr_history', metrics.get('learning_rate', []))
+             if lr_history:
+                 ax2 = ax1.twinx()
+                 ax2.plot(epochs[:len(lr_history)], lr_history, color='red', linestyle='--', label='Learning Rate')
+                 ax2.set_ylabel('Learning Rate', color='red')
+                 ax2.tick_params(axis='y', labelcolor='red')
+                 ax2.set_yscale('log')
+
+         # Title
+         model_type = 'FM' if hasattr(model, 'dimensions') else 'Predictor'
+         model_id = getattr(model, 'id', 'unknown')[:8]
+         ax1.set_title(f'{model_type} Training Loss ({model_id}...)')
+
+         fig.tight_layout()
+         return fig
+
+     def embedding_space_3d(
+         self,
+         fm: 'FoundationalModel',
+         sample_size: int = 2000,
+         interactive: bool = True,
+         color_by: Optional[str] = None,
+         figsize: Tuple[int, int] = (800, 600)
+     ) -> Any:
+         """
+         Create 3D visualization of embedding space.
+
+         Args:
+             fm: FoundationalModel to visualize
+             sample_size: Number of points to sample
+             interactive: Use interactive plotly (True) or static matplotlib (False)
+             color_by: Column name to color points by
+             figsize: Figure size (width, height) for plotly
+
+         Returns:
+             plotly Figure (if interactive=True) or matplotlib Figure
+
+         Example:
+             fig = notebook.embedding_space_3d(fm, interactive=True)
+             fig.show()
+         """
+         # Get projections from server
+         projections = fm.get_projections()
+
+         points_3d = projections.get('3d', projections.get('pca_3d', []))
+         labels = projections.get('labels', [])
+         colors = projections.get('colors', projections.get(color_by, []))
+
+         if not points_3d:
+             raise ValueError("No 3D projection data available")
+
+         if interactive:
+             try:
+                 import plotly.graph_objects as go
+             except ImportError:
+                 raise ImportError("plotly is required for interactive 3D visualization")
+
+             x = [p[0] for p in points_3d[:sample_size]]
+             y = [p[1] for p in points_3d[:sample_size]]
+             z = [p[2] for p in points_3d[:sample_size]]
+
+             trace = go.Scatter3d(
+                 x=x, y=y, z=z,
+                 mode='markers',
+                 marker=dict(
+                     size=3,
+                     color=colors[:sample_size] if colors else None,
+                     colorscale='Viridis',
+                     opacity=0.8
+                 ),
+                 text=labels[:sample_size] if labels else None,
+                 hovertemplate='%{text}<br>(%{x:.2f}, %{y:.2f}, %{z:.2f})<extra></extra>'
+             )
+
+             fig = go.Figure(data=[trace])
+             fig.update_layout(
+                 title=f"Embedding Space 3D ({fm.id[:8]}...)",
+                 width=figsize[0],
+                 height=figsize[1],
+                 scene=dict(
+                     xaxis_title='PC1',
+                     yaxis_title='PC2',
+                     zaxis_title='PC3',
+                 )
+             )
+             return fig
+
+         else:
+             try:
+                 import matplotlib.pyplot as plt
+                 from mpl_toolkits.mplot3d import Axes3D
+             except ImportError:
+                 raise ImportError("matplotlib is required for static 3D visualization")
+
+             fig = plt.figure(figsize=(figsize[0]/100, figsize[1]/100))
+             ax = fig.add_subplot(111, projection='3d')
+
+             x = [p[0] for p in points_3d[:sample_size]]
+             y = [p[1] for p in points_3d[:sample_size]]
+             z = [p[2] for p in points_3d[:sample_size]]
+
+             scatter = ax.scatter(x, y, z, c=colors[:sample_size] if colors else None,
+                                  cmap='viridis', alpha=0.6, s=10)
+
+             ax.set_xlabel('PC1')
+             ax.set_ylabel('PC2')
+             ax.set_zlabel('PC3')
+             ax.set_title(f"Embedding Space 3D ({fm.id[:8]}...)")
+
+             if colors:
+                 plt.colorbar(scatter)
+
+             return fig
+
+     def training_movie(
+         self,
+         model: Union['FoundationalModel', 'Predictor'],
+         notebook_mode: bool = True,
+         fps: int = 2,
+         figsize: Tuple[int, int] = (10, 6)
+     ) -> Any:
+         """
+         Create animated training movie.
+
+         Shows how the model evolves during training, visualizing
+         loss curves and optionally embedding space evolution.
+
+         Args:
+             model: FoundationalModel or Predictor to visualize
+             notebook_mode: True for inline Jupyter display, False for animation
+             fps: Frames per second
+             figsize: Figure size
+
+         Returns:
+             IPython.display.HTML animation (if notebook_mode) or matplotlib FuncAnimation
+
+         Example:
+             movie = notebook.training_movie(fm, notebook_mode=True)
+             # Animation displays automatically in Jupyter
+         """
+         try:
+             import matplotlib.pyplot as plt
+             import matplotlib.animation as animation
+             import numpy as np
+         except ImportError:
+             raise ImportError("matplotlib is required for training_movie")
+
+         # Get training metrics
+         metrics = model.get_training_metrics() if hasattr(model, 'get_training_metrics') else model.get_metrics()
+         loss_history = metrics.get('loss_history', metrics.get('training_loss', []))
+
+         if not loss_history:
+             raise ValueError("No training history available")
+
+         fig, ax = plt.subplots(figsize=figsize)
+         ax.set_xlim(0, len(loss_history))
+         ax.set_ylim(0, max(loss_history) * 1.1)
+         ax.set_xlabel('Epoch')
+         ax.set_ylabel('Loss')
+         ax.set_title('Training Progress')
+
+         line, = ax.plot([], [], 'b-', linewidth=2)
+         epoch_text = ax.text(0.02, 0.95, '', transform=ax.transAxes,
+                              fontsize=12, verticalalignment='top')
+
+         def init():
+             line.set_data([], [])
+             epoch_text.set_text('')
+             return line, epoch_text
+
+         def animate(frame):
+             x = list(range(frame + 1))
+             y = loss_history[:frame + 1]
+             line.set_data(x, y)
+             epoch_text.set_text(f'Epoch: {frame + 1}/{len(loss_history)}')
+             return line, epoch_text
+
+         anim = animation.FuncAnimation(
+             fig, animate, init_func=init,
+             frames=len(loss_history), interval=1000//fps,
+             blit=True, repeat=False
+         )
+
+         if notebook_mode:
+             try:
+                 from IPython.display import HTML
+                 plt.close(fig)  # Prevent static display
+                 return HTML(anim.to_jshtml())
+             except ImportError:
+                 return anim
+         else:
+             return anim
+
+     def embedding_evolution(
+         self,
+         fm: 'FoundationalModel',
+         epoch_range: Optional[Tuple[int, int]] = None,
+         interactive: bool = True,
+         sample_size: int = 500
+     ) -> Any:
+         """
+         Visualize embedding evolution over epochs.
+
+         Shows how the embedding space changes during training.
+
+         Args:
+             fm: FoundationalModel to visualize
+             epoch_range: Tuple of (start_epoch, end_epoch) or None for all
+             interactive: Use interactive plotly
+             sample_size: Number of points to sample
+
+         Returns:
+             plotly Figure (if interactive) or matplotlib Figure
+
+         Example:
+             fig = notebook.embedding_evolution(fm, epoch_range=(1, 50))
+             fig.show()
+         """
+         # This requires epoch-by-epoch projection data which may not be available
+         # Return a placeholder for now
+         try:
+             import plotly.graph_objects as go
+         except ImportError:
+             raise ImportError("plotly is required for embedding_evolution")
+
+         # Try to get evolution data
+         try:
+             metrics = fm.get_training_metrics()
+             evolution = metrics.get('embedding_evolution', [])
+         except Exception:
+             evolution = []
+
+         if not evolution:
+             # Create placeholder figure
+             fig = go.Figure()
+             fig.add_annotation(
+                 text="Embedding evolution data not available",
+                 xref="paper", yref="paper",
+                 x=0.5, y=0.5, showarrow=False,
+                 font=dict(size=16)
+             )
+             fig.update_layout(
+                 title="Embedding Evolution (data not available)",
+                 width=800, height=600
+             )
+             return fig
+
+         # If data is available, create animation
+         frames = []
+         for epoch_data in evolution:
+             epoch = epoch_data.get('epoch', 0)
+             points = epoch_data.get('points', [])[:sample_size]
+
+             frame = go.Frame(
+                 data=[go.Scatter(
+                     x=[p[0] for p in points],
+                     y=[p[1] for p in points],
+                     mode='markers',
+                     marker=dict(size=5, opacity=0.6)
+                 )],
+                 name=str(epoch)
+             )
+             frames.append(frame)
+
+         fig = go.Figure(
+             data=frames[0].data if frames else [],
+             frames=frames,
+             layout=go.Layout(
+                 title="Embedding Evolution",
+                 updatemenus=[{
+                     "type": "buttons",
+                     "buttons": [
+                         {"label": "Play", "method": "animate", "args": [None, {"frame": {"duration": 500}}]},
+                         {"label": "Pause", "method": "animate", "args": [[None], {"mode": "immediate"}]}
+                     ]
+                 }],
+                 sliders=[{
+                     "steps": [{"args": [[str(e['epoch'])], {"frame": {"duration": 0}}], "label": str(e['epoch'])}
+                               for e in evolution],
+                     "currentvalue": {"prefix": "Epoch: "}
+                 }]
+             )
+         )
+
+         return fig
+
+     def training_comparison(
+         self,
+         models: List[Union['FoundationalModel', 'Predictor']],
+         labels: Optional[List[str]] = None,
+         figsize: Tuple[int, int] = (12, 6)
+     ) -> Any:
+         """
+         Compare training across multiple models.
+
+         Args:
+             models: List of models to compare
+             labels: Optional labels for each model
+             figsize: Figure size
+
+         Returns:
+             matplotlib Figure
+
+         Example:
+             fig = notebook.training_comparison([fm1, fm2], labels=['v1', 'v2'])
+             fig.show()
+         """
+         try:
+             import matplotlib.pyplot as plt
+             import numpy as np
+         except ImportError:
+             raise ImportError("matplotlib is required for training_comparison")
+
+         fig, ax = plt.subplots(figsize=figsize)
+
+         colors = plt.cm.tab10(np.linspace(0, 1, len(models)))
+
+         for i, model in enumerate(models):
+             metrics = model.get_training_metrics() if hasattr(model, 'get_training_metrics') else model.get_metrics()
+             loss_history = metrics.get('loss_history', metrics.get('training_loss', []))
+
+             label = labels[i] if labels and i < len(labels) else f"Model {i+1}"
+             ax.plot(loss_history, color=colors[i], linewidth=2, label=label)
+
+         ax.set_xlabel('Epoch')
+         ax.set_ylabel('Loss')
+         ax.set_title('Training Comparison')
+         ax.legend()
+         ax.grid(True, alpha=0.3)
+
+         fig.tight_layout()
+         return fig
+
+     def embedding_space_training(
+         self,
+         fm: 'FoundationalModel',
+         style: str = 'notebook',
+         figsize: Tuple[int, int] = (14, 6)
+     ) -> Any:
+         """
+         Plot embedding space training metrics.
+
+         Shows detailed metrics specific to foundational model training.
+
+         Args:
+             fm: FoundationalModel to visualize
+             style: Plot style
+             figsize: Figure size
+
+         Returns:
+             matplotlib Figure
+         """
+         try:
+             import matplotlib.pyplot as plt
+             import numpy as np
+         except ImportError:
+             raise ImportError("matplotlib is required")
+
+         metrics = fm.get_training_metrics()
+
+         fig, axes = plt.subplots(1, 3, figsize=figsize)
+
+         # Loss curve
+         loss = metrics.get('loss_history', [])
+         if loss:
+             axes[0].plot(loss, 'b-', linewidth=2)
+             axes[0].set_title('Training Loss')
+             axes[0].set_xlabel('Epoch')
+             axes[0].set_ylabel('Loss')
+             axes[0].grid(True, alpha=0.3)
+
+         # Learning rate
+         lr = metrics.get('lr_history', [])
+         if lr:
+             axes[1].semilogy(lr, 'r-', linewidth=2)
+             axes[1].set_title('Learning Rate')
+             axes[1].set_xlabel('Epoch')
+             axes[1].set_ylabel('LR')
+             axes[1].grid(True, alpha=0.3)
+
+         # Gradient norm if available
+         grad_norm = metrics.get('grad_norm_history', [])
+         if grad_norm:
+             axes[2].plot(grad_norm, 'g-', linewidth=2)
+             axes[2].set_title('Gradient Norm')
+             axes[2].set_xlabel('Epoch')
+             axes[2].set_ylabel('Norm')
+             axes[2].grid(True, alpha=0.3)
+         else:
+             axes[2].text(0.5, 0.5, 'No gradient data', ha='center', va='center')
+             axes[2].set_title('Gradient Norm')
+
+         fig.suptitle(f'Embedding Space Training ({fm.id[:8]}...)', fontsize=14)
+         fig.tight_layout()
+         return fig
+
+     def single_predictor_training(
+         self,
+         predictor: 'Predictor',
+         style: str = 'notebook',
+         figsize: Tuple[int, int] = (14, 6)
+     ) -> Any:
+         """
+         Plot single predictor training metrics.
+
+         Shows detailed metrics specific to predictor training including
+         accuracy, AUC, and confusion matrix evolution.
+
+         Args:
+             predictor: Predictor to visualize
+             style: Plot style
+             figsize: Figure size
+
+         Returns:
+             matplotlib Figure
+         """
+         try:
+             import matplotlib.pyplot as plt
+             import numpy as np
+         except ImportError:
+             raise ImportError("matplotlib is required")
+
+         metrics = predictor.get_metrics()
+
+         fig, axes = plt.subplots(1, 3, figsize=figsize)
+
+         # Loss curve
+         loss = metrics.get('loss_history', metrics.get('single_predictor', {}).get('loss_history', []))
+         if loss:
+             axes[0].plot(loss, 'b-', linewidth=2)
+             axes[0].set_title('Training Loss')
+             axes[0].set_xlabel('Epoch')
+             axes[0].set_ylabel('Loss')
+             axes[0].grid(True, alpha=0.3)
+
+         # Accuracy
+         accuracy = metrics.get('accuracy_history', metrics.get('single_predictor', {}).get('accuracy_history', []))
+         if accuracy:
+             axes[1].plot(accuracy, 'g-', linewidth=2)
+             axes[1].set_title('Accuracy')
+             axes[1].set_xlabel('Epoch')
+             axes[1].set_ylabel('Accuracy')
+             axes[1].set_ylim(0, 1)
+             axes[1].grid(True, alpha=0.3)
+         else:
+             final_acc = metrics.get('single_predictor', {}).get('accuracy', predictor.accuracy)
+             if final_acc:
+                 axes[1].axhline(y=final_acc, color='g', linewidth=2)
+                 axes[1].set_title(f'Final Accuracy: {final_acc:.4f}')
+             else:
+                 axes[1].text(0.5, 0.5, 'No accuracy data', ha='center', va='center')
+                 axes[1].set_title('Accuracy')
+
+         # AUC
+         auc_history = metrics.get('auc_history', metrics.get('single_predictor', {}).get('auc_history', []))
+         if auc_history:
+             axes[2].plot(auc_history, 'r-', linewidth=2)
+             axes[2].set_title('AUC')
+             axes[2].set_xlabel('Epoch')
+             axes[2].set_ylabel('AUC')
+             axes[2].set_ylim(0, 1)
+             axes[2].grid(True, alpha=0.3)
+         else:
+             final_auc = metrics.get('single_predictor', {}).get('roc_auc', predictor.auc)
+             if final_auc:
+                 axes[2].axhline(y=final_auc, color='r', linewidth=2)
+                 axes[2].set_title(f'Final AUC: {final_auc:.4f}')
+             else:
+                 axes[2].text(0.5, 0.5, 'No AUC data', ha='center', va='center')
+                 axes[2].set_title('AUC')
+
+         fig.suptitle(f'Predictor Training ({predictor.target_column})', fontsize=14)
+         fig.tight_layout()
+         return fig
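
For orientation, a minimal usage sketch of the new helper in a Jupyter cell, based only on the docstrings above. It assumes an initialized featrix client exposing get_notebook() plus already-trained fm (FoundationalModel) and predictor (Predictor) handles; the variable names are illustrative placeholders and the snippet is not part of the packaged file.

    # Minimal usage sketch (illustrative; not part of this package diff).
    # Assumes `featrix` is an initialized client exposing get_notebook(), and that
    # `fm` (FoundationalModel) and `predictor` (Predictor) are trained model handles.
    notebook = featrix.get_notebook()

    # Loss curve with smoothing and a learning-rate overlay (matplotlib Figure).
    fig = notebook.training_loss(fm, style='notebook', show_learning_rate=True)
    fig.show()

    # Interactive 3D projection of the embedding space (plotly Figure); raises
    # ValueError if the server returns no 3D projection data.
    fig3d = notebook.embedding_space_3d(fm, sample_size=1000, interactive=True)
    fig3d.show()

    # Compare loss curves across runs on one matplotlib axis.
    cmp_fig = notebook.training_comparison([fm, predictor], labels=['fm', 'predictor'])

    # Animated loss curve; with notebook_mode=True this returns IPython HTML and
    # renders inline when left as the last expression in a Jupyter cell.
    notebook.training_movie(fm, notebook_mode=True, fps=4)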