fundamental-client 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,34 @@
1
+ """Fundamental Python SDK for tabular machine learning."""
2
+
3
+ from typing import Optional
4
+
5
+ from fundamental.clients import BaseClient, Fundamental, FundamentalEC2Client
6
+
7
+ # Deprecated aliases
8
+ from fundamental.deprecated import FTMClassifier, FTMRegressor
9
+ from fundamental.estimator import NEXUSClassifier, NEXUSRegressor
10
+
11
# NOTE(review): the wheel this file ships in is labelled 0.2.3 — confirm this
# constant is still being kept in sync with the published package version.
__version__ = "0.1.4"

# Process-wide client registered through set_client(); None until one is set.
_global_client: Optional[BaseClient] = None
14
+
15
+
16
def set_client(client: BaseClient) -> None:
    """Register ``client`` as the module-wide client.

    The value is stored in the module-level ``_global_client`` variable,
    replacing whatever was registered before.

    Args:
        client: The client instance to register.
    """
    global _global_client
    _global_client = client
19
+
20
+
21
def get_client() -> BaseClient:
    """Return the registered module-wide client, or a default one.

    Returns:
        The client stored in ``_global_client`` when one has been registered;
        otherwise a newly constructed ``Fundamental()`` instance. The fallback
        instance is built on every call and is not stored.
    """
    # Compare against None explicitly so that a registered client is never
    # discarded merely because it happens to evaluate as falsy.
    if _global_client is not None:
        return _global_client
    return Fundamental()
23
+
24
+
25
# Names exported by `from fundamental import *`: the deprecated FTM* aliases,
# the two client classes, the NEXUS estimators and the global-client helpers.
__all__ = [
    "FTMClassifier",
    "FTMRegressor",
    "Fundamental",
    "FundamentalEC2Client",
    "NEXUSClassifier",
    "NEXUSRegressor",
    "get_client",
    "set_client",
]
@@ -0,0 +1,7 @@
1
+ """Client implementations for Fundamental API."""
2
+
3
+ from fundamental.clients.base import BaseClient
4
+ from fundamental.clients.ec2 import FundamentalEC2Client
5
+ from fundamental.clients.fundamental import Fundamental
6
+
7
# Client classes re-exported by this package.
__all__ = ["BaseClient", "Fundamental", "FundamentalEC2Client"]
@@ -0,0 +1,37 @@
1
+ """Base client for API interactions."""
2
+
3
+ from abc import ABC
4
+ from typing import TYPE_CHECKING, Dict, Optional
5
+
6
+ from fundamental.config import Config
7
+
8
+ if TYPE_CHECKING:
9
+ from fundamental.services.models import ModelsService
10
+
11
+
12
class BaseClient(ABC):
    """Common foundation shared by every API client implementation.

    Concrete clients subclass this class, build a ``Config`` and pass it to
    ``__init__``.
    """

    config: Config
    models: "ModelsService"
    _trace_dict: Optional[Dict[str, str]]

    def __init__(self, *, config: Config) -> None:
        """Store ``config`` and attach a ``ModelsService`` bound to this client.

        Args:
            config: Configuration object the client will use.
        """
        # Imported here rather than at module level to avoid a circular
        # dependency at runtime.
        from fundamental.services.models import ModelsService

        self.config = config
        self.models = ModelsService(client=self)
        self._trace_dict = None

    def get_trace_dict(self) -> Optional[Dict[str, str]]:
        """Return the trace dictionary currently stored on the client, if any."""
        return self._trace_dict

    def _set_trace_dict(self, trace_dict: Optional[Dict[str, str]]) -> None:
        """Replace the stored trace dictionary (internal use only)."""
        self._trace_dict = trace_dict
