truefoundry 0.4.9rc1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truefoundry might be problematic. See the package's registry page for more details.

Files changed (78)
  1. truefoundry/deploy/__init__.py +5 -0
  2. truefoundry/deploy/builder/builders/tfy_python_buildpack/dockerfile_template.py +2 -2
  3. truefoundry/deploy/lib/dao/application.py +2 -1
  4. truefoundry/deploy/v2/lib/patched_models.py +8 -0
  5. truefoundry/ml/__init__.py +41 -1
  6. truefoundry/ml/autogen/client/__init__.py +44 -14
  7. truefoundry/ml/autogen/client/api/__init__.py +3 -3
  8. truefoundry/ml/autogen/client/api/deprecated_api.py +333 -0
  9. truefoundry/ml/autogen/client/api/generate_code_snippet_api.py +526 -0
  10. truefoundry/ml/autogen/client/api/mlfoundry_artifacts_api.py +0 -322
  11. truefoundry/ml/autogen/client/api_client.py +8 -1
  12. truefoundry/ml/autogen/client/models/__init__.py +41 -11
  13. truefoundry/ml/autogen/client/models/add_features_to_model_version_request_dto.py +3 -17
  14. truefoundry/ml/autogen/client/models/agent.py +1 -1
  15. truefoundry/ml/autogen/client/models/agent_app.py +1 -1
  16. truefoundry/ml/autogen/client/models/agent_open_api_tool.py +1 -1
  17. truefoundry/ml/autogen/client/models/agent_open_api_tool_with_fqn.py +1 -1
  18. truefoundry/ml/autogen/client/models/agent_with_fqn.py +1 -1
  19. truefoundry/ml/autogen/client/models/artifact_version_dto.py +3 -5
  20. truefoundry/ml/autogen/client/models/artifact_version_manifest.py +111 -0
  21. truefoundry/ml/autogen/client/models/assistant_message.py +1 -1
  22. truefoundry/ml/autogen/client/models/blob_storage_reference.py +1 -1
  23. truefoundry/ml/autogen/client/models/chat_prompt.py +1 -1
  24. truefoundry/ml/autogen/client/models/command.py +152 -0
  25. truefoundry/ml/autogen/client/models/{feature_dto.py → create_workflow_task_config_request_dto.py} +18 -14
  26. truefoundry/ml/autogen/client/models/{external_model_source.py → external_artifact_source.py} +12 -11
  27. truefoundry/ml/autogen/client/models/fast_ai_framework.py +75 -0
  28. truefoundry/ml/autogen/client/models/finalize_artifact_version_request_dto.py +3 -5
  29. truefoundry/ml/autogen/client/models/framework.py +250 -14
  30. truefoundry/ml/autogen/client/models/gluon_framework.py +74 -0
  31. truefoundry/ml/autogen/client/models/{upload_model_source.py → h2_o_framework.py} +11 -11
  32. truefoundry/ml/autogen/client/models/image_content_part.py +1 -1
  33. truefoundry/ml/autogen/client/models/keras_framework.py +74 -0
  34. truefoundry/ml/autogen/client/models/light_gbm_framework.py +75 -0
  35. truefoundry/ml/autogen/client/models/manifest.py +154 -0
  36. truefoundry/ml/autogen/client/models/model_version_dto.py +7 -8
  37. truefoundry/ml/autogen/client/models/model_version_environment.py +97 -0
  38. truefoundry/ml/autogen/client/models/model_version_manifest.py +30 -6
  39. truefoundry/ml/autogen/client/models/onnx_framework.py +74 -0
  40. truefoundry/ml/autogen/client/models/paddle_framework.py +75 -0
  41. truefoundry/ml/autogen/client/models/py_torch_framework.py +75 -0
  42. truefoundry/ml/autogen/client/models/{feature_value_type.py → serialization_format.py} +8 -8
  43. truefoundry/ml/autogen/client/models/sklearn_framework.py +92 -0
  44. truefoundry/ml/autogen/client/models/source.py +23 -46
  45. truefoundry/ml/autogen/client/models/source1.py +154 -0
  46. truefoundry/ml/autogen/client/models/spa_cy_framework.py +74 -0
  47. truefoundry/ml/autogen/client/models/stats_models_framework.py +75 -0
  48. truefoundry/ml/autogen/client/models/system_message.py +1 -1
  49. truefoundry/ml/autogen/client/models/{tensorflow_framework.py → tensor_flow_framework.py} +11 -10
  50. truefoundry/ml/autogen/client/models/text_content_part.py +1 -1
  51. truefoundry/ml/autogen/client/models/transformers_framework.py +10 -4
  52. truefoundry/ml/autogen/client/models/trigger_job_run_config_request_dto.py +90 -0
  53. truefoundry/ml/autogen/client/models/trigger_job_run_config_response_dto.py +71 -0
  54. truefoundry/ml/autogen/client/models/{truefoundry_model_source.py → true_foundry_artifact_source.py} +13 -11
  55. truefoundry/ml/autogen/client/models/update_artifact_version_request_dto.py +11 -1
  56. truefoundry/ml/autogen/client/models/update_model_version_request_dto.py +1 -13
  57. truefoundry/ml/autogen/client/models/user_message.py +1 -1
  58. truefoundry/ml/autogen/client/models/xg_boost_framework.py +92 -0
  59. truefoundry/ml/autogen/client_README.md +30 -12
  60. truefoundry/ml/autogen/entities/artifacts.py +87 -9
  61. truefoundry/ml/autogen/models/__init__.py +4 -0
  62. truefoundry/ml/autogen/models/exceptions.py +30 -0
  63. truefoundry/ml/autogen/models/schema.py +1547 -0
  64. truefoundry/ml/autogen/models/signature.py +139 -0
  65. truefoundry/ml/autogen/models/utils.py +699 -0
  66. truefoundry/ml/log_types/artifacts/artifact.py +131 -63
  67. truefoundry/ml/log_types/artifacts/general_artifact.py +7 -26
  68. truefoundry/ml/log_types/artifacts/model.py +195 -197
  69. truefoundry/ml/mlfoundry_api.py +47 -52
  70. truefoundry/ml/mlfoundry_run.py +35 -43
  71. truefoundry/ml/model_framework.py +169 -0
  72. {truefoundry-0.4.9rc1.dist-info → truefoundry-0.5.0.dist-info}/METADATA +1 -1
  73. {truefoundry-0.4.9rc1.dist-info → truefoundry-0.5.0.dist-info}/RECORD +75 -53
  74. truefoundry/ml/autogen/client/api/python_deployment_config_api.py +0 -201
  75. truefoundry/ml/autogen/client/models/model_schema_dto.py +0 -85
  76. truefoundry/ml/autogen/client/models/prediction_type.py +0 -34
  77. {truefoundry-0.4.9rc1.dist-info → truefoundry-0.5.0.dist-info}/WHEEL +0 -0
  78. {truefoundry-0.4.9rc1.dist-info → truefoundry-0.5.0.dist-info}/entry_points.txt +0 -0
