featrixsphere 0.2.6127__py3-none-any.whl → 0.2.6708__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- featrixsphere/__init__.py +1 -1
- featrixsphere/api/foundational_model.py +288 -0
- featrixsphere/api/http_client.py +37 -4
- featrixsphere/api/prediction_result.py +98 -9
- featrixsphere/api/predictor.py +77 -3
- featrixsphere/client.py +81 -58
- {featrixsphere-0.2.6127.dist-info → featrixsphere-0.2.6708.dist-info}/METADATA +1 -1
- featrixsphere-0.2.6708.dist-info/RECORD +17 -0
- {featrixsphere-0.2.6127.dist-info → featrixsphere-0.2.6708.dist-info}/WHEEL +1 -1
- featrixsphere-0.2.6127.dist-info/RECORD +0 -17
- {featrixsphere-0.2.6127.dist-info → featrixsphere-0.2.6708.dist-info}/entry_points.txt +0 -0
- {featrixsphere-0.2.6127.dist-info → featrixsphere-0.2.6708.dist-info}/top_level.txt +0 -0
featrixsphere/__init__.py
CHANGED
|
@@ -566,6 +566,28 @@ class FoundationalModel:
|
|
|
566
566
|
|
|
567
567
|
return self._ctx.get_json(f"/session/{self.id}/projections")
|
|
568
568
|
|
|
569
|
+
def get_sphere_preview(self, save_path: str = None) -> bytes:
    """
    Fetch the 2D sphere projection preview for this model as a PNG.

    Args:
        save_path: Optional filesystem path. When given (and non-empty),
            the downloaded bytes are also written to that path.

    Returns:
        Raw PNG image bytes as returned by the server.

    Raises:
        ValueError: If this model is not connected to a client.
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    image = self._ctx.get_bytes(f"/session/{self.id}/preview")

    # No path requested: hand the bytes straight back without touching disk.
    if not save_path:
        return image

    with open(save_path, 'wb') as handle:
        handle.write(image)
    return image
|
|
590
|
+
|
|
569
591
|
def get_training_metrics(self) -> Dict[str, Any]:
|
|
570
592
|
"""Get training metrics and history."""
|
|
571
593
|
if not self._ctx:
|
|
@@ -640,6 +662,272 @@ class FoundationalModel:
|
|
|
640
662
|
cleaned[key] = value
|
|
641
663
|
return cleaned
|
|
642
664
|
|
|
665
|
+
def get_columns(self) -> List[str]:
    """
    Get the column names in this foundational model's embedding space.

    Returns:
        List of column name strings. An empty list is returned when the
        server response has no 'columns' entry or the entry is null/empty.

    Raises:
        ValueError: If this model is not connected to a client.

    Example:
        columns = fm.get_columns()
        print(columns)  # ['age', 'income', 'city', ...]
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    response = self._ctx.get_json(f"/compute/session/{self.id}/columns")
    # A default passed to .get() only covers a missing key; "or []" also
    # covers an explicit null so callers always receive a list.
    return response.get('columns') or []
|
|
681
|
+
|
|
682
|
+
@property
def columns(self) -> List[str]:
    """
    Column names in this foundational model's embedding space.

    Delegates to get_columns() on every access.
    """
    names = self.get_columns()
    return names
|
|
686
|
+
|
|
687
|
+
def clone(
    self,
    target_compute_cluster: Optional[str] = None,
    new_name: Optional[str] = None,
    source_compute_cluster: Optional[str] = None,
) -> 'FoundationalModel':
    """
    Clone this embedding space, optionally to a different compute node.

    Args:
        target_compute_cluster: Target compute cluster (None = same node)
        new_name: Name for the cloned session
        source_compute_cluster: Source compute cluster (if routing needed).
            NOTE(review): this argument is accepted but is not included in
            the request sent below - confirm how the server expects it.

    Returns:
        New FoundationalModel instance for the cloned embedding space

    Raises:
        ValueError: If this model is not connected to a client.
        RuntimeError: If the server response carries no 'new_session_id'.

    Example:
        cloned = fm.clone(
            target_compute_cluster="churro",
            new_name="my-model-clone"
        )
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    data = {
        "to_compute": target_compute_cluster,
        "new_session_name": new_name,
    }

    response = self._ctx.post_json(
        f"/compute/session/{self.id}/clone_embedding_space",
        data=data
    )

    # Without an id the returned model would build URLs such as
    # "/compute/session//..." on every later call, so fail here instead.
    new_session_id = response.get('new_session_id')
    if not new_session_id:
        raise RuntimeError(
            "clone_embedding_space response did not include 'new_session_id'"
        )

    return FoundationalModel(
        id=new_session_id,
        name=new_name,
        status="done",
        created_at=datetime.now(),
        _ctx=self._ctx,
    )
|
|
731
|
+
|
|
732
|
+
def refresh(self) -> Dict[str, Any]:
    """
    Re-read this foundational model's state from the server.

    The fetched session info is passed to _update_from_session() and the
    local status is overwritten with the server's 'status' value (the
    current status is kept when the response has none).

    Returns:
        Full model info dictionary from the server

    Raises:
        ValueError: If this model is not connected to a client.

    Example:
        info = fm.refresh()
        print(fm.status)  # Updated from server
        print(fm.epochs)  # Updated from server
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    endpoint = f"/compute/session/{self.id}"
    info = self._ctx.get_json(endpoint)

    self._update_from_session(info)
    # Read the fallback after the update so it reflects any status that
    # _update_from_session() may have set.
    fallback_status = self.status
    self.status = info.get('status', fallback_status)
    return info
|
|
754
|
+
|
|
755
|
+
def is_ready(self) -> bool:
    """
    Report whether the server lists this model's status as 'done'.

    The session is fetched from the server on every call and the local
    status attribute is overwritten with the reported value ('unknown'
    when the response carries no status).

    Returns:
        True if the reported status is 'done', False otherwise

    Raises:
        ValueError: If this model is not connected to a client.

    Example:
        if fm.is_ready():
            predictor = fm.create_classifier(target_column="target")
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    session = self._ctx.get_json(f"/compute/session/{self.id}")
    reported = session.get('status', 'unknown')
    self.status = reported
    return reported == 'done'
|
|
773
|
+
|
|
774
|
+
def publish(
    self,
    org_id: str,
    name: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Publish this foundational model to the production directory.

    Published models are protected from garbage collection and available
    across all compute nodes via the shared backplane.

    Args:
        org_id: Organization ID for directory organization
        name: Name for the published model (defaults to self.name)

    Returns:
        dict with published_path, output_path, and status

    Raises:
        ValueError: If the model is not connected to a client, or if
            neither ``name`` nor ``self.name`` provides a name.

    Example:
        fm = featrix.create_foundational_model(name="my_model", csv_file="data.csv")
        fm.wait_for_training()
        fm.publish(org_id="my_org", name="my_model_v1")
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    # An explicit name wins; otherwise fall back to the model's own name.
    chosen_name = name if name else self.name
    if not chosen_name:
        raise ValueError("name is required (either pass it or set it on the model)")

    payload = dict(org_id=org_id, name=chosen_name)
    endpoint = f"/compute/session/{self.id}/publish"
    return self._ctx.post_json(endpoint, data=payload)
|
|
809
|
+
|
|
810
|
+
def deprecate(
    self,
    warning_message: str,
    expiration_date: str,
) -> Dict[str, Any]:
    """
    Deprecate this published model with a warning and expiration date.

    The model remains available until the expiration date. Prediction
    responses will include a model_expiration field warning consumers.

    Args:
        warning_message: Warning message to display
        expiration_date: ISO format date string (e.g., "2026-06-01T00:00:00Z")

    Returns:
        dict with deprecation status

    Raises:
        ValueError: If this model is not connected to a client.

    Example:
        from datetime import datetime, timedelta, timezone
        # Build the timestamp in UTC so the trailing "Z" is accurate.
        expires = datetime.now(timezone.utc) + timedelta(days=90)
        fm.deprecate(
            warning_message="Replaced by v2. Migrate by expiration.",
            expiration_date=expires.strftime("%Y-%m-%dT%H:%M:%SZ")
        )
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    payload = dict(
        warning_message=warning_message,
        expiration_date=expiration_date,
    )
    endpoint = f"/compute/session/{self.id}/deprecate"
    return self._ctx.post_json(endpoint, data=payload)
