arize-phoenix 0.0.29rc8__tar.gz → 0.0.31__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic.

Files changed (100)
  1. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/PKG-INFO +4 -1
  2. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/pyproject.toml +8 -0
  3. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/__init__.py +3 -1
  4. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/datasets/dataset.py +204 -1
  5. arize_phoenix-0.0.31/src/phoenix/experimental/evals/retrievals.py +91 -0
  6. arize_phoenix-0.0.31/src/phoenix/server/api/types/__init__.py +0 -0
  7. arize_phoenix-0.0.31/src/phoenix/server/static/index.js +6235 -0
  8. arize_phoenix-0.0.31/src/phoenix/session/__init__.py +0 -0
  9. arize_phoenix-0.0.29rc8/src/phoenix/server/static/index.js +0 -6198
  10. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/.gitignore +0 -0
  11. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/IP_NOTICE +0 -0
  12. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/LICENSE +0 -0
  13. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/README.md +0 -0
  14. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/config.py +0 -0
  15. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/core/__init__.py +0 -0
  16. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/core/dimension_data_type.py +0 -0
  17. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/core/dimension_type.py +0 -0
  18. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/core/embedding_dimension.py +0 -0
  19. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/core/model.py +0 -0
  20. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/core/model_schema.py +0 -0
  21. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/core/model_schema_adapter.py +0 -0
  22. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/datasets/__init__.py +0 -0
  23. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/datasets/errors.py +0 -0
  24. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/datasets/fixtures.py +0 -0
  25. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/datasets/schema.py +0 -0
  26. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/datasets/validation.py +0 -0
  27. {arize_phoenix-0.0.29rc8/src/phoenix/server → arize_phoenix-0.0.31/src/phoenix/experimental}/__init__.py +0 -0
  28. {arize_phoenix-0.0.29rc8/src/phoenix/server/api → arize_phoenix-0.0.31/src/phoenix/experimental/evals}/__init__.py +0 -0
  29. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/metrics/README.md +0 -0
  30. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/metrics/__init__.py +0 -0
  31. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/metrics/binning.py +0 -0
  32. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/metrics/metrics.py +0 -0
  33. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/metrics/mixins.py +0 -0
  34. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/metrics/timeseries.py +0 -0
  35. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/metrics/wrappers.py +0 -0
  36. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/pointcloud/__init__.py +0 -0
  37. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/pointcloud/clustering.py +0 -0
  38. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/pointcloud/pointcloud.py +0 -0
  39. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/pointcloud/projectors.py +0 -0
  40. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/py.typed +0 -0
  41. {arize_phoenix-0.0.29rc8/src/phoenix/server/api/input_types → arize_phoenix-0.0.31/src/phoenix/server}/__init__.py +0 -0
  42. {arize_phoenix-0.0.29rc8/src/phoenix/server/api/types → arize_phoenix-0.0.31/src/phoenix/server/api}/__init__.py +0 -0
  43. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/context.py +0 -0
  44. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/helpers.py +0 -0
  45. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/input_types/ClusterInput.py +0 -0
  46. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/input_types/Coordinates.py +0 -0
  47. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/input_types/DataQualityMetricInput.py +0 -0
  48. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/input_types/DimensionFilter.py +0 -0
  49. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/input_types/DimensionInput.py +0 -0
  50. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/input_types/Granularity.py +0 -0
  51. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/input_types/PerformanceMetricInput.py +0 -0
  52. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/input_types/TimeRange.py +0 -0
  53. {arize_phoenix-0.0.29rc8/src/phoenix/session → arize_phoenix-0.0.31/src/phoenix/server/api/input_types}/__init__.py +0 -0
  54. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/interceptor.py +0 -0
  55. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/schema.py +0 -0
  56. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/Cluster.py +0 -0
  57. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/DataQualityMetric.py +0 -0
  58. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/Dataset.py +0 -0
  59. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/DatasetRole.py +0 -0
  60. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/DatasetValues.py +0 -0
  61. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/Dimension.py +0 -0
  62. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/DimensionDataType.py +0 -0
  63. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/DimensionShape.py +0 -0
  64. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/DimensionType.py +0 -0
  65. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/DimensionWithValue.py +0 -0
  66. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/EmbeddingDimension.py +0 -0
  67. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/EmbeddingMetadata.py +0 -0
  68. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/Event.py +0 -0
  69. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/EventMetadata.py +0 -0
  70. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/ExportEventsMutation.py +0 -0
  71. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/ExportedFile.py +0 -0
  72. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/Model.py +0 -0
  73. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/NumericRange.py +0 -0
  74. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/PerformanceMetric.py +0 -0
  75. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/PromptResponse.py +0 -0
  76. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/Retrieval.py +0 -0
  77. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/ScalarDriftMetricEnum.py +0 -0
  78. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/Segments.py +0 -0
  79. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/TimeSeries.py +0 -0
  80. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/UMAPPoints.py +0 -0
  81. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/VectorDriftMetricEnum.py +0 -0
  82. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/node.py +0 -0
  83. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/api/types/pagination.py +0 -0
  84. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/app.py +0 -0
  85. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/main.py +0 -0
  86. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/static/apple-touch-icon-114x114.png +0 -0
  87. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/static/apple-touch-icon-120x120.png +0 -0
  88. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/static/apple-touch-icon-144x144.png +0 -0
  89. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/static/apple-touch-icon-152x152.png +0 -0
  90. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/static/apple-touch-icon-180x180.png +0 -0
  91. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/static/apple-touch-icon-72x72.png +0 -0
  92. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/static/apple-touch-icon-76x76.png +0 -0
  93. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/static/apple-touch-icon.png +0 -0
  94. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/static/favicon.ico +0 -0
  95. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/static/index.css +0 -0
  96. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/static/index.html +0 -0
  97. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/static/modernizr.js +0 -0
  98. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/server/thread_server.py +0 -0
  99. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/services.py +0 -0
  100. {arize_phoenix-0.0.29rc8 → arize_phoenix-0.0.31}/src/phoenix/session/session.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: arize-phoenix
