aeri-python 4.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aeri/__init__.py +72 -0
- aeri/_client/_validation.py +204 -0
- aeri/_client/attributes.py +188 -0
- aeri/_client/client.py +3761 -0
- aeri/_client/constants.py +65 -0
- aeri/_client/datasets.py +302 -0
- aeri/_client/environment_variables.py +158 -0
- aeri/_client/get_client.py +149 -0
- aeri/_client/observe.py +661 -0
- aeri/_client/propagation.py +475 -0
- aeri/_client/resource_manager.py +510 -0
- aeri/_client/span.py +1519 -0
- aeri/_client/span_filter.py +76 -0
- aeri/_client/span_processor.py +206 -0
- aeri/_client/utils.py +132 -0
- aeri/_task_manager/media_manager.py +331 -0
- aeri/_task_manager/media_upload_consumer.py +44 -0
- aeri/_task_manager/media_upload_queue.py +12 -0
- aeri/_task_manager/score_ingestion_consumer.py +208 -0
- aeri/_task_manager/task_manager.py +475 -0
- aeri/_utils/__init__.py +19 -0
- aeri/_utils/environment.py +34 -0
- aeri/_utils/error_logging.py +47 -0
- aeri/_utils/parse_error.py +99 -0
- aeri/_utils/prompt_cache.py +188 -0
- aeri/_utils/request.py +137 -0
- aeri/_utils/serializer.py +205 -0
- aeri/api/.fern/metadata.json +14 -0
- aeri/api/__init__.py +836 -0
- aeri/api/annotation_queues/__init__.py +82 -0
- aeri/api/annotation_queues/client.py +1111 -0
- aeri/api/annotation_queues/raw_client.py +2288 -0
- aeri/api/annotation_queues/types/__init__.py +84 -0
- aeri/api/annotation_queues/types/annotation_queue.py +28 -0
- aeri/api/annotation_queues/types/annotation_queue_assignment_request.py +16 -0
- aeri/api/annotation_queues/types/annotation_queue_item.py +34 -0
- aeri/api/annotation_queues/types/annotation_queue_object_type.py +26 -0
- aeri/api/annotation_queues/types/annotation_queue_status.py +22 -0
- aeri/api/annotation_queues/types/create_annotation_queue_assignment_response.py +18 -0
- aeri/api/annotation_queues/types/create_annotation_queue_item_request.py +25 -0
- aeri/api/annotation_queues/types/create_annotation_queue_request.py +20 -0
- aeri/api/annotation_queues/types/delete_annotation_queue_assignment_response.py +14 -0
- aeri/api/annotation_queues/types/delete_annotation_queue_item_response.py +15 -0
- aeri/api/annotation_queues/types/paginated_annotation_queue_items.py +17 -0
- aeri/api/annotation_queues/types/paginated_annotation_queues.py +17 -0
- aeri/api/annotation_queues/types/update_annotation_queue_item_request.py +15 -0
- aeri/api/blob_storage_integrations/__init__.py +73 -0
- aeri/api/blob_storage_integrations/client.py +550 -0
- aeri/api/blob_storage_integrations/raw_client.py +976 -0
- aeri/api/blob_storage_integrations/types/__init__.py +77 -0
- aeri/api/blob_storage_integrations/types/blob_storage_export_frequency.py +26 -0
- aeri/api/blob_storage_integrations/types/blob_storage_export_mode.py +26 -0
- aeri/api/blob_storage_integrations/types/blob_storage_integration_deletion_response.py +14 -0
- aeri/api/blob_storage_integrations/types/blob_storage_integration_file_type.py +26 -0
- aeri/api/blob_storage_integrations/types/blob_storage_integration_response.py +64 -0
- aeri/api/blob_storage_integrations/types/blob_storage_integration_status_response.py +50 -0
- aeri/api/blob_storage_integrations/types/blob_storage_integration_type.py +26 -0
- aeri/api/blob_storage_integrations/types/blob_storage_integrations_response.py +15 -0
- aeri/api/blob_storage_integrations/types/blob_storage_sync_status.py +47 -0
- aeri/api/blob_storage_integrations/types/create_blob_storage_integration_request.py +91 -0
- aeri/api/client.py +679 -0
- aeri/api/comments/__init__.py +44 -0
- aeri/api/comments/client.py +407 -0
- aeri/api/comments/raw_client.py +750 -0
- aeri/api/comments/types/__init__.py +46 -0
- aeri/api/comments/types/create_comment_request.py +47 -0
- aeri/api/comments/types/create_comment_response.py +17 -0
- aeri/api/comments/types/get_comments_response.py +17 -0
- aeri/api/commons/__init__.py +210 -0
- aeri/api/commons/errors/__init__.py +56 -0
- aeri/api/commons/errors/access_denied_error.py +12 -0
- aeri/api/commons/errors/error.py +12 -0
- aeri/api/commons/errors/method_not_allowed_error.py +12 -0
- aeri/api/commons/errors/not_found_error.py +12 -0
- aeri/api/commons/errors/unauthorized_error.py +12 -0
- aeri/api/commons/types/__init__.py +190 -0
- aeri/api/commons/types/base_score.py +90 -0
- aeri/api/commons/types/base_score_v1.py +70 -0
- aeri/api/commons/types/boolean_score.py +26 -0
- aeri/api/commons/types/boolean_score_v1.py +26 -0
- aeri/api/commons/types/categorical_score.py +26 -0
- aeri/api/commons/types/categorical_score_v1.py +26 -0
- aeri/api/commons/types/comment.py +36 -0
- aeri/api/commons/types/comment_object_type.py +30 -0
- aeri/api/commons/types/config_category.py +15 -0
- aeri/api/commons/types/correction_score.py +26 -0
- aeri/api/commons/types/create_score_value.py +5 -0
- aeri/api/commons/types/dataset.py +49 -0
- aeri/api/commons/types/dataset_item.py +58 -0
- aeri/api/commons/types/dataset_run.py +63 -0
- aeri/api/commons/types/dataset_run_item.py +40 -0
- aeri/api/commons/types/dataset_run_with_items.py +19 -0
- aeri/api/commons/types/dataset_status.py +22 -0
- aeri/api/commons/types/map_value.py +11 -0
- aeri/api/commons/types/model.py +125 -0
- aeri/api/commons/types/model_price.py +14 -0
- aeri/api/commons/types/model_usage_unit.py +42 -0
- aeri/api/commons/types/numeric_score.py +17 -0
- aeri/api/commons/types/numeric_score_v1.py +17 -0
- aeri/api/commons/types/observation.py +142 -0
- aeri/api/commons/types/observation_level.py +30 -0
- aeri/api/commons/types/observation_v2.py +235 -0
- aeri/api/commons/types/observations_view.py +89 -0
- aeri/api/commons/types/pricing_tier.py +91 -0
- aeri/api/commons/types/pricing_tier_condition.py +68 -0
- aeri/api/commons/types/pricing_tier_input.py +76 -0
- aeri/api/commons/types/pricing_tier_operator.py +42 -0
- aeri/api/commons/types/score.py +201 -0
- aeri/api/commons/types/score_config.py +66 -0
- aeri/api/commons/types/score_config_data_type.py +26 -0
- aeri/api/commons/types/score_data_type.py +30 -0
- aeri/api/commons/types/score_source.py +26 -0
- aeri/api/commons/types/score_v1.py +131 -0
- aeri/api/commons/types/session.py +25 -0
- aeri/api/commons/types/session_with_traces.py +15 -0
- aeri/api/commons/types/trace.py +84 -0
- aeri/api/commons/types/trace_with_details.py +43 -0
- aeri/api/commons/types/trace_with_full_details.py +45 -0
- aeri/api/commons/types/usage.py +59 -0
- aeri/api/core/__init__.py +111 -0
- aeri/api/core/api_error.py +23 -0
- aeri/api/core/client_wrapper.py +141 -0
- aeri/api/core/datetime_utils.py +30 -0
- aeri/api/core/enum.py +20 -0
- aeri/api/core/file.py +70 -0
- aeri/api/core/force_multipart.py +18 -0
- aeri/api/core/http_client.py +711 -0
- aeri/api/core/http_response.py +55 -0
- aeri/api/core/http_sse/__init__.py +48 -0
- aeri/api/core/http_sse/_api.py +114 -0
- aeri/api/core/http_sse/_decoders.py +66 -0
- aeri/api/core/http_sse/_exceptions.py +7 -0
- aeri/api/core/http_sse/_models.py +17 -0
- aeri/api/core/jsonable_encoder.py +102 -0
- aeri/api/core/pydantic_utilities.py +310 -0
- aeri/api/core/query_encoder.py +60 -0
- aeri/api/core/remove_none_from_dict.py +11 -0
- aeri/api/core/request_options.py +35 -0
- aeri/api/core/serialization.py +282 -0
- aeri/api/dataset_items/__init__.py +52 -0
- aeri/api/dataset_items/client.py +499 -0
- aeri/api/dataset_items/raw_client.py +973 -0
- aeri/api/dataset_items/types/__init__.py +50 -0
- aeri/api/dataset_items/types/create_dataset_item_request.py +37 -0
- aeri/api/dataset_items/types/delete_dataset_item_response.py +17 -0
- aeri/api/dataset_items/types/paginated_dataset_items.py +17 -0
- aeri/api/dataset_run_items/__init__.py +43 -0
- aeri/api/dataset_run_items/client.py +323 -0
- aeri/api/dataset_run_items/raw_client.py +547 -0
- aeri/api/dataset_run_items/types/__init__.py +44 -0
- aeri/api/dataset_run_items/types/create_dataset_run_item_request.py +51 -0
- aeri/api/dataset_run_items/types/paginated_dataset_run_items.py +17 -0
- aeri/api/datasets/__init__.py +55 -0
- aeri/api/datasets/client.py +661 -0
- aeri/api/datasets/raw_client.py +1368 -0
- aeri/api/datasets/types/__init__.py +53 -0
- aeri/api/datasets/types/create_dataset_request.py +31 -0
- aeri/api/datasets/types/delete_dataset_run_response.py +14 -0
- aeri/api/datasets/types/paginated_dataset_runs.py +17 -0
- aeri/api/datasets/types/paginated_datasets.py +17 -0
- aeri/api/health/__init__.py +44 -0
- aeri/api/health/client.py +112 -0
- aeri/api/health/errors/__init__.py +42 -0
- aeri/api/health/errors/service_unavailable_error.py +13 -0
- aeri/api/health/raw_client.py +227 -0
- aeri/api/health/types/__init__.py +40 -0
- aeri/api/health/types/health_response.py +30 -0
- aeri/api/ingestion/__init__.py +169 -0
- aeri/api/ingestion/client.py +221 -0
- aeri/api/ingestion/raw_client.py +293 -0
- aeri/api/ingestion/types/__init__.py +169 -0
- aeri/api/ingestion/types/base_event.py +27 -0
- aeri/api/ingestion/types/create_event_body.py +14 -0
- aeri/api/ingestion/types/create_event_event.py +15 -0
- aeri/api/ingestion/types/create_generation_body.py +40 -0
- aeri/api/ingestion/types/create_generation_event.py +15 -0
- aeri/api/ingestion/types/create_observation_event.py +15 -0
- aeri/api/ingestion/types/create_span_body.py +19 -0
- aeri/api/ingestion/types/create_span_event.py +15 -0
- aeri/api/ingestion/types/ingestion_error.py +17 -0
- aeri/api/ingestion/types/ingestion_event.py +155 -0
- aeri/api/ingestion/types/ingestion_response.py +17 -0
- aeri/api/ingestion/types/ingestion_success.py +15 -0
- aeri/api/ingestion/types/ingestion_usage.py +8 -0
- aeri/api/ingestion/types/observation_body.py +53 -0
- aeri/api/ingestion/types/observation_type.py +54 -0
- aeri/api/ingestion/types/open_ai_completion_usage_schema.py +26 -0
- aeri/api/ingestion/types/open_ai_response_usage_schema.py +24 -0
- aeri/api/ingestion/types/open_ai_usage.py +28 -0
- aeri/api/ingestion/types/optional_observation_body.py +36 -0
- aeri/api/ingestion/types/score_body.py +75 -0
- aeri/api/ingestion/types/score_event.py +15 -0
- aeri/api/ingestion/types/sdk_log_body.py +14 -0
- aeri/api/ingestion/types/sdk_log_event.py +15 -0
- aeri/api/ingestion/types/trace_body.py +36 -0
- aeri/api/ingestion/types/trace_event.py +15 -0
- aeri/api/ingestion/types/update_event_body.py +14 -0
- aeri/api/ingestion/types/update_generation_body.py +40 -0
- aeri/api/ingestion/types/update_generation_event.py +15 -0
- aeri/api/ingestion/types/update_observation_event.py +15 -0
- aeri/api/ingestion/types/update_span_body.py +19 -0
- aeri/api/ingestion/types/update_span_event.py +15 -0
- aeri/api/ingestion/types/usage_details.py +10 -0
- aeri/api/legacy/__init__.py +61 -0
- aeri/api/legacy/client.py +105 -0
- aeri/api/legacy/metrics_v1/__init__.py +40 -0
- aeri/api/legacy/metrics_v1/client.py +214 -0
- aeri/api/legacy/metrics_v1/raw_client.py +322 -0
- aeri/api/legacy/metrics_v1/types/__init__.py +40 -0
- aeri/api/legacy/metrics_v1/types/metrics_response.py +19 -0
- aeri/api/legacy/observations_v1/__init__.py +43 -0
- aeri/api/legacy/observations_v1/client.py +523 -0
- aeri/api/legacy/observations_v1/raw_client.py +759 -0
- aeri/api/legacy/observations_v1/types/__init__.py +44 -0
- aeri/api/legacy/observations_v1/types/observations.py +17 -0
- aeri/api/legacy/observations_v1/types/observations_views.py +17 -0
- aeri/api/legacy/raw_client.py +13 -0
- aeri/api/legacy/score_v1/__init__.py +43 -0
- aeri/api/legacy/score_v1/client.py +329 -0
- aeri/api/legacy/score_v1/raw_client.py +545 -0
- aeri/api/legacy/score_v1/types/__init__.py +44 -0
- aeri/api/legacy/score_v1/types/create_score_request.py +75 -0
- aeri/api/legacy/score_v1/types/create_score_response.py +17 -0
- aeri/api/llm_connections/__init__.py +55 -0
- aeri/api/llm_connections/client.py +311 -0
- aeri/api/llm_connections/raw_client.py +541 -0
- aeri/api/llm_connections/types/__init__.py +53 -0
- aeri/api/llm_connections/types/llm_adapter.py +38 -0
- aeri/api/llm_connections/types/llm_connection.py +77 -0
- aeri/api/llm_connections/types/paginated_llm_connections.py +17 -0
- aeri/api/llm_connections/types/upsert_llm_connection_request.py +69 -0
- aeri/api/media/__init__.py +58 -0
- aeri/api/media/client.py +427 -0
- aeri/api/media/raw_client.py +739 -0
- aeri/api/media/types/__init__.py +56 -0
- aeri/api/media/types/get_media_response.py +55 -0
- aeri/api/media/types/get_media_upload_url_request.py +51 -0
- aeri/api/media/types/get_media_upload_url_response.py +28 -0
- aeri/api/media/types/media_content_type.py +232 -0
- aeri/api/media/types/patch_media_body.py +43 -0
- aeri/api/metrics/__init__.py +40 -0
- aeri/api/metrics/client.py +422 -0
- aeri/api/metrics/raw_client.py +530 -0
- aeri/api/metrics/types/__init__.py +40 -0
- aeri/api/metrics/types/metrics_v2response.py +19 -0
- aeri/api/models/__init__.py +43 -0
- aeri/api/models/client.py +523 -0
- aeri/api/models/raw_client.py +993 -0
- aeri/api/models/types/__init__.py +44 -0
- aeri/api/models/types/create_model_request.py +103 -0
- aeri/api/models/types/paginated_models.py +17 -0
- aeri/api/observations/__init__.py +43 -0
- aeri/api/observations/client.py +522 -0
- aeri/api/observations/raw_client.py +641 -0
- aeri/api/observations/types/__init__.py +44 -0
- aeri/api/observations/types/observations_v2meta.py +21 -0
- aeri/api/observations/types/observations_v2response.py +28 -0
- aeri/api/opentelemetry/__init__.py +67 -0
- aeri/api/opentelemetry/client.py +276 -0
- aeri/api/opentelemetry/raw_client.py +291 -0
- aeri/api/opentelemetry/types/__init__.py +65 -0
- aeri/api/opentelemetry/types/otel_attribute.py +27 -0
- aeri/api/opentelemetry/types/otel_attribute_value.py +46 -0
- aeri/api/opentelemetry/types/otel_resource.py +24 -0
- aeri/api/opentelemetry/types/otel_resource_span.py +32 -0
- aeri/api/opentelemetry/types/otel_scope.py +34 -0
- aeri/api/opentelemetry/types/otel_scope_span.py +28 -0
- aeri/api/opentelemetry/types/otel_span.py +76 -0
- aeri/api/opentelemetry/types/otel_trace_response.py +16 -0
- aeri/api/organizations/__init__.py +73 -0
- aeri/api/organizations/client.py +756 -0
- aeri/api/organizations/raw_client.py +1707 -0
- aeri/api/organizations/types/__init__.py +71 -0
- aeri/api/organizations/types/delete_membership_request.py +16 -0
- aeri/api/organizations/types/membership_deletion_response.py +17 -0
- aeri/api/organizations/types/membership_request.py +18 -0
- aeri/api/organizations/types/membership_response.py +20 -0
- aeri/api/organizations/types/membership_role.py +30 -0
- aeri/api/organizations/types/memberships_response.py +15 -0
- aeri/api/organizations/types/organization_api_key.py +31 -0
- aeri/api/organizations/types/organization_api_keys_response.py +19 -0
- aeri/api/organizations/types/organization_project.py +25 -0
- aeri/api/organizations/types/organization_projects_response.py +15 -0
- aeri/api/projects/__init__.py +67 -0
- aeri/api/projects/client.py +760 -0
- aeri/api/projects/raw_client.py +1577 -0
- aeri/api/projects/types/__init__.py +65 -0
- aeri/api/projects/types/api_key_deletion_response.py +18 -0
- aeri/api/projects/types/api_key_list.py +23 -0
- aeri/api/projects/types/api_key_response.py +30 -0
- aeri/api/projects/types/api_key_summary.py +35 -0
- aeri/api/projects/types/organization.py +22 -0
- aeri/api/projects/types/project.py +34 -0
- aeri/api/projects/types/project_deletion_response.py +15 -0
- aeri/api/projects/types/projects.py +15 -0
- aeri/api/prompt_version/__init__.py +4 -0
- aeri/api/prompt_version/client.py +157 -0
- aeri/api/prompt_version/raw_client.py +264 -0
- aeri/api/prompts/__init__.py +100 -0
- aeri/api/prompts/client.py +550 -0
- aeri/api/prompts/raw_client.py +987 -0
- aeri/api/prompts/types/__init__.py +96 -0
- aeri/api/prompts/types/base_prompt.py +42 -0
- aeri/api/prompts/types/chat_message.py +17 -0
- aeri/api/prompts/types/chat_message_type.py +15 -0
- aeri/api/prompts/types/chat_message_with_placeholders.py +8 -0
- aeri/api/prompts/types/chat_prompt.py +15 -0
- aeri/api/prompts/types/create_chat_prompt_request.py +37 -0
- aeri/api/prompts/types/create_chat_prompt_type.py +15 -0
- aeri/api/prompts/types/create_prompt_request.py +8 -0
- aeri/api/prompts/types/create_text_prompt_request.py +36 -0
- aeri/api/prompts/types/create_text_prompt_type.py +15 -0
- aeri/api/prompts/types/placeholder_message.py +16 -0
- aeri/api/prompts/types/placeholder_message_type.py +15 -0
- aeri/api/prompts/types/prompt.py +58 -0
- aeri/api/prompts/types/prompt_meta.py +35 -0
- aeri/api/prompts/types/prompt_meta_list_response.py +17 -0
- aeri/api/prompts/types/prompt_type.py +20 -0
- aeri/api/prompts/types/text_prompt.py +14 -0
- aeri/api/scim/__init__.py +94 -0
- aeri/api/scim/client.py +686 -0
- aeri/api/scim/raw_client.py +1528 -0
- aeri/api/scim/types/__init__.py +92 -0
- aeri/api/scim/types/authentication_scheme.py +20 -0
- aeri/api/scim/types/bulk_config.py +22 -0
- aeri/api/scim/types/empty_response.py +16 -0
- aeri/api/scim/types/filter_config.py +17 -0
- aeri/api/scim/types/resource_meta.py +17 -0
- aeri/api/scim/types/resource_type.py +27 -0
- aeri/api/scim/types/resource_types_response.py +21 -0
- aeri/api/scim/types/schema_extension.py +17 -0
- aeri/api/scim/types/schema_resource.py +19 -0
- aeri/api/scim/types/schemas_response.py +21 -0
- aeri/api/scim/types/scim_email.py +16 -0
- aeri/api/scim/types/scim_feature_support.py +14 -0
- aeri/api/scim/types/scim_name.py +14 -0
- aeri/api/scim/types/scim_user.py +24 -0
- aeri/api/scim/types/scim_users_list_response.py +25 -0
- aeri/api/scim/types/service_provider_config.py +36 -0
- aeri/api/scim/types/user_meta.py +20 -0
- aeri/api/score_configs/__init__.py +44 -0
- aeri/api/score_configs/client.py +526 -0
- aeri/api/score_configs/raw_client.py +1012 -0
- aeri/api/score_configs/types/__init__.py +46 -0
- aeri/api/score_configs/types/create_score_config_request.py +46 -0
- aeri/api/score_configs/types/score_configs.py +17 -0
- aeri/api/score_configs/types/update_score_config_request.py +53 -0
- aeri/api/scores/__init__.py +76 -0
- aeri/api/scores/client.py +420 -0
- aeri/api/scores/raw_client.py +656 -0
- aeri/api/scores/types/__init__.py +76 -0
- aeri/api/scores/types/get_scores_response.py +17 -0
- aeri/api/scores/types/get_scores_response_data.py +211 -0
- aeri/api/scores/types/get_scores_response_data_boolean.py +15 -0
- aeri/api/scores/types/get_scores_response_data_categorical.py +15 -0
- aeri/api/scores/types/get_scores_response_data_correction.py +15 -0
- aeri/api/scores/types/get_scores_response_data_numeric.py +15 -0
- aeri/api/scores/types/get_scores_response_trace_data.py +38 -0
- aeri/api/sessions/__init__.py +40 -0
- aeri/api/sessions/client.py +262 -0
- aeri/api/sessions/raw_client.py +500 -0
- aeri/api/sessions/types/__init__.py +40 -0
- aeri/api/sessions/types/paginated_sessions.py +17 -0
- aeri/api/trace/__init__.py +44 -0
- aeri/api/trace/client.py +728 -0
- aeri/api/trace/raw_client.py +1208 -0
- aeri/api/trace/types/__init__.py +46 -0
- aeri/api/trace/types/delete_trace_response.py +14 -0
- aeri/api/trace/types/sort.py +14 -0
- aeri/api/trace/types/traces.py +17 -0
- aeri/api/utils/__init__.py +44 -0
- aeri/api/utils/pagination/__init__.py +40 -0
- aeri/api/utils/pagination/types/__init__.py +40 -0
- aeri/api/utils/pagination/types/meta_response.py +38 -0
- aeri/batch_evaluation.py +1643 -0
- aeri/experiment.py +1044 -0
- aeri/langchain/CallbackHandler.py +1377 -0
- aeri/langchain/__init__.py +5 -0
- aeri/langchain/utils.py +212 -0
- aeri/logger.py +28 -0
- aeri/media.py +352 -0
- aeri/model.py +477 -0
- aeri/openai.py +1124 -0
- aeri/py.typed +0 -0
- aeri/span_filter.py +17 -0
- aeri/types.py +79 -0
- aeri/version.py +3 -0
- aeri_python-4.0.0.dist-info/METADATA +51 -0
- aeri_python-4.0.0.dist-info/RECORD +391 -0
- aeri_python-4.0.0.dist-info/WHEEL +4 -0
- aeri_python-4.0.0.dist-info/licenses/LICENSE +21 -0
aeri/_client/observe.py
ADDED
|
@@ -0,0 +1,661 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import contextvars
|
|
3
|
+
import inspect
|
|
4
|
+
import os
|
|
5
|
+
from functools import wraps
|
|
6
|
+
from typing import (
|
|
7
|
+
Any,
|
|
8
|
+
AsyncGenerator,
|
|
9
|
+
Callable,
|
|
10
|
+
Dict,
|
|
11
|
+
Generator,
|
|
12
|
+
Iterable,
|
|
13
|
+
List,
|
|
14
|
+
Optional,
|
|
15
|
+
Tuple,
|
|
16
|
+
TypeVar,
|
|
17
|
+
Union,
|
|
18
|
+
cast,
|
|
19
|
+
overload,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
from opentelemetry.util._decorator import _AgnosticContextManager
|
|
23
|
+
from typing_extensions import ParamSpec
|
|
24
|
+
|
|
25
|
+
from aeri._client.constants import (
|
|
26
|
+
ObservationTypeLiteralNoEvent,
|
|
27
|
+
get_observation_types_list,
|
|
28
|
+
)
|
|
29
|
+
from aeri._client.environment_variables import (
|
|
30
|
+
AERI_OBSERVE_DECORATOR_IO_CAPTURE_ENABLED,
|
|
31
|
+
)
|
|
32
|
+
from aeri._client.get_client import _set_current_public_key, get_client
|
|
33
|
+
from aeri._client.span import (
|
|
34
|
+
AeriAgent,
|
|
35
|
+
AeriChain,
|
|
36
|
+
AeriEmbedding,
|
|
37
|
+
AeriEvaluator,
|
|
38
|
+
AeriGeneration,
|
|
39
|
+
AeriGuardrail,
|
|
40
|
+
AeriRetriever,
|
|
41
|
+
AeriSpan,
|
|
42
|
+
AeriTool,
|
|
43
|
+
)
|
|
44
|
+
from aeri.logger import aeri_logger as logger
|
|
45
|
+
from aeri.types import TraceContext
|
|
46
|
+
|
|
47
|
+
F = TypeVar("F", bound=Callable[..., Any])
|
|
48
|
+
P = ParamSpec("P")
|
|
49
|
+
R = TypeVar("R")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class AeriDecorator:
|
|
53
|
+
"""Implementation of the @observe decorator for seamless Aeri tracing integration.
|
|
54
|
+
|
|
55
|
+
This class provides the core functionality for the @observe decorator, which enables
|
|
56
|
+
automatic tracing of functions and methods in your application with Aeri.
|
|
57
|
+
It handles both synchronous and asynchronous functions, maintains proper trace context,
|
|
58
|
+
and intelligently routes to the correct Aeri client instance.
|
|
59
|
+
|
|
60
|
+
The implementation follows a singleton pattern where a single decorator instance
|
|
61
|
+
handles all @observe decorations throughout the application codebase.
|
|
62
|
+
|
|
63
|
+
Features:
|
|
64
|
+
- Automatic span creation and management for both sync and async functions
|
|
65
|
+
- Proper trace context propagation between decorated functions
|
|
66
|
+
- Specialized handling for LLM-related spans with the 'as_type="generation"' parameter
|
|
67
|
+
- Type-safe decoration that preserves function signatures and type hints
|
|
68
|
+
- Support for explicit trace and parent span ID specification
|
|
69
|
+
- Thread-safe client resolution when multiple Aeri projects are used
|
|
70
|
+
"""
|
|
71
|
+
|
|
72
|
+
@overload
def observe(self, func: F) -> F: ...

@overload
def observe(
    self,
    func: None = None,
    *,
    name: Optional[str] = None,
    as_type: Optional[ObservationTypeLiteralNoEvent] = None,
    capture_input: Optional[bool] = None,
    capture_output: Optional[bool] = None,
    transform_to_string: Optional[Callable[[Iterable], str]] = None,
) -> Callable[[F], F]: ...

def observe(
    self,
    func: Optional[F] = None,
    *,
    name: Optional[str] = None,
    as_type: Optional[ObservationTypeLiteralNoEvent] = None,
    capture_input: Optional[bool] = None,
    capture_output: Optional[bool] = None,
    transform_to_string: Optional[Callable[[Iterable], str]] = None,
) -> Union[F, Callable[[F], F]]:
    """Wrap a function to create and manage Aeri tracing around its execution, supporting both synchronous and asynchronous functions.

    This decorator provides seamless integration of Aeri observability into your codebase. It automatically creates
    spans or generations around function execution, capturing timing, inputs/outputs, and error states. The decorator
    handles both synchronous and asynchronous functions, preserving function signatures and type hints.

    Using OpenTelemetry's distributed tracing system, it maintains proper trace context propagation throughout your application,
    enabling you to see hierarchical traces of function calls with detailed performance metrics and function-specific details.

    Args:
        func (Optional[Callable]): The function to decorate. When used with parentheses @observe(), this will be None.
        name (Optional[str]): Custom name for the created trace or span. If not provided, the function name is used.
        as_type (Optional[Literal]): Set the observation type. Supported values:
                "generation", "span", "agent", "tool", "chain", "retriever", "embedding", "evaluator", "guardrail".
                Observation types are highlighted in the Aeri UI for filtering and visualization.
                The types "generation" and "embedding" create a span on which additional attributes such as model metrics
                can be set. An unsupported value is logged as a warning and replaced with "span".
        capture_input (Optional[bool]): Whether the wrapper captures the function's arguments as the observation input.
                When None, the global default from the environment variable named by
                ``AERI_OBSERVE_DECORATOR_IO_CAPTURE_ENABLED`` is used.
        capture_output (Optional[bool]): Whether the wrapper captures the function's output. When None, the same
                environment default as for ``capture_input`` is used.
        transform_to_string (Optional[Callable[[Iterable], str]]): Optional callable forwarded unchanged to the
                sync/async wrapper. Presumably used to render collected items of a generator result as a single
                string output (confirm in ``_sync_observe`` / ``_async_observe``).

    Returns:
        Callable: When ``func`` is given (``@observe``), the wrapped function. When ``func`` is None
        (``@observe(...)``), a decorator that produces the wrapped function.

    Example:
        For general function tracing with automatic naming:
        ```python
        @observe()
        def process_user_request(user_id, query):
            # Function is automatically traced with name "process_user_request"
            return get_response(query)
        ```

        For language model generation tracking:
        ```python
        from openai import AsyncOpenAI

        async_openai = AsyncOpenAI()

        @observe(name="answer-generation", as_type="generation")
        async def generate_answer(query):
            # Creates a generation-type span with extended LLM metrics
            response = await async_openai.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": query}]
            )
            return response.choices[0].message.content
        ```

        For trace context propagation between functions:
        ```python
        @observe()
        def main_process():
            # Parent span is created
            return sub_process()  # Child span automatically connected to parent

        @observe()
        def sub_process():
            # Automatically becomes a child span of main_process
            return "result"
        ```

    Raises:
        Exception: Propagates any exceptions from the wrapped function after logging them in the trace.

    Notes:
        - The decorator preserves the original function's signature, docstring, and return type.
        - Proper parent-child relationships between spans are automatically maintained.
        - Special keyword arguments can be passed to control tracing:
          - aeri_trace_id: Explicitly set the trace ID for this function call
          - aeri_parent_observation_id: Explicitly set the parent span ID
          - aeri_public_key: Use a specific Aeri project (when multiple clients exist)
        - For async functions, the decorator returns an async function wrapper.
        - For sync functions, the decorator returns a synchronous wrapper.
        - The I/O-capture environment variable is read once, when ``observe`` is applied, not on every call
          of the wrapped function.
    """
    valid_types = set(get_observation_types_list(ObservationTypeLiteralNoEvent))
    if as_type is not None and as_type not in valid_types:
        logger.warning(
            f"Invalid as_type '{as_type}'. Valid types are: {', '.join(sorted(valid_types))}. Defaulting to 'span'."
        )
        as_type = "span"

    # Global I/O capture default, read here (decoration time). Any value other than
    # "false"/"0" (case-insensitive), including an unset variable, enables capture.
    function_io_capture_enabled = os.environ.get(
        AERI_OBSERVE_DECORATOR_IO_CAPTURE_ENABLED, "True"
    ).lower() not in ("false", "0")

    # Explicit per-decorator arguments take precedence over the environment default.
    should_capture_input = (
        capture_input if capture_input is not None else function_io_capture_enabled
    )
    should_capture_output = (
        capture_output
        if capture_output is not None
        else function_io_capture_enabled
    )

    def decorator(func: F) -> F:
        # Coroutine functions get the async wrapper. Everything else, including
        # async generator functions (for which iscoroutinefunction is False),
        # goes through the sync wrapper.
        # NOTE(review): asyncio.iscoroutinefunction is deprecated as of Python 3.14
        # in favour of inspect.iscoroutinefunction; consider switching once legacy
        # `_is_coroutine` markers no longer need to be recognised.
        if asyncio.iscoroutinefunction(func):
            return self._async_observe(
                func,
                name=name,
                as_type=as_type,
                capture_input=should_capture_input,
                capture_output=should_capture_output,
                transform_to_string=transform_to_string,
            )
        return self._sync_observe(
            func,
            name=name,
            as_type=as_type,
            capture_input=should_capture_input,
            capture_output=should_capture_output,
            transform_to_string=transform_to_string,
        )

    # Support both decorator forms:
    # - ``@observe``: Python passes the function directly, so wrap it right away.
    # - ``@observe(...)``: ``func`` is None, so hand back ``decorator`` for Python
    #   to apply to the function in the next step.
    if func is None:
        return decorator
    return decorator(func)
|
|
221
|
+
|
|
222
|
+
def _async_observe(
    self,
    func: F,
    *,
    name: Optional[str],
    as_type: Optional[ObservationTypeLiteralNoEvent],
    capture_input: bool,
    capture_output: bool,
    transform_to_string: Optional[Callable[[Iterable], str]] = None,
) -> F:
    """Wrap the coroutine function ``func`` so that every call is observed.

    Each call opens an observation named ``name`` (falling back to
    ``func.__name__``) of type ``as_type`` (falling back to ``"span"``),
    awaits ``func`` and records its result or error on that observation.

    The reserved keyword arguments ``aeri_trace_id``,
    ``aeri_parent_observation_id`` and ``aeri_public_key`` are removed from
    the call's kwargs before input is captured and before ``func`` runs.

    When ``capture_output`` is true and the awaited result is a sync
    generator, an async generator, or an object whose type is named
    ``StreamingResponse`` and has a ``body_iterator``, the observation is not
    ended here. The context-preserving iterator wrapper ends it once
    iteration finishes.

    Args:
        func: The coroutine function to wrap.
        name: Observation name; ``func.__name__`` is used when falsy.
        as_type: Observation type; ``"span"`` is used when falsy.
        capture_input: Whether to record the call arguments as input.
        capture_output: Whether to record the result (or streamed items).
        transform_to_string: Optional reducer applied to streamed items.

    Returns:
        An async wrapper carrying ``func``'s metadata via ``functools.wraps``.
    """

    @wraps(func)
    async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
        # Pop every reserved aeri_* kwarg before capturing input so none of
        # them is forwarded to ``func`` or recorded as observation input.
        trace_id = cast(str, kwargs.pop("aeri_trace_id", None))
        parent_observation_id = cast(
            str, kwargs.pop("aeri_parent_observation_id", None)
        )
        public_key = cast(str, kwargs.pop("aeri_public_key", None))

        trace_context: Optional[TraceContext] = (
            {
                "trace_id": trace_id,
                "parent_span_id": parent_observation_id,
            }
            if trace_id
            else None
        )
        final_name = name or func.__name__
        captured_input = (
            self._get_input_from_func_args(
                is_method=self._is_method(func),
                func_args=args,
                func_kwargs=kwargs,
            )
            if capture_input
            else None
        )

        # Set public key in execution context for nested decorated functions
        with _set_current_public_key(public_key):
            aeri_client = get_client(public_key=public_key)
            context_manager: Optional[
                Union[
                    _AgnosticContextManager[AeriGeneration],
                    _AgnosticContextManager[AeriSpan],
                    _AgnosticContextManager[AeriAgent],
                    _AgnosticContextManager[AeriTool],
                    _AgnosticContextManager[AeriChain],
                    _AgnosticContextManager[AeriRetriever],
                    _AgnosticContextManager[AeriEvaluator],
                    _AgnosticContextManager[AeriEmbedding],
                    _AgnosticContextManager[AeriGuardrail],
                ]
            ] = (
                aeri_client.start_as_current_observation(
                    name=final_name,
                    as_type=as_type or "span",
                    trace_context=trace_context,
                    input=captured_input,
                    # When returning a generator, ending on exit would be too early.
                    end_on_exit=False,
                )
                if aeri_client
                else None
            )

            if context_manager is None:
                return await func(*args, **kwargs)

            with context_manager as aeri_span_or_generation:
                # Set when a streaming wrapper takes over ending the observation.
                is_return_type_generator = False

                try:
                    result = await func(*args, **kwargs)

                    if capture_output is True:
                        if inspect.isgenerator(result):
                            is_return_type_generator = True

                            return self._wrap_sync_generator_result(
                                aeri_span_or_generation,
                                result,
                                transform_to_string,
                            )

                        if inspect.isasyncgen(result):
                            is_return_type_generator = True

                            return self._wrap_async_generator_result(
                                aeri_span_or_generation,
                                result,
                                transform_to_string,
                            )

                        # handle starlette.StreamingResponse
                        if type(result).__name__ == "StreamingResponse" and hasattr(
                            result, "body_iterator"
                        ):
                            is_return_type_generator = True

                            result.body_iterator = (
                                self._wrap_async_generator_result(
                                    aeri_span_or_generation,
                                    result.body_iterator,
                                    transform_to_string,
                                )
                            )

                        aeri_span_or_generation.update(output=result)

                    return result
                except Exception as e:
                    aeri_span_or_generation.update(
                        level="ERROR", status_message=str(e) or type(e).__name__
                    )

                    # Bare raise keeps the original traceback unchanged.
                    raise
                finally:
                    # Streaming results are ended by their iterator wrapper.
                    if not is_return_type_generator:
                        aeri_span_or_generation.end()

    return cast(F, async_wrapper)
|
|
341
|
+
|
|
342
|
+
def _sync_observe(
    self,
    func: F,
    *,
    name: Optional[str],
    as_type: Optional[ObservationTypeLiteralNoEvent],
    capture_input: bool,
    capture_output: bool,
    transform_to_string: Optional[Callable[[Iterable], str]] = None,
) -> F:
    """Wrap the synchronous callable ``func`` so that every call is observed.

    Each call opens an observation named ``name`` (falling back to
    ``func.__name__``) of type ``as_type`` (falling back to ``"span"``),
    calls ``func`` and records its result or error on that observation.

    The reserved keyword arguments ``aeri_trace_id``,
    ``aeri_parent_observation_id`` and ``aeri_public_key`` are removed from
    the call's kwargs before input is captured and before ``func`` runs.

    When ``capture_output`` is true and the result is a sync generator, an
    async generator, or an object whose type is named ``StreamingResponse``
    and has a ``body_iterator``, the observation is not ended here. The
    context-preserving iterator wrapper ends it once iteration finishes.

    Args:
        func: The callable to wrap.
        name: Observation name; ``func.__name__`` is used when falsy.
        as_type: Observation type; ``"span"`` is used when falsy.
        capture_input: Whether to record the call arguments as input.
        capture_output: Whether to record the result (or streamed items).
        transform_to_string: Optional reducer applied to streamed items.

    Returns:
        A wrapper carrying ``func``'s metadata via ``functools.wraps``.
    """

    @wraps(func)
    def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
        # Pop every reserved aeri_* kwarg before capturing input so none of
        # them is forwarded to ``func`` or recorded as observation input.
        trace_id = kwargs.pop("aeri_trace_id", None)
        parent_observation_id = kwargs.pop("aeri_parent_observation_id", None)
        public_key = kwargs.pop("aeri_public_key", None)

        trace_context: Optional[TraceContext] = (
            {
                "trace_id": trace_id,
                "parent_span_id": parent_observation_id,
            }
            if trace_id
            else None
        )
        final_name = name or func.__name__
        captured_input = (
            self._get_input_from_func_args(
                is_method=self._is_method(func),
                func_args=args,
                func_kwargs=kwargs,
            )
            if capture_input
            else None
        )

        # Set public key in execution context for nested decorated functions
        with _set_current_public_key(public_key):
            aeri_client = get_client(public_key=public_key)
            context_manager: Optional[
                Union[
                    _AgnosticContextManager[AeriGeneration],
                    _AgnosticContextManager[AeriSpan],
                    _AgnosticContextManager[AeriAgent],
                    _AgnosticContextManager[AeriTool],
                    _AgnosticContextManager[AeriChain],
                    _AgnosticContextManager[AeriRetriever],
                    _AgnosticContextManager[AeriEvaluator],
                    _AgnosticContextManager[AeriEmbedding],
                    _AgnosticContextManager[AeriGuardrail],
                ]
            ] = (
                aeri_client.start_as_current_observation(
                    name=final_name,
                    as_type=as_type or "span",
                    trace_context=trace_context,
                    input=captured_input,
                    # When returning a generator, ending on exit would be too early.
                    end_on_exit=False,
                )
                if aeri_client
                else None
            )

            if context_manager is None:
                return func(*args, **kwargs)

            with context_manager as aeri_span_or_generation:
                # Set when a streaming wrapper takes over ending the observation.
                is_return_type_generator = False

                try:
                    result = func(*args, **kwargs)

                    if capture_output is True:
                        if inspect.isgenerator(result):
                            is_return_type_generator = True

                            return self._wrap_sync_generator_result(
                                aeri_span_or_generation,
                                result,
                                transform_to_string,
                            )

                        if inspect.isasyncgen(result):
                            is_return_type_generator = True

                            return self._wrap_async_generator_result(
                                aeri_span_or_generation,
                                result,
                                transform_to_string,
                            )

                        # handle starlette.StreamingResponse
                        if type(result).__name__ == "StreamingResponse" and hasattr(
                            result, "body_iterator"
                        ):
                            is_return_type_generator = True

                            result.body_iterator = (
                                self._wrap_async_generator_result(
                                    aeri_span_or_generation,
                                    result.body_iterator,
                                    transform_to_string,
                                )
                            )

                        aeri_span_or_generation.update(output=result)

                    return result
                except Exception as e:
                    aeri_span_or_generation.update(
                        level="ERROR", status_message=str(e) or type(e).__name__
                    )

                    # Bare raise keeps the original traceback unchanged.
                    raise
                finally:
                    # Streaming results are ended by their iterator wrapper.
                    if not is_return_type_generator:
                        aeri_span_or_generation.end()

    return cast(F, sync_wrapper)
|
|
459
|
+
|
|
460
|
+
@staticmethod
|
|
461
|
+
def _is_method(func: Callable) -> bool:
|
|
462
|
+
return (
|
|
463
|
+
"self" in inspect.signature(func).parameters
|
|
464
|
+
or "cls" in inspect.signature(func).parameters
|
|
465
|
+
)
|
|
466
|
+
|
|
467
|
+
def _get_input_from_func_args(
|
|
468
|
+
self,
|
|
469
|
+
*,
|
|
470
|
+
is_method: bool = False,
|
|
471
|
+
func_args: Tuple = (),
|
|
472
|
+
func_kwargs: Dict = {},
|
|
473
|
+
) -> Dict:
|
|
474
|
+
# Remove implicitly passed "self" or "cls" argument for instance or class methods
|
|
475
|
+
logged_args = func_args[1:] if is_method else func_args
|
|
476
|
+
|
|
477
|
+
return {
|
|
478
|
+
"args": logged_args,
|
|
479
|
+
"kwargs": func_kwargs,
|
|
480
|
+
}
|
|
481
|
+
|
|
482
|
+
def _wrap_sync_generator_result(
    self,
    aeri_span_or_generation: Union[
        AeriSpan,
        AeriGeneration,
        AeriAgent,
        AeriTool,
        AeriChain,
        AeriRetriever,
        AeriEvaluator,
        AeriEmbedding,
        AeriGuardrail,
    ],
    generator: Generator,
    transform_to_string: Optional[Callable[[Iterable], str]] = None,
) -> Any:
    """Wrap a sync generator so its iteration runs in today's context.

    The current ``contextvars`` context is snapshotted at call time. The
    returned wrapper runs every step of ``generator`` inside that snapshot
    and finalizes ``aeri_span_or_generation`` when iteration ends.
    """
    return _ContextPreservedSyncGeneratorWrapper(
        generator,
        contextvars.copy_context(),
        aeri_span_or_generation,
        transform_to_string,
    )
|
|
506
|
+
|
|
507
|
+
def _wrap_async_generator_result(
    self,
    aeri_span_or_generation: Union[
        AeriSpan,
        AeriGeneration,
        AeriAgent,
        AeriTool,
        AeriChain,
        AeriRetriever,
        AeriEvaluator,
        AeriEmbedding,
        AeriGuardrail,
    ],
    generator: AsyncGenerator,
    transform_to_string: Optional[Callable[[Iterable], str]] = None,
) -> Any:
    """Wrap an async generator so its iteration runs in today's context.

    The current ``contextvars`` context is snapshotted at call time. The
    returned async iterator steps ``generator`` inside that snapshot and
    finalizes ``aeri_span_or_generation`` when iteration ends.
    """
    return _ContextPreservedAsyncGeneratorWrapper(
        generator,
        contextvars.copy_context(),
        aeri_span_or_generation,
        transform_to_string,
    )
|
|
531
|
+
|
|
532
|
+
|
|
533
|
+
# Module-level AeriDecorator instance that backs the public ``observe`` API.
_decorator = AeriDecorator()

# Public ``observe`` decorator: the bound ``observe`` method of the instance above.
observe = _decorator.observe
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
class _ContextPreservedSyncGeneratorWrapper:
    """Drive a sync generator one step at a time inside a captured context.

    Every ``next()`` is executed through ``Context.run`` on the context
    supplied at construction. Yielded items are buffered in ``items``.

    On exhaustion, the observation ``span`` is updated with the aggregated
    output and ended. That output is ``transform_fn(items)`` if given,
    otherwise the concatenation when every item is a ``str``, otherwise the
    raw list. Any other exception marks the observation as an error, ends
    it, and propagates.
    """

    def __init__(
        self,
        generator: Generator,
        context: contextvars.Context,
        span: Union[
            AeriSpan,
            AeriGeneration,
            AeriAgent,
            AeriTool,
            AeriChain,
            AeriRetriever,
            AeriEvaluator,
            AeriEmbedding,
            AeriGuardrail,
        ],
        transform_fn: Optional[Callable[[Iterable], str]],
    ) -> None:
        self.generator = generator
        self.context = context
        self.items: List[Any] = []
        self.span = span
        self.transform_fn = transform_fn

    def __iter__(self) -> "_ContextPreservedSyncGeneratorWrapper":
        return self

    def __next__(self) -> Any:
        try:
            # Advance the wrapped generator inside the preserved context.
            value = self.context.run(next, self.generator)
            self.items.append(value)
        except StopIteration:
            buffered = self.items

            if self.transform_fn is not None:
                aggregated: Any = self.transform_fn(buffered)
            elif all(isinstance(chunk, str) for chunk in buffered):
                aggregated = "".join(buffered)
            else:
                aggregated = buffered

            self.span.update(output=aggregated).end()

            raise  # propagate the StopIteration to the consumer
        except Exception as error:
            self.span.update(
                level="ERROR", status_message=str(error) or type(error).__name__
            ).end()

            raise

        return value
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
class _ContextPreservedAsyncGeneratorWrapper:
    """Async generator wrapper that ensures each iteration runs in preserved context.

    Each ``__anext__`` step of ``generator`` is awaited as a task bound to
    ``context`` on Python 3.11+, where ``asyncio.create_task`` accepts a
    ``context`` argument. On older versions the step is awaited directly.
    Yielded items are buffered in ``items``.

    On exhaustion, the observation ``span`` is updated with the aggregated
    output and ended. That output is ``transform_fn(items)`` if given,
    otherwise the concatenation when every item is a ``str``, otherwise the
    raw list. Any other exception, including a TypeError raised by the
    generator itself, marks the observation as an error, ends it, and
    propagates.
    """

    def __init__(
        self,
        generator: AsyncGenerator,
        context: contextvars.Context,
        span: Union[
            AeriSpan,
            AeriGeneration,
            AeriAgent,
            AeriTool,
            AeriChain,
            AeriRetriever,
            AeriEvaluator,
            AeriEmbedding,
            AeriGuardrail,
        ],
        transform_fn: Optional[Callable[[Iterable], str]],
    ) -> None:
        self.generator = generator
        self.context = context
        self.items: List[Any] = []
        self.span = span
        self.transform_fn = transform_fn

    def __aiter__(self) -> "_ContextPreservedAsyncGeneratorWrapper":
        return self

    async def __anext__(self) -> Any:
        try:
            # Create the step awaitable exactly once. Both paths below await
            # this same object, so the underlying iterator is never advanced
            # twice for a single __anext__ call.
            step = self.generator.__anext__()

            try:
                # ``context=`` is accepted by asyncio.create_task on Python 3.11+.
                task = asyncio.create_task(step, context=self.context)  # type: ignore
            except TypeError:
                # Raised synchronously when ``context=`` is unsupported
                # (Python < 3.11) or ``step`` is not a coroutine: await the
                # step directly in the current context instead.
                item = await step
            else:
                # Errors from the generator body (even a TypeError) surface
                # here and are handled by the generic handler below, instead
                # of being mistaken for an unsupported ``context`` parameter.
                item = await task

            self.items.append(item)

            return item

        except StopAsyncIteration:
            # Handle output and span cleanup when generator is exhausted
            output: Any = self.items

            if self.transform_fn is not None:
                output = self.transform_fn(self.items)

            elif all(isinstance(item, str) for item in self.items):
                output = "".join(self.items)

            self.span.update(output=output).end()

            raise  # Re-raise StopAsyncIteration
        except Exception as e:
            self.span.update(
                level="ERROR", status_message=str(e) or type(e).__name__
            ).end()

            raise
|