featrixsphere 0.2.4991__py3-none-any.whl → 0.2.5183__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
featrixsphere/__init__.py CHANGED
@@ -38,7 +38,7 @@ Example:
38
38
  ... labels=['Experiment A', 'Experiment B'])
39
39
  """
40
40
 
41
- __version__ = "0.2.4991"
41
+ __version__ = "0.2.5183"
42
42
  __author__ = "Featrix"
43
43
  __email__ = "support@featrix.com"
44
44
  __license__ = "MIT"
featrixsphere/client.py CHANGED
@@ -877,24 +877,30 @@ class FeatrixSphereClient:
877
877
  # Other HTTP errors - re-raise
878
878
  raise
879
879
 
880
- def publish_session(self, session_id: str) -> Dict[str, Any]:
880
+ def publish_session(self, session_id: str, org_id: str, name: str) -> Dict[str, Any]:
881
881
  """
882
- Publish a session by moving it to /sphere/published/<sessionId>.
882
+ Publish a session by moving it to /backplane/backplane1/sphere/PUBLISHED/<org_id>/<name>/<sessionId>.
883
883
  Moves both the session file and output directory.
884
884
 
885
885
  Args:
886
886
  session_id: Session ID to publish
887
+ org_id: Organization ID for subdirectory organization
888
+ name: Name for the published session (creates subdirectory under org_id)
887
889
 
888
890
  Returns:
889
891
  Response with published_path, output_path, and status
890
892
 
891
893
  Example:
892
894
  ```python
893
- result = client.publish_session("abc123")
895
+ result = client.publish_session("abc123", org_id="org_123", name="production_model")
894
896
  print(f"Published to: {result['published_path']}")
895
897
  ```
