oracle-ads 2.12.11__py3-none-any.whl → 2.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. ads/aqua/__init__.py +7 -1
  2. ads/aqua/app.py +41 -27
  3. ads/aqua/client/client.py +48 -11
  4. ads/aqua/common/entities.py +28 -1
  5. ads/aqua/common/enums.py +32 -21
  6. ads/aqua/common/errors.py +3 -4
  7. ads/aqua/common/utils.py +10 -15
  8. ads/aqua/config/container_config.py +203 -0
  9. ads/aqua/config/evaluation/evaluation_service_config.py +5 -181
  10. ads/aqua/constants.py +1 -1
  11. ads/aqua/evaluation/constants.py +7 -7
  12. ads/aqua/evaluation/errors.py +3 -4
  13. ads/aqua/evaluation/evaluation.py +4 -4
  14. ads/aqua/extension/base_handler.py +4 -0
  15. ads/aqua/extension/model_handler.py +41 -27
  16. ads/aqua/extension/models/ws_models.py +5 -6
  17. ads/aqua/finetuning/constants.py +3 -3
  18. ads/aqua/finetuning/finetuning.py +2 -3
  19. ads/aqua/model/constants.py +7 -7
  20. ads/aqua/model/entities.py +2 -3
  21. ads/aqua/model/enums.py +4 -5
  22. ads/aqua/model/model.py +46 -29
  23. ads/aqua/modeldeployment/deployment.py +6 -14
  24. ads/aqua/modeldeployment/entities.py +5 -3
  25. ads/aqua/server/__init__.py +4 -0
  26. ads/aqua/server/__main__.py +24 -0
  27. ads/aqua/server/app.py +47 -0
  28. ads/aqua/server/aqua_spec.yml +1291 -0
  29. ads/aqua/ui.py +5 -199
  30. ads/common/auth.py +50 -28
  31. ads/common/extended_enum.py +52 -44
  32. ads/common/utils.py +91 -11
  33. ads/config.py +3 -0
  34. ads/llm/__init__.py +12 -8
  35. ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
  36. ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
  37. ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +32 -23
  38. ads/model/artifact_downloader.py +6 -4
  39. ads/model/common/utils.py +15 -3
  40. ads/model/datascience_model.py +422 -71
  41. ads/model/generic_model.py +3 -3
  42. ads/model/model_metadata.py +70 -24
  43. ads/model/model_version_set.py +5 -3
  44. ads/model/service/oci_datascience_model.py +487 -17
  45. ads/opctl/anomaly_detection.py +11 -0
  46. ads/opctl/backend/marketplace/helm_helper.py +13 -14
  47. ads/opctl/cli.py +4 -5
  48. ads/opctl/cmds.py +28 -32
  49. ads/opctl/config/merger.py +8 -11
  50. ads/opctl/config/resolver.py +25 -30
  51. ads/opctl/forecast.py +11 -0
  52. ads/opctl/operator/cli.py +9 -9
  53. ads/opctl/operator/common/backend_factory.py +56 -60
  54. ads/opctl/operator/common/const.py +5 -5
  55. ads/opctl/operator/common/utils.py +16 -0
  56. ads/opctl/operator/lowcode/anomaly/const.py +8 -9
  57. ads/opctl/operator/lowcode/common/data.py +5 -2
  58. ads/opctl/operator/lowcode/common/transformations.py +2 -12
  59. ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +43 -48
  60. ads/opctl/operator/lowcode/forecast/__main__.py +5 -5
  61. ads/opctl/operator/lowcode/forecast/const.py +6 -6
  62. ads/opctl/operator/lowcode/forecast/model/arima.py +6 -3
  63. ads/opctl/operator/lowcode/forecast/model/automlx.py +61 -31
  64. ads/opctl/operator/lowcode/forecast/model/base_model.py +66 -40
  65. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +79 -13
  66. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +5 -2
  67. ads/opctl/operator/lowcode/forecast/model/prophet.py +28 -15
  68. ads/opctl/operator/lowcode/forecast/model_evaluator.py +13 -15
  69. ads/opctl/operator/lowcode/forecast/schema.yaml +1 -1
  70. ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +7 -0
  71. ads/opctl/operator/lowcode/forecast/whatifserve/score.py +19 -11
  72. ads/opctl/operator/lowcode/pii/constant.py +6 -7
  73. ads/opctl/operator/lowcode/recommender/constant.py +12 -7
  74. ads/opctl/operator/runtime/marketplace_runtime.py +4 -10
  75. ads/opctl/operator/runtime/runtime.py +4 -6
  76. ads/pipeline/ads_pipeline_run.py +13 -25
  77. ads/pipeline/visualizer/graph_renderer.py +3 -4
  78. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info}/METADATA +18 -15
  79. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info}/RECORD +82 -74
  80. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info}/WHEEL +1 -1
  81. ads/aqua/config/evaluation/evaluation_service_model_config.py +0 -8
  82. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info}/entry_points.txt +0 -0
  83. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info/licenses}/LICENSE.txt +0 -0