|
|
844
|
+
|
|
845
|
+
def unpublish(self) -> Dict[str, Any]:
    """
    Unpublish this model, moving it back from the published directory.

    WARNING: After unpublishing, the model is subject to garbage
    collection and may be deleted when disk space is low.

    Returns:
        dict with unpublish status

    Raises:
        ValueError: If this model is not connected to a client.

    Example:
        fm.unpublish()
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    # The endpoint takes no arguments; an empty JSON body is still posted.
    endpoint = f"/compute/session/{self.id}/unpublish"
    return self._ctx.post_json(endpoint, data={})
|
|
862
|
+
|
|
863
|
+
def publish_checkpoint(
    self,
    name: str,
    org_id: Optional[str] = None,
    checkpoint_epoch: Optional[int] = None,
    session_name_prefix: Optional[str] = None,
    publish: bool = True,
) -> 'FoundationalModel':
    """
    Publish a checkpoint from this model's training as a new foundation model.

    Creates a NEW FoundationalModel from a training checkpoint with full
    provenance tracking. Useful for snapshotting good intermediate models
    while training continues.

    Args:
        name: Name for the new foundation model (required)
        org_id: Organization ID (required if publish=True)
        checkpoint_epoch: Which epoch checkpoint to use (None = best/latest)
        session_name_prefix: Optional prefix for the new session ID
        publish: Move to published directory (default: True)

    Returns:
        New FoundationalModel instance for the published checkpoint

    Raises:
        ValueError: If the model is not connected to a client, or if
            publish=True and no org_id was given.
        RuntimeError: If the server response carries no
            'foundation_session_id'.

    Example:
        # Snapshot epoch 50 while training continues
        checkpoint_fm = fm.publish_checkpoint(
            name="My Model v0.5",
            org_id="my_org",
            checkpoint_epoch=50
        )
        # Use immediately
        predictor = checkpoint_fm.create_classifier(target_column="target")
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    if publish and not org_id:
        raise ValueError("org_id is required when publish=True")

    data = {
        "name": name,
        "publish": publish,
    }
    # Optional fields are only sent when the caller supplied them.
    # Epoch 0 is a valid value, hence the explicit None check.
    if checkpoint_epoch is not None:
        data["checkpoint_epoch"] = checkpoint_epoch
    if session_name_prefix:
        data["session_name_prefix"] = session_name_prefix
    if org_id:
        data["org_id"] = org_id

    response = self._ctx.post_json(
        f"/compute/session/{self.id}/publish_partial_foundation",
        data=data
    )

    # Without an id the returned model would build URLs such as
    # "/compute/session//..." on every later call, so fail here instead.
    new_session_id = response.get("foundation_session_id")
    if not new_session_id:
        raise RuntimeError(
            "publish_partial_foundation response did not include "
            "'foundation_session_id'"
        )

    new_fm = FoundationalModel(
        id=new_session_id,
        name=name,
        status="done",
        epochs=response.get("checkpoint_epoch"),
        created_at=datetime.now(),
        _ctx=self._ctx,
    )

    return new_fm
