featrixsphere 0.2.6379__py3-none-any.whl → 0.2.6710__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
featrixsphere/__init__.py CHANGED
@@ -57,7 +57,7 @@ TWO API OPTIONS:
57
57
  >>> print(result['prediction'])
58
58
  """
59
59
 
60
- __version__ = "0.2.6379"
60
+ __version__ = "0.2.6710"
61
61
  __author__ = "Featrix"
62
62
  __email__ = "support@featrix.com"
63
63
  __license__ = "MIT"
@@ -342,6 +342,34 @@ class FeatrixSphere(HTTPClientMixin):
342
342
  ground_truth=ground_truth
343
343
  )
344
344
 
345
def list_sessions(
    self,
    name_prefix: str = "",
) -> List[str]:
    """
    List sessions whose names match a search term.

    The term is sent to the ``/compute/sessions-for-org`` endpoint, which
    searches session directory names on the compute cluster for partial
    matches (not just prefixes), despite the parameter's name.

    Args:
        name_prefix: Term to match in session names. When empty, no filter
            is sent with the request.

    Returns:
        List of matching session ID strings (empty when nothing matches or
        the server reports no ``sessions`` value).

    Example:
        sessions = featrix.list_sessions(name_prefix="customer")
        for sid in sessions:
            fm = featrix.foundational_model(sid)
            print(f"{sid}: {fm.status}")
    """
    params = {'name_prefix': name_prefix} if name_prefix else {}
    response = self._get_json("/compute/sessions-for-org", params=params)
    # A present-but-null "sessions" value must not leak None to callers that
    # iterate over the result, so normalise every falsy value to a list.
    return response.get('sessions') or []
372
+
345
373
  def health_check(self) -> Dict[str, Any]:
346
374
  """
347
375
  Check if the API server is healthy.
@@ -22,6 +22,21 @@ from .reference_record import ReferenceRecord
22
22
  logger = logging.getLogger(__name__)
23
23
 
24
24
 
25
+ def _parse_datetime(value) -> Optional[datetime]:
26
+ """Parse a datetime from ISO string or return as-is if already datetime."""
27
+ if value is None:
28
+ return None
29
+ if isinstance(value, datetime):
30
+ return value
31
+ if isinstance(value, str):
32
+ try:
33
+ # Handle ISO format with or without timezone
34
+ return datetime.fromisoformat(value.replace('Z', '+00:00'))
35
+ except (ValueError, AttributeError):
36
+ return None
37
+ return None
38
+
39
+
25
40
  @dataclass
26
41
  class FoundationalModel:
27
42
  """
@@ -68,6 +83,11 @@ class FoundationalModel:
68
83
  epochs: Optional[int] = None
69
84
  final_loss: Optional[float] = None
70
85
  created_at: Optional[datetime] = None
86
+ updated_at: Optional[datetime] = None
87
+ session_type: Optional[str] = None
88
+ compute_cluster: Optional[str] = None
89
+ error_message: Optional[str] = None
90
+ training_progress: Optional[Dict[str, Any]] = None
71
91
 
72
92
  # Internal
73
93
  _ctx: Optional['ClientContext'] = field(default=None, repr=False)
@@ -88,7 +108,9 @@ class FoundationalModel:
88
108
  dimensions=response.get('d_model') or response.get('dimensions'),
89
109
  epochs=response.get('epochs') or response.get('final_epoch'),
90
110
  final_loss=response.get('final_loss'),
91
- created_at=datetime.now(),
111
+ created_at=_parse_datetime(response.get('created_at')),
112
+ session_type=response.get('session_type'),
113
+ compute_cluster=response.get('compute_cluster'),
92
114
  _ctx=ctx,
93
115
  )
94
116
 
@@ -99,19 +121,22 @@ class FoundationalModel:
99
121
  ctx: 'ClientContext'
100
122
  ) -> 'FoundationalModel':
101
123
  """Load FoundationalModel from session ID."""
102
- # Get session info
103
- session_data = ctx.get_json(f"/compute/session/{session_id}")
124
+ # Get session info - response has {"session": {...}, "jobs": {...}}
125
+ response_data = ctx.get_json(f"/compute/session/{session_id}")
126
+ session = response_data.get('session', response_data)
104
127
 
105
128
  fm = cls(
106
129
  id=session_id,
107
- name=session_data.get('name'),
108
- status=session_data.get('status'),
109
- created_at=datetime.now(),
130
+ name=session.get('name'),
131
+ status=session.get('status'),
132
+ created_at=_parse_datetime(session.get('created_at')),
133
+ session_type=session.get('session_type'),
134
+ compute_cluster=session.get('compute_cluster'),
110
135
  _ctx=ctx,
111
136
  )
112
137
 
113
- # Try to get model info
114
- fm._update_from_session(session_data)
138
+ # Extract model info, training stats, jobs, error_message
139
+ fm._update_from_session(response_data)
115
140
 
116
141
  return fm
117
142
 
@@ -439,10 +464,11 @@ class FoundationalModel:
439
464
  last_status = None
440
465
 
441
466
  while time.time() - start_time < max_wait_time:
442
- # Get session status
443
- session_data = self._ctx.get_json(f"/compute/session/{self.id}")
467
+ # Get session status - response has {"session": {...}, "jobs": {...}}
468
+ response_data = self._ctx.get_json(f"/compute/session/{self.id}")
469
+ session_data = response_data.get('session', response_data)
444
470
  status = session_data.get('status', 'unknown')
445
- jobs = session_data.get('jobs', {})
471
+ jobs = response_data.get('jobs', {})
446
472
 
447
473
  # Look for ES training job
448
474
  es_job = None
@@ -475,7 +501,7 @@ class FoundationalModel:
475
501
  # Check completion
476
502
  if job_status == 'done' or status == 'done':
477
503
  self.status = 'done'
478
- self._update_from_session(session_data)
504
+ self._update_from_session(response_data)
479
505
  if show_progress:
480
506
  print(f"Training complete!")
481
507
  if self.dimensions:
@@ -566,6 +592,28 @@ class FoundationalModel:
566
592
 
567
593
  return self._ctx.get_json(f"/session/{self.id}/projections")
568
594
 
595
def get_sphere_preview(self, save_path: Optional[str] = None) -> bytes:
    """
    Get the 2D sphere projection preview image (PNG).

    Args:
        save_path: Optional path to save the PNG file. If provided (and
            non-empty), the image bytes are written to this path.

    Returns:
        Raw image bytes as returned by the ``/session/{id}/preview``
        endpoint.

    Raises:
        ValueError: If this model is not connected to a client.
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    png_bytes = self._ctx.get_bytes(f"/session/{self.id}/preview")

    if save_path:
        # Binary mode: the payload is image data, never text.
        with open(save_path, 'wb') as f:
            f.write(png_bytes)

    return png_bytes
