featrixsphere 0.2.4984__py3-none-any.whl → 0.2.5182__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- featrixsphere/__init__.py +1 -1
- featrixsphere/client.py +179 -111
- {featrixsphere-0.2.4984.dist-info → featrixsphere-0.2.5182.dist-info}/METADATA +1 -1
- featrixsphere-0.2.5182.dist-info/RECORD +7 -0
- featrixsphere/test_client.py +0 -311
- featrixsphere-0.2.4984.dist-info/RECORD +0 -8
- {featrixsphere-0.2.4984.dist-info → featrixsphere-0.2.5182.dist-info}/WHEEL +0 -0
- {featrixsphere-0.2.4984.dist-info → featrixsphere-0.2.5182.dist-info}/entry_points.txt +0 -0
- {featrixsphere-0.2.4984.dist-info → featrixsphere-0.2.5182.dist-info}/top_level.txt +0 -0
featrixsphere/__init__.py
CHANGED
featrixsphere/client.py
CHANGED
|
@@ -877,24 +877,30 @@ class FeatrixSphereClient:
|
|
|
877
877
|
# Other HTTP errors - re-raise
|
|
878
878
|
raise
|
|
879
879
|
|
|
880
|
-
def publish_session(self, session_id: str, org_id: str, name: str) -> Dict[str, Any]:
    """
    Publish a session by moving it to /backplane/backplane1/sphere/PUBLISHED/<org_id>/<name>/<sessionId>.
    Moves both the session file and output directory.

    Args:
        session_id: Session ID to publish
        org_id: Organization ID for subdirectory organization
        name: Name for the published session (creates subdirectory under org_id)

    Returns:
        Response with published_path, output_path, and status

    Raises:
        ValueError: If org_id or name is not a non-empty string, or is not a
            single path component (contains "/" or "\\", or is "." / "..").

    Example:
        ```python
        result = client.publish_session("abc123", org_id="org_123", name="production_model")
        print(f"Published to: {result['published_path']}")
        ```
    """
    # org_id and name are used as directory names under the PUBLISHED tree,
    # so reject values that would add extra nesting levels or escape it.
    for label, value in (("org_id", org_id), ("name", name)):
        if not isinstance(value, str) or not value.strip():
            raise ValueError(f"{label} must be a non-empty string")
        if "/" in value or "\\" in value or value in (".", ".."):
            raise ValueError(f"{label} must be a single path component, got {value!r}")

    data = {
        "org_id": org_id,
        "name": name,
    }
    return self._post_json(f"/compute/session/{session_id}/publish", data)
