orca-sdk 0.0.96__py3-none-any.whl → 0.0.98__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +2 -5
- orca_sdk/_shared/__init__.py +1 -0
- orca_sdk/_shared/metrics.py +1 -1
- orca_sdk/_utils/analysis_ui.py +5 -5
- orca_sdk/_utils/auth.py +23 -33
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/classification_model.py +188 -126
- orca_sdk/classification_model_test.py +102 -0
- orca_sdk/client.py +3515 -0
- orca_sdk/conftest.py +10 -0
- orca_sdk/credentials.py +73 -21
- orca_sdk/credentials_test.py +20 -0
- orca_sdk/datasource.py +186 -81
- orca_sdk/datasource_test.py +194 -0
- orca_sdk/embedding_model.py +267 -75
- orca_sdk/embedding_model_test.py +32 -14
- orca_sdk/job.py +59 -54
- orca_sdk/job_test.py +50 -0
- orca_sdk/memoryset.py +372 -345
- orca_sdk/memoryset_test.py +7 -11
- orca_sdk/regression_model.py +120 -111
- orca_sdk/regression_model_test.py +15 -0
- orca_sdk/telemetry.py +229 -115
- {orca_sdk-0.0.96.dist-info → orca_sdk-0.0.98.dist-info}/METADATA +19 -5
- orca_sdk-0.0.98.dist-info/RECORD +40 -0
- orca_sdk/_generated_api_client/__init__.py +0 -3
- orca_sdk/_generated_api_client/api/__init__.py +0 -287
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_classification_model_gpu_classification_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/delete_classification_model_classification_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/classification_model/evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/predict_label_gpu_classification_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +0 -183
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
- orca_sdk/_generated_api_client/api/datasource/create_embedding_evaluation_datasource_name_or_id_embedding_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +0 -172
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/get_embedding_evaluation_datasource_name_or_id_embedding_evaluation_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
- orca_sdk/_generated_api_client/api/datasource/list_embedding_evaluations_datasource_name_or_id_embedding_evaluation_get.py +0 -235
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/analyze_memoryset_memoryset_name_or_id_analysis_post.py +0 -183
- orca_sdk/_generated_api_client/api/memoryset/batch_delete_memoryset_batch_delete_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -210
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -186
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -235
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -180
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -212
- orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +0 -195
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -210
- orca_sdk/_generated_api_client/api/memoryset/suggest_cascading_edits_memoryset_name_or_id_memory_memory_id_cascading_edits_post.py +0 -233
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -216
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -205
- orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +0 -183
- orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +0 -150
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -192
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -161
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/regression_model/create_regression_model_gpu_regression_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_regression_model_name_or_id_delete.py +0 -154
- orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/regression_model/list_regression_models_regression_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/regression_model/update_regression_model_regression_model_name_or_id_patch.py +0 -183
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -293
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +0 -168
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
- orca_sdk/_generated_api_client/api/telemetry/explain_prediction_telemetry_prediction_prediction_id_explanation_get.py +0 -182
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -180
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
- orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +0 -198
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -198
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
- orca_sdk/_generated_api_client/client.py +0 -216
- orca_sdk/_generated_api_client/errors.py +0 -38
- orca_sdk/_generated_api_client/models/__init__.py +0 -295
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -116
- orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -137
- orca_sdk/_generated_api_client/models/api_key_metadata_scope_item.py +0 -9
- orca_sdk/_generated_api_client/models/base_label_prediction_result.py +0 -130
- orca_sdk/_generated_api_client/models/base_model.py +0 -55
- orca_sdk/_generated_api_client/models/base_score_prediction_result.py +0 -108
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -207
- orca_sdk/_generated_api_client/models/cascade_edit_suggestions_request.py +0 -154
- orca_sdk/_generated_api_client/models/cascading_edit_suggestion.py +0 -92
- orca_sdk/_generated_api_client/models/classification_evaluation_request.py +0 -148
- orca_sdk/_generated_api_client/models/classification_metrics.py +0 -259
- orca_sdk/_generated_api_client/models/classification_model_metadata.py +0 -213
- orca_sdk/_generated_api_client/models/classification_prediction_request.py +0 -220
- orca_sdk/_generated_api_client/models/clone_memoryset_request.py +0 -170
- orca_sdk/_generated_api_client/models/cluster_metrics.py +0 -78
- orca_sdk/_generated_api_client/models/column_info.py +0 -145
- orca_sdk/_generated_api_client/models/column_type.py +0 -14
- orca_sdk/_generated_api_client/models/constraint_violation_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/count_predictions_request.py +0 -195
- orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -120
- orca_sdk/_generated_api_client/models/create_api_key_request_scope_item.py +0 -9
- orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -145
- orca_sdk/_generated_api_client/models/create_api_key_response_scope_item.py +0 -9
- orca_sdk/_generated_api_client/models/create_classification_model_request.py +0 -197
- orca_sdk/_generated_api_client/models/create_memoryset_request.py +0 -325
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +0 -66
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +0 -13
- orca_sdk/_generated_api_client/models/create_regression_model_request.py +0 -137
- orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -156
- orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/delete_memorysets_request.py +0 -70
- orca_sdk/_generated_api_client/models/embed_request.py +0 -135
- orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +0 -187
- orca_sdk/_generated_api_client/models/embedding_evaluation_request.py +0 -179
- orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +0 -158
- orca_sdk/_generated_api_client/models/embedding_evaluation_result.py +0 -86
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
- orca_sdk/_generated_api_client/models/embedding_model_result.py +0 -114
- orca_sdk/_generated_api_client/models/evaluation_response.py +0 -153
- orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +0 -140
- orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +0 -140
- orca_sdk/_generated_api_client/models/feedback_metrics.py +0 -85
- orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
- orca_sdk/_generated_api_client/models/filter_item.py +0 -231
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -17
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -20
- orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
- orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/http_validation_error.py +0 -86
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -210
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -288
- orca_sdk/_generated_api_client/models/labeled_memory.py +0 -186
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -128
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -194
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -246
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +0 -207
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_feedback_metrics.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
- orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -319
- orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
- orca_sdk/_generated_api_client/models/lookup_score_metrics.py +0 -94
- orca_sdk/_generated_api_client/models/memory_metrics.py +0 -165
- orca_sdk/_generated_api_client/models/memory_type.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -212
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -105
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -182
- orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config.py +0 -202
- orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config_clustering_method.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config_partitioning_method.py +0 -10
- orca_sdk/_generated_api_client/models/memoryset_cluster_metrics.py +0 -100
- orca_sdk/_generated_api_client/models/memoryset_duplicate_analysis_config.py +0 -70
- orca_sdk/_generated_api_client/models/memoryset_duplicate_metrics.py +0 -70
- orca_sdk/_generated_api_client/models/memoryset_label_analysis_config.py +0 -70
- orca_sdk/_generated_api_client/models/memoryset_label_metrics.py +0 -116
- orca_sdk/_generated_api_client/models/memoryset_metadata.py +0 -291
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +0 -55
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +0 -13
- orca_sdk/_generated_api_client/models/memoryset_metrics.py +0 -232
- orca_sdk/_generated_api_client/models/memoryset_neighbor_analysis_config.py +0 -83
- orca_sdk/_generated_api_client/models/memoryset_neighbor_metrics.py +0 -76
- orca_sdk/_generated_api_client/models/memoryset_neighbor_metrics_lookup_score_metrics.py +0 -68
- orca_sdk/_generated_api_client/models/memoryset_projection_analysis_config.py +0 -79
- orca_sdk/_generated_api_client/models/memoryset_projection_metrics.py +0 -55
- orca_sdk/_generated_api_client/models/memoryset_update.py +0 -101
- orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -22
- orca_sdk/_generated_api_client/models/paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py +0 -135
- orca_sdk/_generated_api_client/models/pr_curve.py +0 -86
- orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
- orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_0.py +0 -10
- orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_1.py +0 -9
- orca_sdk/_generated_api_client/models/predictive_model_update.py +0 -91
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -107
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -17
- orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
- orca_sdk/_generated_api_client/models/rar_head_type.py +0 -8
- orca_sdk/_generated_api_client/models/regression_evaluation_request.py +0 -148
- orca_sdk/_generated_api_client/models/regression_metrics.py +0 -172
- orca_sdk/_generated_api_client/models/regression_model_metadata.py +0 -177
- orca_sdk/_generated_api_client/models/regression_prediction_request.py +0 -195
- orca_sdk/_generated_api_client/models/roc_curve.py +0 -86
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +0 -196
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +0 -252
- orca_sdk/_generated_api_client/models/scored_memory.py +0 -172
- orca_sdk/_generated_api_client/models/scored_memory_insert.py +0 -128
- orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_lookup.py +0 -180
- orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +0 -193
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +0 -68
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/task.py +0 -198
- orca_sdk/_generated_api_client/models/task_status.py +0 -14
- orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
- orca_sdk/_generated_api_client/models/telemetry_field_type_0_item_type_2.py +0 -9
- orca_sdk/_generated_api_client/models/telemetry_filter_item.py +0 -205
- orca_sdk/_generated_api_client/models/telemetry_filter_item_op.py +0 -15
- orca_sdk/_generated_api_client/models/telemetry_memories_request.py +0 -181
- orca_sdk/_generated_api_client/models/telemetry_sort_options.py +0 -173
- orca_sdk/_generated_api_client/models/telemetry_sort_options_direction.py +0 -9
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -133
- orca_sdk/_generated_api_client/models/validation_error.py +0 -99
- orca_sdk/_generated_api_client/py.typed +0 -1
- orca_sdk/_generated_api_client/types.py +0 -56
- orca_sdk-0.0.96.dist-info/RECORD +0 -278
- {orca_sdk-0.0.96.dist-info → orca_sdk-0.0.98.dist-info}/WHEEL +0 -0
orca_sdk/embedding_model.py
CHANGED
|
@@ -2,28 +2,20 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from abc import abstractmethod
|
|
4
4
|
from datetime import datetime
|
|
5
|
-
from typing import TYPE_CHECKING, Literal, Sequence, cast, overload
|
|
6
|
-
|
|
7
|
-
from .
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
embed_with_pretrained_model_gpu,
|
|
12
|
-
get_finetuned_embedding_model,
|
|
13
|
-
get_pretrained_embedding_model,
|
|
14
|
-
list_finetuned_embedding_models,
|
|
15
|
-
list_pretrained_embedding_models,
|
|
16
|
-
)
|
|
17
|
-
from ._generated_api_client.models import (
|
|
5
|
+
from typing import TYPE_CHECKING, Literal, Sequence, cast, get_args, overload
|
|
6
|
+
|
|
7
|
+
from ._shared.metrics import ClassificationMetrics, RegressionMetrics
|
|
8
|
+
from ._utils.common import UNSET, CreateMode, DropMode
|
|
9
|
+
from .client import (
|
|
10
|
+
EmbeddingEvaluationRequest,
|
|
18
11
|
EmbeddingFinetuningMethod,
|
|
19
12
|
EmbedRequest,
|
|
20
13
|
FinetunedEmbeddingModelMetadata,
|
|
21
14
|
FinetuneEmbeddingModelRequest,
|
|
22
|
-
FinetuneEmbeddingModelRequestTrainingArgs,
|
|
23
15
|
PretrainedEmbeddingModelMetadata,
|
|
24
16
|
PretrainedEmbeddingModelName,
|
|
17
|
+
orca_api,
|
|
25
18
|
)
|
|
26
|
-
from ._utils.common import CreateMode, DropMode
|
|
27
19
|
from .datasource import Datasource
|
|
28
20
|
from .job import Job, Status
|
|
29
21
|
|
|
@@ -32,52 +24,218 @@ if TYPE_CHECKING:
|
|
|
32
24
|
|
|
33
25
|
|
|
34
26
|
class _EmbeddingModel:
|
|
35
|
-
name: str
|
|
36
27
|
embedding_dim: int
|
|
37
28
|
max_seq_length: int
|
|
38
29
|
uses_context: bool
|
|
30
|
+
supports_instructions: bool
|
|
39
31
|
|
|
40
|
-
def __init__(
|
|
41
|
-
self
|
|
32
|
+
def __init__(
|
|
33
|
+
self, *, name: str, embedding_dim: int, max_seq_length: int, uses_context: bool, supports_instructions: bool
|
|
34
|
+
):
|
|
42
35
|
self.embedding_dim = embedding_dim
|
|
43
36
|
self.max_seq_length = max_seq_length
|
|
44
37
|
self.uses_context = uses_context
|
|
38
|
+
self.supports_instructions = supports_instructions
|
|
45
39
|
|
|
46
40
|
@classmethod
|
|
47
41
|
@abstractmethod
|
|
48
42
|
def all(cls) -> Sequence[_EmbeddingModel]:
|
|
49
43
|
pass
|
|
50
44
|
|
|
45
|
+
def _get_instruction_error_message(self) -> str:
|
|
46
|
+
"""Get error message for instruction not supported"""
|
|
47
|
+
if isinstance(self, FinetunedEmbeddingModel):
|
|
48
|
+
return f"Model {self.name} does not support instructions. Instruction-following is only supported by models based on instruction-supporting models."
|
|
49
|
+
elif isinstance(self, PretrainedEmbeddingModel):
|
|
50
|
+
return f"Model {self.name} does not support instructions. Instruction-following is only supported by instruction-supporting models."
|
|
51
|
+
else:
|
|
52
|
+
raise ValueError("Invalid embedding model")
|
|
53
|
+
|
|
51
54
|
@overload
|
|
52
|
-
def embed(self, value: str, max_seq_length: int | None = None) -> list[float]:
|
|
55
|
+
def embed(self, value: str, max_seq_length: int | None = None, prompt: str | None = None) -> list[float]:
|
|
53
56
|
pass
|
|
54
57
|
|
|
55
58
|
@overload
|
|
56
|
-
def embed(
|
|
59
|
+
def embed(
|
|
60
|
+
self, value: list[str], max_seq_length: int | None = None, prompt: str | None = None
|
|
61
|
+
) -> list[list[float]]:
|
|
57
62
|
pass
|
|
58
63
|
|
|
59
|
-
def embed(
|
|
64
|
+
def embed(
|
|
65
|
+
self, value: str | list[str], max_seq_length: int | None = None, prompt: str | None = None
|
|
66
|
+
) -> list[float] | list[list[float]]:
|
|
60
67
|
"""
|
|
61
68
|
Generate embeddings for a value or list of values
|
|
62
69
|
|
|
63
70
|
Params:
|
|
64
71
|
value: The value or list of values to embed
|
|
65
72
|
max_seq_length: The maximum sequence length to truncate the input to
|
|
73
|
+
prompt: Optional prompt for prompt-following embedding models.
|
|
66
74
|
|
|
67
75
|
Returns:
|
|
68
76
|
A matrix of floats representing the embedding for each value if the input is a list of
|
|
69
77
|
values, or a list of floats representing the embedding for the single value if the
|
|
70
78
|
input is a single value
|
|
71
79
|
"""
|
|
72
|
-
|
|
80
|
+
payload: EmbedRequest = {
|
|
81
|
+
"values": value if isinstance(value, list) else [value],
|
|
82
|
+
"max_seq_length": max_seq_length,
|
|
83
|
+
"prompt": prompt,
|
|
84
|
+
}
|
|
73
85
|
if isinstance(self, PretrainedEmbeddingModel):
|
|
74
|
-
embeddings =
|
|
86
|
+
embeddings = orca_api.POST(
|
|
87
|
+
"/gpu/pretrained_embedding_model/{model_name}/embedding",
|
|
88
|
+
params={"model_name": cast(PretrainedEmbeddingModelName, self.name)},
|
|
89
|
+
json=payload,
|
|
90
|
+
timeout=30, # may be slow in case of cold start
|
|
91
|
+
)
|
|
75
92
|
elif isinstance(self, FinetunedEmbeddingModel):
|
|
76
|
-
embeddings =
|
|
93
|
+
embeddings = orca_api.POST(
|
|
94
|
+
"/gpu/finetuned_embedding_model/{name_or_id}/embedding",
|
|
95
|
+
params={"name_or_id": self.id},
|
|
96
|
+
json=payload,
|
|
97
|
+
timeout=30, # may be slow in case of cold start
|
|
98
|
+
)
|
|
77
99
|
else:
|
|
78
100
|
raise ValueError("Invalid embedding model")
|
|
79
101
|
return embeddings if isinstance(value, list) else embeddings[0]
|
|
80
102
|
|
|
103
|
+
@overload
|
|
104
|
+
def evaluate(
|
|
105
|
+
self,
|
|
106
|
+
datasource: Datasource,
|
|
107
|
+
*,
|
|
108
|
+
value_column: str = "value",
|
|
109
|
+
label_column: str,
|
|
110
|
+
score_column: None = None,
|
|
111
|
+
eval_datasource: Datasource | None = None,
|
|
112
|
+
subsample: int | None = None,
|
|
113
|
+
neighbor_count: int = 5,
|
|
114
|
+
batch_size: int = 32,
|
|
115
|
+
weigh_memories: bool = True,
|
|
116
|
+
background: Literal[True],
|
|
117
|
+
) -> Job[ClassificationMetrics]:
|
|
118
|
+
pass
|
|
119
|
+
|
|
120
|
+
@overload
|
|
121
|
+
def evaluate(
|
|
122
|
+
self,
|
|
123
|
+
datasource: Datasource,
|
|
124
|
+
*,
|
|
125
|
+
value_column: str = "value",
|
|
126
|
+
label_column: str,
|
|
127
|
+
score_column: None = None,
|
|
128
|
+
eval_datasource: Datasource | None = None,
|
|
129
|
+
subsample: int | None = None,
|
|
130
|
+
neighbor_count: int = 5,
|
|
131
|
+
batch_size: int = 32,
|
|
132
|
+
weigh_memories: bool = True,
|
|
133
|
+
background: Literal[False] = False,
|
|
134
|
+
) -> ClassificationMetrics:
|
|
135
|
+
pass
|
|
136
|
+
|
|
137
|
+
@overload
|
|
138
|
+
def evaluate(
|
|
139
|
+
self,
|
|
140
|
+
datasource: Datasource,
|
|
141
|
+
*,
|
|
142
|
+
value_column: str = "value",
|
|
143
|
+
label_column: None = None,
|
|
144
|
+
score_column: str,
|
|
145
|
+
eval_datasource: Datasource | None = None,
|
|
146
|
+
subsample: int | None = None,
|
|
147
|
+
neighbor_count: int = 5,
|
|
148
|
+
batch_size: int = 32,
|
|
149
|
+
weigh_memories: bool = True,
|
|
150
|
+
background: Literal[True],
|
|
151
|
+
) -> Job[RegressionMetrics]:
|
|
152
|
+
pass
|
|
153
|
+
|
|
154
|
+
@overload
|
|
155
|
+
def evaluate(
|
|
156
|
+
self,
|
|
157
|
+
datasource: Datasource,
|
|
158
|
+
*,
|
|
159
|
+
value_column: str = "value",
|
|
160
|
+
label_column: None = None,
|
|
161
|
+
score_column: str,
|
|
162
|
+
eval_datasource: Datasource | None = None,
|
|
163
|
+
subsample: int | None = None,
|
|
164
|
+
neighbor_count: int = 5,
|
|
165
|
+
batch_size: int = 32,
|
|
166
|
+
weigh_memories: bool = True,
|
|
167
|
+
background: Literal[False] = False,
|
|
168
|
+
) -> RegressionMetrics:
|
|
169
|
+
pass
|
|
170
|
+
|
|
171
|
+
def evaluate(
    self,
    datasource: Datasource,
    *,
    value_column: str = "value",
    label_column: str | None = None,
    score_column: str | None = None,
    eval_datasource: Datasource | None = None,
    subsample: int | None = None,
    neighbor_count: int = 5,
    batch_size: int = 32,
    weigh_memories: bool = True,
    background: bool = False,
) -> (
    ClassificationMetrics
    | RegressionMetrics
    | Job[ClassificationMetrics]
    | Job[RegressionMetrics]
    | Job[ClassificationMetrics | RegressionMetrics]
):
    """
    Evaluate this embedding model (pretrained or finetuned) on a datasource

    The evaluation is submitted to the pretrained or the finetuned evaluation endpoint,
    depending on the concrete type of this model.

    Params:
        datasource: Datasource to evaluate on, its id is sent as `datasource_name_or_id`
        value_column: Name of the datasource column sent as `datasource_value_column`
        label_column: Name of the datasource column sent as `datasource_label_column`
            (presumably set for classification evaluations - confirm against the API docs)
        score_column: Name of the datasource column sent as `datasource_score_column`
            (presumably set for regression evaluations - confirm against the API docs)
        eval_datasource: Optional second datasource, its id is sent as
            `eval_datasource_name_or_id`, `None` is sent when it is omitted
        subsample: Forwarded unchanged to the evaluation endpoint
        neighbor_count: Forwarded unchanged to the evaluation endpoint
        batch_size: Forwarded unchanged to the evaluation endpoint
        weigh_memories: Forwarded unchanged to the evaluation endpoint
        background: Whether to return the job handle instead of waiting for its result

    Returns:
        The job handle if `background` is `True`, otherwise the result of the job:
        regression metrics when the stored result contains an `"mse"` key and
        classification metrics otherwise

    Raises:
        ValueError: If this model is neither a pretrained nor a finetuned embedding model
        RuntimeError: If the evaluation task has no stored result when it is fetched
    """
    payload: EmbeddingEvaluationRequest = {
        "datasource_name_or_id": datasource.id,
        "datasource_label_column": label_column,
        "datasource_value_column": value_column,
        "datasource_score_column": score_column,
        "eval_datasource_name_or_id": eval_datasource.id if eval_datasource is not None else None,
        "subsample": subsample,
        "neighbor_count": neighbor_count,
        "batch_size": batch_size,
        "weigh_memories": weigh_memories,
    }

    # Pretrained models are addressed by name, finetuned models by id. The literal
    # route strings are kept inline in every call rather than factored into a variable.
    if isinstance(self, PretrainedEmbeddingModel):
        response = orca_api.POST(
            "/pretrained_embedding_model/{model_name}/evaluation",
            params={"model_name": self.name},
            json=payload,
        )
    elif isinstance(self, FinetunedEmbeddingModel):
        response = orca_api.POST(
            "/finetuned_embedding_model/{name_or_id}/evaluation",
            params={"name_or_id": self.id},
            json=payload,
        )
    else:
        raise ValueError("Invalid embedding model")

    def get_result(task_id: str) -> ClassificationMetrics | RegressionMetrics:
        # Fetch the stored result from the endpoint that matches this model's kind
        if isinstance(self, PretrainedEmbeddingModel):
            res = orca_api.GET(
                "/pretrained_embedding_model/{model_name}/evaluation/{task_id}",
                params={"model_name": self.name, "task_id": task_id},
            )["result"]
        elif isinstance(self, FinetunedEmbeddingModel):
            res = orca_api.GET(
                "/finetuned_embedding_model/{name_or_id}/evaluation/{task_id}",
                params={"name_or_id": self.id, "task_id": task_id},
            )["result"]
        else:
            raise ValueError("Invalid embedding model")
        # Explicit check instead of `assert`, which is stripped when Python runs with -O
        if res is None:
            raise RuntimeError(f"Evaluation task {task_id} has no result")
        # A result carrying an "mse" key is treated as a regression result
        return RegressionMetrics(**res) if "mse" in res else ClassificationMetrics(**res)

    job = Job(response["task_id"], lambda: get_result(response["task_id"]))
    return job if background else job.result()
|
|
238
|
+
|
|
81
239
|
|
|
82
240
|
class _ModelDescriptor:
|
|
83
241
|
"""
|
|
@@ -126,7 +284,7 @@ class _ModelDescriptor:
|
|
|
126
284
|
# Load the model on first access
|
|
127
285
|
if self.model is None:
|
|
128
286
|
try:
|
|
129
|
-
self.model = PretrainedEmbeddingModel._get(self.name)
|
|
287
|
+
self.model = PretrainedEmbeddingModel._get(cast(PretrainedEmbeddingModelName, self.name))
|
|
130
288
|
except (KeyError, AttributeError):
|
|
131
289
|
raise AttributeError(f"No embedding model named {self.name}")
|
|
132
290
|
|
|
@@ -152,17 +310,27 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
|
|
|
152
310
|
- **`GIST_LARGE`**: GIST-Large embedding model from Hugging Face ([avsolatorio/GIST-large-Embedding-v0](https://huggingface.co/avsolatorio/GIST-large-Embedding-v0))
|
|
153
311
|
- **`MXBAI_LARGE`**: Mixedbread's Large embedding model from Hugging Face ([mixedbread-ai/mxbai-embed-large-v1](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1))
|
|
154
312
|
- **`QWEN2_1_5B`**: Alibaba's Qwen2-1.5B instruction-tuned embedding model from Hugging Face ([Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct))
|
|
313
|
+
- **`BGE_BASE`**: BAAI's BGE-Base instruction-tuned embedding model from Hugging Face ([BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5))
|
|
155
314
|
|
|
315
|
+
**Instruction Support:**
|
|
316
|
+
|
|
317
|
+
Some models support instruction-following for better task-specific embeddings. You can check if a model supports instructions
|
|
318
|
+
using the `supports_instructions` attribute.
|
|
156
319
|
|
|
157
320
|
Examples:
|
|
158
321
|
>>> PretrainedEmbeddingModel.CDE_SMALL
|
|
159
322
|
PretrainedEmbeddingModel({name: CDE_SMALL, embedding_dim: 768, max_seq_length: 512})
|
|
160
323
|
|
|
324
|
+
>>> # Using instruction with an instruction-supporting model
|
|
325
|
+
>>> model = PretrainedEmbeddingModel.E5_LARGE
|
|
326
|
+
>>> embeddings = model.embed("Hello world", prompt="Represent this sentence for retrieval:")
|
|
327
|
+
|
|
161
328
|
Attributes:
|
|
162
329
|
name: Name of the pretrained embedding model
|
|
163
330
|
embedding_dim: Dimension of the embeddings that are generated by the model
|
|
164
331
|
max_seq_length: Maximum input length (in tokens not characters) that this model can process. Inputs that are longer will be truncated during the embedding process
|
|
165
332
|
uses_context: Whether the pretrained embedding model uses context
|
|
333
|
+
supports_instructions: Whether this model supports instruction-following
|
|
166
334
|
"""
|
|
167
335
|
|
|
168
336
|
# Define descriptors for model access with IDE autocomplete
|
|
@@ -175,17 +343,21 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
|
|
|
175
343
|
GIST_LARGE = _ModelDescriptor("GIST_LARGE")
|
|
176
344
|
MXBAI_LARGE = _ModelDescriptor("MXBAI_LARGE")
|
|
177
345
|
QWEN2_1_5B = _ModelDescriptor("QWEN2_1_5B")
|
|
346
|
+
BGE_BASE = _ModelDescriptor("BGE_BASE")
|
|
178
347
|
|
|
179
|
-
|
|
348
|
+
name: PretrainedEmbeddingModelName
|
|
180
349
|
|
|
181
350
|
def __init__(self, metadata: PretrainedEmbeddingModelMetadata):
|
|
182
351
|
# for internal use only, do not document
|
|
183
|
-
self.
|
|
352
|
+
self.name = metadata["name"]
|
|
184
353
|
super().__init__(
|
|
185
|
-
name=metadata
|
|
186
|
-
embedding_dim=metadata
|
|
187
|
-
max_seq_length=metadata
|
|
188
|
-
uses_context=metadata
|
|
354
|
+
name=metadata["name"],
|
|
355
|
+
embedding_dim=metadata["embedding_dim"],
|
|
356
|
+
max_seq_length=metadata["max_seq_length"],
|
|
357
|
+
uses_context=metadata["uses_context"],
|
|
358
|
+
supports_instructions=(
|
|
359
|
+
bool(metadata["supports_instructions"]) if "supports_instructions" in metadata else False
|
|
360
|
+
),
|
|
189
361
|
)
|
|
190
362
|
|
|
191
363
|
def __eq__(self, other) -> bool:
|
|
@@ -202,19 +374,24 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
|
|
|
202
374
|
Returns:
|
|
203
375
|
A list of all pretrained embedding models available in the OrcaCloud
|
|
204
376
|
"""
|
|
205
|
-
return [cls(metadata) for metadata in
|
|
377
|
+
return [cls(metadata) for metadata in orca_api.GET("/pretrained_embedding_model")]
|
|
206
378
|
|
|
207
379
|
_instances: dict[str, PretrainedEmbeddingModel] = {}
|
|
208
380
|
|
|
209
381
|
@classmethod
|
|
210
|
-
def _get(cls, name: PretrainedEmbeddingModelName
|
|
382
|
+
def _get(cls, name: PretrainedEmbeddingModelName) -> PretrainedEmbeddingModel:
|
|
211
383
|
# for internal use only, do not document - we want people to use dot notation to get the model
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
384
|
+
cache_key = str(name)
|
|
385
|
+
if cache_key not in cls._instances:
|
|
386
|
+
metadata = orca_api.GET(
|
|
387
|
+
"/pretrained_embedding_model/{model_name}",
|
|
388
|
+
params={"model_name": name},
|
|
389
|
+
)
|
|
390
|
+
cls._instances[cache_key] = cls(metadata)
|
|
391
|
+
return cls._instances[cache_key]
|
|
215
392
|
|
|
216
393
|
@classmethod
|
|
217
|
-
def open(cls, name:
|
|
394
|
+
def open(cls, name: PretrainedEmbeddingModelName) -> PretrainedEmbeddingModel:
|
|
218
395
|
"""
|
|
219
396
|
Open an embedding model by name.
|
|
220
397
|
|
|
@@ -231,9 +408,9 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
|
|
|
231
408
|
>>> model = PretrainedEmbeddingModel.open("GTE_BASE")
|
|
232
409
|
"""
|
|
233
410
|
try:
|
|
234
|
-
#
|
|
235
|
-
return
|
|
236
|
-
except AttributeError:
|
|
411
|
+
# Always use the _get method which handles caching properly
|
|
412
|
+
return cls._get(name)
|
|
413
|
+
except (KeyError, AttributeError):
|
|
237
414
|
raise ValueError(f"Unknown model name: {name}")
|
|
238
415
|
|
|
239
416
|
@classmethod
|
|
@@ -247,7 +424,7 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
|
|
|
247
424
|
Returns:
|
|
248
425
|
True if the pretrained embedding model exists, False otherwise
|
|
249
426
|
"""
|
|
250
|
-
return name in PretrainedEmbeddingModelName
|
|
427
|
+
return name in get_args(PretrainedEmbeddingModelName)
|
|
251
428
|
|
|
252
429
|
@overload
|
|
253
430
|
def finetune(
|
|
@@ -258,7 +435,7 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
|
|
|
258
435
|
eval_datasource: Datasource | None = None,
|
|
259
436
|
label_column: str = "label",
|
|
260
437
|
value_column: str = "value",
|
|
261
|
-
training_method: EmbeddingFinetuningMethod
|
|
438
|
+
training_method: EmbeddingFinetuningMethod = "classification",
|
|
262
439
|
training_args: dict | None = None,
|
|
263
440
|
if_exists: CreateMode = "error",
|
|
264
441
|
background: Literal[True],
|
|
@@ -274,7 +451,7 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
|
|
|
274
451
|
eval_datasource: Datasource | None = None,
|
|
275
452
|
label_column: str = "label",
|
|
276
453
|
value_column: str = "value",
|
|
277
|
-
training_method: EmbeddingFinetuningMethod
|
|
454
|
+
training_method: EmbeddingFinetuningMethod = "classification",
|
|
278
455
|
training_args: dict | None = None,
|
|
279
456
|
if_exists: CreateMode = "error",
|
|
280
457
|
background: Literal[False] = False,
|
|
@@ -289,7 +466,7 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
|
|
|
289
466
|
eval_datasource: Datasource | None = None,
|
|
290
467
|
label_column: str = "label",
|
|
291
468
|
value_column: str = "value",
|
|
292
|
-
training_method: EmbeddingFinetuningMethod
|
|
469
|
+
training_method: EmbeddingFinetuningMethod = "classification",
|
|
293
470
|
training_args: dict | None = None,
|
|
294
471
|
if_exists: CreateMode = "error",
|
|
295
472
|
background: bool = False,
|
|
@@ -329,32 +506,35 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
|
|
|
329
506
|
elif exists and if_exists == "open":
|
|
330
507
|
existing = FinetunedEmbeddingModel.open(name)
|
|
331
508
|
|
|
332
|
-
if existing.base_model_name != self.
|
|
509
|
+
if existing.base_model_name != self.name:
|
|
333
510
|
raise ValueError(f"Finetuned embedding model '{name}' already exists, but with different base model")
|
|
334
511
|
|
|
335
512
|
return existing
|
|
336
513
|
|
|
337
514
|
from .memoryset import LabeledMemoryset
|
|
338
515
|
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
516
|
+
payload: FinetuneEmbeddingModelRequest = {
|
|
517
|
+
"name": name,
|
|
518
|
+
"base_model": self.name,
|
|
519
|
+
"label_column": label_column,
|
|
520
|
+
"value_column": value_column,
|
|
521
|
+
"training_method": training_method,
|
|
522
|
+
"training_args": training_args or {},
|
|
523
|
+
}
|
|
524
|
+
if isinstance(train_datasource, Datasource):
|
|
525
|
+
payload["train_datasource_name_or_id"] = train_datasource.id
|
|
526
|
+
elif isinstance(train_datasource, LabeledMemoryset):
|
|
527
|
+
payload["train_memoryset_name_or_id"] = train_datasource.id
|
|
528
|
+
if eval_datasource is not None:
|
|
529
|
+
payload["eval_datasource_name_or_id"] = eval_datasource.id
|
|
530
|
+
|
|
531
|
+
res = orca_api.POST(
|
|
532
|
+
"/finetuned_embedding_model",
|
|
533
|
+
json=payload,
|
|
354
534
|
)
|
|
355
535
|
job = Job(
|
|
356
|
-
res
|
|
357
|
-
lambda: FinetunedEmbeddingModel.open(res
|
|
536
|
+
res["finetuning_task_id"],
|
|
537
|
+
lambda: FinetunedEmbeddingModel.open(res["id"]),
|
|
358
538
|
)
|
|
359
539
|
return job if background else job.result()
|
|
360
540
|
|
|
@@ -374,22 +554,27 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
|
|
|
374
554
|
"""
|
|
375
555
|
|
|
376
556
|
id: str
|
|
557
|
+
name: str
|
|
377
558
|
created_at: datetime
|
|
378
559
|
updated_at: datetime
|
|
560
|
+
base_model_name: PretrainedEmbeddingModelName
|
|
379
561
|
_status: Status
|
|
380
562
|
|
|
381
563
|
def __init__(self, metadata: FinetunedEmbeddingModelMetadata):
|
|
382
564
|
# for internal use only, do not document
|
|
383
|
-
self.id = metadata
|
|
384
|
-
self.
|
|
385
|
-
self.
|
|
386
|
-
self.
|
|
387
|
-
self.
|
|
565
|
+
self.id = metadata["id"]
|
|
566
|
+
self.name = metadata["name"]
|
|
567
|
+
self.created_at = datetime.fromisoformat(metadata["created_at"])
|
|
568
|
+
self.updated_at = datetime.fromisoformat(metadata["updated_at"])
|
|
569
|
+
self.base_model_name = metadata["base_model"]
|
|
570
|
+
self._status = Status(metadata["finetuning_status"])
|
|
571
|
+
|
|
388
572
|
super().__init__(
|
|
389
|
-
name=metadata
|
|
390
|
-
embedding_dim=metadata
|
|
391
|
-
max_seq_length=metadata
|
|
392
|
-
uses_context=metadata
|
|
573
|
+
name=metadata["name"],
|
|
574
|
+
embedding_dim=metadata["embedding_dim"],
|
|
575
|
+
max_seq_length=metadata["max_seq_length"],
|
|
576
|
+
uses_context=metadata["uses_context"],
|
|
577
|
+
supports_instructions=self.base_model.supports_instructions,
|
|
393
578
|
)
|
|
394
579
|
|
|
395
580
|
def __eq__(self, other) -> bool:
|
|
@@ -401,7 +586,7 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
|
|
|
401
586
|
f" name: {self.name},\n"
|
|
402
587
|
f" embedding_dim: {self.embedding_dim},\n"
|
|
403
588
|
f" max_seq_length: {self.max_seq_length},\n"
|
|
404
|
-
f" base_model: PretrainedEmbeddingModel.{self.base_model_name
|
|
589
|
+
f" base_model: PretrainedEmbeddingModel.{self.base_model_name}\n"
|
|
405
590
|
"})"
|
|
406
591
|
)
|
|
407
592
|
|
|
@@ -418,7 +603,7 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
|
|
|
418
603
|
Returns:
|
|
419
604
|
A list of all finetuned embedding model handles in the OrcaCloud
|
|
420
605
|
"""
|
|
421
|
-
return [cls(metadata) for metadata in
|
|
606
|
+
return [cls(metadata) for metadata in orca_api.GET("/finetuned_embedding_model")]
|
|
422
607
|
|
|
423
608
|
@classmethod
|
|
424
609
|
def open(cls, name: str) -> FinetunedEmbeddingModel:
|
|
@@ -434,7 +619,11 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
|
|
|
434
619
|
Raises:
|
|
435
620
|
LookupError: If the finetuned embedding model does not exist
|
|
436
621
|
"""
|
|
437
|
-
|
|
622
|
+
metadata = orca_api.GET(
|
|
623
|
+
"/finetuned_embedding_model/{name_or_id}",
|
|
624
|
+
params={"name_or_id": name},
|
|
625
|
+
)
|
|
626
|
+
return cls(metadata)
|
|
438
627
|
|
|
439
628
|
@classmethod
|
|
440
629
|
def exists(cls, name_or_id: str) -> bool:
|
|
@@ -465,7 +654,10 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
|
|
|
465
654
|
LookupError: If the finetuned embedding model does not exist and `if_not_exists` is `"error"`
|
|
466
655
|
"""
|
|
467
656
|
try:
|
|
468
|
-
|
|
657
|
+
orca_api.DELETE(
|
|
658
|
+
"/finetuned_embedding_model/{name_or_id}",
|
|
659
|
+
params={"name_or_id": name_or_id},
|
|
660
|
+
)
|
|
469
661
|
except (LookupError, RuntimeError):
|
|
470
662
|
if if_not_exists == "error":
|
|
471
663
|
raise
|
orca_sdk/embedding_model_test.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
import logging
|
|
2
|
+
from typing import get_args
|
|
2
3
|
from uuid import uuid4
|
|
3
4
|
|
|
4
5
|
import pytest
|
|
5
6
|
|
|
6
7
|
from .datasource import Datasource
|
|
7
8
|
from .embedding_model import (
|
|
9
|
+
ClassificationMetrics,
|
|
8
10
|
FinetunedEmbeddingModel,
|
|
9
11
|
PretrainedEmbeddingModel,
|
|
10
12
|
PretrainedEmbeddingModelName,
|
|
@@ -30,16 +32,16 @@ def test_open_pretrained_model_unauthenticated(unauthenticated):
|
|
|
30
32
|
|
|
31
33
|
def test_open_pretrained_model_not_found():
|
|
32
34
|
with pytest.raises(LookupError):
|
|
33
|
-
PretrainedEmbeddingModel._get("INVALID_MODEL")
|
|
35
|
+
PretrainedEmbeddingModel._get("INVALID_MODEL") # type: ignore
|
|
34
36
|
|
|
35
37
|
|
|
36
38
|
def test_all_pretrained_models():
|
|
37
39
|
models = PretrainedEmbeddingModel.all()
|
|
38
40
|
assert len(models) > 1
|
|
39
|
-
if len(models) != len(PretrainedEmbeddingModelName):
|
|
41
|
+
if len(models) != len(get_args(PretrainedEmbeddingModelName)):
|
|
40
42
|
logging.warning("Please regenerate the SDK client! Some pretrained model names are not exposed yet.")
|
|
41
43
|
model_names = [m.name for m in models]
|
|
42
|
-
assert all(
|
|
44
|
+
assert all(m in model_names for m in get_args(PretrainedEmbeddingModelName))
|
|
43
45
|
|
|
44
46
|
|
|
45
47
|
def test_embed_text():
|
|
@@ -55,6 +57,13 @@ def test_embed_text_unauthenticated(unauthenticated):
|
|
|
55
57
|
PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
|
|
56
58
|
|
|
57
59
|
|
|
60
|
+
def test_evaluate_pretrained_model(datasource: Datasource):
    # Evaluating with a label column should yield classification metrics
    model = PretrainedEmbeddingModel.GTE_BASE
    metrics = model.evaluate(datasource=datasource, label_column="label")
    assert metrics is not None
    assert isinstance(metrics, ClassificationMetrics)
    assert metrics.accuracy > 0.5
|
|
65
|
+
|
|
66
|
+
|
|
58
67
|
@pytest.fixture(scope="session")
|
|
59
68
|
def finetuned_model(datasource) -> FinetunedEmbeddingModel:
|
|
60
69
|
return PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource)
|
|
@@ -83,18 +92,14 @@ def test_finetune_model_with_memoryset(readonly_memoryset: LabeledMemoryset):
|
|
|
83
92
|
|
|
84
93
|
def test_finetune_model_already_exists_error(datasource: Datasource, finetuned_model):
|
|
85
94
|
with pytest.raises(ValueError):
|
|
86
|
-
PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource
|
|
95
|
+
PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource)
|
|
87
96
|
|
|
88
97
|
|
|
89
98
|
def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_model):
|
|
90
99
|
with pytest.raises(ValueError):
|
|
91
|
-
PretrainedEmbeddingModel.GTE_BASE.finetune(
|
|
92
|
-
"test_finetuned_model", datasource, if_exists="open", value_column="text"
|
|
93
|
-
)
|
|
100
|
+
PretrainedEmbeddingModel.GTE_BASE.finetune("test_finetuned_model", datasource, if_exists="open")
|
|
94
101
|
|
|
95
|
-
new_model = PretrainedEmbeddingModel.DISTILBERT.finetune(
|
|
96
|
-
"test_finetuned_model", datasource, if_exists="open", value_column="text"
|
|
97
|
-
)
|
|
102
|
+
new_model = PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource, if_exists="open")
|
|
98
103
|
assert new_model is not None
|
|
99
104
|
assert new_model.name == "test_finetuned_model"
|
|
100
105
|
assert new_model.base_model == PretrainedEmbeddingModel.DISTILBERT
|
|
@@ -105,9 +110,7 @@ def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_
|
|
|
105
110
|
|
|
106
111
|
def test_finetune_model_unauthenticated(unauthenticated, datasource: Datasource):
|
|
107
112
|
with pytest.raises(ValueError, match="Invalid API key"):
|
|
108
|
-
PretrainedEmbeddingModel.DISTILBERT.finetune(
|
|
109
|
-
"test_finetuned_model_unauthenticated", datasource, value_column="text"
|
|
110
|
-
)
|
|
113
|
+
PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model_unauthenticated", datasource)
|
|
111
114
|
|
|
112
115
|
|
|
113
116
|
def test_use_finetuned_model_in_memoryset(datasource: Datasource, finetuned_model: FinetunedEmbeddingModel):
|
|
@@ -166,7 +169,7 @@ def test_drop_finetuned_model(datasource: Datasource):
|
|
|
166
169
|
|
|
167
170
|
def test_drop_finetuned_model_unauthenticated(unauthenticated, datasource: Datasource):
|
|
168
171
|
with pytest.raises(ValueError, match="Invalid API key"):
|
|
169
|
-
PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource
|
|
172
|
+
PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource)
|
|
170
173
|
|
|
171
174
|
|
|
172
175
|
def test_drop_finetuned_model_not_found():
|
|
@@ -179,3 +182,18 @@ def test_drop_finetuned_model_not_found():
|
|
|
179
182
|
def test_drop_finetuned_model_unauthorized(unauthorized, finetuned_model: FinetunedEmbeddingModel):
    # Dropping a model under the `unauthorized` fixture is expected to raise a LookupError
    model_id = finetuned_model.id
    with pytest.raises(LookupError):
        FinetunedEmbeddingModel.drop(model_id)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def test_supports_instructions():
    # GTE_BASE is expected to report no instruction support
    assert not PretrainedEmbeddingModel.GTE_BASE.supports_instructions

    # BGE_BASE is expected to report instruction support
    assert PretrainedEmbeddingModel.BGE_BASE.supports_instructions
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def test_use_explicit_instruction_prompt():
    # An explicit instruction prompt should change the embedding produced for the same text
    model = PretrainedEmbeddingModel.BGE_BASE
    assert model.supports_instructions
    # Named `text` so the builtin `input` is not shadowed
    text = "Hello world"
    assert model.embed(text, prompt="Represent this sentence for sentiment retrieval:") != model.embed(text)
|