featrixsphere 0.2.6379__py3-none-any.whl → 0.2.6710__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- featrixsphere/__init__.py +1 -1
- featrixsphere/api/client.py +28 -0
- featrixsphere/api/foundational_model.py +394 -20
- featrixsphere/api/http_client.py +37 -4
- featrixsphere/api/prediction_result.py +98 -9
- featrixsphere/api/predictor.py +77 -3
- featrixsphere/client.py +27 -16
- {featrixsphere-0.2.6379.dist-info → featrixsphere-0.2.6710.dist-info}/METADATA +1 -1
- featrixsphere-0.2.6710.dist-info/RECORD +17 -0
- {featrixsphere-0.2.6379.dist-info → featrixsphere-0.2.6710.dist-info}/WHEEL +1 -1
- featrixsphere-0.2.6379.dist-info/RECORD +0 -17
- {featrixsphere-0.2.6379.dist-info → featrixsphere-0.2.6710.dist-info}/entry_points.txt +0 -0
- {featrixsphere-0.2.6379.dist-info → featrixsphere-0.2.6710.dist-info}/top_level.txt +0 -0
featrixsphere/__init__.py
CHANGED
featrixsphere/api/client.py
CHANGED
|
@@ -342,6 +342,34 @@ class FeatrixSphere(HTTPClientMixin):
|
|
|
342
342
|
ground_truth=ground_truth
|
|
343
343
|
)
|
|
344
344
|
|
|
345
|
+
def list_sessions(
    self,
    name_prefix: str = "",
) -> List[str]:
    """
    List sessions matching a name search term.

    Issues GET /compute/sessions-for-org. The term is forwarded as the
    ``name_prefix`` query parameter and omitted entirely when empty.
    Despite the parameter name, the server is documented as matching the
    term anywhere in the session directory name (not just as a prefix).

    Args:
        name_prefix: Term to match in session names. Empty means no filter
            is sent.

    Returns:
        List of matching session ID strings; an empty list when the
        response has no "sessions" entry or the entry is null/empty.

    Example:
        sessions = featrix.list_sessions(name_prefix="customer")
        for sid in sessions:
            fm = featrix.foundational_model(sid)
            print(f"{sid}: {fm.status}")
    """
    params = {'name_prefix': name_prefix} if name_prefix else {}
    response = self._get_json("/compute/sessions-for-org", params=params)
    # `.get('sessions', [])` only covers a missing key; a JSON null would
    # hand None to callers that iterate the result.
    return response.get('sessions') or []