ads/aqua/ui.py CHANGED
@@ -1,12 +1,9 @@
1
1
  #!/usr/bin/env python
2
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
  import concurrent.futures
5
- from dataclasses import dataclass, field, fields
6
5
  from datetime import datetime, timedelta
7
- from enum import Enum
8
6
  from threading import Lock
9
- from typing import Dict, List, Optional
10
7
 
11
8
  from cachetools import TTLCache
12
9
  from oci.exceptions import ServiceError
@@ -14,210 +11,18 @@ from oci.identity.models import Compartment
14
11
 
15
12
  from ads.aqua import logger
16
13
  from ads.aqua.app import AquaApp
17
- from ads.aqua.common.entities import ContainerSpec
18
14
  from ads.aqua.common.enums import Tags
19
15
  from ads.aqua.common.errors import AquaResourceAccessError, AquaValueError
20
16
  from ads.aqua.common.utils import get_container_config, sanitize_response
17
+ from ads.aqua.config.container_config import AquaContainerConfig
21
18
  from ads.aqua.constants import PRIVATE_ENDPOINT_TYPE
22
19
  from ads.common import oci_client as oc
23
20
  from ads.common.auth import default_signer
24
21
  from ads.common.object_storage_details import ObjectStorageDetails
25
- from ads.common.serializer import DataClassSerializable
26
- from ads.config import (
27
- COMPARTMENT_OCID,
28
- DATA_SCIENCE_SERVICE_NAME,
29
- TENANCY_OCID,
30
- )
22
+ from ads.config import COMPARTMENT_OCID, DATA_SCIENCE_SERVICE_NAME, TENANCY_OCID
31
23
  from ads.telemetry import telemetry
32
24
 
33
25
 
