featrixsphere 0.1.583__py3-none-any.whl → 0.2.1314__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,489 +0,0 @@
1
- """
2
- Improved Notebook Training Movie Functions V2
3
-
4
- This module contains enhanced versions of the training movie functions used in FeatrixSphereClient.
5
- Key improvements:
6
- - Fixed scaling issues: No more jarring scale changes between frames
7
- - Precomputed axis limits and data structures
8
- - Better performance with reduced redundant calculations
9
- - Consistent 1:1:1 aspect ratio for 3D plots
10
- - Optimized rendering for smoother animation
11
-
12
- These functions can be used to replace the existing movie methods in client.py
13
- when you want better performance and fixed scaling behavior.
14
- """
15
-
16
- import numpy as np
17
- import matplotlib.pyplot as plt
18
- from typing import Dict, Any, Optional
19
- import pandas as pd
20
-
21
-
22
class NotebookMovieV2:
    """Improved notebook movie functions with fixed axis scaling.

    Wraps a FeatrixSphereClient and caches dataset-wide axis limits so every
    frame of a training movie (and every subplot of the static overview) is
    drawn on the same scale.
    """

    def __init__(self, client):
        """Initialize with reference to the FeatrixSphereClient."""
        self.client = client
        self._cached_limits = {}  # id(epoch_projections) -> precomputed limits entry
        self._cached_data = {}  # Not read or written elsewhere in this class

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _axis_limits(values, margin=0.1, degenerate_pad=0.5):
        """Return padded ``[low, high]`` limits for *values*, or ``None``.

        Non-finite values (NaN, inf, None) are ignored. The range is widened by
        ``margin`` times its span on each side. When every value is identical
        the span is zero, so ``degenerate_pad`` is used instead; identical low
        and high limits make matplotlib warn and pick its own range.
        """
        arr = np.asarray(values, dtype=float)
        arr = arr[np.isfinite(arr)]
        if arr.size == 0:
            return None
        low = float(arr.min())
        high = float(arr.max())
        span = high - low
        pad = span * margin if span > 0 else degenerate_pad
        return [low - pad, high + pad]

    @staticmethod
    def _find_projection(epoch_projections, epoch):
        """Return the first projection whose ``'epoch'`` equals *epoch* (missing counts as 0)."""
        for proj_data in epoch_projections.values():
            if proj_data.get('epoch', 0) == epoch:
                return proj_data
        return None

    def _get_cached_limits(self, epoch_projections):
        """Return the fixed axis limits for *epoch_projections*, computing them once.

        The cache is keyed by ``id()``. Ids can be reused by new objects once the
        original is garbage-collected, so each entry also stores the dict it was
        computed from. The entry is only trusted if it refers to this same object.
        """
        session_key = id(epoch_projections)
        cached = self._cached_limits.get(session_key)
        if cached is None or cached.get('source') is not epoch_projections:
            self._precompute_limits_and_data(epoch_projections, session_key)
            cached = self._cached_limits[session_key]
        return cached

    @staticmethod
    def _ensure_3d_axes(ax):
        """Return a 3D axes occupying the same slot as *ax*.

        If *ax* is already 3D (has a ``zaxis``) it is returned unchanged.
        Otherwise it is removed from its figure and replaced by a 3D axes in
        the same subplot slot, or at the same position for non-grid axes.
        """
        if hasattr(ax, 'zaxis'):
            return ax
        # Importing Axes3D registers the '3d' projection on older matplotlib.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

        # Capture the figure and layout BEFORE removal: removing an artist may
        # detach it from its figure. (Axes.get_geometry() no longer exists in
        # matplotlib >= 3.6, so the subplotspec is used instead.)
        fig = ax.figure
        get_spec = getattr(ax, 'get_subplotspec', None)
        spec = get_spec() if get_spec is not None else None
        position = ax.get_position()
        ax.remove()
        if spec is not None:
            return fig.add_subplot(spec, projection='3d')
        return fig.add_axes(position, projection='3d')

    @staticmethod
    def _apply_cube_limits(ax, cached_data):
        """Give a 3D axes equal-length x/y/z limits centred on the cached bounds.

        Using the dataset-wide cached limits, rather than the current frame's
        points, keeps the scale fixed from epoch to epoch. Does nothing if any
        of the three cached limits is missing.
        """
        bounds = [cached_data.get(key) for key in ('xlim', 'ylim', 'zlim')]
        if not all(bounds):
            return
        # Cached limits are always padded, so every span (and `half`) is > 0.
        half = max(high - low for low, high in bounds) / 2
        mid_x, mid_y, mid_z = ((low + high) * 0.5 for low, high in bounds)
        ax.set_xlim(mid_x - half, mid_x + half)
        ax.set_ylim(mid_y - half, mid_y + half)
        ax.set_zlim(mid_z - half, mid_z + half)
        # Equal limits only look like a cube if the box itself is cubic.
        # set_box_aspect is missing on older matplotlib, so guard it.
        if hasattr(ax, 'set_box_aspect'):
            ax.set_box_aspect((1, 1, 1))

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def plot_embedding_evolution_frame_v2(self, ax, epoch_projections: Dict[str, Any], current_epoch: int):
        """Plot the embedding space for *current_epoch* on a fixed scale.

        Axis limits come from all epochs in *epoch_projections* (computed once
        and cached), so successive frames do not rescale. Points with
        ``x``/``y``/``z`` are drawn in 3D inside a cube. Points with only
        ``x``/``y`` are drawn in 2D with equal aspect.

        Args:
            ax: Matplotlib axes to draw on. A 2D axes is replaced by a 3D one
                in the same slot when the data has a ``z`` column.
            epoch_projections: Mapping whose values are dicts with ``'epoch'``
                and ``'coords'`` (a list of point records).
            current_epoch: Epoch whose projection should be drawn.

        Returns:
            The axes that was actually drawn on. This differs from *ax* if it had
            to be replaced by a 3D axes.
        """
        if not epoch_projections:
            ax.text(0.5, 0.5, 'No embedding evolution data',
                    transform=ax.transAxes, ha='center', va='center')
            return ax

        current_projection = self._find_projection(epoch_projections, current_epoch)
        if not current_projection:
            ax.text(0.5, 0.5, f'No projection data for epoch {current_epoch}',
                    transform=ax.transAxes, ha='center', va='center')
            return ax

        coords = current_projection.get('coords', [])
        if not coords:
            ax.text(0.5, 0.5, 'No coordinate data',
                    transform=ax.transAxes, ha='center', va='center')
            return ax

        df = pd.DataFrame(coords)
        cached_data = self._get_cached_limits(epoch_projections)

        if all(col in df.columns for col in ('x', 'y', 'z')):
            # 3D visualization on a fixed, dataset-wide cube.
            ax = self._ensure_3d_axes(ax)
            ax.scatter(df['x'], df['y'], df['z'], alpha=0.7, s=25, c='steelblue')
            ax.set_xlabel('Dimension 1', fontweight='bold')
            ax.set_ylabel('Dimension 2', fontweight='bold')
            ax.set_zlabel('Dimension 3', fontweight='bold')
            self._apply_cube_limits(ax, cached_data)
        elif 'x' in df.columns and 'y' in df.columns:
            # 2D projection fallback
            ax.scatter(df['x'], df['y'], alpha=0.6, s=20)
            ax.set_xlabel('Dimension 1', fontweight='bold')
            ax.set_ylabel('Dimension 2', fontweight='bold')
            if cached_data['xlim'] and cached_data['ylim']:
                ax.set_xlim(cached_data['xlim'])
                ax.set_ylim(cached_data['ylim'])

        ax.set_title(f'Embedding Space - Epoch {current_epoch}', fontweight='bold')
        ax.grid(True, alpha=0.3)

        # Equal aspect only for 2D axes. 3D axes get their 1:1:1 ratio from
        # _apply_cube_limits, and set_aspect('equal') raises on 3D axes in
        # matplotlib < 3.6.
        if not hasattr(ax, 'zaxis'):
            ax.set_aspect('equal', adjustable='box')
        return ax

    def _precompute_limits_and_data(self, epoch_projections: Dict[str, Any], session_key: Any):
        """Compute fixed, padded axis limits over every epoch and cache them.

        Computing the limits once over all epochs prevents the scale jumps that
        happen when limits are recalculated for each frame. The entry is stored
        in ``self._cached_limits[session_key]`` with keys ``xlim``/``ylim``/``zlim``
        (each ``[low, high]`` or ``None``), ``total_points`` (number of x values),
        ``epochs``, and ``source`` (the dict the limits were computed from).
        """
        all_x, all_y, all_z = [], [], []

        # Collect all coordinates from all epochs
        for proj_data in epoch_projections.values():
            coords = proj_data.get('coords', [])
            if not coords:
                continue
            df = pd.DataFrame(coords)
            if 'x' in df.columns:
                all_x.extend(df['x'].tolist())
            if 'y' in df.columns:
                all_y.extend(df['y'].tolist())
            if 'z' in df.columns:
                all_z.extend(df['z'].tolist())

        margin = 0.1  # 10% margin on each side
        xlim = self._axis_limits(all_x, margin)
        ylim = self._axis_limits(all_y, margin)
        zlim = self._axis_limits(all_z, margin)

        self._cached_limits[session_key] = {
            'xlim': xlim,
            'ylim': ylim,
            'zlim': zlim,
            'total_points': len(all_x),
            'epochs': len(epoch_projections),
            # Holding a reference keeps the id() used as the key from being
            # reused by another object while this entry exists.
            'source': epoch_projections,
        }

        print(f"📏 V2: Precomputed fixed axis limits for {len(epoch_projections)} epochs")
        print(f"   🎯 X: {xlim}, Y: {ylim}, Z: {zlim}")
        print(f"   📊 Total points: {len(all_x)}")

    def create_static_evolution_plot_v2(self, epoch_projections: Dict[str, Any],
                                        sample_size: int, color_by: Optional[str],
                                        session_id: str):
        """Create a grid of per-epoch embedding plots sharing one fixed scale.

        Args:
            epoch_projections: Mapping whose values are dicts with ``'epoch'``
                and ``'coords'``.
            sample_size: Maximum points drawn per epoch. Larger epochs are
                randomly sampled with a fixed seed. ``None`` disables sampling.
            color_by: Optional column name used to colour points (viridis).
            session_id: Session identifier shown (truncated) in the title.

        Returns:
            The matplotlib Figure.
        """
        epochs = sorted(v.get('epoch', 0) for v in epoch_projections.values())

        if not epochs:
            fig, ax = plt.subplots(1, 1, figsize=(10, 8))
            ax.text(0.5, 0.5, 'No epoch projection data', transform=ax.transAxes, ha='center', va='center')
            return fig

        cached_data = self._get_cached_limits(epoch_projections)

        # First projection per epoch: same first-match rule as a linear scan,
        # but built once instead of rescanning for every subplot.
        by_epoch = {}
        for proj_data in epoch_projections.values():
            by_epoch.setdefault(proj_data.get('epoch', 0), proj_data)

        n_epochs = len(epochs)
        cols = min(4, n_epochs)
        rows = (n_epochs + cols - 1) // cols
        fig_width = max(12, 3 * cols)
        fig_height = max(8, 3 * rows)

        # squeeze=False always yields a 2D array, so flatten() works for any grid.
        fig, axes = plt.subplots(rows, cols, figsize=(fig_width, fig_height), squeeze=False)
        axes = axes.flatten()

        for ax, epoch in zip(axes, epochs):
            epoch_data = by_epoch.get(epoch)
            coords = epoch_data.get('coords', []) if epoch_data else []

            if coords:
                df = pd.DataFrame(coords)
                if sample_size is not None and len(df) > sample_size:
                    df = df.sample(sample_size, random_state=42)

                scatter_kwargs = {'alpha': 0.7, 's': 25}
                if color_by and color_by in df.columns:
                    scatter_kwargs.update(c=df[color_by], cmap='viridis')
                else:
                    scatter_kwargs['c'] = 'steelblue'

                if all(col in df.columns for col in ('x', 'y', 'z')):
                    ax = self._ensure_3d_axes(ax)
                    ax.scatter(df['x'], df['y'], df['z'], **scatter_kwargs)
                    ax.set_xlabel('Dimension 1', fontsize=10)
                    ax.set_ylabel('Dimension 2', fontsize=10)
                    ax.set_zlabel('Dimension 3', fontsize=10)
                    # Same cube for every subplot so epochs are comparable.
                    self._apply_cube_limits(ax, cached_data)
                elif 'x' in df.columns and 'y' in df.columns:
                    ax.scatter(df['x'], df['y'], **scatter_kwargs)
                    ax.set_xlabel('Dimension 1', fontsize=10)
                    ax.set_ylabel('Dimension 2', fontsize=10)
                    if cached_data['xlim'] and cached_data['ylim']:
                        ax.set_xlim(cached_data['xlim'])
                        ax.set_ylim(cached_data['ylim'])
                    ax.set_aspect('equal', adjustable='box')

            ax.set_title(f'Epoch {epoch}', fontweight='bold', fontsize=11)
            ax.grid(True, alpha=0.3)
            ax.tick_params(axis='both', which='major', labelsize=8)

        # Hide unused grid cells
        for extra_ax in axes[n_epochs:]:
            extra_ax.set_visible(False)

        fig.suptitle(f'Embedding Evolution V2 - Session {str(session_id)[:12]}...',
                     fontsize=16, fontweight='bold')
        fig.tight_layout()

        return fig

    def create_interactive_training_movie_v2(self, training_metrics, epoch_projections, session_id,
                                             show_embedding_evolution, show_loss_evolution):
        """Create an ipywidgets training movie with play/speed/epoch controls.

        Falls back to ``self.client._create_static_training_movie`` when
        ipywidgets/IPython cannot be imported. Each frame draws the embedding
        (fixed scale, via ``plot_embedding_evolution_frame_v2``) and/or the
        loss curve (via ``self.client._plot_loss_evolution_frame``).

        Returns:
            An IPython ``HTML`` object. It holds a usage banner, or an error
            message when there is no training data.
        """
        try:
            from ipywidgets import widgets, Layout, interactive_output
            from IPython.display import display, HTML
        except ImportError:
            print("⚠️ ipywidgets not available - falling back to static movie")
            return self.client._create_static_training_movie(
                training_metrics, epoch_projections, (15, 10), 'notebook',
                None, show_embedding_evolution, show_loss_evolution, 2
            )

        progress_info = training_metrics.get('progress_info', {})
        loss_history = progress_info.get('loss_history', [])
        training_info = training_metrics.get('training_info', [])

        if not loss_history and not training_info:
            return HTML("<div style='color: red;'>No training data available for movie</div>")

        all_epochs = [entry.get('epoch', 0) for entry in loss_history]
        all_epochs.extend(entry.get('epoch', 0) for entry in training_info)

        if not all_epochs:
            return HTML("<div style='color: red;'>No epoch data found</div>")

        # Sliders start at 1; an IntSlider with max < min raises, so clamp.
        max_epoch = max(1, max(all_epochs))

        # Warm the limits cache so the first frame does not pay for it.
        if show_embedding_evolution and epoch_projections:
            self._get_cached_limits(epoch_projections)

        def update_movie(epoch=1):
            """Render the frame for *epoch*; errors are reported, not raised."""
            fig = None
            try:
                if show_embedding_evolution and show_loss_evolution:
                    fig = plt.figure(figsize=(16, 8))
                    ax2 = fig.add_subplot(1, 2, 1, projection='3d')  # embedding
                    ax1 = fig.add_subplot(1, 2, 2)  # loss
                elif show_loss_evolution:
                    fig, ax1 = plt.subplots(1, 1, figsize=(10, 6))
                    ax2 = None
                else:
                    fig = plt.figure(figsize=(12, 8))
                    ax2 = fig.add_subplot(1, 1, 1, projection='3d')
                    ax1 = None

                if show_embedding_evolution and ax2 is not None:
                    ax2 = self.plot_embedding_evolution_frame_v2(ax2, epoch_projections, epoch)
                    ax2.set_title('🌌 Featrix Sphere - 3D Embedding Space', fontweight='bold', fontsize=14)

                if show_loss_evolution and ax1 is not None:
                    self.client._plot_loss_evolution_frame(ax1, loss_history, training_info, epoch)
                    ax1.set_title('📊 Training Loss', fontweight='bold', fontsize=12)
                    ax1.tick_params(axis='both', which='major', labelsize=10)
                    ax1.grid(True, alpha=0.3)

                fig.tight_layout()
                plt.show()

            except Exception as e:
                print(f"Error in movie frame {epoch}: {e}")
                # Don't leave a half-drawn figure open (it would accumulate
                # across frames or render later in the cell output).
                if fig is not None:
                    plt.close(fig)

        epoch_slider = widgets.IntSlider(
            value=1,
            min=1,
            max=max_epoch,
            step=1,
            description='Epoch:',
            style={'description_width': '70px'},
            layout=Layout(width='600px')
        )

        play_button = widgets.Play(
            value=1,
            min=1,
            max=max_epoch,
            step=1,
            description="Press play",
            disabled=False,
            interval=800  # Slightly slower for better visibility
        )

        speed_slider = widgets.IntSlider(
            value=800,
            min=200,
            max=2000,
            step=100,
            description='Speed (ms):',
            style={'description_width': '90px'},
            layout=Layout(width='350px')
        )

        # Play drives the slider client-side; speed changes the play interval.
        widgets.jslink((play_button, 'value'), (epoch_slider, 'value'))

        def update_speed(change):
            play_button.interval = change['new']
        speed_slider.observe(update_speed, names='value')

        control_box = widgets.VBox([
            widgets.HTML("<b>🎬 Animation Controls</b>"),
            widgets.HBox([play_button, speed_slider])
        ])
        main_controls = widgets.HBox([control_box, epoch_slider])

        # interactive_output binds the existing slider without re-displaying it
        # (interact() would render a second copy of epoch_slider).
        frame_output = interactive_output(update_movie, {'epoch': epoch_slider})
        display(main_controls, frame_output)

        return HTML(f"""
        <div style='background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                    color: white; padding: 20px; border-radius: 15px; margin: 15px 0;
                    box-shadow: 0 8px 32px rgba(0,0,0,0.3);'>
            <h3>🎬 Interactive Training Movie V2 - Session {session_id[:12]}...</h3>
            <div style='background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; margin-top: 10px;'>
                <p><strong>✨ Enhanced Features:</strong></p>
                <ul style='margin: 10px 0; padding-left: 20px;'>
                    <li><strong>🔒 Fixed Scaling:</strong> No more jarring scale changes between frames</li>
                    <li><strong>⚡ Performance:</strong> Optimized rendering for smoother playback</li>
                    <li><strong>🎮 Enhanced Controls:</strong> Better play/pause and speed controls</li>
                    <li><strong>📐 Consistent Ratios:</strong> 1:1:1 aspect ratio maintained throughout</li>
                </ul>
                <p><strong>🎮 How to Use:</strong></p>
                <ul style='margin: 10px 0; padding-left: 20px;'>
                    <li>Use the <strong>Play button</strong> to automatically advance through epochs</li>
                    <li>Adjust <strong>Speed</strong> to control playback rate (200ms = fast, 2000ms = slow)</li>
                    <li>Drag the <strong>Epoch slider</strong> to jump to specific epochs</li>
                    <li>Watch how training progresses with <em>consistent scaling</em>!</li>
                </ul>
            </div>
        </div>
        """)
-
478
-
479
- # Usage example:
480
- #
481
- # from featrixsphere.client_movie_v2 import NotebookMovieV2
482
- #
483
- # # In your FeatrixSphereClient instance:
484
- # movie_v2 = NotebookMovieV2(client)
485
- #
486
- # # Use improved functions:
487
- # movie_v2.plot_embedding_evolution_frame_v2(ax, epoch_projections, current_epoch)
488
- # movie_v2.create_interactive_training_movie_v2(training_metrics, epoch_projections,
489
- # session_id, True, True)