oracle-ads 2.12.9__py3-none-any.whl → 2.12.10rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +4 -4
- ads/aqua/common/enums.py +3 -0
- ads/aqua/common/utils.py +62 -2
- ads/aqua/data.py +2 -19
- ads/aqua/extension/finetune_handler.py +8 -14
- ads/aqua/extension/model_handler.py +19 -2
- ads/aqua/finetuning/constants.py +5 -2
- ads/aqua/finetuning/entities.py +64 -17
- ads/aqua/finetuning/finetuning.py +38 -54
- ads/aqua/model/entities.py +2 -1
- ads/aqua/model/model.py +61 -23
- ads/common/auth.py +9 -9
- ads/llm/autogen/__init__.py +2 -0
- ads/llm/autogen/constants.py +15 -0
- ads/llm/autogen/reports/__init__.py +2 -0
- ads/llm/autogen/reports/base.py +67 -0
- ads/llm/autogen/reports/data.py +103 -0
- ads/llm/autogen/reports/session.py +526 -0
- ads/llm/autogen/reports/templates/chat_box.html +13 -0
- ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
- ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
- ads/llm/autogen/reports/utils.py +56 -0
- ads/llm/autogen/v02/__init__.py +4 -0
- ads/llm/autogen/{client_v02.py → v02/client.py} +23 -10
- ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
- ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
- ads/llm/autogen/v02/loggers/__init__.py +6 -0
- ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
- ads/llm/autogen/v02/loggers/session_logger.py +580 -0
- ads/llm/autogen/v02/loggers/utils.py +86 -0
- ads/llm/autogen/v02/runtime_logging.py +163 -0
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +12 -11
- ads/model/__init__.py +11 -13
- ads/model/artifact.py +47 -8
- ads/model/extractor/embedding_onnx_extractor.py +80 -0
- ads/model/framework/embedding_onnx_model.py +438 -0
- ads/model/generic_model.py +26 -24
- ads/model/model_metadata.py +8 -7
- ads/opctl/config/merger.py +13 -14
- ads/opctl/operator/common/operator_config.py +4 -4
- ads/opctl/operator/lowcode/common/transformations.py +12 -5
- ads/opctl/operator/lowcode/common/utils.py +11 -5
- ads/opctl/operator/lowcode/forecast/const.py +2 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +19 -13
- ads/opctl/operator/lowcode/forecast/model/automlx.py +129 -36
- ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +61 -14
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +10 -3
- ads/opctl/operator/lowcode/forecast/model/prophet.py +25 -18
- ads/opctl/operator/lowcode/forecast/schema.yaml +13 -0
- ads/opctl/operator/lowcode/forecast/utils.py +4 -3
- ads/telemetry/base.py +18 -11
- ads/telemetry/client.py +33 -13
- ads/templates/schemas/openapi.json +1740 -0
- ads/templates/score_embedding_onnx.jinja2 +202 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/METADATA +7 -8
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/RECORD +60 -39
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/WHEEL +0 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,163 @@
|
|
1
|
+
# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
|
2
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
3
|
+
import logging
|
4
|
+
import traceback
|
5
|
+
from sqlite3 import Connection
|
6
|
+
from typing import Any, Dict, List, Optional
|
7
|
+
|
8
|
+
import autogen.runtime_logging
|
9
|
+
from autogen.logger.base_logger import BaseLogger
|
10
|
+
from autogen.logger.logger_factory import LoggerFactory
|
11
|
+
|
12
|
+
# Module-level logger; used below to report failures of the managed AutoGen
# loggers and failures to start logging.
logger = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
class LoggerManager(BaseLogger):
    """Manages multiple AutoGen loggers.

    Every ``BaseLogger`` method is fanned out to each registered logger.
    An exception raised by an individual logger is logged and swallowed so
    that a failing logger never interrupts the program.
    """

    def __init__(self) -> None:
        self.loggers: List[BaseLogger] = []
        super().__init__()

    def add_logger(self, logger: BaseLogger) -> None:
        """Adds a new AutoGen logger."""
        self.loggers.append(logger)

    def _call_loggers(self, method: str, *args, **kwargs) -> List[Any]:
        """Calls the specific method on each AutoGen logger in self.loggers.

        Parameters
        ----------
        method : str
            Name of the ``BaseLogger`` method to invoke.
        *args, **kwargs
            Forwarded unchanged to each logger's method.

        Returns
        -------
        list
            The return values of the loggers whose call succeeded, in
            registration order. Loggers that raised contribute nothing.
        """
        results = []
        for autogen_logger in self.loggers:
            try:
                results.append(getattr(autogen_logger, method)(*args, **kwargs))
            except Exception as e:
                # Catch the logging exception so that the program will not be interrupted.
                logger.error(
                    "Failed to %s with %s: %s",
                    method,
                    autogen_logger.__class__.__name__,
                    str(e),
                )
                logger.debug(traceback.format_exc())
        return results

    def start(self) -> str:
        """Starts all loggers.

        Returns
        -------
        str
            The session ID returned by the first logger that returned a
            non-None value, or None if no logger did.
        """
        return next(
            (session for session in self._call_loggers("start") if session is not None),
            None,
        )

    def stop(self) -> None:
        """Stops all loggers and removes them from this manager."""
        self._call_loggers("stop")
        # Remove the loggers once they are stopped.
        self.loggers = []

    def get_connection(self) -> Optional[Connection]:
        """Returns the first database connection exposed by a managed logger.

        ``Optional`` is used instead of ``None | Connection`` because the
        annotation is evaluated at class-definition time and the PEP 604
        form raises ``TypeError`` before Python 3.10.
        """
        return next(
            (conn for conn in self._call_loggers("get_connection") if conn is not None),
            None,
        )

    def log_chat_completion(self, *args, **kwargs) -> None:
        """Forwards ``log_chat_completion`` to every managed logger."""
        self._call_loggers("log_chat_completion", *args, **kwargs)

    def log_new_agent(self, *args, **kwargs) -> None:
        """Forwards ``log_new_agent`` to every managed logger."""
        self._call_loggers("log_new_agent", *args, **kwargs)

    def log_event(self, *args, **kwargs) -> None:
        """Forwards ``log_event`` to every managed logger."""
        self._call_loggers("log_event", *args, **kwargs)

    def log_new_wrapper(self, *args, **kwargs) -> None:
        """Forwards ``log_new_wrapper`` to every managed logger."""
        self._call_loggers("log_new_wrapper", *args, **kwargs)

    def log_new_client(self, *args, **kwargs) -> None:
        """Forwards ``log_new_client`` to every managed logger."""
        self._call_loggers("log_new_client", *args, **kwargs)

    def log_function_use(self, *args, **kwargs) -> None:
        """Forwards ``log_function_use`` to every managed logger."""
        self._call_loggers("log_function_use", *args, **kwargs)

    def __repr__(self) -> str:
        # Named `managed` to avoid shadowing the module-level `logger`.
        return "\n\n".join(
            f"{str(managed.__class__)}:\n{managed.__repr__()}"
            for managed in self.loggers
        )
