arize-phoenix 3.16.1__py3-none-any.whl → 7.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic; see the registry's advisory page for more details.
- arize_phoenix-7.7.0.dist-info/METADATA +261 -0
- arize_phoenix-7.7.0.dist-info/RECORD +345 -0
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.0.dist-info}/WHEEL +1 -1
- arize_phoenix-7.7.0.dist-info/entry_points.txt +3 -0
- phoenix/__init__.py +86 -14
- phoenix/auth.py +309 -0
- phoenix/config.py +675 -45
- phoenix/core/model.py +32 -30
- phoenix/core/model_schema.py +102 -109
- phoenix/core/model_schema_adapter.py +48 -45
- phoenix/datetime_utils.py +24 -3
- phoenix/db/README.md +54 -0
- phoenix/db/__init__.py +4 -0
- phoenix/db/alembic.ini +85 -0
- phoenix/db/bulk_inserter.py +294 -0
- phoenix/db/engines.py +208 -0
- phoenix/db/enums.py +20 -0
- phoenix/db/facilitator.py +113 -0
- phoenix/db/helpers.py +159 -0
- phoenix/db/insertion/constants.py +2 -0
- phoenix/db/insertion/dataset.py +227 -0
- phoenix/db/insertion/document_annotation.py +171 -0
- phoenix/db/insertion/evaluation.py +191 -0
- phoenix/db/insertion/helpers.py +98 -0
- phoenix/db/insertion/span.py +193 -0
- phoenix/db/insertion/span_annotation.py +158 -0
- phoenix/db/insertion/trace_annotation.py +158 -0
- phoenix/db/insertion/types.py +256 -0
- phoenix/db/migrate.py +86 -0
- phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
- phoenix/db/migrations/env.py +114 -0
- phoenix/db/migrations/script.py.mako +26 -0
- phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
- phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
- phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
- phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
- phoenix/db/models.py +807 -0
- phoenix/exceptions.py +5 -1
- phoenix/experiments/__init__.py +6 -0
- phoenix/experiments/evaluators/__init__.py +29 -0
- phoenix/experiments/evaluators/base.py +158 -0
- phoenix/experiments/evaluators/code_evaluators.py +184 -0
- phoenix/experiments/evaluators/llm_evaluators.py +473 -0
- phoenix/experiments/evaluators/utils.py +236 -0
- phoenix/experiments/functions.py +772 -0
- phoenix/experiments/tracing.py +86 -0
- phoenix/experiments/types.py +726 -0
- phoenix/experiments/utils.py +25 -0
- phoenix/inferences/__init__.py +0 -0
- phoenix/{datasets → inferences}/errors.py +6 -5
- phoenix/{datasets → inferences}/fixtures.py +49 -42
- phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
- phoenix/{datasets → inferences}/schema.py +11 -11
- phoenix/{datasets → inferences}/validation.py +13 -14
- phoenix/logging/__init__.py +3 -0
- phoenix/logging/_config.py +90 -0
- phoenix/logging/_filter.py +6 -0
- phoenix/logging/_formatter.py +69 -0
- phoenix/metrics/__init__.py +5 -4
- phoenix/metrics/binning.py +4 -3
- phoenix/metrics/metrics.py +2 -1
- phoenix/metrics/mixins.py +7 -6
- phoenix/metrics/retrieval_metrics.py +2 -1
- phoenix/metrics/timeseries.py +5 -4
- phoenix/metrics/wrappers.py +9 -3
- phoenix/pointcloud/clustering.py +5 -5
- phoenix/pointcloud/pointcloud.py +7 -5
- phoenix/pointcloud/projectors.py +5 -6
- phoenix/pointcloud/umap_parameters.py +53 -52
- phoenix/server/api/README.md +28 -0
- phoenix/server/api/auth.py +44 -0
- phoenix/server/api/context.py +152 -9
- phoenix/server/api/dataloaders/__init__.py +91 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
- phoenix/server/api/dataloaders/cache/__init__.py +3 -0
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
- phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
- phoenix/server/api/dataloaders/document_evaluations.py +31 -0
- phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
- phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
- phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
- phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
- phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
- phoenix/server/api/dataloaders/project_by_name.py +31 -0
- phoenix/server/api/dataloaders/record_counts.py +116 -0
- phoenix/server/api/dataloaders/session_io.py +79 -0
- phoenix/server/api/dataloaders/session_num_traces.py +30 -0
- phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
- phoenix/server/api/dataloaders/session_token_usages.py +41 -0
- phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
- phoenix/server/api/dataloaders/span_annotations.py +26 -0
- phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
- phoenix/server/api/dataloaders/span_descendants.py +57 -0
- phoenix/server/api/dataloaders/span_projects.py +33 -0
- phoenix/server/api/dataloaders/token_counts.py +124 -0
- phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
- phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
- phoenix/server/api/dataloaders/user_roles.py +30 -0
- phoenix/server/api/dataloaders/users.py +33 -0
- phoenix/server/api/exceptions.py +48 -0
- phoenix/server/api/helpers/__init__.py +12 -0
- phoenix/server/api/helpers/dataset_helpers.py +217 -0
- phoenix/server/api/helpers/experiment_run_filters.py +763 -0
- phoenix/server/api/helpers/playground_clients.py +948 -0
- phoenix/server/api/helpers/playground_registry.py +70 -0
- phoenix/server/api/helpers/playground_spans.py +455 -0
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
- phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
- phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
- phoenix/server/api/input_types/ClearProjectInput.py +15 -0
- phoenix/server/api/input_types/ClusterInput.py +2 -2
- phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
- phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
- phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
- phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
- phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
- phoenix/server/api/input_types/DatasetSort.py +17 -0
- phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
- phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
- phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
- phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
- phoenix/server/api/input_types/DimensionFilter.py +4 -4
- phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
- phoenix/server/api/input_types/Granularity.py +1 -1
- phoenix/server/api/input_types/InvocationParameters.py +162 -0
- phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
- phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
- phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
- phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
- phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
- phoenix/server/api/input_types/SpanSort.py +134 -69
- phoenix/server/api/input_types/TemplateOptions.py +10 -0
- phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
- phoenix/server/api/input_types/UserRoleInput.py +9 -0
- phoenix/server/api/mutations/__init__.py +28 -0
- phoenix/server/api/mutations/api_key_mutations.py +167 -0
- phoenix/server/api/mutations/chat_mutations.py +593 -0
- phoenix/server/api/mutations/dataset_mutations.py +591 -0
- phoenix/server/api/mutations/experiment_mutations.py +75 -0
- phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
- phoenix/server/api/mutations/project_mutations.py +57 -0
- phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
- phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
- phoenix/server/api/mutations/user_mutations.py +329 -0
- phoenix/server/api/openapi/__init__.py +0 -0
- phoenix/server/api/openapi/main.py +17 -0
- phoenix/server/api/openapi/schema.py +16 -0
- phoenix/server/api/queries.py +738 -0
- phoenix/server/api/routers/__init__.py +11 -0
- phoenix/server/api/routers/auth.py +284 -0
- phoenix/server/api/routers/embeddings.py +26 -0
- phoenix/server/api/routers/oauth2.py +488 -0
- phoenix/server/api/routers/v1/__init__.py +64 -0
- phoenix/server/api/routers/v1/datasets.py +1017 -0
- phoenix/server/api/routers/v1/evaluations.py +362 -0
- phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
- phoenix/server/api/routers/v1/experiment_runs.py +167 -0
- phoenix/server/api/routers/v1/experiments.py +308 -0
- phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
- phoenix/server/api/routers/v1/spans.py +267 -0
- phoenix/server/api/routers/v1/traces.py +208 -0
- phoenix/server/api/routers/v1/utils.py +95 -0
- phoenix/server/api/schema.py +44 -241
- phoenix/server/api/subscriptions.py +597 -0
- phoenix/server/api/types/Annotation.py +21 -0
- phoenix/server/api/types/AnnotationSummary.py +55 -0
- phoenix/server/api/types/AnnotatorKind.py +16 -0
- phoenix/server/api/types/ApiKey.py +27 -0
- phoenix/server/api/types/AuthMethod.py +9 -0
- phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
- phoenix/server/api/types/Cluster.py +25 -24
- phoenix/server/api/types/CreateDatasetPayload.py +8 -0
- phoenix/server/api/types/DataQualityMetric.py +31 -13
- phoenix/server/api/types/Dataset.py +288 -63
- phoenix/server/api/types/DatasetExample.py +85 -0
- phoenix/server/api/types/DatasetExampleRevision.py +34 -0
- phoenix/server/api/types/DatasetVersion.py +14 -0
- phoenix/server/api/types/Dimension.py +32 -31
- phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
- phoenix/server/api/types/EmbeddingDimension.py +56 -49
- phoenix/server/api/types/Evaluation.py +25 -31
- phoenix/server/api/types/EvaluationSummary.py +30 -50
- phoenix/server/api/types/Event.py +20 -20
- phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
- phoenix/server/api/types/Experiment.py +152 -0
- phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
- phoenix/server/api/types/ExperimentComparison.py +17 -0
- phoenix/server/api/types/ExperimentRun.py +119 -0
- phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
- phoenix/server/api/types/GenerativeModel.py +9 -0
- phoenix/server/api/types/GenerativeProvider.py +85 -0
- phoenix/server/api/types/Inferences.py +80 -0
- phoenix/server/api/types/InferencesRole.py +23 -0
- phoenix/server/api/types/LabelFraction.py +7 -0
- phoenix/server/api/types/MimeType.py +2 -2
- phoenix/server/api/types/Model.py +54 -54
- phoenix/server/api/types/PerformanceMetric.py +8 -5
- phoenix/server/api/types/Project.py +407 -142
- phoenix/server/api/types/ProjectSession.py +139 -0
- phoenix/server/api/types/Segments.py +4 -4
- phoenix/server/api/types/Span.py +221 -176
- phoenix/server/api/types/SpanAnnotation.py +43 -0
- phoenix/server/api/types/SpanIOValue.py +15 -0
- phoenix/server/api/types/SystemApiKey.py +9 -0
- phoenix/server/api/types/TemplateLanguage.py +10 -0
- phoenix/server/api/types/TimeSeries.py +19 -15
- phoenix/server/api/types/TokenUsage.py +11 -0
- phoenix/server/api/types/Trace.py +154 -0
- phoenix/server/api/types/TraceAnnotation.py +45 -0
- phoenix/server/api/types/UMAPPoints.py +7 -7
- phoenix/server/api/types/User.py +60 -0
- phoenix/server/api/types/UserApiKey.py +45 -0
- phoenix/server/api/types/UserRole.py +15 -0
- phoenix/server/api/types/node.py +4 -112
- phoenix/server/api/types/pagination.py +156 -57
- phoenix/server/api/utils.py +34 -0
- phoenix/server/app.py +864 -115
- phoenix/server/bearer_auth.py +163 -0
- phoenix/server/dml_event.py +136 -0
- phoenix/server/dml_event_handler.py +256 -0
- phoenix/server/email/__init__.py +0 -0
- phoenix/server/email/sender.py +97 -0
- phoenix/server/email/templates/__init__.py +0 -0
- phoenix/server/email/templates/password_reset.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/grpc_server.py +102 -0
- phoenix/server/jwt_store.py +505 -0
- phoenix/server/main.py +305 -116
- phoenix/server/oauth2.py +52 -0
- phoenix/server/openapi/__init__.py +0 -0
- phoenix/server/prometheus.py +111 -0
- phoenix/server/rate_limiters.py +188 -0
- phoenix/server/static/.vite/manifest.json +87 -0
- phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
- phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
- phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
- phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
- phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
- phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
- phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
- phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
- phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
- phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
- phoenix/server/telemetry.py +68 -0
- phoenix/server/templates/index.html +82 -23
- phoenix/server/thread_server.py +3 -3
- phoenix/server/types.py +275 -0
- phoenix/services.py +27 -18
- phoenix/session/client.py +743 -68
- phoenix/session/data_extractor.py +31 -7
- phoenix/session/evaluation.py +3 -9
- phoenix/session/session.py +263 -219
- phoenix/settings.py +22 -0
- phoenix/trace/__init__.py +2 -22
- phoenix/trace/attributes.py +338 -0
- phoenix/trace/dsl/README.md +116 -0
- phoenix/trace/dsl/filter.py +663 -213
- phoenix/trace/dsl/helpers.py +73 -21
- phoenix/trace/dsl/query.py +574 -201
- phoenix/trace/exporter.py +24 -19
- phoenix/trace/fixtures.py +368 -32
- phoenix/trace/otel.py +71 -219
- phoenix/trace/projects.py +3 -2
- phoenix/trace/schemas.py +33 -11
- phoenix/trace/span_evaluations.py +21 -16
- phoenix/trace/span_json_decoder.py +6 -4
- phoenix/trace/span_json_encoder.py +2 -2
- phoenix/trace/trace_dataset.py +47 -32
- phoenix/trace/utils.py +21 -4
- phoenix/utilities/__init__.py +0 -26
- phoenix/utilities/client.py +132 -0
- phoenix/utilities/deprecation.py +31 -0
- phoenix/utilities/error_handling.py +3 -2
- phoenix/utilities/json.py +109 -0
- phoenix/utilities/logging.py +8 -0
- phoenix/utilities/project.py +2 -2
- phoenix/utilities/re.py +49 -0
- phoenix/utilities/span_store.py +0 -23
- phoenix/utilities/template_formatters.py +99 -0
- phoenix/version.py +1 -1
- arize_phoenix-3.16.1.dist-info/METADATA +0 -495
- arize_phoenix-3.16.1.dist-info/RECORD +0 -178
- phoenix/core/project.py +0 -619
- phoenix/core/traces.py +0 -96
- phoenix/experimental/evals/__init__.py +0 -73
- phoenix/experimental/evals/evaluators.py +0 -413
- phoenix/experimental/evals/functions/__init__.py +0 -4
- phoenix/experimental/evals/functions/classify.py +0 -453
- phoenix/experimental/evals/functions/executor.py +0 -353
- phoenix/experimental/evals/functions/generate.py +0 -138
- phoenix/experimental/evals/functions/processing.py +0 -76
- phoenix/experimental/evals/models/__init__.py +0 -14
- phoenix/experimental/evals/models/anthropic.py +0 -175
- phoenix/experimental/evals/models/base.py +0 -170
- phoenix/experimental/evals/models/bedrock.py +0 -221
- phoenix/experimental/evals/models/litellm.py +0 -134
- phoenix/experimental/evals/models/openai.py +0 -448
- phoenix/experimental/evals/models/rate_limiters.py +0 -246
- phoenix/experimental/evals/models/vertex.py +0 -173
- phoenix/experimental/evals/models/vertexai.py +0 -186
- phoenix/experimental/evals/retrievals.py +0 -96
- phoenix/experimental/evals/templates/__init__.py +0 -50
- phoenix/experimental/evals/templates/default_templates.py +0 -472
- phoenix/experimental/evals/templates/template.py +0 -195
- phoenix/experimental/evals/utils/__init__.py +0 -172
- phoenix/experimental/evals/utils/threads.py +0 -27
- phoenix/server/api/helpers.py +0 -11
- phoenix/server/api/routers/evaluation_handler.py +0 -109
- phoenix/server/api/routers/span_handler.py +0 -70
- phoenix/server/api/routers/trace_handler.py +0 -60
- phoenix/server/api/types/DatasetRole.py +0 -23
- phoenix/server/static/index.css +0 -6
- phoenix/server/static/index.js +0 -7447
- phoenix/storage/span_store/__init__.py +0 -23
- phoenix/storage/span_store/text_file.py +0 -85
- phoenix/trace/dsl/missing.py +0 -60
- phoenix/trace/langchain/__init__.py +0 -3
- phoenix/trace/langchain/instrumentor.py +0 -35
- phoenix/trace/llama_index/__init__.py +0 -3
- phoenix/trace/llama_index/callback.py +0 -102
- phoenix/trace/openai/__init__.py +0 -3
- phoenix/trace/openai/instrumentor.py +0 -30
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/LICENSE +0 -0
- /phoenix/{datasets → db/insertion}/__init__.py +0 -0
- /phoenix/{experimental → db/migrations}/__init__.py +0 -0
- /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0
|
@@ -1,246 +0,0 @@
|
|
|
1
|
-
import asyncio
|
|
2
|
-
import time
|
|
3
|
-
from functools import wraps
|
|
4
|
-
from math import exp
|
|
5
|
-
from typing import Any, Callable, Coroutine, Optional, Tuple, Type, TypeVar
|
|
6
|
-
|
|
7
|
-
from typing_extensions import ParamSpec
|
|
8
|
-
|
|
9
|
-
from phoenix.exceptions import PhoenixException
|
|
10
|
-
from phoenix.utilities.logging import printif
|
|
11
|
-
|
|
12
|
-
ParameterSpec = ParamSpec("ParameterSpec")
|
|
13
|
-
GenericType = TypeVar("GenericType")
|
|
14
|
-
AsyncCallable = Callable[ParameterSpec, Coroutine[Any, Any, GenericType]]
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
class UnavailableTokensError(PhoenixException):
    """Raised when the token bucket has no request tokens available."""
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
class AdaptiveTokenBucket:
    """
    An adaptive rate-limiter that adjusts the rate based on the number of rate limit errors.

    This rate limiter does not need to know the exact rate limit. Instead, it starts with a high
    rate and reduces it whenever a rate limit error occurs. The rate is increased slowly over time
    if no further errors occur.

    Args:
        initial_per_second_request_rate (float): The allowed request rate.
        maximum_per_second_request_rate (float): The maximum allowed request rate.
        minimum_per_second_request_rate (float): The floor the rate is never reduced below.
        enforcement_window_minutes (float): The time window over which the rate limit is enforced.
        rate_reduction_factor (float): Multiplier used to reduce the rate limit after an error.
        rate_increase_factor (float): Exponential factor increasing the rate limit over time.
        cooldown_seconds (float): The minimum time before allowing the rate limit to decrease again.
    """

    def __init__(
        self,
        initial_per_second_request_rate: float,
        maximum_per_second_request_rate: float = 1000,
        minimum_per_second_request_rate: float = 0.1,
        enforcement_window_minutes: float = 1,
        rate_reduction_factor: float = 0.5,
        rate_increase_factor: float = 0.01,
        cooldown_seconds: float = 5,
    ):
        now = time.time()
        self._initial_rate = initial_per_second_request_rate
        self.rate = initial_per_second_request_rate
        self.maximum_rate = maximum_per_second_request_rate
        self.minimum_rate = minimum_per_second_request_rate
        self.rate_reduction_factor = rate_reduction_factor
        # window is configured in minutes but all internal math is in seconds
        self.enforcement_window = enforcement_window_minutes * 60
        self.rate_increase_factor = rate_increase_factor
        self.cooldown = cooldown_seconds
        self.last_rate_update = now
        self.last_checked = now
        # back-date the last error so the very first error is never treated as
        # falling inside the cooldown window
        self.last_error = now - self.cooldown
        self.tokens = 0.0

    def increase_rate(self) -> None:
        """Grow the rate exponentially with elapsed time since the last update.

        If more than a full enforcement window has passed without an update,
        the rate is reset to the initial rate instead of compounding further.
        """
        time_since_last_update = time.time() - self.last_rate_update
        if time_since_last_update > self.enforcement_window:
            self.rate = self._initial_rate
        else:
            self.rate *= exp(self.rate_increase_factor * time_since_last_update)
            self.rate = min(self.rate, self.maximum_rate)
        self.last_rate_update = time.time()

    def on_rate_limit_error(self, request_start_time: float, verbose: bool = False) -> None:
        """React to a provider rate-limit error by cutting the request rate.

        Requests that started before the cooldown following the previous error
        expired are ignored, so a burst of concurrent failures only reduces the
        rate once. NOTE: this method blocks the calling thread for
        ``self.cooldown`` seconds.
        """
        now = time.time()
        if request_start_time < (self.last_error + self.cooldown):
            # do not reduce the rate for concurrent requests
            return

        original_rate = self.rate

        self.rate = original_rate * self.rate_reduction_factor
        printif(
            verbose, f"Reducing rate from {original_rate} to {self.rate} after rate limit error"
        )

        # clamp after logging so the message reports the raw reduction
        self.rate = max(self.rate, self.minimum_rate)

        # reset request tokens on a rate limit error
        self.tokens = 0
        self.last_checked = now
        self.last_rate_update = now
        self.last_error = now
        time.sleep(self.cooldown)  # block for a bit to let the rate limit reset

    def max_tokens(self) -> float:
        """Return the bucket capacity: one enforcement window's worth of requests."""
        return self.rate * self.enforcement_window

    def available_requests(self) -> float:
        """Accrue tokens for the time elapsed since the last check and return the balance."""
        now = time.time()
        time_since_last_checked = time.time() - self.last_checked
        self.tokens = min(self.max_tokens(), self.rate * time_since_last_checked + self.tokens)
        self.last_checked = now
        return self.tokens

    def make_request_if_ready(self) -> None:
        """Spend one token, raising ``UnavailableTokensError`` if the balance is too low.

        Requires the balance to exceed 1 (not merely equal it) before spending.
        """
        if self.available_requests() <= 1:
            raise UnavailableTokensError
        self.tokens -= 1

    def wait_until_ready(
        self,
        max_wait_time: float = 300,
    ) -> None:
        """Block (polling) until a token can be spent or ``max_wait_time`` elapses.

        Note: on timeout this returns silently without having spent a token.
        """
        start = time.time()
        while (time.time() - start) < max_wait_time:
            try:
                self.increase_rate()
                self.make_request_if_ready()
                break
            except UnavailableTokensError:
                # poll interval scales inversely with the current rate
                time.sleep(0.1 / self.rate)
                continue

    async def async_wait_until_ready(
        self,
        max_wait_time: float = 10,  # defeat the token bucket rate limiter at low rates (<.1 req/s)
    ) -> None:
        """Async variant of ``wait_until_ready`` using ``asyncio.sleep`` between polls."""
        start = time.time()
        while (time.time() - start) < max_wait_time:
            try:
                self.increase_rate()
                self.make_request_if_ready()
                break
            except UnavailableTokensError:
                await asyncio.sleep(0.1 / self.rate)
                continue
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
class RateLimitError(PhoenixException):
    """Raised after the maximum number of rate-limit retries has been exhausted."""
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
class RateLimiter:
    """Decorator factory that throttles and retries calls that hit rate limits.

    Wraps sync (``limit``) or async (``alimit``) callables, gating each call
    through an ``AdaptiveTokenBucket`` and retrying up to
    ``max_rate_limit_retries`` times when ``rate_limit_error`` is raised.
    """

    def __init__(
        self,
        rate_limit_error: Optional[Type[BaseException]] = None,
        max_rate_limit_retries: int = 3,
        initial_per_second_request_rate: float = 1,
        maximum_per_second_request_rate: float = 50,
        enforcement_window_minutes: float = 1,
        rate_reduction_factor: float = 0.5,
        rate_increase_factor: float = 0.01,
        cooldown_seconds: float = 5,
        verbose: bool = False,
    ) -> None:
        # stored as a tuple so `except self._rate_limit_error:` matches nothing
        # when no error type was provided (an empty tuple catches no exceptions)
        self._rate_limit_error: Tuple[Type[BaseException], ...]
        self._rate_limit_error = (rate_limit_error,) if rate_limit_error is not None else tuple()

        self._max_rate_limit_retries = max_rate_limit_retries
        self._throttler = AdaptiveTokenBucket(
            initial_per_second_request_rate=initial_per_second_request_rate,
            maximum_per_second_request_rate=maximum_per_second_request_rate,
            enforcement_window_minutes=enforcement_window_minutes,
            rate_reduction_factor=rate_reduction_factor,
            rate_increase_factor=rate_increase_factor,
            cooldown_seconds=cooldown_seconds,
        )
        # async primitives are created lazily per event loop (see
        # _initialize_async_primitives) because they must belong to the loop
        # that uses them
        self._rate_limit_handling: Optional[asyncio.Event] = None
        self._rate_limit_handling_lock: Optional[asyncio.Lock] = None
        self._current_loop: Optional[asyncio.AbstractEventLoop] = None
        self._verbose = verbose

    def limit(
        self, fn: Callable[ParameterSpec, GenericType]
    ) -> Callable[ParameterSpec, GenericType]:
        """Wrap a synchronous callable with throttling and rate-limit retries."""

        @wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> GenericType:
            try:
                self._throttler.wait_until_ready()
                # captured just before the call so on_rate_limit_error can tell
                # whether this request predates the most recent error cooldown
                request_start_time = time.time()
                return fn(*args, **kwargs)
            except self._rate_limit_error:
                self._throttler.on_rate_limit_error(request_start_time, verbose=self._verbose)
                for _attempt in range(self._max_rate_limit_retries):
                    try:
                        request_start_time = time.time()
                        self._throttler.wait_until_ready()
                        return fn(*args, **kwargs)
                    except self._rate_limit_error:
                        self._throttler.on_rate_limit_error(
                            request_start_time, verbose=self._verbose
                        )
                        continue
                raise RateLimitError(f"Exceeded max ({self._max_rate_limit_retries}) retries")

        return wrapper

    def _initialize_async_primitives(self) -> None:
        """
        Lazily initialize async primitives to ensure they are created in the correct event loop.
        """
        loop = asyncio.get_running_loop()
        if loop is not self._current_loop:
            # new (or first) event loop: rebuild the Event/Lock bound to it
            self._current_loop = loop
            self._rate_limit_handling = asyncio.Event()
            self._rate_limit_handling.set()
            self._rate_limit_handling_lock = asyncio.Lock()

    def alimit(
        self, fn: AsyncCallable[ParameterSpec, GenericType]
    ) -> AsyncCallable[ParameterSpec, GenericType]:
        """Wrap an async callable with throttling and rate-limit retries.

        While one task is handling a rate-limit error, the shared
        ``_rate_limit_handling`` event is cleared so other tasks pause before
        starting new requests.
        """

        @wraps(fn)
        async def wrapper(*args: Any, **kwargs: Any) -> GenericType:
            self._initialize_async_primitives()
            # narrow the Optional types for the checker (set just above)
            assert self._rate_limit_handling_lock is not None and isinstance(
                self._rate_limit_handling_lock, asyncio.Lock
            )
            assert self._rate_limit_handling is not None and isinstance(
                self._rate_limit_handling, asyncio.Event
            )
            try:
                try:
                    # wait until no other task is mid-error-handling
                    await asyncio.wait_for(self._rate_limit_handling.wait(), 120)
                except asyncio.TimeoutError:
                    self._rate_limit_handling.set()  # Set the event as a failsafe
                await self._throttler.async_wait_until_ready()
                request_start_time = time.time()
                return await fn(*args, **kwargs)
            except self._rate_limit_error:
                async with self._rate_limit_handling_lock:
                    self._rate_limit_handling.clear()  # prevent new requests from starting
                    self._throttler.on_rate_limit_error(request_start_time, verbose=self._verbose)
                    try:
                        for _attempt in range(self._max_rate_limit_retries):
                            try:
                                request_start_time = time.time()
                                await self._throttler.async_wait_until_ready()
                                return await fn(*args, **kwargs)
                            except self._rate_limit_error:
                                self._throttler.on_rate_limit_error(
                                    request_start_time, verbose=self._verbose
                                )
                                continue
                    finally:
                        self._rate_limit_handling.set()  # allow new requests to start
                raise RateLimitError(f"Exceeded max ({self._max_rate_limit_retries}) retries")

        return wrapper
