arize-phoenix 3.16.1__py3-none-any.whl → 7.7.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release.
This version of arize-phoenix might be problematic.
- arize_phoenix-7.7.1.dist-info/METADATA +261 -0
- arize_phoenix-7.7.1.dist-info/RECORD +345 -0
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/WHEEL +1 -1
- arize_phoenix-7.7.1.dist-info/entry_points.txt +3 -0
- phoenix/__init__.py +86 -14
- phoenix/auth.py +309 -0
- phoenix/config.py +675 -45
- phoenix/core/model.py +32 -30
- phoenix/core/model_schema.py +102 -109
- phoenix/core/model_schema_adapter.py +48 -45
- phoenix/datetime_utils.py +24 -3
- phoenix/db/README.md +54 -0
- phoenix/db/__init__.py +4 -0
- phoenix/db/alembic.ini +85 -0
- phoenix/db/bulk_inserter.py +294 -0
- phoenix/db/engines.py +208 -0
- phoenix/db/enums.py +20 -0
- phoenix/db/facilitator.py +113 -0
- phoenix/db/helpers.py +159 -0
- phoenix/db/insertion/constants.py +2 -0
- phoenix/db/insertion/dataset.py +227 -0
- phoenix/db/insertion/document_annotation.py +171 -0
- phoenix/db/insertion/evaluation.py +191 -0
- phoenix/db/insertion/helpers.py +98 -0
- phoenix/db/insertion/span.py +193 -0
- phoenix/db/insertion/span_annotation.py +158 -0
- phoenix/db/insertion/trace_annotation.py +158 -0
- phoenix/db/insertion/types.py +256 -0
- phoenix/db/migrate.py +86 -0
- phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
- phoenix/db/migrations/env.py +114 -0
- phoenix/db/migrations/script.py.mako +26 -0
- phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
- phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
- phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
- phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
- phoenix/db/models.py +807 -0
- phoenix/exceptions.py +5 -1
- phoenix/experiments/__init__.py +6 -0
- phoenix/experiments/evaluators/__init__.py +29 -0
- phoenix/experiments/evaluators/base.py +158 -0
- phoenix/experiments/evaluators/code_evaluators.py +184 -0
- phoenix/experiments/evaluators/llm_evaluators.py +473 -0
- phoenix/experiments/evaluators/utils.py +236 -0
- phoenix/experiments/functions.py +772 -0
- phoenix/experiments/tracing.py +86 -0
- phoenix/experiments/types.py +726 -0
- phoenix/experiments/utils.py +25 -0
- phoenix/inferences/__init__.py +0 -0
- phoenix/{datasets → inferences}/errors.py +6 -5
- phoenix/{datasets → inferences}/fixtures.py +49 -42
- phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
- phoenix/{datasets → inferences}/schema.py +11 -11
- phoenix/{datasets → inferences}/validation.py +13 -14
- phoenix/logging/__init__.py +3 -0
- phoenix/logging/_config.py +90 -0
- phoenix/logging/_filter.py +6 -0
- phoenix/logging/_formatter.py +69 -0
- phoenix/metrics/__init__.py +5 -4
- phoenix/metrics/binning.py +4 -3
- phoenix/metrics/metrics.py +2 -1
- phoenix/metrics/mixins.py +7 -6
- phoenix/metrics/retrieval_metrics.py +2 -1
- phoenix/metrics/timeseries.py +5 -4
- phoenix/metrics/wrappers.py +9 -3
- phoenix/pointcloud/clustering.py +5 -5
- phoenix/pointcloud/pointcloud.py +7 -5
- phoenix/pointcloud/projectors.py +5 -6
- phoenix/pointcloud/umap_parameters.py +53 -52
- phoenix/server/api/README.md +28 -0
- phoenix/server/api/auth.py +44 -0
- phoenix/server/api/context.py +152 -9
- phoenix/server/api/dataloaders/__init__.py +91 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
- phoenix/server/api/dataloaders/cache/__init__.py +3 -0
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
- phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
- phoenix/server/api/dataloaders/document_evaluations.py +31 -0
- phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
- phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
- phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
- phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
- phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
- phoenix/server/api/dataloaders/project_by_name.py +31 -0
- phoenix/server/api/dataloaders/record_counts.py +116 -0
- phoenix/server/api/dataloaders/session_io.py +79 -0
- phoenix/server/api/dataloaders/session_num_traces.py +30 -0
- phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
- phoenix/server/api/dataloaders/session_token_usages.py +41 -0
- phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
- phoenix/server/api/dataloaders/span_annotations.py +26 -0
- phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
- phoenix/server/api/dataloaders/span_descendants.py +57 -0
- phoenix/server/api/dataloaders/span_projects.py +33 -0
- phoenix/server/api/dataloaders/token_counts.py +124 -0
- phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
- phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
- phoenix/server/api/dataloaders/user_roles.py +30 -0
- phoenix/server/api/dataloaders/users.py +33 -0
- phoenix/server/api/exceptions.py +48 -0
- phoenix/server/api/helpers/__init__.py +12 -0
- phoenix/server/api/helpers/dataset_helpers.py +217 -0
- phoenix/server/api/helpers/experiment_run_filters.py +763 -0
- phoenix/server/api/helpers/playground_clients.py +948 -0
- phoenix/server/api/helpers/playground_registry.py +70 -0
- phoenix/server/api/helpers/playground_spans.py +455 -0
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
- phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
- phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
- phoenix/server/api/input_types/ClearProjectInput.py +15 -0
- phoenix/server/api/input_types/ClusterInput.py +2 -2
- phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
- phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
- phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
- phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
- phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
- phoenix/server/api/input_types/DatasetSort.py +17 -0
- phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
- phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
- phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
- phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
- phoenix/server/api/input_types/DimensionFilter.py +4 -4
- phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
- phoenix/server/api/input_types/Granularity.py +1 -1
- phoenix/server/api/input_types/InvocationParameters.py +162 -0
- phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
- phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
- phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
- phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
- phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
- phoenix/server/api/input_types/SpanSort.py +134 -69
- phoenix/server/api/input_types/TemplateOptions.py +10 -0
- phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
- phoenix/server/api/input_types/UserRoleInput.py +9 -0
- phoenix/server/api/mutations/__init__.py +28 -0
- phoenix/server/api/mutations/api_key_mutations.py +167 -0
- phoenix/server/api/mutations/chat_mutations.py +593 -0
- phoenix/server/api/mutations/dataset_mutations.py +591 -0
- phoenix/server/api/mutations/experiment_mutations.py +75 -0
- phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
- phoenix/server/api/mutations/project_mutations.py +57 -0
- phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
- phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
- phoenix/server/api/mutations/user_mutations.py +329 -0
- phoenix/server/api/openapi/__init__.py +0 -0
- phoenix/server/api/openapi/main.py +17 -0
- phoenix/server/api/openapi/schema.py +16 -0
- phoenix/server/api/queries.py +738 -0
- phoenix/server/api/routers/__init__.py +11 -0
- phoenix/server/api/routers/auth.py +284 -0
- phoenix/server/api/routers/embeddings.py +26 -0
- phoenix/server/api/routers/oauth2.py +488 -0
- phoenix/server/api/routers/v1/__init__.py +64 -0
- phoenix/server/api/routers/v1/datasets.py +1017 -0
- phoenix/server/api/routers/v1/evaluations.py +362 -0
- phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
- phoenix/server/api/routers/v1/experiment_runs.py +167 -0
- phoenix/server/api/routers/v1/experiments.py +308 -0
- phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
- phoenix/server/api/routers/v1/spans.py +267 -0
- phoenix/server/api/routers/v1/traces.py +208 -0
- phoenix/server/api/routers/v1/utils.py +95 -0
- phoenix/server/api/schema.py +44 -241
- phoenix/server/api/subscriptions.py +597 -0
- phoenix/server/api/types/Annotation.py +21 -0
- phoenix/server/api/types/AnnotationSummary.py +55 -0
- phoenix/server/api/types/AnnotatorKind.py +16 -0
- phoenix/server/api/types/ApiKey.py +27 -0
- phoenix/server/api/types/AuthMethod.py +9 -0
- phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
- phoenix/server/api/types/Cluster.py +25 -24
- phoenix/server/api/types/CreateDatasetPayload.py +8 -0
- phoenix/server/api/types/DataQualityMetric.py +31 -13
- phoenix/server/api/types/Dataset.py +288 -63
- phoenix/server/api/types/DatasetExample.py +85 -0
- phoenix/server/api/types/DatasetExampleRevision.py +34 -0
- phoenix/server/api/types/DatasetVersion.py +14 -0
- phoenix/server/api/types/Dimension.py +32 -31
- phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
- phoenix/server/api/types/EmbeddingDimension.py +56 -49
- phoenix/server/api/types/Evaluation.py +25 -31
- phoenix/server/api/types/EvaluationSummary.py +30 -50
- phoenix/server/api/types/Event.py +20 -20
- phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
- phoenix/server/api/types/Experiment.py +152 -0
- phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
- phoenix/server/api/types/ExperimentComparison.py +17 -0
- phoenix/server/api/types/ExperimentRun.py +119 -0
- phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
- phoenix/server/api/types/GenerativeModel.py +9 -0
- phoenix/server/api/types/GenerativeProvider.py +85 -0
- phoenix/server/api/types/Inferences.py +80 -0
- phoenix/server/api/types/InferencesRole.py +23 -0
- phoenix/server/api/types/LabelFraction.py +7 -0
- phoenix/server/api/types/MimeType.py +2 -2
- phoenix/server/api/types/Model.py +54 -54
- phoenix/server/api/types/PerformanceMetric.py +8 -5
- phoenix/server/api/types/Project.py +407 -142
- phoenix/server/api/types/ProjectSession.py +139 -0
- phoenix/server/api/types/Segments.py +4 -4
- phoenix/server/api/types/Span.py +221 -176
- phoenix/server/api/types/SpanAnnotation.py +43 -0
- phoenix/server/api/types/SpanIOValue.py +15 -0
- phoenix/server/api/types/SystemApiKey.py +9 -0
- phoenix/server/api/types/TemplateLanguage.py +10 -0
- phoenix/server/api/types/TimeSeries.py +19 -15
- phoenix/server/api/types/TokenUsage.py +11 -0
- phoenix/server/api/types/Trace.py +154 -0
- phoenix/server/api/types/TraceAnnotation.py +45 -0
- phoenix/server/api/types/UMAPPoints.py +7 -7
- phoenix/server/api/types/User.py +60 -0
- phoenix/server/api/types/UserApiKey.py +45 -0
- phoenix/server/api/types/UserRole.py +15 -0
- phoenix/server/api/types/node.py +4 -112
- phoenix/server/api/types/pagination.py +156 -57
- phoenix/server/api/utils.py +34 -0
- phoenix/server/app.py +864 -115
- phoenix/server/bearer_auth.py +163 -0
- phoenix/server/dml_event.py +136 -0
- phoenix/server/dml_event_handler.py +256 -0
- phoenix/server/email/__init__.py +0 -0
- phoenix/server/email/sender.py +97 -0
- phoenix/server/email/templates/__init__.py +0 -0
- phoenix/server/email/templates/password_reset.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/grpc_server.py +102 -0
- phoenix/server/jwt_store.py +505 -0
- phoenix/server/main.py +305 -116
- phoenix/server/oauth2.py +52 -0
- phoenix/server/openapi/__init__.py +0 -0
- phoenix/server/prometheus.py +111 -0
- phoenix/server/rate_limiters.py +188 -0
- phoenix/server/static/.vite/manifest.json +87 -0
- phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
- phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
- phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
- phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
- phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
- phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
- phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
- phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
- phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
- phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
- phoenix/server/telemetry.py +68 -0
- phoenix/server/templates/index.html +82 -23
- phoenix/server/thread_server.py +3 -3
- phoenix/server/types.py +275 -0
- phoenix/services.py +27 -18
- phoenix/session/client.py +743 -68
- phoenix/session/data_extractor.py +31 -7
- phoenix/session/evaluation.py +3 -9
- phoenix/session/session.py +263 -219
- phoenix/settings.py +22 -0
- phoenix/trace/__init__.py +2 -22
- phoenix/trace/attributes.py +338 -0
- phoenix/trace/dsl/README.md +116 -0
- phoenix/trace/dsl/filter.py +663 -213
- phoenix/trace/dsl/helpers.py +73 -21
- phoenix/trace/dsl/query.py +574 -201
- phoenix/trace/exporter.py +24 -19
- phoenix/trace/fixtures.py +368 -32
- phoenix/trace/otel.py +71 -219
- phoenix/trace/projects.py +3 -2
- phoenix/trace/schemas.py +33 -11
- phoenix/trace/span_evaluations.py +21 -16
- phoenix/trace/span_json_decoder.py +6 -4
- phoenix/trace/span_json_encoder.py +2 -2
- phoenix/trace/trace_dataset.py +47 -32
- phoenix/trace/utils.py +21 -4
- phoenix/utilities/__init__.py +0 -26
- phoenix/utilities/client.py +132 -0
- phoenix/utilities/deprecation.py +31 -0
- phoenix/utilities/error_handling.py +3 -2
- phoenix/utilities/json.py +109 -0
- phoenix/utilities/logging.py +8 -0
- phoenix/utilities/project.py +2 -2
- phoenix/utilities/re.py +49 -0
- phoenix/utilities/span_store.py +0 -23
- phoenix/utilities/template_formatters.py +99 -0
- phoenix/version.py +1 -1
- arize_phoenix-3.16.1.dist-info/METADATA +0 -495
- arize_phoenix-3.16.1.dist-info/RECORD +0 -178
- phoenix/core/project.py +0 -619
- phoenix/core/traces.py +0 -96
- phoenix/experimental/evals/__init__.py +0 -73
- phoenix/experimental/evals/evaluators.py +0 -413
- phoenix/experimental/evals/functions/__init__.py +0 -4
- phoenix/experimental/evals/functions/classify.py +0 -453
- phoenix/experimental/evals/functions/executor.py +0 -353
- phoenix/experimental/evals/functions/generate.py +0 -138
- phoenix/experimental/evals/functions/processing.py +0 -76
- phoenix/experimental/evals/models/__init__.py +0 -14
- phoenix/experimental/evals/models/anthropic.py +0 -175
- phoenix/experimental/evals/models/base.py +0 -170
- phoenix/experimental/evals/models/bedrock.py +0 -221
- phoenix/experimental/evals/models/litellm.py +0 -134
- phoenix/experimental/evals/models/openai.py +0 -448
- phoenix/experimental/evals/models/rate_limiters.py +0 -246
- phoenix/experimental/evals/models/vertex.py +0 -173
- phoenix/experimental/evals/models/vertexai.py +0 -186
- phoenix/experimental/evals/retrievals.py +0 -96
- phoenix/experimental/evals/templates/__init__.py +0 -50
- phoenix/experimental/evals/templates/default_templates.py +0 -472
- phoenix/experimental/evals/templates/template.py +0 -195
- phoenix/experimental/evals/utils/__init__.py +0 -172
- phoenix/experimental/evals/utils/threads.py +0 -27
- phoenix/server/api/helpers.py +0 -11
- phoenix/server/api/routers/evaluation_handler.py +0 -109
- phoenix/server/api/routers/span_handler.py +0 -70
- phoenix/server/api/routers/trace_handler.py +0 -60
- phoenix/server/api/types/DatasetRole.py +0 -23
- phoenix/server/static/index.css +0 -6
- phoenix/server/static/index.js +0 -7447
- phoenix/storage/span_store/__init__.py +0 -23
- phoenix/storage/span_store/text_file.py +0 -85
- phoenix/trace/dsl/missing.py +0 -60
- phoenix/trace/langchain/__init__.py +0 -3
- phoenix/trace/langchain/instrumentor.py +0 -35
- phoenix/trace/llama_index/__init__.py +0 -3
- phoenix/trace/llama_index/callback.py +0 -102
- phoenix/trace/openai/__init__.py +0 -3
- phoenix/trace/openai/instrumentor.py +0 -30
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/LICENSE +0 -0
- /phoenix/{datasets → db/insertion}/__init__.py +0 -0
- /phoenix/{experimental → db/migrations}/__init__.py +0 -0
- /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0
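
The moves above show the core restructuring in this release: the `phoenix.datasets` package becomes `phoenix.inferences` (with `dataset.py` renamed to `inferences.py`), the `phoenix.experimental.evals` tree is removed outright, and new `phoenix.db`, `phoenix.experiments`, and server subsystems are added. A hedged sketch of the import migration this implies follows; the class names are assumptions inferred from the file moves, not confirmed by this diff:

```python
# arize-phoenix 3.16.1 (old paths, per the left side of the renames above):
# from phoenix.datasets.dataset import Dataset
# from phoenix.datasets.schema import Schema

# arize-phoenix 7.7.1 (new paths, per the right side of the renames):
from phoenix.inferences.inferences import Inferences  # assumed class name
from phoenix.inferences.schema import Schema  # schema.py moved intact
```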
--- a/phoenix/experimental/evals/models/litellm.py
+++ /dev/null
@@ -1,134 +0,0 @@
-import logging
-import warnings
-from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
-
-from phoenix.experimental.evals.models.base import BaseEvalModel
-
-if TYPE_CHECKING:
-    from tiktoken import Encoding
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class LiteLLMModel(BaseEvalModel):
-    model: str = "gpt-3.5-turbo"
-    """The model name to use."""
-    temperature: float = 0.0
-    """What sampling temperature to use."""
-    max_tokens: int = 256
-    """The maximum number of tokens to generate in the completion."""
-    top_p: float = 1
-    """Total probability mass of tokens to consider at each step."""
-    num_retries: int = 6
-    """Maximum number to retry a model if an RateLimitError, OpenAIError, or
-    ServiceUnavailableError occurs."""
-    request_timeout: int = 60
-    """Maximum number of seconds to wait when retrying."""
-    model_kwargs: Dict[str, Any] = field(default_factory=dict)
-    """Model specific params"""
-
-    # non-LiteLLM params
-    retry_min_seconds: int = 10
-    """Minimum number of seconds to wait when retrying."""
-    max_content_size: Optional[int] = None
-    """If you're using a fine-tuned model, set this to the maximum content size"""
-
-    # Deprecated fields
-    model_name: Optional[str] = None
-    """
-    .. deprecated:: 3.0.0
-       use `model` instead. This will be removed in a future release.
-    """
-
-    def __post_init__(self) -> None:
-        self._migrate_model_name()
-        self._init_environment()
-        self._init_model_encoding()
-
-    def _migrate_model_name(self) -> None:
-        if self.model_name is not None:
-            warning_message = "The `model_name` field is deprecated. Use `model` instead. \
-This will be removed in a future release."
-            warnings.warn(
-                warning_message,
-                DeprecationWarning,
-            )
-            print(warning_message)
-            self.model = self.model_name
-            self.model_name = None
-
-    def _init_environment(self) -> None:
-        try:
-            import litellm
-            from litellm import validate_environment
-
-            self._litellm = litellm
-            env_info = validate_environment(self._litellm.utils.get_llm_provider(self.model))
-
-            if not env_info["keys_in_environment"] and env_info["missing_keys"]:
-                raise RuntimeError(
-                    f"Missing environment variable(s): '{str(env_info['missing_keys'])}', for "
-                    f"model: {self.model}. \nFor additional information about the right "
-                    "environment variables for specific model providers:\n"
-                    "https://docs.litellm.ai/docs/completion/input#provider-specific-params."
-                )
-        except ImportError:
-            self._raise_import_error(
-                package_display_name="LiteLLM",
-                package_name="litellm",
-            )
-
-    def _init_model_encoding(self) -> None:
-        from litellm import decode, encode
-
-        self._encoding = encode
-        self._decoding = decode
-
-    @property
-    def max_context_size(self) -> int:
-        context_size = self.max_content_size or self._litellm.get_max_tokens(self.model).get(
-            "max_tokens", None
-        )
-
-        if context_size is None:
-            raise ValueError(
-                "Can't determine maximum context size. An unknown model was "
-                + f"used: {self.model}."
-            )
-
-        return context_size
-
-    @property
-    def encoder(self) -> "Encoding":
-        raise NotImplementedError
-
-    def get_tokens_from_text(self, text: str) -> List[int]:
-        result: List[int] = self._encoding(model=self.model, text=text)
-        return result
-
-    def get_text_from_tokens(self, tokens: List[int]) -> str:
-        return str(self._decoding(model=self.model, tokens=tokens))
-
-    async def _async_generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str:
-        return self._generate(prompt, **kwargs)
-
-    def _generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str:
-        messages = self._get_messages_from_prompt(prompt)
-        response = self._litellm.completion(
-            model=self.model,
-            messages=messages,
-            temperature=self.temperature,
-            max_tokens=self.max_tokens,
-            top_p=self.top_p,
-            num_retries=self.num_retries,
-            request_timeout=self.request_timeout,
-            **self.model_kwargs,
-        )
-        return str(response.choices[0].message.content)
-
-    def _get_messages_from_prompt(self, prompt: str) -> List[Dict[str, str]]:
-        # LiteLLM requires prompts in the format of messages
-        # messages=[{"content": "ABC?","role": "user"}]
-        return [{"content": prompt, "role": "user"}]
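
The removed `LiteLLMModel` above is a thin wrapper over `litellm.completion`. As a minimal sketch, the equivalent direct call looks like this, assuming `litellm` is installed and the relevant provider key (for example `OPENAI_API_KEY`) is set in the environment; the parameter values mirror the wrapper's dataclass defaults shown above:

```python
import litellm

# Mirrors what LiteLLMModel._generate passed through to litellm.completion.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"content": "Does Phoenix trace LLM calls?", "role": "user"}],
    temperature=0.0,
    max_tokens=256,
    top_p=1,
    num_retries=6,
    request_timeout=60,
)
print(response.choices[0].message.content)
```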
--- a/phoenix/experimental/evals/models/openai.py
+++ /dev/null
@@ -1,448 +0,0 @@
-import logging
-import os
-import warnings
-from dataclasses import dataclass, field, fields
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Dict,
-    List,
-    Optional,
-    Tuple,
-    Union,
-    get_args,
-    get_origin,
-)
-
-from phoenix.exceptions import PhoenixContextLimitExceeded
-from phoenix.experimental.evals.models.base import BaseEvalModel
-from phoenix.experimental.evals.models.rate_limiters import RateLimiter
-
-if TYPE_CHECKING:
-    from tiktoken import Encoding
-
-OPENAI_API_KEY_ENVVAR_NAME = "OPENAI_API_KEY"
-MINIMUM_OPENAI_VERSION = "1.0.0"
-MODEL_TOKEN_LIMIT_MAPPING = {
-    "gpt-3.5-turbo-instruct": 4096,
-    "gpt-3.5-turbo-0301": 4096,
-    "gpt-3.5-turbo-0613": 4096,  # Current gpt-3.5-turbo default
-    "gpt-3.5-turbo-16k-0613": 16385,
-    "gpt-4-0314": 8192,
-    "gpt-4-0613": 8192,  # Current gpt-4 default
-    "gpt-4-32k-0314": 32768,
-    "gpt-4-32k-0613": 32768,
-    "gpt-4-1106-preview": 128000,
-    "gpt-4-0125-preview": 128000,
-    "gpt-4-turbo-preview": 128000,
-    "gpt-4-vision-preview": 128000,
-}
-LEGACY_COMPLETION_API_MODELS = ("gpt-3.5-turbo-instruct",)
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class AzureOptions:
-    api_version: str
-    azure_endpoint: str
-    azure_deployment: Optional[str]
-    azure_ad_token: Optional[str]
-    azure_ad_token_provider: Optional[Callable[[], str]]
-
-
-@dataclass
-class OpenAIModel(BaseEvalModel):
-    api_key: Optional[str] = field(repr=False, default=None)
-    """Your OpenAI key. If not provided, will be read from the environment variable"""
-    organization: Optional[str] = field(repr=False, default=None)
-    """
-    The organization to use for the OpenAI API. If not provided, will default
-    to what's configured in OpenAI
-    """
-    base_url: Optional[str] = field(repr=False, default=None)
-    """
-    An optional base URL to use for the OpenAI API. If not provided, will default
-    to what's configured in OpenAI
-    """
-    model: str = "gpt-4"
-    """
-    Model name to use. In of azure, this is the deployment name such as gpt-35-instant
-    """
-    temperature: float = 0.0
-    """What sampling temperature to use."""
-    max_tokens: int = 256
-    """The maximum number of tokens to generate in the completion.
-    -1 returns as many tokens as possible given the prompt and
-    the models maximal context size."""
-    top_p: float = 1
-    """Total probability mass of tokens to consider at each step."""
-    frequency_penalty: float = 0
-    """Penalizes repeated tokens according to frequency."""
-    presence_penalty: float = 0
-    """Penalizes repeated tokens."""
-    n: int = 1
-    """How many completions to generate for each prompt."""
-    model_kwargs: Dict[str, Any] = field(default_factory=dict)
-    """Holds any model parameters valid for `create` call not explicitly specified."""
-    batch_size: int = 20
-    # TODO: IMPLEMENT BATCHING
-    """Batch size to use when passing multiple documents to generate."""
-    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
-    """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
-    max_retries: int = 20
-    """Maximum number of retries to make when generating."""
-    retry_min_seconds: int = 10
-    """Minimum number of seconds to wait when retrying."""
-    retry_max_seconds: int = 60
-    """Maximum number of seconds to wait when retrying."""
-
-    # Azure options
-    api_version: Optional[str] = field(default=None)
-    """https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning"""
-    azure_endpoint: Optional[str] = field(default=None)
-    """
-    The endpoint to use for azure openai. Available in the azure portal.
-    https://learn.microsoft.com/en-us/azure/cognitive-services/openai/how-to/create-resource?pivots=web-portal#create-a-resource
-    """
-    azure_deployment: Optional[str] = field(default=None)
-    azure_ad_token: Optional[str] = field(default=None)
-    azure_ad_token_provider: Optional[Callable[[], str]] = field(default=None)
-
-    # Deprecated fields
-    model_name: Optional[str] = field(default=None)
-    """
-    .. deprecated:: 3.0.0
-       use `model` instead. This will be removed
-    """
-
-    def __post_init__(self) -> None:
-        self._migrate_model_name()
-        self._init_environment()
-        self._init_open_ai()
-        self._init_tiktoken()
-        self._init_rate_limiter()
-
-    def reload_client(self) -> None:
-        self._init_open_ai()
-
-    def _migrate_model_name(self) -> None:
-        if self.model_name:
-            warning_message = "The `model_name` field is deprecated. Use `model` instead. \
-This will be removed in a future release."
-            print(
-                warning_message,
-            )
-            warnings.warn(warning_message, DeprecationWarning)
-            self.model = self.model_name
-            self.model_name = None
-
-    def _init_environment(self) -> None:
-        try:
-            import openai
-            import openai._utils as openai_util
-
-            self._openai = openai
-            self._openai_util = openai_util
-        except ImportError:
-            self._raise_import_error(
-                package_display_name="OpenAI",
-                package_name="openai",
-                package_min_version=MINIMUM_OPENAI_VERSION,
-            )
-        try:
-            import tiktoken
-
-            self._tiktoken = tiktoken
-        except ImportError:
-            self._raise_import_error(
-                package_name="tiktoken",
-            )
-
-    def _init_open_ai(self) -> None:
-        # For Azure, you need to provide the endpoint and the endpoint
-        self._is_azure = bool(self.azure_endpoint)
-
-        self._model_uses_legacy_completion_api = self.model.startswith(LEGACY_COMPLETION_API_MODELS)
-        if self.api_key is None:
-            api_key = os.getenv(OPENAI_API_KEY_ENVVAR_NAME)
-            if api_key is None:
-                # TODO: Create custom AuthenticationError
-                raise RuntimeError(
-                    "OpenAI's API key not provided. Pass it as an argument to 'api_key' "
-                    "or set it in your environment: 'export OPENAI_API_KEY=sk-****'"
-                )
-            self.api_key = api_key
-
-        # Set the version, organization, and base_url - default to openAI
-        self.api_version = self.api_version or self._openai.api_version
-        self.organization = self.organization or self._openai.organization
-
-        # Initialize specific clients depending on the API backend
-        # Set the type first
-        self._client: Union[self._openai.OpenAI, self._openai.AzureOpenAI]  # type: ignore
-        self._async_client: Union[self._openai.AsyncOpenAI, self._openai.AsyncAzureOpenAI]  # type: ignore
-        if self._is_azure:
-            # Validate the azure options and construct a client
-            azure_options = self._get_azure_options()
-            self._client = self._openai.AzureOpenAI(
-                azure_endpoint=azure_options.azure_endpoint,
-                azure_deployment=azure_options.azure_deployment,
-                api_version=azure_options.api_version,
-                azure_ad_token=azure_options.azure_ad_token,
-                azure_ad_token_provider=azure_options.azure_ad_token_provider,
-                api_key=self.api_key,
-                organization=self.organization,
-            )
-            self._async_client = self._openai.AsyncAzureOpenAI(
-                azure_endpoint=azure_options.azure_endpoint,
-                azure_deployment=azure_options.azure_deployment,
-                api_version=azure_options.api_version,
-                azure_ad_token=azure_options.azure_ad_token,
-                azure_ad_token_provider=azure_options.azure_ad_token_provider,
-                api_key=self.api_key,
-                organization=self.organization,
-            )
-            # return early since we don't need to check the model
-            return
-
-        # The client is not azure, so it must be openai
-        self._client = self._openai.OpenAI(
-            api_key=self.api_key,
-            organization=self.organization,
-            base_url=(self.base_url or self._openai.base_url),
-        )
-
-        # The client is not azure, so it must be openai
-        self._async_client = self._openai.AsyncOpenAI(
-            api_key=self.api_key,
-            organization=self.organization,
-            base_url=(self.base_url or self._openai.base_url),
-            max_retries=0,
-        )
-
-    def _init_tiktoken(self) -> None:
-        try:
-            encoding = self._tiktoken.encoding_for_model(self.model)
-        except KeyError:
-            encoding = self._tiktoken.get_encoding("cl100k_base")
-        self._tiktoken_encoding = encoding
-
-    def _get_azure_options(self) -> AzureOptions:
-        options = {}
-        for option in fields(AzureOptions):
-            if (value := getattr(self, option.name)) is not None:
-                options[option.name] = value
-            else:
-                # raise ValueError if field is not optional
-                # See if the field is optional - e.g. get_origin(Optional[T]) = typing.Union
-                option_is_optional = get_origin(option.type) is Union and type(None) in get_args(
-                    option.type
-                )
-                if not option_is_optional:
-                    raise ValueError(
-                        f"Option '{option.name}' must be set when using Azure OpenAI API"
-                    )
-                options[option.name] = None
-        return AzureOptions(**options)
-
-    def _init_rate_limiter(self) -> None:
-        self._rate_limiter = RateLimiter(
-            rate_limit_error=self._openai.RateLimitError,
-            max_rate_limit_retries=10,
-            initial_per_second_request_rate=5,
-            maximum_per_second_request_rate=20,
-            enforcement_window_minutes=1,
-        )
-
-    @staticmethod
-    def _build_messages(
-        prompt: str, system_instruction: Optional[str] = None
-    ) -> List[Dict[str, str]]:
-        messages = [{"role": "user", "content": prompt}]
-        if system_instruction:
-            messages.insert(0, {"role": "system", "content": str(system_instruction)})
-        return messages
-
-    def verbose_generation_info(self) -> str:
-        return f"OpenAI invocation parameters: {self.public_invocation_params}"
-
-    async def _async_generate(self, prompt: str, **kwargs: Any) -> str:
-        invoke_params = self.invocation_params
-        messages = self._build_messages(prompt, kwargs.get("instruction"))
-        if functions := kwargs.get("functions"):
-            invoke_params["functions"] = functions
-        if function_call := kwargs.get("function_call"):
-            invoke_params["function_call"] = function_call
-        response = await self._async_rate_limited_completion(
-            messages=messages,
-            **invoke_params,
-        )
-        choice = response["choices"][0]
-        if self._model_uses_legacy_completion_api:
-            return str(choice["text"])
-        message = choice["message"]
-        if function_call := message.get("function_call"):
-            return str(function_call.get("arguments") or "")
-        return str(message["content"])
-
-    def _generate(self, prompt: str, **kwargs: Any) -> str:
-        invoke_params = self.invocation_params
-        messages = self._build_messages(prompt, kwargs.get("instruction"))
-        if functions := kwargs.get("functions"):
-            invoke_params["functions"] = functions
-        if function_call := kwargs.get("function_call"):
-            invoke_params["function_call"] = function_call
-        response = self._rate_limited_completion(
-            messages=messages,
-            **invoke_params,
-        )
-        choice = response["choices"][0]
-        if self._model_uses_legacy_completion_api:
-            return str(choice["text"])
-        message = choice["message"]
-        if function_call := message.get("function_call"):
-            return str(function_call.get("arguments") or "")
-        return str(message["content"])
-
-    async def _async_rate_limited_completion(self, **kwargs: Any) -> Any:
-        @self._rate_limiter.alimit
-        async def _async_completion(**kwargs: Any) -> Any:
-            try:
-                if self._model_uses_legacy_completion_api:
-                    if "prompt" not in kwargs:
-                        kwargs["prompt"] = "\n\n".join(
-                            (message.get("content") or "")
-                            for message in (kwargs.pop("messages", None) or ())
-                        )
-                    # OpenAI 1.0.0 API responses are pydantic objects, not dicts
-                    # We must dump the model to get the dict
-                    res = await self._async_client.completions.create(**kwargs)
-                else:
-                    res = await self._async_client.chat.completions.create(**kwargs)
-                return res.model_dump()
-            except self._openai._exceptions.BadRequestError as e:
-                exception_message = e.args[0]
-                if exception_message and "maximum context length" in exception_message:
-                    raise PhoenixContextLimitExceeded(exception_message) from e
-                raise e
-
-        return await _async_completion(**kwargs)
-
-    def _rate_limited_completion(self, **kwargs: Any) -> Any:
-        @self._rate_limiter.limit
-        def _completion(**kwargs: Any) -> Any:
-            try:
-                if self._model_uses_legacy_completion_api:
-                    if "prompt" not in kwargs:
-                        kwargs["prompt"] = "\n\n".join(
-                            (message.get("content") or "")
-                            for message in (kwargs.pop("messages", None) or ())
-                        )
-                    # OpenAI 1.0.0 API responses are pydantic objects, not dicts
-                    # We must dump the model to get the dict
-                    return self._client.completions.create(**kwargs).model_dump()
-                return self._client.chat.completions.create(**kwargs).model_dump()
-            except self._openai._exceptions.BadRequestError as e:
-                exception_message = e.args[0]
-                if exception_message and "maximum context length" in exception_message:
-                    raise PhoenixContextLimitExceeded(exception_message) from e
-                raise e
-
-        return _completion(**kwargs)
-
-    @property
-    def max_context_size(self) -> int:
-        model = self.model
-        # handling finetuned models
-        if "ft-" in model:
-            model = self.model.split(":")[0]
-        if model == "gpt-4":
-            # Map gpt-4 to the current default
-            model = "gpt-4-0613"
-
-        context_size = MODEL_TOKEN_LIMIT_MAPPING.get(model, None)
-
-        if context_size is None:
-            raise ValueError(
-                "Can't determine maximum context size. An unknown model name was "
-                f"used: {model}. Please provide a valid OpenAI model name. "
-                "Known models are: " + ", ".join(MODEL_TOKEN_LIMIT_MAPPING.keys())
-            )
-
-        return context_size
-
-    @property
-    def public_invocation_params(self) -> Dict[str, Any]:
-        return {
-            **({"model": self.model}),
-            **self._default_params,
-            **self.model_kwargs,
-        }
-
-    @property
-    def invocation_params(self) -> Dict[str, Any]:
-        return {
-            **self.public_invocation_params,
-        }
-
-    @property
-    def _default_params(self) -> Dict[str, Any]:
-        """Get the default parameters for calling OpenAI API."""
-        return {
-            "temperature": self.temperature,
-            "max_tokens": self.max_tokens,
-            "frequency_penalty": self.frequency_penalty,
-            "presence_penalty": self.presence_penalty,
-            "top_p": self.top_p,
-            "n": self.n,
-            "timeout": self.request_timeout,
-        }
-
-    @property
-    def encoder(self) -> "Encoding":
-        return self._tiktoken_encoding
-
-    def get_token_count_from_messages(self, messages: List[Dict[str, str]]) -> int:
-        """Return the number of tokens used by a list of messages.
-
-        Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
-        """  # noqa
-        model = self.model
-        if model == "gpt-3.5-turbo-0301":
-            tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
-            tokens_per_name = -1  # if there's a name, the role is omitted
-        else:
-            tokens_per_message = 3
-            tokens_per_name = 1
-
-        token_count = 0
-        for message in messages:
-            token_count += tokens_per_message
-            for key, text in message.items():
-                token_count += len(self.get_tokens_from_text(text))
-                if key == "name":
-                    token_count += tokens_per_name
-        # every reply is primed with <|start|>assistant<|message|>
-        token_count += 3
-        return token_count
-
-    def get_tokens_from_text(self, text: str) -> List[int]:
-        return self.encoder.encode(text)
-
-    def get_text_from_tokens(self, tokens: List[int]) -> str:
-        return self.encoder.decode(tokens)
-
-    @property
-    def supports_function_calling(self) -> bool:
-        if (
-            self._is_azure
-            and self.api_version
-            # The first api version supporting function calling is 2023-07-01-preview.
-            # See https://github.com/Azure/azure-rest-api-specs/blob/58e92dd03733bc175e6a9540f4bc53703b57fcc9/specification/cognitiveservices/data-plane/AzureOpenAI/inference/preview/2023-07-01-preview/inference.json#L895  # noqa E501
-            and self.api_version[:10] < "2023-07-01"
-        ):
-            return False
-        if self._model_uses_legacy_completion_api:
-            return False
-        return True