fundamental-client 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fundamental/__init__.py +34 -0
- fundamental/clients/__init__.py +7 -0
- fundamental/clients/base.py +37 -0
- fundamental/clients/ec2.py +37 -0
- fundamental/clients/fundamental.py +20 -0
- fundamental/config.py +138 -0
- fundamental/constants.py +41 -0
- fundamental/deprecated.py +43 -0
- fundamental/estimator/__init__.py +16 -0
- fundamental/estimator/base.py +263 -0
- fundamental/estimator/classification.py +46 -0
- fundamental/estimator/nexus_estimator.py +120 -0
- fundamental/estimator/regression.py +22 -0
- fundamental/exceptions.py +78 -0
- fundamental/models/__init__.py +4 -0
- fundamental/models/generated.py +431 -0
- fundamental/services/__init__.py +25 -0
- fundamental/services/feature_importance.py +172 -0
- fundamental/services/inference.py +283 -0
- fundamental/services/models.py +186 -0
- fundamental/utils/__init__.py +0 -0
- fundamental/utils/data.py +437 -0
- fundamental/utils/http.py +294 -0
- fundamental/utils/polling.py +97 -0
- fundamental/utils/safetensors_deserialize.py +98 -0
- fundamental_client-0.2.3.dist-info/METADATA +241 -0
- fundamental_client-0.2.3.dist-info/RECORD +29 -0
- fundamental_client-0.2.3.dist-info/WHEEL +4 -0
- fundamental_client-0.2.3.dist-info/licenses/LICENSE +201 -0
fundamental/__init__.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Fundamental Python SDK for tabular machine learning."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
from fundamental.clients import BaseClient, Fundamental, FundamentalEC2Client
|
|
6
|
+
|
|
7
|
+
# Deprecated aliases
|
|
8
|
+
from fundamental.deprecated import FTMClassifier, FTMRegressor
|
|
9
|
+
from fundamental.estimator import NEXUSClassifier, NEXUSRegressor
|
|
10
|
+
|
|
11
|
+
# Must match the published distribution version (this wheel is fundamental-client 0.2.3);
# the previous literal "0.1.4" was stale.
__version__ = "0.2.3"
|
|
12
|
+
|
|
13
|
+
_global_client: Optional[BaseClient] = None
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def set_client(client: BaseClient) -> None:
    """Store ``client`` in the module-level ``_global_client`` slot.

    Args:
        client: Client instance to keep as the package-wide default.
    """
    # Rebind the module global rather than creating a function-local name.
    global _global_client
    _global_client = client
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_client() -> BaseClient:
    """Return the stored global client, or a new ``Fundamental`` if none is set.

    When ``_global_client`` is falsy (e.g. still ``None``) a fresh
    ``Fundamental()`` is constructed on each call; it is not stored.
    """
    registered = _global_client
    if registered:
        return registered
    return Fundamental()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
