arize-phoenix 3.19.4__py3-none-any.whl → 3.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic.

@@ -0,0 +1,730 @@
+ import logging
+ import re
+ import uuid
+ from copy import deepcopy
+ from dataclasses import dataclass, fields, replace
+ from enum import Enum
+ from itertools import groupby
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
+
+ import numpy as np
+ import pandas as pd
+ from pandas import DataFrame, Series, Timestamp, read_parquet
+ from pandas.api.types import (
+     is_numeric_dtype,
+ )
+ from typing_extensions import TypeAlias
+
+ from phoenix.config import DATASET_DIR, GENERATED_DATASET_NAME_PREFIX
+ from phoenix.datetime_utils import normalize_timestamps
+ from phoenix.utilities.deprecation import deprecated
+
+ from . import errors as err
+ from .schema import (
+     LLM_SCHEMA_FIELD_NAMES,
+     MULTI_COLUMN_SCHEMA_FIELD_NAMES,
+     SINGLE_COLUMN_SCHEMA_FIELD_NAMES,
+     EmbeddingColumnNames,
+     EmbeddingFeatures,
+     RetrievalEmbeddingColumnNames,
+     Schema,
+     SchemaFieldName,
+     SchemaFieldValue,
+ )
+ from .validation import validate_dataset_inputs
+
+ logger = logging.getLogger(__name__)
+
+ # A schema-like object. Not recommended to use directly.
+ SchemaLike: TypeAlias = Any
+
+
+ class Inferences:
+     """
+     A dataset to use for analysis with Phoenix.
+     Used to construct a Phoenix session via px.launch_app.
+
+     Parameters
+     ----------
+     dataframe : pandas.DataFrame
+         The pandas dataframe containing the data to analyze.
+     schema : phoenix.Schema
+         The schema of the dataset. Maps dataframe columns to the appropriate
+         model inference dimensions (features, predictions, actuals).
+     name : str, optional
+         The name of the dataset. If not provided, a random name is generated.
+         A name is helpful for identifying the dataset in the application.
+
+     Returns
+     -------
+     inferences : Inferences
+         The Inferences object that can be used in a Phoenix session.
+
+     Examples
+     --------
+     >>> primary_dataset = px.Inferences(
+     ...     dataframe=production_dataframe, schema=schema, name="primary"
+     ... )
+     """
+
+     _data_file_name: str = "data.parquet"
+     _schema_file_name: str = "schema.json"
+     _is_persisted: bool = False
+     _is_empty: bool = False
+
+     def __init__(
+         self,
+         dataframe: DataFrame,
+         schema: Union[Schema, SchemaLike],
+         name: Optional[str] = None,
+     ):
+         # allow for schema-like objects
+         if not isinstance(schema, Schema):
+             schema = _get_schema_from_unknown_schema_param(schema)
+         errors = validate_dataset_inputs(
+             dataframe=dataframe,
+             schema=schema,
+         )
+         if errors:
+             raise err.DatasetError(errors)
+         dataframe, schema = _parse_dataframe_and_schema(dataframe, schema)
+         dataframe, schema = _normalize_timestamps(
+             dataframe, schema, default_timestamp=Timestamp.utcnow()
+         )
+         dataframe = _sort_dataframe_rows_by_timestamp(dataframe, schema)
+         self.__dataframe: DataFrame = dataframe
+         self.__schema: Schema = schema
+         self.__name: str = (
+             name if name is not None else f"{GENERATED_DATASET_NAME_PREFIX}{str(uuid.uuid4())}"
+         )
+         self._is_empty = self.dataframe.empty
+         logger.info(f"Dataset: {self.__name} initialized")
+
+     def __repr__(self) -> str:
+         return f'<Dataset "{self.name}">'
+
+     @property
+     def dataframe(self) -> DataFrame:
+         return self.__dataframe
+
+     @property
+     def schema(self) -> "Schema":
+         return self.__schema
+
+     @property
+     def name(self) -> str:
+         return self.__name
+
+     @classmethod
+     def from_name(cls, name: str) -> "Inferences":
+         """Retrieves a dataset by name from the file system."""
+         directory = DATASET_DIR / name
+         df = read_parquet(directory / cls._data_file_name)
+         with open(directory / cls._schema_file_name) as schema_file:
+             schema_json = schema_file.read()
+         schema = Schema.from_json(schema_json)
+         return cls(df, schema, name)
+
+     def to_disc(self) -> None:
+         """Writes the data and schema to disc."""
+         directory = DATASET_DIR / self.name
+         directory.mkdir(parents=True, exist_ok=True)
+         self.dataframe.to_parquet(
+             directory / self._data_file_name,
+             allow_truncated_timestamps=True,
+             coerce_timestamps="ms",
+         )
+         schema_json_data = self.schema.to_json()
+         with open(directory / self._schema_file_name, "w+") as schema_file:
+             schema_file.write(schema_json_data)
+
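The two methods above form a simple persistence round trip: `to_disc` writes the dataframe to Parquet and the schema to JSON under `DATASET_DIR/<name>`, and `from_name` reads them back. A minimal sketch of the round trip, assuming the top-level exports `px.Inferences` and `px.Schema` and a writable `DATASET_DIR`; the dataframe contents and the name "demo" are illustrative only:

    import pandas as pd
    import phoenix as px

    df = pd.DataFrame({"prediction_id": ["a", "b"], "score": [0.1, 0.9]})
    schema = px.Schema(prediction_id_column_name="prediction_id")
    inferences = px.Inferences(df, schema, name="demo")
    inferences.to_disc()                        # data.parquet + schema.json under DATASET_DIR/demo
    restored = px.Inferences.from_name("demo")  # reads both files back
    assert restored.name == "demo"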
+     @classmethod
+     @deprecated("Inferences.from_open_inference is deprecated and will be removed.")
+     def from_open_inference(cls, dataframe: DataFrame) -> "Inferences":
+         schema = Schema()
+         column_renaming: Dict[str, str] = {}
+         for group_name, group in groupby(
+             sorted(
+                 map(_parse_open_inference_column_name, dataframe.columns),
+                 key=lambda column: column.name,
+             ),
+             key=lambda column: column.name,
+         ):
+             open_inference_columns = list(group)
+             if group_name == "":
+                 column_names_by_category = {
+                     column.category: column.full_name for column in open_inference_columns
+                 }
+                 schema = replace(
+                     schema,
+                     prediction_id_column_name=column_names_by_category.get(
+                         OpenInferenceCategory.id
+                     ),
+                     timestamp_column_name=column_names_by_category.get(
+                         OpenInferenceCategory.timestamp
+                     ),
+                 )
+                 continue
+             column_names_by_specifier = {
+                 column.specifier: column.full_name for column in open_inference_columns
+             }
+             if group_name == "response":
+                 response_vector_column_name = column_names_by_specifier.get(
+                     OpenInferenceSpecifier.embedding
+                 )
+                 if response_vector_column_name is not None:
+                     column_renaming[response_vector_column_name] = "response"
+                     schema = replace(
+                         schema,
+                         response_column_names=EmbeddingColumnNames(
+                             vector_column_name=column_renaming[response_vector_column_name],
+                             raw_data_column_name=column_names_by_specifier.get(
+                                 OpenInferenceSpecifier.default
+                             ),
+                         ),
+                     )
+                 else:
+                     response_text_column_name = column_names_by_specifier.get(
+                         OpenInferenceSpecifier.default
+                     )
+                     if response_text_column_name is None:
+                         raise ValueError(
+                             "invalid OpenInference format: missing text column for response"
+                         )
+                     column_renaming[response_text_column_name] = "response"
+                     schema = replace(
+                         schema,
+                         response_column_names=column_renaming[response_text_column_name],
+                     )
+             elif group_name == "prompt":
+                 prompt_vector_column_name = column_names_by_specifier.get(
+                     OpenInferenceSpecifier.embedding
+                 )
+                 if prompt_vector_column_name is None:
+                     raise ValueError(
+                         "invalid OpenInference format: missing embedding vector column for prompt"
+                     )
+                 column_renaming[prompt_vector_column_name] = "prompt"
+                 schema = replace(
+                     schema,
+                     prompt_column_names=RetrievalEmbeddingColumnNames(
+                         vector_column_name=column_renaming[prompt_vector_column_name],
+                         raw_data_column_name=column_names_by_specifier.get(
+                             OpenInferenceSpecifier.default
+                         ),
+                         context_retrieval_ids_column_name=column_names_by_specifier.get(
+                             OpenInferenceSpecifier.retrieved_document_ids
+                         ),
+                         context_retrieval_scores_column_name=column_names_by_specifier.get(
+                             OpenInferenceSpecifier.retrieved_document_scores
+                         ),
+                     ),
+                 )
+             elif OpenInferenceSpecifier.embedding in column_names_by_specifier:
+                 vector_column_name = column_names_by_specifier[OpenInferenceSpecifier.embedding]
+                 column_renaming[vector_column_name] = group_name
+                 embedding_feature_column_names = schema.embedding_feature_column_names or {}
+                 embedding_feature_column_names.update(
+                     {
+                         group_name: EmbeddingColumnNames(
+                             vector_column_name=column_renaming[vector_column_name],
+                             raw_data_column_name=column_names_by_specifier.get(
+                                 OpenInferenceSpecifier.raw_data
+                             ),
+                             link_to_data_column_name=column_names_by_specifier.get(
+                                 OpenInferenceSpecifier.link_to_data
+                             ),
+                         )
+                     }
+                 )
+                 schema = replace(
+                     schema,
+                     embedding_feature_column_names=embedding_feature_column_names,
+                 )
+             elif len(open_inference_columns) == 1:
+                 open_inference_column = open_inference_columns[0]
+                 raw_column_name = open_inference_column.full_name
+                 column_renaming[raw_column_name] = open_inference_column.name
+                 if open_inference_column.category is OpenInferenceCategory.feature:
+                     schema = replace(
+                         schema,
+                         feature_column_names=(
+                             (schema.feature_column_names or []) + [column_renaming[raw_column_name]]
+                         ),
+                     )
+                 elif open_inference_column.category is OpenInferenceCategory.tag:
+                     schema = replace(
+                         schema,
+                         tag_column_names=(
+                             (schema.tag_column_names or []) + [column_renaming[raw_column_name]]
+                         ),
+                     )
+                 elif open_inference_column.category is OpenInferenceCategory.prediction:
+                     if open_inference_column.specifier is OpenInferenceSpecifier.score:
+                         schema = replace(
+                             schema,
+                             prediction_score_column_name=column_renaming[raw_column_name],
+                         )
+                     if open_inference_column.specifier is OpenInferenceSpecifier.label:
+                         schema = replace(
+                             schema,
+                             prediction_label_column_name=column_renaming[raw_column_name],
+                         )
+                 elif open_inference_column.category is OpenInferenceCategory.actual:
+                     if open_inference_column.specifier is OpenInferenceSpecifier.score:
+                         schema = replace(
+                             schema,
+                             actual_score_column_name=column_renaming[raw_column_name],
+                         )
+                     if open_inference_column.specifier is OpenInferenceSpecifier.label:
+                         schema = replace(
+                             schema,
+                             actual_label_column_name=column_renaming[raw_column_name],
+                         )
+             else:
+                 raise ValueError(f"invalid OpenInference format: duplicated name `{group_name}`")
+
+         return cls(
+             dataframe.rename(
+                 column_renaming,
+                 axis=1,
+                 copy=False,
+             ),
+             schema,
+         )
+
+
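Construction of an `Inferences` object runs the full pipeline seen in `__init__`: inputs are validated, excluded columns are dropped, unreferenced columns are discovered as features, timestamps are normalized to UTC, and rows are sorted by timestamp. A minimal end-to-end sketch, assuming the documented top-level API (`px.Schema`, `px.Inferences`, `px.launch_app`); the data and column names are illustrative:

    import pandas as pd
    import phoenix as px

    df = pd.DataFrame(
        {
            "timestamp": pd.to_datetime(["2024-01-01", "2024-01-02"], utc=True),
            "prediction": ["fraud", "not_fraud"],
            "actual": ["not_fraud", "not_fraud"],
            "merchant_type": ["retail", "grocery"],  # unreferenced -> discovered as a feature
        }
    )
    schema = px.Schema(
        timestamp_column_name="timestamp",
        prediction_label_column_name="prediction",
        actual_label_column_name="actual",
    )
    inferences = px.Inferences(dataframe=df, schema=schema, name="production")
    session = px.launch_app(inferences)  # launches the Phoenix app with this dataset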
+ class OpenInferenceCategory(Enum):
+     id = "id"
+     timestamp = "timestamp"
+     feature = "feature"
+     tag = "tag"
+     prediction = "prediction"
+     actual = "actual"
+
+
+ class OpenInferenceSpecifier(Enum):
+     default = ""
+     score = "score"
+     label = "label"
+     embedding = "embedding"
+     raw_data = "raw_data"
+     link_to_data = "link_to_data"
+     retrieved_document_ids = "retrieved_document_ids"
+     retrieved_document_scores = "retrieved_document_scores"
+
+
+ @dataclass(frozen=True)
+ class _OpenInferenceColumnName:
+     full_name: str
+     category: OpenInferenceCategory
+     data_type: str
+     specifier: OpenInferenceSpecifier = OpenInferenceSpecifier.default
+     name: str = ""
+
+
+ def _parse_open_inference_column_name(column_name: str) -> _OpenInferenceColumnName:
+     pattern = (
+         r"^:(?P<category>\w+)\.(?P<data_type>\[\w+\]|\w+)(\.(?P<specifier>\w+))?:(?P<name>.*)?$"
+     )
+     if match := re.match(pattern, column_name):
+         extract = match.groupdict(default="")
+         return _OpenInferenceColumnName(
+             full_name=column_name,
+             category=OpenInferenceCategory(extract.get("category", "").lower()),
+             data_type=extract.get("data_type", "").lower(),
+             specifier=OpenInferenceSpecifier(extract.get("specifier", "").lower()),
+             name=extract.get("name", ""),
+         )
+     raise ValueError(f"Invalid format for column name: {column_name}")
+
+
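The regex in `_parse_open_inference_column_name` encodes the flat-file column convention `:category.data_type[.specifier]:name`. A small sketch of how column names decompose; the import path of this private helper is an assumption, and the column names are illustrative:

    # Assumed module path for the private parser; for illustration only.
    from phoenix.inferences.inferences import (
        OpenInferenceCategory,
        OpenInferenceSpecifier,
        _parse_open_inference_column_name,
    )

    col = _parse_open_inference_column_name(":feature.[float].embedding:description")
    assert col.category is OpenInferenceCategory.feature
    assert col.specifier is OpenInferenceSpecifier.embedding
    assert col.name == "description"

    col = _parse_open_inference_column_name(":prediction.str.label:")
    assert col.category is OpenInferenceCategory.prediction
    assert col.specifier is OpenInferenceSpecifier.label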
+ def _parse_dataframe_and_schema(dataframe: DataFrame, schema: Schema) -> Tuple[DataFrame, Schema]:
+     """
+     Parses a dataframe according to a schema, infers feature column names when
+     they are not explicitly provided, and removes excluded column names from
+     both the dataframe and the schema.
+
+     Removes column names in `schema.excluded_column_names` from the input dataframe and schema. To
+     remove an embedding feature and all associated columns, add the name of the embedding feature to
+     `schema.excluded_column_names` rather than the associated column names. If
+     `schema.feature_column_names` is `None`, automatically discovers features by adding all column
+     names present in the dataframe but not included in any other schema fields.
+     """
+
+     unseen_excluded_column_names: Set[str] = (
+         set(schema.excluded_column_names) if schema.excluded_column_names is not None else set()
+     )
+     unseen_column_names: Set[str] = set(dataframe.columns.to_list())
+     column_name_to_include: Dict[str, bool] = {}
+     schema_patch: Dict[SchemaFieldName, SchemaFieldValue] = {}
+
+     for schema_field_name in SINGLE_COLUMN_SCHEMA_FIELD_NAMES:
+         _check_single_column_schema_field_for_excluded_columns(
+             schema,
+             schema_field_name,
+             unseen_excluded_column_names,
+             schema_patch,
+             column_name_to_include,
+             unseen_column_names,
+         )
+
+     for schema_field_name in MULTI_COLUMN_SCHEMA_FIELD_NAMES:
+         _check_multi_column_schema_field_for_excluded_columns(
+             schema,
+             schema_field_name,
+             unseen_excluded_column_names,
+             schema_patch,
+             column_name_to_include,
+             unseen_column_names,
+         )
+
+     if schema.embedding_feature_column_names:
+         _check_embedding_features_schema_field_for_excluded_columns(
+             schema.embedding_feature_column_names,
+             unseen_excluded_column_names,
+             schema_patch,
+             column_name_to_include,
+             unseen_column_names,
+         )
+
+     for llm_schema_field_name in LLM_SCHEMA_FIELD_NAMES:
+         embedding_column_name_mapping = getattr(schema, llm_schema_field_name)
+         if isinstance(embedding_column_name_mapping, EmbeddingColumnNames):
+             _check_embedding_column_names_for_excluded_columns(
+                 embedding_column_name_mapping,
+                 column_name_to_include,
+                 unseen_column_names,
+             )
+
+     if not schema.feature_column_names and unseen_column_names:
+         _discover_feature_columns(
+             dataframe,
+             unseen_excluded_column_names,
+             schema_patch,
+             column_name_to_include,
+             unseen_column_names,
+         )
+
+     if unseen_excluded_column_names:
+         logger.warning(
+             "The following columns and embedding features were excluded in the schema but were "
+             "not found in the dataframe: {}".format(", ".join(unseen_excluded_column_names))
+         )
+
+     parsed_dataframe, parsed_schema = _create_and_normalize_dataframe_and_schema(
+         dataframe, schema, schema_patch, column_name_to_include
+     )
+
+     return parsed_dataframe, parsed_schema
+
+
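In practice, the exclusion and discovery rules above mean a schema only needs to name the special columns. A minimal sketch, assuming `px.Schema` and `px.Inferences`; the data and column names are illustrative:

    import pandas as pd
    import phoenix as px

    df = pd.DataFrame(
        {
            "prediction": ["yes", "no"],
            "age": [34, 51],                  # unreferenced -> becomes a feature
            "ssn": ["redacted", "redacted"],  # sensitive column, excluded below
        }
    )
    schema = px.Schema(
        prediction_label_column_name="prediction",
        excluded_column_names=["ssn"],
    )
    inferences = px.Inferences(df, schema)
    assert inferences.schema.feature_column_names == ["age"]
    assert "ssn" not in inferences.dataframe.columns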
+ def _check_single_column_schema_field_for_excluded_columns(
+     schema: Schema,
+     schema_field_name: str,
+     unseen_excluded_column_names: Set[str],
+     schema_patch: Dict[SchemaFieldName, SchemaFieldValue],
+     column_name_to_include: Dict[str, bool],
+     unseen_column_names: Set[str],
+ ) -> None:
+     """
+     Checks single-column schema fields for excluded column names.
+     """
+     column_name: str = getattr(schema, schema_field_name)
+     include_column: bool = column_name not in unseen_excluded_column_names
+     column_name_to_include[column_name] = include_column
+     if not include_column:
+         schema_patch[schema_field_name] = None
+         unseen_excluded_column_names.discard(column_name)
+         logger.debug(f"excluded {schema_field_name}: {column_name}")
+     unseen_column_names.discard(column_name)
+
+
+ def _check_multi_column_schema_field_for_excluded_columns(
+     schema: Schema,
+     schema_field_name: str,
+     unseen_excluded_column_names: Set[str],
+     schema_patch: Dict[SchemaFieldName, SchemaFieldValue],
+     column_name_to_include: Dict[str, bool],
+     unseen_column_names: Set[str],
+ ) -> None:
+     """
+     Checks multi-column schema fields for excluded column names.
+     """
+     column_names: Optional[List[str]] = getattr(schema, schema_field_name)
+     if column_names:
+         included_column_names: List[str] = []
+         excluded_column_names: List[str] = []
+         for column_name in column_names:
+             is_included_column = column_name not in unseen_excluded_column_names
+             column_name_to_include[column_name] = is_included_column
+             if is_included_column:
+                 included_column_names.append(column_name)
+             else:
+                 excluded_column_names.append(column_name)
+                 unseen_excluded_column_names.discard(column_name)
+                 logger.debug(f"excluded {schema_field_name}: {column_name}")
+             unseen_column_names.discard(column_name)
+         schema_patch[schema_field_name] = included_column_names if included_column_names else None
+
+
+ def _check_embedding_features_schema_field_for_excluded_columns(
+     embedding_features: EmbeddingFeatures,
+     unseen_excluded_column_names: Set[str],
+     schema_patch: Dict[SchemaFieldName, SchemaFieldValue],
+     column_name_to_include: Dict[str, bool],
+     unseen_column_names: Set[str],
+ ) -> None:
+     """
+     Checks embedding features for excluded column names.
+     """
+     included_embedding_features: EmbeddingFeatures = {}
+     for (
+         embedding_feature_name,
+         embedding_column_name_mapping,
+     ) in embedding_features.items():
+         include_embedding_feature = embedding_feature_name not in unseen_excluded_column_names
+         if include_embedding_feature:
+             included_embedding_features[embedding_feature_name] = deepcopy(
+                 embedding_column_name_mapping
+             )
+         else:
+             unseen_excluded_column_names.discard(embedding_feature_name)
+
+         for embedding_field in fields(embedding_column_name_mapping):
+             column_name: Optional[str] = getattr(
+                 embedding_column_name_mapping, embedding_field.name
+             )
+             if column_name is not None:
+                 column_name_to_include[column_name] = include_embedding_feature
+                 if (
+                     column_name != embedding_feature_name
+                     and column_name in unseen_excluded_column_names
+                 ):
+                     logger.warning(
+                         f"Excluding embedding feature columns such as "
+                         f'"{column_name}" has no effect; instead exclude the '
+                         f'corresponding embedding feature name "{embedding_feature_name}".'
+                     )
+                     unseen_excluded_column_names.discard(column_name)
+                 unseen_column_names.discard(column_name)
+     schema_patch["embedding_feature_column_names"] = (
+         included_embedding_features if included_embedding_features else None
+     )
+
+
+ def _check_embedding_column_names_for_excluded_columns(
+     embedding_column_name_mapping: EmbeddingColumnNames,
+     column_name_to_include: Dict[str, bool],
+     unseen_column_names: Set[str],
+ ) -> None:
+     """
+     Checks embedding column names for excluded column names.
+     """
+     for embedding_field in fields(embedding_column_name_mapping):
+         column_name: Optional[str] = getattr(embedding_column_name_mapping, embedding_field.name)
+         if column_name is not None:
+             column_name_to_include[column_name] = True
+             unseen_column_names.discard(column_name)
+
+
+ def _discover_feature_columns(
+     dataframe: DataFrame,
+     unseen_excluded_column_names: Set[str],
+     schema_patch: Dict[SchemaFieldName, SchemaFieldValue],
+     column_name_to_include: Dict[str, bool],
+     unseen_column_names: Set[str],
+ ) -> None:
+     """
+     Adds unseen and un-excluded columns as features, with the exception of
+     "prediction_id", which is reserved.
+     """
+     discovered_feature_column_names = []
+     for column_name in unseen_column_names:
+         if column_name not in unseen_excluded_column_names and column_name != "prediction_id":
+             discovered_feature_column_names.append(column_name)
+             column_name_to_include[column_name] = True
+         else:
+             unseen_excluded_column_names.discard(column_name)
+             logger.debug(f"excluded feature: {column_name}")
+     original_column_positions: List[int] = dataframe.columns.get_indexer(
+         discovered_feature_column_names
+     )  # type: ignore
+     feature_column_name_to_position: Dict[str, int] = dict(
+         zip(discovered_feature_column_names, original_column_positions)
+     )
+     discovered_feature_column_names.sort(key=lambda col: feature_column_name_to_position[col])
+     schema_patch["feature_column_names"] = discovered_feature_column_names
+     logger.debug(
+         "Discovered feature column names: {}".format(", ".join(discovered_feature_column_names))
+     )
+
+
+ def _create_and_normalize_dataframe_and_schema(
+     dataframe: DataFrame,
+     schema: Schema,
+     schema_patch: Dict[SchemaFieldName, SchemaFieldValue],
+     column_name_to_include: Dict[str, bool],
+ ) -> Tuple[DataFrame, Schema]:
+     """
+     Creates new dataframe and schema objects to reflect excluded column names
+     and discovered features. This also normalizes dataframe columns to ensure a
+     standard set of columns (i.e. timestamp and prediction_id) and datatypes for
+     those columns.
+     """
+     included_column_names: List[str] = []
+     for column_name in dataframe.columns:
+         if column_name_to_include.get(str(column_name), False):
+             included_column_names.append(str(column_name))
+     parsed_dataframe = dataframe[included_column_names].copy()
+     parsed_schema = replace(schema, excluded_column_names=None, **schema_patch)  # type: ignore
+     pred_id_col_name = parsed_schema.prediction_id_column_name
+     if pred_id_col_name is None:
+         parsed_schema = replace(parsed_schema, prediction_id_column_name="prediction_id")
+         parsed_dataframe["prediction_id"] = _add_prediction_id(len(parsed_dataframe))
+     elif is_numeric_dtype(parsed_dataframe.dtypes[pred_id_col_name]):
+         parsed_dataframe[pred_id_col_name] = parsed_dataframe[pred_id_col_name].astype(str)
+     for embedding in (
+         [parsed_schema.prompt_column_names, parsed_schema.response_column_names]
+         + list(parsed_schema.embedding_feature_column_names.values())
+         if parsed_schema.embedding_feature_column_names is not None
+         else []
+     ):
+         if not isinstance(embedding, EmbeddingColumnNames):
+             continue
+         vector_column_name = embedding.vector_column_name
+         if vector_column_name not in parsed_dataframe.columns:
+             continue
+         parsed_dataframe.loc[:, vector_column_name] = _coerce_vectors_as_arrays_if_necessary(
+             parsed_dataframe.loc[:, vector_column_name],
+             vector_column_name,
+         )
+     return parsed_dataframe, parsed_schema
+
+
+ def _coerce_vectors_as_arrays_if_necessary(
+     series: "pd.Series[Any]",
+     column_name: str,
+ ) -> "pd.Series[Any]":
+     not_na = ~series.isna()
+     if not_na.sum() == 0:
+         return series
+     if invalid_types := set(map(type, series.loc[not_na])) - {np.ndarray}:
+         logger.warning(
+             f"converting items in column `{column_name}` to numpy.ndarray, "
+             f"because they have the following "
+             f"type{'s' if len(invalid_types) > 1 else ''}: "
+             f"{', '.join(map(lambda t: t.__name__, invalid_types))}"
+         )
+         return series.mask(not_na, series.loc[not_na].apply(np.array))
+     return series
+
+
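`_coerce_vectors_as_arrays_if_necessary` exists because embedding vectors often arrive as Python lists; every non-null entry is coerced to `numpy.ndarray` with a warning. The core pandas idiom it uses can be illustrated standalone; the sample data is made up:

    import numpy as np
    import pandas as pd

    vectors = pd.Series([[1.0, 2.0], None, [3.0, 4.0]])
    not_na = ~vectors.isna()
    # Same mask/apply idiom as above: replace non-null entries with ndarray versions.
    coerced = vectors.mask(not_na, vectors.loc[not_na].apply(np.array))
    assert all(isinstance(v, np.ndarray) for v in coerced.loc[not_na])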
+ def _sort_dataframe_rows_by_timestamp(dataframe: DataFrame, schema: Schema) -> DataFrame:
+     """
+     Sorts dataframe rows by timestamp.
+     """
+     timestamp_column_name = schema.timestamp_column_name
+     if timestamp_column_name is None:
+         raise ValueError("Schema must specify a timestamp column name.")
+     dataframe.set_index(timestamp_column_name, drop=False, inplace=True)
+     dataframe.sort_index(inplace=True)
+     return dataframe
+
+
+ def _normalize_timestamps(
+     dataframe: DataFrame,
+     schema: Schema,
+     default_timestamp: Timestamp,
+ ) -> Tuple[DataFrame, Schema]:
+     """
+     Ensures that the dataframe has a timestamp column and the schema has a timestamp field. If the
+     input dataframe contains a column of Unix timestamps, datetime values, or ISO 8601 timestamp
+     strings, it is converted to a UTC timezone-aware timestamp column. If the input dataframe and
+     schema do not contain timestamps, the default timestamp is used. If a timestamp is
+     timezone-naive, it is localized to the local timezone and then converted to UTC.
+     """
+     timestamp_column: Series[Timestamp]
+     if (timestamp_column_name := schema.timestamp_column_name) is None:
+         timestamp_column_name = "timestamp"
+         schema = replace(schema, timestamp_column_name=timestamp_column_name)
+         timestamp_column = (
+             Series([default_timestamp] * len(dataframe), index=dataframe.index)
+             if len(dataframe)
+             else Series([default_timestamp]).iloc[:0].set_axis(dataframe.index, axis=0)
+         )
+     else:
+         timestamp_column = normalize_timestamps(
+             dataframe[timestamp_column_name],
+         )
+     dataframe[timestamp_column_name] = timestamp_column
+     return dataframe, schema
+
+
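A consequence of `_normalize_timestamps` that is visible at the public level: when neither the dataframe nor the schema carries timestamps, every row receives the same default `Timestamp.utcnow()` value in a synthesized, timezone-aware `timestamp` column. A sketch assuming `px.Schema` and `px.Inferences`:

    import pandas as pd
    import phoenix as px

    df = pd.DataFrame({"prediction": ["a", "b"]})
    inferences = px.Inferences(df, px.Schema(prediction_label_column_name="prediction"))
    ts = inferences.dataframe["timestamp"]
    assert ts.dt.tz is not None  # timezone-aware (UTC)
    assert ts.nunique() == 1     # one shared default timestamp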
+ def _get_schema_from_unknown_schema_param(schemaLike: SchemaLike) -> Schema:
+     """
+     Compatibility function for converting arize.utils.types.Schema to phoenix.inferences.Schema.
+     """
+     try:
+         from arize.utils.types import (
+             EmbeddingColumnNames as ArizeEmbeddingColumnNames,  # fmt: off type: ignore
+         )
+         from arize.utils.types import Schema as ArizeSchema
+
+         if not isinstance(schemaLike, ArizeSchema):
+             raise ValueError("Unknown schema passed to Dataset. Please pass a phoenix Schema")
+
+         embedding_feature_column_names: Dict[str, EmbeddingColumnNames] = {}
+         if schemaLike.embedding_feature_column_names is not None:
+             for (
+                 embedding_name,
+                 arize_embedding_feature_column_names,
+             ) in schemaLike.embedding_feature_column_names.items():
+                 if isinstance(arize_embedding_feature_column_names, ArizeEmbeddingColumnNames):
+                     embedding_feature_column_names[embedding_name] = EmbeddingColumnNames(
+                         vector_column_name=arize_embedding_feature_column_names.vector_column_name,
+                         link_to_data_column_name=arize_embedding_feature_column_names.link_to_data_column_name,
+                         raw_data_column_name=arize_embedding_feature_column_names.data_column_name,
+                     )
+         prompt_column_names: Optional[EmbeddingColumnNames] = None
+         if schemaLike.prompt_column_names is not None and isinstance(
+             schemaLike.prompt_column_names, ArizeEmbeddingColumnNames
+         ):
+             prompt_column_names = EmbeddingColumnNames(
+                 vector_column_name=schemaLike.prompt_column_names.vector_column_name,
+                 raw_data_column_name=schemaLike.prompt_column_names.data_column_name,
+                 link_to_data_column_name=schemaLike.prompt_column_names.link_to_data_column_name,
+             )
+         response_column_names: Optional[EmbeddingColumnNames] = None
+         if schemaLike.response_column_names is not None and isinstance(
+             schemaLike.response_column_names, ArizeEmbeddingColumnNames
+         ):
+             response_column_names = EmbeddingColumnNames(
+                 vector_column_name=schemaLike.response_column_names.vector_column_name,
+                 raw_data_column_name=schemaLike.response_column_names.data_column_name,
+                 link_to_data_column_name=schemaLike.response_column_names.link_to_data_column_name,
+             )
+         return Schema(
+             feature_column_names=schemaLike.feature_column_names,
+             tag_column_names=schemaLike.tag_column_names,
+             prediction_label_column_name=schemaLike.prediction_label_column_name,
+             actual_label_column_name=schemaLike.actual_label_column_name,
+             prediction_id_column_name=schemaLike.prediction_id_column_name,
+             timestamp_column_name=schemaLike.timestamp_column_name,
+             embedding_feature_column_names=embedding_feature_column_names,
+             prompt_column_names=prompt_column_names,
+             response_column_names=response_column_names,
+         )
+     except Exception:
+         raise ValueError(
+             """Unsupported Arize Schema. Please pass a phoenix Schema or update
+             to the latest version of Arize."""
+         )
+
+
+ def _add_prediction_id(num_rows: int) -> List[str]:
+     return [str(uuid.uuid4()) for _ in range(num_rows)]
+
+
+ # A dataset with no data. Useful for stubs.
+ EMPTY_INFERENCES = Inferences(pd.DataFrame(), schema=Schema())
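`EMPTY_INFERENCES` gives callers a cheap placeholder object: its dataframe is empty, but it has passed the full construction pipeline. A quick sketch; the module path is an assumption:

    from phoenix.inferences.inferences import EMPTY_INFERENCES  # assumed import path

    assert EMPTY_INFERENCES.dataframe.empty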