|
|
@@ -1,173 +0,0 @@
|
|
|
1
|
-
import logging
|
|
2
|
-
from dataclasses import dataclass, field
|
|
3
|
-
from typing import TYPE_CHECKING, Any, Dict, List
|
|
4
|
-
|
|
5
|
-
from phoenix.experimental.evals.models.base import BaseEvalModel
|
|
6
|
-
from phoenix.experimental.evals.models.rate_limiters import RateLimiter
|
|
7
|
-
from phoenix.utilities.logging import printif
|
|
8
|
-
|
|
9
|
-
if TYPE_CHECKING:
|
|
10
|
-
from tiktoken import Encoding
|
|
11
|
-
|
|
12
|
-
logger = logging.getLogger(__name__)
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
# https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
# Token limits for known Gemini model names; consulted by
# GeminiModel.max_context_size, which raises for names not listed here.
MODEL_TOKEN_LIMIT_MAPPING = {
    "gemini-pro": 32760,
    "gemini-pro-vision": 16384,
}
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
@dataclass
|
|
23
|
-
class GeminiModel(BaseEvalModel):
|
|
24
|
-
# The vertex SDK runs into connection pool limits at high concurrency
|
|
25
|
-
default_concurrency: int = 5
|
|
26
|
-
|
|
27
|
-
model: str = "gemini-pro"
|
|
28
|
-
"""The model name to use."""
|
|
29
|
-
temperature: float = 0.0
|
|
30
|
-
"""What sampling temperature to use."""
|
|
31
|
-
max_tokens: int = 256
|
|
32
|
-
"""The maximum number of tokens to generate in the completion."""
|
|
33
|
-
top_p: float = 1
|
|
34
|
-
"""Total probability mass of tokens to consider at each step."""
|
|
35
|
-
top_k: int = 32
|
|
36
|
-
"""The cutoff where the model no longer selects the words"""
|
|
37
|
-
stop_sequences: List[str] = field(default_factory=list)
|
|
38
|
-
"""If the model encounters a stop sequence, it stops generating further tokens. """
|
|
39
|
-
max_retries: int = 6
|
|
40
|
-
"""Maximum number of retries to make when generating."""
|
|
41
|
-
retry_min_seconds: int = 10
|
|
42
|
-
"""Minimum number of seconds to wait when retrying."""
|
|
43
|
-
retry_max_seconds: int = 60
|
|
44
|
-
"""Maximum number of seconds to wait when retrying."""
|
|
45
|
-
|
|
46
|
-
def __post_init__(self) -> None:
|
|
47
|
-
self._init_client()
|
|
48
|
-
self._init_rate_limiter()
|
|
49
|
-
|
|
50
|
-
def reload_client(self) -> None:
|
|
51
|
-
self._init_client()
|
|
52
|
-
|
|
53
|
-
def _init_client(self) -> None:
|
|
54
|
-
try:
|
|
55
|
-
from google.api_core import exceptions # type:ignore
|
|
56
|
-
from vertexai.preview import generative_models as vertex # type:ignore
|
|
57
|
-
|
|
58
|
-
self._vertex = vertex
|
|
59
|
-
self._gcp_exceptions = exceptions
|
|
60
|
-
self._model = self._vertex.GenerativeModel(self.model)
|
|
61
|
-
except ImportError:
|
|
62
|
-
self._raise_import_error(
|
|
63
|
-
package_name="vertexai",
|
|
64
|
-
)
|
|
65
|
-
|
|
66
|
-
def _init_rate_limiter(self) -> None:
|
|
67
|
-
self._rate_limiter = RateLimiter(
|
|
68
|
-
rate_limit_error=self._gcp_exceptions.ResourceExhausted,
|
|
69
|
-
max_rate_limit_retries=10,
|
|
70
|
-
initial_per_second_request_rate=1,
|
|
71
|
-
maximum_per_second_request_rate=20,
|
|
72
|
-
enforcement_window_minutes=1,
|
|
73
|
-
)
|
|
74
|
-
|
|
75
|
-
@property
|
|
76
|
-
def encoder(self) -> "Encoding":
|
|
77
|
-
raise TypeError("Gemini models contain their own token counting")
|
|
78
|
-
|
|
79
|
-
def get_tokens_from_text(self, text: str) -> List[int]:
|
|
80
|
-
raise NotImplementedError
|
|
81
|
-
|
|
82
|
-
def get_text_from_tokens(self, tokens: List[int]) -> str:
|
|
83
|
-
raise NotImplementedError
|
|
84
|
-
|
|
85
|
-
@property
def max_context_size(self) -> int:
    """Return the maximum context window (in tokens) for `self.model`.

    Raises:
        ValueError: if the model name is not present in
            MODEL_TOKEN_LIMIT_MAPPING (e.g. an unknown or fine-tuned model).
    """
    context_size = MODEL_TOKEN_LIMIT_MAPPING.get(self.model, None)

    if context_size is None:
        # Fixed: the message previously said `max_content_size` (the argument
        # is `max_context_size`) and lacked a space before "when".
        raise ValueError(
            "Can't determine maximum context size. An unknown model name was "
            + f"used: {self.model}. Please set the `max_context_size` argument "
            + "when using fine-tuned models. "
        )

    return context_size