@@ -0,0 +1,37 @@
1
+ """Fundamental EC2 client for private deployment with AWS SigV4 authentication."""
2
+
3
+ from typing import Any, Optional
4
+
5
+ from fundamental.clients.base import BaseClient
6
+ from fundamental.config import Config
7
+
8
+
9
class FundamentalEC2Client(BaseClient):
    """Fundamental API client for private EC2 deployments (AWS SigV4 auth).

    The client is pre-configured for private deployment mode: requests are
    authenticated through AWS IAM roles rather than API keys, using SigV4
    signing with the AWS credentials available in the environment.
    """

    def __init__(
        self,
        aws_region: str,
        api_url: Optional[str] = None,
        **kwargs: Any,  # noqa: ANN401
    ):
        """Build a SigV4-enabled configuration and initialise the client.

        Args:
            aws_region: AWS region used for SigV4 signing (required).
            api_url: Base URL of the private deployment API server
                (API Gateway URL). When omitted, an empty string is passed to
                ``Config``, which then falls back to the FUNDAMENTAL_API_URL
                environment variable.
            **kwargs: Additional configuration options forwarded to ``Config``.
        """
        sigv4_config = Config(
            api_url=api_url or "",
            use_sigv4_auth=True,
            aws_region=aws_region,
            **kwargs,
        )
        super().__init__(config=sigv4_config)
@@ -0,0 +1,20 @@
1
+ """Fundamental client implementation."""
2
+
3
+ from typing import Any
4
+
5
+ from fundamental.clients.base import BaseClient
6
+ from fundamental.config import Config
7
+
8
+
9
class Fundamental(BaseClient):
    """Default client for talking to the Fundamental API."""

    def __init__(
        self,
        **kwargs: Any,  # noqa: ANN401
    ):
        """Create the client.

        Args:
            **kwargs: Configuration options forwarded unchanged to ``Config``.
        """
        super().__init__(config=Config(**kwargs))
