featrixsphere 0.2.5566__py3-none-any.whl → 0.2.5978__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- featrixsphere/__init__.py +37 -18
- featrixsphere/api/__init__.py +50 -0
- featrixsphere/api/api_endpoint.py +280 -0
- featrixsphere/api/client.py +396 -0
- featrixsphere/api/foundational_model.py +658 -0
- featrixsphere/api/http_client.py +209 -0
- featrixsphere/api/notebook_helper.py +584 -0
- featrixsphere/api/prediction_result.py +231 -0
- featrixsphere/api/predictor.py +537 -0
- featrixsphere/api/reference_record.py +227 -0
- featrixsphere/api/vector_database.py +269 -0
- featrixsphere/client.py +211 -8
- {featrixsphere-0.2.5566.dist-info → featrixsphere-0.2.5978.dist-info}/METADATA +1 -1
- featrixsphere-0.2.5978.dist-info/RECORD +17 -0
- featrixsphere-0.2.5566.dist-info/RECORD +0 -7
- {featrixsphere-0.2.5566.dist-info → featrixsphere-0.2.5978.dist-info}/WHEEL +0 -0
- {featrixsphere-0.2.5566.dist-info → featrixsphere-0.2.5978.dist-info}/entry_points.txt +0 -0
- {featrixsphere-0.2.5566.dist-info → featrixsphere-0.2.5978.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,584 @@
|
|
|
1
|
+
"""
|
|
2
|
+
FeatrixNotebookHelper for Jupyter notebook visualization.
|
|
3
|
+
|
|
4
|
+
Provides visualization methods for training metrics, embedding spaces,
|
|
5
|
+
and model analysis in Jupyter notebooks.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
from typing import Dict, Any, Optional, List, Union, Tuple, TYPE_CHECKING
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from .foundational_model import FoundationalModel
|
|
13
|
+
from .predictor import Predictor
|
|
14
|
+
from .http_client import ClientContext
|
|
15
|
+
import matplotlib
|
|
16
|
+
import plotly
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class FeatrixNotebookHelper:
    """
    Helper class for Jupyter notebook visualization.

    Access via featrix.get_notebook() - no separate import needed!

    Plotting libraries (matplotlib, plotly, IPython) are imported inside the
    methods that need them, so constructing the helper does not require them.

    Usage:
        notebook = featrix.get_notebook()

        # Visualize training
        fig = notebook.training_loss(fm, style='notebook')
        fig.show()

        # 3D embedding space
        fig = notebook.embedding_space_3d(fm, interactive=True)
        fig.show()

        # Training movie
        movie = notebook.training_movie(fm, notebook_mode=True)
    """

    # Public ``style`` names -> matplotlib style sheets. Any other name keeps
    # the current matplotlib style.
    _STYLE_SHEETS = {
        'notebook': 'seaborn-v0_8-whitegrid',
        'paper': 'seaborn-v0_8-paper',
        'presentation': 'seaborn-v0_8-talk',
    }

    def __init__(self, ctx: Optional['ClientContext'] = None):
        """Initialize notebook helper with an optional client context (kept as ``self._ctx``)."""
        self._ctx = ctx

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _get_metrics(model: Any) -> Dict[str, Any]:
        """Fetch a model's metrics dict.

        Prefers ``get_training_metrics()`` and falls back to ``get_metrics()``.
        A falsy response (e.g. ``None``) is normalized to an empty dict.
        """
        if hasattr(model, 'get_training_metrics'):
            metrics = model.get_training_metrics()
        else:
            metrics = model.get_metrics()
        return metrics or {}

    @staticmethod
    def _first_series(source: Optional[Dict[str, Any]], *keys: Any) -> List[Any]:
        """Return the first sequence value found under ``keys`` as a list.

        Keys whose value is missing, ``None``, or a scalar are skipped. For example,
        a single final learning rate is a scalar, not a history. Objects exposing
        ``tolist()`` (numpy arrays) are converted first. Returns ``[]`` when no
        key holds a sequence.
        """
        if not source:
            return []
        for key in keys:
            value = source.get(key)
            if hasattr(value, 'tolist'):
                value = value.tolist()
            if isinstance(value, (list, tuple)):
                return list(value)
        return []

    @classmethod
    def _loss_history(cls, metrics: Dict[str, Any]) -> List[Any]:
        """Per-epoch loss values from ``loss_history``, falling back to ``training_loss``."""
        return cls._first_series(metrics, 'loss_history', 'training_loss')

    @classmethod
    def _lr_history(cls, metrics: Dict[str, Any]) -> List[Any]:
        """Per-epoch learning rates from ``lr_history``, falling back to ``learning_rate``."""
        return cls._first_series(metrics, 'lr_history', 'learning_rate')

    @classmethod
    def _resolve_colors(cls, projections: Dict[str, Any], color_by: Optional[str]) -> List[Any]:
        """Pick per-point color values.

        The caller's ``color_by`` column wins when present. Otherwise the
        server-provided ``colors`` entry is used.
        """
        if color_by:
            return cls._first_series(projections, color_by, 'colors')
        return cls._first_series(projections, 'colors')

    @staticmethod
    def _short_id(model: Any) -> str:
        """First 8 characters of ``model.id`` for titles, or ``'unknown'`` when the id is missing/empty."""
        model_id = getattr(model, 'id', None)
        return str(model_id)[:8] if model_id else 'unknown'

    @staticmethod
    def _select_epochs(evolution: List[Dict[str, Any]],
                       epoch_range: Optional[Tuple[int, int]]) -> List[Dict[str, Any]]:
        """Keep snapshots whose ``epoch`` (default 0) lies in the inclusive ``epoch_range``; ``None`` keeps all."""
        if epoch_range is None:
            return list(evolution)
        start, end = epoch_range
        return [snap for snap in evolution if start <= snap.get('epoch', 0) <= end]

    @staticmethod
    def _loss_ylim(values: List[float]) -> Tuple[float, float]:
        """Y-axis limits that contain every value with ~10% headroom.

        Non-negative series keep a zero baseline. Negative values extend the lower
        limit below zero, and a constant series still gets a non-empty range.
        """
        lo, hi = min(values), max(values)
        pad = 0.1 * (hi - lo) or 0.1 * abs(hi) or 1.0
        y_min = 0.0 if lo >= 0 else lo - pad
        y_max = hi * 1.1 if hi > 0 else hi + pad
        return y_min, y_max

    @classmethod
    def _style_context(cls, style: str) -> Any:
        """Context manager applying ``style`` only while a figure is being built.

        Unlike ``plt.style.use``, the user's global matplotlib style is restored
        on exit. Unknown names, and style sheets missing from the installed
        matplotlib, fall back to the current style instead of raising.
        """
        import contextlib
        import matplotlib.pyplot as plt

        sheet = cls._STYLE_SHEETS.get(style)
        if sheet is None or sheet not in plt.style.available:
            return contextlib.nullcontext()
        return plt.style.context(sheet)

    @staticmethod
    def _plot_history(ax: Any, values: List[Any], fmt: str, title: str, ylabel: str,
                      empty_text: str, log_y: bool = False, unit_range: bool = False) -> None:
        """Plot a per-epoch history on ``ax`` (1-based epochs), or a centered note when ``values`` is empty."""
        ax.set_title(title)
        if not values:
            ax.text(0.5, 0.5, empty_text, ha='center', va='center', transform=ax.transAxes)
            return
        epochs = list(range(1, len(values) + 1))
        plot = ax.semilogy if log_y else ax.plot
        plot(epochs, values, fmt, linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        if unit_range:
            ax.set_ylim(0, 1)
        ax.grid(True, alpha=0.3)

    @staticmethod
    def _plot_final_value(ax: Any, value: Any, color: str, name: str, empty_text: str) -> None:
        """Draw a horizontal line at a final scalar metric, or a note when it is missing/non-numeric.

        0.0 is a valid value and is plotted; only ``None`` or unconvertible
        values count as missing.
        """
        try:
            final = float(value) if value is not None else None
        except (TypeError, ValueError):
            final = None
        if final is None:
            ax.text(0.5, 0.5, empty_text, ha='center', va='center', transform=ax.transAxes)
            ax.set_title(name)
            return
        ax.axhline(y=final, color=color, linewidth=2)
        ax.set_title(f'Final {name}: {final:.4f}')

    # ------------------------------------------------------------------
    # Public plotting API
    # ------------------------------------------------------------------

    def training_loss(
        self,
        model: Union['FoundationalModel', 'Predictor'],
        style: str = 'notebook',
        show_learning_rate: bool = True,
        smooth: bool = True,
        figsize: Tuple[int, int] = (12, 6)
    ) -> Any:
        """
        Plot training loss curves.

        Args:
            model: FoundationalModel or Predictor to visualize
            style: Plot style ('notebook', 'paper', 'presentation'), applied to
                this figure only. Other names keep the current style.
            show_learning_rate: Show the learning-rate history (if any) on a
                secondary log-scale axis
            smooth: Overlay a moving average when there are more than 10 epochs
            figsize: Figure size (width, height) in inches

        Returns:
            matplotlib Figure

        Raises:
            ImportError: If matplotlib or numpy is not installed.

        Example:
            fig = notebook.training_loss(fm, style='notebook')
            fig.show()
        """
        try:
            import matplotlib.pyplot as plt
            import numpy as np
        except ImportError as err:
            raise ImportError("matplotlib is required for training_loss visualization") from err

        metrics = self._get_metrics(model)
        loss_history = self._loss_history(metrics)
        epochs = list(range(1, len(loss_history) + 1))

        with self._style_context(style):
            fig, ax1 = plt.subplots(figsize=figsize)

            if smooth and len(loss_history) > 10:
                # Moving average over ~1/5 of the run (at most 10 epochs). 'valid'
                # mode yields len - window + 1 points, aligned to each window's last epoch.
                window = min(10, len(loss_history) // 5)
                loss_smoothed = np.convolve(loss_history, np.ones(window) / window, mode='valid')
                ax1.plot(epochs, loss_history, alpha=0.3, color='blue', label='Loss (raw)')
                ax1.plot(epochs[window - 1:], loss_smoothed, color='blue', linewidth=2,
                         label='Loss (smoothed)')
            else:
                ax1.plot(epochs, loss_history, color='blue', linewidth=2, label='Loss')

            ax1.set_xlabel('Epoch')
            ax1.set_ylabel('Loss', color='blue')
            ax1.tick_params(axis='y', labelcolor='blue')
            handles, labels = ax1.get_legend_handles_labels()

            lr_history = self._lr_history(metrics) if show_learning_rate else []
            if lr_history:
                ax2 = ax1.twinx()
                # The LR series gets its own epoch axis, so a history longer than
                # the loss history cannot cause an x/y length mismatch.
                lr_epochs = list(range(1, len(lr_history) + 1))
                ax2.plot(lr_epochs, lr_history, color='red', linestyle='--', label='Learning Rate')
                ax2.set_ylabel('Learning Rate', color='red')
                ax2.tick_params(axis='y', labelcolor='red')
                ax2.set_yscale('log')
                lr_handles, lr_labels = ax2.get_legend_handles_labels()
                handles += lr_handles
                labels += lr_labels

            if handles:
                # twinx axes keep separate artist lists; merge them into one legend.
                ax1.legend(handles, labels, loc='best')

            model_type = 'FM' if hasattr(model, 'dimensions') else 'Predictor'
            ax1.set_title(f'{model_type} Training Loss ({self._short_id(model)}...)')
            fig.tight_layout()

        return fig

    def embedding_space_3d(
        self,
        fm: 'FoundationalModel',
        sample_size: int = 2000,
        interactive: bool = True,
        color_by: Optional[str] = None,
        figsize: Tuple[int, int] = (800, 600)
    ) -> Any:
        """
        Create 3D visualization of embedding space.

        Args:
            fm: FoundationalModel to visualize
            sample_size: Maximum number of points to plot. The first
                ``sample_size`` points are used; there is no random sampling.
            interactive: Use interactive plotly (True) or static matplotlib (False)
            color_by: Projection column to color points by. Takes precedence over
                the server-provided ``colors`` entry.
            figsize: Figure size (width, height) in pixels. The matplotlib path
                converts at 100 px per inch.

        Returns:
            plotly Figure (if interactive=True) or matplotlib Figure

        Raises:
            ValueError: If the projections contain no 3D points.
            ImportError: If the required plotting library is not installed.

        Example:
            fig = notebook.embedding_space_3d(fm, interactive=True)
            fig.show()
        """
        projections = fm.get_projections() or {}
        points_3d = self._first_series(projections, '3d', 'pca_3d')
        if not points_3d:
            raise ValueError("No 3D projection data available")

        subset = points_3d[:sample_size]
        n_points = len(subset)
        x = [p[0] for p in subset]
        y = [p[1] for p in subset]
        z = [p[2] for p in subset]

        labels = self._first_series(projections, 'labels')[:n_points]
        colors = self._resolve_colors(projections, color_by)
        if colors and len(colors) < n_points:
            # matplotlib rejects a color array shorter than the point list.
            logger.warning("Ignoring %d color values for %d points (length mismatch)",
                           len(colors), n_points)
            colors = []
        colors = colors[:n_points]
        title = f"Embedding Space 3D ({self._short_id(fm)}...)"

        if interactive:
            try:
                import plotly.graph_objects as go
            except ImportError as err:
                raise ImportError("plotly is required for interactive 3D visualization") from err

            # Plotly shows unresolved %{...} placeholders literally, so only
            # reference %{text} when labels are available.
            hovertemplate = '(%{x:.2f}, %{y:.2f}, %{z:.2f})<extra></extra>'
            if labels:
                hovertemplate = '%{text}<br>' + hovertemplate

            trace = go.Scatter3d(
                x=x, y=y, z=z,
                mode='markers',
                marker=dict(
                    size=3,
                    color=colors or None,
                    colorscale='Viridis',
                    opacity=0.8
                ),
                text=labels or None,
                hovertemplate=hovertemplate
            )

            fig = go.Figure(data=[trace])
            fig.update_layout(
                title=title,
                width=figsize[0],
                height=figsize[1],
                scene=dict(
                    xaxis_title='PC1',
                    yaxis_title='PC2',
                    zaxis_title='PC3',
                )
            )
            return fig

        try:
            import matplotlib.pyplot as plt
            from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 - registers '3d' projection on older matplotlib
        except ImportError as err:
            raise ImportError("matplotlib is required for static 3D visualization") from err

        fig = plt.figure(figsize=(figsize[0] / 100, figsize[1] / 100))
        ax = fig.add_subplot(111, projection='3d')

        if colors:
            scatter = ax.scatter(x, y, z, c=colors, cmap='viridis', alpha=0.6, s=10)
            fig.colorbar(scatter, ax=ax)
        else:
            ax.scatter(x, y, z, alpha=0.6, s=10)

        ax.set_xlabel('PC1')
        ax.set_ylabel('PC2')
        ax.set_zlabel('PC3')
        ax.set_title(title)
        return fig

    def training_movie(
        self,
        model: Union['FoundationalModel', 'Predictor'],
        notebook_mode: bool = True,
        fps: int = 2,
        figsize: Tuple[int, int] = (10, 6)
    ) -> Any:
        """
        Create animated training movie.

        Animates the loss curve epoch by epoch (1-based epochs on the x-axis).

        Args:
            model: FoundationalModel or Predictor to visualize
            notebook_mode: True to return an IPython HTML (JS) player; False (or
                IPython unavailable) returns the matplotlib animation object
            fps: Frames per second (must be > 0)
            figsize: Figure size in inches

        Returns:
            IPython.display.HTML (if notebook_mode and IPython is installed) or
            matplotlib.animation.FuncAnimation

        Raises:
            ValueError: If ``fps`` is not positive or there is no loss history.
            ImportError: If matplotlib is not installed.

        Example:
            movie = notebook.training_movie(fm, notebook_mode=True)
            # Displays automatically in Jupyter
        """
        if fps <= 0:
            raise ValueError("fps must be a positive number")

        metrics = self._get_metrics(model)
        loss_history = self._loss_history(metrics)
        if not loss_history:
            raise ValueError("No training history available")

        try:
            import matplotlib.pyplot as plt
            import matplotlib.animation as animation
        except ImportError as err:
            raise ImportError("matplotlib is required for training_movie") from err

        n_epochs = len(loss_history)
        fig, ax = plt.subplots(figsize=figsize)
        ax.set_xlim(0, n_epochs + 1)
        ax.set_ylim(*self._loss_ylim(loss_history))
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title('Training Progress')

        line, = ax.plot([], [], 'b-', linewidth=2)
        epoch_text = ax.text(0.02, 0.95, '', transform=ax.transAxes,
                             fontsize=12, verticalalignment='top')

        def init():
            line.set_data([], [])
            epoch_text.set_text('')
            return line, epoch_text

        def animate(frame):
            shown = frame + 1
            line.set_data(list(range(1, shown + 1)), loss_history[:shown])
            epoch_text.set_text(f'Epoch: {shown}/{n_epochs}')
            return line, epoch_text

        anim = animation.FuncAnimation(
            fig, animate, init_func=init,
            frames=n_epochs, interval=max(1, int(1000 / fps)),
            blit=True, repeat=False
        )

        if not notebook_mode:
            return anim
        try:
            from IPython.display import HTML
        except ImportError:
            return anim
        plt.close(fig)  # Prevent the static figure from also being displayed
        return HTML(anim.to_jshtml())

    def embedding_evolution(
        self,
        fm: 'FoundationalModel',
        epoch_range: Optional[Tuple[int, int]] = None,
        interactive: bool = True,
        sample_size: int = 500
    ) -> Any:
        """
        Visualize embedding evolution over epochs.

        Builds a plotly animation with one frame per ``embedding_evolution``
        snapshot in the training metrics. If no snapshots are available, or
        none fall inside ``epoch_range``, a placeholder figure is returned.

        Args:
            fm: FoundationalModel to visualize
            epoch_range: Inclusive (start_epoch, end_epoch) filter, or None for all
            interactive: Accepted for API compatibility; a plotly Figure is
                always returned
            sample_size: Maximum number of points per frame (first N are used)

        Returns:
            plotly Figure

        Raises:
            ImportError: If plotly is not installed.

        Example:
            fig = notebook.embedding_evolution(fm, epoch_range=(1, 50))
            fig.show()
        """
        try:
            import plotly.graph_objects as go
        except ImportError as err:
            raise ImportError("plotly is required for embedding_evolution") from err

        try:
            metrics = fm.get_training_metrics() or {}
            evolution = metrics.get('embedding_evolution') or []
        except Exception as exc:
            # Evolution data is optional: degrade to the placeholder, but say why.
            logger.warning("Could not fetch embedding evolution data: %s", exc)
            evolution = []

        evolution = self._select_epochs(evolution, epoch_range)

        if not evolution:
            fig = go.Figure()
            fig.add_annotation(
                text="Embedding evolution data not available",
                xref="paper", yref="paper",
                x=0.5, y=0.5, showarrow=False,
                font=dict(size=16)
            )
            fig.update_layout(
                title="Embedding Evolution (data not available)",
                width=800, height=600
            )
            return fig

        frames = []
        frame_names = []
        for epoch_data in evolution:
            name = str(epoch_data.get('epoch', 0))
            points = list(epoch_data.get('points') or [])[:sample_size]
            frames.append(go.Frame(
                data=[go.Scatter(
                    x=[p[0] for p in points],
                    y=[p[1] for p in points],
                    mode='markers',
                    marker=dict(size=5, opacity=0.6)
                )],
                name=name
            ))
            frame_names.append(name)

        # Slider steps jump to frames by name, so reuse the frame names rather than
        # re-reading epoch_data['epoch'] (which may be absent).
        steps = [{"args": [[name], {"frame": {"duration": 0}}], "label": name}
                 for name in frame_names]

        fig = go.Figure(
            data=frames[0].data,
            frames=frames,
            layout=go.Layout(
                title="Embedding Evolution",
                updatemenus=[{
                    "type": "buttons",
                    "buttons": [
                        {"label": "Play", "method": "animate", "args": [None, {"frame": {"duration": 500}}]},
                        {"label": "Pause", "method": "animate", "args": [[None], {"mode": "immediate"}]}
                    ]
                }],
                sliders=[{
                    "steps": steps,
                    "currentvalue": {"prefix": "Epoch: "}
                }]
            )
        )

        return fig

    def training_comparison(
        self,
        models: List[Union['FoundationalModel', 'Predictor']],
        labels: Optional[List[str]] = None,
        figsize: Tuple[int, int] = (12, 6)
    ) -> Any:
        """
        Compare training across multiple models.

        Args:
            models: List of models to compare
            labels: Optional labels for each model; missing entries default to
                "Model N"
            figsize: Figure size in inches

        Returns:
            matplotlib Figure

        Raises:
            ImportError: If matplotlib is not installed.

        Example:
            fig = notebook.training_comparison([fm1, fm2], labels=['v1', 'v2'])
            fig.show()
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError as err:
            raise ImportError("matplotlib is required for training_comparison") from err

        fig, ax = plt.subplots(figsize=figsize)
        palette = plt.cm.tab10

        for i, model in enumerate(models):
            loss_history = self._loss_history(self._get_metrics(model))
            label = labels[i] if labels and i < len(labels) else f"Model {i+1}"
            epochs = list(range(1, len(loss_history) + 1))
            # tab10 is a 10-color qualitative map: index it directly so up to 10
            # models get clearly distinct colors.
            ax.plot(epochs, loss_history, color=palette(i % 10), linewidth=2, label=label)

        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title('Training Comparison')
        if models:
            ax.legend()
        ax.grid(True, alpha=0.3)

        fig.tight_layout()
        return fig

    def embedding_space_training(
        self,
        fm: 'FoundationalModel',
        style: str = 'notebook',
        figsize: Tuple[int, int] = (14, 6)
    ) -> Any:
        """
        Plot embedding space training metrics.

        Three panels: training loss, learning rate (log scale) and gradient
        norm. A panel whose history is missing shows a short note instead.

        Args:
            fm: FoundationalModel to visualize
            style: Plot style ('notebook', 'paper', 'presentation'), applied to
                this figure only
            figsize: Figure size in inches

        Returns:
            matplotlib Figure

        Raises:
            ImportError: If matplotlib is not installed.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError as err:
            raise ImportError("matplotlib is required") from err

        metrics = fm.get_training_metrics() or {}

        with self._style_context(style):
            fig, axes = plt.subplots(1, 3, figsize=figsize)

            self._plot_history(axes[0], self._loss_history(metrics), 'b-',
                               'Training Loss', 'Loss', 'No loss data')
            self._plot_history(axes[1], self._lr_history(metrics), 'r-',
                               'Learning Rate', 'LR', 'No learning rate data', log_y=True)
            self._plot_history(axes[2], self._first_series(metrics, 'grad_norm_history'), 'g-',
                               'Gradient Norm', 'Norm', 'No gradient data')

            fig.suptitle(f'Embedding Space Training ({self._short_id(fm)}...)', fontsize=14)
            fig.tight_layout()

        return fig

    def single_predictor_training(
        self,
        predictor: 'Predictor',
        style: str = 'notebook',
        figsize: Tuple[int, int] = (14, 6)
    ) -> Any:
        """
        Plot single predictor training metrics.

        Three panels: loss, accuracy and AUC histories. Each history is read
        from the top-level metrics, falling back to the nested
        ``single_predictor`` dict. When an accuracy/AUC history is absent, the
        final value (``single_predictor.accuracy`` / ``roc_auc``, else the
        predictor's ``accuracy`` / ``auc`` attribute) is drawn as a horizontal line.

        Args:
            predictor: Predictor to visualize
            style: Plot style ('notebook', 'paper', 'presentation'), applied to
                this figure only
            figsize: Figure size in inches

        Returns:
            matplotlib Figure

        Raises:
            ImportError: If matplotlib is not installed.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError as err:
            raise ImportError("matplotlib is required") from err

        metrics = predictor.get_metrics() or {}
        nested = metrics.get('single_predictor')
        if not isinstance(nested, dict):
            nested = {}

        def history(key):
            return self._first_series(metrics, key) or self._first_series(nested, key)

        with self._style_context(style):
            fig, axes = plt.subplots(1, 3, figsize=figsize)

            self._plot_history(axes[0], history('loss_history'), 'b-',
                               'Training Loss', 'Loss', 'No loss data')

            accuracy = history('accuracy_history')
            if accuracy:
                self._plot_history(axes[1], accuracy, 'g-', 'Accuracy', 'Accuracy',
                                   'No accuracy data', unit_range=True)
            else:
                final_acc = nested.get('accuracy', getattr(predictor, 'accuracy', None))
                self._plot_final_value(axes[1], final_acc, 'g', 'Accuracy', 'No accuracy data')

            auc_history = history('auc_history')
            if auc_history:
                self._plot_history(axes[2], auc_history, 'r-', 'AUC', 'AUC',
                                   'No AUC data', unit_range=True)
            else:
                final_auc = nested.get('roc_auc', getattr(predictor, 'auc', None))
                self._plot_final_value(axes[2], final_auc, 'r', 'AUC', 'No AUC data')

            target = getattr(predictor, 'target_column', 'unknown')
            fig.suptitle(f'Predictor Training ({target})', fontsize=14)
            fig.tight_layout()

        return fig
|