|
|
899
905
|
|
|
900
906
|
def deprecate_session(self, session_id: str, warning_message: str, expiration_date: str) -> Dict[str, Any]:
|
|
@@ -1241,7 +1247,18 @@ class FeatrixSphereClient:
|
|
|
1241
1247
|
} if show_live_training_movie else None
|
|
1242
1248
|
|
|
1243
1249
|
while time.time() - start_time < max_wait_time:
|
|
1244
|
-
|
|
1250
|
+
try:
|
|
1251
|
+
session_info = self.get_session_status(session_id)
|
|
1252
|
+
except KeyboardInterrupt:
|
|
1253
|
+
print("\n\n⚠️ Interrupted by user (Ctrl+C)")
|
|
1254
|
+
print(f" Session {session_id} status check interrupted")
|
|
1255
|
+
print(" Returning current session status...")
|
|
1256
|
+
# Return last known status or get it one more time if possible
|
|
1257
|
+
try:
|
|
1258
|
+
return self.get_session_status(session_id)
|
|
1259
|
+
except:
|
|
1260
|
+
# If we can't get status, return a basic SessionInfo
|
|
1261
|
+
return SessionInfo(session_id=session_id, status="unknown", jobs={})
|
|
1245
1262
|
|
|
1246
1263
|
# Clear previous output and show updated status
|
|
1247
1264
|
clear_output(wait=True)
|
|
@@ -1357,7 +1374,19 @@ class FeatrixSphereClient:
|
|
|
1357
1374
|
session_task = progress.add_task(f"[bold green]Session {session_id}", total=100)
|
|
1358
1375
|
|
|
1359
1376
|
while time.time() - start_time < max_wait_time:
|
|
1360
|
-
|
|
1377
|
+
try:
|
|
1378
|
+
session_info = self.get_session_status(session_id)
|
|
1379
|
+
except KeyboardInterrupt:
|
|
1380
|
+
progress.console.print("\n\n[bold yellow]⚠️ Interrupted by user (Ctrl+C)[/bold yellow]")
|
|
1381
|
+
progress.console.print(f" Session {session_id} status check interrupted")
|
|
1382
|
+
progress.console.print(" Returning current session status...")
|
|
1383
|
+
# Return last known status or get it one more time if possible
|
|
1384
|
+
try:
|
|
1385
|
+
return self.get_session_status(session_id)
|
|
1386
|
+
except:
|
|
1387
|
+
# If we can't get status, return a basic SessionInfo
|
|
1388
|
+
from featrixsphere.client import SessionInfo
|
|
1389
|
+
return SessionInfo(session_id=session_id, status="unknown", jobs={})
|
|
1361
1390
|
|
|
1362
1391
|
# Update session progress
|
|
1363
1392
|
elapsed = time.time() - start_time
|
|
@@ -1426,7 +1455,19 @@ class FeatrixSphereClient:
|
|
|
1426
1455
|
jobs_appeared = False
|
|
1427
1456
|
|
|
1428
1457
|
while time.time() - initial_wait_start < initial_wait_timeout:
|
|
1429
|
-
|
|
1458
|
+
try:
|
|
1459
|
+
session_info = self.get_session_status(session_id)
|
|
1460
|
+
except KeyboardInterrupt:
|
|
1461
|
+
logger.info("\n⚠️ Interrupted by user (Ctrl+C)")
|
|
1462
|
+
logger.info(f" Session {session_id} status check interrupted")
|
|
1463
|
+
logger.info(" Returning current session status...")
|
|
1464
|
+
# Return last known status or get it one more time if possible
|
|
1465
|
+
try:
|
|
1466
|
+
return self.get_session_status(session_id)
|
|
1467
|
+
except:
|
|
1468
|
+
# If we can't get status, return a basic SessionInfo
|
|
1469
|
+
return SessionInfo(session_id=session_id, status="unknown", jobs={})
|
|
1470
|
+
|
|
1430
1471
|
if session_info.jobs:
|
|
1431
1472
|
jobs_appeared = True
|
|
1432
1473
|
break
|
|
@@ -1434,7 +1475,20 @@ class FeatrixSphereClient:
|
|
|
1434
1475
|
|
|
1435
1476
|
# Main monitoring loop
|
|
1436
1477
|
while time.time() - start_time < max_wait_time:
|
|
1437
|
-
|
|
1478
|
+
try:
|
|
1479
|
+
session_info = self.get_session_status(session_id)
|
|
1480
|
+
except KeyboardInterrupt:
|
|
1481
|
+
logger.info("\n⚠️ Interrupted by user (Ctrl+C)")
|
|
1482
|
+
logger.info(f" Session {session_id} status check interrupted")
|
|
1483
|
+
logger.info(" Returning current session status...")
|
|
1484
|
+
# Return last known status or get it one more time if possible
|
|
1485
|
+
try:
|
|
1486
|
+
return self.get_session_status(session_id)
|
|
1487
|
+
except:
|
|
1488
|
+
# If we can't get status, return a basic SessionInfo
|
|
1489
|
+
from featrixsphere.client import SessionInfo
|
|
1490
|
+
return SessionInfo(session_id=session_id, status="unknown", jobs={})
|
|
1491
|
+
|
|
1438
1492
|
elapsed = time.time() - start_time
|
|
1439
1493
|
|
|
1440
1494
|
# Call the callback with current status
|
|
@@ -1471,7 +1525,19 @@ class FeatrixSphereClient:
|
|
|
1471
1525
|
last_num_lines = 0
|
|
1472
1526
|
|
|
1473
1527
|
while time.time() - start_time < max_wait_time:
|
|
1474
|
-
|
|
1528
|
+
try:
|
|
1529
|
+
session_info = self.get_session_status(session_id)
|
|
1530
|
+
except KeyboardInterrupt:
|
|
1531
|
+
print("\n\n⚠️ Interrupted by user (Ctrl+C)")
|
|
1532
|
+
print(f" Session {session_id} status check interrupted")
|
|
1533
|
+
print(" Returning current session status...")
|
|
1534
|
+
# Return last known status or get it one more time if possible
|
|
1535
|
+
try:
|
|
1536
|
+
return self.get_session_status(session_id)
|
|
1537
|
+
except:
|
|
1538
|
+
# If we can't get status, return a basic SessionInfo
|
|
1539
|
+
from featrixsphere.client import SessionInfo
|
|
1540
|
+
return SessionInfo(session_id=session_id, status="unknown", jobs={})
|
|
1475
1541
|
|
|
1476
1542
|
# Clear previous lines if terminal supports it
|
|
1477
1543
|
if sys.stdout.isatty() and last_num_lines > 0:
|
|
@@ -2866,6 +2932,110 @@ class FeatrixSphereClient:
|
|
|
2866
2932
|
response_data = self._post_json(f"/session/{session_id}/predict", request_payload, max_retries=max_retries)
|
|
2867
2933
|
return response_data
|
|
2868
2934
|
|
|
2935
|
+
def explain(self, session_id: str, record, class_idx: int = None,
            target_column: str = None, predictor_id: str = None,
            record_b: Dict[str, Any] = None,
            max_retries: int = None) -> Dict[str, Any]:
    """
    Explain a prediction using gradient attribution.

    Supports multiple modes:
    - explain(record): Explain a single row
    - explain(record, record_b=other_record): Compare two rows
    - explain([record1, record2, ...]): Explain multiple rows (a tuple also works)

    Returns what matters to Featrix in the given row(s):
    - Which features mattered for this prediction
    - Which relationships mattered for this prediction

    Args:
        session_id: ID of session with trained predictor
        record: Record dictionary (without target column), or list/tuple of records
        class_idx: Target class index for attribution (default: predicted class)
        target_column: Specific target column predictor to use (optional)
        predictor_id: Specific predictor ID to use (optional)
        record_b: Optional second record for comparison
        max_retries: Number of retries for errors (default: uses client default)

    Returns:
        For single record:
            Dictionary with:
            - feature_scores: {col_name: score} - gradient norm per feature
            - pair_scores: {(i, j): score} - gradient norm per relationship pair
            - target_class_idx: The class index used for attribution
            - logit: The prediction logit

        For two records (record_b provided):
            Dictionary with:
            - record_a: Explanation for first record
            - record_b: Explanation for second record
            - difference: Difference in feature_scores and pair_scores

        For list of records:
            Dictionary with:
            - explanations: List of explanation dictionaries, one per record

        Wherever a ``pair_scores`` dict appears (top level, each entry of
        ``explanations``, ``record_a``, ``record_b``, ``difference``), string
        keys of the form "i_j" are converted to ``(i, j)`` int tuples. Keys
        that do not parse are left unchanged. Sections missing from the
        response are skipped rather than raising KeyError.
    """
    def clean(rec):
        # Convert numpy values, then replace NaN/Inf with nulls so the payload is valid JSON.
        return self.replace_nans_with_nulls(self._clean_numpy_values(rec))

    # A tuple of records is sent as a list (JSON has no tuple type).
    if isinstance(record, (list, tuple)):
        cleaned_record = [clean(r) for r in record]
    else:
        cleaned_record = clean(record)

    # Build request payload; optional fields are only sent when set.
    request_payload = {
        "query_record": cleaned_record,
    }
    if class_idx is not None:
        request_payload["class_idx"] = class_idx
    if target_column:
        request_payload["target_column"] = target_column
    if predictor_id:
        request_payload["predictor_id"] = predictor_id
    if record_b is not None:
        request_payload["query_record_b"] = clean(record_b)

    response_data = self._post_json(f"/session/{session_id}/explain", request_payload, max_retries=max_retries)

    def convert_pair_scores(ps_dict):
        """Turn "i_j" string keys into (i, j) int tuples; keep any other key as-is."""
        if not isinstance(ps_dict, dict):
            return ps_dict
        result = {}
        for key, score in ps_dict.items():
            new_key = key
            if isinstance(key, str):
                parts = key.split("_")
                if len(parts) == 2:
                    try:
                        new_key = (int(parts[0]), int(parts[1]))
                    except ValueError:
                        pass
            result[new_key] = score
        return result

    def convert_section(section):
        """Convert pair_scores in place when section is a dict that carries them."""
        if isinstance(section, dict) and "pair_scores" in section:
            section["pair_scores"] = convert_pair_scores(section["pair_scores"])

    if not isinstance(response_data, dict):
        return response_data

    # The response shape depends on the mode. Visit every place pair_scores can
    # appear, tolerating missing or partial sections.
    convert_section(response_data)
    explanations = response_data.get("explanations")
    if isinstance(explanations, list):
        for expl in explanations:
            convert_section(expl)
    for name in ("record_a", "record_b", "difference"):
        convert_section(response_data.get(name))

    return response_data