|
78
|
+
|
79
|
+
|
80
|
+
def start(
    autogen_logger: Optional[BaseLogger] = None,
    logger_type: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
) -> str:
    """Starts logging with AutoGen logger.
    Specify your custom autogen_logger, or the logger_type and config to use a built-in logger.

    Any AutoGen logger already configured is kept: it is wrapped together
    with the new logger in a ``LoggerManager``, which is then installed as
    ``autogen.runtime_logging.autogen_logger``.

    Parameters
    ----------
    autogen_logger : BaseLogger, optional
        An AutoGen logger, which should be a subclass of autogen.logger.base_logger.BaseLogger.
    logger_type : str, optional
        Logger type, which can be a built-in AutoGen logger type ("file", or "sqlite"), by default None.
    config : dict, optional
        Configurations for the built-in AutoGen logger, by default None

    Returns
    -------
    str
        A unique session ID returned from starting the logger,
        or None if the logger failed to start (the failure is logged).

    Raises
    ------
    ValueError
        If both ``autogen_logger`` and ``logger_type`` are specified.
    """
    if autogen_logger and logger_type:
        # Exceptions do not apply %-formatting, so build the message eagerly.
        raise ValueError(
            f"Please specify only autogen_logger({autogen_logger}) "
            f"or logger_type({logger_type})."
        )

    # Check if a logger is already configured
    existing_logger = autogen.runtime_logging.autogen_logger
    if not existing_logger:
        # No logger is configured
        logger_manager = LoggerManager()
    elif isinstance(existing_logger, LoggerManager):
        # Logger is already configured with ADS
        logger_manager = existing_logger
    else:
        # Logger is configured but it is not via ADS
        logger.warning("AutoGen is already configured with %s", str(existing_logger))
        logger_manager = LoggerManager()
        logger_manager.add_logger(existing_logger)

    # Create a built-in AutoGen logger if a custom one was not provided.
    if not autogen_logger:
        autogen_logger = LoggerFactory.get_logger(
            logger_type=logger_type, config=config
        )

    # Pre-bind so a failed start returns None instead of raising UnboundLocalError.
    session_id = None
    try:
        session_id = autogen_logger.start()
        # Register only after a successful start so a broken logger is not
        # left attached to an already-installed manager.
        logger_manager.add_logger(autogen_logger)
        autogen.runtime_logging.is_logging = True
        autogen.runtime_logging.autogen_logger = logger_manager
    except Exception as e:
        logger.error("Failed to start logging: %s", e)
    return session_id