896
898
  """
897
- response_data = self._post_json(f"/compute/session/{session_id}/publish", {})
899
+ data = {
900
+ "org_id": org_id,
901
+ "name": name
902
+ }
903
+ response_data = self._post_json(f"/compute/session/{session_id}/publish", data)
898
904
  return response_data
899
905
 
900
906
  def deprecate_session(self, session_id: str, warning_message: str, expiration_date: str) -> Dict[str, Any]:
@@ -1241,7 +1247,18 @@ class FeatrixSphereClient:
1241
1247
  } if show_live_training_movie else None
1242
1248
 
1243
1249
  while time.time() - start_time < max_wait_time:
1244
- session_info = self.get_session_status(session_id)
1250
+ try:
1251
+ session_info = self.get_session_status(session_id)
1252
+ except KeyboardInterrupt:
1253
+ print("\n\n⚠️ Interrupted by user (Ctrl+C)")
1254
+ print(f" Session {session_id} status check interrupted")
1255
+ print(" Returning current session status...")
1256
+ # Return last known status or get it one more time if possible
1257
+ try:
1258
+ return self.get_session_status(session_id)
1259
+ except:
1260
+ # If we can't get status, return a basic SessionInfo
1261
+ return SessionInfo(session_id=session_id, status="unknown", jobs={})
1245
1262
 
1246
1263
  # Clear previous output and show updated status
1247
1264
  clear_output(wait=True)
@@ -1357,7 +1374,19 @@ class FeatrixSphereClient:
1357
1374
  session_task = progress.add_task(f"[bold green]Session {session_id}", total=100)
1358
1375
 
1359
1376
  while time.time() - start_time < max_wait_time:
1360
- session_info = self.get_session_status(session_id)
1377
+ try:
1378
+ session_info = self.get_session_status(session_id)
1379
+ except KeyboardInterrupt:
1380
+ progress.console.print("\n\n[bold yellow]⚠️ Interrupted by user (Ctrl+C)[/bold yellow]")
1381
+ progress.console.print(f" Session {session_id} status check interrupted")
1382
+ progress.console.print(" Returning current session status...")
1383
+ # Return last known status or get it one more time if possible
1384
+ try:
1385
+ return self.get_session_status(session_id)
1386
+ except:
1387
+ # If we can't get status, return a basic SessionInfo
1388
+ from featrixsphere.client import SessionInfo
1389
+ return SessionInfo(session_id=session_id, status="unknown", jobs={})
1361
1390
 
1362
1391
  # Update session progress
1363
1392
  elapsed = time.time() - start_time
@@ -1426,7 +1455,19 @@ class FeatrixSphereClient:
1426
1455
  jobs_appeared = False
1427
1456
 
1428
1457
  while time.time() - initial_wait_start < initial_wait_timeout:
1429
- session_info = self.get_session_status(session_id)
1458
+ try:
1459
+ session_info = self.get_session_status(session_id)
1460
+ except KeyboardInterrupt:
1461
+ logger.info("\n⚠️ Interrupted by user (Ctrl+C)")
1462
+ logger.info(f" Session {session_id} status check interrupted")
1463
+ logger.info(" Returning current session status...")
1464
+ # Return last known status or get it one more time if possible
1465
+ try:
1466
+ return self.get_session_status(session_id)
1467
+ except:
1468
+ # If we can't get status, return a basic SessionInfo
1469
+ return SessionInfo(session_id=session_id, status="unknown", jobs={})
1470
+
1430
1471
  if session_info.jobs:
1431
1472
  jobs_appeared = True
1432
1473
  break
@@ -1434,7 +1475,20 @@ class FeatrixSphereClient:
1434
1475
 
1435
1476
  # Main monitoring loop
1436
1477
  while time.time() - start_time < max_wait_time:
1437
- session_info = self.get_session_status(session_id)
1478
+ try:
1479
+ session_info = self.get_session_status(session_id)
1480
+ except KeyboardInterrupt:
1481
+ logger.info("\n⚠️ Interrupted by user (Ctrl+C)")
1482
+ logger.info(f" Session {session_id} status check interrupted")
1483
+ logger.info(" Returning current session status...")
1484
+ # Return last known status or get it one more time if possible
1485
+ try:
1486
+ return self.get_session_status(session_id)
1487
+ except:
1488
+ # If we can't get status, return a basic SessionInfo
1489
+ from featrixsphere.client import SessionInfo
1490
+ return SessionInfo(session_id=session_id, status="unknown", jobs={})
1491
+
1438
1492
  elapsed = time.time() - start_time
1439
1493
 
1440
1494
  # Call the callback with current status
@@ -1471,7 +1525,19 @@ class FeatrixSphereClient:
1471
1525
  last_num_lines = 0
1472
1526
 
1473
1527
  while time.time() - start_time < max_wait_time:
1474
- session_info = self.get_session_status(session_id)
1528
+ try:
1529
+ session_info = self.get_session_status(session_id)
1530
+ except KeyboardInterrupt:
1531
+ print("\n\n⚠️ Interrupted by user (Ctrl+C)")
1532
+ print(f" Session {session_id} status check interrupted")
1533
+ print(" Returning current session status...")
1534
+ # Return last known status or get it one more time if possible
1535
+ try:
1536
+ return self.get_session_status(session_id)
1537
+ except:
1538
+ # If we can't get status, return a basic SessionInfo
1539
+ from featrixsphere.client import SessionInfo
1540
+ return SessionInfo(session_id=session_id, status="unknown", jobs={})
1475
1541
 
1476
1542
  # Clear previous lines if terminal supports it
1477
1543
  if sys.stdout.isatty() and last_num_lines > 0:
@@ -2866,6 +2932,110 @@ class FeatrixSphereClient:
2866
2932
  response_data = self._post_json(f"/session/{session_id}/predict", request_payload, max_retries=max_retries)
2867
2933
  return response_data
2868
2934
 
2935
def explain(self, session_id: str, record, class_idx: int = None,
            target_column: str = None, predictor_id: str = None,
            record_b: Dict[str, Any] = None,
            max_retries: int = None) -> Dict[str, Any]:
    """
    Explain a prediction using gradient attribution.

    Supports multiple modes:
    - explain(record): Explain a single row
    - explain(record, record_b=other_record): Compare two rows
    - explain([record1, record2, ...]): Explain multiple rows

    Returns what matters to Featrix in the given row(s):
    - Which features mattered for this prediction
    - Which relationships mattered for this prediction

    Args:
        session_id: ID of session with trained predictor
        record: Record dictionary (without target column), or list of records
        class_idx: Target class index for attribution (default: predicted class)
        target_column: Specific target column predictor to use (optional)
        predictor_id: Specific predictor ID to use (optional)
        record_b: Optional second record for comparison
        max_retries: Number of retries for errors (default: uses client default)

    Returns:
        For single record:
            Dictionary with:
            - feature_scores: {col_name: score} - gradient norm per feature
            - pair_scores: {(i, j): score} - gradient norm per relationship pair
            - target_class_idx: The class index used for attribution
            - logit: The prediction logit

        For two records (record_b provided):
            Dictionary with:
            - record_a: Explanation for first record
            - record_b: Explanation for second record
            - difference: Difference in feature_scores and pair_scores

        For list of records:
            Dictionary with:
            - explanations: List of explanation dictionaries, one per record

    Note:
        pair_scores keys arrive as "i_j" strings (JSON object keys are always
        strings) and are converted to (i, j) integer tuples. Keys that do not
        parse as two integers are kept unchanged. Sections missing from the
        response are left untouched rather than raising.
    """
    def clean(rec):
        # Strip numpy types, then turn NaN/Inf into nulls so the payload is valid JSON.
        return self.replace_nans_with_nulls(self._clean_numpy_values(rec))

    if isinstance(record, list):
        cleaned_record = [clean(r) for r in record]
    else:
        cleaned_record = clean(record)

    # Build request payload; optional fields are only sent when provided.
    request_payload = {
        "query_record": cleaned_record,
    }
    if class_idx is not None:
        request_payload["class_idx"] = class_idx
    if target_column:
        request_payload["target_column"] = target_column
    if predictor_id:
        request_payload["predictor_id"] = predictor_id
    if record_b is not None:
        request_payload["query_record_b"] = clean(record_b)

    response_data = self._post_json(f"/session/{session_id}/explain", request_payload, max_retries=max_retries)

    def convert_pair_scores(ps_dict):
        """Return a copy of ps_dict with "i_j" keys converted to (i, j) tuples."""
        if not isinstance(ps_dict, dict):
            return ps_dict
        result = {}
        for key, score in ps_dict.items():
            if isinstance(key, str):
                parts = key.split("_")
                if len(parts) == 2:
                    try:
                        result[(int(parts[0]), int(parts[1]))] = score
                        continue
                    except ValueError:
                        pass
            # Unparseable (or non-string) keys are preserved as-is.
            result[key] = score
        return result

    def convert_section(section):
        """Convert pair_scores in place if section is a dict carrying them."""
        if isinstance(section, dict) and "pair_scores" in section:
            section["pair_scores"] = convert_pair_scores(section["pair_scores"])

    # Only dict responses carry pair_scores; anything else is returned unchanged.
    if not isinstance(response_data, dict):
        return response_data

    if "pair_scores" in response_data:
        # Single-record explanation.
        convert_section(response_data)
    elif "explanations" in response_data:
        # Multi-record explanation; tolerate a null list.
        for expl in response_data["explanations"] or []:
            convert_section(expl)
    elif "record_a" in response_data:
        # Two-record comparison; each section may be absent from the response.
        for section_name in ("record_a", "record_b", "difference"):
            convert_section(response_data.get(section_name))

    return response_data
3038
+
2869
3039
  def get_training_metrics(self, session_id: str, max_retries: int = None) -> Dict[str, Any]:
2870
3040
  """
