orca-sdk 0.0.97__py3-none-any.whl → 0.0.100__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +1 -0
- orca_sdk/_shared/__init__.py +1 -0
- orca_sdk/_utils/analysis_ui.py +5 -5
- orca_sdk/_utils/auth.py +23 -33
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/classification_model.py +188 -126
- orca_sdk/classification_model_test.py +57 -8
- orca_sdk/client.py +3563 -0
- orca_sdk/conftest.py +10 -0
- orca_sdk/credentials.py +59 -21
- orca_sdk/credentials_test.py +20 -0
- orca_sdk/datasource.py +80 -93
- orca_sdk/datasource_test.py +41 -7
- orca_sdk/embedding_model.py +225 -71
- orca_sdk/embedding_model_test.py +27 -36
- orca_sdk/job.py +49 -45
- orca_sdk/job_test.py +16 -0
- orca_sdk/memoryset.py +340 -353
- orca_sdk/memoryset_test.py +7 -11
- orca_sdk/regression_model.py +120 -111
- orca_sdk/regression_model_test.py +15 -0
- orca_sdk/telemetry.py +162 -139
- {orca_sdk-0.0.97.dist-info → orca_sdk-0.0.100.dist-info}/METADATA +2 -5
- orca_sdk-0.0.100.dist-info/RECORD +40 -0
- orca_sdk/_generated_api_client/__init__.py +0 -3
- orca_sdk/_generated_api_client/api/__init__.py +0 -307
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
- orca_sdk/_generated_api_client/api/auth/create_org_plan_auth_org_plan_post.py +0 -168
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
- orca_sdk/_generated_api_client/api/auth/get_org_plan_auth_org_plan_get.py +0 -122
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
- orca_sdk/_generated_api_client/api/auth/update_org_plan_auth_org_plan_put.py +0 -168
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_classification_model_classification_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/delete_classification_model_classification_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/classification_model/evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/predict_label_gpu_classification_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +0 -183
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_from_content_datasource_post.py +0 -224
- orca_sdk/_generated_api_client/api/datasource/create_datasource_from_files_datasource_upload_post.py +0 -229
- orca_sdk/_generated_api_client/api/datasource/create_embedding_evaluation_datasource_name_or_id_embedding_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +0 -172
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/get_embedding_evaluation_datasource_name_or_id_embedding_evaluation_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
- orca_sdk/_generated_api_client/api/datasource/list_embedding_evaluations_datasource_name_or_id_embedding_evaluation_get.py +0 -235
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/analyze_memoryset_memoryset_name_or_id_analysis_post.py +0 -183
- orca_sdk/_generated_api_client/api/memoryset/batch_delete_memoryset_batch_delete_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -210
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -186
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -235
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -180
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -212
- orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +0 -195
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -210
- orca_sdk/_generated_api_client/api/memoryset/suggest_cascading_edits_memoryset_name_or_id_memory_memory_id_cascading_edits_post.py +0 -233
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -216
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -205
- orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +0 -183
- orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +0 -150
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -192
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -161
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/regression_model/create_regression_model_regression_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_regression_model_name_or_id_delete.py +0 -154
- orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/regression_model/list_regression_models_regression_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/regression_model/update_regression_model_regression_model_name_or_id_patch.py +0 -183
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -288
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +0 -168
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
- orca_sdk/_generated_api_client/api/telemetry/explain_prediction_telemetry_prediction_prediction_id_explanation_get.py +0 -182
- orca_sdk/_generated_api_client/api/telemetry/generate_memory_suggestions_telemetry_prediction_prediction_id_memory_suggestions_post.py +0 -239
- orca_sdk/_generated_api_client/api/telemetry/get_action_recommendation_telemetry_prediction_prediction_id_action_get.py +0 -192
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -180
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
- orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +0 -198
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -198
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
- orca_sdk/_generated_api_client/client.py +0 -216
- orca_sdk/_generated_api_client/errors.py +0 -38
- orca_sdk/_generated_api_client/models/__init__.py +0 -345
- orca_sdk/_generated_api_client/models/action_recommendation.py +0 -82
- orca_sdk/_generated_api_client/models/action_recommendation_action.py +0 -11
- orca_sdk/_generated_api_client/models/add_memory_recommendations.py +0 -85
- orca_sdk/_generated_api_client/models/add_memory_suggestion.py +0 -79
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -116
- orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -137
- orca_sdk/_generated_api_client/models/api_key_metadata_scope_item.py +0 -9
- orca_sdk/_generated_api_client/models/base_label_prediction_result.py +0 -130
- orca_sdk/_generated_api_client/models/base_model.py +0 -55
- orca_sdk/_generated_api_client/models/base_score_prediction_result.py +0 -108
- orca_sdk/_generated_api_client/models/body_create_datasource_from_files_datasource_upload_post.py +0 -145
- orca_sdk/_generated_api_client/models/cascade_edit_suggestions_request.py +0 -154
- orca_sdk/_generated_api_client/models/cascading_edit_suggestion.py +0 -92
- orca_sdk/_generated_api_client/models/class_representatives.py +0 -92
- orca_sdk/_generated_api_client/models/classification_evaluation_request.py +0 -148
- orca_sdk/_generated_api_client/models/classification_metrics.py +0 -259
- orca_sdk/_generated_api_client/models/classification_model_metadata.py +0 -227
- orca_sdk/_generated_api_client/models/classification_prediction_request.py +0 -220
- orca_sdk/_generated_api_client/models/clone_memoryset_request.py +0 -210
- orca_sdk/_generated_api_client/models/cluster_metrics.py +0 -78
- orca_sdk/_generated_api_client/models/column_info.py +0 -145
- orca_sdk/_generated_api_client/models/column_type.py +0 -14
- orca_sdk/_generated_api_client/models/constraint_violation_error_response.py +0 -81
- orca_sdk/_generated_api_client/models/constraint_violation_error_response_status_code.py +0 -8
- orca_sdk/_generated_api_client/models/count_predictions_request.py +0 -195
- orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -120
- orca_sdk/_generated_api_client/models/create_api_key_request_scope_item.py +0 -9
- orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -145
- orca_sdk/_generated_api_client/models/create_api_key_response_scope_item.py +0 -9
- orca_sdk/_generated_api_client/models/create_classification_model_request.py +0 -237
- orca_sdk/_generated_api_client/models/create_datasource_from_content_request.py +0 -101
- orca_sdk/_generated_api_client/models/create_memoryset_request.py +0 -365
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +0 -66
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +0 -13
- orca_sdk/_generated_api_client/models/create_org_plan_request.py +0 -73
- orca_sdk/_generated_api_client/models/create_org_plan_request_tier.py +0 -11
- orca_sdk/_generated_api_client/models/create_regression_model_request.py +0 -157
- orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -156
- orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/delete_memorysets_request.py +0 -70
- orca_sdk/_generated_api_client/models/embed_request.py +0 -155
- orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +0 -205
- orca_sdk/_generated_api_client/models/embedding_evaluation_request.py +0 -197
- orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +0 -158
- orca_sdk/_generated_api_client/models/embedding_evaluation_result.py +0 -86
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
- orca_sdk/_generated_api_client/models/embedding_model_result.py +0 -123
- orca_sdk/_generated_api_client/models/evaluation_response.py +0 -153
- orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +0 -140
- orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +0 -140
- orca_sdk/_generated_api_client/models/feedback_metrics.py +0 -85
- orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
- orca_sdk/_generated_api_client/models/filter_item.py +0 -239
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -17
- orca_sdk/_generated_api_client/models/filter_item_field_type_1_item_type_0.py +0 -8
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_0.py +0 -8
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -22
- orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
- orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/http_validation_error.py +0 -86
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -81
- orca_sdk/_generated_api_client/models/internal_server_error_response_status_code.py +0 -8
- orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -210
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -288
- orca_sdk/_generated_api_client/models/labeled_memory.py +0 -186
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -128
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -194
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +0 -207
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_feedback_metrics.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
- orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -319
- orca_sdk/_generated_api_client/models/lookup_request.py +0 -101
- orca_sdk/_generated_api_client/models/lookup_score_metrics.py +0 -94
- orca_sdk/_generated_api_client/models/memory_metrics.py +0 -263
- orca_sdk/_generated_api_client/models/memory_type.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -245
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -105
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -182
- orca_sdk/_generated_api_client/models/memoryset_class_patterns_analysis_config.py +0 -79
- orca_sdk/_generated_api_client/models/memoryset_class_patterns_metrics.py +0 -138
- orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config.py +0 -202
- orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config_clustering_method.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config_partitioning_method.py +0 -10
- orca_sdk/_generated_api_client/models/memoryset_cluster_metrics.py +0 -100
- orca_sdk/_generated_api_client/models/memoryset_duplicate_analysis_config.py +0 -70
- orca_sdk/_generated_api_client/models/memoryset_duplicate_metrics.py +0 -70
- orca_sdk/_generated_api_client/models/memoryset_label_analysis_config.py +0 -70
- orca_sdk/_generated_api_client/models/memoryset_label_metrics.py +0 -116
- orca_sdk/_generated_api_client/models/memoryset_metadata.py +0 -333
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +0 -55
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +0 -13
- orca_sdk/_generated_api_client/models/memoryset_metrics.py +0 -265
- orca_sdk/_generated_api_client/models/memoryset_neighbor_analysis_config.py +0 -83
- orca_sdk/_generated_api_client/models/memoryset_neighbor_metrics.py +0 -76
- orca_sdk/_generated_api_client/models/memoryset_neighbor_metrics_lookup_score_metrics.py +0 -68
- orca_sdk/_generated_api_client/models/memoryset_projection_analysis_config.py +0 -79
- orca_sdk/_generated_api_client/models/memoryset_projection_metrics.py +0 -55
- orca_sdk/_generated_api_client/models/memoryset_update.py +0 -121
- orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -99
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -23
- orca_sdk/_generated_api_client/models/not_found_error_response_status_code.py +0 -8
- orca_sdk/_generated_api_client/models/org_plan.py +0 -99
- orca_sdk/_generated_api_client/models/org_plan_tier.py +0 -11
- orca_sdk/_generated_api_client/models/paginated_task.py +0 -108
- orca_sdk/_generated_api_client/models/paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py +0 -135
- orca_sdk/_generated_api_client/models/pr_curve.py +0 -86
- orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
- orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_0.py +0 -10
- orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_1.py +0 -9
- orca_sdk/_generated_api_client/models/predictive_model_update.py +0 -111
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -115
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -17
- orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
- orca_sdk/_generated_api_client/models/rar_head_type.py +0 -8
- orca_sdk/_generated_api_client/models/regression_evaluation_request.py +0 -148
- orca_sdk/_generated_api_client/models/regression_metrics.py +0 -172
- orca_sdk/_generated_api_client/models/regression_model_metadata.py +0 -191
- orca_sdk/_generated_api_client/models/regression_prediction_request.py +0 -195
- orca_sdk/_generated_api_client/models/roc_curve.py +0 -86
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +0 -196
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +0 -252
- orca_sdk/_generated_api_client/models/scored_memory.py +0 -172
- orca_sdk/_generated_api_client/models/scored_memory_insert.py +0 -128
- orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_lookup.py +0 -180
- orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +0 -193
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -81
- orca_sdk/_generated_api_client/models/service_unavailable_error_response_status_code.py +0 -8
- orca_sdk/_generated_api_client/models/task.py +0 -198
- orca_sdk/_generated_api_client/models/task_status.py +0 -14
- orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
- orca_sdk/_generated_api_client/models/telemetry_field_type_0_item_type_0.py +0 -8
- orca_sdk/_generated_api_client/models/telemetry_field_type_0_item_type_2.py +0 -9
- orca_sdk/_generated_api_client/models/telemetry_field_type_1_item_type_0.py +0 -8
- orca_sdk/_generated_api_client/models/telemetry_field_type_1_item_type_1.py +0 -8
- orca_sdk/_generated_api_client/models/telemetry_filter_item.py +0 -217
- orca_sdk/_generated_api_client/models/telemetry_filter_item_op.py +0 -15
- orca_sdk/_generated_api_client/models/telemetry_memories_request.py +0 -181
- orca_sdk/_generated_api_client/models/telemetry_sort_options.py +0 -185
- orca_sdk/_generated_api_client/models/telemetry_sort_options_direction.py +0 -9
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -73
- orca_sdk/_generated_api_client/models/unauthenticated_error_response_status_code.py +0 -8
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -81
- orca_sdk/_generated_api_client/models/unauthorized_error_response_status_code.py +0 -8
- orca_sdk/_generated_api_client/models/update_org_plan_request.py +0 -73
- orca_sdk/_generated_api_client/models/update_org_plan_request_tier.py +0 -11
- orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -133
- orca_sdk/_generated_api_client/models/validation_error.py +0 -99
- orca_sdk/_generated_api_client/py.typed +0 -1
- orca_sdk/_generated_api_client/types.py +0 -56
- orca_sdk-0.0.97.dist-info/RECORD +0 -309
- {orca_sdk-0.0.97.dist-info → orca_sdk-0.0.100.dist-info}/WHEEL +0 -0
orca_sdk/embedding_model.py
CHANGED
|
@@ -2,28 +2,20 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from abc import abstractmethod
|
|
4
4
|
from datetime import datetime
|
|
5
|
-
from typing import TYPE_CHECKING, Literal, Sequence, cast, overload
|
|
6
|
-
|
|
7
|
-
from .
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
embed_with_pretrained_model_gpu,
|
|
12
|
-
get_finetuned_embedding_model,
|
|
13
|
-
get_pretrained_embedding_model,
|
|
14
|
-
list_finetuned_embedding_models,
|
|
15
|
-
list_pretrained_embedding_models,
|
|
16
|
-
)
|
|
17
|
-
from ._generated_api_client.models import (
|
|
5
|
+
from typing import TYPE_CHECKING, Literal, Sequence, cast, get_args, overload
|
|
6
|
+
|
|
7
|
+
from ._shared.metrics import ClassificationMetrics, RegressionMetrics
|
|
8
|
+
from ._utils.common import UNSET, CreateMode, DropMode
|
|
9
|
+
from .client import (
|
|
10
|
+
EmbeddingEvaluationRequest,
|
|
18
11
|
EmbeddingFinetuningMethod,
|
|
19
12
|
EmbedRequest,
|
|
20
13
|
FinetunedEmbeddingModelMetadata,
|
|
21
14
|
FinetuneEmbeddingModelRequest,
|
|
22
|
-
FinetuneEmbeddingModelRequestTrainingArgs,
|
|
23
15
|
PretrainedEmbeddingModelMetadata,
|
|
24
16
|
PretrainedEmbeddingModelName,
|
|
17
|
+
orca_api,
|
|
25
18
|
)
|
|
26
|
-
from ._utils.common import UNSET, CreateMode, DropMode
|
|
27
19
|
from .datasource import Datasource
|
|
28
20
|
from .job import Job, Status
|
|
29
21
|
|
|
@@ -32,7 +24,6 @@ if TYPE_CHECKING:
|
|
|
32
24
|
|
|
33
25
|
|
|
34
26
|
class _EmbeddingModel:
|
|
35
|
-
name: str
|
|
36
27
|
embedding_dim: int
|
|
37
28
|
max_seq_length: int
|
|
38
29
|
uses_context: bool
|
|
@@ -41,7 +32,6 @@ class _EmbeddingModel:
|
|
|
41
32
|
def __init__(
|
|
42
33
|
self, *, name: str, embedding_dim: int, max_seq_length: int, uses_context: bool, supports_instructions: bool
|
|
43
34
|
):
|
|
44
|
-
self.name = name
|
|
45
35
|
self.embedding_dim = embedding_dim
|
|
46
36
|
self.max_seq_length = max_seq_length
|
|
47
37
|
self.uses_context = uses_context
|
|
@@ -56,8 +46,10 @@ class _EmbeddingModel:
|
|
|
56
46
|
"""Get error message for instruction not supported"""
|
|
57
47
|
if isinstance(self, FinetunedEmbeddingModel):
|
|
58
48
|
return f"Model {self.name} does not support instructions. Instruction-following is only supported by models based on instruction-supporting models."
|
|
59
|
-
|
|
49
|
+
elif isinstance(self, PretrainedEmbeddingModel):
|
|
60
50
|
return f"Model {self.name} does not support instructions. Instruction-following is only supported by instruction-supporting models."
|
|
51
|
+
else:
|
|
52
|
+
raise ValueError("Invalid embedding model")
|
|
61
53
|
|
|
62
54
|
@overload
|
|
63
55
|
def embed(self, value: str, max_seq_length: int | None = None, prompt: str | None = None) -> list[float]:
|
|
@@ -85,17 +77,165 @@ class _EmbeddingModel:
|
|
|
85
77
|
values, or a list of floats representing the embedding for the single value if the
|
|
86
78
|
input is a single value
|
|
87
79
|
"""
|
|
88
|
-
|
|
89
|
-
values
|
|
90
|
-
|
|
80
|
+
payload: EmbedRequest = {
|
|
81
|
+
"values": value if isinstance(value, list) else [value],
|
|
82
|
+
"max_seq_length": max_seq_length,
|
|
83
|
+
"prompt": prompt,
|
|
84
|
+
}
|
|
91
85
|
if isinstance(self, PretrainedEmbeddingModel):
|
|
92
|
-
embeddings =
|
|
86
|
+
embeddings = orca_api.POST(
|
|
87
|
+
"/gpu/pretrained_embedding_model/{model_name}/embedding",
|
|
88
|
+
params={"model_name": cast(PretrainedEmbeddingModelName, self.name)},
|
|
89
|
+
json=payload,
|
|
90
|
+
timeout=30, # may be slow in case of cold start
|
|
91
|
+
)
|
|
93
92
|
elif isinstance(self, FinetunedEmbeddingModel):
|
|
94
|
-
embeddings =
|
|
93
|
+
embeddings = orca_api.POST(
|
|
94
|
+
"/gpu/finetuned_embedding_model/{name_or_id}/embedding",
|
|
95
|
+
params={"name_or_id": self.id},
|
|
96
|
+
json=payload,
|
|
97
|
+
timeout=30, # may be slow in case of cold start
|
|
98
|
+
)
|
|
95
99
|
else:
|
|
96
100
|
raise ValueError("Invalid embedding model")
|
|
97
101
|
return embeddings if isinstance(value, list) else embeddings[0]
|
|
98
102
|
|
|
103
|
+
@overload
|
|
104
|
+
def evaluate(
|
|
105
|
+
self,
|
|
106
|
+
datasource: Datasource,
|
|
107
|
+
*,
|
|
108
|
+
value_column: str = "value",
|
|
109
|
+
label_column: str,
|
|
110
|
+
score_column: None = None,
|
|
111
|
+
eval_datasource: Datasource | None = None,
|
|
112
|
+
subsample: int | None = None,
|
|
113
|
+
neighbor_count: int = 5,
|
|
114
|
+
batch_size: int = 32,
|
|
115
|
+
weigh_memories: bool = True,
|
|
116
|
+
background: Literal[True],
|
|
117
|
+
) -> Job[ClassificationMetrics]:
|
|
118
|
+
pass
|
|
119
|
+
|
|
120
|
+
@overload
|
|
121
|
+
def evaluate(
|
|
122
|
+
self,
|
|
123
|
+
datasource: Datasource,
|
|
124
|
+
*,
|
|
125
|
+
value_column: str = "value",
|
|
126
|
+
label_column: str,
|
|
127
|
+
score_column: None = None,
|
|
128
|
+
eval_datasource: Datasource | None = None,
|
|
129
|
+
subsample: int | None = None,
|
|
130
|
+
neighbor_count: int = 5,
|
|
131
|
+
batch_size: int = 32,
|
|
132
|
+
weigh_memories: bool = True,
|
|
133
|
+
background: Literal[False] = False,
|
|
134
|
+
) -> ClassificationMetrics:
|
|
135
|
+
pass
|
|
136
|
+
|
|
137
|
+
@overload
|
|
138
|
+
def evaluate(
|
|
139
|
+
self,
|
|
140
|
+
datasource: Datasource,
|
|
141
|
+
*,
|
|
142
|
+
value_column: str = "value",
|
|
143
|
+
label_column: None = None,
|
|
144
|
+
score_column: str,
|
|
145
|
+
eval_datasource: Datasource | None = None,
|
|
146
|
+
subsample: int | None = None,
|
|
147
|
+
neighbor_count: int = 5,
|
|
148
|
+
batch_size: int = 32,
|
|
149
|
+
weigh_memories: bool = True,
|
|
150
|
+
background: Literal[True],
|
|
151
|
+
) -> Job[RegressionMetrics]:
|
|
152
|
+
pass
|
|
153
|
+
|
|
154
|
+
@overload
|
|
155
|
+
def evaluate(
|
|
156
|
+
self,
|
|
157
|
+
datasource: Datasource,
|
|
158
|
+
*,
|
|
159
|
+
value_column: str = "value",
|
|
160
|
+
label_column: None = None,
|
|
161
|
+
score_column: str,
|
|
162
|
+
eval_datasource: Datasource | None = None,
|
|
163
|
+
subsample: int | None = None,
|
|
164
|
+
neighbor_count: int = 5,
|
|
165
|
+
batch_size: int = 32,
|
|
166
|
+
weigh_memories: bool = True,
|
|
167
|
+
background: Literal[False] = False,
|
|
168
|
+
) -> RegressionMetrics:
|
|
169
|
+
pass
|
|
170
|
+
|
|
171
|
+
def evaluate(
|
|
172
|
+
self,
|
|
173
|
+
datasource: Datasource,
|
|
174
|
+
*,
|
|
175
|
+
value_column: str = "value",
|
|
176
|
+
label_column: str | None = None,
|
|
177
|
+
score_column: str | None = None,
|
|
178
|
+
eval_datasource: Datasource | None = None,
|
|
179
|
+
subsample: int | None = None,
|
|
180
|
+
neighbor_count: int = 5,
|
|
181
|
+
batch_size: int = 32,
|
|
182
|
+
weigh_memories: bool = True,
|
|
183
|
+
background: bool = False,
|
|
184
|
+
) -> (
|
|
185
|
+
ClassificationMetrics
|
|
186
|
+
| RegressionMetrics
|
|
187
|
+
| Job[ClassificationMetrics]
|
|
188
|
+
| Job[RegressionMetrics]
|
|
189
|
+
| Job[ClassificationMetrics | RegressionMetrics]
|
|
190
|
+
):
|
|
191
|
+
"""
|
|
192
|
+
Evaluate the finetuned embedding model
|
|
193
|
+
"""
|
|
194
|
+
payload: EmbeddingEvaluationRequest = {
|
|
195
|
+
"datasource_name_or_id": datasource.id,
|
|
196
|
+
"datasource_label_column": label_column,
|
|
197
|
+
"datasource_value_column": value_column,
|
|
198
|
+
"datasource_score_column": score_column,
|
|
199
|
+
"eval_datasource_name_or_id": eval_datasource.id if eval_datasource is not None else None,
|
|
200
|
+
"subsample": subsample,
|
|
201
|
+
"neighbor_count": neighbor_count,
|
|
202
|
+
"batch_size": batch_size,
|
|
203
|
+
"weigh_memories": weigh_memories,
|
|
204
|
+
}
|
|
205
|
+
if isinstance(self, PretrainedEmbeddingModel):
|
|
206
|
+
response = orca_api.POST(
|
|
207
|
+
"/pretrained_embedding_model/{model_name}/evaluation",
|
|
208
|
+
params={"model_name": self.name},
|
|
209
|
+
json=payload,
|
|
210
|
+
)
|
|
211
|
+
elif isinstance(self, FinetunedEmbeddingModel):
|
|
212
|
+
response = orca_api.POST(
|
|
213
|
+
"/finetuned_embedding_model/{name_or_id}/evaluation",
|
|
214
|
+
params={"name_or_id": self.id},
|
|
215
|
+
json=payload,
|
|
216
|
+
)
|
|
217
|
+
else:
|
|
218
|
+
raise ValueError("Invalid embedding model")
|
|
219
|
+
|
|
220
|
+
def get_result(task_id: str) -> ClassificationMetrics | RegressionMetrics:
|
|
221
|
+
if isinstance(self, PretrainedEmbeddingModel):
|
|
222
|
+
res = orca_api.GET(
|
|
223
|
+
"/pretrained_embedding_model/{model_name}/evaluation/{task_id}",
|
|
224
|
+
params={"model_name": self.name, "task_id": task_id},
|
|
225
|
+
)["result"]
|
|
226
|
+
elif isinstance(self, FinetunedEmbeddingModel):
|
|
227
|
+
res = orca_api.GET(
|
|
228
|
+
"/finetuned_embedding_model/{name_or_id}/evaluation/{task_id}",
|
|
229
|
+
params={"name_or_id": self.id, "task_id": task_id},
|
|
230
|
+
)["result"]
|
|
231
|
+
else:
|
|
232
|
+
raise ValueError("Invalid embedding model")
|
|
233
|
+
assert res is not None
|
|
234
|
+
return RegressionMetrics(**res) if "mse" in res else ClassificationMetrics(**res)
|
|
235
|
+
|
|
236
|
+
job = Job(response["task_id"], lambda: get_result(response["task_id"]))
|
|
237
|
+
return job if background else job.result()
|
|
238
|
+
|
|
99
239
|
|
|
100
240
|
class _ModelDescriptor:
|
|
101
241
|
"""
|
|
@@ -144,7 +284,7 @@ class _ModelDescriptor:
|
|
|
144
284
|
# Load the model on first access
|
|
145
285
|
if self.model is None:
|
|
146
286
|
try:
|
|
147
|
-
self.model = PretrainedEmbeddingModel._get(self.name)
|
|
287
|
+
self.model = PretrainedEmbeddingModel._get(cast(PretrainedEmbeddingModelName, self.name))
|
|
148
288
|
except (KeyError, AttributeError):
|
|
149
289
|
raise AttributeError(f"No embedding model named {self.name}")
|
|
150
290
|
|
|
@@ -205,19 +345,18 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
|
|
|
205
345
|
QWEN2_1_5B = _ModelDescriptor("QWEN2_1_5B")
|
|
206
346
|
BGE_BASE = _ModelDescriptor("BGE_BASE")
|
|
207
347
|
|
|
208
|
-
|
|
348
|
+
name: PretrainedEmbeddingModelName
|
|
209
349
|
|
|
210
350
|
def __init__(self, metadata: PretrainedEmbeddingModelMetadata):
|
|
211
351
|
# for internal use only, do not document
|
|
212
|
-
self.
|
|
213
|
-
|
|
352
|
+
self.name = metadata["name"]
|
|
214
353
|
super().__init__(
|
|
215
|
-
name=metadata
|
|
216
|
-
embedding_dim=metadata
|
|
217
|
-
max_seq_length=metadata
|
|
218
|
-
uses_context=metadata
|
|
354
|
+
name=metadata["name"],
|
|
355
|
+
embedding_dim=metadata["embedding_dim"],
|
|
356
|
+
max_seq_length=metadata["max_seq_length"],
|
|
357
|
+
uses_context=metadata["uses_context"],
|
|
219
358
|
supports_instructions=(
|
|
220
|
-
bool(metadata
|
|
359
|
+
bool(metadata["supports_instructions"]) if "supports_instructions" in metadata else False
|
|
221
360
|
),
|
|
222
361
|
)
|
|
223
362
|
|
|
@@ -235,21 +374,24 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
|
|
|
235
374
|
Returns:
|
|
236
375
|
A list of all pretrained embedding models available in the OrcaCloud
|
|
237
376
|
"""
|
|
238
|
-
return [cls(metadata) for metadata in
|
|
377
|
+
return [cls(metadata) for metadata in orca_api.GET("/pretrained_embedding_model")]
|
|
239
378
|
|
|
240
379
|
_instances: dict[str, PretrainedEmbeddingModel] = {}
|
|
241
380
|
|
|
242
381
|
@classmethod
|
|
243
|
-
def _get(cls, name: PretrainedEmbeddingModelName
|
|
382
|
+
def _get(cls, name: PretrainedEmbeddingModelName) -> PretrainedEmbeddingModel:
|
|
244
383
|
# for internal use only, do not document - we want people to use dot notation to get the model
|
|
245
384
|
cache_key = str(name)
|
|
246
385
|
if cache_key not in cls._instances:
|
|
247
|
-
metadata =
|
|
386
|
+
metadata = orca_api.GET(
|
|
387
|
+
"/pretrained_embedding_model/{model_name}",
|
|
388
|
+
params={"model_name": name},
|
|
389
|
+
)
|
|
248
390
|
cls._instances[cache_key] = cls(metadata)
|
|
249
391
|
return cls._instances[cache_key]
|
|
250
392
|
|
|
251
393
|
@classmethod
|
|
252
|
-
def open(cls, name:
|
|
394
|
+
def open(cls, name: PretrainedEmbeddingModelName) -> PretrainedEmbeddingModel:
|
|
253
395
|
"""
|
|
254
396
|
Open an embedding model by name.
|
|
255
397
|
|
|
@@ -282,7 +424,7 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
|
|
|
282
424
|
Returns:
|
|
283
425
|
True if the pretrained embedding model exists, False otherwise
|
|
284
426
|
"""
|
|
285
|
-
return name in PretrainedEmbeddingModelName
|
|
427
|
+
return name in get_args(PretrainedEmbeddingModelName)
|
|
286
428
|
|
|
287
429
|
@overload
|
|
288
430
|
def finetune(
|
|
@@ -293,7 +435,7 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
|
|
|
293
435
|
eval_datasource: Datasource | None = None,
|
|
294
436
|
label_column: str = "label",
|
|
295
437
|
value_column: str = "value",
|
|
296
|
-
training_method: EmbeddingFinetuningMethod
|
|
438
|
+
training_method: EmbeddingFinetuningMethod = "classification",
|
|
297
439
|
training_args: dict | None = None,
|
|
298
440
|
if_exists: CreateMode = "error",
|
|
299
441
|
background: Literal[True],
|
|
@@ -309,7 +451,7 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
|
|
|
309
451
|
eval_datasource: Datasource | None = None,
|
|
310
452
|
label_column: str = "label",
|
|
311
453
|
value_column: str = "value",
|
|
312
|
-
training_method: EmbeddingFinetuningMethod
|
|
454
|
+
training_method: EmbeddingFinetuningMethod = "classification",
|
|
313
455
|
training_args: dict | None = None,
|
|
314
456
|
if_exists: CreateMode = "error",
|
|
315
457
|
background: Literal[False] = False,
|
|
@@ -324,7 +466,7 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
|
|
|
324
466
|
eval_datasource: Datasource | None = None,
|
|
325
467
|
label_column: str = "label",
|
|
326
468
|
value_column: str = "value",
|
|
327
|
-
training_method: EmbeddingFinetuningMethod
|
|
469
|
+
training_method: EmbeddingFinetuningMethod = "classification",
|
|
328
470
|
training_args: dict | None = None,
|
|
329
471
|
if_exists: CreateMode = "error",
|
|
330
472
|
background: bool = False,
|
|
@@ -364,32 +506,35 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
|
|
|
364
506
|
elif exists and if_exists == "open":
|
|
365
507
|
existing = FinetunedEmbeddingModel.open(name)
|
|
366
508
|
|
|
367
|
-
if existing.base_model_name != self.
|
|
509
|
+
if existing.base_model_name != self.name:
|
|
368
510
|
raise ValueError(f"Finetuned embedding model '{name}' already exists, but with different base model")
|
|
369
511
|
|
|
370
512
|
return existing
|
|
371
513
|
|
|
372
514
|
from .memoryset import LabeledMemoryset
|
|
373
515
|
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
516
|
+
payload: FinetuneEmbeddingModelRequest = {
|
|
517
|
+
"name": name,
|
|
518
|
+
"base_model": self.name,
|
|
519
|
+
"label_column": label_column,
|
|
520
|
+
"value_column": value_column,
|
|
521
|
+
"training_method": training_method,
|
|
522
|
+
"training_args": training_args or {},
|
|
523
|
+
}
|
|
524
|
+
if isinstance(train_datasource, Datasource):
|
|
525
|
+
payload["train_datasource_name_or_id"] = train_datasource.id
|
|
526
|
+
elif isinstance(train_datasource, LabeledMemoryset):
|
|
527
|
+
payload["train_memoryset_name_or_id"] = train_datasource.id
|
|
528
|
+
if eval_datasource is not None:
|
|
529
|
+
payload["eval_datasource_name_or_id"] = eval_datasource.id
|
|
530
|
+
|
|
531
|
+
res = orca_api.POST(
|
|
532
|
+
"/finetuned_embedding_model",
|
|
533
|
+
json=payload,
|
|
389
534
|
)
|
|
390
535
|
job = Job(
|
|
391
|
-
res
|
|
392
|
-
lambda: FinetunedEmbeddingModel.open(res
|
|
536
|
+
res["finetuning_task_id"],
|
|
537
|
+
lambda: FinetunedEmbeddingModel.open(res["id"]),
|
|
393
538
|
)
|
|
394
539
|
return job if background else job.result()
|
|
395
540
|
|
|
@@ -409,23 +554,26 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
|
|
|
409
554
|
"""
|
|
410
555
|
|
|
411
556
|
id: str
|
|
557
|
+
name: str
|
|
412
558
|
created_at: datetime
|
|
413
559
|
updated_at: datetime
|
|
560
|
+
base_model_name: PretrainedEmbeddingModelName
|
|
414
561
|
_status: Status
|
|
415
562
|
|
|
416
563
|
def __init__(self, metadata: FinetunedEmbeddingModelMetadata):
|
|
417
564
|
# for internal use only, do not document
|
|
418
|
-
self.id = metadata
|
|
419
|
-
self.
|
|
420
|
-
self.
|
|
421
|
-
self.
|
|
422
|
-
self.
|
|
565
|
+
self.id = metadata["id"]
|
|
566
|
+
self.name = metadata["name"]
|
|
567
|
+
self.created_at = datetime.fromisoformat(metadata["created_at"])
|
|
568
|
+
self.updated_at = datetime.fromisoformat(metadata["updated_at"])
|
|
569
|
+
self.base_model_name = metadata["base_model"]
|
|
570
|
+
self._status = Status(metadata["finetuning_status"])
|
|
423
571
|
|
|
424
572
|
super().__init__(
|
|
425
|
-
name=metadata
|
|
426
|
-
embedding_dim=metadata
|
|
427
|
-
max_seq_length=metadata
|
|
428
|
-
uses_context=metadata
|
|
573
|
+
name=metadata["name"],
|
|
574
|
+
embedding_dim=metadata["embedding_dim"],
|
|
575
|
+
max_seq_length=metadata["max_seq_length"],
|
|
576
|
+
uses_context=metadata["uses_context"],
|
|
429
577
|
supports_instructions=self.base_model.supports_instructions,
|
|
430
578
|
)
|
|
431
579
|
|
|
@@ -438,7 +586,7 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
|
|
|
438
586
|
f" name: {self.name},\n"
|
|
439
587
|
f" embedding_dim: {self.embedding_dim},\n"
|
|
440
588
|
f" max_seq_length: {self.max_seq_length},\n"
|
|
441
|
-
f" base_model: PretrainedEmbeddingModel.{self.base_model_name
|
|
589
|
+
f" base_model: PretrainedEmbeddingModel.{self.base_model_name}\n"
|
|
442
590
|
"})"
|
|
443
591
|
)
|
|
444
592
|
|
|
@@ -455,7 +603,7 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
|
|
|
455
603
|
Returns:
|
|
456
604
|
A list of all finetuned embedding model handles in the OrcaCloud
|
|
457
605
|
"""
|
|
458
|
-
return [cls(metadata) for metadata in
|
|
606
|
+
return [cls(metadata) for metadata in orca_api.GET("/finetuned_embedding_model")]
|
|
459
607
|
|
|
460
608
|
@classmethod
|
|
461
609
|
def open(cls, name: str) -> FinetunedEmbeddingModel:
|
|
@@ -471,7 +619,10 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
|
|
|
471
619
|
Raises:
|
|
472
620
|
LookupError: If the finetuned embedding model does not exist
|
|
473
621
|
"""
|
|
474
|
-
metadata =
|
|
622
|
+
metadata = orca_api.GET(
|
|
623
|
+
"/finetuned_embedding_model/{name_or_id}",
|
|
624
|
+
params={"name_or_id": name},
|
|
625
|
+
)
|
|
475
626
|
return cls(metadata)
|
|
476
627
|
|
|
477
628
|
@classmethod
|
|
@@ -503,7 +654,10 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
|
|
|
503
654
|
LookupError: If the finetuned embedding model does not exist and `if_not_exists` is `"error"`
|
|
504
655
|
"""
|
|
505
656
|
try:
|
|
506
|
-
|
|
657
|
+
orca_api.DELETE(
|
|
658
|
+
"/finetuned_embedding_model/{name_or_id}",
|
|
659
|
+
params={"name_or_id": name_or_id},
|
|
660
|
+
)
|
|
507
661
|
except (LookupError, RuntimeError):
|
|
508
662
|
if if_not_exists == "error":
|
|
509
663
|
raise
|
orca_sdk/embedding_model_test.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
import logging
|
|
2
|
+
from typing import get_args
|
|
2
3
|
from uuid import uuid4
|
|
3
4
|
|
|
4
5
|
import pytest
|
|
5
6
|
|
|
6
7
|
from .datasource import Datasource
|
|
7
8
|
from .embedding_model import (
|
|
9
|
+
ClassificationMetrics,
|
|
8
10
|
FinetunedEmbeddingModel,
|
|
9
11
|
PretrainedEmbeddingModel,
|
|
10
12
|
PretrainedEmbeddingModelName,
|
|
@@ -30,16 +32,16 @@ def test_open_pretrained_model_unauthenticated(unauthenticated):
|
|
|
30
32
|
|
|
31
33
|
def test_open_pretrained_model_not_found():
|
|
32
34
|
with pytest.raises(LookupError):
|
|
33
|
-
PretrainedEmbeddingModel._get("INVALID_MODEL")
|
|
35
|
+
PretrainedEmbeddingModel._get("INVALID_MODEL") # type: ignore
|
|
34
36
|
|
|
35
37
|
|
|
36
38
|
def test_all_pretrained_models():
|
|
37
39
|
models = PretrainedEmbeddingModel.all()
|
|
38
40
|
assert len(models) > 1
|
|
39
|
-
if len(models) != len(PretrainedEmbeddingModelName):
|
|
41
|
+
if len(models) != len(get_args(PretrainedEmbeddingModelName)):
|
|
40
42
|
logging.warning("Please regenerate the SDK client! Some pretrained model names are not exposed yet.")
|
|
41
43
|
model_names = [m.name for m in models]
|
|
42
|
-
assert all(
|
|
44
|
+
assert all(m in model_names for m in get_args(PretrainedEmbeddingModelName))
|
|
43
45
|
|
|
44
46
|
|
|
45
47
|
def test_embed_text():
|
|
@@ -55,6 +57,13 @@ def test_embed_text_unauthenticated(unauthenticated):
|
|
|
55
57
|
PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
|
|
56
58
|
|
|
57
59
|
|
|
60
|
+
def test_evaluate_pretrained_model(datasource: Datasource):
|
|
61
|
+
metrics = PretrainedEmbeddingModel.GTE_BASE.evaluate(datasource=datasource, label_column="label")
|
|
62
|
+
assert metrics is not None
|
|
63
|
+
assert isinstance(metrics, ClassificationMetrics)
|
|
64
|
+
assert metrics.accuracy > 0.5
|
|
65
|
+
|
|
66
|
+
|
|
58
67
|
@pytest.fixture(scope="session")
|
|
59
68
|
def finetuned_model(datasource) -> FinetunedEmbeddingModel:
|
|
60
69
|
return PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource)
|
|
@@ -83,18 +92,14 @@ def test_finetune_model_with_memoryset(readonly_memoryset: LabeledMemoryset):
|
|
|
83
92
|
|
|
84
93
|
def test_finetune_model_already_exists_error(datasource: Datasource, finetuned_model):
|
|
85
94
|
with pytest.raises(ValueError):
|
|
86
|
-
PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource
|
|
95
|
+
PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource)
|
|
87
96
|
|
|
88
97
|
|
|
89
98
|
def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_model):
|
|
90
99
|
with pytest.raises(ValueError):
|
|
91
|
-
PretrainedEmbeddingModel.GTE_BASE.finetune(
|
|
92
|
-
"test_finetuned_model", datasource, if_exists="open", value_column="text"
|
|
93
|
-
)
|
|
100
|
+
PretrainedEmbeddingModel.GTE_BASE.finetune("test_finetuned_model", datasource, if_exists="open")
|
|
94
101
|
|
|
95
|
-
new_model = PretrainedEmbeddingModel.DISTILBERT.finetune(
|
|
96
|
-
"test_finetuned_model", datasource, if_exists="open", value_column="text"
|
|
97
|
-
)
|
|
102
|
+
new_model = PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource, if_exists="open")
|
|
98
103
|
assert new_model is not None
|
|
99
104
|
assert new_model.name == "test_finetuned_model"
|
|
100
105
|
assert new_model.base_model == PretrainedEmbeddingModel.DISTILBERT
|
|
@@ -105,9 +110,7 @@ def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_
|
|
|
105
110
|
|
|
106
111
|
def test_finetune_model_unauthenticated(unauthenticated, datasource: Datasource):
|
|
107
112
|
with pytest.raises(ValueError, match="Invalid API key"):
|
|
108
|
-
PretrainedEmbeddingModel.DISTILBERT.finetune(
|
|
109
|
-
"test_finetuned_model_unauthenticated", datasource, value_column="text"
|
|
110
|
-
)
|
|
113
|
+
PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model_unauthenticated", datasource)
|
|
111
114
|
|
|
112
115
|
|
|
113
116
|
def test_use_finetuned_model_in_memoryset(datasource: Datasource, finetuned_model: FinetunedEmbeddingModel):
|
|
@@ -166,7 +169,7 @@ def test_drop_finetuned_model(datasource: Datasource):
|
|
|
166
169
|
|
|
167
170
|
def test_drop_finetuned_model_unauthenticated(unauthenticated, datasource: Datasource):
|
|
168
171
|
with pytest.raises(ValueError, match="Invalid API key"):
|
|
169
|
-
PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource
|
|
172
|
+
PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource)
|
|
170
173
|
|
|
171
174
|
|
|
172
175
|
def test_drop_finetuned_model_not_found():
|
|
@@ -181,28 +184,16 @@ def test_drop_finetuned_model_unauthorized(unauthorized, finetuned_model: Finetu
|
|
|
181
184
|
FinetunedEmbeddingModel.drop(finetuned_model.id)
|
|
182
185
|
|
|
183
186
|
|
|
184
|
-
def
|
|
185
|
-
|
|
186
|
-
# Test with an instruction-supporting model
|
|
187
|
-
model = PretrainedEmbeddingModel.open("E5_LARGE")
|
|
188
|
-
|
|
189
|
-
# Verify the model properties
|
|
190
|
-
assert model.supports_instructions
|
|
191
|
-
|
|
192
|
-
# Test that prompt parameter is passed through correctly (orcalib handles the default)
|
|
193
|
-
embeddings_explicit_instruction = model.embed("Hello world", prompt="Represent this sentence for retrieval:")
|
|
194
|
-
embeddings_no_instruction = model.embed("Hello world")
|
|
195
|
-
|
|
196
|
-
# These should be different since one uses a prompt and the other doesn't
|
|
197
|
-
assert embeddings_explicit_instruction != embeddings_no_instruction
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
def test_default_instruction_error_cases():
|
|
201
|
-
"""Test basic embedding model functionality."""
|
|
202
|
-
# Test that model opens correctly and has instruction support information
|
|
203
|
-
model = PretrainedEmbeddingModel.open("GTE_BASE")
|
|
187
|
+
def test_supports_instructions():
|
|
188
|
+
model = PretrainedEmbeddingModel.GTE_BASE
|
|
204
189
|
assert not model.supports_instructions
|
|
205
190
|
|
|
206
|
-
|
|
207
|
-
instruction_model = PretrainedEmbeddingModel.open("E5_LARGE")
|
|
191
|
+
instruction_model = PretrainedEmbeddingModel.BGE_BASE
|
|
208
192
|
assert instruction_model.supports_instructions
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def test_use_explicit_instruction_prompt():
|
|
196
|
+
model = PretrainedEmbeddingModel.BGE_BASE
|
|
197
|
+
assert model.supports_instructions
|
|
198
|
+
input = "Hello world"
|
|
199
|
+
assert model.embed(input, prompt="Represent this sentence for sentiment retrieval:") != model.embed(input)
|