|
|
97
|
-
|
|
98
|
-
@property
def generation_config(self) -> Dict[str, Any]:
    """Assemble the Gemini generation-config kwargs from this model's fields."""
    config: Dict[str, Any] = dict(
        temperature=self.temperature,
        max_output_tokens=self.max_tokens,
        top_p=self.top_p,
        top_k=self.top_k,
        stop_sequences=self.stop_sequences,
    )
    return config
|
|
107
|
-
|
|
108
|
-
def _generate(self, prompt: str, **kwargs: Any) -> str:
    """Synchronously generate a completion for `prompt` via the Gemini SDK."""
    # instruction is an invalid input to Gemini models, it is passed in by
    # BaseEvalModel.__call__ and needs to be removed
    kwargs.pop("instruction", None)

    # Wrap the SDK call so the rate limiter can throttle and retry it.
    @self._rate_limiter.limit
    def _rate_limited_completion(
        prompt: str, generation_config: Dict[str, Any], **kwargs: Any
    ) -> Any:
        response = self._model.generate_content(
            contents=prompt, generation_config=generation_config, **kwargs
        )
        return self._parse_response_candidates(response)

    response = _rate_limited_completion(
        prompt=prompt,
        generation_config=self.generation_config,
        **kwargs,
    )

    return str(response)