3
- Version: 0.0.29rc8
3
+ Version: 0.0.31
4
4
  Summary: ML Observability in your notebook
5
5
  Project-URL: Documentation, https://docs.arize.com/phoenix/
6
6
  Project-URL: Issues, https://github.com/Arize-ai/phoenix/issues
@@ -41,6 +41,9 @@ Requires-Dist: pytest; extra == 'dev'
41
41
  Requires-Dist: pytest-cov; extra == 'dev'
42
42
  Requires-Dist: pytest-lazy-fixture; extra == 'dev'
43
43
  Requires-Dist: strawberry-graphql[debug-server]==0.178.0; extra == 'dev'
44
+ Provides-Extra: experimental
45
+ Requires-Dist: openai; extra == 'experimental'
46
+ Requires-Dist: tenacity; extra == 'experimental'
44
47
  Description-Content-Type: text/markdown
45
48
 
46
49
  <p align="center">
@@ -51,6 +51,10 @@ dev = [
51
51
  "pre-commit",
52
52
  "arize[AutoEmbeddings, LLM_Evaluation]",
53
53
  ]
54
+ experimental = [
55
+ "openai",
56
+ "tenacity",
57
+ ]
54
58
 
55
59
  [project.urls]
56
60
  Documentation = "https://docs.arize.com/phoenix/"
@@ -90,6 +94,8 @@ dependencies = [
90
94
  "pytest-cov",
91
95
  "pytest-lazy-fixture",
92
96
  "arize",
97
+ "openai",
98
+ "tenacity",
93
99
  ]
94
100
 
95
101
  [tool.hatch.envs.type]
@@ -98,6 +104,7 @@ dependencies = [
98
104
  "pandas-stubs",
99
105
  "pytest",
100
106
  "types-psutil",
107
+ "tenacity",
101
108
  ]
102
109
 
103
110
  [tool.hatch.envs.style]
@@ -241,6 +248,7 @@ module = [
241
248
  "arize.*",
242
249
  "portpicker",
243
250
  "wrapt",
251
+ "openai",
244
252
  ]
245
253
  ignore_missing_imports = true
246
254
 
@@ -2,8 +2,9 @@ from .datasets.dataset import Dataset
2
2
  from .datasets.fixtures import ExampleDatasets, load_example
3
3
  from .datasets.schema import EmbeddingColumnNames, RetrievalEmbeddingColumnNames, Schema
4
4
  from .session.session import Session, active_session, close_app, launch_app
5
+ from .trace.fixtures import load_example_traces
5
6
 
6
- __version__ = "0.0.29rc8"
7
+ __version__ = "0.0.31"
7
8
 
8
9
  # module level doc-string
