truefoundry 0.5.0rc7__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truefoundry might be problematic; see the package's registry page for more details.

Files changed (67):
  1. truefoundry/common/auth_service_client.py +2 -2
  2. truefoundry/common/constants.py +9 -0
  3. truefoundry/common/utils.py +81 -1
  4. truefoundry/deploy/__init__.py +5 -0
  5. truefoundry/deploy/builder/builders/tfy_notebook_buildpack/__init__.py +4 -2
  6. truefoundry/deploy/builder/builders/tfy_python_buildpack/__init__.py +7 -5
  7. truefoundry/deploy/builder/builders/tfy_python_buildpack/dockerfile_template.py +87 -28
  8. truefoundry/deploy/builder/constants.py +8 -0
  9. truefoundry/deploy/builder/utils.py +9 -4
  10. truefoundry/deploy/cli/cli.py +2 -0
  11. truefoundry/deploy/cli/commands/__init__.py +1 -0
  12. truefoundry/deploy/cli/commands/deploy_init_command.py +22 -0
  13. truefoundry/deploy/lib/dao/application.py +2 -1
  14. truefoundry/deploy/v2/lib/patched_models.py +8 -0
  15. truefoundry/ml/__init__.py +25 -16
  16. truefoundry/ml/autogen/client/__init__.py +21 -3
  17. truefoundry/ml/autogen/client/api/mlfoundry_artifacts_api.py +325 -0
  18. truefoundry/ml/autogen/client/models/__init__.py +21 -3
  19. truefoundry/ml/autogen/client/models/artifact_version_manifest.py +2 -2
  20. truefoundry/ml/autogen/client/models/export_deployment_files_request_dto.py +82 -0
  21. truefoundry/ml/autogen/client/models/infer_method_name.py +34 -0
  22. truefoundry/ml/autogen/client/models/model_server.py +34 -0
  23. truefoundry/ml/autogen/client/models/model_version_environment.py +1 -1
  24. truefoundry/ml/autogen/client/models/model_version_manifest.py +2 -8
  25. truefoundry/ml/autogen/client/models/sklearn_framework.py +15 -5
  26. truefoundry/ml/autogen/client/models/sklearn_model_schema.py +82 -0
  27. truefoundry/ml/autogen/client/models/{serialization_format.py → sklearn_serialization_format.py} +5 -5
  28. truefoundry/ml/autogen/client/models/transformers_framework.py +2 -2
  29. truefoundry/ml/autogen/client/models/validate_external_storage_root_request_dto.py +71 -0
  30. truefoundry/ml/autogen/client/models/validate_external_storage_root_response_dto.py +69 -0
  31. truefoundry/ml/autogen/client/models/xg_boost_framework.py +17 -5
  32. truefoundry/ml/autogen/client/models/xg_boost_model_schema.py +88 -0
  33. truefoundry/ml/autogen/client/models/xg_boost_serialization_format.py +36 -0
  34. truefoundry/ml/autogen/client_README.md +11 -1
  35. truefoundry/ml/autogen/entities/artifacts.py +95 -39
  36. truefoundry/ml/autogen/models/signature.py +6 -3
  37. truefoundry/ml/autogen/models/utils.py +12 -7
  38. truefoundry/ml/cli/commands/model_init.py +97 -0
  39. truefoundry/ml/cli/utils.py +34 -0
  40. truefoundry/ml/log_types/artifacts/model.py +50 -38
  41. truefoundry/ml/log_types/artifacts/utils.py +38 -2
  42. truefoundry/ml/mlfoundry_api.py +74 -80
  43. truefoundry/ml/mlfoundry_run.py +0 -32
  44. truefoundry/ml/model_framework.py +372 -3
  45. truefoundry/ml/validation_utils.py +2 -0
  46. {truefoundry-0.5.0rc7.dist-info → truefoundry-0.5.1.dist-info}/METADATA +1 -5
  47. {truefoundry-0.5.0rc7.dist-info → truefoundry-0.5.1.dist-info}/RECORD +49 -56
  48. truefoundry/deploy/function_service/__init__.py +0 -3
  49. truefoundry/deploy/function_service/__main__.py +0 -27
  50. truefoundry/deploy/function_service/app.py +0 -92
  51. truefoundry/deploy/function_service/build.py +0 -45
  52. truefoundry/deploy/function_service/remote/__init__.py +0 -6
  53. truefoundry/deploy/function_service/remote/context.py +0 -3
  54. truefoundry/deploy/function_service/remote/method.py +0 -67
  55. truefoundry/deploy/function_service/remote/remote.py +0 -144
  56. truefoundry/deploy/function_service/route.py +0 -137
  57. truefoundry/deploy/function_service/service.py +0 -113
  58. truefoundry/deploy/function_service/utils.py +0 -53
  59. truefoundry/langchain/__init__.py +0 -12
  60. truefoundry/langchain/deprecated.py +0 -302
  61. truefoundry/langchain/truefoundry_chat.py +0 -130
  62. truefoundry/langchain/truefoundry_embeddings.py +0 -171
  63. truefoundry/langchain/truefoundry_llm.py +0 -106
  64. truefoundry/langchain/utils.py +0 -44
  65. truefoundry/ml/log_types/artifacts/model_extras.py +0 -48
  66. {truefoundry-0.5.0rc7.dist-info → truefoundry-0.5.1.dist-info}/WHEEL +0 -0
  67. {truefoundry-0.5.0rc7.dist-info → truefoundry-0.5.1.dist-info}/entry_points.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  # generated by datamodel-codegen:
