aeri-python 4.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aeri/__init__.py +72 -0
- aeri/_client/_validation.py +204 -0
- aeri/_client/attributes.py +188 -0
- aeri/_client/client.py +3761 -0
- aeri/_client/constants.py +65 -0
- aeri/_client/datasets.py +302 -0
- aeri/_client/environment_variables.py +158 -0
- aeri/_client/get_client.py +149 -0
- aeri/_client/observe.py +661 -0
- aeri/_client/propagation.py +475 -0
- aeri/_client/resource_manager.py +510 -0
- aeri/_client/span.py +1519 -0
- aeri/_client/span_filter.py +76 -0
- aeri/_client/span_processor.py +206 -0
- aeri/_client/utils.py +132 -0
- aeri/_task_manager/media_manager.py +331 -0
- aeri/_task_manager/media_upload_consumer.py +44 -0
- aeri/_task_manager/media_upload_queue.py +12 -0
- aeri/_task_manager/score_ingestion_consumer.py +208 -0
- aeri/_task_manager/task_manager.py +475 -0
- aeri/_utils/__init__.py +19 -0
- aeri/_utils/environment.py +34 -0
- aeri/_utils/error_logging.py +47 -0
- aeri/_utils/parse_error.py +99 -0
- aeri/_utils/prompt_cache.py +188 -0
- aeri/_utils/request.py +137 -0
- aeri/_utils/serializer.py +205 -0
- aeri/api/.fern/metadata.json +14 -0
- aeri/api/__init__.py +836 -0
- aeri/api/annotation_queues/__init__.py +82 -0
- aeri/api/annotation_queues/client.py +1111 -0
- aeri/api/annotation_queues/raw_client.py +2288 -0
- aeri/api/annotation_queues/types/__init__.py +84 -0
- aeri/api/annotation_queues/types/annotation_queue.py +28 -0
- aeri/api/annotation_queues/types/annotation_queue_assignment_request.py +16 -0
- aeri/api/annotation_queues/types/annotation_queue_item.py +34 -0
- aeri/api/annotation_queues/types/annotation_queue_object_type.py +26 -0
- aeri/api/annotation_queues/types/annotation_queue_status.py +22 -0
- aeri/api/annotation_queues/types/create_annotation_queue_assignment_response.py +18 -0
- aeri/api/annotation_queues/types/create_annotation_queue_item_request.py +25 -0
- aeri/api/annotation_queues/types/create_annotation_queue_request.py +20 -0
- aeri/api/annotation_queues/types/delete_annotation_queue_assignment_response.py +14 -0
- aeri/api/annotation_queues/types/delete_annotation_queue_item_response.py +15 -0
- aeri/api/annotation_queues/types/paginated_annotation_queue_items.py +17 -0
- aeri/api/annotation_queues/types/paginated_annotation_queues.py +17 -0
- aeri/api/annotation_queues/types/update_annotation_queue_item_request.py +15 -0
- aeri/api/blob_storage_integrations/__init__.py +73 -0
- aeri/api/blob_storage_integrations/client.py +550 -0
- aeri/api/blob_storage_integrations/raw_client.py +976 -0
- aeri/api/blob_storage_integrations/types/__init__.py +77 -0
- aeri/api/blob_storage_integrations/types/blob_storage_export_frequency.py +26 -0
- aeri/api/blob_storage_integrations/types/blob_storage_export_mode.py +26 -0
- aeri/api/blob_storage_integrations/types/blob_storage_integration_deletion_response.py +14 -0
- aeri/api/blob_storage_integrations/types/blob_storage_integration_file_type.py +26 -0
- aeri/api/blob_storage_integrations/types/blob_storage_integration_response.py +64 -0
- aeri/api/blob_storage_integrations/types/blob_storage_integration_status_response.py +50 -0
- aeri/api/blob_storage_integrations/types/blob_storage_integration_type.py +26 -0
- aeri/api/blob_storage_integrations/types/blob_storage_integrations_response.py +15 -0
- aeri/api/blob_storage_integrations/types/blob_storage_sync_status.py +47 -0
- aeri/api/blob_storage_integrations/types/create_blob_storage_integration_request.py +91 -0
- aeri/api/client.py +679 -0
- aeri/api/comments/__init__.py +44 -0
- aeri/api/comments/client.py +407 -0
- aeri/api/comments/raw_client.py +750 -0
- aeri/api/comments/types/__init__.py +46 -0
- aeri/api/comments/types/create_comment_request.py +47 -0
- aeri/api/comments/types/create_comment_response.py +17 -0
- aeri/api/comments/types/get_comments_response.py +17 -0
- aeri/api/commons/__init__.py +210 -0
- aeri/api/commons/errors/__init__.py +56 -0
- aeri/api/commons/errors/access_denied_error.py +12 -0
- aeri/api/commons/errors/error.py +12 -0
- aeri/api/commons/errors/method_not_allowed_error.py +12 -0
- aeri/api/commons/errors/not_found_error.py +12 -0
- aeri/api/commons/errors/unauthorized_error.py +12 -0
- aeri/api/commons/types/__init__.py +190 -0
- aeri/api/commons/types/base_score.py +90 -0
- aeri/api/commons/types/base_score_v1.py +70 -0
- aeri/api/commons/types/boolean_score.py +26 -0
- aeri/api/commons/types/boolean_score_v1.py +26 -0
- aeri/api/commons/types/categorical_score.py +26 -0
- aeri/api/commons/types/categorical_score_v1.py +26 -0
- aeri/api/commons/types/comment.py +36 -0
- aeri/api/commons/types/comment_object_type.py +30 -0
- aeri/api/commons/types/config_category.py +15 -0
- aeri/api/commons/types/correction_score.py +26 -0
- aeri/api/commons/types/create_score_value.py +5 -0
- aeri/api/commons/types/dataset.py +49 -0
- aeri/api/commons/types/dataset_item.py +58 -0
- aeri/api/commons/types/dataset_run.py +63 -0
- aeri/api/commons/types/dataset_run_item.py +40 -0
- aeri/api/commons/types/dataset_run_with_items.py +19 -0
- aeri/api/commons/types/dataset_status.py +22 -0
- aeri/api/commons/types/map_value.py +11 -0
- aeri/api/commons/types/model.py +125 -0
- aeri/api/commons/types/model_price.py +14 -0
- aeri/api/commons/types/model_usage_unit.py +42 -0
- aeri/api/commons/types/numeric_score.py +17 -0
- aeri/api/commons/types/numeric_score_v1.py +17 -0
- aeri/api/commons/types/observation.py +142 -0
- aeri/api/commons/types/observation_level.py +30 -0
- aeri/api/commons/types/observation_v2.py +235 -0
- aeri/api/commons/types/observations_view.py +89 -0
- aeri/api/commons/types/pricing_tier.py +91 -0
- aeri/api/commons/types/pricing_tier_condition.py +68 -0
- aeri/api/commons/types/pricing_tier_input.py +76 -0
- aeri/api/commons/types/pricing_tier_operator.py +42 -0
- aeri/api/commons/types/score.py +201 -0
- aeri/api/commons/types/score_config.py +66 -0
- aeri/api/commons/types/score_config_data_type.py +26 -0
- aeri/api/commons/types/score_data_type.py +30 -0
- aeri/api/commons/types/score_source.py +26 -0
- aeri/api/commons/types/score_v1.py +131 -0
- aeri/api/commons/types/session.py +25 -0
- aeri/api/commons/types/session_with_traces.py +15 -0
- aeri/api/commons/types/trace.py +84 -0
- aeri/api/commons/types/trace_with_details.py +43 -0
- aeri/api/commons/types/trace_with_full_details.py +45 -0
- aeri/api/commons/types/usage.py +59 -0
- aeri/api/core/__init__.py +111 -0
- aeri/api/core/api_error.py +23 -0
- aeri/api/core/client_wrapper.py +141 -0
- aeri/api/core/datetime_utils.py +30 -0
- aeri/api/core/enum.py +20 -0
- aeri/api/core/file.py +70 -0
- aeri/api/core/force_multipart.py +18 -0
- aeri/api/core/http_client.py +711 -0
- aeri/api/core/http_response.py +55 -0
- aeri/api/core/http_sse/__init__.py +48 -0
- aeri/api/core/http_sse/_api.py +114 -0
- aeri/api/core/http_sse/_decoders.py +66 -0
- aeri/api/core/http_sse/_exceptions.py +7 -0
- aeri/api/core/http_sse/_models.py +17 -0
- aeri/api/core/jsonable_encoder.py +102 -0
- aeri/api/core/pydantic_utilities.py +310 -0
- aeri/api/core/query_encoder.py +60 -0
- aeri/api/core/remove_none_from_dict.py +11 -0
- aeri/api/core/request_options.py +35 -0
- aeri/api/core/serialization.py +282 -0
- aeri/api/dataset_items/__init__.py +52 -0
- aeri/api/dataset_items/client.py +499 -0
- aeri/api/dataset_items/raw_client.py +973 -0
- aeri/api/dataset_items/types/__init__.py +50 -0
- aeri/api/dataset_items/types/create_dataset_item_request.py +37 -0
- aeri/api/dataset_items/types/delete_dataset_item_response.py +17 -0
- aeri/api/dataset_items/types/paginated_dataset_items.py +17 -0
- aeri/api/dataset_run_items/__init__.py +43 -0
- aeri/api/dataset_run_items/client.py +323 -0
- aeri/api/dataset_run_items/raw_client.py +547 -0
- aeri/api/dataset_run_items/types/__init__.py +44 -0
- aeri/api/dataset_run_items/types/create_dataset_run_item_request.py +51 -0
- aeri/api/dataset_run_items/types/paginated_dataset_run_items.py +17 -0
- aeri/api/datasets/__init__.py +55 -0
- aeri/api/datasets/client.py +661 -0
- aeri/api/datasets/raw_client.py +1368 -0
- aeri/api/datasets/types/__init__.py +53 -0
- aeri/api/datasets/types/create_dataset_request.py +31 -0
- aeri/api/datasets/types/delete_dataset_run_response.py +14 -0
- aeri/api/datasets/types/paginated_dataset_runs.py +17 -0
- aeri/api/datasets/types/paginated_datasets.py +17 -0
- aeri/api/health/__init__.py +44 -0
- aeri/api/health/client.py +112 -0
- aeri/api/health/errors/__init__.py +42 -0
- aeri/api/health/errors/service_unavailable_error.py +13 -0
- aeri/api/health/raw_client.py +227 -0
- aeri/api/health/types/__init__.py +40 -0
- aeri/api/health/types/health_response.py +30 -0
- aeri/api/ingestion/__init__.py +169 -0
- aeri/api/ingestion/client.py +221 -0
- aeri/api/ingestion/raw_client.py +293 -0
- aeri/api/ingestion/types/__init__.py +169 -0
- aeri/api/ingestion/types/base_event.py +27 -0
- aeri/api/ingestion/types/create_event_body.py +14 -0
- aeri/api/ingestion/types/create_event_event.py +15 -0
- aeri/api/ingestion/types/create_generation_body.py +40 -0
- aeri/api/ingestion/types/create_generation_event.py +15 -0
- aeri/api/ingestion/types/create_observation_event.py +15 -0
- aeri/api/ingestion/types/create_span_body.py +19 -0
- aeri/api/ingestion/types/create_span_event.py +15 -0
- aeri/api/ingestion/types/ingestion_error.py +17 -0
- aeri/api/ingestion/types/ingestion_event.py +155 -0
- aeri/api/ingestion/types/ingestion_response.py +17 -0
- aeri/api/ingestion/types/ingestion_success.py +15 -0
- aeri/api/ingestion/types/ingestion_usage.py +8 -0
- aeri/api/ingestion/types/observation_body.py +53 -0
- aeri/api/ingestion/types/observation_type.py +54 -0
- aeri/api/ingestion/types/open_ai_completion_usage_schema.py +26 -0
- aeri/api/ingestion/types/open_ai_response_usage_schema.py +24 -0
- aeri/api/ingestion/types/open_ai_usage.py +28 -0
- aeri/api/ingestion/types/optional_observation_body.py +36 -0
- aeri/api/ingestion/types/score_body.py +75 -0
- aeri/api/ingestion/types/score_event.py +15 -0
- aeri/api/ingestion/types/sdk_log_body.py +14 -0
- aeri/api/ingestion/types/sdk_log_event.py +15 -0
- aeri/api/ingestion/types/trace_body.py +36 -0
- aeri/api/ingestion/types/trace_event.py +15 -0
- aeri/api/ingestion/types/update_event_body.py +14 -0
- aeri/api/ingestion/types/update_generation_body.py +40 -0
- aeri/api/ingestion/types/update_generation_event.py +15 -0
- aeri/api/ingestion/types/update_observation_event.py +15 -0
- aeri/api/ingestion/types/update_span_body.py +19 -0
- aeri/api/ingestion/types/update_span_event.py +15 -0
- aeri/api/ingestion/types/usage_details.py +10 -0
- aeri/api/legacy/__init__.py +61 -0
- aeri/api/legacy/client.py +105 -0
- aeri/api/legacy/metrics_v1/__init__.py +40 -0
- aeri/api/legacy/metrics_v1/client.py +214 -0
- aeri/api/legacy/metrics_v1/raw_client.py +322 -0
- aeri/api/legacy/metrics_v1/types/__init__.py +40 -0
- aeri/api/legacy/metrics_v1/types/metrics_response.py +19 -0
- aeri/api/legacy/observations_v1/__init__.py +43 -0
- aeri/api/legacy/observations_v1/client.py +523 -0
- aeri/api/legacy/observations_v1/raw_client.py +759 -0
- aeri/api/legacy/observations_v1/types/__init__.py +44 -0
- aeri/api/legacy/observations_v1/types/observations.py +17 -0
- aeri/api/legacy/observations_v1/types/observations_views.py +17 -0
- aeri/api/legacy/raw_client.py +13 -0
- aeri/api/legacy/score_v1/__init__.py +43 -0
- aeri/api/legacy/score_v1/client.py +329 -0
- aeri/api/legacy/score_v1/raw_client.py +545 -0
- aeri/api/legacy/score_v1/types/__init__.py +44 -0
- aeri/api/legacy/score_v1/types/create_score_request.py +75 -0
- aeri/api/legacy/score_v1/types/create_score_response.py +17 -0
- aeri/api/llm_connections/__init__.py +55 -0
- aeri/api/llm_connections/client.py +311 -0
- aeri/api/llm_connections/raw_client.py +541 -0
- aeri/api/llm_connections/types/__init__.py +53 -0
- aeri/api/llm_connections/types/llm_adapter.py +38 -0
- aeri/api/llm_connections/types/llm_connection.py +77 -0
- aeri/api/llm_connections/types/paginated_llm_connections.py +17 -0
- aeri/api/llm_connections/types/upsert_llm_connection_request.py +69 -0
- aeri/api/media/__init__.py +58 -0
- aeri/api/media/client.py +427 -0
- aeri/api/media/raw_client.py +739 -0
- aeri/api/media/types/__init__.py +56 -0
- aeri/api/media/types/get_media_response.py +55 -0
- aeri/api/media/types/get_media_upload_url_request.py +51 -0
- aeri/api/media/types/get_media_upload_url_response.py +28 -0
- aeri/api/media/types/media_content_type.py +232 -0
- aeri/api/media/types/patch_media_body.py +43 -0
- aeri/api/metrics/__init__.py +40 -0
- aeri/api/metrics/client.py +422 -0
- aeri/api/metrics/raw_client.py +530 -0
- aeri/api/metrics/types/__init__.py +40 -0
- aeri/api/metrics/types/metrics_v2response.py +19 -0
- aeri/api/models/__init__.py +43 -0
- aeri/api/models/client.py +523 -0
- aeri/api/models/raw_client.py +993 -0
- aeri/api/models/types/__init__.py +44 -0
- aeri/api/models/types/create_model_request.py +103 -0
- aeri/api/models/types/paginated_models.py +17 -0
- aeri/api/observations/__init__.py +43 -0
- aeri/api/observations/client.py +522 -0
- aeri/api/observations/raw_client.py +641 -0
- aeri/api/observations/types/__init__.py +44 -0
- aeri/api/observations/types/observations_v2meta.py +21 -0
- aeri/api/observations/types/observations_v2response.py +28 -0
- aeri/api/opentelemetry/__init__.py +67 -0
- aeri/api/opentelemetry/client.py +276 -0
- aeri/api/opentelemetry/raw_client.py +291 -0
- aeri/api/opentelemetry/types/__init__.py +65 -0
- aeri/api/opentelemetry/types/otel_attribute.py +27 -0
- aeri/api/opentelemetry/types/otel_attribute_value.py +46 -0
- aeri/api/opentelemetry/types/otel_resource.py +24 -0
- aeri/api/opentelemetry/types/otel_resource_span.py +32 -0
- aeri/api/opentelemetry/types/otel_scope.py +34 -0
- aeri/api/opentelemetry/types/otel_scope_span.py +28 -0
- aeri/api/opentelemetry/types/otel_span.py +76 -0
- aeri/api/opentelemetry/types/otel_trace_response.py +16 -0
- aeri/api/organizations/__init__.py +73 -0
- aeri/api/organizations/client.py +756 -0
- aeri/api/organizations/raw_client.py +1707 -0
- aeri/api/organizations/types/__init__.py +71 -0
- aeri/api/organizations/types/delete_membership_request.py +16 -0
- aeri/api/organizations/types/membership_deletion_response.py +17 -0
- aeri/api/organizations/types/membership_request.py +18 -0
- aeri/api/organizations/types/membership_response.py +20 -0
- aeri/api/organizations/types/membership_role.py +30 -0
- aeri/api/organizations/types/memberships_response.py +15 -0
- aeri/api/organizations/types/organization_api_key.py +31 -0
- aeri/api/organizations/types/organization_api_keys_response.py +19 -0
- aeri/api/organizations/types/organization_project.py +25 -0
- aeri/api/organizations/types/organization_projects_response.py +15 -0
- aeri/api/projects/__init__.py +67 -0
- aeri/api/projects/client.py +760 -0
- aeri/api/projects/raw_client.py +1577 -0
- aeri/api/projects/types/__init__.py +65 -0
- aeri/api/projects/types/api_key_deletion_response.py +18 -0
- aeri/api/projects/types/api_key_list.py +23 -0
- aeri/api/projects/types/api_key_response.py +30 -0
- aeri/api/projects/types/api_key_summary.py +35 -0
- aeri/api/projects/types/organization.py +22 -0
- aeri/api/projects/types/project.py +34 -0
- aeri/api/projects/types/project_deletion_response.py +15 -0
- aeri/api/projects/types/projects.py +15 -0
- aeri/api/prompt_version/__init__.py +4 -0
- aeri/api/prompt_version/client.py +157 -0
- aeri/api/prompt_version/raw_client.py +264 -0
- aeri/api/prompts/__init__.py +100 -0
- aeri/api/prompts/client.py +550 -0
- aeri/api/prompts/raw_client.py +987 -0
- aeri/api/prompts/types/__init__.py +96 -0
- aeri/api/prompts/types/base_prompt.py +42 -0
- aeri/api/prompts/types/chat_message.py +17 -0
- aeri/api/prompts/types/chat_message_type.py +15 -0
- aeri/api/prompts/types/chat_message_with_placeholders.py +8 -0
- aeri/api/prompts/types/chat_prompt.py +15 -0
- aeri/api/prompts/types/create_chat_prompt_request.py +37 -0
- aeri/api/prompts/types/create_chat_prompt_type.py +15 -0
- aeri/api/prompts/types/create_prompt_request.py +8 -0
- aeri/api/prompts/types/create_text_prompt_request.py +36 -0
- aeri/api/prompts/types/create_text_prompt_type.py +15 -0
- aeri/api/prompts/types/placeholder_message.py +16 -0
- aeri/api/prompts/types/placeholder_message_type.py +15 -0
- aeri/api/prompts/types/prompt.py +58 -0
- aeri/api/prompts/types/prompt_meta.py +35 -0
- aeri/api/prompts/types/prompt_meta_list_response.py +17 -0
- aeri/api/prompts/types/prompt_type.py +20 -0
- aeri/api/prompts/types/text_prompt.py +14 -0
- aeri/api/scim/__init__.py +94 -0
- aeri/api/scim/client.py +686 -0
- aeri/api/scim/raw_client.py +1528 -0
- aeri/api/scim/types/__init__.py +92 -0
- aeri/api/scim/types/authentication_scheme.py +20 -0
- aeri/api/scim/types/bulk_config.py +22 -0
- aeri/api/scim/types/empty_response.py +16 -0
- aeri/api/scim/types/filter_config.py +17 -0
- aeri/api/scim/types/resource_meta.py +17 -0
- aeri/api/scim/types/resource_type.py +27 -0
- aeri/api/scim/types/resource_types_response.py +21 -0
- aeri/api/scim/types/schema_extension.py +17 -0
- aeri/api/scim/types/schema_resource.py +19 -0
- aeri/api/scim/types/schemas_response.py +21 -0
- aeri/api/scim/types/scim_email.py +16 -0
- aeri/api/scim/types/scim_feature_support.py +14 -0
- aeri/api/scim/types/scim_name.py +14 -0
- aeri/api/scim/types/scim_user.py +24 -0
- aeri/api/scim/types/scim_users_list_response.py +25 -0
- aeri/api/scim/types/service_provider_config.py +36 -0
- aeri/api/scim/types/user_meta.py +20 -0
- aeri/api/score_configs/__init__.py +44 -0
- aeri/api/score_configs/client.py +526 -0
- aeri/api/score_configs/raw_client.py +1012 -0
- aeri/api/score_configs/types/__init__.py +46 -0
- aeri/api/score_configs/types/create_score_config_request.py +46 -0
- aeri/api/score_configs/types/score_configs.py +17 -0
- aeri/api/score_configs/types/update_score_config_request.py +53 -0
- aeri/api/scores/__init__.py +76 -0
- aeri/api/scores/client.py +420 -0
- aeri/api/scores/raw_client.py +656 -0
- aeri/api/scores/types/__init__.py +76 -0
- aeri/api/scores/types/get_scores_response.py +17 -0
- aeri/api/scores/types/get_scores_response_data.py +211 -0
- aeri/api/scores/types/get_scores_response_data_boolean.py +15 -0
- aeri/api/scores/types/get_scores_response_data_categorical.py +15 -0
- aeri/api/scores/types/get_scores_response_data_correction.py +15 -0
- aeri/api/scores/types/get_scores_response_data_numeric.py +15 -0
- aeri/api/scores/types/get_scores_response_trace_data.py +38 -0
- aeri/api/sessions/__init__.py +40 -0
- aeri/api/sessions/client.py +262 -0
- aeri/api/sessions/raw_client.py +500 -0
- aeri/api/sessions/types/__init__.py +40 -0
- aeri/api/sessions/types/paginated_sessions.py +17 -0
- aeri/api/trace/__init__.py +44 -0
- aeri/api/trace/client.py +728 -0
- aeri/api/trace/raw_client.py +1208 -0
- aeri/api/trace/types/__init__.py +46 -0
- aeri/api/trace/types/delete_trace_response.py +14 -0
- aeri/api/trace/types/sort.py +14 -0
- aeri/api/trace/types/traces.py +17 -0
- aeri/api/utils/__init__.py +44 -0
- aeri/api/utils/pagination/__init__.py +40 -0
- aeri/api/utils/pagination/types/__init__.py +40 -0
- aeri/api/utils/pagination/types/meta_response.py +38 -0
- aeri/batch_evaluation.py +1643 -0
- aeri/experiment.py +1044 -0
- aeri/langchain/CallbackHandler.py +1377 -0
- aeri/langchain/__init__.py +5 -0
- aeri/langchain/utils.py +212 -0
- aeri/logger.py +28 -0
- aeri/media.py +352 -0
- aeri/model.py +477 -0
- aeri/openai.py +1124 -0
- aeri/py.typed +0 -0
- aeri/span_filter.py +17 -0
- aeri/types.py +79 -0
- aeri/version.py +3 -0
- aeri_python-4.0.0.dist-info/METADATA +51 -0
- aeri_python-4.0.0.dist-info/RECORD +391 -0
- aeri_python-4.0.0.dist-info/WHEEL +4 -0
- aeri_python-4.0.0.dist-info/licenses/LICENSE +21 -0
aeri/langchain/utils.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
"""@private"""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from typing import Any, Dict, List, Literal, Optional, cast
|
|
5
|
+
|
|
6
|
+
# NOTE ON DEPENDENCIES:
|
|
7
|
+
# - since Jan 2024, there is https://pypi.org/project/langchain-openai/ which is a separate package and imports openai models.
|
|
8
|
+
# Decided to not make this a dependency of aeri as few people will have this. Need to match these models manually
|
|
9
|
+
# - langchain_community is loaded as a dependency of langchain, but extremely unreliable. Decided to not depend on it.
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _extract_model_name(
    serialized: Optional[Dict[str, Any]],
    **kwargs: Any,
) -> Optional[str]:
    """Extract the model name from the serialized object or the callback kwargs.

    Strategies are tried in a fixed order and the first truthy result is
    returned, so the order of the steps below is significant:

    1. Known classes whose model name lives at a fixed key path.
    2. ``AzureOpenAI``, whose name is assembled from deployment information.
    3. Known classes whose model name is parsed from ``serialized["repr"]``.
    4. A catch-all over the most common key paths.

    Args:
        serialized: Serialized model description (its ``id``, ``kwargs`` and
            ``repr`` entries are read), or ``None``.
        **kwargs: Additional keyword arguments; ``invocation_params`` is the
            entry consulted for model information.

    Returns:
        The model name, or ``None`` when no strategy produced one.
    """
    # 1) Known classes where the model name (aka id) sits at a fixed path
    #    (an array of keys) inside either ``serialized`` or ``kwargs``.
    models_by_id = [
        ("ChatGoogleGenerativeAI", ["kwargs", "model"], "serialized"),
        ("ChatMistralAI", ["kwargs", "model"], "serialized"),
        ("ChatVertexAi", ["kwargs", "model_name"], "serialized"),
        ("ChatVertexAI", ["kwargs", "model_name"], "serialized"),
        ("OpenAI", ["invocation_params", "model_name"], "kwargs"),
        ("ChatOpenAI", ["invocation_params", "model_name"], "kwargs"),
        ("AzureChatOpenAI", ["invocation_params", "model"], "kwargs"),
        ("AzureChatOpenAI", ["invocation_params", "model_name"], "kwargs"),
        ("AzureChatOpenAI", ["invocation_params", "azure_deployment"], "kwargs"),
        ("HuggingFacePipeline", ["invocation_params", "model_id"], "kwargs"),
        ("BedrockChat", ["kwargs", "model_id"], "serialized"),
        ("Bedrock", ["kwargs", "model_id"], "serialized"),
        ("BedrockLLM", ["kwargs", "model_id"], "serialized"),
        ("ChatBedrock", ["kwargs", "model_id"], "serialized"),
        ("LlamaCpp", ["invocation_params", "model_path"], "kwargs"),
        ("WatsonxLLM", ["invocation_params", "model_id"], "kwargs"),
    ]

    for class_name, keys, select_from in models_by_id:
        model = _extract_model_by_path_for_id(
            class_name,
            serialized,
            kwargs,
            keys,
            cast(Literal["serialized", "kwargs"], select_from),
        )
        if model:
            return model

    # 2) AzureOpenAI: prefer an explicit model from the invocation params,
    #    otherwise combine the deployment name and deployment version.
    if serialized:
        serialized_id = serialized.get("id")
        if (
            serialized_id
            and isinstance(serialized_id, list)
            and serialized_id[-1] == "AzureOpenAI"
        ):
            invocation_params = kwargs.get("invocation_params")
            if invocation_params and isinstance(invocation_params, dict):
                if invocation_params.get("model"):
                    return str(invocation_params.get("model"))

                if invocation_params.get("model_name"):
                    return str(invocation_params.get("model_name"))

            deployment_name = None
            deployment_version = None

            serialized_kwargs = serialized.get("kwargs")
            if serialized_kwargs and isinstance(serialized_kwargs, dict):
                # Read the deployment version directly. It used to be read
                # only when the unrelated ``openai_api_version`` key was set,
                # which silently dropped a supplied version otherwise.
                deployment_version = serialized_kwargs.get("deployment_version") or None
                deployment_name = serialized_kwargs.get("deployment_name") or None

            # Every path of the AzureOpenAI branch returns; later strategies
            # are never consulted for this class.
            if not isinstance(deployment_name, str):
                return None

            if not isinstance(deployment_version, str):
                return deployment_name

            # Avoid appending a version the deployment name already contains.
            if deployment_version in deployment_name:
                return deployment_name

            return deployment_name + "-" + deployment_version

    # 3) For some models the name cannot be reached by a key path; it is
    #    parsed with a regex from the string representation of the model
    #    object. The third tuple element is the fallback name.
    models_by_pattern = [
        ("Anthropic", "model", "anthropic"),
        ("ChatAnthropic", "model", None),
        ("ChatTongyi", "model_name", None),
        ("ChatCohere", "model", None),
        ("Cohere", "model", None),
        ("HuggingFaceHub", "model", None),
        ("ChatAnyscale", "model_name", None),
        ("TextGen", "model", "text-gen"),
        ("Ollama", "model", None),
        ("OllamaLLM", "model", None),
        ("ChatOllama", "model", None),
        ("ChatFireworks", "model", None),
        ("ChatPerplexity", "model", None),
        ("VLLM", "model", None),
        ("Xinference", "model_uid", None),
        ("ChatOCIGenAI", "model_id", None),
        ("DeepInfra", "model_id", None),
    ]

    for class_name, pattern, default in models_by_pattern:
        model = _extract_model_from_repr_by_pattern(
            class_name, serialized, pattern, default
        )
        if model:
            return model

    # 4) Finally, try the most likely paths as a catch-all.
    random_paths = [
        ["kwargs", "model_name"],
        ["kwargs", "model"],
        ["invocation_params", "model_name"],
        ["invocation_params", "model"],
    ]
    for select in ["kwargs", "serialized"]:
        for path in random_paths:
            model = _extract_model_by_path(
                serialized, kwargs, path, cast(Literal["serialized", "kwargs"], select)
            )
            if model:
                return str(model)

    return None
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _extract_model_from_repr_by_pattern(
    id: str,
    serialized: Optional[Dict[str, Any]],
    pattern: str,
    default: Optional[str] = None,
) -> Optional[str]:
    """Parse the model name out of ``serialized["repr"]`` for one class.

    Args:
        id: Class name that must equal the last element of ``serialized["id"]``.
        serialized: Serialized model description, or ``None``.
        pattern: Attribute name handed to ``_extract_model_with_regex``.
        default: Fallback returned when the repr string is present but the
            regex finds nothing.

    Returns:
        The extracted name, the fallback, or ``None``. ``None`` is also
        returned when the class does not match or no repr string exists.
    """
    if serialized is None:
        return None

    id_path = serialized.get("id")
    if not (id_path and isinstance(id_path, list) and id_path[-1] == id):
        return None

    repr_text = serialized.get("repr")
    if not repr_text or not isinstance(repr_text, str):
        # The fallback is only applied once a repr string was inspected.
        return None

    found = _extract_model_with_regex(pattern, repr_text)
    if found:
        return found
    return default or None
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _extract_model_with_regex(pattern: str, text: str) -> Optional[str]:
|
|
163
|
+
match = re.search(rf"{pattern}='(.*?)'", text)
|
|
164
|
+
if match:
|
|
165
|
+
return match.group(1)
|
|
166
|
+
return None
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _extract_model_by_path_for_id(
    id: str,
    serialized: Optional[Dict[str, Any]],
    kwargs: Dict[str, Any],
    keys: List[str],
    select_from: Literal["serialized", "kwargs"],
) -> Optional[str]:
    """Look up a model name by key path, but only for one specific class.

    Args:
        id: Class name that must equal the last element of ``serialized["id"]``.
        serialized: Serialized model description, or ``None``.
        kwargs: Keyword-argument mapping that may hold the model name.
        keys: Key path to follow inside the selected mapping.
        select_from: Which mapping the path is resolved against.

    Returns:
        The value found at the path as a string, or ``None`` when the class
        does not match or the path yields nothing.
    """
    # Without a (non-empty) serialized object the class cannot be identified.
    if not serialized:
        return None

    id_path = serialized.get("id")
    if not (id_path and isinstance(id_path, list) and id_path[-1] == id):
        return None

    found = _extract_model_by_path(serialized, kwargs, keys, select_from)
    return None if found is None else str(found)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def _extract_model_by_path(
|
|
194
|
+
serialized: Optional[Dict[str, Any]],
|
|
195
|
+
kwargs: dict,
|
|
196
|
+
keys: List[str],
|
|
197
|
+
select_from: Literal["serialized", "kwargs"],
|
|
198
|
+
) -> Optional[str]:
|
|
199
|
+
if serialized is None and select_from == "serialized":
|
|
200
|
+
return None
|
|
201
|
+
|
|
202
|
+
current_obj = kwargs if select_from == "kwargs" else serialized
|
|
203
|
+
|
|
204
|
+
for key in keys:
|
|
205
|
+
if current_obj and isinstance(current_obj, dict):
|
|
206
|
+
current_obj = current_obj.get(key)
|
|
207
|
+
else:
|
|
208
|
+
return None
|
|
209
|
+
if not current_obj:
|
|
210
|
+
return None
|
|
211
|
+
|
|
212
|
+
return str(current_obj) if current_obj else None
|
aeri/logger.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Logger configuration for Aeri OpenTelemetry integration.
|
|
2
|
+
|
|
3
|
+
This module initializes and configures loggers used by the Aeri OpenTelemetry integration.
|
|
4
|
+
It sets up the main 'aeri' logger and configures the httpx logger to reduce noise.
|
|
5
|
+
|
|
6
|
+
Log levels used throughout Aeri:
|
|
7
|
+
- DEBUG: Detailed tracing information useful for development and diagnostics
|
|
8
|
+
- INFO: Normal operational information confirming expected behavior
|
|
9
|
+
- WARNING: Indication of potential issues that don't prevent operation
|
|
10
|
+
- ERROR: Errors that prevent specific operations but allow continued execution
|
|
11
|
+
- CRITICAL: Critical errors that may prevent further operation
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import logging
|
|
15
|
+
|
|
16
|
+
# Create the main Aeri logger; its threshold is WARNING.
aeri_logger = logging.getLogger("aeri")
aeri_logger.setLevel(logging.WARNING)