__all__ = [
|
|
26
|
+
"FTMClassifier",
|
|
27
|
+
"FTMRegressor",
|
|
28
|
+
"Fundamental",
|
|
29
|
+
"FundamentalEC2Client",
|
|
30
|
+
"NEXUSClassifier",
|
|
31
|
+
"NEXUSRegressor",
|
|
32
|
+
"get_client",
|
|
33
|
+
"set_client",
|
|
34
|
+
]
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
"""Client implementations for Fundamental API."""
|
|
2
|
+
|
|
3
|
+
from fundamental.clients.base import BaseClient
|
|
4
|
+
from fundamental.clients.ec2 import FundamentalEC2Client
|
|
5
|
+
from fundamental.clients.fundamental import Fundamental
|
|
6
|
+
|
|
7
|
+
__all__ = ["BaseClient", "Fundamental", "FundamentalEC2Client"]
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""Base client for API interactions."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from typing import TYPE_CHECKING, Dict, Optional
|
|
5
|
+
|
|
6
|
+
from fundamental.config import Config
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from fundamental.services.models import ModelsService
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BaseClient(ABC):
    """Shared foundation for every Fundamental API client.

    Concrete clients build a ``Config`` and pass it to this initializer,
    which stores it and attaches a ``ModelsService`` bound to the client.
    """

    # Configuration object handed to ``__init__``.
    config: Config
    # Service instance created in ``__init__`` with ``client=self``.
    models: "ModelsService"
    # Trace dictionary; ``None`` until ``_set_trace_dict`` stores a value.
    _trace_dict: Optional[Dict[str, str]]

    def __init__(self, *, config: Config) -> None:
        """Store ``config`` and create the ``models`` service.

        Args:
            config: Configuration for this client (keyword-only).
        """
        # Imported inside the method to avoid a circular dependency at runtime.
        from fundamental.services.models import ModelsService as _ModelsService

        self.config = config
        self.models = _ModelsService(client=self)
        self._trace_dict = None

    def get_trace_dict(self) -> Optional[Dict[str, str]]:
        """Return the trace dictionary currently held by the client, if any."""
        return self._trace_dict

    def _set_trace_dict(self, trace_dict: Optional[Dict[str, str]]) -> None:
        """Replace the held trace dictionary (internal use only)."""
        self._trace_dict = trace_dict
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""Fundamental EC2 client for private deployment with AWS SigV4 authentication."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Optional
|
|
4
|
+
|
|
5
|
+
from fundamental.clients.base import BaseClient
|
|
6
|
+
from fundamental.config import Config
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class FundamentalEC2Client(BaseClient):
    """Fundamental API client for private EC2 deployments (AWS SigV4 auth).

    The client always builds its ``Config`` with ``use_sigv4_auth=True``, so
    requests are meant to be authenticated through AWS credentials / IAM roles
    rather than an API key.
    """

    def __init__(
        self,
        aws_region: str,
        api_url: Optional[str] = None,
        **kwargs: Any,  # noqa: ANN401
    ):
        """Build a SigV4-enabled ``Config`` and initialize the base client.

        Args:
            aws_region: AWS region for SigV4 signing (required).
            api_url: Base URL of the private deployment API server
                (API Gateway URL). When omitted, an empty string is passed to
                ``Config``.
            **kwargs: Extra configuration options forwarded to ``Config``.
        """
        resolved_url = api_url if api_url else ""
        super().__init__(
            config=Config(
                api_url=resolved_url,
                use_sigv4_auth=True,
                aws_region=aws_region,
                **kwargs,
            )
        )
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Fundamental client implementation."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from fundamental.clients.base import BaseClient
|
|
6
|
+
from fundamental.config import Config
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Fundamental(BaseClient):
    """Default client for talking to the Fundamental API."""

    def __init__(
        self,
        **kwargs: Any,  # noqa: ANN401
    ):
        """Create the client from ``Config`` keyword options.

        Args:
            **kwargs: Passed unchanged to ``Config``.
        """
        super().__init__(config=Config(**kwargs))