|
|
930
|
+
|
|
643
931
|
def to_dict(self) -> Dict[str, Any]:
|
|
644
932
|
"""Convert to dictionary representation."""
|
|
645
933
|
return {
|
featrixsphere/api/http_client.py
CHANGED
|
@@ -121,14 +121,34 @@ class HTTPClientMixin:
|
|
|
121
121
|
|
|
122
122
|
def _unwrap_response(self, response_json: Dict[str, Any]) -> Dict[str, Any]:
|
|
123
123
|
"""
|
|
124
|
-
Unwrap server response, handling
|
|
124
|
+
Unwrap server response, handling wrapper formats.
|
|
125
125
|
|
|
126
|
-
The server
|
|
126
|
+
The server may wrap responses as:
|
|
127
|
+
- {"response": {...}}
|
|
128
|
+
- {"_meta": {...}, "data": {...}}
|
|
129
|
+
|
|
130
|
+
Captures server metadata when present.
|
|
127
131
|
"""
|
|
128
|
-
if isinstance(response_json, dict)
|
|
129
|
-
|
|
132
|
+
if isinstance(response_json, dict):
|
|
133
|
+
# Handle _meta/data wrapper (captures server metadata)
|
|
134
|
+
if '_meta' in response_json and 'data' in response_json:
|
|
135
|
+
self._last_server_metadata = response_json['_meta']
|
|
136
|
+
return response_json['data']
|
|
137
|
+
# Handle response wrapper
|
|
138
|
+
if 'response' in response_json and len(response_json) == 1:
|
|
139
|
+
return response_json['response']
|
|
130
140
|
return response_json
|
|
131
141
|
|
|
142
|
+
@property
def last_server_metadata(self) -> Optional[Dict[str, Any]]:
    """
    Metadata captured from the most recent "_meta"-wrapped server response.

    Contains server info like compute_cluster_time, compute_cluster,
    compute_cluster_version, etc. Returns None if no metadata has been
    stored on this object yet.
    """
    try:
        return self._last_server_metadata
    except AttributeError:
        # Nothing has been captured yet.
        return None
|
|
151
|
+
|
|
132
152
|
def _get_json(
|
|
133
153
|
self,
|
|
134
154
|
endpoint: str,
|
|
@@ -139,6 +159,16 @@ class HTTPClientMixin:
|
|
|
139
159
|
response = self._make_request("GET", endpoint, max_retries=max_retries, **kwargs)
|
|
140
160
|
return self._unwrap_response(response.json())
|
|
141
161
|
|
|
162
|
+
def _get_bytes(
|
|
163
|
+
self,
|
|
164
|
+
endpoint: str,
|
|
165
|
+
max_retries: Optional[int] = None,
|
|
166
|
+
**kwargs
|
|
167
|
+
) -> bytes:
|
|
168
|
+
"""Make a GET request and return raw bytes (for binary content like images)."""
|
|
169
|
+
response = self._make_request("GET", endpoint, max_retries=max_retries, **kwargs)
|
|
170
|
+
return response.content
|
|
171
|
+
|
|
142
172
|
def _post_json(
|
|
143
173
|
self,
|
|
144
174
|
endpoint: str,
|
|
@@ -207,3 +237,6 @@ class ClientContext:
|
|
|
207
237
|
def post_multipart(self, endpoint: str, data: Dict[str, Any] = None,
                   files: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
    """Forward a multipart POST to the wrapped client's _post_multipart()."""
    send = self._client._post_multipart
    return send(endpoint, data, files, **kwargs)
|
|
240
|
+
|
|
241
|
+
def get_bytes(self, endpoint: str, **kwargs) -> bytes:
    """Forward a raw-bytes GET to the wrapped client's _get_bytes()."""
    fetch = self._client._get_bytes
    return fetch(endpoint, **kwargs)
|
|
@@ -23,11 +23,16 @@ class PredictionResult:
|
|
|
23
23
|
prediction: Raw prediction result (class probabilities or numeric value)
|
|
24
24
|
predicted_class: Predicted class name (for classification)
|
|
25
25
|
confidence: Confidence score (for classification)
|
|
26
|
+
probabilities: Full probability distribution (for classification)
|
|
27
|
+
threshold: Classification threshold (for binary classification)
|
|
26
28
|
query_record: Original input record
|
|
27
29
|
predictor_id: ID of predictor that made this prediction
|
|
28
30
|
session_id: Session ID (internal)
|
|
29
31
|
timestamp: When prediction was made
|
|
30
32
|
target_column: Target column name
|
|
33
|
+
guardrails: Per-column guardrail warnings (if any)
|
|
34
|
+
ignored_query_columns: Columns in input that were not used (not in training data)
|
|
35
|
+
available_query_columns: All columns the model knows about
|
|
31
36
|
|
|
32
37
|
Usage:
|
|
33
38
|
result = predictor.predict({"age": 35, "income": 50000})
|
|
@@ -35,6 +40,14 @@ class PredictionResult:
|
|
|
35
40
|
print(result.confidence) # 0.87
|
|
36
41
|
print(result.prediction_uuid) # UUID for feedback
|
|
37
42
|
|
|
43
|
+
# Check for guardrail warnings
|
|
44
|
+
if result.guardrails:
|
|
45
|
+
print(f"Warnings: {len(result.guardrails)} columns with issues")
|
|
46
|
+
|
|
47
|
+
# Check for ignored columns
|
|
48
|
+
if result.ignored_query_columns:
|
|
49
|
+
print(f"Ignored: {result.ignored_query_columns}")
|
|
50
|
+
|
|
38
51
|
# Send feedback if prediction was wrong
|
|
39
52
|
if result.predicted_class != actual_label:
|
|
40
53
|
feedback = result.send_feedback(ground_truth=actual_label)
|
|
@@ -44,14 +57,52 @@ class PredictionResult:
|
|
|
44
57
|
prediction_uuid: Optional[str] = None
|
|
45
58
|
prediction: Optional[Union[Dict[str, float], float]] = None
|
|
46
59
|
predicted_class: Optional[str] = None
|
|
60
|
+
probability: Optional[float] = None
|
|
47
61
|
confidence: Optional[float] = None
|
|
62
|
+
probabilities: Optional[Dict[str, float]] = None
|
|
63
|
+
threshold: Optional[float] = None
|
|
48
64
|
query_record: Optional[Dict[str, Any]] = None
|
|
65
|
+
|
|
66
|
+
# Documentation fields - explain what prediction/probability/confidence mean
|
|
67
|
+
readme_prediction: str = field(default="The predicted class label (for classification) or value (for regression).")
|
|
68
|
+
readme_probability: str = field(default="Raw probability of the predicted class from the model's softmax output.")
|
|
69
|
+
readme_confidence: str = field(default=(
|
|
70
|
+
"For binary classification: normalized margin from threshold. "
|
|
71
|
+
"confidence = (prob - threshold) / (1 - threshold) if predicting positive, "
|
|
72
|
+
"or (threshold - prob) / threshold if predicting negative. "
|
|
73
|
+
"Ranges from 0 (at decision boundary) to 1 (maximally certain). "
|
|
74
|
+
"For multi-class: same as probability."
|
|
75
|
+
))
|
|
76
|
+
readme_threshold: str = field(default=(
|
|
77
|
+
"Decision boundary for binary classification. "
|
|
78
|
+
"If P(positive_class) >= threshold, predict positive; otherwise predict negative. "
|
|
79
|
+
"Calibrated to optimize F1 score on validation data."
|
|
80
|
+
))
|
|
81
|
+
readme_probabilities: str = field(default=(
|
|
82
|
+
"Full probability distribution across all classes from the model's softmax output. "
|
|
83
|
+
"Dictionary mapping class labels to their probabilities (sum to 1.0)."
|
|
84
|
+
))
|
|
85
|
+
readme_pos_label: str = field(default=(
|
|
86
|
+
"The class label considered 'positive' for binary classification metrics. "
|
|
87
|
+
"Threshold and confidence calculations are relative to this class."
|
|
88
|
+
))
|
|
49
89
|
predictor_id: Optional[str] = None
|
|
50
90
|
session_id: Optional[str] = None
|
|
51
91
|
target_column: Optional[str] = None
|
|
52
92
|
timestamp: Optional[datetime] = None
|
|
53
93
|
model_version: Optional[str] = None
|
|
54
94
|
|
|
95
|
+
# Checkpoint info from the model (epoch, metric_type, metric_value)
|
|
96
|
+
checkpoint_info: Optional[Dict[str, Any]] = None
|
|
97
|
+
|
|
98
|
+
# Guardrails and warnings
|
|
99
|
+
guardrails: Optional[Dict[str, Any]] = None
|
|
100
|
+
ignored_query_columns: Optional[list] = None
|
|
101
|
+
available_query_columns: Optional[list] = None
|
|
102
|
+
|
|
103
|
+
# Feature importance (from leave-one-out ablation)
|
|
104
|
+
feature_importance: Optional[Dict[str, float]] = None
|
|
105
|
+
|
|
55
106
|
# Internal: client context for sending feedback
|
|
56
107
|
_ctx: Optional['ClientContext'] = field(default=None, repr=False)
|
|
57
108
|
|
|
@@ -73,29 +124,44 @@ class PredictionResult:
|
|
|
73
124
|
Returns:
|
|
74
125
|
PredictionResult instance
|
|
75
126
|
"""
|
|
76
|
-
# Extract prediction data
|
|
127
|
+
# Extract prediction data - handle both formats
|
|
128
|
+
# New format: prediction is the class label, probabilities is separate
|
|
129
|
+
# Old format: prediction is the probabilities dict
|
|
77
130
|
prediction = response.get('prediction')
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
131
|
+
probabilities = response.get('probabilities')
|
|
132
|
+
predicted_class = response.get('predicted_class')
|
|
133
|
+
probability = response.get('probability')
|
|
134
|
+
confidence = response.get('confidence')
|
|
135
|
+
|
|
136
|
+
# For old format where prediction is the probabilities dict
|
|
137
|
+
if isinstance(prediction, dict) and not probabilities:
|
|
138
|
+
probabilities = prediction
|
|
84
139
|
if prediction:
|
|
85
140
|
predicted_class = max(prediction.keys(), key=lambda k: prediction[k])
|
|
86
|
-
|
|
141
|
+
probability = prediction[predicted_class]
|
|
142
|
+
confidence = probability # Old format: confidence = probability
|
|
143
|
+
elif isinstance(prediction, str) and not predicted_class:
|
|
144
|
+
# New format: prediction is already the class label
|
|
145
|
+
predicted_class = prediction
|
|
87
146
|
|
|
88
147
|
return cls(
|
|
89
148
|
prediction_uuid=response.get('prediction_uuid') or response.get('prediction_id'),
|
|
90
149
|
prediction=prediction,
|
|
91
150
|
predicted_class=predicted_class,
|
|
151
|
+
probability=probability,
|
|
92
152
|
confidence=confidence,
|
|
153
|
+
probabilities=probabilities,
|
|
154
|
+
threshold=response.get('threshold'),
|
|
93
155
|
query_record=query_record,
|
|
94
156
|
predictor_id=response.get('predictor_id'),
|
|
95
157
|
session_id=response.get('session_id'),
|
|
96
158
|
target_column=response.get('target_column'),
|
|
97
159
|
timestamp=datetime.now(),
|
|
98
160
|
model_version=response.get('model_version'),
|
|
161
|
+
checkpoint_info=response.get('checkpoint_info'),
|
|
162
|
+
guardrails=response.get('guardrails'),
|
|
163
|
+
ignored_query_columns=response.get('ignored_query_columns'),
|
|
164
|
+
available_query_columns=response.get('available_query_columns'),
|
|
99
165
|
_ctx=ctx,
|
|
100
166
|
)
|
|
101
167
|
|
|
@@ -126,18 +192,41 @@ class PredictionResult:
|
|
|
126
192
|
|
|
127
193
|
def to_dict(self) -> Dict[str, Any]:
    """
    Convert this prediction result to a plain dictionary.

    The core prediction fields and the readme_* documentation strings are
    always present. checkpoint_info, guardrails, ignored_query_columns,
    available_query_columns and feature_importance are only added when
    they hold a truthy value. The timestamp is emitted in ISO format
    (None when unset).
    """
    stamp = self.timestamp.isoformat() if self.timestamp else None

    payload = dict(
        prediction_uuid=self.prediction_uuid,
        prediction=self.prediction,
        predicted_class=self.predicted_class,
        probability=self.probability,
        confidence=self.confidence,
        probabilities=self.probabilities,
        threshold=self.threshold,
        query_record=self.query_record,
        predictor_id=self.predictor_id,
        session_id=self.session_id,
        target_column=self.target_column,
        timestamp=stamp,
        model_version=self.model_version,
        # Documentation
        readme_prediction=self.readme_prediction,
        readme_probability=self.readme_probability,
        readme_confidence=self.readme_confidence,
        readme_threshold=self.readme_threshold,
        readme_probabilities=self.readme_probabilities,
        readme_pos_label=self.readme_pos_label,
    )

    # Optional sections are appended in a fixed order, and only when set.
    optional_fields = (
        'checkpoint_info',
        'guardrails',
        'ignored_query_columns',
        'available_query_columns',
        'feature_importance',
    )
    for field_name in optional_fields:
        value = getattr(self, field_name)
        if value:
            payload[field_name] = value
    return payload
|
|
141
230
|
|
|
142
231
|
|
|
143
232
|
@dataclass
|
featrixsphere/api/predictor.py
CHANGED
|
@@ -105,7 +105,8 @@ class Predictor:
|
|
|
105
105
|
def predict(
|
|
106
106
|
self,
|
|
107
107
|
record: Dict[str, Any],
|
|
108
|
-
best_metric_preference: Optional[str] = None
|
|
108
|
+
best_metric_preference: Optional[str] = None,
|
|
109
|
+
feature_importance: bool = False
|
|
109
110
|
) -> PredictionResult:
|
|
110
111
|
"""
|
|
111
112
|
Make a single prediction.
|
|
@@ -113,15 +114,21 @@ class Predictor:
|
|
|
113
114
|
Args:
|
|
114
115
|
record: Input record dictionary
|
|
115
116
|
best_metric_preference: Metric checkpoint to use ("roc_auc", "pr_auc", or None)
|
|
117
|
+
feature_importance: If True, compute feature importance via leave-one-out ablation
|
|
116
118
|
|
|
117
119
|
Returns:
|
|
118
|
-
PredictionResult with prediction, confidence, and prediction_uuid
|
|
120
|
+
PredictionResult with prediction, confidence, and prediction_uuid.
|
|
121
|
+
If feature_importance=True, also includes feature_importance dict.
|
|
119
122
|
|
|
120
123
|
Example:
|
|
121
124
|
result = predictor.predict({"age": 35, "income": 50000})
|
|
122
125
|
print(result.predicted_class) # "churned"
|
|
123
126
|
print(result.confidence) # 0.87
|
|
124
127
|
print(result.prediction_uuid) # UUID for feedback
|
|
128
|
+
|
|
129
|
+
# With feature importance
|
|
130
|
+
result = predictor.predict(record, feature_importance=True)
|
|
131
|
+
print(result.feature_importance) # {"income": 0.15, "age": 0.08, ...}
|
|
125
132
|
"""
|
|
126
133
|
if not self._ctx:
|
|
127
134
|
raise ValueError("Predictor not connected to client")
|
|
@@ -129,7 +136,44 @@ class Predictor:
|
|
|
129
136
|
# Clean the record
|
|
130
137
|
cleaned_record = self._clean_record(record)
|
|
131
138
|
|
|
132
|
-
|
|
139
|
+
if feature_importance:
|
|
140
|
+
# Build N+1 records: original + each feature nulled out
|
|
141
|
+
columns = list(cleaned_record.keys())
|
|
142
|
+
batch = [cleaned_record] # Original first
|
|
143
|
+
|
|
144
|
+
for col in columns:
|
|
145
|
+
ablated = cleaned_record.copy()
|
|
146
|
+
ablated[col] = None
|
|
147
|
+
batch.append(ablated)
|
|
148
|
+
|
|
149
|
+
# Single batch call
|
|
150
|
+
results = self.batch_predict(
|
|
151
|
+
batch,
|
|
152
|
+
show_progress=False,
|
|
153
|
+
best_metric_preference=best_metric_preference
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
# Compare: importance = |original_confidence - ablated_confidence|
|
|
157
|
+
original = results[0]
|
|
158
|
+
importance = {}
|
|
159
|
+
original_conf = original.confidence or 0.0
|
|
160
|
+
|
|
161
|
+
for i, col in enumerate(columns):
|
|
162
|
+
ablated_result = results[i + 1]
|
|
163
|
+
ablated_conf = ablated_result.confidence or 0.0
|
|
164
|
+
# Higher delta = more important
|
|
165
|
+
delta = abs(original_conf - ablated_conf)
|
|
166
|
+
importance[col] = round(delta, 4)
|
|
167
|
+
|
|
168
|
+
# Sort by importance (highest first)
|
|
169
|
+
original.feature_importance = dict(sorted(
|
|
170
|
+
importance.items(),
|
|
171
|
+
key=lambda x: x[1],
|
|
172
|
+
reverse=True
|
|
173
|
+
))
|
|
174
|
+
return original
|
|
175
|
+
|
|
176
|
+
# Standard single prediction
|
|
133
177
|
request_payload = {
|
|
134
178
|
"query_record": cleaned_record,
|
|
135
179
|
"predictor_id": self.id,
|
|
@@ -196,6 +240,36 @@ class Predictor:
|
|
|
196
240
|
|
|
197
241
|
return results
|
|
198
242
|
|
|
243
|
+
def predict_csv_file(
    self,
    csv_path: str,
    show_progress: bool = True,
    best_metric_preference: Optional[str] = None
) -> List[PredictionResult]:
    """
    Read a CSV file with pandas and batch-predict every row.

    Args:
        csv_path: Path to the CSV file
        show_progress: Show progress bar
        best_metric_preference: Metric checkpoint to use

    Returns:
        List of PredictionResult objects, as returned by batch_predict()

    Example:
        results = predictor.predict_csv_file("test_data.csv")
        for r in results:
            print(r.predicted_class, r.confidence)
    """
    # Imported lazily so pandas is only needed when this method is used.
    import pandas as pd

    frame = pd.read_csv(csv_path)
    options = {
        "show_progress": show_progress,
        "best_metric_preference": best_metric_preference,
    }
    return self.batch_predict(frame, **options)
|
|
272
|
+
|
|
199
273
|
def explain(
|
|
200
274
|
self,
|
|
201
275
|
record: Dict[str, Any],
|
featrixsphere/client.py
CHANGED
|
@@ -137,22 +137,24 @@ class SessionInfo:
|
|
|
137
137
|
class PredictionBatch:
|
|
138
138
|
"""
|
|
139
139
|
Cached prediction batch that allows instant lookups after initial batch processing.
|
|
140
|
-
|
|
140
|
+
|
|
141
141
|
Usage:
|
|
142
142
|
# First run - populate cache
|
|
143
143
|
batch = client.predict_batch(session_id, records)
|
|
144
|
-
|
|
144
|
+
|
|
145
145
|
# Second run - instant cache lookups
|
|
146
146
|
for i in values1:
|
|
147
147
|
for j in values2:
|
|
148
148
|
record = {"param1": i, "param2": j}
|
|
149
149
|
result = batch.predict(record) # Instant!
|
|
150
150
|
"""
|
|
151
|
-
|
|
152
|
-
def __init__(self, session_id: str, client: 'FeatrixSphereClient', target_column: str = None
|
|
151
|
+
|
|
152
|
+
def __init__(self, session_id: str, client: 'FeatrixSphereClient', target_column: str = None,
|
|
153
|
+
best_metric_preference: str = None):
|
|
153
154
|
self.session_id = session_id
|
|
154
155
|
self.client = client
|
|
155
156
|
self.target_column = target_column
|
|
157
|
+
self.best_metric_preference = best_metric_preference
|
|
156
158
|
self._cache = {} # record_hash -> prediction_result
|
|
157
159
|
self._stats = {'hits': 0, 'misses': 0, 'populated': 0}
|
|
158
160
|
|
|
@@ -203,14 +205,15 @@ class PredictionBatch:
|
|
|
203
205
|
"""Populate the cache with batch predictions."""
|
|
204
206
|
if not records:
|
|
205
207
|
return {'summary': {'total_records': 0, 'successful': 0, 'failed': 0}}
|
|
206
|
-
|
|
208
|
+
|
|
207
209
|
print(f"🚀 Creating prediction batch for {len(records)} records...")
|
|
208
|
-
|
|
210
|
+
|
|
209
211
|
# Use existing batch prediction system
|
|
210
212
|
batch_results = self.client.predict_records(
|
|
211
213
|
session_id=self.session_id,
|
|
212
214
|
records=records,
|
|
213
215
|
target_column=self.target_column,
|
|
216
|
+
best_metric_preference=self.best_metric_preference,
|
|
214
217
|
show_progress_bar=True
|
|
215
218
|
)
|
|
216
219
|
|
|
@@ -2927,24 +2930,26 @@ class FeatrixSphereClient:
|
|
|
2927
2930
|
# Single Predictor Functionality
|
|
2928
2931
|
# =========================================================================
|
|
2929
2932
|
|
|
2930
|
-
def predict(self, session_id: str, record: Dict[str, Any], target_column: str = None,
|
|
2933
|
+
def predict(self, session_id: str, record: Dict[str, Any], target_column: str = None,
|
|
2931
2934
|
predictor_id: str = None, best_metric_preference: str = None,
|
|
2932
|
-
|
|
2935
|
+
checkpoint_epoch: int = None, max_retries: int = None,
|
|
2936
|
+
queue_batches: bool = False) -> Dict[str, Any]:
|
|
2933
2937
|
"""
|
|
2934
2938
|
Make a single prediction for a record.
|
|
2935
|
-
|
|
2939
|
+
|
|
2936
2940
|
Args:
|
|
2937
2941
|
session_id: ID of session with trained predictor
|
|
2938
2942
|
record: Record dictionary (without target column)
|
|
2939
2943
|
target_column: Specific target column predictor to use (required if multiple predictors exist and predictor_id not specified)
|
|
2940
2944
|
predictor_id: Specific predictor ID to use (recommended - more precise than target_column)
|
|
2941
2945
|
best_metric_preference: Which metric checkpoint to use: "roc_auc", "pr_auc", or None (use default checkpoint) (default: None)
|
|
2946
|
+
checkpoint_epoch: Specific epoch checkpoint to use (e.g., 65 for epoch 65). Overrides best_metric_preference.
|
|
2942
2947
|
max_retries: Number of retries for errors (default: uses client default)
|
|
2943
2948
|
queue_batches: If True, queue this prediction for batch processing instead of immediate API call
|
|
2944
|
-
|
|
2949
|
+
|
|
2945
2950
|
Returns:
|
|
2946
2951
|
Prediction result dictionary if queue_batches=False, or queue ID if queue_batches=True
|
|
2947
|
-
|
|
2952
|
+
|
|
2948
2953
|
Note:
|
|
2949
2954
|
predictor_id is recommended over target_column for precision. If both are provided, predictor_id takes precedence.
|
|
2950
2955
|
Use client.list_predictors(session_id) to see available predictor IDs.
|
|
@@ -2955,29 +2960,31 @@ class FeatrixSphereClient:
|
|
|
2955
2960
|
if should_warn:
|
|
2956
2961
|
call_count = len(self._prediction_call_times.get(session_id, []))
|
|
2957
2962
|
self._show_batching_warning(session_id, call_count)
|
|
2958
|
-
|
|
2963
|
+
|
|
2959
2964
|
# If queueing is enabled, add to queue and return queue ID
|
|
2960
2965
|
if queue_batches:
|
|
2961
2966
|
queue_id = self._add_to_prediction_queue(session_id, record, target_column, predictor_id)
|
|
2962
2967
|
return {"queued": True, "queue_id": queue_id}
|
|
2963
|
-
|
|
2964
|
-
# Clean NaN/Inf values
|
|
2968
|
+
|
|
2969
|
+
# Clean NaN/Inf values
|
|
2965
2970
|
cleaned_record = self._clean_numpy_values(record)
|
|
2966
2971
|
cleaned_record = self.replace_nans_with_nulls(cleaned_record)
|
|
2967
|
-
|
|
2972
|
+
|
|
2968
2973
|
# Build request payload - let the server handle predictor resolution
|
|
2969
2974
|
request_payload = {
|
|
2970
2975
|
"query_record": cleaned_record,
|
|
2971
2976
|
}
|
|
2972
|
-
|
|
2977
|
+
|
|
2973
2978
|
# Include whatever the caller provided - server will figure it out
|
|
2974
2979
|
if target_column:
|
|
2975
2980
|
request_payload["target_column"] = target_column
|
|
2976
2981
|
if predictor_id:
|
|
2977
2982
|
request_payload["predictor_id"] = predictor_id
|
|
2978
|
-
if best_metric_preference:
|
|
2983
|
+
if checkpoint_epoch is not None:
|
|
2984
|
+
request_payload["checkpoint_epoch"] = checkpoint_epoch
|
|
2985
|
+
elif best_metric_preference:
|
|
2979
2986
|
request_payload["best_metric_preference"] = best_metric_preference
|
|
2980
|
-
|
|
2987
|
+
|
|
2981
2988
|
# Just send it to the server - it has all the smart fallback logic
|
|
2982
2989
|
response_data = self._post_json(f"/session/{session_id}/predict", request_payload, max_retries=max_retries)
|
|
2983
2990
|
return response_data
|
|
@@ -6648,8 +6655,8 @@ class FeatrixSphereClient:
|
|
|
6648
6655
|
|
|
6649
6656
|
def predict_table(self, session_id: str, table_data: Dict[str, Any],
|
|
6650
6657
|
target_column: str = None, predictor_id: str = None,
|
|
6651
|
-
best_metric_preference: str = None,
|
|
6652
|
-
trace: bool = False) -> Dict[str, Any]:
|
|
6658
|
+
best_metric_preference: str = None, checkpoint_epoch: int = None,
|
|
6659
|
+
max_retries: int = None, trace: bool = False) -> Dict[str, Any]:
|
|
6653
6660
|
"""
|
|
6654
6661
|
Make batch predictions using JSON Tables format.
|
|
6655
6662
|
|
|
@@ -6658,6 +6665,8 @@ class FeatrixSphereClient:
|
|
|
6658
6665
|
table_data: Data in JSON Tables format, or list of records, or dict with 'table'/'records'
|
|
6659
6666
|
target_column: Specific target column predictor to use (required if multiple predictors exist)
|
|
6660
6667
|
predictor_id: Specific predictor ID to use (recommended - more precise than target_column)
|
|
6668
|
+
best_metric_preference: Which metric checkpoint to use: "roc_auc", "pr_auc", or None (default)
|
|
6669
|
+
checkpoint_epoch: Specific epoch checkpoint to use (e.g., 65). Overrides best_metric_preference.
|
|
6661
6670
|
max_retries: Number of retries for errors (default: uses client default, recommend higher for batch)
|
|
6662
6671
|
trace: Enable detailed debug logging (default: False)
|
|
6663
6672
|
|
|
@@ -6710,7 +6719,9 @@ class FeatrixSphereClient:
|
|
|
6710
6719
|
table_data['target_column'] = target_column
|
|
6711
6720
|
if predictor_id:
|
|
6712
6721
|
table_data['predictor_id'] = predictor_id
|
|
6713
|
-
if best_metric_preference:
|
|
6722
|
+
if checkpoint_epoch is not None:
|
|
6723
|
+
table_data['checkpoint_epoch'] = checkpoint_epoch
|
|
6724
|
+
elif best_metric_preference:
|
|
6714
6725
|
table_data['best_metric_preference'] = best_metric_preference
|
|
6715
6726
|
|
|
6716
6727
|
if trace:
|
|
@@ -6744,7 +6755,8 @@ class FeatrixSphereClient:
|
|
|
6744
6755
|
raise
|
|
6745
6756
|
|
|
6746
6757
|
def predict_records(self, session_id: str, records: List[Dict[str, Any]],
|
|
6747
|
-
target_column: str = None, predictor_id: str = None,
|
|
6758
|
+
target_column: str = None, predictor_id: str = None,
|
|
6759
|
+
best_metric_preference: str = None, checkpoint_epoch: int = None,
|
|
6748
6760
|
batch_size: int = 2500, use_async: bool = False,
|
|
6749
6761
|
show_progress_bar: bool = True, print_target_column_warning: bool = True,
|
|
6750
6762
|
trace: bool = False) -> Dict[str, Any]:
|
|
@@ -6756,6 +6768,8 @@ class FeatrixSphereClient:
|
|
|
6756
6768
|
records: List of record dictionaries
|
|
6757
6769
|
target_column: Specific target column predictor to use (required if multiple predictors exist and predictor_id not specified)
|
|
6758
6770
|
predictor_id: Specific predictor ID to use (recommended - more precise than target_column)
|
|
6771
|
+
best_metric_preference: Which metric checkpoint to use: "roc_auc", "pr_auc", or None (default)
|
|
6772
|
+
checkpoint_epoch: Specific epoch checkpoint to use (e.g., 65). Overrides best_metric_preference.
|
|
6759
6773
|
batch_size: Number of records to send per API call (default: 2500)
|
|
6760
6774
|
use_async: Force async processing for large datasets (default: False - async disabled due to pickle issues)
|
|
6761
6775
|
show_progress_bar: Whether to show progress bar for async jobs (default: True)
|
|
@@ -7744,23 +7758,25 @@ class FeatrixSphereClient:
|
|
|
7744
7758
|
else:
|
|
7745
7759
|
return data
|
|
7746
7760
|
|
|
7747
|
-
def predict_csv_file(self, session_id: str, file_path: Path) -> Dict[str, Any]:
|
|
7761
|
+
def predict_csv_file(self, session_id: str, file_path: Path,
|
|
7762
|
+
best_metric_preference: str = None) -> Dict[str, Any]:
|
|
7748
7763
|
"""
|
|
7749
7764
|
Make batch predictions on a CSV file.
|
|
7750
|
-
|
|
7765
|
+
|
|
7751
7766
|
Args:
|
|
7752
7767
|
session_id: ID of session with trained predictor
|
|
7753
7768
|
file_path: Path to CSV file
|
|
7754
|
-
|
|
7769
|
+
best_metric_preference: Which metric checkpoint to use: "roc_auc", "pr_auc", or None (default)
|
|
7770
|
+
|
|
7755
7771
|
Returns:
|
|
7756
7772
|
Batch prediction results
|
|
7757
7773
|
"""
|
|
7758
7774
|
import pandas as pd
|
|
7759
7775
|
from jsontables import JSONTablesEncoder
|
|
7760
|
-
|
|
7776
|
+
|
|
7761
7777
|
if not file_path.exists():
|
|
7762
7778
|
raise FileNotFoundError(f"File not found: {file_path}")
|
|
7763
|
-
|
|
7779
|
+
|
|
7764
7780
|
# Support CSV, Parquet, JSON, and JSONL files
|
|
7765
7781
|
file_path_str = str(file_path).lower()
|
|
7766
7782
|
if file_path_str.endswith('.parquet'):
|
|
@@ -7779,29 +7795,31 @@ class FeatrixSphereClient:
|
|
|
7779
7795
|
df = pd.read_json(file_path)
|
|
7780
7796
|
else:
|
|
7781
7797
|
df = pd.read_csv(file_path)
|
|
7782
|
-
|
|
7798
|
+
|
|
7783
7799
|
# Convert to JSON Tables format and clean NaNs
|
|
7784
7800
|
table_data = JSONTablesEncoder.from_dataframe(df)
|
|
7785
7801
|
cleaned_table_data = self.replace_nans_with_nulls(table_data)
|
|
7786
|
-
|
|
7787
|
-
return self.predict_table(session_id, cleaned_table_data)
|
|
7788
7802
|
|
|
7789
|
-
|
|
7803
|
+
return self.predict_table(session_id, cleaned_table_data, best_metric_preference=best_metric_preference)
|
|
7804
|
+
|
|
7805
|
+
def run_predictions(self, session_id: str, records: List[Dict[str, Any]],
|
|
7806
|
+
best_metric_preference: str = None) -> Dict[str, Any]:
|
|
7790
7807
|
"""
|
|
7791
7808
|
Run predictions on provided records. Clean and fast for production use.
|
|
7792
|
-
|
|
7809
|
+
|
|
7793
7810
|
Args:
|
|
7794
7811
|
session_id: ID of session with trained predictor
|
|
7795
7812
|
records: List of record dictionaries
|
|
7796
|
-
|
|
7813
|
+
best_metric_preference: Which metric checkpoint to use: "roc_auc", "pr_auc", or None (default)
|
|
7814
|
+
|
|
7797
7815
|
Returns:
|
|
7798
7816
|
Dictionary with prediction results
|
|
7799
7817
|
"""
|
|
7800
7818
|
# Clean NaNs for JSON encoding
|
|
7801
7819
|
cleaned_records = self.replace_nans_with_nulls(records)
|
|
7802
|
-
|
|
7820
|
+
|
|
7803
7821
|
# Make batch predictions
|
|
7804
|
-
batch_results = self.predict_records(session_id, cleaned_records)
|
|
7822
|
+
batch_results = self.predict_records(session_id, cleaned_records, best_metric_preference=best_metric_preference)
|
|
7805
7823
|
predictions = batch_results['predictions']
|
|
7806
7824
|
|
|
7807
7825
|
# Process predictions into clean format
|
|
@@ -8517,32 +8535,33 @@ class FeatrixSphereClient:
|
|
|
8517
8535
|
|
|
8518
8536
|
return cleared_counts
|
|
8519
8537
|
|
|
8520
|
-
def predict_batch(self, session_id: str, records: List[Dict[str, Any]],
|
|
8521
|
-
|
|
8538
|
+
def predict_batch(self, session_id: str, records: List[Dict[str, Any]],
|
|
8539
|
+
target_column: str = None, best_metric_preference: str = None) -> PredictionBatch:
|
|
8522
8540
|
"""
|
|
8523
8541
|
Create a prediction batch for instant cached lookups.
|
|
8524
|
-
|
|
8542
|
+
|
|
8525
8543
|
Perfect for parameter sweeps, grid searches, and exploring prediction surfaces.
|
|
8526
8544
|
Run your loops twice with identical code - first populates cache, second gets instant results.
|
|
8527
|
-
|
|
8545
|
+
|
|
8528
8546
|
Args:
|
|
8529
8547
|
session_id: ID of session with trained predictor
|
|
8530
8548
|
records: List of all records you'll want to predict on
|
|
8531
8549
|
target_column: Specific target column predictor to use
|
|
8532
|
-
|
|
8550
|
+
best_metric_preference: Which metric checkpoint to use: "roc_auc", "pr_auc", or None (default)
|
|
8551
|
+
|
|
8533
8552
|
Returns:
|
|
8534
8553
|
PredictionBatch object with instant predict() method
|
|
8535
|
-
|
|
8554
|
+
|
|
8536
8555
|
Example:
|
|
8537
8556
|
# Generate all combinations you'll need
|
|
8538
8557
|
records = []
|
|
8539
8558
|
for i in range(10):
|
|
8540
8559
|
for j in range(10):
|
|
8541
8560
|
records.append({"param1": i, "param2": j})
|
|
8542
|
-
|
|
8561
|
+
|
|
8543
8562
|
# First run - populate cache with batch processing
|
|
8544
8563
|
batch = client.predict_batch(session_id, records)
|
|
8545
|
-
|
|
8564
|
+
|
|
8546
8565
|
# Second run - same loops but instant cache lookups
|
|
8547
8566
|
results = []
|
|
8548
8567
|
for i in range(10):
|
|
@@ -8552,50 +8571,52 @@ class FeatrixSphereClient:
|
|
|
8552
8571
|
results.append(result)
|
|
8553
8572
|
"""
|
|
8554
8573
|
# Create batch object
|
|
8555
|
-
batch = PredictionBatch(session_id, self, target_column)
|
|
8556
|
-
|
|
8574
|
+
batch = PredictionBatch(session_id, self, target_column, best_metric_preference)
|
|
8575
|
+
|
|
8557
8576
|
# Populate cache with batch predictions
|
|
8558
8577
|
batch._populate_cache(records)
|
|
8559
|
-
|
|
8578
|
+
|
|
8560
8579
|
return batch
|
|
8561
8580
|
|
|
8562
|
-
def predict_grid(self, session_id: str, degrees_of_freedom: int,
|
|
8563
|
-
|
|
8581
|
+
def predict_grid(self, session_id: str, degrees_of_freedom: int,
|
|
8582
|
+
grid_shape: tuple = None, target_column: str = None,
|
|
8583
|
+
best_metric_preference: str = None) -> 'PredictionGrid':
|
|
8564
8584
|
"""
|
|
8565
8585
|
Create a prediction grid for exploring parameter surfaces with automatic visualization.
|
|
8566
|
-
|
|
8586
|
+
|
|
8567
8587
|
Perfect for 1D curves, 2D heatmaps, and 3D surfaces with built-in plotting functions.
|
|
8568
|
-
|
|
8588
|
+
|
|
8569
8589
|
Args:
|
|
8570
8590
|
session_id: ID of session with trained predictor
|
|
8571
8591
|
degrees_of_freedom: Number of dimensions (1, 2, or 3)
|
|
8572
8592
|
grid_shape: Custom grid shape tuple (default: auto-sized)
|
|
8573
8593
|
target_column: Specific target column predictor to use
|
|
8574
|
-
|
|
8594
|
+
best_metric_preference: Which metric checkpoint to use: "roc_auc", "pr_auc", or None (default)
|
|
8595
|
+
|
|
8575
8596
|
Returns:
|
|
8576
8597
|
PredictionGrid object with predict() and plotting methods
|
|
8577
|
-
|
|
8598
|
+
|
|
8578
8599
|
Example:
|
|
8579
8600
|
# 2D parameter sweep with automatic plotting
|
|
8580
8601
|
grid = client.predict_grid(session_id, degrees_of_freedom=2)
|
|
8581
8602
|
grid.set_axis_labels(["Spend", "Campaign Type"])
|
|
8582
8603
|
grid.set_axis_values(0, [100, 250, 500])
|
|
8583
8604
|
grid.set_axis_values(1, ["search", "display", "social"])
|
|
8584
|
-
|
|
8605
|
+
|
|
8585
8606
|
for i, spend in enumerate([100, 250, 500]):
|
|
8586
8607
|
for j, campaign in enumerate(["search", "display", "social"]):
|
|
8587
8608
|
record = {"spend": spend, "campaign_type": campaign}
|
|
8588
8609
|
grid.predict(record, grid_position=(i, j))
|
|
8589
|
-
|
|
8610
|
+
|
|
8590
8611
|
# Automatic visualization
|
|
8591
8612
|
grid.plot_heatmap() # 2D heatmap
|
|
8592
8613
|
grid.plot_3d() # 3D surface
|
|
8593
|
-
|
|
8614
|
+
|
|
8594
8615
|
# Find optimal parameters
|
|
8595
8616
|
optimal_pos = grid.get_optimal_position()
|
|
8596
8617
|
print(f"Optimal parameters at grid position: {optimal_pos}")
|
|
8597
8618
|
"""
|
|
8598
|
-
return PredictionGrid(session_id, self, degrees_of_freedom, grid_shape, target_column)
|
|
8619
|
+
return PredictionGrid(session_id, self, degrees_of_freedom, grid_shape, target_column, best_metric_preference)
|
|
8599
8620
|
|
|
8600
8621
|
def get_embedding_space_columns(self, session_id: str) -> Dict[str, Any]:
|
|
8601
8622
|
"""
|
|
@@ -8672,12 +8693,13 @@ class PredictionGrid:
|
|
|
8672
8693
|
grid.plot_3d() # 3D surface plot
|
|
8673
8694
|
"""
|
|
8674
8695
|
|
|
8675
|
-
def __init__(self, session_id: str, client: 'FeatrixSphereClient', degrees_of_freedom: int,
|
|
8676
|
-
grid_shape: tuple = None, target_column: str = None):
|
|
8696
|
+
def __init__(self, session_id: str, client: 'FeatrixSphereClient', degrees_of_freedom: int,
|
|
8697
|
+
grid_shape: tuple = None, target_column: str = None, best_metric_preference: str = None):
|
|
8677
8698
|
self.session_id = session_id
|
|
8678
8699
|
self.client = client
|
|
8679
8700
|
self.degrees_of_freedom = degrees_of_freedom
|
|
8680
8701
|
self.target_column = target_column
|
|
8702
|
+
self.best_metric_preference = best_metric_preference
|
|
8681
8703
|
|
|
8682
8704
|
# Initialize grid matrix based on degrees of freedom
|
|
8683
8705
|
if grid_shape:
|
|
@@ -8769,6 +8791,7 @@ class PredictionGrid:
|
|
|
8769
8791
|
session_id=self.session_id,
|
|
8770
8792
|
records=records_list,
|
|
8771
8793
|
target_column=self.target_column,
|
|
8794
|
+
best_metric_preference=self.best_metric_preference,
|
|
8772
8795
|
show_progress_bar=show_progress
|
|
8773
8796
|
)
|
|
8774
8797
|
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
featrixsphere/__init__.py,sha256=NPVY3_tG74OqiLz1MYbrb9R4Wff68xFCI6zr3gVRTyQ,2190
|
|
2
|
+
featrixsphere/client.py,sha256=JNYAHDFtxmhQVuclO3dWEphJuxNFLi-PfvWylZWKF_4,452929
|
|
3
|
+
featrixsphere/api/__init__.py,sha256=quyvuPphVj9wb6v8Dio0SMG9iHgJAmY3asHk3f_zF10,1269
|
|
4
|
+
featrixsphere/api/api_endpoint.py,sha256=i3eCWuaUXftnH1Ai6MFZ7md7pC2FcRAIRO87CBZhyEQ,9000
|
|
5
|
+
featrixsphere/api/client.py,sha256=TdpujNsJxO4GfPMI_KoemQWV89go3KuK6OPAo9jX6Bs,12574
|
|
6
|
+
featrixsphere/api/foundational_model.py,sha256=wf5-VvVUXYoiyKr3y4Ok8OnwMhtaUuYLXwdgCwFuM-k,31065
|
|
7
|
+
featrixsphere/api/http_client.py,sha256=q59-41fHua_7AwtPFCvshlSUKJ-fS0X337L9Ooyn0DI,8440
|
|
8
|
+
featrixsphere/api/notebook_helper.py,sha256=xY9jsao26eaNiFh2s0_TlRZnR8xZ4P_e0EOKr2PtoVs,20060
|
|
9
|
+
featrixsphere/api/prediction_result.py,sha256=HQsJdr89zWxdRx395nevN3aP7ZXZuZxB4UGX5Ykhkfk,12235
|
|
10
|
+
featrixsphere/api/predictor.py,sha256=1v0ffkEjmrO3BP0PNWAXtiAU-AlOQJSiDICmW1bQbGU,20300
|
|
11
|
+
featrixsphere/api/reference_record.py,sha256=-XOTF6ynznB3ouz06w3AF8X9SVId0g_dO20VvGNesUQ,7095
|
|
12
|
+
featrixsphere/api/vector_database.py,sha256=BplxKkPnAbcBX1A4KxFBJVb3qkQ-FH9zi9v2dWG5CgY,7976
|
|
13
|
+
featrixsphere-0.2.6708.dist-info/METADATA,sha256=IpNh40w32ZlavhKaRkBFRLvpXyXmoNddYMjbc2NpBGI,16232
|
|
14
|
+
featrixsphere-0.2.6708.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
15
|
+
featrixsphere-0.2.6708.dist-info/entry_points.txt,sha256=QreJeYfD_VWvbEqPmMXZ3pqqlFlJ1qZb-NtqnyhEldc,51
|
|
16
|
+
featrixsphere-0.2.6708.dist-info/top_level.txt,sha256=AyN4wjfzlD0hWnDieuEHX0KckphIk_aC73XCG4df5uU,14
|
|
17
|
+
featrixsphere-0.2.6708.dist-info/RECORD,,
|
|
@@ -1,17 +0,0 @@
|
|
|
1
|
-
featrixsphere/__init__.py,sha256=tnkl4O62quDkp8n4aXzKQu1ELKyF5yfqpgmcjA9-rnY,2190
|
|
2
|
-
featrixsphere/client.py,sha256=mxUcqkhMkAsYpwJjdamQdXB_wRlnV-TUCU71xA289tA,451235
|
|
3
|
-
featrixsphere/api/__init__.py,sha256=quyvuPphVj9wb6v8Dio0SMG9iHgJAmY3asHk3f_zF10,1269
|
|
4
|
-
featrixsphere/api/api_endpoint.py,sha256=i3eCWuaUXftnH1Ai6MFZ7md7pC2FcRAIRO87CBZhyEQ,9000
|
|
5
|
-
featrixsphere/api/client.py,sha256=TdpujNsJxO4GfPMI_KoemQWV89go3KuK6OPAo9jX6Bs,12574
|
|
6
|
-
featrixsphere/api/foundational_model.py,sha256=ZF5wKMs6SfsNC3XYYXgbRMhnrtmLe6NeckjCCiH0fK0,21628
|
|
7
|
-
featrixsphere/api/http_client.py,sha256=TsOQHHNTDFGAR3mdHevj-0wy1-hPtgHXKe8Egiz5FVo,7269
|
|
8
|
-
featrixsphere/api/notebook_helper.py,sha256=xY9jsao26eaNiFh2s0_TlRZnR8xZ4P_e0EOKr2PtoVs,20060
|
|
9
|
-
featrixsphere/api/prediction_result.py,sha256=Tx7LXzF4XT-U3VqAN_IFc5DvxPnygc78M2usrD-yMu4,7521
|
|
10
|
-
featrixsphere/api/predictor.py,sha256=-vwCKpCfTgZKqzpDnzy1iYZQ-1-MGW8aErvxM9trktw,17652
|
|
11
|
-
featrixsphere/api/reference_record.py,sha256=-XOTF6ynznB3ouz06w3AF8X9SVId0g_dO20VvGNesUQ,7095
|
|
12
|
-
featrixsphere/api/vector_database.py,sha256=BplxKkPnAbcBX1A4KxFBJVb3qkQ-FH9zi9v2dWG5CgY,7976
|
|
13
|
-
featrixsphere-0.2.6127.dist-info/METADATA,sha256=Ew2yMa6rOSrsv1bSq38HXJogU0O5bliojWkbyQYGWV0,16232
|
|
14
|
-
featrixsphere-0.2.6127.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
15
|
-
featrixsphere-0.2.6127.dist-info/entry_points.txt,sha256=QreJeYfD_VWvbEqPmMXZ3pqqlFlJ1qZb-NtqnyhEldc,51
|
|
16
|
-
featrixsphere-0.2.6127.dist-info/top_level.txt,sha256=AyN4wjfzlD0hWnDieuEHX0KckphIk_aC73XCG4df5uU,14
|
|
17
|
-
featrixsphere-0.2.6127.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|