|
|
3038
|
+
|
|
2869
3039
|
def get_training_metrics(self, session_id: str, max_retries: int = None) -> Dict[str, Any]:
|
|
2870
3040
|
"""
|
|
2871
3041
|
Get training metrics for a session's single predictor.
|
|
@@ -3985,108 +4155,6 @@ class FeatrixSphereClient:
|
|
|
3985
4155
|
|
|
3986
4156
|
return predictor_id
|
|
3987
4157
|
|
|
3988
|
-
def _resolve_predictor_id(self, session_id: str, predictor_id: str = None, target_column: str = None, debug: bool = False) -> Dict[str, Any]:
    """
    Map a predictor_id or target_column onto a concrete predictor description.

    Lookup precedence:
    1. An explicit ``predictor_id`` (a key of the available-predictors mapping).
    2. ``target_column``, matched against each predictor's ``target_column``.
    3. With neither given, the session's only predictor.

    Args:
        session_id: Session whose predictors are inspected
        predictor_id: Exact predictor key to look up
        target_column: Target column to match when predictor_id is not given
        debug: Print diagnostic output

    Returns:
        Dict with 'target_column', 'predictor_id', 'path' and 'type'. If the
        session lists no predictors at all, the caller's arguments are echoed
        back with 'path'/'type' set to None and 'server_discovery': True, so the
        server can locate the predictor itself.

    Raises:
        ValueError: If the requested predictor does not exist, or the choice is
            ambiguous (more than one candidate and no predictor_id).
    """
    predictors = self._get_available_predictors(session_id, debug=debug)

    def describe(pid, info, column):
        # Uniform result shape for every successful lookup.
        return {
            'target_column': column,
            'predictor_id': pid,
            'path': info.get('path'),
            'type': info.get('type'),
        }

    if not predictors:
        # Nothing listed by the models endpoint. Rather than failing here, defer
        # to the server: its /predict endpoint has fallback logic to find
        # checkpoint files even when the session file was not updated.
        if debug:
            print("⚠️ No predictors found via models endpoint, letting server handle discovery")
        return {
            'target_column': target_column,
            'predictor_id': predictor_id,
            'path': None,
            'type': None,
            'server_discovery': True,
        }

    # Predictor IDs are the mapping keys, so an explicit ID is a direct lookup.
    if predictor_id:
        if predictor_id not in predictors:
            raise ValueError(
                f"Predictor ID '{predictor_id}' not found for session {session_id}. "
                f"Available predictor IDs: {list(predictors.keys())}"
            )
        chosen = predictors[predictor_id]
        return describe(predictor_id, chosen, chosen.get('target_column'))

    # No ID and no column: only unambiguous when exactly one predictor exists.
    if target_column is None:
        if len(predictors) != 1:
            columns = list(set(p.get('target_column') for p in predictors.values()))
            raise ValueError(
                f"Multiple predictors found for session {session_id} with target columns: {columns}. "
                "Please specify predictor_id parameter for precise selection."
            )
        only_id = next(iter(predictors))
        only_info = predictors[only_id]
        return describe(only_id, only_info, only_info.get('target_column'))

    # Several predictors may share a target column; it must match exactly one.
    matches = [
        (pid, info)
        for pid, info in predictors.items()
        if info.get('target_column') == target_column
    ]
    if not matches:
        columns = list(set(p.get('target_column') for p in predictors.values()))
        raise ValueError(
            f"No trained predictor found for target column '{target_column}' in session {session_id}. "
            f"Available target columns: {columns}"
        )
    if len(matches) > 1:
        matching_ids = [pid for pid, _ in matches]
        raise ValueError(
            f"Multiple predictors found for target column '{target_column}' in session {session_id}: {matching_ids}. "
            "Please specify predictor_id parameter for precise selection."
        )
    match_id, match_info = matches[0]
    return describe(match_id, match_info, target_column)
|
|
4089
|
-
|
|
4090
4158
|
def list_predictors(self, session_id: str, verbose: bool = True, debug: bool = False) -> Dict[str, Dict[str, Any]]:
|
|
4091
4159
|
"""
|
|
4092
4160
|
List all available predictors in a session and their target columns.
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
featrixsphere/__init__.py,sha256=IsFr56pqNLZdE9VfJJgYXY7FiHx2WW29eoUB3urP_-g,1888
|
|
2
|
+
featrixsphere/client.py,sha256=8rcft0KMNkW3pUQPsk70ZudWJB_ndCPQBj_MctMf7a0,435094
|
|
3
|
+
featrixsphere-0.2.5182.dist-info/METADATA,sha256=TEPYBNmlbLMAkQsd5ri0N2qMdjZogp94erecVxC48EQ,16232
|
|
4
|
+
featrixsphere-0.2.5182.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
5
|
+
featrixsphere-0.2.5182.dist-info/entry_points.txt,sha256=QreJeYfD_VWvbEqPmMXZ3pqqlFlJ1qZb-NtqnyhEldc,51
|
|
6
|
+
featrixsphere-0.2.5182.dist-info/top_level.txt,sha256=AyN4wjfzlD0hWnDieuEHX0KckphIk_aC73XCG4df5uU,14
|
|
7
|
+
featrixsphere-0.2.5182.dist-info/RECORD,,
|
featrixsphere/test_client.py
DELETED
|
@@ -1,311 +0,0 @@
|
|
|
1
|
-
#!/usr/bin/env python3
|
|
2
|
-
"""
|
|
3
|
-
Tests for FeatrixSphereClient
|
|
4
|
-
|
|
5
|
-
These tests verify basic functionality without requiring a live API server.
|
|
6
|
-
"""
|
|
7
|
-
|
|
8
|
-
import unittest
|
|
9
|
-
from unittest.mock import Mock, patch, MagicMock
|
|
10
|
-
import json
|
|
11
|
-
from pathlib import Path
|
|
12
|
-
import sys
|
|
13
|
-
|
|
14
|
-
# Mock optional dependencies before importing featrixsphere
import sys