616
+
569
617
  def get_training_metrics(self) -> Dict[str, Any]:
570
618
  """Get training metrics and history."""
571
619
  if not self._ctx:
@@ -605,27 +653,73 @@ class FoundationalModel:
605
653
 
606
654
  return predictors
607
655
 
608
- def _update_from_session(self, session_data: Dict[str, Any]) -> None:
609
- """Update fields from session data."""
610
- # Try to get model info from various places
611
- model_info = session_data.get('model_info', {})
612
- training_stats = session_data.get('training_stats', {})
656
+ def _update_from_session(self, response_data: Dict[str, Any]) -> None:
657
+ """Update fields from session API response.
658
+
659
+ The response from GET /session/{id} has structure:
660
+ {"session": {...}, "jobs": {...}, ...}
661
+ """
662
+ # Handle both nested and flat response formats
663
+ session = response_data.get('session', response_data)
664
+ jobs = response_data.get('jobs', {})
665
+
666
+ # Core session fields
667
+ if session.get('name') and not self.name:
668
+ self.name = session['name']
669
+ if session.get('status'):
670
+ self.status = session['status']
671
+ if session.get('session_type'):
672
+ self.session_type = session['session_type']
673
+ if session.get('compute_cluster'):
674
+ self.compute_cluster = session['compute_cluster']
675
+ if session.get('created_at') and not self.created_at:
676
+ self.created_at = _parse_datetime(session['created_at'])
677
+ if session.get('finished_at'):
678
+ self.updated_at = _parse_datetime(session['finished_at'])
679
+ elif session.get('started_at'):
680
+ self.updated_at = _parse_datetime(session['started_at'])
681
+
682
+ # Model info from session
683
+ model_info = session.get('model_info', {})
684
+ training_stats = session.get('training_stats', {})
613
685
 
614
686
  self.dimensions = (
615
687
  model_info.get('d_model') or
616
688
  model_info.get('embedding_dim') or
617
- session_data.get('d_model')
689
+ session.get('d_model')
618
690
  )
619
691
  self.epochs = (
620
692
  training_stats.get('final_epoch') or
621
693
  training_stats.get('epochs_trained') or
622
- session_data.get('epochs')
694
+ session.get('epochs')
623
695
  )
624
696
  self.final_loss = (
625
697
  training_stats.get('final_loss') or
626
- session_data.get('final_loss')
698
+ session.get('final_loss')
627
699
  )
628
700
 
701
+ # Extract error_message and training_progress from jobs
702
+ for job_id, job in jobs.items():
703
+ job_type = job.get('job_type', '')
704
+ job_status = job.get('status', '')
705
+
706
+ # Training progress from ES training job
707
+ if job_type in ('train_embedding_space', 'train_es', 'training'):
708
+ current_epoch = job.get('current_epoch') or job.get('epoch')
709
+ total_epochs = job.get('total_epochs') or job.get('epochs')
710
+ if current_epoch or total_epochs:
711
+ self.training_progress = {
712
+ 'current_epoch': current_epoch,
713
+ 'total_epochs': total_epochs,
714
+ 'job_status': job_status,
715
+ }
716
+
717
+ # Error message from any failed job
718
+ if job_status in ('failed', 'error'):
719
+ err = job.get('error') or job.get('error_message')
720
+ if err:
721
+ self.error_message = err
722
+
629
723
  def _clean_record(self, record: Dict[str, Any]) -> Dict[str, Any]:
630
724
  """Clean a record for API submission."""
631
725
  import math
@@ -640,6 +734,281 @@ class FoundationalModel:
640
734
  cleaned[key] = value
641
735
  return cleaned
642
736
 
737
def get_columns(self) -> List[str]:
    """
    Get the column names in this foundational model's embedding space.

    Reads ``/compute/session/{id}/columns``. The names are taken from the
    ``columns`` key, falling back to ``column_names`` (the key the
    ``schema_metadata`` property documents for the same endpoint).

    Returns:
        List of column name strings (empty if the server reports none)

    Raises:
        ValueError: If this model is not connected to a client.

    Example:
        columns = fm.get_columns()
        print(columns)  # ['age', 'income', 'city', ...]
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    response = self._ctx.get_json(f"/compute/session/{self.id}/columns")
    return response.get('columns') or response.get('column_names') or []
753
+
754
@property
def columns(self) -> List[str]:
    """Column names in this model's embedding space (see get_columns)."""
    # Delegates to get_columns() on every access.
    names = self.get_columns()
    return names
758
+
759
@property
def schema_metadata(self) -> Dict[str, Any]:
    """Schema metadata for this model, including column names and types.

    Returns:
        Dict with 'column_names', 'column_types', and 'num_columns'

    Raises:
        ValueError: If this model is not connected to a client.
    """
    ctx = self._ctx
    if not ctx:
        raise ValueError("FoundationalModel not connected to client")
    endpoint = f"/compute/session/{self.id}/columns"
    return ctx.get_json(endpoint)
