featrixsphere 0.2.5563__py3-none-any.whl → 0.2.5978__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,658 @@
1
+ """
2
+ FoundationalModel class for FeatrixSphere API.
3
+
4
+ Represents a trained embedding space (foundational model).
5
+ """
6
+
7
+ import time
8
+ import logging
9
+ from dataclasses import dataclass, field
10
+ from datetime import datetime
11
+ from typing import Dict, Any, Optional, List, Union, TYPE_CHECKING
12
+
13
+ if TYPE_CHECKING:
14
+ from .http_client import ClientContext
15
+ import pandas as pd
16
+ import numpy as np
17
+
18
+ from .predictor import Predictor
19
+ from .vector_database import VectorDatabase
20
+ from .reference_record import ReferenceRecord
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
@dataclass
class FoundationalModel:
    """
    A trained embedding space (foundational model) on the FeatrixSphere API.

    From a trained model you can spin off classifiers/regressors, vector
    databases, and reference records, or encode raw records into vectors.

    Usage:
        # Create from client
        fm = featrix.create_foundational_model(
            name="customer_embeddings",
            csv_file="customers.csv"
        )

        # Wait for training
        fm.wait_for_training()

        # Create classifier
        predictor = fm.create_classifier(
            name="churn_predictor",
            target_column="churned"
        )

        # Create vector database
        vdb = fm.create_vector_database(
            name="customer_search",
            records=customer_records
        )

        # Encode records
        vectors = fm.encode([{"age": 35}, {"age": 42}])
    """

    # Session ID / FM ID on the server.
    id: str
    # Human-readable model name.
    name: Optional[str] = None
    # Training status: "training", "done", or "error".
    status: Optional[str] = None
    # Embedding dimensionality (d_model).
    dimensions: Optional[int] = None
    # Training epochs completed.
    epochs: Optional[int] = None
    # Final training loss, if reported by the server.
    final_loss: Optional[float] = None
    # Local timestamp of when this client object was built (not server time).
    created_at: Optional[datetime] = None

    # Internal client context used for HTTP calls; hidden from repr.
    _ctx: Optional['ClientContext'] = field(default=None, repr=False)