34
- class ModelFormat(Enum):
35
- GGUF = "GGUF"
36
- SAFETENSORS = "SAFETENSORS"
37
- UNKNOWN = "UNKNOWN"
38
-
39
- def to_dict(self):
40
- return self.value
41
-
42
-
43
- # todo: the container config spec information is shared across ui and deployment modules, move them
44
- # within ads.aqua.common.entities. In that case, check for circular imports due to usage of get_container_config.
45
-
46
-
47
- @dataclass(repr=False)
48
- class AquaContainerEvaluationConfig(DataClassSerializable):
49
- """
50
- Represents the evaluation configuration for the container.
51
- """
52
-
53
- inference_max_threads: Optional[int] = None
54
- inference_rps: Optional[int] = None
55
- inference_timeout: Optional[int] = None
56
- inference_retries: Optional[int] = None
57
- inference_backoff_factor: Optional[int] = None
58
- inference_delay: Optional[int] = None
59
-
60
- @classmethod
61
- def from_config(cls, config: dict) -> "AquaContainerEvaluationConfig":
62
- return cls(
63
- inference_max_threads=config.get("inference_max_threads"),
64
- inference_rps=config.get("inference_rps"),
65
- inference_timeout=config.get("inference_timeout"),
66
- inference_retries=config.get("inference_retries"),
67
- inference_backoff_factor=config.get("inference_backoff_factor"),
68
- inference_delay=config.get("inference_delay"),
69
- )
70
-
71
- def to_filtered_dict(self):
72
- return {
73
- field.name: getattr(self, field.name)
74
- for field in fields(self)
75
- if getattr(self, field.name) is not None
76
- }
77
-
78
-
79
- @dataclass(repr=False)
80
- class AquaContainerConfigSpec(DataClassSerializable):
81
- cli_param: str = None
82
- server_port: str = None
83
- health_check_port: str = None
84
- env_vars: List[dict] = None
85
- restricted_params: List[str] = None
86
-
87
-
88
- @dataclass(repr=False)
89
- class AquaContainerConfigItem(DataClassSerializable):
90
- """Represents an item of the AQUA container configuration."""
91
-
92
- class Platform(Enum):
93
- ARM_CPU = "ARM_CPU"
94
- NVIDIA_GPU = "NVIDIA_GPU"
95
-
96
- def to_dict(self):
97
- return self.value
98
-
99
- def __repr__(self):
100
- return repr(self.value)
101
-
102
- name: str = None
103
- version: str = None
104
- display_name: str = None
105
- family: str = None
106
- platforms: List[Platform] = None
107
- model_formats: List[ModelFormat] = None
108
- spec: AquaContainerConfigSpec = field(default_factory=AquaContainerConfigSpec)
109
-
110
-
111
- @dataclass(repr=False)
112
- class AquaContainerConfig(DataClassSerializable):
113
- """
114
- Represents a configuration with AQUA containers to be returned to the client.
115
- """
116
-
117
- inference: Dict[str, AquaContainerConfigItem] = field(default_factory=dict)
118
- finetune: Dict[str, AquaContainerConfigItem] = field(default_factory=dict)
119
- evaluate: Dict[str, AquaContainerConfigItem] = field(default_factory=dict)
120
-
121
- def to_dict(self):
122
- return {
123
- "inference": list(self.inference.values()),
124
- "finetune": list(self.finetune.values()),
125
- "evaluate": list(self.evaluate.values()),
126
- }
127
-
128
- @classmethod
129
- def from_container_index_json(
130
- cls,
131
- config: Optional[Dict] = None,
132
- enable_spec: Optional[bool] = False,
133
- ) -> "AquaContainerConfig":
134
- """
135
- Create an AquaContainerConfig instance from a container index JSON.
136
-
137
- Parameters
138
- ----------
139
- config : Dict
140
- The container index JSON.
141
- enable_spec: bool
142
- flag to check if container specification details should be fetched.
143
-
144
- Returns
145
- -------
146
- AquaContainerConfig
147
- The container configuration instance.
148
- """
149
- if not config:
150
- config = get_container_config()
151
- inference_items = {}
152
- finetune_items = {}
153
- evaluate_items = {}
154
-
155
- # extract inference containers
156
- for container_type, containers in config.items():
157
- if isinstance(containers, list):
158
- for container in containers:
159
- platforms = [
160
- AquaContainerConfigItem.Platform[platform]
161
- for platform in container.get("platforms", [])
162
- ]
163
- model_formats = [
164
- ModelFormat[model_format]
165
- for model_format in container.get("modelFormats", [])
166
- ]
167
- container_spec = (
168
- config.get(ContainerSpec.CONTAINER_SPEC, {}).get(
169
- container_type, {}
170
- )
171
- if enable_spec
172
- else None
173
- )
174
- container_item = AquaContainerConfigItem(
175
- name=container.get("name", ""),
176
- version=container.get("version", ""),
177
- display_name=container.get(
178
- "displayName", container.get("version", "")
179
- ),
180
- family=container_type,
181
- platforms=platforms,
182
- model_formats=model_formats,
183
- spec=(
184
- AquaContainerConfigSpec(
185
- cli_param=container_spec.get(
186
- ContainerSpec.CLI_PARM, ""
187
- ),
188
- server_port=container_spec.get(
189
- ContainerSpec.SERVER_PORT, ""
190
- ),
191
- health_check_port=container_spec.get(
192
- ContainerSpec.HEALTH_CHECK_PORT, ""
193
- ),
194
- env_vars=container_spec.get(ContainerSpec.ENV_VARS, []),
195
- restricted_params=container_spec.get(
196
- ContainerSpec.RESTRICTED_PARAMS, []
197
- ),
198
- )
199
- if container_spec
200
- else None
201
- ),
202
- )
203
- if container.get("type") == "inference":
204
- inference_items[container_type] = container_item
205
- elif (
206
- container.get("type") == "fine-tune"
207
- or container_type == "odsc-llm-fine-tuning"
208
- ):
209
- finetune_items[container_type] = container_item
210
- elif (
211
- container.get("type") == "evaluate"
212
- or container_type == "odsc-llm-evaluate"
213
- ):
214
- evaluate_items[container_type] = container_item
215
-
216
- return AquaContainerConfig(
217
- inference=inference_items, finetune=finetune_items, evaluate=evaluate_items
218
- )
219
-
220
-
221
26
  class AquaUIApp(AquaApp):