769
+
770
def clone(
    self,
    target_compute_cluster: Optional[str] = None,
    new_name: Optional[str] = None,
    source_compute_cluster: Optional[str] = None,
) -> 'FoundationalModel':
    """
    Clone this embedding space, optionally to a different compute node.

    Args:
        target_compute_cluster: Target compute cluster (None = same node)
        new_name: Name for the cloned session
        source_compute_cluster: Source compute cluster (if routing needed).
            NOTE(review): accepted for API compatibility but not currently
            included in the request -- confirm the expected wire field
            before forwarding it.

    Returns:
        New FoundationalModel instance for the cloned embedding space

    Raises:
        ValueError: If this model is not connected to a client.
        RuntimeError: If the server response carries no ``new_session_id``.

    Example:
        cloned = fm.clone(
            target_compute_cluster="churro",
            new_name="my-model-clone"
        )
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    data = {
        "to_compute": target_compute_cluster,
        "new_session_name": new_name,
    }

    response = self._ctx.post_json(
        f"/compute/session/{self.id}/clone_embedding_space",
        data=data
    )

    new_session_id = response.get('new_session_id')
    if not new_session_id:
        # Without an ID every later call on the clone would address
        # "/compute/session//..." -- fail here instead.
        raise RuntimeError(
            f"Clone of session {self.id} did not return a new_session_id"
        )

    return FoundationalModel(
        id=new_session_id,
        name=new_name,
        status="done",
        created_at=datetime.now(),
        # None for the target means "same node", i.e. this model's cluster.
        compute_cluster=target_compute_cluster or self.compute_cluster,
        _ctx=self._ctx,
    )
814
+
815
def refresh(self) -> Dict[str, Any]:
    """
    Re-fetch this foundational model's session from the server.

    The response is applied to the local attributes through
    ``_update_from_session`` and then handed back to the caller.

    Returns:
        Full model info dictionary from the server

    Raises:
        ValueError: If this model is not connected to a client.

    Example:
        info = fm.refresh()
        print(fm.status)  # Updated from server
        print(fm.epochs)  # Updated from server
    """
    ctx = self._ctx
    if not ctx:
        raise ValueError("FoundationalModel not connected to client")

    server_info = ctx.get_json(f"/compute/session/{self.id}")
    self._update_from_session(server_info)
    return server_info
836
+
837
def is_ready(self) -> bool:
    """
    Check if this foundational model has finished training and is ready for use.

    Refreshes local state from the server via ``refresh()`` and then checks
    the resulting status.

    Returns:
        True if the refreshed status is 'done', False otherwise

    Raises:
        ValueError: If this model is not connected to a client (raised by
            ``refresh()``).

    Example:
        if fm.is_ready():
            predictor = fm.create_classifier(target_column="target")
    """
    # refresh() owns the fetch-and-update sequence (including the
    # connection check), so it is reused rather than duplicated here.
    self.refresh()
    return self.status == 'done'
854
+
855
def publish(
    self,
    org_id: str,
    name: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Publish this foundational model to the production directory.

    Published models are protected from garbage collection and available
    across all compute nodes via the shared backplane.

    Args:
        org_id: Organization ID for directory organization (required)
        name: Name for the published model (defaults to self.name)

    Returns:
        dict with published_path, output_path, and status

    Raises:
        ValueError: If the model is not connected to a client, ``org_id``
            is empty, or no name is available.

    Example:
        fm = featrix.create_foundational_model(name="my_model", csv_file="data.csv")
        fm.wait_for_training()
        fm.publish(org_id="my_org", name="my_model_v1")
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    # Same up-front check publish_checkpoint() applies, so an empty org is
    # rejected locally instead of being sent to the server.
    if not org_id:
        raise ValueError("org_id is required")

    publish_name = name or self.name
    if not publish_name:
        raise ValueError("name is required (either pass it or set it on the model)")

    data = {
        "org_id": org_id,
        "name": publish_name,
    }
    return self._ctx.post_json(f"/compute/session/{self.id}/publish", data=data)
890
+
891
def deprecate(
    self,
    warning_message: str,
    expiration_date: str,
) -> Dict[str, Any]:
    """
    Deprecate this published model with a warning and expiration date.

    The model remains available until the expiration date. Prediction
    responses will include a model_expiration field warning consumers.

    Args:
        warning_message: Warning message to display
        expiration_date: ISO format date string (e.g., "2026-06-01T00:00:00Z")

    Returns:
        dict with deprecation status

    Raises:
        ValueError: If the model is not connected to a client or
            ``expiration_date`` is empty.

    Example:
        from datetime import datetime, timedelta, timezone
        # Build the timestamp from an aware UTC datetime: appending "Z" to
        # a naive local time would mislabel local time as UTC.
        expires = datetime.now(timezone.utc) + timedelta(days=90)
        expiration = expires.strftime("%Y-%m-%dT%H:%M:%SZ")
        fm.deprecate(
            warning_message="Replaced by v2. Migrate by expiration.",
            expiration_date=expiration
        )
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    if not expiration_date:
        raise ValueError("expiration_date is required")

    data = {
        "warning_message": warning_message,
        "expiration_date": expiration_date,
    }
    return self._ctx.post_json(f"/compute/session/{self.id}/deprecate", data=data)
925
+
926
def unpublish(self) -> Dict[str, Any]:
    """
    Move this model back out of the published directory.

    WARNING: After unpublishing, the model is subject to garbage
    collection and may be deleted when disk space is low.

    Returns:
        dict with unpublish status

    Raises:
        ValueError: If this model is not connected to a client.

    Example:
        fm.unpublish()
    """
    ctx = self._ctx
    if not ctx:
        raise ValueError("FoundationalModel not connected to client")

    # The endpoint takes no arguments; an empty JSON body is posted.
    endpoint = f"/compute/session/{self.id}/unpublish"
    return ctx.post_json(endpoint, data={})
