orca-sdk 0.0.97__py3-none-any.whl → 0.0.100__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +1 -0
- orca_sdk/_shared/__init__.py +1 -0
- orca_sdk/_utils/analysis_ui.py +5 -5
- orca_sdk/_utils/auth.py +23 -33
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/classification_model.py +188 -126
- orca_sdk/classification_model_test.py +57 -8
- orca_sdk/client.py +3563 -0
- orca_sdk/conftest.py +10 -0
- orca_sdk/credentials.py +59 -21
- orca_sdk/credentials_test.py +20 -0
- orca_sdk/datasource.py +80 -93
- orca_sdk/datasource_test.py +41 -7
- orca_sdk/embedding_model.py +225 -71
- orca_sdk/embedding_model_test.py +27 -36
- orca_sdk/job.py +49 -45
- orca_sdk/job_test.py +16 -0
- orca_sdk/memoryset.py +340 -353
- orca_sdk/memoryset_test.py +7 -11
- orca_sdk/regression_model.py +120 -111
- orca_sdk/regression_model_test.py +15 -0
- orca_sdk/telemetry.py +162 -139
- {orca_sdk-0.0.97.dist-info → orca_sdk-0.0.100.dist-info}/METADATA +2 -5
- orca_sdk-0.0.100.dist-info/RECORD +40 -0
- orca_sdk/_generated_api_client/__init__.py +0 -3
- orca_sdk/_generated_api_client/api/__init__.py +0 -307
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
- orca_sdk/_generated_api_client/api/auth/create_org_plan_auth_org_plan_post.py +0 -168
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
- orca_sdk/_generated_api_client/api/auth/get_org_plan_auth_org_plan_get.py +0 -122
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
- orca_sdk/_generated_api_client/api/auth/update_org_plan_auth_org_plan_put.py +0 -168
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_classification_model_classification_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/delete_classification_model_classification_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/classification_model/evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/predict_label_gpu_classification_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +0 -183
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_from_content_datasource_post.py +0 -224
- orca_sdk/_generated_api_client/api/datasource/create_datasource_from_files_datasource_upload_post.py +0 -229
- orca_sdk/_generated_api_client/api/datasource/create_embedding_evaluation_datasource_name_or_id_embedding_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +0 -172
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/get_embedding_evaluation_datasource_name_or_id_embedding_evaluation_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
- orca_sdk/_generated_api_client/api/datasource/list_embedding_evaluations_datasource_name_or_id_embedding_evaluation_get.py +0 -235
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/analyze_memoryset_memoryset_name_or_id_analysis_post.py +0 -183
- orca_sdk/_generated_api_client/api/memoryset/batch_delete_memoryset_batch_delete_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -210
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -186
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -235
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -180
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -212
- orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +0 -195
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -210
- orca_sdk/_generated_api_client/api/memoryset/suggest_cascading_edits_memoryset_name_or_id_memory_memory_id_cascading_edits_post.py +0 -233
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -216
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -205
- orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +0 -183
- orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +0 -150
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -192
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -161
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/regression_model/create_regression_model_regression_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_regression_model_name_or_id_delete.py +0 -154
- orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/regression_model/list_regression_models_regression_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/regression_model/update_regression_model_regression_model_name_or_id_patch.py +0 -183
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -288
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +0 -168
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
- orca_sdk/_generated_api_client/api/telemetry/explain_prediction_telemetry_prediction_prediction_id_explanation_get.py +0 -182
- orca_sdk/_generated_api_client/api/telemetry/generate_memory_suggestions_telemetry_prediction_prediction_id_memory_suggestions_post.py +0 -239
- orca_sdk/_generated_api_client/api/telemetry/get_action_recommendation_telemetry_prediction_prediction_id_action_get.py +0 -192
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -180
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
- orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +0 -198
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -198
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
- orca_sdk/_generated_api_client/client.py +0 -216
- orca_sdk/_generated_api_client/errors.py +0 -38
- orca_sdk/_generated_api_client/models/__init__.py +0 -345
- orca_sdk/_generated_api_client/models/action_recommendation.py +0 -82
- orca_sdk/_generated_api_client/models/action_recommendation_action.py +0 -11
- orca_sdk/_generated_api_client/models/add_memory_recommendations.py +0 -85
- orca_sdk/_generated_api_client/models/add_memory_suggestion.py +0 -79
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -116
- orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -137
- orca_sdk/_generated_api_client/models/api_key_metadata_scope_item.py +0 -9
- orca_sdk/_generated_api_client/models/base_label_prediction_result.py +0 -130
- orca_sdk/_generated_api_client/models/base_model.py +0 -55
- orca_sdk/_generated_api_client/models/base_score_prediction_result.py +0 -108
- orca_sdk/_generated_api_client/models/body_create_datasource_from_files_datasource_upload_post.py +0 -145
- orca_sdk/_generated_api_client/models/cascade_edit_suggestions_request.py +0 -154
- orca_sdk/_generated_api_client/models/cascading_edit_suggestion.py +0 -92
- orca_sdk/_generated_api_client/models/class_representatives.py +0 -92
- orca_sdk/_generated_api_client/models/classification_evaluation_request.py +0 -148
- orca_sdk/_generated_api_client/models/classification_metrics.py +0 -259
- orca_sdk/_generated_api_client/models/classification_model_metadata.py +0 -227
- orca_sdk/_generated_api_client/models/classification_prediction_request.py +0 -220
- orca_sdk/_generated_api_client/models/clone_memoryset_request.py +0 -210
- orca_sdk/_generated_api_client/models/cluster_metrics.py +0 -78
- orca_sdk/_generated_api_client/models/column_info.py +0 -145
- orca_sdk/_generated_api_client/models/column_type.py +0 -14
- orca_sdk/_generated_api_client/models/constraint_violation_error_response.py +0 -81
- orca_sdk/_generated_api_client/models/constraint_violation_error_response_status_code.py +0 -8
- orca_sdk/_generated_api_client/models/count_predictions_request.py +0 -195
- orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -120
- orca_sdk/_generated_api_client/models/create_api_key_request_scope_item.py +0 -9
- orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -145
- orca_sdk/_generated_api_client/models/create_api_key_response_scope_item.py +0 -9
- orca_sdk/_generated_api_client/models/create_classification_model_request.py +0 -237
- orca_sdk/_generated_api_client/models/create_datasource_from_content_request.py +0 -101
- orca_sdk/_generated_api_client/models/create_memoryset_request.py +0 -365
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +0 -66
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +0 -13
- orca_sdk/_generated_api_client/models/create_org_plan_request.py +0 -73
- orca_sdk/_generated_api_client/models/create_org_plan_request_tier.py +0 -11
- orca_sdk/_generated_api_client/models/create_regression_model_request.py +0 -157
- orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -156
- orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/delete_memorysets_request.py +0 -70
- orca_sdk/_generated_api_client/models/embed_request.py +0 -155
- orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +0 -205
- orca_sdk/_generated_api_client/models/embedding_evaluation_request.py +0 -197
- orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +0 -158
- orca_sdk/_generated_api_client/models/embedding_evaluation_result.py +0 -86
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
- orca_sdk/_generated_api_client/models/embedding_model_result.py +0 -123
- orca_sdk/_generated_api_client/models/evaluation_response.py +0 -153
- orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +0 -140
- orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +0 -140
- orca_sdk/_generated_api_client/models/feedback_metrics.py +0 -85
- orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
- orca_sdk/_generated_api_client/models/filter_item.py +0 -239
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -17
- orca_sdk/_generated_api_client/models/filter_item_field_type_1_item_type_0.py +0 -8
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_0.py +0 -8
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -22
- orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
- orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/http_validation_error.py +0 -86
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -81
- orca_sdk/_generated_api_client/models/internal_server_error_response_status_code.py +0 -8
- orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -210
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -288
- orca_sdk/_generated_api_client/models/labeled_memory.py +0 -186
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -128
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -194
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +0 -207
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_feedback_metrics.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
- orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -319
- orca_sdk/_generated_api_client/models/lookup_request.py +0 -101
- orca_sdk/_generated_api_client/models/lookup_score_metrics.py +0 -94
- orca_sdk/_generated_api_client/models/memory_metrics.py +0 -263
- orca_sdk/_generated_api_client/models/memory_type.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -245
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -105
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -182
- orca_sdk/_generated_api_client/models/memoryset_class_patterns_analysis_config.py +0 -79
- orca_sdk/_generated_api_client/models/memoryset_class_patterns_metrics.py +0 -138
- orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config.py +0 -202
- orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config_clustering_method.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config_partitioning_method.py +0 -10
- orca_sdk/_generated_api_client/models/memoryset_cluster_metrics.py +0 -100
- orca_sdk/_generated_api_client/models/memoryset_duplicate_analysis_config.py +0 -70
- orca_sdk/_generated_api_client/models/memoryset_duplicate_metrics.py +0 -70
- orca_sdk/_generated_api_client/models/memoryset_label_analysis_config.py +0 -70
- orca_sdk/_generated_api_client/models/memoryset_label_metrics.py +0 -116
- orca_sdk/_generated_api_client/models/memoryset_metadata.py +0 -333
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +0 -55
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +0 -13
- orca_sdk/_generated_api_client/models/memoryset_metrics.py +0 -265
- orca_sdk/_generated_api_client/models/memoryset_neighbor_analysis_config.py +0 -83
- orca_sdk/_generated_api_client/models/memoryset_neighbor_metrics.py +0 -76
- orca_sdk/_generated_api_client/models/memoryset_neighbor_metrics_lookup_score_metrics.py +0 -68
- orca_sdk/_generated_api_client/models/memoryset_projection_analysis_config.py +0 -79
- orca_sdk/_generated_api_client/models/memoryset_projection_metrics.py +0 -55
- orca_sdk/_generated_api_client/models/memoryset_update.py +0 -121
- orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -99
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -23
- orca_sdk/_generated_api_client/models/not_found_error_response_status_code.py +0 -8
- orca_sdk/_generated_api_client/models/org_plan.py +0 -99
- orca_sdk/_generated_api_client/models/org_plan_tier.py +0 -11
- orca_sdk/_generated_api_client/models/paginated_task.py +0 -108
- orca_sdk/_generated_api_client/models/paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py +0 -135
- orca_sdk/_generated_api_client/models/pr_curve.py +0 -86
- orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
- orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_0.py +0 -10
- orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_1.py +0 -9
- orca_sdk/_generated_api_client/models/predictive_model_update.py +0 -111
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -115
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -17
- orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
- orca_sdk/_generated_api_client/models/rar_head_type.py +0 -8
- orca_sdk/_generated_api_client/models/regression_evaluation_request.py +0 -148
- orca_sdk/_generated_api_client/models/regression_metrics.py +0 -172
- orca_sdk/_generated_api_client/models/regression_model_metadata.py +0 -191
- orca_sdk/_generated_api_client/models/regression_prediction_request.py +0 -195
- orca_sdk/_generated_api_client/models/roc_curve.py +0 -86
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +0 -196
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +0 -252
- orca_sdk/_generated_api_client/models/scored_memory.py +0 -172
- orca_sdk/_generated_api_client/models/scored_memory_insert.py +0 -128
- orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_lookup.py +0 -180
- orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +0 -193
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -81
- orca_sdk/_generated_api_client/models/service_unavailable_error_response_status_code.py +0 -8
- orca_sdk/_generated_api_client/models/task.py +0 -198
- orca_sdk/_generated_api_client/models/task_status.py +0 -14
- orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
- orca_sdk/_generated_api_client/models/telemetry_field_type_0_item_type_0.py +0 -8
- orca_sdk/_generated_api_client/models/telemetry_field_type_0_item_type_2.py +0 -9
- orca_sdk/_generated_api_client/models/telemetry_field_type_1_item_type_0.py +0 -8
- orca_sdk/_generated_api_client/models/telemetry_field_type_1_item_type_1.py +0 -8
- orca_sdk/_generated_api_client/models/telemetry_filter_item.py +0 -217
- orca_sdk/_generated_api_client/models/telemetry_filter_item_op.py +0 -15
- orca_sdk/_generated_api_client/models/telemetry_memories_request.py +0 -181
- orca_sdk/_generated_api_client/models/telemetry_sort_options.py +0 -185
- orca_sdk/_generated_api_client/models/telemetry_sort_options_direction.py +0 -9
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -73
- orca_sdk/_generated_api_client/models/unauthenticated_error_response_status_code.py +0 -8
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -81
- orca_sdk/_generated_api_client/models/unauthorized_error_response_status_code.py +0 -8
- orca_sdk/_generated_api_client/models/update_org_plan_request.py +0 -73
- orca_sdk/_generated_api_client/models/update_org_plan_request_tier.py +0 -11
- orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -133
- orca_sdk/_generated_api_client/models/validation_error.py +0 -99
- orca_sdk/_generated_api_client/py.typed +0 -1
- orca_sdk/_generated_api_client/types.py +0 -56
- orca_sdk-0.0.97.dist-info/RECORD +0 -309
- {orca_sdk-0.0.97.dist-info → orca_sdk-0.0.100.dist-info}/WHEEL +0 -0
orca_sdk/classification_model.py
CHANGED
|
@@ -1,56 +1,61 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
|
-
import os
|
|
5
4
|
from contextlib import contextmanager
|
|
6
5
|
from datetime import datetime
|
|
7
6
|
from typing import Any, Generator, Iterable, Literal, cast, overload
|
|
8
|
-
from uuid import UUID
|
|
9
7
|
|
|
10
8
|
from datasets import Dataset
|
|
11
9
|
|
|
12
|
-
from .
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
get_classification_model_evaluation,
|
|
18
|
-
list_classification_models,
|
|
19
|
-
list_predictions,
|
|
20
|
-
predict_label_gpu,
|
|
21
|
-
record_prediction_feedback,
|
|
22
|
-
update_classification_model,
|
|
23
|
-
)
|
|
24
|
-
from ._generated_api_client.models import (
|
|
25
|
-
ClassificationEvaluationRequest,
|
|
10
|
+
from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
|
|
11
|
+
from ._utils.common import UNSET, CreateMode, DropMode
|
|
12
|
+
from .client import (
|
|
13
|
+
BootstrapClassificationModelMeta,
|
|
14
|
+
BootstrapClassificationModelResult,
|
|
26
15
|
ClassificationModelMetadata,
|
|
27
|
-
ClassificationPredictionRequest,
|
|
28
|
-
CreateClassificationModelRequest,
|
|
29
|
-
LabelPredictionWithMemoriesAndFeedback,
|
|
30
|
-
ListPredictionsRequest,
|
|
31
|
-
)
|
|
32
|
-
from ._generated_api_client.models import (
|
|
33
|
-
PredictionSortItemItemType0 as PredictionSortColumns,
|
|
34
|
-
)
|
|
35
|
-
from ._generated_api_client.models import (
|
|
36
|
-
PredictionSortItemItemType1 as PredictionSortDirection,
|
|
37
|
-
)
|
|
38
|
-
from ._generated_api_client.models import (
|
|
39
16
|
PredictiveModelUpdate,
|
|
40
17
|
RACHeadType,
|
|
18
|
+
orca_api,
|
|
41
19
|
)
|
|
42
|
-
from ._generated_api_client.types import UNSET as CLIENT_UNSET
|
|
43
|
-
from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
|
|
44
|
-
from ._utils.common import UNSET, CreateMode, DropMode
|
|
45
20
|
from .datasource import Datasource
|
|
46
21
|
from .job import Job
|
|
47
22
|
from .memoryset import (
|
|
48
23
|
FilterItem,
|
|
49
24
|
FilterItemTuple,
|
|
50
25
|
LabeledMemoryset,
|
|
26
|
+
_is_metric_column,
|
|
51
27
|
_parse_filter_item_from_tuple,
|
|
52
28
|
)
|
|
53
|
-
from .telemetry import
|
|
29
|
+
from .telemetry import (
|
|
30
|
+
ClassificationPrediction,
|
|
31
|
+
TelemetryMode,
|
|
32
|
+
_get_telemetry_config,
|
|
33
|
+
_parse_feedback,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class BootstrappedClassificationModel:
|
|
38
|
+
|
|
39
|
+
datasource: Datasource | None
|
|
40
|
+
memoryset: LabeledMemoryset | None
|
|
41
|
+
classification_model: ClassificationModel | None
|
|
42
|
+
agent_output: BootstrapClassificationModelResult | None
|
|
43
|
+
|
|
44
|
+
def __init__(self, metadata: BootstrapClassificationModelMeta):
|
|
45
|
+
self.datasource = Datasource.open(metadata["datasource_meta"]["id"])
|
|
46
|
+
self.memoryset = LabeledMemoryset.open(metadata["memoryset_meta"]["id"])
|
|
47
|
+
self.classification_model = ClassificationModel.open(metadata["model_meta"]["id"])
|
|
48
|
+
self.agent_output = metadata["agent_output"]
|
|
49
|
+
|
|
50
|
+
def __repr__(self):
|
|
51
|
+
return (
|
|
52
|
+
"BootstrappedClassificationModel({\n"
|
|
53
|
+
f" datasource: {self.datasource},\n"
|
|
54
|
+
f" memoryset: {self.memoryset},\n"
|
|
55
|
+
f" classification_model: {self.classification_model},\n"
|
|
56
|
+
f" agent_output: {self.agent_output},\n"
|
|
57
|
+
"})"
|
|
58
|
+
)
|
|
54
59
|
|
|
55
60
|
|
|
56
61
|
class ClassificationModel:
|
|
@@ -86,18 +91,18 @@ class ClassificationModel:
|
|
|
86
91
|
|
|
87
92
|
def __init__(self, metadata: ClassificationModelMetadata):
|
|
88
93
|
# for internal use only, do not document
|
|
89
|
-
self.id = metadata
|
|
90
|
-
self.name = metadata
|
|
91
|
-
self.description = metadata
|
|
92
|
-
self.memoryset = LabeledMemoryset.open(metadata
|
|
93
|
-
self.head_type = metadata
|
|
94
|
-
self.num_classes = metadata
|
|
95
|
-
self.memory_lookup_count = metadata
|
|
96
|
-
self.weigh_memories = metadata
|
|
97
|
-
self.min_memory_weight = metadata
|
|
98
|
-
self.version = metadata
|
|
99
|
-
self.locked = metadata
|
|
100
|
-
self.created_at = metadata
|
|
94
|
+
self.id = metadata["id"]
|
|
95
|
+
self.name = metadata["name"]
|
|
96
|
+
self.description = metadata["description"]
|
|
97
|
+
self.memoryset = LabeledMemoryset.open(metadata["memoryset_id"])
|
|
98
|
+
self.head_type = metadata["head_type"]
|
|
99
|
+
self.num_classes = metadata["num_classes"]
|
|
100
|
+
self.memory_lookup_count = metadata["memory_lookup_count"]
|
|
101
|
+
self.weigh_memories = metadata["weigh_memories"]
|
|
102
|
+
self.min_memory_weight = metadata["min_memory_weight"]
|
|
103
|
+
self.version = metadata["version"]
|
|
104
|
+
self.locked = metadata["locked"]
|
|
105
|
+
self.created_at = datetime.fromisoformat(metadata["created_at"])
|
|
101
106
|
|
|
102
107
|
self._memoryset_override_id: str | None = None
|
|
103
108
|
self._last_prediction: ClassificationPrediction | None = None
|
|
@@ -140,7 +145,7 @@ class ClassificationModel:
|
|
|
140
145
|
cls,
|
|
141
146
|
name: str,
|
|
142
147
|
memoryset: LabeledMemoryset,
|
|
143
|
-
head_type:
|
|
148
|
+
head_type: RACHeadType = "KNN",
|
|
144
149
|
*,
|
|
145
150
|
description: str | None = None,
|
|
146
151
|
num_classes: int | None = None,
|
|
@@ -206,17 +211,18 @@ class ClassificationModel:
|
|
|
206
211
|
|
|
207
212
|
return existing
|
|
208
213
|
|
|
209
|
-
metadata =
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
214
|
+
metadata = orca_api.POST(
|
|
215
|
+
"/classification_model",
|
|
216
|
+
json={
|
|
217
|
+
"name": name,
|
|
218
|
+
"memoryset_name_or_id": memoryset.id,
|
|
219
|
+
"head_type": head_type,
|
|
220
|
+
"memory_lookup_count": memory_lookup_count,
|
|
221
|
+
"num_classes": num_classes,
|
|
222
|
+
"weigh_memories": weigh_memories,
|
|
223
|
+
"min_memory_weight": min_memory_weight,
|
|
224
|
+
"description": description,
|
|
225
|
+
},
|
|
220
226
|
)
|
|
221
227
|
return cls(metadata)
|
|
222
228
|
|
|
@@ -234,7 +240,7 @@ class ClassificationModel:
|
|
|
234
240
|
Raises:
|
|
235
241
|
LookupError: If the classification model does not exist
|
|
236
242
|
"""
|
|
237
|
-
return cls(
|
|
243
|
+
return cls(orca_api.GET("/classification_model/{name_or_id}", params={"name_or_id": name}))
|
|
238
244
|
|
|
239
245
|
@classmethod
|
|
240
246
|
def exists(cls, name_or_id: str) -> bool:
|
|
@@ -261,7 +267,7 @@ class ClassificationModel:
|
|
|
261
267
|
Returns:
|
|
262
268
|
List of handles to all classification models in the OrcaCloud
|
|
263
269
|
"""
|
|
264
|
-
return [cls(metadata) for metadata in
|
|
270
|
+
return [cls(metadata) for metadata in orca_api.GET("/classification_model")]
|
|
265
271
|
|
|
266
272
|
@classmethod
|
|
267
273
|
def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
|
|
@@ -280,7 +286,7 @@ class ClassificationModel:
|
|
|
280
286
|
LookupError: If the classification model does not exist and if_not_exists is `"error"`
|
|
281
287
|
"""
|
|
282
288
|
try:
|
|
283
|
-
|
|
289
|
+
orca_api.DELETE("/classification_model/{name_or_id}", params={"name_or_id": name_or_id})
|
|
284
290
|
logging.info(f"Deleted model {name_or_id}")
|
|
285
291
|
except LookupError:
|
|
286
292
|
if if_not_exists == "error":
|
|
@@ -311,11 +317,12 @@ class ClassificationModel:
|
|
|
311
317
|
Lock the model:
|
|
312
318
|
>>> model.set(locked=True)
|
|
313
319
|
"""
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
320
|
+
update: PredictiveModelUpdate = {}
|
|
321
|
+
if description is not UNSET:
|
|
322
|
+
update["description"] = description
|
|
323
|
+
if locked is not UNSET:
|
|
324
|
+
update["locked"] = locked
|
|
325
|
+
orca_api.PATCH("/classification_model/{name_or_id}", params={"name_or_id": self.id}, json=update)
|
|
319
326
|
self.refresh()
|
|
320
327
|
|
|
321
328
|
def lock(self) -> None:
|
|
@@ -333,7 +340,9 @@ class ClassificationModel:
|
|
|
333
340
|
expected_labels: list[int] | None = None,
|
|
334
341
|
filters: list[FilterItemTuple] = [],
|
|
335
342
|
tags: set[str] | None = None,
|
|
336
|
-
save_telemetry:
|
|
343
|
+
save_telemetry: TelemetryMode = "on",
|
|
344
|
+
prompt: str | None = None,
|
|
345
|
+
use_lookup_cache: bool = True,
|
|
337
346
|
) -> list[ClassificationPrediction]:
|
|
338
347
|
pass
|
|
339
348
|
|
|
@@ -344,17 +353,21 @@ class ClassificationModel:
|
|
|
344
353
|
expected_labels: int | None = None,
|
|
345
354
|
filters: list[FilterItemTuple] = [],
|
|
346
355
|
tags: set[str] | None = None,
|
|
347
|
-
save_telemetry:
|
|
356
|
+
save_telemetry: TelemetryMode = "on",
|
|
357
|
+
prompt: str | None = None,
|
|
358
|
+
use_lookup_cache: bool = True,
|
|
348
359
|
) -> ClassificationPrediction:
|
|
349
360
|
pass
|
|
350
361
|
|
|
351
362
|
def predict(
|
|
352
363
|
self,
|
|
353
364
|
value: list[str] | str,
|
|
354
|
-
expected_labels: list[int] | int | None = None,
|
|
365
|
+
expected_labels: list[int] | list[str] | int | str | None = None,
|
|
355
366
|
filters: list[FilterItemTuple] = [],
|
|
356
367
|
tags: set[str] | None = None,
|
|
357
|
-
save_telemetry:
|
|
368
|
+
save_telemetry: TelemetryMode = "on",
|
|
369
|
+
prompt: str | None = None,
|
|
370
|
+
use_lookup_cache: bool = True,
|
|
358
371
|
) -> list[ClassificationPrediction] | ClassificationPrediction:
|
|
359
372
|
"""
|
|
360
373
|
Predict label(s) for the given input value(s) grounded in similar memories
|
|
@@ -362,6 +375,7 @@ class ClassificationModel:
|
|
|
362
375
|
Params:
|
|
363
376
|
value: Value(s) to get predict the labels of
|
|
364
377
|
expected_labels: Expected label(s) for the given input to record for model evaluation
|
|
378
|
+
filters: Optional filters to apply during memory lookup
|
|
365
379
|
tags: Tags to add to the prediction(s)
|
|
366
380
|
save_telemetry: Whether to save telemetry for the prediction(s). One of
|
|
367
381
|
* `"off"`: Do not save telemetry
|
|
@@ -369,6 +383,7 @@ class ClassificationModel:
|
|
|
369
383
|
environment variable is set.
|
|
370
384
|
* `"sync"`: Save telemetry synchronously
|
|
371
385
|
* `"async"`: Save telemetry asynchronously
|
|
386
|
+
prompt: Optional prompt to use for instruction-tuned embedding models
|
|
372
387
|
|
|
373
388
|
Returns:
|
|
374
389
|
Label prediction or list of label predictions
|
|
@@ -384,48 +399,60 @@ class ClassificationModel:
|
|
|
384
399
|
ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
|
|
385
400
|
ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
|
|
386
401
|
]
|
|
402
|
+
|
|
403
|
+
Using a prompt with an instruction-tuned embedding model:
|
|
404
|
+
>>> prediction = model.predict("I am happy", prompt="Represent this text for sentiment classification:")
|
|
405
|
+
ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
|
|
387
406
|
"""
|
|
388
407
|
|
|
389
408
|
parsed_filters = [
|
|
390
409
|
_parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
|
|
391
410
|
]
|
|
392
411
|
|
|
393
|
-
if
|
|
412
|
+
if any(_is_metric_column(filter[0]) for filter in filters):
|
|
394
413
|
raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
|
|
395
414
|
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
415
|
+
if isinstance(expected_labels, int):
|
|
416
|
+
expected_labels = [expected_labels]
|
|
417
|
+
elif isinstance(expected_labels, str):
|
|
418
|
+
expected_labels = [self.memoryset.label_names.index(expected_labels)]
|
|
419
|
+
elif isinstance(expected_labels, list):
|
|
420
|
+
expected_labels = [
|
|
421
|
+
self.memoryset.label_names.index(label) if isinstance(label, str) else label
|
|
422
|
+
for label in expected_labels
|
|
423
|
+
]
|
|
424
|
+
|
|
425
|
+
telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
|
|
426
|
+
response = orca_api.POST(
|
|
427
|
+
"/gpu/classification_model/{name_or_id}/prediction",
|
|
428
|
+
params={"name_or_id": self.id},
|
|
429
|
+
json={
|
|
430
|
+
"input_values": value if isinstance(value, list) else [value],
|
|
431
|
+
"memoryset_override_name_or_id": self._memoryset_override_id,
|
|
432
|
+
"expected_labels": expected_labels,
|
|
433
|
+
"tags": list(tags or set()),
|
|
434
|
+
"save_telemetry": telemetry_on,
|
|
435
|
+
"save_telemetry_synchronously": telemetry_sync,
|
|
436
|
+
"filters": cast(list[FilterItem], parsed_filters),
|
|
437
|
+
"prompt": prompt,
|
|
438
|
+
"use_lookup_cache": use_lookup_cache,
|
|
439
|
+
},
|
|
413
440
|
)
|
|
414
441
|
|
|
415
|
-
if
|
|
442
|
+
if telemetry_on and any(p["prediction_id"] is None for p in response):
|
|
416
443
|
raise RuntimeError("Failed to save prediction to database.")
|
|
417
444
|
|
|
418
445
|
predictions = [
|
|
419
446
|
ClassificationPrediction(
|
|
420
|
-
prediction_id=prediction
|
|
421
|
-
label=prediction
|
|
422
|
-
label_name=prediction
|
|
447
|
+
prediction_id=prediction["prediction_id"],
|
|
448
|
+
label=prediction["label"],
|
|
449
|
+
label_name=prediction["label_name"],
|
|
423
450
|
score=None,
|
|
424
|
-
confidence=prediction
|
|
425
|
-
anomaly_score=prediction
|
|
451
|
+
confidence=prediction["confidence"],
|
|
452
|
+
anomaly_score=prediction["anomaly_score"],
|
|
426
453
|
memoryset=self.memoryset,
|
|
427
454
|
model=self,
|
|
428
|
-
logits=prediction
|
|
455
|
+
logits=prediction["logits"],
|
|
429
456
|
input_value=input_value,
|
|
430
457
|
)
|
|
431
458
|
for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
|
|
@@ -439,7 +466,7 @@ class ClassificationModel:
|
|
|
439
466
|
limit: int = 100,
|
|
440
467
|
offset: int = 0,
|
|
441
468
|
tag: str | None = None,
|
|
442
|
-
sort: list[tuple[
|
|
469
|
+
sort: list[tuple[Literal["anomaly_score", "confidence", "timestamp"], Literal["asc", "desc"]]] = [],
|
|
443
470
|
expected_label_match: bool | None = None,
|
|
444
471
|
) -> list[ClassificationPrediction]:
|
|
445
472
|
"""
|
|
@@ -475,30 +502,31 @@ class ClassificationModel:
|
|
|
475
502
|
>>> predictions = model.predictions(expected_label_match=False)
|
|
476
503
|
[ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
|
|
477
504
|
"""
|
|
478
|
-
predictions =
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
505
|
+
predictions = orca_api.POST(
|
|
506
|
+
"/telemetry/prediction",
|
|
507
|
+
json={
|
|
508
|
+
"model_id": self.id,
|
|
509
|
+
"limit": limit,
|
|
510
|
+
"offset": offset,
|
|
511
|
+
"sort": [list(sort_item) for sort_item in sort],
|
|
512
|
+
"tag": tag,
|
|
513
|
+
"expected_label_match": expected_label_match,
|
|
514
|
+
},
|
|
487
515
|
)
|
|
488
516
|
return [
|
|
489
517
|
ClassificationPrediction(
|
|
490
|
-
prediction_id=prediction
|
|
491
|
-
label=prediction
|
|
492
|
-
label_name=prediction
|
|
518
|
+
prediction_id=prediction["prediction_id"],
|
|
519
|
+
label=prediction["label"],
|
|
520
|
+
label_name=prediction["label_name"],
|
|
493
521
|
score=None,
|
|
494
|
-
confidence=prediction
|
|
495
|
-
anomaly_score=prediction
|
|
522
|
+
confidence=prediction["confidence"],
|
|
523
|
+
anomaly_score=prediction["anomaly_score"],
|
|
496
524
|
memoryset=self.memoryset,
|
|
497
525
|
model=self,
|
|
498
526
|
telemetry=prediction,
|
|
499
527
|
)
|
|
500
528
|
for prediction in predictions
|
|
501
|
-
if
|
|
529
|
+
if "label" in prediction
|
|
502
530
|
]
|
|
503
531
|
|
|
504
532
|
def _evaluate_datasource(
|
|
@@ -510,23 +538,28 @@ class ClassificationModel:
|
|
|
510
538
|
tags: set[str] | None,
|
|
511
539
|
background: bool = False,
|
|
512
540
|
) -> ClassificationMetrics | Job[ClassificationMetrics]:
|
|
513
|
-
response =
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
541
|
+
response = orca_api.POST(
|
|
542
|
+
"/classification_model/{model_name_or_id}/evaluation",
|
|
543
|
+
params={"model_name_or_id": self.id},
|
|
544
|
+
json={
|
|
545
|
+
"datasource_name_or_id": datasource.id,
|
|
546
|
+
"datasource_label_column": label_column,
|
|
547
|
+
"datasource_value_column": value_column,
|
|
548
|
+
"memoryset_override_name_or_id": self._memoryset_override_id,
|
|
549
|
+
"record_telemetry": record_predictions,
|
|
550
|
+
"telemetry_tags": list(tags) if tags else None,
|
|
551
|
+
},
|
|
523
552
|
)
|
|
524
553
|
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
554
|
+
def get_value():
|
|
555
|
+
res = orca_api.GET(
|
|
556
|
+
"/classification_model/{model_name_or_id}/evaluation/{task_id}",
|
|
557
|
+
params={"model_name_or_id": self.id, "task_id": response["task_id"]},
|
|
558
|
+
)
|
|
559
|
+
assert res["result"] is not None
|
|
560
|
+
return ClassificationMetrics(**res["result"])
|
|
561
|
+
|
|
562
|
+
job = Job(response["task_id"], get_value)
|
|
530
563
|
return job if background else job.result()
|
|
531
564
|
|
|
532
565
|
def _evaluate_dataset(
|
|
@@ -709,8 +742,37 @@ class ClassificationModel:
|
|
|
709
742
|
ValueError: If the value does not match previous value types for the category, or is a
|
|
710
743
|
[`float`][float] that is not between `-1.0` and `+1.0`.
|
|
711
744
|
"""
|
|
712
|
-
|
|
713
|
-
|
|
745
|
+
orca_api.PUT(
|
|
746
|
+
"/telemetry/prediction/feedback",
|
|
747
|
+
json=[
|
|
714
748
|
_parse_feedback(f) for f in (cast(list[dict], [feedback]) if isinstance(feedback, dict) else feedback)
|
|
715
749
|
],
|
|
716
750
|
)
|
|
751
|
+
|
|
752
|
+
@staticmethod
|
|
753
|
+
def bootstrap_model(
|
|
754
|
+
model_description: str,
|
|
755
|
+
label_names: list[str],
|
|
756
|
+
initial_examples: list[tuple[str, str]],
|
|
757
|
+
num_examples_per_label: int,
|
|
758
|
+
background: bool = False,
|
|
759
|
+
) -> Job[BootstrappedClassificationModel] | BootstrappedClassificationModel:
|
|
760
|
+
response = orca_api.POST(
|
|
761
|
+
"/agents/bootstrap_classification_model",
|
|
762
|
+
json={
|
|
763
|
+
"model_description": model_description,
|
|
764
|
+
"label_names": label_names,
|
|
765
|
+
"initial_examples": [{"text": text, "label_name": label_name} for text, label_name in initial_examples],
|
|
766
|
+
"num_examples_per_label": num_examples_per_label,
|
|
767
|
+
},
|
|
768
|
+
)
|
|
769
|
+
|
|
770
|
+
def get_result() -> BootstrappedClassificationModel:
|
|
771
|
+
res = orca_api.GET(
|
|
772
|
+
"/agents/bootstrap_classification_model/{task_id}", params={"task_id": response["task_id"]}
|
|
773
|
+
)
|
|
774
|
+
assert res["result"] is not None
|
|
775
|
+
return BootstrappedClassificationModel(res["result"])
|
|
776
|
+
|
|
777
|
+
job = Job(response["task_id"], get_result)
|
|
778
|
+
return job if background else job.result()
|
|
@@ -326,6 +326,37 @@ def test_predict_with_filters(classification_model: ClassificationModel):
|
|
|
326
326
|
assert filtered_prediction.label_name == "cats"
|
|
327
327
|
|
|
328
328
|
|
|
329
|
+
def test_predict_with_memoryset_update(writable_memoryset: LabeledMemoryset):
|
|
330
|
+
model = ClassificationModel.create(
|
|
331
|
+
"test_predict_with_memoryset_update",
|
|
332
|
+
writable_memoryset,
|
|
333
|
+
num_classes=2,
|
|
334
|
+
memory_lookup_count=3,
|
|
335
|
+
)
|
|
336
|
+
|
|
337
|
+
prediction = model.predict("Do you love soup?")
|
|
338
|
+
assert prediction.label == 0
|
|
339
|
+
assert prediction.label_name == "soup"
|
|
340
|
+
|
|
341
|
+
# insert new memories
|
|
342
|
+
writable_memoryset.insert(
|
|
343
|
+
[
|
|
344
|
+
{"value": "Do you love soup?", "label": 1, "key": "g1"},
|
|
345
|
+
{"value": "Do you love soup for dinner?", "label": 1, "key": "g2"},
|
|
346
|
+
{"value": "Do you love crackers?", "label": 1, "key": "g2"},
|
|
347
|
+
{"value": "Do you love broth?", "label": 1, "key": "g2"},
|
|
348
|
+
{"value": "Do you love chicken soup?", "label": 1, "key": "g2"},
|
|
349
|
+
{"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
|
|
350
|
+
{"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
|
|
351
|
+
],
|
|
352
|
+
)
|
|
353
|
+
prediction = model.predict("Do you love soup?")
|
|
354
|
+
assert prediction.label == 1
|
|
355
|
+
assert prediction.label_name == "cats"
|
|
356
|
+
|
|
357
|
+
ClassificationModel.drop("test_predict_with_memoryset_update")
|
|
358
|
+
|
|
359
|
+
|
|
329
360
|
def test_last_prediction_with_batch(classification_model: ClassificationModel):
|
|
330
361
|
predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
|
|
331
362
|
assert classification_model.last_prediction is not None
|
|
@@ -396,6 +427,8 @@ def test_action_recommendation(writable_memoryset: LabeledMemoryset):
|
|
|
396
427
|
# Make a prediction with expected label to simulate incorrect prediction
|
|
397
428
|
prediction = model.predict("Do you love soup?", expected_labels=1)
|
|
398
429
|
|
|
430
|
+
memoryset_length = model.memoryset.length
|
|
431
|
+
|
|
399
432
|
try:
|
|
400
433
|
# Get action recommendation
|
|
401
434
|
action, rationale = prediction.recommend_action()
|
|
@@ -406,19 +439,22 @@ def test_action_recommendation(writable_memoryset: LabeledMemoryset):
|
|
|
406
439
|
assert len(rationale) > 10
|
|
407
440
|
|
|
408
441
|
# Test memory suggestions
|
|
409
|
-
|
|
442
|
+
suggestions_response = prediction.generate_memory_suggestions(num_memories=2)
|
|
443
|
+
memory_suggestions = suggestions_response.suggestions
|
|
410
444
|
|
|
411
445
|
assert memory_suggestions is not None
|
|
412
446
|
assert len(memory_suggestions) == 2
|
|
413
447
|
|
|
414
448
|
for suggestion in memory_suggestions:
|
|
415
|
-
assert isinstance(suggestion,
|
|
416
|
-
assert
|
|
417
|
-
assert
|
|
418
|
-
assert
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
449
|
+
assert isinstance(suggestion[0], str)
|
|
450
|
+
assert len(suggestion[0]) > 0
|
|
451
|
+
assert isinstance(suggestion[1], str)
|
|
452
|
+
assert suggestion[1] in model.memoryset.label_names
|
|
453
|
+
|
|
454
|
+
suggestions_response.apply()
|
|
455
|
+
|
|
456
|
+
model.memoryset.refresh()
|
|
457
|
+
assert model.memoryset.length == memoryset_length + 2
|
|
422
458
|
|
|
423
459
|
except Exception as e:
|
|
424
460
|
if "ANTHROPIC_API_KEY" in str(e):
|
|
@@ -427,3 +463,16 @@ def test_action_recommendation(writable_memoryset: LabeledMemoryset):
|
|
|
427
463
|
raise e
|
|
428
464
|
finally:
|
|
429
465
|
ClassificationModel.drop("test_model_for_action")
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
def test_predict_with_prompt(classification_model: ClassificationModel):
|
|
469
|
+
"""Test that prompt parameter is properly passed through to predictions"""
|
|
470
|
+
# Test with an instruction-supporting embedding model if available
|
|
471
|
+
prediction_with_prompt = classification_model.predict(
|
|
472
|
+
"I love this product!", prompt="Represent this text for sentiment classification:"
|
|
473
|
+
)
|
|
474
|
+
prediction_without_prompt = classification_model.predict("I love this product!")
|
|
475
|
+
|
|
476
|
+
# Both should work and return valid predictions
|
|
477
|
+
assert prediction_with_prompt.label is not None
|
|
478
|
+
assert prediction_without_prompt.label is not None
|