|
|
129
|
-
|
|
130
|
-
async def _async_generate(self, prompt: str, **kwargs: Any) -> str:
    """Asynchronously generate a completion for `prompt` via the Gemini SDK."""
    # instruction is an invalid input to Gemini models, it is passed in by
    # BaseEvalModel.__call__ and needs to be removed
    kwargs.pop("instruction", None)

    # Wrap the async SDK call so the rate limiter can throttle and retry it.
    @self._rate_limiter.alimit
    async def _rate_limited_completion(
        prompt: str, generation_config: Dict[str, Any], **kwargs: Any
    ) -> Any:
        response = await self._model.generate_content_async(
            contents=prompt, generation_config=generation_config, **kwargs
        )
        return self._parse_response_candidates(response)

    response = await _rate_limited_completion(
        prompt=prompt,
        generation_config=self.generation_config,
        **kwargs,
    )

    return str(response)
|
|
151
|
-
|
|
152
|
-
def _parse_response_candidates(self, response: Any) -> Any:
    """Extract the text of the first candidate from a Gemini response.

    Falls back to an empty string (logging details when verbose) whenever
    the response has no candidates or the candidate carries no text.
    """
    if not hasattr(response, "candidates"):
        printif(self._verbose, "The 'response' object does not have a 'candidates' attribute.")
        return ""

    candidates = response.candidates
    if not isinstance(candidates, list) or not candidates:
        printif(
            self._verbose,
            "The 'candidates' attribute of 'response' is either not a list or is empty.",
        )
        printif(self._verbose, str(response))
        return ""

    try:
        return candidates[0].text
    except ValueError:
        # The SDK raises ValueError when a candidate has no text part.
        printif(
            self._verbose, "The 'candidates' object does not have a 'text' attribute."
        )
        printif(self._verbose, str(candidates[0]))
        return ""
