truefoundry-0.5.0rc7-py3-none-any.whl → truefoundry-0.5.1rc2-py3-none-any.whl

This diff compares the publicly released contents of the two package versions as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of truefoundry has been flagged as potentially problematic.
Files changed (52)
  1. truefoundry/common/utils.py +73 -1
  2. truefoundry/deploy/__init__.py +5 -0
  3. truefoundry/deploy/cli/cli.py +2 -0
  4. truefoundry/deploy/cli/commands/__init__.py +1 -0
  5. truefoundry/deploy/cli/commands/deploy_init_command.py +22 -0
  6. truefoundry/deploy/lib/dao/application.py +2 -1
  7. truefoundry/deploy/v2/lib/patched_models.py +8 -0
  8. truefoundry/ml/__init__.py +14 -12
  9. truefoundry/ml/autogen/client/__init__.py +5 -0
  10. truefoundry/ml/autogen/client/api/mlfoundry_artifacts_api.py +161 -0
  11. truefoundry/ml/autogen/client/models/__init__.py +5 -0
  12. truefoundry/ml/autogen/client/models/artifact_version_manifest.py +2 -2
  13. truefoundry/ml/autogen/client/models/export_deployment_files_request_dto.py +82 -0
  14. truefoundry/ml/autogen/client/models/infer_method_name.py +34 -0
  15. truefoundry/ml/autogen/client/models/model_server.py +34 -0
  16. truefoundry/ml/autogen/client/models/model_version_environment.py +1 -1
  17. truefoundry/ml/autogen/client/models/model_version_manifest.py +3 -3
  18. truefoundry/ml/autogen/client/models/sklearn_framework.py +17 -1
  19. truefoundry/ml/autogen/client/models/transformers_framework.py +2 -2
  20. truefoundry/ml/autogen/client/models/xg_boost_framework.py +6 -1
  21. truefoundry/ml/autogen/client_README.md +4 -0
  22. truefoundry/ml/autogen/entities/artifacts.py +29 -7
  23. truefoundry/ml/cli/commands/model_init.py +97 -0
  24. truefoundry/ml/cli/utils.py +34 -0
  25. truefoundry/ml/log_types/artifacts/model.py +63 -24
  26. truefoundry/ml/log_types/artifacts/utils.py +37 -1
  27. truefoundry/ml/mlfoundry_api.py +74 -78
  28. truefoundry/ml/mlfoundry_run.py +0 -30
  29. truefoundry/ml/model_framework.py +257 -3
  30. truefoundry/ml/validation_utils.py +2 -0
  31. {truefoundry-0.5.0rc7.dist-info → truefoundry-0.5.1rc2.dist-info}/METADATA +1 -5
  32. {truefoundry-0.5.0rc7.dist-info → truefoundry-0.5.1rc2.dist-info}/RECORD +34 -46
  33. truefoundry/deploy/function_service/__init__.py +0 -3
  34. truefoundry/deploy/function_service/__main__.py +0 -27
  35. truefoundry/deploy/function_service/app.py +0 -92
  36. truefoundry/deploy/function_service/build.py +0 -45
  37. truefoundry/deploy/function_service/remote/__init__.py +0 -6
  38. truefoundry/deploy/function_service/remote/context.py +0 -3
  39. truefoundry/deploy/function_service/remote/method.py +0 -67
  40. truefoundry/deploy/function_service/remote/remote.py +0 -144
  41. truefoundry/deploy/function_service/route.py +0 -137
  42. truefoundry/deploy/function_service/service.py +0 -113
  43. truefoundry/deploy/function_service/utils.py +0 -53
  44. truefoundry/langchain/__init__.py +0 -12
  45. truefoundry/langchain/deprecated.py +0 -302
  46. truefoundry/langchain/truefoundry_chat.py +0 -130
  47. truefoundry/langchain/truefoundry_embeddings.py +0 -171
  48. truefoundry/langchain/truefoundry_llm.py +0 -106
  49. truefoundry/langchain/utils.py +0 -44
  50. truefoundry/ml/log_types/artifacts/model_extras.py +0 -48
  51. {truefoundry-0.5.0rc7.dist-info → truefoundry-0.5.1rc2.dist-info}/WHEEL +0 -0
  52. {truefoundry-0.5.0rc7.dist-info → truefoundry-0.5.1rc2.dist-info}/entry_points.txt +0 -0
truefoundry/ml/autogen/client/models/model_version_manifest.py
@@ -40,11 +40,11 @@ class ModelVersionManifest(BaseModel):
 
     description: Optional[constr(strict=True, max_length=512)] = Field(
         default=None,
-        description="+label=Description +docs=Description of the artifact version",
+        description="+label=Description +usage=Description of the artifact or model version +docs=Description of the artifact or model version",
     )
     metadata: Dict[str, Any] = Field(
         default=...,
-        description="+label=Metadata +docs=Metadata for the model version +usage=Metadata for the model version +uiType=JsonInput",
+        description="+label=Metadata +docs=Metadata for the artifact or model version +usage=Metadata for the artifact or model version +uiType=JsonInput",
     )
     type: Optional[StrictStr] = "model-version"
     source: Source1 = Field(...)
@@ -55,7 +55,7 @@ class ModelVersionManifest(BaseModel):
     )
     model_schema: Optional[Dict[str, Any]] = Field(
         default=None,
-        description="+label=Model Schema +usage=Schema of the model +uiType=Hidden",
+        description="+label=Model Schema +usage=Schema of the model +uiType=JsonInput",
     )
     __properties = [
         "description",
truefoundry/ml/autogen/client/models/sklearn_framework.py
@@ -18,6 +18,7 @@ import pprint
 import re  # noqa: F401
 from typing import Optional
 
+from truefoundry.ml.autogen.client.models.infer_method_name import InferMethodName
 from truefoundry.ml.autogen.client.models.serialization_format import (
     SerializationFormat,
 )
@@ -37,7 +38,20 @@ class SklearnFramework(BaseModel):
         default=None,
         description="+label=Serialization format +usage=Serialization format used for the model",
     )
-    __properties = ["type", "serialization_format"]
+    model_filepath: Optional[StrictStr] = Field(
+        default=None,
+        description="+label=Model file path +usage=Relative path to the model file",
+    )
+    infer_method_name: Optional[InferMethodName] = Field(
+        default=None,
+        description="+label=Inference method name +usage=Name of the method used for inference",
+    )
+    __properties = [
+        "type",
+        "serialization_format",
+        "model_filepath",
+        "infer_method_name",
+    ]
 
     @validator("type")
     def type_validate_enum(cls, value):
@@ -83,6 +97,8 @@ class SklearnFramework(BaseModel):
             {
                 "type": obj.get("type"),
                 "serialization_format": obj.get("serialization_format"),
+                "model_filepath": obj.get("model_filepath"),
+                "infer_method_name": obj.get("infer_method_name"),
             }
         )
         return _obj
truefoundry/ml/autogen/client/models/transformers_framework.py
@@ -37,11 +37,11 @@ class TransformersFramework(BaseModel):
     )
     pipeline_tag: Optional[StrictStr] = Field(
         default=None,
-        description="+label=Pipeline Tag +usage=Pipeline tag +docs=Pipeline tag for the framework",
+        description="+label=Pipeline Tag +usage=The `pipeline()` task this model can be used with e.g. `text-generation`. See [huggingface docs](https://huggingface.co/docs/transformers/main/en/main_classes/pipelines#transformers.pipeline.task) for all possible values +docs=Pipeline tag for the framework",
     )
     base_model: Optional[StrictStr] = Field(
         default=None,
-        description="+label=Base Model +usage=Base model +docs=Base model Id. If this is a finetuned model, this points to the base model used for finetuning",
+        description="+label=Base Model +usage=Base model Id. If this is a finetuned model, this points to the base model used for finetuning +docs=Base model Id. If this is a finetuned model, this points to the base model used for finetuning",
     )
     __properties = ["type", "library_name", "pipeline_tag", "base_model"]
 
truefoundry/ml/autogen/client/models/xg_boost_framework.py
@@ -37,7 +37,11 @@ class XGBoostFramework(BaseModel):
         default=None,
         description="+label=Serialization format +usage=Serialization format used for the model",
     )
-    __properties = ["type", "serialization_format"]
+    model_filepath: Optional[StrictStr] = Field(
+        default=None,
+        description="+label=Model file path +usage=Relative path to the model file",
+    )
+    __properties = ["type", "serialization_format", "model_filepath"]
 
     @validator("type")
     def type_validate_enum(cls, value):
@@ -83,6 +87,7 @@ class XGBoostFramework(BaseModel):
             {
                 "type": obj.get("type"),
                 "serialization_format": obj.get("serialization_format"),
+                "model_filepath": obj.get("model_filepath"),
             }
         )
         return _obj
truefoundry/ml/autogen/client_README.md
@@ -102,6 +102,7 @@ Class | Method | HTTP request | Description
 *MlfoundryArtifactsApi* | [**delete_dataset_post**](truefoundry/ml/autogen/client/docs/MlfoundryArtifactsApi.md#delete_dataset_post) | **POST** /api/2.0/mlflow/mlfoundry-artifacts/datasets/delete | Delete Dataset
 *MlfoundryArtifactsApi* | [**delete_files_for_dataset_delete**](truefoundry/ml/autogen/client/docs/MlfoundryArtifactsApi.md#delete_files_for_dataset_delete) | **DELETE** /api/2.0/mlflow/mlfoundry-artifacts/datasets/files/ | Delete Files For Dataset
 *MlfoundryArtifactsApi* | [**delete_model_version_post**](truefoundry/ml/autogen/client/docs/MlfoundryArtifactsApi.md#delete_model_version_post) | **POST** /api/2.0/mlflow/mlfoundry-artifacts/model-versions/delete | Delete Model Version
+*MlfoundryArtifactsApi* | [**export_deployment_files_by_fqn_post**](truefoundry/ml/autogen/client/docs/MlfoundryArtifactsApi.md#export_deployment_files_by_fqn_post) | **POST** /api/2.0/mlflow/mlfoundry-artifacts/model-versions/export-deployment-files-by-fqn | Export Deployment Files By Fqn
 *MlfoundryArtifactsApi* | [**finalize_artifact_version_post**](truefoundry/ml/autogen/client/docs/MlfoundryArtifactsApi.md#finalize_artifact_version_post) | **POST** /api/2.0/mlflow/mlfoundry-artifacts/artifact-versions/finalize | Finalize Artifact Version
 *MlfoundryArtifactsApi* | [**get_artifact_by_fqn_get**](truefoundry/ml/autogen/client/docs/MlfoundryArtifactsApi.md#get_artifact_by_fqn_get) | **GET** /api/2.0/mlflow/mlfoundry-artifacts/artifacts/get-by-fqn | Get Artifact By Fqn
 *MlfoundryArtifactsApi* | [**get_artifact_by_id_get**](truefoundry/ml/autogen/client/docs/MlfoundryArtifactsApi.md#get_artifact_by_id_get) | **GET** /api/2.0/mlflow/mlfoundry-artifacts/artifacts/get | Get Artifact By Id
@@ -211,6 +212,7 @@ Class | Method | HTTP request | Description
  - [ExperimentIdRequestDto](truefoundry/ml/autogen/client/docs/ExperimentIdRequestDto.md)
  - [ExperimentResponseDto](truefoundry/ml/autogen/client/docs/ExperimentResponseDto.md)
  - [ExperimentTagDto](truefoundry/ml/autogen/client/docs/ExperimentTagDto.md)
+ - [ExportDeploymentFilesRequestDto](truefoundry/ml/autogen/client/docs/ExportDeploymentFilesRequestDto.md)
  - [ExternalArtifactSource](truefoundry/ml/autogen/client/docs/ExternalArtifactSource.md)
  - [FastAIFramework](truefoundry/ml/autogen/client/docs/FastAIFramework.md)
  - [FileInfoDto](truefoundry/ml/autogen/client/docs/FileInfoDto.md)
@@ -233,6 +235,7 @@ Class | Method | HTTP request | Description
  - [HTTPValidationError](truefoundry/ml/autogen/client/docs/HTTPValidationError.md)
  - [ImageContentPart](truefoundry/ml/autogen/client/docs/ImageContentPart.md)
  - [ImageUrl](truefoundry/ml/autogen/client/docs/ImageUrl.md)
+ - [InferMethodName](truefoundry/ml/autogen/client/docs/InferMethodName.md)
  - [InternalMetadata](truefoundry/ml/autogen/client/docs/InternalMetadata.md)
  - [KerasFramework](truefoundry/ml/autogen/client/docs/KerasFramework.md)
  - [LatestRunLogDto](truefoundry/ml/autogen/client/docs/LatestRunLogDto.md)
@@ -270,6 +273,7 @@ Class | Method | HTTP request | Description
  - [ModelConfiguration](truefoundry/ml/autogen/client/docs/ModelConfiguration.md)
  - [ModelDto](truefoundry/ml/autogen/client/docs/ModelDto.md)
  - [ModelResponseDto](truefoundry/ml/autogen/client/docs/ModelResponseDto.md)
+ - [ModelServer](truefoundry/ml/autogen/client/docs/ModelServer.md)
  - [ModelVersionDto](truefoundry/ml/autogen/client/docs/ModelVersionDto.md)
  - [ModelVersionEnvironment](truefoundry/ml/autogen/client/docs/ModelVersionEnvironment.md)
  - [ModelVersionManifest](truefoundry/ml/autogen/client/docs/ModelVersionManifest.md)
truefoundry/ml/autogen/entities/artifacts.py
@@ -1,6 +1,6 @@
 # generated by datamodel-codegen:
 #   filename: artifacts.json
-#   timestamp: 2024-11-28T08:27:16+00:00
+#   timestamp: 2024-12-04T12:00:28+00:00
 
 from __future__ import annotations
 
@@ -52,11 +52,11 @@ class AgentWithFQN(Agent):
 class BaseArtifactVersion(BaseModel):
     description: Optional[constr(max_length=512)] = Field(
         None,
-        description="+label=Description\n+docs=Description of the artifact version",
+        description="+label=Description\n+usage=Description of the artifact or model version\n+docs=Description of the artifact or model version",
     )
     metadata: Dict[str, Any] = Field(
         ...,
-        description="+label=Metadata\n+docs=Metadata for the model version\n+usage=Metadata for the model version\n+uiType=JsonInput",
+        description="+label=Metadata\n+docs=Metadata for the artifact or model version\n+usage=Metadata for the artifact or model version\n+uiType=JsonInput",
     )
 
 
@@ -228,7 +228,7 @@ class ModelConfiguration(BaseModel):
 
 class ModelVersionEnvironment(BaseModel):
     """
-    +label=ModelVersionEnvironment
+    +label=Environment
     """
 
     python_version: Optional[constr(regex=r"^\d+(\.\d+){1,2}([\-\.a-z0-9]+)?$")] = (
@@ -287,6 +287,16 @@ class SerializationFormat(str, Enum):
     pickle = "pickle"
 
 
+class InferMethodName(str, Enum):
+    """
+    +label=Inference method name
+    +usage=Name of the method used for inference
+    """
+
+    predict = "predict"
+    predict_proba = "predict_proba"
+
+
 class SklearnFramework(BaseModel):
     """
     +docs=Scikit-learn framework for the model version
@@ -300,6 +310,14 @@ class SklearnFramework(BaseModel):
         None,
         description="+label=Serialization format\n+usage=Serialization format used for the model",
     )
+    model_filepath: Optional[str] = Field(
+        None,
+        description="+label=Model file path\n+usage=Relative path to the model file",
+    )
+    infer_method_name: Optional[InferMethodName] = Field(
+        None,
+        description="+label=Inference method name\n+usage=Name of the method used for inference",
+    )
 
 
 class SpaCyFramework(BaseModel):
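
The two new SklearnFramework fields above let a logged scikit-learn model carry enough information for automated serving. A minimal, hypothetical sketch, assuming the generated entities are imported directly from truefoundry.ml.autogen.entities.artifacts (the module shown above; the package may also re-export them publicly) and that the client's log_model accepts the object as its framework argument as in earlier releases:

    # Hypothetical sketch: describe a scikit-learn model with the new fields.
    from truefoundry.ml.autogen.entities.artifacts import (
        InferMethodName,
        SerializationFormat,
        SklearnFramework,
    )

    framework = SklearnFramework(
        type="sklearn",  # discriminator value assumed
        serialization_format=SerializationFormat.pickle,
        model_filepath="model.pkl",  # relative path inside the uploaded folder
        infer_method_name=InferMethodName.predict,  # or InferMethodName.predict_proba
    )

    # Assumed call shape; the exact log_model signature is not part of this diff.
    # client.log_model(
    #     ml_repo="my-ml-repo",
    #     name="churn-classifier",
    #     model_file_or_folder="artifacts/",
    #     framework=framework,
    # )
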
@@ -389,11 +407,11 @@ class TransformersFramework(BaseModel):
     )
     pipeline_tag: Optional[str] = Field(
         None,
-        description="+label=Pipeline Tag\n+usage=Pipeline tag\n+docs=Pipeline tag for the framework",
+        description="+label=Pipeline Tag\n+usage=The `pipeline()` task this model can be used with e.g. `text-generation`. See [huggingface docs](https://huggingface.co/docs/transformers/main/en/main_classes/pipelines#transformers.pipeline.task) for all possible values\n+docs=Pipeline tag for the framework",
     )
     base_model: Optional[str] = Field(
         None,
-        description="+label=Base Model\n+usage=Base model\n+docs=Base model Id. If this is a finetuned model, this points to the base model used for finetuning",
+        description="+label=Base Model\n+usage=Base model Id. If this is a finetuned model, this points to the base model used for finetuning\n+docs=Base model Id. If this is a finetuned model, this points to the base model used for finetuning",
     )
 
 
@@ -446,6 +464,10 @@ class XGBoostFramework(BaseModel):
         None,
         description="+label=Serialization format\n+usage=Serialization format used for the model",
     )
+    model_filepath: Optional[str] = Field(
+        None,
+        description="+label=Model file path\n+usage=Relative path to the model file",
+    )
 
 
 class AgentOpenAPITool(BaseModel):
@@ -573,7 +595,7 @@ class ModelVersion(BaseArtifactVersion):
     step: conint(ge=0) = Field(0, description="+label=Step")
     model_schema: Optional[Dict[str, Any]] = Field(
         None,
-        description="+label=Model Schema\n+usage=Schema of the model\n+uiType=Hidden",
+        description="+label=Model Schema\n+usage=Schema of the model\n+uiType=JsonInput",
     )
 
 
truefoundry/ml/cli/commands/model_init.py (new file)
@@ -0,0 +1,97 @@
+import os
+from typing import Optional
+
+import rich_click as click
+
+from truefoundry.deploy.cli.console import console
+from truefoundry.deploy.cli.const import COMMAND_CLS
+from truefoundry.deploy.cli.util import handle_exception_wrapper
+from truefoundry.ml.autogen.client.models import ModelServer
+from truefoundry.ml.cli.utils import (
+    AppName,
+    NonEmptyString,
+)
+from truefoundry.ml.mlfoundry_api import get_client
+
+
+@click.command(
+    name="model",
+    cls=COMMAND_CLS,
+    help="Generating application code for the specified model version.",
+)
+@click.option(
+    "--name",
+    required=True,
+    type=AppName(),
+    help="Name for the model server deployment",
+    show_default=True,
+)
+@click.option(
+    "--model-version-fqn",
+    "--model_version_fqn",
+    type=NonEmptyString(),
+    required=True,
+    show_default=True,
+    help="Fully Qualified Name (FQN) of the model version to deploy, e.g., 'model:tenant_name/my-model/linear-regression:2'",
+)
+@click.option(
+    "-w",
+    "--workspace-fqn",
+    "--workspace_fqn",
+    type=NonEmptyString(),
+    required=True,
+    show_default=True,
+    help="Fully Qualified Name (FQN) of the workspace to deploy",
+)
+@click.option(
+    "--model-server",
+    "--model_server",
+    type=click.Choice(ModelServer, case_sensitive=False),
+    default=ModelServer.FASTAPI.value,
+    show_default=True,
+    help="Specify the model server (Case Insensitive).",
+)
+@click.option(
+    "--output-dir",
+    "--output_dir",
+    type=click.Path(exists=True, file_okay=False, writable=True),
+    help="Output directory for the model server code",
+    required=False,
+    show_default=True,
+    default=os.getcwd(),
+)
+@handle_exception_wrapper
+def model_init_command(
+    name: str,
+    model_version_fqn: str,
+    workspace_fqn: str,
+    model_server: ModelServer,
+    output_dir: Optional[str],
+):
+    """
+    Generates application code for the specified model version.
+    """
+    ml_client = get_client()
+    console.print(f"Generating application code for {model_version_fqn!r}")
+    output_dir = ml_client._initialize_model_server(
+        name=name,
+        model_version_fqn=model_version_fqn,
+        workspace_fqn=workspace_fqn,
+        model_server=ModelServer[model_server.upper()],
+        output_dir=output_dir,
+    )
+    message = f"""
+[bold green]Model Server code initialized successfully![/bold green]
+
+[bold]Code Location:[/bold] {output_dir}
+
+[bold]Next Steps:[/bold]
+- Navigate to the model server directory:
+  [green]cd {output_dir}[/green]
+- Refer to the README file in the directory for further instructions.
+"""
+    console.print(message)
+
+
+def get_model_init_command():
+    return model_init_command
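
Since model_init_command is an ordinary click command, its option surface can be inspected locally without deploying anything. An illustrative sketch using click's test runner; the command is normally reached through the tfy CLI, and the exact subcommand path it is registered under is not visible in this diff:

    from click.testing import CliRunner

    from truefoundry.ml.cli.commands.model_init import model_init_command

    runner = CliRunner()
    result = runner.invoke(model_init_command, ["--help"])
    # Lists --name, --model-version-fqn, --workspace-fqn, --model-server and --output-dir.
    print(result.output)
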
truefoundry/ml/cli/utils.py (new file)
@@ -0,0 +1,34 @@
+import rich_click as click
+
+from truefoundry.ml import MlFoundryException
+from truefoundry.ml.validation_utils import (
+    _APP_NAME_REGEX,
+)
+
+
+class AppName(click.ParamType):
+    """
+    Custom ParamType to validate application names.
+    """
+
+    name = "application-name"
+
+    def convert(self, value, param, ctx):
+        try:
+            if not value or not _APP_NAME_REGEX.match(value):
+                raise MlFoundryException(
+                    f"{value!r} must be lowercase and cannot contain spaces. It can only contain alphanumeric characters and hyphens. "
+                    f"Length must be between 1 and 30 characters."
+                )
+        except MlFoundryException as e:
+            self.fail(str(e), param, ctx)
+        return value
+
+
+class NonEmptyString(click.ParamType):
+    name = "non-empty-string"
+
+    def convert(self, value, param, ctx):
+        if isinstance(value, str) and not value.strip():
+            self.fail("Value cannot be empty or contain only spaces.", param, ctx)
+        return value
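
Both ParamTypes fail fast before the command body runs. A small sketch of that behaviour, calling the converters directly (click's ParamType.convert can be exercised with param=None and ctx=None outside a CLI invocation):

    import click

    from truefoundry.ml.cli.utils import AppName, NonEmptyString

    # Valid values are returned unchanged.
    assert AppName().convert("my-model-server", None, None) == "my-model-server"
    assert NonEmptyString().convert("tfy-ws-fqn", None, None) == "tfy-ws-fqn"

    # Invalid values are rejected via ParamType.fail(), which raises click.UsageError.
    try:
        AppName().convert("My Model Server", None, None)  # uppercase + spaces
    except click.UsageError as exc:
        print(exc)
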
truefoundry/ml/log_types/artifacts/model.py
@@ -8,7 +8,7 @@ import typing
 import uuid
 import warnings
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
 from truefoundry.ml.artifact.truefoundry_artifact_repo import (
     ArtifactIdentifier,
@@ -44,11 +44,19 @@ from truefoundry.ml.log_types.artifacts.utils import (
     _validate_description,
     calculate_total_size,
 )
-from truefoundry.ml.model_framework import ModelFrameworkType, _ModelFramework
+from truefoundry.ml.model_framework import (
+    ModelFrameworkType,
+    _ModelFramework,
+    auto_update_environment_details,
+    auto_update_model_framework_details,
+)
 from truefoundry.ml.session import _get_api_client
 from truefoundry.pydantic_v1 import BaseModel, Extra
 
 if TYPE_CHECKING:
+    import numpy as np
+    import pandas as pd
+
     from truefoundry.ml.mlfoundry_run import MlFoundryRun
 
 
@@ -98,6 +106,7 @@ class ModelVersion:
         self._description: str = ""
         self._metadata: Dict[str, Any] = {}
         self._model_schema: Optional[Dict[str, Any]] = None
+        self._environment: Optional[ModelVersionEnvironment] = None
         self._framework: Optional[ModelFrameworkType] = None
         self._set_mutable_attrs()
 
@@ -143,19 +152,20 @@ class ModelVersion:
             self._model_schema = copy.deepcopy(
                 self._model_version.manifest.model_schema
             )
-            if self._model_version.manifest.framework:
-                self._framework = copy.deepcopy(
-                    self._model_version.manifest.framework.actual_instance
-                )
-            else:
-                self._framework = None
+            self._environment = copy.deepcopy(self._model_version.manifest.environment)
+            self._framework = (
+                copy.deepcopy(self._model_version.manifest.framework.actual_instance)
+                if self._model_version.manifest.framework
+                else None
+            )
         else:
             self._description = self._model_version.description or ""
             self._metadata = copy.deepcopy(self._model_version.artifact_metadata)
+            self._model_schema = None
+            self._environment = None
             self._framework = _ModelFramework.to_model_framework_type(
                 self._model_version.model_framework
             )
-            self._model_schema = None
 
     def _refetch_model_version(self, reset_mutable_attrs: bool = True):
         _model_version = self._mlfoundry_artifacts_api.get_model_version_get(
@@ -242,6 +252,23 @@ class ModelVersion:
             return
         self._model_schema = copy.deepcopy(value)
 
+    @property
+    def environment(self) -> Optional[ModelVersionEnvironment]:
+        """Get the environment details for the model"""
+        return self._environment
+
+    @environment.setter
+    def environment(self, value: Optional[Dict[str, Any]]):
+        """set the environment details for the model"""
+        if not self._model_version.manifest:
+            warnings.warn(
+                message="This model version was created using an older serialization format. Environment will not be updated",
+                category=DeprecationWarning,
+                stacklevel=2,
+            )
+            return
+        self._environment = copy.deepcopy(value)
+
     @property
     def framework(self) -> Optional["ModelFrameworkType"]:
         """Get the framework of the model"""
@@ -440,6 +467,7 @@ class ModelVersion:
         self._model_version.manifest.description = self.description
         self._model_version.manifest.metadata = self.metadata
         self._model_version.manifest.model_schema = self.model_schema
+        self._model_version.manifest.environment = self.environment
         self._model_version.manifest.framework = (
             Framework.from_dict(self.framework.dict()) if self.framework else None
         )
@@ -467,7 +495,6 @@ def _log_model_version(  # noqa: C901
     model_file_or_folder: Union[str, BlobStorageDirectory],
     mlfoundry_artifacts_api: Optional[MlfoundryArtifactsApi] = None,
     ml_repo_id: Optional[str] = None,
-    additional_files: Sequence[Tuple[Union[str, Path], Optional[str]]] = (),
     description: Optional[str] = None,
     metadata: Optional[Dict[str, Any]] = None,
     step: Optional[int] = 0,
@@ -498,7 +525,6 @@ def _log_model_version(  # noqa: C901
     step = step or 0
     total_size = None
     metadata = metadata or {}
-    additional_files = additional_files or {}
 
     _validate_description(description)
     _validate_artifact_metadata(metadata)
@@ -516,17 +542,6 @@ def _log_model_version(  # noqa: C901
             ignore_model_dir_dest_conflict=True,
         )
 
-        # verify additional files and paths, copy additional files
-        if additional_files:
-            logger.info("Adding `additional_files` to model version contents")
-            temp_dest_to_src_map = _copy_additional_files(
-                root_dir=temp_dir.name,
-                files_dir="",
-                model_dir=None,
-                additional_files=additional_files,
-                ignore_model_dir_dest_conflict=False,
-                existing_dest_to_src_map=temp_dest_to_src_map,
-            )
     except Exception as e:
         temp_dir.cleanup()
         raise MlFoundryException("Failed to log model") from e
@@ -580,13 +595,22 @@ def _log_model_version(  # noqa: C901
     else:
         raise MlFoundryException("Invalid model_file_or_folder provided")
 
-    _framework = _ModelFramework.to_model_framework_type(framework)
     _source_cls = typing.get_type_hints(ModelVersionManifest)["source"]
+
+    # Auto fetch the framework & environment details if not provided
+    framework = _ModelFramework.to_model_framework_type(framework)
+    if framework and isinstance(model_file_or_folder, str):
+        auto_update_model_framework_details(
+            framework=framework, model_file_or_folder=model_file_or_folder
+        )
+    environment = environment or ModelVersionEnvironment()
+    auto_update_environment_details(environment=environment, framework=framework)
+
     model_manifest = ModelVersionManifest(
         description=description,
         metadata=metadata,
         source=_source_cls.from_dict(source.dict()),
-        framework=Framework.from_dict(_framework.dict()) if _framework else None,
+        framework=Framework.from_dict(framework.dict()) if framework else None,
         environment=environment,
         step=step,
         model_schema=model_schema,
@@ -603,3 +627,18 @@ def _log_model_version(  # noqa: C901
         )
     )
     return ModelVersion.from_fqn(fqn=artifact_version_response.artifact_version.fqn)
+
+
+def infer_signature(
+    model_input: Any = None,
+    model_output: Optional[
+        Union["pd.DataFrame", "np.ndarray", Dict[str, "np.ndarray"]]
+    ] = None,
+    params: Optional[Dict[str, Any]] = None,
+):
+    # TODO: Importing this globally causes hard dependencies on some libraries like pandas
+    from truefoundry.ml.autogen.models import infer_signature as _infer_signature
+
+    return _infer_signature(
+        model_input=model_input, model_output=model_output, params=params
+    )
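
A minimal usage sketch for the new infer_signature helper, assuming it is called from this module (the import path may differ if it is re-exported from truefoundry.ml) and that pandas and numpy are installed, since the underlying implementation is only imported at call time:

    import numpy as np
    import pandas as pd

    from truefoundry.ml.log_types.artifacts.model import infer_signature

    X = pd.DataFrame({"age": [34, 51], "income": [52_000.0, 71_500.0]})
    y = np.array([0, 1])

    signature = infer_signature(model_input=X, model_output=y, params={"threshold": 0.5})
    print(signature)  # inspect the inferred input/output/params schema
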
truefoundry/ml/log_types/artifacts/utils.py
@@ -2,7 +2,7 @@ import json
 import logging
 import os
 import posixpath
-from pathlib import Path
+from pathlib import Path, PureWindowsPath
 from typing import Any, Dict, Optional, Sequence, Tuple, Union
 
 from truefoundry.ml.exceptions import MlFoundryException
@@ -11,6 +11,13 @@ from truefoundry.ml.log_types.artifacts.constants import DESCRIPTION_MAX_LENGTH
 logger = logging.getLogger(__name__)
 
 
+def to_unix_path(path):
+    path = os.path.normpath(path)
+    if os.path.sep == "\\":
+        path = PureWindowsPath(path).as_posix()
+    return path
+
+
 def _copy_tree(
     root_dir: str, src_path: str, dest_path: str, dest_to_src: Dict[str, str]
 ):
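
to_unix_path only rewrites separators when the interpreter is running on Windows (os.path.sep == "\\"); on POSIX it just normalizes the path. For example:

    from truefoundry.ml.log_types.artifacts.utils import to_unix_path

    # On Windows: "models/sklearn/model.joblib"; on POSIX the backslashes are
    # treated as ordinary characters and the value is returned unchanged.
    print(to_unix_path(r"models\sklearn\model.joblib"))

    # Redundant separators and "." segments are collapsed on every platform.
    print(to_unix_path("models//sklearn/./model.joblib"))  # models/sklearn/model.joblib
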
@@ -42,6 +49,35 @@ def is_destination_path_dirlike(dest_path) -> bool:
     return False
 
 
+def get_single_file_path_if_only_one_in_directory(path: str) -> Optional[str]:
+    """
+    Get the filename from a path, or return a filename from a directory if a single file exists.
+    Args:
+        path: The file or folder path.
+    Returns:
+        Optional[str]: The filename or None if no files are found or multiple files are found.
+    """
+    # If it's already a file, return it as-is
+    if os.path.isfile(path):
+        return path
+
+    # If it's a directory, check if it contains a single file
+    if is_destination_path_dirlike(path):
+        all_files = []
+        for root, _, files in os.walk(path):
+            # Collect all files found in any subdirectory
+            all_files.extend(os.path.join(root, f) for f in files)
+            # If more than one file is found, stop early
+            if len(all_files) > 1:
+                return None
+
+        # If only one file is found, return it
+        if len(all_files) == 1:
+            return all_files[0]
+
+    return None  # No file found or Multiple files found
+
+
 def _copy_additional_files(
     root_dir: str,
     files_dir: str,  # relative to root dir e.g. "files/"
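
A quick sketch of how the new helper behaves for a plain file path, a directory containing exactly one file, and a directory with several files (temporary paths, assuming is_destination_path_dirlike treats an existing directory as dir-like):

    import os
    import tempfile

    from truefoundry.ml.log_types.artifacts.utils import (
        get_single_file_path_if_only_one_in_directory,
    )

    with tempfile.TemporaryDirectory() as d:
        model_path = os.path.join(d, "model.joblib")
        open(model_path, "w").close()

        # A file path is returned as-is.
        assert get_single_file_path_if_only_one_in_directory(model_path) == model_path
        # A directory with exactly one file anywhere below it yields that file.
        assert get_single_file_path_if_only_one_in_directory(d) == model_path

        # A second file makes the result ambiguous, so None is returned.
        open(os.path.join(d, "requirements.txt"), "w").close()
        assert get_single_file_path_if_only_one_in_directory(d) is None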