|
138
|
+
|
139
|
+
|
140
|
+
def stop(*loggers) -> BaseLogger:
    """Stops AutoGen logging.

    When the active AutoGen logger is a ``LoggerManager`` and one or more
    loggers are passed in, only those loggers are stopped, and each one that
    the manager holds is removed from it. In every other case (no loggers
    given, or the active logger is not a ``LoggerManager``) this delegates to
    ``autogen.runtime_logging.stop()``.

    Parameters
    ----------
    *loggers : BaseLogger
        Specific loggers to stop.

    Returns
    -------
    BaseLogger
        The AutoGen logger that was active when this function was called.
    """
    active = autogen.runtime_logging.autogen_logger
    selective = bool(loggers) and isinstance(active, LoggerManager)

    if not selective:
        autogen.runtime_logging.stop()
        return active

    # `target` rather than `logger` so the module-level logger is not shadowed.
    for target in loggers:
        target.stop()
        if target in active.loggers:
            active.loggers.remove(target)
    return active
|
156
|
+
|
157
|
+
|
158
|
+
def get_loggers() -> List[BaseLogger]:
    """Gets a list of existing AutoGen loggers.

    Returns
    -------
    list of BaseLogger
        The loggers held by the ADS ``LoggerManager`` if one is installed,
        a single-element list with the active AutoGen logger otherwise,
        or an empty list when no AutoGen logger is configured.
    """
    autogen_logger = autogen.runtime_logging.autogen_logger
    if autogen_logger is None:
        # Previously returned [None], which callers would then try to use as a logger.
        return []
    if isinstance(autogen_logger, LoggerManager):
        return autogen_logger.loggers
    return [autogen_logger]