|
|
@@ -1,186 +0,0 @@
|
|
|
1
|
-
import logging
|
|
2
|
-
import warnings
|
|
3
|
-
from dataclasses import dataclass
|
|
4
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
|
5
|
-
|
|
6
|
-
from phoenix.experimental.evals.models.base import BaseEvalModel
|
|
7
|
-
|
|
8
|
-
if TYPE_CHECKING:
|
|
9
|
-
from google.auth.credentials import Credentials # type:ignore
|
|
10
|
-
|
|
11
|
-
logger = logging.getLogger(__name__)
|
|
12
|
-
|
|
13
|
-
MINIMUM_VERTEX_AI_VERSION = "1.33.0"
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
@dataclass
class VertexAIModel(BaseEvalModel):
    # Eval-model wrapper around the VertexAI text/code generation SDK.
    project: Optional[str] = None
    "project (str): The default project to use when making API calls."
    location: Optional[str] = None
    "location (str): The default location to use when making API calls. If not "
    "set defaults to us-central-1."
    credentials: Optional["Credentials"] = None
    model: str = "text-bison"
    tuned_model: Optional[str] = None
    "The name of a tuned model. If provided, model is ignored."
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    retry_min_seconds: int = 10
    """Minimum number of seconds to wait when retrying."""
    retry_max_seconds: int = 60
    """Maximum number of seconds to wait when retrying."""
    temperature: float = 0.0
    """What sampling temperature to use."""
    max_tokens: int = 256
    """The maximum number of tokens to generate in the completion.
    -1 returns as many tokens as possible given the prompt and
    the models maximal context size."""
    top_p: float = 0.95
    "Tokens are selected from most probable to least until the sum of their "
    "probabilities equals the top-p value. Top-p is ignored for Codey models."
    top_k: int = 40
    "How the model selects tokens for output, the next token is selected from "
    "among the top-k most probable tokens. Top-k is ignored for Codey models."

    # Deprecated fields
    model_name: Optional[str] = None
    """
    .. deprecated:: 3.0.0
       use `model` instead. This will be removed in a future release.
    """
    tuned_model_name: Optional[str] = None
    """
    .. deprecated:: 3.0.0
       use `tuned_model` instead. This will be removed in a future release.
    """

    def __post_init__(self) -> None:
        # Order matters: migrate deprecated field names before anything reads
        # `model`/`tuned_model`, and initialize the SDK before instantiating.
        self._migrate_model_name()
        self._init_environment()
        self._init_vertex_ai()
        self._instantiate_model()

    def _migrate_model_name(self) -> None:
        # Copy deprecated `model_name`/`tuned_model_name` values onto the new
        # fields, emitting both a DeprecationWarning and a print (warnings are
        # often filtered in notebooks, hence the print for visibility).
        if self.model_name is not None:
            warning_message = (
                "The `model_name` field is deprecated. Use `model` instead. "
                + "This will be removed in a future release."
            )
            warnings.warn(
                warning_message,
                DeprecationWarning,
            )
            print(warning_message)
            self.model = self.model_name
            self.model_name = None
        if self.tuned_model_name is not None:
            warning_message = (
                "`tuned_model_name` field is deprecated. Use `tuned_model` instead. "
                + "This will be removed in a future release."
            )
            warnings.warn(
                warning_message,
                DeprecationWarning,
            )
            print(warning_message)
            self.tuned_model = self.tuned_model_name
            self.tuned_model_name = None

    def _init_environment(self) -> None:
        # Import lazily so phoenix remains usable without the optional
        # google-cloud-aiplatform dependency installed.
        try:
            import google.api_core.exceptions as google_exceptions  # type:ignore
            import vertexai  # type:ignore

            self._vertexai = vertexai
            self._google_exceptions = google_exceptions
        except ImportError:
            self._raise_import_error(
                package_display_name="VertexAI",
                package_name="google-cloud-aiplatform",
                package_min_version=MINIMUM_VERTEX_AI_VERSION,
            )

    def _init_vertex_ai(self) -> None:
        # Configure the SDK with project / location / credentials.
        self._vertexai.init(**self._init_params)

    def _instantiate_model(self) -> None:
        # Codey and text models live behind different SDK classes; a tuned
        # model name, when given, takes precedence over the base model name.
        if self.is_codey_model:
            from vertexai.preview.language_models import CodeGenerationModel  # type:ignore

            model = CodeGenerationModel
        else:
            from vertexai.preview.language_models import TextGenerationModel

            model = TextGenerationModel

        if self.tuned_model:
            self._model = model.get_tuned_model(self.tuned_model)
        else:
            self._model = model.from_pretrained(self.model)

    def verbose_generation_info(self) -> str:
        # Shown when verbose evaluation output is enabled.
        return f"VertexAI invocation parameters: {self.invocation_params}"

    async def _async_generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str:
        # The VertexAI SDK call is synchronous; this simply delegates.
        return self._generate(prompt, **kwargs)

    def _generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str:
        # Run a single completion and return its text.
        invoke_params = self.invocation_params
        response = self._model.predict(
            prompt=prompt,
            **invoke_params,
        )
        return str(response.text)

    @property
    def is_codey_model(self) -> bool:
        # A tuned model name, when present, determines the model family.
        return is_codey_model(self.tuned_model or self.model)

    @property
    def _init_params(self) -> Dict[str, Any]:
        # Keyword arguments forwarded to vertexai.init().
        return {
            "project": self.project,
            "location": self.location,
            "credentials": self.credentials,
        }

    @property
    def invocation_params(self) -> Dict[str, Any]:
        # Codey models reject top_k/top_p, so only include them for text models.
        params = {
            "temperature": self.temperature,
            "max_output_tokens": self.max_tokens,
        }
        if self.is_codey_model:
            return params
        else:
            return {
                **params,
                "top_k": self.top_k,
                "top_p": self.top_p,
            }

    def get_tokens_from_text(self, text: str) -> List[int]:
        # Local tokenization is not supported for VertexAI models.
        raise NotImplementedError

    def get_text_from_tokens(self, tokens: List[int]) -> str:
        # Local detokenization is not supported for VertexAI models.
        raise NotImplementedError

    @property
    def max_context_size(self) -> int:
        # Context-window lookup is not implemented for VertexAI models.
        raise NotImplementedError

    @property
    def encoder(self):  # type:ignore
        # No local encoder; token handling is done by the service.
        raise NotImplementedError
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
def is_codey_model(model_name: str) -> bool:
    """Report whether *model_name* denotes a Codey (code-generation) model.

    Args:
        model_name: The model name to check.

    Returns: True if the model name is a Codey model.
    """
    # Codey model identifiers all embed the substring "code".
    return model_name.find("code") != -1