943
+
944
def publish_checkpoint(
    self,
    name: str,
    org_id: Optional[str] = None,
    checkpoint_epoch: Optional[int] = None,
    session_name_prefix: Optional[str] = None,
    publish: bool = True,
) -> 'FoundationalModel':
    """
    Publish a checkpoint from this model's training as a new foundation model.

    Creates a NEW FoundationalModel from a training checkpoint with full
    provenance tracking. Useful for snapshotting good intermediate models
    while training continues.

    Args:
        name: Name for the new foundation model (required)
        org_id: Organization ID (required if publish=True)
        checkpoint_epoch: Which epoch checkpoint to use (None = best/latest)
        session_name_prefix: Optional prefix for the new session ID
        publish: Move to published directory (default: True)

    Returns:
        New FoundationalModel instance for the published checkpoint

    Raises:
        ValueError: If the model is not connected to a client, or
            ``publish`` is True and ``org_id`` is missing.
        RuntimeError: If the server response carries no
            ``foundation_session_id``.

    Example:
        # Snapshot epoch 50 while training continues
        checkpoint_fm = fm.publish_checkpoint(
            name="My Model v0.5",
            org_id="my_org",
            checkpoint_epoch=50
        )
        # Use immediately
        predictor = checkpoint_fm.create_classifier(target_column="target")
    """
    if not self._ctx:
        raise ValueError("FoundationalModel not connected to client")

    if publish and not org_id:
        raise ValueError("org_id is required when publish=True")

    data = {
        "name": name,
        "publish": publish,
    }
    # Optional fields are only sent when the caller supplied them. Epoch 0
    # is a valid checkpoint, hence the explicit None comparison.
    if checkpoint_epoch is not None:
        data["checkpoint_epoch"] = checkpoint_epoch
    if session_name_prefix:
        data["session_name_prefix"] = session_name_prefix
    if org_id:
        data["org_id"] = org_id

    response = self._ctx.post_json(
        f"/compute/session/{self.id}/publish_partial_foundation",
        data=data
    )

    foundation_session_id = response.get("foundation_session_id")
    if not foundation_session_id:
        # A model with an empty ID would silently address the wrong
        # endpoints on every later call, so fail here instead.
        raise RuntimeError(
            f"Publishing a checkpoint of session {self.id} did not return "
            "a foundation_session_id"
        )

    return FoundationalModel(
        id=foundation_session_id,
        name=name,
        status="done",
        epochs=response.get("checkpoint_epoch"),
        created_at=datetime.now(),
        _ctx=self._ctx,
    )
1011
+
643
1012
  def to_dict(self) -> Dict[str, Any]:
644
1013
  """Convert to dictionary representation."""
645
1014
  return {
@@ -650,6 +1019,11 @@ class FoundationalModel:
650
1019
  'epochs': self.epochs,
651
1020
  'final_loss': self.final_loss,
652
1021
  'created_at': self.created_at.isoformat() if self.created_at else None,
1022
+ 'updated_at': self.updated_at.isoformat() if self.updated_at else None,
1023
+ 'session_type': self.session_type,
1024
+ 'compute_cluster': self.compute_cluster,
1025
+ 'error_message': self.error_message,
1026
+ 'training_progress': self.training_progress,
653
1027
  }
654
1028
 
655
1029
  def __repr__(self) -> str:
@@ -121,14 +121,34 @@ class HTTPClientMixin:
121
121
 
122
122
  def _unwrap_response(self, response_json: Dict[str, Any]) -> Dict[str, Any]:
123
123
  """
124
- Unwrap server response, handling the 'response' wrapper if present.
124
+ Unwrap server response, handling wrapper formats.
125
125
 
126
- The server sometimes wraps responses in {"response": {...}}.
126
+ The server may wrap responses as:
127
+ - {"response": {...}}
128
+ - {"_meta": {...}, "data": {...}}
129
+
130
+ Captures server metadata when present.
127
131
  """
128
- if isinstance(response_json, dict) and 'response' in response_json and len(response_json) == 1:
129
- return response_json['response']
132
+ if isinstance(response_json, dict):
133
+ # Handle _meta/data wrapper (captures server metadata)
134
+ if '_meta' in response_json and 'data' in response_json:
135
+ self._last_server_metadata = response_json['_meta']
136
+ return response_json['data']
137
+ # Handle response wrapper
138
+ if 'response' in response_json and len(response_json) == 1:
139
+ return response_json['response']
130
140
  return response_json
131
141
 
142
@property
def last_server_metadata(self) -> Optional[Dict[str, Any]]:
    """
    Server metadata captured from the most recent response that carried it.

    Contains server info like compute_cluster_time, compute_cluster,
    compute_cluster_version, etc. None until such a response has been seen.
    """
    try:
        return self._last_server_metadata
    except AttributeError:
        # No response with metadata has been unwrapped on this instance.
        return None
