featrixsphere 0.2.1141__py3-none-any.whl → 0.2.1221__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- featrixsphere/__init__.py +1 -1
- featrixsphere/client.py +431 -19
- {featrixsphere-0.2.1141.dist-info → featrixsphere-0.2.1221.dist-info}/METADATA +1 -1
- featrixsphere-0.2.1221.dist-info/RECORD +9 -0
- featrixsphere-0.2.1141.dist-info/RECORD +0 -9
- {featrixsphere-0.2.1141.dist-info → featrixsphere-0.2.1221.dist-info}/WHEEL +0 -0
- {featrixsphere-0.2.1141.dist-info → featrixsphere-0.2.1221.dist-info}/entry_points.txt +0 -0
- {featrixsphere-0.2.1141.dist-info → featrixsphere-0.2.1221.dist-info}/top_level.txt +0 -0
featrixsphere/__init__.py
CHANGED
featrixsphere/client.py
CHANGED
@@ -633,13 +633,97 @@ class FeatrixSphereClient:
             _client=self
         )
 
-    def get_model_card(self, session_id: str, max_retries: int = None) -> Dict[str, Any]:
+    def update_user_metadata(self, session_id: str, metadata: Dict[str, Any], write_mode: str = "merge") -> Dict[str, Any]:
+        """
+        Update user metadata for a session.
+
+        Args:
+            session_id: The session ID to update metadata for
+            metadata: Dictionary of metadata to update (max 32KB total)
+            write_mode: How to update metadata:
+                - "merge" (default): Merge new metadata with existing (existing keys are updated, new keys are added)
+                - "overwrite": Replace all user_metadata with the new dictionary
+
+        Returns:
+            Dictionary containing the updated session information
+
+        Raises:
+            requests.exceptions.HTTPError: If the request fails
+            ValueError: If write_mode is not "merge" or "overwrite"
+
+        Example:
+            >>> # Merge new metadata with existing
+            >>> client.update_user_metadata(
+            ...     session_id="abc123",
+            ...     metadata={"new_key": "value", "existing_key": "updated_value"},
+            ...     write_mode="merge"
+            ... )
+
+            >>> # Replace all metadata
+            >>> client.update_user_metadata(
+            ...     session_id="abc123",
+            ...     metadata={"only_key": "only_value"},
+            ...     write_mode="overwrite"
+            ... )
+        """
+        if write_mode not in ["merge", "overwrite"]:
+            raise ValueError(f"write_mode must be 'merge' or 'overwrite', got '{write_mode}'")
+
+        request_data = {
+            "user_metadata": metadata,
+            "write_mode": write_mode
+        }
+
+        response_data = self._post_json(f"/session/{session_id}/update_user_metadata", request_data)
+        return response_data
+
+    def is_foundation_model_ready(self, session_id: str, max_retries: int = None) -> Tuple[bool, str]:
+        """
+        Check if a foundation model session is ready to use (training completed).
+
+        Args:
+            session_id: The session ID to check
+            max_retries: Maximum number of retries (defaults to client default)
+
+        Returns:
+            Tuple of (is_ready: bool, status_message: str)
+            - is_ready: True if session is done and model card is available
+            - status_message: Human-readable status message
+
+        Example:
+            >>> is_ready, message = client.is_foundation_model_ready("session_123")
+            >>> if not is_ready:
+            ...     print(f"Foundation model not ready: {message}")
+        """
+        try:
+            session_status = self.get_session_status(session_id, max_retries=max_retries)
+
+            if session_status.status in ["done", "DONE"]:
+                # Check if model card exists
+                try:
+                    self.get_model_card(session_id, max_retries=max_retries, check_status_first=False)
+                    return True, "Foundation model is ready"
+                except (requests.exceptions.HTTPError, FileNotFoundError):
+                    return False, "Session is done but model card is not available yet"
+            else:
+                return False, f"Session is still {session_status.status}. Training may still be in progress."
+
+        except requests.exceptions.HTTPError as e:
+            if e.response.status_code == 404:
+                return False, f"Session {session_id} not found"
+            return False, f"Error checking session status: {e}"
+        except Exception as e:
+            return False, f"Error checking foundation model: {e}"
+
+    def get_model_card(self, session_id: str, max_retries: int = None, check_status_first: bool = True) -> Dict[str, Any]:
         """
         Get the model card JSON for a given session.
 
         Args:
             session_id: The session ID to get the model card for
             max_retries: Maximum number of retries (defaults to client default)
+            check_status_first: If True, check session status before fetching model card.
+                Provides better error messages if session is still training.
 
         Returns:
             Dictionary containing the model card JSON data
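The merge semantics above are easiest to see side by side. A minimal sketch of the two write modes, assuming a configured client (the import path is inferred from the file layout, and the session ID, metadata values, and before/after states in comments are made up, following the docstring's description):

```python
from featrixsphere.client import FeatrixSphereClient  # import path assumed

client = FeatrixSphereClient()

# Suppose the session currently carries user_metadata:
#   {"owner": "dana", "stage": "dev"}

# "merge" updates existing keys and adds new ones:
client.update_user_metadata("abc123", {"stage": "prod", "team": "ml"}, write_mode="merge")
# -> {"owner": "dana", "stage": "prod", "team": "ml"}

# "overwrite" replaces the whole dictionary:
client.update_user_metadata("abc123", {"stage": "prod"}, write_mode="overwrite")
# -> {"stage": "prod"}
```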
@@ -647,12 +731,31 @@ class FeatrixSphereClient:
         Raises:
             requests.exceptions.HTTPError: If the request fails
             FileNotFoundError: If the model card doesn't exist (404)
+            ValueError: If session is not ready and check_status_first is True
 
         Example:
             >>> client = FeatrixSphereClient()
             >>> model_card = client.get_model_card("session_123")
             >>> print(model_card["model_details"]["name"])
         """
+        # Check session status first to provide better error messages
+        if check_status_first:
+            try:
+                session_status = self.get_session_status(session_id, max_retries=max_retries)
+                if session_status.status not in ["done", "DONE"]:
+                    raise ValueError(
+                        f"Session {session_id} is not ready (status: {session_status.status}). "
+                        f"Model card is only available after training completes. "
+                        f"Use wait_for_session_completion() to wait for training to finish."
+                    )
+            except requests.exceptions.HTTPError as e:
+                # If we can't get status, continue and let the model_card request fail
+                # This handles cases where the session doesn't exist
+                if e.response.status_code == 404:
+                    raise FileNotFoundError(f"Session {session_id} not found") from e
+                # For other HTTP errors, continue to try model_card request
+                pass
+
         response = self._make_request(
             "GET",
             f"/session/{session_id}/model_card",
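`is_foundation_model_ready` and the new `check_status_first` flag compose into a wait-then-fetch loop. A sketch of that pattern (`wait_for_model_card` is our hypothetical helper, not part of the package; it reuses the `client` from the previous sketch):

```python
import time

def wait_for_model_card(client, session_id, poll_interval=30, timeout=3600):
    """Poll until the foundation model is ready, then fetch its model card."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        is_ready, message = client.is_foundation_model_ready(session_id)
        if is_ready:
            # Readiness already confirmed the session status, so skip the
            # redundant status lookup inside get_model_card.
            return client.get_model_card(session_id, check_status_first=False)
        print(f"Not ready yet: {message}")
        time.sleep(poll_interval)
    raise TimeoutError(f"Foundation model not ready after {timeout}s")
```

For training as a whole, the error message in the diff points at `wait_for_session_completion()`; the loop above just layers the model-card availability check on top.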
@@ -660,6 +763,77 @@ class FeatrixSphereClient:
         )
         return response.json()
 
+    def publish_session(self, session_id: str) -> Dict[str, Any]:
+        """
+        Publish a session by moving it to /sphere/published/<sessionId>.
+        Moves both the session file and output directory.
+
+        Args:
+            session_id: Session ID to publish
+
+        Returns:
+            Response with published_path, output_path, and status
+
+        Example:
+            ```python
+            result = client.publish_session("abc123")
+            print(f"Published to: {result['published_path']}")
+            ```
+        """
+        response_data = self._post_json(f"/compute/session/{session_id}/publish", {})
+        return response_data
+
+    def deprecate_session(self, session_id: str, warning_message: str, expiration_date: str) -> Dict[str, Any]:
+        """
+        Deprecate a published session with a warning message and expiration date.
+        The session remains available until the expiration date.
+
+        Args:
+            session_id: Session ID to deprecate
+            warning_message: Warning message to display about deprecation
+            expiration_date: ISO format date string when session will be removed (e.g., "2025-12-31T23:59:59Z")
+
+        Returns:
+            Response with deprecation status
+
+        Example:
+            ```python
+            from datetime import datetime, timedelta
+
+            expiration = (datetime.now() + timedelta(days=90)).isoformat() + "Z"
+            result = client.deprecate_session(
+                session_id="abc123",
+                warning_message="This session will be removed on 2025-12-31",
+                expiration_date=expiration
+            )
+            ```
+        """
+        data = {
+            "warning_message": warning_message,
+            "expiration_date": expiration_date
+        }
+        response_data = self._post_json(f"/compute/session/{session_id}/deprecate", data)
+        return response_data
+
+    def unpublish_session(self, session_id: str) -> Dict[str, Any]:
+        """
+        Unpublish a session by moving it back from /sphere/published/<sessionId>.
+
+        Args:
+            session_id: Session ID to unpublish
+
+        Returns:
+            Response with unpublish status
+
+        Example:
+            ```python
+            result = client.unpublish_session("abc123")
+            print(f"Status: {result['status']}")
+            ```
+        """
+        response_data = self._post_json(f"/compute/session/{session_id}/unpublish", {})
+        return response_data
+
     def get_sessions_for_org(self, name_prefix: str, max_retries: int = None) -> Dict[str, Any]:
         """
         Get all sessions matching a name prefix across all compute nodes.
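The three new endpoints form a publish, deprecate, unpublish lifecycle. A sketch stitched together from the docstring examples (session ID and dates are illustrative):

```python
from datetime import datetime, timedelta

result = client.publish_session("abc123")
print(f"Published to: {result['published_path']}")

# Give consumers 90 days' notice before removal.
expiration = (datetime.now() + timedelta(days=90)).isoformat() + "Z"
client.deprecate_session(
    session_id="abc123",
    warning_message=f"This session will be removed on {expiration[:10]}",
    expiration_date=expiration,
)

# Or take it back out of /sphere/published entirely.
client.unpublish_session("abc123")
```

Note that the docstring's recipe appends "Z" to a naive local timestamp; `datetime.now(timezone.utc)` would make the expiration genuinely UTC.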
@@ -703,8 +877,8 @@ class FeatrixSphereClient:
         >>> print(f"Model card recreated: {model_card['model_info']['name']}")
         """
         response = self._make_request(
-            "
-            f"/session/{session_id}/model_card",
+            "GET",
+            f"/compute/session/{session_id}/model_card",
             max_retries=max_retries
         )
         return response.json()
@@ -1555,10 +1729,10 @@ class FeatrixSphereClient:
 
     def upload_file_and_create_session(self, file_path: Path, session_name_prefix: str = None, name: str = None, webhooks: Dict[str, str] = None) -> SessionInfo:
         """
-        Upload a CSV file and create a new session.
+        Upload a CSV, Parquet, JSON, or JSONL file and create a new session.
 
         Args:
-            file_path: Path to the CSV file to upload
+            file_path: Path to the CSV, Parquet, JSON, or JSONL file to upload
             session_name_prefix: Optional prefix for the session ID. Session will be named <prefix>-<full-uuid>
             name: Optional name for the embedding space/model (for identification and metadata)
             webhooks: Optional dict with webhook configuration keys (webhook_callback_secret, s3_backup_url, model_id_update_url)
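With the format check relaxed, non-CSV files go through the same entry point unchanged. A sketch (path and names are illustrative; the `session_id` attribute on the returned `SessionInfo` is an assumption):

```python
from pathlib import Path

session = client.upload_file_and_create_session(
    file_path=Path("data/customers.parquet"),
    session_name_prefix="demo",
    name="customers-embedding",
)
print(session.session_id)  # SessionInfo field name assumed
```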
@@ -1622,7 +1796,7 @@ class FeatrixSphereClient:
                                        webhooks: Dict[str, str] = None,
                                        epochs: int = None) -> SessionInfo:
         """
-        Upload a pandas DataFrame or CSV file and create a new session.
+        Upload a pandas DataFrame, CSV file, Parquet file, JSON file, or JSONL file and create a new session.
 
         Special Column: __featrix_train_predictor
         ------------------------------------------
@@ -1630,7 +1804,7 @@ class FeatrixSphereClient:
         which rows are used for single predictor training.
 
         How it works:
-        - Add a boolean column "__featrix_train_predictor" to your DataFrame/CSV before upload
+        - Add a boolean column "__featrix_train_predictor" to your DataFrame/CSV/Parquet/JSON/JSONL before upload
         - Set it to True for rows you want to use for predictor training
         - Set it to False (or any other value) for rows to exclude from predictor training
         - Embedding space training uses ALL rows (ignores this column)
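Since `__featrix_train_predictor` is an ordinary column, wiring it up is one line of pandas. A sketch (the column names are made up, and `upload_df_and_create_session` is our guess at the DataFrame-upload method whose name falls outside this hunk):

```python
import pandas as pd

df = pd.DataFrame({
    "age": [34, 51, 29, 44],
    "churned": [0, 1, 0, 1],
})
# The first three rows feed predictor training; embedding-space
# training still sees all four rows.
df["__featrix_train_predictor"] = [True, True, True, False]

session = client.upload_df_and_create_session(df=df, name="churn-demo")  # method name assumed
```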
@@ -1664,7 +1838,7 @@ class FeatrixSphereClient:
         Args:
             df: pandas DataFrame to upload (optional if file_path is provided)
             filename: Name to give the uploaded file (default: "data.csv")
-            file_path: Path to CSV file to upload (optional if df is provided)
+            file_path: Path to CSV, Parquet, JSON, or JSONL file to upload (optional if df is provided)
             column_overrides: Dict mapping column names to types ("scalar", "set", "free_string", "free_string_list")
             column_types: Alias for column_overrides (for backward compatibility)
             string_list_delimiter: Delimiter for free_string_list columns (default: "|")
@@ -1705,21 +1879,90 @@ class FeatrixSphereClient:
         if not os.path.exists(file_path):
             raise FileNotFoundError(f"File not found: {file_path}")
 
-        # Check if it's a
-
-
+        # Check if it's a supported file type
+        file_ext = file_path.lower()
+        if not file_ext.endswith(('.csv', '.csv.gz', '.parquet', '.json', '.jsonl')):
+            raise ValueError("File must be a CSV, Parquet, JSON, or JSONL file (with .csv, .csv.gz, .parquet, .json, or .jsonl extension)")
 
         print(f"Uploading file: {file_path}")
 
         # Read the file content
         if file_path.endswith('.gz'):
-            # Already gzipped
+            # Already gzipped CSV
             with gzip.open(file_path, 'rb') as f:
                 file_content = f.read()
             upload_filename = os.path.basename(file_path)
             content_type = 'application/gzip'
+        elif file_path.lower().endswith(('.json', '.jsonl')):
+            # JSON/JSONL file - read as DataFrame, convert to CSV, then compress
+            print(f"Reading {'JSONL' if file_path.lower().endswith('.jsonl') else 'JSON'} file...")
+            try:
+                from featrix.neural.input_data_file import featrix_wrap_read_json_file
+                json_df = featrix_wrap_read_json_file(file_path)
+                if json_df is None:
+                    raise ValueError(f"Failed to parse {'JSONL' if file_path.lower().endswith('.jsonl') else 'JSON'} file")
+            except ImportError:
+                # Fallback to pandas if featrix wrapper not available
+                if file_path.lower().endswith('.jsonl'):
+                    # JSONL: one JSON object per line
+                    import json
+                    records = []
+                    with open(file_path, 'r', encoding='utf-8') as f:
+                        for line in f:
+                            if line.strip():
+                                records.append(json.loads(line))
+                    json_df = pd.DataFrame(records)
+                else:
+                    # Regular JSON
+                    json_df = pd.read_json(file_path)
+
+            # Clean NaN values before CSV conversion
+            cleaned_df = json_df.where(pd.notna(json_df), None)
+
+            # Convert to CSV and compress
+            csv_buffer = io.StringIO()
+            cleaned_df.to_csv(csv_buffer, index=False)
+            csv_data = csv_buffer.getvalue().encode('utf-8')
+
+            print(f"Compressing {'JSONL' if file_path.lower().endswith('.jsonl') else 'JSON'} (converted to CSV)...")
+            compressed_buffer = io.BytesIO()
+            with gzip.GzipFile(fileobj=compressed_buffer, mode='wb') as gz:
+                gz.write(csv_data)
+            file_content = compressed_buffer.getvalue()
+            upload_filename = os.path.basename(file_path).replace('.jsonl', '.csv.gz').replace('.json', '.csv.gz')
+            content_type = 'application/gzip'
+
+            original_size = len(csv_data)
+            compressed_size = len(file_content)
+            compression_ratio = (1 - compressed_size / original_size) * 100
+            print(f"Converted {'JSONL' if file_path.lower().endswith('.jsonl') else 'JSON'} to CSV and compressed from {original_size:,} to {compressed_size:,} bytes ({compression_ratio:.1f}% reduction)")
+        elif file_path.lower().endswith('.parquet'):
+            # Parquet file - read as DataFrame, convert to CSV, then compress
+            print("Reading Parquet file...")
+            parquet_df = pd.read_parquet(file_path)
+
+            # Clean NaN values before CSV conversion
+            cleaned_df = parquet_df.where(pd.notna(parquet_df), None)
+
+            # Convert to CSV and compress
+            csv_buffer = io.StringIO()
+            cleaned_df.to_csv(csv_buffer, index=False)
+            csv_data = csv_buffer.getvalue().encode('utf-8')
+
+            print("Compressing Parquet (converted to CSV)...")
+            compressed_buffer = io.BytesIO()
+            with gzip.GzipFile(fileobj=compressed_buffer, mode='wb') as gz:
+                gz.write(csv_data)
+            file_content = compressed_buffer.getvalue()
+            upload_filename = os.path.basename(file_path).replace('.parquet', '.csv.gz')
+            content_type = 'application/gzip'
+
+            original_size = len(csv_data)
+            compressed_size = len(file_content)
+            compression_ratio = (1 - compressed_size / original_size) * 100
+            print(f"Converted Parquet to CSV and compressed from {original_size:,} to {compressed_size:,} bytes ({compression_ratio:.1f}% reduction)")
         else:
-            #
+            # Regular CSV file - read and compress it
             with open(file_path, 'rb') as f:
                 csv_content = f.read()
 
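The JSON/JSONL and Parquet branches above repeat the same DataFrame to CSV to gzip steps. A shared helper, sketched here (not part of the released client), would keep the two paths in sync:

```python
import gzip
import io
import os

import pandas as pd

def dataframe_to_gzipped_csv(df: pd.DataFrame, source_path: str):
    """Return (gzipped CSV bytes, upload filename) for a DataFrame."""
    cleaned = df.where(pd.notna(df), None)  # same NaN cleanup as the diff
    csv_data = cleaned.to_csv(index=False).encode("utf-8")
    buf = io.BytesIO()
    with gzip.GzipFile(fileobj=buf, mode="wb") as gz:
        gz.write(csv_data)
    base = os.path.basename(source_path)
    # Strip a known extension from the end only.
    for ext in (".jsonl", ".json", ".parquet"):
        if base.lower().endswith(ext):
            base = base[: -len(ext)]
            break
    return buf.getvalue(), base + ".csv.gz"
```

Stripping the extension from the end also sidesteps the chained `str.replace` calls in the diff, which would rewrite a matching substring anywhere in the filename, not just at the end.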
@@ -4753,7 +4996,7 @@ class FeatrixSphereClient:
                                    predictor_id: str = None, target_column: str = None,
                                    batch_size: int = 0, learning_rate: float = None,
                                    poll_interval: int = 30, max_poll_time: int = 3600,
-                                   verbose: bool = True) -> Dict[str, Any]:
+                                   verbose: bool = True, webhooks: Dict[str, str] = None) -> Dict[str, Any]:
         """
         Continue training an existing single predictor for more epochs.
         Loads the existing predictor and resumes training from where it left off.
@@ -4768,6 +5011,7 @@ class FeatrixSphereClient:
             poll_interval: Seconds between status checks (default: 30)
             max_poll_time: Maximum time to poll in seconds (default: 3600 = 1 hour)
             verbose: Whether to print status updates (default: True)
+            webhooks: Optional dict with webhook configuration keys (webhook_callback_secret, s3_backup_url, model_id_update_url)
 
         Returns:
             Response with continuation start confirmation or completion status
@@ -4799,6 +5043,8 @@ class FeatrixSphereClient:
             data["target_column"] = target_column
         if learning_rate is not None:
             data["learning_rate"] = learning_rate
+        if webhooks:
+            data["webhooks"] = webhooks
 
         if verbose:
             print(f"🔄 Continuing training for predictor on session {session_id}")
@@ -4888,6 +5134,139 @@ class FeatrixSphereClient:
             print(f"❌ Error starting predictor continuation: {e}")
             raise
 
+    def foundation_model_train_more(self, session_id: str, es_id: str = None, data_passes: int = None,
+                                    epochs: int = None, poll_interval: int = 30, max_poll_time: int = 3600,
+                                    verbose: bool = True, webhooks: Dict[str, str] = None) -> Dict[str, Any]:
+        """
+        Continue training an existing foundation model (embedding space) for more epochs.
+        Loads the existing embedding space and resumes training from where it left off.
+
+        Args:
+            session_id: Session ID containing the trained foundation model
+            es_id: Embedding space ID (optional, uses session's ES if not provided)
+            data_passes: Additional epochs to train (preferred, default: 50)
+            epochs: Additional epochs to train (deprecated, use data_passes instead, for compatibility)
+            poll_interval: Seconds between status checks (default: 30)
+            max_poll_time: Maximum time to poll in seconds (default: 3600 = 1 hour)
+            verbose: Whether to print status updates (default: True)
+            webhooks: Optional dict with webhook configuration keys (webhook_callback_secret, s3_backup_url, model_id_update_url)
+
+        Returns:
+            Response with continuation start confirmation or completion status
+
+        Example:
+            ```python
+            # Continue training for 50 more epochs
+            result = client.foundation_model_train_more(
+                session_id="abc123",
+                data_passes=50
+            )
+            ```
+        """
+        # Support both data_passes and epochs for compatibility
+        if data_passes is None and epochs is None:
+            data_passes = 50  # Default
+        elif data_passes is None:
+            data_passes = epochs  # Use epochs if data_passes not provided
+        # If both provided, data_passes takes precedence
+
+        if data_passes <= 0:
+            raise ValueError("data_passes (or epochs) must be > 0 (specify additional epochs to train)")
+
+        data = {
+            "data_passes": data_passes,
+        }
+
+        if es_id:
+            data["es_id"] = es_id
+        if webhooks:
+            data["webhooks"] = webhooks
+
+        if verbose:
+            print(f"🔄 Continuing training for foundation model on session {session_id}")
+            print(f"   Additional epochs: {data_passes}")
+            if es_id:
+                print(f"   ES ID: {es_id}")
+
+        try:
+            response_data = self._post_json(f"/compute/session/{session_id}/train_foundation_model_more", data)
+
+            if verbose:
+                print(f"✅ Foundation model continuation started: {response_data.get('message')}")
+
+            # Poll for completion if requested
+            if poll_interval > 0 and max_poll_time > 0:
+                import time
+                start_time = time.time()
+                last_status = ""
+
+                while time.time() - start_time < max_poll_time:
+                    try:
+                        session_info = self.get_session_status(session_id)
+                        jobs = session_info.jobs if hasattr(session_info, 'jobs') else {}
+
+                        # Find continuation jobs
+                        es_jobs = {j_id: j for j_id, j in jobs.items()
+                                   if j.get('type') == 'train_es'}
+
+                        if not es_jobs:
+                            if verbose:
+                                print("✅ No continuation jobs found - training may have completed")
+                            break
+
+                        # Check job statuses
+                        running_jobs = [j_id for j_id, j in es_jobs.items() if j.get('status') == 'running']
+                        completed_jobs = [j_id for j_id, j in es_jobs.items() if j.get('status') == 'done']
+                        failed_jobs = [j_id for j_id, j in es_jobs.items() if j.get('status') == 'failed']
+
+                        current_status = f"Running: {len(running_jobs)}, Done: {len(completed_jobs)}, Failed: {len(failed_jobs)}"
+                        if current_status != last_status and verbose:
+                            print(f"📊 Status: {current_status}")
+                        last_status = current_status
+
+                        if not running_jobs and (completed_jobs or failed_jobs):
+                            if completed_jobs:
+                                if verbose:
+                                    print(f"✅ Foundation model continuation completed successfully!")
+                                return {
+                                    "message": "Foundation model continuation completed successfully",
+                                    "session_id": session_id,
+                                    "status": "completed",
+                                    "additional_epochs": data_passes
+                                }
+                            else:
+                                if verbose:
+                                    print(f"❌ Foundation model continuation failed")
+                                return {
+                                    "message": "Foundation model continuation failed",
+                                    "session_id": session_id,
+                                    "status": "failed",
+                                    "failed_jobs": failed_jobs
+                                }
+
+                        time.sleep(poll_interval)
+                    except Exception as poll_error:
+                        if verbose:
+                            print(f"⚠️ Error during polling: {poll_error}")
+                        time.sleep(poll_interval)
+
+                # Timeout
+                if verbose:
+                    print(f"⏱️ Polling timeout reached ({max_poll_time}s)")
+                return {
+                    "message": "Polling timeout",
+                    "session_id": session_id,
+                    "status": "timeout",
+                    "additional_epochs": data_passes
+                }
+
+            return response_data
+
+        except Exception as e:
+            if verbose:
+                print(f"❌ Error starting foundation model continuation: {e}")
+            raise
+
     def _train_single_predictor_with_file(
         self,
         session_id: str,
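A usage sketch for the new continuation API, including the webhook pass-through added in this release (the secret and URLs are placeholders):

```python
webhooks = {
    "webhook_callback_secret": "s3cr3t",
    "s3_backup_url": "s3://my-bucket/backups/",
    "model_id_update_url": "https://example.com/hooks/model-id",
}

result = client.foundation_model_train_more(
    session_id="abc123",
    data_passes=100,     # preferred over the deprecated epochs=
    webhooks=webhooks,
    poll_interval=30,    # 0 disables client-side polling
    max_poll_time=7200,
)
print(result.get("status"))  # "completed", "failed", or "timeout" when polled
```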
@@ -5965,7 +6344,24 @@ class FeatrixSphereClient:
         if not file_path.exists():
             raise FileNotFoundError(f"File not found: {file_path}")
 
-        df = pd.read_csv(file_path)
+        # Support CSV, Parquet, JSON, and JSONL files
+        file_path_str = str(file_path).lower()
+        if file_path_str.endswith('.parquet'):
+            df = pd.read_parquet(file_path)
+        elif file_path_str.endswith('.jsonl'):
+            # JSONL: one JSON object per line
+            import json
+            records = []
+            with open(file_path, 'r', encoding='utf-8') as f:
+                for line in f:
+                    if line.strip():
+                        records.append(json.loads(line))
+            df = pd.DataFrame(records)
+        elif file_path_str.endswith('.json'):
+            # Regular JSON
+            df = pd.read_json(file_path)
+        else:
+            df = pd.read_csv(file_path)
 
         # Convert to JSON Tables format and clean NaNs
         table_data = JSONTablesEncoder.from_dataframe(df)
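This extension dispatch now appears three times in the diff (file upload, file-based predictor training, and prediction). A single reader, sketched below, could serve all three call sites (not in the released client):

```python
import json

import pandas as pd

def read_tabular_file(path) -> pd.DataFrame:
    """Load a CSV, Parquet, JSON, or JSONL file into a DataFrame by extension."""
    name = str(path).lower()
    if name.endswith(".parquet"):
        return pd.read_parquet(path)
    if name.endswith(".jsonl"):
        # JSONL: one JSON object per line, blank lines skipped
        with open(path, "r", encoding="utf-8") as f:
            records = [json.loads(line) for line in f if line.strip()]
        return pd.DataFrame(records)
    if name.endswith(".json"):
        return pd.read_json(path)
    return pd.read_csv(path)
```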
@@ -6119,11 +6515,11 @@ class FeatrixSphereClient:
     def run_csv_predictions(self, session_id: str, csv_file: str, target_column: str = None,
                             sample_size: int = None, remove_target: bool = True) -> Dict[str, Any]:
         """
-        Run predictions on a CSV file with automatic accuracy calculation.
+        Run predictions on a CSV, Parquet, JSON, or JSONL file with automatic accuracy calculation.
 
         Args:
             session_id: ID of session with trained predictor
-            csv_file: Path to CSV file
+            csv_file: Path to CSV, Parquet, JSON, or JSONL file
             target_column: Name of target column (for accuracy calculation)
             sample_size: Number of records to test (None = all records)
             remove_target: Whether to remove target column from prediction input
@@ -6133,8 +6529,24 @@ class FeatrixSphereClient:
         """
        import pandas as pd
 
-        # Load CSV
-        df = pd.read_csv(csv_file)
+        # Load CSV, Parquet, JSON, or JSONL
+        csv_file_lower = csv_file.lower()
+        if csv_file_lower.endswith('.parquet'):
+            df = pd.read_parquet(csv_file)
+        elif csv_file_lower.endswith('.jsonl'):
+            # JSONL: one JSON object per line
+            import json
+            records = []
+            with open(csv_file, 'r', encoding='utf-8') as f:
+                for line in f:
+                    if line.strip():
+                        records.append(json.loads(line))
+            df = pd.DataFrame(records)
+        elif csv_file_lower.endswith('.json'):
+            # Regular JSON
+            df = pd.read_json(csv_file)
+        else:
+            df = pd.read_csv(csv_file)
 
         # Handle target column
         actual_values = None
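Despite the `csv_file` parameter name, the prediction path now accepts the same four formats. A sketch (paths and column names are illustrative):

```python
results = client.run_csv_predictions(
    session_id="abc123",
    csv_file="holdout.parquet",  # parameter name kept for backward compatibility
    target_column="churned",
    sample_size=500,             # None scores every record
)
```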
featrixsphere-0.2.1221.dist-info/RECORD
ADDED
@@ -0,0 +1,9 @@
+featrixsphere/__init__.py,sha256=P6lNvrcNRFDLvHiBlcFeM-UTp0SymQm5msQQHX6W8XU,1888
+featrixsphere/cli.py,sha256=AW9O3vCvCNJ2UxVGN66eRmeN7XLSiHJlvK6JLZ9UJXc,13358
+featrixsphere/client.py,sha256=lwPu6RQVA5zps9DTYrVlPhRv8_8PEdFGI11hezLSEyA,379068
+featrixsphere/test_client.py,sha256=4SiRbib0ms3poK0UpnUv4G0HFQSzidF3Iswo_J2cjLk,11981
+featrixsphere-0.2.1221.dist-info/METADATA,sha256=k8SrUaI9oa3JaTObnLVQGMA-iyoZM7wLdWnzluqh-Ms,16232
+featrixsphere-0.2.1221.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+featrixsphere-0.2.1221.dist-info/entry_points.txt,sha256=QreJeYfD_VWvbEqPmMXZ3pqqlFlJ1qZb-NtqnyhEldc,51
+featrixsphere-0.2.1221.dist-info/top_level.txt,sha256=AyN4wjfzlD0hWnDieuEHX0KckphIk_aC73XCG4df5uU,14
+featrixsphere-0.2.1221.dist-info/RECORD,,
featrixsphere-0.2.1141.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
-featrixsphere/__init__.py,sha256=FMxe64cn4iu9Ce5UDkOAtWZQMeWSijwX-tsiTDvblkM,1888
-featrixsphere/cli.py,sha256=AW9O3vCvCNJ2UxVGN66eRmeN7XLSiHJlvK6JLZ9UJXc,13358
-featrixsphere/client.py,sha256=TsiV-nr0VbBS1jJfidk5zrhOx6StolKsSn_txH0wmmg,358958
-featrixsphere/test_client.py,sha256=4SiRbib0ms3poK0UpnUv4G0HFQSzidF3Iswo_J2cjLk,11981
-featrixsphere-0.2.1141.dist-info/METADATA,sha256=27KEfgeXQqUNAlO3HIFhYkJU43YN3RdCjTJ_-viNJow,16232
-featrixsphere-0.2.1141.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-featrixsphere-0.2.1141.dist-info/entry_points.txt,sha256=QreJeYfD_VWvbEqPmMXZ3pqqlFlJ1qZb-NtqnyhEldc,51
-featrixsphere-0.2.1141.dist-info/top_level.txt,sha256=AyN4wjfzlD0hWnDieuEHX0KckphIk_aC73XCG4df5uU,14
-featrixsphere-0.2.1141.dist-info/RECORD,,
{featrixsphere-0.2.1141.dist-info → featrixsphere-0.2.1221.dist-info}/WHEEL
File without changes
{featrixsphere-0.2.1141.dist-info → featrixsphere-0.2.1221.dist-info}/entry_points.txt
File without changes
{featrixsphere-0.2.1141.dist-info → featrixsphere-0.2.1221.dist-info}/top_level.txt
File without changes