222
27
  """Contains APIs for supporting Aqua UI.
223
28
 
@@ -512,7 +317,8 @@ class AquaUIApp(AquaApp):
512
317
 
513
318
  Returns
514
319
  -------
515
- str has json representation of `oci.data_science.models.ModelDeploymentShapeSummary`."""
320
+ str has json representation of `oci.data_science.models.ModelDeploymentShapeSummary`.
321
+ """
516
322
  compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID)
517
323
  logger.info(
518
324
  f"Loading model deployment shape summary from compartment: {compartment_id}"
ads/common/auth.py CHANGED
@@ -1,23 +1,25 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8; -*-
3
2
 
4
- # Copyright (c) 2021, 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2021, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
  import copy
8
- from datetime import datetime
9
7
  import os
10
- from dataclasses import dataclass
11
8
  import time
9
+ from dataclasses import dataclass
10
+ from datetime import datetime
12
11
  from typing import Any, Callable, Dict, Optional, Union
13
12
 
14
- import ads.telemetry
15
13
  import oci
14
+ from oci.config import (
15
+ DEFAULT_LOCATION, # "~/.oci/config"
16
+ DEFAULT_PROFILE, # "DEFAULT"
17
+ )
18
+
19
+ import ads.telemetry
16
20
  from ads.common import logger
17
21
  from ads.common.decorator.deprecate import deprecated
18
- from ads.common.extended_enum import ExtendedEnumMeta
19
- from oci.config import DEFAULT_LOCATION # "~/.oci/config"
20
- from oci.config import DEFAULT_PROFILE # "DEFAULT"
22
+ from ads.common.extended_enum import ExtendedEnum
21
23
 
22
24
  SECURITY_TOKEN_LEFT_TIME = 600
23
25
 
@@ -26,7 +28,7 @@ class SecurityTokenError(Exception): # pragma: no cover
26
28
  pass
27
29
 
28
30
 
29
- class AuthType(str, metaclass=ExtendedEnumMeta):
31
+ class AuthType(ExtendedEnum):
30
32
  API_KEY = "api_key"
31
33
  RESOURCE_PRINCIPAL = "resource_principal"
32
34
  INSTANCE_PRINCIPAL = "instance_principal"
@@ -73,7 +75,11 @@ class AuthState(metaclass=SingletonMeta):
73
75
  self.oci_key_profile = self.oci_key_profile or os.environ.get(
74
76
  "OCI_CONFIG_PROFILE", DEFAULT_PROFILE
75
77
  )
76
- self.oci_config = self.oci_config or {"region": os.environ["OCI_RESOURCE_REGION"]} if os.environ.get("OCI_RESOURCE_REGION") else {}
78
+ self.oci_config = (
79
+ self.oci_config or {"region": os.environ["OCI_RESOURCE_REGION"]}
80
+ if os.environ.get("OCI_RESOURCE_REGION")
81
+ else {}
82
+ )
77
83
  self.oci_signer_kwargs = self.oci_signer_kwargs or {}
78
84
  self.oci_client_kwargs = self.oci_client_kwargs or {}
79
85
 
@@ -82,11 +88,15 @@ def set_auth(
82
88
  auth: Optional[str] = AuthType.API_KEY,
83
89
  oci_config_location: Optional[str] = DEFAULT_LOCATION,
84
90
  profile: Optional[str] = DEFAULT_PROFILE,
85
- config: Optional[Dict] = {"region": os.environ["OCI_RESOURCE_REGION"]} if os.environ.get("OCI_RESOURCE_REGION") else {},
91
+ config: Optional[Dict] = (
92
+ {"region": os.environ["OCI_RESOURCE_REGION"]}
93
+ if os.environ.get("OCI_RESOURCE_REGION")
94
+ else {}
95
+ ),
86
96
  signer: Optional[Any] = None,
87
97
  signer_callable: Optional[Callable] = None,
88
- signer_kwargs: Optional[Dict] = {},
89
- client_kwargs: Optional[Dict] = {},
98
+ signer_kwargs: Optional[Dict] = None,
99
+ client_kwargs: Optional[Dict] = None,
90
100
  ) -> None:
91
101
  """