fundamental/config.py ADDED
@@ -0,0 +1,138 @@
1
+ """Configuration management for Fundamental SDK."""
2
+
3
+ import os
4
+ from dataclasses import dataclass
5
+ from importlib.metadata import version
6
+ from typing import Optional
7
+
8
+ from fundamental.constants import (
9
+ API_URL,
10
+ COMPLETE_MULTIPART_UPLOAD_PATH,
11
+ DEFAULT_DOWNLOAD_RESULT_TIMEOUT_SECONDS,
12
+ DEFAULT_FEATURE_IMPORTANCE_TIMEOUT_SECONDS,
13
+ DEFAULT_FIT_TIMEOUT_SECONDS,
14
+ DEFAULT_PREDICT_TIMEOUT_SECONDS,
15
+ DEFAULT_RETRIES_COUNT,
16
+ DEFAULT_TIMEOUT_SECONDS,
17
+ FEATURE_IMPORTANCE_MODEL_METADATA_GENERATE_PATH,
18
+ FEATURE_IMPORTANCE_PATH,
19
+ FEATURE_IMPORTANCE_STATUS_PATH,
20
+ FIT_MODEL_METADATA_GENERATE_PATH,
21
+ FIT_PATH,
22
+ FIT_STATUS_PATH,
23
+ MODEL_MANAGEMENT_PATH,
24
+ PREDICT_MODEL_METADATA_GENERATE_PATH,
25
+ PREDICT_PATH,
26
+ PREDICT_STATUS_PATH,
27
+ )
28
+
29
+
30
@dataclass(frozen=True)
class Config:
    """Immutable configuration for Fundamental SDK.

    Attributes:
        api_key: API key for authentication (defaults to FUNDAMENTAL_API_KEY env var)
        api_url: Base URL for the API (defaults to FUNDAMENTAL_API_URL env var or API_URL)
        retries: Number of retries for failed requests
        timeout: Default timeout in seconds
        fit_timeout: Timeout for fit operations
        predict_timeout: Timeout for predict operations
        feature_importance_timeout: Timeout for feature importance operations
        download_prediction_result_timeout: Timeout for downloading prediction results
        download_feature_importance_result_timeout: Timeout for downloading feature
            importance results
        use_sigv4_auth: Use AWS SigV4 authentication instead of API key (for private deployment)
        aws_region: AWS region for SigV4 signing (required when use_sigv4_auth=True)
    """

    api_key: str = ""
    api_url: str = ""
    retries: int = DEFAULT_RETRIES_COUNT
    timeout: int = DEFAULT_TIMEOUT_SECONDS
    fit_timeout: int = DEFAULT_FIT_TIMEOUT_SECONDS
    predict_timeout: int = DEFAULT_PREDICT_TIMEOUT_SECONDS
    feature_importance_timeout: int = DEFAULT_FEATURE_IMPORTANCE_TIMEOUT_SECONDS
    download_prediction_result_timeout: int = DEFAULT_DOWNLOAD_RESULT_TIMEOUT_SECONDS
    download_feature_importance_result_timeout: int = DEFAULT_DOWNLOAD_RESULT_TIMEOUT_SECONDS
    use_sigv4_auth: bool = False
    aws_region: str = ""

    def __post_init__(self) -> None:
        """Fill unset values from the environment and validate the auth settings.

        Raises:
            ValueError: If SigV4 auth is enabled without ``aws_region``, or if
                API-key auth is used and no API key could be resolved.
        """
        # The dataclass is frozen, so fields are written via object.__setattr__.
        # The trailing `or ""` keeps api_key a str even when neither the
        # argument nor the environment variable supplies a value (os.getenv
        # would otherwise leave None in a field declared as str).
        api_key: str = self.api_key or os.getenv("FUNDAMENTAL_API_KEY") or ""
        object.__setattr__(self, "api_key", api_key)
        url: str = self.api_url or os.getenv("FUNDAMENTAL_API_URL") or API_URL
        object.__setattr__(self, "api_url", url)

        if self.use_sigv4_auth:
            # Validate AWS region is provided when using SigV4 auth
            if not self.aws_region:
                raise ValueError(
                    "AWS region is required when using SigV4 authentication. "
                    "Please provide it via aws_region parameter."
                )
        else:
            # Validate API key when not using SigV4 auth
            self.get_api_key()  # quick fail on missing api keys

    def get_api_key(self) -> str:
        """Get the API key.

        Raises:
            ValueError: If no API key is set.
        """
        if not self.api_key:
            raise ValueError(
                "API key is required. Please provide it via:\n"
                "1. Fundamental(api_key='your-key')\n"
                "2. Set FUNDAMENTAL_API_KEY environment variable"
            )
        return self.api_key

    def _build_url(self, path: str) -> str:
        """Join the base URL (trailing slashes removed) with an endpoint path."""
        return f"{self.api_url.rstrip('/')}{path}"

    def get_full_fit_url(self) -> str:
        """Get the complete URL for the fit endpoint."""
        return self._build_url(FIT_PATH)

    def get_full_predict_url(self) -> str:
        """Get the complete URL for the predict endpoint."""
        return self._build_url(PREDICT_PATH)

    def get_full_model_management_url(self) -> str:
        """Get the complete URL for the model management endpoint."""
        return self._build_url(MODEL_MANAGEMENT_PATH)

    def get_full_fit_model_metadata_generate_url(self) -> str:
        """Get the complete URL for the fit model metadata generation endpoint."""
        return self._build_url(FIT_MODEL_METADATA_GENERATE_PATH)

    def get_full_predict_model_metadata_generate_url(self) -> str:
        """Get the complete URL for the predict metadata generation endpoint."""
        return self._build_url(PREDICT_MODEL_METADATA_GENERATE_PATH)

    def get_full_complete_multipart_upload_url(self) -> str:
        """Get the complete URL for the complete multipart upload endpoint."""
        return self._build_url(COMPLETE_MULTIPART_UPLOAD_PATH)

    def get_full_fit_status_url(self) -> str:
        """Get the complete URL for the fit status endpoint."""
        return self._build_url(FIT_STATUS_PATH)

    def get_full_predict_status_url(self) -> str:
        """Get the complete URL for the predict status endpoint."""
        return self._build_url(PREDICT_STATUS_PATH)

    def get_full_feature_importance_url(self) -> str:
        """Get the complete URL for the feature importance endpoint."""
        return self._build_url(FEATURE_IMPORTANCE_PATH)

    def get_full_feature_importance_status_url(self) -> str:
        """Get the complete URL for the feature importance status endpoint."""
        return self._build_url(FEATURE_IMPORTANCE_STATUS_PATH)

    def get_full_feature_importance_model_metadata_generate_url(self) -> str:
        """Get the complete URL for the feature importance metadata generation endpoint."""
        return self._build_url(FEATURE_IMPORTANCE_MODEL_METADATA_GENERATE_PATH)

    def get_version(self) -> str:
        """Get the installed SDK version, or "unknown" if it cannot be determined."""
        # The package is published as the "fundamental-client" distribution, so
        # that name is tried first; "fundamental" is kept as a fallback for
        # installs that register the distribution under the import name.
        for distribution_name in ("fundamental-client", "fundamental"):
            try:
                return version(distribution_name)
            except Exception:
                # Best effort: any lookup failure just moves on to the next name.
                continue
        return "unknown"
