orca-sdk 0.0.97__py3-none-any.whl → 0.0.100__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +1 -0
- orca_sdk/_shared/__init__.py +1 -0
- orca_sdk/_utils/analysis_ui.py +5 -5
- orca_sdk/_utils/auth.py +23 -33
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/classification_model.py +188 -126
- orca_sdk/classification_model_test.py +57 -8
- orca_sdk/client.py +3563 -0
- orca_sdk/conftest.py +10 -0
- orca_sdk/credentials.py +59 -21
- orca_sdk/credentials_test.py +20 -0
- orca_sdk/datasource.py +80 -93
- orca_sdk/datasource_test.py +41 -7
- orca_sdk/embedding_model.py +225 -71
- orca_sdk/embedding_model_test.py +27 -36
- orca_sdk/job.py +49 -45
- orca_sdk/job_test.py +16 -0
- orca_sdk/memoryset.py +340 -353
- orca_sdk/memoryset_test.py +7 -11
- orca_sdk/regression_model.py +120 -111
- orca_sdk/regression_model_test.py +15 -0
- orca_sdk/telemetry.py +162 -139
- {orca_sdk-0.0.97.dist-info → orca_sdk-0.0.100.dist-info}/METADATA +2 -5
- orca_sdk-0.0.100.dist-info/RECORD +40 -0
- orca_sdk/_generated_api_client/__init__.py +0 -3
- orca_sdk/_generated_api_client/api/__init__.py +0 -307
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
- orca_sdk/_generated_api_client/api/auth/create_org_plan_auth_org_plan_post.py +0 -168
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
- orca_sdk/_generated_api_client/api/auth/get_org_plan_auth_org_plan_get.py +0 -122
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
- orca_sdk/_generated_api_client/api/auth/update_org_plan_auth_org_plan_put.py +0 -168
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_classification_model_classification_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/delete_classification_model_classification_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/classification_model/evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/predict_label_gpu_classification_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +0 -183
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_from_content_datasource_post.py +0 -224
- orca_sdk/_generated_api_client/api/datasource/create_datasource_from_files_datasource_upload_post.py +0 -229
- orca_sdk/_generated_api_client/api/datasource/create_embedding_evaluation_datasource_name_or_id_embedding_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +0 -172
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/get_embedding_evaluation_datasource_name_or_id_embedding_evaluation_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
- orca_sdk/_generated_api_client/api/datasource/list_embedding_evaluations_datasource_name_or_id_embedding_evaluation_get.py +0 -235
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/analyze_memoryset_memoryset_name_or_id_analysis_post.py +0 -183
- orca_sdk/_generated_api_client/api/memoryset/batch_delete_memoryset_batch_delete_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -210
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -186
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -235
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -180
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -212
- orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +0 -195
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -210
- orca_sdk/_generated_api_client/api/memoryset/suggest_cascading_edits_memoryset_name_or_id_memory_memory_id_cascading_edits_post.py +0 -233
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -216
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -205
- orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +0 -183
- orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +0 -150
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -192
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -161
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/regression_model/create_regression_model_regression_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_regression_model_name_or_id_delete.py +0 -154
- orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/regression_model/list_regression_models_regression_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/regression_model/update_regression_model_regression_model_name_or_id_patch.py +0 -183
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -288
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +0 -168
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
- orca_sdk/_generated_api_client/api/telemetry/explain_prediction_telemetry_prediction_prediction_id_explanation_get.py +0 -182
- orca_sdk/_generated_api_client/api/telemetry/generate_memory_suggestions_telemetry_prediction_prediction_id_memory_suggestions_post.py +0 -239
- orca_sdk/_generated_api_client/api/telemetry/get_action_recommendation_telemetry_prediction_prediction_id_action_get.py +0 -192
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -180
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
- orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +0 -198
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -198
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
- orca_sdk/_generated_api_client/client.py +0 -216
- orca_sdk/_generated_api_client/errors.py +0 -38
- orca_sdk/_generated_api_client/models/__init__.py +0 -345
- orca_sdk/_generated_api_client/models/action_recommendation.py +0 -82
- orca_sdk/_generated_api_client/models/action_recommendation_action.py +0 -11
- orca_sdk/_generated_api_client/models/add_memory_recommendations.py +0 -85
- orca_sdk/_generated_api_client/models/add_memory_suggestion.py +0 -79
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -116
- orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -137
- orca_sdk/_generated_api_client/models/api_key_metadata_scope_item.py +0 -9
- orca_sdk/_generated_api_client/models/base_label_prediction_result.py +0 -130
- orca_sdk/_generated_api_client/models/base_model.py +0 -55
- orca_sdk/_generated_api_client/models/base_score_prediction_result.py +0 -108
- orca_sdk/_generated_api_client/models/body_create_datasource_from_files_datasource_upload_post.py +0 -145
- orca_sdk/_generated_api_client/models/cascade_edit_suggestions_request.py +0 -154
- orca_sdk/_generated_api_client/models/cascading_edit_suggestion.py +0 -92
- orca_sdk/_generated_api_client/models/class_representatives.py +0 -92
- orca_sdk/_generated_api_client/models/classification_evaluation_request.py +0 -148
- orca_sdk/_generated_api_client/models/classification_metrics.py +0 -259
- orca_sdk/_generated_api_client/models/classification_model_metadata.py +0 -227
- orca_sdk/_generated_api_client/models/classification_prediction_request.py +0 -220
- orca_sdk/_generated_api_client/models/clone_memoryset_request.py +0 -210
- orca_sdk/_generated_api_client/models/cluster_metrics.py +0 -78
- orca_sdk/_generated_api_client/models/column_info.py +0 -145
- orca_sdk/_generated_api_client/models/column_type.py +0 -14
- orca_sdk/_generated_api_client/models/constraint_violation_error_response.py +0 -81
- orca_sdk/_generated_api_client/models/constraint_violation_error_response_status_code.py +0 -8
- orca_sdk/_generated_api_client/models/count_predictions_request.py +0 -195
- orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -120
- orca_sdk/_generated_api_client/models/create_api_key_request_scope_item.py +0 -9
- orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -145
- orca_sdk/_generated_api_client/models/create_api_key_response_scope_item.py +0 -9
- orca_sdk/_generated_api_client/models/create_classification_model_request.py +0 -237
- orca_sdk/_generated_api_client/models/create_datasource_from_content_request.py +0 -101
- orca_sdk/_generated_api_client/models/create_memoryset_request.py +0 -365
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +0 -66
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +0 -13
- orca_sdk/_generated_api_client/models/create_org_plan_request.py +0 -73
- orca_sdk/_generated_api_client/models/create_org_plan_request_tier.py +0 -11
- orca_sdk/_generated_api_client/models/create_regression_model_request.py +0 -157
- orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -156
- orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/delete_memorysets_request.py +0 -70
- orca_sdk/_generated_api_client/models/embed_request.py +0 -155
- orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +0 -205
- orca_sdk/_generated_api_client/models/embedding_evaluation_request.py +0 -197
- orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +0 -158
- orca_sdk/_generated_api_client/models/embedding_evaluation_result.py +0 -86
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
- orca_sdk/_generated_api_client/models/embedding_model_result.py +0 -123
- orca_sdk/_generated_api_client/models/evaluation_response.py +0 -153
- orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +0 -140
- orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +0 -140
- orca_sdk/_generated_api_client/models/feedback_metrics.py +0 -85
- orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
- orca_sdk/_generated_api_client/models/filter_item.py +0 -239
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -17
- orca_sdk/_generated_api_client/models/filter_item_field_type_1_item_type_0.py +0 -8
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_0.py +0 -8
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -22
- orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
- orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/http_validation_error.py +0 -86
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -81
- orca_sdk/_generated_api_client/models/internal_server_error_response_status_code.py +0 -8
- orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -210
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -288
- orca_sdk/_generated_api_client/models/labeled_memory.py +0 -186
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -128
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -194
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +0 -207
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_feedback_metrics.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
- orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -319
- orca_sdk/_generated_api_client/models/lookup_request.py +0 -101
- orca_sdk/_generated_api_client/models/lookup_score_metrics.py +0 -94
- orca_sdk/_generated_api_client/models/memory_metrics.py +0 -263
- orca_sdk/_generated_api_client/models/memory_type.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -245
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -105
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -182
- orca_sdk/_generated_api_client/models/memoryset_class_patterns_analysis_config.py +0 -79
- orca_sdk/_generated_api_client/models/memoryset_class_patterns_metrics.py +0 -138
- orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config.py +0 -202
- orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config_clustering_method.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config_partitioning_method.py +0 -10
- orca_sdk/_generated_api_client/models/memoryset_cluster_metrics.py +0 -100
- orca_sdk/_generated_api_client/models/memoryset_duplicate_analysis_config.py +0 -70
- orca_sdk/_generated_api_client/models/memoryset_duplicate_metrics.py +0 -70
- orca_sdk/_generated_api_client/models/memoryset_label_analysis_config.py +0 -70
- orca_sdk/_generated_api_client/models/memoryset_label_metrics.py +0 -116
- orca_sdk/_generated_api_client/models/memoryset_metadata.py +0 -333
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +0 -55
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +0 -13
- orca_sdk/_generated_api_client/models/memoryset_metrics.py +0 -265
- orca_sdk/_generated_api_client/models/memoryset_neighbor_analysis_config.py +0 -83
- orca_sdk/_generated_api_client/models/memoryset_neighbor_metrics.py +0 -76
- orca_sdk/_generated_api_client/models/memoryset_neighbor_metrics_lookup_score_metrics.py +0 -68
- orca_sdk/_generated_api_client/models/memoryset_projection_analysis_config.py +0 -79
- orca_sdk/_generated_api_client/models/memoryset_projection_metrics.py +0 -55
- orca_sdk/_generated_api_client/models/memoryset_update.py +0 -121
- orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -99
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -23
- orca_sdk/_generated_api_client/models/not_found_error_response_status_code.py +0 -8
- orca_sdk/_generated_api_client/models/org_plan.py +0 -99
- orca_sdk/_generated_api_client/models/org_plan_tier.py +0 -11
- orca_sdk/_generated_api_client/models/paginated_task.py +0 -108
- orca_sdk/_generated_api_client/models/paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py +0 -135
- orca_sdk/_generated_api_client/models/pr_curve.py +0 -86
- orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
- orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_0.py +0 -10
- orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_1.py +0 -9
- orca_sdk/_generated_api_client/models/predictive_model_update.py +0 -111
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -115
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -17
- orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
- orca_sdk/_generated_api_client/models/rar_head_type.py +0 -8
- orca_sdk/_generated_api_client/models/regression_evaluation_request.py +0 -148
- orca_sdk/_generated_api_client/models/regression_metrics.py +0 -172
- orca_sdk/_generated_api_client/models/regression_model_metadata.py +0 -191
- orca_sdk/_generated_api_client/models/regression_prediction_request.py +0 -195
- orca_sdk/_generated_api_client/models/roc_curve.py +0 -86
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +0 -196
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +0 -252
- orca_sdk/_generated_api_client/models/scored_memory.py +0 -172
- orca_sdk/_generated_api_client/models/scored_memory_insert.py +0 -128
- orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_lookup.py +0 -180
- orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +0 -193
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -81
- orca_sdk/_generated_api_client/models/service_unavailable_error_response_status_code.py +0 -8
- orca_sdk/_generated_api_client/models/task.py +0 -198
- orca_sdk/_generated_api_client/models/task_status.py +0 -14
- orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
- orca_sdk/_generated_api_client/models/telemetry_field_type_0_item_type_0.py +0 -8
- orca_sdk/_generated_api_client/models/telemetry_field_type_0_item_type_2.py +0 -9
- orca_sdk/_generated_api_client/models/telemetry_field_type_1_item_type_0.py +0 -8
- orca_sdk/_generated_api_client/models/telemetry_field_type_1_item_type_1.py +0 -8
- orca_sdk/_generated_api_client/models/telemetry_filter_item.py +0 -217
- orca_sdk/_generated_api_client/models/telemetry_filter_item_op.py +0 -15
- orca_sdk/_generated_api_client/models/telemetry_memories_request.py +0 -181
- orca_sdk/_generated_api_client/models/telemetry_sort_options.py +0 -185
- orca_sdk/_generated_api_client/models/telemetry_sort_options_direction.py +0 -9
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -73
- orca_sdk/_generated_api_client/models/unauthenticated_error_response_status_code.py +0 -8
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -81
- orca_sdk/_generated_api_client/models/unauthorized_error_response_status_code.py +0 -8
- orca_sdk/_generated_api_client/models/update_org_plan_request.py +0 -73
- orca_sdk/_generated_api_client/models/update_org_plan_request_tier.py +0 -11
- orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -133
- orca_sdk/_generated_api_client/models/validation_error.py +0 -99
- orca_sdk/_generated_api_client/py.typed +0 -1
- orca_sdk/_generated_api_client/types.py +0 -56
- orca_sdk-0.0.97.dist-info/RECORD +0 -309
- {orca_sdk-0.0.97.dist-info → orca_sdk-0.0.100.dist-info}/WHEEL +0 -0
orca_sdk/memoryset_test.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import os
|
|
2
1
|
import random
|
|
3
2
|
from uuid import uuid4
|
|
4
3
|
|
|
@@ -364,17 +363,14 @@ def test_clone_memoryset(readonly_memoryset: LabeledMemoryset):
|
|
|
364
363
|
|
|
365
364
|
|
|
366
365
|
def test_embedding_evaluation(eval_datasource: Datasource):
|
|
367
|
-
|
|
366
|
+
results = LabeledMemoryset.run_embedding_evaluation(
|
|
368
367
|
eval_datasource, embedding_models=["CDE_SMALL"], neighbor_count=2
|
|
369
368
|
)
|
|
370
|
-
assert
|
|
371
|
-
assert
|
|
372
|
-
assert
|
|
373
|
-
assert
|
|
374
|
-
assert
|
|
375
|
-
assert response["evaluation_results"][0] is not None
|
|
376
|
-
assert response["evaluation_results"][0]["embedding_model_name"] == "CDE_SMALL"
|
|
377
|
-
assert response["evaluation_results"][0]["embedding_model_path"] == "OrcaDB/cde-small-v1"
|
|
369
|
+
assert isinstance(results, list)
|
|
370
|
+
assert len(results) == 1
|
|
371
|
+
assert results[0] is not None
|
|
372
|
+
assert results[0]["embedding_model_name"] == "CDE_SMALL"
|
|
373
|
+
assert results[0]["embedding_model_path"] == "OrcaDB/cde-small-v1"
|
|
378
374
|
|
|
379
375
|
|
|
380
376
|
@pytest.fixture(scope="function")
|
|
@@ -427,7 +423,7 @@ def test_get_cascading_edits_suggestions(writable_memoryset: LabeledMemoryset):
|
|
|
427
423
|
|
|
428
424
|
# Validate the suggestions
|
|
429
425
|
assert len(suggestions) == 1
|
|
430
|
-
assert suggestions[0]
|
|
426
|
+
assert suggestions[0]["neighbor"]["value"] == mislabeled_soup_text
|
|
431
427
|
|
|
432
428
|
|
|
433
429
|
def test_analyze_invalid_analysis_name(readonly_memoryset: LabeledMemoryset):
|
orca_sdk/regression_model.py
CHANGED
|
@@ -1,53 +1,29 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
|
-
import os
|
|
5
4
|
from contextlib import contextmanager
|
|
6
5
|
from datetime import datetime
|
|
7
6
|
from typing import Any, Generator, Iterable, Literal, cast, overload
|
|
8
|
-
from uuid import UUID
|
|
9
7
|
|
|
10
|
-
import numpy as np
|
|
11
8
|
from datasets import Dataset
|
|
12
9
|
|
|
13
|
-
from .
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
evaluate_regression_model,
|
|
17
|
-
get_regression_model,
|
|
18
|
-
get_regression_model_evaluation,
|
|
19
|
-
list_predictions,
|
|
20
|
-
list_regression_models,
|
|
21
|
-
predict_score_gpu,
|
|
22
|
-
record_prediction_feedback,
|
|
23
|
-
update_regression_model,
|
|
24
|
-
)
|
|
25
|
-
from ._generated_api_client.models import (
|
|
26
|
-
CreateRegressionModelRequest,
|
|
27
|
-
ListPredictionsRequest,
|
|
28
|
-
)
|
|
29
|
-
from ._generated_api_client.models import (
|
|
30
|
-
PredictionSortItemItemType0 as PredictionSortColumns,
|
|
31
|
-
)
|
|
32
|
-
from ._generated_api_client.models import (
|
|
33
|
-
PredictionSortItemItemType1 as PredictionSortDirection,
|
|
34
|
-
)
|
|
35
|
-
from ._generated_api_client.models import (
|
|
10
|
+
from ._shared.metrics import RegressionMetrics, calculate_regression_metrics
|
|
11
|
+
from ._utils.common import UNSET, CreateMode, DropMode
|
|
12
|
+
from .client import (
|
|
36
13
|
PredictiveModelUpdate,
|
|
37
14
|
RARHeadType,
|
|
38
|
-
RegressionEvaluationRequest,
|
|
39
15
|
RegressionModelMetadata,
|
|
40
|
-
|
|
41
|
-
ScorePredictionWithMemoriesAndFeedback,
|
|
16
|
+
orca_api,
|
|
42
17
|
)
|
|
43
|
-
from ._generated_api_client.types import UNSET as CLIENT_UNSET
|
|
44
|
-
from ._generated_api_client.types import Response
|
|
45
|
-
from ._shared.metrics import RegressionMetrics, calculate_regression_metrics
|
|
46
|
-
from ._utils.common import UNSET, CreateMode, DropMode
|
|
47
18
|
from .datasource import Datasource
|
|
48
19
|
from .job import Job
|
|
49
20
|
from .memoryset import ScoredMemoryset
|
|
50
|
-
from .telemetry import
|
|
21
|
+
from .telemetry import (
|
|
22
|
+
RegressionPrediction,
|
|
23
|
+
TelemetryMode,
|
|
24
|
+
_get_telemetry_config,
|
|
25
|
+
_parse_feedback,
|
|
26
|
+
)
|
|
51
27
|
|
|
52
28
|
logger = logging.getLogger(__name__)
|
|
53
29
|
|
|
@@ -86,17 +62,17 @@ class RegressionModel:
|
|
|
86
62
|
|
|
87
63
|
def __init__(self, metadata: RegressionModelMetadata):
|
|
88
64
|
# for internal use only, do not document
|
|
89
|
-
self.id = metadata
|
|
90
|
-
self.name = metadata
|
|
91
|
-
self.description = metadata
|
|
92
|
-
self.memoryset = ScoredMemoryset.open(metadata
|
|
93
|
-
self.head_type = metadata
|
|
94
|
-
self.memory_lookup_count = metadata
|
|
95
|
-
self.version = metadata
|
|
96
|
-
self.locked = metadata
|
|
97
|
-
self.created_at = metadata
|
|
98
|
-
self.updated_at = metadata
|
|
99
|
-
self.memoryset_id = metadata
|
|
65
|
+
self.id = metadata["id"]
|
|
66
|
+
self.name = metadata["name"]
|
|
67
|
+
self.description = metadata["description"]
|
|
68
|
+
self.memoryset = ScoredMemoryset.open(metadata["memoryset_id"])
|
|
69
|
+
self.head_type = metadata["head_type"]
|
|
70
|
+
self.memory_lookup_count = metadata["memory_lookup_count"]
|
|
71
|
+
self.version = metadata["version"]
|
|
72
|
+
self.locked = metadata["locked"]
|
|
73
|
+
self.created_at = datetime.fromisoformat(metadata["created_at"])
|
|
74
|
+
self.updated_at = datetime.fromisoformat(metadata["updated_at"])
|
|
75
|
+
self.memoryset_id = metadata["memoryset_id"]
|
|
100
76
|
|
|
101
77
|
self._memoryset_override_id = None
|
|
102
78
|
self._last_prediction = None
|
|
@@ -178,13 +154,14 @@ class RegressionModel:
|
|
|
178
154
|
|
|
179
155
|
return existing
|
|
180
156
|
|
|
181
|
-
metadata =
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
157
|
+
metadata = orca_api.POST(
|
|
158
|
+
"/regression_model",
|
|
159
|
+
json={
|
|
160
|
+
"name": name,
|
|
161
|
+
"memoryset_name_or_id": memoryset.id,
|
|
162
|
+
"memory_lookup_count": memory_lookup_count,
|
|
163
|
+
"description": description,
|
|
164
|
+
},
|
|
188
165
|
)
|
|
189
166
|
return cls(metadata)
|
|
190
167
|
|
|
@@ -202,7 +179,7 @@ class RegressionModel:
|
|
|
202
179
|
Raises:
|
|
203
180
|
LookupError: If the regression model does not exist
|
|
204
181
|
"""
|
|
205
|
-
return cls(
|
|
182
|
+
return cls(orca_api.GET("/regression_model/{name_or_id}", params={"name_or_id": name}))
|
|
206
183
|
|
|
207
184
|
@classmethod
|
|
208
185
|
def exists(cls, name_or_id: str) -> bool:
|
|
@@ -229,7 +206,7 @@ class RegressionModel:
|
|
|
229
206
|
Returns:
|
|
230
207
|
List of handles to all regression models in the OrcaCloud
|
|
231
208
|
"""
|
|
232
|
-
return [cls(metadata) for metadata in
|
|
209
|
+
return [cls(metadata) for metadata in orca_api.GET("/regression_model")]
|
|
233
210
|
|
|
234
211
|
@classmethod
|
|
235
212
|
def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
|
|
@@ -248,7 +225,7 @@ class RegressionModel:
|
|
|
248
225
|
LookupError: If the regression model does not exist and if_not_exists is `"error"`
|
|
249
226
|
"""
|
|
250
227
|
try:
|
|
251
|
-
|
|
228
|
+
orca_api.DELETE("/regression_model/{name_or_id}", params={"name_or_id": name_or_id})
|
|
252
229
|
logging.info(f"Deleted model {name_or_id}")
|
|
253
230
|
except LookupError:
|
|
254
231
|
if if_not_exists == "error":
|
|
@@ -279,11 +256,12 @@ class RegressionModel:
|
|
|
279
256
|
Lock the model:
|
|
280
257
|
>>> model.set(locked=True)
|
|
281
258
|
"""
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
259
|
+
update: PredictiveModelUpdate = {}
|
|
260
|
+
if description is not UNSET:
|
|
261
|
+
update["description"] = description
|
|
262
|
+
if locked is not UNSET:
|
|
263
|
+
update["locked"] = locked
|
|
264
|
+
orca_api.PATCH("/regression_model/{name_or_id}", params={"name_or_id": self.id}, json=update)
|
|
287
265
|
self.refresh()
|
|
288
266
|
|
|
289
267
|
def lock(self) -> None:
|
|
@@ -300,7 +278,9 @@ class RegressionModel:
|
|
|
300
278
|
value: str,
|
|
301
279
|
expected_scores: float | None = None,
|
|
302
280
|
tags: set[str] | None = None,
|
|
303
|
-
save_telemetry:
|
|
281
|
+
save_telemetry: TelemetryMode = "on",
|
|
282
|
+
prompt: str | None = None,
|
|
283
|
+
use_lookup_cache: bool = True,
|
|
304
284
|
) -> RegressionPrediction: ...
|
|
305
285
|
|
|
306
286
|
@overload
|
|
@@ -309,7 +289,9 @@ class RegressionModel:
|
|
|
309
289
|
value: list[str],
|
|
310
290
|
expected_scores: list[float] | None = None,
|
|
311
291
|
tags: set[str] | None = None,
|
|
312
|
-
save_telemetry:
|
|
292
|
+
save_telemetry: TelemetryMode = "on",
|
|
293
|
+
prompt: str | None = None,
|
|
294
|
+
use_lookup_cache: bool = True,
|
|
313
295
|
) -> list[RegressionPrediction]: ...
|
|
314
296
|
|
|
315
297
|
# TODO: add filter support
|
|
@@ -318,7 +300,9 @@ class RegressionModel:
|
|
|
318
300
|
value: str | list[str],
|
|
319
301
|
expected_scores: float | list[float] | None = None,
|
|
320
302
|
tags: set[str] | None = None,
|
|
321
|
-
save_telemetry:
|
|
303
|
+
save_telemetry: TelemetryMode = "on",
|
|
304
|
+
prompt: str | None = None,
|
|
305
|
+
use_lookup_cache: bool = True,
|
|
322
306
|
) -> RegressionPrediction | list[RegressionPrediction]:
|
|
323
307
|
"""
|
|
324
308
|
Make predictions using the regression model.
|
|
@@ -331,6 +315,7 @@ class RegressionModel:
|
|
|
331
315
|
which will save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
|
|
332
316
|
environment variable is set to `"1"`. You can also pass `"sync"` or `"async"` to
|
|
333
317
|
explicitly set the save mode.
|
|
318
|
+
prompt: Optional prompt for instruction-tuned embedding models
|
|
334
319
|
|
|
335
320
|
Returns:
|
|
336
321
|
Single RegressionPrediction or list of RegressionPrediction objects
|
|
@@ -338,35 +323,37 @@ class RegressionModel:
|
|
|
338
323
|
Raises:
|
|
339
324
|
ValueError: If expected_scores length doesn't match value length for batch predictions
|
|
340
325
|
"""
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
326
|
+
telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
|
|
327
|
+
response = orca_api.POST(
|
|
328
|
+
"/gpu/regression_model/{name_or_id}/prediction",
|
|
329
|
+
params={"name_or_id": self.id},
|
|
330
|
+
json={
|
|
331
|
+
"input_values": value if isinstance(value, list) else [value],
|
|
332
|
+
"memoryset_override_name_or_id": self._memoryset_override_id,
|
|
333
|
+
"expected_scores": (
|
|
346
334
|
expected_scores
|
|
347
335
|
if isinstance(expected_scores, list)
|
|
348
336
|
else [expected_scores] if expected_scores is not None else None
|
|
349
337
|
),
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
),
|
|
338
|
+
"tags": list(tags or set()),
|
|
339
|
+
"save_telemetry": telemetry_on,
|
|
340
|
+
"save_telemetry_synchronously": telemetry_sync,
|
|
341
|
+
"prompt": prompt,
|
|
342
|
+
"use_lookup_cache": use_lookup_cache,
|
|
343
|
+
},
|
|
357
344
|
)
|
|
358
345
|
|
|
359
|
-
if
|
|
346
|
+
if telemetry_on and any(p["prediction_id"] is None for p in response):
|
|
360
347
|
raise RuntimeError("Failed to save prediction to database.")
|
|
361
348
|
|
|
362
349
|
predictions = [
|
|
363
350
|
RegressionPrediction(
|
|
364
|
-
prediction_id=prediction
|
|
351
|
+
prediction_id=prediction["prediction_id"],
|
|
365
352
|
label=None,
|
|
366
353
|
label_name=None,
|
|
367
|
-
score=prediction
|
|
368
|
-
confidence=prediction
|
|
369
|
-
anomaly_score=prediction
|
|
354
|
+
score=prediction["score"],
|
|
355
|
+
confidence=prediction["confidence"],
|
|
356
|
+
anomaly_score=prediction["anomaly_score"],
|
|
370
357
|
memoryset=self.memoryset,
|
|
371
358
|
model=self,
|
|
372
359
|
logits=None,
|
|
@@ -383,7 +370,7 @@ class RegressionModel:
|
|
|
383
370
|
limit: int = 100,
|
|
384
371
|
offset: int = 0,
|
|
385
372
|
tag: str | None = None,
|
|
386
|
-
sort: list[tuple[
|
|
373
|
+
sort: list[tuple[Literal["anomaly_score", "confidence", "timestamp"], Literal["asc", "desc"]]] = [],
|
|
387
374
|
) -> list[RegressionPrediction]:
|
|
388
375
|
"""
|
|
389
376
|
Get a list of predictions made by this model
|
|
@@ -411,23 +398,24 @@ class RegressionModel:
|
|
|
411
398
|
>>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
|
|
412
399
|
[RegressionPrediction({score: 4.2, confidence: 0.90, anomaly_score: 0.1, input_value: 'Good service'})]
|
|
413
400
|
"""
|
|
414
|
-
predictions =
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
401
|
+
predictions = orca_api.POST(
|
|
402
|
+
"/telemetry/prediction",
|
|
403
|
+
json={
|
|
404
|
+
"model_id": self.id,
|
|
405
|
+
"limit": limit,
|
|
406
|
+
"offset": offset,
|
|
407
|
+
"sort": [list(sort_item) for sort_item in sort],
|
|
408
|
+
"tag": tag,
|
|
409
|
+
},
|
|
422
410
|
)
|
|
423
411
|
return [
|
|
424
412
|
RegressionPrediction(
|
|
425
|
-
prediction_id=prediction
|
|
413
|
+
prediction_id=prediction["prediction_id"],
|
|
426
414
|
label=None,
|
|
427
415
|
label_name=None,
|
|
428
|
-
score=prediction
|
|
429
|
-
confidence=prediction
|
|
430
|
-
anomaly_score=prediction
|
|
416
|
+
score=prediction["score"],
|
|
417
|
+
confidence=prediction["confidence"],
|
|
418
|
+
anomaly_score=prediction["anomaly_score"],
|
|
431
419
|
memoryset=self.memoryset,
|
|
432
420
|
model=self,
|
|
433
421
|
telemetry=prediction,
|
|
@@ -435,7 +423,7 @@ class RegressionModel:
|
|
|
435
423
|
input_value=None,
|
|
436
424
|
)
|
|
437
425
|
for prediction in predictions
|
|
438
|
-
if
|
|
426
|
+
if "score" in prediction
|
|
439
427
|
]
|
|
440
428
|
|
|
441
429
|
def _evaluate_datasource(
|
|
@@ -447,23 +435,28 @@ class RegressionModel:
|
|
|
447
435
|
tags: set[str] | None,
|
|
448
436
|
background: bool = False,
|
|
449
437
|
) -> RegressionMetrics | Job[RegressionMetrics]:
|
|
450
|
-
response =
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
438
|
+
response = orca_api.POST(
|
|
439
|
+
"/regression_model/{model_name_or_id}/evaluation",
|
|
440
|
+
params={"model_name_or_id": self.id},
|
|
441
|
+
json={
|
|
442
|
+
"datasource_name_or_id": datasource.id,
|
|
443
|
+
"datasource_score_column": score_column,
|
|
444
|
+
"datasource_value_column": value_column,
|
|
445
|
+
"memoryset_override_name_or_id": self._memoryset_override_id,
|
|
446
|
+
"record_telemetry": record_predictions,
|
|
447
|
+
"telemetry_tags": list(tags) if tags else None,
|
|
448
|
+
},
|
|
460
449
|
)
|
|
461
450
|
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
451
|
+
def get_value():
|
|
452
|
+
res = orca_api.GET(
|
|
453
|
+
"/regression_model/{model_name_or_id}/evaluation/{task_id}",
|
|
454
|
+
params={"model_name_or_id": self.id, "task_id": response["task_id"]},
|
|
455
|
+
)
|
|
456
|
+
assert res["result"] is not None
|
|
457
|
+
return RegressionMetrics(**res["result"])
|
|
458
|
+
|
|
459
|
+
job = Job(response["task_id"], get_value)
|
|
467
460
|
return job if background else job.result()
|
|
468
461
|
|
|
469
462
|
def _evaluate_dataset(
|
|
@@ -474,6 +467,7 @@ class RegressionModel:
|
|
|
474
467
|
record_predictions: bool,
|
|
475
468
|
tags: set[str],
|
|
476
469
|
batch_size: int,
|
|
470
|
+
prompt: str | None = None,
|
|
477
471
|
) -> RegressionMetrics:
|
|
478
472
|
predictions = [
|
|
479
473
|
prediction
|
|
@@ -483,6 +477,7 @@ class RegressionModel:
|
|
|
483
477
|
expected_scores=dataset[i : i + batch_size][score_column],
|
|
484
478
|
tags=tags,
|
|
485
479
|
save_telemetry="sync" if record_predictions else "off",
|
|
480
|
+
prompt=prompt,
|
|
486
481
|
)
|
|
487
482
|
]
|
|
488
483
|
|
|
@@ -502,6 +497,7 @@ class RegressionModel:
|
|
|
502
497
|
record_predictions: bool = False,
|
|
503
498
|
tags: set[str] = {"evaluation"},
|
|
504
499
|
batch_size: int = 100,
|
|
500
|
+
prompt: str | None = None,
|
|
505
501
|
background: Literal[True],
|
|
506
502
|
) -> Job[RegressionMetrics]:
|
|
507
503
|
pass
|
|
@@ -516,6 +512,7 @@ class RegressionModel:
|
|
|
516
512
|
record_predictions: bool = False,
|
|
517
513
|
tags: set[str] = {"evaluation"},
|
|
518
514
|
batch_size: int = 100,
|
|
515
|
+
prompt: str | None = None,
|
|
519
516
|
background: Literal[False] = False,
|
|
520
517
|
) -> RegressionMetrics:
|
|
521
518
|
pass
|
|
@@ -529,6 +526,7 @@ class RegressionModel:
|
|
|
529
526
|
record_predictions: bool = False,
|
|
530
527
|
tags: set[str] = {"evaluation"},
|
|
531
528
|
batch_size: int = 100,
|
|
529
|
+
prompt: str | None = None,
|
|
532
530
|
background: bool = False,
|
|
533
531
|
) -> RegressionMetrics | Job[RegressionMetrics]:
|
|
534
532
|
"""
|
|
@@ -541,6 +539,7 @@ class RegressionModel:
|
|
|
541
539
|
record_predictions: Whether to record [`RegressionPrediction`][orca_sdk.telemetry.RegressionPrediction]s for analysis
|
|
542
540
|
tags: Optional tags to add to the recorded [`RegressionPrediction`][orca_sdk.telemetry.RegressionPrediction]s
|
|
543
541
|
batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
|
|
542
|
+
prompt: Optional prompt for instruction-tuned embedding models
|
|
544
543
|
background: Whether to run the operation in the background and return a job handle
|
|
545
544
|
|
|
546
545
|
Returns:
|
|
@@ -554,6 +553,14 @@ class RegressionModel:
|
|
|
554
553
|
r2: 0.8500,
|
|
555
554
|
anomaly_score: 0.3500 ± 0.0500,
|
|
556
555
|
})
|
|
556
|
+
|
|
557
|
+
>>> # Using with an instruction-tuned embedding model
|
|
558
|
+
>>> model.evaluate(dataset,prompt="Represent this review for rating prediction:")
|
|
559
|
+
RegressionMetrics({
|
|
560
|
+
mae: 0.2000,
|
|
561
|
+
rmse: 0.3000,
|
|
562
|
+
r2: 0.9000,
|
|
563
|
+
anomaly_score: 0.3000 ± 0.0400})
|
|
557
564
|
"""
|
|
558
565
|
if isinstance(data, Datasource):
|
|
559
566
|
return self._evaluate_datasource(
|
|
@@ -572,6 +579,7 @@ class RegressionModel:
|
|
|
572
579
|
record_predictions=record_predictions,
|
|
573
580
|
tags=tags,
|
|
574
581
|
batch_size=batch_size,
|
|
582
|
+
prompt=prompt,
|
|
575
583
|
)
|
|
576
584
|
else:
|
|
577
585
|
raise ValueError(f"Invalid data type: {type(data)}")
|
|
@@ -640,8 +648,9 @@ class RegressionModel:
|
|
|
640
648
|
ValueError: If the value does not match previous value types for the category, or is a
|
|
641
649
|
[`float`][float] that is not between `-1.0` and `+1.0`.
|
|
642
650
|
"""
|
|
643
|
-
|
|
644
|
-
|
|
651
|
+
orca_api.PUT(
|
|
652
|
+
"/telemetry/prediction/feedback",
|
|
653
|
+
json=[
|
|
645
654
|
_parse_feedback(f) for f in (cast(list[dict], [feedback]) if isinstance(feedback, dict) else feedback)
|
|
646
655
|
],
|
|
647
656
|
)
|
|
@@ -207,6 +207,21 @@ def test_predict_constraint_violation(scored_memoryset: ScoredMemoryset):
|
|
|
207
207
|
model.predict("test")
|
|
208
208
|
|
|
209
209
|
|
|
210
|
+
def test_predict_with_prompt(regression_model: RegressionModel):
|
|
211
|
+
"""Test that prompt parameter is properly passed through to predictions"""
|
|
212
|
+
# Test with an instruction-supporting embedding model if available
|
|
213
|
+
prediction_with_prompt = regression_model.predict(
|
|
214
|
+
"This product is amazing!", prompt="Represent this text for rating prediction:"
|
|
215
|
+
)
|
|
216
|
+
prediction_without_prompt = regression_model.predict("This product is amazing!")
|
|
217
|
+
|
|
218
|
+
# Both should work and return valid predictions
|
|
219
|
+
assert prediction_with_prompt.score is not None
|
|
220
|
+
assert prediction_without_prompt.score is not None
|
|
221
|
+
assert 0 <= prediction_with_prompt.confidence <= 1
|
|
222
|
+
assert 0 <= prediction_without_prompt.confidence <= 1
|
|
223
|
+
|
|
224
|
+
|
|
210
225
|
def test_record_prediction_feedback(regression_model: RegressionModel):
|
|
211
226
|
predictions = regression_model.predict(["This is excellent!", "This is terrible!"])
|
|
212
227
|
expected_scores = [0.9, 0.1]
|