9
10
  __doc__ = """
@@ -32,4 +33,5 @@ __all__ = [
32
33
  "close_app",
33
34
  "launch_app",
34
35
  "Session",
36
+ "load_example_traces",
35
37
  ]
@@ -1,7 +1,10 @@
1
1
  import logging
2
+ import re
2
3
  import uuid
3
4
  from copy import deepcopy
4
- from dataclasses import fields, replace
5
+ from dataclasses import dataclass, fields, replace
6
+ from enum import Enum
7
+ from itertools import groupby
5
8
  from typing import Any, Dict, List, Optional, Set, Tuple, Union
6
9
 
7
10
  import numpy as np
@@ -25,6 +28,7 @@ from .schema import (
25
28
  SINGLE_COLUMN_SCHEMA_FIELD_NAMES,
26
29
  EmbeddingColumnNames,
27
30
  EmbeddingFeatures,
31
+ RetrievalEmbeddingColumnNames,
28
32
  Schema,
29
33
  SchemaFieldName,
30
34
  SchemaFieldValue,
@@ -121,6 +125,160 @@ class Dataset:
121
125
  schema = Schema.from_json(schema_json)
122
126
  return cls(df, schema, name)
123
127
 
128
+ @classmethod
129
+ def from_open_inference(cls, dataframe: DataFrame) -> "Dataset":
130
+ schema = Schema()
131
+ column_renaming: Dict[str, str] = {}
132
+ for group_name, group in groupby(
133
+ sorted(
134
+ map(_parse_open_inference_column_name, dataframe.columns),
135
+ key=lambda column: column.name,
136
+ ),
137
+ key=lambda column: column.name,
138
+ ):
139
+ open_inference_columns = list(group)
140
+ if group_name == "":
141
+ column_names_by_category = {
142
+ column.category: column.full_name for column in open_inference_columns
143
+ }
144
+ schema = replace(
145
+ schema,
146
+ prediction_id_column_name=column_names_by_category.get(
147
+ OpenInferenceCategory.id
148
+ ),
149
+ timestamp_column_name=column_names_by_category.get(
150
+ OpenInferenceCategory.timestamp
151
+ ),
152
+ )
153
+ continue
154
+ column_names_by_specifier = {
155
+ column.specifier: column.full_name for column in open_inference_columns
156
+ }
157
+ if group_name == "response":
158
+ response_vector_column_name = column_names_by_specifier.get(
159
+ OpenInferenceSpecifier.embedding
160
+ )
161
+ if response_vector_column_name is not None:
162
+ column_renaming[response_vector_column_name] = "response"
163
+ schema = replace(
164
+ schema,
165
+ response_column_names=EmbeddingColumnNames(
166
+ vector_column_name=column_renaming[response_vector_column_name],
167
+ raw_data_column_name=column_names_by_specifier.get(
168
+ OpenInferenceSpecifier.default
169
+ ),
170
+ ),
171
+ )
172
+ else:
173
+ response_text_column_name = column_names_by_specifier.get(
174
+ OpenInferenceSpecifier.default
175
+ )
176
+ if response_text_column_name is None:
177
+ raise ValueError(
178
+ "invalid OpenInference format: missing text column for response"
179
+ )
180
+ column_renaming[response_text_column_name] = "response"
181
+ schema = replace(
182
+ schema,
183
+ response_column_names=column_renaming[response_text_column_name],
184
+ )
185
+ elif group_name == "prompt":
186
+ prompt_vector_column_name = column_names_by_specifier.get(
187
+ OpenInferenceSpecifier.embedding
188
+ )
189
+ if prompt_vector_column_name is None:
190
+ raise ValueError(
191
+ "invalid OpenInference format: missing embedding vector column for prompt"
192
+ )
193
+ column_renaming[prompt_vector_column_name] = "prompt"
194
+ schema = replace(
195
+ schema,
196
+ prompt_column_names=RetrievalEmbeddingColumnNames(
197
+ vector_column_name=column_renaming[prompt_vector_column_name],
198
+ raw_data_column_name=column_names_by_specifier.get(
199
+ OpenInferenceSpecifier.default
200
+ ),
201
+ context_retrieval_ids_column_name=column_names_by_specifier.get(
202
+ OpenInferenceSpecifier.retrieved_document_ids
203
+ ),
204
+ context_retrieval_scores_column_name=column_names_by_specifier.get(
205
+ OpenInferenceSpecifier.retrieved_document_scores
206
+ ),
207
+ ),
208
+ )
209
+ elif OpenInferenceSpecifier.embedding in column_names_by_specifier:
210
+ vector_column_name = column_names_by_specifier[OpenInferenceSpecifier.embedding]
211
+ column_renaming[vector_column_name] = group_name
212
+ embedding_feature_column_names = schema.embedding_feature_column_names or {}
213
+ embedding_feature_column_names.update(
214
+ {
215
+ group_name: EmbeddingColumnNames(
216
+ vector_column_name=column_renaming[vector_column_name],
217
+ raw_data_column_name=column_names_by_specifier.get(
218
+ OpenInferenceSpecifier.raw_data
219
+ ),
220
+ link_to_data_column_name=column_names_by_specifier.get(
221
+ OpenInferenceSpecifier.link_to_data
222
+ ),
223
+ )
224
+ }
225
+ )
226
+ schema = replace(
227
+ schema,
228
+ embedding_feature_column_names=embedding_feature_column_names,
229
+ )
230
+ elif len(open_inference_columns) == 1:
231
+ open_inference_column = open_inference_columns[0]
232
+ raw_column_name = open_inference_column.full_name
233
+ column_renaming[raw_column_name] = open_inference_column.name
234
+ if open_inference_column.category is OpenInferenceCategory.feature:
235
+ schema = replace(
236
+ schema,
237
+ feature_column_names=(
238
+ (schema.feature_column_names or []) + [column_renaming[raw_column_name]]
239
+ ),
240
+ )
241
+ elif open_inference_column.category is OpenInferenceCategory.tag:
242
+ schema = replace(
243
+ schema,
244
+ tag_column_names=(
245
+ (schema.tag_column_names or []) + [column_renaming[raw_column_name]]
246
+ ),
247
+ )
248
+ elif open_inference_column.category is OpenInferenceCategory.prediction:
249
+ if open_inference_column.specifier is OpenInferenceSpecifier.score:
250
+ schema = replace(
251
+ schema,
252
+ prediction_score_column_name=column_renaming[raw_column_name],
253
+ )
254
+ if open_inference_column.specifier is OpenInferenceSpecifier.label:
255
+ schema = replace(
256
+ schema,
257
+ prediction_label_column_name=column_renaming[raw_column_name],
258
+ )
259
+ elif open_inference_column.category is OpenInferenceCategory.actual:
260
+ if open_inference_column.specifier is OpenInferenceSpecifier.score:
261
+ schema = replace(
262
+ schema,
263
+ actual_score_column_name=column_renaming[raw_column_name],
264
+ )
265
+ if open_inference_column.specifier is OpenInferenceSpecifier.label:
266
+ schema = replace(
267
+ schema,
268
+ actual_label_column_name=column_renaming[raw_column_name],
269
+ )
270
+ else:
271
+ raise ValueError(f"invalid OpenInference format: duplicated name `{group_name}`")
272
+
273
+ return cls(
274
+ dataframe.rename(
275
+ column_renaming,
276
+ axis=1,
277
+ copy=False,
278
+ ),
279
+ schema,
280
+ )
281
+
124
282
  def to_disc(self) -> None:
125
283
  """writes the data and schema to disc"""
126
284
  directory = DATASET_DIR / self.name
@@ -528,3 +686,48 @@ def _get_schema_from_unknown_schema_param(schemaLike: SchemaLike) -> Schema:
528
686
 
529
687
  def _add_prediction_id(num_rows: int) -> List[str]:
530
688
  return [str(uuid.uuid4()) for _ in range(num_rows)]
689
+
690
+
691
+ class OpenInferenceCategory(Enum):
692
+ id = "id"
693
+ timestamp = "timestamp"
694
+ feature = "feature"
695
+ tag = "tag"
696
+ prediction = "prediction"
697
+ actual = "actual"
698
+
699
+
700
+ class OpenInferenceSpecifier(Enum):
701
+ default = ""
702
+ score = "score"
703
+ label = "label"
704
+ embedding = "embedding"
705
+ raw_data = "raw_data"
706
+ link_to_data = "link_to_data"
707
+ retrieved_document_ids = "retrieved_document_ids"
708
+ retrieved_document_scores = "retrieved_document_scores"
709
+
710
+
711
+ @dataclass(frozen=True)
712
+ class _OpenInferenceColumnName:
713
+ full_name: str
714
+ category: OpenInferenceCategory
715
+ data_type: str
716
+ specifier: OpenInferenceSpecifier = OpenInferenceSpecifier.default
717
+ name: str = ""
718
+
719
+
720
+ def _parse_open_inference_column_name(column_name: str) -> _OpenInferenceColumnName:
721
+ pattern = (
722
+ r"^:(?P<category>\w+)\.(?P<data_type>\[\w+\]|\w+)(\.(?P<specifier>\w+))?:(?P<name>.*)?$"
723
+ )
724
+ if match := re.match(pattern, column_name):
725
+ extract = match.groupdict(default="")
726
+ return _OpenInferenceColumnName(
727
+ full_name=column_name,
728
+ category=OpenInferenceCategory(extract.get("category", "").lower()),
729
+ data_type=extract.get("data_type", "").lower(),
730
+ specifier=OpenInferenceSpecifier(extract.get("specifier", "").lower()),
731
+ name=extract.get("name", ""),
732
+ )
733
+ raise ValueError(f"Invalid format for column name: {column_name}")
@@ -0,0 +1,91 @@
1
+ """
2
+ Helper functions for evaluating the retrieval step of retrieval-augmented generation.
3
+ """
4
+
5
+ from typing import List, Optional
6
+
7
+ import openai
8
+ from tenacity import (
9
+ retry,
10
+ stop_after_attempt,
11
+ wait_random_exponential,
12
+ )
13
+
14
+ _EVALUATION_SYSTEM_MESSAGE = (
15
+ "You will be given a query and a reference text. "
16
+ "You must determine whether the reference text contains an answer to the input query. "
17
+ 'Your response must be single word, either "relevant" or "irrelevant", '
18
+ "and should not contain any text or characters aside from that word. "
19
+ '"irrelevant" means that the reference text does not contain an answer to the query. '
20
+ '"relevant" means the reference text contains an answer to the query.'
21
+ )
22
+ _QUERY_CONTEXT_PROMPT_TEMPLATE = """# Query: {query}
23
+
24
+ # Reference: {reference}
25
+
26
+ # Answer ("relevant" or "irrelevant"): """
27
+
28
+
29
+ def compute_precisions_at_k(
30
+ relevance_classifications: List[Optional[bool]],
31
+ ) -> List[Optional[float]]:
32
+ """Given a list of relevance classifications, computes precision@k for k = 1, 2, ..., n, where
33
+ n is the length of the input list.
34
+
35
+ Args:
36
+ relevance_classifications (List[Optional[bool]]): A list of relevance classifications for a
37
+ set of retrieved documents, sorted by order of retrieval (i.e., the first element is the
38
+ classification for the first retrieved document, the second element is the
39
+ classification for the second retrieved document, etc.). The list may contain None
40
+ values, which indicate that the relevance classification for the corresponding document
41
+ is unknown.
42
+
43
+ Returns:
44
+ List[Optional[float]]: A list of precision@k values for k = 1, 2, ..., n, where n is the
45
+ length of the input list. The first element is the precision@1 value, the second element
46
+ is the precision@2 value, etc. If the input list contains any None values, those values
47
+ are omitted when computing the precision@k values.
48
+ """
49
+ precisions_at_k = []
50
+ num_relevant_classifications = 0
51
+ num_non_none_classifications = 0
52
+ for relevance_classification in relevance_classifications:
53
+ if isinstance(relevance_classification, bool):
54
+ num_non_none_classifications += 1
55
+ num_relevant_classifications += int(relevance_classification)
56
+ precisions_at_k.append(
57
+ num_relevant_classifications / num_non_none_classifications
58
+ if num_non_none_classifications > 0
59
+ else None
60
+ )
61
+ return precisions_at_k
62
+
63
+
64
+ @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
65
+ def classify_relevance(query: str, document: str, model_name: str) -> Optional[bool]:
66
+ """Given a query and a document, determines whether the document contains an answer to the
67
+ query.
68
+
69
+ Args:
70
+ query (str): The query text. document (str): The document text. model_name (str): The name
71
+ of the OpenAI API model to use for the classification.
72
+
73
+ Returns:
74
+ Optional[bool]: A boolean indicating whether the document contains an answer to the query
75
+ (True meaning relevant, False meaning irrelevant), or None if the LLM produces an
76
+ unparseable output.
77
+ """
78
+ prompt = _QUERY_CONTEXT_PROMPT_TEMPLATE.format(
79
+ query=query,
80
+ reference=document,
81
+ )
82
+ response = openai.ChatCompletion.create(
83
+ messages=[
84
+ {"role": "system", "content": _EVALUATION_SYSTEM_MESSAGE},
85
+ {"role": "user", "content": prompt},
86
+ ],
87
+ model=model_name,
88
+ )
89
+ raw_response_text = str(response["choices"][0]["message"]["content"]).strip()
90
+ relevance_classification = {"relevant": True, "irrelevant": False}.get(raw_response_text)
91
+ return relevance_classification