74
+
75
+ @classmethod
76
+ def from_response(
77
+ cls,
78
+ response: Dict[str, Any],
79
+ ctx: Optional['ClientContext'] = None
80
+ ) -> 'FoundationalModel':
81
+ """Create FoundationalModel from API response."""
82
+ session_id = response.get('session_id', '')
83
+
84
+ return cls(
85
+ id=session_id,
86
+ name=response.get('name'),
87
+ status=response.get('status'),
88
+ dimensions=response.get('d_model') or response.get('dimensions'),
89
+ epochs=response.get('epochs') or response.get('final_epoch'),
90
+ final_loss=response.get('final_loss'),
91
+ created_at=datetime.now(),
92
+ _ctx=ctx,
93
+ )
94
+
95
+ @classmethod
96
+ def from_session_id(
97
+ cls,
98
+ session_id: str,
99
+ ctx: 'ClientContext'
100
+ ) -> 'FoundationalModel':
101
+ """Load FoundationalModel from session ID."""
102
+ # Get session info
103
+ session_data = ctx.get_json(f"/compute/session/{session_id}")
104
+
105
+ fm = cls(
106
+ id=session_id,
107
+ name=session_data.get('name'),
108
+ status=session_data.get('status'),
109
+ created_at=datetime.now(),
110
+ _ctx=ctx,
111
+ )
112
+
113
+ # Try to get model info
114
+ fm._update_from_session(session_data)
115
+
116
+ return fm
117
+
118
+ def create_classifier(
119
+ self,
120
+ target_column: str,
121
+ name: Optional[str] = None,
122
+ labels_file: Optional[str] = None,
123
+ labels_df: Optional['pd.DataFrame'] = None,
124
+ epochs: int = 0,
125
+ rare_label_value: Optional[str] = None,
126
+ class_imbalance: Optional[Dict[str, float]] = None,
127
+ webhooks: Optional[Dict[str, str]] = None,
128
+ **kwargs
129
+ ) -> Predictor:
130
+ """
131
+ Create a classifier predictor from this foundational model.
132
+
133
+ Args:
134
+ target_column: Column name to predict
135
+ name: Predictor name (optional)
136
+ labels_file: Path to labels file (optional - use existing data if not provided)
137
+ labels_df: DataFrame with labels (optional)
138
+ epochs: Training epochs (0 = auto)
139
+ rare_label_value: Rare class for metrics
140
+ class_imbalance: Expected class distribution
141
+ webhooks: Webhook URLs for events
142
+ **kwargs: Additional training parameters
143
+
144
+ Returns:
145
+ Predictor object (training started)
146
+
147
+ Example:
148
+ predictor = fm.create_classifier(
149
+ target_column="churned",
150
+ name="churn_predictor"
151
+ )
152
+ predictor.wait_for_training()
153
+ """
154
+ return self._create_predictor(
155
+ target_column=target_column,
156
+ target_type="set",
157
+ name=name,
158
+ labels_file=labels_file,
159
+ labels_df=labels_df,
160
+ epochs=epochs,
161
+ rare_label_value=rare_label_value,
162
+ class_imbalance=class_imbalance,
163
+ webhooks=webhooks,
164
+ **kwargs
165
+ )
166
+
167
+ def create_regressor(
168
+ self,
169
+ target_column: str,
170
+ name: Optional[str] = None,
171
+ labels_file: Optional[str] = None,
172
+ labels_df: Optional['pd.DataFrame'] = None,
173
+ epochs: int = 0,
174
+ webhooks: Optional[Dict[str, str]] = None,
175
+ **kwargs
176
+ ) -> Predictor:
177
+ """
178
+ Create a regressor predictor from this foundational model.
179
+
180
+ Args:
181
+ target_column: Column name to predict
182
+ name: Predictor name (optional)
183
+ labels_file: Path to labels file (optional)
184
+ labels_df: DataFrame with labels (optional)
185
+ epochs: Training epochs (0 = auto)
186
+ webhooks: Webhook URLs for events
187
+ **kwargs: Additional training parameters
188
+
189
+ Returns:
190
+ Predictor object (training started)
191
+ """
192
+ return self._create_predictor(
193
+ target_column=target_column,
194
+ target_type="numeric",
195
+ name=name,
196
+ labels_file=labels_file,
197
+ labels_df=labels_df,
198
+ epochs=epochs,
199
+ webhooks=webhooks,
200
+ **kwargs
201
+ )
202
+
203
+ def _create_predictor(
204
+ self,
205
+ target_column: str,
206
+ target_type: str,
207
+ name: Optional[str] = None,
208
+ labels_file: Optional[str] = None,
209
+ labels_df: Optional['pd.DataFrame'] = None,
210
+ epochs: int = 0,
211
+ rare_label_value: Optional[str] = None,
212
+ class_imbalance: Optional[Dict[str, float]] = None,
213
+ webhooks: Optional[Dict[str, str]] = None,
214
+ **kwargs
215
+ ) -> Predictor:
216
+ """Internal method to create predictor."""
217
+ if not self._ctx:
218
+ raise ValueError("FoundationalModel not connected to client")
219
+
220
+ # If labels file provided, use the file upload endpoint
221
+ if labels_file or labels_df is not None:
222
+ return self._create_predictor_with_labels(
223
+ target_column=target_column,
224
+ target_type=target_type,
225
+ labels_file=labels_file,
226
+ labels_df=labels_df,
227
+ epochs=epochs,
228
+ rare_label_value=rare_label_value,
229
+ class_imbalance=class_imbalance,
230
+ webhooks=webhooks,
231
+ **kwargs
232
+ )
233
+
234
+ # Use simple endpoint (train on existing session data)
235
+ data = {
236
+ "target_column": target_column,
237
+ "target_column_type": target_type,
238
+ "epochs": epochs,
239
+ }
240
+ if rare_label_value:
241
+ data["rare_label_value"] = rare_label_value
242
+ if class_imbalance:
243
+ data["class_imbalance"] = class_imbalance
244
+ if webhooks:
245
+ data["webhooks"] = webhooks
246
+ data.update(kwargs)
247
+
248
+ response = self._ctx.post_json(
249
+ f"/compute/session/{self.id}/train_predictor",
250
+ data=data
251
+ )
252
+
253
+ predictor = Predictor(
254
+ id=response.get('predictor_id', ''),
255
+ session_id=self.id,
256
+ target_column=target_column,
257
+ target_type=target_type,
258
+ name=name,
259
+ status="training",
260
+ _ctx=self._ctx,
261
+ _foundational_model=self,
262
+ )
263
+
264
+ return predictor
265
+
266
+ def _create_predictor_with_labels(
267
+ self,
268
+ target_column: str,
269
+ target_type: str,
270
+ labels_file: Optional[str] = None,
271
+ labels_df: Optional['pd.DataFrame'] = None,
272
+ epochs: int = 0,
273
+ rare_label_value: Optional[str] = None,
274
+ class_imbalance: Optional[Dict[str, float]] = None,
275
+ webhooks: Optional[Dict[str, str]] = None,
276
+ **kwargs
277
+ ) -> Predictor:
278
+ """Create predictor with separate labels file."""
279
+ import io
280
+ import gzip
281
+
282
+ # Prepare the labels data
283
+ if labels_df is not None:
284
+ # Convert DataFrame to CSV bytes
285
+ csv_buffer = io.StringIO()
286
+ labels_df.to_csv(csv_buffer, index=False)
287
+ file_content = csv_buffer.getvalue().encode('utf-8')
288
+ filename = "labels.csv"
289
+ elif labels_file:
290
+ with open(labels_file, 'rb') as f:
291
+ file_content = f.read()
292
+ filename = labels_file.split('/')[-1]
293
+ else:
294
+ raise ValueError("Either labels_file or labels_df must be provided")
295
+
296
+ # Compress if large
297
+ if len(file_content) > 100_000:
298
+ compressed = gzip.compress(file_content)
299
+ if len(compressed) < len(file_content):
300
+ file_content = compressed
301
+ filename = filename + '.gz'
302
+
303
+ # Build form data
304
+ form_data = {
305
+ "target_column": target_column,
306
+ "target_column_type": target_type,
307
+ "epochs": str(epochs),
308
+ }
309
+ if rare_label_value:
310
+ form_data["rare_label_value"] = rare_label_value
311
+ if class_imbalance:
312
+ import json
313
+ form_data["class_imbalance"] = json.dumps(class_imbalance)
314
+ if webhooks:
315
+ import json
316
+ form_data["webhooks"] = json.dumps(webhooks)
317
+
318
+ files = {
319
+ "file": (filename, file_content)
320
+ }
321
+
322
+ response = self._ctx.post_multipart(
323
+ f"/compute/session/{self.id}/train_predictor",
324
+ data=form_data,
325
+ files=files
326
+ )
327
+
328
+ return Predictor(
329
+ id=response.get('predictor_id', ''),
330
+ session_id=self.id,
331
+ target_column=target_column,
332
+ target_type=target_type,
333
+ status="training",
334
+ _ctx=self._ctx,
335
+ _foundational_model=self,
336
+ )
337
+
338
+ def create_vector_database(
339
+ self,
340
+ name: Optional[str] = None,
341
+ records: Optional[Union[List[Dict[str, Any]], 'pd.DataFrame']] = None
342
+ ) -> VectorDatabase:
343
+ """
344
+ Create a vector database from this foundational model.
345
+
346
+ Args:
347
+ name: Database name
348
+ records: Initial records to add (optional)
349
+
350
+ Returns:
351
+ VectorDatabase object
352
+
353
+ Example:
354
+ vdb = fm.create_vector_database(
355
+ name="customer_search",
356
+ records=customer_records
357
+ )
358
+ similar = vdb.similarity_search({"age": 35}, k=5)
359
+ """
360
+ vdb = VectorDatabase.from_session(
361
+ session_id=self.id,
362
+ name=name,
363
+ ctx=self._ctx,
364
+ foundational_model=self,
365
+ )
366
+
367
+ # Add initial records if provided
368
+ if records is not None:
369
+ vdb.add_records(records)
370
+
371
+ return vdb
372
+
373
+ def create_reference_record(
374
+ self,
375
+ record: Dict[str, Any],
376
+ name: Optional[str] = None
377
+ ) -> ReferenceRecord:
378
+ """
379
+ Create a reference record from a specific record for similarity search.
380
+
381
+ A reference record is a reference point in the embedding space that you can use
382
+ to find similar records. Particularly useful when you only have a positive
383
+ class but no negative class - just find more records like the positive example.
384
+
385
+ Args:
386
+ record: The record to create a reference from
387
+ name: Optional name for the reference record
388
+
389
+ Returns:
390
+ ReferenceRecord object that can be used for similarity search
391
+
392
+ Example:
393
+ # Create a reference record from a high-value customer
394
+ ref = fm.create_reference_record(
395
+ record={"age": 35, "income": 100000, "plan": "premium"},
396
+ name="high_value_customer"
397
+ )
398
+
399
+ # Find similar customers
400
+ similar = ref.find_similar(k=10, vector_database=vdb)
401
+ """
402
+ if not self._ctx:
403
+ raise ValueError("FoundationalModel not connected to client")
404
+
405
+ return ReferenceRecord.from_record(
406
+ record=record,
407
+ session_id=self.id,
408
+ name=name,
409
+ ctx=self._ctx,
410
+ foundational_model=self,
411
+ )
412
+
413
+ def wait_for_training(
414
+ self,
415
+ max_wait_time: int = 3600,
416
+ poll_interval: int = 10,
417
+ show_progress: bool = True
418
+ ) -> 'FoundationalModel':
419
+ """
420
+ Wait for foundational model training to complete.
421
+
422
+ Args:
423
+ max_wait_time: Maximum wait time in seconds
424
+ poll_interval: Polling interval in seconds
425
+ show_progress: Print progress updates
426
+
427
+ Returns:
428
+ Self (updated with final status)
429
+
430
+ Raises:
431
+ TimeoutError: If training doesn't complete in time
432
+ RuntimeError: If training fails
433
+ """
434
+ if not self._ctx:
435
+ raise ValueError("FoundationalModel not connected to client")
436
+
437
+ start_time = time.time()
438
+ last_epoch = None
439
+ last_status = None
440
+
441
+ while time.time() - start_time < max_wait_time:
442
+ # Get session status
443
+ session_data = self._ctx.get_json(f"/compute/session/{self.id}")
444
+ status = session_data.get('status', 'unknown')
445
+ jobs = session_data.get('jobs', {})
446
+
447
+ # Look for ES training job
448
+ es_job = None
449
+ for job_id, job in jobs.items():
450
+ if job.get('job_type') in ('train_embedding_space', 'train_es', 'training'):
451
+ es_job = job
452
+ break
453
+
454
+ # Get progress info
455
+ current_epoch = None
456
+ total_epochs = None
457
+ if es_job:
458
+ current_epoch = es_job.get('current_epoch') or es_job.get('epoch')
459
+ total_epochs = es_job.get('total_epochs') or es_job.get('epochs')
460
+ job_status = es_job.get('status', status)
461
+ else:
462
+ job_status = status
463
+
464
+ # Progress update
465
+ if show_progress:
466
+ elapsed = int(time.time() - start_time)
467
+ if current_epoch != last_epoch or job_status != last_status:
468
+ if current_epoch and total_epochs:
469
+ print(f"[{elapsed}s] Training: epoch {current_epoch}/{total_epochs} ({job_status})")
470
+ else:
471
+ print(f"[{elapsed}s] Training: {job_status}")
472
+ last_epoch = current_epoch
473
+ last_status = job_status
474
+
475
+ # Check completion
476
+ if job_status == 'done' or status == 'done':
477
+ self.status = 'done'
478
+ self._update_from_session(session_data)
479
+ if show_progress:
480
+ print(f"Training complete!")
481
+ if self.dimensions:
482
+ print(f" Dimensions: {self.dimensions}")
483
+ if self.epochs:
484
+ print(f" Epochs: {self.epochs}")
485
+ return self
486
+
487
+ elif job_status == 'failed' or status == 'failed':
488
+ error_msg = 'Unknown error'
489
+ if es_job:
490
+ error_msg = es_job.get('error', error_msg)
491
+ self.status = 'error'
492
+ raise RuntimeError(f"Training failed: {error_msg}")
493
+
494
+ time.sleep(poll_interval)
495
+
496
+ raise TimeoutError(f"Training did not complete within {max_wait_time}s")
497
+
498
+ def encode(
499
+ self,
500
+ records: Union[Dict[str, Any], List[Dict[str, Any]], 'pd.DataFrame']
501
+ ) -> List[List[float]]:
502
+ """
503
+ Encode records to embedding vectors.
504
+
505
+ Args:
506
+ records: Single record, list of records, or DataFrame
507
+
508
+ Returns:
509
+ List of embedding vectors (as lists of floats)
510
+
511
+ Example:
512
+ vectors = fm.encode([
513
+ {"age": 35, "income": 50000},
514
+ {"age": 42, "income": 75000}
515
+ ])
516
+ """
517
+ if not self._ctx:
518
+ raise ValueError("FoundationalModel not connected to client")
519
+
520
+ # Normalize input
521
+ if isinstance(records, dict):
522
+ records = [records]
523
+ elif hasattr(records, 'to_dict'):
524
+ records = records.to_dict('records')
525
+
526
+ # Clean records
527
+ cleaned = [self._clean_record(r) for r in records]
528
+
529
+ response = self._ctx.post_json(
530
+ f"/session/{self.id}/encode_records",
531
+ data={"records": cleaned}
532
+ )
533
+
534
+ return response.get('embeddings', [])
535
+
536
+ def extend(
537
+ self,
538
+ new_data_file: Optional[str] = None,
539
+ new_data_df: Optional['pd.DataFrame'] = None,
540
+ epochs: Optional[int] = None,
541
+ **kwargs
542
+ ) -> 'FoundationalModel':
543
+ """
544
+ Extend this foundational model with new data.
545
+
546
+ Args:
547
+ new_data_file: Path to new data file
548
+ new_data_df: DataFrame with new data
549
+ epochs: Additional training epochs
550
+ **kwargs: Additional parameters
551
+
552
+ Returns:
553
+ New FoundationalModel instance (training started)
554
+ """
555
+ if not self._ctx:
556
+ raise ValueError("FoundationalModel not connected to client")
557
+
558
+ # This creates a new session with extended data
559
+ # Implementation depends on server API
560
+ raise NotImplementedError("extend() not yet implemented - use create_foundational_model with new data")
561
+
562
+ def get_projections(self) -> Dict[str, Any]:
563
+ """Get 2D/3D projections for visualization."""
564
+ if not self._ctx:
565
+ raise ValueError("FoundationalModel not connected to client")
566
+
567
+ return self._ctx.get_json(f"/session/{self.id}/projections")
568
+
569
+ def get_training_metrics(self) -> Dict[str, Any]:
570
+ """Get training metrics and history."""
571
+ if not self._ctx:
572
+ raise ValueError("FoundationalModel not connected to client")
573
+
574
+ return self._ctx.get_json(f"/session/{self.id}/training_metrics")
575
+
576
+ def get_model_card(self) -> Dict[str, Any]:
577
+ """Get the model card for this foundational model."""
578
+ if not self._ctx:
579
+ raise ValueError("FoundationalModel not connected to client")
580
+
581
+ return self._ctx.get_json(f"/session/{self.id}/model_card")
582
+
583
+ def list_predictors(self) -> List[Predictor]:
584
+ """List all predictors for this foundational model."""
585
+ if not self._ctx:
586
+ raise ValueError("FoundationalModel not connected to client")
587
+
588
+ response = self._ctx.get_json(f"/session/{self.id}/predictor")
589
+ predictors_data = response.get('predictors', {})
590
+
591
+ predictors = []
592
+ for pred_id, pred_info in predictors_data.items():
593
+ pred = Predictor(
594
+ id=pred_id,
595
+ session_id=self.id,
596
+ target_column=pred_info.get('target_column', ''),
597
+ target_type=pred_info.get('target_type', 'set'),
598
+ name=pred_info.get('name'),
599
+ status=pred_info.get('status'),
600
+ accuracy=pred_info.get('accuracy'),
601
+ _ctx=self._ctx,
602
+ _foundational_model=self,
603
+ )
604
+ predictors.append(pred)
605
+
606
+ return predictors
607
+
608
+ def _update_from_session(self, session_data: Dict[str, Any]) -> None:
609
+ """Update fields from session data."""
610
+ # Try to get model info from various places
611
+ model_info = session_data.get('model_info', {})
612
+ training_stats = session_data.get('training_stats', {})
613
+
614
+ self.dimensions = (
615
+ model_info.get('d_model') or
616
+ model_info.get('embedding_dim') or
617
+ session_data.get('d_model')
618
+ )
619
+ self.epochs = (
620
+ training_stats.get('final_epoch') or
621
+ training_stats.get('epochs_trained') or
622
+ session_data.get('epochs')
623
+ )
624
+ self.final_loss = (
625
+ training_stats.get('final_loss') or
626
+ session_data.get('final_loss')
627
+ )
628
+
629
+ def _clean_record(self, record: Dict[str, Any]) -> Dict[str, Any]:
630
+ """Clean a record for API submission."""
631
+ import math
632
+
633
+ cleaned = {}
634
+ for key, value in record.items():
635
+ if isinstance(value, float):
636
+ if math.isnan(value) or math.isinf(value):
637
+ value = None
638
+ if hasattr(value, 'item'):
639
+ value = value.item()
640
+ cleaned[key] = value
641
+ return cleaned
642
+
643
+ def to_dict(self) -> Dict[str, Any]:
644
+ """Convert to dictionary representation."""
645
+ return {
646
+ 'id': self.id,
647
+ 'name': self.name,
648
+ 'status': self.status,
649
+ 'dimensions': self.dimensions,
650
+ 'epochs': self.epochs,
651
+ 'final_loss': self.final_loss,
652
+ 'created_at': self.created_at.isoformat() if self.created_at else None,
653
+ }
654
+
655
+ def __repr__(self) -> str:
656
+ status_str = f", status='{self.status}'" if self.status else ""
657
+ dims_str = f", dims={self.dimensions}" if self.dimensions else ""
658
+ return f"FoundationalModel(id='{self.id}'{status_str}{dims_str})"