orca-sdk 0.1.0__tar.gz → 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk-0.1.2/PKG-INFO +97 -0
- orca_sdk-0.1.2/README.md +72 -0
- orca_sdk-0.1.2/orca_sdk/__init__.py +30 -0
- orca_sdk-0.1.2/orca_sdk/_shared/__init__.py +10 -0
- orca_sdk-0.1.2/orca_sdk/_shared/metrics.py +393 -0
- orca_sdk-0.1.2/orca_sdk/_shared/metrics_test.py +273 -0
- {orca_sdk-0.1.0 → orca_sdk-0.1.2}/orca_sdk/_utils/analysis_ui.py +13 -11
- {orca_sdk-0.1.0 → orca_sdk-0.1.2}/orca_sdk/_utils/analysis_ui_style.css +0 -3
- orca_sdk-0.1.2/orca_sdk/_utils/auth.py +61 -0
- {orca_sdk-0.1.0 → orca_sdk-0.1.2}/orca_sdk/_utils/auth_test.py +1 -1
- {orca_sdk-0.1.0 → orca_sdk-0.1.2}/orca_sdk/_utils/data_parsing.py +28 -2
- {orca_sdk-0.1.0 → orca_sdk-0.1.2}/orca_sdk/_utils/data_parsing_test.py +15 -15
- orca_sdk-0.1.2/orca_sdk/_utils/pagination.py +126 -0
- orca_sdk-0.1.2/orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk-0.1.2/orca_sdk/_utils/prediction_result_ui.py +110 -0
- orca_sdk-0.1.2/orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk-0.1.2/orca_sdk/_utils/value_parser.py +45 -0
- orca_sdk-0.1.2/orca_sdk/_utils/value_parser_test.py +39 -0
- orca_sdk-0.1.2/orca_sdk/classification_model.py +809 -0
- orca_sdk-0.1.2/orca_sdk/classification_model_test.py +496 -0
- orca_sdk-0.1.2/orca_sdk/client.py +3747 -0
- orca_sdk-0.1.2/orca_sdk/conftest.py +262 -0
- orca_sdk-0.1.2/orca_sdk/credentials.py +177 -0
- orca_sdk-0.1.0/orca_sdk/orca_credentials_test.py → orca_sdk-0.1.2/orca_sdk/credentials_test.py +21 -1
- orca_sdk-0.1.2/orca_sdk/datasource.py +524 -0
- orca_sdk-0.1.2/orca_sdk/datasource_test.py +337 -0
- orca_sdk-0.1.2/orca_sdk/embedding_model.py +690 -0
- {orca_sdk-0.1.0 → orca_sdk-0.1.2}/orca_sdk/embedding_model_test.py +40 -14
- orca_sdk-0.1.2/orca_sdk/job.py +337 -0
- orca_sdk-0.1.2/orca_sdk/job_test.py +108 -0
- orca_sdk-0.1.2/orca_sdk/memoryset.py +2190 -0
- orca_sdk-0.1.2/orca_sdk/memoryset_test.py +510 -0
- orca_sdk-0.1.2/orca_sdk/regression_model.py +684 -0
- orca_sdk-0.1.2/orca_sdk/regression_model_test.py +369 -0
- orca_sdk-0.1.2/orca_sdk/telemetry.py +692 -0
- orca_sdk-0.1.2/orca_sdk/telemetry_test.py +119 -0
- orca_sdk-0.1.2/pyproject.toml +84 -0
- orca_sdk-0.1.0/PKG-INFO +0 -39
- orca_sdk-0.1.0/README.md +0 -15
- orca_sdk-0.1.0/orca_sdk/__init__.py +0 -19
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/__init__.py +0 -3
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/__init__.py +0 -193
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/client.py +0 -216
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/errors.py +0 -38
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/__init__.py +0 -159
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/base_model.py +0 -55
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/column_info.py +0 -114
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/column_type.py +0 -14
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/embed_request.py +0 -127
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/filter_item.py +0 -231
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/task.py +0 -198
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/task_status.py +0 -14
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/py.typed +0 -1
- orca_sdk-0.1.0/orca_sdk/_generated_api_client/types.py +0 -56
- orca_sdk-0.1.0/orca_sdk/_utils/__init__.py +0 -0
- orca_sdk-0.1.0/orca_sdk/_utils/auth.py +0 -63
- orca_sdk-0.1.0/orca_sdk/_utils/prediction_result_ui.py +0 -64
- orca_sdk-0.1.0/orca_sdk/_utils/task.py +0 -73
- orca_sdk-0.1.0/orca_sdk/classification_model.py +0 -499
- orca_sdk-0.1.0/orca_sdk/classification_model_test.py +0 -266
- orca_sdk-0.1.0/orca_sdk/conftest.py +0 -117
- orca_sdk-0.1.0/orca_sdk/datasource.py +0 -333
- orca_sdk-0.1.0/orca_sdk/datasource_test.py +0 -95
- orca_sdk-0.1.0/orca_sdk/embedding_model.py +0 -336
- orca_sdk-0.1.0/orca_sdk/labeled_memoryset.py +0 -1154
- orca_sdk-0.1.0/orca_sdk/labeled_memoryset_test.py +0 -271
- orca_sdk-0.1.0/orca_sdk/orca_credentials.py +0 -75
- orca_sdk-0.1.0/orca_sdk/telemetry.py +0 -386
- orca_sdk-0.1.0/orca_sdk/telemetry_test.py +0 -100
- orca_sdk-0.1.0/pyproject.toml +0 -71
- {orca_sdk-0.1.0/orca_sdk/_generated_api_client/api/auth → orca_sdk-0.1.2/orca_sdk/_utils}/__init__.py +0 -0
- {orca_sdk-0.1.0 → orca_sdk-0.1.2}/orca_sdk/_utils/common.py +0 -0
- {orca_sdk-0.1.0 → orca_sdk-0.1.2}/orca_sdk/_utils/prediction_result_ui.css +0 -0
orca_sdk-0.1.2/PKG-INFO
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: orca_sdk
|
|
3
|
+
Version: 0.1.2
|
|
4
|
+
Summary: SDK for interacting with Orca Services
|
|
5
|
+
License-Expression: Apache-2.0
|
|
6
|
+
Author: Orca DB Inc.
|
|
7
|
+
Author-email: dev-rel@orcadb.ai
|
|
8
|
+
Requires-Python: >=3.11,<3.14
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
13
|
+
Requires-Dist: datasets (>=3.1.0,<4)
|
|
14
|
+
Requires-Dist: gradio (>=5.44.1,<6)
|
|
15
|
+
Requires-Dist: httpx (>=0.28.1)
|
|
16
|
+
Requires-Dist: httpx-retries (>=0.4.3,<0.5.0)
|
|
17
|
+
Requires-Dist: numpy (>=2.1.0,<3)
|
|
18
|
+
Requires-Dist: pandas (>=2.2.3,<3)
|
|
19
|
+
Requires-Dist: pyarrow (>=18.0.0,<19)
|
|
20
|
+
Requires-Dist: python-dotenv (>=1.1.0)
|
|
21
|
+
Requires-Dist: scikit-learn (>=1.6.1,<2)
|
|
22
|
+
Requires-Dist: torch (>=2.8.0,<3)
|
|
23
|
+
Description-Content-Type: text/markdown
|
|
24
|
+
|
|
25
|
+
<!--
|
|
26
|
+
IMPORTANT NOTE:
|
|
27
|
+
- This file is rendered on the public-facing PyPI page: https://pypi.org/project/orca_sdk/
|
|
28
|
+
- Only content suitable for public consumption should be placed in this file; everything else should go into CONTRIBUTING.md
|
|
29
|
+
-->
|
|
30
|
+
|
|
31
|
+
# OrcaSDK
|
|
32
|
+
|
|
33
|
+
OrcaSDK is a Python library for building and using retrieval-augmented models with [OrcaCloud](https://orcadb.ai). It enables you to create, deploy, and maintain models that can adapt to changing circumstances without retraining by accessing external data called "memories."
|
|
34
|
+
|
|
35
|
+
## Documentation
|
|
36
|
+
|
|
37
|
+
You can find the documentation for all things Orca at [docs.orcadb.ai](https://docs.orcadb.ai). This includes tutorials, how-to guides, and the full interface reference for OrcaSDK.
|
|
38
|
+
|
|
39
|
+
## Features
|
|
40
|
+
|
|
41
|
+
- **Labeled Memorysets**: Store and manage labeled examples that your models can use to guide predictions
|
|
42
|
+
- **Classification Models**: Build retrieval-augmented classification models that adapt to new data without retraining
|
|
43
|
+
- **Embedding Models**: Use pre-trained or fine-tuned embedding models to represent your data
|
|
44
|
+
- **Telemetry**: Collect feedback and monitor memory usage to optimize model performance
|
|
45
|
+
- **Datasources**: Easily ingest data from various sources into your memorysets
|
|
46
|
+
|
|
47
|
+
## Installation
|
|
48
|
+
|
|
49
|
+
OrcaSDK is compatible with Python 3.11 through 3.13 and is available on [PyPI](https://pypi.org/project/orca_sdk/). You can install it with your favorite Python package manager:
|
|
50
|
+
|
|
51
|
+
- Pip: `pip install orca_sdk`
|
|
52
|
+
- Conda: `conda install orca_sdk`
|
|
53
|
+
- Poetry: `poetry add orca_sdk`
|
|
54
|
+
|
|
55
|
+
## Quick Start
|
|
56
|
+
|
|
57
|
+
```python
|
|
58
|
+
from dotenv import load_dotenv
|
|
59
|
+
from orca_sdk import OrcaCredentials, LabeledMemoryset, ClassificationModel
|
|
60
|
+
|
|
61
|
+
# Load your API key from environment variables
|
|
62
|
+
load_dotenv()
|
|
63
|
+
assert OrcaCredentials.is_authenticated()
|
|
64
|
+
|
|
65
|
+
# Create a labeled memoryset
|
|
66
|
+
memoryset = LabeledMemoryset.from_disk("my_memoryset", "./data.jsonl")
|
|
67
|
+
|
|
68
|
+
# Create a classification model using the memoryset
|
|
69
|
+
model = ClassificationModel("my_model", memoryset)
|
|
70
|
+
|
|
71
|
+
# Make predictions
|
|
72
|
+
prediction = model.predict("my input")
|
|
73
|
+
|
|
74
|
+
# Get Action Recommendation
|
|
75
|
+
action, rationale = prediction.recommend_action()
|
|
76
|
+
print(f"Recommended action: {action}")
|
|
77
|
+
print(f"Rationale: {rationale}")
|
|
78
|
+
|
|
79
|
+
# Generate and add synthetic memory suggestions
|
|
80
|
+
if action == "add_memories":
|
|
81
|
+
suggestions = prediction.generate_memory_suggestions(num_memories=3)
|
|
82
|
+
|
|
83
|
+
# Review suggestions
|
|
84
|
+
for suggestion in suggestions:
|
|
85
|
+
print(f"Suggested: '{suggestion['value']}' -> {suggestion['label']}")
|
|
86
|
+
|
|
87
|
+
# Add suggestions to memoryset
|
|
88
|
+
model.memoryset.insert(suggestions)
|
|
89
|
+
print(f"Added {len(suggestions)} new memories to improve model performance!")
|
|
90
|
+
```
|
|
91
|
+
|
|
92
|
+
For a more detailed walkthrough, check out our [Quick Start Guide](https://docs.orcadb.ai/quickstart-sdk/).
|
|
93
|
+
|
|
94
|
+
## Support
|
|
95
|
+
|
|
96
|
+
If you have any questions, please reach out to us at support@orcadb.ai.
|
|
97
|
+
|
orca_sdk-0.1.2/README.md
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
<!--
|
|
2
|
+
IMPORTANT NOTE:
|
|
3
|
+
- This file is rendered on the public-facing PyPI page: https://pypi.org/project/orca_sdk/
|
|
4
|
+
- Only content suitable for public consumption should be placed in this file; everything else should go into CONTRIBUTING.md
|
|
5
|
+
-->
|
|
6
|
+
|
|
7
|
+
# OrcaSDK
|
|
8
|
+
|
|
9
|
+
OrcaSDK is a Python library for building and using retrieval-augmented models with [OrcaCloud](https://orcadb.ai). It enables you to create, deploy, and maintain models that can adapt to changing circumstances without retraining by accessing external data called "memories."
|
|
10
|
+
|
|
11
|
+
## Documentation
|
|
12
|
+
|
|
13
|
+
You can find the documentation for all things Orca at [docs.orcadb.ai](https://docs.orcadb.ai). This includes tutorials, how-to guides, and the full interface reference for OrcaSDK.
|
|
14
|
+
|
|
15
|
+
## Features
|
|
16
|
+
|
|
17
|
+
- **Labeled Memorysets**: Store and manage labeled examples that your models can use to guide predictions
|
|
18
|
+
- **Classification Models**: Build retrieval-augmented classification models that adapt to new data without retraining
|
|
19
|
+
- **Embedding Models**: Use pre-trained or fine-tuned embedding models to represent your data
|
|
20
|
+
- **Telemetry**: Collect feedback and monitor memory usage to optimize model performance
|
|
21
|
+
- **Datasources**: Easily ingest data from various sources into your memorysets
|
|
22
|
+
|
|
23
|
+
## Installation
|
|
24
|
+
|
|
25
|
+
OrcaSDK is compatible with Python 3.11 through 3.13 and is available on [PyPI](https://pypi.org/project/orca_sdk/). You can install it with your favorite Python package manager:
|
|
26
|
+
|
|
27
|
+
- Pip: `pip install orca_sdk`
|
|
28
|
+
- Conda: `conda install orca_sdk`
|
|
29
|
+
- Poetry: `poetry add orca_sdk`
|
|
30
|
+
|
|
31
|
+
## Quick Start
|
|
32
|
+
|
|
33
|
+
```python
|
|
34
|
+
from dotenv import load_dotenv
|
|
35
|
+
from orca_sdk import OrcaCredentials, LabeledMemoryset, ClassificationModel
|
|
36
|
+
|
|
37
|
+
# Load your API key from environment variables
|
|
38
|
+
load_dotenv()
|
|
39
|
+
assert OrcaCredentials.is_authenticated()
|
|
40
|
+
|
|
41
|
+
# Create a labeled memoryset
|
|
42
|
+
memoryset = LabeledMemoryset.from_disk("my_memoryset", "./data.jsonl")
|
|
43
|
+
|
|
44
|
+
# Create a classification model using the memoryset
|
|
45
|
+
model = ClassificationModel("my_model", memoryset)
|
|
46
|
+
|
|
47
|
+
# Make predictions
|
|
48
|
+
prediction = model.predict("my input")
|
|
49
|
+
|
|
50
|
+
# Get Action Recommendation
|
|
51
|
+
action, rationale = prediction.recommend_action()
|
|
52
|
+
print(f"Recommended action: {action}")
|
|
53
|
+
print(f"Rationale: {rationale}")
|
|
54
|
+
|
|
55
|
+
# Generate and add synthetic memory suggestions
|
|
56
|
+
if action == "add_memories":
|
|
57
|
+
suggestions = prediction.generate_memory_suggestions(num_memories=3)
|
|
58
|
+
|
|
59
|
+
# Review suggestions
|
|
60
|
+
for suggestion in suggestions:
|
|
61
|
+
print(f"Suggested: '{suggestion['value']}' -> {suggestion['label']}")
|
|
62
|
+
|
|
63
|
+
# Add suggestions to memoryset
|
|
64
|
+
model.memoryset.insert(suggestions)
|
|
65
|
+
print(f"Added {len(suggestions)} new memories to improve model performance!")
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
For a more detailed walkthrough, check out our [Quick Start Guide](https://docs.orcadb.ai/quickstart-sdk/).
|
|
69
|
+
|
|
70
|
+
## Support
|
|
71
|
+
|
|
72
|
+
If you have any questions, please reach out to us at support@orcadb.ai.
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""
|
|
2
|
+
OrcaSDK is a Python library for building and using retrieval augmented models in the OrcaCloud.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from ._utils.common import UNSET, CreateMode, DropMode
|
|
6
|
+
from .classification_model import ClassificationMetrics, ClassificationModel
|
|
7
|
+
from .client import orca_api
|
|
8
|
+
from .credentials import OrcaCredentials
|
|
9
|
+
from .datasource import Datasource
|
|
10
|
+
from .embedding_model import (
|
|
11
|
+
FinetunedEmbeddingModel,
|
|
12
|
+
PretrainedEmbeddingModel,
|
|
13
|
+
PretrainedEmbeddingModelName,
|
|
14
|
+
)
|
|
15
|
+
from .job import Job, Status
|
|
16
|
+
from .memoryset import (
|
|
17
|
+
CascadingEditSuggestion,
|
|
18
|
+
FilterItemTuple,
|
|
19
|
+
LabeledMemory,
|
|
20
|
+
LabeledMemoryLookup,
|
|
21
|
+
LabeledMemoryset,
|
|
22
|
+
ScoredMemory,
|
|
23
|
+
ScoredMemoryLookup,
|
|
24
|
+
ScoredMemoryset,
|
|
25
|
+
)
|
|
26
|
+
from .regression_model import RegressionModel
|
|
27
|
+
from .telemetry import ClassificationPrediction, FeedbackCategory, RegressionPrediction
|
|
28
|
+
|
|
29
|
+
# `__all__` deliberately lists only the names imported from the private module `._utils.common`,
# so that they show up on the root page of the reference docs; the other names imported above
# come from public submodules.
# NOTE(review): this also limits `from orca_sdk import *` to exactly these three names — confirm intended.
__all__ = ["UNSET", "CreateMode", "DropMode"]
|
|
@@ -0,0 +1,393 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains metrics for usage with the Hugging Face Trainer.
|
|
3
|
+
|
|
4
|
+
IMPORTANT:
|
|
5
|
+
- This is a shared file between OrcaLib and the OrcaSDK.
|
|
6
|
+
- Please ensure that it does not have any dependencies on the OrcaLib code.
|
|
7
|
+
- Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
|
|
8
|
+
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import Any, Literal, TypedDict, cast
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import sklearn.metrics
|
|
16
|
+
from numpy.typing import NDArray
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# we don't want to depend on scipy or torch in orca_sdk
|
|
20
|
+
def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Return the softmax of `logits` along `axis`.

    The per-slice maximum is subtracted before exponentiating so large logits do not
    overflow; this shift does not change the result.

    Args:
        logits: Array of unnormalized scores.
        axis: Axis along which the probabilities are normalized to sum to 1.

    Returns:
        An array with the same shape as `logits` whose slices along `axis` sum to 1.
    """
    slice_max = np.max(logits, axis=axis, keepdims=True)
    weights = np.exp(logits - slice_max)
    normalizer = np.sum(weights, axis=axis, keepdims=True)
    return weights / normalizer
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# We don't want to depend on transformers just for the eval_pred type in orca_sdk
|
|
27
|
+
def transform_eval_pred(eval_pred: Any) -> tuple[NDArray, NDArray[np.float32]]:
    """Unpack a Trainer `compute_metrics` argument into `(references, logits)`.

    The argument is unpacked as a `(logits, references)` pair (e.g. a
    `transformers.trainer_utils.EvalPrediction`). When the logits arrive as a tuple, only
    its first element is kept.

    Args:
        eval_pred: A two-item iterable of `(logits, references)`.

    Returns:
        The `(references, logits)` pair, in that order, for use in
        `calculate_classification_metrics`.

    Raises:
        ValueError: If the logits are not a numpy array, or if the references are not a
            numpy array (which happens when several label columns are present).
    """
    raw_logits, labels = eval_pred  # transformers.trainer_utils.EvalPrediction
    selected_logits = raw_logits[0] if isinstance(raw_logits, tuple) else raw_logits

    if not isinstance(selected_logits, np.ndarray):
        raise ValueError("Logits must be a numpy array")
    if not isinstance(labels, np.ndarray):
        raise ValueError(
            "Multiple label columns found, use the `label_names` training argument to specify which one to use"
        )

    return labels, selected_logits
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class PRCurve(TypedDict):
    """Precision-recall curve stored as parallel lists, as returned by `calculate_pr_curve`."""

    # decision thresholds, in ascending order
    thresholds: list[float]
    # precision value paired with the threshold at the same index
    precisions: list[float]
    # recall value paired with the threshold at the same index
    recalls: list[float]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def calculate_pr_curve(
|
|
49
|
+
references: NDArray[np.int64],
|
|
50
|
+
probabilities: NDArray[np.float32],
|
|
51
|
+
max_length: int = 100,
|
|
52
|
+
) -> PRCurve:
|
|
53
|
+
if probabilities.ndim == 1:
|
|
54
|
+
probabilities_slice = probabilities
|
|
55
|
+
elif probabilities.ndim == 2:
|
|
56
|
+
probabilities_slice = probabilities[:, 1]
|
|
57
|
+
else:
|
|
58
|
+
raise ValueError("Probabilities must be 1 or 2 dimensional")
|
|
59
|
+
|
|
60
|
+
if len(probabilities_slice) != len(references):
|
|
61
|
+
raise ValueError("Probabilities and references must have the same length")
|
|
62
|
+
|
|
63
|
+
precisions, recalls, thresholds = sklearn.metrics.precision_recall_curve(references, probabilities_slice)
|
|
64
|
+
|
|
65
|
+
# Convert all arrays to float32 immediately after getting them
|
|
66
|
+
precisions = precisions.astype(np.float32)
|
|
67
|
+
recalls = recalls.astype(np.float32)
|
|
68
|
+
thresholds = thresholds.astype(np.float32)
|
|
69
|
+
|
|
70
|
+
# Concatenate with 0 to include the lowest threshold
|
|
71
|
+
thresholds = np.concatenate(([0], thresholds))
|
|
72
|
+
|
|
73
|
+
# Sort by threshold
|
|
74
|
+
sorted_indices = np.argsort(thresholds)
|
|
75
|
+
thresholds = thresholds[sorted_indices]
|
|
76
|
+
precisions = precisions[sorted_indices]
|
|
77
|
+
recalls = recalls[sorted_indices]
|
|
78
|
+
|
|
79
|
+
if len(precisions) > max_length:
|
|
80
|
+
new_thresholds = np.linspace(0, 1, max_length, dtype=np.float32)
|
|
81
|
+
new_precisions = np.interp(new_thresholds, thresholds, precisions)
|
|
82
|
+
new_recalls = np.interp(new_thresholds, thresholds, recalls)
|
|
83
|
+
thresholds = new_thresholds
|
|
84
|
+
precisions = new_precisions
|
|
85
|
+
recalls = new_recalls
|
|
86
|
+
|
|
87
|
+
return PRCurve(
|
|
88
|
+
thresholds=cast(list[float], thresholds.tolist()),
|
|
89
|
+
precisions=cast(list[float], precisions.tolist()),
|
|
90
|
+
recalls=cast(list[float], recalls.tolist()),
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class ROCCurve(TypedDict, total=True):
    """Receiver operating characteristic curve sampled at a sequence of thresholds."""

    # Score cutoffs at which the rates below were measured
    thresholds: list[float]
    # False positive rate at each threshold
    false_positive_rates: list[float]
    # True positive rate at each threshold
    true_positive_rates: list[float]
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def calculate_roc_curve(
    references: NDArray[np.int64],
    probabilities: NDArray[np.float32],
    max_length: int = 100,
) -> ROCCurve:
    """
    Compute the receiver operating characteristic curve of a binary classifier.

    Params:
        references: True label for each sample
        probabilities: Positive-class scores, either as a 1D array or as a 2D array whose column 1
            holds the positive-class score (any other columns are ignored)
        max_length: Largest number of points to return; longer curves are linearly interpolated
            onto `max_length` evenly spaced thresholds between 0 and 1

    Returns:
        Thresholds in ascending order together with the false and true positive rates paired with each

    Raises:
        ValueError: If probabilities are not 1 or 2 dimensional, or their length differs from references
    """
    if probabilities.ndim not in (1, 2):
        raise ValueError("Probabilities must be 1 or 2 dimensional")
    positive_scores = probabilities if probabilities.ndim == 1 else probabilities[:, 1]

    if len(references) != len(positive_scores):
        raise ValueError("Probabilities and references must have the same length")

    raw_fpr, raw_tpr, raw_thresholds = sklearn.metrics.roc_curve(references, positive_scores.astype(np.float32))

    thresholds = raw_thresholds.astype(np.float32)
    # The first threshold is replaced with 1.0 (instead of inf) so interpolation gets finite values
    thresholds[0] = 1.0

    order = np.argsort(thresholds)
    thresholds = thresholds[order]
    fpr = raw_fpr.astype(np.float32)[order]
    tpr = raw_tpr.astype(np.float32)[order]

    if len(fpr) > max_length:
        # Resample onto a fixed, evenly spaced grid to bound the size of the curve
        grid = np.linspace(0, 1, max_length, dtype=np.float32)
        fpr = np.interp(grid, thresholds, fpr)
        tpr = np.interp(grid, thresholds, tpr)
        thresholds = grid

    return ROCCurve(
        false_positive_rates=cast(list[float], fpr.tolist()),
        true_positive_rates=cast(list[float], tpr.tolist()),
        thresholds=cast(list[float], thresholds.tolist()),
    )
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
@dataclass
class ClassificationMetrics:
    """Metrics summarizing how well a classification model's predictions match the expected labels."""

    coverage: float
    """Percentage of predictions that are not none"""

    f1_score: float
    """F1 score of the predictions"""

    accuracy: float
    """Accuracy of the predictions"""

    loss: float | None
    """Cross-entropy loss of the logits"""

    anomaly_score_mean: float | None = None
    """Mean of anomaly scores across the dataset"""

    anomaly_score_median: float | None = None
    """Median of anomaly scores across the dataset"""

    anomaly_score_variance: float | None = None
    """Variance of anomaly scores across the dataset"""

    roc_auc: float | None = None
    """Receiver operating characteristic area under the curve"""

    pr_auc: float | None = None
    """Average precision (area under the curve of the precision-recall curve)"""

    pr_curve: PRCurve | None = None
    """Precision-recall curve"""

    roc_curve: ROCCurve | None = None
    """Receiver operating characteristic curve"""

    def __repr__(self) -> str:
        parts = [
            "ClassificationMetrics({\n",
            f" accuracy: {self.accuracy:.4f},\n",
            f" f1_score: {self.f1_score:.4f},\n",
        ]
        # Optional metrics are shown whenever they were computed; `is not None` (rather than
        # truthiness) keeps legitimate 0.0 values visible.
        if self.roc_auc is not None:
            parts.append(f" roc_auc: {self.roc_auc:.4f},\n")
        if self.pr_auc is not None:
            parts.append(f" pr_auc: {self.pr_auc:.4f},\n")
        if self.anomaly_score_mean is not None:
            if self.anomaly_score_variance is not None:
                parts.append(f" anomaly_score: {self.anomaly_score_mean:.4f} ± {self.anomaly_score_variance:.4f},\n")
            else:
                # formatting None with :.4f would raise TypeError
                parts.append(f" anomaly_score: {self.anomaly_score_mean:.4f},\n")
        parts.append("})")
        return "".join(parts)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def calculate_classification_metrics(
    expected_labels: list[int] | NDArray[np.int64],
    logits: list[list[float]] | list[NDArray[np.float32]] | NDArray[np.float32],
    anomaly_scores: list[float] | None = None,
    average: Literal["micro", "macro", "weighted", "binary"] | None = None,
    multi_class: Literal["ovr", "ovo"] = "ovr",
    include_curves: bool = False,
) -> ClassificationMetrics:
    """
    Calculate classification metrics for model evaluation.

    Params:
        expected_labels: True class index for each sample
        logits: Model outputs, either a 1D array of positive-class probabilities in [0, 1] (binary
            classification) or a 2D array with one row per sample and one column per class. For 2D
            input: if any value is negative or any row sums to zero, softmax is applied; otherwise,
            if the rows do not all sum to 1, each row is divided by its sum. Rows whose
            probabilities are entirely nan count as missing predictions.
        anomaly_scores: Optional anomaly scores for each prediction
        average: Averaging strategy passed to the F1 score; defaults to "binary" for two expected
            classes without missing predictions and "weighted" otherwise
        multi_class: Multi-class strategy passed to the ROC AUC score when there are more than two classes
        include_curves: Whether to compute precision-recall and ROC curves (binary classification only)

    Returns:
        Classification metrics. Loss is None when any prediction is missing. ROC AUC is None when
        predictions are missing or the number of distinct predicted classes differs from the number
        of distinct expected classes; PR AUC and the curves are additionally only computed for
        binary classification.

    Raises:
        ValueError: If logits are not 1 or 2 dimensional, 1D logits fall outside [0, 1], or 2D
            logits have fewer than two columns
    """
    references = np.array(expected_labels)

    logits = np.array(logits)
    if logits.ndim == 1:
        if (logits > 1).any() or (logits < 0).any():
            raise ValueError("Logits must be between 0 and 1 for binary classification")
        # convert 1D probabilities (binary) to 2D probabilities
        logits = np.column_stack([1 - logits, logits])
        probabilities = logits  # no need to convert to probabilities
    elif logits.ndim == 2:
        if logits.shape[1] < 2:
            raise ValueError("Use a different metric function for regression tasks")
        row_sums = logits.sum(-1, keepdims=True)
        if (logits < 0).any() or (row_sums == 0).any():
            # Negative values (or all-zero rows) cannot be probabilities, so treat the input as logits.
            # Exact zeros alone are valid probabilities and must not trigger softmax.
            probabilities = softmax(logits)
        elif not np.allclose(row_sums, 1.0):
            # convert non-negative scores to probabilities through normalization
            probabilities = logits / row_sums
        else:
            probabilities = logits
    else:
        raise ValueError("Logits must be 1 or 2 dimensional")

    none_prediction_mask = np.isnan(probabilities).all(axis=-1)
    predictions = np.argmax(probabilities, axis=-1)
    predictions[none_prediction_mask] = -1  # set predictions to -1 for all nan logits

    num_classes_references = len(set(references))
    num_classes_predictions = len(set(predictions))
    num_none_predictions = int(none_prediction_mask.sum())
    coverage = 1 - num_none_predictions / len(probabilities)

    if average is None:
        average = "binary" if num_classes_references == 2 and num_none_predictions == 0 else "weighted"

    # explicit None/length check so numpy arrays are accepted as well as lists
    has_anomaly_scores = anomaly_scores is not None and len(anomaly_scores) > 0
    anomaly_score_mean = float(np.mean(anomaly_scores)) if has_anomaly_scores else None
    anomaly_score_median = float(np.median(anomaly_scores)) if has_anomaly_scores else None
    anomaly_score_variance = float(np.var(anomaly_scores)) if has_anomaly_scores else None

    accuracy = sklearn.metrics.accuracy_score(references, predictions)
    f1 = sklearn.metrics.f1_score(references, predictions, average=average)
    # Ensure sklearn sees the full class set corresponding to probability columns
    # to avoid errors when y_true does not contain all classes.
    loss = (
        sklearn.metrics.log_loss(
            references,
            probabilities,
            labels=list(range(probabilities.shape[1])),
        )
        if num_none_predictions == 0
        else None
    )

    roc_auc = None
    pr_auc = None
    roc_curve = None
    pr_curve = None
    if num_classes_references == num_classes_predictions and num_none_predictions == 0:
        if num_classes_references == 2:
            # special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
            # Score with the positive-class probability. The raw logits[:, 1] ignores the other
            # class's value (so it ranks differently from the softmaxed/normalized probability) and
            # may fall outside the [0, 1] range the curve helpers resample over.
            positive_scores = probabilities[:, 1]
            roc_auc = sklearn.metrics.roc_auc_score(references, positive_scores)
            pr_auc = sklearn.metrics.average_precision_score(references, positive_scores)
            if include_curves:
                roc_curve = calculate_roc_curve(references, positive_scores)
                pr_curve = calculate_pr_curve(references, positive_scores)
        else:
            roc_auc = sklearn.metrics.roc_auc_score(references, probabilities, multi_class=multi_class)

    return ClassificationMetrics(
        coverage=float(coverage),
        accuracy=float(accuracy),
        f1_score=float(f1),
        loss=float(loss) if loss is not None else None,
        anomaly_score_mean=anomaly_score_mean,
        anomaly_score_median=anomaly_score_median,
        anomaly_score_variance=anomaly_score_variance,
        roc_auc=float(roc_auc) if roc_auc is not None else None,
        pr_auc=float(pr_auc) if pr_auc is not None else None,
        pr_curve=pr_curve,
        roc_curve=roc_curve,
    )
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
@dataclass
|
|
293
|
+
class RegressionMetrics:
|
|
294
|
+
coverage: float
|
|
295
|
+
"""Percentage of predictions that are not none"""
|
|
296
|
+
|
|
297
|
+
mse: float
|
|
298
|
+
"""Mean squared error of the predictions"""
|
|
299
|
+
|
|
300
|
+
rmse: float
|
|
301
|
+
"""Root mean squared error of the predictions"""
|
|
302
|
+
|
|
303
|
+
mae: float
|
|
304
|
+
"""Mean absolute error of the predictions"""
|
|
305
|
+
|
|
306
|
+
r2: float
|
|
307
|
+
"""R-squared score (coefficient of determination) of the predictions"""
|
|
308
|
+
|
|
309
|
+
explained_variance: float
|
|
310
|
+
"""Explained variance score of the predictions"""
|
|
311
|
+
|
|
312
|
+
loss: float
|
|
313
|
+
"""Mean squared error loss of the predictions"""
|
|
314
|
+
|
|
315
|
+
anomaly_score_mean: float | None = None
|
|
316
|
+
"""Mean of anomaly scores across the dataset"""
|
|
317
|
+
|
|
318
|
+
anomaly_score_median: float | None = None
|
|
319
|
+
"""Median of anomaly scores across the dataset"""
|
|
320
|
+
|
|
321
|
+
anomaly_score_variance: float | None = None
|
|
322
|
+
"""Variance of anomaly scores across the dataset"""
|
|
323
|
+
|
|
324
|
+
def __repr__(self) -> str:
|
|
325
|
+
return (
|
|
326
|
+
"RegressionMetrics({\n"
|
|
327
|
+
+ f" mae: {self.mae:.4f},\n"
|
|
328
|
+
+ f" rmse: {self.rmse:.4f},\n"
|
|
329
|
+
+ f" r2: {self.r2:.4f},\n"
|
|
330
|
+
+ (
|
|
331
|
+
f" anomaly_score: {self.anomaly_score_mean:.4f} ± {self.anomaly_score_variance:.4f},\n"
|
|
332
|
+
if self.anomaly_score_mean
|
|
333
|
+
else ""
|
|
334
|
+
)
|
|
335
|
+
+ "})"
|
|
336
|
+
)
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
def calculate_regression_metrics(
    expected_scores: NDArray[np.float32] | list[float],
    predicted_scores: NDArray[np.float32] | list[float],
    anomaly_scores: list[float] | None = None,
) -> RegressionMetrics:
    """
    Calculate regression metrics for model evaluation.

    Params:
        expected_scores: True target values
        predicted_scores: Predicted values from the model; nan marks a missing prediction, which is
            excluded from the error metrics and counted against coverage
        anomaly_scores: Optional anomaly scores for each prediction

    Returns:
        Comprehensive regression metrics including MSE, RMSE, MAE, R², and explained variance.
        If every prediction is missing, coverage is 0 and the error metrics are nan.

    Raises:
        ValueError: If predictions and references have different lengths or are empty
    """
    references = np.array(expected_scores)
    predictions = np.array(predicted_scores)

    if len(predictions) != len(references):
        raise ValueError("Predictions and references must have the same length")
    if len(predictions) == 0:
        # coverage would be 0/0 and sklearn rejects empty inputs
        raise ValueError("Predictions and references must not be empty")

    # explicit None/length check so numpy arrays are accepted as well as lists
    has_anomaly_scores = anomaly_scores is not None and len(anomaly_scores) > 0
    anomaly_score_mean = float(np.mean(anomaly_scores)) if has_anomaly_scores else None
    anomaly_score_median = float(np.median(anomaly_scores)) if has_anomaly_scores else None
    anomaly_score_variance = float(np.var(anomaly_scores)) if has_anomaly_scores else None

    none_prediction_mask = np.isnan(predictions)
    num_none_predictions = int(none_prediction_mask.sum())
    coverage = 1 - num_none_predictions / len(predictions)
    if num_none_predictions > 0:
        references = references[~none_prediction_mask]
        predictions = predictions[~none_prediction_mask]

    if len(predictions) == 0:
        # Every prediction is missing: report zero coverage instead of letting sklearn fail on empty arrays
        mse = mae = r2 = explained_var = float("nan")
    else:
        # Calculate core regression metrics
        mse = float(sklearn.metrics.mean_squared_error(references, predictions))
        mae = float(sklearn.metrics.mean_absolute_error(references, predictions))
        r2 = float(sklearn.metrics.r2_score(references, predictions))
        explained_var = float(sklearn.metrics.explained_variance_score(references, predictions))
    rmse = float(np.sqrt(mse))

    return RegressionMetrics(
        coverage=float(coverage),
        mse=mse,
        rmse=rmse,
        mae=mae,
        r2=r2,
        explained_variance=explained_var,
        loss=mse,  # For regression, loss is typically MSE
        anomaly_score_mean=anomaly_score_mean,
        anomaly_score_median=anomaly_score_median,
        anomaly_score_variance=anomaly_score_variance,
    )
|