# numpy: if it is not installed, register a permissive stand-in object under
# the module name so that importing featrixsphere does not fail.
try:
    import numpy as np
except ImportError:
    class MockNumpy:
        """numpy stand-in: ``ndarray`` type, ``array()`` -> [], other attributes -> no-op callables."""

        class ndarray:
            pass

        def array(self, *args, **kwargs):
            return []

        def __getattr__(self, name):
            return lambda *args, **kwargs: None

    sys.modules['numpy'] = MockNumpy()

# matplotlib: same approach. Each module name gets its own stub instance.
try:
    import matplotlib.pyplot as plt
except ImportError:
    class MockMatplotlib:
        """matplotlib stand-in: ``Figure`` type, any other attribute -> no-op callable."""

        class Figure:
            pass

        def __getattr__(self, name):
            return lambda *args, **kwargs: None

    for _mpl_module in ('matplotlib', 'matplotlib.pyplot', 'matplotlib.dates'):
        sys.modules[_mpl_module] = MockMatplotlib()

# Put the directory that contains this package directory (two levels above
# this file) first on sys.path so featrixsphere imports from the checkout.
_project_root = Path(__file__).parent.parent
sys.path.insert(0, str(_project_root))
|
|
45
|
-
|
|
46
|
-
try:
    from featrixsphere import FeatrixSphereClient, SessionInfo, PredictionBatch
