featrixsphere 0.2.5566__py3-none-any.whl → 0.2.6127__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- featrixsphere/__init__.py +37 -18
- featrixsphere/api/__init__.py +50 -0
- featrixsphere/api/api_endpoint.py +280 -0
- featrixsphere/api/client.py +396 -0
- featrixsphere/api/foundational_model.py +658 -0
- featrixsphere/api/http_client.py +209 -0
- featrixsphere/api/notebook_helper.py +584 -0
- featrixsphere/api/prediction_result.py +231 -0
- featrixsphere/api/predictor.py +537 -0
- featrixsphere/api/reference_record.py +227 -0
- featrixsphere/api/vector_database.py +269 -0
- featrixsphere/client.py +218 -8
- {featrixsphere-0.2.5566.dist-info → featrixsphere-0.2.6127.dist-info}/METADATA +1 -1
- featrixsphere-0.2.6127.dist-info/RECORD +17 -0
- featrixsphere-0.2.5566.dist-info/RECORD +0 -7
- {featrixsphere-0.2.5566.dist-info → featrixsphere-0.2.6127.dist-info}/WHEEL +0 -0
- {featrixsphere-0.2.5566.dist-info → featrixsphere-0.2.6127.dist-info}/entry_points.txt +0 -0
- {featrixsphere-0.2.5566.dist-info → featrixsphere-0.2.6127.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,658 @@
|
|
|
1
|
+
"""
|
|
2
|
+
FoundationalModel class for FeatrixSphere API.
|
|
3
|
+
|
|
4
|
+
Represents a trained embedding space (foundational model).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import time
|
|
8
|
+
import logging
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from datetime import datetime
|
|
11
|
+
from typing import Dict, Any, Optional, List, Union, TYPE_CHECKING
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from .http_client import ClientContext
|
|
15
|
+
import pandas as pd
|
|
16
|
+
import numpy as np
|
|
17
|
+
|
|
18
|
+
from .predictor import Predictor
|
|
19
|
+
from .vector_database import VectorDatabase
|
|
20
|
+
from .reference_record import ReferenceRecord
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class FoundationalModel:
    """
    Represents a foundational model (embedding space).

    Attributes:
        id: Session ID / FM ID
        name: Model name
        status: Training status ("training", "done", "error")
        dimensions: Embedding dimensions (d_model)
        epochs: Training epochs completed
        created_at: Creation timestamp

    Usage:
        # Create from client
        fm = featrix.create_foundational_model(
            name="customer_embeddings",
            csv_file="customers.csv"
        )

        # Wait for training
        fm.wait_for_training()

        # Create classifier
        predictor = fm.create_classifier(
            name="churn_predictor",
            target_column="churned"
        )

        # Create vector database
        vdb = fm.create_vector_database(
            name="customer_search",
            records=customer_records
        )

        # Encode records
        vectors = fm.encode([{"age": 35}, {"age": 42}])
    """

    # Server-assigned session ID; doubles as the foundational-model ID.
    id: str
    # Human-readable model name, when the server reports one.
    name: Optional[str] = None
    # Training status as reported by the server (e.g. "training", "done", "error").
    status: Optional[str] = None
    # Embedding dimensionality (d_model); populated once training reports it.
    dimensions: Optional[int] = None
    # Number of training epochs completed; populated once known.
    epochs: Optional[int] = None
    # Final training loss; populated once known.
    final_loss: Optional[float] = None
    # Client-side timestamp set when this object was constructed (not server time).
    created_at: Optional[datetime] = None

    # Internal: HTTP client context used for all API calls; excluded from repr.
    # None means the object is detached from a client and API methods will raise.
    _ctx: Optional['ClientContext'] = field(default=None, repr=False)
|
|
74
|
+
|
|
75
|
+
@classmethod
def from_response(
    cls,
    response: Dict[str, Any],
    ctx: Optional['ClientContext'] = None
) -> 'FoundationalModel':
    """Build a FoundationalModel out of a raw API response payload."""
    # The server may report these values under either of two keys,
    # depending on endpoint; accept both spellings.
    dims = response.get('d_model') or response.get('dimensions')
    trained_epochs = response.get('epochs') or response.get('final_epoch')

    return cls(
        id=response.get('session_id', ''),
        name=response.get('name'),
        status=response.get('status'),
        dimensions=dims,
        epochs=trained_epochs,
        final_loss=response.get('final_loss'),
        created_at=datetime.now(),
        _ctx=ctx,
    )
|
|
94
|
+
|
|
95
|
+
@classmethod
def from_session_id(
    cls,
    session_id: str,
    ctx: 'ClientContext'
) -> 'FoundationalModel':
    """Load a FoundationalModel for an existing session ID."""
    # Fetch the session record from the server first.
    payload = ctx.get_json(f"/compute/session/{session_id}")

    model = cls(
        id=session_id,
        name=payload.get('name'),
        status=payload.get('status'),
        created_at=datetime.now(),
        _ctx=ctx,
    )

    # Pull dimensions / epochs / final loss out of the payload when present.
    model._update_from_session(payload)
    return model
|
|
117
|
+
|
|
118
|
+
def create_classifier(
    self,
    target_column: str,
    name: Optional[str] = None,
    labels_file: Optional[str] = None,
    labels_df: Optional['pd.DataFrame'] = None,
    epochs: int = 0,
    rare_label_value: Optional[str] = None,
    class_imbalance: Optional[Dict[str, float]] = None,
    webhooks: Optional[Dict[str, str]] = None,
    **kwargs
) -> Predictor:
    """
    Train a classifier on top of this foundational model.

    Args:
        target_column: Column name to predict.
        name: Predictor name (optional).
        labels_file: Path to a labels file; if omitted (and no labels_df),
            the session's existing data is used.
        labels_df: DataFrame with labels (optional).
        epochs: Training epochs (0 = auto).
        rare_label_value: Rare class used for metrics.
        class_imbalance: Expected class distribution.
        webhooks: Webhook URLs for events.
        **kwargs: Extra training parameters forwarded to the server.

    Returns:
        Predictor with training already started.

    Example:
        predictor = fm.create_classifier(
            target_column="churned",
            name="churn_predictor"
        )
        predictor.wait_for_training()
    """
    # A classifier is just a predictor whose target is categorical ("set").
    return self._create_predictor(
        target_column=target_column,
        target_type="set",
        name=name,
        labels_file=labels_file,
        labels_df=labels_df,
        epochs=epochs,
        rare_label_value=rare_label_value,
        class_imbalance=class_imbalance,
        webhooks=webhooks,
        **kwargs
    )
|
|
166
|
+
|
|
167
|
+
def create_regressor(
    self,
    target_column: str,
    name: Optional[str] = None,
    labels_file: Optional[str] = None,
    labels_df: Optional['pd.DataFrame'] = None,
    epochs: int = 0,
    webhooks: Optional[Dict[str, str]] = None,
    **kwargs
) -> Predictor:
    """
    Train a regressor on top of this foundational model.

    Args:
        target_column: Column name to predict.
        name: Predictor name (optional).
        labels_file: Path to a labels file (optional).
        labels_df: DataFrame with labels (optional).
        epochs: Training epochs (0 = auto).
        webhooks: Webhook URLs for events.
        **kwargs: Extra training parameters forwarded to the server.

    Returns:
        Predictor with training already started.
    """
    # A regressor is a predictor whose target is continuous ("numeric").
    return self._create_predictor(
        target_column=target_column,
        target_type="numeric",
        name=name,
        labels_file=labels_file,
        labels_df=labels_df,
        epochs=epochs,
        webhooks=webhooks,
        **kwargs
    )
|
|
202
|
+
|
|
203
|
+
def _create_predictor(
    self,
    target_column: str,
    target_type: str,
    name: Optional[str] = None,
    labels_file: Optional[str] = None,
    labels_df: Optional['pd.DataFrame'] = None,
    epochs: int = 0,
    rare_label_value: Optional[str] = None,
    class_imbalance: Optional[Dict[str, float]] = None,
    webhooks: Optional[Dict[str, str]] = None,
    **kwargs
) -> Predictor:
    """Internal method to create a predictor.

    Routes to the multipart upload path when labels are supplied, otherwise
    posts a JSON request to train on the session's existing data.

    Args:
        target_column: Column name to predict.
        target_type: "set" (classifier) or "numeric" (regressor).
        name: Predictor name (optional).
        labels_file: Path to labels file (optional).
        labels_df: DataFrame with labels (optional).
        epochs: Training epochs (0 = auto).
        rare_label_value: Rare class for metrics.
        class_imbalance: Expected class distribution.
        webhooks: Webhook URLs for events.
        **kwargs: Extra training parameters merged into the JSON payload.

    Returns:
        Predictor object (training started).

    Raises:
        ValueError: If this model has no attached client context.
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    # If labels were provided, use the file upload endpoint.
    if labels_file or labels_df is not None:
        return self._create_predictor_with_labels(
            target_column=target_column,
            target_type=target_type,
            # BUG FIX: `name` was previously dropped on this code path, so
            # predictors trained with explicit labels ended up unnamed.
            name=name,
            labels_file=labels_file,
            labels_df=labels_df,
            epochs=epochs,
            rare_label_value=rare_label_value,
            class_imbalance=class_imbalance,
            webhooks=webhooks,
            **kwargs
        )

    # Use simple endpoint (train on existing session data).
    data = {
        "target_column": target_column,
        "target_column_type": target_type,
        "epochs": epochs,
    }
    if rare_label_value:
        data["rare_label_value"] = rare_label_value
    if class_imbalance:
        data["class_imbalance"] = class_imbalance
    if webhooks:
        data["webhooks"] = webhooks
    data.update(kwargs)

    response = self._ctx.post_json(
        f"/compute/session/{self.id}/train_predictor",
        data=data
    )

    predictor = Predictor(
        id=response.get('predictor_id', ''),
        session_id=self.id,
        target_column=target_column,
        target_type=target_type,
        name=name,
        status="training",
        _ctx=self._ctx,
        _foundational_model=self,
    )

    return predictor
|
|
265
|
+
|
|
266
|
+
def _create_predictor_with_labels(
    self,
    target_column: str,
    target_type: str,
    labels_file: Optional[str] = None,
    labels_df: Optional['pd.DataFrame'] = None,
    epochs: int = 0,
    rare_label_value: Optional[str] = None,
    class_imbalance: Optional[Dict[str, float]] = None,
    webhooks: Optional[Dict[str, str]] = None,
    name: Optional[str] = None,
    **kwargs
) -> Predictor:
    """Create predictor with a separate labels file.

    Uploads the labels (from a path or a DataFrame) as multipart form data
    to the train_predictor endpoint and returns a Predictor in "training"
    status.

    Args:
        target_column: Column name to predict.
        target_type: "set" (classifier) or "numeric" (regressor).
        labels_file: Path to a labels file.
        labels_df: DataFrame with labels (takes precedence over labels_file).
        epochs: Training epochs (0 = auto).
        rare_label_value: Rare class for metrics.
        class_imbalance: Expected class distribution.
        webhooks: Webhook URLs for events.
        name: Predictor name (BUG FIX: previously not accepted here, so it
            was silently swallowed by **kwargs and the Predictor was unnamed).
        **kwargs: Accepted for signature compatibility.
            NOTE(review): extra kwargs are currently NOT forwarded to the
            server on this path — confirm whether they belong in form_data.

    Returns:
        Predictor object (training started).

    Raises:
        ValueError: If neither labels_file nor labels_df is provided.
    """
    import gzip
    import io
    import json  # hoisted: was imported twice inside conditionals below

    # Prepare the labels payload.
    if labels_df is not None:
        # Serialize the DataFrame to CSV bytes.
        csv_buffer = io.StringIO()
        labels_df.to_csv(csv_buffer, index=False)
        file_content = csv_buffer.getvalue().encode('utf-8')
        filename = "labels.csv"
    elif labels_file:
        with open(labels_file, 'rb') as f:
            file_content = f.read()
        filename = labels_file.split('/')[-1]
    else:
        raise ValueError("Either labels_file or labels_df must be provided")

    # Compress large payloads; keep the gzip form only if it actually shrinks.
    if len(file_content) > 100_000:
        compressed = gzip.compress(file_content)
        if len(compressed) < len(file_content):
            file_content = compressed
            filename = filename + '.gz'

    # Build form data (multipart fields are sent as strings).
    form_data = {
        "target_column": target_column,
        "target_column_type": target_type,
        "epochs": str(epochs),
    }
    if rare_label_value:
        form_data["rare_label_value"] = rare_label_value
    if class_imbalance:
        form_data["class_imbalance"] = json.dumps(class_imbalance)
    if webhooks:
        form_data["webhooks"] = json.dumps(webhooks)

    files = {
        "file": (filename, file_content)
    }

    response = self._ctx.post_multipart(
        f"/compute/session/{self.id}/train_predictor",
        data=form_data,
        files=files
    )

    return Predictor(
        id=response.get('predictor_id', ''),
        session_id=self.id,
        target_column=target_column,
        target_type=target_type,
        name=name,
        status="training",
        _ctx=self._ctx,
        _foundational_model=self,
    )
|
|
337
|
+
|
|
338
|
+
def create_vector_database(
    self,
    name: Optional[str] = None,
    records: Optional[Union[List[Dict[str, Any]], 'pd.DataFrame']] = None
) -> VectorDatabase:
    """
    Create a vector database backed by this foundational model.

    Args:
        name: Database name.
        records: Initial records to add (optional).

    Returns:
        VectorDatabase object.

    Raises:
        ValueError: If this model has no attached client context.

    Example:
        vdb = fm.create_vector_database(
            name="customer_search",
            records=customer_records
        )
        similar = vdb.similarity_search({"age": 35}, k=5)
    """
    # CONSISTENCY FIX: every sibling API method guards against a detached
    # model; previously a None ctx was passed straight into VectorDatabase.
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    vdb = VectorDatabase.from_session(
        session_id=self.id,
        name=name,
        ctx=self._ctx,
        foundational_model=self,
    )

    # Add initial records if provided.
    if records is not None:
        vdb.add_records(records)

    return vdb
|
|
372
|
+
|
|
373
|
+
def create_reference_record(
    self,
    record: Dict[str, Any],
    name: Optional[str] = None
) -> ReferenceRecord:
    """
    Turn a concrete record into a reference point for similarity search.

    Reference records are anchors in the embedding space: when you only
    have positive examples (no negative class), you can simply search for
    records that embed near the anchor.

    Args:
        record: The record to anchor on.
        name: Optional name for the reference record.

    Returns:
        ReferenceRecord usable for similarity search.

    Raises:
        ValueError: If this model has no attached client context.

    Example:
        # Create a reference record from a high-value customer
        ref = fm.create_reference_record(
            record={"age": 35, "income": 100000, "plan": "premium"},
            name="high_value_customer"
        )

        # Find similar customers
        similar = ref.find_similar(k=10, vector_database=vdb)
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    return ReferenceRecord.from_record(
        record=record,
        session_id=self.id,
        name=name,
        ctx=self._ctx,
        foundational_model=self,
    )
|
|
412
|
+
|
|
413
|
+
def wait_for_training(
    self,
    max_wait_time: int = 3600,
    poll_interval: int = 10,
    show_progress: bool = True
) -> 'FoundationalModel':
    """
    Block until foundational-model training completes.

    Polls the session endpoint every `poll_interval` seconds, printing
    progress when the epoch or job status changes.

    Args:
        max_wait_time: Maximum wait time in seconds.
        poll_interval: Polling interval in seconds.
        show_progress: Print progress updates.

    Returns:
        Self (updated with final status / dimensions / epochs).

    Raises:
        ValueError: If this model has no attached client context.
        TimeoutError: If training doesn't complete in time.
        RuntimeError: If training fails.
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    start_time = time.time()
    last_epoch = None
    last_status = None

    while time.time() - start_time < max_wait_time:
        # Get session status.
        session_data = self._ctx.get_json(f"/compute/session/{self.id}")
        status = session_data.get('status', 'unknown')
        jobs = session_data.get('jobs', {})

        # Look for the embedding-space training job among the session's jobs.
        es_job = None
        for job_id, job in jobs.items():
            if job.get('job_type') in ('train_embedding_space', 'train_es', 'training'):
                es_job = job
                break

        # Extract progress info; job payloads use a couple of key spellings.
        current_epoch = None
        total_epochs = None
        if es_job:
            current_epoch = es_job.get('current_epoch') or es_job.get('epoch')
            total_epochs = es_job.get('total_epochs') or es_job.get('epochs')
            job_status = es_job.get('status', status)
        else:
            job_status = status

        # Print only when something changed, to avoid spamming the console.
        if show_progress:
            elapsed = int(time.time() - start_time)
            if current_epoch != last_epoch or job_status != last_status:
                if current_epoch and total_epochs:
                    print(f"[{elapsed}s] Training: epoch {current_epoch}/{total_epochs} ({job_status})")
                else:
                    print(f"[{elapsed}s] Training: {job_status}")
                last_epoch = current_epoch
                last_status = job_status

        # Check completion.
        if job_status == 'done' or status == 'done':
            self.status = 'done'
            self._update_from_session(session_data)
            if show_progress:
                # idiom fix: was an f-string with no placeholders
                print("Training complete!")
                if self.dimensions:
                    print(f"  Dimensions: {self.dimensions}")
                if self.epochs:
                    print(f"  Epochs: {self.epochs}")
            return self

        elif job_status == 'failed' or status == 'failed':
            error_msg = 'Unknown error'
            if es_job:
                error_msg = es_job.get('error', error_msg)
            self.status = 'error'
            raise RuntimeError(f"Training failed: {error_msg}")

        time.sleep(poll_interval)

    raise TimeoutError(f"Training did not complete within {max_wait_time}s")
|
|
497
|
+
|
|
498
|
+
def encode(
    self,
    records: Union[Dict[str, Any], List[Dict[str, Any]], 'pd.DataFrame']
) -> List[List[float]]:
    """
    Encode records into embedding vectors.

    Args:
        records: A single record dict, a list of record dicts, or a
            DataFrame (anything exposing ``to_dict('records')``).

    Returns:
        List of embedding vectors (lists of floats).

    Raises:
        ValueError: If this model has no attached client context.

    Example:
        vectors = fm.encode([
            {"age": 35, "income": 50000},
            {"age": 42, "income": 75000}
        ])
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    # Normalize every accepted input shape into a list of dicts.
    if isinstance(records, dict):
        batch = [records]
    elif hasattr(records, 'to_dict'):
        batch = records.to_dict('records')
    else:
        batch = records

    # Sanitize each record (non-finite floats, numpy scalars) before posting.
    payload = {"records": [self._clean_record(rec) for rec in batch]}

    result = self._ctx.post_json(
        f"/session/{self.id}/encode_records",
        data=payload
    )
    return result.get('embeddings', [])
|
|
535
|
+
|
|
536
|
+
def extend(
    self,
    new_data_file: Optional[str] = None,
    new_data_df: Optional['pd.DataFrame'] = None,
    epochs: Optional[int] = None,
    **kwargs
) -> 'FoundationalModel':
    """
    Extend this foundational model with additional data.

    Args:
        new_data_file: Path to the new data file.
        new_data_df: DataFrame with the new data.
        epochs: Additional training epochs.
        **kwargs: Additional parameters.

    Returns:
        New FoundationalModel instance (training started).

    Raises:
        ValueError: If this model has no attached client context.
        NotImplementedError: Always, for now — server support is pending.
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    # Extending would create a new session seeded with combined data;
    # the server-side API for that does not exist yet.
    raise NotImplementedError("extend() not yet implemented - use create_foundational_model with new data")
|
|
561
|
+
|
|
562
|
+
def get_projections(self) -> Dict[str, Any]:
    """Fetch 2D/3D projection data for visualizing the embedding space.

    Raises:
        ValueError: If this model has no attached client context.
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")
    return self._ctx.get_json(f"/session/{self.id}/projections")
|
|
568
|
+
|
|
569
|
+
def get_training_metrics(self) -> Dict[str, Any]:
    """Fetch training metrics and loss history for this model.

    Raises:
        ValueError: If this model has no attached client context.
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")
    return self._ctx.get_json(f"/session/{self.id}/training_metrics")
|
|
575
|
+
|
|
576
|
+
def get_model_card(self) -> Dict[str, Any]:
    """Fetch the model card describing this foundational model.

    Raises:
        ValueError: If this model has no attached client context.
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")
    return self._ctx.get_json(f"/session/{self.id}/model_card")
|
|
582
|
+
|
|
583
|
+
def list_predictors(self) -> List[Predictor]:
    """Fetch every predictor trained on this foundational model.

    Returns:
        List of Predictor handles populated from the server's registry.

    Raises:
        ValueError: If this model has no attached client context.
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    response = self._ctx.get_json(f"/session/{self.id}/predictor")
    registry = response.get('predictors', {})

    # One Predictor handle per registry entry, all sharing this session.
    return [
        Predictor(
            id=pred_id,
            session_id=self.id,
            target_column=info.get('target_column', ''),
            target_type=info.get('target_type', 'set'),
            name=info.get('name'),
            status=info.get('status'),
            accuracy=info.get('accuracy'),
            _ctx=self._ctx,
            _foundational_model=self,
        )
        for pred_id, info in registry.items()
    ]
|
|
607
|
+
|
|
608
|
+
def _update_from_session(self, session_data: Dict[str, Any]) -> None:
    """Refresh dimensions / epochs / final loss from a session payload.

    The server reports these under several possible keys depending on
    endpoint and version, so each field falls back through the known spots.
    """
    info = session_data.get('model_info', {})
    stats = session_data.get('training_stats', {})

    self.dimensions = (
        info.get('d_model')
        or info.get('embedding_dim')
        or session_data.get('d_model')
    )
    self.epochs = (
        stats.get('final_epoch')
        or stats.get('epochs_trained')
        or session_data.get('epochs')
    )
    self.final_loss = (
        stats.get('final_loss')
        or session_data.get('final_loss')
    )
|
|
628
|
+
|
|
629
|
+
def _clean_record(self, record: Dict[str, Any]) -> Dict[str, Any]:
    """Sanitize one record so every value serializes cleanly as JSON."""
    import math

    out: Dict[str, Any] = {}
    for field_name, val in record.items():
        # Non-finite floats aren't valid JSON; send them as nulls instead.
        if isinstance(val, float) and (math.isnan(val) or math.isinf(val)):
            val = None
        # Unwrap numpy-style scalars (anything exposing .item()) to plain
        # Python values.
        if hasattr(val, 'item'):
            val = val.item()
        out[field_name] = val
    return out
|
|
642
|
+
|
|
643
|
+
def to_dict(self) -> Dict[str, Any]:
    """Convert this model's public fields to a plain dictionary."""
    # Timestamps are serialized as ISO-8601 strings (or None when unset).
    created = self.created_at.isoformat() if self.created_at else None
    return {
        'id': self.id,
        'name': self.name,
        'status': self.status,
        'dimensions': self.dimensions,
        'epochs': self.epochs,
        'final_loss': self.final_loss,
        'created_at': created,
    }
|
|
654
|
+
|
|
655
|
+
def __repr__(self) -> str:
    """Compact one-line summary; unset status/dims fields are omitted."""
    pieces = [f"id='{self.id}'"]
    if self.status:
        pieces.append(f"status='{self.status}'")
    if self.dimensions:
        pieces.append(f"dims={self.dimensions}")
    return f"FoundationalModel({', '.join(pieces)})"