featrixsphere 0.2.3737__py3-none-any.whl → 0.2.4982__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
featrixsphere/__init__.py CHANGED
@@ -38,7 +38,7 @@ Example:
     ... labels=['Experiment A', 'Experiment B'])
 """
 
-__version__ = "0.2.3737"
+__version__ = "0.2.4982"
 __author__ = "Featrix"
 __email__ = "support@featrix.com"
 __license__ = "MIT"
featrixsphere/client.py CHANGED
@@ -11,7 +11,7 @@ import time
 import requests
 from pathlib import Path
 from typing import Dict, Any, Optional, List, Tuple, Union
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 import gzip
 import os
 import random
@@ -65,6 +65,7 @@ class SessionInfo:
     status: str
     jobs: Dict[str, Any]
     job_queue_positions: Dict[str, Any]
+    job_plan: List[Dict[str, Any]] = field(default_factory=list)
     _client: Optional['FeatrixSphereClient'] = None
 
     def predictors(self) -> List[Dict[str, Any]]:
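A minimal usage sketch for the new `job_plan` field (not part of the diff; `client` and the session ID are hypothetical). The `job_type`/`spec` keys match what the target-column fallback later in this diff reads:

```python
# Hypothetical sketch: inspect the job plan on a freshly fetched session.
session_info = client.get_session_status("my-session-id")
for job in session_info.job_plan:
    # Each entry is a plain dict; the fallback logic in this diff reads
    # 'job_type' and 'spec' -> 'target_column'.
    print(job.get('job_type'), job.get('spec', {}).get('target_column'))
```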
@@ -82,6 +83,51 @@ class SessionInfo:
             return list(predictors_dict.values())
         except Exception:
             return []
+
+    def embedding_space_info(self) -> Optional[Dict[str, Any]]:
+        """
+        Get embedding space information for this session.
+
+        Returns:
+            Dictionary with ES info (dimensions, epochs, etc.) or None if not available
+        """
+        if not self._client:
+            return None
+
+        try:
+            # Get session details from the client
+            session_data = self._client._get_json(f"/compute/session/{self.session_id}")
+
+            es_info = {}
+
+            # Extract embedding space path
+            embedding_space_path = session_data.get('embedding_space')
+            if embedding_space_path:
+                es_info['embedding_space_path'] = embedding_space_path
+
+            # Extract model architecture info
+            model_info = session_data.get('model_info', {}) or session_data.get('embedding_space', {})
+            if isinstance(model_info, dict):
+                es_info['d_model'] = model_info.get('d_model') or model_info.get('embedding_dim')
+                es_info['dimensions'] = es_info.get('d_model')  # Alias for compatibility
+                es_info['parameter_count'] = model_info.get('parameter_count') or model_info.get('num_parameters')
+                es_info['layer_count'] = model_info.get('layer_count') or model_info.get('num_layers')
+
+            # Extract training statistics
+            training_stats = session_data.get('training_stats', {}) or session_data.get('stats', {})
+            if isinstance(training_stats, dict):
+                es_info['epochs'] = training_stats.get('final_epoch') or training_stats.get('epochs_trained') or training_stats.get('epochs')
+                es_info['final_loss'] = training_stats.get('final_loss') or training_stats.get('loss')
+                es_info['final_val_loss'] = training_stats.get('final_val_loss') or training_stats.get('validation_loss')
+                es_info['training_time_seconds'] = training_stats.get('training_time') or training_stats.get('elapsed_seconds')
+
+            # If we have any info, return it
+            if es_info:
+                return es_info
+
+            return None
+        except Exception:
+            return None
 
 
 class PredictionBatch:
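A usage sketch for the new `embedding_space_info()` helper (not part of the diff). Every key shown is populated by the method above, and any of them may be None when the server response lacks the underlying field:

```python
# Hypothetical sketch: summarize the embedding space behind a session.
info = session_info.embedding_space_info()
if info:
    print(f"d_model={info.get('d_model')}, epochs={info.get('epochs')}, "
          f"val_loss={info.get('final_val_loss')}")
else:
    print("No embedding space info available yet")
```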
@@ -617,6 +663,7 @@ class FeatrixSphereClient:
             status=response_data.get('status', 'unknown'),
             jobs={},
             job_queue_positions={},
+            job_plan=[],
             _client=self
         )
 
@@ -640,6 +687,7 @@ class FeatrixSphereClient:
         session = response_data.get('session', {})
         jobs = response_data.get('jobs', {})
         positions = response_data.get('job_queue_positions', {})
+        job_plan = session.get('job_plan', [])
 
         return SessionInfo(
             session_id=session.get('session_id', session_id),
@@ -647,6 +695,7 @@ class FeatrixSphereClient:
             status=session.get('status', 'unknown'),
             jobs=jobs,
             job_queue_positions=positions,
+            job_plan=job_plan,
             _client=self
         )
 
@@ -1370,6 +1419,20 @@ class FeatrixSphereClient:
 
         start_time = time.time()
 
+        # Initial wait for job dispatch (jobs are dispatched asynchronously after session creation)
+        # Wait up to 10 seconds for jobs to appear before starting the main monitoring loop
+        initial_wait_timeout = 10
+        initial_wait_start = time.time()
+        jobs_appeared = False
+
+        while time.time() - initial_wait_start < initial_wait_timeout:
+            session_info = self.get_session_status(session_id)
+            if session_info.jobs:
+                jobs_appeared = True
+                break
+            time.sleep(0.5)  # Check every 500ms during initial wait
+
+        # Main monitoring loop
         while time.time() - start_time < max_wait_time:
             session_info = self.get_session_status(session_id)
             elapsed = time.time() - start_time
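Because the dispatch wait now runs inside the client, callers no longer race the asynchronous job dispatch. A sketch under that assumption (`df` is a hypothetical DataFrame):

```python
# Hypothetical sketch: create a session and block on completion.
# The 10-second dispatch wait above happens inside the client call.
session = client.upload_df_and_create_session(df=df)
final_info = client.wait_for_session_completion(session.session_id)
```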
@@ -1842,7 +1905,8 @@ class FeatrixSphereClient:
             session_type=response_data.get('session_type', 'embedding_space'),
             status=response_data.get('status', 'ready'),
             jobs={},
-            job_queue_positions={}
+            job_queue_positions={},
+            job_plan=[]
         )
 
     def fine_tune_embedding_space(
  def fine_tune_embedding_space(
@@ -1972,7 +2036,144 @@ class FeatrixSphereClient:
1972
2036
  session_type=response_data.get('session_type', 'embedding_space_finetune'),
1973
2037
  status=response_data.get('status', 'ready'),
1974
2038
  jobs={},
1975
- job_queue_positions={}
2039
+ job_queue_positions={},
2040
+ job_plan=[]
2041
+ )
2042
+
2043
+ def extend_embedding_space(
2044
+ self,
2045
+ name: str,
2046
+ parent_session_id: str = None,
2047
+ parent_embedding_space_path: str = None,
2048
+ s3_training_dataset: str = None,
2049
+ s3_validation_dataset: str = None,
2050
+ n_epochs: int = None,
2051
+ webhooks: Dict[str, str] = None,
2052
+ user_metadata: Dict[str, Any] = None
2053
+ ) -> SessionInfo:
2054
+ """
2055
+ Extend an existing embedding space with new feature columns.
2056
+
2057
+ This method takes a pre-trained embedding space and extends it with new feature columns
2058
+ from enriched training/validation data. The extended ES preserves existing encoder weights
2059
+ and creates new codecs for the new columns.
2060
+
2061
+ **When to Use Extend vs Fine-Tune:**
2062
+ - **Extend**: When you've added NEW COLUMNS (features) to your dataset
2063
+ - **Fine-Tune**: When you have new rows with the SAME COLUMNS
2064
+
2065
+ **How It Works:**
2066
+ 1. Loads the parent embedding space
2067
+ 2. Identifies new columns in the enriched dataset
2068
+ 3. Creates codecs for the new columns
2069
+ 4. Copies existing encoder weights (preserves learned representations)
2070
+ 5. Fine-tunes for shorter duration (default: original_epochs / 4)
2071
+ 6. Returns extended embedding space with all columns
2072
+
2073
+ Args:
2074
+ name: Name for the extended embedding space
2075
+ parent_session_id: Session ID of the parent embedding space (optional)
2076
+ parent_embedding_space_path: Direct path to parent embedding space pickle file (optional)
2077
+ s3_training_dataset: S3 URL for enriched training dataset with new columns (must start with 's3://')
2078
+ s3_validation_dataset: S3 URL for enriched validation dataset with new columns (must start with 's3://')
2079
+ n_epochs: Number of epochs for extension training (default: original_epochs / 4)
2080
+ webhooks: Optional dict with webhook configuration keys
2081
+ user_metadata: Optional metadata dict to attach to the session
2082
+
2083
+ Returns:
2084
+ SessionInfo for the newly created extension session
2085
+
2086
+ Raises:
2087
+ ValueError: If S3 URLs are invalid or neither parent identifier is provided
2088
+
2089
+ Example:
2090
+ ```python
2091
+ # Extend an existing embedding space with new feature columns
2092
+ client = FeatrixSphereClient("https://sphere-api.featrix.com")
2093
+
2094
+ # Original ES was trained on: age, income, credit_score
2095
+ # New data includes engineered features: debt_to_income_ratio, age_bin
2096
+ extended = client.extend_embedding_space(
2097
+ name="customer_model_with_features",
2098
+ parent_session_id="abc123-20240101-120000",
2099
+ s3_training_dataset="s3://my-bucket/enriched_training.csv",
2100
+ s3_validation_dataset="s3://my-bucket/enriched_validation.csv",
2101
+ n_epochs=25 # Optional: specify epochs (defaults to original/4)
2102
+ )
2103
+
2104
+ # Wait for extension to complete
2105
+ client.wait_for_session_completion(extended.session_id)
2106
+
2107
+ # The extended ES now includes the new feature columns
2108
+ # Use it for predictions with enriched data
2109
+ result = client.predict(extended.session_id, {
2110
+ "age": 35,
2111
+ "income": 75000,
2112
+ "credit_score": 720,
2113
+ "debt_to_income_ratio": 0.25, # New feature!
2114
+ "age_bin": "30-40" # New feature!
2115
+ })
2116
+ ```
2117
+ """
2118
+ # Validate S3 URLs
2119
+ if s3_training_dataset and not s3_training_dataset.startswith('s3://'):
2120
+ raise ValueError("s3_training_dataset must be a valid S3 URL (s3://...)")
2121
+ if s3_validation_dataset and not s3_validation_dataset.startswith('s3://'):
2122
+ raise ValueError("s3_validation_dataset must be a valid S3 URL (s3://...)")
2123
+
2124
+ # Validate that we have either parent_session_id or parent_embedding_space_path
2125
+ if not parent_session_id and not parent_embedding_space_path:
2126
+ raise ValueError("Either parent_session_id or parent_embedding_space_path must be provided")
2127
+
2128
+ print(f"Extending embedding space '{name}' with new features...")
2129
+ if parent_session_id:
2130
+ print(f" Parent session: {parent_session_id}")
2131
+ if parent_embedding_space_path:
2132
+ print(f" Parent embedding space: {parent_embedding_space_path}")
2133
+ print(f" Enriched training data: {s3_training_dataset}")
2134
+ print(f" Enriched validation data: {s3_validation_dataset}")
2135
+ if n_epochs:
2136
+ print(f" Extension epochs: {n_epochs}")
2137
+ else:
2138
+ print(f" Extension epochs: auto (original/4)")
2139
+
2140
+ data = {
2141
+ "name": name,
2142
+ "s3_file_data_set_training": s3_training_dataset,
2143
+ "s3_file_data_set_validation": s3_validation_dataset
2144
+ }
2145
+
2146
+ if parent_session_id:
2147
+ data["parent_session_id"] = parent_session_id
2148
+ if parent_embedding_space_path:
2149
+ data["parent_embedding_space_path"] = parent_embedding_space_path
2150
+ if n_epochs is not None:
2151
+ data["n_epochs"] = n_epochs
2152
+
2153
+ if webhooks:
2154
+ data['webhooks'] = webhooks
2155
+ if user_metadata:
2156
+ import json
2157
+ data['user_metadata'] = json.dumps(user_metadata)
2158
+ print(f"User metadata: {user_metadata}")
2159
+
2160
+ response_data = self._post_json("/compute/extend-embedding-space", data)
2161
+
2162
+ session_id = response_data.get('session_id')
2163
+ extend_info = response_data.get('extend_es_info', {})
2164
+
2165
+ print(f"Extension session created: {session_id}")
2166
+ if extend_info:
2167
+ print(f" Original epochs: {extend_info.get('original_epochs', 'N/A')}")
2168
+ print(f" Extension epochs: {extend_info.get('extension_epochs', 'N/A')}")
2169
+
2170
+ return SessionInfo(
2171
+ session_id=session_id,
2172
+ session_type=response_data.get('session_type', 'embedding_space_extend'),
2173
+ status=response_data.get('status', 'ready'),
2174
+ jobs={},
2175
+ job_queue_positions={},
2176
+ job_plan=[]
1976
2177
  )
1977
2178
 
1978
2179
  # =========================================================================
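The docstring example above references the parent by `parent_session_id`; the parent can equally be referenced by file path. A minimal sketch (the pickle path is hypothetical):

```python
# Hypothetical sketch: reference the parent ES by pickle path instead of
# by session ID (at least one of the two identifiers is required).
extended = client.extend_embedding_space(
    name="customer_model_v2",
    parent_embedding_space_path="/sphere/models/parent_es.pkl",  # hypothetical path
    s3_training_dataset="s3://my-bucket/enriched_training.csv",
    s3_validation_dataset="s3://my-bucket/enriched_validation.csv",
)
```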
@@ -2033,7 +2234,8 @@ class FeatrixSphereClient:
             session_type=response_data.get('session_type', 'sphere'),
             status=response_data.get('status', 'ready'),
             jobs={},
-            job_queue_positions={}
+            job_queue_positions={},
+            job_plan=[]
         )
 
     def upload_df_and_create_session(self, df=None, filename: str = "data.csv", file_path: str = None,
@@ -2309,23 +2511,57 @@ class FeatrixSphereClient:
                 compression_ratio = (1 - compressed_size / original_size) * 100
                 print(f"Converted Parquet to CSV and compressed from {original_size:,} to {compressed_size:,} bytes ({compression_ratio:.1f}% reduction)")
             else:
-                # Regular CSV file - read and compress it
+                # Regular CSV file - check size and suggest Parquet for large files
                 with open(file_path, 'rb') as f:
                     csv_content = f.read()
 
-                # Compress the content
-                print("Compressing CSV file...")
-                compressed_buffer = io.BytesIO()
-                with gzip.GzipFile(fileobj=compressed_buffer, mode='wb') as gz:
-                    gz.write(csv_content)
-                file_content = compressed_buffer.getvalue()
-                upload_filename = os.path.basename(file_path) + '.gz'
-                content_type = 'application/gzip'
+                csv_size_mb = len(csv_content) / (1024 * 1024)
+                CSV_WARNING_THRESHOLD_MB = 1.0  # Warn if CSV > 1MB
 
-                original_size = len(csv_content)
-                compressed_size = len(file_content)
-                compression_ratio = (1 - compressed_size / original_size) * 100
-                print(f"Compressed from {original_size:,} to {compressed_size:,} bytes ({compression_ratio:.1f}% reduction)")
+                if csv_size_mb > CSV_WARNING_THRESHOLD_MB:
+                    print(f"\n⚠️ Warning: CSV file is {csv_size_mb:.1f} MB")
+                    print(f"   Parquet format is more efficient for large files (smaller size, faster upload).")
+                    print(f"   Converting to Parquet format for better performance...")
+
+                    # Read CSV as DataFrame
+                    csv_df = pd.read_csv(file_path)
+
+                    # Convert to Parquet in memory
+                    parquet_buffer = io.BytesIO()
+                    try:
+                        # Try pyarrow first (faster), fallback to fastparquet
+                        csv_df.to_parquet(parquet_buffer, index=False, engine='pyarrow')
+                    except (ImportError, ValueError):
+                        # Fallback to fastparquet or default engine
+                        try:
+                            csv_df.to_parquet(parquet_buffer, index=False, engine='fastparquet')
+                        except (ImportError, ValueError):
+                            # Last resort: use default engine
+                            csv_df.to_parquet(parquet_buffer, index=False)
+                    parquet_content = parquet_buffer.getvalue()
+                    parquet_size_mb = len(parquet_content) / (1024 * 1024)
+
+                    # Use Parquet instead of compressed CSV
+                    file_content = parquet_content
+                    upload_filename = os.path.basename(file_path).replace('.csv', '.parquet')
+                    content_type = 'application/octet-stream'
+
+                    size_reduction = (1 - len(parquet_content) / len(csv_content)) * 100
+                    print(f"   ✅ Converted to Parquet: {csv_size_mb:.1f} MB → {parquet_size_mb:.1f} MB ({size_reduction:.1f}% reduction)")
+                else:
+                    # Small CSV - compress as before
+                    print("Compressing CSV file...")
+                    compressed_buffer = io.BytesIO()
+                    with gzip.GzipFile(fileobj=compressed_buffer, mode='wb') as gz:
+                        gz.write(csv_content)
+                    file_content = compressed_buffer.getvalue()
+                    upload_filename = os.path.basename(file_path) + '.gz'
+                    content_type = 'application/gzip'
+
+                    original_size = len(csv_content)
+                    compressed_size = len(file_content)
+                    compression_ratio = (1 - compressed_size / original_size) * 100
+                    print(f"Compressed from {original_size:,} to {compressed_size:,} bytes ({compression_ratio:.1f}% reduction)")
 
         # Handle DataFrame input
         else:
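The pyarrow → fastparquet → default-engine chain above is self-contained enough to lift out. A standalone sketch of the same pattern (not the packaged implementation), assuming pandas plus at least one Parquet engine is installed:

```python
import io

import pandas as pd


def df_to_parquet_bytes(df: pd.DataFrame) -> bytes:
    """Serialize a DataFrame to Parquet bytes, trying engines in order.

    A sketch of the fallback chain used in this diff, not the packaged code.
    """
    buf = io.BytesIO()
    for engine in ('pyarrow', 'fastparquet', 'auto'):
        try:
            df.to_parquet(buf, index=False, engine=engine)
            return buf.getvalue()
        except (ImportError, ValueError):
            buf.seek(0)
            buf.truncate()  # Discard any partial write before retrying
    raise RuntimeError("No usable Parquet engine installed")
```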
@@ -2334,29 +2570,31 @@ class FeatrixSphereClient:
 
             print(f"Uploading DataFrame ({len(df)} rows, {len(df.columns)} columns)")
 
-            # Clean NaN values in DataFrame before CSV conversion
+            # Clean NaN values in DataFrame before conversion
             # This prevents JSON encoding issues when the server processes the data
             # Use pandas.notna() with where() for compatibility with all pandas versions
             cleaned_df = df.where(pd.notna(df), None)  # Replace NaN with None for JSON compatibility
 
-            # Convert DataFrame to CSV and compress
-            csv_buffer = io.StringIO()
-            cleaned_df.to_csv(csv_buffer, index=False)
-            csv_data = csv_buffer.getvalue().encode('utf-8')
-
-            # Compress the CSV data
-            print("Compressing DataFrame...")
-            compressed_buffer = io.BytesIO()
-            with gzip.GzipFile(fileobj=compressed_buffer, mode='wb') as gz:
-                gz.write(csv_data)
-            file_content = compressed_buffer.getvalue()
-            upload_filename = filename if filename.endswith('.gz') else filename + '.gz'
-            content_type = 'application/gzip'
-
-            original_size = len(csv_data)
-            compressed_size = len(file_content)
-            compression_ratio = (1 - compressed_size / original_size) * 100
-            print(f"Compressed from {original_size:,} to {compressed_size:,} bytes ({compression_ratio:.1f}% reduction)")
+            # Always use Parquet format for DataFrames (smaller, faster than CSV.gz)
+            print("Converting DataFrame to Parquet format...")
+            parquet_buffer = io.BytesIO()
+            try:
+                # Try pyarrow first (faster), fallback to fastparquet
+                cleaned_df.to_parquet(parquet_buffer, index=False, engine='pyarrow')
+            except (ImportError, ValueError):
+                # Fallback to fastparquet or default engine
+                try:
+                    cleaned_df.to_parquet(parquet_buffer, index=False, engine='fastparquet')
+                except (ImportError, ValueError):
+                    # Last resort: use default engine
+                    cleaned_df.to_parquet(parquet_buffer, index=False)
+
+            file_content = parquet_buffer.getvalue()
+            parquet_size_mb = len(file_content) / (1024 * 1024)
+            upload_filename = filename.replace('.csv', '.parquet') if filename.endswith('.csv') else filename + '.parquet'
+            content_type = 'application/octet-stream'
+
+            print(f"✅ Saved as Parquet: {parquet_size_mb:.2f} MB")
 
         # Upload the compressed file with optional column overrides
         files = {'file': (upload_filename, file_content, content_type)}
@@ -2403,9 +2641,18 @@ class FeatrixSphereClient:
         file_size_mb = len(file_content) / (1024 * 1024)
         CHUNK_SIZE_MB = 512  # 512 MB chunk size
         CHUNK_SIZE_BYTES = CHUNK_SIZE_MB * 1024 * 1024
+        LARGE_FILE_WARNING_MB = 10  # Warn if file > 10 MB
+
+        if file_size_mb > LARGE_FILE_WARNING_MB:
+            print(f"\n⚠️ Warning: File size ({file_size_mb:.1f} MB) is quite large")
+            print(f"   For very large files (>10 MB), consider using S3 uploads:")
+            print(f"   1. Upload your file to S3 (or your cloud storage)")
+            print(f"   2. Generate a signed/private URL with read access")
+            print(f"   3. Contact Featrix support to configure S3-based uploads")
+            print(f"   This can be more reliable than direct uploads for large datasets.")
 
         if file_size_mb > CHUNK_SIZE_MB:
-            print(f"⚠️ Warning: File size ({file_size_mb:.1f} MB) exceeds {CHUNK_SIZE_MB} MB threshold")
+            print(f"\n⚠️ Warning: File size ({file_size_mb:.1f} MB) exceeds {CHUNK_SIZE_MB} MB threshold")
             print(f"   Large uploads may timeout. Consider splitting the data or using smaller batches.")
 
@@ -2451,7 +2698,8 @@ class FeatrixSphereClient:
             session_type=response_data.get('session_type', 'sphere'),
             status=response_data.get('status', 'ready'),
             jobs={},
-            job_queue_positions={}
+            job_queue_positions={},
+            job_plan=[]
         )
 
 
@@ -2564,7 +2812,8 @@ class FeatrixSphereClient:
     # =========================================================================
 
     def predict(self, session_id: str, record: Dict[str, Any], target_column: str = None,
-                predictor_id: str = None, max_retries: int = None, queue_batches: bool = False) -> Dict[str, Any]:
+                predictor_id: str = None, best_metric_preference: str = None,
+                max_retries: int = None, queue_batches: bool = False) -> Dict[str, Any]:
         """
         Make a single prediction for a record.
 
@@ -2573,6 +2822,7 @@ class FeatrixSphereClient:
             record: Record dictionary (without target column)
             target_column: Specific target column predictor to use (required if multiple predictors exist and predictor_id is not specified)
             predictor_id: Specific predictor ID to use (recommended - more precise than target_column)
+            best_metric_preference: Which metric checkpoint to use: "roc_auc", "pr_auc", or None for the default checkpoint (default: None)
             max_retries: Number of retries for errors (default: uses client default)
             queue_batches: If True, queue this prediction for batch processing instead of an immediate API call
 
@@ -2630,6 +2880,10 @@ class FeatrixSphereClient:
         if resolved_predictor_id:
             request_payload["predictor_id"] = resolved_predictor_id
 
+        # Include best_metric_preference if specified
+        if best_metric_preference:
+            request_payload["best_metric_preference"] = best_metric_preference
+
         response_data = self._post_json(f"/session/{session_id}/predict", request_payload, max_retries=max_retries)
         return response_data
 
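A usage sketch for the new parameter (not part of the diff; the record's columns are hypothetical). `"roc_auc"` and `"pr_auc"` are the values documented above:

```python
# Hypothetical sketch: request the checkpoint selected by ROC-AUC
# instead of the default checkpoint.
result = client.predict(
    "my-session-id",
    {"age": 35, "income": 75000},  # hypothetical record
    best_metric_preference="roc_auc",
)
```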
@@ -2712,7 +2966,7 @@ class FeatrixSphereClient:
     def plot_training_loss(self, session_id: str, figsize: Tuple[int, int] = (12, 8),
                            style: str = 'notebook', save_path: Optional[str] = None,
                            show_learning_rate: bool = True, smooth: bool = True,
-                           title: Optional[str] = None) -> 'plt.Figure':
+                           title: Optional[str] = None):
         """
         Plot comprehensive training loss curves for a session (both embedding space and single predictor).
 
@@ -2800,7 +3054,7 @@ class FeatrixSphereClient:
 
     def plot_embedding_space_training(self, session_id: str, figsize: Tuple[int, int] = (10, 6),
                                       style: str = 'notebook', save_path: Optional[str] = None,
-                                      show_mutual_info: bool = False) -> 'plt.Figure':
+                                      show_mutual_info: bool = False):
         """
         Plot detailed embedding space training metrics.
 
@@ -2876,7 +3130,7 @@ class FeatrixSphereClient:
 
     def plot_single_predictor_training(self, session_id: str, figsize: Tuple[int, int] = (10, 6),
                                        style: str = 'notebook', save_path: Optional[str] = None,
-                                       show_metrics: bool = True) -> 'plt.Figure':
+                                       show_metrics: bool = True):
         """
         Plot detailed single predictor training metrics.
 
@@ -2952,7 +3206,7 @@ class FeatrixSphereClient:
 
     def plot_training_comparison(self, session_ids: List[str], labels: Optional[List[str]] = None,
                                  figsize: Tuple[int, int] = (12, 8), style: str = 'notebook',
-                                 save_path: Optional[str] = None) -> plt.Figure:
+                                 save_path: Optional[str] = None):
         """
         Compare training curves across multiple sessions.
 
@@ -3205,7 +3459,7 @@ class FeatrixSphereClient:
     def plot_embedding_space_3d(self, session_id: str, sample_size: int = 2000,
                                 color_by: Optional[str] = None, size_by: Optional[str] = None,
                                 interactive: bool = True, style: str = 'notebook',
-                                title: Optional[str] = None, save_path: Optional[str] = None) -> Union[plt.Figure, 'go.Figure']:
+                                title: Optional[str] = None, save_path: Optional[str] = None):
         """
         Create interactive 3D visualization of the embedding space.
 
@@ -3281,7 +3535,7 @@ class FeatrixSphereClient:
                               style: str = 'notebook', save_path: Optional[str] = None,
                               show_embedding_evolution: bool = True,
                               show_loss_evolution: bool = True,
-                              fps: int = 2, notebook_mode: bool = True) -> Union[plt.Figure, 'HTML']:
+                              fps: int = 2, notebook_mode: bool = True):
         """
         Create an animated training movie showing loss curves and embedding evolution.
 
@@ -3335,7 +3589,7 @@ class FeatrixSphereClient:
 
     def plot_embedding_evolution(self, session_id: str, epoch_range: Optional[Tuple[int, int]] = None,
                                  interactive: bool = True, sample_size: int = 1000,
-                                 color_by: Optional[str] = None) -> Union[plt.Figure, 'go.Figure']:
+                                 color_by: Optional[str] = None):
         """
         Show how embedding space evolves during training across epochs.
 
@@ -3771,7 +4025,18 @@ class FeatrixSphereClient:
         available_predictors = self._get_available_predictors(session_id, debug=debug)
 
         if not available_predictors:
-            raise ValueError(f"No trained predictors found for session {session_id}")
+            # Don't fail here - let the server try to find/auto-discover the predictor.
+            # The server's /predict endpoint has smart fallback logic to find checkpoint files
+            # even if the session file wasn't properly updated (e.g., training crashed).
+            if debug:
+                print(f"⚠️ No predictors found via models endpoint, letting server handle discovery")
+            return {
+                'target_column': target_column,
+                'predictor_id': predictor_id,
+                'path': None,
+                'type': None,
+                'server_discovery': True  # Flag that the server should auto-discover
+            }
 
         # If predictor_id is provided, find it directly (since it's now the key)
         if predictor_id:
@@ -4649,6 +4914,7 @@ class FeatrixSphereClient:
             status="running",
             jobs={},
             job_queue_positions={},
+            job_plan=[],
             _client=self
         )
 
@@ -5500,6 +5766,7 @@ class FeatrixSphereClient:
             status="running",
             jobs={},
             job_queue_positions={},
+            job_plan=[],
             _client=self
         )
 
@@ -6083,7 +6350,8 @@ class FeatrixSphereClient:
     # =========================================================================
 
     def predict_table(self, session_id: str, table_data: Dict[str, Any],
-                      target_column: str = None, predictor_id: str = None, max_retries: int = None) -> Dict[str, Any]:
+                      target_column: str = None, predictor_id: str = None,
+                      best_metric_preference: str = None, max_retries: int = None) -> Dict[str, Any]:
         """
         Make batch predictions using JSON Tables format.
 
@@ -6124,6 +6392,8 @@ class FeatrixSphereClient:
             table_data['target_column'] = target_column
         if predictor_id:
             table_data['predictor_id'] = predictor_id
+        if best_metric_preference:
+            table_data['best_metric_preference'] = best_metric_preference
 
         try:
             response_data = self._post_json(f"/session/{session_id}/predict_table", table_data, max_retries=max_retries)
@@ -6136,7 +6406,8 @@ class FeatrixSphereClient:
             raise
 
     def predict_records(self, session_id: str, records: List[Dict[str, Any]],
-                        target_column: str = None, predictor_id: str = None, batch_size: int = 2500, use_async: bool = False,
+                        target_column: str = None, predictor_id: str = None, best_metric_preference: str = None,
+                        batch_size: int = 2500, use_async: bool = False,
                         show_progress_bar: bool = True, print_target_column_warning: bool = True) -> Dict[str, Any]:
         """
         Make batch predictions on a list of records with automatic client-side batching.
@@ -6183,7 +6454,8 @@ class FeatrixSphereClient:
             table_data = JSONTablesEncoder.from_records(cleaned_records)
 
             try:
-                result = self.predict_table(session_id, table_data)
+                result = self.predict_table(session_id, table_data, target_column=target_column,
+                                            predictor_id=predictor_id, best_metric_preference=best_metric_preference)
 
                 # Check if server returned an async job
                 if result.get('async') and result.get('job_id'):
@@ -6220,7 +6492,8 @@ class FeatrixSphereClient:
             table_data = JSONTablesEncoder.from_records(cleaned_records)
 
             try:
-                return self.predict_table(session_id, table_data)
+                return self.predict_table(session_id, table_data, target_column=target_column,
+                                          predictor_id=predictor_id, best_metric_preference=best_metric_preference)
             except Exception as e:
                 if "404" in str(e) and "Single predictor not found" in str(e):
                     self._raise_predictor_not_found_error(session_id, "predict_records")
@@ -6250,7 +6523,8 @@ class FeatrixSphereClient:
                 table_data = JSONTablesEncoder.from_records(chunk_records)
 
                 # Make prediction
-                chunk_result = self.predict_table(session_id, table_data)
+                chunk_result = self.predict_table(session_id, table_data, target_column=target_column,
+                                                  predictor_id=predictor_id, best_metric_preference=best_metric_preference)
                 chunk_predictions = chunk_result.get('predictions', [])
 
                 # Adjust row indices to match original dataset
@@ -6611,7 +6885,8 @@ class FeatrixSphereClient:
             print(f"\n⏰ Timeout after {max_wait_time} seconds")
             return {'status': 'timeout', 'message': f'Job did not complete within {max_wait_time} seconds'}
 
-    def predict_df(self, session_id: str, df, target_column: str = None, predictor_id: str = None, show_progress_bar: bool = True, print_target_column_warning: bool = True) -> Dict[str, Any]:
+    def predict_df(self, session_id: str, df, target_column: str = None, predictor_id: str = None,
+                   best_metric_preference: str = None, show_progress_bar: bool = True, print_target_column_warning: bool = True) -> Dict[str, Any]:
         """
         Make batch predictions on a pandas DataFrame.
 
@@ -6636,7 +6911,8 @@ class FeatrixSphereClient:
         records = df.to_dict(orient='records')
         # Clean NaNs for JSON encoding
         cleaned_records = self.replace_nans_with_nulls(records)
-        return self.predict_records(session_id, cleaned_records, target_column=target_column, predictor_id=predictor_id, show_progress_bar=show_progress_bar, print_target_column_warning=print_target_column_warning)
+        return self.predict_records(session_id, cleaned_records, target_column=target_column, predictor_id=predictor_id,
+                                    best_metric_preference=best_metric_preference, show_progress_bar=show_progress_bar, print_target_column_warning=print_target_column_warning)
 
     def _raise_predictor_not_found_error(self, session_id: str, method_name: str):
         """
@@ -6767,28 +7043,54 @@ class FeatrixSphereClient:
             training_metrics = models.get('training_metrics', {})
             if debug:
                 print(f"🔍 Debug: training_metrics available = {training_metrics.get('available')}")
+            target_column = None
+            metadata = {}
+
             if training_metrics.get('available'):
-                metrics_data = self.get_training_metrics(session_id)
-                if debug:
-                    print(f"🔍 Debug: metrics_data keys = {list(metrics_data.keys())}")
-                training_metrics_inner = metrics_data.get('training_metrics', {})
-                if debug:
-                    print(f"🔍 Debug: training_metrics_inner keys = {list(training_metrics_inner.keys()) if training_metrics_inner else 'None'}")
-                target_column = training_metrics_inner.get('target_column')
-                if debug:
-                    print(f"🔍 Debug: extracted target_column = {target_column}")
-                if target_column:
-                    # Extract metadata from training metrics
-                    metadata = self._extract_predictor_metadata(metrics_data, debug)
-
-                    # Generate unique predictor ID
-                    predictor_path = single_predictor.get('path', '')
+                try:
+                    metrics_data = self.get_training_metrics(session_id)
+                    if debug:
+                        print(f"🔍 Debug: metrics_data keys = {list(metrics_data.keys())}")
+                    training_metrics_inner = metrics_data.get('training_metrics', {})
+                    if debug:
+                        print(f"🔍 Debug: training_metrics_inner keys = {list(training_metrics_inner.keys()) if training_metrics_inner else 'None'}")
+                    target_column = training_metrics_inner.get('target_column')
+                    if debug:
+                        print(f"🔍 Debug: extracted target_column = {target_column}")
+                    if target_column:
+                        # Extract metadata from training metrics
+                        metadata = self._extract_predictor_metadata(metrics_data, debug)
+                except Exception as e:
+                    if debug:
+                        print(f"⚠️ Could not get training metrics: {e}")
+
+            # Fallback: try to get target column from job_plan
+            if not target_column:
+                job_plan = session.get('job_plan', [])
+                for job in job_plan:
+                    if job.get('job_type') == 'train_single_predictor':
+                        spec = job.get('spec', {})
+                        target_column = spec.get('target_column')
+                        if target_column:
+                            if debug:
+                                print(f"🔍 Debug: extracted target_column from job_plan: {target_column}")
+                            break
+
+            # If the predictor is available, add it even without target_column (can be None)
+            if single_predictor.get('available') or single_predictor.get('predictors'):
+                # Generate unique predictor ID
+                predictor_path = single_predictor.get('path', '')
+                if not predictor_path and single_predictor.get('predictors'):
+                    # Use first predictor from new format
+                    predictor_path = single_predictor.get('predictors', [{}])[0].get('path', '')
+
+                if predictor_path:
                     predictor_id = self._generate_predictor_id(predictor_path, 'single_predictor')
 
                     predictors[predictor_id] = {
                         'predictor_id': predictor_id,
                         'path': predictor_path,
-                        'target_column': target_column,
+                        'target_column': target_column,  # Can be None
                         'available': True,
                         'type': 'single_predictor',
                         **metadata  # Include epochs, validation_loss, job_status, etc.
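For reference, the shape the job_plan fallback expects, as a sketch (field names come straight from the lookup code above; the values are hypothetical):

```python
# Hypothetical job_plan entry consumed by the fallback above.
job_plan = [
    {
        "job_type": "train_single_predictor",
        "spec": {"target_column": "churned"},
    },
]
```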
@@ -6811,6 +7113,9 @@ class FeatrixSphereClient:
             if debug:
                 print(f"🔍 Debug: single_predictors array = {single_predictors_paths}")
             if single_predictors_paths:
+                target_column = None
+                metadata = {}
+
                 # Try to get target column info from training metrics
                 training_metrics = models.get('training_metrics', {})
                 if training_metrics.get('available'):
@@ -6820,30 +7125,44 @@ class FeatrixSphereClient:
                         if target_column:
                             # Extract metadata from training metrics
                             metadata = self._extract_predictor_metadata(metrics_data, debug)
-
-                            # Add each predictor individually with its own predictor_id key
-                            for i, path in enumerate(single_predictors_paths):
-                                predictor_id = self._generate_predictor_id(path, f'multiple_predictor_{i}')
-
-                                predictors[predictor_id] = {
-                                    'predictor_id': predictor_id,
-                                    'path': path,
-                                    'target_column': target_column,
-                                    'available': True,
-                                    'type': 'single_predictor',  # Each is treated as individual predictor
-                                    'predictor_index': i,  # Track original index for compatibility
-                                    **metadata  # Include epochs, validation_loss, job_status, etc.
-                                }
-                                if debug:
-                                    print(f"✅ Added predictor {i} for target_column: {target_column}")
-                                    print(f"   Predictor ID: {predictor_id}")
-                                    print(f"   Path: {path}")
-
-                            if debug:
-                                print(f"   Total predictors added: {len(single_predictors_paths)}")
-                                print(f"   Shared metadata: {metadata}")
                     except Exception as e:
-                        print(f"Warning: Could not extract target column from training metrics: {e}")
+                        if debug:
+                            print(f"⚠️ Could not get training metrics: {e}")
+
+                # Fallback: try to get target column from job_plan
+                if not target_column:
+                    job_plan = session.get('job_plan', [])
+                    for job in job_plan:
+                        if job.get('job_type') == 'train_single_predictor':
+                            spec = job.get('spec', {})
+                            target_column = spec.get('target_column')
+                            if target_column:
+                                if debug:
+                                    print(f"🔍 Debug: extracted target_column from job_plan: {target_column}")
+                                break
+
+                # Add each predictor even if target_column is None
+                for i, path in enumerate(single_predictors_paths):
+                    predictor_id = self._generate_predictor_id(path, f'multiple_predictor_{i}')
+
+                    predictors[predictor_id] = {
+                        'predictor_id': predictor_id,
+                        'path': path,
+                        'target_column': target_column,  # Can be None
+                        'available': True,
+                        'type': 'single_predictor',  # Each is treated as individual predictor
+                        'predictor_index': i,  # Track original index for compatibility
+                        **metadata  # Include epochs, validation_loss, job_status, etc.
+                    }
+                    if debug:
+                        print(f"✅ Added predictor {i} for target_column: {target_column}")
+                        print(f"   Predictor ID: {predictor_id}")
+                        print(f"   Path: {path}")
+
+                if debug:
+                    print(f"   Total predictors added: {len(single_predictors_paths)}")
+                    if metadata:
+                        print(f"   Shared metadata: {metadata}")
 
         # Fallback: check old format single_predictor field
         single_predictor_path = session.get('single_predictor')
@@ -6920,7 +7239,7 @@ class FeatrixSphereClient:
             target_column: Specific target column to validate, or None for auto-detect
 
         Returns:
-            Validated target column name
+            Validated target column name (or None if it cannot be determined; the server will handle discovery)
 
         Raises:
             ValueError: If target_column is invalid or multiple predictors exist without specification
@@ -6928,7 +7247,8 @@ class FeatrixSphereClient:
         available_predictors = self._get_available_predictors(session_id)
 
         if not available_predictors:
-            raise ValueError(f"No trained predictors found for session {session_id}")
+            # Don't fail - let the server handle discovery. Return the provided target_column or None.
+            return target_column
 
         if target_column is None:
             # Auto-detect: only valid if there's exactly one predictor
@@ -6982,6 +7302,10 @@ class FeatrixSphereClient:
                 # Re-raise validation errors
                 raise e
 
+        # If we couldn't determine the target column (server will handle it), just return records as-is
+        if validated_target_column is None:
+            return records
+
         if validated_target_column in records[0]:
             if print_warning:
                 print(f"⚠️ Warning: Removing target column '{validated_target_column}' from prediction data")
featrixsphere-0.2.3737.dist-info/METADATA → featrixsphere-0.2.4982.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: featrixsphere
-Version: 0.2.3737
+Version: 0.2.4982
 Summary: Transform any CSV into a production-ready ML model in minutes, not months.
 Home-page: https://github.com/Featrix/sphere
 Author: Featrix
featrixsphere-0.2.4982.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+featrixsphere/__init__.py,sha256=6e5gN0j6g8dLuetEEDGVk9ZXasdgsuii0qNF83OQIYU,1888
+featrixsphere/client.py,sha256=HouSLZcEoGjKZy9c0HK0TG-Xog7L_12kCik54TrwzS8,432653
+featrixsphere/test_client.py,sha256=4SiRbib0ms3poK0UpnUv4G0HFQSzidF3Iswo_J2cjLk,11981
+featrixsphere-0.2.4982.dist-info/METADATA,sha256=DI-xhIm1bAk2rz-5XxK-2jdQsySuQGRZPTLN7LcxEiM,16232
+featrixsphere-0.2.4982.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+featrixsphere-0.2.4982.dist-info/entry_points.txt,sha256=QreJeYfD_VWvbEqPmMXZ3pqqlFlJ1qZb-NtqnyhEldc,51
+featrixsphere-0.2.4982.dist-info/top_level.txt,sha256=AyN4wjfzlD0hWnDieuEHX0KckphIk_aC73XCG4df5uU,14
+featrixsphere-0.2.4982.dist-info/RECORD,,
featrixsphere-0.2.3737.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-featrixsphere/__init__.py,sha256=rPPiD3URmePTVO31B5XHkaGVl6zd-r6OIbpBxjqi9Yg,1888
-featrixsphere/client.py,sha256=XaNFHfjogpj3exISZG1Q2SIMn-NewVsUELpzN7-5I-A,416085
-featrixsphere/test_client.py,sha256=4SiRbib0ms3poK0UpnUv4G0HFQSzidF3Iswo_J2cjLk,11981
-featrixsphere-0.2.3737.dist-info/METADATA,sha256=3gUOu7cpZPdzBkw5Z_8H6StVxotvIBCkOXmOWfW951c,16232
-featrixsphere-0.2.3737.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-featrixsphere-0.2.3737.dist-info/entry_points.txt,sha256=QreJeYfD_VWvbEqPmMXZ3pqqlFlJ1qZb-NtqnyhEldc,51
-featrixsphere-0.2.3737.dist-info/top_level.txt,sha256=AyN4wjfzlD0hWnDieuEHX0KckphIk_aC73XCG4df5uU,14
-featrixsphere-0.2.3737.dist-info/RECORD,,