151
+
132
152
  def _get_json(
133
153
  self,
134
154
  endpoint: str,
@@ -139,6 +159,16 @@ class HTTPClientMixin:
139
159
  response = self._make_request("GET", endpoint, max_retries=max_retries, **kwargs)
140
160
  return self._unwrap_response(response.json())
141
161
 
162
def _get_bytes(
    self,
    endpoint: str,
    max_retries: Optional[int] = None,
    **kwargs
) -> bytes:
    """Issue a GET request and return the raw body (for binary content like images)."""
    http_response = self._make_request(
        "GET",
        endpoint,
        max_retries=max_retries,
        **kwargs
    )
    # The body is returned as-is: no JSON decoding or envelope unwrapping.
    raw_body = http_response.content
    return raw_body
171
+
142
172
  def _post_json(
143
173
  self,
144
174
  endpoint: str,
@@ -207,3 +237,6 @@ class ClientContext:
207
237
  def post_multipart(self, endpoint: str, data: Dict[str, Any] = None,
208
238
  files: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
209
239
  return self._client._post_multipart(endpoint, data, files, **kwargs)
240
+
241
def get_bytes(self, endpoint: str, **kwargs) -> bytes:
    """Fetch raw bytes from ``endpoint`` through the wrapped client."""
    client = self._client
    return client._get_bytes(endpoint, **kwargs)
@@ -23,11 +23,16 @@ class PredictionResult:
23
23
  prediction: Raw prediction result (class probabilities or numeric value)
24
24
  predicted_class: Predicted class name (for classification)
25
25
  confidence: Confidence score (for classification)
26
+ probabilities: Full probability distribution (for classification)
27
+ threshold: Classification threshold (for binary classification)
26
28
  query_record: Original input record
27
29
  predictor_id: ID of predictor that made this prediction
28
30
  session_id: Session ID (internal)
29
31
  timestamp: When prediction was made
30
32
  target_column: Target column name
33
+ guardrails: Per-column guardrail warnings (if any)
34
+ ignored_query_columns: Columns in input that were not used (not in training data)
35
+ available_query_columns: All columns the model knows about
31
36
 
32
37
  Usage:
33
38
  result = predictor.predict({"age": 35, "income": 50000})
@@ -35,6 +40,14 @@ class PredictionResult:
35
40
  print(result.confidence) # 0.87
36
41
  print(result.prediction_uuid) # UUID for feedback
37
42
 
43
+ # Check for guardrail warnings
44
+ if result.guardrails:
45
+ print(f"Warnings: {len(result.guardrails)} columns with issues")
46
+
47
+ # Check for ignored columns
48
+ if result.ignored_query_columns:
49
+ print(f"Ignored: {result.ignored_query_columns}")
50
+
38
51
  # Send feedback if prediction was wrong
39
52
  if result.predicted_class != actual_label:
40
53
  feedback = result.send_feedback(ground_truth=actual_label)
@@ -44,14 +57,52 @@ class PredictionResult:
44
57
  prediction_uuid: Optional[str] = None
45
58
  prediction: Optional[Union[Dict[str, float], float]] = None
46
59
  predicted_class: Optional[str] = None
60
+ probability: Optional[float] = None
47
61
  confidence: Optional[float] = None
62
+ probabilities: Optional[Dict[str, float]] = None
63
+ threshold: Optional[float] = None
48
64
  query_record: Optional[Dict[str, Any]] = None
65
+
66
+ # Documentation fields - explain what prediction/probability/confidence mean
67
+ readme_prediction: str = field(default="The predicted class label (for classification) or value (for regression).")
68
+ readme_probability: str = field(default="Raw probability of the predicted class from the model's softmax output.")
69
+ readme_confidence: str = field(default=(
70
+ "For binary classification: normalized margin from threshold. "
71
+ "confidence = (prob - threshold) / (1 - threshold) if predicting positive, "
72
+ "or (threshold - prob) / threshold if predicting negative. "
73
+ "Ranges from 0 (at decision boundary) to 1 (maximally certain). "
74
+ "For multi-class: same as probability."
75
+ ))
76
+ readme_threshold: str = field(default=(
77
+ "Decision boundary for binary classification. "
78
+ "If P(positive_class) >= threshold, predict positive; otherwise predict negative. "
79
+ "Calibrated to optimize F1 score on validation data."
80
+ ))
81
+ readme_probabilities: str = field(default=(
82
+ "Full probability distribution across all classes from the model's softmax output. "
83
+ "Dictionary mapping class labels to their probabilities (sum to 1.0)."
84
+ ))
85
+ readme_pos_label: str = field(default=(
86
+ "The class label considered 'positive' for binary classification metrics. "
87
+ "Threshold and confidence calculations are relative to this class."
88
+ ))
49
89
  predictor_id: Optional[str] = None
50
90
  session_id: Optional[str] = None
51
91
  target_column: Optional[str] = None
52
92
  timestamp: Optional[datetime] = None
53
93
  model_version: Optional[str] = None
54
94
 
95
+ # Checkpoint info from the model (epoch, metric_type, metric_value)
96
+ checkpoint_info: Optional[Dict[str, Any]] = None
97
+
98
+ # Guardrails and warnings
99
+ guardrails: Optional[Dict[str, Any]] = None
100
+ ignored_query_columns: Optional[list] = None
101
+ available_query_columns: Optional[list] = None
102
+
103
+ # Feature importance (from leave-one-out ablation)
104
+ feature_importance: Optional[Dict[str, float]] = None
105
+
55
106
  # Internal: client context for sending feedback
56
107
  _ctx: Optional['ClientContext'] = field(default=None, repr=False)
57
108
 
@@ -73,29 +124,44 @@ class PredictionResult:
73
124
  Returns:
74
125
  PredictionResult instance
75
126
  """
76
- # Extract prediction data
127
+ # Extract prediction data - handle both formats
128
+ # New format: prediction is the class label, probabilities is separate
129
+ # Old format: prediction is the probabilities dict
77
130
  prediction = response.get('prediction')
78
- predicted_class = None
79
- confidence = None
80
-
81
- # For classification, extract class and confidence
82
- if isinstance(prediction, dict):
83
- # Find the class with highest probability
131
+ probabilities = response.get('probabilities')
132
+ predicted_class = response.get('predicted_class')
133
+ probability = response.get('probability')
134
+ confidence = response.get('confidence')
135
+
136
+ # For old format where prediction is the probabilities dict
137
+ if isinstance(prediction, dict) and not probabilities:
138
+ probabilities = prediction
84
139
  if prediction:
85
140
  predicted_class = max(prediction.keys(), key=lambda k: prediction[k])
86
- confidence = prediction[predicted_class]
141
+ probability = prediction[predicted_class]
142
+ confidence = probability # Old format: confidence = probability
143
+ elif isinstance(prediction, str) and not predicted_class:
144
+ # New format: prediction is already the class label
145
+ predicted_class = prediction
87
146
 
88
147
  return cls(
89
148
  prediction_uuid=response.get('prediction_uuid') or response.get('prediction_id'),
90
149
  prediction=prediction,
91
150
  predicted_class=predicted_class,
151
+ probability=probability,
92
152
  confidence=confidence,
153
+ probabilities=probabilities,
154
+ threshold=response.get('threshold'),
93
155
  query_record=query_record,
94
156
  predictor_id=response.get('predictor_id'),
95
157
  session_id=response.get('session_id'),
96
158
  target_column=response.get('target_column'),
97
159
  timestamp=datetime.now(),
98
160
  model_version=response.get('model_version'),
161
+ checkpoint_info=response.get('checkpoint_info'),
162
+ guardrails=response.get('guardrails'),
163
+ ignored_query_columns=response.get('ignored_query_columns'),
164
+ available_query_columns=response.get('available_query_columns'),
99
165
  _ctx=ctx,
100
166
  )
101
167
 
@@ -126,18 +192,41 @@ class PredictionResult:
126
192
 
127
193
  def to_dict(self) -> Dict[str, Any]:
128
194
  """Convert to dictionary representation."""
129
- return {
195
+ result = {
130
196
  'prediction_uuid': self.prediction_uuid,
131
197
  'prediction': self.prediction,
132
198
  'predicted_class': self.predicted_class,
199
+ 'probability': self.probability,
133
200
  'confidence': self.confidence,
201
+ 'probabilities': self.probabilities,
202
+ 'threshold': self.threshold,
134
203
  'query_record': self.query_record,
135
204
  'predictor_id': self.predictor_id,
136
205
  'session_id': self.session_id,
137
206
  'target_column': self.target_column,
138
207
  'timestamp': self.timestamp.isoformat() if self.timestamp else None,
139
208
  'model_version': self.model_version,
209
+ # Documentation
210
+ 'readme_prediction': self.readme_prediction,
211
+ 'readme_probability': self.readme_probability,
212
+ 'readme_confidence': self.readme_confidence,
213
+ 'readme_threshold': self.readme_threshold,
214
+ 'readme_probabilities': self.readme_probabilities,
215
+ 'readme_pos_label': self.readme_pos_label,
140
216
  }
217
+ # Include checkpoint_info if present
218
+ if self.checkpoint_info:
219
+ result['checkpoint_info'] = self.checkpoint_info
220
+ # Include guardrails if present
221
+ if self.guardrails:
222
+ result['guardrails'] = self.guardrails
223
+ if self.ignored_query_columns:
224
+ result['ignored_query_columns'] = self.ignored_query_columns
225
+ if self.available_query_columns:
226
+ result['available_query_columns'] = self.available_query_columns
227
+ if self.feature_importance:
228
+ result['feature_importance'] = self.feature_importance
229
+ return result
141
230
 
142
231
 
143
232
  @dataclass
@@ -105,7 +105,8 @@ class Predictor:
105
105
  def predict(
106
106
  self,
107
107
  record: Dict[str, Any],
108
- best_metric_preference: Optional[str] = None
108
+ best_metric_preference: Optional[str] = None,
109
+ feature_importance: bool = False
109
110
  ) -> PredictionResult:
110
111
  """
111
112
  Make a single prediction.
@@ -113,15 +114,21 @@ class Predictor:
113
114
  Args:
114
115
  record: Input record dictionary
115
116
  best_metric_preference: Metric checkpoint to use ("roc_auc", "pr_auc", or None)
117
+ feature_importance: If True, compute feature importance via leave-one-out ablation
116
118
 
117
119
  Returns:
118
- PredictionResult with prediction, confidence, and prediction_uuid
120
+ PredictionResult with prediction, confidence, and prediction_uuid.
121
+ If feature_importance=True, also includes feature_importance dict.
119
122
 
120
123
  Example:
121
124
  result = predictor.predict({"age": 35, "income": 50000})
122
125
  print(result.predicted_class) # "churned"
123
126
  print(result.confidence) # 0.87
124
127
  print(result.prediction_uuid) # UUID for feedback
128
+
129
+ # With feature importance
130
+ result = predictor.predict(record, feature_importance=True)
131
+ print(result.feature_importance) # {"income": 0.15, "age": 0.08, ...}
125
132
  """
126
133
  if not self._ctx:
127
134
  raise ValueError("Predictor not connected to client")
@@ -129,7 +136,44 @@ class Predictor:
129
136
  # Clean the record
130
137
  cleaned_record = self._clean_record(record)
131
138
 
132
- # Build request
139
+ if feature_importance:
140
+ # Build N+1 records: original + each feature nulled out
141
+ columns = list(cleaned_record.keys())
142
+ batch = [cleaned_record] # Original first
143
+
144
+ for col in columns:
145
+ ablated = cleaned_record.copy()
146
+ ablated[col] = None
147
+ batch.append(ablated)
148
+
149
+ # Single batch call
150
+ results = self.batch_predict(
151
+ batch,
152
+ show_progress=False,
153
+ best_metric_preference=best_metric_preference
154
+ )
155
+
156
+ # Compare: importance = |original_confidence - ablated_confidence|
157
+ original = results[0]
158
+ importance = {}
159
+ original_conf = original.confidence or 0.0
160
+
161
+ for i, col in enumerate(columns):
162
+ ablated_result = results[i + 1]
163
+ ablated_conf = ablated_result.confidence or 0.0
164
+ # Higher delta = more important
165
+ delta = abs(original_conf - ablated_conf)
166
+ importance[col] = round(delta, 4)
167
+
168
+ # Sort by importance (highest first)
169
+ original.feature_importance = dict(sorted(
170
+ importance.items(),
171
+ key=lambda x: x[1],
172
+ reverse=True
173
+ ))
174
+ return original
175
+
176
+ # Standard single prediction
133
177
  request_payload = {
134
178
  "query_record": cleaned_record,
135
179
  "predictor_id": self.id,
@@ -196,6 +240,36 @@ class Predictor:
196
240
 
197
241
  return results
198
242
 
243
+ def predict_csv_file(
244
+ self,
245
+ csv_path: str,
246
+ show_progress: bool = True,
247
+ best_metric_preference: Optional[str] = None
248
+ ) -> List[PredictionResult]:
249
+ """
250
+ Load a CSV file and run batch predictions on all rows.
251
+
252
+ Args:
253
+ csv_path: Path to the CSV file
254
+ show_progress: Show progress bar
255
+ best_metric_preference: Metric checkpoint to use
256
+
257
+ Returns:
258
+ List of PredictionResult objects
259
+
260
+ Example:
261
+ results = predictor.predict_csv_file("test_data.csv")
262
+ for r in results:
263
+ print(r.predicted_class, r.confidence)
264
+ """
265
+ import pandas as pd
266
+ df = pd.read_csv(csv_path)
267
+ return self.batch_predict(
268
+ df,
269
+ show_progress=show_progress,
270
+ best_metric_preference=best_metric_preference
271
+ )
272
+
199
273
  def explain(
200
274
  self,
201
275
  record: Dict[str, Any],
featrixsphere/client.py CHANGED
@@ -2930,24 +2930,26 @@ class FeatrixSphereClient:
2930
2930
  # Single Predictor Functionality
2931
2931
  # =========================================================================
2932
2932
 
2933
- def predict(self, session_id: str, record: Dict[str, Any], target_column: str = None,
2933
+ def predict(self, session_id: str, record: Dict[str, Any], target_column: str = None,
2934
2934
  predictor_id: str = None, best_metric_preference: str = None,
2935
- max_retries: int = None, queue_batches: bool = False) -> Dict[str, Any]:
2935
+ checkpoint_epoch: int = None, max_retries: int = None,
2936
+ queue_batches: bool = False) -> Dict[str, Any]:
2936
2937
  """
2937
2938
  Make a single prediction for a record.
2938
-
2939
+
2939
2940
  Args:
2940
2941
  session_id: ID of session with trained predictor
2941
2942
  record: Record dictionary (without target column)
2942
2943
  target_column: Specific target column predictor to use (required if multiple predictors exist and predictor_id not specified)
2943
2944
  predictor_id: Specific predictor ID to use (recommended - more precise than target_column)
2944
2945
  best_metric_preference: Which metric checkpoint to use: "roc_auc", "pr_auc", or None (use default checkpoint) (default: None)
2946
+ checkpoint_epoch: Specific epoch checkpoint to use (e.g., 65 for epoch 65). Overrides best_metric_preference.
2945
2947
  max_retries: Number of retries for errors (default: uses client default)
2946
2948
  queue_batches: If True, queue this prediction for batch processing instead of immediate API call
2947
-
2949
+
2948
2950
  Returns:
2949
2951
  Prediction result dictionary if queue_batches=False, or queue ID if queue_batches=True
2950
-
2952
+
2951
2953
  Note:
2952
2954
  predictor_id is recommended over target_column for precision. If both are provided, predictor_id takes precedence.
2953
2955
  Use client.list_predictors(session_id) to see available predictor IDs.
@@ -2958,29 +2960,31 @@ class FeatrixSphereClient:
2958
2960
  if should_warn:
2959
2961
  call_count = len(self._prediction_call_times.get(session_id, []))
2960
2962
  self._show_batching_warning(session_id, call_count)
2961
-
2963
+
2962
2964
  # If queueing is enabled, add to queue and return queue ID
2963
2965
  if queue_batches:
2964
2966
  queue_id = self._add_to_prediction_queue(session_id, record, target_column, predictor_id)
2965
2967
  return {"queued": True, "queue_id": queue_id}
2966
-
2967
- # Clean NaN/Inf values
2968
+
2969
+ # Clean NaN/Inf values
2968
2970
  cleaned_record = self._clean_numpy_values(record)
2969
2971
  cleaned_record = self.replace_nans_with_nulls(cleaned_record)
2970
-
2972
+
2971
2973
  # Build request payload - let the server handle predictor resolution
2972
2974
  request_payload = {
2973
2975
  "query_record": cleaned_record,
2974
2976
  }
2975
-
2977
+
2976
2978
  # Include whatever the caller provided - server will figure it out
2977
2979
  if target_column:
2978
2980
  request_payload["target_column"] = target_column
2979
2981
  if predictor_id:
2980
2982
  request_payload["predictor_id"] = predictor_id
2981
- if best_metric_preference:
2983
+ if checkpoint_epoch is not None:
2984
+ request_payload["checkpoint_epoch"] = checkpoint_epoch
2985
+ elif best_metric_preference:
2982
2986
  request_payload["best_metric_preference"] = best_metric_preference
2983
-
2987
+
2984
2988
  # Just send it to the server - it has all the smart fallback logic
2985
2989
  response_data = self._post_json(f"/session/{session_id}/predict", request_payload, max_retries=max_retries)
2986
2990
  return response_data
@@ -6651,8 +6655,8 @@ class FeatrixSphereClient:
6651
6655
 
6652
6656
  def predict_table(self, session_id: str, table_data: Dict[str, Any],
6653
6657
  target_column: str = None, predictor_id: str = None,
6654
- best_metric_preference: str = None, max_retries: int = None,
6655
- trace: bool = False) -> Dict[str, Any]:
6658
+ best_metric_preference: str = None, checkpoint_epoch: int = None,
6659
+ max_retries: int = None, trace: bool = False) -> Dict[str, Any]:
6656
6660
  """
6657
6661
  Make batch predictions using JSON Tables format.
6658
6662
 
@@ -6661,6 +6665,8 @@ class FeatrixSphereClient:
6661
6665
  table_data: Data in JSON Tables format, or list of records, or dict with 'table'/'records'
6662
6666
  target_column: Specific target column predictor to use (required if multiple predictors exist)
6663
6667
  predictor_id: Specific predictor ID to use (recommended - more precise than target_column)
6668
+ best_metric_preference: Which metric checkpoint to use: "roc_auc", "pr_auc", or None (default)
6669
+ checkpoint_epoch: Specific epoch checkpoint to use (e.g., 65). Overrides best_metric_preference.
6664
6670
  max_retries: Number of retries for errors (default: uses client default, recommend higher for batch)
6665
6671
  trace: Enable detailed debug logging (default: False)
6666
6672
 
@@ -6713,7 +6719,9 @@ class FeatrixSphereClient:
6713
6719
  table_data['target_column'] = target_column
6714
6720
  if predictor_id:
6715
6721
  table_data['predictor_id'] = predictor_id
6716
- if best_metric_preference:
6722
+ if checkpoint_epoch is not None:
6723
+ table_data['checkpoint_epoch'] = checkpoint_epoch
6724
+ elif best_metric_preference:
6717
6725
  table_data['best_metric_preference'] = best_metric_preference
6718
6726
 
6719
6727
  if trace:
@@ -6747,7 +6755,8 @@ class FeatrixSphereClient:
6747
6755
  raise
6748
6756
 
6749
6757
  def predict_records(self, session_id: str, records: List[Dict[str, Any]],
6750
- target_column: str = None, predictor_id: str = None, best_metric_preference: str = None,
6758
+ target_column: str = None, predictor_id: str = None,
6759
+ best_metric_preference: str = None, checkpoint_epoch: int = None,
6751
6760
  batch_size: int = 2500, use_async: bool = False,
6752
6761
  show_progress_bar: bool = True, print_target_column_warning: bool = True,
6753
6762
  trace: bool = False) -> Dict[str, Any]:
@@ -6759,6 +6768,8 @@ class FeatrixSphereClient:
6759
6768
  records: List of record dictionaries
6760
6769
  target_column: Specific target column predictor to use (required if multiple predictors exist and predictor_id not specified)
6761
6770
  predictor_id: Specific predictor ID to use (recommended - more precise than target_column)
6771
+ best_metric_preference: Which metric checkpoint to use: "roc_auc", "pr_auc", or None (default)
6772
+ checkpoint_epoch: Specific epoch checkpoint to use (e.g., 65). Overrides best_metric_preference.
6762
6773
  batch_size: Number of records to send per API call (default: 2500)
6763
6774
  use_async: Force async processing for large datasets (default: False - async disabled due to pickle issues)
6764
6775
  show_progress_bar: Whether to show progress bar for async jobs (default: True)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: featrixsphere
3
- Version: 0.2.6379
3
+ Version: 0.2.6710
4
4
  Summary: Transform any CSV into a production-ready ML model in minutes, not months.
5
5
  Home-page: https://github.com/Featrix/sphere
6
6
  Author: Featrix
@@ -0,0 +1,17 @@
1
+ featrixsphere/__init__.py,sha256=QQglKOYv0bjonuO0-wOkeyyXMWhv3yK0_s5Uaap-GVk,2190
2
+ featrixsphere/client.py,sha256=JNYAHDFtxmhQVuclO3dWEphJuxNFLi-PfvWylZWKF_4,452929
3
+ featrixsphere/api/__init__.py,sha256=quyvuPphVj9wb6v8Dio0SMG9iHgJAmY3asHk3f_zF10,1269
4
+ featrixsphere/api/api_endpoint.py,sha256=i3eCWuaUXftnH1Ai6MFZ7md7pC2FcRAIRO87CBZhyEQ,9000
5
+ featrixsphere/api/client.py,sha256=TvNqrzSPQdw0A4kW48M0S3SDrBRmkc6kTY8UkzO4eRs,13426
6
+ featrixsphere/api/foundational_model.py,sha256=0ZFO-mJs66nVRXQbM0o1fB4HmhzLBXUqbTCF46LVH1k,34925
7
+ featrixsphere/api/http_client.py,sha256=q59-41fHua_7AwtPFCvshlSUKJ-fS0X337L9Ooyn0DI,8440
8
+ featrixsphere/api/notebook_helper.py,sha256=xY9jsao26eaNiFh2s0_TlRZnR8xZ4P_e0EOKr2PtoVs,20060
9
+ featrixsphere/api/prediction_result.py,sha256=HQsJdr89zWxdRx395nevN3aP7ZXZuZxB4UGX5Ykhkfk,12235
10
+ featrixsphere/api/predictor.py,sha256=1v0ffkEjmrO3BP0PNWAXtiAU-AlOQJSiDICmW1bQbGU,20300
11
+ featrixsphere/api/reference_record.py,sha256=-XOTF6ynznB3ouz06w3AF8X9SVId0g_dO20VvGNesUQ,7095
12
+ featrixsphere/api/vector_database.py,sha256=BplxKkPnAbcBX1A4KxFBJVb3qkQ-FH9zi9v2dWG5CgY,7976
13
+ featrixsphere-0.2.6710.dist-info/METADATA,sha256=_gDwyRsSfEa0thWd6IjhMeCEZF5dMc8BfLBL4J2b5HQ,16232
14
+ featrixsphere-0.2.6710.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
15
+ featrixsphere-0.2.6710.dist-info/entry_points.txt,sha256=QreJeYfD_VWvbEqPmMXZ3pqqlFlJ1qZb-NtqnyhEldc,51
16
+ featrixsphere-0.2.6710.dist-info/top_level.txt,sha256=AyN4wjfzlD0hWnDieuEHX0KckphIk_aC73XCG4df5uU,14
17
+ featrixsphere-0.2.6710.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.10.1)
2
+ Generator: setuptools (80.10.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,17 +0,0 @@
1
- featrixsphere/__init__.py,sha256=m4FTeSot2GaITV5l_kD5WrSPZKdKmVbcmRwXZE_nYJk,2190
2
- featrixsphere/client.py,sha256=Nj6C_Th4jyK7JQIXUJ_URok9AA0OND6DOAjoFbKhs2Q,452098
3
- featrixsphere/api/__init__.py,sha256=quyvuPphVj9wb6v8Dio0SMG9iHgJAmY3asHk3f_zF10,1269
4
- featrixsphere/api/api_endpoint.py,sha256=i3eCWuaUXftnH1Ai6MFZ7md7pC2FcRAIRO87CBZhyEQ,9000
5
- featrixsphere/api/client.py,sha256=TdpujNsJxO4GfPMI_KoemQWV89go3KuK6OPAo9jX6Bs,12574
6
- featrixsphere/api/foundational_model.py,sha256=ZF5wKMs6SfsNC3XYYXgbRMhnrtmLe6NeckjCCiH0fK0,21628
7
- featrixsphere/api/http_client.py,sha256=TsOQHHNTDFGAR3mdHevj-0wy1-hPtgHXKe8Egiz5FVo,7269
8
- featrixsphere/api/notebook_helper.py,sha256=xY9jsao26eaNiFh2s0_TlRZnR8xZ4P_e0EOKr2PtoVs,20060
9
- featrixsphere/api/prediction_result.py,sha256=Tx7LXzF4XT-U3VqAN_IFc5DvxPnygc78M2usrD-yMu4,7521
10
- featrixsphere/api/predictor.py,sha256=-vwCKpCfTgZKqzpDnzy1iYZQ-1-MGW8aErvxM9trktw,17652
11
- featrixsphere/api/reference_record.py,sha256=-XOTF6ynznB3ouz06w3AF8X9SVId0g_dO20VvGNesUQ,7095
12
- featrixsphere/api/vector_database.py,sha256=BplxKkPnAbcBX1A4KxFBJVb3qkQ-FH9zi9v2dWG5CgY,7976
13
- featrixsphere-0.2.6379.dist-info/METADATA,sha256=EdpmIuyoX1hr1eelFuZbN-zOwrsIsN9TupOeehDJxys,16232
14
- featrixsphere-0.2.6379.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
15
- featrixsphere-0.2.6379.dist-info/entry_points.txt,sha256=QreJeYfD_VWvbEqPmMXZ3pqqlFlJ1qZb-NtqnyhEldc,51
16
- featrixsphere-0.2.6379.dist-info/top_level.txt,sha256=AyN4wjfzlD0hWnDieuEHX0KckphIk_aC73XCG4df5uU,14
17
- featrixsphere-0.2.6379.dist-info/RECORD,,