@@ -0,0 +1,41 @@
1
# Default base URL; Config falls back to it when neither the api_url argument
# nor the FUNDAMENTAL_API_URL environment variable is set.
API_URL = "https://api.fundamental.tech"
# Endpoint paths appended to the base URL by the Config.get_full_*_url helpers.
FIT_PATH = "/api/v1/model/fit"
PREDICT_PATH = "/api/v1/model/predict"
FIT_STATUS_PATH = "/api/v1/model/fit/status"
PREDICT_STATUS_PATH = "/api/v1/model/predict/status"
FEATURE_IMPORTANCE_PATH = "/api/v1/model/feature-importance"
FEATURE_IMPORTANCE_STATUS_PATH = "/api/v1/model/feature-importance/status"
FEATURE_IMPORTANCE_MODEL_METADATA_GENERATE_PATH = "/api/v1/model/feature-importance/create-metadata"
MODEL_MANAGEMENT_PATH = "/api/v1/model-management/trained-models"
FIT_MODEL_METADATA_GENERATE_PATH = "/api/v1/model/fit/create-metadata"
PREDICT_MODEL_METADATA_GENERATE_PATH = "/api/v1/model/predict/create-metadata"
COMPLETE_MULTIPART_UPLOAD_PATH = "/api/v1/model/complete-multipart-upload"

# Defaults for Config.timeout (seconds) and Config.retries.
DEFAULT_TIMEOUT_SECONDS = 30
DEFAULT_RETRIES_COUNT = 1

# Shared default constants
DEFAULT_POLLING_INTERVAL_SECONDS = 2
DEFAULT_SUBMIT_REQUEST_TIMEOUT_SECONDS = 30
DEFAULT_DOWNLOAD_RESULT_TIMEOUT_SECONDS = 60 * 60 * 3  # 3 hours
DEFAULT_MODEL_METADATA_UPLOAD_TIMEOUT_SECONDS = 60 * 60 * 1  # 1 hour
DEFAULT_COMPLETE_MULTIPART_UPLOAD_TIMEOUT_SECONDS = 30
DEFAULT_MODEL_METADATA_GENERATE_TIMEOUT_SECONDS = 30

# Predict-specific
DEFAULT_PREDICT_POLLING_REQUESTS_WITHOUT_DELAY = 100
DEFAULT_PREDICT_TIMEOUT_SECONDS = 60 * 60  # 1 hour

# Feature importance-specific
DEFAULT_FEATURE_IMPORTANCE_TIMEOUT_SECONDS = 60 * 60 * 24  # 24 hours

# Fit-specific
DEFAULT_FIT_TIMEOUT_SECONDS = 60 * 60 * 3  # 3 hours


# NOTE(review): presumably the names under which the train/test data files are
# uploaded — their usage is not visible here, confirm against the callers.
X_TRAIN_FILE_NAME = "x_train_file"
Y_TRAIN_FILE_NAME = "y_train_file"
X_TEST_FILE_NAME = "x_test_file"