2871
3041
  Get training metrics for a session's single predictor.
@@ -8117,6 +8287,55 @@ class FeatrixSphereClient:
8117
8287
  """
8118
8288
  return PredictionGrid(session_id, self, degrees_of_freedom, grid_shape, target_column)
8119
8289
 
8290
def get_embedding_space_columns(self, session_id: str) -> Dict[str, Any]:
    """
    Get column names and types from the embedding space.

    Tries to get from model_card.json first (if training completed),
    otherwise falls back to loading the embedding space directly.

    The model card is used only when it yields a non-empty list of feature
    names. If the card cannot be fetched, is not a dict, or carries no feature
    names, the /compute/session/<id>/columns endpoint is queried instead.
    Null or malformed sections in the card are tolerated rather than
    discarding data that was already extracted.

    Args:
        session_id: Session ID with trained embedding space

    Returns:
        Dictionary with:
        - column_names: List of column names
        - column_types: Dict mapping column names to types (scalar, set, free_string, etc.)
        - num_columns: Total number of columns

    Example:
        >>> columns = client.get_embedding_space_columns(session_id)
        >>> print(f"Columns: {columns['column_names']}")
        >>> print(f"Types: {columns['column_types']}")
    """
    # Only the fetch is guarded: failure here (e.g. model card doesn't exist
    # yet) is an expected, best-effort miss that triggers the fallback.
    try:
        model_card = self.get_model_card(session_id)
    except Exception:
        model_card = None

    if isinstance(model_card, dict):
        # Extract column names from training_dataset.feature_names
        training_dataset = model_card.get('training_dataset')
        column_names = []
        if isinstance(training_dataset, dict):
            column_names = training_dataset.get('feature_names') or []
        if not isinstance(column_names, (list, tuple)):
            column_names = []

        # Extract column types from feature_inventory; a null/non-dict
        # inventory just yields no types instead of discarding the names.
        feature_inventory = model_card.get('feature_inventory')
        column_types = {}
        if isinstance(feature_inventory, dict):
            for feature_name, feature_info in feature_inventory.items():
                if isinstance(feature_info, dict):
                    column_types[feature_name] = feature_info.get('type', 'unknown')

        if column_names:
            return {
                "column_names": column_names,
                "column_types": column_types,
                "num_columns": len(column_names)
            }

    # Fallback: load embedding space directly
    return self._get_json(f"/compute/session/{session_id}/columns")
8338
+
8120
8339
 
8121
8340
  class PredictionGrid:
8122
8341
  """