2
2
  # filename: artifacts.json
3
- # timestamp: 2024-11-28T08:27:16+00:00
3
+ # timestamp: 2024-12-09T09:04:12+00:00
4
4
 
5
5
  from __future__ import annotations
6
6
 
@@ -52,11 +52,11 @@ class AgentWithFQN(Agent):
52
52
  class BaseArtifactVersion(BaseModel):
53
53
  description: Optional[constr(max_length=512)] = Field(
54
54
  None,
55
- description="+label=Description\n+docs=Description of the artifact version",
55
+ description="+label=Description\n+usage=Description of the artifact or model version\n+docs=Description of the artifact or model version",
56
56
  )
57
57
  metadata: Dict[str, Any] = Field(
58
58
  ...,
59
- description="+label=Metadata\n+docs=Metadata for the model version\n+usage=Metadata for the model version\n+uiType=JsonInput",
59
+ description="+label=Metadata\n+docs=Metadata for the artifact or model version\n+usage=Metadata for the artifact or model version\n+uiType=JsonInput",
60
60
  )
61
61
 
62
62
 
@@ -228,7 +228,7 @@ class ModelConfiguration(BaseModel):
228
228
 
229
229
  class ModelVersionEnvironment(BaseModel):
230
230
  """
231
- +label=ModelVersionEnvironment
231
+ +label=Environment
232
232
  """
233
233
 
234
234
  python_version: Optional[constr(regex=r"^\d+(\.\d+){1,2}([\-\.a-z0-9]+)?$")] = (
@@ -276,32 +276,44 @@ class PyTorchFramework(BaseModel):
276
276
  )
277
277
 
278
278
 
279
- class SerializationFormat(str, Enum):
279
+ class InferMethodName(str, Enum):
280
280
  """
281
- +label=Serialization format
282
- +usage=Serialization format used for the model
281
+ +label=Inference Method Name
282
+ +usage=Name of the method used for inference
283
283
  """
284
284
 
285
- cloudpickle = "cloudpickle"
286
- joblib = "joblib"
287
- pickle = "pickle"
285
+ predict = "predict"
286
+ predict_proba = "predict_proba"
288
287
 
289
288
 
290
- class SklearnFramework(BaseModel):
289
+ class SklearnModelSchema(BaseModel):
291
290
  """
292
- +docs=Scikit-learn framework for the model version
293
- +label=Sklearn
291
+ +label=Sklearn Model Schema
294
292
  """
295
293
 