# EC2/SigV4 authentication
SIGV4_SERVICE_NAME = "execute-api"
@@ -0,0 +1,43 @@
1
+ """
2
+ Deprecated aliases for backward compatibility.
3
+
4
+ This module contains deprecated estimator class aliases that will be removed in a future version.
5
+ Estimator classes emit FutureWarning when instantiated.
6
+
7
+ Note: FTMError alias is in fundamental.exceptions (not here) to avoid circular imports.
8
+
9
+ TO REMOVE:
10
+ 1. Delete this file
11
+ 2. Remove imports from fundamental/__init__.py
12
+ """
13
+
14
+ import warnings
15
+ from typing import Literal
16
+
17
+ from fundamental.estimator.classification import NEXUSClassifier
18
+ from fundamental.estimator.regression import NEXUSRegressor
19
+
20
+
21
def _deprecation_warning(old_name: str, new_name: str) -> None:
    """Issue a ``FutureWarning`` saying ``old_name`` was replaced by ``new_name``.

    Args:
        old_name: The deprecated name being used.
        new_name: The name that should be used instead.
    """
    message = f"{old_name} is deprecated. Use {new_name} instead."
    # stacklevel=3 attributes the warning two frames above this helper, i.e.
    # to the code that called this helper's caller.
    warnings.warn(message, FutureWarning, stacklevel=3)
28
+
29
+
30
class FTMClassifier(NEXUSClassifier):
    """Deprecated alias of ``NEXUSClassifier``; use ``NEXUSClassifier`` instead.

    Instantiating this class emits a ``FutureWarning``.
    """

    def __init__(self, mode: Literal["quality", "speed"] = "quality"):
        # Warn first, then initialise exactly like the class being aliased.
        _deprecation_warning("FTMClassifier", "NEXUSClassifier")
        super().__init__(mode=mode)
36
+
37
+
38
class FTMRegressor(NEXUSRegressor):
    """Deprecated alias of ``NEXUSRegressor``; use ``NEXUSRegressor`` instead.

    Instantiating this class emits a ``FutureWarning``.
    """

    def __init__(self, mode: Literal["quality", "speed"] = "quality"):
        # Warn first, then initialise exactly like the class being aliased.
        _deprecation_warning("FTMRegressor", "NEXUSRegressor")
        super().__init__(mode=mode)