92
102
  Sets the default authentication type.
@@ -187,6 +197,9 @@ def set_auth(
187
197
  >>> # instance principals authentication dictionary created based on callable with kwargs parameters:
188
198
  >>> ads.set_auth(signer_callable=signer_callable, signer_kwargs=signer_kwargs)
189
199
  """
200
+ signer_kwargs = signer_kwargs or {}
201
+ client_kwargs = client_kwargs or {}
202
+
190
203
  auth_state = AuthState()
191
204
 
192
205
  valid_auth_keys = AuthFactory.classes.keys()
@@ -202,8 +215,8 @@ def set_auth(
202
215
  oci_config_location != DEFAULT_LOCATION or profile != DEFAULT_PROFILE
203
216
  ):
204
217
  raise ValueError(
205
- f"'config' and 'oci_config_location', 'profile' pair are mutually exclusive."
206
- f"Please specify 'config' OR 'oci_config_location', 'profile' pair."
218
+ "'config' and 'oci_config_location', 'profile' pair are mutually exclusive."
219
+ "Please specify 'config' OR 'oci_config_location', 'profile' pair."
207
220
  )
208
221
 
209
222
  auth_state.oci_config = config
@@ -250,9 +263,11 @@ def api_keys(
250
263
  """
251
264
  signer_args = dict(
252
265
  oci_config=oci_config if isinstance(oci_config, Dict) else {},
253
- oci_config_location=oci_config
254
- if isinstance(oci_config, str)
255
- else os.path.expanduser(DEFAULT_LOCATION),
266
+ oci_config_location=(
267
+ oci_config
268
+ if isinstance(oci_config, str)
269
+ else os.path.expanduser(DEFAULT_LOCATION)
270
+ ),
256
271
  oci_key_profile=profile,
257
272
  client_kwargs=client_kwargs,
258
273
  )
@@ -326,9 +341,11 @@ def security_token(
326
341
  """
327
342
  signer_args = dict(
328
343
  oci_config=oci_config if isinstance(oci_config, Dict) else {},
329
- oci_config_location=oci_config
330
- if isinstance(oci_config, str)
331
- else os.path.expanduser(DEFAULT_LOCATION),
344
+ oci_config_location=(
345
+ oci_config
346
+ if isinstance(oci_config, str)
347
+ else os.path.expanduser(DEFAULT_LOCATION)
348
+ ),
332
349
  oci_key_profile=profile,
333
350
  client_kwargs=client_kwargs,
334
351
  )
@@ -621,7 +638,7 @@ class APIKey(AuthSignerGenerator):
621
638
  )
622
639
 
623
640
  oci.config.validate_config(configuration)
624
- logger.debug(f"Using 'api_key' authentication.")
641
+ logger.debug("Using 'api_key' authentication.")
625
642
  return {
626
643
  "config": configuration,
627
644
  "signer": oci.signer.Signer(
@@ -684,14 +701,19 @@ class ResourcePrincipal(AuthSignerGenerator):
684
701
  "signer": oci.auth.signers.get_resource_principals_signer(),
685
702
  "client_kwargs": self.client_kwargs,
686
703
  }
687
- logger.debug(f"Using 'resource_principal' authentication.")
704
+ logger.debug("Using 'resource_principal' authentication.")
688
705
  return signer_dict
689
706
 
690
707
  @staticmethod
691
708
  def supported():
692
709
  return any(
693
710
  os.environ.get(var)
694
- for var in ['JOB_RUN_OCID', 'NB_SESSION_OCID', 'DATAFLOW_RUN_ID', 'PIPELINE_RUN_OCID']
711
+ for var in [
712
+ "JOB_RUN_OCID",
713
+ "NB_SESSION_OCID",
714
+ "DATAFLOW_RUN_ID",
715
+ "PIPELINE_RUN_OCID",
716
+ ]
695
717
  )
696
718
 
697
719
 
@@ -747,7 +769,7 @@ class InstancePrincipal(AuthSignerGenerator):
747
769
  ),
748
770
  "client_kwargs": self.client_kwargs,
749
771
  }
750
- logger.debug(f"Using 'instance_principal' authentication.")
772
+ logger.debug("Using 'instance_principal' authentication.")
751
773
  return signer_dict
752
774
 
753
775
 
@@ -814,7 +836,7 @@ class SecurityToken(AuthSignerGenerator):
814
836
  oci.config.from_file(self.oci_config_location, self.oci_key_profile)
815
837
  )
816
838
 
817
- logger.debug(f"Using 'security_token' authentication.")
839
+ logger.debug("Using 'security_token' authentication.")
818
840
 
819
841
  for parameter in self.SECURITY_TOKEN_REQUIRED:
820
842
  if parameter not in configuration:
@@ -903,7 +925,7 @@ class SecurityToken(AuthSignerGenerator):
903
925
  raise ValueError("Invalid `security_token_file`. Specify a valid path.")
904
926
  try:
905
927
  token = None
906
- with open(expanded_path, "r") as f:
928
+ with open(expanded_path) as f:
907
929
  token = f.read()
908
930
  return token
909
931
  except:
@@ -1023,7 +1045,7 @@ class OCIAuthContext:
1023
1045
  logger.debug(f"OCI profile set to {self.profile}")
1024
1046
  else:
1025
1047
  ads.set_auth(auth=AuthType.RESOURCE_PRINCIPAL)
1026
- logger.debug(f"OCI auth set to resource principal")
1048
+ logger.debug("OCI auth set to resource principal")
1027
1049
  return self
1028
1050
 
1029
1051
  def __exit__(self, exc_type, exc_val, exc_tb):
ads/common/extended_enum.py CHANGED
@@ -1,73 +1,81 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8; -*-
3
2
 
4
- # Copyright (c) 2022, 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2022, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
 
8
7
  from abc import ABCMeta
9
- from enum import Enum
10
8
 
11
9
 
12
10
  class ExtendedEnumMeta(ABCMeta):
13
- """The helper metaclass to extend functionality of a generic Enum.
11
+ """
12
+ A helper metaclass to extend functionality of a generic "Enum-like" class.
14
13
 
15
14
  Methods
16
15
  -------
17
- values(cls) -> list:
18
- Gets the list of class attributes.
19
-
20
- Examples
21
- --------
22
- >>> class TestEnum(str, metaclass=ExtendedEnumMeta):
23
- ... KEY1 = "value1"
24
- ... KEY2 = "value2"
25
- >>> print(TestEnum.KEY1) # "value1"
16
+ __contains__(cls, item) -> bool:
17
+ Checks if `item` is among the attribute values of the class.
18
+ Case-insensitive if `item` is a string.
19
+ values(cls) -> tuple:
20
+ Returns the tuple of class attribute values.
21
+ keys(cls) -> tuple:
22
+ Returns the tuple of class attribute names.
26
23
  """
27
24
 
28
- def __contains__(cls, value):
29
- return value and value.lower() in tuple(value.lower() for value in cls.values())
30
-
31
- def values(cls) -> list:
32
- """Gets the list of class attributes values.
25
+ def __contains__(cls, item: object) -> bool:
26
+ """
27
+ Checks if `item` is a member of the class's values.
33
28
 
34
- Returns
35
- -------
36
- list
37
- The list of class values.
29
+ - If `item` is a string, does a case-insensitive match against any string
30
+ values stored in the class.
31
+ - Otherwise, does a direct membership test.
32
+ """
33
+ # Gather the attribute values
34
+ attr_values = cls.values()
35
+
36
+ # If item is a string, compare case-insensitively to any str-type values
37
+ if isinstance(item, str):
38
+ return any(
39
+ isinstance(val, str) and val.lower() == item.lower()
40
+ for val in attr_values
41
+ )
42
+ else:
43
+ # For non-string items (e.g., int), do a direct membership check
44
+ return item in attr_values
45
+
46
+ def __iter__(cls):
47
+ # Make the class iterable by returning an iterator over its values
48
+ return iter(cls.values())
49
+
50
+ def values(cls) -> tuple:
51
+ """
52
+ Gets the tuple of class attribute values, excluding private or special
53
+ attributes and any callables (methods, etc.).
38
54
  """
39
55
  return tuple(
40
- value for key, value in cls.__dict__.items() if not key.startswith("_")
56
+ value
57
+ for key, value in cls.__dict__.items()
58
+ if not key.startswith("_") and not callable(value)
41
59
  )
42
60
 
43
- def keys(cls) -> list:
44
- """Gets the list of class attributes names.
45
-
46
- Returns
47
- -------
48
- list
49
- The list of class attributes names.
61
+ def keys(cls) -> tuple:
62
+ """
63
+ Gets the tuple of class attribute names, excluding private or special
64
+ attributes and any callables (methods, etc.).
50
65
  """
51
66
  return tuple(
52
- key for key, value in cls.__dict__.items() if not key.startswith("_")
67
+ key
68
+ for key, value in cls.__dict__.items()
69
+ if not key.startswith("_") and not callable(value)
53
70
  )
54
71
 
55
72
 
56
- class ExtendedEnum(Enum):
73
+ class ExtendedEnum(metaclass=ExtendedEnumMeta):
57
74
  """The base class to extend functionality of a generic Enum.
58
75
 
59
76
  Examples
60
77
  --------
61
- >>> class TestEnum(ExtendedEnumMeta):
62
- ... KEY1 = "value1"
63
- ... KEY2 = "value2"
64
- >>> print(TestEnum.KEY1.value) # "value1"
78
+ >>> class TestEnum(ExtendedEnum):
79
+ ... KEY1 = "v1"
80
+ ... KEY2 = "v2"
65
81
  """
66
-
67
- @classmethod
68
- def values(cls):
69
- return sorted(map(lambda c: c.value, cls))
70
-
71
- @classmethod
72
- def keys(cls):
73
- return sorted(map(lambda c: c.name, cls))