featrixsphere 0.2.4991__py3-none-any.whl → 0.2.5183__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- featrixsphere/__init__.py +1 -1
- featrixsphere/client.py +228 -58
- {featrixsphere-0.2.4991.dist-info → featrixsphere-0.2.5183.dist-info}/METADATA +1 -1
- featrixsphere-0.2.5183.dist-info/RECORD +7 -0
- featrixsphere/test_client.py +0 -311
- featrixsphere-0.2.4991.dist-info/RECORD +0 -8
- {featrixsphere-0.2.4991.dist-info → featrixsphere-0.2.5183.dist-info}/WHEEL +0 -0
- {featrixsphere-0.2.4991.dist-info → featrixsphere-0.2.5183.dist-info}/entry_points.txt +0 -0
- {featrixsphere-0.2.4991.dist-info → featrixsphere-0.2.5183.dist-info}/top_level.txt +0 -0
featrixsphere/__init__.py
CHANGED
featrixsphere/client.py
CHANGED
|
@@ -877,24 +877,30 @@ class FeatrixSphereClient:
|
|
|
877
877
|
# Other HTTP errors - re-raise
|
|
878
878
|
raise
|
|
879
879
|
|
|
880
|
-
def publish_session(self, session_id: str) -> Dict[str, Any]:
|
|
880
|
+
def publish_session(self, session_id: str, org_id: str, name: str) -> Dict[str, Any]:
|
|
881
881
|
"""
|
|
882
|
-
Publish a session by moving it to /sphere/
|
|
882
|
+
Publish a session by moving it to /backplane/backplane1/sphere/PUBLISHED/<org_id>/<name>/<sessionId>.
|
|
883
883
|
Moves both the session file and output directory.
|
|
884
884
|
|
|
885
885
|
Args:
|
|
886
886
|
session_id: Session ID to publish
|
|
887
|
+
org_id: Organization ID for subdirectory organization
|
|
888
|
+
name: Name for the published session (creates subdirectory under org_id)
|
|
887
889
|
|
|
888
890
|
Returns:
|
|
889
891
|
Response with published_path, output_path, and status
|
|
890
892
|
|
|
891
893
|
Example:
|
|
892
894
|
```python
|
|
893
|
-
result = client.publish_session("abc123")
|
|
895
|
+
result = client.publish_session("abc123", org_id="org_123", name="production_model")
|
|
894
896
|
print(f"Published to: {result['published_path']}")
|
|
895
897
|
```
|
|
896
898
|
"""
|
|
897
|
-
|
|
899
|
+
data = {
|
|
900
|
+
"org_id": org_id,
|
|
901
|
+
"name": name
|
|
902
|
+
}
|
|
903
|
+
response_data = self._post_json(f"/compute/session/{session_id}/publish", data)
|
|
898
904
|
return response_data
|
|
899
905
|
|
|
900
906
|
def deprecate_session(self, session_id: str, warning_message: str, expiration_date: str) -> Dict[str, Any]:
|
|
@@ -1241,7 +1247,18 @@ class FeatrixSphereClient:
|
|
|
1241
1247
|
} if show_live_training_movie else None
|
|
1242
1248
|
|
|
1243
1249
|
while time.time() - start_time < max_wait_time:
|
|
1244
|
-
|
|
1250
|
+
try:
|
|
1251
|
+
session_info = self.get_session_status(session_id)
|
|
1252
|
+
except KeyboardInterrupt:
|
|
1253
|
+
print("\n\n⚠️ Interrupted by user (Ctrl+C)")
|
|
1254
|
+
print(f" Session {session_id} status check interrupted")
|
|
1255
|
+
print(" Returning current session status...")
|
|
1256
|
+
# Return last known status or get it one more time if possible
|
|
1257
|
+
try:
|
|
1258
|
+
return self.get_session_status(session_id)
|
|
1259
|
+
except:
|
|
1260
|
+
# If we can't get status, return a basic SessionInfo
|
|
1261
|
+
return SessionInfo(session_id=session_id, status="unknown", jobs={})
|
|
1245
1262
|
|
|
1246
1263
|
# Clear previous output and show updated status
|
|
1247
1264
|
clear_output(wait=True)
|
|
@@ -1357,7 +1374,19 @@ class FeatrixSphereClient:
|
|
|
1357
1374
|
session_task = progress.add_task(f"[bold green]Session {session_id}", total=100)
|
|
1358
1375
|
|
|
1359
1376
|
while time.time() - start_time < max_wait_time:
|
|
1360
|
-
|
|
1377
|
+
try:
|
|
1378
|
+
session_info = self.get_session_status(session_id)
|
|
1379
|
+
except KeyboardInterrupt:
|
|
1380
|
+
progress.console.print("\n\n[bold yellow]⚠️ Interrupted by user (Ctrl+C)[/bold yellow]")
|
|
1381
|
+
progress.console.print(f" Session {session_id} status check interrupted")
|
|
1382
|
+
progress.console.print(" Returning current session status...")
|
|
1383
|
+
# Return last known status or get it one more time if possible
|
|
1384
|
+
try:
|
|
1385
|
+
return self.get_session_status(session_id)
|
|
1386
|
+
except:
|
|
1387
|
+
# If we can't get status, return a basic SessionInfo
|
|
1388
|
+
from featrixsphere.client import SessionInfo
|
|
1389
|
+
return SessionInfo(session_id=session_id, status="unknown", jobs={})
|
|
1361
1390
|
|
|
1362
1391
|
# Update session progress
|
|
1363
1392
|
elapsed = time.time() - start_time
|
|
@@ -1426,7 +1455,19 @@ class FeatrixSphereClient:
|
|
|
1426
1455
|
jobs_appeared = False
|
|
1427
1456
|
|
|
1428
1457
|
while time.time() - initial_wait_start < initial_wait_timeout:
|
|
1429
|
-
|
|
1458
|
+
try:
|
|
1459
|
+
session_info = self.get_session_status(session_id)
|
|
1460
|
+
except KeyboardInterrupt:
|
|
1461
|
+
logger.info("\n⚠️ Interrupted by user (Ctrl+C)")
|
|
1462
|
+
logger.info(f" Session {session_id} status check interrupted")
|
|
1463
|
+
logger.info(" Returning current session status...")
|
|
1464
|
+
# Return last known status or get it one more time if possible
|
|
1465
|
+
try:
|
|
1466
|
+
return self.get_session_status(session_id)
|
|
1467
|
+
except:
|
|
1468
|
+
# If we can't get status, return a basic SessionInfo
|
|
1469
|
+
return SessionInfo(session_id=session_id, status="unknown", jobs={})
|
|
1470
|
+
|
|
1430
1471
|
if session_info.jobs:
|
|
1431
1472
|
jobs_appeared = True
|
|
1432
1473
|
break
|
|
@@ -1434,7 +1475,20 @@ class FeatrixSphereClient:
|
|
|
1434
1475
|
|
|
1435
1476
|
# Main monitoring loop
|
|
1436
1477
|
while time.time() - start_time < max_wait_time:
|
|
1437
|
-
|
|
1478
|
+
try:
|
|
1479
|
+
session_info = self.get_session_status(session_id)
|
|
1480
|
+
except KeyboardInterrupt:
|
|
1481
|
+
logger.info("\n⚠️ Interrupted by user (Ctrl+C)")
|
|
1482
|
+
logger.info(f" Session {session_id} status check interrupted")
|
|
1483
|
+
logger.info(" Returning current session status...")
|
|
1484
|
+
# Return last known status or get it one more time if possible
|
|
1485
|
+
try:
|
|
1486
|
+
return self.get_session_status(session_id)
|
|
1487
|
+
except:
|
|
1488
|
+
# If we can't get status, return a basic SessionInfo
|
|
1489
|
+
from featrixsphere.client import SessionInfo
|
|
1490
|
+
return SessionInfo(session_id=session_id, status="unknown", jobs={})
|
|
1491
|
+
|
|
1438
1492
|
elapsed = time.time() - start_time
|
|
1439
1493
|
|
|
1440
1494
|
# Call the callback with current status
|
|
@@ -1471,7 +1525,19 @@ class FeatrixSphereClient:
|
|
|
1471
1525
|
last_num_lines = 0
|
|
1472
1526
|
|
|
1473
1527
|
while time.time() - start_time < max_wait_time:
|
|
1474
|
-
|
|
1528
|
+
try:
|
|
1529
|
+
session_info = self.get_session_status(session_id)
|
|
1530
|
+
except KeyboardInterrupt:
|
|
1531
|
+
print("\n\n⚠️ Interrupted by user (Ctrl+C)")
|
|
1532
|
+
print(f" Session {session_id} status check interrupted")
|
|
1533
|
+
print(" Returning current session status...")
|
|
1534
|
+
# Return last known status or get it one more time if possible
|
|
1535
|
+
try:
|
|
1536
|
+
return self.get_session_status(session_id)
|
|
1537
|
+
except:
|
|
1538
|
+
# If we can't get status, return a basic SessionInfo
|
|
1539
|
+
from featrixsphere.client import SessionInfo
|
|
1540
|
+
return SessionInfo(session_id=session_id, status="unknown", jobs={})
|
|
1475
1541
|
|
|
1476
1542
|
# Clear previous lines if terminal supports it
|
|
1477
1543
|
if sys.stdout.isatty() and last_num_lines > 0:
|
|
@@ -2866,6 +2932,110 @@ class FeatrixSphereClient:
|
|
|
2866
2932
|
response_data = self._post_json(f"/session/{session_id}/predict", request_payload, max_retries=max_retries)
|
|
2867
2933
|
return response_data
|
|
2868
2934
|
|
|
2935
|
+
def explain(self, session_id: str, record, class_idx: int = None,
|
|
2936
|
+
target_column: str = None, predictor_id: str = None,
|
|
2937
|
+
record_b: Dict[str, Any] = None,
|
|
2938
|
+
max_retries: int = None) -> Dict[str, Any]:
|
|
2939
|
+
"""
|
|
2940
|
+
Explain a prediction using gradient attribution.
|
|
2941
|
+
|
|
2942
|
+
Supports multiple modes:
|
|
2943
|
+
- explain(record): Explain a single row
|
|
2944
|
+
- explain(record, record_b=other_record): Compare two rows
|
|
2945
|
+
- explain([record1, record2, ...]): Explain multiple rows
|
|
2946
|
+
|
|
2947
|
+
Returns what matters to Featrix in the given row(s):
|
|
2948
|
+
- Which features mattered for this prediction
|
|
2949
|
+
- Which relationships mattered for this prediction
|
|
2950
|
+
|
|
2951
|
+
Args:
|
|
2952
|
+
session_id: ID of session with trained predictor
|
|
2953
|
+
record: Record dictionary (without target column), or list of records
|
|
2954
|
+
class_idx: Target class index for attribution (default: predicted class)
|
|
2955
|
+
target_column: Specific target column predictor to use (optional)
|
|
2956
|
+
predictor_id: Specific predictor ID to use (optional)
|
|
2957
|
+
record_b: Optional second record for comparison
|
|
2958
|
+
max_retries: Number of retries for errors (default: uses client default)
|
|
2959
|
+
|
|
2960
|
+
Returns:
|
|
2961
|
+
For single record:
|
|
2962
|
+
Dictionary with:
|
|
2963
|
+
- feature_scores: {col_name: score} - gradient norm per feature
|
|
2964
|
+
- pair_scores: {(i, j): score} - gradient norm per relationship pair
|
|
2965
|
+
- target_class_idx: The class index used for attribution
|
|
2966
|
+
- logit: The prediction logit
|
|
2967
|
+
|
|
2968
|
+
For two records (record_b provided):
|
|
2969
|
+
Dictionary with:
|
|
2970
|
+
- record_a: Explanation for first record
|
|
2971
|
+
- record_b: Explanation for second record
|
|
2972
|
+
- difference: Difference in feature_scores and pair_scores
|
|
2973
|
+
|
|
2974
|
+
For list of records:
|
|
2975
|
+
Dictionary with:
|
|
2976
|
+
- explanations: List of explanation dictionaries, one per record
|
|
2977
|
+
"""
|
|
2978
|
+
# Clean NaN/Inf values
|
|
2979
|
+
if isinstance(record, list):
|
|
2980
|
+
cleaned_record = [self.replace_nans_with_nulls(self._clean_numpy_values(r)) for r in record]
|
|
2981
|
+
else:
|
|
2982
|
+
cleaned_record = self.replace_nans_with_nulls(self._clean_numpy_values(record))
|
|
2983
|
+
|
|
2984
|
+
cleaned_record_b = None
|
|
2985
|
+
if record_b is not None:
|
|
2986
|
+
cleaned_record_b = self.replace_nans_with_nulls(self._clean_numpy_values(record_b))
|
|
2987
|
+
|
|
2988
|
+
# Build request payload
|
|
2989
|
+
request_payload = {
|
|
2990
|
+
"query_record": cleaned_record,
|
|
2991
|
+
}
|
|
2992
|
+
|
|
2993
|
+
if class_idx is not None:
|
|
2994
|
+
request_payload["class_idx"] = class_idx
|
|
2995
|
+
if target_column:
|
|
2996
|
+
request_payload["target_column"] = target_column
|
|
2997
|
+
if predictor_id:
|
|
2998
|
+
request_payload["predictor_id"] = predictor_id
|
|
2999
|
+
if cleaned_record_b is not None:
|
|
3000
|
+
request_payload["query_record_b"] = cleaned_record_b
|
|
3001
|
+
|
|
3002
|
+
response_data = self._post_json(f"/session/{session_id}/explain", request_payload, max_retries=max_retries)
|
|
3003
|
+
|
|
3004
|
+
# Helper to convert pair_scores keys back to tuples
|
|
3005
|
+
def convert_pair_scores(ps_dict):
|
|
3006
|
+
if not isinstance(ps_dict, dict):
|
|
3007
|
+
return ps_dict
|
|
3008
|
+
result = {}
|
|
3009
|
+
for key, score in ps_dict.items():
|
|
3010
|
+
# Key format is "i_j"
|
|
3011
|
+
parts = key.split("_")
|
|
3012
|
+
if len(parts) == 2:
|
|
3013
|
+
try:
|
|
3014
|
+
i, j = int(parts[0]), int(parts[1])
|
|
3015
|
+
result[(i, j)] = score
|
|
3016
|
+
except ValueError:
|
|
3017
|
+
result[key] = score
|
|
3018
|
+
else:
|
|
3019
|
+
result[key] = score
|
|
3020
|
+
return result
|
|
3021
|
+
|
|
3022
|
+
# Convert pair_scores keys back to tuples for easier use
|
|
3023
|
+
if "pair_scores" in response_data:
|
|
3024
|
+
response_data["pair_scores"] = convert_pair_scores(response_data["pair_scores"])
|
|
3025
|
+
elif "explanations" in response_data:
|
|
3026
|
+
for expl in response_data["explanations"]:
|
|
3027
|
+
if "pair_scores" in expl:
|
|
3028
|
+
expl["pair_scores"] = convert_pair_scores(expl["pair_scores"])
|
|
3029
|
+
elif "record_a" in response_data:
|
|
3030
|
+
if "pair_scores" in response_data["record_a"]:
|
|
3031
|
+
response_data["record_a"]["pair_scores"] = convert_pair_scores(response_data["record_a"]["pair_scores"])
|
|
3032
|
+
if "pair_scores" in response_data["record_b"]:
|
|
3033
|
+
response_data["record_b"]["pair_scores"] = convert_pair_scores(response_data["record_b"]["pair_scores"])
|
|
3034
|
+
if "difference" in response_data and "pair_scores" in response_data["difference"]:
|
|
3035
|
+
response_data["difference"]["pair_scores"] = convert_pair_scores(response_data["difference"]["pair_scores"])
|
|
3036
|
+
|
|
3037
|
+
return response_data
|
|
3038
|
+
|
|
2869
3039
|
def get_training_metrics(self, session_id: str, max_retries: int = None) -> Dict[str, Any]:
|
|
2870
3040
|
"""
|
|
2871
3041
|
Get training metrics for a session's single predictor.
|
|
@@ -8117,6 +8287,55 @@ class FeatrixSphereClient:
|
|
|
8117
8287
|
"""
|
|
8118
8288
|
return PredictionGrid(session_id, self, degrees_of_freedom, grid_shape, target_column)
|
|
8119
8289
|
|
|
8290
|
+
def get_embedding_space_columns(self, session_id: str) -> Dict[str, Any]:
|
|
8291
|
+
"""
|
|
8292
|
+
Get column names and types from the embedding space.
|
|
8293
|
+
|
|
8294
|
+
Tries to get from model_card.json first (if training completed),
|
|
8295
|
+
otherwise falls back to loading the embedding space directly.
|
|
8296
|
+
|
|
8297
|
+
Args:
|
|
8298
|
+
session_id: Session ID with trained embedding space
|
|
8299
|
+
|
|
8300
|
+
Returns:
|
|
8301
|
+
Dictionary with:
|
|
8302
|
+
- column_names: List of column names
|
|
8303
|
+
- column_types: Dict mapping column names to types (scalar, set, free_string, etc.)
|
|
8304
|
+
- num_columns: Total number of columns
|
|
8305
|
+
|
|
8306
|
+
Example:
|
|
8307
|
+
>>> columns = client.get_embedding_space_columns(session_id)
|
|
8308
|
+
>>> print(f"Columns: {columns['column_names']}")
|
|
8309
|
+
>>> print(f"Types: {columns['column_types']}")
|
|
8310
|
+
"""
|
|
8311
|
+
# Try model_card first (if training completed)
|
|
8312
|
+
try:
|
|
8313
|
+
model_card = self.get_model_card(session_id)
|
|
8314
|
+
|
|
8315
|
+
# Extract column names from training_dataset.feature_names
|
|
8316
|
+
training_dataset = model_card.get('training_dataset', {})
|
|
8317
|
+
column_names = training_dataset.get('feature_names', [])
|
|
8318
|
+
|
|
8319
|
+
# Extract column types from feature_inventory
|
|
8320
|
+
feature_inventory = model_card.get('feature_inventory', {})
|
|
8321
|
+
column_types = {}
|
|
8322
|
+
for feature_name, feature_info in feature_inventory.items():
|
|
8323
|
+
if isinstance(feature_info, dict):
|
|
8324
|
+
column_types[feature_name] = feature_info.get('type', 'unknown')
|
|
8325
|
+
|
|
8326
|
+
if column_names:
|
|
8327
|
+
return {
|
|
8328
|
+
"column_names": column_names,
|
|
8329
|
+
"column_types": column_types,
|
|
8330
|
+
"num_columns": len(column_names)
|
|
8331
|
+
}
|
|
8332
|
+
except Exception:
|
|
8333
|
+
# Model card doesn't exist yet, fall back to direct endpoint
|
|
8334
|
+
pass
|
|
8335
|
+
|
|
8336
|
+
# Fallback: load embedding space directly
|
|
8337
|
+
return self._get_json(f"/compute/session/{session_id}/columns")
|
|
8338
|
+
|
|
8120
8339
|
|
|
8121
8340
|
class PredictionGrid:
|
|
8122
8341
|
"""
|
|
@@ -8862,55 +9081,6 @@ class PredictionGrid:
|
|
|
8862
9081
|
else:
|
|
8863
9082
|
return fig
|
|
8864
9083
|
|
|
8865
|
-
def get_embedding_space_columns(self, session_id: str) -> Dict[str, Any]:
|
|
8866
|
-
"""
|
|
8867
|
-
Get column names and types from the embedding space.
|
|
8868
|
-
|
|
8869
|
-
Tries to get from model_card.json first (if training completed),
|
|
8870
|
-
otherwise falls back to loading the embedding space directly.
|
|
8871
|
-
|
|
8872
|
-
Args:
|
|
8873
|
-
session_id: Session ID with trained embedding space
|
|
8874
|
-
|
|
8875
|
-
Returns:
|
|
8876
|
-
Dictionary with:
|
|
8877
|
-
- column_names: List of column names
|
|
8878
|
-
- column_types: Dict mapping column names to types (scalar, set, free_string, etc.)
|
|
8879
|
-
- num_columns: Total number of columns
|
|
8880
|
-
|
|
8881
|
-
Example:
|
|
8882
|
-
>>> columns = client.get_embedding_space_columns(session_id)
|
|
8883
|
-
>>> print(f"Columns: {columns['column_names']}")
|
|
8884
|
-
>>> print(f"Types: {columns['column_types']}")
|
|
8885
|
-
"""
|
|
8886
|
-
# Try model_card first (if training completed)
|
|
8887
|
-
try:
|
|
8888
|
-
model_card = self.get_model_card(session_id)
|
|
8889
|
-
|
|
8890
|
-
# Extract column names from training_dataset.feature_names
|
|
8891
|
-
training_dataset = model_card.get('training_dataset', {})
|
|
8892
|
-
column_names = training_dataset.get('feature_names', [])
|
|
8893
|
-
|
|
8894
|
-
# Extract column types from feature_inventory
|
|
8895
|
-
feature_inventory = model_card.get('feature_inventory', {})
|
|
8896
|
-
column_types = {}
|
|
8897
|
-
for feature_name, feature_info in feature_inventory.items():
|
|
8898
|
-
if isinstance(feature_info, dict):
|
|
8899
|
-
column_types[feature_name] = feature_info.get('type', 'unknown')
|
|
8900
|
-
|
|
8901
|
-
if column_names:
|
|
8902
|
-
return {
|
|
8903
|
-
"column_names": column_names,
|
|
8904
|
-
"column_types": column_types,
|
|
8905
|
-
"num_columns": len(column_names)
|
|
8906
|
-
}
|
|
8907
|
-
except Exception:
|
|
8908
|
-
# Model card doesn't exist yet, fall back to direct endpoint
|
|
8909
|
-
pass
|
|
8910
|
-
|
|
8911
|
-
# Fallback: load embedding space directly
|
|
8912
|
-
return self._get_json(f"/compute/session/{session_id}/columns")
|
|
8913
|
-
|
|
8914
9084
|
def get_predictor_schema(self, session_id: str, predictor_index: int = 0) -> Dict[str, Any]:
|
|
8915
9085
|
"""
|
|
8916
9086
|
Get predictor schema/metadata for validating input data locally.
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
featrixsphere/__init__.py,sha256=tPVhFLHujHMC0bnsnjggdMIb2chNkYbbm0wUX-lwWYY,1888
|
|
2
|
+
featrixsphere/client.py,sha256=YvOB2y8zh4iCMccXQ-4ZsQ8dgmUSQlkLh2zsxIiIoYM,435090
|
|
3
|
+
featrixsphere-0.2.5183.dist-info/METADATA,sha256=BEVBhDxyQvjFfDWyzSSKTZLwybnQFyPYU6dk_cxB5CM,16232
|
|
4
|
+
featrixsphere-0.2.5183.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
5
|
+
featrixsphere-0.2.5183.dist-info/entry_points.txt,sha256=QreJeYfD_VWvbEqPmMXZ3pqqlFlJ1qZb-NtqnyhEldc,51
|
|
6
|
+
featrixsphere-0.2.5183.dist-info/top_level.txt,sha256=AyN4wjfzlD0hWnDieuEHX0KckphIk_aC73XCG4df5uU,14
|
|
7
|
+
featrixsphere-0.2.5183.dist-info/RECORD,,
|
featrixsphere/test_client.py
DELETED
|
@@ -1,311 +0,0 @@
|
|
|
1
|
-
#!/usr/bin/env python3
|
|
2
|
-
"""
|
|
3
|
-
Tests for FeatrixSphereClient
|
|
4
|
-
|
|
5
|
-
These tests verify basic functionality without requiring a live API server.
|
|
6
|
-
"""
|
|
7
|
-
|
|
8
|
-
import unittest
|
|
9
|
-
from unittest.mock import Mock, patch, MagicMock
|
|
10
|
-
import json
|
|
11
|
-
from pathlib import Path
|
|
12
|
-
import sys
|
|
13
|
-
|
|
14
|
-
# Mock optional dependencies before importing featrixsphere
|
|
15
|
-
import sys
|
|
16
|
-
|
|
17
|
-
# Mock numpy
|
|
18
|
-
try:
|
|
19
|
-
import numpy as np
|
|
20
|
-
except ImportError:
|
|
21
|
-
class MockNumpy:
|
|
22
|
-
class ndarray:
|
|
23
|
-
pass
|
|
24
|
-
def array(self, *args, **kwargs):
|
|
25
|
-
return []
|
|
26
|
-
def __getattr__(self, name):
|
|
27
|
-
return lambda *args, **kwargs: None
|
|
28
|
-
sys.modules['numpy'] = MockNumpy()
|
|
29
|
-
|
|
30
|
-
# Mock matplotlib
|
|
31
|
-
try:
|
|
32
|
-
import matplotlib.pyplot as plt
|
|
33
|
-
except ImportError:
|
|
34
|
-
class MockMatplotlib:
|
|
35
|
-
class Figure:
|
|
36
|
-
pass
|
|
37
|
-
def __getattr__(self, name):
|
|
38
|
-
return lambda *args, **kwargs: None
|
|
39
|
-
sys.modules['matplotlib'] = MockMatplotlib()
|
|
40
|
-
sys.modules['matplotlib.pyplot'] = MockMatplotlib()
|
|
41
|
-
sys.modules['matplotlib.dates'] = MockMatplotlib()
|
|
42
|
-
|
|
43
|
-
# Add parent directory to path to import featrixsphere
|
|
44
|
-
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
45
|
-
|
|
46
|
-
try:
|
|
47
|
-
from featrixsphere import FeatrixSphereClient, SessionInfo, PredictionBatch
|
|
48
|
-
except (ImportError, AttributeError) as e:
|
|
49
|
-
# If import fails, create minimal mocks for basic structure tests
|
|
50
|
-
print(f"⚠️ Warning: Could not fully import featrixsphere: {e}")
|
|
51
|
-
print(" Running minimal structure tests only...")
|
|
52
|
-
|
|
53
|
-
# Create minimal mocks for basic testing
|
|
54
|
-
class MockSession:
|
|
55
|
-
def __init__(self):
|
|
56
|
-
self.headers = {}
|
|
57
|
-
self.timeout = 30
|
|
58
|
-
|
|
59
|
-
class MockFeatrixSphereClient:
|
|
60
|
-
def __init__(self, base_url="http://test.com", **kwargs):
|
|
61
|
-
self.base_url = base_url.rstrip('/')
|
|
62
|
-
self.compute_cluster = kwargs.get('compute_cluster')
|
|
63
|
-
self.session = MockSession()
|
|
64
|
-
self.default_max_retries = 5
|
|
65
|
-
# Set header if compute_cluster provided
|
|
66
|
-
if self.compute_cluster:
|
|
67
|
-
self.session.headers['X-Featrix-Node'] = self.compute_cluster
|
|
68
|
-
|
|
69
|
-
def set_compute_cluster(self, cluster):
|
|
70
|
-
self.compute_cluster = cluster
|
|
71
|
-
if cluster:
|
|
72
|
-
self.session.headers['X-Featrix-Node'] = cluster
|
|
73
|
-
else:
|
|
74
|
-
self.session.headers.pop('X-Featrix-Node', None)
|
|
75
|
-
|
|
76
|
-
def _make_request(self, method, endpoint, **kwargs):
|
|
77
|
-
from unittest.mock import Mock
|
|
78
|
-
response = Mock()
|
|
79
|
-
response.status_code = 200
|
|
80
|
-
response.json.return_value = {}
|
|
81
|
-
return response
|
|
82
|
-
|
|
83
|
-
class MockSessionInfo:
|
|
84
|
-
def __init__(self, session_id, session_type, status, jobs, job_queue_positions, _client=None):
|
|
85
|
-
self.session_id = session_id
|
|
86
|
-
self.session_type = session_type
|
|
87
|
-
self.status = status
|
|
88
|
-
|
|
89
|
-
class MockPredictionBatch:
|
|
90
|
-
def __init__(self, session_id, client, target_column=None):
|
|
91
|
-
self.session_id = session_id
|
|
92
|
-
self.client = client
|
|
93
|
-
self._cache = {}
|
|
94
|
-
self._stats = {'hits': 0, 'misses': 0, 'populated': 0}
|
|
95
|
-
|
|
96
|
-
def _hash_record(self, record):
|
|
97
|
-
import hashlib
|
|
98
|
-
import json
|
|
99
|
-
sorted_items = sorted(record.items())
|
|
100
|
-
record_str = json.dumps(sorted_items, sort_keys=True)
|
|
101
|
-
return hashlib.md5(record_str.encode()).hexdigest()
|
|
102
|
-
|
|
103
|
-
def predict(self, record):
|
|
104
|
-
record_hash = self._hash_record(record)
|
|
105
|
-
if record_hash in self._cache:
|
|
106
|
-
self._stats['hits'] += 1
|
|
107
|
-
return self._cache[record_hash]
|
|
108
|
-
else:
|
|
109
|
-
self._stats['misses'] += 1
|
|
110
|
-
return {
|
|
111
|
-
'cache_miss': True,
|
|
112
|
-
'record': record,
|
|
113
|
-
'suggestion': 'Record not found in batch cache. Add to records list and recreate batch.'
|
|
114
|
-
}
|
|
115
|
-
|
|
116
|
-
def get_stats(self):
|
|
117
|
-
total = self._stats['hits'] + self._stats['misses']
|
|
118
|
-
return {
|
|
119
|
-
'cache_hits': self._stats['hits'],
|
|
120
|
-
'cache_misses': self._stats['misses'],
|
|
121
|
-
'total_requests': total,
|
|
122
|
-
'hit_rate': self._stats['hits'] / total if total > 0 else 0.0
|
|
123
|
-
}
|
|
124
|
-
|
|
125
|
-
FeatrixSphereClient = MockFeatrixSphereClient
|
|
126
|
-
SessionInfo = MockSessionInfo
|
|
127
|
-
PredictionBatch = MockPredictionBatch
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
class TestFeatrixSphereClient(unittest.TestCase):
|
|
131
|
-
"""Test cases for FeatrixSphereClient."""
|
|
132
|
-
|
|
133
|
-
def setUp(self):
|
|
134
|
-
"""Set up test fixtures."""
|
|
135
|
-
self.client = FeatrixSphereClient(base_url="http://test-server.com")
|
|
136
|
-
|
|
137
|
-
def test_client_initialization(self):
|
|
138
|
-
"""Test that client initializes correctly."""
|
|
139
|
-
self.assertEqual(self.client.base_url, "http://test-server.com")
|
|
140
|
-
self.assertIsNotNone(self.client.session)
|
|
141
|
-
self.assertEqual(self.client.default_max_retries, 5)
|
|
142
|
-
|
|
143
|
-
def test_client_with_compute_cluster(self):
|
|
144
|
-
"""Test client initialization with compute cluster."""
|
|
145
|
-
client = FeatrixSphereClient(
|
|
146
|
-
base_url="http://test-server.com",
|
|
147
|
-
compute_cluster="burrito"
|
|
148
|
-
)
|
|
149
|
-
self.assertEqual(client.compute_cluster, "burrito")
|
|
150
|
-
self.assertIn("X-Featrix-Node", client.session.headers)
|
|
151
|
-
self.assertEqual(client.session.headers["X-Featrix-Node"], "burrito")
|
|
152
|
-
|
|
153
|
-
def test_set_compute_cluster(self):
|
|
154
|
-
"""Test setting compute cluster after initialization."""
|
|
155
|
-
self.client.set_compute_cluster("churro")
|
|
156
|
-
self.assertEqual(self.client.compute_cluster, "churro")
|
|
157
|
-
self.assertEqual(self.client.session.headers.get("X-Featrix-Node"), "churro")
|
|
158
|
-
|
|
159
|
-
# Test removing cluster
|
|
160
|
-
self.client.set_compute_cluster(None)
|
|
161
|
-
self.assertIsNone(self.client.compute_cluster)
|
|
162
|
-
self.assertNotIn("X-Featrix-Node", self.client.session.headers)
|
|
163
|
-
|
|
164
|
-
def test_endpoint_auto_prefix(self):
|
|
165
|
-
"""Test that session endpoints get /compute prefix automatically."""
|
|
166
|
-
# Skip if using mocks (client doesn't have full requests functionality)
|
|
167
|
-
if not hasattr(self.client.session, 'get'):
|
|
168
|
-
self.skipTest("Skipping - using mocks without full requests support")
|
|
169
|
-
|
|
170
|
-
with patch.object(self.client.session, 'get') as mock_get:
|
|
171
|
-
mock_response = Mock()
|
|
172
|
-
mock_response.status_code = 200
|
|
173
|
-
mock_response.json.return_value = {}
|
|
174
|
-
mock_get.return_value = mock_response
|
|
175
|
-
|
|
176
|
-
# Should auto-add /compute prefix
|
|
177
|
-
self.client._make_request('GET', '/session/test-123')
|
|
178
|
-
mock_get.assert_called_once()
|
|
179
|
-
call_url = mock_get.call_args[0][0]
|
|
180
|
-
self.assertIn('/compute/session/test-123', call_url)
|
|
181
|
-
|
|
182
|
-
def test_session_info_initialization(self):
|
|
183
|
-
"""Test SessionInfo dataclass initialization."""
|
|
184
|
-
session = SessionInfo(
|
|
185
|
-
session_id="test-123",
|
|
186
|
-
session_type="embedding_space",
|
|
187
|
-
status="complete",
|
|
188
|
-
jobs={},
|
|
189
|
-
job_queue_positions={}
|
|
190
|
-
)
|
|
191
|
-
self.assertEqual(session.session_id, "test-123")
|
|
192
|
-
self.assertEqual(session.session_type, "embedding_space")
|
|
193
|
-
self.assertEqual(session.status, "complete")
|
|
194
|
-
|
|
195
|
-
def test_prediction_batch_hash_record(self):
|
|
196
|
-
"""Test PredictionBatch record hashing."""
|
|
197
|
-
batch = PredictionBatch("test-123", self.client)
|
|
198
|
-
|
|
199
|
-
record1 = {"a": 1, "b": 2}
|
|
200
|
-
record2 = {"b": 2, "a": 1} # Same keys, different order
|
|
201
|
-
record3 = {"a": 1, "b": 3} # Different value
|
|
202
|
-
|
|
203
|
-
hash1 = batch._hash_record(record1)
|
|
204
|
-
hash2 = batch._hash_record(record2)
|
|
205
|
-
hash3 = batch._hash_record(record3)
|
|
206
|
-
|
|
207
|
-
# Same records should hash to same value (order-independent)
|
|
208
|
-
self.assertEqual(hash1, hash2)
|
|
209
|
-
# Different records should hash to different values
|
|
210
|
-
self.assertNotEqual(hash1, hash3)
|
|
211
|
-
|
|
212
|
-
def test_prediction_batch_cache_miss(self):
|
|
213
|
-
"""Test PredictionBatch cache miss behavior."""
|
|
214
|
-
batch = PredictionBatch("test-123", self.client)
|
|
215
|
-
|
|
216
|
-
record = {"feature": "value"}
|
|
217
|
-
result = batch.predict(record)
|
|
218
|
-
|
|
219
|
-
self.assertTrue(result.get('cache_miss'))
|
|
220
|
-
self.assertEqual(result.get('record'), record)
|
|
221
|
-
self.assertIn('suggestion', result)
|
|
222
|
-
|
|
223
|
-
def test_prediction_batch_stats(self):
|
|
224
|
-
"""Test PredictionBatch statistics tracking."""
|
|
225
|
-
batch = PredictionBatch("test-123", self.client)
|
|
226
|
-
|
|
227
|
-
# Make some predictions (cache misses)
|
|
228
|
-
batch.predict({"a": 1})
|
|
229
|
-
batch.predict({"b": 2})
|
|
230
|
-
|
|
231
|
-
stats = batch.get_stats()
|
|
232
|
-
self.assertEqual(stats['cache_misses'], 2)
|
|
233
|
-
self.assertEqual(stats['cache_hits'], 0)
|
|
234
|
-
self.assertEqual(stats['total_requests'], 2)
|
|
235
|
-
self.assertEqual(stats['hit_rate'], 0.0)
|
|
236
|
-
|
|
237
|
-
def test_prediction_batch_cache_hit(self):
|
|
238
|
-
"""Test PredictionBatch cache hit behavior."""
|
|
239
|
-
batch = PredictionBatch("test-123", self.client)
|
|
240
|
-
|
|
241
|
-
# Manually populate cache
|
|
242
|
-
record = {"feature": "value"}
|
|
243
|
-
record_hash = batch._hash_record(record)
|
|
244
|
-
batch._cache[record_hash] = {"prediction": "test_result"}
|
|
245
|
-
batch._stats['populated'] = 1
|
|
246
|
-
|
|
247
|
-
# Now predict should hit cache
|
|
248
|
-
result = batch.predict(record)
|
|
249
|
-
self.assertFalse(result.get('cache_miss', False))
|
|
250
|
-
self.assertEqual(result.get('prediction'), "test_result")
|
|
251
|
-
|
|
252
|
-
stats = batch.get_stats()
|
|
253
|
-
self.assertEqual(stats['cache_hits'], 1)
|
|
254
|
-
self.assertEqual(stats['cache_misses'], 0)
|
|
255
|
-
self.assertEqual(stats['hit_rate'], 1.0)
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
class TestClientErrorHandling(unittest.TestCase):
|
|
259
|
-
"""Test error handling in FeatrixSphereClient."""
|
|
260
|
-
|
|
261
|
-
def setUp(self):
|
|
262
|
-
"""Set up test fixtures."""
|
|
263
|
-
self.client = FeatrixSphereClient(base_url="http://test-server.com")
|
|
264
|
-
|
|
265
|
-
def test_make_request_retry_on_500(self):
|
|
266
|
-
"""Test that 500 errors trigger retries."""
|
|
267
|
-
# Skip if using mocks (client doesn't have full requests functionality)
|
|
268
|
-
if not hasattr(self.client.session, 'get'):
|
|
269
|
-
self.skipTest("Skipping - using mocks without full requests support")
|
|
270
|
-
|
|
271
|
-
with patch.object(self.client.session, 'get') as mock_get:
|
|
272
|
-
# First call returns 500, second returns 200
|
|
273
|
-
mock_response_500 = Mock()
|
|
274
|
-
mock_response_500.status_code = 500
|
|
275
|
-
mock_response_200 = Mock()
|
|
276
|
-
mock_response_200.status_code = 200
|
|
277
|
-
mock_response_200.json.return_value = {}
|
|
278
|
-
mock_get.side_effect = [mock_response_500, mock_response_200]
|
|
279
|
-
|
|
280
|
-
# Should retry and eventually succeed
|
|
281
|
-
response = self.client._make_request('GET', '/test', max_retries=2)
|
|
282
|
-
self.assertEqual(response.status_code, 200)
|
|
283
|
-
self.assertEqual(mock_get.call_count, 2)
|
|
284
|
-
|
|
285
|
-
def test_make_request_timeout(self):
|
|
286
|
-
"""Test timeout handling."""
|
|
287
|
-
# Skip if using mocks (client doesn't have full requests functionality)
|
|
288
|
-
if not hasattr(self.client.session, 'get'):
|
|
289
|
-
self.skipTest("Skipping - using mocks without full requests support")
|
|
290
|
-
|
|
291
|
-
import requests
|
|
292
|
-
with patch.object(self.client.session, 'get') as mock_get:
|
|
293
|
-
mock_get.side_effect = requests.exceptions.Timeout("Request timed out")
|
|
294
|
-
|
|
295
|
-
# Should raise after retries exhausted
|
|
296
|
-
with self.assertRaises(Exception):
|
|
297
|
-
self.client._make_request('GET', '/test', max_retries=1)
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
def run_tests():
|
|
301
|
-
"""Run all tests and return exit code."""
|
|
302
|
-
loader = unittest.TestLoader()
|
|
303
|
-
suite = loader.loadTestsFromModule(sys.modules[__name__])
|
|
304
|
-
runner = unittest.TextTestRunner(verbosity=2)
|
|
305
|
-
result = runner.run(suite)
|
|
306
|
-
return 0 if result.wasSuccessful() else 1
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
if __name__ == '__main__':
|
|
310
|
-
sys.exit(run_tests())
|
|
311
|
-
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
featrixsphere/__init__.py,sha256=UaVSYqUbFAR92BWYcQkfV7lVfnbogtZFbAJTgDdmUQo,1888
|
|
2
|
-
featrixsphere/client.py,sha256=bqRRyumplQmeRDrxyyhYcj_VIOnl_svh-fJ8VUcxBv4,426243
|
|
3
|
-
featrixsphere/test_client.py,sha256=4SiRbib0ms3poK0UpnUv4G0HFQSzidF3Iswo_J2cjLk,11981
|
|
4
|
-
featrixsphere-0.2.4991.dist-info/METADATA,sha256=DOj--UmpmV2TsuBAtsn72NYFRwn6Nipqq71oGj_bAnc,16232
|
|
5
|
-
featrixsphere-0.2.4991.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
6
|
-
featrixsphere-0.2.4991.dist-info/entry_points.txt,sha256=QreJeYfD_VWvbEqPmMXZ3pqqlFlJ1qZb-NtqnyhEldc,51
|
|
7
|
-
featrixsphere-0.2.4991.dist-info/top_level.txt,sha256=AyN4wjfzlD0hWnDieuEHX0KckphIk_aC73XCG4df5uU,14
|
|
8
|
-
featrixsphere-0.2.4991.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|