featrixsphere 0.2.3613__py3-none-any.whl → 0.2.4982__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
featrixsphere/__init__.py CHANGED
@@ -38,7 +38,7 @@ Example:
38
38
  ... labels=['Experiment A', 'Experiment B'])
39
39
  """
40
40
 
41
- __version__ = "0.2.3613"
41
+ __version__ = "0.2.4982"
42
42
  __author__ = "Featrix"
43
43
  __email__ = "support@featrix.com"
44
44
  __license__ = "MIT"
featrixsphere/client.py CHANGED
@@ -11,7 +11,7 @@ import time
11
11
  import requests
12
12
  from pathlib import Path
13
13
  from typing import Dict, Any, Optional, List, Tuple, Union
14
- from dataclasses import dataclass
14
+ from dataclasses import dataclass, field
15
15
  import gzip
16
16
  import os
17
17
  import random
@@ -65,6 +65,7 @@ class SessionInfo:
65
65
  status: str
66
66
  jobs: Dict[str, Any]
67
67
  job_queue_positions: Dict[str, Any]
68
+ job_plan: List[Dict[str, Any]] = field(default_factory=list)
68
69
  _client: Optional['FeatrixSphereClient'] = None
69
70
 
70
71
  def predictors(self) -> List[Dict[str, Any]]:
@@ -82,6 +83,51 @@ class SessionInfo:
82
83
  return list(predictors_dict.values())
83
84
  except Exception:
84
85
  return []
86
+
87
+ def embedding_space_info(self) -> Optional[Dict[str, Any]]:
88
+ """
89
+ Get embedding space information for this session.
90
+
91
+ Returns:
92
+ Dictionary with ES info (dimensions, epochs, etc.) or None if not available
93
+ """
94
+ if not self._client:
95
+ return None
96
+
97
+ try:
98
+ # Get session details from the client
99
+ session_data = self._client._get_json(f"/compute/session/{self.session_id}")
100
+
101
+ es_info = {}
102
+
103
+ # Extract embedding space path
104
+ embedding_space_path = session_data.get('embedding_space')
105
+ if embedding_space_path:
106
+ es_info['embedding_space_path'] = embedding_space_path
107
+
108
+ # Extract model architecture info
109
+ model_info = session_data.get('model_info', {}) or session_data.get('embedding_space', {})
110
+ if isinstance(model_info, dict):
111
+ es_info['d_model'] = model_info.get('d_model') or model_info.get('embedding_dim')
112
+ es_info['dimensions'] = es_info.get('d_model') # Alias for compatibility
113
+ es_info['parameter_count'] = model_info.get('parameter_count') or model_info.get('num_parameters')
114
+ es_info['layer_count'] = model_info.get('layer_count') or model_info.get('num_layers')
115
+
116
+ # Extract training statistics
117
+ training_stats = session_data.get('training_stats', {}) or session_data.get('stats', {})
118
+ if isinstance(training_stats, dict):
119
+ es_info['epochs'] = training_stats.get('final_epoch') or training_stats.get('epochs_trained') or training_stats.get('epochs')
120
+ es_info['final_loss'] = training_stats.get('final_loss') or training_stats.get('loss')
121
+ es_info['final_val_loss'] = training_stats.get('final_val_loss') or training_stats.get('validation_loss')
122
+ es_info['training_time_seconds'] = training_stats.get('training_time') or training_stats.get('elapsed_seconds')
123
+
124
+ # If we have any info, return it
125
+ if es_info:
126
+ return es_info
127
+
128
+ return None
129
+ except Exception:
130
+ return None
85
131
 
86
132
 
87
133
  class PredictionBatch:
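The new `SessionInfo.embedding_space_info()` helper pulls architecture and training details out of the session payload and returns None when nothing useful is present. A minimal usage sketch, assuming a trained session (the session id below is a placeholder):

```python
# Hedged usage sketch of the new embedding_space_info() helper.
# Keys mirror the extraction logic above; anything the server omits comes back as None.
from featrixsphere.client import FeatrixSphereClient

client = FeatrixSphereClient("https://sphere-api.featrix.com")
session = client.get_session_status("abc123-20240101-120000")  # placeholder session id

es_info = session.embedding_space_info()
if es_info:
    print("d_model:", es_info.get("d_model"))
    print("epochs:", es_info.get("epochs"))
    print("final val loss:", es_info.get("final_val_loss"))
else:
    print("Embedding space info not available yet")
```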
@@ -566,6 +612,11 @@ class FeatrixSphereClient:
566
612
  response = self._make_request("DELETE", endpoint, max_retries=max_retries, **kwargs)
567
613
  return self._unwrap_response(response.json())
568
614
 
615
+ def _post_multipart(self, endpoint: str, data: Dict[str, Any] = None, files: Dict[str, Any] = None, max_retries: int = None, **kwargs) -> Dict[str, Any]:
616
+ """Make a POST request with multipart/form-data (for file uploads) and return JSON response."""
617
+ response = self._make_request("POST", endpoint, data=data, files=files, max_retries=max_retries, **kwargs)
618
+ return self._unwrap_response(response.json())
619
+
569
620
  # =========================================================================
570
621
  # Session Management
571
622
  # =========================================================================
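`_post_multipart` is a thin wrapper over `_make_request` for endpoints that expect `multipart/form-data` rather than JSON. A sketch of the call shape it supports; the endpoint and form fields here are illustrative, not taken from this diff:

```python
# Illustrative only - the endpoint and form fields are hypothetical.
files = {"file": ("data.parquet", file_content, "application/octet-stream")}
data = {"name": "my_session"}
response_data = self._post_multipart("/compute/some-upload-endpoint", data=data, files=files)
```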
@@ -612,6 +663,7 @@ class FeatrixSphereClient:
612
663
  status=response_data.get('status', 'unknown'),
613
664
  jobs={},
614
665
  job_queue_positions={},
666
+ job_plan=[],
615
667
  _client=self
616
668
  )
617
669
 
@@ -635,6 +687,7 @@ class FeatrixSphereClient:
635
687
  session = response_data.get('session', {})
636
688
  jobs = response_data.get('jobs', {})
637
689
  positions = response_data.get('job_queue_positions', {})
690
+ job_plan = session.get('job_plan', [])
638
691
 
639
692
  return SessionInfo(
640
693
  session_id=session.get('session_id', session_id),
@@ -642,6 +695,7 @@ class FeatrixSphereClient:
642
695
  status=session.get('status', 'unknown'),
643
696
  jobs=jobs,
644
697
  job_queue_positions=positions,
698
+ job_plan=job_plan,
645
699
  _client=self
646
700
  )
647
701
 
@@ -1365,6 +1419,20 @@ class FeatrixSphereClient:
1365
1419
 
1366
1420
  start_time = time.time()
1367
1421
 
1422
+ # Initial wait for job dispatch (jobs are dispatched asynchronously after session creation)
1423
+ # Wait up to 10 seconds for jobs to appear before starting main monitoring loop
1424
+ initial_wait_timeout = 10
1425
+ initial_wait_start = time.time()
1426
+ jobs_appeared = False
1427
+
1428
+ while time.time() - initial_wait_start < initial_wait_timeout:
1429
+ session_info = self.get_session_status(session_id)
1430
+ if session_info.jobs:
1431
+ jobs_appeared = True
1432
+ break
1433
+ time.sleep(0.5) # Check every 500ms during initial wait
1434
+
1435
+ # Main monitoring loop
1368
1436
  while time.time() - start_time < max_wait_time:
1369
1437
  session_info = self.get_session_status(session_id)
1370
1438
  elapsed = time.time() - start_time
@@ -1837,7 +1905,8 @@ class FeatrixSphereClient:
1837
1905
  session_type=response_data.get('session_type', 'embedding_space'),
1838
1906
  status=response_data.get('status', 'ready'),
1839
1907
  jobs={},
1840
- job_queue_positions={}
1908
+ job_queue_positions={},
1909
+ job_plan=[]
1841
1910
  )
1842
1911
 
1843
1912
  def fine_tune_embedding_space(
@@ -1967,7 +2036,144 @@ class FeatrixSphereClient:
1967
2036
  session_type=response_data.get('session_type', 'embedding_space_finetune'),
1968
2037
  status=response_data.get('status', 'ready'),
1969
2038
  jobs={},
1970
- job_queue_positions={}
2039
+ job_queue_positions={},
2040
+ job_plan=[]
2041
+ )
2042
+
2043
+ def extend_embedding_space(
2044
+ self,
2045
+ name: str,
2046
+ parent_session_id: str = None,
2047
+ parent_embedding_space_path: str = None,
2048
+ s3_training_dataset: str = None,
2049
+ s3_validation_dataset: str = None,
2050
+ n_epochs: int = None,
2051
+ webhooks: Dict[str, str] = None,
2052
+ user_metadata: Dict[str, Any] = None
2053
+ ) -> SessionInfo:
2054
+ """
2055
+ Extend an existing embedding space with new feature columns.
2056
+
2057
+ This method takes a pre-trained embedding space and extends it with new feature columns
2058
+ from enriched training/validation data. The extended ES preserves existing encoder weights
2059
+ and creates new codecs for the new columns.
2060
+
2061
+ **When to Use Extend vs Fine-Tune:**
2062
+ - **Extend**: When you've added NEW COLUMNS (features) to your dataset
2063
+ - **Fine-Tune**: When you have new rows with the SAME COLUMNS
2064
+
2065
+ **How It Works:**
2066
+ 1. Loads the parent embedding space
2067
+ 2. Identifies new columns in the enriched dataset
2068
+ 3. Creates codecs for the new columns
2069
+ 4. Copies existing encoder weights (preserves learned representations)
2070
+ 5. Fine-tunes for a shorter duration (default: original_epochs / 4)
2071
+ 6. Returns extended embedding space with all columns
2072
+
2073
+ Args:
2074
+ name: Name for the extended embedding space
2075
+ parent_session_id: Session ID of the parent embedding space (optional)
2076
+ parent_embedding_space_path: Direct path to parent embedding space pickle file (optional)
2077
+ s3_training_dataset: S3 URL for enriched training dataset with new columns (must start with 's3://')
2078
+ s3_validation_dataset: S3 URL for enriched validation dataset with new columns (must start with 's3://')
2079
+ n_epochs: Number of epochs for extension training (default: original_epochs / 4)
2080
+ webhooks: Optional dict with webhook configuration keys
2081
+ user_metadata: Optional metadata dict to attach to the session
2082
+
2083
+ Returns:
2084
+ SessionInfo for the newly created extension session
2085
+
2086
+ Raises:
2087
+ ValueError: If S3 URLs are invalid or neither parent identifier is provided
2088
+
2089
+ Example:
2090
+ ```python
2091
+ # Extend an existing embedding space with new feature columns
2092
+ client = FeatrixSphereClient("https://sphere-api.featrix.com")
2093
+
2094
+ # Original ES was trained on: age, income, credit_score
2095
+ # New data includes engineered features: debt_to_income_ratio, age_bin
2096
+ extended = client.extend_embedding_space(
2097
+ name="customer_model_with_features",
2098
+ parent_session_id="abc123-20240101-120000",
2099
+ s3_training_dataset="s3://my-bucket/enriched_training.csv",
2100
+ s3_validation_dataset="s3://my-bucket/enriched_validation.csv",
2101
+ n_epochs=25 # Optional: specify epochs (defaults to original/4)
2102
+ )
2103
+
2104
+ # Wait for extension to complete
2105
+ client.wait_for_session_completion(extended.session_id)
2106
+
2107
+ # The extended ES now includes the new feature columns
2108
+ # Use it for predictions with enriched data
2109
+ result = client.predict(extended.session_id, {
2110
+ "age": 35,
2111
+ "income": 75000,
2112
+ "credit_score": 720,
2113
+ "debt_to_income_ratio": 0.25, # New feature!
2114
+ "age_bin": "30-40" # New feature!
2115
+ })
2116
+ ```
2117
+ """
2118
+ # Validate S3 URLs
2119
+ if s3_training_dataset and not s3_training_dataset.startswith('s3://'):
2120
+ raise ValueError("s3_training_dataset must be a valid S3 URL (s3://...)")
2121
+ if s3_validation_dataset and not s3_validation_dataset.startswith('s3://'):
2122
+ raise ValueError("s3_validation_dataset must be a valid S3 URL (s3://...)")
2123
+
2124
+ # Validate that we have either parent_session_id or parent_embedding_space_path
2125
+ if not parent_session_id and not parent_embedding_space_path:
2126
+ raise ValueError("Either parent_session_id or parent_embedding_space_path must be provided")
2127
+
2128
+ print(f"Extending embedding space '{name}' with new features...")
2129
+ if parent_session_id:
2130
+ print(f" Parent session: {parent_session_id}")
2131
+ if parent_embedding_space_path:
2132
+ print(f" Parent embedding space: {parent_embedding_space_path}")
2133
+ print(f" Enriched training data: {s3_training_dataset}")
2134
+ print(f" Enriched validation data: {s3_validation_dataset}")
2135
+ if n_epochs:
2136
+ print(f" Extension epochs: {n_epochs}")
2137
+ else:
2138
+ print(f" Extension epochs: auto (original/4)")
2139
+
2140
+ data = {
2141
+ "name": name,
2142
+ "s3_file_data_set_training": s3_training_dataset,
2143
+ "s3_file_data_set_validation": s3_validation_dataset
2144
+ }
2145
+
2146
+ if parent_session_id:
2147
+ data["parent_session_id"] = parent_session_id
2148
+ if parent_embedding_space_path:
2149
+ data["parent_embedding_space_path"] = parent_embedding_space_path
2150
+ if n_epochs is not None:
2151
+ data["n_epochs"] = n_epochs
2152
+
2153
+ if webhooks:
2154
+ data['webhooks'] = webhooks
2155
+ if user_metadata:
2156
+ import json
2157
+ data['user_metadata'] = json.dumps(user_metadata)
2158
+ print(f"User metadata: {user_metadata}")
2159
+
2160
+ response_data = self._post_json("/compute/extend-embedding-space", data)
2161
+
2162
+ session_id = response_data.get('session_id')
2163
+ extend_info = response_data.get('extend_es_info', {})
2164
+
2165
+ print(f"Extension session created: {session_id}")
2166
+ if extend_info:
2167
+ print(f" Original epochs: {extend_info.get('original_epochs', 'N/A')}")
2168
+ print(f" Extension epochs: {extend_info.get('extension_epochs', 'N/A')}")
2169
+
2170
+ return SessionInfo(
2171
+ session_id=session_id,
2172
+ session_type=response_data.get('session_type', 'embedding_space_extend'),
2173
+ status=response_data.get('status', 'ready'),
2174
+ jobs={},
2175
+ job_queue_positions={},
2176
+ job_plan=[]
1971
2177
  )
1972
2178
 
1973
2179
  # =========================================================================
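The docstring example above passes `parent_session_id`; `extend_embedding_space` also accepts a direct `parent_embedding_space_path` when the parent session is no longer available. A hedged sketch of that variant (the pickle path is a placeholder):

```python
# Extending from a pickle path instead of a parent session id (path is a placeholder).
extended = client.extend_embedding_space(
    name="customer_model_with_features",
    parent_embedding_space_path="/path/to/parent_embedding_space.pkl",
    s3_training_dataset="s3://my-bucket/enriched_training.csv",
    s3_validation_dataset="s3://my-bucket/enriched_validation.csv",
)
client.wait_for_session_completion(extended.session_id)
```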
@@ -2028,7 +2234,8 @@ class FeatrixSphereClient:
2028
2234
  session_type=response_data.get('session_type', 'sphere'),
2029
2235
  status=response_data.get('status', 'ready'),
2030
2236
  jobs={},
2031
- job_queue_positions={}
2237
+ job_queue_positions={},
2238
+ job_plan=[]
2032
2239
  )
2033
2240
 
2034
2241
  def upload_df_and_create_session(self, df=None, filename: str = "data.csv", file_path: str = None,
@@ -2304,23 +2511,57 @@ class FeatrixSphereClient:
2304
2511
  compression_ratio = (1 - compressed_size / original_size) * 100
2305
2512
  print(f"Converted Parquet to CSV and compressed from {original_size:,} to {compressed_size:,} bytes ({compression_ratio:.1f}% reduction)")
2306
2513
  else:
2307
- # Regular CSV file - read and compress it
2514
+ # Regular CSV file - check size and suggest Parquet for large files
2308
2515
  with open(file_path, 'rb') as f:
2309
2516
  csv_content = f.read()
2310
2517
 
2311
- # Compress the content
2312
- print("Compressing CSV file...")
2313
- compressed_buffer = io.BytesIO()
2314
- with gzip.GzipFile(fileobj=compressed_buffer, mode='wb') as gz:
2315
- gz.write(csv_content)
2316
- file_content = compressed_buffer.getvalue()
2317
- upload_filename = os.path.basename(file_path) + '.gz'
2318
- content_type = 'application/gzip'
2518
+ csv_size_mb = len(csv_content) / (1024 * 1024)
2519
+ CSV_WARNING_THRESHOLD_MB = 1.0 # Warn if CSV > 1MB
2319
2520
 
2320
- original_size = len(csv_content)
2321
- compressed_size = len(file_content)
2322
- compression_ratio = (1 - compressed_size / original_size) * 100
2323
- print(f"Compressed from {original_size:,} to {compressed_size:,} bytes ({compression_ratio:.1f}% reduction)")
2521
+ if csv_size_mb > CSV_WARNING_THRESHOLD_MB:
2522
+ print(f"\n⚠️ Warning: CSV file is {csv_size_mb:.1f} MB")
2523
+ print(f" Parquet format is more efficient for large files (smaller size, faster upload).")
2524
+ print(f" Converting to Parquet format for better performance...")
2525
+
2526
+ # Read CSV as DataFrame
2527
+ csv_df = pd.read_csv(file_path)
2528
+
2529
+ # Convert to Parquet in memory
2530
+ parquet_buffer = io.BytesIO()
2531
+ try:
2532
+ # Try pyarrow first (faster), fallback to fastparquet
2533
+ csv_df.to_parquet(parquet_buffer, index=False, engine='pyarrow')
2534
+ except (ImportError, ValueError):
2535
+ # Fallback to fastparquet or default engine
2536
+ try:
2537
+ csv_df.to_parquet(parquet_buffer, index=False, engine='fastparquet')
2538
+ except (ImportError, ValueError):
2539
+ # Last resort: use default engine
2540
+ csv_df.to_parquet(parquet_buffer, index=False)
2541
+ parquet_content = parquet_buffer.getvalue()
2542
+ parquet_size_mb = len(parquet_content) / (1024 * 1024)
2543
+
2544
+ # Use Parquet instead of compressed CSV
2545
+ file_content = parquet_content
2546
+ upload_filename = os.path.basename(file_path).replace('.csv', '.parquet')
2547
+ content_type = 'application/octet-stream'
2548
+
2549
+ size_reduction = (1 - len(parquet_content) / len(csv_content)) * 100
2550
+ print(f" ✅ Converted to Parquet: {csv_size_mb:.1f} MB → {parquet_size_mb:.1f} MB ({size_reduction:.1f}% reduction)")
2551
+ else:
2552
+ # Small CSV - compress as before
2553
+ print("Compressing CSV file...")
2554
+ compressed_buffer = io.BytesIO()
2555
+ with gzip.GzipFile(fileobj=compressed_buffer, mode='wb') as gz:
2556
+ gz.write(csv_content)
2557
+ file_content = compressed_buffer.getvalue()
2558
+ upload_filename = os.path.basename(file_path) + '.gz'
2559
+ content_type = 'application/gzip'
2560
+
2561
+ original_size = len(csv_content)
2562
+ compressed_size = len(file_content)
2563
+ compression_ratio = (1 - compressed_size / original_size) * 100
2564
+ print(f"Compressed from {original_size:,} to {compressed_size:,} bytes ({compression_ratio:.1f}% reduction)")
2324
2565
 
2325
2566
  # Handle DataFrame input
2326
2567
  else:
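Large CSV uploads are now converted to Parquet in memory, trying `pyarrow` first and falling back to `fastparquet`, so one of those engines needs to be importable for the fast path. A quick caller-side check, written as a sketch that mirrors the preference order in the code above:

```python
from typing import Optional

def parquet_engine_available() -> Optional[str]:
    """Return the first available Parquet engine, matching the upload code's preference order."""
    for engine in ("pyarrow", "fastparquet"):
        try:
            __import__(engine)
            return engine
        except ImportError:
            continue
    return None

print("Parquet engine:", parquet_engine_available() or "none installed")
```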
@@ -2329,29 +2570,31 @@ class FeatrixSphereClient:
2329
2570
 
2330
2571
  print(f"Uploading DataFrame ({len(df)} rows, {len(df.columns)} columns)")
2331
2572
 
2332
- # Clean NaN values in DataFrame before CSV conversion
2573
+ # Clean NaN values in DataFrame before conversion
2333
2574
  # This prevents JSON encoding issues when the server processes the data
2334
2575
  # Use pandas.notna() with where() for compatibility with all pandas versions
2335
2576
  cleaned_df = df.where(pd.notna(df), None) # Replace NaN with None for JSON compatibility
2336
2577
 
2337
- # Convert DataFrame to CSV and compress
2338
- csv_buffer = io.StringIO()
2339
- cleaned_df.to_csv(csv_buffer, index=False)
2340
- csv_data = csv_buffer.getvalue().encode('utf-8')
2341
-
2342
- # Compress the CSV data
2343
- print("Compressing DataFrame...")
2344
- compressed_buffer = io.BytesIO()
2345
- with gzip.GzipFile(fileobj=compressed_buffer, mode='wb') as gz:
2346
- gz.write(csv_data)
2347
- file_content = compressed_buffer.getvalue()
2348
- upload_filename = filename if filename.endswith('.gz') else filename + '.gz'
2349
- content_type = 'application/gzip'
2350
-
2351
- original_size = len(csv_data)
2352
- compressed_size = len(file_content)
2353
- compression_ratio = (1 - compressed_size / original_size) * 100
2354
- print(f"Compressed from {original_size:,} to {compressed_size:,} bytes ({compression_ratio:.1f}% reduction)")
2578
+ # Always use Parquet format for DataFrames (smaller, faster than CSV.gz)
2579
+ print("Converting DataFrame to Parquet format...")
2580
+ parquet_buffer = io.BytesIO()
2581
+ try:
2582
+ # Try pyarrow first (faster), fallback to fastparquet
2583
+ cleaned_df.to_parquet(parquet_buffer, index=False, engine='pyarrow')
2584
+ except (ImportError, ValueError):
2585
+ # Fallback to fastparquet or default engine
2586
+ try:
2587
+ cleaned_df.to_parquet(parquet_buffer, index=False, engine='fastparquet')
2588
+ except (ImportError, ValueError):
2589
+ # Last resort: use default engine
2590
+ cleaned_df.to_parquet(parquet_buffer, index=False)
2591
+
2592
+ file_content = parquet_buffer.getvalue()
2593
+ parquet_size_mb = len(file_content) / (1024 * 1024)
2594
+ upload_filename = filename.replace('.csv', '.parquet') if filename.endswith('.csv') else filename + '.parquet'
2595
+ content_type = 'application/octet-stream'
2596
+
2597
+ print(f"✅ Saved as Parquet: {parquet_size_mb:.2f} MB")
2355
2598
 
2356
2599
  # Upload the compressed file with optional column overrides
2357
2600
  files = {'file': (upload_filename, file_content, content_type)}
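DataFrame uploads now always serialize to Parquet instead of gzipped CSV, and the upload filename is rewritten to end in `.parquet`. A hedged caller-side example (the column data is made up):

```python
import pandas as pd

# Made-up sample data; the client converts the frame to Parquet before uploading.
df = pd.DataFrame({"age": [35, 52], "income": [75000, 61000]})
session = client.upload_df_and_create_session(df=df, filename="customers.csv")
print(session.session_id)
```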
@@ -2398,9 +2641,18 @@ class FeatrixSphereClient:
2398
2641
  file_size_mb = len(file_content) / (1024 * 1024)
2399
2642
  CHUNK_SIZE_MB = 512 # 512 MB chunk size
2400
2643
  CHUNK_SIZE_BYTES = CHUNK_SIZE_MB * 1024 * 1024
2644
+ LARGE_FILE_WARNING_MB = 10 # Warn if file > 10 MB
2645
+
2646
+ if file_size_mb > LARGE_FILE_WARNING_MB:
2647
+ print(f"\n⚠️ Warning: File size ({file_size_mb:.1f} MB) is quite large")
2648
+ print(f" For very large files (>10 MB), consider using S3 uploads:")
2649
+ print(f" 1. Upload your file to S3 (or your cloud storage)")
2650
+ print(f" 2. Generate a signed/private URL with read access")
2651
+ print(f" 3. Contact Featrix support to configure S3-based uploads")
2652
+ print(f" This can be more reliable than direct uploads for large datasets.")
2401
2653
 
2402
2654
  if file_size_mb > CHUNK_SIZE_MB:
2403
- print(f"⚠️ Warning: File size ({file_size_mb:.1f} MB) exceeds {CHUNK_SIZE_MB} MB threshold")
2655
+ print(f"\n⚠️ Warning: File size ({file_size_mb:.1f} MB) exceeds {CHUNK_SIZE_MB} MB threshold")
2404
2656
  print(f" Large uploads may timeout. Consider splitting the data or using smaller batches.")
2405
2657
 
2406
2658
  # Try upload with retry on 504
@@ -2446,7 +2698,8 @@ class FeatrixSphereClient:
2446
2698
  session_type=response_data.get('session_type', 'sphere'),
2447
2699
  status=response_data.get('status', 'ready'),
2448
2700
  jobs={},
2449
- job_queue_positions={}
2701
+ job_queue_positions={},
2702
+ job_plan=[]
2450
2703
  )
2451
2704
 
2452
2705
 
@@ -2559,7 +2812,8 @@ class FeatrixSphereClient:
2559
2812
  # =========================================================================
2560
2813
 
2561
2814
  def predict(self, session_id: str, record: Dict[str, Any], target_column: str = None,
2562
- predictor_id: str = None, max_retries: int = None, queue_batches: bool = False) -> Dict[str, Any]:
2815
+ predictor_id: str = None, best_metric_preference: str = None,
2816
+ max_retries: int = None, queue_batches: bool = False) -> Dict[str, Any]:
2563
2817
  """
2564
2818
  Make a single prediction for a record.
2565
2819
 
@@ -2568,6 +2822,7 @@ class FeatrixSphereClient:
2568
2822
  record: Record dictionary (without target column)
2569
2823
  target_column: Specific target column predictor to use (required if multiple predictors exist and predictor_id not specified)
2570
2824
  predictor_id: Specific predictor ID to use (recommended - more precise than target_column)
2825
+ best_metric_preference: Which metric checkpoint to use: "roc_auc", "pr_auc", or None to use the default checkpoint (default: None)
2571
2826
  max_retries: Number of retries for errors (default: uses client default)
2572
2827
  queue_batches: If True, queue this prediction for batch processing instead of immediate API call
2573
2828
 
@@ -2625,6 +2880,10 @@ class FeatrixSphereClient:
2625
2880
  if resolved_predictor_id:
2626
2881
  request_payload["predictor_id"] = resolved_predictor_id
2627
2882
 
2883
+ # Include best_metric_preference if specified
2884
+ if best_metric_preference:
2885
+ request_payload["best_metric_preference"] = best_metric_preference
2886
+
2628
2887
  response_data = self._post_json(f"/session/{session_id}/predict", request_payload, max_retries=max_retries)
2629
2888
  return response_data
2630
2889
 
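With the preference wired into the request payload, a caller can pick which saved checkpoint serves a single prediction. A hedged example (the record fields are illustrative and the session id is a placeholder):

```python
# Use the ROC-AUC-best checkpoint for this prediction; omit the argument to keep the default.
result = client.predict(
    "abc123-20240101-120000",  # placeholder session id
    {"age": 35, "income": 75000, "credit_score": 720},
    best_metric_preference="roc_auc",
)
print(result)
```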
@@ -2707,7 +2966,7 @@ class FeatrixSphereClient:
2707
2966
  def plot_training_loss(self, session_id: str, figsize: Tuple[int, int] = (12, 8),
2708
2967
  style: str = 'notebook', save_path: Optional[str] = None,
2709
2968
  show_learning_rate: bool = True, smooth: bool = True,
2710
- title: Optional[str] = None) -> 'plt.Figure':
2969
+ title: Optional[str] = None):
2711
2970
  """
2712
2971
  Plot comprehensive training loss curves for a session (both embedding space and single predictor).
2713
2972
 
@@ -2795,7 +3054,7 @@ class FeatrixSphereClient:
2795
3054
 
2796
3055
  def plot_embedding_space_training(self, session_id: str, figsize: Tuple[int, int] = (10, 6),
2797
3056
  style: str = 'notebook', save_path: Optional[str] = None,
2798
- show_mutual_info: bool = False) -> 'plt.Figure':
3057
+ show_mutual_info: bool = False):
2799
3058
  """
2800
3059
  Plot detailed embedding space training metrics.
2801
3060
 
@@ -2871,7 +3130,7 @@ class FeatrixSphereClient:
2871
3130
 
2872
3131
  def plot_single_predictor_training(self, session_id: str, figsize: Tuple[int, int] = (10, 6),
2873
3132
  style: str = 'notebook', save_path: Optional[str] = None,
2874
- show_metrics: bool = True) -> 'plt.Figure':
3133
+ show_metrics: bool = True):
2875
3134
  """
2876
3135
  Plot detailed single predictor training metrics.
2877
3136
 
@@ -2947,7 +3206,7 @@ class FeatrixSphereClient:
2947
3206
 
2948
3207
  def plot_training_comparison(self, session_ids: List[str], labels: Optional[List[str]] = None,
2949
3208
  figsize: Tuple[int, int] = (12, 8), style: str = 'notebook',
2950
- save_path: Optional[str] = None) -> plt.Figure:
3209
+ save_path: Optional[str] = None):
2951
3210
  """
2952
3211
  Compare training curves across multiple sessions.
2953
3212
 
@@ -3200,7 +3459,7 @@ class FeatrixSphereClient:
3200
3459
  def plot_embedding_space_3d(self, session_id: str, sample_size: int = 2000,
3201
3460
  color_by: Optional[str] = None, size_by: Optional[str] = None,
3202
3461
  interactive: bool = True, style: str = 'notebook',
3203
- title: Optional[str] = None, save_path: Optional[str] = None) -> Union[plt.Figure, 'go.Figure']:
3462
+ title: Optional[str] = None, save_path: Optional[str] = None):
3204
3463
  """
3205
3464
  Create interactive 3D visualization of the embedding space.
3206
3465
 
@@ -3276,7 +3535,7 @@ class FeatrixSphereClient:
3276
3535
  style: str = 'notebook', save_path: Optional[str] = None,
3277
3536
  show_embedding_evolution: bool = True,
3278
3537
  show_loss_evolution: bool = True,
3279
- fps: int = 2, notebook_mode: bool = True) -> Union[plt.Figure, 'HTML']:
3538
+ fps: int = 2, notebook_mode: bool = True):
3280
3539
  """
3281
3540
  Create an animated training movie showing loss curves and embedding evolution.
3282
3541
 
@@ -3330,7 +3589,7 @@ class FeatrixSphereClient:
3330
3589
 
3331
3590
  def plot_embedding_evolution(self, session_id: str, epoch_range: Optional[Tuple[int, int]] = None,
3332
3591
  interactive: bool = True, sample_size: int = 1000,
3333
- color_by: Optional[str] = None) -> Union[plt.Figure, 'go.Figure']:
3592
+ color_by: Optional[str] = None):
3334
3593
  """
3335
3594
  Show how embedding space evolves during training across epochs.
3336
3595
 
@@ -3766,7 +4025,18 @@ class FeatrixSphereClient:
3766
4025
  available_predictors = self._get_available_predictors(session_id, debug=debug)
3767
4026
 
3768
4027
  if not available_predictors:
3769
- raise ValueError(f"No trained predictors found for session {session_id}")
4028
+ # Don't fail here - let the server try to find/auto-discover the predictor
4029
+ # The server's /predict endpoint has smart fallback logic to find checkpoint files
4030
+ # even if the session file wasn't properly updated (e.g., training crashed)
4031
+ if debug:
4032
+ print(f"⚠️ No predictors found via models endpoint, letting server handle discovery")
4033
+ return {
4034
+ 'target_column': target_column,
4035
+ 'predictor_id': predictor_id,
4036
+ 'path': None,
4037
+ 'type': None,
4038
+ 'server_discovery': True # Flag that server should auto-discover
4039
+ }
3770
4040
 
3771
4041
  # If predictor_id is provided, find it directly (since it's now the key)
3772
4042
  if predictor_id:
@@ -4644,6 +4914,7 @@ class FeatrixSphereClient:
4644
4914
  status="running",
4645
4915
  jobs={},
4646
4916
  job_queue_positions={},
4917
+ job_plan=[],
4647
4918
  _client=self
4648
4919
  )
4649
4920
 
@@ -5495,6 +5766,7 @@ class FeatrixSphereClient:
5495
5766
  status="running",
5496
5767
  jobs={},
5497
5768
  job_queue_positions={},
5769
+ job_plan=[],
5498
5770
  _client=self
5499
5771
  )
5500
5772
 
@@ -6078,7 +6350,8 @@ class FeatrixSphereClient:
6078
6350
  # =========================================================================
6079
6351
 
6080
6352
  def predict_table(self, session_id: str, table_data: Dict[str, Any],
6081
- target_column: str = None, predictor_id: str = None, max_retries: int = None) -> Dict[str, Any]:
6353
+ target_column: str = None, predictor_id: str = None,
6354
+ best_metric_preference: str = None, max_retries: int = None) -> Dict[str, Any]:
6082
6355
  """
6083
6356
  Make batch predictions using JSON Tables format.
6084
6357
 
@@ -6119,6 +6392,8 @@ class FeatrixSphereClient:
6119
6392
  table_data['target_column'] = target_column
6120
6393
  if predictor_id:
6121
6394
  table_data['predictor_id'] = predictor_id
6395
+ if best_metric_preference:
6396
+ table_data['best_metric_preference'] = best_metric_preference
6122
6397
 
6123
6398
  try:
6124
6399
  response_data = self._post_json(f"/session/{session_id}/predict_table", table_data, max_retries=max_retries)
@@ -6131,7 +6406,8 @@ class FeatrixSphereClient:
6131
6406
  raise
6132
6407
 
6133
6408
  def predict_records(self, session_id: str, records: List[Dict[str, Any]],
6134
- target_column: str = None, predictor_id: str = None, batch_size: int = 2500, use_async: bool = False,
6409
+ target_column: str = None, predictor_id: str = None, best_metric_preference: str = None,
6410
+ batch_size: int = 2500, use_async: bool = False,
6135
6411
  show_progress_bar: bool = True, print_target_column_warning: bool = True) -> Dict[str, Any]:
6136
6412
  """
6137
6413
  Make batch predictions on a list of records with automatic client-side batching.
@@ -6178,7 +6454,8 @@ class FeatrixSphereClient:
6178
6454
  table_data = JSONTablesEncoder.from_records(cleaned_records)
6179
6455
 
6180
6456
  try:
6181
- result = self.predict_table(session_id, table_data)
6457
+ result = self.predict_table(session_id, table_data, target_column=target_column,
6458
+ predictor_id=predictor_id, best_metric_preference=best_metric_preference)
6182
6459
 
6183
6460
  # Check if server returned an async job
6184
6461
  if result.get('async') and result.get('job_id'):
@@ -6215,7 +6492,8 @@ class FeatrixSphereClient:
6215
6492
  table_data = JSONTablesEncoder.from_records(cleaned_records)
6216
6493
 
6217
6494
  try:
6218
- return self.predict_table(session_id, table_data)
6495
+ return self.predict_table(session_id, table_data, target_column=target_column,
6496
+ predictor_id=predictor_id, best_metric_preference=best_metric_preference)
6219
6497
  except Exception as e:
6220
6498
  if "404" in str(e) and "Single predictor not found" in str(e):
6221
6499
  self._raise_predictor_not_found_error(session_id, "predict_records")
@@ -6245,7 +6523,8 @@ class FeatrixSphereClient:
6245
6523
  table_data = JSONTablesEncoder.from_records(chunk_records)
6246
6524
 
6247
6525
  # Make prediction
6248
- chunk_result = self.predict_table(session_id, table_data)
6526
+ chunk_result = self.predict_table(session_id, table_data, target_column=target_column,
6527
+ predictor_id=predictor_id, best_metric_preference=best_metric_preference)
6249
6528
  chunk_predictions = chunk_result.get('predictions', [])
6250
6529
 
6251
6530
  # Adjust row indices to match original dataset
@@ -6606,7 +6885,8 @@ class FeatrixSphereClient:
6606
6885
  print(f"\n⏰ Timeout after {max_wait_time} seconds")
6607
6886
  return {'status': 'timeout', 'message': f'Job did not complete within {max_wait_time} seconds'}
6608
6887
 
6609
- def predict_df(self, session_id: str, df, target_column: str = None, predictor_id: str = None, show_progress_bar: bool = True, print_target_column_warning: bool = True) -> Dict[str, Any]:
6888
+ def predict_df(self, session_id: str, df, target_column: str = None, predictor_id: str = None,
6889
+ best_metric_preference: str = None, show_progress_bar: bool = True, print_target_column_warning: bool = True) -> Dict[str, Any]:
6610
6890
  """
6611
6891
  Make batch predictions on a pandas DataFrame.
6612
6892
 
@@ -6631,7 +6911,8 @@ class FeatrixSphereClient:
6631
6911
  records = df.to_dict(orient='records')
6632
6912
  # Clean NaNs for JSON encoding
6633
6913
  cleaned_records = self.replace_nans_with_nulls(records)
6634
- return self.predict_records(session_id, cleaned_records, target_column=target_column, predictor_id=predictor_id, show_progress_bar=show_progress_bar, print_target_column_warning=print_target_column_warning)
6914
+ return self.predict_records(session_id, cleaned_records, target_column=target_column, predictor_id=predictor_id,
6915
+ best_metric_preference=best_metric_preference, show_progress_bar=show_progress_bar, print_target_column_warning=print_target_column_warning)
6635
6916
 
6636
6917
  def _raise_predictor_not_found_error(self, session_id: str, method_name: str):
6637
6918
  """
@@ -6762,28 +7043,54 @@ class FeatrixSphereClient:
6762
7043
  training_metrics = models.get('training_metrics', {})
6763
7044
  if debug:
6764
7045
  print(f"🔍 Debug: training_metrics available = {training_metrics.get('available')}")
7046
+ target_column = None
7047
+ metadata = {}
7048
+
6765
7049
  if training_metrics.get('available'):
6766
- metrics_data = self.get_training_metrics(session_id)
6767
- if debug:
6768
- print(f"🔍 Debug: metrics_data keys = {list(metrics_data.keys())}")
6769
- training_metrics_inner = metrics_data.get('training_metrics', {})
6770
- if debug:
6771
- print(f"🔍 Debug: training_metrics_inner keys = {list(training_metrics_inner.keys()) if training_metrics_inner else 'None'}")
6772
- target_column = training_metrics_inner.get('target_column')
6773
- if debug:
6774
- print(f"🔍 Debug: extracted target_column = {target_column}")
6775
- if target_column:
6776
- # Extract metadata from training metrics
6777
- metadata = self._extract_predictor_metadata(metrics_data, debug)
6778
-
6779
- # Generate unique predictor ID
6780
- predictor_path = single_predictor.get('path', '')
7050
+ try:
7051
+ metrics_data = self.get_training_metrics(session_id)
7052
+ if debug:
7053
+ print(f"🔍 Debug: metrics_data keys = {list(metrics_data.keys())}")
7054
+ training_metrics_inner = metrics_data.get('training_metrics', {})
7055
+ if debug:
7056
+ print(f"🔍 Debug: training_metrics_inner keys = {list(training_metrics_inner.keys()) if training_metrics_inner else 'None'}")
7057
+ target_column = training_metrics_inner.get('target_column')
7058
+ if debug:
7059
+ print(f"🔍 Debug: extracted target_column = {target_column}")
7060
+ if target_column:
7061
+ # Extract metadata from training metrics
7062
+ metadata = self._extract_predictor_metadata(metrics_data, debug)
7063
+ except Exception as e:
7064
+ if debug:
7065
+ print(f"⚠️ Could not get training metrics: {e}")
7066
+
7067
+ # Fallback: try to get target column from job_plan
7068
+ if not target_column:
7069
+ job_plan = session.get('job_plan', [])
7070
+ for job in job_plan:
7071
+ if job.get('job_type') == 'train_single_predictor':
7072
+ spec = job.get('spec', {})
7073
+ target_column = spec.get('target_column')
7074
+ if target_column:
7075
+ if debug:
7076
+ print(f"🔍 Debug: extracted target_column from job_plan: {target_column}")
7077
+ break
7078
+
7079
+ # If predictor is available, add it even without target_column (can be None)
7080
+ if single_predictor.get('available') or single_predictor.get('predictors'):
7081
+ # Generate unique predictor ID
7082
+ predictor_path = single_predictor.get('path', '')
7083
+ if not predictor_path and single_predictor.get('predictors'):
7084
+ # Use first predictor from new format
7085
+ predictor_path = single_predictor.get('predictors', [{}])[0].get('path', '')
7086
+
7087
+ if predictor_path:
6781
7088
  predictor_id = self._generate_predictor_id(predictor_path, 'single_predictor')
6782
7089
 
6783
7090
  predictors[predictor_id] = {
6784
7091
  'predictor_id': predictor_id,
6785
7092
  'path': predictor_path,
6786
- 'target_column': target_column,
7093
+ 'target_column': target_column, # Can be None
6787
7094
  'available': True,
6788
7095
  'type': 'single_predictor',
6789
7096
  **metadata # Include epochs, validation_loss, job_status, etc.
@@ -6806,6 +7113,9 @@ class FeatrixSphereClient:
6806
7113
  if debug:
6807
7114
  print(f"🔍 Debug: single_predictors array = {single_predictors_paths}")
6808
7115
  if single_predictors_paths:
7116
+ target_column = None
7117
+ metadata = {}
7118
+
6809
7119
  # Try to get target column info from training metrics
6810
7120
  training_metrics = models.get('training_metrics', {})
6811
7121
  if training_metrics.get('available'):
@@ -6815,30 +7125,44 @@ class FeatrixSphereClient:
6815
7125
  if target_column:
6816
7126
  # Extract metadata from training metrics
6817
7127
  metadata = self._extract_predictor_metadata(metrics_data, debug)
6818
-
6819
- # Add each predictor individually with its own predictor_id key
6820
- for i, path in enumerate(single_predictors_paths):
6821
- predictor_id = self._generate_predictor_id(path, f'multiple_predictor_{i}')
6822
-
6823
- predictors[predictor_id] = {
6824
- 'predictor_id': predictor_id,
6825
- 'path': path,
6826
- 'target_column': target_column,
6827
- 'available': True,
6828
- 'type': 'single_predictor', # Each is treated as individual predictor
6829
- 'predictor_index': i, # Track original index for compatibility
6830
- **metadata # Include epochs, validation_loss, job_status, etc.
6831
- }
6832
- if debug:
6833
- print(f"✅ Added predictor {i} for target_column: {target_column}")
6834
- print(f" Predictor ID: {predictor_id}")
6835
- print(f" Path: {path}")
6836
-
6837
- if debug:
6838
- print(f" Total predictors added: {len(single_predictors_paths)}")
6839
- print(f" Shared metadata: {metadata}")
6840
7128
  except Exception as e:
6841
- print(f"Warning: Could not extract target column from training metrics: {e}")
7129
+ if debug:
7130
+ print(f"⚠️ Could not get training metrics: {e}")
7131
+
7132
+ # Fallback: try to get target column from job_plan
7133
+ if not target_column:
7134
+ job_plan = session.get('job_plan', [])
7135
+ for job in job_plan:
7136
+ if job.get('job_type') == 'train_single_predictor':
7137
+ spec = job.get('spec', {})
7138
+ target_column = spec.get('target_column')
7139
+ if target_column:
7140
+ if debug:
7141
+ print(f"🔍 Debug: extracted target_column from job_plan: {target_column}")
7142
+ break
7143
+
7144
+ # Add each predictor even if target_column is None
7145
+ for i, path in enumerate(single_predictors_paths):
7146
+ predictor_id = self._generate_predictor_id(path, f'multiple_predictor_{i}')
7147
+
7148
+ predictors[predictor_id] = {
7149
+ 'predictor_id': predictor_id,
7150
+ 'path': path,
7151
+ 'target_column': target_column, # Can be None
7152
+ 'available': True,
7153
+ 'type': 'single_predictor', # Each is treated as individual predictor
7154
+ 'predictor_index': i, # Track original index for compatibility
7155
+ **metadata # Include epochs, validation_loss, job_status, etc.
7156
+ }
7157
+ if debug:
7158
+ print(f"✅ Added predictor {i} for target_column: {target_column}")
7159
+ print(f" Predictor ID: {predictor_id}")
7160
+ print(f" Path: {path}")
7161
+
7162
+ if debug:
7163
+ print(f" Total predictors added: {len(single_predictors_paths)}")
7164
+ if metadata:
7165
+ print(f" Shared metadata: {metadata}")
6842
7166
 
6843
7167
  # Fallback: check old format single_predictor field
6844
7168
  single_predictor_path = session.get('single_predictor')
@@ -6915,7 +7239,7 @@ class FeatrixSphereClient:
6915
7239
  target_column: Specific target column to validate, or None for auto-detect
6916
7240
 
6917
7241
  Returns:
6918
- Validated target column name
7242
+ Validated target column name (or None if it cannot be determined; the server will handle discovery)
6919
7243
 
6920
7244
  Raises:
6921
7245
  ValueError: If target_column is invalid or multiple predictors exist without specification
@@ -6923,7 +7247,8 @@ class FeatrixSphereClient:
6923
7247
  available_predictors = self._get_available_predictors(session_id)
6924
7248
 
6925
7249
  if not available_predictors:
6926
- raise ValueError(f"No trained predictors found for session {session_id}")
7250
+ # Don't fail - let server handle discovery. Return provided target_column or None.
7251
+ return target_column
6927
7252
 
6928
7253
  if target_column is None:
6929
7254
  # Auto-detect: only valid if there's exactly one predictor
@@ -6977,6 +7302,10 @@ class FeatrixSphereClient:
6977
7302
  # Re-raise validation errors
6978
7303
  raise e
6979
7304
 
7305
+ # If we couldn't determine target column (server will handle), just return records as-is
7306
+ if validated_target_column is None:
7307
+ return records
7308
+
6980
7309
  if validated_target_column in records[0]:
6981
7310
  if print_warning:
6982
7311
  print(f"⚠️ Warning: Removing target column '{validated_target_column}' from prediction data")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: featrixsphere
3
- Version: 0.2.3613
3
+ Version: 0.2.4982
4
4
  Summary: Transform any CSV into a production-ready ML model in minutes, not months.
5
5
  Home-page: https://github.com/Featrix/sphere
6
6
  Author: Featrix
@@ -0,0 +1,8 @@
1
+ featrixsphere/__init__.py,sha256=6e5gN0j6g8dLuetEEDGVk9ZXasdgsuii0qNF83OQIYU,1888
2
+ featrixsphere/client.py,sha256=HouSLZcEoGjKZy9c0HK0TG-Xog7L_12kCik54TrwzS8,432653
3
+ featrixsphere/test_client.py,sha256=4SiRbib0ms3poK0UpnUv4G0HFQSzidF3Iswo_J2cjLk,11981
4
+ featrixsphere-0.2.4982.dist-info/METADATA,sha256=DI-xhIm1bAk2rz-5XxK-2jdQsySuQGRZPTLN7LcxEiM,16232
5
+ featrixsphere-0.2.4982.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
+ featrixsphere-0.2.4982.dist-info/entry_points.txt,sha256=QreJeYfD_VWvbEqPmMXZ3pqqlFlJ1qZb-NtqnyhEldc,51
7
+ featrixsphere-0.2.4982.dist-info/top_level.txt,sha256=AyN4wjfzlD0hWnDieuEHX0KckphIk_aC73XCG4df5uU,14
8
+ featrixsphere-0.2.4982.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- featrixsphere/__init__.py,sha256=PGw5nesAoGHUrtvevAaFuhNq3Qfgzm7JYrbvMx6bCU4,1888
2
- featrixsphere/client.py,sha256=TiEZIT1Qmc983CbPljc8__e0jJRnpQ3Lf6SabwrvLlo,415649
3
- featrixsphere/test_client.py,sha256=4SiRbib0ms3poK0UpnUv4G0HFQSzidF3Iswo_J2cjLk,11981
4
- featrixsphere-0.2.3613.dist-info/METADATA,sha256=0AlxzTM9nXmXhPka3AUhMhIAgdAMxPhzU0Sy18FXc8E,16232
5
- featrixsphere-0.2.3613.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
- featrixsphere-0.2.3613.dist-info/entry_points.txt,sha256=QreJeYfD_VWvbEqPmMXZ3pqqlFlJ1qZb-NtqnyhEldc,51
7
- featrixsphere-0.2.3613.dist-info/top_level.txt,sha256=AyN4wjfzlD0hWnDieuEHX0KckphIk_aC73XCG4df5uU,14
8
- featrixsphere-0.2.3613.dist-info/RECORD,,