296
- type: Literal["sklearn"] = Field(
297
- ..., description="+label=Type\n+usage=Type of the framework\n+value=sklearn"
294
+ infer_method_name: InferMethodName = Field(
295
+ ...,
296
+ description="+label=Inference Method Name\n+usage=Name of the method used for inference",
298
297
  )
299
- serialization_format: Optional[SerializationFormat] = Field(
300
- None,
301
- description="+label=Serialization format\n+usage=Serialization format used for the model",
298
+ inputs: List[Dict[str, Any]] = Field(
299
+ ..., description="+label= Input Schema\n+usage=Schema of the input"
300
+ )
301
+ outputs: List[Dict[str, Any]] = Field(
302
+ ..., description="+label= Output Schema\n+usage=Schema of the output"
302
303
  )
303
304
 
304
305
 
306
+ class SklearnSerializationFormat(str, Enum):
307
+ """
308
+ +label=Serialization format
309
+ +usage=Serialization format used for sklearn models
310
+ """
311
+
312
+ cloudpickle = "cloudpickle"
313
+ joblib = "joblib"
314
+ pickle = "pickle"
315
+
316
+
305
317
  class SpaCyFramework(BaseModel):
306
318
  """
307
319
  +docs=spaCy framework for the model version
@@ -389,11 +401,11 @@ class TransformersFramework(BaseModel):
389
401
  )
390
402
  pipeline_tag: Optional[str] = Field(
391
403
  None,
392
- description="+label=Pipeline Tag\n+usage=Pipeline tag\n+docs=Pipeline tag for the framework",
404
+ description="+label=Pipeline Tag\n+usage=The `pipeline()` task this model can be used with e.g. `text-generation`. See [huggingface docs](https://huggingface.co/docs/transformers/main/en/main_classes/pipelines#transformers.pipeline.task) for all possible values\n+docs=Pipeline tag for the framework",
393
405
  )
394
406
  base_model: Optional[str] = Field(
395
407
  None,
396
- description="+label=Base Model\n+usage=Base model\n+docs=Base model Id. If this is a finetuned model, this points to the base model used for finetuning",
408
+ description="+label=Base Model\n+usage=Base model Id. If this is a finetuned model, this points to the base model used for finetuning\n+docs=Base model Id. If this is a finetuned model, this points to the base model used for finetuning",
397
409
  )
398
410
 
399
411
 
@@ -433,19 +445,33 @@ class UserMessage(BaseModel):
433
445
  )
434
446
 
435
447
 
436
- class XGBoostFramework(BaseModel):
448
+ class XGBoostModelSchema(BaseModel):
437
449
  """
438
- +docs=XGBoost framework for the model version
439
- +label=XGBoost
450
+ +label=XGBoost Model Schema
440
451
  """
441
452
 
442
- type: Literal["xgboost"] = Field(
443
- ..., description="+label=Type\n+usage=Type of the framework\n+value=xgboost"
453
+ infer_method_name: Literal["predict"] = Field(
454
+ ...,
455
+ description="+label=Inference Method Name\n+usage=Name of the method used for inference",
444
456
  )
445
- serialization_format: Optional[SerializationFormat] = Field(
446
- None,
447
- description="+label=Serialization format\n+usage=Serialization format used for the model",
457
+ inputs: List[Dict[str, Any]] = Field(
458
+ ..., description="+label= Input Schema\n+usage=Schema of the input"
448
459
  )
460
+ outputs: List[Dict[str, Any]] = Field(
461
+ ..., description="+label= Output Schema\n+usage=Schema of the output"
462
+ )
463
+
464
+
465
+ class XGBoostSerializationFormat(str, Enum):
466
+ """
467
+ +label=Serialization format
468
+ +usage=Serialization format used for XGBoost models
469
+ """
470
+
471
+ cloudpickle = "cloudpickle"
472
+ joblib = "joblib"
473
+ pickle = "pickle"
474
+ json = "json"
449
475
 
450
476
 
451
477
  class AgentOpenAPITool(BaseModel):
@@ -541,6 +567,47 @@ class ChatPrompt(BasePrompt):
541
567
  )
542
568
 
543
569
 
570
+ class SklearnFramework(BaseModel):
571
+ """
572
+ +docs=Scikit-learn framework for the model version
573
+ +label=Sklearn
574
+ """
575
+
576
+ type: Literal["sklearn"] = Field(
577
+ ..., description="+label=Type\n+usage=Type of the framework\n+value=sklearn"
578
+ )
579
+ model_filepath: Optional[str] = Field(
580
+ None,
581
+ description="+label=Model file path\n+usage=Relative path to the model file",
582
+ )
583
+ serialization_format: Optional[SklearnSerializationFormat] = None
584
+ model_schema: Optional[SklearnModelSchema] = None
585
+
586
+
587
+ class XGBoostFramework(BaseModel):
588
+ """
589
+ +docs=XGBoost framework for the model version
590
+ +label=XGBoost
591
+ """
592
+
593
+ type: Literal["xgboost"] = Field(
594
+ ..., description="+label=Type\n+usage=Type of the framework\n+value=xgboost"
595
+ )
596
+ serialization_format: Optional[XGBoostSerializationFormat] = None
597
+ model_filepath: Optional[str] = Field(
598
+ None,
599
+ description="+label=Model file path\n+usage=Relative path to the model file",
600
+ )
601
+ model_schema: Optional[XGBoostModelSchema] = None
602
+
603
+
604
+ class AgentApp(BaseModel):
605
+ type: Literal["agent-app"] = Field(..., description="+value=agent-app")
606
+ tools: List[AgentOpenAPIToolWithFQN]
607
+ agents: List[AgentWithFQN]
608
+ root_agent: constr(min_length=1)
609
+
610
+
544
611
  class ModelVersion(BaseArtifactVersion):
545
612
  type: Literal["model-version"] = Field(
546
613
  ..., description='+label=Type\n+usage=Model Version\n+value="model-version"'
@@ -571,17 +638,6 @@ class ModelVersion(BaseArtifactVersion):
571
638
  )
572
639
  environment: Optional[ModelVersionEnvironment] = None
573
640
  step: conint(ge=0) = Field(0, description="+label=Step")
574
- model_schema: Optional[Dict[str, Any]] = Field(
575
- None,
576
- description="+label=Model Schema\n+usage=Schema of the model\n+uiType=Hidden",
577
- )
578
-
579
-
580
- class AgentApp(BaseModel):
581
- type: Literal["agent-app"] = Field(..., description="+value=agent-app")
582
- tools: List[AgentOpenAPIToolWithFQN]
583
- agents: List[AgentWithFQN]
584
- root_agent: constr(min_length=1)
585
641
 
586
642
 
587
643
  class VersionedArtifactType(BaseModel):
@@ -1,13 +1,16 @@
1
1
  from dataclasses import dataclass, is_dataclass
2
- from typing import Any, Dict, Optional, Union
2
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Union
3
3
 
4
4
  import numpy as np
5
- import pandas as pd
5
+
6
+ if TYPE_CHECKING:
7
+ import pandas as pd
8
+
6
9
 
7
10
  from .schema import ParamSchema, Schema, convert_dataclass_to_schema
8
11
  from .utils import infer_param_schema, infer_schema
9
12
 
10
- MlflowInferableDataset = Union[pd.DataFrame, np.ndarray, Dict[str, np.ndarray]]
13
+ MlflowInferableDataset = Union["pd.DataFrame", "np.ndarray", Dict[str, "np.ndarray"]]
11
14
 
12
15
 
13
16
  class ModelSignature:
@@ -4,7 +4,12 @@ from collections import defaultdict
4
4
  from typing import Any, Dict, List, Optional, Union
5
5
 
6
6
  import numpy as np
7
- import pandas as pd
7
+
8
+ try:
9
+ import pandas as pd
10
+ except ImportError:
11
+ pd = None
12
+
8
13
 
9
14
  from .exceptions import MlflowException
10
15
  from .schema import (
@@ -330,7 +335,7 @@ def infer_schema(data: Any) -> Schema: # noqa: C901
330
335
  ]
331
336
  )
332
337
  # pandas.Series
333
- elif isinstance(data, pd.Series):
338
+ elif pd and isinstance(data, pd.Series):
334
339
  name = getattr(data, "name", None)
335
340
  schema = Schema(
336
341
  [
@@ -342,7 +347,7 @@ def infer_schema(data: Any) -> Schema: # noqa: C901
342
347
  ]
343
348
  )
344
349
  # pandas.DataFrame
345
- elif isinstance(data, pd.DataFrame):
350
+ elif pd and isinstance(data, pd.DataFrame):
346
351
  schema = Schema(
347
352
  [
348
353
  ColSpec(
@@ -473,13 +478,13 @@ def _is_none_or_nan(x):
473
478
 
474
479
 
475
480
  def _infer_required(col) -> bool:
476
- if isinstance(col, (list, pd.Series)):
481
+ if pd and isinstance(col, (list, pd.Series)):
477
482
  return not any(_is_none_or_nan(x) for x in col)
478
483
  return not _is_none_or_nan(col)
479
484
 
480
485
 
481
- def _infer_pandas_column(col: pd.Series) -> DataType:
482
- if not isinstance(col, pd.Series):
486
+ def _infer_pandas_column(col: "pd.Series") -> DataType:
487
+ if pd and not isinstance(col, pd.Series):
483
488
  raise TypeError(f"Expected pandas.Series, got '{type(col)}'.")
484
489
  if len(col.values.shape) > 1:
485
490
  raise MlflowException(f"Expected 1d array, got array with shape {col.shape}")
@@ -496,7 +501,7 @@ def _infer_pandas_column(col: pd.Series) -> DataType:
496
501
  # For backwards compatibility, we fall back to string
497
502
  # if the provided array is of string type
498
503
  # This is for diviner test where df field is ('key2', 'key1', 'key0')
499
- if pd.api.types.is_string_dtype(col):
504
+ if pd and pd.api.types.is_string_dtype(col):
500
505
  return DataType.string
501
506
  raise MlflowException(
502
507
  f"Failed to infer schema for pandas.Series {col}. Error: {e}"
@@ -0,0 +1,97 @@
1
+ import os
2
+ from typing import Optional
3
+
4
+ import rich_click as click
5
+
6
+ from truefoundry.deploy.cli.console import console
7
+ from truefoundry.deploy.cli.const import COMMAND_CLS
8
+ from truefoundry.deploy.cli.util import handle_exception_wrapper
9
+ from truefoundry.ml.autogen.client.models import ModelServer
10
+ from truefoundry.ml.cli.utils import (
11
+ AppName,
12
+ NonEmptyString,
13
+ )
14
+ from truefoundry.ml.mlfoundry_api import get_client
15
+
16
+
17
@click.command(
    name="model",
    cls=COMMAND_CLS,
    help="Generating application code for the specified model version.",
)
@click.option(
    "--name",
    required=True,
    type=AppName(),
    help="Name for the model server deployment",
    show_default=True,
)
@click.option(
    "--model-version-fqn",
    "--model_version_fqn",
    type=NonEmptyString(),
    required=True,
    show_default=True,
    help="Fully Qualified Name (FQN) of the model version to deploy, e.g., 'model:tenant_name/my-model/linear-regression:2'",
)
@click.option(
    "-w",
    "--workspace-fqn",
    "--workspace_fqn",
    type=NonEmptyString(),
    required=True,
    show_default=True,
    help="Fully Qualified Name (FQN) of the workspace to deploy",
)
@click.option(
    "--model-server",
    "--model_server",
    # Offer the enum *values* as plain strings: click then returns the
    # canonical value string, independent of how the installed click version
    # treats Enum classes passed to Choice.
    type=click.Choice(
        [server.value for server in ModelServer], case_sensitive=False
    ),
    default=ModelServer.FASTAPI.value,
    show_default=True,
    help="Specify the model server (Case Insensitive).",
)
@click.option(
    "--output-dir",
    "--output_dir",
    type=click.Path(exists=True, file_okay=False, writable=True),
    help="Output directory for the model server code",
    required=False,
    # A callable default is resolved when the command runs, not when this
    # module is imported, so the directory the user invokes the CLI from is
    # used (and the importing process's cwd is not baked into --help).
    show_default="current working directory",
    default=os.getcwd,
)
@handle_exception_wrapper
def model_init_command(
    name: str,
    model_version_fqn: str,
    workspace_fqn: str,
    model_server: str,
    output_dir: Optional[str],
):
    """
    Generates application code for the specified model version.

    Delegates code generation to the ML client's ``_initialize_model_server``
    and then prints the resulting code location and next steps.
    """
    ml_client = get_client()
    console.print(f"Generating application code for {model_version_fqn!r}")
    output_dir = ml_client._initialize_model_server(
        name=name,
        model_version_fqn=model_version_fqn,
        workspace_fqn=workspace_fqn,
        # Look the member up by value: click.Choice above only yields values
        # of ModelServer, so this never depends on name/value casing.
        model_server=ModelServer(model_server),
        output_dir=output_dir,
    )
    message = f"""
[bold green]Model Server code initialized successfully![/bold green]

[bold]Code Location:[/bold] {output_dir}

[bold]Next Steps:[/bold]
- Navigate to the model server directory:
[green]cd {output_dir}[/green]
- Refer to the README file in the directory for further instructions.
"""
    console.print(message)
94
+
95
+
96
def get_model_init_command():
    """Return the ``model`` init click command defined in this module."""
    command = model_init_command
    return command
@@ -0,0 +1,34 @@
1
+ import rich_click as click
2
+
3
+ from truefoundry.ml import MlFoundryException
4
+ from truefoundry.ml.validation_utils import (
5
+ _APP_NAME_REGEX,
6
+ )
7
+
8
+
9
class AppName(click.ParamType):
    """
    Custom ParamType to validate application names.

    A value is accepted only if it is a non-empty string matched by
    ``_APP_NAME_REGEX``; anything else fails the parameter with a usage error.
    """

    name = "application-name"

    def convert(self, value, param, ctx):
        # Report invalid names straight through ``self.fail`` rather than
        # raising an exception only to catch it again a line later. Non-string
        # values are rejected the same way instead of crashing inside the regex.
        if (
            not isinstance(value, str)
            or not value
            or not _APP_NAME_REGEX.match(value)
        ):
            self.fail(
                f"{value!r} must be lowercase and cannot contain spaces. It can only contain alphanumeric characters and hyphens. "
                "Length must be between 1 and 30 characters.",
                param,
                ctx,
            )
        return value
26
+
27
+
28
class NonEmptyString(click.ParamType):
    """
    Click parameter type that rejects empty or whitespace-only strings.

    Values that are not strings are returned unchanged.
    """

    name = "non-empty-string"

    def convert(self, value, param, ctx):
        is_blank = isinstance(value, str) and value.strip() == ""
        if is_blank:
            self.fail("Value cannot be empty or contain only spaces.", param, ctx)
        return value
@@ -8,7 +8,7 @@ import typing
8
8
  import uuid
9
9
  import warnings
10
10
  from pathlib import Path
11
- from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union
11
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Union
12
12
 
13
13
  from truefoundry.ml.artifact.truefoundry_artifact_repo import (
14
14
  ArtifactIdentifier,
@@ -31,6 +31,7 @@ from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
31
31
  TrueFoundryArtifactSource,
32
32
  UpdateModelVersionRequestDto,
33
33
  )
34
+ from truefoundry.ml.autogen.models import infer_signature as _infer_signature
34
35
  from truefoundry.ml.enums import ModelFramework
35
36
  from truefoundry.ml.exceptions import MlFoundryException
36
37
  from truefoundry.ml.log_types.artifacts.artifact import BlobStorageDirectory
@@ -44,11 +45,19 @@ from truefoundry.ml.log_types.artifacts.utils import (
44
45
  _validate_description,
45
46
  calculate_total_size,
46
47
  )
47
- from truefoundry.ml.model_framework import ModelFrameworkType, _ModelFramework
48
+ from truefoundry.ml.model_framework import (
49
+ ModelFrameworkType,
50
+ _ModelFramework,
51
+ auto_update_environment_details,
52
+ auto_update_model_framework_details,
53
+ )
48
54
  from truefoundry.ml.session import _get_api_client
49
55
  from truefoundry.pydantic_v1 import BaseModel, Extra
50
56
 
51
57
  if TYPE_CHECKING:
58
+ import numpy as np
59
+ import pandas as pd
60
+
52
61
  from truefoundry.ml.mlfoundry_run import MlFoundryRun
53
62
 
54
63
 
@@ -97,7 +106,7 @@ class ModelVersion:
97
106
  self._deleted = False
98
107
  self._description: str = ""
99
108
  self._metadata: Dict[str, Any] = {}
100
- self._model_schema: Optional[Dict[str, Any]] = None
109
+ self._environment: Optional[ModelVersionEnvironment] = None
101
110
  self._framework: Optional[ModelFrameworkType] = None
102
111
  self._set_mutable_attrs()
103
112
 
@@ -140,22 +149,19 @@ class ModelVersion:
140
149
  if self._model_version.manifest:
141
150
  self._description = self._model_version.manifest.description or ""
142
151
  self._metadata = copy.deepcopy(self._model_version.manifest.metadata)
143
- self._model_schema = copy.deepcopy(
144
- self._model_version.manifest.model_schema
152
+ self._environment = copy.deepcopy(self._model_version.manifest.environment)
153
+ self._framework = (
154
+ copy.deepcopy(self._model_version.manifest.framework.actual_instance)
155
+ if self._model_version.manifest.framework
156
+ else None
145
157
  )
146
- if self._model_version.manifest.framework:
147
- self._framework = copy.deepcopy(
148
- self._model_version.manifest.framework.actual_instance
149
- )
150
- else:
151
- self._framework = None
152
158
  else:
153
159
  self._description = self._model_version.description or ""
154
160
  self._metadata = copy.deepcopy(self._model_version.artifact_metadata)
161
+ self._environment = None
155
162
  self._framework = _ModelFramework.to_model_framework_type(
156
163
  self._model_version.model_framework
157
164
  )
158
- self._model_schema = None
159
165
 
160
166
  def _refetch_model_version(self, reset_mutable_attrs: bool = True):
161
167
  _model_version = self._mlfoundry_artifacts_api.get_model_version_get(
@@ -226,21 +232,21 @@ class ModelVersion:
226
232
  self._metadata = copy.deepcopy(value)
227
233
 
228
234
  @property
229
- def model_schema(self) -> Optional[Dict[str, Any]]:
230
- """Get model_schema for the current model"""
231
- return self._model_schema
235
+ def environment(self) -> Optional[ModelVersionEnvironment]:
236
+ """Get the environment details for the model"""
237
+ return self._environment
232
238
 
233
- @model_schema.setter
234
- def model_schema(self, value: Optional[Dict[str, Any]]):
235
- """set the model_schema for current model"""
239
+ @environment.setter
240
+ def environment(self, value: Optional[Dict[str, Any]]):
241
+ """set the environment details for the model"""
236
242
  if not self._model_version.manifest:
237
243
  warnings.warn(
238
- message="This model version was created using an older serialization format. model_schema will not be updated",
244
+ message="This model version was created using an older serialization format. Environment will not be updated",
239
245
  category=DeprecationWarning,
240
246
  stacklevel=2,
241
247
  )
242
248
  return
243
- self._model_schema = copy.deepcopy(value)
249
+ self._environment = copy.deepcopy(value)
244
250
 
245
251
  @property
246
252
  def framework(self) -> Optional["ModelFrameworkType"]:
@@ -439,7 +445,7 @@ class ModelVersion:
439
445
  if self._model_version.manifest:
440
446
  self._model_version.manifest.description = self.description
441
447
  self._model_version.manifest.metadata = self.metadata
442
- self._model_version.manifest.model_schema = self.model_schema
448
+ self._model_version.manifest.environment = self.environment
443
449
  self._model_version.manifest.framework = (
444
450
  Framework.from_dict(self.framework.dict()) if self.framework else None
445
451
  )
@@ -467,14 +473,12 @@ def _log_model_version( # noqa: C901
467
473
  model_file_or_folder: Union[str, BlobStorageDirectory],
468
474
  mlfoundry_artifacts_api: Optional[MlfoundryArtifactsApi] = None,
469
475
  ml_repo_id: Optional[str] = None,
470
- additional_files: Sequence[Tuple[Union[str, Path], Optional[str]]] = (),
471
476
  description: Optional[str] = None,
472
477
  metadata: Optional[Dict[str, Any]] = None,
473
478
  step: Optional[int] = 0,
474
479
  progress: Optional[bool] = None,
475
480
  framework: Optional[Union[str, ModelFramework, "ModelFrameworkType"]] = None,
476
481
  environment: Optional[ModelVersionEnvironment] = None,
477
- model_schema: Optional[Dict[str, Any]] = None,
478
482
  ) -> ModelVersion:
479
483
  if (run and mlfoundry_artifacts_api) or (not run and not mlfoundry_artifacts_api):
480
484
  raise MlFoundryException(
@@ -498,7 +502,6 @@ def _log_model_version( # noqa: C901
498
502
  step = step or 0
499
503
  total_size = None
500
504
  metadata = metadata or {}
501
- additional_files = additional_files or {}
502
505
 
503
506
  _validate_description(description)
504
507
  _validate_artifact_metadata(metadata)
@@ -516,17 +519,6 @@ def _log_model_version( # noqa: C901
516
519
  ignore_model_dir_dest_conflict=True,
517
520
  )
518
521
 
519
- # verify additional files and paths, copy additional files
520
- if additional_files:
521
- logger.info("Adding `additional_files` to model version contents")
522
- temp_dest_to_src_map = _copy_additional_files(
523
- root_dir=temp_dir.name,
524
- files_dir="",
525
- model_dir=None,
526
- additional_files=additional_files,
527
- ignore_model_dir_dest_conflict=False,
528
- existing_dest_to_src_map=temp_dest_to_src_map,
529
- )
530
522
  except Exception as e:
531
523
  temp_dir.cleanup()
532
524
  raise MlFoundryException("Failed to log model") from e
@@ -580,16 +572,24 @@ def _log_model_version( # noqa: C901
580
572
  else:
581
573
  raise MlFoundryException("Invalid model_file_or_folder provided")
582
574
 
583
- _framework = _ModelFramework.to_model_framework_type(framework)
584
575
  _source_cls = typing.get_type_hints(ModelVersionManifest)["source"]
576
+
577
+ # Auto fetch the framework & environment details if not provided
578
+ framework = _ModelFramework.to_model_framework_type(framework)
579
+ if framework and isinstance(model_file_or_folder, str):
580
+ auto_update_model_framework_details(
581
+ framework=framework, model_file_or_folder=model_file_or_folder
582
+ )
583
+ environment = environment or ModelVersionEnvironment()
584
+ auto_update_environment_details(environment=environment, framework=framework)
585
+
585
586
  model_manifest = ModelVersionManifest(
586
587
  description=description,
587
588
  metadata=metadata,
588
589
  source=_source_cls.from_dict(source.dict()),
589
- framework=Framework.from_dict(_framework.dict()) if _framework else None,
590
+ framework=Framework.from_dict(framework.dict()) if framework else None,
590
591
  environment=environment,
591
592
  step=step,
592
- model_schema=model_schema,
593
593
  )
594
594
  artifact_version_response = mlfoundry_artifacts_api.finalize_artifact_version_post(
595
595
  finalize_artifact_version_request_dto=FinalizeArtifactVersionRequestDto(
@@ -603,3 +603,15 @@ def _log_model_version( # noqa: C901
603
603
  )
604
604
  )
605
605
  return ModelVersion.from_fqn(fqn=artifact_version_response.artifact_version.fqn)
606
+
607
+
608
def infer_signature(
    model_input: Any = None,
    model_output: Optional[
        Union["pd.DataFrame", "np.ndarray", Dict[str, "np.ndarray"]]
    ] = None,
    params: Optional[Dict[str, Any]] = None,
):
    """
    Public wrapper around ``truefoundry.ml.autogen.models.infer_signature``.

    All three arguments are forwarded unchanged by keyword, and the
    underlying function's result is returned as-is.
    """
    forwarded = {
        "model_input": model_input,
        "model_output": model_output,
        "params": params,
    }
    return _infer_signature(**forwarded)
@@ -2,7 +2,7 @@ import json
2
2
  import logging
3
3
  import os
4
4
  import posixpath
5
- from pathlib import Path
5
+ from pathlib import Path, PureWindowsPath
6
6
  from typing import Any, Dict, Optional, Sequence, Tuple, Union
7
7
 
8
8
  from truefoundry.ml.exceptions import MlFoundryException
@@ -11,6 +11,13 @@ from truefoundry.ml.log_types.artifacts.constants import DESCRIPTION_MAX_LENGTH
11
11
  logger = logging.getLogger(__name__)
12
12
 
13
13
 
14
def to_unix_path(path):
    """
    Normalize ``path`` and return it using forward-slash separators.

    On platforms whose separator is a backslash, the normalized path is
    converted with ``PureWindowsPath.as_posix()``; elsewhere the normalized
    path is returned unchanged.
    """
    normalized = os.path.normpath(path)
    if os.path.sep != "\\":
        return normalized
    return PureWindowsPath(normalized).as_posix()
19
+
20
+
14
21
  def _copy_tree(
15
22
  root_dir: str, src_path: str, dest_path: str, dest_to_src: Dict[str, str]
16
23
  ):
@@ -42,6 +49,35 @@ def is_destination_path_dirlike(dest_path) -> bool:
42
49
  return False
43
50
 
44
51
 
52
def get_single_file_path_if_only_one_in_directory(path: str) -> Optional[str]:
    """
    Return the path of the only file at or under ``path``, if unambiguous.

    Args:
        path: The file or folder path.
    Returns:
        Optional[str]: ``path`` itself if it is a file; the joined path of the
        single file found anywhere under it (searched recursively) if it is a
        directory containing exactly one file; otherwise None (no files, more
        than one file, or a path that is neither a file nor a directory).
    """
    # If it's already a file, return it as-is.
    if os.path.isfile(path):
        return path

    # This inspects an existing source path, so check it is a directory
    # directly rather than using the destination-path heuristic.
    if not os.path.isdir(path):
        return None

    found: Optional[str] = None
    for root, _, files in os.walk(path):
        for filename in files:
            if found is not None:
                # A second file makes the answer ambiguous; stop walking now.
                return None
            found = os.path.join(root, filename)
    return found
79
+
80
+
45
81
  def _copy_additional_files(
46
82
  root_dir: str,
47
83
  files_dir: str, # relative to root dir e.g. "files/"
@@ -136,7 +172,7 @@ def _get_src_dest_pairs(
136
172
  dest_to_src_map: Dict[str, str],
137
173
  ) -> Sequence[Tuple[str, str]]:
138
174
  src_dest_pairs = [
139
- (src_path, os.path.relpath(dest_abs_path, root_dir))
175
+ (src_path, to_unix_path(os.path.relpath(dest_abs_path, root_dir)))
140
176
  for dest_abs_path, src_path in dest_to_src_map.items()
141
177
  ]
142
178
  return src_dest_pairs