truefoundry 0.5.0rc6__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truefoundry might be problematic. Click here for more details.
- truefoundry/common/auth_service_client.py +2 -2
- truefoundry/common/constants.py +9 -0
- truefoundry/common/utils.py +81 -1
- truefoundry/deploy/__init__.py +5 -0
- truefoundry/deploy/builder/builders/tfy_notebook_buildpack/__init__.py +4 -2
- truefoundry/deploy/builder/builders/tfy_python_buildpack/__init__.py +7 -5
- truefoundry/deploy/builder/builders/tfy_python_buildpack/dockerfile_template.py +87 -28
- truefoundry/deploy/builder/constants.py +8 -0
- truefoundry/deploy/builder/utils.py +9 -4
- truefoundry/deploy/cli/cli.py +2 -0
- truefoundry/deploy/cli/commands/__init__.py +1 -0
- truefoundry/deploy/cli/commands/deploy_init_command.py +22 -0
- truefoundry/deploy/lib/dao/application.py +2 -1
- truefoundry/deploy/v2/lib/patched_models.py +8 -0
- truefoundry/ml/__init__.py +25 -15
- truefoundry/ml/artifact/truefoundry_artifact_repo.py +8 -3
- truefoundry/ml/autogen/client/__init__.py +24 -0
- truefoundry/ml/autogen/client/api/mlfoundry_artifacts_api.py +325 -0
- truefoundry/ml/autogen/client/models/__init__.py +24 -0
- truefoundry/ml/autogen/client/models/artifact_version_manifest.py +2 -2
- truefoundry/ml/autogen/client/models/export_deployment_files_request_dto.py +82 -0
- truefoundry/ml/autogen/client/models/infer_method_name.py +34 -0
- truefoundry/ml/autogen/client/models/model_server.py +34 -0
- truefoundry/ml/autogen/client/models/model_version_environment.py +97 -0
- truefoundry/ml/autogen/client/models/model_version_manifest.py +13 -8
- truefoundry/ml/autogen/client/models/sklearn_framework.py +25 -2
- truefoundry/ml/autogen/client/models/sklearn_model_schema.py +82 -0
- truefoundry/ml/autogen/client/models/sklearn_serialization_format.py +35 -0
- truefoundry/ml/autogen/client/models/transformers_framework.py +2 -2
- truefoundry/ml/autogen/client/models/validate_external_storage_root_request_dto.py +71 -0
- truefoundry/ml/autogen/client/models/validate_external_storage_root_response_dto.py +69 -0
- truefoundry/ml/autogen/client/models/xg_boost_framework.py +28 -3
- truefoundry/ml/autogen/client/models/xg_boost_model_schema.py +88 -0
- truefoundry/ml/autogen/client/models/xg_boost_serialization_format.py +36 -0
- truefoundry/ml/autogen/client_README.md +12 -0
- truefoundry/ml/autogen/entities/artifacts.py +119 -26
- truefoundry/ml/autogen/models/signature.py +6 -3
- truefoundry/ml/autogen/models/utils.py +12 -7
- truefoundry/ml/cli/commands/model_init.py +97 -0
- truefoundry/ml/cli/utils.py +34 -0
- truefoundry/ml/log_types/artifacts/model.py +53 -38
- truefoundry/ml/log_types/artifacts/utils.py +38 -2
- truefoundry/ml/mlfoundry_api.py +77 -81
- truefoundry/ml/mlfoundry_run.py +3 -33
- truefoundry/ml/model_framework.py +372 -3
- truefoundry/ml/validation_utils.py +2 -0
- {truefoundry-0.5.0rc6.dist-info → truefoundry-0.5.1.dist-info}/METADATA +2 -6
- {truefoundry-0.5.0rc6.dist-info → truefoundry-0.5.1.dist-info}/RECORD +50 -55
- truefoundry/deploy/function_service/__init__.py +0 -3
- truefoundry/deploy/function_service/__main__.py +0 -27
- truefoundry/deploy/function_service/app.py +0 -92
- truefoundry/deploy/function_service/build.py +0 -45
- truefoundry/deploy/function_service/remote/__init__.py +0 -6
- truefoundry/deploy/function_service/remote/context.py +0 -3
- truefoundry/deploy/function_service/remote/method.py +0 -67
- truefoundry/deploy/function_service/remote/remote.py +0 -144
- truefoundry/deploy/function_service/route.py +0 -137
- truefoundry/deploy/function_service/service.py +0 -113
- truefoundry/deploy/function_service/utils.py +0 -53
- truefoundry/langchain/__init__.py +0 -12
- truefoundry/langchain/deprecated.py +0 -302
- truefoundry/langchain/truefoundry_chat.py +0 -130
- truefoundry/langchain/truefoundry_embeddings.py +0 -171
- truefoundry/langchain/truefoundry_llm.py +0 -106
- truefoundry/langchain/utils.py +0 -44
- truefoundry/ml/log_types/artifacts/model_extras.py +0 -48
- {truefoundry-0.5.0rc6.dist-info → truefoundry-0.5.1.dist-info}/WHEEL +0 -0
- {truefoundry-0.5.0rc6.dist-info → truefoundry-0.5.1.dist-info}/entry_points.txt +0 -0
|
@@ -1,10 +1,50 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
1
3
|
import warnings
|
|
2
|
-
from
|
|
3
|
-
|
|
4
|
-
from
|
|
4
|
+
from collections import OrderedDict
|
|
5
|
+
from pickle import load as pickle_load
|
|
6
|
+
from typing import (
|
|
7
|
+
TYPE_CHECKING,
|
|
8
|
+
Any,
|
|
9
|
+
Callable,
|
|
10
|
+
Dict,
|
|
11
|
+
List,
|
|
12
|
+
Literal,
|
|
13
|
+
Optional,
|
|
14
|
+
Type,
|
|
15
|
+
Union,
|
|
16
|
+
get_args,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
from truefoundry.common.utils import (
|
|
20
|
+
get_python_version_major_minor,
|
|
21
|
+
list_pip_packages_installed,
|
|
22
|
+
)
|
|
23
|
+
from truefoundry.ml.autogen.client import (
|
|
24
|
+
SklearnSerializationFormat,
|
|
25
|
+
XGBoostSerializationFormat,
|
|
26
|
+
)
|
|
5
27
|
from truefoundry.ml.autogen.entities import artifacts as autogen_artifacts
|
|
28
|
+
from truefoundry.ml.autogen.models import infer_signature
|
|
29
|
+
from truefoundry.ml.enums import ModelFramework
|
|
30
|
+
from truefoundry.ml.log_types.artifacts.utils import (
|
|
31
|
+
get_single_file_path_if_only_one_in_directory,
|
|
32
|
+
to_unix_path,
|
|
33
|
+
)
|
|
6
34
|
from truefoundry.pydantic_v1 import BaseModel, Field
|
|
7
35
|
|
|
36
|
+
if TYPE_CHECKING:
|
|
37
|
+
from sklearn.base import BaseEstimator
|
|
38
|
+
from xgboost import Booster, XGBModel
|
|
39
|
+
|
|
40
|
+
# Pip packages needed to load a model serialized in a given format.
# Formats without an entry here (for example PICKLE) contribute no extra
# packages, because lookups use ``.get(format, [])``.
SERIALIZATION_FORMAT_TO_PACKAGES_NAME_MAP = dict(
    [
        (SklearnSerializationFormat.JOBLIB, ["joblib"]),
        (SklearnSerializationFormat.CLOUDPICKLE, ["cloudpickle"]),
        (XGBoostSerializationFormat.JOBLIB, ["joblib"]),
        (XGBoostSerializationFormat.CLOUDPICKLE, ["cloudpickle"]),
    ]
)
|
|
47
|
+
|
|
8
48
|
|
|
9
49
|
class FastAIFramework(autogen_artifacts.FastAIFramework):
|
|
10
50
|
"""FastAI model Framework"""
|
|
@@ -111,6 +151,87 @@ ModelFrameworkType = Union[
|
|
|
111
151
|
]
|
|
112
152
|
|
|
113
153
|
|
|
154
|
+
class _SerializationFormatLoaderRegistry:
    """
    Ordered registry of loaders used to detect a model file's serialization format.

    Every registered loader accepts an open binary file object and tries to
    deserialize it. ``_detect_model_serialization_format`` tries the loaders in
    insertion order and reports the first format whose loader succeeds.
    """

    def __init__(
        self,
        framework: Union[
            Type[Union[SklearnFramework, XGBoostFramework]],
            SklearnFramework,
            XGBoostFramework,
        ],
    ):
        """
        Args:
            framework: The framework class (``SklearnFramework`` or
                ``XGBoostFramework``) or an instance of one.
        """
        # Accept both a class and an instance. Comparing an instance to a class
        # with ``==`` is always False. That used to pick the XGBoost format enum
        # for sklearn models and never register the XGBoost JSON loader.
        framework_cls = framework if isinstance(framework, type) else type(framework)
        is_sklearn = issubclass(framework_cls, SklearnFramework)
        is_xgboost = issubclass(framework_cls, XGBoostFramework)
        format_class: Type[
            Union[SklearnSerializationFormat, XGBoostSerializationFormat]
        ] = (SklearnSerializationFormat if is_sklearn else XGBoostSerializationFormat)

        # An OrderedDict keeps the loaders in priority order:
        #   1. joblib       (if importable)
        #   2. cloudpickle  (if importable)
        #   3. XGBoost JSON (XGBoost frameworks only, if xgboost is importable)
        #   4. pickle       (standard-library fallback, always registered)
        self._loader_map: Dict[
            Union[SklearnSerializationFormat, XGBoostSerializationFormat],
            Callable[[Any], object],
        ] = OrderedDict()

        try:
            from joblib import load as joblib_load
        except ImportError:
            pass
        else:
            self._loader_map[format_class.JOBLIB] = joblib_load

        try:
            from cloudpickle import load as cloudpickle_load
        except ImportError:
            pass
        else:
            self._loader_map[format_class.CLOUDPICKLE] = cloudpickle_load

        if is_xgboost:
            try:
                from xgboost import Booster
            except ImportError:
                pass
            else:

                def _load_xgboost_model(source: Any) -> object:
                    # Booster.load_model accepts a path or a bytearray buffer,
                    # not a file object, so read file objects into memory first.
                    # Paths are still passed through unchanged. A fresh Booster
                    # per call avoids mutating shared state between probes.
                    if hasattr(source, "read"):
                        source = bytearray(source.read())
                    booster = Booster()
                    booster.load_model(source)
                    return booster

                self._loader_map[format_class.JSON] = _load_xgboost_model

        # Plain pickle is always available and acts as the last-resort loader.
        self._loader_map[format_class.PICKLE] = pickle_load

    def get_loader_map(
        self,
    ) -> Dict[
        Union[SklearnSerializationFormat, XGBoostSerializationFormat],
        Callable[[Any], object],
    ]:
        """Return the ordered mapping of serialization format -> loader callable."""
        return self._loader_map

    def _detect_model_serialization_format(
        self,
        model_file_path: str,
    ) -> Optional[Union[SklearnSerializationFormat, XGBoostSerializationFormat]]:
        """
        Try each registered loader on the file and return the first format that loads.

        Args:
            model_file_path (str): The path to the file to be loaded.

        Returns:
            Optional[Union[SklearnSerializationFormat, XGBoostSerializationFormat]]:
                The serialization format whose loader succeeded, or None if none did.
        """
        # NOTE: this really deserializes the file, and pickle-based loaders can
        # execute arbitrary code while loading. Only call it on trusted model files.
        for serialization_format, loader in self._loader_map.items():
            try:
                with open(model_file_path, "rb") as model_file:
                    loader(model_file)
            except Exception:
                # Best-effort probing: any failure just means "not this format".
                continue
            return serialization_format
        return None
|
|
233
|
+
|
|
234
|
+
|
|
114
235
|
class _ModelFramework(BaseModel):
|
|
115
236
|
__root__: ModelFrameworkType = Field(discriminator="type")
|
|
116
237
|
|
|
@@ -167,3 +288,251 @@ class _ModelFramework(BaseModel):
|
|
|
167
288
|
return None
|
|
168
289
|
|
|
169
290
|
return cls.parse_obj(obj).__root__
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
# Pip package names a model environment depends on, keyed by the model
# framework class. Lookups use the exact class of the framework instance.
_MODEL_FRAMEWORK_TO_PIP_PACKAGES: Dict[Type[ModelFrameworkType], List[str]] = dict(
    [
        (SklearnFramework, ["scikit-learn", "numpy", "pandas"]),
        (XGBoostFramework, ["xgboost", "numpy", "pandas"]),
    ]
)
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def _get_required_framework_pip_packages(framework: "ModelFrameworkType") -> List[str]:
    """
    Fetch the pip package names required for a given model framework.

    Args:
        framework ("ModelFrameworkType"): The model framework instance for which to
            fetch the pip packages. Lookup is by the instance's exact class.

    Returns:
        List[str]: A new list with the pip package names required for the given
            framework, or an empty list if the framework type has no entry.
            The list is a copy, so callers may mutate it freely.
    """
    # Return a copy. Callers such as _fetch_framework_specific_pip_packages
    # extend the result, which would otherwise permanently grow the shared
    # module-level mapping on every call.
    return list(_MODEL_FRAMEWORK_TO_PIP_PACKAGES.get(framework.__class__, []))
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _fetch_framework_specific_pip_packages(
    framework: "ModelFrameworkType",
) -> List[str]:
    """
    Fetch pinned pip requirements for the given framework, including any
    dependencies related to the framework's serialization format.

    Args:
        framework: The framework object (e.g., SklearnFramework, XGBoostFramework).

    Returns:
        List[str]: ``"name==version"`` strings built from the packages that
            ``list_pip_packages_installed`` reports for the collected package
            names (e.g., ['numpy==1.19.5', ...]). Presumably packages that are
            not installed are omitted. TODO: confirm against that helper.
    """
    # Work on a private copy so extending it below can never mutate the shared
    # module-level framework -> packages mapping.
    framework_package_names = list(
        _get_required_framework_pip_packages(framework=framework)
    )

    # Add serialization-format dependencies (e.g. joblib / cloudpickle) if applicable.
    if isinstance(framework, (SklearnFramework, XGBoostFramework)):
        framework_package_names.extend(
            SERIALIZATION_FORMAT_TO_PACKAGES_NAME_MAP.get(
                framework.serialization_format, []
            )
        )

    return [
        f"{package.name}=={package.version}"
        for package in list_pip_packages_installed(
            filter_package_names=framework_package_names
        )
    ]
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
def auto_update_environment_details(
    environment: autogen_artifacts.ModelVersionEnvironment,
    framework: Optional[ModelFrameworkType],
):
    """
    Fill in missing environment details in place.

    ``python_version`` is set from the running interpreter when empty.
    ``pip_packages`` is derived from the framework only when a framework is
    given and no packages were provided.

    Args:
        environment: Environment object holding ``python_version`` and ``pip_packages``.
        framework: Optional framework object (e.g., SklearnFramework, XGBoostFramework)
            used to decide which pip packages to pin.
    """
    if not environment.python_version:
        environment.python_version = get_python_version_major_minor()

    # Leave pip_packages alone when there is no framework to derive them from,
    # or when the caller already supplied them.
    if not framework or environment.pip_packages:
        return
    environment.pip_packages = _fetch_framework_specific_pip_packages(framework)
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def _validate_and_get_absolute_model_filepath(
|
|
367
|
+
model_file_or_folder: str,
|
|
368
|
+
model_filepath: Optional[str] = None,
|
|
369
|
+
) -> Optional[str]:
|
|
370
|
+
# If no model_filepath is set, resolve it from the directory
|
|
371
|
+
if not model_filepath:
|
|
372
|
+
# If model_filepath is not set, resolve it based on these cases:
|
|
373
|
+
# - Case 1: model_file_or_folder/model.joblib -> model.joblib
|
|
374
|
+
# - Case 2: model_file_or_folder/folder/model.joblib -> folder/model.joblib
|
|
375
|
+
# - Case 3: model_file_or_folder/folder/model.joblib, model_file_or_folder/config.json -> None
|
|
376
|
+
return get_single_file_path_if_only_one_in_directory(model_file_or_folder)
|
|
377
|
+
|
|
378
|
+
# If model_filepath is already set, validate and resolve it:
|
|
379
|
+
# - Case 1: Resolve the absolute file path of the model file relative to the provided directory.
|
|
380
|
+
# Example: If model_file_or_folder is '/root/models' and model_filepath is 'model.joblib',
|
|
381
|
+
# the resolved model file path would be '/root/models/model.joblib'. Validate it.
|
|
382
|
+
#
|
|
383
|
+
# - Case 2: If model_filepath is a relative path, resolve it to an absolute path based on the provided directory.
|
|
384
|
+
# Example: If model_file_or_folder is '/root/models' and model_filepath is 'subfolder/model.joblib',
|
|
385
|
+
# the resolved path would be '/root/models/subfolder/model.joblib'. Validate it.
|
|
386
|
+
#
|
|
387
|
+
# - Case 3: Verify that the resolved model file exists and is a valid file.
|
|
388
|
+
# Example: If the resolved path is '/root/models/model.joblib', check if the file exists.
|
|
389
|
+
# If it does not exist, raise a FileNotFoundError.
|
|
390
|
+
#
|
|
391
|
+
# - Case 4: Ensure the resolved model file is located within the specified directory or is the directory itself.
|
|
392
|
+
# Example: If the resolved path is '/root/models/model.joblib' and model_file_or_folder is '/root/models',
|
|
393
|
+
# the resolved path is valid. If the file lies outside '/root/models', raise a ValueError.
|
|
394
|
+
#
|
|
395
|
+
|
|
396
|
+
# If model_filepath is set, Resolve the absolute path of the model file (It can be a relative path or absolute path)
|
|
397
|
+
model_dir = (
|
|
398
|
+
os.path.dirname(model_file_or_folder)
|
|
399
|
+
if os.path.isfile(model_file_or_folder)
|
|
400
|
+
else model_file_or_folder
|
|
401
|
+
)
|
|
402
|
+
absolute_model_filepath = os.path.abspath(os.path.join(model_dir, model_filepath))
|
|
403
|
+
|
|
404
|
+
# Validate if resolve valid is within the provided directory or is the same as it
|
|
405
|
+
if not (
|
|
406
|
+
absolute_model_filepath == model_file_or_folder
|
|
407
|
+
or absolute_model_filepath.startswith(model_file_or_folder + os.sep)
|
|
408
|
+
):
|
|
409
|
+
raise ValueError(
|
|
410
|
+
f"model_filepath '{model_filepath}' must be relative to "
|
|
411
|
+
f"{model_file_or_folder}. Resolved path '{absolute_model_filepath}' is invalid."
|
|
412
|
+
)
|
|
413
|
+
|
|
414
|
+
if not os.path.isfile(absolute_model_filepath):
|
|
415
|
+
raise FileNotFoundError(f"Model file not found: {absolute_model_filepath}")
|
|
416
|
+
|
|
417
|
+
return absolute_model_filepath
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
def _validate_and_resolve_model_filepath(
    model_file_or_folder: str,
    model_filepath: Optional[str] = None,
) -> Optional[str]:
    """
    Validate the model file path and return it in unix form, relative to the model location.

    When ``model_file_or_folder`` is a directory, the result is the file's path
    relative to that directory. Otherwise it is just the file's basename.

    Returns:
        Optional[str]: The unix-style relative path, or None when no model file
            could be resolved.
    """
    resolved_path = _validate_and_get_absolute_model_filepath(
        model_file_or_folder=model_file_or_folder, model_filepath=model_filepath
    )
    if not resolved_path:
        return None

    if os.path.isdir(model_file_or_folder):
        relative_path = os.path.relpath(resolved_path, model_file_or_folder)
    else:
        relative_path = os.path.basename(resolved_path)
    return to_unix_path(relative_path)
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
def auto_update_model_framework_details(
    framework: "ModelFrameworkType", model_file_or_folder: str
):
    """
    Auto update the model framework details based on the provided model file or folder path.

    For Sklearn / XGBoost frameworks this resolves and validates
    ``framework.model_filepath``, storing it relative to ``model_file_or_folder``
    in unix form. When ``framework.serialization_format`` is unset, it also
    detects the format by trying to load the model file. Other framework types
    are left untouched.

    Args:
        framework: The framework object (e.g., SklearnFramework, XGBoostFramework)
            to update in place.
        model_file_or_folder: The path to the model file or folder.

    Raises:
        ValueError, FileNotFoundError: Propagated from model file path validation
            when an explicit ``model_filepath`` is invalid.
    """
    # Ensure the model file or folder path is an absolute path
    model_file_or_folder = os.path.abspath(model_file_or_folder)

    if not isinstance(framework, (SklearnFramework, XGBoostFramework)):
        return

    framework.model_filepath = _validate_and_resolve_model_filepath(
        model_file_or_folder=model_file_or_folder,
        model_filepath=framework.model_filepath,
    )
    # Nothing to detect when no model file could be resolved (e.g. the folder
    # holds several files and no explicit path was given), or when the caller
    # already set the serialization format.
    if not framework.model_filepath or framework.serialization_format:
        return

    absolute_model_filepath = (
        model_file_or_folder
        if os.path.isfile(model_file_or_folder)
        else os.path.join(model_file_or_folder, framework.model_filepath)
    )
    # The registry's ``framework`` parameter is typed as a framework *class* and
    # picks formats by comparing against the classes. Pass type(framework)
    # rather than the instance itself.
    loader_registry = _SerializationFormatLoaderRegistry(framework=type(framework))
    framework.serialization_format = loader_registry._detect_model_serialization_format(
        model_file_path=absolute_model_filepath
    )
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
def _infer_schema(
|
|
472
|
+
model_input: Any,
|
|
473
|
+
model: Union["BaseEstimator", "Booster", "XGBModel"],
|
|
474
|
+
infer_method_name: str = "predict",
|
|
475
|
+
) -> Dict[str, Any]:
|
|
476
|
+
if not hasattr(model, infer_method_name):
|
|
477
|
+
raise ValueError(
|
|
478
|
+
f"Model does not have the method '{infer_method_name}' to infer the schema."
|
|
479
|
+
)
|
|
480
|
+
model_infer_method = getattr(model, infer_method_name)
|
|
481
|
+
model_output = model_infer_method(model_input)
|
|
482
|
+
|
|
483
|
+
model_signature = infer_signature(
|
|
484
|
+
model_input=model_input, model_output=model_output
|
|
485
|
+
)
|
|
486
|
+
return model_signature.to_dict()
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
def sklearn_infer_schema(
    model_input: Any,
    model: "BaseEstimator",
    infer_method_name: str = "predict",
) -> autogen_artifacts.SklearnModelSchema:
    """
    Infer the input/output schema of a Sklearn model.

    Args:
        model_input (Any): Sample input used to run the model for schema inference.
        model (BaseEstimator): The Sklearn model instance.
        infer_method_name (str): Name of the model method used for inference,
            e.g. "predict" (default) or "predict_proba".

    Returns:
        SklearnModelSchema: The inferred schema of the Sklearn model.
    """
    signature_dict = _infer_schema(
        model_input=model_input, model=model, infer_method_name=infer_method_name
    )
    # The signature dict stores inputs/outputs as JSON strings; decode them.
    input_schema = json.loads(signature_dict["inputs"])
    output_schema = json.loads(signature_dict["outputs"])
    return autogen_artifacts.SklearnModelSchema(
        infer_method_name=infer_method_name,
        inputs=input_schema,
        outputs=output_schema,
    )
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
def xgboost_infer_schema(
    model_input: Any,
    model: Union["Booster", "XGBModel"],
    infer_method_name: str = "predict",
) -> autogen_artifacts.XGBoostModelSchema:
    """
    Infer the input/output schema of an XGBoost model.

    Args:
        model_input (Any): Sample input used to run the model for schema inference.
        model (Booster | XGBModel): The XGBoost model instance.
        infer_method_name (str): Name of the model method used for inference,
            e.g. "predict" (default) or "predict_proba".

    Returns:
        XGBoostModelSchema: The inferred schema of the XGBoost model.
    """
    signature_dict = _infer_schema(
        model_input=model_input, model=model, infer_method_name=infer_method_name
    )
    # The signature dict stores inputs/outputs as JSON strings; decode them.
    input_schema = json.loads(signature_dict["inputs"])
    output_schema = json.loads(signature_dict["outputs"])
    return autogen_artifacts.XGBoostModelSchema(
        infer_method_name=infer_method_name,
        inputs=input_schema,
        outputs=output_schema,
    )
|
|
@@ -29,6 +29,8 @@ _ML_REPO_NAME_REGEX = re.compile(r"^[a-zA-Z][a-zA-Z0-9\-]{1,98}[a-zA-Z0-9]$")
|
|
|
29
29
|
_RUN_NAME_REGEX = re.compile(r"^[a-zA-Z0-9-]*$")
|
|
30
30
|
_RUN_LOG_LOG_TYPE_REGEX = re.compile(r"^[a-zA-Z0-9-/]*$")
|
|
31
31
|
_RUN_LOG_KEY_REGEX = re.compile(r"^[a-zA-Z0-9-_]*$")
|
|
32
|
+
_APP_NAME_REGEX = re.compile(r"^[a-z][a-z0-9\\-]{1,30}[a-z0-9]$")
|
|
33
|
+
|
|
32
34
|
|
|
33
35
|
MAX_PARAMS_TAGS_PER_BATCH = 100
|
|
34
36
|
MAX_METRICS_PER_BATCH = 1000
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: truefoundry
|
|
3
|
-
Version: 0.5.0rc6
|
|
3
|
+
Version: 0.5.1
|
|
4
4
|
Summary: Truefoundry CLI
|
|
5
5
|
Author: Abhishek Choudhary
|
|
6
6
|
Author-email: abhishek@truefoundry.com
|
|
@@ -19,7 +19,6 @@ Requires-Dist: aenum (>=3.0.0,<4.0.0)
|
|
|
19
19
|
Requires-Dist: click (>=7.0.0,<9.0.0)
|
|
20
20
|
Requires-Dist: coolname (>=1.1.0,<2.0.0)
|
|
21
21
|
Requires-Dist: docker (>=6.1.2,<8.0.0)
|
|
22
|
-
Requires-Dist: fastapi (>=0.56.0,<0.200.0)
|
|
23
22
|
Requires-Dist: filelock (>=3.8.0,<4.0.0)
|
|
24
23
|
Requires-Dist: flytekit (==1.13.13) ; extra == "workflow"
|
|
25
24
|
Requires-Dist: gitignorefile (>=1.1.2,<2.0.0)
|
|
@@ -29,15 +28,13 @@ Requires-Dist: numpy (>=1.23.0,<2.0.0) ; python_version < "3.12"
|
|
|
29
28
|
Requires-Dist: numpy (>=1.26.0,<2.0.0) ; python_version >= "3.12"
|
|
30
29
|
Requires-Dist: openai (>=1.16.2,<2.0.0)
|
|
31
30
|
Requires-Dist: packaging (>=20.0,<25.0)
|
|
32
|
-
Requires-Dist: pandas (>=1.0.0,<3.0.0) ; python_version < "3.10"
|
|
33
|
-
Requires-Dist: pandas (>=1.4.0,<3.0.0) ; python_version >= "3.10"
|
|
34
31
|
Requires-Dist: pydantic (>=1.8.2,<3.0.0)
|
|
35
32
|
Requires-Dist: pygments (>=2.12.0,<3.0.0)
|
|
36
33
|
Requires-Dist: python-dateutil (>=2.8.2,<3.0.0)
|
|
37
34
|
Requires-Dist: python-dotenv (>=1.0.1,<2.0.0)
|
|
38
35
|
Requires-Dist: python-socketio[client] (>=5.5.2,<6.0.0)
|
|
39
36
|
Requires-Dist: questionary (>=1.10.0,<2.0.0)
|
|
40
|
-
Requires-Dist: requests (>=2.
|
|
37
|
+
Requires-Dist: requests (>=2.18.0,<3.0.0)
|
|
41
38
|
Requires-Dist: requirements-parser (>=0.11.0,<0.12.0)
|
|
42
39
|
Requires-Dist: rich (>=13.7.1,<14.0.0)
|
|
43
40
|
Requires-Dist: rich-click (>=1.2.1,<2.0.0)
|
|
@@ -46,7 +43,6 @@ Requires-Dist: scipy (>=1.5.0,<2.0.0) ; python_version < "3.12"
|
|
|
46
43
|
Requires-Dist: tqdm (>=4.0.0,<5.0.0)
|
|
47
44
|
Requires-Dist: typing-extensions (>=4.0)
|
|
48
45
|
Requires-Dist: urllib3 (>=1.26.18,<3)
|
|
49
|
-
Requires-Dist: uvicorn (>=0.13.0,<1.0.0)
|
|
50
46
|
Requires-Dist: yq (>=3.1.0,<4.0.0)
|
|
51
47
|
Description-Content-Type: text/markdown
|
|
52
48
|
|