orca-sdk 0.0.91__tar.gz → 0.0.92__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/PKG-INFO +3 -1
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/base_label_prediction_result.py +9 -1
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +8 -0
- orca_sdk-0.0.92/orca_sdk/_shared/__init__.py +1 -0
- orca_sdk-0.0.92/orca_sdk/_shared/metrics.py +195 -0
- orca_sdk-0.0.92/orca_sdk/_shared/metrics_test.py +169 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/classification_model.py +144 -19
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/classification_model_test.py +49 -22
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/telemetry.py +3 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/pyproject.toml +3 -1
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/README.md +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/__init__.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/__init__.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/__init__.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/classification_model/update_model_classification_model_name_or_id_patch.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/datasource/create_embedding_evaluation_datasource_name_or_id_embedding_evaluation_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/datasource/get_embedding_evaluation_datasource_name_or_id_embedding_evaluation_task_id_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/datasource/list_embedding_evaluations_datasource_name_or_id_embedding_evaluation_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/analyze_memoryset_memoryset_name_or_id_analysis_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/batch_delete_memoryset_batch_delete_memoryset_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/explain_prediction_telemetry_prediction_prediction_id_explanation_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/client.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/errors.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/__init__.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/api_key_metadata_scope_item.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/base_model.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/cluster_metrics.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/column_info.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/column_type.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/constraint_violation_error_response.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/create_api_key_request_scope_item.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/create_api_key_response_scope_item.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/delete_memorysets_request.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/embed_request.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/embedding_evaluation_request.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/embedding_evaluation_result.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/embedding_model_result.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/evaluation_request.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/evaluation_response.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/feedback_metrics.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/feedback_type.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/filter_item.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/filter_item_op.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/get_memories_request.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +8 -8
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_feedback_metrics.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_metadata.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/labeled_memoryset_update.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/list_memories_request.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/lookup_request.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/lookup_score_metrics.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memory_metrics.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config_clustering_method.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config_partitioning_method.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_cluster_metrics.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_duplicate_analysis_config.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_duplicate_metrics.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_label_analysis_config.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_label_metrics.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_metrics.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_neighbor_analysis_config.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_neighbor_metrics.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_neighbor_metrics_lookup_score_metrics.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_projection_analysis_config.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/memoryset_projection_metrics.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/paginated_labeled_memory_with_feedback_metrics.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/precision_recall_curve.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/prediction_request.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_0.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_1.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/rac_head_type.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/rac_model_update.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/roc_curve.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/task.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/task_status.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/task_status_info.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/telemetry_field_type_0_item_type_2.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/telemetry_filter_item.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/telemetry_filter_item_op.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/telemetry_memories_request.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/telemetry_sort_options.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/telemetry_sort_options_direction.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/py.typed +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_generated_api_client/types.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/__init__.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/analysis_ui.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/analysis_ui_style.css +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/auth.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/auth_test.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/common.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/data_parsing.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/data_parsing_test.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/prediction_result_ui.css +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/prediction_result_ui.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/task.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/value_parser.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/_utils/value_parser_test.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/conftest.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/credentials.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/credentials_test.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/datasource.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/datasource_test.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/embedding_model.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/embedding_model_test.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/memoryset.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/memoryset_test.py +0 -0
- {orca_sdk-0.0.91 → orca_sdk-0.0.92}/orca_sdk/telemetry_test.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: orca_sdk
|
|
3
|
-
Version: 0.0.91
|
|
3
|
+
Version: 0.0.92
|
|
4
4
|
Summary: SDK for interacting with Orca Services
|
|
5
5
|
License: Apache-2.0
|
|
6
6
|
Author: Orca DB Inc.
|
|
@@ -20,7 +20,9 @@ Requires-Dist: pandas (>=2.2.3,<3.0.0)
|
|
|
20
20
|
Requires-Dist: pyarrow (>=18.0.0,<19.0.0)
|
|
21
21
|
Requires-Dist: python-dateutil (>=2.8.0,<3.0.0)
|
|
22
22
|
Requires-Dist: python-dotenv (>=1.1.0,<2.0.0)
|
|
23
|
+
Requires-Dist: scikit-learn (>=1.6.1,<2.0.0)
|
|
23
24
|
Requires-Dist: torch (>=2.5.1,<3.0.0)
|
|
25
|
+
Requires-Dist: transformers (>=4.51.3,<5.0.0)
|
|
24
26
|
Description-Content-Type: text/markdown
|
|
25
27
|
|
|
26
28
|
<!--
|
|
@@ -10,7 +10,7 @@ The main change is:
|
|
|
10
10
|
|
|
11
11
|
# flake8: noqa: C901
|
|
12
12
|
|
|
13
|
-
from typing import Any, Type, TypeVar, Union, cast
|
|
13
|
+
from typing import Any, List, Type, TypeVar, Union, cast
|
|
14
14
|
|
|
15
15
|
from attrs import define as _attrs_define
|
|
16
16
|
from attrs import field as _attrs_field
|
|
@@ -28,6 +28,7 @@ class BaseLabelPredictionResult:
|
|
|
28
28
|
anomaly_score (Union[None, float]):
|
|
29
29
|
label (int):
|
|
30
30
|
label_name (Union[None, str]):
|
|
31
|
+
logits (List[float]):
|
|
31
32
|
"""
|
|
32
33
|
|
|
33
34
|
prediction_id: Union[None, str]
|
|
@@ -35,6 +36,7 @@ class BaseLabelPredictionResult:
|
|
|
35
36
|
anomaly_score: Union[None, float]
|
|
36
37
|
label: int
|
|
37
38
|
label_name: Union[None, str]
|
|
39
|
+
logits: List[float]
|
|
38
40
|
additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
|
|
39
41
|
|
|
40
42
|
def to_dict(self) -> dict[str, Any]:
|
|
@@ -51,6 +53,8 @@ class BaseLabelPredictionResult:
|
|
|
51
53
|
label_name: Union[None, str]
|
|
52
54
|
label_name = self.label_name
|
|
53
55
|
|
|
56
|
+
logits = self.logits
|
|
57
|
+
|
|
54
58
|
field_dict: dict[str, Any] = {}
|
|
55
59
|
field_dict.update(self.additional_properties)
|
|
56
60
|
field_dict.update(
|
|
@@ -60,6 +64,7 @@ class BaseLabelPredictionResult:
|
|
|
60
64
|
"anomaly_score": anomaly_score,
|
|
61
65
|
"label": label,
|
|
62
66
|
"label_name": label_name,
|
|
67
|
+
"logits": logits,
|
|
63
68
|
}
|
|
64
69
|
)
|
|
65
70
|
|
|
@@ -94,12 +99,15 @@ class BaseLabelPredictionResult:
|
|
|
94
99
|
|
|
95
100
|
label_name = _parse_label_name(d.pop("label_name"))
|
|
96
101
|
|
|
102
|
+
logits = cast(List[float], d.pop("logits"))
|
|
103
|
+
|
|
97
104
|
base_label_prediction_result = cls(
|
|
98
105
|
prediction_id=prediction_id,
|
|
99
106
|
confidence=confidence,
|
|
100
107
|
anomaly_score=anomaly_score,
|
|
101
108
|
label=label,
|
|
102
109
|
label_name=label_name,
|
|
110
|
+
logits=logits,
|
|
103
111
|
)
|
|
104
112
|
|
|
105
113
|
base_label_prediction_result.additional_properties = d
|
|
@@ -43,6 +43,7 @@ class LabeledMemorysetMetadata:
|
|
|
43
43
|
label_names (List[str]):
|
|
44
44
|
created_at (datetime.datetime):
|
|
45
45
|
updated_at (datetime.datetime):
|
|
46
|
+
memories_updated_at (datetime.datetime):
|
|
46
47
|
insertion_task_id (str):
|
|
47
48
|
insertion_status (TaskStatus): Status of task in the task queue
|
|
48
49
|
metrics (MemorysetMetrics):
|
|
@@ -59,6 +60,7 @@ class LabeledMemorysetMetadata:
|
|
|
59
60
|
label_names: List[str]
|
|
60
61
|
created_at: datetime.datetime
|
|
61
62
|
updated_at: datetime.datetime
|
|
63
|
+
memories_updated_at: datetime.datetime
|
|
62
64
|
insertion_task_id: str
|
|
63
65
|
insertion_status: TaskStatus
|
|
64
66
|
metrics: "MemorysetMetrics"
|
|
@@ -97,6 +99,8 @@ class LabeledMemorysetMetadata:
|
|
|
97
99
|
|
|
98
100
|
updated_at = self.updated_at.isoformat()
|
|
99
101
|
|
|
102
|
+
memories_updated_at = self.memories_updated_at.isoformat()
|
|
103
|
+
|
|
100
104
|
insertion_task_id = self.insertion_task_id
|
|
101
105
|
|
|
102
106
|
insertion_status = (
|
|
@@ -120,6 +124,7 @@ class LabeledMemorysetMetadata:
|
|
|
120
124
|
"label_names": label_names,
|
|
121
125
|
"created_at": created_at,
|
|
122
126
|
"updated_at": updated_at,
|
|
127
|
+
"memories_updated_at": memories_updated_at,
|
|
123
128
|
"insertion_task_id": insertion_task_id,
|
|
124
129
|
"insertion_status": insertion_status,
|
|
125
130
|
"metrics": metrics,
|
|
@@ -180,6 +185,8 @@ class LabeledMemorysetMetadata:
|
|
|
180
185
|
|
|
181
186
|
updated_at = isoparse(d.pop("updated_at"))
|
|
182
187
|
|
|
188
|
+
memories_updated_at = isoparse(d.pop("memories_updated_at"))
|
|
189
|
+
|
|
183
190
|
insertion_task_id = d.pop("insertion_task_id")
|
|
184
191
|
|
|
185
192
|
insertion_status = TaskStatus(d.pop("insertion_status"))
|
|
@@ -198,6 +205,7 @@ class LabeledMemorysetMetadata:
|
|
|
198
205
|
label_names=label_names,
|
|
199
206
|
created_at=created_at,
|
|
200
207
|
updated_at=updated_at,
|
|
208
|
+
memories_updated_at=memories_updated_at,
|
|
201
209
|
insertion_task_id=insertion_task_id,
|
|
202
210
|
insertion_status=insertion_status,
|
|
203
211
|
metrics=metrics,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .metrics import calculate_pr_curve, calculate_roc_curve, compute_classifier_metrics
|
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains metrics for usage with the Hugging Face Trainer.
|
|
3
|
+
|
|
4
|
+
IMPORTANT:
|
|
5
|
+
- This is a shared file between OrcaLib and the Orca SDK.
|
|
6
|
+
- Please ensure that it does not have any dependencies on the OrcaLib code.
|
|
7
|
+
- Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
|
|
8
|
+
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from typing import Literal, Tuple, TypedDict
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
from numpy.typing import NDArray
|
|
15
|
+
from scipy.special import softmax
|
|
16
|
+
from sklearn.metrics import accuracy_score, auc, f1_score, log_loss
|
|
17
|
+
from sklearn.metrics import precision_recall_curve as sklearn_precision_recall_curve
|
|
18
|
+
from sklearn.metrics import roc_auc_score
|
|
19
|
+
from sklearn.metrics import roc_curve as sklearn_roc_curve
|
|
20
|
+
from transformers.trainer_utils import EvalPrediction
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ClassificationMetrics(TypedDict):
|
|
24
|
+
accuracy: float
|
|
25
|
+
f1_score: float
|
|
26
|
+
roc_auc: float | None # receiver operating characteristic area under the curve (if all classes are present)
|
|
27
|
+
pr_auc: float | None # precision-recall area under the curve (only for binary classification)
|
|
28
|
+
log_loss: float # cross-entropy loss for probabilities
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def compute_classifier_metrics(eval_pred: EvalPrediction) -> ClassificationMetrics:
|
|
32
|
+
"""
|
|
33
|
+
Compute standard metrics for classifier with Hugging Face Trainer.
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
eval_pred: The predictions containing logits and expected labels as given by the Trainer.
|
|
37
|
+
|
|
38
|
+
Returns:
|
|
39
|
+
A dictionary containing the accuracy, f1 score, and ROC AUC score.
|
|
40
|
+
"""
|
|
41
|
+
logits, references = eval_pred
|
|
42
|
+
if isinstance(logits, tuple):
|
|
43
|
+
logits = logits[0]
|
|
44
|
+
if not isinstance(logits, np.ndarray):
|
|
45
|
+
raise ValueError("Logits must be a numpy array")
|
|
46
|
+
if not isinstance(references, np.ndarray):
|
|
47
|
+
raise ValueError(
|
|
48
|
+
"Multiple label columns found, use the `label_names` training argument to specify which one to use"
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
if not (logits > 0).all():
|
|
52
|
+
# convert logits to probabilities with softmax if necessary
|
|
53
|
+
probabilities = softmax(logits)
|
|
54
|
+
elif not np.allclose(logits.sum(-1, keepdims=True), 1.0):
|
|
55
|
+
# convert logits to probabilities through normalization if necessary
|
|
56
|
+
probabilities = logits / logits.sum(-1, keepdims=True)
|
|
57
|
+
else:
|
|
58
|
+
probabilities = logits
|
|
59
|
+
|
|
60
|
+
return classification_scores(references, probabilities)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def classification_scores(
|
|
64
|
+
references: NDArray[np.int64],
|
|
65
|
+
probabilities: NDArray[np.float32],
|
|
66
|
+
average: Literal["micro", "macro", "weighted", "binary"] | None = None,
|
|
67
|
+
multi_class: Literal["ovr", "ovo"] = "ovr",
|
|
68
|
+
) -> ClassificationMetrics:
|
|
69
|
+
if probabilities.ndim == 1:
|
|
70
|
+
# convert 1D probabilities (binary) to 2D logits
|
|
71
|
+
probabilities = np.column_stack([1 - probabilities, probabilities])
|
|
72
|
+
elif probabilities.ndim == 2:
|
|
73
|
+
if probabilities.shape[1] < 2:
|
|
74
|
+
raise ValueError("Use a different metric function for regression tasks")
|
|
75
|
+
else:
|
|
76
|
+
raise ValueError("Probabilities must be 1 or 2 dimensional")
|
|
77
|
+
|
|
78
|
+
predictions = np.argmax(probabilities, axis=-1)
|
|
79
|
+
|
|
80
|
+
num_classes_references = len(set(references))
|
|
81
|
+
num_classes_predictions = len(set(predictions))
|
|
82
|
+
|
|
83
|
+
if average is None:
|
|
84
|
+
average = "binary" if num_classes_references == 2 else "weighted"
|
|
85
|
+
|
|
86
|
+
accuracy = accuracy_score(references, predictions)
|
|
87
|
+
f1 = f1_score(references, predictions, average=average)
|
|
88
|
+
loss = log_loss(references, probabilities)
|
|
89
|
+
|
|
90
|
+
if num_classes_references == num_classes_predictions:
|
|
91
|
+
# special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
|
|
92
|
+
if num_classes_references == 2:
|
|
93
|
+
roc_auc = roc_auc_score(references, probabilities[:, 1])
|
|
94
|
+
precisions, recalls, _ = calculate_pr_curve(references, probabilities[:, 1])
|
|
95
|
+
pr_auc = auc(recalls, precisions)
|
|
96
|
+
else:
|
|
97
|
+
roc_auc = roc_auc_score(references, probabilities, multi_class=multi_class)
|
|
98
|
+
pr_auc = None
|
|
99
|
+
else:
|
|
100
|
+
roc_auc = None
|
|
101
|
+
pr_auc = None
|
|
102
|
+
|
|
103
|
+
return {
|
|
104
|
+
"accuracy": float(accuracy),
|
|
105
|
+
"f1_score": float(f1),
|
|
106
|
+
"roc_auc": float(roc_auc) if roc_auc is not None else None,
|
|
107
|
+
"pr_auc": float(pr_auc) if pr_auc is not None else None,
|
|
108
|
+
"log_loss": float(loss),
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def calculate_pr_curve(
|
|
113
|
+
references: NDArray[np.int64],
|
|
114
|
+
probabilities: NDArray[np.float32],
|
|
115
|
+
max_length: int = 100,
|
|
116
|
+
) -> Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
|
|
117
|
+
if probabilities.ndim == 1:
|
|
118
|
+
probabilities_slice = probabilities
|
|
119
|
+
elif probabilities.ndim == 2:
|
|
120
|
+
probabilities_slice = probabilities[:, 1]
|
|
121
|
+
else:
|
|
122
|
+
raise ValueError("Probabilities must be 1 or 2 dimensional")
|
|
123
|
+
|
|
124
|
+
if len(probabilities_slice) != len(references):
|
|
125
|
+
raise ValueError("Probabilities and references must have the same length")
|
|
126
|
+
|
|
127
|
+
precisions, recalls, thresholds = sklearn_precision_recall_curve(references, probabilities_slice)
|
|
128
|
+
|
|
129
|
+
# Convert all arrays to float32 immediately after getting them
|
|
130
|
+
precisions = precisions.astype(np.float32)
|
|
131
|
+
recalls = recalls.astype(np.float32)
|
|
132
|
+
thresholds = thresholds.astype(np.float32)
|
|
133
|
+
|
|
134
|
+
# Concatenate with 0 to include the lowest threshold
|
|
135
|
+
thresholds = np.concatenate(([0], thresholds))
|
|
136
|
+
|
|
137
|
+
# Sort by threshold
|
|
138
|
+
sorted_indices = np.argsort(thresholds)
|
|
139
|
+
thresholds = thresholds[sorted_indices]
|
|
140
|
+
precisions = precisions[sorted_indices]
|
|
141
|
+
recalls = recalls[sorted_indices]
|
|
142
|
+
|
|
143
|
+
if len(precisions) > max_length:
|
|
144
|
+
new_thresholds = np.linspace(0, 1, max_length, dtype=np.float32)
|
|
145
|
+
new_precisions = np.interp(new_thresholds, thresholds, precisions)
|
|
146
|
+
new_recalls = np.interp(new_thresholds, thresholds, recalls)
|
|
147
|
+
thresholds = new_thresholds
|
|
148
|
+
precisions = new_precisions
|
|
149
|
+
recalls = new_recalls
|
|
150
|
+
|
|
151
|
+
return precisions.astype(np.float32), recalls.astype(np.float32), thresholds.astype(np.float32)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def calculate_roc_curve(
|
|
155
|
+
references: NDArray[np.int64],
|
|
156
|
+
probabilities: NDArray[np.float32],
|
|
157
|
+
max_length: int = 100,
|
|
158
|
+
) -> Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
|
|
159
|
+
if probabilities.ndim == 1:
|
|
160
|
+
probabilities_slice = probabilities
|
|
161
|
+
elif probabilities.ndim == 2:
|
|
162
|
+
probabilities_slice = probabilities[:, 1]
|
|
163
|
+
else:
|
|
164
|
+
raise ValueError("Probabilities must be 1 or 2 dimensional")
|
|
165
|
+
|
|
166
|
+
if len(probabilities_slice) != len(references):
|
|
167
|
+
raise ValueError("Probabilities and references must have the same length")
|
|
168
|
+
|
|
169
|
+
# Convert probabilities to float32 before calling sklearn_roc_curve
|
|
170
|
+
probabilities_slice = probabilities_slice.astype(np.float32)
|
|
171
|
+
fpr, tpr, thresholds = sklearn_roc_curve(references, probabilities_slice)
|
|
172
|
+
|
|
173
|
+
# Convert all arrays to float32 immediately after getting them
|
|
174
|
+
fpr = fpr.astype(np.float32)
|
|
175
|
+
tpr = tpr.astype(np.float32)
|
|
176
|
+
thresholds = thresholds.astype(np.float32)
|
|
177
|
+
|
|
178
|
+
# We set the first threshold to 1.0 instead of inf for reasonable values in interpolation
|
|
179
|
+
thresholds[0] = 1.0
|
|
180
|
+
|
|
181
|
+
# Sort by threshold
|
|
182
|
+
sorted_indices = np.argsort(thresholds)
|
|
183
|
+
thresholds = thresholds[sorted_indices]
|
|
184
|
+
fpr = fpr[sorted_indices]
|
|
185
|
+
tpr = tpr[sorted_indices]
|
|
186
|
+
|
|
187
|
+
if len(fpr) > max_length:
|
|
188
|
+
new_thresholds = np.linspace(0, 1, max_length, dtype=np.float32)
|
|
189
|
+
new_fpr = np.interp(new_thresholds, thresholds, fpr)
|
|
190
|
+
new_tpr = np.interp(new_thresholds, thresholds, tpr)
|
|
191
|
+
thresholds = new_thresholds
|
|
192
|
+
fpr = new_fpr
|
|
193
|
+
tpr = new_tpr
|
|
194
|
+
|
|
195
|
+
return fpr.astype(np.float32), tpr.astype(np.float32), thresholds.astype(np.float32)
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
"""
|
|
2
|
+
IMPORTANT:
|
|
3
|
+
- This is a shared file between OrcaLib and the Orca SDK.
|
|
4
|
+
- Please ensure that it does not have any dependencies on the OrcaLib code.
|
|
5
|
+
- Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Literal
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pytest
|
|
12
|
+
|
|
13
|
+
from .metrics import (
|
|
14
|
+
EvalPrediction,
|
|
15
|
+
calculate_pr_curve,
|
|
16
|
+
calculate_roc_curve,
|
|
17
|
+
classification_scores,
|
|
18
|
+
compute_classifier_metrics,
|
|
19
|
+
softmax,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def test_binary_metrics():
|
|
24
|
+
y_true = np.array([0, 1, 1, 0, 1])
|
|
25
|
+
y_score = np.array([0.1, 0.9, 0.8, 0.3, 0.2])
|
|
26
|
+
|
|
27
|
+
metrics = classification_scores(y_true, y_score)
|
|
28
|
+
|
|
29
|
+
assert metrics["accuracy"] == 0.8
|
|
30
|
+
assert metrics["f1_score"] == 0.8
|
|
31
|
+
assert metrics["roc_auc"] is not None
|
|
32
|
+
assert metrics["roc_auc"] > 0.8
|
|
33
|
+
assert metrics["roc_auc"] < 1.0
|
|
34
|
+
assert metrics["pr_auc"] is not None
|
|
35
|
+
assert metrics["pr_auc"] > 0.8
|
|
36
|
+
assert metrics["pr_auc"] < 1.0
|
|
37
|
+
assert metrics["log_loss"] is not None
|
|
38
|
+
assert metrics["log_loss"] > 0.0
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def test_multiclass_metrics_with_2_classes():
|
|
42
|
+
y_true = np.array([0, 1, 1, 0, 1])
|
|
43
|
+
y_score = np.array([[0.9, 0.1], [0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
|
|
44
|
+
|
|
45
|
+
metrics = classification_scores(y_true, y_score)
|
|
46
|
+
|
|
47
|
+
assert metrics["accuracy"] == 0.8
|
|
48
|
+
assert metrics["f1_score"] == 0.8
|
|
49
|
+
assert metrics["roc_auc"] is not None
|
|
50
|
+
assert metrics["roc_auc"] > 0.8
|
|
51
|
+
assert metrics["roc_auc"] < 1.0
|
|
52
|
+
assert metrics["pr_auc"] is not None
|
|
53
|
+
assert metrics["pr_auc"] > 0.8
|
|
54
|
+
assert metrics["pr_auc"] < 1.0
|
|
55
|
+
assert metrics["log_loss"] is not None
|
|
56
|
+
assert metrics["log_loss"] > 0.0
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@pytest.mark.parametrize(
|
|
60
|
+
"average, multiclass",
|
|
61
|
+
[("micro", "ovr"), ("macro", "ovr"), ("weighted", "ovr"), ("micro", "ovo"), ("macro", "ovo"), ("weighted", "ovo")],
|
|
62
|
+
)
|
|
63
|
+
def test_multiclass_metrics_with_3_classes(
|
|
64
|
+
average: Literal["micro", "macro", "weighted"], multiclass: Literal["ovr", "ovo"]
|
|
65
|
+
):
|
|
66
|
+
y_true = np.array([0, 1, 1, 0, 2])
|
|
67
|
+
y_score = np.array([[0.9, 0.1, 0.0], [0.1, 0.9, 0.0], [0.2, 0.8, 0.0], [0.7, 0.3, 0.0], [0.0, 0.0, 1.0]])
|
|
68
|
+
|
|
69
|
+
metrics = classification_scores(y_true, y_score, average=average, multi_class=multiclass)
|
|
70
|
+
|
|
71
|
+
assert metrics["accuracy"] == 1.0
|
|
72
|
+
assert metrics["f1_score"] == 1.0
|
|
73
|
+
assert metrics["roc_auc"] is not None
|
|
74
|
+
assert metrics["roc_auc"] > 0.8
|
|
75
|
+
assert metrics["pr_auc"] is None
|
|
76
|
+
assert metrics["log_loss"] is not None
|
|
77
|
+
assert metrics["log_loss"] > 0.0
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def test_does_not_modify_logits_unless_necessary():
|
|
81
|
+
logits = np.array([[0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
|
|
82
|
+
references = np.array([0, 1, 0, 1])
|
|
83
|
+
metrics = compute_classifier_metrics(EvalPrediction(logits, references))
|
|
84
|
+
assert metrics["log_loss"] == classification_scores(references, logits)["log_loss"]
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def test_normalizes_logits_if_necessary():
|
|
88
|
+
logits = np.array([[1.2, 3.9], [1.2, 5.8], [1.2, 2.7], [1.2, 1.3]])
|
|
89
|
+
references = np.array([0, 1, 0, 1])
|
|
90
|
+
metrics = compute_classifier_metrics(EvalPrediction(logits, references))
|
|
91
|
+
assert (
|
|
92
|
+
metrics["log_loss"] == classification_scores(references, logits / logits.sum(axis=1, keepdims=True))["log_loss"]
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def test_softmaxes_logits_if_necessary():
|
|
97
|
+
logits = np.array([[-1.2, 3.9], [1.2, -5.8], [1.2, 2.7], [1.2, 1.3]])
|
|
98
|
+
references = np.array([0, 1, 0, 1])
|
|
99
|
+
metrics = compute_classifier_metrics(EvalPrediction(logits, references))
|
|
100
|
+
assert metrics["log_loss"] == classification_scores(references, softmax(logits))["log_loss"]
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def test_precision_recall_curve():
|
|
104
|
+
y_true = np.array([0, 1, 1, 0, 1])
|
|
105
|
+
y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
|
|
106
|
+
|
|
107
|
+
precision, recall, thresholds = calculate_pr_curve(y_true, y_score)
|
|
108
|
+
assert precision is not None
|
|
109
|
+
assert recall is not None
|
|
110
|
+
assert thresholds is not None
|
|
111
|
+
|
|
112
|
+
assert len(precision) == len(recall) == len(thresholds) == 6
|
|
113
|
+
assert precision[0] == 0.6
|
|
114
|
+
assert recall[0] == 1.0
|
|
115
|
+
assert precision[-1] == 1.0
|
|
116
|
+
assert recall[-1] == 0.0
|
|
117
|
+
|
|
118
|
+
# test that thresholds are sorted
|
|
119
|
+
assert np.all(np.diff(thresholds) >= 0)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def test_roc_curve():
|
|
123
|
+
y_true = np.array([0, 1, 1, 0, 1])
|
|
124
|
+
y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
|
|
125
|
+
|
|
126
|
+
fpr, tpr, thresholds = calculate_roc_curve(y_true, y_score)
|
|
127
|
+
assert fpr is not None
|
|
128
|
+
assert tpr is not None
|
|
129
|
+
assert thresholds is not None
|
|
130
|
+
|
|
131
|
+
assert len(fpr) == len(tpr) == len(thresholds) == 6
|
|
132
|
+
assert fpr[0] == 1.0
|
|
133
|
+
assert tpr[0] == 1.0
|
|
134
|
+
assert fpr[-1] == 0.0
|
|
135
|
+
assert tpr[-1] == 0.0
|
|
136
|
+
|
|
137
|
+
# test that thresholds are sorted
|
|
138
|
+
assert np.all(np.diff(thresholds) >= 0)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def test_precision_recall_curve_max_length():
|
|
142
|
+
y_true = np.array([0, 1, 1, 0, 1])
|
|
143
|
+
y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
|
|
144
|
+
|
|
145
|
+
precision, recall, thresholds = calculate_pr_curve(y_true, y_score, max_length=5)
|
|
146
|
+
assert len(precision) == len(recall) == len(thresholds) == 5
|
|
147
|
+
|
|
148
|
+
assert precision[0] == 0.6
|
|
149
|
+
assert recall[0] == 1.0
|
|
150
|
+
assert precision[-1] == 1.0
|
|
151
|
+
assert recall[-1] == 0.0
|
|
152
|
+
|
|
153
|
+
# test that thresholds are sorted
|
|
154
|
+
assert np.all(np.diff(thresholds) >= 0)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def test_roc_curve_max_length():
|
|
158
|
+
y_true = np.array([0, 1, 1, 0, 1])
|
|
159
|
+
y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
|
|
160
|
+
|
|
161
|
+
fpr, tpr, thresholds = calculate_roc_curve(y_true, y_score, max_length=5)
|
|
162
|
+
assert len(fpr) == len(tpr) == len(thresholds) == 5
|
|
163
|
+
assert fpr[0] == 1.0
|
|
164
|
+
assert tpr[0] == 1.0
|
|
165
|
+
assert fpr[-1] == 0.0
|
|
166
|
+
assert tpr[-1] == 0.0
|
|
167
|
+
|
|
168
|
+
# test that thresholds are sorted
|
|
169
|
+
assert np.all(np.diff(thresholds) >= 0)
|