@@ -0,0 +1,16 @@
1
+ """
2
+ NEXUS Estimators Module
3
+
4
+ This module provides scikit-learn compatible estimators for the NEXUS foundation
5
+ model, organized by task type.
6
+
7
+ Note: Deprecated FTMClassifier/FTMRegressor aliases are available from `fundamental` package.
8
+ """
9
+
10
+ from fundamental.estimator.classification import NEXUSClassifier
11
+ from fundamental.estimator.regression import NEXUSRegressor
12
+
13
# Estimator classes re-exported by this package.
__all__ = [
    "NEXUSClassifier",
    "NEXUSRegressor",
]
@@ -0,0 +1,263 @@
1
+ """
2
+ Base classes and model registry for NEXUS estimators.
3
+
4
+ This module contains the ModelRegistry for managing available models and
5
+ the abstract BaseNEXUSEstimator class that provides common functionality
6
+ for both classification and regression tasks.
7
+ """
8
+
9
+ import logging
10
+ from abc import ABC
11
+ from typing import Any, ClassVar, Literal, Optional
12
+
13
+ import numpy.typing as npt
14
+ from sklearn.base import BaseEstimator
15
+ from sklearn.utils.validation import check_is_fitted
16
+ from typing_extensions import Self
17
+
18
+ import fundamental
19
+ from fundamental.clients import BaseClient
20
+ from fundamental.services import poll_fit_result, remote_fit, remote_predict, submit_fit_task
21
+ from fundamental.services.models import ModelsService
22
+ from fundamental.utils.data import (
23
+ XType,
24
+ YType,
25
+ check_n_features_compat,
26
+ validate_data,
27
+ validate_inputs_type,
28
+ )
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
class BaseNEXUSEstimator(BaseEstimator, ABC):  # type: ignore[misc]
    """
    Abstract base class for NEXUS estimators.

    Provides common functionality and interface for both classification
    and regression estimators. Fitting and prediction are delegated to the
    remote service through the client returned by ``fundamental.get_client()``.
    """

    # Set by concrete subclasses; forwarded as ``task`` to the fit services.
    _task_type: ClassVar[Literal["classification", "regression"]]

    def __init__(
        self,
        mode: Literal["quality", "speed"] = "quality",
    ):
        """Initialize the estimator.

        Args:
            mode: Training mode - "quality" for best performance, "speed" for faster training.
        """
        self.mode: Literal["quality", "speed"] = mode

    def _get_client(self) -> BaseClient:
        """Return the client to use, as resolved by ``fundamental.get_client()``."""
        return fundamental.get_client()

    def _set_n_features_in(self, X: XType) -> None:
        """Set n_features_in_ attribute for sklearn compatibility.

        Args:
            X: Input features array or dataframe.
        """
        self.n_features_in_ = X.shape[1]

    def load_model(self, trained_model_id: str) -> Self:
        """
        Load a trained model and mark the estimator as fitted.

        Parameters
        ----------
        trained_model_id : str
            Identifier of the trained model to load.

        Returns
        -------
        Self
            Self for method chaining.
        """
        # NOTE(review): n_features_in_ is not set on this path unless it comes
        # in through the loaded estimator fields — confirm that
        # check_n_features_compat copes with that at predict time.
        self._load_trained_model(trained_model_id)
        self.fitted_ = True
        return self

    def fit(self, X: XType, y: YType) -> Self:
        """
        Fit the model to training data.

        Parameters
        ----------
        X : XType
            Training features as numpy array, pandas DataFrame.
        y : YType
            Training targets as numpy array, pandas Series.

        Returns
        -------
        Self
            Self for method chaining.
        """
        validate_inputs_type(X=X, y=y)
        validate_data(X=X, y=y)

        response = remote_fit(
            X=X,
            y=y,
            task=type(self)._task_type,
            mode=self.mode,
            client=self._get_client(),
        )
        self._load_estimator_fields(response.estimator_fields)
        self.trained_model_id_ = response.trained_model_id
        self.fitted_ = True
        self._set_n_features_in(X)
        return self

    def set_attributes(
        self,
        attributes: dict[str, str],
    ) -> Self:
        """
        Set attributes on the fitted model.

        Parameters
        ----------
        attributes : dict[str, str]
            Dictionary of key-value pairs to set as model attributes.

        Returns
        -------
        Self
            Self for method chaining.

        Raises
        ------
        NotFittedError
            If the model has not been fitted yet.

        Examples
        --------
        >>> clf.set_attributes({"stage": "prod", "owner": "lital"})
        """
        check_is_fitted(self)
        assert self.trained_model_id_ is not None

        models_service = ModelsService(client=self._get_client())
        models_service.set_attributes(
            model_id=self.trained_model_id_,
            attributes=attributes,
        )
        return self

    def submit_fit_task(self, X: XType, y: YType) -> str:
        """
        Submit a fit task without waiting for completion.

        Parameters
        ----------
        X : XType
            Training features as numpy array, pandas DataFrame.
        y : YType
            Training targets as numpy array, pandas Series.

        Returns
        -------
        str
            Task ID for polling with poll_fit_result.
        """
        validate_inputs_type(X=X, y=y)
        validate_data(X=X, y=y)
        result = submit_fit_task(
            X=X,
            y=y,
            task=type(self)._task_type,
            mode=self.mode,
            client=self._get_client(),
        )
        # Store for poll_fit_result to use
        self._pending_trained_model_id = result.trained_model_id
        self._pending_n_features_in = X.shape[1]
        return result.task_id

    def poll_fit_result(self, task_id: str) -> Optional[Self]:
        """
        Check fit task status and load model if complete.

        Parameters
        ----------
        task_id : str
            The task ID returned by submit_fit_task.

        Returns
        -------
        Optional[Self]
            Self if fitting completed (for method chaining), None if still in progress.

        Raises
        ------
        AttributeError
            If this estimator holds no pending fit task, i.e. submit_fit_task
            was not called on it or the task result was already loaded.
        """
        # The pending state is stored by submit_fit_task and removed once the
        # result has been loaded; fail with an actionable message instead of
        # an attribute error on a private name.
        if not hasattr(self, "_pending_trained_model_id"):
            raise AttributeError(
                "No pending fit task on this estimator. "
                "Call submit_fit_task() before poll_fit_result()."
            )

        result = poll_fit_result(
            task_id=task_id,
            trained_model_id=self._pending_trained_model_id,
            client=self._get_client(),
        )
        if result is None:
            return None

        self._load_estimator_fields(result.estimator_fields)
        self.trained_model_id_ = result.trained_model_id
        self.fitted_ = True
        self.n_features_in_ = self._pending_n_features_in
        # Clean up pending attributes
        del self._pending_n_features_in
        del self._pending_trained_model_id
        return self

    def predict(self, X: XType) -> npt.NDArray[Any]:
        """
        Predict preds function.

        Parameters
        ----------
        X : XType
            Input features as numpy array, pandas DataFrame.


        Returns
        -------
        np.ndarray
            Model predictions.

        Raises
        ------
        NotFittedError
            If the model has not been fitted yet.
        """
        return self._predict(X=X, output_type="preds")

    def _predict(
        self,
        X: XType,
        output_type: Literal["preds", "probas"],
    ) -> npt.NDArray[Any]:
        """Internal prediction method with configurable output type."""
        validate_inputs_type(X=X)
        validate_data(X=X)
        check_is_fitted(self)

        # Validate n_features_in_ for sklearn compatibility
        check_n_features_compat(self, X, reset=False)

        assert self.trained_model_id_ is not None
        return remote_predict(
            X=X,
            output_type=output_type,
            trained_model_id=self.trained_model_id_,
            client=self._get_client(),
        )

    def _load_trained_model(self, model_id: str) -> None:
        """Fetch ``model_id`` through ``ModelsService`` and copy its fields onto self.

        Any failure is logged and then re-raised unchanged.
        """
        try:
            models_service = ModelsService(client=self._get_client())
            response = models_service.load(
                trained_model_id=model_id,
            )
            self._load_estimator_fields(response.estimator_fields)
            self.trained_model_id_ = model_id
        except Exception as e:
            logger.error("Failed to load model %s: %s", model_id, e)
            # Bare raise re-raises the active exception as-is.
            raise

    def _load_estimator_fields(self, estimator_fields: dict[str, Any]) -> None:
        """Copy each entry of ``estimator_fields`` onto this estimator as an attribute.

        Fields that cannot be set are skipped with a warning.
        """
        for field, value in estimator_fields.items():
            try:
                # Keep private attrs as-is, add underscore to public attrs (sklearn convention)
                field_name = field if field.endswith("_") or field.startswith("_") else f"{field}_"
                setattr(self, field_name, value)
            except AttributeError:
                logger.warning("Field %s is not a valid attribute in the current estimator", field)