# Configure httpx logger to reduce noise from HTTP requests
httpx_logger = logging.getLogger("httpx")
httpx_logger.setLevel(logging.WARNING)

# Console handler with a timestamped "time - logger - level - message" format.
# The handler and formatter are always created so both names stay importable.
console_handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)

# Add the console handler only if no handlers exist, so an application that
# already configured the httpx logger does not get duplicate output.
if not httpx_logger.handlers:
    httpx_logger.addHandler(console_handler)
|
aeri/media.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
"""This module contains the AeriMedia class, which is used to wrap media objects for upload to Aeri."""
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import hashlib
|
|
5
|
+
import os
|
|
6
|
+
import re
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, TypeVar, cast
|
|
8
|
+
|
|
9
|
+
import httpx
|
|
10
|
+
|
|
11
|
+
from aeri.api import MediaContentType
|
|
12
|
+
from aeri.logger import aeri_logger as logger
|
|
13
|
+
from aeri.types import ParsedMediaReference
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from aeri._client.client import Aeri
|
|
17
|
+
|
|
18
|
+
T = TypeVar("T")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class AeriMedia:
    """A class for wrapping media objects for upload to Aeri.

    The content can be supplied in one of three ways, checked in this order:

    1. ``base64_data_uri``: a base64 data URI carrying both content and MIME type.
    2. ``content_bytes`` together with ``content_type``.
    3. ``file_path`` together with ``content_type``; the file must exist.

    Construction never raises. When no usable combination is supplied, or the
    data URI cannot be parsed, or the file cannot be read (or is empty), an error
    is logged where applicable and ``_content_type`` is left as ``None``, which in
    turn makes ``_reference_string`` return ``None``.

    Args:
        obj (Optional[object]): The source object to be wrapped. Can be accessed via the `obj` attribute.
        base64_data_uri (Optional[str]): A base64-encoded data URI containing the media content
            and content type (e.g., "data:image/jpeg;base64,/9j/4AAQ...").
        content_type (Optional[str]): The MIME type of the media content when providing raw bytes
            or reading from a file.
        content_bytes (Optional[bytes]): Raw bytes of the media content.
        file_path (Optional[str]): The path to the file containing the media content. For relative paths,
            the current working directory is used.
    """

    obj: object

    _content_bytes: Optional[bytes]
    _content_type: Optional[MediaContentType]
    _source: Optional[str]
    _media_id: Optional[str]

    def __init__(
        self,
        *,
        obj: Optional[object] = None,
        base64_data_uri: Optional[str] = None,
        content_type: Optional[MediaContentType] = None,
        content_bytes: Optional[bytes] = None,
        file_path: Optional[str] = None,
    ):
        """Initialize a AeriMedia object.

        Args:
            obj: The object to wrap.

            base64_data_uri: A base64-encoded data URI containing the media content
                and content type (e.g., "data:image/jpeg;base64,/9j/4AAQ...").
            content_type: The MIME type of the media content when providing raw bytes or reading from a file.
            content_bytes: Raw bytes of the media content.
            file_path: The path to the file containing the media content. For relative paths,
                the current working directory is used.
        """
        self.obj = obj

        if base64_data_uri is not None:
            # A data URI wins over every other argument; parsing failures are
            # logged by the parser and yield (None, None).
            parsed_data = self._parse_base64_data_uri(base64_data_uri)
            self._content_bytes, self._content_type = parsed_data
            self._source = "base64_data_uri"

        elif content_bytes is not None and content_type is not None:
            self._content_type = content_type
            self._content_bytes = content_bytes
            self._source = "bytes"
        elif (
            file_path is not None
            and content_type is not None
            and os.path.exists(file_path)
        ):
            self._content_bytes = self._read_file(file_path)
            # An unreadable or empty file leaves type and source unset.
            self._content_type = content_type if self._content_bytes else None
            self._source = "file" if self._content_bytes else None
        else:
            logger.error(
                "base64_data_uri, or content_bytes and content_type, or file_path must be provided to AeriMedia"
            )

            self._content_bytes = None
            self._content_type = None
            self._source = None

        self._media_id = self._get_media_id()

    def _read_file(self, file_path: str) -> Optional[bytes]:
        """Return the binary content of ``file_path``, or ``None`` if reading fails.

        Any exception raised while opening or reading is logged, not propagated.
        """
        try:
            with open(file_path, "rb") as file:
                return file.read()
        except Exception as e:
            logger.error(f"Error reading file at path {file_path}", exc_info=e)

            return None

    def _get_media_id(self) -> Optional[str]:
        """Derive the media id from the content hash.

        Returns:
            The first 22 characters of the base64 SHA-256 hash with ``+`` and ``/``
            replaced by ``-`` and ``_``, or ``None`` when there is no content hash.
        """
        content_hash = self._content_sha256_hash

        if content_hash is None:
            return None

        # Convert hash to base64Url
        url_safe_content_hash = content_hash.replace("+", "-").replace("/", "_")

        return url_safe_content_hash[:22]

    @property
    def _content_length(self) -> Optional[int]:
        """Number of content bytes; ``None`` when content is missing or zero-length."""
        return len(self._content_bytes) if self._content_bytes else None

    @property
    def _content_sha256_hash(self) -> Optional[str]:
        """Base64-encoded SHA-256 digest of the content, or ``None`` without content."""
        if self._content_bytes is None:
            return None

        sha256_hash_bytes = hashlib.sha256(self._content_bytes).digest()

        return base64.b64encode(sha256_hash_bytes).decode("utf-8")

    @property
    def _reference_string(self) -> Optional[str]:
        """Media reference string for this content.

        Returns ``None`` unless content type, source and media id are all set.
        """
        if self._content_type is None or self._source is None or self._media_id is None:
            return None

        return f"@@@aeriMedia:type={self._content_type}|id={self._media_id}|source={self._source}@@@"

    @staticmethod
    def _split_reference_string(reference_string: str) -> Tuple[str, str, str]:
        """Validate a media reference string and extract its raw field values.

        Args:
            reference_string: The reference string to split.

        Returns:
            A ``(media_id, source, content_type)`` tuple of the raw string values.

        Raises:
            ValueError: If the string is empty, not a string, lacks the expected
                prefix or suffix, contains a pair without ``=``, or is missing
                one of the ``type``, ``id`` or ``source`` fields.
        """
        prefix = "@@@aeriMedia:"
        suffix = "@@@"

        if not reference_string:
            raise ValueError("Reference string is empty")

        if not isinstance(reference_string, str):
            raise ValueError("Reference string is not a string")

        if not reference_string.startswith("@@@aeriMedia:type="):
            raise ValueError(
                "Reference string does not start with '@@@aeriMedia:type='"
            )

        if not reference_string.endswith(suffix):
            raise ValueError("Reference string does not end with '@@@'")

        # Remove exactly one prefix and one suffix. str.rstrip("@@@") would treat
        # its argument as a character set and also eat any '@' that belongs to
        # the last field value.
        content = reference_string[len(prefix) : -len(suffix)]

        # Split into key-value pairs; a later duplicate key overrides an earlier one
        fields = {}
        for pair in content.split("|"):
            key, separator, value = pair.partition("=")
            if not separator:
                raise ValueError(
                    f"Malformed key-value pair in reference string: {pair!r}"
                )
            fields[key] = value

        # Verify all required fields are present
        if not all(key in fields for key in ("type", "id", "source")):
            raise ValueError("Missing required fields in reference string")

        return fields["id"], fields["source"], fields["type"]

    @staticmethod
    def parse_reference_string(reference_string: str) -> ParsedMediaReference:
        """Parse a media reference string into a ParsedMediaReference.

        Example reference string:
            "@@@aeriMedia:type=image/jpeg|id=some-uuid|source=base64_data_uri@@@"

        Args:
            reference_string: The reference string to parse.

        Returns:
            A TypedDict with the media_id, source, and content_type.

        Raises:
            ValueError: If the reference string is empty or not a string.
            ValueError: If the reference string does not start with "@@@aeriMedia:type=".
            ValueError: If the reference string does not end with "@@@".
            ValueError: If a key-value pair has no "=" or required fields are missing.
        """
        media_id, source, content_type = AeriMedia._split_reference_string(
            reference_string
        )

        return ParsedMediaReference(
            media_id=media_id,
            source=source,
            content_type=cast(MediaContentType, content_type),
        )

    def _parse_base64_data_uri(
        self, data: str
    ) -> Tuple[Optional[bytes], Optional[MediaContentType]]:
        """Decode a base64 data URI into its bytes and content type.

        Args:
            data: A data URI such as "data:image/jpeg;base64,/9j/4AAQ...".

        Returns:
            ``(content_bytes, content_type)`` on success. On any failure the
            error is logged and ``(None, None)`` is returned.
        """
        # Example data URI: data:image/jpeg;base64,/9j/4AAQ...
        try:
            if not data or not isinstance(data, str):
                raise ValueError("Data URI is not a string")

            if not data.startswith("data:"):
                raise ValueError("Data URI does not start with 'data:'")

            # A URI without a comma fails to unpack here and is handled below
            header, actual_data = data[5:].split(",", 1)
            if not header or not actual_data:
                raise ValueError("Invalid URI")

            # Split header into parts and check for base64
            header_parts = header.split(";")
            if "base64" not in header_parts:
                raise ValueError("Data is not base64 encoded")

            # Content type is the first part
            content_type = header_parts[0]
            if not content_type:
                raise ValueError("Content type is empty")

            return base64.b64decode(actual_data), cast(MediaContentType, content_type)

        except Exception as e:
            logger.error("Error parsing base64 data URI", exc_info=e)

            return None, None

    @staticmethod
    def resolve_media_references(
        *,
        obj: T,
        aeri_client: "Aeri",
        resolve_with: Literal["base64_data_uri"],
        max_depth: int = 10,
        content_fetch_timeout_seconds: int = 10,
    ) -> T:
        """Replace media reference strings in an object with base64 data URIs.

        This method recursively traverses an object (up to max_depth) looking for media reference strings
        in the format "@@@aeriMedia:...@@@". When found, it (synchronously) fetches the actual media content using
        the provided Aeri client and replaces the reference string with a base64 data URI.

        If fetching media content fails for a reference string, a warning is logged and the reference
        string is left unchanged.

        Args:
            obj: The object to process. Can be a primitive value, array, or nested object.
                If the object has a __dict__ attribute, a dict will be returned instead of the original object type.
            aeri_client: Aeri client instance used to fetch media content.
            resolve_with: The representation of the media content to replace the media reference string with.
                Currently only "base64_data_uri" is supported; the value is not otherwise inspected.
            max_depth: Optional. Default is 10. The maximum depth to traverse the object.
                Values nested deeper than this are returned as they are.
            content_fetch_timeout_seconds: Optional. Default is 10. Timeout passed to the
                HTTP request that downloads each media item.

        Returns:
            The input with media references replaced by base64 data URIs where possible.
            Lists and dicts are rebuilt as new containers; any other value is returned as is.
            If the input object has a __dict__ attribute, a dict will be returned instead of the original object type.

        Example:
            obj = {
                "image": "@@@aeriMedia:type=image/jpeg|id=123|source=bytes@@@",
                "nested": {
                    "pdf": "@@@aeriMedia:type=application/pdf|id=456|source=bytes@@@"
                }
            }

            result = AeriMedia.resolve_media_references(
                obj=obj, aeri_client=aeri_client, resolve_with="base64_data_uri"
            )

            # Result:
            # {
            #     "image": "data:image/jpeg;base64,/9j/4AAQSkZJRg...",
            #     "nested": {
            #         "pdf": "data:application/pdf;base64,JVBERi0xLjcK..."
            #     }
            # }
        """
        # Compiled once per call instead of once per visited string
        reference_pattern = re.compile(r"@@@aeriMedia:.+?@@@")

        def resolve_string(text: str) -> str:
            """Return ``text`` with every resolvable reference string replaced."""
            reference_string_matches = reference_pattern.findall(text)
            if not reference_string_matches:
                return text

            # Use the client's own httpx client when it has one
            httpx_client = (
                aeri_client._resources.httpx_client
                if aeri_client._resources is not None
                else None
            )

            result = text

            # dict.fromkeys drops repeated references while keeping first-seen
            # order, so each distinct media item is downloaded only once.
            for reference_string in dict.fromkeys(reference_string_matches):
                try:
                    parsed_media_reference = AeriMedia.parse_reference_string(
                        reference_string
                    )
                    media_data = aeri_client.api.media.get(
                        parsed_media_reference["media_id"]
                    )
                    http_get = (
                        httpx_client.get if httpx_client is not None else httpx.get
                    )
                    media_content = http_get(
                        media_data.url, timeout=content_fetch_timeout_seconds
                    )
                    media_content.raise_for_status()

                    base64_media_content = base64.b64encode(
                        media_content.content
                    ).decode()
                    base64_data_uri = f"data:{media_data.content_type};base64,{base64_media_content}"
                except Exception as e:
                    logger.warning(
                        f"Error fetching media content for reference string {reference_string}: {e}"
                    )
                    # Do not replace the reference string if there's an error
                    continue

                result = result.replace(reference_string, base64_data_uri)

            return result

        def traverse(node: Any, depth: int) -> Any:
            """Walk ``node`` recursively, resolving strings down to ``max_depth``."""
            if depth > max_depth:
                return node

            # Handle string
            if isinstance(node, str):
                return resolve_string(node)

            # Handle arrays
            if isinstance(node, list):
                return [traverse(item, depth + 1) for item in node]

            # Handle dictionaries
            if isinstance(node, dict):
                return {key: traverse(value, depth + 1) for key, value in node.items()}

            # Handle objects: their attributes are returned as a plain dict
            if hasattr(node, "__dict__"):
                return {
                    key: traverse(value, depth + 1)
                    for key, value in node.__dict__.items()
                }

            return node

        return cast(T, traverse(obj, 0))
|