except (ImportError, AttributeError) as import_error:
    # The real package could not be imported. Substitute lightweight stand-ins
    # so the structural tests below still have something to exercise.
    print(f"⚠️ Warning: Could not fully import featrixsphere: {import_error}")
    print(" Running minimal structure tests only...")

    class MockSession:
        """requests.Session stand-in holding only headers and a timeout."""

        def __init__(self):
            self.headers = {}
            self.timeout = 30

    class MockFeatrixSphereClient:
        """Offline client stand-in exposing the attributes the tests inspect."""

        def __init__(self, base_url="http://test.com", **kwargs):
            self.base_url = base_url.rstrip('/')
            self.session = MockSession()
            self.default_max_retries = 5
            # Construction and later updates share one code path for the routing header.
            self.set_compute_cluster(kwargs.get('compute_cluster'))

        def set_compute_cluster(self, cluster):
            self.compute_cluster = cluster
            headers = self.session.headers
            if not cluster:
                headers.pop('X-Featrix-Node', None)
            else:
                headers['X-Featrix-Node'] = cluster

        def _make_request(self, method, endpoint, **kwargs):
            # Every request "succeeds" with an empty JSON body.
            from unittest.mock import Mock
            fake_response = Mock(status_code=200)
            fake_response.json.return_value = {}
            return fake_response

    class MockSessionInfo:
        """SessionInfo stand-in; jobs/job_queue_positions/_client are accepted but not stored."""

        def __init__(self, session_id, session_type, status, jobs, job_queue_positions, _client=None):
            self.session_id, self.session_type, self.status = session_id, session_type, status

    class MockPredictionBatch:
        """PredictionBatch stand-in: a hash-keyed result cache with hit/miss counters."""

        def __init__(self, session_id, client, target_column=None):
            self.session_id = session_id
            self.client = client
            self._cache = {}
            self._stats = dict(hits=0, misses=0, populated=0)

        def _hash_record(self, record):
            # Sorting items makes the digest independent of key order.
            import hashlib
            import json
            canonical = json.dumps(sorted(record.items()), sort_keys=True)
            return hashlib.md5(canonical.encode()).hexdigest()

        def predict(self, record):
            digest = self._hash_record(record)
            if digest not in self._cache:
                self._stats['misses'] += 1
                return {
                    'cache_miss': True,
                    'record': record,
                    'suggestion': 'Record not found in batch cache. Add to records list and recreate batch.'
                }
            self._stats['hits'] += 1
            return self._cache[digest]

        def get_stats(self):
            hits, misses = self._stats['hits'], self._stats['misses']
            total = hits + misses
            return {
                'cache_hits': hits,
                'cache_misses': misses,
                'total_requests': total,
                'hit_rate': hits / total if total > 0 else 0.0,
            }

    FeatrixSphereClient = MockFeatrixSphereClient
    SessionInfo = MockSessionInfo
    PredictionBatch = MockPredictionBatch
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
class TestFeatrixSphereClient(unittest.TestCase):
    """Client construction, compute-cluster routing and PredictionBatch caching."""

    def setUp(self):
        """Give every test a fresh client aimed at a dummy server."""
        self.client = FeatrixSphereClient(base_url="http://test-server.com")

    def test_client_initialization(self):
        """A new client keeps its base URL, owns a session and defaults to 5 retries."""
        client = self.client
        self.assertEqual(client.base_url, "http://test-server.com")
        self.assertIsNotNone(client.session)
        self.assertEqual(client.default_max_retries, 5)

    def test_client_with_compute_cluster(self):
        """compute_cluster given at construction becomes the X-Featrix-Node header."""
        client = FeatrixSphereClient(base_url="http://test-server.com", compute_cluster="burrito")
        headers = client.session.headers
        self.assertEqual(client.compute_cluster, "burrito")
        self.assertIn("X-Featrix-Node", headers)
        self.assertEqual(headers["X-Featrix-Node"], "burrito")

    def test_set_compute_cluster(self):
        """set_compute_cluster updates the header, and clearing it removes the header."""
        client = self.client
        client.set_compute_cluster("churro")
        self.assertEqual(client.compute_cluster, "churro")
        self.assertEqual(client.session.headers.get("X-Featrix-Node"), "churro")

        # Clearing the cluster must also drop the routing header.
        client.set_compute_cluster(None)
        self.assertIsNone(client.compute_cluster)
        self.assertNotIn("X-Featrix-Node", client.session.headers)

    def test_endpoint_auto_prefix(self):
        """Session endpoints are rewritten under the /compute prefix."""
        if not hasattr(self.client.session, 'get'):
            # The fallback mock session has no HTTP verbs to patch.
            self.skipTest("Skipping - using mocks without full requests support")

        ok_response = Mock()
        ok_response.status_code = 200
        ok_response.json.return_value = {}

        with patch.object(self.client.session, 'get', return_value=ok_response) as mock_get:
            self.client._make_request('GET', '/session/test-123')
            mock_get.assert_called_once()
            requested_url = mock_get.call_args[0][0]
            self.assertIn('/compute/session/test-123', requested_url)

    def test_session_info_initialization(self):
        """SessionInfo exposes the identifying fields it was built with."""
        info = SessionInfo(
            session_id="test-123",
            session_type="embedding_space",
            status="complete",
            jobs={},
            job_queue_positions={},
        )
        expectations = (
            ("session_id", "test-123"),
            ("session_type", "embedding_space"),
            ("status", "complete"),
        )
        for attribute, expected in expectations:
            self.assertEqual(getattr(info, attribute), expected)

    def test_prediction_batch_hash_record(self):
        """Record hashes ignore key order but reflect value changes."""
        batch = PredictionBatch("test-123", self.client)
        # Baseline, the same content with keys reordered, and a changed value.
        baseline, reordered, changed = (
            batch._hash_record(rec)
            for rec in ({"a": 1, "b": 2}, {"b": 2, "a": 1}, {"a": 1, "b": 3})
        )
        self.assertEqual(baseline, reordered)
        self.assertNotEqual(baseline, changed)

    def test_prediction_batch_cache_miss(self):
        """An unknown record yields a cache-miss payload that echoes the record."""
        batch = PredictionBatch("test-123", self.client)
        record = {"feature": "value"}
        outcome = batch.predict(record)

        self.assertTrue(outcome.get('cache_miss'))
        self.assertEqual(outcome.get('record'), record)
        self.assertIn('suggestion', outcome)

    def test_prediction_batch_stats(self):
        """Two misses produce matching counters and a zero hit rate."""
        batch = PredictionBatch("test-123", self.client)
        for rec in ({"a": 1}, {"b": 2}):
            batch.predict(rec)

        stats = batch.get_stats()
        expected = {'cache_misses': 2, 'cache_hits': 0, 'total_requests': 2, 'hit_rate': 0.0}
        for key, value in expected.items():
            self.assertEqual(stats[key], value)

    def test_prediction_batch_cache_hit(self):
        """A pre-populated cache entry is returned and counted as a hit."""
        batch = PredictionBatch("test-123", self.client)
        record = {"feature": "value"}
        batch._cache[batch._hash_record(record)] = {"prediction": "test_result"}
        batch._stats['populated'] = 1

        outcome = batch.predict(record)
        self.assertFalse(outcome.get('cache_miss', False))
        self.assertEqual(outcome.get('prediction'), "test_result")

        stats = batch.get_stats()
        expected = {'cache_hits': 1, 'cache_misses': 0, 'hit_rate': 1.0}
        for key, value in expected.items():
            self.assertEqual(stats[key], value)
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
class TestClientErrorHandling(unittest.TestCase):
    """Retry and timeout behaviour of FeatrixSphereClient._make_request."""

    def setUp(self):
        """Give every test a fresh client aimed at a dummy server."""
        self.client = FeatrixSphereClient(base_url="http://test-server.com")

    def test_make_request_retry_on_500(self):
        """A 500 followed by a 200 is retried once and then succeeds."""
        if not hasattr(self.client.session, 'get'):
            # The fallback mock session has no HTTP verbs to patch.
            self.skipTest("Skipping - using mocks without full requests support")

        server_error = Mock(status_code=500)
        success = Mock(status_code=200)
        success.json.return_value = {}

        with patch.object(self.client.session, 'get', side_effect=[server_error, success]) as mock_get:
            response = self.client._make_request('GET', '/test', max_retries=2)
            self.assertEqual(response.status_code, 200)
            self.assertEqual(mock_get.call_count, 2)

    def test_make_request_timeout(self):
        """Persistent timeouts surface as an exception once retries are exhausted."""
        if not hasattr(self.client.session, 'get'):
            self.skipTest("Skipping - using mocks without full requests support")

        import requests
        timeout = requests.exceptions.Timeout("Request timed out")
        with patch.object(self.client.session, 'get', side_effect=timeout):
            with self.assertRaises(Exception):
                self.client._make_request('GET', '/test', max_retries=1)
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
def run_tests():
    """
    Run every TestCase defined in this module with verbose output.

    Returns:
        Process exit code: 0 when the run succeeds, 1 otherwise.
    """
    this_module = sys.modules[__name__]
    suite = unittest.TestLoader().loadTestsFromModule(this_module)
    outcome = unittest.TextTestRunner(verbosity=2).run(suite)
    if outcome.wasSuccessful():
        return 0
    return 1


