oracle-ads 2.12.10rc0__py3-none-any.whl → 2.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +2 -1
- ads/aqua/app.py +46 -19
- ads/aqua/client/__init__.py +3 -0
- ads/aqua/client/client.py +799 -0
- ads/aqua/common/enums.py +19 -14
- ads/aqua/common/errors.py +3 -4
- ads/aqua/common/utils.py +2 -2
- ads/aqua/constants.py +1 -0
- ads/aqua/evaluation/constants.py +7 -7
- ads/aqua/evaluation/errors.py +3 -4
- ads/aqua/evaluation/evaluation.py +20 -12
- ads/aqua/extension/aqua_ws_msg_handler.py +14 -7
- ads/aqua/extension/base_handler.py +12 -9
- ads/aqua/extension/model_handler.py +29 -1
- ads/aqua/extension/models/ws_models.py +5 -6
- ads/aqua/finetuning/constants.py +3 -3
- ads/aqua/finetuning/entities.py +3 -0
- ads/aqua/finetuning/finetuning.py +32 -1
- ads/aqua/model/constants.py +7 -7
- ads/aqua/model/entities.py +2 -1
- ads/aqua/model/enums.py +4 -5
- ads/aqua/model/model.py +158 -76
- ads/aqua/modeldeployment/deployment.py +22 -10
- ads/aqua/modeldeployment/entities.py +3 -1
- ads/cli.py +16 -8
- ads/common/auth.py +33 -20
- ads/common/extended_enum.py +52 -44
- ads/llm/__init__.py +11 -8
- ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
- ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
- ads/model/artifact_downloader.py +3 -4
- ads/model/datascience_model.py +84 -64
- ads/model/generic_model.py +3 -3
- ads/model/model_metadata.py +17 -11
- ads/model/service/oci_datascience_model.py +12 -14
- ads/opctl/backend/marketplace/helm_helper.py +13 -14
- ads/opctl/cli.py +4 -5
- ads/opctl/cmds.py +28 -32
- ads/opctl/config/merger.py +8 -11
- ads/opctl/config/resolver.py +25 -30
- ads/opctl/operator/cli.py +9 -9
- ads/opctl/operator/common/backend_factory.py +56 -60
- ads/opctl/operator/common/const.py +5 -5
- ads/opctl/operator/lowcode/anomaly/const.py +8 -9
- ads/opctl/operator/lowcode/common/transformations.py +38 -3
- ads/opctl/operator/lowcode/common/utils.py +11 -1
- ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +43 -48
- ads/opctl/operator/lowcode/forecast/__main__.py +10 -0
- ads/opctl/operator/lowcode/forecast/const.py +6 -6
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +1 -1
- ads/opctl/operator/lowcode/forecast/operator_config.py +31 -0
- ads/opctl/operator/lowcode/forecast/schema.yaml +63 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +233 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/score.py +238 -0
- ads/opctl/operator/lowcode/pii/constant.py +6 -7
- ads/opctl/operator/lowcode/recommender/constant.py +12 -7
- ads/opctl/operator/runtime/marketplace_runtime.py +4 -10
- ads/opctl/operator/runtime/runtime.py +4 -6
- ads/pipeline/ads_pipeline_run.py +13 -25
- ads/pipeline/visualizer/graph_renderer.py +3 -4
- {oracle_ads-2.12.10rc0.dist-info → oracle_ads-2.13.0.dist-info}/METADATA +4 -2
- {oracle_ads-2.12.10rc0.dist-info → oracle_ads-2.13.0.dist-info}/RECORD +66 -59
- {oracle_ads-2.12.10rc0.dist-info → oracle_ads-2.13.0.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.12.10rc0.dist-info → oracle_ads-2.13.0.dist-info}/WHEEL +0 -0
- {oracle_ads-2.12.10rc0.dist-info → oracle_ads-2.13.0.dist-info}/entry_points.txt +0 -0
@@ -1,8 +1,7 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# Copyright (c) 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
|
3
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
4
|
|
5
|
-
import logging
|
6
5
|
import shlex
|
7
6
|
from typing import Dict, List, Optional, Union
|
8
7
|
|
@@ -271,7 +270,7 @@ class AquaDeploymentApp(AquaApp):
|
|
271
270
|
f"field. Either re-register the model with custom container URI, or set container_image_uri "
|
272
271
|
f"parameter when creating this deployment."
|
273
272
|
) from err
|
274
|
-
|
273
|
+
logger.info(
|
275
274
|
f"Aqua Image used for deploying {aqua_model.id} : {container_image_uri}"
|
276
275
|
)
|
277
276
|
|
@@ -282,14 +281,14 @@ class AquaDeploymentApp(AquaApp):
|
|
282
281
|
default_cmd_var = shlex.split(cmd_var_string)
|
283
282
|
if default_cmd_var:
|
284
283
|
cmd_var = validate_cmd_var(default_cmd_var, cmd_var)
|
285
|
-
|
284
|
+
logger.info(f"CMD used for deploying {aqua_model.id} :{cmd_var}")
|
286
285
|
except ValueError:
|
287
|
-
|
286
|
+
logger.debug(
|
288
287
|
f"CMD will be ignored for this deployment as {AQUA_DEPLOYMENT_CONTAINER_CMD_VAR_METADATA_NAME} "
|
289
288
|
f"key is not available in the custom metadata field for this model."
|
290
289
|
)
|
291
290
|
except Exception as e:
|
292
|
-
|
291
|
+
logger.error(
|
293
292
|
f"There was an issue processing CMD arguments. Error: {str(e)}"
|
294
293
|
)
|
295
294
|
|
@@ -385,7 +384,7 @@ class AquaDeploymentApp(AquaApp):
|
|
385
384
|
if key not in env_var:
|
386
385
|
env_var.update(env)
|
387
386
|
|
388
|
-
|
387
|
+
logger.info(f"Env vars used for deploying {aqua_model.id} :{env_var}")
|
389
388
|
|
390
389
|
# Start model deployment
|
391
390
|
# configure model deployment infrastructure
|
@@ -440,10 +439,14 @@ class AquaDeploymentApp(AquaApp):
|
|
440
439
|
.with_runtime(container_runtime)
|
441
440
|
).deploy(wait_for_completion=False)
|
442
441
|
|
442
|
+
deployment_id = deployment.dsc_model_deployment.id
|
443
|
+
logger.info(
|
444
|
+
f"Aqua model deployment {deployment_id} created for model {aqua_model.id}."
|
445
|
+
)
|
443
446
|
model_type = (
|
444
447
|
AQUA_MODEL_TYPE_CUSTOM if is_fine_tuned_model else AQUA_MODEL_TYPE_SERVICE
|
445
448
|
)
|
446
|
-
|
449
|
+
|
447
450
|
# we arbitrarily choose last 8 characters of OCID to identify MD in telemetry
|
448
451
|
telemetry_kwargs = {"ocid": get_ocid_substring(deployment_id, key_len=8)}
|
449
452
|
|
@@ -539,6 +542,9 @@ class AquaDeploymentApp(AquaApp):
|
|
539
542
|
value=state,
|
540
543
|
)
|
541
544
|
|
545
|
+
logger.info(
|
546
|
+
f"Fetched {len(results)} model deployments from compartment_id={compartment_id}."
|
547
|
+
)
|
542
548
|
# tracks number of times deployment listing was called
|
543
549
|
self.telemetry.record_event_async(category="aqua/deployment", action="list")
|
544
550
|
|
@@ -546,18 +552,21 @@ class AquaDeploymentApp(AquaApp):
|
|
546
552
|
|
547
553
|
@telemetry(entry_point="plugin=deployment&action=delete", name="aqua")
|
548
554
|
def delete(self, model_deployment_id: str):
|
555
|
+
logger.info(f"Deleting model deployment {model_deployment_id}.")
|
549
556
|
return self.ds_client.delete_model_deployment(
|
550
557
|
model_deployment_id=model_deployment_id
|
551
558
|
).data
|
552
559
|
|
553
560
|
@telemetry(entry_point="plugin=deployment&action=deactivate", name="aqua")
|
554
561
|
def deactivate(self, model_deployment_id: str):
|
562
|
+
logger.info(f"Deactivating model deployment {model_deployment_id}.")
|
555
563
|
return self.ds_client.deactivate_model_deployment(
|
556
564
|
model_deployment_id=model_deployment_id
|
557
565
|
).data
|
558
566
|
|
559
567
|
@telemetry(entry_point="plugin=deployment&action=activate", name="aqua")
|
560
568
|
def activate(self, model_deployment_id: str):
|
569
|
+
logger.info(f"Activating model deployment {model_deployment_id}.")
|
561
570
|
return self.ds_client.activate_model_deployment(
|
562
571
|
model_deployment_id=model_deployment_id
|
563
572
|
).data
|
@@ -579,6 +588,8 @@ class AquaDeploymentApp(AquaApp):
|
|
579
588
|
AquaDeploymentDetail:
|
580
589
|
The instance of the Aqua model deployment details.
|
581
590
|
"""
|
591
|
+
logger.info(f"Fetching model deployment details for {model_deployment_id}.")
|
592
|
+
|
582
593
|
model_deployment = self.ds_client.get_model_deployment(
|
583
594
|
model_deployment_id=model_deployment_id, **kwargs
|
584
595
|
).data
|
@@ -594,7 +605,8 @@ class AquaDeploymentApp(AquaApp):
|
|
594
605
|
|
595
606
|
if not oci_aqua:
|
596
607
|
raise AquaRuntimeError(
|
597
|
-
f"Target deployment {model_deployment_id} is not Aqua deployment
|
608
|
+
f"Target deployment {model_deployment_id} is not Aqua deployment as it does not contain "
|
609
|
+
f"{Tags.AQUA_TAG} tag."
|
598
610
|
)
|
599
611
|
|
600
612
|
log_id = ""
|
@@ -652,7 +664,7 @@ class AquaDeploymentApp(AquaApp):
|
|
652
664
|
config = self.get_config(model_id, AQUA_MODEL_DEPLOYMENT_CONFIG)
|
653
665
|
if not config:
|
654
666
|
logger.debug(
|
655
|
-
f"Deployment config for custom model: {model_id} is not available."
|
667
|
+
f"Deployment config for custom model: {model_id} is not available. Use defaults."
|
656
668
|
)
|
657
669
|
return config
|
658
670
|
|
@@ -41,6 +41,7 @@ class AquaDeployment(DataClassSerializable):
|
|
41
41
|
id: str = None
|
42
42
|
display_name: str = None
|
43
43
|
aqua_service_model: bool = None
|
44
|
+
model_id: str = None
|
44
45
|
aqua_model_name: str = None
|
45
46
|
state: str = None
|
46
47
|
description: str = None
|
@@ -97,7 +98,7 @@ class AquaDeployment(DataClassSerializable):
|
|
97
98
|
else None
|
98
99
|
),
|
99
100
|
)
|
100
|
-
|
101
|
+
model_id = oci_model_deployment._model_deployment_configuration_details.model_configuration_details.model_id
|
101
102
|
tags = {}
|
102
103
|
tags.update(oci_model_deployment.freeform_tags or UNKNOWN_DICT)
|
103
104
|
tags.update(oci_model_deployment.defined_tags or UNKNOWN_DICT)
|
@@ -110,6 +111,7 @@ class AquaDeployment(DataClassSerializable):
|
|
110
111
|
|
111
112
|
return AquaDeployment(
|
112
113
|
id=oci_model_deployment.id,
|
114
|
+
model_id=model_id,
|
113
115
|
display_name=oci_model_deployment.display_name,
|
114
116
|
aqua_service_model=aqua_service_model_tag is not None,
|
115
117
|
aqua_model_name=aqua_model_name,
|
ads/cli.py
CHANGED
@@ -1,14 +1,15 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
#
|
3
|
-
|
4
|
-
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
|
2
|
+
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
|
5
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
4
|
|
5
|
+
import json
|
6
|
+
import logging
|
7
7
|
import sys
|
8
8
|
import traceback
|
9
|
-
|
9
|
+
import uuid
|
10
10
|
|
11
11
|
import fire
|
12
|
+
from pydantic import BaseModel
|
12
13
|
|
13
14
|
from ads.common import logger
|
14
15
|
|
@@ -27,7 +28,7 @@ except Exception as ex:
|
|
27
28
|
)
|
28
29
|
logger.debug(ex)
|
29
30
|
logger.debug(traceback.format_exc())
|
30
|
-
exit()
|
31
|
+
sys.exit()
|
31
32
|
|
32
33
|
# https://packaging.python.org/en/latest/guides/single-sourcing-package-version/#single-sourcing-the-package-version
|
33
34
|
if sys.version_info >= (3, 8):
|
@@ -84,7 +85,13 @@ def serialize(data):
|
|
84
85
|
The string representation of each dataclass object.
|
85
86
|
"""
|
86
87
|
if isinstance(data, list):
|
87
|
-
|
88
|
+
for item in data:
|
89
|
+
if isinstance(item, BaseModel):
|
90
|
+
print(json.dumps(item.dict(), indent=4))
|
91
|
+
else:
|
92
|
+
print(str(item))
|
93
|
+
elif isinstance(data, BaseModel):
|
94
|
+
print(json.dumps(data.dict(), indent=4))
|
88
95
|
else:
|
89
96
|
print(str(data))
|
90
97
|
|
@@ -122,8 +129,9 @@ def exit_program(ex: Exception, logger: "logging.Logger") -> None:
|
|
122
129
|
... exit_program(e, logger)
|
123
130
|
"""
|
124
131
|
|
125
|
-
|
126
|
-
logger.
|
132
|
+
request_id = str(uuid.uuid4())
|
133
|
+
logger.debug(f"Error Request ID: {request_id}\nError: {traceback.format_exc()}")
|
134
|
+
logger.error(f"Error Request ID: {request_id}\n" f"Error: {str(ex)}")
|
127
135
|
|
128
136
|
exit_code = getattr(ex, "exit_code", 1)
|
129
137
|
logger.error(f"Exit code: {exit_code}")
|
ads/common/auth.py
CHANGED
@@ -1,23 +1,25 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8; -*-
|
3
2
|
|
4
|
-
# Copyright (c) 2021,
|
3
|
+
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
6
|
import copy
|
8
|
-
from datetime import datetime
|
9
7
|
import os
|
10
|
-
from dataclasses import dataclass
|
11
8
|
import time
|
9
|
+
from dataclasses import dataclass
|
10
|
+
from datetime import datetime
|
12
11
|
from typing import Any, Callable, Dict, Optional, Union
|
13
12
|
|
14
|
-
import ads.telemetry
|
15
13
|
import oci
|
14
|
+
from oci.config import (
|
15
|
+
DEFAULT_LOCATION, # "~/.oci/config"
|
16
|
+
DEFAULT_PROFILE, # "DEFAULT"
|
17
|
+
)
|
18
|
+
|
19
|
+
import ads.telemetry
|
16
20
|
from ads.common import logger
|
17
21
|
from ads.common.decorator.deprecate import deprecated
|
18
|
-
from ads.common.extended_enum import
|
19
|
-
from oci.config import DEFAULT_LOCATION # "~/.oci/config"
|
20
|
-
from oci.config import DEFAULT_PROFILE # "DEFAULT"
|
22
|
+
from ads.common.extended_enum import ExtendedEnum
|
21
23
|
|
22
24
|
SECURITY_TOKEN_LEFT_TIME = 600
|
23
25
|
|
@@ -26,7 +28,7 @@ class SecurityTokenError(Exception): # pragma: no cover
|
|
26
28
|
pass
|
27
29
|
|
28
30
|
|
29
|
-
class AuthType(
|
31
|
+
class AuthType(ExtendedEnum):
|
30
32
|
API_KEY = "api_key"
|
31
33
|
RESOURCE_PRINCIPAL = "resource_principal"
|
32
34
|
INSTANCE_PRINCIPAL = "instance_principal"
|
@@ -73,7 +75,11 @@ class AuthState(metaclass=SingletonMeta):
|
|
73
75
|
self.oci_key_profile = self.oci_key_profile or os.environ.get(
|
74
76
|
"OCI_CONFIG_PROFILE", DEFAULT_PROFILE
|
75
77
|
)
|
76
|
-
self.oci_config =
|
78
|
+
self.oci_config = (
|
79
|
+
self.oci_config or {"region": os.environ["OCI_RESOURCE_REGION"]}
|
80
|
+
if os.environ.get("OCI_RESOURCE_REGION")
|
81
|
+
else {}
|
82
|
+
)
|
77
83
|
self.oci_signer_kwargs = self.oci_signer_kwargs or {}
|
78
84
|
self.oci_client_kwargs = self.oci_client_kwargs or {}
|
79
85
|
|
@@ -82,7 +88,9 @@ def set_auth(
|
|
82
88
|
auth: Optional[str] = AuthType.API_KEY,
|
83
89
|
oci_config_location: Optional[str] = DEFAULT_LOCATION,
|
84
90
|
profile: Optional[str] = DEFAULT_PROFILE,
|
85
|
-
config: Optional[Dict] = {"region": os.environ["OCI_RESOURCE_REGION"]}
|
91
|
+
config: Optional[Dict] = {"region": os.environ["OCI_RESOURCE_REGION"]}
|
92
|
+
if os.environ.get("OCI_RESOURCE_REGION")
|
93
|
+
else {},
|
86
94
|
signer: Optional[Any] = None,
|
87
95
|
signer_callable: Optional[Callable] = None,
|
88
96
|
signer_kwargs: Optional[Dict] = {},
|
@@ -202,8 +210,8 @@ def set_auth(
|
|
202
210
|
oci_config_location != DEFAULT_LOCATION or profile != DEFAULT_PROFILE
|
203
211
|
):
|
204
212
|
raise ValueError(
|
205
|
-
|
206
|
-
|
213
|
+
"'config' and 'oci_config_location', 'profile' pair are mutually exclusive."
|
214
|
+
"Please specify 'config' OR 'oci_config_location', 'profile' pair."
|
207
215
|
)
|
208
216
|
|
209
217
|
auth_state.oci_config = config
|
@@ -621,7 +629,7 @@ class APIKey(AuthSignerGenerator):
|
|
621
629
|
)
|
622
630
|
|
623
631
|
oci.config.validate_config(configuration)
|
624
|
-
logger.debug(
|
632
|
+
logger.debug("Using 'api_key' authentication.")
|
625
633
|
return {
|
626
634
|
"config": configuration,
|
627
635
|
"signer": oci.signer.Signer(
|
@@ -684,14 +692,19 @@ class ResourcePrincipal(AuthSignerGenerator):
|
|
684
692
|
"signer": oci.auth.signers.get_resource_principals_signer(),
|
685
693
|
"client_kwargs": self.client_kwargs,
|
686
694
|
}
|
687
|
-
logger.debug(
|
695
|
+
logger.debug("Using 'resource_principal' authentication.")
|
688
696
|
return signer_dict
|
689
697
|
|
690
698
|
@staticmethod
|
691
699
|
def supported():
|
692
700
|
return any(
|
693
701
|
os.environ.get(var)
|
694
|
-
for var in [
|
702
|
+
for var in [
|
703
|
+
"JOB_RUN_OCID",
|
704
|
+
"NB_SESSION_OCID",
|
705
|
+
"DATAFLOW_RUN_ID",
|
706
|
+
"PIPELINE_RUN_OCID",
|
707
|
+
]
|
695
708
|
)
|
696
709
|
|
697
710
|
|
@@ -747,7 +760,7 @@ class InstancePrincipal(AuthSignerGenerator):
|
|
747
760
|
),
|
748
761
|
"client_kwargs": self.client_kwargs,
|
749
762
|
}
|
750
|
-
logger.debug(
|
763
|
+
logger.debug("Using 'instance_principal' authentication.")
|
751
764
|
return signer_dict
|
752
765
|
|
753
766
|
|
@@ -814,7 +827,7 @@ class SecurityToken(AuthSignerGenerator):
|
|
814
827
|
oci.config.from_file(self.oci_config_location, self.oci_key_profile)
|
815
828
|
)
|
816
829
|
|
817
|
-
logger.debug(
|
830
|
+
logger.debug("Using 'security_token' authentication.")
|
818
831
|
|
819
832
|
for parameter in self.SECURITY_TOKEN_REQUIRED:
|
820
833
|
if parameter not in configuration:
|
@@ -903,7 +916,7 @@ class SecurityToken(AuthSignerGenerator):
|
|
903
916
|
raise ValueError("Invalid `security_token_file`. Specify a valid path.")
|
904
917
|
try:
|
905
918
|
token = None
|
906
|
-
with open(expanded_path
|
919
|
+
with open(expanded_path) as f:
|
907
920
|
token = f.read()
|
908
921
|
return token
|
909
922
|
except:
|
@@ -1023,7 +1036,7 @@ class OCIAuthContext:
|
|
1023
1036
|
logger.debug(f"OCI profile set to {self.profile}")
|
1024
1037
|
else:
|
1025
1038
|
ads.set_auth(auth=AuthType.RESOURCE_PRINCIPAL)
|
1026
|
-
logger.debug(
|
1039
|
+
logger.debug("OCI auth set to resource principal")
|
1027
1040
|
return self
|
1028
1041
|
|
1029
1042
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
ads/common/extended_enum.py
CHANGED
@@ -1,73 +1,81 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8; -*-
|
3
2
|
|
4
|
-
# Copyright (c) 2022,
|
3
|
+
# Copyright (c) 2022, 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
6
|
|
8
7
|
from abc import ABCMeta
|
9
|
-
from enum import Enum
|
10
8
|
|
11
9
|
|
12
10
|
class ExtendedEnumMeta(ABCMeta):
|
13
|
-
"""
|
11
|
+
"""
|
12
|
+
A helper metaclass to extend functionality of a generic "Enum-like" class.
|
14
13
|
|
15
14
|
Methods
|
16
15
|
-------
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
... KEY2 = "value2"
|
25
|
-
>>> print(TestEnum.KEY1) # "value1"
|
16
|
+
__contains__(cls, item) -> bool:
|
17
|
+
Checks if `item` is among the attribute values of the class.
|
18
|
+
Case-insensitive if `item` is a string.
|
19
|
+
values(cls) -> tuple:
|
20
|
+
Returns the tuple of class attribute values.
|
21
|
+
keys(cls) -> tuple:
|
22
|
+
Returns the tuple of class attribute names.
|
26
23
|
"""
|
27
24
|
|
28
|
-
def __contains__(cls,
|
29
|
-
|
30
|
-
|
31
|
-
def values(cls) -> list:
|
32
|
-
"""Gets the list of class attributes values.
|
25
|
+
def __contains__(cls, item: object) -> bool:
|
26
|
+
"""
|
27
|
+
Checks if `item` is a member of the class's values.
|
33
28
|
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
29
|
+
- If `item` is a string, does a case-insensitive match against any string
|
30
|
+
values stored in the class.
|
31
|
+
- Otherwise, does a direct membership test.
|
32
|
+
"""
|
33
|
+
# Gather the attribute values
|
34
|
+
attr_values = cls.values()
|
35
|
+
|
36
|
+
# If item is a string, compare case-insensitively to any str-type values
|
37
|
+
if isinstance(item, str):
|
38
|
+
return any(
|
39
|
+
isinstance(val, str) and val.lower() == item.lower()
|
40
|
+
for val in attr_values
|
41
|
+
)
|
42
|
+
else:
|
43
|
+
# For non-string items (e.g., int), do a direct membership check
|
44
|
+
return item in attr_values
|
45
|
+
|
46
|
+
def __iter__(cls):
|
47
|
+
# Make the class iterable by returning an iterator over its values
|
48
|
+
return iter(cls.values())
|
49
|
+
|
50
|
+
def values(cls) -> tuple:
|
51
|
+
"""
|
52
|
+
Gets the tuple of class attribute values, excluding private or special
|
53
|
+
attributes and any callables (methods, etc.).
|
38
54
|
"""
|
39
55
|
return tuple(
|
40
|
-
value
|
56
|
+
value
|
57
|
+
for key, value in cls.__dict__.items()
|
58
|
+
if not key.startswith("_") and not callable(value)
|
41
59
|
)
|
42
60
|
|
43
|
-
def keys(cls) ->
|
44
|
-
"""
|
45
|
-
|
46
|
-
|
47
|
-
-------
|
48
|
-
list
|
49
|
-
The list of class attributes names.
|
61
|
+
def keys(cls) -> tuple:
|
62
|
+
"""
|
63
|
+
Gets the tuple of class attribute names, excluding private or special
|
64
|
+
attributes and any callables (methods, etc.).
|
50
65
|
"""
|
51
66
|
return tuple(
|
52
|
-
key
|
67
|
+
key
|
68
|
+
for key, value in cls.__dict__.items()
|
69
|
+
if not key.startswith("_") and not callable(value)
|
53
70
|
)
|
54
71
|
|
55
72
|
|
56
|
-
class ExtendedEnum(
|
73
|
+
class ExtendedEnum(metaclass=ExtendedEnumMeta):
|
57
74
|
"""The base class to extend functionality of a generic Enum.
|
58
75
|
|
59
76
|
Examples
|
60
77
|
--------
|
61
|
-
>>> class TestEnum(
|
62
|
-
... KEY1 = "
|
63
|
-
... KEY2 = "
|
64
|
-
>>> print(TestEnum.KEY1.value) # "value1"
|
78
|
+
>>> class TestEnum(ExtendedEnum):
|
79
|
+
... KEY1 = "v1"
|
80
|
+
... KEY2 = "v2"
|
65
81
|
"""
|
66
|
-
|
67
|
-
@classmethod
|
68
|
-
def values(cls):
|
69
|
-
return sorted(map(lambda c: c.value, cls))
|
70
|
-
|
71
|
-
@classmethod
|
72
|
-
def keys(cls):
|
73
|
-
return sorted(map(lambda c: c.name, cls))
|
ads/llm/__init__.py
CHANGED
@@ -1,21 +1,24 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c)
|
3
|
+
# Copyright (c) 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
6
|
try:
|
8
7
|
import langchain
|
9
|
-
|
10
|
-
|
11
|
-
OCIModelDeploymentTGI,
|
12
|
-
)
|
8
|
+
|
9
|
+
from ads.llm.chat_template import ChatTemplates
|
13
10
|
from ads.llm.langchain.plugins.chat_models.oci_data_science import (
|
14
11
|
ChatOCIModelDeployment,
|
15
|
-
ChatOCIModelDeploymentVLLM,
|
16
12
|
ChatOCIModelDeploymentTGI,
|
13
|
+
ChatOCIModelDeploymentVLLM,
|
14
|
+
)
|
15
|
+
from ads.llm.langchain.plugins.embeddings.oci_data_science_model_deployment_endpoint import (
|
16
|
+
OCIDataScienceEmbedding,
|
17
|
+
)
|
18
|
+
from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
|
19
|
+
OCIModelDeploymentTGI,
|
20
|
+
OCIModelDeploymentVLLM,
|
17
21
|
)
|
18
|
-
from ads.llm.chat_template import ChatTemplates
|
19
22
|
except ImportError as ex:
|
20
23
|
if ex.name == "langchain":
|
21
24
|
raise ImportError(
|