|
|
372
|
+
|
|
345
373
|
def health_check(self) -> Dict[str, Any]:
|
|
346
374
|
"""
|
|
347
375
|
Check if the API server is healthy.
|
|
@@ -22,6 +22,21 @@ from .reference_record import ReferenceRecord
|
|
|
22
22
|
logger = logging.getLogger(__name__)
|
|
23
23
|
|
|
24
24
|
|
|
25
|
+
def _parse_datetime(value) -> Optional[datetime]:
|
|
26
|
+
"""Parse a datetime from ISO string or return as-is if already datetime."""
|
|
27
|
+
if value is None:
|
|
28
|
+
return None
|
|
29
|
+
if isinstance(value, datetime):
|
|
30
|
+
return value
|
|
31
|
+
if isinstance(value, str):
|
|
32
|
+
try:
|
|
33
|
+
# Handle ISO format with or without timezone
|
|
34
|
+
return datetime.fromisoformat(value.replace('Z', '+00:00'))
|
|
35
|
+
except (ValueError, AttributeError):
|
|
36
|
+
return None
|
|
37
|
+
return None
|
|
38
|
+
|
|
39
|
+
|
|
25
40
|
@dataclass
|
|
26
41
|
class FoundationalModel:
|
|
27
42
|
"""
|
|
@@ -68,6 +83,11 @@ class FoundationalModel:
|
|
|
68
83
|
epochs: Optional[int] = None
|
|
69
84
|
final_loss: Optional[float] = None
|
|
70
85
|
created_at: Optional[datetime] = None
|
|
86
|
+
updated_at: Optional[datetime] = None
|
|
87
|
+
session_type: Optional[str] = None
|
|
88
|
+
compute_cluster: Optional[str] = None
|
|
89
|
+
error_message: Optional[str] = None
|
|
90
|
+
training_progress: Optional[Dict[str, Any]] = None
|
|
71
91
|
|
|
72
92
|
# Internal
|
|
73
93
|
_ctx: Optional['ClientContext'] = field(default=None, repr=False)
|
|
@@ -88,7 +108,9 @@ class FoundationalModel:
|
|
|
88
108
|
dimensions=response.get('d_model') or response.get('dimensions'),
|
|
89
109
|
epochs=response.get('epochs') or response.get('final_epoch'),
|
|
90
110
|
final_loss=response.get('final_loss'),
|
|
91
|
-
created_at=
|
|
111
|
+
created_at=_parse_datetime(response.get('created_at')),
|
|
112
|
+
session_type=response.get('session_type'),
|
|
113
|
+
compute_cluster=response.get('compute_cluster'),
|
|
92
114
|
_ctx=ctx,
|
|
93
115
|
)
|
|
94
116
|
|
|
@@ -99,19 +121,22 @@ class FoundationalModel:
|
|
|
99
121
|
ctx: 'ClientContext'
|
|
100
122
|
) -> 'FoundationalModel':
|
|
101
123
|
"""Load FoundationalModel from session ID."""
|
|
102
|
-
# Get session info
|
|
103
|
-
|
|
124
|
+
# Get session info - response has {"session": {...}, "jobs": {...}}
|
|
125
|
+
response_data = ctx.get_json(f"/compute/session/{session_id}")
|
|
126
|
+
session = response_data.get('session', response_data)
|
|
104
127
|
|
|
105
128
|
fm = cls(
|
|
106
129
|
id=session_id,
|
|
107
|
-
name=
|
|
108
|
-
status=
|
|
109
|
-
created_at=
|
|
130
|
+
name=session.get('name'),
|
|
131
|
+
status=session.get('status'),
|
|
132
|
+
created_at=_parse_datetime(session.get('created_at')),
|
|
133
|
+
session_type=session.get('session_type'),
|
|
134
|
+
compute_cluster=session.get('compute_cluster'),
|
|
110
135
|
_ctx=ctx,
|
|
111
136
|
)
|
|
112
137
|
|
|
113
|
-
#
|
|
114
|
-
fm._update_from_session(
|
|
138
|
+
# Extract model info, training stats, jobs, error_message
|
|
139
|
+
fm._update_from_session(response_data)
|
|
115
140
|
|
|
116
141
|
return fm
|
|
117
142
|
|
|
@@ -439,10 +464,11 @@ class FoundationalModel:
|
|
|
439
464
|
last_status = None
|
|
440
465
|
|
|
441
466
|
while time.time() - start_time < max_wait_time:
|
|
442
|
-
# Get session status
|
|
443
|
-
|
|
467
|
+
# Get session status - response has {"session": {...}, "jobs": {...}}
|
|
468
|
+
response_data = self._ctx.get_json(f"/compute/session/{self.id}")
|
|
469
|
+
session_data = response_data.get('session', response_data)
|
|
444
470
|
status = session_data.get('status', 'unknown')
|
|
445
|
-
jobs =
|
|
471
|
+
jobs = response_data.get('jobs', {})
|
|
446
472
|
|
|
447
473
|
# Look for ES training job
|
|
448
474
|
es_job = None
|
|
@@ -475,7 +501,7 @@ class FoundationalModel:
|
|
|
475
501
|
# Check completion
|
|
476
502
|
if job_status == 'done' or status == 'done':
|
|
477
503
|
self.status = 'done'
|
|
478
|
-
self._update_from_session(
|
|
504
|
+
self._update_from_session(response_data)
|
|
479
505
|
if show_progress:
|
|
480
506
|
print(f"Training complete!")
|
|
481
507
|
if self.dimensions:
|
|
@@ -566,6 +592,28 @@ class FoundationalModel:
|
|
|
566
592
|
|
|
567
593
|
return self._ctx.get_json(f"/session/{self.id}/projections")
|
|
568
594
|
|
|
595
|
+
def get_sphere_preview(self, save_path: Optional[str] = None) -> bytes:
    """
    Get the 2D sphere projection preview image (PNG).

    Issues GET /session/{id}/preview through the client context and
    returns the body unchanged; the bytes are not validated here.

    Args:
        save_path: Optional path to save the PNG file. When truthy, the
            image bytes are also written to this path in binary mode.

    Returns:
        Raw PNG image bytes as returned by the server.

    Raises:
        ValueError: If this model is not attached to a client context.
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    png_bytes = self._ctx.get_bytes(f"/session/{self.id}/preview")

    if save_path:
        with open(save_path, 'wb') as f:
            f.write(png_bytes)

    return png_bytes
|
|
616
|
+
|
|
569
617
|
def get_training_metrics(self) -> Dict[str, Any]:
|
|
570
618
|
"""Get training metrics and history."""
|
|
571
619
|
if not self._ctx:
|
|
@@ -605,27 +653,73 @@ class FoundationalModel:
|
|
|
605
653
|
|
|
606
654
|
return predictors
|
|
607
655
|
|
|
608
|
-
def _update_from_session(self,
|
|
609
|
-
"""Update fields from session
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
656
|
+
def _update_from_session(self, response_data: Dict[str, Any]) -> None:
|
|
657
|
+
"""Update fields from session API response.
|
|
658
|
+
|
|
659
|
+
The response from GET /session/{id} has structure:
|
|
660
|
+
{"session": {...}, "jobs": {...}, ...}
|
|
661
|
+
"""
|
|
662
|
+
# Handle both nested and flat response formats
|
|
663
|
+
session = response_data.get('session', response_data)
|
|
664
|
+
jobs = response_data.get('jobs', {})
|
|
665
|
+
|
|
666
|
+
# Core session fields
|
|
667
|
+
if session.get('name') and not self.name:
|
|
668
|
+
self.name = session['name']
|
|
669
|
+
if session.get('status'):
|
|
670
|
+
self.status = session['status']
|
|
671
|
+
if session.get('session_type'):
|
|
672
|
+
self.session_type = session['session_type']
|
|
673
|
+
if session.get('compute_cluster'):
|
|
674
|
+
self.compute_cluster = session['compute_cluster']
|
|
675
|
+
if session.get('created_at') and not self.created_at:
|
|
676
|
+
self.created_at = _parse_datetime(session['created_at'])
|
|
677
|
+
if session.get('finished_at'):
|
|
678
|
+
self.updated_at = _parse_datetime(session['finished_at'])
|
|
679
|
+
elif session.get('started_at'):
|
|
680
|
+
self.updated_at = _parse_datetime(session['started_at'])
|
|
681
|
+
|
|
682
|
+
# Model info from session
|
|
683
|
+
model_info = session.get('model_info', {})
|
|
684
|
+
training_stats = session.get('training_stats', {})
|
|
613
685
|
|
|
614
686
|
self.dimensions = (
|
|
615
687
|
model_info.get('d_model') or
|
|
616
688
|
model_info.get('embedding_dim') or
|
|
617
|
-
|
|
689
|
+
session.get('d_model')
|
|
618
690
|
)
|
|
619
691
|
self.epochs = (
|
|
620
692
|
training_stats.get('final_epoch') or
|
|
621
693
|
training_stats.get('epochs_trained') or
|
|
622
|
-
|
|
694
|
+
session.get('epochs')
|
|
623
695
|
)
|
|
624
696
|
self.final_loss = (
|
|
625
697
|
training_stats.get('final_loss') or
|
|
626
|
-
|
|
698
|
+
session.get('final_loss')
|
|
627
699
|
)
|
|
628
700
|
|
|
701
|
+
# Extract error_message and training_progress from jobs
|
|
702
|
+
for job_id, job in jobs.items():
|
|
703
|
+
job_type = job.get('job_type', '')
|
|
704
|
+
job_status = job.get('status', '')
|
|
705
|
+
|
|
706
|
+
# Training progress from ES training job
|
|
707
|
+
if job_type in ('train_embedding_space', 'train_es', 'training'):
|
|
708
|
+
current_epoch = job.get('current_epoch') or job.get('epoch')
|
|
709
|
+
total_epochs = job.get('total_epochs') or job.get('epochs')
|
|
710
|
+
if current_epoch or total_epochs:
|
|
711
|
+
self.training_progress = {
|
|
712
|
+
'current_epoch': current_epoch,
|
|
713
|
+
'total_epochs': total_epochs,
|
|
714
|
+
'job_status': job_status,
|
|
715
|
+
}
|
|
716
|
+
|
|
717
|
+
# Error message from any failed job
|
|
718
|
+
if job_status in ('failed', 'error'):
|
|
719
|
+
err = job.get('error') or job.get('error_message')
|
|
720
|
+
if err:
|
|
721
|
+
self.error_message = err
|
|
722
|
+
|
|
629
723
|
def _clean_record(self, record: Dict[str, Any]) -> Dict[str, Any]:
|
|
630
724
|
"""Clean a record for API submission."""
|
|
631
725
|
import math
|
|
@@ -640,6 +734,281 @@ class FoundationalModel:
|
|
|
640
734
|
cleaned[key] = value
|
|
641
735
|
return cleaned
|
|
642
736
|
|
|
737
|
+
def get_columns(self) -> List[str]:
    """
    Get the column names in this foundational model's embedding space.

    Issues GET /compute/session/{id}/columns and reads the "columns"
    entry of the response.

    Returns:
        List of column name strings; an empty list when the entry is
        missing or null.

    Raises:
        ValueError: If this model is not attached to a client context.

    Example:
        columns = fm.get_columns()
        print(columns)  # ['age', 'income', 'city', ...]
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    response = self._ctx.get_json(f"/compute/session/{self.id}/columns")
    # A JSON null must not leak through as None (the contract is a list).
    return response.get('columns') or []
|
|
753
|
+
|
|
754
|
+
@property
def columns(self) -> List[str]:
    """Column names of this model's embedding space.

    Property shorthand for ``get_columns()``; every access delegates to
    that method, nothing is cached here.
    """
    names = self.get_columns()
    return names
|
|
758
|
+
|
|
759
|
+
@property
def schema_metadata(self) -> Dict[str, Any]:
    """Schema metadata for this model, fetched from the server.

    Returns the body of GET /compute/session/{id}/columns unchanged.

    Returns:
        Dict with 'column_names', 'column_types', and 'num_columns'
        (NOTE(review): key names are taken from the original docstring
        and are not checked here; confirm against the server response).

    Raises:
        ValueError: If this model is not attached to a client context.
    """
    ctx = self._ctx
    if not ctx:
        raise ValueError("FoundationalModel not connected to client")
    endpoint = f"/compute/session/{self.id}/columns"
    return ctx.get_json(endpoint)
|
|
769
|
+
|
|
770
|
+
def clone(
    self,
    target_compute_cluster: Optional[str] = None,
    new_name: Optional[str] = None,
    source_compute_cluster: Optional[str] = None,
) -> 'FoundationalModel':
    """
    Clone this embedding space, optionally to a different compute node.

    Posts to /compute/session/{id}/clone_embedding_space and wraps the
    returned session ID in a new FoundationalModel bound to the same
    client context.

    Args:
        target_compute_cluster: Target compute cluster (None = same node)
        new_name: Name for the cloned session
        source_compute_cluster: Source compute cluster (if routing needed).
            NOTE(review): accepted for API compatibility but not sent to
            the server by this method; confirm the intended request field
            before relying on it.

    Returns:
        New FoundationalModel instance for the cloned embedding space.
        Its status is set to "done" locally and its ``compute_cluster``
        is the requested target, falling back to this model's cluster.

    Raises:
        ValueError: If this model is not attached to a client context.

    Example:
        cloned = fm.clone(
            target_compute_cluster="churro",
            new_name="my-model-clone"
        )
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    data = {
        "to_compute": target_compute_cluster,
        "new_session_name": new_name,
    }

    response = self._ctx.post_json(
        f"/compute/session/{self.id}/clone_embedding_space",
        data=data
    )

    new_session_id = response.get('new_session_id', '')
    return FoundationalModel(
        id=new_session_id,
        name=new_name,
        status="done",
        # Carry the known cluster over so the clone does not report None
        # until its first refresh. Assumes "None = same node" means the
        # source model's recorded cluster (TODO confirm).
        compute_cluster=target_compute_cluster or self.compute_cluster,
        created_at=datetime.now(),
        _ctx=self._ctx,
    )
|
|
814
|
+
|
|
815
|
+
def refresh(self) -> Dict[str, Any]:
    """
    Re-read this model's state from the server.

    Fetches GET /compute/session/{id}, feeds the payload to
    ``_update_from_session`` so local attributes (status, epochs,
    dimensions, etc.) are updated, and hands the same payload back.

    Returns:
        Full model info dictionary from the server

    Raises:
        ValueError: If this model is not attached to a client context.

    Example:
        info = fm.refresh()
        print(fm.status)    # Updated from server
        print(fm.epochs)    # Updated from server
    """
    ctx = self._ctx
    if not ctx:
        raise ValueError("FoundationalModel not connected to client")

    payload = ctx.get_json(f"/compute/session/{self.id}")
    self._update_from_session(payload)
    return payload
|
|
836
|
+
|
|
837
|
+
def is_ready(self) -> bool:
    """
    Report whether training has finished and the model is usable.

    Each call fetches GET /compute/session/{id}, applies the payload via
    ``_update_from_session``, and then compares ``self.status`` with
    'done'.

    Returns:
        True if training is complete, False otherwise

    Raises:
        ValueError: If this model is not attached to a client context.

    Example:
        if fm.is_ready():
            predictor = fm.create_classifier(target_column="target")
    """
    ctx = self._ctx
    if not ctx:
        raise ValueError("FoundationalModel not connected to client")

    payload = ctx.get_json(f"/compute/session/{self.id}")
    self._update_from_session(payload)
    finished = self.status == 'done'
    return finished
|
|
854
|
+
|
|
855
|
+
def publish(
    self,
    org_id: str,
    name: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Publish this foundational model to the production directory.

    Published models are protected from garbage collection and available
    across all compute nodes via the shared backplane.

    Posts ``{"org_id", "name"}`` to /compute/session/{id}/publish.

    Args:
        org_id: Organization ID for directory organization
        name: Name for the published model (defaults to self.name)

    Returns:
        dict with published_path, output_path, and status

    Raises:
        ValueError: If the model has no client context, or if neither
            *name* nor ``self.name`` provides a non-empty name.

    Example:
        fm = featrix.create_foundational_model(name="my_model", csv_file="data.csv")
        fm.wait_for_training()
        fm.publish(org_id="my_org", name="my_model_v1")
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    # An explicit name wins; otherwise fall back to the model's own name.
    publish_name = name if name else self.name
    if not publish_name:
        raise ValueError("name is required (either pass it or set it on the model)")

    payload = {"org_id": org_id, "name": publish_name}
    endpoint = f"/compute/session/{self.id}/publish"
    return self._ctx.post_json(endpoint, data=payload)
|
|
890
|
+
|
|
891
|
+
def deprecate(
    self,
    warning_message: str,
    expiration_date: str,
) -> Dict[str, Any]:
    """
    Deprecate this published model with a warning and expiration date.

    The model remains available until the expiration date. Prediction
    responses will include a model_expiration field warning consumers.

    Posts both values unchanged to /compute/session/{id}/deprecate; the
    date string is not parsed or validated by this method.

    Args:
        warning_message: Warning message to display
        expiration_date: ISO format date string (e.g., "2026-06-01T00:00:00Z")

    Returns:
        dict with deprecation status

    Raises:
        ValueError: If this model is not attached to a client context.

    Example:
        from datetime import datetime, timedelta, timezone
        # Build the timestamp in UTC: a trailing "Z" on a naive local
        # time would label the wrong instant.
        expiration = (
            datetime.now(timezone.utc) + timedelta(days=90)
        ).strftime("%Y-%m-%dT%H:%M:%SZ")
        fm.deprecate(
            warning_message="Replaced by v2. Migrate by expiration.",
            expiration_date=expiration
        )
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    data = {
        "warning_message": warning_message,
        "expiration_date": expiration_date,
    }
    return self._ctx.post_json(f"/compute/session/{self.id}/deprecate", data=data)
|
|
925
|
+
|
|
926
|
+
def unpublish(self) -> Dict[str, Any]:
    """
    Unpublish this model, moving it back from the published directory.

    WARNING: After unpublishing, the model is subject to garbage
    collection and may be deleted when disk space is low.

    Posts an empty JSON body to /compute/session/{id}/unpublish.

    Returns:
        dict with unpublish status

    Raises:
        ValueError: If this model is not attached to a client context.

    Example:
        fm.unpublish()
    """
    ctx = self._ctx
    if not ctx:
        raise ValueError("FoundationalModel not connected to client")

    endpoint = f"/compute/session/{self.id}/unpublish"
    return ctx.post_json(endpoint, data={})
|
|
943
|
+
|
|
944
|
+
def publish_checkpoint(
    self,
    name: str,
    org_id: Optional[str] = None,
    checkpoint_epoch: Optional[int] = None,
    session_name_prefix: Optional[str] = None,
    publish: bool = True,
) -> 'FoundationalModel':
    """
    Publish a checkpoint from this model's training as a new foundation model.

    Creates a NEW FoundationalModel from a training checkpoint with full
    provenance tracking. Useful for snapshotting good intermediate models
    while training continues.

    Posts to /compute/session/{id}/publish_partial_foundation. Optional
    arguments are only included in the request when they are set.

    Args:
        name: Name for the new foundation model (required)
        org_id: Organization ID (required if publish=True)
        checkpoint_epoch: Which epoch checkpoint to use (None = best/latest)
        session_name_prefix: Optional prefix for the new session ID
        publish: Move to published directory (default: True)

    Returns:
        New FoundationalModel instance for the published checkpoint. Its
        id comes from the response's "foundation_session_id" and its
        epochs from "checkpoint_epoch"; status is set to "done" locally.

    Raises:
        ValueError: If the model has no client context, or if
            ``publish`` is true and no ``org_id`` was given.

    Example:
        # Snapshot epoch 50 while training continues
        checkpoint_fm = fm.publish_checkpoint(
            name="My Model v0.5",
            org_id="my_org",
            checkpoint_epoch=50
        )
        # Use immediately
        predictor = checkpoint_fm.create_classifier(target_column="target")
    """
    ctx = self._ctx
    if not ctx:
        raise ValueError("FoundationalModel not connected to client")
    if publish and not org_id:
        raise ValueError("org_id is required when publish=True")

    payload = {"name": name, "publish": publish}
    # Epoch 0 is a legitimate value, hence the explicit None check.
    if checkpoint_epoch is not None:
        payload["checkpoint_epoch"] = checkpoint_epoch
    if session_name_prefix:
        payload["session_name_prefix"] = session_name_prefix
    if org_id:
        payload["org_id"] = org_id

    endpoint = f"/compute/session/{self.id}/publish_partial_foundation"
    reply = ctx.post_json(endpoint, data=payload)

    return FoundationalModel(
        id=reply.get("foundation_session_id", ""),
        name=name,
        status="done",
        epochs=reply.get("checkpoint_epoch"),
        created_at=datetime.now(),
        _ctx=ctx,
    )
|
|
1011
|
+
|
|
643
1012
|
def to_dict(self) -> Dict[str, Any]:
|
|
644
1013
|
"""Convert to dictionary representation."""
|
|
645
1014
|
return {
|
|
@@ -650,6 +1019,11 @@ class FoundationalModel:
|
|
|
650
1019
|
'epochs': self.epochs,
|
|
651
1020
|
'final_loss': self.final_loss,
|
|
652
1021
|
'created_at': self.created_at.isoformat() if self.created_at else None,
|
|
1022
|
+
'updated_at': self.updated_at.isoformat() if self.updated_at else None,
|
|
1023
|
+
'session_type': self.session_type,
|
|
1024
|
+
'compute_cluster': self.compute_cluster,
|
|
1025
|
+
'error_message': self.error_message,
|
|
1026
|
+
'training_progress': self.training_progress,
|
|
653
1027
|
}
|
|
654
1028
|
|
|
655
1029
|
def __repr__(self) -> str:
|
featrixsphere/api/http_client.py
CHANGED
|
@@ -121,14 +121,34 @@ class HTTPClientMixin:
|
|
|
121
121
|
|
|
122
122
|
def _unwrap_response(self, response_json: Dict[str, Any]) -> Dict[str, Any]:
|
|
123
123
|
"""
|
|
124
|
-
Unwrap server response, handling
|
|
124
|
+
Unwrap server response, handling wrapper formats.
|
|
125
125
|
|
|
126
|
-
The server
|
|
126
|
+
The server may wrap responses as:
|
|
127
|
+
- {"response": {...}}
|
|
128
|
+
- {"_meta": {...}, "data": {...}}
|
|
129
|
+
|
|
130
|
+
Captures server metadata when present.
|
|
127
131
|
"""
|
|
128
|
-
if isinstance(response_json, dict)
|
|
129
|
-
|
|
132
|
+
if isinstance(response_json, dict):
|
|
133
|
+
# Handle _meta/data wrapper (captures server metadata)
|
|
134
|
+
if '_meta' in response_json and 'data' in response_json:
|
|
135
|
+
self._last_server_metadata = response_json['_meta']
|
|
136
|
+
return response_json['data']
|
|
137
|
+
# Handle response wrapper
|
|
138
|
+
if 'response' in response_json and len(response_json) == 1:
|
|
139
|
+
return response_json['response']
|
|
130
140
|
return response_json
|
|
131
141
|
|
|
142
|
+
@property
def last_server_metadata(self) -> Optional[Dict[str, Any]]:
    """
    Server metadata captured by ``_unwrap_response``.

    Holds the "_meta" object of the latest response that carried one
    (server info like compute_cluster_time, compute_cluster,
    compute_cluster_version, etc.), or None if none has been stored yet.
    """
    try:
        return self._last_server_metadata
    except AttributeError:
        # Nothing recorded yet: the attribute is only created on demand.
        return None
|
|
151
|
+
|
|
132
152
|
def _get_json(
|
|
133
153
|
self,
|
|
134
154
|
endpoint: str,
|
|
@@ -139,6 +159,16 @@ class HTTPClientMixin:
|
|
|
139
159
|
response = self._make_request("GET", endpoint, max_retries=max_retries, **kwargs)
|
|
140
160
|
return self._unwrap_response(response.json())
|
|
141
161
|
|
|
162
|
+
def _get_bytes(
|
|
163
|
+
self,
|
|
164
|
+
endpoint: str,
|
|
165
|
+
max_retries: Optional[int] = None,
|
|
166
|
+
**kwargs
|
|
167
|
+
) -> bytes:
|
|
168
|
+
"""Make a GET request and return raw bytes (for binary content like images)."""
|
|
169
|
+
response = self._make_request("GET", endpoint, max_retries=max_retries, **kwargs)
|
|
170
|
+
return response.content
|
|
171
|
+
|
|
142
172
|
def _post_json(
|
|
143
173
|
self,
|
|
144
174
|
endpoint: str,
|
|
@@ -207,3 +237,6 @@ class ClientContext:
|
|
|
207
237
|
def post_multipart(self, endpoint: str, data: Dict[str, Any] = None,
|
|
208
238
|
files: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
|
|
209
239
|
return self._client._post_multipart(endpoint, data, files, **kwargs)
|
|
240
|
+
|
|
241
|
+
def get_bytes(self, endpoint: str, **kwargs) -> bytes:
    """Return the raw bytes for *endpoint* by delegating to the wrapped client's ``_get_bytes``."""
    fetch = self._client._get_bytes
    return fetch(endpoint, **kwargs)
|
|
@@ -23,11 +23,16 @@ class PredictionResult:
|
|
|
23
23
|
prediction: Raw prediction result (class probabilities or numeric value)
|
|
24
24
|
predicted_class: Predicted class name (for classification)
|
|
25
25
|
confidence: Confidence score (for classification)
|
|
26
|
+
probabilities: Full probability distribution (for classification)
|
|
27
|
+
threshold: Classification threshold (for binary classification)
|
|
26
28
|
query_record: Original input record
|
|
27
29
|
predictor_id: ID of predictor that made this prediction
|
|
28
30
|
session_id: Session ID (internal)
|
|
29
31
|
timestamp: When prediction was made
|
|
30
32
|
target_column: Target column name
|
|
33
|
+
guardrails: Per-column guardrail warnings (if any)
|
|
34
|
+
ignored_query_columns: Columns in input that were not used (not in training data)
|
|
35
|
+
available_query_columns: All columns the model knows about
|
|
31
36
|
|
|
32
37
|
Usage:
|
|
33
38
|
result = predictor.predict({"age": 35, "income": 50000})
|
|
@@ -35,6 +40,14 @@ class PredictionResult:
|
|
|
35
40
|
print(result.confidence) # 0.87
|
|
36
41
|
print(result.prediction_uuid) # UUID for feedback
|
|
37
42
|
|
|
43
|
+
# Check for guardrail warnings
|
|
44
|
+
if result.guardrails:
|
|
45
|
+
print(f"Warnings: {len(result.guardrails)} columns with issues")
|
|
46
|
+
|
|
47
|
+
# Check for ignored columns
|
|
48
|
+
if result.ignored_query_columns:
|
|
49
|
+
print(f"Ignored: {result.ignored_query_columns}")
|
|
50
|
+
|
|
38
51
|
# Send feedback if prediction was wrong
|
|
39
52
|
if result.predicted_class != actual_label:
|
|
40
53
|
feedback = result.send_feedback(ground_truth=actual_label)
|
|
@@ -44,14 +57,52 @@ class PredictionResult:
|
|
|
44
57
|
prediction_uuid: Optional[str] = None
|
|
45
58
|
prediction: Optional[Union[Dict[str, float], float]] = None
|
|
46
59
|
predicted_class: Optional[str] = None
|
|
60
|
+
probability: Optional[float] = None
|
|
47
61
|
confidence: Optional[float] = None
|
|
62
|
+
probabilities: Optional[Dict[str, float]] = None
|
|
63
|
+
threshold: Optional[float] = None
|
|
48
64
|
query_record: Optional[Dict[str, Any]] = None
|
|
65
|
+
|
|
66
|
+
# Documentation fields - explain what prediction/probability/confidence mean
|
|
67
|
+
readme_prediction: str = field(default="The predicted class label (for classification) or value (for regression).")
|
|
68
|
+
readme_probability: str = field(default="Raw probability of the predicted class from the model's softmax output.")
|
|
69
|
+
readme_confidence: str = field(default=(
|
|
70
|
+
"For binary classification: normalized margin from threshold. "
|
|
71
|
+
"confidence = (prob - threshold) / (1 - threshold) if predicting positive, "
|
|
72
|
+
"or (threshold - prob) / threshold if predicting negative. "
|
|
73
|
+
"Ranges from 0 (at decision boundary) to 1 (maximally certain). "
|
|
74
|
+
"For multi-class: same as probability."
|
|
75
|
+
))
|
|
76
|
+
readme_threshold: str = field(default=(
|
|
77
|
+
"Decision boundary for binary classification. "
|
|
78
|
+
"If P(positive_class) >= threshold, predict positive; otherwise predict negative. "
|
|
79
|
+
"Calibrated to optimize F1 score on validation data."
|
|
80
|
+
))
|
|
81
|
+
readme_probabilities: str = field(default=(
|
|
82
|
+
"Full probability distribution across all classes from the model's softmax output. "
|
|
83
|
+
"Dictionary mapping class labels to their probabilities (sum to 1.0)."
|
|
84
|
+
))
|
|
85
|
+
readme_pos_label: str = field(default=(
|
|
86
|
+
"The class label considered 'positive' for binary classification metrics. "
|
|
87
|
+
"Threshold and confidence calculations are relative to this class."
|
|
88
|
+
))
|
|
49
89
|
predictor_id: Optional[str] = None
|
|
50
90
|
session_id: Optional[str] = None
|
|
51
91
|
target_column: Optional[str] = None
|
|
52
92
|
timestamp: Optional[datetime] = None
|
|
53
93
|
model_version: Optional[str] = None
|
|
54
94
|
|
|
95
|
+
# Checkpoint info from the model (epoch, metric_type, metric_value)
|
|
96
|
+
checkpoint_info: Optional[Dict[str, Any]] = None
|
|
97
|
+
|
|
98
|
+
# Guardrails and warnings
|
|
99
|
+
guardrails: Optional[Dict[str, Any]] = None
|
|
100
|
+
ignored_query_columns: Optional[list] = None
|
|
101
|
+
available_query_columns: Optional[list] = None
|
|
102
|
+
|
|
103
|
+
# Feature importance (from leave-one-out ablation)
|
|
104
|
+
feature_importance: Optional[Dict[str, float]] = None
|
|
105
|
+
|
|
55
106
|
# Internal: client context for sending feedback
|
|
56
107
|
_ctx: Optional['ClientContext'] = field(default=None, repr=False)
|
|
57
108
|
|
|
@@ -73,29 +124,44 @@ class PredictionResult:
|
|
|
73
124
|
Returns:
|
|
74
125
|
PredictionResult instance
|
|
75
126
|
"""
|
|
76
|
-
# Extract prediction data
|
|
127
|
+
# Extract prediction data - handle both formats
|
|
128
|
+
# New format: prediction is the class label, probabilities is separate
|
|
129
|
+
# Old format: prediction is the probabilities dict
|
|
77
130
|
prediction = response.get('prediction')
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
131
|
+
probabilities = response.get('probabilities')
|
|
132
|
+
predicted_class = response.get('predicted_class')
|
|
133
|
+
probability = response.get('probability')
|
|
134
|
+
confidence = response.get('confidence')
|
|
135
|
+
|
|
136
|
+
# For old format where prediction is the probabilities dict
|
|
137
|
+
if isinstance(prediction, dict) and not probabilities:
|
|
138
|
+
probabilities = prediction
|
|
84
139
|
if prediction:
|
|
85
140
|
predicted_class = max(prediction.keys(), key=lambda k: prediction[k])
|
|
86
|
-
|
|
141
|
+
probability = prediction[predicted_class]
|
|
142
|
+
confidence = probability # Old format: confidence = probability
|
|
143
|
+
elif isinstance(prediction, str) and not predicted_class:
|
|
144
|
+
# New format: prediction is already the class label
|
|
145
|
+
predicted_class = prediction
|
|
87
146
|
|
|
88
147
|
return cls(
|
|
89
148
|
prediction_uuid=response.get('prediction_uuid') or response.get('prediction_id'),
|
|
90
149
|
prediction=prediction,
|
|
91
150
|
predicted_class=predicted_class,
|
|
151
|
+
probability=probability,
|
|
92
152
|
confidence=confidence,
|
|
153
|
+
probabilities=probabilities,
|
|
154
|
+
threshold=response.get('threshold'),
|
|
93
155
|
query_record=query_record,
|
|
94
156
|
predictor_id=response.get('predictor_id'),
|
|
95
157
|
session_id=response.get('session_id'),
|
|
96
158
|
target_column=response.get('target_column'),
|
|
97
159
|
timestamp=datetime.now(),
|
|
98
160
|
model_version=response.get('model_version'),
|
|
161
|
+
checkpoint_info=response.get('checkpoint_info'),
|
|
162
|
+
guardrails=response.get('guardrails'),
|
|
163
|
+
ignored_query_columns=response.get('ignored_query_columns'),
|
|
164
|
+
available_query_columns=response.get('available_query_columns'),
|
|
99
165
|
_ctx=ctx,
|
|
100
166
|
)
|
|
101
167
|
|
|
@@ -126,18 +192,41 @@ class PredictionResult:
|
|
|
126
192
|
|
|
127
193
|
def to_dict(self) -> Dict[str, Any]:
|
|
128
194
|
"""Convert to dictionary representation."""
|
|
129
|
-
|
|
195
|
+
result = {
|
|
130
196
|
'prediction_uuid': self.prediction_uuid,
|
|
131
197
|
'prediction': self.prediction,
|
|
132
198
|
'predicted_class': self.predicted_class,
|
|
199
|
+
'probability': self.probability,
|
|
133
200
|
'confidence': self.confidence,
|
|
201
|
+
'probabilities': self.probabilities,
|
|
202
|
+
'threshold': self.threshold,
|
|
134
203
|
'query_record': self.query_record,
|
|
135
204
|
'predictor_id': self.predictor_id,
|
|
136
205
|
'session_id': self.session_id,
|
|
137
206
|
'target_column': self.target_column,
|
|
138
207
|
'timestamp': self.timestamp.isoformat() if self.timestamp else None,
|
|
139
208
|
'model_version': self.model_version,
|
|
209
|
+
# Documentation
|
|
210
|
+
'readme_prediction': self.readme_prediction,
|
|
211
|
+
'readme_probability': self.readme_probability,
|
|
212
|
+
'readme_confidence': self.readme_confidence,
|
|
213
|
+
'readme_threshold': self.readme_threshold,
|
|
214
|
+
'readme_probabilities': self.readme_probabilities,
|
|
215
|
+
'readme_pos_label': self.readme_pos_label,
|
|
140
216
|
}
|
|
217
|
+
# Include checkpoint_info if present
|
|
218
|
+
if self.checkpoint_info:
|
|
219
|
+
result['checkpoint_info'] = self.checkpoint_info
|
|
220
|
+
# Include guardrails if present
|
|
221
|
+
if self.guardrails:
|
|
222
|
+
result['guardrails'] = self.guardrails
|
|
223
|
+
if self.ignored_query_columns:
|
|
224
|
+
result['ignored_query_columns'] = self.ignored_query_columns
|
|
225
|
+
if self.available_query_columns:
|
|
226
|
+
result['available_query_columns'] = self.available_query_columns
|
|
227
|
+
if self.feature_importance:
|
|
228
|
+
result['feature_importance'] = self.feature_importance
|
|
229
|
+
return result
|
|
141
230
|
|
|
142
231
|
|
|
143
232
|
@dataclass
|
featrixsphere/api/predictor.py
CHANGED
|
@@ -105,7 +105,8 @@ class Predictor:
|
|
|
105
105
|
def predict(
|
|
106
106
|
self,
|
|
107
107
|
record: Dict[str, Any],
|
|
108
|
-
best_metric_preference: Optional[str] = None
|
|
108
|
+
best_metric_preference: Optional[str] = None,
|
|
109
|
+
feature_importance: bool = False
|
|
109
110
|
) -> PredictionResult:
|
|
110
111
|
"""
|
|
111
112
|
Make a single prediction.
|
|
@@ -113,15 +114,21 @@ class Predictor:
|
|
|
113
114
|
Args:
|
|
114
115
|
record: Input record dictionary
|
|
115
116
|
best_metric_preference: Metric checkpoint to use ("roc_auc", "pr_auc", or None)
|
|
117
|
+
feature_importance: If True, compute feature importance via leave-one-out ablation
|
|
116
118
|
|
|
117
119
|
Returns:
|
|
118
|
-
PredictionResult with prediction, confidence, and prediction_uuid
|
|
120
|
+
PredictionResult with prediction, confidence, and prediction_uuid.
|
|
121
|
+
If feature_importance=True, also includes feature_importance dict.
|
|
119
122
|
|
|
120
123
|
Example:
|
|
121
124
|
result = predictor.predict({"age": 35, "income": 50000})
|
|
122
125
|
print(result.predicted_class) # "churned"
|
|
123
126
|
print(result.confidence) # 0.87
|
|
124
127
|
print(result.prediction_uuid) # UUID for feedback
|
|
128
|
+
|
|
129
|
+
# With feature importance
|
|
130
|
+
result = predictor.predict(record, feature_importance=True)
|
|
131
|
+
print(result.feature_importance) # {"income": 0.15, "age": 0.08, ...}
|
|
125
132
|
"""
|
|
126
133
|
if not self._ctx:
|
|
127
134
|
raise ValueError("Predictor not connected to client")
|
|
@@ -129,7 +136,44 @@ class Predictor:
|
|
|
129
136
|
# Clean the record
|
|
130
137
|
cleaned_record = self._clean_record(record)
|
|
131
138
|
|
|
132
|
-
|
|
139
|
+
if feature_importance:
|
|
140
|
+
# Build N+1 records: original + each feature nulled out
|
|
141
|
+
columns = list(cleaned_record.keys())
|
|
142
|
+
batch = [cleaned_record] # Original first
|
|
143
|
+
|
|
144
|
+
for col in columns:
|
|
145
|
+
ablated = cleaned_record.copy()
|
|
146
|
+
ablated[col] = None
|
|
147
|
+
batch.append(ablated)
|
|
148
|
+
|
|
149
|
+
# Single batch call
|
|
150
|
+
results = self.batch_predict(
|
|
151
|
+
batch,
|
|
152
|
+
show_progress=False,
|
|
153
|
+
best_metric_preference=best_metric_preference
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
# Compare: importance = |original_confidence - ablated_confidence|
|
|
157
|
+
original = results[0]
|
|
158
|
+
importance = {}
|
|
159
|
+
original_conf = original.confidence or 0.0
|
|
160
|
+
|
|
161
|
+
for i, col in enumerate(columns):
|
|
162
|
+
ablated_result = results[i + 1]
|
|
163
|
+
ablated_conf = ablated_result.confidence or 0.0
|
|
164
|
+
# Higher delta = more important
|
|
165
|
+
delta = abs(original_conf - ablated_conf)
|
|
166
|
+
importance[col] = round(delta, 4)
|
|
167
|
+
|
|
168
|
+
# Sort by importance (highest first)
|
|
169
|
+
original.feature_importance = dict(sorted(
|
|
170
|
+
importance.items(),
|
|
171
|
+
key=lambda x: x[1],
|
|
172
|
+
reverse=True
|
|
173
|
+
))
|
|
174
|
+
return original
|
|
175
|
+
|
|
176
|
+
# Standard single prediction
|
|
133
177
|
request_payload = {
|
|
134
178
|
"query_record": cleaned_record,
|
|
135
179
|
"predictor_id": self.id,
|
|
@@ -196,6 +240,36 @@ class Predictor:
|
|
|
196
240
|
|
|
197
241
|
return results
|
|
198
242
|
|
|
243
|
+
def predict_csv_file(
|
|
244
|
+
self,
|
|
245
|
+
csv_path: str,
|
|
246
|
+
show_progress: bool = True,
|
|
247
|
+
best_metric_preference: Optional[str] = None
|
|
248
|
+
) -> List[PredictionResult]:
|
|
249
|
+
"""
|
|
250
|
+
Load a CSV file and run batch predictions on all rows.
|
|
251
|
+
|
|
252
|
+
Args:
|
|
253
|
+
csv_path: Path to the CSV file
|
|
254
|
+
show_progress: Show progress bar
|
|
255
|
+
best_metric_preference: Metric checkpoint to use
|
|
256
|
+
|
|
257
|
+
Returns:
|
|
258
|
+
List of PredictionResult objects
|
|
259
|
+
|
|
260
|
+
Example:
|
|
261
|
+
results = predictor.predict_csv_file("test_data.csv")
|
|
262
|
+
for r in results:
|
|
263
|
+
print(r.predicted_class, r.confidence)
|
|
264
|
+
"""
|
|
265
|
+
import pandas as pd
|
|
266
|
+
df = pd.read_csv(csv_path)
|
|
267
|
+
return self.batch_predict(
|
|
268
|
+
df,
|
|
269
|
+
show_progress=show_progress,
|
|
270
|
+
best_metric_preference=best_metric_preference
|
|
271
|
+
)
|
|
272
|
+
|
|
199
273
|
def explain(
|
|
200
274
|
self,
|
|
201
275
|
record: Dict[str, Any],
|
featrixsphere/client.py
CHANGED
|
@@ -2930,24 +2930,26 @@ class FeatrixSphereClient:
|
|
|
2930
2930
|
# Single Predictor Functionality
|
|
2931
2931
|
# =========================================================================
|
|
2932
2932
|
|
|
2933
|
-
def predict(self, session_id: str, record: Dict[str, Any], target_column: str = None,
|
|
2933
|
+
def predict(self, session_id: str, record: Dict[str, Any], target_column: str = None,
|
|
2934
2934
|
predictor_id: str = None, best_metric_preference: str = None,
|
|
2935
|
-
|
|
2935
|
+
checkpoint_epoch: int = None, max_retries: int = None,
|
|
2936
|
+
queue_batches: bool = False) -> Dict[str, Any]:
|
|
2936
2937
|
"""
|
|
2937
2938
|
Make a single prediction for a record.
|
|
2938
|
-
|
|
2939
|
+
|
|
2939
2940
|
Args:
|
|
2940
2941
|
session_id: ID of session with trained predictor
|
|
2941
2942
|
record: Record dictionary (without target column)
|
|
2942
2943
|
target_column: Specific target column predictor to use (required if multiple predictors exist and predictor_id not specified)
|
|
2943
2944
|
predictor_id: Specific predictor ID to use (recommended - more precise than target_column)
|
|
2944
2945
|
best_metric_preference: Which metric checkpoint to use: "roc_auc", "pr_auc", or None (use default checkpoint) (default: None)
|
|
2946
|
+
checkpoint_epoch: Specific epoch checkpoint to use (e.g., 65 for epoch 65). Overrides best_metric_preference.
|
|
2945
2947
|
max_retries: Number of retries for errors (default: uses client default)
|
|
2946
2948
|
queue_batches: If True, queue this prediction for batch processing instead of immediate API call
|
|
2947
|
-
|
|
2949
|
+
|
|
2948
2950
|
Returns:
|
|
2949
2951
|
Prediction result dictionary if queue_batches=False, or queue ID if queue_batches=True
|
|
2950
|
-
|
|
2952
|
+
|
|
2951
2953
|
Note:
|
|
2952
2954
|
predictor_id is recommended over target_column for precision. If both are provided, predictor_id takes precedence.
|
|
2953
2955
|
Use client.list_predictors(session_id) to see available predictor IDs.
|
|
@@ -2958,29 +2960,31 @@ class FeatrixSphereClient:
|
|
|
2958
2960
|
if should_warn:
|
|
2959
2961
|
call_count = len(self._prediction_call_times.get(session_id, []))
|
|
2960
2962
|
self._show_batching_warning(session_id, call_count)
|
|
2961
|
-
|
|
2963
|
+
|
|
2962
2964
|
# If queueing is enabled, add to queue and return queue ID
|
|
2963
2965
|
if queue_batches:
|
|
2964
2966
|
queue_id = self._add_to_prediction_queue(session_id, record, target_column, predictor_id)
|
|
2965
2967
|
return {"queued": True, "queue_id": queue_id}
|
|
2966
|
-
|
|
2967
|
-
# Clean NaN/Inf values
|
|
2968
|
+
|
|
2969
|
+
# Clean NaN/Inf values
|
|
2968
2970
|
cleaned_record = self._clean_numpy_values(record)
|
|
2969
2971
|
cleaned_record = self.replace_nans_with_nulls(cleaned_record)
|
|
2970
|
-
|
|
2972
|
+
|
|
2971
2973
|
# Build request payload - let the server handle predictor resolution
|
|
2972
2974
|
request_payload = {
|
|
2973
2975
|
"query_record": cleaned_record,
|
|
2974
2976
|
}
|
|
2975
|
-
|
|
2977
|
+
|
|
2976
2978
|
# Include whatever the caller provided - server will figure it out
|
|
2977
2979
|
if target_column:
|
|
2978
2980
|
request_payload["target_column"] = target_column
|
|
2979
2981
|
if predictor_id:
|
|
2980
2982
|
request_payload["predictor_id"] = predictor_id
|
|
2981
|
-
if
|
|
2983
|
+
if checkpoint_epoch is not None:
|
|
2984
|
+
request_payload["checkpoint_epoch"] = checkpoint_epoch
|
|
2985
|
+
elif best_metric_preference:
|
|
2982
2986
|
request_payload["best_metric_preference"] = best_metric_preference
|
|
2983
|
-
|
|
2987
|
+
|
|
2984
2988
|
# Just send it to the server - it has all the smart fallback logic
|
|
2985
2989
|
response_data = self._post_json(f"/session/{session_id}/predict", request_payload, max_retries=max_retries)
|
|
2986
2990
|
return response_data
|
|
@@ -6651,8 +6655,8 @@ class FeatrixSphereClient:
|
|
|
6651
6655
|
|
|
6652
6656
|
def predict_table(self, session_id: str, table_data: Dict[str, Any],
|
|
6653
6657
|
target_column: str = None, predictor_id: str = None,
|
|
6654
|
-
best_metric_preference: str = None,
|
|
6655
|
-
trace: bool = False) -> Dict[str, Any]:
|
|
6658
|
+
best_metric_preference: str = None, checkpoint_epoch: int = None,
|
|
6659
|
+
max_retries: int = None, trace: bool = False) -> Dict[str, Any]:
|
|
6656
6660
|
"""
|
|
6657
6661
|
Make batch predictions using JSON Tables format.
|
|
6658
6662
|
|
|
@@ -6661,6 +6665,8 @@ class FeatrixSphereClient:
|
|
|
6661
6665
|
table_data: Data in JSON Tables format, or list of records, or dict with 'table'/'records'
|
|
6662
6666
|
target_column: Specific target column predictor to use (required if multiple predictors exist)
|
|
6663
6667
|
predictor_id: Specific predictor ID to use (recommended - more precise than target_column)
|
|
6668
|
+
best_metric_preference: Which metric checkpoint to use: "roc_auc", "pr_auc", or None (default)
|
|
6669
|
+
checkpoint_epoch: Specific epoch checkpoint to use (e.g., 65). Overrides best_metric_preference.
|
|
6664
6670
|
max_retries: Number of retries for errors (default: uses client default, recommend higher for batch)
|
|
6665
6671
|
trace: Enable detailed debug logging (default: False)
|
|
6666
6672
|
|
|
@@ -6713,7 +6719,9 @@ class FeatrixSphereClient:
|
|
|
6713
6719
|
table_data['target_column'] = target_column
|
|
6714
6720
|
if predictor_id:
|
|
6715
6721
|
table_data['predictor_id'] = predictor_id
|
|
6716
|
-
if
|
|
6722
|
+
if checkpoint_epoch is not None:
|
|
6723
|
+
table_data['checkpoint_epoch'] = checkpoint_epoch
|
|
6724
|
+
elif best_metric_preference:
|
|
6717
6725
|
table_data['best_metric_preference'] = best_metric_preference
|
|
6718
6726
|
|
|
6719
6727
|
if trace:
|
|
@@ -6747,7 +6755,8 @@ class FeatrixSphereClient:
|
|
|
6747
6755
|
raise
|
|
6748
6756
|
|
|
6749
6757
|
def predict_records(self, session_id: str, records: List[Dict[str, Any]],
|
|
6750
|
-
target_column: str = None, predictor_id: str = None,
|
|
6758
|
+
target_column: str = None, predictor_id: str = None,
|
|
6759
|
+
best_metric_preference: str = None, checkpoint_epoch: int = None,
|
|
6751
6760
|
batch_size: int = 2500, use_async: bool = False,
|
|
6752
6761
|
show_progress_bar: bool = True, print_target_column_warning: bool = True,
|
|
6753
6762
|
trace: bool = False) -> Dict[str, Any]:
|
|
@@ -6759,6 +6768,8 @@ class FeatrixSphereClient:
|
|
|
6759
6768
|
records: List of record dictionaries
|
|
6760
6769
|
target_column: Specific target column predictor to use (required if multiple predictors exist and predictor_id not specified)
|
|
6761
6770
|
predictor_id: Specific predictor ID to use (recommended - more precise than target_column)
|
|
6771
|
+
best_metric_preference: Which metric checkpoint to use: "roc_auc", "pr_auc", or None (default)
|
|
6772
|
+
checkpoint_epoch: Specific epoch checkpoint to use (e.g., 65). Overrides best_metric_preference.
|
|
6762
6773
|
batch_size: Number of records to send per API call (default: 2500)
|
|
6763
6774
|
use_async: Force async processing for large datasets (default: False - async disabled due to pickle issues)
|
|
6764
6775
|
show_progress_bar: Whether to show progress bar for async jobs (default: True)
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
featrixsphere/__init__.py,sha256=QQglKOYv0bjonuO0-wOkeyyXMWhv3yK0_s5Uaap-GVk,2190
|
|
2
|
+
featrixsphere/client.py,sha256=JNYAHDFtxmhQVuclO3dWEphJuxNFLi-PfvWylZWKF_4,452929
|
|
3
|
+
featrixsphere/api/__init__.py,sha256=quyvuPphVj9wb6v8Dio0SMG9iHgJAmY3asHk3f_zF10,1269
|
|
4
|
+
featrixsphere/api/api_endpoint.py,sha256=i3eCWuaUXftnH1Ai6MFZ7md7pC2FcRAIRO87CBZhyEQ,9000
|
|
5
|
+
featrixsphere/api/client.py,sha256=TvNqrzSPQdw0A4kW48M0S3SDrBRmkc6kTY8UkzO4eRs,13426
|
|
6
|
+
featrixsphere/api/foundational_model.py,sha256=0ZFO-mJs66nVRXQbM0o1fB4HmhzLBXUqbTCF46LVH1k,34925
|
|
7
|
+
featrixsphere/api/http_client.py,sha256=q59-41fHua_7AwtPFCvshlSUKJ-fS0X337L9Ooyn0DI,8440
|
|
8
|
+
featrixsphere/api/notebook_helper.py,sha256=xY9jsao26eaNiFh2s0_TlRZnR8xZ4P_e0EOKr2PtoVs,20060
|
|
9
|
+
featrixsphere/api/prediction_result.py,sha256=HQsJdr89zWxdRx395nevN3aP7ZXZuZxB4UGX5Ykhkfk,12235
|
|
10
|
+
featrixsphere/api/predictor.py,sha256=1v0ffkEjmrO3BP0PNWAXtiAU-AlOQJSiDICmW1bQbGU,20300
|
|
11
|
+
featrixsphere/api/reference_record.py,sha256=-XOTF6ynznB3ouz06w3AF8X9SVId0g_dO20VvGNesUQ,7095
|
|
12
|
+
featrixsphere/api/vector_database.py,sha256=BplxKkPnAbcBX1A4KxFBJVb3qkQ-FH9zi9v2dWG5CgY,7976
|
|
13
|
+
featrixsphere-0.2.6710.dist-info/METADATA,sha256=_gDwyRsSfEa0thWd6IjhMeCEZF5dMc8BfLBL4J2b5HQ,16232
|
|
14
|
+
featrixsphere-0.2.6710.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
15
|
+
featrixsphere-0.2.6710.dist-info/entry_points.txt,sha256=QreJeYfD_VWvbEqPmMXZ3pqqlFlJ1qZb-NtqnyhEldc,51
|
|
16
|
+
featrixsphere-0.2.6710.dist-info/top_level.txt,sha256=AyN4wjfzlD0hWnDieuEHX0KckphIk_aC73XCG4df5uU,14
|
|
17
|
+
featrixsphere-0.2.6710.dist-info/RECORD,,
|
|
@@ -1,17 +0,0 @@
|
|
|
1
|
-
featrixsphere/__init__.py,sha256=m4FTeSot2GaITV5l_kD5WrSPZKdKmVbcmRwXZE_nYJk,2190
|
|
2
|
-
featrixsphere/client.py,sha256=Nj6C_Th4jyK7JQIXUJ_URok9AA0OND6DOAjoFbKhs2Q,452098
|
|
3
|
-
featrixsphere/api/__init__.py,sha256=quyvuPphVj9wb6v8Dio0SMG9iHgJAmY3asHk3f_zF10,1269
|
|
4
|
-
featrixsphere/api/api_endpoint.py,sha256=i3eCWuaUXftnH1Ai6MFZ7md7pC2FcRAIRO87CBZhyEQ,9000
|
|
5
|
-
featrixsphere/api/client.py,sha256=TdpujNsJxO4GfPMI_KoemQWV89go3KuK6OPAo9jX6Bs,12574
|
|
6
|
-
featrixsphere/api/foundational_model.py,sha256=ZF5wKMs6SfsNC3XYYXgbRMhnrtmLe6NeckjCCiH0fK0,21628
|
|
7
|
-
featrixsphere/api/http_client.py,sha256=TsOQHHNTDFGAR3mdHevj-0wy1-hPtgHXKe8Egiz5FVo,7269
|
|
8
|
-
featrixsphere/api/notebook_helper.py,sha256=xY9jsao26eaNiFh2s0_TlRZnR8xZ4P_e0EOKr2PtoVs,20060
|
|
9
|
-
featrixsphere/api/prediction_result.py,sha256=Tx7LXzF4XT-U3VqAN_IFc5DvxPnygc78M2usrD-yMu4,7521
|
|
10
|
-
featrixsphere/api/predictor.py,sha256=-vwCKpCfTgZKqzpDnzy1iYZQ-1-MGW8aErvxM9trktw,17652
|
|
11
|
-
featrixsphere/api/reference_record.py,sha256=-XOTF6ynznB3ouz06w3AF8X9SVId0g_dO20VvGNesUQ,7095
|
|
12
|
-
featrixsphere/api/vector_database.py,sha256=BplxKkPnAbcBX1A4KxFBJVb3qkQ-FH9zi9v2dWG5CgY,7976
|
|
13
|
-
featrixsphere-0.2.6379.dist-info/METADATA,sha256=EdpmIuyoX1hr1eelFuZbN-zOwrsIsN9TupOeehDJxys,16232
|
|
14
|
-
featrixsphere-0.2.6379.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
|
|
15
|
-
featrixsphere-0.2.6379.dist-info/entry_points.txt,sha256=QreJeYfD_VWvbEqPmMXZ3pqqlFlJ1qZb-NtqnyhEldc,51
|
|
16
|
-
featrixsphere-0.2.6379.dist-info/top_level.txt,sha256=AyN4wjfzlD0hWnDieuEHX0KckphIk_aC73XCG4df5uU,14
|
|
17
|
-
featrixsphere-0.2.6379.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|