aeri-python 4.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aeri/__init__.py +72 -0
- aeri/_client/_validation.py +204 -0
- aeri/_client/attributes.py +188 -0
- aeri/_client/client.py +3761 -0
- aeri/_client/constants.py +65 -0
- aeri/_client/datasets.py +302 -0
- aeri/_client/environment_variables.py +158 -0
- aeri/_client/get_client.py +149 -0
- aeri/_client/observe.py +661 -0
- aeri/_client/propagation.py +475 -0
- aeri/_client/resource_manager.py +510 -0
- aeri/_client/span.py +1519 -0
- aeri/_client/span_filter.py +76 -0
- aeri/_client/span_processor.py +206 -0
- aeri/_client/utils.py +132 -0
- aeri/_task_manager/media_manager.py +331 -0
- aeri/_task_manager/media_upload_consumer.py +44 -0
- aeri/_task_manager/media_upload_queue.py +12 -0
- aeri/_task_manager/score_ingestion_consumer.py +208 -0
- aeri/_task_manager/task_manager.py +475 -0
- aeri/_utils/__init__.py +19 -0
- aeri/_utils/environment.py +34 -0
- aeri/_utils/error_logging.py +47 -0
- aeri/_utils/parse_error.py +99 -0
- aeri/_utils/prompt_cache.py +188 -0
- aeri/_utils/request.py +137 -0
- aeri/_utils/serializer.py +205 -0
- aeri/api/.fern/metadata.json +14 -0
- aeri/api/__init__.py +836 -0
- aeri/api/annotation_queues/__init__.py +82 -0
- aeri/api/annotation_queues/client.py +1111 -0
- aeri/api/annotation_queues/raw_client.py +2288 -0
- aeri/api/annotation_queues/types/__init__.py +84 -0
- aeri/api/annotation_queues/types/annotation_queue.py +28 -0
- aeri/api/annotation_queues/types/annotation_queue_assignment_request.py +16 -0
- aeri/api/annotation_queues/types/annotation_queue_item.py +34 -0
- aeri/api/annotation_queues/types/annotation_queue_object_type.py +26 -0
- aeri/api/annotation_queues/types/annotation_queue_status.py +22 -0
- aeri/api/annotation_queues/types/create_annotation_queue_assignment_response.py +18 -0
- aeri/api/annotation_queues/types/create_annotation_queue_item_request.py +25 -0
- aeri/api/annotation_queues/types/create_annotation_queue_request.py +20 -0
- aeri/api/annotation_queues/types/delete_annotation_queue_assignment_response.py +14 -0
- aeri/api/annotation_queues/types/delete_annotation_queue_item_response.py +15 -0
- aeri/api/annotation_queues/types/paginated_annotation_queue_items.py +17 -0
- aeri/api/annotation_queues/types/paginated_annotation_queues.py +17 -0
- aeri/api/annotation_queues/types/update_annotation_queue_item_request.py +15 -0
- aeri/api/blob_storage_integrations/__init__.py +73 -0
- aeri/api/blob_storage_integrations/client.py +550 -0
- aeri/api/blob_storage_integrations/raw_client.py +976 -0
- aeri/api/blob_storage_integrations/types/__init__.py +77 -0
- aeri/api/blob_storage_integrations/types/blob_storage_export_frequency.py +26 -0
- aeri/api/blob_storage_integrations/types/blob_storage_export_mode.py +26 -0
- aeri/api/blob_storage_integrations/types/blob_storage_integration_deletion_response.py +14 -0
- aeri/api/blob_storage_integrations/types/blob_storage_integration_file_type.py +26 -0
- aeri/api/blob_storage_integrations/types/blob_storage_integration_response.py +64 -0
- aeri/api/blob_storage_integrations/types/blob_storage_integration_status_response.py +50 -0
- aeri/api/blob_storage_integrations/types/blob_storage_integration_type.py +26 -0
- aeri/api/blob_storage_integrations/types/blob_storage_integrations_response.py +15 -0
- aeri/api/blob_storage_integrations/types/blob_storage_sync_status.py +47 -0
- aeri/api/blob_storage_integrations/types/create_blob_storage_integration_request.py +91 -0
- aeri/api/client.py +679 -0
- aeri/api/comments/__init__.py +44 -0
- aeri/api/comments/client.py +407 -0
- aeri/api/comments/raw_client.py +750 -0
- aeri/api/comments/types/__init__.py +46 -0
- aeri/api/comments/types/create_comment_request.py +47 -0
- aeri/api/comments/types/create_comment_response.py +17 -0
- aeri/api/comments/types/get_comments_response.py +17 -0
- aeri/api/commons/__init__.py +210 -0
- aeri/api/commons/errors/__init__.py +56 -0
- aeri/api/commons/errors/access_denied_error.py +12 -0
- aeri/api/commons/errors/error.py +12 -0
- aeri/api/commons/errors/method_not_allowed_error.py +12 -0
- aeri/api/commons/errors/not_found_error.py +12 -0
- aeri/api/commons/errors/unauthorized_error.py +12 -0
- aeri/api/commons/types/__init__.py +190 -0
- aeri/api/commons/types/base_score.py +90 -0
- aeri/api/commons/types/base_score_v1.py +70 -0
- aeri/api/commons/types/boolean_score.py +26 -0
- aeri/api/commons/types/boolean_score_v1.py +26 -0
- aeri/api/commons/types/categorical_score.py +26 -0
- aeri/api/commons/types/categorical_score_v1.py +26 -0
- aeri/api/commons/types/comment.py +36 -0
- aeri/api/commons/types/comment_object_type.py +30 -0
- aeri/api/commons/types/config_category.py +15 -0
- aeri/api/commons/types/correction_score.py +26 -0
- aeri/api/commons/types/create_score_value.py +5 -0
- aeri/api/commons/types/dataset.py +49 -0
- aeri/api/commons/types/dataset_item.py +58 -0
- aeri/api/commons/types/dataset_run.py +63 -0
- aeri/api/commons/types/dataset_run_item.py +40 -0
- aeri/api/commons/types/dataset_run_with_items.py +19 -0
- aeri/api/commons/types/dataset_status.py +22 -0
- aeri/api/commons/types/map_value.py +11 -0
- aeri/api/commons/types/model.py +125 -0
- aeri/api/commons/types/model_price.py +14 -0
- aeri/api/commons/types/model_usage_unit.py +42 -0
- aeri/api/commons/types/numeric_score.py +17 -0
- aeri/api/commons/types/numeric_score_v1.py +17 -0
- aeri/api/commons/types/observation.py +142 -0
- aeri/api/commons/types/observation_level.py +30 -0
- aeri/api/commons/types/observation_v2.py +235 -0
- aeri/api/commons/types/observations_view.py +89 -0
- aeri/api/commons/types/pricing_tier.py +91 -0
- aeri/api/commons/types/pricing_tier_condition.py +68 -0
- aeri/api/commons/types/pricing_tier_input.py +76 -0
- aeri/api/commons/types/pricing_tier_operator.py +42 -0
- aeri/api/commons/types/score.py +201 -0
- aeri/api/commons/types/score_config.py +66 -0
- aeri/api/commons/types/score_config_data_type.py +26 -0
- aeri/api/commons/types/score_data_type.py +30 -0
- aeri/api/commons/types/score_source.py +26 -0
- aeri/api/commons/types/score_v1.py +131 -0
- aeri/api/commons/types/session.py +25 -0
- aeri/api/commons/types/session_with_traces.py +15 -0
- aeri/api/commons/types/trace.py +84 -0
- aeri/api/commons/types/trace_with_details.py +43 -0
- aeri/api/commons/types/trace_with_full_details.py +45 -0
- aeri/api/commons/types/usage.py +59 -0
- aeri/api/core/__init__.py +111 -0
- aeri/api/core/api_error.py +23 -0
- aeri/api/core/client_wrapper.py +141 -0
- aeri/api/core/datetime_utils.py +30 -0
- aeri/api/core/enum.py +20 -0
- aeri/api/core/file.py +70 -0
- aeri/api/core/force_multipart.py +18 -0
- aeri/api/core/http_client.py +711 -0
- aeri/api/core/http_response.py +55 -0
- aeri/api/core/http_sse/__init__.py +48 -0
- aeri/api/core/http_sse/_api.py +114 -0
- aeri/api/core/http_sse/_decoders.py +66 -0
- aeri/api/core/http_sse/_exceptions.py +7 -0
- aeri/api/core/http_sse/_models.py +17 -0
- aeri/api/core/jsonable_encoder.py +102 -0
- aeri/api/core/pydantic_utilities.py +310 -0
- aeri/api/core/query_encoder.py +60 -0
- aeri/api/core/remove_none_from_dict.py +11 -0
- aeri/api/core/request_options.py +35 -0
- aeri/api/core/serialization.py +282 -0
- aeri/api/dataset_items/__init__.py +52 -0
- aeri/api/dataset_items/client.py +499 -0
- aeri/api/dataset_items/raw_client.py +973 -0
- aeri/api/dataset_items/types/__init__.py +50 -0
- aeri/api/dataset_items/types/create_dataset_item_request.py +37 -0
- aeri/api/dataset_items/types/delete_dataset_item_response.py +17 -0
- aeri/api/dataset_items/types/paginated_dataset_items.py +17 -0
- aeri/api/dataset_run_items/__init__.py +43 -0
- aeri/api/dataset_run_items/client.py +323 -0
- aeri/api/dataset_run_items/raw_client.py +547 -0
- aeri/api/dataset_run_items/types/__init__.py +44 -0
- aeri/api/dataset_run_items/types/create_dataset_run_item_request.py +51 -0
- aeri/api/dataset_run_items/types/paginated_dataset_run_items.py +17 -0
- aeri/api/datasets/__init__.py +55 -0
- aeri/api/datasets/client.py +661 -0
- aeri/api/datasets/raw_client.py +1368 -0
- aeri/api/datasets/types/__init__.py +53 -0
- aeri/api/datasets/types/create_dataset_request.py +31 -0
- aeri/api/datasets/types/delete_dataset_run_response.py +14 -0
- aeri/api/datasets/types/paginated_dataset_runs.py +17 -0
- aeri/api/datasets/types/paginated_datasets.py +17 -0
- aeri/api/health/__init__.py +44 -0
- aeri/api/health/client.py +112 -0
- aeri/api/health/errors/__init__.py +42 -0
- aeri/api/health/errors/service_unavailable_error.py +13 -0
- aeri/api/health/raw_client.py +227 -0
- aeri/api/health/types/__init__.py +40 -0
- aeri/api/health/types/health_response.py +30 -0
- aeri/api/ingestion/__init__.py +169 -0
- aeri/api/ingestion/client.py +221 -0
- aeri/api/ingestion/raw_client.py +293 -0
- aeri/api/ingestion/types/__init__.py +169 -0
- aeri/api/ingestion/types/base_event.py +27 -0
- aeri/api/ingestion/types/create_event_body.py +14 -0
- aeri/api/ingestion/types/create_event_event.py +15 -0
- aeri/api/ingestion/types/create_generation_body.py +40 -0
- aeri/api/ingestion/types/create_generation_event.py +15 -0
- aeri/api/ingestion/types/create_observation_event.py +15 -0
- aeri/api/ingestion/types/create_span_body.py +19 -0
- aeri/api/ingestion/types/create_span_event.py +15 -0
- aeri/api/ingestion/types/ingestion_error.py +17 -0
- aeri/api/ingestion/types/ingestion_event.py +155 -0
- aeri/api/ingestion/types/ingestion_response.py +17 -0
- aeri/api/ingestion/types/ingestion_success.py +15 -0
- aeri/api/ingestion/types/ingestion_usage.py +8 -0
- aeri/api/ingestion/types/observation_body.py +53 -0
- aeri/api/ingestion/types/observation_type.py +54 -0
- aeri/api/ingestion/types/open_ai_completion_usage_schema.py +26 -0
- aeri/api/ingestion/types/open_ai_response_usage_schema.py +24 -0
- aeri/api/ingestion/types/open_ai_usage.py +28 -0
- aeri/api/ingestion/types/optional_observation_body.py +36 -0
- aeri/api/ingestion/types/score_body.py +75 -0
- aeri/api/ingestion/types/score_event.py +15 -0
- aeri/api/ingestion/types/sdk_log_body.py +14 -0
- aeri/api/ingestion/types/sdk_log_event.py +15 -0
- aeri/api/ingestion/types/trace_body.py +36 -0
- aeri/api/ingestion/types/trace_event.py +15 -0
- aeri/api/ingestion/types/update_event_body.py +14 -0
- aeri/api/ingestion/types/update_generation_body.py +40 -0
- aeri/api/ingestion/types/update_generation_event.py +15 -0
- aeri/api/ingestion/types/update_observation_event.py +15 -0
- aeri/api/ingestion/types/update_span_body.py +19 -0
- aeri/api/ingestion/types/update_span_event.py +15 -0
- aeri/api/ingestion/types/usage_details.py +10 -0
- aeri/api/legacy/__init__.py +61 -0
- aeri/api/legacy/client.py +105 -0
- aeri/api/legacy/metrics_v1/__init__.py +40 -0
- aeri/api/legacy/metrics_v1/client.py +214 -0
- aeri/api/legacy/metrics_v1/raw_client.py +322 -0
- aeri/api/legacy/metrics_v1/types/__init__.py +40 -0
- aeri/api/legacy/metrics_v1/types/metrics_response.py +19 -0
- aeri/api/legacy/observations_v1/__init__.py +43 -0
- aeri/api/legacy/observations_v1/client.py +523 -0
- aeri/api/legacy/observations_v1/raw_client.py +759 -0
- aeri/api/legacy/observations_v1/types/__init__.py +44 -0
- aeri/api/legacy/observations_v1/types/observations.py +17 -0
- aeri/api/legacy/observations_v1/types/observations_views.py +17 -0
- aeri/api/legacy/raw_client.py +13 -0
- aeri/api/legacy/score_v1/__init__.py +43 -0
- aeri/api/legacy/score_v1/client.py +329 -0
- aeri/api/legacy/score_v1/raw_client.py +545 -0
- aeri/api/legacy/score_v1/types/__init__.py +44 -0
- aeri/api/legacy/score_v1/types/create_score_request.py +75 -0
- aeri/api/legacy/score_v1/types/create_score_response.py +17 -0
- aeri/api/llm_connections/__init__.py +55 -0
- aeri/api/llm_connections/client.py +311 -0
- aeri/api/llm_connections/raw_client.py +541 -0
- aeri/api/llm_connections/types/__init__.py +53 -0
- aeri/api/llm_connections/types/llm_adapter.py +38 -0
- aeri/api/llm_connections/types/llm_connection.py +77 -0
- aeri/api/llm_connections/types/paginated_llm_connections.py +17 -0
- aeri/api/llm_connections/types/upsert_llm_connection_request.py +69 -0
- aeri/api/media/__init__.py +58 -0
- aeri/api/media/client.py +427 -0
- aeri/api/media/raw_client.py +739 -0
- aeri/api/media/types/__init__.py +56 -0
- aeri/api/media/types/get_media_response.py +55 -0
- aeri/api/media/types/get_media_upload_url_request.py +51 -0
- aeri/api/media/types/get_media_upload_url_response.py +28 -0
- aeri/api/media/types/media_content_type.py +232 -0
- aeri/api/media/types/patch_media_body.py +43 -0
- aeri/api/metrics/__init__.py +40 -0
- aeri/api/metrics/client.py +422 -0
- aeri/api/metrics/raw_client.py +530 -0
- aeri/api/metrics/types/__init__.py +40 -0
- aeri/api/metrics/types/metrics_v2response.py +19 -0
- aeri/api/models/__init__.py +43 -0
- aeri/api/models/client.py +523 -0
- aeri/api/models/raw_client.py +993 -0
- aeri/api/models/types/__init__.py +44 -0
- aeri/api/models/types/create_model_request.py +103 -0
- aeri/api/models/types/paginated_models.py +17 -0
- aeri/api/observations/__init__.py +43 -0
- aeri/api/observations/client.py +522 -0
- aeri/api/observations/raw_client.py +641 -0
- aeri/api/observations/types/__init__.py +44 -0
- aeri/api/observations/types/observations_v2meta.py +21 -0
- aeri/api/observations/types/observations_v2response.py +28 -0
- aeri/api/opentelemetry/__init__.py +67 -0
- aeri/api/opentelemetry/client.py +276 -0
- aeri/api/opentelemetry/raw_client.py +291 -0
- aeri/api/opentelemetry/types/__init__.py +65 -0
- aeri/api/opentelemetry/types/otel_attribute.py +27 -0
- aeri/api/opentelemetry/types/otel_attribute_value.py +46 -0
- aeri/api/opentelemetry/types/otel_resource.py +24 -0
- aeri/api/opentelemetry/types/otel_resource_span.py +32 -0
- aeri/api/opentelemetry/types/otel_scope.py +34 -0
- aeri/api/opentelemetry/types/otel_scope_span.py +28 -0
- aeri/api/opentelemetry/types/otel_span.py +76 -0
- aeri/api/opentelemetry/types/otel_trace_response.py +16 -0
- aeri/api/organizations/__init__.py +73 -0
- aeri/api/organizations/client.py +756 -0
- aeri/api/organizations/raw_client.py +1707 -0
- aeri/api/organizations/types/__init__.py +71 -0
- aeri/api/organizations/types/delete_membership_request.py +16 -0
- aeri/api/organizations/types/membership_deletion_response.py +17 -0
- aeri/api/organizations/types/membership_request.py +18 -0
- aeri/api/organizations/types/membership_response.py +20 -0
- aeri/api/organizations/types/membership_role.py +30 -0
- aeri/api/organizations/types/memberships_response.py +15 -0
- aeri/api/organizations/types/organization_api_key.py +31 -0
- aeri/api/organizations/types/organization_api_keys_response.py +19 -0
- aeri/api/organizations/types/organization_project.py +25 -0
- aeri/api/organizations/types/organization_projects_response.py +15 -0
- aeri/api/projects/__init__.py +67 -0
- aeri/api/projects/client.py +760 -0
- aeri/api/projects/raw_client.py +1577 -0
- aeri/api/projects/types/__init__.py +65 -0
- aeri/api/projects/types/api_key_deletion_response.py +18 -0
- aeri/api/projects/types/api_key_list.py +23 -0
- aeri/api/projects/types/api_key_response.py +30 -0
- aeri/api/projects/types/api_key_summary.py +35 -0
- aeri/api/projects/types/organization.py +22 -0
- aeri/api/projects/types/project.py +34 -0
- aeri/api/projects/types/project_deletion_response.py +15 -0
- aeri/api/projects/types/projects.py +15 -0
- aeri/api/prompt_version/__init__.py +4 -0
- aeri/api/prompt_version/client.py +157 -0
- aeri/api/prompt_version/raw_client.py +264 -0
- aeri/api/prompts/__init__.py +100 -0
- aeri/api/prompts/client.py +550 -0
- aeri/api/prompts/raw_client.py +987 -0
- aeri/api/prompts/types/__init__.py +96 -0
- aeri/api/prompts/types/base_prompt.py +42 -0
- aeri/api/prompts/types/chat_message.py +17 -0
- aeri/api/prompts/types/chat_message_type.py +15 -0
- aeri/api/prompts/types/chat_message_with_placeholders.py +8 -0
- aeri/api/prompts/types/chat_prompt.py +15 -0
- aeri/api/prompts/types/create_chat_prompt_request.py +37 -0
- aeri/api/prompts/types/create_chat_prompt_type.py +15 -0
- aeri/api/prompts/types/create_prompt_request.py +8 -0
- aeri/api/prompts/types/create_text_prompt_request.py +36 -0
- aeri/api/prompts/types/create_text_prompt_type.py +15 -0
- aeri/api/prompts/types/placeholder_message.py +16 -0
- aeri/api/prompts/types/placeholder_message_type.py +15 -0
- aeri/api/prompts/types/prompt.py +58 -0
- aeri/api/prompts/types/prompt_meta.py +35 -0
- aeri/api/prompts/types/prompt_meta_list_response.py +17 -0
- aeri/api/prompts/types/prompt_type.py +20 -0
- aeri/api/prompts/types/text_prompt.py +14 -0
- aeri/api/scim/__init__.py +94 -0
- aeri/api/scim/client.py +686 -0
- aeri/api/scim/raw_client.py +1528 -0
- aeri/api/scim/types/__init__.py +92 -0
- aeri/api/scim/types/authentication_scheme.py +20 -0
- aeri/api/scim/types/bulk_config.py +22 -0
- aeri/api/scim/types/empty_response.py +16 -0
- aeri/api/scim/types/filter_config.py +17 -0
- aeri/api/scim/types/resource_meta.py +17 -0
- aeri/api/scim/types/resource_type.py +27 -0
- aeri/api/scim/types/resource_types_response.py +21 -0
- aeri/api/scim/types/schema_extension.py +17 -0
- aeri/api/scim/types/schema_resource.py +19 -0
- aeri/api/scim/types/schemas_response.py +21 -0
- aeri/api/scim/types/scim_email.py +16 -0
- aeri/api/scim/types/scim_feature_support.py +14 -0
- aeri/api/scim/types/scim_name.py +14 -0
- aeri/api/scim/types/scim_user.py +24 -0
- aeri/api/scim/types/scim_users_list_response.py +25 -0
- aeri/api/scim/types/service_provider_config.py +36 -0
- aeri/api/scim/types/user_meta.py +20 -0
- aeri/api/score_configs/__init__.py +44 -0
- aeri/api/score_configs/client.py +526 -0
- aeri/api/score_configs/raw_client.py +1012 -0
- aeri/api/score_configs/types/__init__.py +46 -0
- aeri/api/score_configs/types/create_score_config_request.py +46 -0
- aeri/api/score_configs/types/score_configs.py +17 -0
- aeri/api/score_configs/types/update_score_config_request.py +53 -0
- aeri/api/scores/__init__.py +76 -0
- aeri/api/scores/client.py +420 -0
- aeri/api/scores/raw_client.py +656 -0
- aeri/api/scores/types/__init__.py +76 -0
- aeri/api/scores/types/get_scores_response.py +17 -0
- aeri/api/scores/types/get_scores_response_data.py +211 -0
- aeri/api/scores/types/get_scores_response_data_boolean.py +15 -0
- aeri/api/scores/types/get_scores_response_data_categorical.py +15 -0
- aeri/api/scores/types/get_scores_response_data_correction.py +15 -0
- aeri/api/scores/types/get_scores_response_data_numeric.py +15 -0
- aeri/api/scores/types/get_scores_response_trace_data.py +38 -0
- aeri/api/sessions/__init__.py +40 -0
- aeri/api/sessions/client.py +262 -0
- aeri/api/sessions/raw_client.py +500 -0
- aeri/api/sessions/types/__init__.py +40 -0
- aeri/api/sessions/types/paginated_sessions.py +17 -0
- aeri/api/trace/__init__.py +44 -0
- aeri/api/trace/client.py +728 -0
- aeri/api/trace/raw_client.py +1208 -0
- aeri/api/trace/types/__init__.py +46 -0
- aeri/api/trace/types/delete_trace_response.py +14 -0
- aeri/api/trace/types/sort.py +14 -0
- aeri/api/trace/types/traces.py +17 -0
- aeri/api/utils/__init__.py +44 -0
- aeri/api/utils/pagination/__init__.py +40 -0
- aeri/api/utils/pagination/types/__init__.py +40 -0
- aeri/api/utils/pagination/types/meta_response.py +38 -0
- aeri/batch_evaluation.py +1643 -0
- aeri/experiment.py +1044 -0
- aeri/langchain/CallbackHandler.py +1377 -0
- aeri/langchain/__init__.py +5 -0
- aeri/langchain/utils.py +212 -0
- aeri/logger.py +28 -0
- aeri/media.py +352 -0
- aeri/model.py +477 -0
- aeri/openai.py +1124 -0
- aeri/py.typed +0 -0
- aeri/span_filter.py +17 -0
- aeri/types.py +79 -0
- aeri/version.py +3 -0
- aeri_python-4.0.0.dist-info/METADATA +51 -0
- aeri_python-4.0.0.dist-info/RECORD +391 -0
- aeri_python-4.0.0.dist-info/WHEEL +4 -0
- aeri_python-4.0.0.dist-info/licenses/LICENSE +21 -0
aeri/_client/client.py
ADDED
|
@@ -0,0 +1,3761 @@
|
|
|
1
|
+
"""Aeri OpenTelemetry integration module.
|
|
2
|
+
|
|
3
|
+
This module implements Aeri's core observability functionality on top of the OpenTelemetry (OTel) standard.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import asyncio
|
|
7
|
+
import logging
|
|
8
|
+
import os
|
|
9
|
+
import re
|
|
10
|
+
import urllib.parse
|
|
11
|
+
import warnings
|
|
12
|
+
from datetime import datetime
|
|
13
|
+
from hashlib import sha256
|
|
14
|
+
from time import time_ns
|
|
15
|
+
from typing import (
|
|
16
|
+
Any,
|
|
17
|
+
Callable,
|
|
18
|
+
Dict,
|
|
19
|
+
List,
|
|
20
|
+
Literal,
|
|
21
|
+
Optional,
|
|
22
|
+
Type,
|
|
23
|
+
Union,
|
|
24
|
+
cast,
|
|
25
|
+
overload,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
import backoff
|
|
29
|
+
import httpx
|
|
30
|
+
from opentelemetry import trace as otel_trace_api
|
|
31
|
+
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
|
|
32
|
+
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
|
|
33
|
+
from opentelemetry.util._decorator import (
|
|
34
|
+
_AgnosticContextManager,
|
|
35
|
+
_agnosticcontextmanager,
|
|
36
|
+
)
|
|
37
|
+
from packaging.version import Version
|
|
38
|
+
from typing_extensions import deprecated
|
|
39
|
+
|
|
40
|
+
from aeri._client.attributes import AeriOtelSpanAttributes, _serialize
|
|
41
|
+
from aeri._client.constants import (
|
|
42
|
+
AERI_SDK_EXPERIMENT_ENVIRONMENT,
|
|
43
|
+
ObservationTypeGenerationLike,
|
|
44
|
+
ObservationTypeLiteral,
|
|
45
|
+
ObservationTypeLiteralNoEvent,
|
|
46
|
+
ObservationTypeSpanLike,
|
|
47
|
+
get_observation_types_list,
|
|
48
|
+
)
|
|
49
|
+
from aeri._client.datasets import DatasetClient
|
|
50
|
+
from aeri._client.environment_variables import (
|
|
51
|
+
AERI_BASE_URL,
|
|
52
|
+
AERI_DEBUG,
|
|
53
|
+
AERI_HOST,
|
|
54
|
+
AERI_PUBLIC_KEY,
|
|
55
|
+
AERI_RELEASE,
|
|
56
|
+
AERI_SAMPLE_RATE,
|
|
57
|
+
AERI_SECRET_KEY,
|
|
58
|
+
AERI_TIMEOUT,
|
|
59
|
+
AERI_TRACING_ENABLED,
|
|
60
|
+
AERI_TRACING_ENVIRONMENT,
|
|
61
|
+
)
|
|
62
|
+
from aeri._client.propagation import (
|
|
63
|
+
PropagatedExperimentAttributes,
|
|
64
|
+
_propagate_attributes,
|
|
65
|
+
)
|
|
66
|
+
from aeri._client.resource_manager import AeriResourceManager
|
|
67
|
+
from aeri._client.span import (
|
|
68
|
+
AeriAgent,
|
|
69
|
+
AeriChain,
|
|
70
|
+
AeriEmbedding,
|
|
71
|
+
AeriEvaluator,
|
|
72
|
+
AeriEvent,
|
|
73
|
+
AeriGeneration,
|
|
74
|
+
AeriGuardrail,
|
|
75
|
+
AeriRetriever,
|
|
76
|
+
AeriSpan,
|
|
77
|
+
AeriTool,
|
|
78
|
+
)
|
|
79
|
+
from aeri._client._validation import AeriClientConfig, ObservationInput, ScoreInput, TraceContextInput
|
|
80
|
+
from aeri._client.utils import get_sha256_hash_hex, run_async_safely
|
|
81
|
+
from aeri._utils import _get_timestamp
|
|
82
|
+
from aeri._utils.environment import get_common_release_envs
|
|
83
|
+
from aeri._utils.parse_error import handle_fern_exception
|
|
84
|
+
from aeri._utils.prompt_cache import PromptCache
|
|
85
|
+
from aeri.api import (
|
|
86
|
+
CreateChatPromptRequest,
|
|
87
|
+
CreateChatPromptType,
|
|
88
|
+
CreateTextPromptRequest,
|
|
89
|
+
Dataset,
|
|
90
|
+
DatasetItem,
|
|
91
|
+
DatasetRunWithItems,
|
|
92
|
+
DatasetStatus,
|
|
93
|
+
DeleteDatasetRunResponse,
|
|
94
|
+
Error,
|
|
95
|
+
MapValue,
|
|
96
|
+
NotFoundError,
|
|
97
|
+
PaginatedDatasetRuns,
|
|
98
|
+
Prompt_Chat,
|
|
99
|
+
Prompt_Text,
|
|
100
|
+
ScoreBody,
|
|
101
|
+
TraceBody,
|
|
102
|
+
)
|
|
103
|
+
from aeri.batch_evaluation import (
|
|
104
|
+
BatchEvaluationResult,
|
|
105
|
+
BatchEvaluationResumeToken,
|
|
106
|
+
BatchEvaluationRunner,
|
|
107
|
+
CompositeEvaluatorFunction,
|
|
108
|
+
MapperFunction,
|
|
109
|
+
)
|
|
110
|
+
from aeri.experiment import (
|
|
111
|
+
Evaluation,
|
|
112
|
+
EvaluatorFunction,
|
|
113
|
+
ExperimentData,
|
|
114
|
+
ExperimentItem,
|
|
115
|
+
ExperimentItemResult,
|
|
116
|
+
ExperimentResult,
|
|
117
|
+
RunEvaluatorFunction,
|
|
118
|
+
TaskFunction,
|
|
119
|
+
_run_evaluator,
|
|
120
|
+
_run_task,
|
|
121
|
+
)
|
|
122
|
+
from aeri.logger import aeri_logger
|
|
123
|
+
from aeri.media import AeriMedia
|
|
124
|
+
from aeri.model import (
|
|
125
|
+
ChatMessageDict,
|
|
126
|
+
ChatMessageWithPlaceholdersDict,
|
|
127
|
+
ChatPromptClient,
|
|
128
|
+
PromptClient,
|
|
129
|
+
TextPromptClient,
|
|
130
|
+
)
|
|
131
|
+
from aeri.types import MaskFunction, ScoreDataType, SpanLevel, TraceContext
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class Aeri:
|
|
135
|
+
"""Main client for Aeri tracing and platform features.
|
|
136
|
+
|
|
137
|
+
This class provides an interface for creating and managing traces, spans,
|
|
138
|
+
and generations in Aeri as well as interacting with the Aeri API.
|
|
139
|
+
|
|
140
|
+
The client features a thread-safe singleton pattern for each unique public API key,
|
|
141
|
+
ensuring consistent trace context propagation across your application. It implements
|
|
142
|
+
efficient batching of spans with configurable flush settings and includes background
|
|
143
|
+
thread management for media uploads and score ingestion.
|
|
144
|
+
|
|
145
|
+
Configuration is flexible through either direct parameters or environment variables,
|
|
146
|
+
with graceful fallbacks and runtime configuration updates.
|
|
147
|
+
|
|
148
|
+
Attributes:
|
|
149
|
+
api: Synchronous API client for Aeri backend communication
|
|
150
|
+
async_api: Asynchronous API client for Aeri backend communication
|
|
151
|
+
_otel_tracer: Internal AeriTracer instance managing OpenTelemetry components
|
|
152
|
+
|
|
153
|
+
Parameters:
|
|
154
|
+
public_key (Optional[str]): Your Aeri public API key. Can also be set via AERI_PUBLIC_KEY environment variable.
|
|
155
|
+
secret_key (Optional[str]): Your Aeri secret API key. Can also be set via AERI_SECRET_KEY environment variable.
|
|
156
|
+
base_url (Optional[str]): The Aeri API base URL. Defaults to "https://api.aeri.com" (the fallback used when neither the parameter nor the AERI_BASE_URL environment variable is set). Can also be set via AERI_BASE_URL environment variable.
|
|
157
|
+
host (Optional[str]): Deprecated. Use base_url instead. The Aeri API host URL. Defaults to "https://api.aeri.com" when no other URL source is provided.
|
|
158
|
+
timeout (Optional[int]): Timeout in seconds for API requests. Defaults to 5 seconds.
|
|
159
|
+
httpx_client (Optional[httpx.Client]): Custom httpx client for making non-tracing HTTP requests. If not provided, a default client will be created.
|
|
160
|
+
debug (bool): Enable debug logging. Defaults to False. Can also be set via AERI_DEBUG environment variable.
|
|
161
|
+
tracing_enabled (Optional[bool]): Enable or disable tracing. Defaults to True. Can also be set via AERI_TRACING_ENABLED environment variable.
|
|
162
|
+
flush_at (Optional[int]): Number of spans to batch before sending to the API. Defaults to 512. Can also be set via AERI_FLUSH_AT environment variable.
|
|
163
|
+
flush_interval (Optional[float]): Time in seconds between batch flushes. Defaults to 5 seconds. Can also be set via AERI_FLUSH_INTERVAL environment variable.
|
|
164
|
+
environment (Optional[str]): Environment name for tracing. Default is 'default'. Can also be set via AERI_TRACING_ENVIRONMENT environment variable. Can be any lowercase alphanumeric string with hyphens and underscores that does not start with 'aeri'.
|
|
165
|
+
release (Optional[str]): Release version/hash of your application. Used for grouping analytics by release.
|
|
166
|
+
media_upload_thread_count (Optional[int]): Number of background threads for handling media uploads. Defaults to 1. Can also be set via AERI_MEDIA_UPLOAD_THREAD_COUNT environment variable.
|
|
167
|
+
sample_rate (Optional[float]): Sampling rate for traces (0.0 to 1.0). Defaults to 1.0 (100% of traces are sampled). Can also be set via AERI_SAMPLE_RATE environment variable.
|
|
168
|
+
mask (Optional[MaskFunction]): Function to mask sensitive data in traces before sending to the API.
|
|
169
|
+
blocked_instrumentation_scopes (Optional[List[str]]): Deprecated. Use `should_export_span` instead. Equivalent behavior:
|
|
170
|
+
```python
|
|
171
|
+
from aeri.span_filter import is_default_export_span
|
|
172
|
+
blocked = {"sqlite", "requests"}
|
|
173
|
+
|
|
174
|
+
should_export_span = lambda span: (
|
|
175
|
+
is_default_export_span(span)
|
|
176
|
+
and (
|
|
177
|
+
span.instrumentation_scope is None
|
|
178
|
+
or span.instrumentation_scope.name not in blocked
|
|
179
|
+
)
|
|
180
|
+
)
|
|
181
|
+
```
|
|
182
|
+
should_export_span (Optional[Callable[[ReadableSpan], bool]]): Callback to decide whether to export a span. If omitted, Aeri uses the default filter (Aeri SDK spans, spans with `gen_ai.*` attributes, and known LLM instrumentation scopes).
|
|
183
|
+
additional_headers (Optional[Dict[str, str]]): Additional headers to include in all API requests and OTLPSpanExporter requests. These headers will be merged with default headers. Note: If httpx_client is provided, additional_headers must be set directly on your custom httpx_client as well.
|
|
184
|
+
tracer_provider(Optional[TracerProvider]): OpenTelemetry TracerProvider to use for Aeri. This can be useful to set to have disconnected tracing between Aeri and other OpenTelemetry-span emitting libraries. Note: To track active spans, the context is still shared between TracerProviders. This may lead to broken trace trees.
|
|
185
|
+
|
|
186
|
+
Example:
|
|
187
|
+
```python
|
|
188
|
+
from aeri.otel import Aeri
|
|
189
|
+
|
|
190
|
+
# Initialize the client (reads from env vars if not provided)
|
|
191
|
+
aeri = Aeri(
|
|
192
|
+
public_key="your-public-key",
|
|
193
|
+
secret_key="your-secret-key",
|
|
194
|
+
host="https://api.aeri.com", # Optional, default shown
|
|
195
|
+
)
|
|
196
|
+
|
|
197
|
+
# Create a trace span
|
|
198
|
+
with aeri.start_as_current_observation(name="process-query") as span:
|
|
199
|
+
# Your application code here
|
|
200
|
+
|
|
201
|
+
# Create a nested generation span for an LLM call
|
|
202
|
+
with span.start_as_current_generation(
|
|
203
|
+
name="generate-response",
|
|
204
|
+
model="gpt-4",
|
|
205
|
+
input={"query": "Tell me about AI"},
|
|
206
|
+
model_parameters={"temperature": 0.7, "max_tokens": 500}
|
|
207
|
+
) as generation:
|
|
208
|
+
# Generate response here
|
|
209
|
+
response = "AI is a field of computer science..."
|
|
210
|
+
|
|
211
|
+
generation.update(
|
|
212
|
+
output=response,
|
|
213
|
+
usage_details={"prompt_tokens": 10, "completion_tokens": 50},
|
|
214
|
+
cost_details={"total_cost": 0.0023}
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
# Score the generation (supports NUMERIC, BOOLEAN, CATEGORICAL)
|
|
218
|
+
generation.score(name="relevance", value=0.95, data_type="NUMERIC")
|
|
219
|
+
```
|
|
220
|
+
"""
|
|
221
|
+
|
|
222
|
+
_resources: Optional[AeriResourceManager] = None
|
|
223
|
+
_mask: Optional[MaskFunction] = None
|
|
224
|
+
_otel_tracer: otel_trace_api.Tracer
|
|
225
|
+
|
|
226
|
+
def __init__(
    self,
    *,
    public_key: Optional[str] = None,
    secret_key: Optional[str] = None,
    base_url: Optional[str] = None,
    host: Optional[str] = None,
    timeout: Optional[int] = None,
    httpx_client: Optional[httpx.Client] = None,
    debug: bool = False,
    tracing_enabled: Optional[bool] = True,
    flush_at: Optional[int] = None,
    flush_interval: Optional[float] = None,
    environment: Optional[str] = None,
    release: Optional[str] = None,
    media_upload_thread_count: Optional[int] = None,
    sample_rate: Optional[float] = None,
    mask: Optional[MaskFunction] = None,
    blocked_instrumentation_scopes: Optional[List[str]] = None,
    should_export_span: Optional[Callable[[ReadableSpan], bool]] = None,
    additional_headers: Optional[Dict[str, str]] = None,
    tracer_provider: Optional[TracerProvider] = None,
):
    """Initialize the Aeri client.

    Arguments left as ``None`` are resolved from their corresponding
    ``AERI_*`` environment variables. If ``public_key`` or ``secret_key``
    cannot be resolved, the client is left disabled: a NoOpTracer is
    installed, no resources are created, and the method returns early.
    """
    # ── Resolve raw values from kwargs + env vars ────────────────────────
    # Precedence: explicit base_url > AERI_BASE_URL env > explicit host >
    # AERI_HOST env > hard-coded default.
    resolved_base_url = (
        base_url
        or os.environ.get(AERI_BASE_URL)
        or host
        or os.environ.get(AERI_HOST, "https://api.aeri.com")
    )
    # ── Strict Pydantic V2 validation of constructor config ──────────────
    # NOTE(review): only _cfg.base_url is consumed below; the remaining
    # fields are re-resolved from kwargs/env further down, so this call
    # mainly acts as an early validation gate — confirm the duplication is
    # intentional.
    _cfg = AeriClientConfig.model_validate(
        {
            "public_key": public_key or os.environ.get(AERI_PUBLIC_KEY),
            "secret_key": secret_key or os.environ.get(AERI_SECRET_KEY),
            "base_url": resolved_base_url,
            "timeout": timeout or int(os.environ.get(AERI_TIMEOUT, 5)),
            "debug": debug or (os.getenv(AERI_DEBUG, "false").lower() == "true"),
            "tracing_enabled": (
                tracing_enabled
                and os.environ.get(AERI_TRACING_ENABLED, "true").lower() != "false"
            ),
            "flush_at": flush_at or 15,
            "flush_interval": flush_interval or 1.0,
            "environment": environment or os.environ.get(AERI_TRACING_ENVIRONMENT),
            "release": release or os.environ.get(AERI_RELEASE),
            "media_upload_thread_count": media_upload_thread_count or 1,
            "sample_rate": sample_rate or float(os.environ.get(AERI_SAMPLE_RATE, 1.0)),
        }
    )

    self._base_url = _cfg.base_url
    # NOTE(review): the cast claims str, but the value is None when neither
    # the argument nor AERI_TRACING_ENVIRONMENT is set.
    self._environment = environment or cast(
        str, os.environ.get(AERI_TRACING_ENVIRONMENT)
    )
    self._release = (
        release
        or os.environ.get(AERI_RELEASE, None)
        or get_common_release_envs()
    )
    self._project_id: Optional[str] = None
    sample_rate = sample_rate or float(os.environ.get(AERI_SAMPLE_RATE, 1.0))
    if not 0.0 <= sample_rate <= 1.0:
        raise ValueError(
            f"Sample rate must be between 0.0 and 1.0, got {sample_rate}"
        )

    timeout = timeout or int(os.environ.get(AERI_TIMEOUT, 5))

    # Tracing is on only when both the argument and the env var allow it.
    self._tracing_enabled = (
        tracing_enabled
        and os.environ.get(AERI_TRACING_ENABLED, "true").lower() != "false"
    )
    if not self._tracing_enabled:
        aeri_logger.info(
            "Configuration: Aeri tracing is explicitly disabled. No data will be sent to the Aeri API."
        )

    debug = (
        debug if debug else (os.getenv(AERI_DEBUG, "false").lower() == "true")
    )
    if debug:
        logging.basicConfig(
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
        aeri_logger.setLevel(logging.DEBUG)

    # Missing credentials disable the client rather than raising, so host
    # applications keep working when Aeri is not configured.
    public_key = public_key or os.environ.get(AERI_PUBLIC_KEY)
    if public_key is None:
        aeri_logger.warning(
            "Authentication error: Aeri client initialized without public_key. Client will be disabled. "
            "Provide a public_key parameter or set AERI_PUBLIC_KEY environment variable. "
        )
        self._otel_tracer = otel_trace_api.NoOpTracer()
        return

    secret_key = secret_key or os.environ.get(AERI_SECRET_KEY)
    if secret_key is None:
        aeri_logger.warning(
            "Authentication error: Aeri client initialized without secret_key. Client will be disabled. "
            "Provide a secret_key parameter or set AERI_SECRET_KEY environment variable. "
        )
        self._otel_tracer = otel_trace_api.NoOpTracer()
        return

    # A globally disabled OpenTelemetry SDK suppresses all spans; warn so
    # the user knows why nothing appears in the UI.
    if os.environ.get("OTEL_SDK_DISABLED", "false").lower() == "true":
        aeri_logger.warning(
            "OTEL_SDK_DISABLED is set. Aeri tracing will be disabled and no traces will appear in the UI."
        )

    if blocked_instrumentation_scopes is not None:
        warnings.warn(
            "`blocked_instrumentation_scopes` is deprecated and will be removed in a future release. "
            "Use `should_export_span` instead. Example: "
            "from aeri.span_filter import is_default_export_span; "
            'blocked={"scope"}; should_export_span=lambda span: '
            "is_default_export_span(span) and (span.instrumentation_scope is None or "
            "span.instrumentation_scope.name not in blocked).",
            DeprecationWarning,
            stacklevel=2,
        )

    # Initialize api and tracer if requirements are met
    self._resources = AeriResourceManager(
        public_key=public_key,
        secret_key=secret_key,
        base_url=self._base_url,
        timeout=timeout,
        environment=self._environment,
        release=release,
        flush_at=flush_at,
        flush_interval=flush_interval,
        httpx_client=httpx_client,
        media_upload_thread_count=media_upload_thread_count,
        sample_rate=sample_rate,
        mask=mask,
        tracing_enabled=self._tracing_enabled,
        blocked_instrumentation_scopes=blocked_instrumentation_scopes,
        should_export_span=should_export_span,
        additional_headers=additional_headers,
        tracer_provider=tracer_provider,
    )
    self._mask = self._resources.mask

    # Fall back to a NoOpTracer when tracing is off or no tracer was built.
    self._otel_tracer = (
        self._resources.tracer
        if self._tracing_enabled and self._resources.tracer is not None
        else otel_trace_api.NoOpTracer()
    )
    self.api = self._resources.api
    self.async_api = self._resources.async_api
# Overloads: the concrete return type of start_observation is narrowed by the
# Literal value passed for ``as_type``.
@overload
def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["generation"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    completion_start_time: Optional[datetime] = None,
    model: Optional[str] = None,
    model_parameters: Optional[Dict[str, MapValue]] = None,
    usage_details: Optional[Dict[str, int]] = None,
    cost_details: Optional[Dict[str, float]] = None,
    prompt: Optional[PromptClient] = None,
) -> AeriGeneration: ...

@overload
def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["span"] = "span",
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
) -> AeriSpan: ...

@overload
def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["agent"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
) -> AeriAgent: ...

@overload
def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["tool"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
) -> AeriTool: ...

@overload
def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["chain"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
) -> AeriChain: ...

@overload
def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["retriever"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
) -> AeriRetriever: ...

@overload
def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["evaluator"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
) -> AeriEvaluator: ...

@overload
def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["embedding"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    completion_start_time: Optional[datetime] = None,
    model: Optional[str] = None,
    model_parameters: Optional[Dict[str, MapValue]] = None,
    usage_details: Optional[Dict[str, int]] = None,
    cost_details: Optional[Dict[str, float]] = None,
    prompt: Optional[PromptClient] = None,
) -> AeriEmbedding: ...

@overload
def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["guardrail"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
) -> AeriGuardrail: ...

def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: ObservationTypeLiteralNoEvent = "span",
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    completion_start_time: Optional[datetime] = None,
    model: Optional[str] = None,
    model_parameters: Optional[Dict[str, MapValue]] = None,
    usage_details: Optional[Dict[str, int]] = None,
    cost_details: Optional[Dict[str, float]] = None,
    prompt: Optional[PromptClient] = None,
) -> Union[
    AeriSpan,
    AeriGeneration,
    AeriAgent,
    AeriTool,
    AeriChain,
    AeriRetriever,
    AeriEvaluator,
    AeriEmbedding,
    AeriGuardrail,
]:
    """Create a new observation of the specified type.

    This method creates a new observation but does not set it as the current span in the
    context. To create and use an observation within a context, use start_as_current_observation().

    Args:
        trace_context: Optional context for connecting to an existing trace
        name: Name of the observation
        as_type: Type of observation to create (defaults to "span")
        input: Input data for the operation
        output: Output data from the operation
        metadata: Additional metadata to associate with the observation
        version: Version identifier for the code or component
        level: Importance level of the observation
        status_message: Optional status message for the observation
        completion_start_time: When the model started generating (for generation types)
        model: Name/identifier of the AI model used (for generation types)
        model_parameters: Parameters used for the model (for generation types)
        usage_details: Token usage information (for generation types)
        cost_details: Cost information (for generation types)
        prompt: Associated prompt template (for generation types)

    Returns:
        An observation object of the appropriate type that must be ended with .end()
    """
    if trace_context:
        trace_id = trace_context.get("trace_id", None)
        parent_span_id = trace_context.get("parent_span_id", None)

        if trace_id:
            # Connect to the caller-supplied trace: build a non-recording
            # remote parent and start the new span under it.
            remote_parent_span = self._create_remote_parent_span(
                trace_id=trace_id, parent_span_id=parent_span_id
            )

            with otel_trace_api.use_span(
                cast(otel_trace_api.Span, remote_parent_span)
            ):
                otel_span = self._otel_tracer.start_span(name=name)
                # Mark the span AS_ROOT since its parent lives outside this
                # process.
                otel_span.set_attribute(AeriOtelSpanAttributes.AS_ROOT, True)

                return self._create_observation_from_otel_span(
                    otel_span=otel_span,
                    as_type=as_type,
                    input=input,
                    output=output,
                    metadata=metadata,
                    version=version,
                    level=level,
                    status_message=status_message,
                    completion_start_time=completion_start_time,
                    model=model,
                    model_parameters=model_parameters,
                    usage_details=usage_details,
                    cost_details=cost_details,
                    prompt=prompt,
                )

    # No (usable) trace_context: parent from the current OTEL context.
    otel_span = self._otel_tracer.start_span(name=name)

    return self._create_observation_from_otel_span(
        otel_span=otel_span,
        as_type=as_type,
        input=input,
        output=output,
        metadata=metadata,
        version=version,
        level=level,
        status_message=status_message,
        completion_start_time=completion_start_time,
        model=model,
        model_parameters=model_parameters,
        usage_details=usage_details,
        cost_details=cost_details,
        prompt=prompt,
    )
def _create_observation_from_otel_span(
    self,
    *,
    otel_span: otel_trace_api.Span,
    as_type: ObservationTypeLiteralNoEvent,
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    completion_start_time: Optional[datetime] = None,
    model: Optional[str] = None,
    model_parameters: Optional[Dict[str, MapValue]] = None,
    usage_details: Optional[Dict[str, int]] = None,
    cost_details: Optional[Dict[str, float]] = None,
    prompt: Optional[PromptClient] = None,
) -> Union[
    AeriSpan,
    AeriGeneration,
    AeriAgent,
    AeriTool,
    AeriChain,
    AeriRetriever,
    AeriEvaluator,
    AeriEmbedding,
    AeriGuardrail,
]:
    """Create the appropriate observation type from an OTEL span.

    Wraps an already-started OTEL span in the observation class matching
    ``as_type``. Generation-like types additionally receive the model,
    usage, cost and prompt details; for all other types those arguments
    are ignored.

    Args:
        otel_span: The underlying OTEL span to wrap.
        as_type: Observation type selecting the wrapper class.
        input/output/metadata/version/level/status_message: Common
            observation attributes forwarded to the wrapper.
        completion_start_time/model/model_parameters/usage_details/
            cost_details/prompt: Generation-only attributes.

    Returns:
        The constructed observation wrapper.
    """
    observation_class = self._get_span_class(as_type)

    # Attributes accepted by every observation class.
    kwargs: Dict[str, Any] = dict(
        otel_span=otel_span,
        aeri_client=self,
        environment=self._environment,
        release=self._release,
        input=input,
        output=output,
        metadata=metadata,
        version=version,
        level=level,
        status_message=status_message,
    )
    if as_type in get_observation_types_list(ObservationTypeGenerationLike):
        # Generation-like classes take extra model/usage/cost arguments.
        kwargs.update(
            completion_start_time=completion_start_time,
            model=model,
            model_parameters=model_parameters,
            usage_details=usage_details,
            cost_details=cost_details,
            prompt=prompt,
        )

    # Type ignore to prevent overloads of internal _get_span_class function,
    # issue is that AeriEvent could be returned and that classes have diff. args
    return observation_class(**kwargs)  # type: ignore[return-value,call-arg]
# Overloads: the concrete context-manager type yielded by
# start_as_current_observation is narrowed by the Literal value of ``as_type``.
@overload
def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["generation"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    completion_start_time: Optional[datetime] = None,
    model: Optional[str] = None,
    model_parameters: Optional[Dict[str, MapValue]] = None,
    usage_details: Optional[Dict[str, int]] = None,
    cost_details: Optional[Dict[str, float]] = None,
    prompt: Optional[PromptClient] = None,
    end_on_exit: Optional[bool] = None,
) -> _AgnosticContextManager[AeriGeneration]: ...

@overload
def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["span"] = "span",
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    end_on_exit: Optional[bool] = None,
) -> _AgnosticContextManager[AeriSpan]: ...

@overload
def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["agent"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    end_on_exit: Optional[bool] = None,
) -> _AgnosticContextManager[AeriAgent]: ...

@overload
def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["tool"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    end_on_exit: Optional[bool] = None,
) -> _AgnosticContextManager[AeriTool]: ...

@overload
def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["chain"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    end_on_exit: Optional[bool] = None,
) -> _AgnosticContextManager[AeriChain]: ...

@overload
def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["retriever"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    end_on_exit: Optional[bool] = None,
) -> _AgnosticContextManager[AeriRetriever]: ...

@overload
def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["evaluator"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    end_on_exit: Optional[bool] = None,
) -> _AgnosticContextManager[AeriEvaluator]: ...

@overload
def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["embedding"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    completion_start_time: Optional[datetime] = None,
    model: Optional[str] = None,
    model_parameters: Optional[Dict[str, MapValue]] = None,
    usage_details: Optional[Dict[str, int]] = None,
    cost_details: Optional[Dict[str, float]] = None,
    prompt: Optional[PromptClient] = None,
    end_on_exit: Optional[bool] = None,
) -> _AgnosticContextManager[AeriEmbedding]: ...

@overload
def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["guardrail"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    end_on_exit: Optional[bool] = None,
) -> _AgnosticContextManager[AeriGuardrail]: ...

def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: ObservationTypeLiteralNoEvent = "span",
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    completion_start_time: Optional[datetime] = None,
    model: Optional[str] = None,
    model_parameters: Optional[Dict[str, MapValue]] = None,
    usage_details: Optional[Dict[str, int]] = None,
    cost_details: Optional[Dict[str, float]] = None,
    prompt: Optional[PromptClient] = None,
    end_on_exit: Optional[bool] = None,
) -> Union[
    _AgnosticContextManager[AeriGeneration],
    _AgnosticContextManager[AeriSpan],
    _AgnosticContextManager[AeriAgent],
    _AgnosticContextManager[AeriTool],
    _AgnosticContextManager[AeriChain],
    _AgnosticContextManager[AeriRetriever],
    _AgnosticContextManager[AeriEvaluator],
    _AgnosticContextManager[AeriEmbedding],
    _AgnosticContextManager[AeriGuardrail],
]:
    """Create a new observation and set it as the current span in a context manager.

    This method creates a new observation of the specified type and sets it as the
    current span within a context manager. Use this method with a 'with' statement to
    automatically handle the observation lifecycle within a code block.

    The created observation will be the child of the current span in the context.

    Args:
        trace_context: Optional context for connecting to an existing trace
        name: Name of the observation (e.g., function or operation name)
        as_type: Type of observation to create (defaults to "span")
        input: Input data for the operation (can be any JSON-serializable object)
        output: Output data from the operation (can be any JSON-serializable object)
        metadata: Additional metadata to associate with the observation
        version: Version identifier for the code or component
        level: Importance level of the observation (info, warning, error)
        status_message: Optional status message for the observation
        end_on_exit (default: True): Whether to end the span automatically when leaving the context manager. If False, the span must be manually ended to avoid memory leaks.

    The following parameters are available when as_type is: "generation" or "embedding".
        completion_start_time: When the model started generating the response
        model: Name/identifier of the AI model used (e.g., "gpt-4")
        model_parameters: Parameters used for the model (e.g., temperature, max_tokens)
        usage_details: Token usage information (e.g., prompt_tokens, completion_tokens)
        cost_details: Cost information for the model call
        prompt: Associated prompt template from Aeri prompt management

    Returns:
        A context manager that yields the appropriate observation type based on as_type

    Example:
        ```python
        # Create a span
        with aeri.start_as_current_observation(name="process-query", as_type="span") as span:
            # Do work
            result = process_data()
            span.update(output=result)

            # Create a child span automatically
            with span.start_as_current_observation(name="sub-operation") as child_span:
                # Do sub-operation work
                child_span.update(output="sub-result")

        # Create a tool observation
        with aeri.start_as_current_observation(name="web-search", as_type="tool") as tool:
            # Do tool work
            results = search_web(query)
            tool.update(output=results)

        # Create a generation observation
        with aeri.start_as_current_observation(
            name="answer-generation",
            as_type="generation",
            model="gpt-4"
        ) as generation:
            # Generate answer
            response = llm.generate(...)
            generation.update(output=response)
        ```
    """
    # Generation-like types carry the extra model/usage/cost arguments.
    if as_type in get_observation_types_list(ObservationTypeGenerationLike):
        if trace_context:
            trace_id = trace_context.get("trace_id", None)
            parent_span_id = trace_context.get("parent_span_id", None)

            if trace_id:
                # Parent the new span under the caller-supplied remote trace.
                remote_parent_span = self._create_remote_parent_span(
                    trace_id=trace_id, parent_span_id=parent_span_id
                )

                return cast(
                    Union[
                        _AgnosticContextManager[AeriGeneration],
                        _AgnosticContextManager[AeriEmbedding],
                    ],
                    self._create_span_with_parent_context(
                        as_type=as_type,
                        name=name,
                        remote_parent_span=remote_parent_span,
                        parent=None,
                        end_on_exit=end_on_exit,
                        input=input,
                        output=output,
                        metadata=metadata,
                        version=version,
                        level=level,
                        status_message=status_message,
                        completion_start_time=completion_start_time,
                        model=model,
                        model_parameters=model_parameters,
                        usage_details=usage_details,
                        cost_details=cost_details,
                        prompt=prompt,
                    ),
                )

        # No (usable) trace_context: parent from the current OTEL context.
        return cast(
            Union[
                _AgnosticContextManager[AeriGeneration],
                _AgnosticContextManager[AeriEmbedding],
            ],
            self._start_as_current_otel_span_with_processed_media(
                as_type=as_type,
                name=name,
                end_on_exit=end_on_exit,
                input=input,
                output=output,
                metadata=metadata,
                version=version,
                level=level,
                status_message=status_message,
                completion_start_time=completion_start_time,
                model=model,
                model_parameters=model_parameters,
                usage_details=usage_details,
                cost_details=cost_details,
                prompt=prompt,
            ),
        )

    # Span-like types take only the common observation attributes.
    if as_type in get_observation_types_list(ObservationTypeSpanLike):
        if trace_context:
            trace_id = trace_context.get("trace_id", None)
            parent_span_id = trace_context.get("parent_span_id", None)

            if trace_id:
                remote_parent_span = self._create_remote_parent_span(
                    trace_id=trace_id, parent_span_id=parent_span_id
                )

                return cast(
                    Union[
                        _AgnosticContextManager[AeriSpan],
                        _AgnosticContextManager[AeriAgent],
                        _AgnosticContextManager[AeriTool],
                        _AgnosticContextManager[AeriChain],
                        _AgnosticContextManager[AeriRetriever],
                        _AgnosticContextManager[AeriEvaluator],
                        _AgnosticContextManager[AeriGuardrail],
                    ],
                    self._create_span_with_parent_context(
                        as_type=as_type,
                        name=name,
                        remote_parent_span=remote_parent_span,
                        parent=None,
                        end_on_exit=end_on_exit,
                        input=input,
                        output=output,
                        metadata=metadata,
                        version=version,
                        level=level,
                        status_message=status_message,
                    ),
                )

        return cast(
            Union[
                _AgnosticContextManager[AeriSpan],
                _AgnosticContextManager[AeriAgent],
                _AgnosticContextManager[AeriTool],
                _AgnosticContextManager[AeriChain],
                _AgnosticContextManager[AeriRetriever],
                _AgnosticContextManager[AeriEvaluator],
                _AgnosticContextManager[AeriGuardrail],
            ],
            self._start_as_current_otel_span_with_processed_media(
                as_type=as_type,
                name=name,
                end_on_exit=end_on_exit,
                input=input,
                output=output,
                metadata=metadata,
                version=version,
                level=level,
                status_message=status_message,
            ),
        )

    # This should never be reached since all valid types are handled above
    aeri_logger.warning(
        f"Unknown observation type: {as_type}, falling back to span"
    )
    return self._start_as_current_otel_span_with_processed_media(
        as_type="span",
        name=name,
        end_on_exit=end_on_exit,
        input=input,
        output=output,
        metadata=metadata,
        version=version,
        level=level,
        status_message=status_message,
    )
def _get_span_class(
|
|
1084
|
+
self,
|
|
1085
|
+
as_type: ObservationTypeLiteral,
|
|
1086
|
+
) -> Union[
|
|
1087
|
+
Type[AeriAgent],
|
|
1088
|
+
Type[AeriTool],
|
|
1089
|
+
Type[AeriChain],
|
|
1090
|
+
Type[AeriRetriever],
|
|
1091
|
+
Type[AeriEvaluator],
|
|
1092
|
+
Type[AeriEmbedding],
|
|
1093
|
+
Type[AeriGuardrail],
|
|
1094
|
+
Type[AeriGeneration],
|
|
1095
|
+
Type[AeriEvent],
|
|
1096
|
+
Type[AeriSpan],
|
|
1097
|
+
]:
|
|
1098
|
+
"""Get the appropriate span class based on as_type."""
|
|
1099
|
+
normalized_type = as_type.lower()
|
|
1100
|
+
|
|
1101
|
+
if normalized_type == "agent":
|
|
1102
|
+
return AeriAgent
|
|
1103
|
+
elif normalized_type == "tool":
|
|
1104
|
+
return AeriTool
|
|
1105
|
+
elif normalized_type == "chain":
|
|
1106
|
+
return AeriChain
|
|
1107
|
+
elif normalized_type == "retriever":
|
|
1108
|
+
return AeriRetriever
|
|
1109
|
+
elif normalized_type == "evaluator":
|
|
1110
|
+
return AeriEvaluator
|
|
1111
|
+
elif normalized_type == "embedding":
|
|
1112
|
+
return AeriEmbedding
|
|
1113
|
+
elif normalized_type == "guardrail":
|
|
1114
|
+
return AeriGuardrail
|
|
1115
|
+
elif normalized_type == "generation":
|
|
1116
|
+
return AeriGeneration
|
|
1117
|
+
elif normalized_type == "event":
|
|
1118
|
+
return AeriEvent
|
|
1119
|
+
elif normalized_type == "span":
|
|
1120
|
+
return AeriSpan
|
|
1121
|
+
else:
|
|
1122
|
+
return AeriSpan
|
|
1123
|
+
|
|
1124
|
+
    @_agnosticcontextmanager
    def _create_span_with_parent_context(
        self,
        *,
        name: str,
        parent: Optional[otel_trace_api.Span] = None,
        remote_parent_span: Optional[otel_trace_api.Span] = None,
        as_type: ObservationTypeLiteralNoEvent,
        end_on_exit: Optional[bool] = None,
        input: Optional[Any] = None,
        output: Optional[Any] = None,
        metadata: Optional[Any] = None,
        version: Optional[str] = None,
        level: Optional[SpanLevel] = None,
        status_message: Optional[str] = None,
        completion_start_time: Optional[datetime] = None,
        model: Optional[str] = None,
        model_parameters: Optional[Dict[str, MapValue]] = None,
        usage_details: Optional[Dict[str, int]] = None,
        cost_details: Optional[Dict[str, float]] = None,
        prompt: Optional[PromptClient] = None,
    ) -> Any:
        """Context manager: start a new observation under an explicit parent span.

        Either ``parent`` (an in-process span) or ``remote_parent_span`` (a
        non-recording span built from a remote trace context) must be provided;
        ``parent`` wins when both are set. The new observation is created while
        the chosen parent is the active OTel span, so it becomes its child.

        Yields:
            The created Aeri span wrapper (type depends on ``as_type``).
        """
        # Prefer the local parent; fall back to the remote one. The cast only
        # silences the type checker — callers guarantee one of the two is set.
        parent_span = parent or cast(otel_trace_api.Span, remote_parent_span)

        with otel_trace_api.use_span(parent_span):
            with self._start_as_current_otel_span_with_processed_media(
                name=name,
                as_type=as_type,
                end_on_exit=end_on_exit,
                input=input,
                output=output,
                metadata=metadata,
                version=version,
                level=level,
                status_message=status_message,
                completion_start_time=completion_start_time,
                model=model,
                model_parameters=model_parameters,
                usage_details=usage_details,
                cost_details=cost_details,
                prompt=prompt,
            ) as aeri_span:
                if remote_parent_span is not None:
                    # Spans parented to a remote context are marked as local
                    # roots so the backend renders them as trace entry points.
                    aeri_span._otel_span.set_attribute(
                        AeriOtelSpanAttributes.AS_ROOT, True
                    )

                yield aeri_span
|
|
1172
|
+
|
|
1173
|
+
    @_agnosticcontextmanager
    def _start_as_current_otel_span_with_processed_media(
        self,
        *,
        name: str,
        as_type: Optional[ObservationTypeLiteralNoEvent] = None,
        end_on_exit: Optional[bool] = None,
        input: Optional[Any] = None,
        output: Optional[Any] = None,
        metadata: Optional[Any] = None,
        version: Optional[str] = None,
        level: Optional[SpanLevel] = None,
        status_message: Optional[str] = None,
        completion_start_time: Optional[datetime] = None,
        model: Optional[str] = None,
        model_parameters: Optional[Dict[str, MapValue]] = None,
        usage_details: Optional[Dict[str, int]] = None,
        cost_details: Optional[Dict[str, float]] = None,
        prompt: Optional[PromptClient] = None,
    ) -> Any:
        """Context manager: start an OTel span and wrap it in the Aeri class for ``as_type``.

        The span is made current for the duration of the ``with`` block and,
        unless ``end_on_exit`` is explicitly False, is ended on exit.
        Generation-specific arguments (model, usage, cost, prompt, ...) are
        only forwarded when the resolved class is generation-like.

        Yields:
            The Aeri span wrapper around the newly started OTel span.
        """
        with self._otel_tracer.start_as_current_span(
            name=name,
            # OTel's default is to end on exit; preserve that unless the
            # caller explicitly opted out.
            end_on_exit=end_on_exit if end_on_exit is not None else True,
        ) as otel_span:
            span_class = self._get_span_class(
                as_type or "generation"
            )  # default was "generation"
            common_args = {
                "otel_span": otel_span,
                "aeri_client": self,
                "environment": self._environment,
                "release": self._release,
                "input": input,
                "output": output,
                "metadata": metadata,
                "version": version,
                "level": level,
                "status_message": status_message,
            }

            # Only generation-like classes accept model/usage/cost arguments.
            if span_class in [
                AeriGeneration,
                AeriEmbedding,
            ]:
                common_args.update(
                    {
                        "completion_start_time": completion_start_time,
                        "model": model,
                        "model_parameters": model_parameters,
                        "usage_details": usage_details,
                        "cost_details": cost_details,
                        "prompt": prompt,
                    }
                )
            # For span-like types (span, agent, tool, chain, retriever, evaluator, guardrail), no generation properties needed

            yield span_class(**common_args)  # type: ignore[arg-type]
|
|
1230
|
+
|
|
1231
|
+
def _get_current_otel_span(self) -> Optional[otel_trace_api.Span]:
|
|
1232
|
+
current_span = otel_trace_api.get_current_span()
|
|
1233
|
+
|
|
1234
|
+
if current_span is otel_trace_api.INVALID_SPAN:
|
|
1235
|
+
aeri_logger.warning(
|
|
1236
|
+
"Context error: No active span in current context. Operations that depend on an active span will be skipped. "
|
|
1237
|
+
"Ensure spans are created with start_as_current_observation() or that you're operating within an active span context."
|
|
1238
|
+
)
|
|
1239
|
+
return None
|
|
1240
|
+
|
|
1241
|
+
return current_span
|
|
1242
|
+
|
|
1243
|
+
    def update_current_generation(
        self,
        *,
        name: Optional[str] = None,
        input: Optional[Any] = None,
        output: Optional[Any] = None,
        metadata: Optional[Any] = None,
        version: Optional[str] = None,
        level: Optional[SpanLevel] = None,
        status_message: Optional[str] = None,
        completion_start_time: Optional[datetime] = None,
        model: Optional[str] = None,
        model_parameters: Optional[Dict[str, MapValue]] = None,
        usage_details: Optional[Dict[str, int]] = None,
        cost_details: Optional[Dict[str, float]] = None,
        prompt: Optional[PromptClient] = None,
    ) -> None:
        """Update the current active generation span with new information.

        This method updates the current generation span in the active context with
        additional information. It's useful for adding output, usage stats, or other
        details that become available during or after model generation.

        Args:
            name: The generation name
            input: Updated input data for the model
            output: Output from the model (e.g., completions)
            metadata: Additional metadata to associate with the generation
            version: Version identifier for the model or component
            level: Importance level of the generation (info, warning, error)
            status_message: Optional status message for the generation
            completion_start_time: When the model started generating the response
            model: Name/identifier of the AI model used (e.g., "gpt-4")
            model_parameters: Parameters used for the model (e.g., temperature, max_tokens)
            usage_details: Token usage information (e.g., prompt_tokens, completion_tokens)
            cost_details: Cost information for the model call
            prompt: Associated prompt template from Aeri prompt management

        Example:
            ```python
            with aeri.start_as_current_generation(name="answer-query") as generation:
                # Initial setup and API call
                response = llm.generate(...)

                # Update with results that weren't available at creation time
                aeri.update_current_generation(
                    output=response.text,
                    usage_details={
                        "prompt_tokens": response.usage.prompt_tokens,
                        "completion_tokens": response.usage.completion_tokens
                    }
                )
            ```
        """
        if not self._tracing_enabled:
            aeri_logger.debug(
                "Operation skipped: update_current_generation - Tracing is disabled or client is in no-op mode."
            )
            return

        # Silently no-ops (after a warning) when there is no active span.
        current_otel_span = self._get_current_otel_span()

        if current_otel_span is not None:
            # Wrap the raw OTel span so the update goes through Aeri's
            # generation attribute mapping.
            # NOTE(review): unlike update_current_span, this constructor omits
            # environment/release — confirm AeriGeneration does not need them
            # for updates.
            generation = AeriGeneration(
                otel_span=current_otel_span, aeri_client=self
            )

            # Renames must go through the OTel span API directly.
            if name:
                current_otel_span.update_name(name)

            generation.update(
                input=input,
                output=output,
                metadata=metadata,
                version=version,
                level=level,
                status_message=status_message,
                completion_start_time=completion_start_time,
                model=model,
                model_parameters=model_parameters,
                usage_details=usage_details,
                cost_details=cost_details,
                prompt=prompt,
            )
|
|
1327
|
+
|
|
1328
|
+
    def update_current_span(
        self,
        *,
        name: Optional[str] = None,
        input: Optional[Any] = None,
        output: Optional[Any] = None,
        metadata: Optional[Any] = None,
        version: Optional[str] = None,
        level: Optional[SpanLevel] = None,
        status_message: Optional[str] = None,
    ) -> None:
        """Update the current active span with new information.

        This method updates the current span in the active context with
        additional information. It's useful for adding outputs or metadata
        that become available during execution.

        Args:
            name: The span name
            input: Updated input data for the operation
            output: Output data from the operation
            metadata: Additional metadata to associate with the span
            version: Version identifier for the code or component
            level: Importance level of the span (info, warning, error)
            status_message: Optional status message for the span

        Example:
            ```python
            with aeri.start_as_current_observation(name="process-data") as span:
                # Initial processing
                result = process_first_part()

                # Update with intermediate results
                aeri.update_current_span(metadata={"intermediate_result": result})

                # Continue processing
                final_result = process_second_part(result)

                # Final update
                aeri.update_current_span(output=final_result)
            ```
        """
        if not self._tracing_enabled:
            aeri_logger.debug(
                "Operation skipped: update_current_span - Tracing is disabled or client is in no-op mode."
            )
            return

        # Silently no-ops (after a warning) when there is no active span.
        current_otel_span = self._get_current_otel_span()

        if current_otel_span is not None:
            # Wrap the raw OTel span so the update goes through Aeri's
            # span attribute mapping.
            span = AeriSpan(
                otel_span=current_otel_span,
                aeri_client=self,
                environment=self._environment,
                release=self._release,
            )

            # Renames must go through the OTel span API directly.
            if name:
                current_otel_span.update_name(name)

            span.update(
                input=input,
                output=output,
                metadata=metadata,
                version=version,
                level=level,
                status_message=status_message,
            )
|
|
1397
|
+
|
|
1398
|
+
    @deprecated(
        "Trace-level input/output is deprecated. "
        "For trace attributes (user_id, session_id, tags, etc.), use propagate_attributes() instead. "
        "This method will be removed in a future major version."
    )
    def set_current_trace_io(
        self,
        *,
        input: Optional[Any] = None,
        output: Optional[Any] = None,
    ) -> None:
        """Set trace-level input and output for the current span's trace.

        .. deprecated::
            This is a legacy method for backward compatibility with Aeri platform
            features that still rely on trace-level input/output (e.g., legacy LLM-as-a-judge
            evaluators). It will be removed in a future major version.

        For setting other trace attributes (user_id, session_id, metadata, tags, version),
        use :meth:`propagate_attributes` instead.

        Args:
            input: Input data to associate with the trace.
            output: Output data to associate with the trace.
        """
        if not self._tracing_enabled:
            aeri_logger.debug(
                "Operation skipped: set_current_trace_io - Tracing is disabled or client is in no-op mode."
            )
            return

        current_otel_span = self._get_current_otel_span()

        # Only recording spans can carry new attributes.
        if current_otel_span is not None and current_otel_span.is_recording():
            # Re-wrap the active span in the class matching its recorded
            # observation type (defaulting to "span" when unset).
            existing_observation_type = current_otel_span.attributes.get(  # type: ignore[attr-defined]
                AeriOtelSpanAttributes.OBSERVATION_TYPE, "span"
            )
            # We need to preserve the class to keep the correct observation type
            span_class = self._get_span_class(existing_observation_type)
            span = span_class(
                otel_span=current_otel_span,
                aeri_client=self,
                environment=self._environment,
                release=self._release,
            )

            span.set_trace_io(
                input=input,
                output=output,
            )
|
|
1448
|
+
|
|
1449
|
+
def set_current_trace_as_public(self) -> None:
|
|
1450
|
+
"""Make the current trace publicly accessible via its URL.
|
|
1451
|
+
|
|
1452
|
+
When a trace is published, anyone with the trace link can view the full trace
|
|
1453
|
+
without needing to be logged in to Aeri. This action cannot be undone
|
|
1454
|
+
programmatically - once published, the entire trace becomes public.
|
|
1455
|
+
|
|
1456
|
+
This is a convenience method that publishes the trace from the currently
|
|
1457
|
+
active span context. Use this when you want to make a trace public from
|
|
1458
|
+
within a traced function without needing direct access to the span object.
|
|
1459
|
+
"""
|
|
1460
|
+
if not self._tracing_enabled:
|
|
1461
|
+
aeri_logger.debug(
|
|
1462
|
+
"Operation skipped: set_current_trace_as_public - Tracing is disabled or client is in no-op mode."
|
|
1463
|
+
)
|
|
1464
|
+
return
|
|
1465
|
+
|
|
1466
|
+
current_otel_span = self._get_current_otel_span()
|
|
1467
|
+
|
|
1468
|
+
if current_otel_span is not None and current_otel_span.is_recording():
|
|
1469
|
+
existing_observation_type = current_otel_span.attributes.get( # type: ignore[attr-defined]
|
|
1470
|
+
AeriOtelSpanAttributes.OBSERVATION_TYPE, "span"
|
|
1471
|
+
)
|
|
1472
|
+
# We need to preserve the class to keep the correct observation type
|
|
1473
|
+
span_class = self._get_span_class(existing_observation_type)
|
|
1474
|
+
span = span_class(
|
|
1475
|
+
otel_span=current_otel_span,
|
|
1476
|
+
aeri_client=self,
|
|
1477
|
+
environment=self._environment,
|
|
1478
|
+
)
|
|
1479
|
+
|
|
1480
|
+
span.set_trace_as_public()
|
|
1481
|
+
|
|
1482
|
+
def create_event(
|
|
1483
|
+
self,
|
|
1484
|
+
*,
|
|
1485
|
+
trace_context: Optional[TraceContext] = None,
|
|
1486
|
+
name: str,
|
|
1487
|
+
input: Optional[Any] = None,
|
|
1488
|
+
output: Optional[Any] = None,
|
|
1489
|
+
metadata: Optional[Any] = None,
|
|
1490
|
+
version: Optional[str] = None,
|
|
1491
|
+
level: Optional[SpanLevel] = None,
|
|
1492
|
+
status_message: Optional[str] = None,
|
|
1493
|
+
) -> AeriEvent:
|
|
1494
|
+
"""Create a new Aeri observation of type 'EVENT'.
|
|
1495
|
+
|
|
1496
|
+
The created Aeri Event observation will be the child of the current span in the context.
|
|
1497
|
+
|
|
1498
|
+
Args:
|
|
1499
|
+
trace_context: Optional context for connecting to an existing trace
|
|
1500
|
+
name: Name of the span (e.g., function or operation name)
|
|
1501
|
+
input: Input data for the operation (can be any JSON-serializable object)
|
|
1502
|
+
output: Output data from the operation (can be any JSON-serializable object)
|
|
1503
|
+
metadata: Additional metadata to associate with the span
|
|
1504
|
+
version: Version identifier for the code or component
|
|
1505
|
+
level: Importance level of the span (info, warning, error)
|
|
1506
|
+
status_message: Optional status message for the span
|
|
1507
|
+
|
|
1508
|
+
Returns:
|
|
1509
|
+
The Aeri Event object
|
|
1510
|
+
|
|
1511
|
+
Example:
|
|
1512
|
+
```python
|
|
1513
|
+
event = aeri.create_event(name="process-event")
|
|
1514
|
+
```
|
|
1515
|
+
"""
|
|
1516
|
+
timestamp = time_ns()
|
|
1517
|
+
|
|
1518
|
+
if trace_context:
|
|
1519
|
+
trace_id = trace_context.get("trace_id", None)
|
|
1520
|
+
parent_span_id = trace_context.get("parent_span_id", None)
|
|
1521
|
+
|
|
1522
|
+
if trace_id:
|
|
1523
|
+
remote_parent_span = self._create_remote_parent_span(
|
|
1524
|
+
trace_id=trace_id, parent_span_id=parent_span_id
|
|
1525
|
+
)
|
|
1526
|
+
|
|
1527
|
+
with otel_trace_api.use_span(
|
|
1528
|
+
cast(otel_trace_api.Span, remote_parent_span)
|
|
1529
|
+
):
|
|
1530
|
+
otel_span = self._otel_tracer.start_span(
|
|
1531
|
+
name=name, start_time=timestamp
|
|
1532
|
+
)
|
|
1533
|
+
otel_span.set_attribute(AeriOtelSpanAttributes.AS_ROOT, True)
|
|
1534
|
+
|
|
1535
|
+
return cast(
|
|
1536
|
+
AeriEvent,
|
|
1537
|
+
AeriEvent(
|
|
1538
|
+
otel_span=otel_span,
|
|
1539
|
+
aeri_client=self,
|
|
1540
|
+
environment=self._environment,
|
|
1541
|
+
release=self._release,
|
|
1542
|
+
input=input,
|
|
1543
|
+
output=output,
|
|
1544
|
+
metadata=metadata,
|
|
1545
|
+
version=version,
|
|
1546
|
+
level=level,
|
|
1547
|
+
status_message=status_message,
|
|
1548
|
+
).end(end_time=timestamp),
|
|
1549
|
+
)
|
|
1550
|
+
|
|
1551
|
+
otel_span = self._otel_tracer.start_span(name=name, start_time=timestamp)
|
|
1552
|
+
|
|
1553
|
+
return cast(
|
|
1554
|
+
AeriEvent,
|
|
1555
|
+
AeriEvent(
|
|
1556
|
+
otel_span=otel_span,
|
|
1557
|
+
aeri_client=self,
|
|
1558
|
+
environment=self._environment,
|
|
1559
|
+
release=self._release,
|
|
1560
|
+
input=input,
|
|
1561
|
+
output=output,
|
|
1562
|
+
metadata=metadata,
|
|
1563
|
+
version=version,
|
|
1564
|
+
level=level,
|
|
1565
|
+
status_message=status_message,
|
|
1566
|
+
).end(end_time=timestamp),
|
|
1567
|
+
)
|
|
1568
|
+
|
|
1569
|
+
    def _create_remote_parent_span(
        self, *, trace_id: str, parent_span_id: Optional[str]
    ) -> Any:
        """Build a non-recording span that carries an externally supplied trace context.

        Used to parent new observations under a caller-provided trace ID (and,
        optionally, parent span ID). The returned span records nothing itself;
        it only propagates the context.
        """
        if not self._is_valid_trace_id(trace_id):
            # NOTE(review): the message says "Ignoring trace ID", but the value
            # is still used below; int(trace_id, 16) will raise ValueError for
            # non-hex input — confirm intended behavior.
            aeri_logger.warning(
                f"Passed trace ID '{trace_id}' is not a valid 32 lowercase hex char Aeri trace id. Ignoring trace ID."
            )

        if parent_span_id and not self._is_valid_span_id(parent_span_id):
            # NOTE(review): same caveat as above — the invalid span ID is still
            # parsed below.
            aeri_logger.warning(
                f"Passed span ID '{parent_span_id}' is not a valid 16 lowercase hex char Aeri span id. Ignoring parent span ID."
            )

        int_trace_id = int(trace_id, 16)
        # Without an explicit parent span ID, synthesize a random one so the
        # context is still well-formed.
        int_parent_span_id = (
            int(parent_span_id, 16)
            if parent_span_id
            else RandomIdGenerator().generate_span_id()
        )

        span_context = otel_trace_api.SpanContext(
            trace_id=int_trace_id,
            span_id=int_parent_span_id,
            trace_flags=otel_trace_api.TraceFlags(0x01),  # mark span as sampled
            # NOTE(review): is_remote=False even though this represents a remote
            # context — confirm this is deliberate.
            is_remote=False,
        )

        return otel_trace_api.NonRecordingSpan(span_context)
|
|
1597
|
+
|
|
1598
|
+
def _is_valid_trace_id(self, trace_id: str) -> bool:
|
|
1599
|
+
pattern = r"^[0-9a-f]{32}$"
|
|
1600
|
+
|
|
1601
|
+
return bool(re.match(pattern, trace_id))
|
|
1602
|
+
|
|
1603
|
+
def _is_valid_span_id(self, span_id: str) -> bool:
|
|
1604
|
+
pattern = r"^[0-9a-f]{16}$"
|
|
1605
|
+
|
|
1606
|
+
return bool(re.match(pattern, span_id))
|
|
1607
|
+
|
|
1608
|
+
def _create_observation_id(self, *, seed: Optional[str] = None) -> str:
|
|
1609
|
+
"""Create a unique observation ID for use with Aeri.
|
|
1610
|
+
|
|
1611
|
+
This method generates a unique observation ID (span ID in OpenTelemetry terms)
|
|
1612
|
+
for use with various Aeri APIs. It can either generate a random ID or
|
|
1613
|
+
create a deterministic ID based on a seed string.
|
|
1614
|
+
|
|
1615
|
+
Observation IDs must be 16 lowercase hexadecimal characters, representing 8 bytes.
|
|
1616
|
+
This method ensures the generated ID meets this requirement. If you need to
|
|
1617
|
+
correlate an external ID with a Aeri observation ID, use the external ID as
|
|
1618
|
+
the seed to get a valid, deterministic observation ID.
|
|
1619
|
+
|
|
1620
|
+
Args:
|
|
1621
|
+
seed: Optional string to use as a seed for deterministic ID generation.
|
|
1622
|
+
If provided, the same seed will always produce the same ID.
|
|
1623
|
+
If not provided, a random ID will be generated.
|
|
1624
|
+
|
|
1625
|
+
Returns:
|
|
1626
|
+
A 16-character lowercase hexadecimal string representing the observation ID.
|
|
1627
|
+
|
|
1628
|
+
Example:
|
|
1629
|
+
```python
|
|
1630
|
+
# Generate a random observation ID
|
|
1631
|
+
obs_id = aeri.create_observation_id()
|
|
1632
|
+
|
|
1633
|
+
# Generate a deterministic ID based on a seed
|
|
1634
|
+
user_obs_id = aeri.create_observation_id(seed="user-123-feedback")
|
|
1635
|
+
|
|
1636
|
+
# Correlate an external item ID with a Aeri observation ID
|
|
1637
|
+
item_id = "item-789012"
|
|
1638
|
+
correlated_obs_id = aeri.create_observation_id(seed=item_id)
|
|
1639
|
+
|
|
1640
|
+
# Use the ID with Aeri APIs
|
|
1641
|
+
aeri.create_score(
|
|
1642
|
+
name="relevance",
|
|
1643
|
+
value=0.95,
|
|
1644
|
+
trace_id=trace_id,
|
|
1645
|
+
observation_id=obs_id
|
|
1646
|
+
)
|
|
1647
|
+
```
|
|
1648
|
+
"""
|
|
1649
|
+
if not seed:
|
|
1650
|
+
span_id_int = RandomIdGenerator().generate_span_id()
|
|
1651
|
+
|
|
1652
|
+
return self._format_otel_span_id(span_id_int)
|
|
1653
|
+
|
|
1654
|
+
return sha256(seed.encode("utf-8")).digest()[:8].hex()
|
|
1655
|
+
|
|
1656
|
+
@staticmethod
|
|
1657
|
+
def create_trace_id(*, seed: Optional[str] = None) -> str:
|
|
1658
|
+
"""Create a unique trace ID for use with Aeri.
|
|
1659
|
+
|
|
1660
|
+
This method generates a unique trace ID for use with various Aeri APIs.
|
|
1661
|
+
It can either generate a random ID or create a deterministic ID based on
|
|
1662
|
+
a seed string.
|
|
1663
|
+
|
|
1664
|
+
Trace IDs must be 32 lowercase hexadecimal characters, representing 16 bytes.
|
|
1665
|
+
This method ensures the generated ID meets this requirement. If you need to
|
|
1666
|
+
correlate an external ID with a Aeri trace ID, use the external ID as the
|
|
1667
|
+
seed to get a valid, deterministic Aeri trace ID.
|
|
1668
|
+
|
|
1669
|
+
Args:
|
|
1670
|
+
seed: Optional string to use as a seed for deterministic ID generation.
|
|
1671
|
+
If provided, the same seed will always produce the same ID.
|
|
1672
|
+
If not provided, a random ID will be generated.
|
|
1673
|
+
|
|
1674
|
+
Returns:
|
|
1675
|
+
A 32-character lowercase hexadecimal string representing the Aeri trace ID.
|
|
1676
|
+
|
|
1677
|
+
Example:
|
|
1678
|
+
```python
|
|
1679
|
+
# Generate a random trace ID
|
|
1680
|
+
trace_id = aeri.create_trace_id()
|
|
1681
|
+
|
|
1682
|
+
# Generate a deterministic ID based on a seed
|
|
1683
|
+
session_trace_id = aeri.create_trace_id(seed="session-456")
|
|
1684
|
+
|
|
1685
|
+
# Correlate an external ID with a Aeri trace ID
|
|
1686
|
+
external_id = "external-system-123456"
|
|
1687
|
+
correlated_trace_id = aeri.create_trace_id(seed=external_id)
|
|
1688
|
+
|
|
1689
|
+
# Use the ID with trace context
|
|
1690
|
+
with aeri.start_as_current_observation(
|
|
1691
|
+
name="process-request",
|
|
1692
|
+
trace_context={"trace_id": trace_id}
|
|
1693
|
+
) as span:
|
|
1694
|
+
# Operation will be part of the specific trace
|
|
1695
|
+
pass
|
|
1696
|
+
```
|
|
1697
|
+
"""
|
|
1698
|
+
if not seed:
|
|
1699
|
+
trace_id_int = RandomIdGenerator().generate_trace_id()
|
|
1700
|
+
|
|
1701
|
+
return Aeri._format_otel_trace_id(trace_id_int)
|
|
1702
|
+
|
|
1703
|
+
return sha256(seed.encode("utf-8")).digest()[:16].hex()
|
|
1704
|
+
|
|
1705
|
+
def _get_otel_trace_id(self, otel_span: otel_trace_api.Span) -> str:
|
|
1706
|
+
span_context = otel_span.get_span_context()
|
|
1707
|
+
|
|
1708
|
+
return self._format_otel_trace_id(span_context.trace_id)
|
|
1709
|
+
|
|
1710
|
+
def _get_otel_span_id(self, otel_span: otel_trace_api.Span) -> str:
|
|
1711
|
+
span_context = otel_span.get_span_context()
|
|
1712
|
+
|
|
1713
|
+
return self._format_otel_span_id(span_context.span_id)
|
|
1714
|
+
|
|
1715
|
+
@staticmethod
|
|
1716
|
+
def _format_otel_span_id(span_id_int: int) -> str:
|
|
1717
|
+
"""Format an integer span ID to a 16-character lowercase hex string.
|
|
1718
|
+
|
|
1719
|
+
Internal method to convert an OpenTelemetry integer span ID to the standard
|
|
1720
|
+
W3C Trace Context format (16-character lowercase hex string).
|
|
1721
|
+
|
|
1722
|
+
Args:
|
|
1723
|
+
span_id_int: 64-bit integer representing a span ID
|
|
1724
|
+
|
|
1725
|
+
Returns:
|
|
1726
|
+
A 16-character lowercase hexadecimal string
|
|
1727
|
+
"""
|
|
1728
|
+
return format(span_id_int, "016x")
|
|
1729
|
+
|
|
1730
|
+
@staticmethod
|
|
1731
|
+
def _format_otel_trace_id(trace_id_int: int) -> str:
|
|
1732
|
+
"""Format an integer trace ID to a 32-character lowercase hex string.
|
|
1733
|
+
|
|
1734
|
+
Internal method to convert an OpenTelemetry integer trace ID to the standard
|
|
1735
|
+
W3C Trace Context format (32-character lowercase hex string).
|
|
1736
|
+
|
|
1737
|
+
Args:
|
|
1738
|
+
trace_id_int: 128-bit integer representing a trace ID
|
|
1739
|
+
|
|
1740
|
+
Returns:
|
|
1741
|
+
A 32-character lowercase hexadecimal string
|
|
1742
|
+
"""
|
|
1743
|
+
return format(trace_id_int, "032x")
|
|
1744
|
+
|
|
1745
|
+
    # Overload: NUMERIC/BOOLEAN scores take a float value.
    @overload
    def create_score(
        self,
        *,
        name: str,
        value: float,
        session_id: Optional[str] = None,
        dataset_run_id: Optional[str] = None,
        trace_id: Optional[str] = None,
        observation_id: Optional[str] = None,
        score_id: Optional[str] = None,
        data_type: Optional[Literal["NUMERIC", "BOOLEAN"]] = None,
        comment: Optional[str] = None,
        config_id: Optional[str] = None,
        metadata: Optional[Any] = None,
        timestamp: Optional[datetime] = None,
    ) -> None: ...
|
|
1762
|
+
|
|
1763
|
+
@overload
|
|
1764
|
+
def create_score(
|
|
1765
|
+
self,
|
|
1766
|
+
*,
|
|
1767
|
+
name: str,
|
|
1768
|
+
value: str,
|
|
1769
|
+
session_id: Optional[str] = None,
|
|
1770
|
+
dataset_run_id: Optional[str] = None,
|
|
1771
|
+
trace_id: Optional[str] = None,
|
|
1772
|
+
score_id: Optional[str] = None,
|
|
1773
|
+
observation_id: Optional[str] = None,
|
|
1774
|
+
data_type: Optional[Literal["CATEGORICAL"]] = "CATEGORICAL",
|
|
1775
|
+
comment: Optional[str] = None,
|
|
1776
|
+
config_id: Optional[str] = None,
|
|
1777
|
+
metadata: Optional[Any] = None,
|
|
1778
|
+
timestamp: Optional[datetime] = None,
|
|
1779
|
+
) -> None: ...
|
|
1780
|
+
|
|
1781
|
+
    def create_score(
        self,
        *,
        name: str,
        value: Union[float, str],
        session_id: Optional[str] = None,
        dataset_run_id: Optional[str] = None,
        trace_id: Optional[str] = None,
        observation_id: Optional[str] = None,
        score_id: Optional[str] = None,
        data_type: Optional[ScoreDataType] = None,
        comment: Optional[str] = None,
        config_id: Optional[str] = None,
        metadata: Optional[Any] = None,
        timestamp: Optional[datetime] = None,
    ) -> None:
        """Create a score for a specific trace or observation.

        This method creates a score for evaluating a Aeri trace or observation. Scores can be
        used to track quality metrics, user feedback, or automated evaluations.

        Args:
            name: Name of the score (e.g., "relevance", "accuracy")
            value: Score value (can be numeric for NUMERIC/BOOLEAN types or string for CATEGORICAL)
            session_id: ID of the Aeri session to associate the score with
            dataset_run_id: ID of the Aeri dataset run to associate the score with
            trace_id: ID of the Aeri trace to associate the score with
            observation_id: Optional ID of the specific observation to score. Trace ID must be provided too.
            score_id: Optional custom ID for the score (auto-generated if not provided)
            data_type: Type of score (NUMERIC, BOOLEAN, or CATEGORICAL)
            comment: Optional comment or explanation for the score
            config_id: Optional ID of a score config defined in Aeri
            metadata: Optional metadata to be attached to the score
            timestamp: Optional timestamp for the score (defaults to current UTC time)

        Example:
            ```python
            # Create a numeric score for accuracy
            aeri.create_score(
                name="accuracy",
                value=0.92,
                trace_id="abcdef1234567890abcdef1234567890",
                data_type="NUMERIC",
                comment="High accuracy with minor irrelevant details"
            )

            # Create a categorical score for sentiment
            aeri.create_score(
                name="sentiment",
                value="positive",
                trace_id="abcdef1234567890abcdef1234567890",
                observation_id="abcdef1234567890",
                data_type="CATEGORICAL"
            )
            ```
        """
        if not self._tracing_enabled:
            return

        # ── Pydantic V2 strict validation before any further processing ──────
        try:
            _validated_score = ScoreInput.model_validate(
                {
                    "name": name,
                    "value": value,
                    "data_type": data_type,
                    "comment": comment,
                    "config_id": config_id,
                    "trace_id": trace_id,
                    "observation_id": observation_id,
                }
            )
        except Exception as validation_exc:
            # Invalid scores are dropped (logged), never raised to the caller.
            err_msg = f"Score validation failed for name={name!r}: {validation_exc}"
            aeri_logger.error(err_msg)
            return

        score_id = score_id or self._create_observation_id()

        try:
            new_body = ScoreBody(
                id=score_id,
                session_id=session_id,
                datasetRunId=dataset_run_id,
                traceId=_validated_score.trace_id,
                observationId=_validated_score.observation_id,
                name=_validated_score.name,
                value=_validated_score.value,
                dataType=_validated_score.data_type,  # type: ignore
                comment=_validated_score.comment,
                configId=_validated_score.config_id,
                environment=self._environment,
                metadata=metadata,
            )

            # Ingestion envelope; the event "id" is a fresh random identifier,
            # unrelated to the score's own trace/observation IDs.
            event = {
                "id": self.create_trace_id(),
                "type": "score-create",
                "timestamp": timestamp or _get_timestamp(),
                "body": new_body,
            }

            if self._resources is not None:
                # Force the score to be in sample if it was for a legacy trace ID, i.e. non-32 hexchar
                # NOTE(review): with no trace_id at all (session/dataset-run scores),
                # force_sample is True — confirm this is intended.
                force_sample = (
                    not self._is_valid_trace_id(trace_id) if trace_id else True
                )

                self._resources.add_score_task(
                    event,
                    force_sample=force_sample,
                )

        except Exception as e:
            aeri_logger.exception(
                f"Error creating score: Failed to process score event for trace_id={trace_id}, name={name}. Error: {e}"
            )
|
|
1898
|
+
|
|
1899
|
+
def _create_trace_tags_via_ingestion(
    self,
    *,
    trace_id: str,
    tags: List[str],
) -> None:
    """Enqueue a trace-create ingestion event that attaches *tags* to a trace.

    Silently returns when tracing is disabled or when there are no tags
    to send; any failure while building or enqueuing the event is logged
    rather than raised.
    """
    if not self._tracing_enabled or not tags:
        return

    try:
        body = TraceBody(id=trace_id, tags=tags)
        ingestion_event = {
            "id": self.create_trace_id(),
            "type": "trace-create",
            "timestamp": _get_timestamp(),
            "body": body,
        }

        resources = self._resources
        if resources is not None:
            resources.add_trace_task(ingestion_event)
    except Exception as e:
        aeri_logger.exception(
            f"Error updating trace tags: Failed to process trace update event for trace_id={trace_id}. Error: {e}"
        )
|
|
1931
|
+
|
|
1932
|
+
@overload
def score_current_span(
    self,
    *,
    name: str,
    value: float,
    score_id: Optional[str] = None,
    data_type: Optional[Literal["NUMERIC", "BOOLEAN"]] = None,
    comment: Optional[str] = None,
    config_id: Optional[str] = None,
    metadata: Optional[Any] = None,
) -> None: ...

@overload
def score_current_span(
    self,
    *,
    name: str,
    value: str,
    score_id: Optional[str] = None,
    data_type: Optional[Literal["CATEGORICAL"]] = "CATEGORICAL",
    comment: Optional[str] = None,
    config_id: Optional[str] = None,
    metadata: Optional[Any] = None,
) -> None: ...

def score_current_span(
    self,
    *,
    name: str,
    value: Union[float, str],
    score_id: Optional[str] = None,
    data_type: Optional[ScoreDataType] = None,
    comment: Optional[str] = None,
    config_id: Optional[str] = None,
    metadata: Optional[Any] = None,
) -> None:
    """Create a score for the current active span.

    This method scores the currently active span in the context. It's a convenient
    way to score the current operation without needing to know its trace and span IDs.

    Args:
        name: Name of the score (e.g., "relevance", "accuracy")
        value: Score value (can be numeric for NUMERIC/BOOLEAN types or string for CATEGORICAL)
        score_id: Optional custom ID for the score (auto-generated if not provided)
        data_type: Type of score (NUMERIC, BOOLEAN, or CATEGORICAL)
        comment: Optional comment or explanation for the score
        config_id: Optional ID of a score config defined in Aeri
        metadata: Optional metadata to be attached to the score

    Example:
        ```python
        with aeri.start_as_current_generation(name="answer-query") as generation:
            # Generate answer
            response = generate_answer(...)
            generation.update(output=response)

            # Score the generation
            aeri.score_current_span(
                name="relevance",
                value=0.85,
                data_type="NUMERIC",
                comment="Mostly relevant but contains some tangential information",
                metadata={"model": "gpt-4", "prompt_version": "v2"}
            )
        ```
    """
    current_span = self._get_current_otel_span()

    if current_span is None:
        # Previously a missing active span silently dropped the score;
        # log it so missing scores can be diagnosed.
        aeri_logger.warning(
            f"Score: No active span found in context; score name='{name}' was not recorded."
        )
        return

    trace_id = self._get_otel_trace_id(current_span)
    observation_id = self._get_otel_span_id(current_span)

    aeri_logger.info(
        f"Score: Creating score name='{name}' value={value} for current span ({observation_id}) in trace {trace_id}"
    )

    # The casts are type-checker-only: create_score's overloads accept
    # either numeric or categorical shapes at runtime.
    self.create_score(
        trace_id=trace_id,
        observation_id=observation_id,
        name=name,
        value=cast(str, value),
        score_id=score_id,
        data_type=cast(Literal["CATEGORICAL"], data_type),
        comment=comment,
        config_id=config_id,
        metadata=metadata,
    )
|
|
2021
|
+
|
|
2022
|
+
@overload
def score_current_trace(
    self,
    *,
    name: str,
    value: float,
    score_id: Optional[str] = None,
    data_type: Optional[Literal["NUMERIC", "BOOLEAN"]] = None,
    comment: Optional[str] = None,
    config_id: Optional[str] = None,
    metadata: Optional[Any] = None,
) -> None: ...

@overload
def score_current_trace(
    self,
    *,
    name: str,
    value: str,
    score_id: Optional[str] = None,
    data_type: Optional[Literal["CATEGORICAL"]] = "CATEGORICAL",
    comment: Optional[str] = None,
    config_id: Optional[str] = None,
    metadata: Optional[Any] = None,
) -> None: ...

def score_current_trace(
    self,
    *,
    name: str,
    value: Union[float, str],
    score_id: Optional[str] = None,
    data_type: Optional[ScoreDataType] = None,
    comment: Optional[str] = None,
    config_id: Optional[str] = None,
    metadata: Optional[Any] = None,
) -> None:
    """Create a score for the current trace.

    This method scores the trace of the currently active span. Unlike score_current_span,
    this method associates the score with the entire trace rather than a specific span.
    It's useful for scoring overall performance or quality of the entire operation.

    Args:
        name: Name of the score (e.g., "user_satisfaction", "overall_quality")
        value: Score value (can be numeric for NUMERIC/BOOLEAN types or string for CATEGORICAL)
        score_id: Optional custom ID for the score (auto-generated if not provided)
        data_type: Type of score (NUMERIC, BOOLEAN, or CATEGORICAL)
        comment: Optional comment or explanation for the score
        config_id: Optional ID of a score config defined in Aeri
        metadata: Optional metadata to be attached to the score

    Example:
        ```python
        with aeri.start_as_current_observation(name="process-user-request") as span:
            # Process request
            result = process_complete_request()
            span.update(output=result)

            # Score the overall trace
            aeri.score_current_trace(
                name="overall_quality",
                value=0.95,
                data_type="NUMERIC",
                comment="High quality end-to-end response",
                metadata={"evaluator": "gpt-4", "criteria": "comprehensive"}
            )
        ```
    """
    current_span = self._get_current_otel_span()

    if current_span is None:
        # Previously a missing active span silently dropped the score;
        # log it so missing scores can be diagnosed.
        aeri_logger.warning(
            f"Score: No active span found in context; trace score name='{name}' was not recorded."
        )
        return

    trace_id = self._get_otel_trace_id(current_span)

    aeri_logger.info(
        f"Score: Creating score name='{name}' value={value} for entire trace {trace_id}"
    )

    # The casts are type-checker-only: create_score's overloads accept
    # either numeric or categorical shapes at runtime.
    self.create_score(
        trace_id=trace_id,
        name=name,
        value=cast(str, value),
        score_id=score_id,
        data_type=cast(Literal["CATEGORICAL"], data_type),
        comment=comment,
        config_id=config_id,
        metadata=metadata,
    )
|
|
2110
|
+
|
|
2111
|
+
def flush(self) -> None:
    """Synchronously push every buffered span, score, and event to the Aeri API.

    Use this when you need a guarantee that all recorded data has been sent
    before continuing, instead of waiting for the periodic background flush.

    Example:
        ```python
        # Record some spans and scores
        with aeri.start_as_current_observation(name="operation") as span:
            # Do work...
            pass

        # Ensure all data is sent to Aeri before proceeding
        aeri.flush()

        # Continue with other work
        ```
    """
    resources = self._resources
    if resources is None:
        return
    resources.flush()
|
|
2133
|
+
|
|
2134
|
+
def shutdown(self) -> None:
    """Flush outstanding data and terminate the client's background workers.

    Call this when the application exits to avoid losing buffered events or
    leaking worker threads. For most applications, using the client as a
    context manager — or relying on the automatic atexit shutdown — is
    sufficient.

    Example:
        ```python
        # Initialize Aeri
        aeri = Aeri(public_key="...", secret_key="...")

        # Use Aeri throughout your application
        # ...

        # When application is shutting down
        aeri.shutdown()
        ```
    """
    resources = self._resources
    if resources is None:
        return
    resources.shutdown()
|
|
2158
|
+
|
|
2159
|
+
def get_current_trace_id(self) -> Optional[str]:
    """Return the trace ID of the currently active span, if any.

    Handy for correlating the running operation with logs or external
    systems, or for creating related operations against the same trace.

    Returns:
        The active trace ID as a 32-character lowercase hexadecimal string,
        or None when tracing is disabled or no span is active.

    Example:
        ```python
        with aeri.start_as_current_observation(name="process-request") as span:
            # Get the current trace ID for reference
            trace_id = aeri.get_current_trace_id()

            # Use it for external correlation
            log.info(f"Processing request with trace_id: {trace_id}")

            # Or pass to another system
            external_system.process(data, trace_id=trace_id)
        ```
    """
    if not self._tracing_enabled:
        aeri_logger.debug(
            "Operation skipped: get_current_trace_id - Tracing is disabled or client is in no-op mode."
        )
        return None

    span = self._get_current_otel_span()
    if not span:
        return None
    return self._get_otel_trace_id(span)
|
|
2192
|
+
|
|
2193
|
+
def get_current_observation_id(self) -> Optional[str]:
    """Return the observation (span) ID of the currently active span, if any.

    Useful for referencing the running operation in logs or external systems,
    or for attaching scores and other related operations to it later.

    Returns:
        The active observation ID as a 16-character lowercase hexadecimal
        string, or None when tracing is disabled or no span is active.

    Example:
        ```python
        with aeri.start_as_current_observation(name="process-user-query") as span:
            # Get the current observation ID
            observation_id = aeri.get_current_observation_id()

            # Store it for later reference
            cache.set(f"query_{query_id}_observation", observation_id)

            # Process the query...
        ```
    """
    if not self._tracing_enabled:
        aeri_logger.debug(
            "Operation skipped: get_current_observation_id - Tracing is disabled or client is in no-op mode."
        )
        return None

    span = self._get_current_otel_span()
    if not span:
        return None
    return self._get_otel_span_id(span)
|
|
2225
|
+
|
|
2226
|
+
def _get_project_id(self) -> Optional[str]:
|
|
2227
|
+
"""Fetch and return the current project id. Persisted across requests. Returns None if no project id is found for api keys."""
|
|
2228
|
+
if not self._project_id:
|
|
2229
|
+
proj = self.api.projects.get()
|
|
2230
|
+
if not proj.data or not proj.data[0].id:
|
|
2231
|
+
return None
|
|
2232
|
+
|
|
2233
|
+
self._project_id = proj.data[0].id
|
|
2234
|
+
|
|
2235
|
+
return self._project_id
|
|
2236
|
+
|
|
2237
|
+
def get_trace_url(self, *, trace_id: Optional[str] = None) -> Optional[str]:
    """Get the URL to view a trace in the Aeri UI.

    This method generates a URL that links directly to a trace in the Aeri UI.
    It's useful for providing links in logs, notifications, or debugging tools.

    Args:
        trace_id: Optional trace ID to generate a URL for. If not provided,
            the trace ID of the current active span will be used.

    Returns:
        A URL string pointing to the trace in the Aeri UI,
        or None if the project ID couldn't be retrieved or no trace ID is available.

    Example:
        ```python
        # Get URL for the current trace
        with aeri.start_as_current_observation(name="process-request") as span:
            trace_url = aeri.get_trace_url()
            log.info(f"Processing trace: {trace_url}")

        # Get URL for a specific trace
        specific_trace_url = aeri.get_trace_url(trace_id="1234567890abcdef1234567890abcdef")
        send_notification(f"Review needed for trace: {specific_trace_url}")
        ```
    """
    final_trace_id = trace_id or self.get_current_trace_id()
    if not final_trace_id:
        return None

    project_id = self._get_project_id()
    if not project_id:
        return None

    # final_trace_id was validated by the early return above, so only the
    # project id can be missing at this point (previous code re-checked it).
    return f"{self._base_url}/project/{project_id}/traces/{final_trace_id}"
|
|
2274
|
+
|
|
2275
|
+
def get_dataset(
    self,
    name: str,
    *,
    fetch_items_page_size: Optional[int] = 50,
    version: Optional[datetime] = None,
) -> "DatasetClient":
    """Fetch a dataset and all of its items by name.

    Items are retrieved page by page until every page reported by the API
    has been read.

    Args:
        name (str): Name of the dataset to fetch.
        fetch_items_page_size (Optional[int]): Page size used while paging
            through the dataset items. Defaults to 50.
        version (Optional[datetime]): Retrieve dataset items as they existed
            at this specific point in time (UTC). If provided, returns the
            state of items at the specified UTC timestamp; otherwise the
            latest version. Must be a timezone-aware datetime object in UTC.

    Returns:
        DatasetClient: Client wrapping the dataset and its fetched items.
    """
    try:
        aeri_logger.debug(f"Getting datasets {name}")
        dataset = self.api.datasets.get(dataset_name=self._url_encode(name))

        all_items = []
        current_page = 1
        total_pages = 1  # updated from the first response

        while current_page <= total_pages:
            listing = self.api.dataset_items.list(
                # NOTE(review): this call encodes with is_url_param=True while
                # the dataset fetch above does not — presumably query- vs.
                # path-parameter encoding; confirm against _url_encode.
                dataset_name=self._url_encode(name, is_url_param=True),
                page=current_page,
                limit=fetch_items_page_size,
                version=version,
            )
            all_items.extend(listing.data)
            total_pages = listing.meta.total_pages
            current_page += 1

        return DatasetClient(
            dataset=dataset,
            items=all_items,
            version=version,
            aeri_client=self,
        )
    except Error as e:
        handle_fern_exception(e)
        raise e
|
|
2325
|
+
|
|
2326
|
+
def get_dataset_run(
    self, *, dataset_name: str, run_name: str
) -> DatasetRunWithItems:
    """Retrieve a single dataset run, including its items.

    Args:
        dataset_name (str): Name of the dataset the run belongs to.
        run_name (str): Name of the run to fetch.

    Returns:
        DatasetRunWithItems: The matching run together with its items.
    """
    try:
        run = self.api.datasets.get_run(
            dataset_name=self._url_encode(dataset_name),
            run_name=self._url_encode(run_name),
            request_options=None,
        )
        return cast(DatasetRunWithItems, run)
    except Error as e:
        handle_fern_exception(e)
        raise e
|
|
2350
|
+
|
|
2351
|
+
def get_dataset_runs(
    self,
    *,
    dataset_name: str,
    page: Optional[int] = None,
    limit: Optional[int] = None,
) -> PaginatedDatasetRuns:
    """List the runs recorded for a dataset, one page at a time.

    Args:
        dataset_name (str): Name of the dataset whose runs to list.
        page (Optional[int]): Page number, starts at 1.
        limit (Optional[int]): Maximum number of runs per page.

    Returns:
        PaginatedDatasetRuns: One page of dataset runs plus pagination metadata.
    """
    try:
        runs_page = self.api.datasets.get_runs(
            dataset_name=self._url_encode(dataset_name),
            page=page,
            limit=limit,
            request_options=None,
        )
        return cast(PaginatedDatasetRuns, runs_page)
    except Error as e:
        handle_fern_exception(e)
        raise e
|
|
2381
|
+
|
|
2382
|
+
def delete_dataset_run(
    self, *, dataset_name: str, run_name: str
) -> DeleteDatasetRunResponse:
    """Permanently remove a dataset run together with all of its run items.

    This action is irreversible.

    Args:
        dataset_name (str): Name of the dataset the run belongs to.
        run_name (str): Name of the run to delete.

    Returns:
        DeleteDatasetRunResponse: Confirmation of the deletion.
    """
    try:
        response = self.api.datasets.delete_run(
            dataset_name=self._url_encode(dataset_name),
            run_name=self._url_encode(run_name),
            request_options=None,
        )
        return cast(DeleteDatasetRunResponse, response)
    except Error as e:
        handle_fern_exception(e)
        raise e
|
|
2406
|
+
|
|
2407
|
+
def run_experiment(
    self,
    *,
    name: str,
    run_name: Optional[str] = None,
    description: Optional[str] = None,
    data: ExperimentData,
    task: TaskFunction,
    evaluators: Optional[List[EvaluatorFunction]] = None,
    composite_evaluator: Optional[CompositeEvaluatorFunction] = None,
    run_evaluators: Optional[List[RunEvaluatorFunction]] = None,
    max_concurrency: int = 50,
    metadata: Optional[Dict[str, str]] = None,
    _dataset_version: Optional[datetime] = None,
) -> ExperimentResult:
    """Run an experiment on a dataset with automatic tracing and evaluation.

    This method executes a task function on each item in the provided dataset,
    automatically traces all executions with Aeri for observability, runs
    item-level and run-level evaluators on the outputs, and returns comprehensive
    results with evaluation metrics.

    The experiment system provides:
    - Automatic tracing of all task executions
    - Concurrent processing with configurable limits
    - Comprehensive error handling that isolates failures
    - Integration with Aeri datasets for experiment tracking
    - Flexible evaluation framework supporting both sync and async evaluators

    Args:
        name: Human-readable name for the experiment. Used for identification
            in the Aeri UI.
        run_name: Optional exact name for the experiment run. If provided, this will be
            used as the exact dataset run name if the `data` contains Aeri dataset items.
            If not provided, this will default to the experiment name appended with an ISO timestamp.
        description: Optional description explaining the experiment's purpose,
            methodology, or expected outcomes.
        data: Array of data items to process. Can be either:
            - List of dict-like items with 'input', 'expected_output', 'metadata' keys
            - List of Aeri DatasetItem objects from dataset.items
        task: Function that processes each data item and returns output.
            Must accept 'item' as keyword argument and can return sync or async results.
            The task function signature should be: task(*, item, **kwargs) -> Any
        evaluators: List of functions to evaluate each item's output individually.
            Each evaluator receives input, output, expected_output, and metadata.
            Can return single Evaluation dict or list of Evaluation dicts.
            Defaults to no evaluators.
        composite_evaluator: Optional function that creates composite scores from item-level evaluations.
            Receives the same inputs as item-level evaluators (input, output, expected_output, metadata)
            plus the list of evaluations from item-level evaluators. Useful for weighted averages,
            pass/fail decisions based on multiple criteria, or custom scoring logic combining multiple metrics.
        run_evaluators: List of functions to evaluate the entire experiment run.
            Each run evaluator receives all item_results and can compute aggregate metrics.
            Useful for calculating averages, distributions, or cross-item comparisons.
            Defaults to no run evaluators.
        max_concurrency: Maximum number of concurrent task executions (default: 50).
            Controls the number of items processed simultaneously. Adjust based on
            API rate limits and system resources.
        metadata: Optional metadata dictionary to attach to all experiment traces.
            This metadata will be included in every trace created during the experiment.
            If `data` are Aeri dataset items, the metadata will be attached to the dataset run, too.

    Returns:
        ExperimentResult containing:
        - run_name: The experiment run name. This is equal to the dataset run name if experiment was on Aeri dataset.
        - item_results: List of results for each processed item with outputs and evaluations
        - run_evaluations: List of aggregate evaluation results for the entire run
        - dataset_run_id: ID of the dataset run (if using Aeri datasets)
        - dataset_run_url: Direct URL to view results in Aeri UI (if applicable)

    Raises:
        ValueError: If required parameters are missing or invalid
        Exception: If experiment setup fails (individual item failures are handled gracefully)

    Examples:
        Basic experiment with local data:
        ```python
        def summarize_text(*, item, **kwargs):
            return f"Summary: {item['input'][:50]}..."

        def length_evaluator(*, input, output, expected_output=None, **kwargs):
            return {
                "name": "output_length",
                "value": len(output),
                "comment": f"Output contains {len(output)} characters"
            }

        result = aeri.run_experiment(
            name="Text Summarization Test",
            description="Evaluate summarization quality and length",
            data=[
                {"input": "Long article text...", "expected_output": "Expected summary"},
                {"input": "Another article...", "expected_output": "Another summary"}
            ],
            task=summarize_text,
            evaluators=[length_evaluator]
        )

        print(f"Processed {len(result.item_results)} items")
        for item_result in result.item_results:
            print(f"Input: {item_result.item['input']}")
            print(f"Output: {item_result.output}")
            print(f"Evaluations: {item_result.evaluations}")
        ```

        Advanced experiment with async task and multiple evaluators:
        ```python
        async def llm_task(*, item, **kwargs):
            # Simulate async LLM call
            response = await openai_client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": item["input"]}]
            )
            return response.choices[0].message.content

        def accuracy_evaluator(*, input, output, expected_output=None, **kwargs):
            if expected_output and expected_output.lower() in output.lower():
                return {"name": "accuracy", "value": 1.0, "comment": "Correct answer"}
            return {"name": "accuracy", "value": 0.0, "comment": "Incorrect answer"}

        def toxicity_evaluator(*, input, output, expected_output=None, **kwargs):
            # Simulate toxicity check
            toxicity_score = check_toxicity(output)  # Your toxicity checker
            return {
                "name": "toxicity",
                "value": toxicity_score,
                "comment": f"Toxicity level: {'high' if toxicity_score > 0.7 else 'low'}"
            }

        def average_accuracy(*, item_results, **kwargs):
            accuracies = [
                eval.value for result in item_results
                for eval in result.evaluations
                if eval.name == "accuracy"
            ]
            return {
                "name": "average_accuracy",
                "value": sum(accuracies) / len(accuracies) if accuracies else 0,
                "comment": f"Average accuracy across {len(accuracies)} items"
            }

        result = aeri.run_experiment(
            name="LLM Safety and Accuracy Test",
            description="Evaluate model accuracy and safety across diverse prompts",
            data=test_dataset,  # Your dataset items
            task=llm_task,
            evaluators=[accuracy_evaluator, toxicity_evaluator],
            run_evaluators=[average_accuracy],
            max_concurrency=5,  # Limit concurrent API calls
            metadata={"model": "gpt-4", "temperature": 0.7}
        )
        ```

        Using with Aeri datasets:
        ```python
        # Get dataset from Aeri
        dataset = aeri.get_dataset("my-eval-dataset")

        result = dataset.run_experiment(
            name="Production Model Evaluation",
            description="Monthly evaluation of production model performance",
            task=my_production_task,
            evaluators=[accuracy_evaluator, latency_evaluator]
        )

        # Results automatically linked to dataset in Aeri UI
        print(f"View results: {result['dataset_run_url']}")
        ```

    Note:
        - Task and evaluator functions can be either synchronous or asynchronous
        - Individual item failures are logged but don't stop the experiment
        - All executions are automatically traced and visible in Aeri UI
        - When using Aeri datasets, results are automatically linked for easy comparison
        - This method works in both sync and async contexts (Jupyter notebooks, web apps, etc.)
        - Async execution is handled automatically with smart event loop detection
    """
    # evaluators/run_evaluators previously defaulted to mutable `[]` (shared
    # across calls — B006); a None sentinel is behavior-identical because the
    # call below already normalizes with `or []`.
    return cast(
        ExperimentResult,
        run_async_safely(
            self._run_experiment_async(
                name=name,
                run_name=self._create_experiment_run_name(
                    name=name, run_name=run_name
                ),
                description=description,
                data=data,
                task=task,
                evaluators=evaluators or [],
                composite_evaluator=composite_evaluator,
                run_evaluators=run_evaluators or [],
                max_concurrency=max_concurrency,
                metadata=metadata,
                dataset_version=_dataset_version,
            ),
        ),
    )
|
|
2602
|
+
|
|
2603
|
+
async def _run_experiment_async(
    self,
    *,
    name: str,
    run_name: str,
    description: Optional[str],
    data: ExperimentData,
    task: TaskFunction,
    evaluators: List[EvaluatorFunction],
    composite_evaluator: Optional[CompositeEvaluatorFunction],
    run_evaluators: List[RunEvaluatorFunction],
    max_concurrency: int,
    metadata: Optional[Dict[str, Any]] = None,
    dataset_version: Optional[datetime] = None,
) -> ExperimentResult:
    """Execute one experiment run concurrently over all items in ``data``.

    Orchestration:
      1. Bound parallelism with an ``asyncio.Semaphore(max_concurrency)`` and
         process every item through ``_process_experiment_item`` via
         ``asyncio.gather``.
      2. Item failures are logged and dropped (``return_exceptions=True``);
         they never abort the run.
      3. Run-level evaluators are invoked once over the surviving results.
      4. Best effort: derive a dataset-run URL, persist run-level evaluations
         as scores, and flush pending telemetry before returning.

    Args:
        name: Experiment name (logging and result object only).
        run_name: Concrete name of this run.
        description: Optional run description, carried into the result.
        data: Items to process; must support ``len()`` and iteration.
        task: Task callable, forwarded per item.
        evaluators: Item-level evaluator callables, forwarded per item.
        composite_evaluator: Optional composite evaluator, forwarded per item.
        run_evaluators: Run-level evaluators, each called with the list of
            valid item results through ``_run_evaluator``.
        max_concurrency: Maximum number of items processed at once.
        metadata: Optional metadata forwarded per item.
        dataset_version: Optional dataset version forwarded per item.

    Returns:
        ExperimentResult aggregating the valid item results, run-level
        evaluations, and (when derivable) the dataset run id and URL.
    """
    aeri_logger.debug(
        f"Starting experiment '{name}' run '{run_name}' with {len(data)} items"
    )

    # Set up concurrency control
    semaphore = asyncio.Semaphore(max_concurrency)

    # Process all items
    async def process_item(item: ExperimentItem) -> ExperimentItemResult:
        async with semaphore:
            return await self._process_experiment_item(
                item,
                task,
                evaluators,
                composite_evaluator,
                name,
                run_name,
                description,
                metadata,
                dataset_version,
            )

    # Run all items concurrently
    tasks = [process_item(item) for item in data]
    item_results = await asyncio.gather(*tasks, return_exceptions=True)

    # Filter out any exceptions and log errors
    valid_results: List[ExperimentItemResult] = []
    for i, result in enumerate(item_results):
        if isinstance(result, Exception):
            aeri_logger.error(f"Item {i} failed: {result}")
        elif isinstance(result, ExperimentItemResult):
            valid_results.append(result)  # type: ignore

    # Run experiment-level evaluators
    run_evaluations: List[Evaluation] = []
    for run_evaluator in run_evaluators:
        try:
            evaluations = await _run_evaluator(
                run_evaluator, item_results=valid_results
            )
            run_evaluations.extend(evaluations)
        except Exception as e:
            # A failing run evaluator must not invalidate the whole run.
            aeri_logger.error(f"Run evaluator failed: {e}")

    # Generate dataset run URL if applicable
    # NOTE(review): assumes all items belong to the same dataset run — only
    # the first valid result's dataset_run_id is consulted; confirm upstream.
    dataset_run_id = valid_results[0].dataset_run_id if valid_results else None
    dataset_run_url = None
    if dataset_run_id and data:
        try:
            # Check if the first item has dataset_id (for DatasetItem objects)
            first_item = data[0]
            dataset_id = None

            if hasattr(first_item, "dataset_id"):
                dataset_id = getattr(first_item, "dataset_id", None)

            if dataset_id:
                project_id = self._get_project_id()

                if project_id:
                    dataset_run_url = f"{self._base_url}/project/{project_id}/datasets/{dataset_id}/runs/{dataset_run_id}"

        except Exception:
            pass  # URL generation is optional

    # Store run-level evaluations as scores
    for evaluation in run_evaluations:
        try:
            if dataset_run_id:
                self.create_score(
                    dataset_run_id=dataset_run_id,
                    name=evaluation.name or "<unknown>",
                    value=evaluation.value,  # type: ignore
                    comment=evaluation.comment,
                    metadata=evaluation.metadata,
                    data_type=evaluation.data_type,  # type: ignore
                    config_id=evaluation.config_id,
                )

        except Exception as e:
            aeri_logger.error(f"Failed to store run evaluation: {e}")

    # Flush scores and traces
    self.flush()

    return ExperimentResult(
        name=name,
        run_name=run_name,
        description=description,
        item_results=valid_results,
        run_evaluations=run_evaluations,
        dataset_run_id=dataset_run_id,
        dataset_run_url=dataset_run_url,
    )
|
|
2713
|
+
|
|
2714
|
+
async def _process_experiment_item(
    self,
    item: ExperimentItem,
    task: Callable,
    evaluators: List[Callable],
    composite_evaluator: Optional[CompositeEvaluatorFunction],
    experiment_name: str,
    experiment_run_name: str,
    experiment_description: Optional[str],
    experiment_metadata: Optional[Dict[str, Any]] = None,
    dataset_version: Optional[datetime] = None,
) -> ExperimentItemResult:
    """Run the task and all evaluators for a single experiment item.

    Everything executes inside one root observation span so the task output,
    evaluations, and scores are linked to a single trace. Item inputs are read
    from either a plain dict (``item["input"]`` etc.) or an object with
    ``input`` / ``expected_output`` / ``metadata`` attributes.

    Flow:
      1. Extract input/expected_output/metadata; a missing input raises.
      2. If the item looks like a dataset item (has ``id`` and ``dataset_id``),
         create a dataset run item via the sync API in a worker thread.
      3. Set experiment span attributes, propagate experiment context, and
         await the task.
      4. Run item-level evaluators, then the optional composite evaluator;
         each produced evaluation is stored as a score. Evaluator failures
         are logged, never raised.

    Raises:
        ValueError: When the item has no input (span is marked ERROR first).
        Exception: Any task failure is recorded on the span and re-raised.
    """
    span_name = "experiment-item-run"

    with self.start_as_current_observation(name=span_name) as span:
        try:
            input_data = (
                item.get("input")
                if isinstance(item, dict)
                else getattr(item, "input", None)
            )

            if input_data is None:
                raise ValueError("Experiment Item is missing input. Skipping item.")

            expected_output = (
                item.get("expected_output")
                if isinstance(item, dict)
                else getattr(item, "expected_output", None)
            )

            item_metadata = (
                item.get("metadata")
                if isinstance(item, dict)
                else getattr(item, "metadata", None)
            )

            # Experiment-level metadata forms the base; item metadata is
            # merged in later and may override these keys.
            final_observation_metadata = {
                "experiment_name": experiment_name,
                "experiment_run_name": experiment_run_name,
                **(experiment_metadata or {}),
            }

            trace_id = span.trace_id
            dataset_id = None
            dataset_item_id = None
            dataset_run_id = None

            # Link to dataset run if this is a dataset item
            if hasattr(item, "id") and hasattr(item, "dataset_id"):
                try:
                    # Use sync API to avoid event loop issues when run_async_safely
                    # creates multiple event loops across different threads
                    dataset_run_item = await asyncio.to_thread(
                        self.api.dataset_run_items.create,
                        run_name=experiment_run_name,
                        run_description=experiment_description,
                        metadata=experiment_metadata,
                        dataset_item_id=item.id,  # type: ignore
                        trace_id=trace_id,
                        observation_id=span.id,
                        dataset_version=dataset_version,
                    )

                    dataset_run_id = dataset_run_item.dataset_run_id

                except Exception as e:
                    # Linking is best-effort; the item still runs unlinked.
                    aeri_logger.error(f"Failed to create dataset run item: {e}")

            if (
                not isinstance(item, dict)
                and hasattr(item, "dataset_id")
                and hasattr(item, "id")
            ):
                dataset_id = item.dataset_id
                dataset_item_id = item.id

                final_observation_metadata.update(
                    {"dataset_id": dataset_id, "dataset_item_id": dataset_item_id}
                )

            if isinstance(item_metadata, dict):
                final_observation_metadata.update(item_metadata)

            # Stable fallbacks when the item is not dataset-backed: a fresh
            # observation id for the run, a content hash for the item.
            experiment_id = dataset_run_id or self._create_observation_id()
            experiment_item_id = (
                dataset_item_id or get_sha256_hash_hex(_serialize(input_data))[:16]
            )
            # Only non-None attributes are set on the underlying OTel span.
            span._otel_span.set_attributes(
                {
                    k: v
                    for k, v in {
                        AeriOtelSpanAttributes.ENVIRONMENT: AERI_SDK_EXPERIMENT_ENVIRONMENT,
                        AeriOtelSpanAttributes.EXPERIMENT_DESCRIPTION: experiment_description,
                        AeriOtelSpanAttributes.EXPERIMENT_ITEM_EXPECTED_OUTPUT: _serialize(
                            expected_output
                        ),
                    }.items()
                    if v is not None
                }
            )

            propagated_experiment_attributes = PropagatedExperimentAttributes(
                experiment_id=experiment_id,
                experiment_name=experiment_run_name,
                experiment_metadata=_serialize(experiment_metadata),
                experiment_dataset_id=dataset_id,
                experiment_item_id=experiment_item_id,
                experiment_item_metadata=_serialize(item_metadata),
                experiment_item_root_observation_id=span.id,
            )

            # Task executes with experiment context propagated so nested
            # observations created by the task attach to this experiment.
            with _propagate_attributes(experiment=propagated_experiment_attributes):
                output = await _run_task(task, item)

            span.update(
                input=input_data,
                output=output,
                metadata=final_observation_metadata,
            )

        except Exception as e:
            # Record the failure on the span before propagating to the caller
            # (the caller logs and drops failed items).
            span.update(
                output=f"Error: {str(e)}", level="ERROR", status_message=str(e)
            )
            raise e

        # Run evaluators
        evaluations = []

        for evaluator in evaluators:
            try:
                eval_metadata: Optional[Dict[str, Any]] = None

                if isinstance(item, dict):
                    eval_metadata = item.get("metadata")
                elif hasattr(item, "metadata"):
                    eval_metadata = item.metadata

                with _propagate_attributes(
                    experiment=propagated_experiment_attributes
                ):
                    eval_results = await _run_evaluator(
                        evaluator,
                        input=input_data,
                        output=output,
                        expected_output=expected_output,
                        metadata=eval_metadata,
                    )
                evaluations.extend(eval_results)

                # Store evaluations as scores
                for evaluation in eval_results:
                    self.create_score(
                        trace_id=trace_id,
                        observation_id=span.id,
                        name=evaluation.name,
                        value=evaluation.value,  # type: ignore
                        comment=evaluation.comment,
                        metadata=evaluation.metadata,
                        config_id=evaluation.config_id,
                        data_type=evaluation.data_type,  # type: ignore
                    )

            except Exception as e:
                # One failing evaluator does not stop the others.
                aeri_logger.error(f"Evaluator failed: {e}")

        # Run composite evaluator if provided and we have evaluations
        if composite_evaluator and evaluations:
            try:
                composite_eval_metadata: Optional[Dict[str, Any]] = None
                if isinstance(item, dict):
                    composite_eval_metadata = item.get("metadata")
                elif hasattr(item, "metadata"):
                    composite_eval_metadata = item.metadata

                with _propagate_attributes(
                    experiment=propagated_experiment_attributes
                ):
                    result = composite_evaluator(
                        input=input_data,
                        output=output,
                        expected_output=expected_output,
                        metadata=composite_eval_metadata,
                        evaluations=evaluations,
                    )

                    # Handle async composite evaluators
                    if asyncio.iscoroutine(result):
                        result = await result

                # Normalize to list
                composite_evals: List[Evaluation] = []
                if isinstance(result, (dict, Evaluation)):
                    composite_evals = [result]  # type: ignore
                elif isinstance(result, list):
                    composite_evals = result  # type: ignore

                # Store composite evaluations as scores and add to evaluations list
                for composite_evaluation in composite_evals:
                    self.create_score(
                        trace_id=trace_id,
                        observation_id=span.id,
                        name=composite_evaluation.name,
                        value=composite_evaluation.value,  # type: ignore
                        comment=composite_evaluation.comment,
                        metadata=composite_evaluation.metadata,
                        config_id=composite_evaluation.config_id,
                        data_type=composite_evaluation.data_type,  # type: ignore
                    )
                    evaluations.append(composite_evaluation)

            except Exception as e:
                aeri_logger.error(f"Composite evaluator failed: {e}")

        return ExperimentItemResult(
            item=item,
            output=output,
            evaluations=evaluations,
            trace_id=trace_id,
            dataset_run_id=dataset_run_id,
        )
|
|
2936
|
+
|
|
2937
|
+
def _create_experiment_run_name(
|
|
2938
|
+
self, *, name: Optional[str] = None, run_name: Optional[str] = None
|
|
2939
|
+
) -> str:
|
|
2940
|
+
if run_name:
|
|
2941
|
+
return run_name
|
|
2942
|
+
|
|
2943
|
+
iso_timestamp = _get_timestamp().isoformat().replace("+00:00", "Z")
|
|
2944
|
+
|
|
2945
|
+
return f"{name} - {iso_timestamp}"
|
|
2946
|
+
|
|
2947
|
+
def run_batched_evaluation(
    self,
    *,
    scope: Literal["traces", "observations"],
    mapper: MapperFunction,
    filter: Optional[str] = None,
    fetch_batch_size: int = 50,
    fetch_trace_fields: Optional[str] = None,
    max_items: Optional[int] = None,
    max_retries: int = 3,
    evaluators: List[EvaluatorFunction],
    composite_evaluator: Optional[CompositeEvaluatorFunction] = None,
    max_concurrency: int = 5,
    metadata: Optional[Dict[str, Any]] = None,
    _add_observation_scores_to_trace: bool = False,
    _additional_trace_tags: Optional[List[str]] = None,
    resume_from: Optional[BatchEvaluationResumeToken] = None,
    verbose: bool = False,
) -> BatchEvaluationResult:
    """Fetch existing traces or observations and evaluate each one.

    Items are streamed from the Aeri API in batches, transformed by
    ``mapper`` into evaluator inputs, scored by each function in
    ``evaluators`` (optionally combined via ``composite_evaluator``), and the
    resulting scores are written back to the original entities. Suited for
    evaluating production traces after deployment, backtesting new metrics
    on historical data, and periodic quality monitoring at scale.

    Args:
        scope: "traces" to evaluate whole traces, "observations" to evaluate
            individual spans/generations/events.
        mapper: Sync or async function mapping an API object to an
            EvaluatorInputs (input, output, expected_output, metadata).
        evaluators: Evaluator functions run per item; individual failures
            are logged and do not abort the batch.
        filter: Optional JSON filter string (same format as the Aeri API),
            e.g. '{"tags": ["production"]}'. None fetches all items.
        fetch_batch_size: Items fetched per API call / held in memory.
        fetch_trace_fields: Comma-separated field groups for trace fetches
            ('core', 'io', 'scores', 'observations', 'metrics'). Excluded
            'observations'/'scores' come back as empty arrays; excluded
            'metrics' yields -1 for 'totalCost' and 'latency'. Only used
            when scope is 'traces'.
        max_items: Cap on total items processed; None processes everything.
        max_retries: Retry attempts for failed batch fetches, with
            exponential backoff (1s, 2s, 4s).
        composite_evaluator: Optional function combining an item's
            evaluations into one composite Evaluation.
        max_concurrency: Maximum number of items evaluated concurrently.
        metadata: Metadata attached to every created score.
        resume_from: Resume token from a previous incomplete run; uses
            timestamp-based filtering to avoid duplicate processing.
        verbose: Log progress information while running.

    Returns:
        BatchEvaluationResult with fetch/process/failure counts, scores
        created (item-level and composite), per-evaluator statistics, a
        resume_token when incomplete, completion flag, duration, failed
        item ids, an error summary, and whether more items remained.

    Raises:
        ValueError: If an invalid scope is provided.

    Note:
        Evaluator and per-item failures are tracked without stopping the
        run; all scores are flushed to Aeri at the end.
    """
    batch_runner = BatchEvaluationRunner(self)

    # Build the coroutine first, then drive it to completion on whichever
    # event-loop strategy run_async_safely selects for this context.
    evaluation_coroutine = batch_runner.run_async(
        scope=scope,
        mapper=mapper,
        evaluators=evaluators,
        filter=filter,
        fetch_batch_size=fetch_batch_size,
        fetch_trace_fields=fetch_trace_fields,
        max_items=max_items,
        max_concurrency=max_concurrency,
        composite_evaluator=composite_evaluator,
        metadata=metadata,
        _add_observation_scores_to_trace=_add_observation_scores_to_trace,
        _additional_trace_tags=_additional_trace_tags,
        max_retries=max_retries,
        verbose=verbose,
        resume_from=resume_from,
    )

    return cast(BatchEvaluationResult, run_async_safely(evaluation_coroutine))
|
|
3181
|
+
|
|
3182
|
+
def auth_check(self) -> bool:
    """Verify that the configured public/secret key pair is valid.

    Performs a blocking request against the projects endpoint; avoid
    calling it on production hot paths.

    Returns:
        True when the credentials resolve to at least one project; False
        when the client was never properly initialized.

    Raises:
        Exception: If the credentials are accepted but no project exists.
    """
    try:
        project_list = self.api.projects.get()
        project_count = len(project_list.data)
        aeri_logger.debug(
            f"Auth check successful, found {project_count} projects"
        )

        if project_count == 0:
            raise Exception(
                "Auth check failed, no project found for the keys provided."
            )

        return True

    except AttributeError as e:
        # A partially constructed client exposes no usable `api` attribute.
        aeri_logger.warning(
            f"Auth check failed: Client not properly initialized. Error: {e}"
        )
        return False

    except Error as e:
        handle_fern_exception(e)
        raise e
|
|
3211
|
+
|
|
3212
|
+
def create_dataset(
    self,
    *,
    name: str,
    description: Optional[str] = None,
    metadata: Optional[Any] = None,
    input_schema: Optional[Any] = None,
    expected_output_schema: Optional[Any] = None,
) -> Dataset:
    """Create a dataset with the given name on Aeri.

    Args:
        name: Name of the dataset to create.
        description: Description of the dataset. Defaults to None.
        metadata: Additional metadata. Defaults to None.
        input_schema: JSON Schema for validating dataset item inputs. When set, all new items will be validated against this schema.
        expected_output_schema: JSON Schema for validating dataset item expected outputs. When set, all new items will be validated against this schema.

    Returns:
        Dataset: The created dataset as returned by the Aeri API.

    Raises:
        Error: API errors are logged via handle_fern_exception and re-raised.
    """
    try:
        # Fixed log message: previously read "Creating datasets" (typo).
        aeri_logger.debug(f"Creating dataset {name}")

        result = self.api.datasets.create(
            name=name,
            description=description,
            metadata=metadata,
            input_schema=input_schema,
            expected_output_schema=expected_output_schema,
        )

        # The generated API client returns a compatible payload; cast for typing.
        return cast(Dataset, result)

    except Error as e:
        handle_fern_exception(e)
        raise e
|
|
3249
|
+
|
|
3250
|
+
def create_dataset_item(
    self,
    *,
    dataset_name: str,
    input: Optional[Any] = None,
    expected_output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    source_trace_id: Optional[str] = None,
    source_observation_id: Optional[str] = None,
    status: Optional[DatasetStatus] = None,
    id: Optional[str] = None,
) -> DatasetItem:
    """Create a dataset item, upserting when an item with ``id`` exists.

    Args:
        dataset_name: Name of the dataset that should receive the item.
        input: Input data (any dict, list, or scalar). Defaults to None.
        expected_output: Expected output data (any dict, list, or scalar).
            Defaults to None.
        metadata: Additional metadata (any dict, list, or scalar).
            Defaults to None.
        source_trace_id: Id of the source trace. Defaults to None.
        source_observation_id: Id of the source observation. Defaults to None.
        status: Status of the dataset item; newly created items default
            to ACTIVE.
        id: Optional caller-supplied id for deduplication. Must be globally
            unique and cannot be reused across datasets.

    Returns:
        DatasetItem: The created dataset item as returned by the Aeri API.

    Example:
        ```python
        from aeri import Aeri

        aeri = Aeri()

        # Upload an item to the Aeri dataset named "capital_cities"
        aeri.create_dataset_item(
            dataset_name="capital_cities",
            input={"input": {"country": "Italy"}},
            expected_output={"expected_output": "Rome"},
            metadata={"foo": "bar"}
        )
        ```
    """
    try:
        aeri_logger.debug(f"Creating dataset item for dataset {dataset_name}")

        created_item = self.api.dataset_items.create(
            dataset_name=dataset_name,
            input=input,
            expected_output=expected_output,
            metadata=metadata,
            source_trace_id=source_trace_id,
            source_observation_id=source_observation_id,
            status=status,
            id=id,
        )

        # The generated client returns a compatible payload; cast for typing.
        return cast(DatasetItem, created_item)
    except Error as e:
        handle_fern_exception(e)
        raise e
|
|
3312
|
+
|
|
3313
|
+
def resolve_media_references(
    self,
    *,
    obj: Any,
    resolve_with: Literal["base64_data_uri"],
    max_depth: int = 10,
    content_fetch_timeout_seconds: int = 5,
) -> Any:
    """Replace media reference strings in an object with base64 data URIs.

    Recursively walks ``obj`` (up to ``max_depth``) looking for media
    reference strings of the form "@@@aeriMedia:...@@@". Each reference is
    synchronously fetched through this client and replaced with a base64
    data URI. When a fetch fails, a warning is logged and the reference
    string is left untouched.

    Args:
        obj: The object to process — a primitive, array, or nested object.
            Objects carrying a __dict__ are returned as plain dicts.
        resolve_with: Target representation for resolved media. Only
            "base64_data_uri" is currently supported.
        max_depth: Maximum traversal depth. Default is 10.
        content_fetch_timeout_seconds: Timeout in seconds for each media
            fetch. Default is 5.

    Returns:
        A deep copy of ``obj`` with every resolvable media reference
        replaced by a base64 data URI.

    Example:
        obj = {
            "image": "@@@aeriMedia:type=image/jpeg|id=123|source=bytes@@@",
            "nested": {
                "pdf": "@@@aeriMedia:type=application/pdf|id=456|source=bytes@@@"
            }
        }

        result = client.resolve_media_references(
            obj=obj, resolve_with="base64_data_uri"
        )

        # Result:
        # {
        #     "image": "data:image/jpeg;base64,/9j/4AAQSkZJRg...",
        #     "nested": {
        #         "pdf": "data:application/pdf;base64,JVBERi0xLjcK..."
        #     }
        # }
    """
    # Delegate the traversal/fetch work to the shared media helper.
    resolved_copy = AeriMedia.resolve_media_references(
        aeri_client=self,
        obj=obj,
        resolve_with=resolve_with,
        max_depth=max_depth,
        content_fetch_timeout_seconds=content_fetch_timeout_seconds,
    )
    return resolved_copy
|
|
3367
|
+
|
|
3368
|
+
@overload
# Typing overload only (no runtime body): requesting type="chat" yields a
# ChatPromptClient, and `fallback`, if given, must be a list of chat messages.
def get_prompt(
    self,
    name: str,
    *,
    version: Optional[int] = None,
    label: Optional[str] = None,
    type: Literal["chat"],
    cache_ttl_seconds: Optional[int] = None,
    fallback: Optional[List[ChatMessageDict]] = None,
    max_retries: Optional[int] = None,
    fetch_timeout_seconds: Optional[int] = None,
) -> ChatPromptClient: ...
|
|
3381
|
+
|
|
3382
|
+
# Overload: the default type="text" returns a TextPromptClient and
# `fallback` (if provided) is a plain prompt string.
@overload
def get_prompt(
    self,
    name: str,
    *,
    version: Optional[int] = None,
    label: Optional[str] = None,
    type: Literal["text"] = "text",
    cache_ttl_seconds: Optional[int] = None,
    fallback: Optional[str] = None,
    max_retries: Optional[int] = None,
    fetch_timeout_seconds: Optional[int] = None,
) -> TextPromptClient: ...
|
|
3395
|
+
|
|
3396
|
+
def get_prompt(
    self,
    name: str,
    *,
    version: Optional[int] = None,
    label: Optional[str] = None,
    type: Literal["chat", "text"] = "text",
    cache_ttl_seconds: Optional[int] = None,
    fallback: Union[Optional[List[ChatMessageDict]], Optional[str]] = None,
    max_retries: Optional[int] = None,
    fetch_timeout_seconds: Optional[int] = None,
) -> PromptClient:
    """Get a prompt.

    This method attempts to fetch the requested prompt from the local cache. If the prompt is not found
    in the cache or if the cached prompt has expired, it will try to fetch the prompt from the server again
    and update the cache. If fetching the new prompt fails, and there is an expired prompt in the cache, it will
    return the expired prompt as a fallback.

    Args:
        name (str): The name of the prompt to retrieve.

    Keyword Args:
        version (Optional[int]): The version of the prompt to retrieve. If no label and version is specified, the `production` label is returned. Specify either version or label, not both.
        label: Optional[str]: The label of the prompt to retrieve. If no label and version is specified, the `production` label is returned. Specify either version or label, not both.
        cache_ttl_seconds: Optional[int]: Time-to-live in seconds for caching the prompt. Must be specified as a
            keyword argument. If not set, defaults to 60 seconds. Disables caching if set to 0.
        type: Literal["chat", "text"]: The type of the prompt to retrieve. Defaults to "text".
        fallback: Union[Optional[List[ChatMessageDict]], Optional[str]]: The prompt content to return if fetching the prompt fails. Important on the first call where no cached prompt is available. Follows Aeri prompt formatting with double curly braces for variables. Defaults to None. Note: a falsy fallback (empty string or empty list) is treated as "no fallback" and the fetch error is raised instead.
        max_retries: Optional[int]: The maximum number of retries in case of API/network errors. Defaults to 2. The maximum value is 4. Retries wait a constant interval between attempts.
        fetch_timeout_seconds: Optional[int]: The timeout in seconds for fetching the prompt. Defaults to the default timeout set on the SDK, which is 5 seconds per default.

    Returns:
        The prompt object retrieved from the cache or directly fetched if not cached or expired of type
        - TextPromptClient, if type argument is 'text'.
        - ChatPromptClient, if type argument is 'chat'.

    Raises:
        Error: If the SDK resources were not initialized.
        ValueError: If both version and label are given, or if name is empty.
        Exception: Propagates any exceptions raised during the fetching of a new prompt, unless there is an
            expired prompt in the cache, in which case it logs a warning and returns the expired prompt.
    """
    if self._resources is None:
        raise Error(
            "SDK is not correctly initialized. Check the init logs for more details."
        )
    if version is not None and label is not None:
        raise ValueError("Cannot specify both version and label at the same time.")

    if not name:
        raise ValueError("Prompt name cannot be empty.")

    cache_key = PromptCache.generate_cache_key(name, version=version, label=label)
    # Clamp caller-supplied retries into [0, 4]; None means the default of 2.
    bounded_max_retries = self._get_bounded_max_retries(
        max_retries, default_max_retries=2, max_retries_upper_bound=4
    )

    aeri_logger.debug(f"Getting prompt '{cache_key}'")
    cached_prompt = self._resources.prompt_cache.get(cache_key)

    # Cache miss, or caching explicitly disabled via cache_ttl_seconds=0:
    # fetch synchronously from the server.
    if cached_prompt is None or cache_ttl_seconds == 0:
        aeri_logger.debug(
            f"Prompt '{cache_key}' not found in cache or caching disabled."
        )
        try:
            return self._fetch_prompt_and_update_cache(
                name,
                version=version,
                label=label,
                ttl_seconds=cache_ttl_seconds,
                max_retries=bounded_max_retries,
                fetch_timeout_seconds=fetch_timeout_seconds,
            )
        except Exception as e:
            # Fetch failed with no cached copy available: fall back to the
            # caller-provided prompt content if one was given (truthy only).
            if fallback:
                aeri_logger.warning(
                    f"Returning fallback prompt for '{cache_key}' due to fetch error: {e}"
                )

                # Build a minimal synthetic prompt object; version 0 and empty
                # config/tags mark it as a client-side stand-in.
                fallback_client_args: Dict[str, Any] = {
                    "name": name,
                    "prompt": fallback,
                    "type": type,
                    "version": version or 0,
                    "config": {},
                    "labels": [label] if label else [],
                    "tags": [],
                }

                if type == "text":
                    return TextPromptClient(
                        prompt=Prompt_Text(**fallback_client_args),
                        is_fallback=True,
                    )

                if type == "chat":
                    return ChatPromptClient(
                        prompt=Prompt_Chat(**fallback_client_args),
                        is_fallback=True,
                    )

            # No usable fallback: surface the original fetch error.
            raise e

    # Cached copy exists but its TTL has elapsed: serve it stale and refresh
    # asynchronously so callers never block on the network here.
    if cached_prompt.is_expired():
        aeri_logger.debug(f"Stale prompt '{cache_key}' found in cache.")
        try:
            # refresh prompt in background thread, refresh_prompt deduplicates tasks
            aeri_logger.debug(f"Refreshing prompt '{cache_key}' in background.")

            def refresh_task() -> None:
                # Closure captures the original lookup parameters so the
                # background refresh fetches exactly the same prompt variant.
                self._fetch_prompt_and_update_cache(
                    name,
                    version=version,
                    label=label,
                    ttl_seconds=cache_ttl_seconds,
                    max_retries=bounded_max_retries,
                    fetch_timeout_seconds=fetch_timeout_seconds,
                )

            self._resources.prompt_cache.add_refresh_prompt_task(
                cache_key,
                refresh_task,
            )
            aeri_logger.debug(
                f"Returning stale prompt '{cache_key}' from cache."
            )
            # return stale prompt
            return cached_prompt.value

        except Exception as e:
            aeri_logger.warning(
                f"Error when refreshing cached prompt '{cache_key}', returning cached version. Error: {e}"
            )
            # creation of refresh prompt task failed, return stale prompt
            return cached_prompt.value

    # Fresh cache hit.
    return cached_prompt.value
|
|
3532
|
+
|
|
3533
|
+
def _fetch_prompt_and_update_cache(
    self,
    name: str,
    *,
    version: Optional[int] = None,
    label: Optional[str] = None,
    ttl_seconds: Optional[int] = None,
    max_retries: int,
    fetch_timeout_seconds: Optional[int],
) -> PromptClient:
    """Fetch a prompt from the server and store the result in the local cache.

    The API call is retried up to ``max_retries`` additional times with a
    constant wait between attempts. On success the prompt is wrapped in the
    client type matching its server-reported ``type`` and cached under the
    (name, version, label) cache key with ``ttl_seconds``.

    Args:
        name: Prompt name (URL-encoded before the request).
        version: Optional specific version to fetch.
        label: Optional label to fetch.
        ttl_seconds: TTL to store alongside the cached prompt.
        max_retries: Number of retries on top of the initial attempt.
        fetch_timeout_seconds: Per-request timeout; None uses the SDK default.

    Returns:
        ChatPromptClient or TextPromptClient depending on the server response.

    Raises:
        NotFoundError: If the prompt does not exist; the stale cache entry
            (if any) is evicted before re-raising.
        Exception: Any other fetch error, after logging.
    """
    cache_key = PromptCache.generate_cache_key(name, version=version, label=label)
    aeri_logger.debug(f"Fetching prompt '{cache_key}' from server...")

    try:

        # max_tries counts the first attempt too, hence max_retries + 1.
        @backoff.on_exception(
            backoff.constant, Exception, max_tries=max_retries + 1, logger=None
        )
        def fetch_prompts() -> Any:
            return self.api.prompts.get(
                self._url_encode(name),
                version=version,
                label=label,
                # Only override the request timeout when explicitly given,
                # otherwise let the SDK-level default apply.
                request_options={
                    "timeout_in_seconds": fetch_timeout_seconds,
                }
                if fetch_timeout_seconds is not None
                else None,
            )

        prompt_response = fetch_prompts()

        # Wrap in the concrete client matching the server-reported type.
        prompt: PromptClient
        if prompt_response.type == "chat":
            prompt = ChatPromptClient(prompt_response)
        else:
            prompt = TextPromptClient(prompt_response)

        if self._resources is not None:
            self._resources.prompt_cache.set(cache_key, prompt, ttl_seconds)

        return prompt

    except NotFoundError as not_found_error:
        # The prompt was deleted (or never existed): drop any stale cache
        # entry so later calls don't keep serving it.
        aeri_logger.warning(
            f"Prompt '{cache_key}' not found during refresh, evicting from cache."
        )
        if self._resources is not None:
            self._resources.prompt_cache.delete(cache_key)
        raise not_found_error

    except Exception as e:
        aeri_logger.error(
            f"Error while fetching prompt '{cache_key}': {str(e)}"
        )
        raise e
|
|
3589
|
+
|
|
3590
|
+
def _get_bounded_max_retries(
|
|
3591
|
+
self,
|
|
3592
|
+
max_retries: Optional[int],
|
|
3593
|
+
*,
|
|
3594
|
+
default_max_retries: int = 2,
|
|
3595
|
+
max_retries_upper_bound: int = 4,
|
|
3596
|
+
) -> int:
|
|
3597
|
+
if max_retries is None:
|
|
3598
|
+
return default_max_retries
|
|
3599
|
+
|
|
3600
|
+
bounded_max_retries = min(
|
|
3601
|
+
max(max_retries, 0),
|
|
3602
|
+
max_retries_upper_bound,
|
|
3603
|
+
)
|
|
3604
|
+
|
|
3605
|
+
return bounded_max_retries
|
|
3606
|
+
|
|
3607
|
+
# Overload: creating a "chat" prompt takes a list of chat messages (optionally
# with placeholders) and returns a ChatPromptClient.
@overload
def create_prompt(
    self,
    *,
    name: str,
    prompt: List[Union[ChatMessageDict, ChatMessageWithPlaceholdersDict]],
    labels: List[str] = [],
    tags: Optional[List[str]] = None,
    type: Optional[Literal["chat"]],
    config: Optional[Any] = None,
    commit_message: Optional[str] = None,
) -> ChatPromptClient: ...
|
|
3619
|
+
|
|
3620
|
+
# Overload: creating a "text" prompt (the default) takes a plain string and
# returns a TextPromptClient.
@overload
def create_prompt(
    self,
    *,
    name: str,
    prompt: str,
    labels: List[str] = [],
    tags: Optional[List[str]] = None,
    type: Optional[Literal["text"]] = "text",
    config: Optional[Any] = None,
    commit_message: Optional[str] = None,
) -> TextPromptClient: ...
|
|
3632
|
+
|
|
3633
|
+
def create_prompt(
    self,
    *,
    name: str,
    prompt: Union[
        str, List[Union[ChatMessageDict, ChatMessageWithPlaceholdersDict]]
    ],
    labels: List[str] = [],  # NOTE(review): mutable default; safe here as it is never mutated
    tags: Optional[List[str]] = None,
    type: Optional[Literal["chat", "text"]] = "text",
    config: Optional[Any] = None,
    commit_message: Optional[str] = None,
) -> PromptClient:
    """Create a new prompt in Aeri.

    Any cached prompts with the same name are invalidated after a
    successful creation.

    Keyword Args:
        name: The name of the prompt to be created.
        prompt: The content of the prompt to be created. A plain string for
            "text" prompts, or a list of chat messages for "chat" prompts.
        labels: The labels of the prompt. Defaults to an empty list. To create a default-served prompt, add the 'production' label.
        tags: The tags of the prompt. Defaults to None. Will be applied to all versions of the prompt.
        config: Additional structured data to be saved with the prompt. Defaults to None.
        type: The type of the prompt to be created. "chat" vs. "text". Defaults to "text".
        commit_message: Optional string describing the change.

    Returns:
        TextPromptClient: The prompt if type argument is 'text'.
        ChatPromptClient: The prompt if type argument is 'chat'.

    Raises:
        ValueError: If the shape of `prompt` does not match `type`
            (list required for "chat", str required for "text").
    """
    try:
        aeri_logger.debug(f"Creating prompt {name=}, {labels=}")

        if type == "chat":
            if not isinstance(prompt, list):
                raise ValueError(
                    "For 'chat' type, 'prompt' must be a list of chat messages with role and content attributes."
                )
            # cast(Any, ...) satisfies the generated request model's typing;
            # the server validates the message structure.
            request: Union[CreateChatPromptRequest, CreateTextPromptRequest] = (
                CreateChatPromptRequest(
                    name=name,
                    prompt=cast(Any, prompt),
                    labels=labels,
                    tags=tags,
                    config=config or {},
                    commit_message=commit_message,
                    type=CreateChatPromptType.CHAT,
                )
            )
            server_prompt = self.api.prompts.create(request=request)

            # Drop any cached versions of this prompt so subsequent
            # get_prompt calls see the new version.
            if self._resources is not None:
                self._resources.prompt_cache.invalidate(name)

            return ChatPromptClient(prompt=cast(Prompt_Chat, server_prompt))

        if not isinstance(prompt, str):
            raise ValueError("For 'text' type, 'prompt' must be a string.")

        request = CreateTextPromptRequest(
            name=name,
            prompt=prompt,
            labels=labels,
            tags=tags,
            config=config or {},
            commit_message=commit_message,
        )

        server_prompt = self.api.prompts.create(request=request)

        # Same cache invalidation as the chat branch.
        if self._resources is not None:
            self._resources.prompt_cache.invalidate(name)

        return TextPromptClient(prompt=cast(Prompt_Text, server_prompt))

    except Error as e:
        handle_fern_exception(e)
        raise e
|
|
3710
|
+
|
|
3711
|
+
def update_prompt(
    self,
    *,
    name: str,
    version: int,
    new_labels: List[str] = [],
) -> Any:
    """Update an existing prompt version in Aeri. The Aeri SDK prompt cache is invalidated for all prompts with the specified name.

    Args:
        name (str): The name of the prompt to update.
        version (int): The version number of the prompt to update.
        new_labels (List[str], optional): New labels to assign to the prompt version. Labels are unique across versions. The "latest" label is reserved and managed by Aeri. Defaults to [].

    Returns:
        Prompt: The updated prompt from the Aeri API.

    """
    # Name is percent-encoded so prompts living in folders ("a/b") form a
    # single path segment.
    updated_prompt = self.api.prompt_version.update(
        name=self._url_encode(name),
        version=version,
        new_labels=new_labels,
    )

    # Evict every cached variant of this prompt so the label change is
    # visible on the next get_prompt call.
    if self._resources is not None:
        self._resources.prompt_cache.invalidate(name)

    return updated_prompt
|
|
3739
|
+
|
|
3740
|
+
def _url_encode(self, url: str, *, is_url_param: Optional[bool] = False) -> str:
|
|
3741
|
+
# httpx ≥ 0.28 does its own WHATWG-compliant quoting (eg. encodes bare
|
|
3742
|
+
# “%”, “?”, “#”, “|”, … in query/path parts). Re-quoting here would
|
|
3743
|
+
# double-encode, so we skip when the value is about to be sent straight
|
|
3744
|
+
# to httpx (`is_url_param=True`) and the installed version is ≥ 0.28.
|
|
3745
|
+
if is_url_param and Version(httpx.__version__) >= Version("0.28.0"):
|
|
3746
|
+
return url
|
|
3747
|
+
|
|
3748
|
+
# urllib.parse.quote does not escape slashes "/" by default; we need to add safe="" to force escaping
|
|
3749
|
+
# we need add safe="" to force escaping of slashes
|
|
3750
|
+
# This is necessary for prompts in prompt folders
|
|
3751
|
+
return urllib.parse.quote(url, safe="")
|
|
3752
|
+
|
|
3753
|
+
def clear_prompt_cache(self) -> None:
    """Clear the entire prompt cache, removing all cached prompts.

    This method is useful when you want to force a complete refresh of all
    cached prompts, for example after major updates or when you need to
    ensure the latest versions are fetched from the server.
    """
    # No-op when the SDK never finished initializing its resources.
    if self._resources is None:
        return

    self._resources.prompt_cache.clear()