if __name__ == '__main__':
    sys.exit(run_tests())
|
|
311
|
-
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
featrixsphere/__init__.py,sha256=IU_WDJxGy4vfrUELv_Y381Oj_bzfD4GTcbgnqrIOigQ,1888
|
|
2
|
-
featrixsphere/client.py,sha256=kWGR7cYH0IDWNZDX5w8yPP9-2wA-3KfXxFlYl6w01wE,431295
|
|
3
|
-
featrixsphere/test_client.py,sha256=4SiRbib0ms3poK0UpnUv4G0HFQSzidF3Iswo_J2cjLk,11981
|
|
4
|
-
featrixsphere-0.2.4984.dist-info/METADATA,sha256=sRgn36Uyp-gIc-OwlnscbzyMerml0dFm7DPVwX8iBrs,16232
|
|
5
|
-
featrixsphere-0.2.4984.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
6
|
-
featrixsphere-0.2.4984.dist-info/entry_points.txt,sha256=QreJeYfD_VWvbEqPmMXZ3pqqlFlJ1qZb-NtqnyhEldc,51
|
|
7
|
-
featrixsphere-0.2.4984.dist-info/top_level.txt,sha256=AyN4wjfzlD0hWnDieuEHX0KckphIk_aC73XCG4df5uU,14
|
|
8
|
-
featrixsphere-0.2.4984.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|