featrixsphere 0.2.5563__py3-none-any.whl → 0.2.5978__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- featrixsphere/__init__.py +37 -18
- featrixsphere/api/__init__.py +50 -0
- featrixsphere/api/api_endpoint.py +280 -0
- featrixsphere/api/client.py +396 -0
- featrixsphere/api/foundational_model.py +658 -0
- featrixsphere/api/http_client.py +209 -0
- featrixsphere/api/notebook_helper.py +584 -0
- featrixsphere/api/prediction_result.py +231 -0
- featrixsphere/api/predictor.py +537 -0
- featrixsphere/api/reference_record.py +227 -0
- featrixsphere/api/vector_database.py +269 -0
- featrixsphere/client.py +215 -12
- {featrixsphere-0.2.5563.dist-info → featrixsphere-0.2.5978.dist-info}/METADATA +1 -1
- featrixsphere-0.2.5978.dist-info/RECORD +17 -0
- featrixsphere-0.2.5563.dist-info/RECORD +0 -7
- {featrixsphere-0.2.5563.dist-info → featrixsphere-0.2.5978.dist-info}/WHEEL +0 -0
- {featrixsphere-0.2.5563.dist-info → featrixsphere-0.2.5978.dist-info}/entry_points.txt +0 -0
- {featrixsphere-0.2.5563.dist-info → featrixsphere-0.2.5978.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,537 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Predictor class for FeatrixSphere API.
|
|
3
|
+
|
|
4
|
+
Represents a trained predictor (classifier or regressor).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import time
|
|
8
|
+
import logging
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from datetime import datetime
|
|
11
|
+
from typing import Dict, Any, Optional, List, Union, TYPE_CHECKING
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from .http_client import ClientContext
|
|
15
|
+
from .foundational_model import FoundationalModel
|
|
16
|
+
import pandas as pd
|
|
17
|
+
|
|
18
|
+
from .prediction_result import PredictionResult
|
|
19
|
+
from .api_endpoint import APIEndpoint
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class Predictor:
    """
    Represents a trained predictor (classifier or regressor).

    Attributes:
        id: Predictor ID
        name: Predictor name
        target_column: Target column name
        target_type: Target type ("set", "numeric", "binary")
        status: Training status ("training", "done", "error")
        session_id: Parent session ID
        accuracy: Training accuracy (if available)
        auc: ROC AUC (if available)
        f1: F1 score (if available)
        created_at: Client-side timestamp of when this object was built
            (``from_response`` uses ``datetime.now()``, not a server value)

    Usage:
        # Create from foundational model
        predictor = fm.create_classifier(
            name="churn_predictor",
            target_column="churned"
        )

        # Wait for training
        predictor.wait_for_training()

        # Make prediction
        result = predictor.predict({"age": 35, "income": 50000})
        print(result.predicted_class)
        print(result.confidence)

        # Batch predictions
        results = predictor.batch_predict([
            {"age": 35, "income": 50000},
            {"age": 42, "income": 75000}
        ])
    """

    id: str
    session_id: str
    target_column: str
    target_type: str = "set"
    name: Optional[str] = None
    status: Optional[str] = None
    accuracy: Optional[float] = None
    auc: Optional[float] = None
    f1: Optional[float] = None
    created_at: Optional[datetime] = None

    # Internal: API client context and parent model (excluded from repr).
    _ctx: Optional['ClientContext'] = field(default=None, repr=False)
    _foundational_model: Optional['FoundationalModel'] = field(default=None, repr=False)

    # ------------------------------------------------------------------
    # Construction helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _first_not_none(mapping: Dict[str, Any], *keys: str, default: Any = None) -> Any:
        """Return the value of the first key in *keys* whose value is not None.

        Unlike ``mapping.get(a) or mapping.get(b)``, this keeps legitimate
        falsy values such as a metric of ``0.0``.
        """
        for key in keys:
            value = mapping.get(key)
            if value is not None:
                return value
        return default

    @classmethod
    def from_response(
        cls,
        response: Dict[str, Any],
        session_id: str,
        ctx: Optional['ClientContext'] = None,
        foundational_model: Optional['FoundationalModel'] = None
    ) -> 'Predictor':
        """
        Create a Predictor from an API response dictionary.

        Args:
            response: Raw API response. The ID is read from ``predictor_id`` or
                ``id``; the target type from ``target_type`` or
                ``target_column_type`` (default "set"); metrics from
                ``accuracy``, ``auc``/``roc_auc`` and ``f1``/``f1_score``.
            session_id: Parent session ID.
            ctx: Client context used for subsequent API calls.
            foundational_model: Parent foundational model, if known.

        Returns:
            A new Predictor. ``created_at`` is set to the local current time.
        """
        return cls(
            id=response.get('predictor_id') or response.get('id') or '',
            session_id=session_id,
            target_column=response.get('target_column', ''),
            target_type=response.get('target_type') or response.get('target_column_type') or 'set',
            name=response.get('name'),
            status=response.get('status'),
            accuracy=response.get('accuracy'),
            # Not-None lookup so a genuine 0.0 metric is not replaced by the fallback key.
            auc=cls._first_not_none(response, 'auc', 'roc_auc'),
            f1=cls._first_not_none(response, 'f1', 'f1_score'),
            created_at=datetime.now(),
            _ctx=ctx,
            _foundational_model=foundational_model,
        )

    @property
    def foundational_model(self) -> Optional['FoundationalModel']:
        """Get the parent foundational model."""
        return self._foundational_model

    def _require_ctx(self) -> 'ClientContext':
        """Return the client context, or raise ValueError if none is attached."""
        if not self._ctx:
            raise ValueError("Predictor not connected to client")
        return self._ctx

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------

    def predict(
        self,
        record: Dict[str, Any],
        best_metric_preference: Optional[str] = None
    ) -> 'PredictionResult':
        """
        Make a single prediction.

        Args:
            record: Input record dictionary
            best_metric_preference: Metric checkpoint to use ("roc_auc", "pr_auc", or None)

        Returns:
            PredictionResult with prediction, confidence, and prediction_uuid

        Raises:
            ValueError: If the predictor is not connected to a client.

        Example:
            result = predictor.predict({"age": 35, "income": 50000})
            print(result.predicted_class)  # "churned"
            print(result.confidence)  # 0.87
            print(result.prediction_uuid)  # UUID for feedback
        """
        ctx = self._require_ctx()

        cleaned_record = self._clean_record(record)

        request_payload = {
            "query_record": cleaned_record,
            "predictor_id": self.id,
        }
        if best_metric_preference:
            request_payload["best_metric_preference"] = best_metric_preference

        response = ctx.post_json(
            f"/session/{self.session_id}/predict",
            data=request_payload
        )

        return PredictionResult.from_response(response, cleaned_record, ctx)

    def batch_predict(
        self,
        records: Union[List[Dict[str, Any]], 'pd.DataFrame'],
        show_progress: bool = True,
        best_metric_preference: Optional[str] = None
    ) -> List['PredictionResult']:
        """
        Make batch predictions.

        Args:
            records: List of record dictionaries or DataFrame
            show_progress: Accepted for API compatibility; currently unused.
            best_metric_preference: Metric checkpoint to use (sent only when given)

        Returns:
            List of PredictionResult objects, one per prediction returned by
            the server. An empty input returns ``[]`` without a request.

        Raises:
            ValueError: If the predictor is not connected to a client.
        """
        ctx = self._require_ctx()

        # Anything with ``to_dict`` (e.g. a pandas DataFrame) is converted to row dicts.
        if hasattr(records, 'to_dict'):
            records = records.to_dict('records')

        cleaned_records = [self._clean_record(r) for r in records]
        if not cleaned_records:
            return []

        request_payload: Dict[str, Any] = {
            "table": {"rows": cleaned_records},
            "predictor_id": self.id,
        }
        # Previously built but never sent; forward it like predict() does.
        if best_metric_preference:
            request_payload["best_metric_preference"] = best_metric_preference

        response = ctx.post_json(
            f"/session/{self.session_id}/predict_table",
            data=request_payload
        )

        predictions = response.get('predictions', [])
        # NOTE(review): if the server returns more predictions than rows, the
        # extra results are paired with an empty record; fewer yields fewer results.
        return [
            PredictionResult.from_response(
                pred,
                cleaned_records[i] if i < len(cleaned_records) else {},
                ctx,
            )
            for i, pred in enumerate(predictions)
        ]

    def explain(
        self,
        record: Dict[str, Any],
        class_idx: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Explain a prediction using gradient attribution.

        Args:
            record: Input record to explain
            class_idx: Class index to explain (for multi-class); sent only when not None

        Returns:
            Explanation dictionary as returned by the server's explain endpoint

        Raises:
            ValueError: If the predictor is not connected to a client.
        """
        ctx = self._require_ctx()

        request_payload: Dict[str, Any] = {
            "query_record": self._clean_record(record),
            "predictor_id": self.id,
        }
        if class_idx is not None:
            request_payload["class_idx"] = class_idx

        return ctx.post_json(
            f"/session/{self.session_id}/explain",
            data=request_payload
        )

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def _find_training_job(self, jobs: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return the first ``train_single_predictor`` job targeting our column, if any."""
        # NOTE(review): jobs are matched by target column only; if several
        # predictors share a target, the first match is used.
        for job in jobs.values():
            if (job.get('job_type') == 'train_single_predictor'
                    and job.get('target_column') == self.target_column):
                return job
        return None

    def wait_for_training(
        self,
        max_wait_time: int = 3600,
        poll_interval: int = 10,
        show_progress: bool = True
    ) -> 'Predictor':
        """
        Wait for predictor training to complete.

        Polls ``/compute/session/{session_id}`` until the matching training
        job reports ``done`` (success) or ``failed``/``error`` (failure).

        Args:
            max_wait_time: Maximum wait time in seconds
            poll_interval: Polling interval in seconds
            show_progress: Print progress updates

        Returns:
            Self (updated with final status and, when available, metrics)

        Raises:
            ValueError: If the predictor is not connected to a client
            TimeoutError: If training doesn't complete in time
            RuntimeError: If training fails
        """
        ctx = self._require_ctx()

        start_time = time.time()
        deadline = start_time + max_wait_time
        last_status = None

        while time.time() < deadline:
            session_data = ctx.get_json(f"/compute/session/{self.session_id}")
            predictor_job = self._find_training_job(session_data.get('jobs', {}))

            if predictor_job:
                status = predictor_job.get('status', 'unknown')

                # Only report status transitions, not every poll.
                if status != last_status and show_progress:
                    elapsed = int(time.time() - start_time)
                    print(f"[{elapsed}s] Predictor training: {status}")
                    last_status = status

                if status == 'done':
                    self.status = 'done'
                    self._update_metrics()
                    if show_progress:
                        print("Predictor training complete!")
                        # ``is not None`` so a genuine 0.0 metric is still shown.
                        if self.accuracy is not None:
                            print(f" Accuracy: {self.accuracy:.4f}")
                        if self.auc is not None:
                            print(f" AUC: {self.auc:.4f}")
                    return self

                if status in ('failed', 'error'):
                    error_msg = predictor_job.get('error', 'Unknown error')
                    self.status = 'error'
                    raise RuntimeError(f"Predictor training failed: {error_msg}")

            # Never sleep past the deadline.
            remaining = deadline - time.time()
            if remaining <= 0:
                break
            time.sleep(min(poll_interval, remaining))

        raise TimeoutError(f"Predictor training did not complete within {max_wait_time}s")

    def train_more(
        self,
        epochs: int = 50,
        **kwargs
    ) -> 'Predictor':
        """
        Continue training the predictor.

        Args:
            epochs: Additional epochs to train
            **kwargs: Additional training parameters (merged last, so they can
                override ``epochs``/``target_column``)

        Returns:
            Self, with ``status`` set to "training"

        Raises:
            ValueError: If the predictor is not connected to a client.
        """
        ctx = self._require_ctx()

        data = {
            "epochs": epochs,
            "target_column": self.target_column,
            **kwargs
        }

        ctx.post_json(
            f"/compute/session/{self.session_id}/train_predictor_more",
            data=data
        )

        self.status = "training"
        return self

    # ------------------------------------------------------------------
    # Deployment / webhooks
    # ------------------------------------------------------------------

    def create_api_endpoint(
        self,
        name: str,
        api_key: Optional[str] = None,
        description: Optional[str] = None
    ) -> 'APIEndpoint':
        """
        Create a named API endpoint for this predictor.

        Args:
            name: Endpoint name
            api_key: API key (if None, auto-generate)
            description: Endpoint description

        Returns:
            APIEndpoint object

        Raises:
            ValueError: If the predictor is not connected to a client.

        Example:
            endpoint = predictor.create_api_endpoint(
                name="production_api",
                description="Production endpoint"
            )
            print(f"API Key: {endpoint.api_key}")
        """
        ctx = self._require_ctx()

        data = {
            "name": name,
            "predictor_id": self.id,
        }
        # Optional fields are omitted (not sent as null) when empty.
        if api_key:
            data["api_key"] = api_key
        if description:
            data["description"] = description

        response = ctx.post_json(
            f"/session/{self.session_id}/create_endpoint",
            data=data
        )

        return APIEndpoint.from_response(
            response=response,
            predictor_id=self.id,
            session_id=self.session_id,
            ctx=ctx,
            predictor=self,
        )

    def configure_webhooks(
        self,
        training_finished: Optional[str] = None,
        training_started: Optional[str] = None,
        alert_drift: Optional[str] = None,
        alert_performance_degradation: Optional[str] = None,
        alert_error_rate: Optional[str] = None,
        alert_quota_threshold: Optional[str] = None,
        prediction_error: Optional[str] = None,
        usage: Optional[str] = None,
        batch_job_completed: Optional[str] = None,
        webhook_secret: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Configure webhooks for various predictor events.

        Only arguments with a truthy value are sent.

        Args:
            training_finished: URL for training completion webhook
            training_started: URL for training start webhook
            alert_drift: URL for data drift alerts
            alert_performance_degradation: URL for performance alerts
            alert_error_rate: URL for error rate alerts
            alert_quota_threshold: URL for quota threshold alerts
            prediction_error: URL for prediction error webhook
            usage: URL for usage statistics (can be spammy!)
            batch_job_completed: URL for batch job completion
            webhook_secret: Secret for webhook verification

        Returns:
            The server's ``webhooks`` field, or the submitted mapping if absent

        Raises:
            ValueError: If the predictor is not connected to a client.

        Example:
            predictor.configure_webhooks(
                training_finished="https://api.example.com/webhooks/training",
                alert_drift="https://api.example.com/webhooks/drift",
                webhook_secret="my_secret_key"
            )
        """
        ctx = self._require_ctx()

        requested = {
            "training_finished": training_finished,
            "training_started": training_started,
            "alert_drift": alert_drift,
            "alert_performance_degradation": alert_performance_degradation,
            "alert_error_rate": alert_error_rate,
            "alert_quota_threshold": alert_quota_threshold,
            "prediction_error": prediction_error,
            "usage": usage,
            "batch_job_completed": batch_job_completed,
            "webhook_secret": webhook_secret,
        }
        webhooks = {event: url for event, url in requested.items() if url}

        response = ctx.post_json(
            f"/session/{self.session_id}/configure_webhooks",
            data={"webhooks": webhooks, "predictor_id": self.id}
        )

        return response.get('webhooks', webhooks)

    def get_webhooks(self) -> Dict[str, str]:
        """
        Get current webhook configuration.

        Returns:
            Dictionary of webhook event types to URLs (empty if none)

        Raises:
            ValueError: If the predictor is not connected to a client.
        """
        ctx = self._require_ctx()

        # NOTE(review): unlike configure/disable, no predictor_id is sent here.
        response = ctx.get_json(
            f"/session/{self.session_id}/webhooks"
        )

        return response.get('webhooks', {})

    def disable_webhook(self, event_type: str) -> None:
        """
        Disable a specific webhook event.

        Args:
            event_type: Webhook event type to disable

        Raises:
            ValueError: If the predictor is not connected to a client.
        """
        ctx = self._require_ctx()

        ctx.post_json(
            f"/session/{self.session_id}/disable_webhook",
            data={"event_type": event_type, "predictor_id": self.id}
        )

    # ------------------------------------------------------------------
    # Metrics
    # ------------------------------------------------------------------

    def get_metrics(self) -> Dict[str, Any]:
        """Get training metrics from the session-level training_metrics endpoint.

        Raises:
            ValueError: If the predictor is not connected to a client.
        """
        ctx = self._require_ctx()

        return ctx.get_json(f"/session/{self.session_id}/training_metrics")

    def _update_metrics(self) -> None:
        """Best-effort refresh of accuracy/auc/f1 from the server's ``single_predictor`` metrics."""
        try:
            metrics = self.get_metrics()
            sp_metrics = metrics.get('single_predictor', {})
            if sp_metrics:
                self.accuracy = sp_metrics.get('accuracy')
                self.auc = self._first_not_none(sp_metrics, 'roc_auc', 'auc')
                self.f1 = self._first_not_none(sp_metrics, 'f1', 'f1_score')
        except Exception as exc:
            # Metrics may not be available yet; keep this non-fatal but visible.
            logger.debug("Could not update metrics for predictor %s: %s", self.id, exc)

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------

    def _clean_record(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """
        Clean a record for API (JSON) submission.

        Converts objects exposing ``.item()`` (e.g. numpy scalars) to native
        Python values. Multi-element arrays become lists via ``.tolist()``.
        NaN/Inf floats are then replaced by None.
        """
        import math

        cleaned = {}
        for key, value in record.items():
            # Unwrap numpy-like scalars first, so the NaN/Inf check below also
            # catches types that are not float subclasses (e.g. numpy.float32).
            to_scalar = getattr(value, 'item', None)
            if callable(to_scalar):
                try:
                    value = to_scalar()
                except ValueError:
                    # e.g. multi-element ndarray: cannot collapse to a scalar.
                    to_list = getattr(value, 'tolist', None)
                    if not callable(to_list):
                        raise
                    value = to_list()

            if isinstance(value, float) and not math.isfinite(value):
                value = None

            cleaned[key] = value
        return cleaned

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-friendly dictionary (``created_at`` as ISO string)."""
        return {
            'id': self.id,
            'session_id': self.session_id,
            'target_column': self.target_column,
            'target_type': self.target_type,
            'name': self.name,
            'status': self.status,
            'accuracy': self.accuracy,
            'auc': self.auc,
            'f1': self.f1,
            'created_at': self.created_at.isoformat() if self.created_at else None,
        }

    def __repr__(self) -> str:
        status_str = f", status='{self.status}'" if self.status else ""
        # ``is not None`` so accuracy=0.0 is still shown.
        acc_str = f", accuracy={self.accuracy:.4f}" if self.accuracy is not None else ""
        return f"Predictor(id='{self.id}', target='{self.target_column}'{status_str}{acc_str})"