featrixsphere 0.2.5566__py3-none-any.whl → 0.2.6127__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,537 @@
1
+ """
2
+ Predictor class for FeatrixSphere API.
3
+
4
+ Represents a trained predictor (classifier or regressor).
5
+ """
6
+
7
+ import time
8
+ import logging
9
+ from dataclasses import dataclass, field
10
+ from datetime import datetime
11
+ from typing import Dict, Any, Optional, List, Union, TYPE_CHECKING
12
+
13
+ if TYPE_CHECKING:
14
+ from .http_client import ClientContext
15
+ from .foundational_model import FoundationalModel
16
+ import pandas as pd
17
+
18
+ from .prediction_result import PredictionResult
19
+ from .api_endpoint import APIEndpoint
20
+
21
logger = logging.getLogger(__name__)


@dataclass
class Predictor:
    """
    Represents a trained predictor (classifier or regressor).

    Attributes:
        id: Predictor ID
        name: Predictor name
        target_column: Target column name
        target_type: Target type ("set", "numeric", "binary")
        status: Training status ("training", "done", "error")
        session_id: Parent session ID
        accuracy: Training accuracy (if available)
        auc: ROC AUC (if available)
        f1: F1 score (if available)
        created_at: Creation timestamp (server-supplied when available,
            otherwise the local time the object was built)

    Usage:
        # Create from foundational model
        predictor = fm.create_classifier(
            name="churn_predictor",
            target_column="churned"
        )

        # Wait for training
        predictor.wait_for_training()

        # Make prediction
        result = predictor.predict({"age": 35, "income": 50000})
        print(result.predicted_class)
        print(result.confidence)

        # Batch predictions
        results = predictor.batch_predict([
            {"age": 35, "income": 50000},
            {"age": 42, "income": 75000}
        ])
    """

    id: str
    session_id: str
    target_column: str
    target_type: str = "set"
    name: Optional[str] = None
    status: Optional[str] = None
    accuracy: Optional[float] = None
    auc: Optional[float] = None
    f1: Optional[float] = None
    created_at: Optional[datetime] = None

    # Internal: client connection and parent model (hidden from repr).
    _ctx: Optional['ClientContext'] = field(default=None, repr=False)
    _foundational_model: Optional['FoundationalModel'] = field(default=None, repr=False)

    # ------------------------------------------------------------------
    # Construction helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _first_present(data: Dict[str, Any], *keys: str) -> Any:
        """Return the value of the first key in *keys* whose value is not None.

        Unlike an ``a or b`` chain, a legitimate falsy value such as ``0.0``
        is kept rather than skipped.
        """
        for key in keys:
            value = data.get(key)
            if value is not None:
                return value
        return None

    @staticmethod
    def _parse_timestamp(value: Any) -> datetime:
        """Parse a server-supplied timestamp, falling back to the current time.

        Accepts a ``datetime`` or an ISO-8601 string; anything else (missing,
        unparseable, non-string) yields ``datetime.now()``.
        """
        if isinstance(value, datetime):
            return value
        if isinstance(value, str):
            # fromisoformat() only accepts a trailing 'Z' from Python 3.11 on.
            text = value[:-1] + '+00:00' if value.endswith('Z') else value
            try:
                return datetime.fromisoformat(text)
            except ValueError:
                pass
        return datetime.now()

    @classmethod
    def from_response(
        cls,
        response: Dict[str, Any],
        session_id: str,
        ctx: Optional['ClientContext'] = None,
        foundational_model: Optional['FoundationalModel'] = None
    ) -> 'Predictor':
        """Create a Predictor from an API response dictionary.

        Args:
            response: Raw response; accepts both the ``predictor_id``/``id``,
                ``target_type``/``target_column_type``, ``auc``/``roc_auc`` and
                ``f1``/``f1_score`` key spellings.
            session_id: Parent session ID.
            ctx: Client context used for subsequent API calls.
            foundational_model: Parent foundational model, if known.

        Returns:
            A new Predictor instance.
        """
        return cls(
            id=response.get('predictor_id') or response.get('id', ''),
            session_id=session_id,
            target_column=response.get('target_column', ''),
            # Fall back to "set" even when the server sends an explicit null.
            target_type=(response.get('target_type')
                         or response.get('target_column_type')
                         or 'set'),
            name=response.get('name'),
            status=response.get('status'),
            accuracy=response.get('accuracy'),
            auc=cls._first_present(response, 'auc', 'roc_auc'),
            f1=cls._first_present(response, 'f1', 'f1_score'),
            created_at=cls._parse_timestamp(response.get('created_at')),
            _ctx=ctx,
            _foundational_model=foundational_model,
        )

    @property
    def foundational_model(self) -> Optional['FoundationalModel']:
        """Get the parent foundational model."""
        return self._foundational_model

    def _require_ctx(self) -> 'ClientContext':
        """Return the client context, raising if this predictor is detached.

        Raises:
            ValueError: If no client context is attached.
        """
        if not self._ctx:
            raise ValueError("Predictor not connected to client")
        return self._ctx

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------

    def predict(
        self,
        record: Dict[str, Any],
        best_metric_preference: Optional[str] = None
    ) -> 'PredictionResult':
        """
        Make a single prediction.

        Args:
            record: Input record dictionary
            best_metric_preference: Metric checkpoint to use ("roc_auc", "pr_auc", or None)

        Returns:
            PredictionResult with prediction, confidence, and prediction_uuid

        Raises:
            ValueError: If the predictor is not connected to a client

        Example:
            result = predictor.predict({"age": 35, "income": 50000})
            print(result.predicted_class)  # "churned"
            print(result.confidence)  # 0.87
            print(result.prediction_uuid)  # UUID for feedback
        """
        ctx = self._require_ctx()

        cleaned_record = self._clean_record(record)

        request_payload = {
            "query_record": cleaned_record,
            "predictor_id": self.id,
        }
        if best_metric_preference:
            request_payload["best_metric_preference"] = best_metric_preference

        response = ctx.post_json(
            f"/session/{self.session_id}/predict",
            data=request_payload
        )

        return PredictionResult.from_response(response, cleaned_record, ctx)

    def batch_predict(
        self,
        records: Union[List[Dict[str, Any]], 'pd.DataFrame'],
        show_progress: bool = True,
        best_metric_preference: Optional[str] = None
    ) -> List['PredictionResult']:
        """
        Make batch predictions via the ``predict_table`` endpoint.

        Args:
            records: List of record dictionaries or DataFrame
            show_progress: Accepted for API compatibility; currently unused
                (no progress output is produced)
            best_metric_preference: Metric checkpoint to use; forwarded to
                the server when set

        Returns:
            List of PredictionResult objects, one per prediction returned by
            the server (an empty list for empty input, without a request)

        Raises:
            ValueError: If the predictor is not connected to a client
        """
        ctx = self._require_ctx()

        # DataFrames expose to_dict(); convert to a list of row dicts.
        if hasattr(records, 'to_dict'):
            records = records.to_dict('records')

        cleaned_records = [self._clean_record(r) for r in records]
        if not cleaned_records:
            return []

        request_payload = {
            "table": {"rows": cleaned_records},
            "predictor_id": self.id,
        }
        if best_metric_preference:
            request_payload["best_metric_preference"] = best_metric_preference

        response = ctx.post_json(
            f"/session/{self.session_id}/predict_table",
            data=request_payload
        )

        predictions = response.get('predictions', [])
        if len(predictions) != len(cleaned_records):
            logger.warning(
                "predict_table returned %d predictions for %d records",
                len(predictions), len(cleaned_records),
            )

        # Pair each prediction with the record it was made for; extra
        # predictions (if any) get an empty record.
        return [
            PredictionResult.from_response(
                pred,
                cleaned_records[i] if i < len(cleaned_records) else {},
                ctx,
            )
            for i, pred in enumerate(predictions)
        ]

    def explain(
        self,
        record: Dict[str, Any],
        class_idx: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Explain a prediction using gradient attribution.

        Args:
            record: Input record to explain
            class_idx: Class index to explain (for multi-class)

        Returns:
            Explanation dictionary with feature attributions (as returned by
            the server)

        Raises:
            ValueError: If the predictor is not connected to a client
        """
        ctx = self._require_ctx()

        request_payload = {
            "query_record": self._clean_record(record),
            "predictor_id": self.id,
        }
        if class_idx is not None:
            request_payload["class_idx"] = class_idx

        return ctx.post_json(
            f"/session/{self.session_id}/explain",
            data=request_payload
        )

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def _find_training_job(self, jobs: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return the first ``train_single_predictor`` job for this target column."""
        # NOTE(review): jobs are matched by target column only; if several
        # predictors share a target, this may pick the wrong one.
        for job in jobs.values():
            if (job.get('job_type') == 'train_single_predictor'
                    and job.get('target_column') == self.target_column):
                return job
        return None

    def wait_for_training(
        self,
        max_wait_time: int = 3600,
        poll_interval: int = 10,
        show_progress: bool = True
    ) -> 'Predictor':
        """
        Wait for predictor training to complete.

        Args:
            max_wait_time: Maximum wait time in seconds
            poll_interval: Polling interval in seconds
            show_progress: Print progress updates

        Returns:
            Self (updated with final status)

        Raises:
            ValueError: If the predictor is not connected to a client
            TimeoutError: If training doesn't complete in time
            RuntimeError: If training fails (job status "failed" or "error")
        """
        ctx = self._require_ctx()

        start_time = time.time()
        last_status = None

        while time.time() - start_time < max_wait_time:
            session_data = ctx.get_json(f"/compute/session/{self.session_id}")
            predictor_job = self._find_training_job(session_data.get('jobs') or {})

            if predictor_job:
                status = predictor_job.get('status', 'unknown')

                if status != last_status and show_progress:
                    elapsed = int(time.time() - start_time)
                    print(f"[{elapsed}s] Predictor training: {status}")
                last_status = status

                if status == 'done':
                    self.status = 'done'
                    self._update_metrics()
                    if show_progress:
                        print("Predictor training complete!")
                        # Explicit None checks so a 0.0 metric is still shown.
                        if self.accuracy is not None:
                            print(f"  Accuracy: {self.accuracy:.4f}")
                        if self.auc is not None:
                            print(f"  AUC: {self.auc:.4f}")
                    return self

                if status in ('failed', 'error'):
                    error_msg = predictor_job.get('error', 'Unknown error')
                    self.status = 'error'
                    raise RuntimeError(f"Predictor training failed: {error_msg}")

            # Never sleep past the deadline.
            remaining = max_wait_time - (time.time() - start_time)
            if remaining <= 0:
                break
            time.sleep(min(poll_interval, remaining))

        raise TimeoutError(f"Predictor training did not complete within {max_wait_time}s")

    def train_more(
        self,
        epochs: int = 50,
        **kwargs
    ) -> 'Predictor':
        """
        Continue training the predictor.

        Args:
            epochs: Additional epochs to train
            **kwargs: Additional training parameters (may override the
                ``epochs``/``target_column`` entries in the request)

        Returns:
            Self (training started)

        Raises:
            ValueError: If the predictor is not connected to a client
        """
        ctx = self._require_ctx()

        data = {
            "epochs": epochs,
            "target_column": self.target_column,
            **kwargs
        }

        ctx.post_json(
            f"/compute/session/{self.session_id}/train_predictor_more",
            data=data
        )

        self.status = "training"
        return self

    # ------------------------------------------------------------------
    # Endpoints and webhooks
    # ------------------------------------------------------------------

    def create_api_endpoint(
        self,
        name: str,
        api_key: Optional[str] = None,
        description: Optional[str] = None
    ) -> 'APIEndpoint':
        """
        Create a named API endpoint for this predictor.

        Args:
            name: Endpoint name
            api_key: API key (if None, omitted from the request)
            description: Endpoint description

        Returns:
            APIEndpoint object

        Raises:
            ValueError: If the predictor is not connected to a client

        Example:
            endpoint = predictor.create_api_endpoint(
                name="production_api",
                description="Production endpoint"
            )
            print(f"API Key: {endpoint.api_key}")
        """
        ctx = self._require_ctx()

        data = {
            "name": name,
            "predictor_id": self.id,
        }
        if api_key:
            data["api_key"] = api_key
        if description:
            data["description"] = description

        response = ctx.post_json(
            f"/session/{self.session_id}/create_endpoint",
            data=data
        )

        return APIEndpoint.from_response(
            response=response,
            predictor_id=self.id,
            session_id=self.session_id,
            ctx=ctx,
            predictor=self,
        )

    def configure_webhooks(
        self,
        training_finished: Optional[str] = None,
        training_started: Optional[str] = None,
        alert_drift: Optional[str] = None,
        alert_performance_degradation: Optional[str] = None,
        alert_error_rate: Optional[str] = None,
        alert_quota_threshold: Optional[str] = None,
        prediction_error: Optional[str] = None,
        usage: Optional[str] = None,
        batch_job_completed: Optional[str] = None,
        webhook_secret: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Configure webhooks for various predictor events.

        Only arguments with a truthy value are sent.

        Args:
            training_finished: URL for training completion webhook
            training_started: URL for training start webhook
            alert_drift: URL for data drift alerts
            alert_performance_degradation: URL for performance alerts
            alert_error_rate: URL for error rate alerts
            alert_quota_threshold: URL for quota threshold alerts
            prediction_error: URL for prediction error webhook
            usage: URL for usage statistics (can be spammy!)
            batch_job_completed: URL for batch job completion
            webhook_secret: Secret for webhook verification

        Returns:
            The server's ``webhooks`` value, or the submitted mapping if the
            response has none

        Raises:
            ValueError: If the predictor is not connected to a client

        Example:
            predictor.configure_webhooks(
                training_finished="https://api.example.com/webhooks/training",
                alert_drift="https://api.example.com/webhooks/drift",
                webhook_secret="my_secret_key"
            )
        """
        ctx = self._require_ctx()

        candidates = {
            "training_finished": training_finished,
            "training_started": training_started,
            "alert_drift": alert_drift,
            "alert_performance_degradation": alert_performance_degradation,
            "alert_error_rate": alert_error_rate,
            "alert_quota_threshold": alert_quota_threshold,
            "prediction_error": prediction_error,
            "usage": usage,
            "batch_job_completed": batch_job_completed,
            "webhook_secret": webhook_secret,
        }
        webhooks = {event: value for event, value in candidates.items() if value}

        response = ctx.post_json(
            f"/session/{self.session_id}/configure_webhooks",
            data={"webhooks": webhooks, "predictor_id": self.id}
        )

        return response.get('webhooks', webhooks)

    def get_webhooks(self) -> Dict[str, str]:
        """
        Get current webhook configuration.

        Returns:
            Dictionary of webhook event types to URLs (empty if the response
            has no ``webhooks`` key)

        Raises:
            ValueError: If the predictor is not connected to a client
        """
        ctx = self._require_ctx()
        response = ctx.get_json(f"/session/{self.session_id}/webhooks")
        return response.get('webhooks', {})

    def disable_webhook(self, event_type: str) -> None:
        """
        Disable a specific webhook event.

        Args:
            event_type: Webhook event type to disable

        Raises:
            ValueError: If the predictor is not connected to a client
        """
        ctx = self._require_ctx()
        ctx.post_json(
            f"/session/{self.session_id}/disable_webhook",
            data={"event_type": event_type, "predictor_id": self.id}
        )

    # ------------------------------------------------------------------
    # Metrics
    # ------------------------------------------------------------------

    def get_metrics(self) -> Dict[str, Any]:
        """Get the session's training metrics from the server.

        Raises:
            ValueError: If the predictor is not connected to a client
        """
        ctx = self._require_ctx()
        return ctx.get_json(f"/session/{self.session_id}/training_metrics")

    def _update_metrics(self) -> None:
        """Refresh accuracy/auc/f1 from the server's ``single_predictor`` metrics.

        Best-effort: fetch failures are logged at debug level and ignored,
        and metrics the server does not report leave existing values intact.
        """
        try:
            metrics = self.get_metrics()
        except Exception as exc:  # metrics may not be available yet
            logger.debug("Could not fetch metrics for predictor %s: %s", self.id, exc)
            return

        if not isinstance(metrics, dict):
            return
        sp_metrics = metrics.get('single_predictor')
        if not sp_metrics or not isinstance(sp_metrics, dict):
            return

        accuracy = sp_metrics.get('accuracy')
        auc = self._first_present(sp_metrics, 'roc_auc', 'auc')
        f1 = self._first_present(sp_metrics, 'f1', 'f1_score')
        if accuracy is not None:
            self.accuracy = accuracy
        if auc is not None:
            self.auc = auc
        if f1 is not None:
            self.f1 = f1

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------

    def _clean_record(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Make a record JSON-friendly for API submission.

        Objects exposing a callable ``item()`` (numpy scalars, single-element
        arrays/Series) are unwrapped to native Python values. If ``item()``
        fails (e.g. a multi-element array), ``tolist()`` is used when
        available. Non-finite floats (NaN, +/-Inf) become ``None``; the
        finiteness check runs after unwrapping so numpy NaNs are caught too.
        """
        import math

        cleaned = {}
        for key, value in record.items():
            item = getattr(value, 'item', None)
            if callable(item):
                try:
                    value = item()
                except (ValueError, TypeError):
                    to_list = getattr(value, 'tolist', None)
                    if callable(to_list):
                        value = to_list()
            if isinstance(value, float) and not math.isfinite(value):
                value = None
            cleaned[key] = value
        return cleaned

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dictionary (``created_at`` as an ISO string)."""
        return {
            'id': self.id,
            'session_id': self.session_id,
            'target_column': self.target_column,
            'target_type': self.target_type,
            'name': self.name,
            'status': self.status,
            'accuracy': self.accuracy,
            'auc': self.auc,
            'f1': self.f1,
            'created_at': self.created_at.isoformat() if self.created_at else None,
        }

    def __repr__(self) -> str:
        status_str = f", status='{self.status}'" if self.status else ""
        # Explicit None check so an accuracy of 0.0 is still shown.
        acc_str = f", accuracy={self.accuracy:.4f}" if self.accuracy is not None else ""
        return f"Predictor(id='{self.id}', target='{self.target_column}'{status_str}{acc_str})"