@@ -0,0 +1,46 @@
1
+ """NEXUS Classification estimator."""
2
+
3
+ import logging
4
+ from typing import Any, Literal
5
+
6
+ import numpy.typing as npt
7
+ from sklearn.base import ClassifierMixin
8
+
9
+ from fundamental.estimator.nexus_estimator import NEXUSEstimator
10
+ from fundamental.utils.data import XType
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class NEXUSClassifier(ClassifierMixin, NEXUSEstimator):
    """NEXUS estimator for classification tasks."""

    _task_type = "classification"  # type: ignore[assignment]

    def __init__(
        self,
        mode: Literal["quality", "speed"] = "quality",
    ):
        """Create the classifier.

        Parameters
        ----------
        mode : {"quality", "speed"}
            Passed unchanged to the parent estimator's ``__init__``.
        """
        super().__init__(mode=mode)

    def predict_proba(self, X: XType) -> npt.NDArray[Any]:
        """
        Predict class probabilities for ``X``.

        Parameters
        ----------
        X : XType
            Input features as numpy array, pandas DataFrame.

        Returns
        -------
        np.ndarray
            Model probabilities.

        Raises
        ------
        NotFittedError
            If the model has not been fitted yet.
        """
        # Same prediction path as the parent's _predict, requesting "probas".
        probabilities = self._predict(X=X, output_type="probas")
        return probabilities