arize 8.0.0a23__py3-none-any.whl → 8.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/__init__.py +11 -10
- arize/_exporter/client.py +1 -1
- arize/_generated/api_client/__init__.py +0 -2
- arize/_generated/api_client/models/__init__.py +0 -1
- arize/_generated/api_client/models/datasets_create_request.py +2 -10
- arize/_generated/api_client/models/datasets_examples_insert_request.py +2 -10
- arize/_generated/api_client/test/test_datasets_create_request.py +2 -6
- arize/_generated/api_client/test/test_datasets_examples_insert_request.py +2 -6
- arize/_generated/api_client/test/test_datasets_examples_list200_response.py +2 -6
- arize/_generated/api_client/test/test_datasets_examples_update_request.py +2 -6
- arize/_generated/api_client/test/test_experiments_create_request.py +2 -6
- arize/_generated/api_client/test/test_experiments_runs_list200_response.py +2 -6
- arize/_generated/api_client_README.md +0 -1
- arize/client.py +47 -163
- arize/config.py +59 -100
- arize/datasets/client.py +11 -6
- arize/embeddings/nlp_generators.py +12 -6
- arize/embeddings/tabular_generators.py +14 -11
- arize/experiments/__init__.py +12 -0
- arize/experiments/client.py +13 -9
- arize/experiments/functions.py +6 -6
- arize/experiments/types.py +3 -3
- arize/{models → ml}/batch_validation/errors.py +2 -2
- arize/{models → ml}/batch_validation/validator.py +5 -3
- arize/{models → ml}/casting.py +42 -78
- arize/{models → ml}/client.py +19 -17
- arize/{models → ml}/proto.py +2 -2
- arize/{models → ml}/stream_validation.py +1 -1
- arize/{models → ml}/surrogate_explainer/mimic.py +6 -2
- arize/{types.py → ml/types.py} +99 -234
- arize/pre_releases.py +2 -1
- arize/projects/client.py +11 -6
- arize/spans/client.py +91 -86
- arize/spans/conversion.py +11 -4
- arize/spans/validation/common/value_validation.py +1 -1
- arize/spans/validation/spans/dataframe_form_validation.py +1 -1
- arize/spans/validation/spans/value_validation.py +2 -1
- arize/utils/dataframe.py +1 -1
- arize/utils/online_tasks/dataframe_preprocessor.py +5 -6
- arize/utils/types.py +105 -0
- arize/version.py +1 -1
- {arize-8.0.0a23.dist-info → arize-8.0.0b1.dist-info}/METADATA +56 -59
- {arize-8.0.0a23.dist-info → arize-8.0.0b1.dist-info}/RECORD +50 -51
- arize/_generated/api_client/models/primitive_value.py +0 -172
- arize/_generated/api_client/test/test_primitive_value.py +0 -50
- /arize/{models → ml}/__init__.py +0 -0
- /arize/{models → ml}/batch_validation/__init__.py +0 -0
- /arize/{models → ml}/bounded_executor.py +0 -0
- /arize/{models → ml}/surrogate_explainer/__init__.py +0 -0
- {arize-8.0.0a23.dist-info → arize-8.0.0b1.dist-info}/WHEEL +0 -0
- {arize-8.0.0a23.dist-info → arize-8.0.0b1.dist-info}/licenses/LICENSE +0 -0
- {arize-8.0.0a23.dist-info → arize-8.0.0b1.dist-info}/licenses/NOTICE +0 -0
arize/spans/client.py
CHANGED
|
@@ -26,7 +26,7 @@ from arize.exceptions.base import (
|
|
|
26
26
|
from arize.exceptions.models import MissingProjectNameError
|
|
27
27
|
from arize.exceptions.spaces import MissingSpaceIDError
|
|
28
28
|
from arize.logging import CtxAdapter
|
|
29
|
-
from arize.types import Environments
|
|
29
|
+
from arize.ml.types import Environments
|
|
30
30
|
from arize.utils.arrow import post_arrow_table
|
|
31
31
|
from arize.utils.dataframe import (
|
|
32
32
|
remove_extraneous_columns,
|
|
@@ -44,14 +44,18 @@ logger = logging.getLogger(__name__)
|
|
|
44
44
|
|
|
45
45
|
|
|
46
46
|
class SpansClient:
|
|
47
|
-
"""Client for logging LLM tracing spans and evaluations to Arize.
|
|
47
|
+
"""Client for logging LLM tracing spans and evaluations to Arize.
|
|
48
48
|
|
|
49
|
-
|
|
50
|
-
|
|
49
|
+
This class is primarily intended for internal use within the SDK. Users are
|
|
50
|
+
highly encouraged to access resource-specific functionality via
|
|
51
|
+
:class:`arize.ArizeClient`.
|
|
52
|
+
"""
|
|
51
53
|
|
|
52
|
-
|
|
53
|
-
sdk_config: SDK configuration containing API endpoints and credentials.
|
|
54
|
+
def __init__(self, *, sdk_config: SDKConfiguration) -> None:
|
|
54
55
|
"""
|
|
56
|
+
Args:
|
|
57
|
+
sdk_config: Resolved SDK configuration.
|
|
58
|
+
""" # noqa: D205, D212
|
|
55
59
|
self._sdk_config = sdk_config
|
|
56
60
|
|
|
57
61
|
def log(
|
|
@@ -72,23 +76,23 @@ class SpansClient:
|
|
|
72
76
|
successful delivery of records.
|
|
73
77
|
|
|
74
78
|
Args:
|
|
75
|
-
space_id
|
|
76
|
-
project_name
|
|
77
|
-
dataframe
|
|
78
|
-
evals_dataframe
|
|
79
|
+
space_id: The space ID where the project resides.
|
|
80
|
+
project_name: A unique name to identify your project in the Arize platform.
|
|
81
|
+
dataframe: The dataframe containing the LLM traces.
|
|
82
|
+
evals_dataframe: A dataframe containing LLM evaluations data.
|
|
79
83
|
The evaluations are joined to their corresponding spans via a left outer join, i.e.,
|
|
80
84
|
using only `context.span_id` from the spans dataframe. Defaults to None.
|
|
81
|
-
datetime_format
|
|
85
|
+
datetime_format: Format for the timestamp captured in the LLM traces.
|
|
82
86
|
Defaults to "%Y-%m-%dT%H:%M:%S.%f+00:00".
|
|
83
|
-
validate
|
|
87
|
+
validate: When set to True, validation is run before sending data.
|
|
84
88
|
Defaults to True.
|
|
85
|
-
timeout
|
|
89
|
+
timeout: You can stop waiting for a response after a given number
|
|
86
90
|
of seconds with the timeout parameter. Defaults to None.
|
|
87
|
-
tmp_dir
|
|
91
|
+
tmp_dir: Temporary directory/file to store the serialized data in binary
|
|
88
92
|
before sending to Arize.
|
|
89
93
|
|
|
90
94
|
Returns:
|
|
91
|
-
|
|
95
|
+
Response object from the HTTP request.
|
|
92
96
|
|
|
93
97
|
"""
|
|
94
98
|
from arize.spans.columns import (
|
|
@@ -274,15 +278,15 @@ class SpansClient:
|
|
|
274
278
|
each evaluation to its respective span.
|
|
275
279
|
|
|
276
280
|
Args:
|
|
277
|
-
space_id
|
|
278
|
-
project_name
|
|
279
|
-
dataframe
|
|
280
|
-
validate
|
|
281
|
+
space_id: The space ID where the project resides.
|
|
282
|
+
project_name: A unique name to identify your project in the Arize platform.
|
|
283
|
+
dataframe: A dataframe containing LLM evaluations data.
|
|
284
|
+
validate: When set to True, validation is run before sending data.
|
|
281
285
|
Defaults to True.
|
|
282
|
-
force_http
|
|
283
|
-
timeout
|
|
286
|
+
force_http: Force the use of HTTP for data upload. Defaults to False.
|
|
287
|
+
timeout: You can stop waiting for a response after a given number
|
|
284
288
|
of seconds with the timeout parameter. Defaults to None.
|
|
285
|
-
tmp_dir
|
|
289
|
+
tmp_dir: Temporary directory/file to store the serialized data in binary
|
|
286
290
|
before sending to Arize.
|
|
287
291
|
"""
|
|
288
292
|
from arize.spans.columns import EVAL_COLUMN_PATTERN, SPAN_SPAN_ID_COL
|
|
@@ -447,10 +451,10 @@ class SpansClient:
|
|
|
447
451
|
`annotation.notes` column can be included for free-form text notes.
|
|
448
452
|
|
|
449
453
|
Args:
|
|
450
|
-
space_id
|
|
451
|
-
project_name
|
|
452
|
-
dataframe
|
|
453
|
-
validate
|
|
454
|
+
space_id: The space ID where the project resides.
|
|
455
|
+
project_name: A unique name to identify your project in the Arize platform.
|
|
456
|
+
dataframe: A dataframe containing LLM annotation data.
|
|
457
|
+
validate: When set to True, validation is run before sending data.
|
|
454
458
|
Defaults to True.
|
|
455
459
|
"""
|
|
456
460
|
from arize.spans.columns import (
|
|
@@ -661,6 +665,7 @@ class SpansClient:
|
|
|
661
665
|
This method is only supported for LLM model types.
|
|
662
666
|
|
|
663
667
|
The dataframe must contain a column `context.span_id` to identify spans and either:
|
|
668
|
+
|
|
664
669
|
1. A column with JSON patch documents (specified by patch_document_column_name), or
|
|
665
670
|
2. One or more columns with prefix `attributes.metadata.` that will be automatically
|
|
666
671
|
converted to a patch document (e.g., `attributes.metadata.tag` → `{"tag": value}`).
|
|
@@ -668,7 +673,8 @@ class SpansClient:
|
|
|
668
673
|
If both methods are used, the explicit patch document is applied after the individual field updates.
|
|
669
674
|
The patches will be applied to the `attributes.metadata` field of each span.
|
|
670
675
|
|
|
671
|
-
|
|
676
|
+
Type Handling:
|
|
677
|
+
|
|
672
678
|
- The client primarily supports string, integer, and float data types.
|
|
673
679
|
- Boolean values are converted to string representations.
|
|
674
680
|
- Nested JSON objects and arrays are serialized to JSON strings during transmission.
|
|
@@ -685,12 +691,14 @@ class SpansClient:
|
|
|
685
691
|
|
|
686
692
|
Returns:
|
|
687
693
|
Dictionary containing update results with the following keys:
|
|
694
|
+
|
|
688
695
|
- spans_processed: Total number of spans in the input dataframe
|
|
689
696
|
- spans_updated: Count of successfully updated span metadata records
|
|
690
697
|
- spans_failed: Count of spans that failed to update
|
|
691
698
|
- errors: List of dictionaries with 'span_id' and 'error_message' keys for each failed span
|
|
692
699
|
|
|
693
|
-
|
|
700
|
+
Error types from the server include:
|
|
701
|
+
|
|
694
702
|
- parse_failure: Failed to parse JSON metadata
|
|
695
703
|
- patch_failure: Failed to apply JSON patch
|
|
696
704
|
- type_conflict: Type conflict in metadata
|
|
@@ -699,58 +707,60 @@ class SpansClient:
|
|
|
699
707
|
- druid_rejection: Backend rejected the update
|
|
700
708
|
|
|
701
709
|
Raises:
|
|
702
|
-
AuthError: When API key or space ID is missing
|
|
703
|
-
ValidationFailure: When validation of the dataframe or values fails
|
|
704
|
-
ImportError: When required tracing dependencies are missing
|
|
705
|
-
ArrowInvalid: When the dataframe cannot be converted to Arrow format
|
|
706
|
-
RuntimeError: If the request fails or no response is received
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
df = pd.DataFrame(
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
)
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
710
|
+
AuthError: When API key or space ID is missing.
|
|
711
|
+
ValidationFailure: When validation of the dataframe or values fails.
|
|
712
|
+
ImportError: When required tracing dependencies are missing.
|
|
713
|
+
ArrowInvalid: When the dataframe cannot be converted to Arrow format.
|
|
714
|
+
RuntimeError: If the request fails or no response is received.
|
|
715
|
+
|
|
716
|
+
Examples:
|
|
717
|
+
Method 1: Using a patch document
|
|
718
|
+
|
|
719
|
+
>>> df = pd.DataFrame(
|
|
720
|
+
... {
|
|
721
|
+
... "context.span_id": ["span1", "span2"],
|
|
722
|
+
... "patch_document": [
|
|
723
|
+
... {"tag": "important"},
|
|
724
|
+
... {"priority": "high"},
|
|
725
|
+
... ],
|
|
726
|
+
... }
|
|
727
|
+
... )
|
|
728
|
+
|
|
729
|
+
Method 2: Using direct field columns
|
|
730
|
+
|
|
731
|
+
>>> df = pd.DataFrame(
|
|
732
|
+
... {
|
|
733
|
+
... "context.span_id": ["span1", "span2"],
|
|
734
|
+
... "attributes.metadata.tag": ["important", "standard"],
|
|
735
|
+
... "attributes.metadata.priority": ["high", "medium"],
|
|
736
|
+
... }
|
|
737
|
+
... )
|
|
738
|
+
|
|
739
|
+
Method 3: Combining both approaches
|
|
740
|
+
|
|
741
|
+
>>> df = pd.DataFrame(
|
|
742
|
+
... {
|
|
743
|
+
... "context.span_id": ["span1"],
|
|
744
|
+
... "attributes.metadata.tag": ["important"],
|
|
745
|
+
... "patch_document": [
|
|
746
|
+
... {"priority": "high"}
|
|
747
|
+
... ], # Overrides conflicting fields
|
|
748
|
+
... }
|
|
749
|
+
... )
|
|
750
|
+
|
|
751
|
+
Method 4: Setting fields to null
|
|
752
|
+
|
|
753
|
+
>>> df = pd.DataFrame(
|
|
754
|
+
... {
|
|
755
|
+
... "context.span_id": ["span1"],
|
|
756
|
+
... "attributes.metadata.old_field": [
|
|
757
|
+
... None
|
|
758
|
+
... ], # Sets field to JSON null
|
|
759
|
+
... "patch_document": [
|
|
760
|
+
... {"other_field": None}
|
|
761
|
+
... ], # Also sets field to JSON null
|
|
762
|
+
... }
|
|
763
|
+
... )
|
|
754
764
|
"""
|
|
755
765
|
# Import validation modules
|
|
756
766
|
from arize.spans.columns import SPAN_SPAN_ID_COL
|
|
@@ -992,7 +1002,6 @@ class SpansClient:
|
|
|
992
1002
|
end_time: datetime,
|
|
993
1003
|
where: str = "",
|
|
994
1004
|
columns: list | None = None,
|
|
995
|
-
similarity_search_params: SimilaritySearchParams | None = None,
|
|
996
1005
|
stream_chunk_size: int | None = None,
|
|
997
1006
|
) -> pd.DataFrame:
|
|
998
1007
|
"""Export span data from Arize to a pandas DataFrame.
|
|
@@ -1002,8 +1011,7 @@ class SpansClient:
|
|
|
1002
1011
|
WHERE clauses and similarity search for semantic retrieval.
|
|
1003
1012
|
|
|
1004
1013
|
Returns:
|
|
1005
|
-
|
|
1006
|
-
pd.DataFrame: DataFrame containing the requested span data with columns
|
|
1014
|
+
DataFrame containing the requested span data with columns
|
|
1007
1015
|
for span metadata, attributes, events, and any custom fields.
|
|
1008
1016
|
"""
|
|
1009
1017
|
with ArizeFlightClient(
|
|
@@ -1025,7 +1033,6 @@ class SpansClient:
|
|
|
1025
1033
|
end_time=end_time,
|
|
1026
1034
|
where=where,
|
|
1027
1035
|
columns=columns,
|
|
1028
|
-
similarity_search_params=similarity_search_params,
|
|
1029
1036
|
stream_chunk_size=stream_chunk_size,
|
|
1030
1037
|
)
|
|
1031
1038
|
|
|
@@ -1039,7 +1046,6 @@ class SpansClient:
|
|
|
1039
1046
|
end_time: datetime,
|
|
1040
1047
|
where: str = "",
|
|
1041
1048
|
columns: list | None = None,
|
|
1042
|
-
similarity_search_params: SimilaritySearchParams | None = None,
|
|
1043
1049
|
stream_chunk_size: int | None = None,
|
|
1044
1050
|
) -> None:
|
|
1045
1051
|
"""Export span data from Arize to a Parquet file.
|
|
@@ -1069,7 +1075,6 @@ class SpansClient:
|
|
|
1069
1075
|
end_time=end_time,
|
|
1070
1076
|
where=where,
|
|
1071
1077
|
columns=columns,
|
|
1072
|
-
similarity_search_params=similarity_search_params,
|
|
1073
1078
|
stream_chunk_size=stream_chunk_size,
|
|
1074
1079
|
)
|
|
1075
1080
|
|
arize/spans/conversion.py
CHANGED
|
@@ -35,10 +35,17 @@ def convert_timestamps(df: pd.DataFrame, fmt: str = "") -> pd.DataFrame:
|
|
|
35
35
|
|
|
36
36
|
def _datetime_to_ns(dt: object, fmt: str) -> int:
|
|
37
37
|
if isinstance(dt, str):
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
38
|
+
# Try ISO 8601 with timezone first
|
|
39
|
+
try:
|
|
40
|
+
parsed = datetime.fromisoformat(dt)
|
|
41
|
+
if parsed.tzinfo is None:
|
|
42
|
+
# If no timezone, assume UTC
|
|
43
|
+
parsed = parsed.replace(tzinfo=timezone.utc)
|
|
44
|
+
except ValueError:
|
|
45
|
+
# Fall back to custom format
|
|
46
|
+
parsed = datetime.strptime(dt, fmt).replace(tzinfo=timezone.utc)
|
|
47
|
+
|
|
48
|
+
return int(parsed.timestamp() * 1e9)
|
|
42
49
|
if isinstance(dt, datetime):
|
|
43
50
|
return int(datetime.timestamp(dt) * 1e9)
|
|
44
51
|
if isinstance(dt, pd.Timestamp):
|
|
@@ -25,7 +25,7 @@ from arize.spans.validation.common.errors import (
|
|
|
25
25
|
InvalidStringValueNotAllowedInColumn,
|
|
26
26
|
InvalidTimestampValueInColumn,
|
|
27
27
|
)
|
|
28
|
-
from arize.types import is_json_str
|
|
28
|
+
from arize.utils.types import is_json_str
|
|
29
29
|
|
|
30
30
|
logger = logging.getLogger(__name__)
|
|
31
31
|
|
|
@@ -12,7 +12,7 @@ from arize.spans.conversion import is_missing_value
|
|
|
12
12
|
from arize.spans.validation.common.errors import (
|
|
13
13
|
InvalidDataFrameColumnContentTypes,
|
|
14
14
|
)
|
|
15
|
-
from arize.types import is_array_of, is_dict_of, is_list_of
|
|
15
|
+
from arize.utils.types import is_array_of, is_dict_of, is_list_of
|
|
16
16
|
|
|
17
17
|
logger = logging.getLogger(__name__)
|
|
18
18
|
|
|
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
|
|
|
7
7
|
|
|
8
8
|
from arize.constants import spans as tracing_constants
|
|
9
9
|
from arize.constants.ml import MAX_EMBEDDING_DIMENSIONALITY
|
|
10
|
+
from arize.ml.types import StatusCodes
|
|
10
11
|
from arize.spans import columns as tracing_cols
|
|
11
12
|
from arize.spans.validation.common import value_validation
|
|
12
13
|
from arize.spans.validation.common.errors import (
|
|
@@ -15,7 +16,7 @@ from arize.spans.validation.common.errors import (
|
|
|
15
16
|
InvalidEventValueInColumn,
|
|
16
17
|
InvalidLLMMessageValueInColumn,
|
|
17
18
|
)
|
|
18
|
-
from arize.types import
|
|
19
|
+
from arize.utils.types import is_dict_of, is_json_str
|
|
19
20
|
|
|
20
21
|
if TYPE_CHECKING:
|
|
21
22
|
import pandas as pd
|
arize/utils/dataframe.py
CHANGED
|
@@ -122,7 +122,7 @@ def extract_nested_data_to_column(
|
|
|
122
122
|
def _introspect_arize_attribute(value: object, attribute: str) -> object:
|
|
123
123
|
"""Recursively drill into `value` following the dot-delimited `attribute`.
|
|
124
124
|
|
|
125
|
-
|
|
125
|
+
Examples:
|
|
126
126
|
value: [{'message.role': 'assistant', 'message.content': 'The capital of China is Beijing.'}]
|
|
127
127
|
attribute: "0.message.content"
|
|
128
128
|
Returns: 'The capital of China is Beijing.'
|
|
@@ -132,7 +132,6 @@ def _introspect_arize_attribute(value: object, attribute: str) -> object:
|
|
|
132
132
|
- Parses JSON strings
|
|
133
133
|
- Converts NumPy arrays to lists
|
|
134
134
|
- Allows dotted keys (e.g. "message.content") by combining parts
|
|
135
|
-
|
|
136
135
|
"""
|
|
137
136
|
if not attribute:
|
|
138
137
|
return value
|
|
@@ -195,10 +194,10 @@ def _parse_value(
|
|
|
195
194
|
idx = _try_int(key)
|
|
196
195
|
if idx is not None:
|
|
197
196
|
# Must be a tuple or list (_ensure_deserialized() already casts numpy arrays to python lists)
|
|
198
|
-
if isinstance(current_value,
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
return (
|
|
197
|
+
if isinstance(current_value, list | tuple) and 0 <= idx < len(
|
|
198
|
+
current_value
|
|
199
|
+
):
|
|
200
|
+
return (current_value[idx], num_parts_processed)
|
|
202
201
|
return (None, num_parts_processed)
|
|
203
202
|
|
|
204
203
|
# 2) Try dict approach
|
arize/utils/types.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""Common type definitions and data models used across the Arize SDK."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from collections.abc import Iterable, Sequence
|
|
5
|
+
from typing import (
|
|
6
|
+
TypeVar,
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def is_json_str(s: str) -> bool:
    """Report whether ``s`` parses as a JSON document.

    Args:
        s: The candidate text, handed to :func:`json.loads` unchanged.

    Returns:
        True if ``json.loads`` accepts ``s``. False if parsing raises
        ``ValueError`` (malformed JSON) or ``TypeError`` (an input type that
        ``json.loads`` cannot read, such as ``None``).
    """
    try:
        json.loads(s)
    except (TypeError, ValueError):
        # Either failure means "not JSON"; the parsed value itself is unused.
        return False
    return True
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
T = TypeVar("T", bound=type)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def is_array_of(arr: Sequence[object], tp: T) -> bool:
|
|
34
|
+
"""Check if a value is a numpy array with all elements of a specific type.
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
arr: The sequence to check.
|
|
38
|
+
tp: The expected type for all elements.
|
|
39
|
+
|
|
40
|
+
Returns:
|
|
41
|
+
True if arr is a numpy array and all elements are of type tp.
|
|
42
|
+
"""
|
|
43
|
+
return isinstance(arr, np.ndarray) and all(isinstance(x, tp) for x in arr)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def is_list_of(lst: Sequence[object], tp: T) -> bool:
|
|
47
|
+
"""Check if a value is a list with all elements of a specific type.
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
lst: The sequence to check.
|
|
51
|
+
tp: The expected type for all elements.
|
|
52
|
+
|
|
53
|
+
Returns:
|
|
54
|
+
True if lst is a list and all elements are of type tp.
|
|
55
|
+
"""
|
|
56
|
+
return isinstance(lst, list) and all(isinstance(x, tp) for x in lst)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def is_iterable_of(lst: Sequence[object], tp: T) -> bool:
|
|
60
|
+
"""Check if a value is an iterable with all elements of a specific type.
|
|
61
|
+
|
|
62
|
+
Args:
|
|
63
|
+
lst: The sequence to check.
|
|
64
|
+
tp: The expected type for all elements.
|
|
65
|
+
|
|
66
|
+
Returns:
|
|
67
|
+
True if lst is an iterable and all elements are of type tp.
|
|
68
|
+
"""
|
|
69
|
+
return isinstance(lst, Iterable) and all(isinstance(x, tp) for x in lst)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def is_dict_of(
|
|
73
|
+
d: dict[object, object],
|
|
74
|
+
key_allowed_types: T,
|
|
75
|
+
value_allowed_types: T = (),
|
|
76
|
+
value_list_allowed_types: T = (),
|
|
77
|
+
) -> bool:
|
|
78
|
+
"""Method to check types are valid for dictionary.
|
|
79
|
+
|
|
80
|
+
Args:
|
|
81
|
+
d: Dictionary itself.
|
|
82
|
+
key_allowed_types: All allowed types for keys of dictionary.
|
|
83
|
+
value_allowed_types: All allowed types for values of dictionary.
|
|
84
|
+
value_list_allowed_types: If value is a list, these are the allowed
|
|
85
|
+
types for value list.
|
|
86
|
+
|
|
87
|
+
Returns:
|
|
88
|
+
True if the data types of dictionary match the types specified by the
|
|
89
|
+
arguments, false otherwise.
|
|
90
|
+
"""
|
|
91
|
+
if value_list_allowed_types and not isinstance(
|
|
92
|
+
value_list_allowed_types, tuple
|
|
93
|
+
):
|
|
94
|
+
value_list_allowed_types = (value_list_allowed_types,)
|
|
95
|
+
|
|
96
|
+
return (
|
|
97
|
+
isinstance(d, dict)
|
|
98
|
+
and all(isinstance(k, key_allowed_types) for k in d)
|
|
99
|
+
and all(
|
|
100
|
+
isinstance(v, value_allowed_types)
|
|
101
|
+
or any(is_list_of(v, t) for t in value_list_allowed_types)
|
|
102
|
+
for v in d.values()
|
|
103
|
+
if value_allowed_types or value_list_allowed_types
|
|
104
|
+
)
|
|
105
|
+
)
|
arize/version.py
CHANGED