|
|
@@ -1,96 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Helper functions for evaluating the retrieval step of retrieval-augmented generation.
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from typing import List, Optional
|
|
6
|
-
|
|
7
|
-
from tenacity import (
|
|
8
|
-
retry,
|
|
9
|
-
stop_after_attempt,
|
|
10
|
-
wait_random_exponential,
|
|
11
|
-
)
|
|
12
|
-
|
|
13
|
-
# System message instructing the LLM to act as a binary relevance judge; the
# strict one-word response contract keeps downstream parsing trivial.
_EVALUATION_SYSTEM_MESSAGE = (
    "You will be given a query and a reference text. "
    "You must determine whether the reference text contains an answer to the input query. "
    'Your response must be single word, either "relevant" or "irrelevant", '
    "and should not contain any text or characters aside from that word. "
    '"irrelevant" means that the reference text does not contain an answer to the query. '
    '"relevant" means the reference text contains an answer to the query.'
)
# User-message template, filled with the query and the retrieved document.
_QUERY_CONTEXT_PROMPT_TEMPLATE = """# Query: {query}

# Reference: {reference}

# Answer ("relevant" or "irrelevant"): """
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
def compute_precisions_at_k(
    relevance_classifications: List[Optional[bool]],
) -> List[Optional[float]]:
    """Compute precision@k for every prefix of a ranked relevance list.

    Args:
        relevance_classifications (List[Optional[bool]]): Relevance judgments
            for a set of retrieved documents in retrieval order (first element
            is the first retrieved document, and so on). None entries mark
            documents whose relevance is unknown.

    Returns:
        List[Optional[float]]: precision@k for k = 1, 2, ..., n, where n is
        the input length. Unknown (None) judgments are excluded from both the
        numerator and the denominator; a prefix containing only unknown
        judgments yields None.
    """
    precisions: List[Optional[float]] = []
    relevant_so_far = 0
    judged_so_far = 0
    for judgment in relevance_classifications:
        if isinstance(judgment, bool):
            judged_so_far += 1
            if judgment:
                relevant_so_far += 1
        # Precision over the judged portion of the current prefix.
        precisions.append(
            relevant_so_far / judged_so_far if judged_so_far else None
        )
    return precisions
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def classify_relevance(query: str, document: str, model_name: str) -> Optional[bool]:
    """Judge whether *document* contains an answer to *query* using an LLM.

    Args:
        query (str): The query text. document (str): The document text.
        model_name (str): The name of the OpenAI API model used for the
        classification.

    Returns:
        Optional[bool]: True when the model deems the document relevant,
        False when irrelevant, or None when the model's output cannot be
        parsed as either label.
    """

    from openai import OpenAI

    client = OpenAI()

    rendered_prompt = _QUERY_CONTEXT_PROMPT_TEMPLATE.format(
        query=query,
        reference=document,
    )
    completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": _EVALUATION_SYSTEM_MESSAGE},
            {"role": "user", "content": rendered_prompt},
        ],
        model=model_name,
    )

    label = str(completion.choices[0].message.content).strip()
    # Anything other than the two expected labels maps to None (unparseable).
    if label == "relevant":
        return True
    if label == "irrelevant":
        return False
    return None
|