|
@@ -769,6 +769,8 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
|
|
769
769
|
Science Model Deployment endpoint. See:
|
770
770
|
https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint
|
771
771
|
|
772
|
+
See https://docs.vllm.ai/en/latest/api/inference_params.html for the defaults of the parameters.
|
773
|
+
|
772
774
|
Example:
|
773
775
|
|
774
776
|
.. code-block:: python
|
@@ -786,7 +788,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
|
|
786
788
|
|
787
789
|
""" # noqa: E501
|
788
790
|
|
789
|
-
frequency_penalty: float =
|
791
|
+
frequency_penalty: Optional[float] = None
|
790
792
|
"""Penalizes repeated tokens according to frequency. Between 0 and 1."""
|
791
793
|
|
792
794
|
logit_bias: Optional[Dict[str, float]] = None
|
@@ -798,7 +800,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
|
|
798
800
|
n: int = 1
|
799
801
|
"""Number of output sequences to return for the given prompt."""
|
800
802
|
|
801
|
-
presence_penalty: float =
|
803
|
+
presence_penalty: Optional[float] = None
|
802
804
|
"""Penalizes repeated tokens. Between 0 and 1."""
|
803
805
|
|
804
806
|
temperature: float = 0.2
|
@@ -812,7 +814,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
|
|
812
814
|
(the one with the highest log probability per token).
|
813
815
|
"""
|
814
816
|
|
815
|
-
use_beam_search: Optional[bool] =
|
817
|
+
use_beam_search: Optional[bool] = None
|
816
818
|
"""Whether to use beam search instead of sampling."""
|
817
819
|
|
818
820
|
top_k: Optional[int] = -1
|
@@ -822,15 +824,15 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
|
|
822
824
|
"""Float that represents the minimum probability for a token to be considered.
|
823
825
|
Must be in [0,1]. 0 to disable this."""
|
824
826
|
|
825
|
-
repetition_penalty: Optional[float] =
|
827
|
+
repetition_penalty: Optional[float] = None
|
826
828
|
"""Float that penalizes new tokens based on their frequency in the
|
827
829
|
generated text. Values > 1 encourage the model to use new tokens."""
|
828
830
|
|
829
|
-
length_penalty: Optional[float] =
|
831
|
+
length_penalty: Optional[float] = None
|
830
832
|
"""Float that penalizes sequences based on their length. Used only
|
831
833
|
when `use_beam_search` is True."""
|
832
834
|
|
833
|
-
early_stopping: Optional[bool] =
|
835
|
+
early_stopping: Optional[bool] = None
|
834
836
|
"""Controls the stopping condition for beam search. It accepts the
|
835
837
|
following values: `True`, where the generation stops as soon as there
|
836
838
|
are `best_of` complete candidates; `False`, where a heuristic is applied
|
@@ -842,7 +844,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
|
|
842
844
|
"""Whether to ignore the EOS token and continue generating tokens after
|
843
845
|
the EOS token is generated."""
|
844
846
|
|
845
|
-
min_tokens: Optional[int] =
|
847
|
+
min_tokens: Optional[int] = None
|
846
848
|
"""Minimum number of tokens to generate per output sequence before
|
847
849
|
EOS or stop_token_ids can be generated"""
|
848
850
|
|
@@ -851,12 +853,11 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
|
|
851
853
|
The returned output will contain the stop tokens unless the stop tokens
|
852
854
|
are special tokens."""
|
853
855
|
|
854
|
-
skip_special_tokens: Optional[bool] =
|
856
|
+
skip_special_tokens: Optional[bool] = None
|
855
857
|
"""Whether to skip special tokens in the output. Defaults to True."""
|
856
858
|
|
857
|
-
spaces_between_special_tokens: Optional[bool] =
|
858
|
-
"""Whether to add spaces between special tokens in the output.
|
859
|
-
Defaults to True."""
|
859
|
+
spaces_between_special_tokens: Optional[bool] = None
|
860
|
+
"""Whether to add spaces between special tokens in the output."""
|
860
861
|
|
861
862
|
tool_choice: Optional[str] = None
|
862
863
|
"""Whether to use tool calling.
|
ads/model/__init__.py
CHANGED
@@ -1,29 +1,26 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c) 2021,
|
3
|
+
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
|
-
from ads.model.generic_model import GenericModel, ModelState
|
8
6
|
from ads.model.datascience_model import DataScienceModel
|
9
|
-
from ads.model.
|
7
|
+
from ads.model.deployment.model_deployer import ModelDeployer
|
8
|
+
from ads.model.deployment.model_deployment import ModelDeployment
|
9
|
+
from ads.model.deployment.model_deployment_properties import ModelDeploymentProperties
|
10
10
|
from ads.model.framework.automl_model import AutoMLModel
|
11
|
+
from ads.model.framework.embedding_onnx_model import EmbeddingONNXModel
|
12
|
+
from ads.model.framework.huggingface_model import HuggingFacePipelineModel
|
11
13
|
from ads.model.framework.lightgbm_model import LightGBMModel
|
12
14
|
from ads.model.framework.pytorch_model import PyTorchModel
|
13
15
|
from ads.model.framework.sklearn_model import SklearnModel
|
16
|
+
from ads.model.framework.spark_model import SparkPipelineModel
|
14
17
|
from ads.model.framework.tensorflow_model import TensorFlowModel
|
15
18
|
from ads.model.framework.xgboost_model import XGBoostModel
|
16
|
-
from ads.model.
|
17
|
-
from ads.model.
|
18
|
-
|
19
|
-
from ads.model.deployment.model_deployer import ModelDeployer
|
20
|
-
from ads.model.deployment.model_deployment import ModelDeployment
|
21
|
-
from ads.model.deployment.model_deployment_properties import ModelDeploymentProperties
|
22
|
-
|
19
|
+
from ads.model.generic_model import GenericModel, ModelState
|
20
|
+
from ads.model.model_properties import ModelProperties
|
21
|
+
from ads.model.model_version_set import ModelVersionSet, experiment
|
23
22
|
from ads.model.serde.common import SERDE
|
24
23
|
from ads.model.serde.model_input import ModelInputSerializer
|
25
|
-
|
26
|
-
from ads.model.model_version_set import ModelVersionSet, experiment
|
27
24
|
from ads.model.service.oci_datascience_model_version_set import (
|
28
25
|
ModelVersionSetNotExists,
|
29
26
|
ModelVersionSetNotSaved,
|
@@ -42,6 +39,7 @@ __all__ = [
|
|
42
39
|
"XGBoostModel",
|
43
40
|
"SparkPipelineModel",
|
44
41
|
"HuggingFacePipelineModel",
|
42
|
+
"EmbeddingONNXModel",
|
45
43
|
"ModelDeployer",
|
46
44
|
"ModelDeployment",
|
47
45
|
"ModelDeploymentProperties",
|
ads/model/artifact.py
CHANGED
@@ -1,28 +1,28 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c) 2022,
|
3
|
+
# Copyright (c) 2022, 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
6
|
import fnmatch
|
8
7
|
import importlib
|
9
8
|
import os
|
10
|
-
import sys
|
11
9
|
import shutil
|
10
|
+
import sys
|
12
11
|
import tempfile
|
13
12
|
import uuid
|
14
|
-
import
|
13
|
+
from datetime import datetime
|
15
14
|
from typing import Dict, Optional, Tuple
|
15
|
+
|
16
|
+
import fsspec
|
17
|
+
from jinja2 import Environment, PackageLoader
|
18
|
+
|
19
|
+
from ads import __version__
|
16
20
|
from ads.common import auth as authutil
|
17
21
|
from ads.common import logger, utils
|
18
22
|
from ads.common.object_storage_details import ObjectStorageDetails
|
19
23
|
from ads.config import CONDA_BUCKET_NAME, CONDA_BUCKET_NS
|
20
24
|
from ads.model.runtime.env_info import EnvInfo, InferenceEnvInfo, TrainingEnvInfo
|
21
25
|
from ads.model.runtime.runtime_info import RuntimeInfo
|
22
|
-
from jinja2 import Environment, PackageLoader
|
23
|
-
import warnings
|
24
|
-
from ads import __version__
|
25
|
-
from datetime import datetime
|
26
26
|
|
27
27
|
MODEL_ARTIFACT_VERSION = "3.0"
|
28
28
|
REQUIRED_ARTIFACT_FILES = ("runtime.yaml", "score.py")
|
@@ -378,6 +378,45 @@ class ModelArtifact:
|
|
378
378
|
) as f:
|
379
379
|
f.write(scorefn_template.render(context))
|
380
380
|
|
381
|
+
def prepare_schema(self, schema_name: str):
    """Copies schema to artifact directory.

    The schema is looked up under ``<ads package>/templates/schemas/`` and
    copied into ``self.artifact_dir`` under its base name, overwriting any
    existing file with that name.

    Parameters
    ----------
    schema_name: str
        The schema file name, e.g. ``"openapi.json"``.

    Returns
    -------
    None

    Raises
    ------
    FileExistsError
        If `schema_name` does not refer to an existing schema file.
        NOTE: ``FileNotFoundError`` would describe this more accurately;
        ``FileExistsError`` is kept so existing callers that catch it still work.
    """
    schemas_dir = os.path.join(
        os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
        "templates",
        "schemas",
    )
    uri_src = os.path.join(schemas_dir, schema_name)

    # `isfile` rather than `exists`: an empty name or a sub-directory would
    # otherwise pass this check and only fail later inside copy_file.
    if not os.path.isfile(uri_src):
        raise FileExistsError(
            f"{schema_name} does not exists. "
            "Ensure the schema name is valid or specify a different one."
        )

    uri_dst = os.path.join(self.artifact_dir, os.path.basename(uri_src))

    utils.copy_file(
        uri_src=uri_src,
        uri_dst=uri_dst,
        force_overwrite=True,
        auth=self.auth,
    )
|
419
|
+
|
381
420
|
def reload(self):
|
382
421
|
"""Syncs the `score.py` to reload the model and predict function.
|
383
422
|
|
@@ -0,0 +1,80 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
|
3
|
+
# Copyright (c) 2025 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
from ads.common.decorator.runtime_dependency import (
|
7
|
+
OptionalDependency,
|
8
|
+
runtime_dependency,
|
9
|
+
)
|
10
|
+
from ads.model.extractor.model_info_extractor import ModelInfoExtractor
|
11
|
+
from ads.model.model_metadata import Framework
|
12
|
+
|
13
|
+
|
14
|
+
class EmbeddingONNXExtractor(ModelInfoExtractor):
    """Extracts model metadata for ``EmbeddingONNXModel`` models.

    Attributes
    ----------
    model: object
        The model whose metadata is extracted.

    Methods
    -------
    framework(self) -> str
        The framework identifier, ``Framework.EMBEDDING_ONNX``.
    algorithm(self) -> str
        The fixed algorithm name ``"Embedding_ONNX"``.
    version(self) -> str
        The version of the installed ``onnxruntime`` package.
    hyperparameter(self) -> None
        Always ``None``; no hyperparameters are extracted for this model type.
    """

    # Fixed algorithm label reported for every embedding ONNX model.
    _ALGORITHM_NAME = "Embedding_ONNX"

    def __init__(self, model=None):
        self.model = model

    @property
    def framework(self):
        """Framework identifier of the model.

        Returns
        -------
        str:
            ``Framework.EMBEDDING_ONNX``.
        """
        return Framework.EMBEDDING_ONNX

    @property
    def algorithm(self):
        """Algorithm name of the model.

        Returns
        -------
        str:
            The constant ``"Embedding_ONNX"``; it does not depend on ``self.model``.
        """
        return self._ALGORITHM_NAME

    @property
    @runtime_dependency(module="onnxruntime", install_from=OptionalDependency.ONNX)
    def version(self):
        """Framework version of the model.

        ``onnxruntime`` is resolved through the ``runtime_dependency``
        decorator, which names ``OptionalDependency.ONNX`` as the install source.

        Returns
        -------
        str:
            ``onnxruntime.__version__``.
        """
        runtime_version = onnxruntime.__version__
        return runtime_version

    @property
    def hyperparameter(self):
        """Hyperparameters of the model.

        Returns
        -------
        None:
            No hyperparameters are extracted for embedding ONNX models.
        """
        return None
|