@@ -8862,55 +9081,6 @@ class PredictionGrid:
8862
9081
  else:
8863
9082
  return fig
8864
9083
 
8865
- def get_embedding_space_columns(self, session_id: str) -> Dict[str, Any]:
8866
- """
8867
- Get column names and types from the embedding space.
8868
-
8869
- Tries to get from model_card.json first (if training completed),
8870
- otherwise falls back to loading the embedding space directly.
8871
-
8872
- Args:
8873
- session_id: Session ID with trained embedding space
8874
-
8875
- Returns:
8876
- Dictionary with:
8877
- - column_names: List of column names
8878
- - column_types: Dict mapping column names to types (scalar, set, free_string, etc.)
8879
- - num_columns: Total number of columns
8880
-
8881
- Example:
8882
- >>> columns = client.get_embedding_space_columns(session_id)
8883
- >>> print(f"Columns: {columns['column_names']}")
8884
- >>> print(f"Types: {columns['column_types']}")
8885
- """
8886
- # Try model_card first (if training completed)
8887
- try:
8888
- model_card = self.get_model_card(session_id)
8889
-
8890
- # Extract column names from training_dataset.feature_names
8891
- training_dataset = model_card.get('training_dataset', {})
8892
- column_names = training_dataset.get('feature_names', [])
8893
-
8894
- # Extract column types from feature_inventory
8895
- feature_inventory = model_card.get('feature_inventory', {})
8896
- column_types = {}
8897
- for feature_name, feature_info in feature_inventory.items():
8898
- if isinstance(feature_info, dict):
8899
- column_types[feature_name] = feature_info.get('type', 'unknown')
8900
-
8901
- if column_names:
8902
- return {
8903
- "column_names": column_names,
8904
- "column_types": column_types,
8905
- "num_columns": len(column_names)
8906
- }
8907
- except Exception:
8908
- # Model card doesn't exist yet, fall back to direct endpoint
8909
- pass
8910
-
8911
- # Fallback: load embedding space directly
8912
- return self._get_json(f"/compute/session/{session_id}/columns")
8913
-
8914
9084
  def get_predictor_schema(self, session_id: str, predictor_index: int = 0) -> Dict[str, Any]:
8915
9085
  """
8916
9086
  Get predictor schema/metadata for validating input data locally.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: featrixsphere
3
- Version: 0.2.4991
3
+ Version: 0.2.5183
4
4
  Summary: Transform any CSV into a production-ready ML model in minutes, not months.
5
5
  Home-page: https://github.com/Featrix/sphere
6
6
  Author: Featrix
@@ -0,0 +1,7 @@
1
+ featrixsphere/__init__.py,sha256=tPVhFLHujHMC0bnsnjggdMIb2chNkYbbm0wUX-lwWYY,1888
2
+ featrixsphere/client.py,sha256=YvOB2y8zh4iCMccXQ-4ZsQ8dgmUSQlkLh2zsxIiIoYM,435090
3
+ featrixsphere-0.2.5183.dist-info/METADATA,sha256=BEVBhDxyQvjFfDWyzSSKTZLwybnQFyPYU6dk_cxB5CM,16232
4
+ featrixsphere-0.2.5183.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
5
+ featrixsphere-0.2.5183.dist-info/entry_points.txt,sha256=QreJeYfD_VWvbEqPmMXZ3pqqlFlJ1qZb-NtqnyhEldc,51
6
+ featrixsphere-0.2.5183.dist-info/top_level.txt,sha256=AyN4wjfzlD0hWnDieuEHX0KckphIk_aC73XCG4df5uU,14
7
+ featrixsphere-0.2.5183.dist-info/RECORD,,
@@ -1,311 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Tests for FeatrixSphereClient
4
-
5
- These tests verify basic functionality without requiring a live API server.
6
- """
7
-
8
- import unittest
9
- from unittest.mock import Mock, patch, MagicMock
10
- import json
11
- from pathlib import Path
12
- import sys
13
-
14
- # Mock optional dependencies before importing featrixsphere
15
- import sys
16
-
17
- # Mock numpy
18
- try:
19
- import numpy as np
20
- except ImportError:
21
- class MockNumpy:
22
- class ndarray:
23
- pass
24
- def array(self, *args, **kwargs):
25
- return []
26
- def __getattr__(self, name):
27
- return lambda *args, **kwargs: None
28
- sys.modules['numpy'] = MockNumpy()
29
-
30
- # Mock matplotlib
31
- try:
32
- import matplotlib.pyplot as plt
33
- except ImportError:
34
- class MockMatplotlib:
35
- class Figure:
36
- pass
37
- def __getattr__(self, name):
38
- return lambda *args, **kwargs: None
39
- sys.modules['matplotlib'] = MockMatplotlib()
40
- sys.modules['matplotlib.pyplot'] = MockMatplotlib()
41
- sys.modules['matplotlib.dates'] = MockMatplotlib()
42
-
43
- # Add parent directory to path to import featrixsphere
44
- sys.path.insert(0, str(Path(__file__).parent.parent))
45
-
46
- try:
47
- from featrixsphere import FeatrixSphereClient, SessionInfo, PredictionBatch
48
- except (ImportError, AttributeError) as e:
49
- # If import fails, create minimal mocks for basic structure tests
50
- print(f"⚠️ Warning: Could not fully import featrixsphere: {e}")
51
- print(" Running minimal structure tests only...")
52
-
53
- # Create minimal mocks for basic testing
54
- class MockSession:
55
- def __init__(self):
56
- self.headers = {}
57
- self.timeout = 30
58
-
59
- class MockFeatrixSphereClient:
60
- def __init__(self, base_url="http://test.com", **kwargs):
61
- self.base_url = base_url.rstrip('/')
62
- self.compute_cluster = kwargs.get('compute_cluster')
63
- self.session = MockSession()
64
- self.default_max_retries = 5
65
- # Set header if compute_cluster provided
66
- if self.compute_cluster:
67
- self.session.headers['X-Featrix-Node'] = self.compute_cluster
68
-
69
- def set_compute_cluster(self, cluster):
70
- self.compute_cluster = cluster
71
- if cluster:
72
- self.session.headers['X-Featrix-Node'] = cluster
73
- else:
74
- self.session.headers.pop('X-Featrix-Node', None)
75
-
76
- def _make_request(self, method, endpoint, **kwargs):
77
- from unittest.mock import Mock
78
- response = Mock()
79
- response.status_code = 200
80
- response.json.return_value = {}
81
- return response
82
-
83
- class MockSessionInfo:
84
- def __init__(self, session_id, session_type, status, jobs, job_queue_positions, _client=None):
85
- self.session_id = session_id
86
- self.session_type = session_type
87
- self.status = status
88
-
89
- class MockPredictionBatch:
90
- def __init__(self, session_id, client, target_column=None):
91
- self.session_id = session_id
92
- self.client = client
93
- self._cache = {}
94
- self._stats = {'hits': 0, 'misses': 0, 'populated': 0}
95
-
96
- def _hash_record(self, record):
97
- import hashlib
98
- import json
99
- sorted_items = sorted(record.items())
100
- record_str = json.dumps(sorted_items, sort_keys=True)
101
- return hashlib.md5(record_str.encode()).hexdigest()
102
-
103
- def predict(self, record):
104
- record_hash = self._hash_record(record)
105
- if record_hash in self._cache:
106
- self._stats['hits'] += 1
107
- return self._cache[record_hash]
108
- else:
109
- self._stats['misses'] += 1
110
- return {
111
- 'cache_miss': True,
112
- 'record': record,
113
- 'suggestion': 'Record not found in batch cache. Add to records list and recreate batch.'
114
- }
115
-
116
- def get_stats(self):
117
- total = self._stats['hits'] + self._stats['misses']
118
- return {
119
- 'cache_hits': self._stats['hits'],
120
- 'cache_misses': self._stats['misses'],
121
- 'total_requests': total,
122
- 'hit_rate': self._stats['hits'] / total if total > 0 else 0.0
123
- }
124
-
125
- FeatrixSphereClient = MockFeatrixSphereClient
126
- SessionInfo = MockSessionInfo
127
- PredictionBatch = MockPredictionBatch
128
-
129
-
130
- class TestFeatrixSphereClient(unittest.TestCase):
131
- """Test cases for FeatrixSphereClient."""
132
-
133
- def setUp(self):
134
- """Set up test fixtures."""
135
- self.client = FeatrixSphereClient(base_url="http://test-server.com")
136
-
137
- def test_client_initialization(self):
138
- """Test that client initializes correctly."""
139
- self.assertEqual(self.client.base_url, "http://test-server.com")
140
- self.assertIsNotNone(self.client.session)
141
- self.assertEqual(self.client.default_max_retries, 5)
142
-
143
- def test_client_with_compute_cluster(self):
144
- """Test client initialization with compute cluster."""
145
- client = FeatrixSphereClient(
146
- base_url="http://test-server.com",
147
- compute_cluster="burrito"
148
- )
149
- self.assertEqual(client.compute_cluster, "burrito")
150
- self.assertIn("X-Featrix-Node", client.session.headers)
151
- self.assertEqual(client.session.headers["X-Featrix-Node"], "burrito")
152
-
153
- def test_set_compute_cluster(self):
154
- """Test setting compute cluster after initialization."""
155
- self.client.set_compute_cluster("churro")
156
- self.assertEqual(self.client.compute_cluster, "churro")
157
- self.assertEqual(self.client.session.headers.get("X-Featrix-Node"), "churro")
158
-
159
- # Test removing cluster
160
- self.client.set_compute_cluster(None)
161
- self.assertIsNone(self.client.compute_cluster)
162
- self.assertNotIn("X-Featrix-Node", self.client.session.headers)
163
-
164
- def test_endpoint_auto_prefix(self):
165
- """Test that session endpoints get /compute prefix automatically."""
166
- # Skip if using mocks (client doesn't have full requests functionality)
167
- if not hasattr(self.client.session, 'get'):
168
- self.skipTest("Skipping - using mocks without full requests support")
169
-
170
- with patch.object(self.client.session, 'get') as mock_get:
171
- mock_response = Mock()
172
- mock_response.status_code = 200
173
- mock_response.json.return_value = {}
174
- mock_get.return_value = mock_response
175
-
176
- # Should auto-add /compute prefix
177
- self.client._make_request('GET', '/session/test-123')
178
- mock_get.assert_called_once()
179
- call_url = mock_get.call_args[0][0]
180
- self.assertIn('/compute/session/test-123', call_url)
181
-
182
- def test_session_info_initialization(self):
183
- """Test SessionInfo dataclass initialization."""
184
- session = SessionInfo(
185
- session_id="test-123",
186
- session_type="embedding_space",
187
- status="complete",
188
- jobs={},
189
- job_queue_positions={}
190
- )
191
- self.assertEqual(session.session_id, "test-123")
192
- self.assertEqual(session.session_type, "embedding_space")
193
- self.assertEqual(session.status, "complete")
194
-
195
- def test_prediction_batch_hash_record(self):
196
- """Test PredictionBatch record hashing."""
197
- batch = PredictionBatch("test-123", self.client)
198
-
199
- record1 = {"a": 1, "b": 2}
200
- record2 = {"b": 2, "a": 1} # Same keys, different order
201
- record3 = {"a": 1, "b": 3} # Different value
202
-
203
- hash1 = batch._hash_record(record1)
204
- hash2 = batch._hash_record(record2)
205
- hash3 = batch._hash_record(record3)
206
-
207
- # Same records should hash to same value (order-independent)
208
- self.assertEqual(hash1, hash2)
209
- # Different records should hash to different values
210
- self.assertNotEqual(hash1, hash3)
211
-
212
- def test_prediction_batch_cache_miss(self):
213
- """Test PredictionBatch cache miss behavior."""
214
- batch = PredictionBatch("test-123", self.client)
215
-
216
- record = {"feature": "value"}
217
- result = batch.predict(record)
218
-
219
- self.assertTrue(result.get('cache_miss'))
220
- self.assertEqual(result.get('record'), record)
221
- self.assertIn('suggestion', result)
222
-
223
- def test_prediction_batch_stats(self):
224
- """Test PredictionBatch statistics tracking."""
225
- batch = PredictionBatch("test-123", self.client)
226
-
227
- # Make some predictions (cache misses)
228
- batch.predict({"a": 1})
229
- batch.predict({"b": 2})
230
-
231
- stats = batch.get_stats()
232
- self.assertEqual(stats['cache_misses'], 2)
233
- self.assertEqual(stats['cache_hits'], 0)
234
- self.assertEqual(stats['total_requests'], 2)
235
- self.assertEqual(stats['hit_rate'], 0.0)
236
-
237
- def test_prediction_batch_cache_hit(self):
238
- """Test PredictionBatch cache hit behavior."""
239
- batch = PredictionBatch("test-123", self.client)
240
-
241
- # Manually populate cache
242
- record = {"feature": "value"}
243
- record_hash = batch._hash_record(record)
244
- batch._cache[record_hash] = {"prediction": "test_result"}
245
- batch._stats['populated'] = 1
246
-
247
- # Now predict should hit cache
248
- result = batch.predict(record)
249
- self.assertFalse(result.get('cache_miss', False))
250
- self.assertEqual(result.get('prediction'), "test_result")
251
-
252
- stats = batch.get_stats()
253
- self.assertEqual(stats['cache_hits'], 1)
254
- self.assertEqual(stats['cache_misses'], 0)
255
- self.assertEqual(stats['hit_rate'], 1.0)
256
-
257
-
258
- class TestClientErrorHandling(unittest.TestCase):
259
- """Test error handling in FeatrixSphereClient."""
260
-
261
- def setUp(self):
262
- """Set up test fixtures."""
263
- self.client = FeatrixSphereClient(base_url="http://test-server.com")
264
-
265
- def test_make_request_retry_on_500(self):
266
- """Test that 500 errors trigger retries."""
267
- # Skip if using mocks (client doesn't have full requests functionality)
268
- if not hasattr(self.client.session, 'get'):
269
- self.skipTest("Skipping - using mocks without full requests support")
270
-
271
- with patch.object(self.client.session, 'get') as mock_get:
272
- # First call returns 500, second returns 200
273
- mock_response_500 = Mock()
274
- mock_response_500.status_code = 500
275
- mock_response_200 = Mock()
276
- mock_response_200.status_code = 200
277
- mock_response_200.json.return_value = {}
278
- mock_get.side_effect = [mock_response_500, mock_response_200]
279
-
280
- # Should retry and eventually succeed
281
- response = self.client._make_request('GET', '/test', max_retries=2)
282
- self.assertEqual(response.status_code, 200)
283
- self.assertEqual(mock_get.call_count, 2)
284
-
285
- def test_make_request_timeout(self):
286
- """Test timeout handling."""
287
- # Skip if using mocks (client doesn't have full requests functionality)
288
- if not hasattr(self.client.session, 'get'):
289
- self.skipTest("Skipping - using mocks without full requests support")
290
-
291
- import requests
292
- with patch.object(self.client.session, 'get') as mock_get:
293
- mock_get.side_effect = requests.exceptions.Timeout("Request timed out")
294
-
295
- # Should raise after retries exhausted
296
- with self.assertRaises(Exception):
297
- self.client._make_request('GET', '/test', max_retries=1)
298
-
299
-
300
- def run_tests():
301
- """Run all tests and return exit code."""
302
- loader = unittest.TestLoader()
303
- suite = loader.loadTestsFromModule(sys.modules[__name__])
304
- runner = unittest.TextTestRunner(verbosity=2)
305
- result = runner.run(suite)
306
- return 0 if result.wasSuccessful() else 1
307
-
308
-
309
- if __name__ == '__main__':
310
- sys.exit(run_tests())
311
-
@@ -1,8 +0,0 @@
1
- featrixsphere/__init__.py,sha256=UaVSYqUbFAR92BWYcQkfV7lVfnbogtZFbAJTgDdmUQo,1888
2
- featrixsphere/client.py,sha256=bqRRyumplQmeRDrxyyhYcj_VIOnl_svh-fJ8VUcxBv4,426243
3
- featrixsphere/test_client.py,sha256=4SiRbib0ms3poK0UpnUv4G0HFQSzidF3Iswo_J2cjLk,11981
4
- featrixsphere-0.2.4991.dist-info/METADATA,sha256=DOj--UmpmV2TsuBAtsn72NYFRwn6Nipqq71oGj_bAnc,16232
5
- featrixsphere-0.2.4991.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
- featrixsphere-0.2.4991.dist-info/entry_points.txt,sha256=QreJeYfD_VWvbEqPmMXZ3pqqlFlJ1qZb-NtqnyhEldc,51
7
- featrixsphere-0.2.4991.dist-info/top_level.txt,sha256=AyN4wjfzlD0hWnDieuEHX0KckphIk_aC73XCG4df5uU,14
8
- featrixsphere-0.2.4991.dist-info/RECORD,,