|
fundamental/config.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
"""Configuration management for Fundamental SDK."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from importlib.metadata import version
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
from fundamental.constants import (
|
|
9
|
+
API_URL,
|
|
10
|
+
COMPLETE_MULTIPART_UPLOAD_PATH,
|
|
11
|
+
DEFAULT_DOWNLOAD_RESULT_TIMEOUT_SECONDS,
|
|
12
|
+
DEFAULT_FEATURE_IMPORTANCE_TIMEOUT_SECONDS,
|
|
13
|
+
DEFAULT_FIT_TIMEOUT_SECONDS,
|
|
14
|
+
DEFAULT_PREDICT_TIMEOUT_SECONDS,
|
|
15
|
+
DEFAULT_RETRIES_COUNT,
|
|
16
|
+
DEFAULT_TIMEOUT_SECONDS,
|
|
17
|
+
FEATURE_IMPORTANCE_MODEL_METADATA_GENERATE_PATH,
|
|
18
|
+
FEATURE_IMPORTANCE_PATH,
|
|
19
|
+
FEATURE_IMPORTANCE_STATUS_PATH,
|
|
20
|
+
FIT_MODEL_METADATA_GENERATE_PATH,
|
|
21
|
+
FIT_PATH,
|
|
22
|
+
FIT_STATUS_PATH,
|
|
23
|
+
MODEL_MANAGEMENT_PATH,
|
|
24
|
+
PREDICT_MODEL_METADATA_GENERATE_PATH,
|
|
25
|
+
PREDICT_PATH,
|
|
26
|
+
PREDICT_STATUS_PATH,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass(frozen=True)
class Config:
    """Immutable configuration for Fundamental SDK.

    Attributes:
        api_key: API key for authentication (defaults to FUNDAMENTAL_API_KEY env var)
        api_url: Base URL for the API (defaults to FUNDAMENTAL_API_URL env var or API_URL)
        retries: Number of retries for failed requests
        timeout: Default timeout in seconds
        fit_timeout: Timeout for fit operations
        predict_timeout: Timeout for predict operations
        feature_importance_timeout: Timeout for feature importance operations
        download_prediction_result_timeout: Timeout for downloading prediction results
        download_feature_importance_result_timeout: Timeout for downloading feature
            importance results
        use_sigv4_auth: Use AWS SigV4 authentication instead of API key (for private deployment)
        aws_region: AWS region for SigV4 signing (required when use_sigv4_auth=True)
    """

    api_key: str = ""
    api_url: str = ""
    retries: int = DEFAULT_RETRIES_COUNT
    timeout: int = DEFAULT_TIMEOUT_SECONDS
    fit_timeout: int = DEFAULT_FIT_TIMEOUT_SECONDS
    predict_timeout: int = DEFAULT_PREDICT_TIMEOUT_SECONDS
    feature_importance_timeout: int = DEFAULT_FEATURE_IMPORTANCE_TIMEOUT_SECONDS
    download_prediction_result_timeout: int = DEFAULT_DOWNLOAD_RESULT_TIMEOUT_SECONDS
    download_feature_importance_result_timeout: int = DEFAULT_DOWNLOAD_RESULT_TIMEOUT_SECONDS
    use_sigv4_auth: bool = False
    aws_region: str = ""

    def __post_init__(self) -> None:
        """Fill unset fields from the environment and validate the auth settings.

        Raises:
            ValueError: If SigV4 auth is enabled without ``aws_region``, or if
                API-key auth is used and no API key is available.
        """
        # The dataclass is frozen, so resolved values must be written through
        # object.__setattr__.
        # Fall back to "" (not None) so the field always honours its ``str`` type,
        # including in SigV4 mode where a missing key is not an error.
        api_key: str = self.api_key or os.getenv("FUNDAMENTAL_API_KEY") or ""
        object.__setattr__(self, "api_key", api_key)
        url: str = self.api_url or os.getenv("FUNDAMENTAL_API_URL") or API_URL
        object.__setattr__(self, "api_url", url)

        if self.use_sigv4_auth:
            # Validate AWS region is provided when using SigV4 auth
            if not self.aws_region:
                raise ValueError(
                    "AWS region is required when using SigV4 authentication. "
                    "Please provide it via aws_region parameter."
                )
        else:
            # Validate API key when not using SigV4 auth
            _ = self.get_api_key()  # quick fail on missing api keys

    def get_api_key(self) -> str:
        """Get the API key, raising an error if not set.

        Raises:
            ValueError: If no API key is configured.
        """
        if not self.api_key:
            raise ValueError(
                "API key is required. Please provide it via:\n"
                "1. Fundamental(api_key='your-key')\n"
                "2. Set FUNDAMENTAL_API_KEY environment variable"
            )
        return self.api_key

    def _url_for(self, path: str) -> str:
        """Append ``path`` to ``api_url`` after stripping trailing slashes from the base."""
        return f"{self.api_url.rstrip('/')}{path}"

    def get_full_fit_url(self) -> str:
        """Get the complete URL for the fit endpoint."""
        return self._url_for(FIT_PATH)

    def get_full_predict_url(self) -> str:
        """Get the complete URL for the predict endpoint."""
        return self._url_for(PREDICT_PATH)

    def get_full_model_management_url(self) -> str:
        """Get the complete URL for the model management endpoint."""
        return self._url_for(MODEL_MANAGEMENT_PATH)

    def get_full_fit_model_metadata_generate_url(self) -> str:
        """Get the complete URL for the model metadata generation endpoint."""
        return self._url_for(FIT_MODEL_METADATA_GENERATE_PATH)

    def get_full_predict_model_metadata_generate_url(self) -> str:
        """Get the complete URL for the predict metadata generation endpoint."""
        return self._url_for(PREDICT_MODEL_METADATA_GENERATE_PATH)

    def get_full_complete_multipart_upload_url(self) -> str:
        """Get the complete URL for the complete multipart upload endpoint."""
        return self._url_for(COMPLETE_MULTIPART_UPLOAD_PATH)

    def get_full_fit_status_url(self) -> str:
        """Get the complete URL for the fit status endpoint."""
        return self._url_for(FIT_STATUS_PATH)

    def get_full_predict_status_url(self) -> str:
        """Get the complete URL for the predict status endpoint."""
        return self._url_for(PREDICT_STATUS_PATH)

    def get_full_feature_importance_url(self) -> str:
        """Get the complete URL for the feature importance endpoint."""
        return self._url_for(FEATURE_IMPORTANCE_PATH)

    def get_full_feature_importance_status_url(self) -> str:
        """Get the complete URL for the feature importance status endpoint."""
        return self._url_for(FEATURE_IMPORTANCE_STATUS_PATH)

    def get_full_feature_importance_model_metadata_generate_url(self) -> str:
        """Get the complete URL for the feature importance metadata generation endpoint."""
        return self._url_for(FEATURE_IMPORTANCE_MODEL_METADATA_GENERATE_PATH)

    def get_version(self) -> str:
        """Get the SDK version.

        Returns:
            The installed distribution's version string, or ``"unknown"`` when
            no matching distribution metadata can be read.
        """
        # The wheel is published as "fundamental-client"; "fundamental" is only
        # the import package name. The old name is kept as a fallback so any
        # environment where that lookup used to succeed keeps working.
        for dist_name in ("fundamental-client", "fundamental"):
            try:
                return version(dist_name)
            except Exception:
                # Best-effort lookup: try the next candidate name.
                continue
        return "unknown"
|
fundamental/constants.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# Base URL and endpoint paths
API_URL = "https://api.fundamental.tech"
FIT_PATH = "/api/v1/model/fit"
PREDICT_PATH = "/api/v1/model/predict"
FIT_STATUS_PATH = "/api/v1/model/fit/status"
PREDICT_STATUS_PATH = "/api/v1/model/predict/status"
FEATURE_IMPORTANCE_PATH = "/api/v1/model/feature-importance"
FEATURE_IMPORTANCE_STATUS_PATH = "/api/v1/model/feature-importance/status"
FEATURE_IMPORTANCE_MODEL_METADATA_GENERATE_PATH = "/api/v1/model/feature-importance/create-metadata"
MODEL_MANAGEMENT_PATH = "/api/v1/model-management/trained-models"
FIT_MODEL_METADATA_GENERATE_PATH = "/api/v1/model/fit/create-metadata"
PREDICT_MODEL_METADATA_GENERATE_PATH = "/api/v1/model/predict/create-metadata"
COMPLETE_MULTIPART_UPLOAD_PATH = "/api/v1/model/complete-multipart-upload"

# Private unit helper so the long durations below read as hours.
_SECONDS_PER_HOUR = 60 * 60

DEFAULT_TIMEOUT_SECONDS = 30
DEFAULT_RETRIES_COUNT = 1

# Shared default constants
DEFAULT_POLLING_INTERVAL_SECONDS = 2
DEFAULT_SUBMIT_REQUEST_TIMEOUT_SECONDS = 30
DEFAULT_DOWNLOAD_RESULT_TIMEOUT_SECONDS = 3 * _SECONDS_PER_HOUR  # 3 hours
DEFAULT_MODEL_METADATA_UPLOAD_TIMEOUT_SECONDS = 1 * _SECONDS_PER_HOUR  # 1 hour
DEFAULT_COMPLETE_MULTIPART_UPLOAD_TIMEOUT_SECONDS = 30
DEFAULT_MODEL_METADATA_GENERATE_TIMEOUT_SECONDS = 30

# Predict-specific
DEFAULT_PREDICT_POLLING_REQUESTS_WITHOUT_DELAY = 100
DEFAULT_PREDICT_TIMEOUT_SECONDS = 1 * _SECONDS_PER_HOUR  # 1 hour

# Feature importance-specific
DEFAULT_FEATURE_IMPORTANCE_TIMEOUT_SECONDS = 24 * _SECONDS_PER_HOUR  # 24 hours

# Fit-specific
DEFAULT_FIT_TIMEOUT_SECONDS = 3 * _SECONDS_PER_HOUR  # 3 hours


# Names of the uploaded data files
X_TRAIN_FILE_NAME = "x_train_file"
Y_TRAIN_FILE_NAME = "y_train_file"
X_TEST_FILE_NAME = "x_test_file"

# EC2/SigV4 authentication
SIGV4_SERVICE_NAME = "execute-api"
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Deprecated aliases for backward compatibility.
|
|
3
|
+
|
|
4
|
+
This module contains deprecated estimator class aliases that will be removed in a future version.
|
|
5
|
+
Estimator classes emit FutureWarning when instantiated.
|
|
6
|
+
|
|
7
|
+
Note: FTMError alias is in fundamental.exceptions (not here) to avoid circular imports.
|
|
8
|
+
|
|
9
|
+
TO REMOVE:
|
|
10
|
+
1. Delete this file
|
|
11
|
+
2. Remove imports from fundamental/__init__.py
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import warnings
|
|
15
|
+
from typing import Literal
|
|
16
|
+
|
|
17
|
+
from fundamental.estimator.classification import NEXUSClassifier
|
|
18
|
+
from fundamental.estimator.regression import NEXUSRegressor
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _deprecation_warning(old_name: str, new_name: str) -> None:
    """Warn (``FutureWarning``) that ``old_name`` has been renamed to ``new_name``.

    Args:
        old_name: Deprecated name shown in the message.
        new_name: Replacement name shown in the message.
    """
    message = f"{old_name} is deprecated. Use {new_name} instead."
    # stacklevel=3 attributes the warning two frames above this helper, i.e.
    # to whoever called the function that invoked it.
    warnings.warn(message, category=FutureWarning, stacklevel=3)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class FTMClassifier(NEXUSClassifier):
    """Deprecated: Use NEXUSClassifier instead."""

    def __init__(self, mode: Literal["quality", "speed"] = "quality"):
        """Emit the deprecation warning, then initialize as ``NEXUSClassifier``.

        Args:
            mode: Forwarded unchanged to ``NEXUSClassifier.__init__``.
        """
        # Warn first so the FutureWarning is raised on every instantiation.
        _deprecation_warning(old_name="FTMClassifier", new_name="NEXUSClassifier")
        super().__init__(mode=mode)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class FTMRegressor(NEXUSRegressor):
    """Deprecated: Use NEXUSRegressor instead."""

    def __init__(self, mode: Literal["quality", "speed"] = "quality"):
        """Emit the deprecation warning, then initialize as ``NEXUSRegressor``.

        Args:
            mode: Forwarded unchanged to ``NEXUSRegressor.__init__``.
        """
        # Warn first so the FutureWarning is raised on every instantiation.
        _deprecation_warning(old_name="FTMRegressor", new_name="NEXUSRegressor")
        super().__init__(mode=mode)
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""
|
|
2
|
+
NEXUS Estimators Module
|
|
3
|
+
|
|
4
|
+
This module provides scikit-learn compatible estimators for the NEXUS foundation
|
|
5
|
+
model, organized by task type.
|
|
6
|
+
|
|
7
|
+
Note: Deprecated FTMClassifier/FTMRegressor aliases are available from `fundamental` package.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from fundamental.estimator.classification import NEXUSClassifier
|
|
11
|
+
from fundamental.estimator.regression import NEXUSRegressor
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"NEXUSClassifier",
|
|
15
|
+
"NEXUSRegressor",
|
|
16
|
+
]
|
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base classes for NEXUS estimators.
|
|
3
|
+
|
|
4
|
+
This module contains
|
|
5
|
+
the abstract BaseNEXUSEstimator class that provides common functionality
|
|
6
|
+
for both classification and regression tasks.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import logging
|
|
10
|
+
from abc import ABC
|
|
11
|
+
from typing import Any, ClassVar, Literal, Optional
|
|
12
|
+
|
|
13
|
+
import numpy.typing as npt
|
|
14
|
+
from sklearn.base import BaseEstimator
|
|
15
|
+
from sklearn.utils.validation import check_is_fitted
|
|
16
|
+
from typing_extensions import Self
|
|
17
|
+
|
|
18
|
+
import fundamental
|
|
19
|
+
from fundamental.clients import BaseClient
|
|
20
|
+
from fundamental.services import poll_fit_result, remote_fit, remote_predict, submit_fit_task
|
|
21
|
+
from fundamental.services.models import ModelsService
|
|
22
|
+
from fundamental.utils.data import (
|
|
23
|
+
XType,
|
|
24
|
+
YType,
|
|
25
|
+
check_n_features_compat,
|
|
26
|
+
validate_data,
|
|
27
|
+
validate_inputs_type,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class BaseNEXUSEstimator(BaseEstimator, ABC):  # type: ignore[misc]
    """
    Abstract base class for NEXUS estimators.

    Provides common functionality and interface for both classification
    and regression estimators. Fitting and prediction are delegated to the
    ``fundamental.services`` helpers (``remote_fit``, ``submit_fit_task``,
    ``poll_fit_result``, ``remote_predict``) using the client returned by
    ``fundamental.get_client()``.
    """

    # Task identifier sent with fit requests; set by concrete subclasses.
    _task_type: ClassVar[Literal["classification", "regression"]]

    def __init__(
        self,
        mode: Literal["quality", "speed"] = "quality",
    ):
        """Initialize the estimator.

        Args:
            mode: Training mode - "quality" for best performance, "speed" for faster training.
        """
        self.mode: Literal["quality", "speed"] = mode

    def _get_client(self) -> BaseClient:
        """Return the client provided by ``fundamental.get_client()``."""
        return fundamental.get_client()

    def _set_n_features_in(self, X: XType) -> None:
        """Set n_features_in_ attribute for sklearn compatibility.

        Args:
            X: Input features array or dataframe.
        """
        self.n_features_in_ = X.shape[1]

    def load_model(self, trained_model_id: str) -> Self:
        """
        Load a trained model and mark this estimator as fitted.

        Parameters
        ----------
        trained_model_id : str
            Identifier of the trained model to load.

        Returns
        -------
        Self
            Self for method chaining.
        """
        self._load_trained_model(trained_model_id)
        self.fitted_ = True
        return self

    def fit(self, X: XType, y: YType) -> Self:
        """
        Fit the model to training data.

        Parameters
        ----------
        X : XType
            Training features as numpy array, pandas DataFrame.
        y : YType
            Training targets as numpy array, pandas Series.

        Returns
        -------
        Self
            Self for method chaining.
        """
        validate_inputs_type(X=X, y=y)
        validate_data(X=X, y=y)

        response = remote_fit(
            X=X,
            y=y,
            task=type(self)._task_type,
            mode=self.mode,
            client=self._get_client(),
        )
        self._load_estimator_fields(response.estimator_fields)
        self.trained_model_id_ = response.trained_model_id
        self.fitted_ = True
        self._set_n_features_in(X)
        return self

    def set_attributes(
        self,
        attributes: dict[str, str],
    ) -> Self:
        """
        Set attributes on the fitted model.

        Parameters
        ----------
        attributes : dict[str, str]
            Dictionary of key-value pairs to set as model attributes.

        Returns
        -------
        Self
            Self for method chaining.

        Raises
        ------
        NotFittedError
            If the model has not been fitted yet.

        Examples
        --------
        >>> clf.set_attributes({"stage": "prod", "owner": "lital"})
        """
        check_is_fitted(self)
        assert self.trained_model_id_ is not None

        models_service = ModelsService(client=self._get_client())
        models_service.set_attributes(
            model_id=self.trained_model_id_,
            attributes=attributes,
        )
        return self

    def submit_fit_task(self, X: XType, y: YType) -> str:
        """
        Submit a fit task without waiting for completion.

        Parameters
        ----------
        X : XType
            Training features as numpy array, pandas DataFrame.
        y : YType
            Training targets as numpy array, pandas Series.

        Returns
        -------
        str
            Task ID for polling with poll_fit_result.
        """
        validate_inputs_type(X=X, y=y)
        validate_data(X=X, y=y)
        result = submit_fit_task(
            X=X,
            y=y,
            task=type(self)._task_type,
            mode=self.mode,
            client=self._get_client(),
        )
        # Store for poll_fit_result to use
        self._pending_trained_model_id = result.trained_model_id
        self._pending_n_features_in = X.shape[1]
        return result.task_id

    def poll_fit_result(self, task_id: str) -> Optional[Self]:
        """
        Check fit task status and load model if complete.

        Parameters
        ----------
        task_id : str
            The task ID returned by submit_fit_task.

        Returns
        -------
        Optional[Self]
            Self if fitting completed (for method chaining), None if still in progress.

        Raises
        ------
        AttributeError
            If this estimator has no pending fit task, i.e. submit_fit_task
            was not called on it or the pending task was already completed.
        """
        try:
            pending_model_id = self._pending_trained_model_id
            pending_n_features = self._pending_n_features_in
        except AttributeError:
            # Same exception type as the bare attribute access used to raise,
            # but with a message that tells the caller what to do.
            raise AttributeError(
                "No pending fit task on this estimator. Call submit_fit_task() "
                "on this instance before poll_fit_result()."
            ) from None

        result = poll_fit_result(
            task_id=task_id,
            trained_model_id=pending_model_id,
            client=self._get_client(),
        )
        if result is None:
            return None

        self._load_estimator_fields(result.estimator_fields)
        self.trained_model_id_ = result.trained_model_id
        self.fitted_ = True
        self.n_features_in_ = pending_n_features
        # Clean up pending attributes
        del self._pending_n_features_in
        del self._pending_trained_model_id
        return self

    def predict(self, X: XType) -> npt.NDArray[Any]:
        """
        Return model predictions for ``X``.

        Parameters
        ----------
        X : XType
            Input features as numpy array, pandas DataFrame.

        Returns
        -------
        np.ndarray
            Model predictions.

        Raises
        ------
        NotFittedError
            If the model has not been fitted yet.
        """
        return self._predict(X=X, output_type="preds")

    def _predict(
        self,
        X: XType,
        output_type: Literal["preds", "probas"],
    ) -> npt.NDArray[Any]:
        """Internal prediction method with configurable output type."""
        validate_inputs_type(X=X)
        validate_data(X=X)
        check_is_fitted(self)

        # Validate n_features_in_ for sklearn compatibility
        check_n_features_compat(self, X, reset=False)

        assert self.trained_model_id_ is not None
        return remote_predict(
            X=X,
            output_type=output_type,
            trained_model_id=self.trained_model_id_,
            client=self._get_client(),
        )

    def _load_trained_model(self, model_id: str) -> None:
        """Load ``model_id`` through ``ModelsService`` and copy its estimator fields.

        Any failure is logged and re-raised unchanged.
        """
        try:
            models_service = ModelsService(client=self._get_client())
            response = models_service.load(
                trained_model_id=model_id,
            )
            self._load_estimator_fields(response.estimator_fields)
            self.trained_model_id_ = model_id
        except Exception as e:
            logger.error("Failed to load model %s: %s", model_id, e)
            # Bare raise re-raises the active exception with its traceback intact.
            raise

    def _load_estimator_fields(self, estimator_fields: dict[str, Any]) -> None:
        """Copy ``estimator_fields`` onto this estimator as attributes.

        A field that cannot be set (``AttributeError``) is skipped with a warning.
        """
        for field, value in estimator_fields.items():
            try:
                # Keep private attrs as-is, add underscore to public attrs (sklearn convention)
                field_name = field if field.endswith("_") or field.startswith("_") else f"{field}_"
                setattr(self, field_name, value)
            except AttributeError:
                logger.warning(f"Field {field} is not a valid attribute in the current estimator")
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""NEXUS Classification estimator."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Any, Literal
|
|
5
|
+
|
|
6
|
+
import numpy.typing as npt
|
|
7
|
+
from sklearn.base import ClassifierMixin
|
|
8
|
+
|
|
9
|
+
from fundamental.estimator.nexus_estimator import NEXUSEstimator
|
|
10
|
+
from fundamental.utils.data import XType
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class NEXUSClassifier(ClassifierMixin, NEXUSEstimator):
    """NEXUS estimator specialised for classification tasks."""

    _task_type = "classification"  # type: ignore[assignment]

    def __init__(self, mode: Literal["quality", "speed"] = "quality"):
        """Initialize the classifier.

        Args:
            mode: Passed unchanged to the parent initializer.
        """
        super().__init__(mode=mode)

    def predict_proba(self, X: XType) -> npt.NDArray[Any]:
        """
        Return class probabilities for ``X``.

        Parameters
        ----------
        X : XType
            Input features as numpy array, pandas DataFrame.

        Returns
        -------
        np.ndarray
            Model probabilities.

        Raises
        ------
        NotFittedError
            If the model has not been fitted yet.
        """
        # Same prediction path as predict(), but requesting the "probas" output.
        probabilities = self._predict(X=X, output_type="probas")
        return probabilities
|