@@ -1,14 +1,23 @@
1
1
  import os
2
2
  import time
3
3
  import uuid
4
- from pathlib import Path
5
- from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
4
+ from typing import (
5
+ TYPE_CHECKING,
6
+ Any,
7
+ Dict,
8
+ Iterator,
9
+ List,
10
+ Optional,
11
+ Sequence,
12
+ Tuple,
13
+ Union,
14
+ )
6
15
 
7
16
  import coolname
8
17
  import pandas as pd
9
18
 
10
19
  from truefoundry.common.utils import relogin_error_message
11
- from truefoundry.ml import constants
20
+ from truefoundry.ml import ModelVersionEnvironment, constants
12
21
  from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
13
22
  ArtifactDto,
14
23
  ArtifactType,
@@ -38,11 +47,15 @@ from truefoundry.ml.internal_namespace import NAMESPACE
38
47
  from truefoundry.ml.log_types.artifacts.artifact import (
39
48
  ArtifactPath,
40
49
  ArtifactVersion,
50
+ BlobStorageDirectory,
41
51
  ChatPromptVersion,
42
52
  )
43
53
  from truefoundry.ml.log_types.artifacts.dataset import DataDirectory
44
54
  from truefoundry.ml.log_types.artifacts.general_artifact import _log_artifact_version
45
- from truefoundry.ml.log_types.artifacts.model import ModelVersion, _log_model_version
55
+ from truefoundry.ml.log_types.artifacts.model import (
56
+ ModelVersion,
57
+ _log_model_version,
58
+ )
46
59
  from truefoundry.ml.logger import logger
47
60
  from truefoundry.ml.mlfoundry_run import MlFoundryRun
48
61
  from truefoundry.ml.session import (
@@ -57,6 +70,9 @@ from truefoundry.ml.validation_utils import (
57
70
  _validate_run_name,
58
71
  )
59
72
 
73
+ if TYPE_CHECKING:
74
+ from truefoundry.ml import ModelFrameworkType
75
+
60
76
  _SEARCH_MAX_RESULTS_DEFAULT = 1000
61
77
 
62
78
  _INTERNAL_ENV_VARS = [
@@ -1209,12 +1225,13 @@ class MlFoundry:
1209
1225
  *,
1210
1226
  ml_repo: str,
1211
1227
  name: str,
1212
- model_file_or_folder: str,
1213
- framework: Optional[Union[ModelFramework, str]],
1214
- additional_files: Sequence[Tuple[Union[str, Path], Optional[str]]] = (),
1228
+ model_file_or_folder: Union[str, BlobStorageDirectory],
1215
1229
  description: Optional[str] = None,
1216
1230
  metadata: Optional[Dict[str, Any]] = None,
1217
1231
  progress: Optional[bool] = None,
1232
+ framework: Optional[Union[str, ModelFramework, "ModelFrameworkType"]] = None,
1233
+ environment: Optional[ModelVersionEnvironment] = None,
1234
+ model_schema: Optional[Dict[str, Any]] = None,
1218
1235
  ) -> ModelVersion:
1219
1236
  """
1220
1237
  Serialize and log a versioned model under the current ml_repo. Each logged model generates a new version
@@ -1226,45 +1243,24 @@ class MlFoundry:
1226
1243
  name (str): Name of the model. If a model with this name already exists under the current ML Repo,
1227
1244
  the logged model will be added as a new version under that `name`. If no models exist with the given
1228
1245
  `name`, the given model will be logged as version 1.
1229
- model_file_or_folder (str): Path to either a single file or a folder containing model files. This folder
1230
- is usually created using serialization methods of libraries or frameworks e.g. `joblib.dump`,
1231
- `model.save_pretrained(...)`, `torch.save(...)`, `model.save(...)`
1232
- framework (Union[enums.ModelFramework, str]): Model Framework. Ex:- pytorch, sklearn, tensorflow etc.
1233
- The full list of supported frameworks can be found in `truefoundry.ml.enums.ModelFramework`.
1234
- Can also be `None` when `model` is `None`.
1235
- additional_files (Sequence[Tuple[Union[str, Path], Optional[str]]], optional): A list of pairs
1236
- of (source path, destination path) to add additional files and folders
1237
- to the model version contents. The first member of the pair should be a file or directory path
1238
- and the second member should be the path inside the model versions contents to upload to.
1239
- The model version contents are arranged like follows
1240
- .
1241
- └── model/
1242
- └── # model files are serialized here
1243
- └── # any additional files and folders can be added here.
1244
-
1245
- You can also add additional files to model/ subdirectory by specifying the destination path as model/
1246
-
1247
- ```python
1248
- run.log_model(
1249
- name="xyz",
1250
- model_file_or_folder="clf.joblib",
1251
- framework="sklearn",
1252
- additional_files=[("foo.txt", "foo/bar/foo.txt"), ("tokenizer/", "foo/tokenizer/")]
1253
- )
1254
- ```
1255
1246
 
1256
- would result in
1247
+ model_file_or_folder (Union[str, BlobStorageDirectory]):
1248
+ str:
1249
+ Path to either a single file or a folder containing model files.
1250
+ This folder is typically created using serialization methods from libraries or frameworks,
1251
+ e.g., `joblib.dump`, `model.save_pretrained(...)`, `torch.save(...)`, or `model.save(...)`.
1252
+ BlobStorageDirectory:
1253
+ uri (str): URI to the model file or folder in a storage integration associated with the specified ML Repo.
1254
+ The model files or folder must reside within the same storage integration as the specified ML Repo.
1255
+ Accepted URI formats include `s3://integration-bucket-name/prefix/path/to/model` or `gs://integration-bucket-name/prefix/path/to/model`.
1256
+ If the URI points to a model in a different storage integration, an error will be raised indicating "Invalid source URI."
1257
+
1258
+ framework (Optional[Union[ModelFramework, ModelFrameworkType]]): Framework used for model serialization.
1259
+ Supported frameworks values (ModelFrameworkType) can be imported from `from truefoundry.ml import *`.
1260
+ Supported frameworks can be found in `truefoundry.ml.enums.ModelFramework`.
1261
+ Can also be `None` if the framework is not known or not supported.
1262
+ **Deprecated**: Prefer `ModelFrameworkType` over `enums.ModelFramework`.
1257
1263
 
1258
- ```
1259
- .
1260
- ├── model/
1261
- │ └── clf.joblib # if `model_file_or_folder` is a folder, contents will be added here
1262
- └── foo/
1263
- ├── bar/
1264
- │ └── foo.txt
1265
- └── tokenizer/
1266
- └── # contents of tokenizer/ directory will be uploaded here
1267
- ```
1268
1264
  description (Optional[str], optional): arbitrary text upto 1024 characters to store as description.
1269
1265
  This field can be updated at any time after logging. Defaults to `None`
1270
1266
  metadata (Optional[Dict[str, Any]], optional): arbitrary json serializable dictionary to store metadata.
@@ -1281,8 +1277,7 @@ class MlFoundry:
1281
1277
  ### Sklearn
1282
1278
 
1283
1279
  ```python
1284
- from truefoundry.ml import get_client
1285
- from truefoundry.ml.enums import ModelFramework
1280
+ from truefoundry.ml import get_client, SklearnFramework
1286
1281
 
1287
1282
  import joblib
1288
1283
  import numpy as np
@@ -1307,7 +1302,7 @@ class MlFoundry:
1307
1302
  ml_repo="my-classification-project",
1308
1303
  name="my-sklearn-model",
1309
1304
  model_file_or_folder="sklearn-pipeline.joblib",
1310
- framework=ModelFramework.SKLEARN,
1305
+ framework=SklearnFramework(),
1311
1306
  metadata={"accuracy": 0.99, "f1": 0.80},
1312
1307
  step=1, # step number, useful when using iterative algorithms like SGD
1313
1308
  )
@@ -1317,8 +1312,7 @@ class MlFoundry:
1317
1312
  ### Huggingface Transformers
1318
1313
 
1319
1314
  ```python
1320
- from truefoundry.ml import get_client
1321
- from truefoundry.ml.enums import ModelFramework
1315
+ from truefoundry.ml import get_client, TransformersFramework, LibraryName
1322
1316
 
1323
1317
  import torch
1324
1318
  from transformers import AutoTokenizer, AutoConfig, pipeline, AutoModelForCausalLM
@@ -1342,7 +1336,7 @@ class MlFoundry:
1342
1336
  ml_repo="my-llm-project",
1343
1337
  name="my-transformers-model",
1344
1338
  model_file_or_folder="my-transformers-model/",
1345
- framework=ModelFramework.TRANSFORMERS
1339
+ framework=TransformersFramework(library_name=LibraryName.TRANSFORMERS, pipeline_tag='text-generation')
1346
1340
  )
1347
1341
  print(model_version.fqn)
1348
1342
  ```
@@ -1356,12 +1350,13 @@ class MlFoundry:
1356
1350
  ml_repo_id=ml_repo_id,
1357
1351
  name=name,
1358
1352
  model_file_or_folder=model_file_or_folder,
1359
- framework=framework,
1360
- additional_files=additional_files,
1361
1353
  description=description,
1362
1354
  metadata=metadata,
1363
1355
  step=None,
1364
1356
  progress=progress,
1357
+ framework=framework,
1358
+ environment=environment,
1359
+ model_schema=model_schema,
1365
1360
  )
1366
1361
  logger.info(f"Logged model successfully with fqn {model_version.fqn!r}")
1367
1362
  return model_version
@@ -3,7 +3,6 @@ import os
3
3
  import platform
4
4
  import re
5
5
  import time
6
- from pathlib import Path
7
6
  from typing import (
8
7
  TYPE_CHECKING,
9
8
  Any,
@@ -12,7 +11,6 @@ from typing import (
12
11
  Iterator,
13
12
  List,
14
13
  Optional,
15
- Sequence,
16
14
  Tuple,
17
15
  Union,
18
16
  )
@@ -20,7 +18,7 @@ from urllib.parse import urljoin, urlsplit
20
18
 
21
19
  from truefoundry import version
22
20
  from truefoundry.common.utils import relogin_error_message
23
- from truefoundry.ml import constants, enums
21
+ from truefoundry.ml import ModelVersionEnvironment, constants
24
22
  from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
25
23
  ArtifactType,
26
24
  DeleteRunRequest,
@@ -40,13 +38,20 @@ from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
40
38
  UpdateRunRequestDto,
41
39
  )
42
40
  from truefoundry.ml.entities import Metric
43
- from truefoundry.ml.enums import RunStatus
41
+ from truefoundry.ml.enums import ModelFramework, RunStatus
44
42
  from truefoundry.ml.exceptions import MlFoundryException
45
43
  from truefoundry.ml.internal_namespace import NAMESPACE
46
44
  from truefoundry.ml.log_types import Image, Plot
47
- from truefoundry.ml.log_types.artifacts.artifact import ArtifactPath, ArtifactVersion
45
+ from truefoundry.ml.log_types.artifacts.artifact import (
46
+ ArtifactPath,
47
+ ArtifactVersion,
48
+ BlobStorageDirectory,
49
+ )
48
50
  from truefoundry.ml.log_types.artifacts.general_artifact import _log_artifact_version
49
- from truefoundry.ml.log_types.artifacts.model import ModelVersion, _log_model_version
51
+ from truefoundry.ml.log_types.artifacts.model import (
52
+ ModelVersion,
53
+ _log_model_version,
54
+ )
50
55
  from truefoundry.ml.logger import logger
51
56
  from truefoundry.ml.run_utils import ParamsType, flatten_dict, process_params
52
57
  from truefoundry.ml.session import ACTIVE_RUNS, _get_api_client, get_active_session
@@ -61,6 +66,8 @@ if TYPE_CHECKING:
61
66
  import matplotlib
62
67
  import plotly
63
68
 
69
+ from truefoundry.ml import ModelFrameworkType
70
+
64
71
 
65
72
  def _ensure_not_deleted(method):
66
73
  @functools.wraps(method)
@@ -920,13 +927,14 @@ class MlFoundryRun:
920
927
  self,
921
928
  *,
922
929
  name: str,
923
- model_file_or_folder: str,
924
- framework: Optional[Union[enums.ModelFramework, str]],
925
- additional_files: Sequence[Tuple[Union[str, Path], Optional[str]]] = (),
930
+ model_file_or_folder: Union[str, BlobStorageDirectory],
926
931
  description: Optional[str] = None,
927
932
  metadata: Optional[Dict[str, Any]] = None,
928
933
  step: int = 0,
929
934
  progress: Optional[bool] = None,
935
+ framework: Optional[Union[str, ModelFramework, "ModelFrameworkType"]] = None,
936
+ environment: Optional[ModelVersionEnvironment] = None,
937
+ model_schema: Optional[Dict[str, Any]] = None,
930
938
  ) -> ModelVersion:
931
939
  # TODO (chiragjn): Document mapping of framework to list of valid model save kwargs
932
940
  # TODO (chiragjn): Add more examples
@@ -939,39 +947,22 @@ class MlFoundryRun:
939
947
  name (str): Name of the model. If a model with this name already exists under the current ML Repo,
940
948
  the logged model will be added as a new version under that `name`. If no models exist with the given
941
949
  `name`, the given model will be logged as version 1.
942
- model_file_or_folder (str): Path to either a single file or a folder containing model files. This folder
943
- is usually created using serialization methods of libraries or frameworks e.g. `joblib.dump`,
944
- `model.save_pretrained(...)`, `torch.save(...)`, `model.save(...)`
945
- framework (Union[enums.ModelFramework, str]): Model Framework. Ex:- pytorch, sklearn, tensorflow etc.
946
- The full list of supported frameworks can be found in `truefoundry.ml.enums.ModelFramework`.
947
- Can also be `None` when `model` is `None`.
948
- additional_files (Sequence[Tuple[Union[str, Path], Optional[str]]], optional): A list of pairs
949
- of (source path, destination path) to add additional files and folders
950
- to the model version contents. The first member of the pair should be a file or directory path
951
- and the second member should be the path inside the model versions contents to upload to.
952
- The model version contents are arranged like follows
953
- .
954
- └── model/
955
- └── # model files are serialized here
956
- └── # any additional files and folders can be added here.
950
+ model_file_or_folder (Union[str, BlobStorageDirectory]):
951
+ str:
952
+ Path to either a single file or a folder containing model files.
953
+ This folder is typically created using serialization methods from libraries or frameworks,
954
+ e.g., `joblib.dump`, `model.save_pretrained(...)`, `torch.save(...)`, or `model.save(...)`.
955
+ BlobStorageDirectory:
956
+ uri (str): URI to the model file or folder in a storage integration associated with the specified ML Repo.
957
+ The model files or folder must reside within the same storage integration as the specified ML Repo.
958
+ Accepted URI formats include `s3://integration-bucket-name/prefix/path/to/model` or `gs://integration-bucket-name/prefix/path/to/model`.
959
+ If the URI points to a model in a different storage integration, an error will be raised indicating "Invalid source URI."
960
+ framework (Optional[Union[ModelFramework, ModelFrameworkType]]): Framework used for model serialization.
961
+ Supported frameworks values (ModelFrameworkType) can be imported from `from truefoundry.ml import *`.
962
+ Supported frameworks can be found in `truefoundry.ml.enums.ModelFramework`.
963
+ Can also be `None` if the framework is not known or not supported.
964
+ **Deprecated**: Prefer `ModelFrameworkType` over `enums.ModelFramework`.
957
965
 
958
- You can also add additional files to model/ subdirectory by specifying the destination path as model/
959
-
960
- ```
961
- E.g. >>> run.log_model(
962
- ... name="xyz", model_file_or_folder="clf.joblib", framework="sklearn",
963
- ... additional_files=[("foo.txt", "foo/bar/foo.txt"), ("tokenizer/", "foo/tokenizer/")]
964
- ... )
965
- would result in
966
- .
967
- ├── model/
968
- │ └── clf.joblib # if `model_file_or_folder` is a folder, contents will be added here
969
- └── foo/
970
- ├── bar/
971
- │ └── foo.txt
972
- └── tokenizer/
973
- └── # contents of tokenizer/ directory will be uploaded here
974
- ```
975
966
  description (Optional[str], optional): arbitrary text upto 1024 characters to store as description.
976
967
  This field can be updated at any time after logging. Defaults to `None`
977
968
  metadata (Optional[Dict[str, Any]], optional): arbitrary json serializable dictionary to store metadata.
@@ -1064,12 +1055,13 @@ class MlFoundryRun:
1064
1055
  run=self,
1065
1056
  name=name,
1066
1057
  model_file_or_folder=model_file_or_folder,
1067
- framework=framework,
1068
- additional_files=additional_files,
1069
1058
  description=description,
1070
1059
  metadata=metadata,
1071
1060
  step=step,
1072
1061
  progress=progress,
1062
+ framework=framework,
1063
+ environment=environment,
1064
+ model_schema=model_schema,
1073
1065
  )
1074
1066
  logger.info(f"Logged model successfully with fqn {model_version.fqn!r}")
1075
1067
  return model_version
@@ -0,0 +1,169 @@
1
+ import warnings
2
+ from typing import Any, Dict, Literal, Optional, Union, get_args
3
+
4
+ from truefoundry.ml import ModelFramework
5
+ from truefoundry.ml.autogen.entities import artifacts as autogen_artifacts
6
+ from truefoundry.pydantic_v1 import BaseModel, Field
7
+
8
+
9
+ class FastAIFramework(autogen_artifacts.FastAIFramework):
10
+ """FastAI model Framework"""
11
+
12
+ type: Literal["fastai"] = "fastai"
13
+
14
+
15
+ class GluonFramework(autogen_artifacts.GluonFramework):
16
+ """Gluon model Framework"""
17
+
18
+ type: Literal["gluon"] = "gluon"
19
+
20
+
21
+ class H2OFramework(autogen_artifacts.H2OFramework):
22
+ """H2O model Framework"""
23
+
24
+ type: Literal["h2o"] = "h2o"
25
+
26
+
27
+ class KerasFramework(autogen_artifacts.KerasFramework):
28
+ """Keras model Framework"""
29
+
30
+ type: Literal["keras"] = "keras"
31
+
32
+
33
+ class LightGBMFramework(autogen_artifacts.LightGBMFramework):
34
+ """LightGBM model Framework"""
35
+
36
+ type: Literal["lightgbm"] = "lightgbm"
37
+
38
+
39
+ class ONNXFramework(autogen_artifacts.ONNXFramework):
40
+ """ONNX model Framework"""
41
+
42
+ type: Literal["onnx"] = "onnx"
43
+
44
+
45
+ class PaddleFramework(autogen_artifacts.PaddleFramework):
46
+ """Paddle model Framework"""
47
+
48
+ type: Literal["paddle"] = "paddle"
49
+
50
+
51
+ class PyTorchFramework(autogen_artifacts.PyTorchFramework):
52
+ """PyTorch model Framework"""
53
+
54
+ type: Literal["pytorch"] = "pytorch"
55
+
56
+
57
+ class SklearnFramework(autogen_artifacts.SklearnFramework):
58
+ """Sklearn model Framework"""
59
+
60
+ type: Literal["sklearn"] = "sklearn"
61
+
62
+
63
+ class SpaCyFramework(autogen_artifacts.SpaCyFramework):
64
+ """SpaCy model Framework"""
65
+
66
+ type: Literal["spacy"] = "spacy"
67
+
68
+
69
+ class StatsModelsFramework(autogen_artifacts.StatsModelsFramework):
70
+ """StatsModels model Framework"""
71
+
72
+ type: Literal["statsmodels"] = "statsmodels"
73
+
74
+
75
+ class TensorFlowFramework(autogen_artifacts.TensorFlowFramework):
76
+ """TensorFlow model Framework"""
77
+
78
+ type: Literal["tensorflow"] = "tensorflow"
79
+
80
+
81
+ class TransformersFramework(autogen_artifacts.TransformersFramework):
82
+ """Transformers model Framework"""
83
+
84
+ type: Literal["transformers"] = "transformers"
85
+
86
+
87
+ class XGBoostFramework(autogen_artifacts.XGBoostFramework):
88
+ """XGBoost model Framework"""
89
+
90
+ type: Literal["xgboost"] = "xgboost"
91
+
92
+
93
+ # Union of all the model frameworks
94
+
95
+
96
+ ModelFrameworkType = Union[
97
+ FastAIFramework,
98
+ GluonFramework,
99
+ H2OFramework,
100
+ KerasFramework,
101
+ LightGBMFramework,
102
+ ONNXFramework,
103
+ PaddleFramework,
104
+ PyTorchFramework,
105
+ SklearnFramework,
106
+ SpaCyFramework,
107
+ StatsModelsFramework,
108
+ TensorFlowFramework,
109
+ TransformersFramework,
110
+ XGBoostFramework,
111
+ ]
112
+
113
+
114
+ class _ModelFramework(BaseModel):
115
+ __root__: ModelFrameworkType = Field(discriminator="type")
116
+
117
+ @classmethod
118
+ def to_model_framework_type(
119
+ cls,
120
+ framework: Optional[Union[str, ModelFramework, "ModelFrameworkType"]] = None,
121
+ ) -> Optional["ModelFrameworkType"]:
122
+ """
123
+ Converts a ModelFramework or string representation to a ModelFrameworkType object.
124
+
125
+ Args:
126
+ framework (Optional[Union[str, ModelFramework, ModelFrameworkType]]): ModelFrameworkType or equivalent input.
127
+ Supported frameworks can be found in `truefoundry.ml.enums.ModelFramework`.
128
+ May be `None` if the framework is unknown or unsupported.
129
+ **Deprecated**: Prefer passing a `ModelFrameworkType` instance.
130
+
131
+ Returns:
132
+ ModelFrameworkType corresponding to the input, or None if the input is None.
133
+ """
134
+ if framework is None:
135
+ return None
136
+
137
+ # Issue a deprecation warning for str and ModelFramework types
138
+ if isinstance(framework, (str, ModelFramework)):
139
+ warnings.warn(
140
+ "Passing a string or ModelFramework Enum is deprecated. Please use a ModelFrameworkType object.",
141
+ DeprecationWarning,
142
+ stacklevel=2,
143
+ )
144
+
145
+ # Convert string to ModelFramework
146
+ if isinstance(framework, str):
147
+ framework = ModelFramework(framework)
148
+
149
+ # Convert ModelFramework to ModelFrameworkType
150
+ if isinstance(framework, ModelFramework):
151
+ if framework == ModelFramework.UNKNOWN:
152
+ return None
153
+ return cls.parse_obj({"type": framework.value}).__root__
154
+
155
+ # Directly return if already a ModelFrameworkType
156
+ if isinstance(framework, get_args(ModelFrameworkType)):
157
+ return framework
158
+
159
+ raise ValueError(
160
+ "framework must be a string, ModelFramework enum, or ModelFrameworkType object"
161
+ )
162
+
163
+ @classmethod
164
+ def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[ModelFrameworkType]:
165
+ """Create an instance of ModelFramework from a dict"""
166
+ if obj is None:
167
+ return None
168
+
169
+ return cls.parse_obj(obj).__root__
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: truefoundry
3
- Version: 0.4.9rc1
3
+ Version: 0.5.0
4
4
  Summary: Truefoundry CLI
5
5
  Author: Abhishek Choudhary
6
6
  Author-email: abhishek@truefoundry.com