orca-sdk: orca_sdk-0.0.96-py3-none-any.whl → orca_sdk-0.0.98-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (269):
  1. orca_sdk/__init__.py +2 -5
  2. orca_sdk/_shared/__init__.py +1 -0
  3. orca_sdk/_shared/metrics.py +1 -1
  4. orca_sdk/_utils/analysis_ui.py +5 -5
  5. orca_sdk/_utils/auth.py +23 -33
  6. orca_sdk/_utils/pagination.py +126 -0
  7. orca_sdk/_utils/pagination_test.py +132 -0
  8. orca_sdk/classification_model.py +188 -126
  9. orca_sdk/classification_model_test.py +102 -0
  10. orca_sdk/client.py +3515 -0
  11. orca_sdk/conftest.py +10 -0
  12. orca_sdk/credentials.py +73 -21
  13. orca_sdk/credentials_test.py +20 -0
  14. orca_sdk/datasource.py +186 -81
  15. orca_sdk/datasource_test.py +194 -0
  16. orca_sdk/embedding_model.py +267 -75
  17. orca_sdk/embedding_model_test.py +32 -14
  18. orca_sdk/job.py +59 -54
  19. orca_sdk/job_test.py +50 -0
  20. orca_sdk/memoryset.py +372 -345
  21. orca_sdk/memoryset_test.py +7 -11
  22. orca_sdk/regression_model.py +120 -111
  23. orca_sdk/regression_model_test.py +15 -0
  24. orca_sdk/telemetry.py +229 -115
  25. {orca_sdk-0.0.96.dist-info → orca_sdk-0.0.98.dist-info}/METADATA +19 -5
  26. orca_sdk-0.0.98.dist-info/RECORD +40 -0
  27. orca_sdk/_generated_api_client/__init__.py +0 -3
  28. orca_sdk/_generated_api_client/api/__init__.py +0 -287
  29. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  30. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  31. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  32. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  33. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  34. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  35. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  36. orca_sdk/_generated_api_client/api/classification_model/create_classification_model_gpu_classification_model_post.py +0 -170
  37. orca_sdk/_generated_api_client/api/classification_model/delete_classification_model_classification_model_name_or_id_delete.py +0 -156
  38. orca_sdk/_generated_api_client/api/classification_model/delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  39. orca_sdk/_generated_api_client/api/classification_model/evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py +0 -183
  40. orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +0 -156
  41. orca_sdk/_generated_api_client/api/classification_model/get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  42. orca_sdk/_generated_api_client/api/classification_model/list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  43. orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +0 -127
  44. orca_sdk/_generated_api_client/api/classification_model/predict_label_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  45. orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +0 -183
  46. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  47. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
  48. orca_sdk/_generated_api_client/api/datasource/create_embedding_evaluation_datasource_name_or_id_embedding_evaluation_post.py +0 -183
  49. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  50. orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +0 -172
  51. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  52. orca_sdk/_generated_api_client/api/datasource/get_embedding_evaluation_datasource_name_or_id_embedding_evaluation_task_id_get.py +0 -169
  53. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  54. orca_sdk/_generated_api_client/api/datasource/list_embedding_evaluations_datasource_name_or_id_embedding_evaluation_get.py +0 -235
  55. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  56. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  57. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  58. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  59. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  60. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  61. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  62. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  63. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  64. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  65. orca_sdk/_generated_api_client/api/memoryset/analyze_memoryset_memoryset_name_or_id_analysis_post.py +0 -183
  66. orca_sdk/_generated_api_client/api/memoryset/batch_delete_memoryset_batch_delete_memoryset_post.py +0 -168
  67. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  68. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  69. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  70. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  71. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  72. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  73. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -210
  74. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -186
  75. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  76. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -188
  77. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -235
  78. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -180
  79. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -212
  80. orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +0 -195
  81. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -210
  82. orca_sdk/_generated_api_client/api/memoryset/suggest_cascading_edits_memoryset_name_or_id_memory_memory_id_cascading_edits_post.py +0 -233
  83. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -216
  84. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -205
  85. orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +0 -183
  86. orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
  87. orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +0 -150
  88. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  89. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -192
  90. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -161
  91. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  92. orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
  93. orca_sdk/_generated_api_client/api/regression_model/create_regression_model_gpu_regression_model_post.py +0 -170
  94. orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  95. orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_regression_model_name_or_id_delete.py +0 -154
  96. orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +0 -183
  97. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  98. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +0 -156
  99. orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +0 -161
  100. orca_sdk/_generated_api_client/api/regression_model/list_regression_models_regression_model_get.py +0 -127
  101. orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +0 -190
  102. orca_sdk/_generated_api_client/api/regression_model/update_regression_model_regression_model_name_or_id_patch.py +0 -183
  103. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  104. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  105. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  106. orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +0 -156
  107. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -293
  108. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  109. orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +0 -168
  110. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  111. orca_sdk/_generated_api_client/api/telemetry/explain_prediction_telemetry_prediction_prediction_id_explanation_get.py +0 -182
  112. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  113. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -180
  114. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  115. orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +0 -198
  116. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -198
  117. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  118. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  119. orca_sdk/_generated_api_client/client.py +0 -216
  120. orca_sdk/_generated_api_client/errors.py +0 -38
  121. orca_sdk/_generated_api_client/models/__init__.py +0 -295
  122. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -116
  123. orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -137
  124. orca_sdk/_generated_api_client/models/api_key_metadata_scope_item.py +0 -9
  125. orca_sdk/_generated_api_client/models/base_label_prediction_result.py +0 -130
  126. orca_sdk/_generated_api_client/models/base_model.py +0 -55
  127. orca_sdk/_generated_api_client/models/base_score_prediction_result.py +0 -108
  128. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -207
  129. orca_sdk/_generated_api_client/models/cascade_edit_suggestions_request.py +0 -154
  130. orca_sdk/_generated_api_client/models/cascading_edit_suggestion.py +0 -92
  131. orca_sdk/_generated_api_client/models/classification_evaluation_request.py +0 -148
  132. orca_sdk/_generated_api_client/models/classification_metrics.py +0 -259
  133. orca_sdk/_generated_api_client/models/classification_model_metadata.py +0 -213
  134. orca_sdk/_generated_api_client/models/classification_prediction_request.py +0 -220
  135. orca_sdk/_generated_api_client/models/clone_memoryset_request.py +0 -170
  136. orca_sdk/_generated_api_client/models/cluster_metrics.py +0 -78
  137. orca_sdk/_generated_api_client/models/column_info.py +0 -145
  138. orca_sdk/_generated_api_client/models/column_type.py +0 -14
  139. orca_sdk/_generated_api_client/models/constraint_violation_error_response.py +0 -80
  140. orca_sdk/_generated_api_client/models/count_predictions_request.py +0 -195
  141. orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -120
  142. orca_sdk/_generated_api_client/models/create_api_key_request_scope_item.py +0 -9
  143. orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -145
  144. orca_sdk/_generated_api_client/models/create_api_key_response_scope_item.py +0 -9
  145. orca_sdk/_generated_api_client/models/create_classification_model_request.py +0 -197
  146. orca_sdk/_generated_api_client/models/create_memoryset_request.py +0 -325
  147. orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +0 -66
  148. orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +0 -13
  149. orca_sdk/_generated_api_client/models/create_regression_model_request.py +0 -137
  150. orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -156
  151. orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  152. orca_sdk/_generated_api_client/models/delete_memorysets_request.py +0 -70
  153. orca_sdk/_generated_api_client/models/embed_request.py +0 -135
  154. orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +0 -187
  155. orca_sdk/_generated_api_client/models/embedding_evaluation_request.py +0 -179
  156. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +0 -158
  157. orca_sdk/_generated_api_client/models/embedding_evaluation_result.py +0 -86
  158. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  159. orca_sdk/_generated_api_client/models/embedding_model_result.py +0 -114
  160. orca_sdk/_generated_api_client/models/evaluation_response.py +0 -153
  161. orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +0 -140
  162. orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +0 -140
  163. orca_sdk/_generated_api_client/models/feedback_metrics.py +0 -85
  164. orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  165. orca_sdk/_generated_api_client/models/filter_item.py +0 -231
  166. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -17
  167. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -20
  168. orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  169. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  170. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  171. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  172. orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  173. orca_sdk/_generated_api_client/models/http_validation_error.py +0 -86
  174. orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
  175. orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  176. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -210
  177. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  178. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -288
  179. orca_sdk/_generated_api_client/models/labeled_memory.py +0 -186
  180. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -128
  181. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  182. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -194
  183. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  184. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  185. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -246
  186. orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  187. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  188. orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +0 -207
  189. orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_feedback_metrics.py +0 -68
  190. orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_metadata.py +0 -68
  191. orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  192. orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -319
  193. orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
  194. orca_sdk/_generated_api_client/models/lookup_score_metrics.py +0 -94
  195. orca_sdk/_generated_api_client/models/memory_metrics.py +0 -165
  196. orca_sdk/_generated_api_client/models/memory_type.py +0 -9
  197. orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -212
  198. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -105
  199. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -182
  200. orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config.py +0 -202
  201. orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config_clustering_method.py +0 -9
  202. orca_sdk/_generated_api_client/models/memoryset_cluster_analysis_config_partitioning_method.py +0 -10
  203. orca_sdk/_generated_api_client/models/memoryset_cluster_metrics.py +0 -100
  204. orca_sdk/_generated_api_client/models/memoryset_duplicate_analysis_config.py +0 -70
  205. orca_sdk/_generated_api_client/models/memoryset_duplicate_metrics.py +0 -70
  206. orca_sdk/_generated_api_client/models/memoryset_label_analysis_config.py +0 -70
  207. orca_sdk/_generated_api_client/models/memoryset_label_metrics.py +0 -116
  208. orca_sdk/_generated_api_client/models/memoryset_metadata.py +0 -291
  209. orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +0 -55
  210. orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +0 -13
  211. orca_sdk/_generated_api_client/models/memoryset_metrics.py +0 -232
  212. orca_sdk/_generated_api_client/models/memoryset_neighbor_analysis_config.py +0 -83
  213. orca_sdk/_generated_api_client/models/memoryset_neighbor_metrics.py +0 -76
  214. orca_sdk/_generated_api_client/models/memoryset_neighbor_metrics_lookup_score_metrics.py +0 -68
  215. orca_sdk/_generated_api_client/models/memoryset_projection_analysis_config.py +0 -79
  216. orca_sdk/_generated_api_client/models/memoryset_projection_metrics.py +0 -55
  217. orca_sdk/_generated_api_client/models/memoryset_update.py +0 -101
  218. orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
  219. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -22
  220. orca_sdk/_generated_api_client/models/paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py +0 -135
  221. orca_sdk/_generated_api_client/models/pr_curve.py +0 -86
  222. orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  223. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  224. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  225. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  226. orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_0.py +0 -10
  227. orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_1.py +0 -9
  228. orca_sdk/_generated_api_client/models/predictive_model_update.py +0 -91
  229. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -107
  230. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -17
  231. orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  232. orca_sdk/_generated_api_client/models/rar_head_type.py +0 -8
  233. orca_sdk/_generated_api_client/models/regression_evaluation_request.py +0 -148
  234. orca_sdk/_generated_api_client/models/regression_metrics.py +0 -172
  235. orca_sdk/_generated_api_client/models/regression_model_metadata.py +0 -177
  236. orca_sdk/_generated_api_client/models/regression_prediction_request.py +0 -195
  237. orca_sdk/_generated_api_client/models/roc_curve.py +0 -86
  238. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +0 -196
  239. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +0 -68
  240. orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +0 -252
  241. orca_sdk/_generated_api_client/models/scored_memory.py +0 -172
  242. orca_sdk/_generated_api_client/models/scored_memory_insert.py +0 -128
  243. orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +0 -68
  244. orca_sdk/_generated_api_client/models/scored_memory_lookup.py +0 -180
  245. orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +0 -68
  246. orca_sdk/_generated_api_client/models/scored_memory_metadata.py +0 -68
  247. orca_sdk/_generated_api_client/models/scored_memory_update.py +0 -171
  248. orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +0 -68
  249. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +0 -193
  250. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +0 -68
  251. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +0 -68
  252. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
  253. orca_sdk/_generated_api_client/models/task.py +0 -198
  254. orca_sdk/_generated_api_client/models/task_status.py +0 -14
  255. orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  256. orca_sdk/_generated_api_client/models/telemetry_field_type_0_item_type_2.py +0 -9
  257. orca_sdk/_generated_api_client/models/telemetry_filter_item.py +0 -205
  258. orca_sdk/_generated_api_client/models/telemetry_filter_item_op.py +0 -15
  259. orca_sdk/_generated_api_client/models/telemetry_memories_request.py +0 -181
  260. orca_sdk/_generated_api_client/models/telemetry_sort_options.py +0 -173
  261. orca_sdk/_generated_api_client/models/telemetry_sort_options_direction.py +0 -9
  262. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
  263. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
  264. orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -133
  265. orca_sdk/_generated_api_client/models/validation_error.py +0 -99
  266. orca_sdk/_generated_api_client/py.typed +0 -1
  267. orca_sdk/_generated_api_client/types.py +0 -56
  268. orca_sdk-0.0.96.dist-info/RECORD +0 -278
  269. {orca_sdk-0.0.96.dist-info → orca_sdk-0.0.98.dist-info}/WHEEL +0 -0
@@ -2,28 +2,20 @@ from __future__ import annotations
2
2
 
3
3
  from abc import abstractmethod
4
4
  from datetime import datetime
5
- from typing import TYPE_CHECKING, Literal, Sequence, cast, overload
6
-
7
- from ._generated_api_client.api import (
8
- create_finetuned_embedding_model,
9
- delete_finetuned_embedding_model,
10
- embed_with_finetuned_model_gpu,
11
- embed_with_pretrained_model_gpu,
12
- get_finetuned_embedding_model,
13
- get_pretrained_embedding_model,
14
- list_finetuned_embedding_models,
15
- list_pretrained_embedding_models,
16
- )
17
- from ._generated_api_client.models import (
5
+ from typing import TYPE_CHECKING, Literal, Sequence, cast, get_args, overload
6
+
7
+ from ._shared.metrics import ClassificationMetrics, RegressionMetrics
8
+ from ._utils.common import UNSET, CreateMode, DropMode
9
+ from .client import (
10
+ EmbeddingEvaluationRequest,
18
11
  EmbeddingFinetuningMethod,
19
12
  EmbedRequest,
20
13
  FinetunedEmbeddingModelMetadata,
21
14
  FinetuneEmbeddingModelRequest,
22
- FinetuneEmbeddingModelRequestTrainingArgs,
23
15
  PretrainedEmbeddingModelMetadata,
24
16
  PretrainedEmbeddingModelName,
17
+ orca_api,
25
18
  )
26
- from ._utils.common import CreateMode, DropMode
27
19
  from .datasource import Datasource
28
20
  from .job import Job, Status
29
21
 
@@ -32,52 +24,218 @@ if TYPE_CHECKING:
32
24
 
33
25
 
34
26
  class _EmbeddingModel:
35
- name: str
36
27
  embedding_dim: int
37
28
  max_seq_length: int
38
29
  uses_context: bool
30
+ supports_instructions: bool
39
31
 
40
- def __init__(self, *, name: str, embedding_dim: int, max_seq_length: int, uses_context: bool):
41
- self.name = name
32
+ def __init__(
33
+ self, *, name: str, embedding_dim: int, max_seq_length: int, uses_context: bool, supports_instructions: bool
34
+ ):
42
35
  self.embedding_dim = embedding_dim
43
36
  self.max_seq_length = max_seq_length
44
37
  self.uses_context = uses_context
38
+ self.supports_instructions = supports_instructions
45
39
 
46
40
  @classmethod
47
41
  @abstractmethod
48
42
  def all(cls) -> Sequence[_EmbeddingModel]:
49
43
  pass
50
44
 
45
+ def _get_instruction_error_message(self) -> str:
46
+ """Get error message for instruction not supported"""
47
+ if isinstance(self, FinetunedEmbeddingModel):
48
+ return f"Model {self.name} does not support instructions. Instruction-following is only supported by models based on instruction-supporting models."
49
+ elif isinstance(self, PretrainedEmbeddingModel):
50
+ return f"Model {self.name} does not support instructions. Instruction-following is only supported by instruction-supporting models."
51
+ else:
52
+ raise ValueError("Invalid embedding model")
53
+
51
54
  @overload
52
- def embed(self, value: str, max_seq_length: int | None = None) -> list[float]:
55
+ def embed(self, value: str, max_seq_length: int | None = None, prompt: str | None = None) -> list[float]:
53
56
  pass
54
57
 
55
58
  @overload
56
- def embed(self, value: list[str], max_seq_length: int | None = None) -> list[list[float]]:
59
+ def embed(
60
+ self, value: list[str], max_seq_length: int | None = None, prompt: str | None = None
61
+ ) -> list[list[float]]:
57
62
  pass
58
63
 
59
- def embed(self, value: str | list[str], max_seq_length: int | None = None) -> list[float] | list[list[float]]:
64
+ def embed(
65
+ self, value: str | list[str], max_seq_length: int | None = None, prompt: str | None = None
66
+ ) -> list[float] | list[list[float]]:
60
67
  """
61
68
  Generate embeddings for a value or list of values
62
69
 
63
70
  Params:
64
71
  value: The value or list of values to embed
65
72
  max_seq_length: The maximum sequence length to truncate the input to
73
+ prompt: Optional prompt for prompt-following embedding models.
66
74
 
67
75
  Returns:
68
76
  A matrix of floats representing the embedding for each value if the input is a list of
69
77
  values, or a list of floats representing the embedding for the single value if the
70
78
  input is a single value
71
79
  """
72
- request = EmbedRequest(values=value if isinstance(value, list) else [value], max_seq_length=max_seq_length)
80
+ payload: EmbedRequest = {
81
+ "values": value if isinstance(value, list) else [value],
82
+ "max_seq_length": max_seq_length,
83
+ "prompt": prompt,
84
+ }
73
85
  if isinstance(self, PretrainedEmbeddingModel):
74
- embeddings = embed_with_pretrained_model_gpu(self._model_name, body=request)
86
+ embeddings = orca_api.POST(
87
+ "/gpu/pretrained_embedding_model/{model_name}/embedding",
88
+ params={"model_name": cast(PretrainedEmbeddingModelName, self.name)},
89
+ json=payload,
90
+ timeout=30, # may be slow in case of cold start
91
+ )
75
92
  elif isinstance(self, FinetunedEmbeddingModel):
76
- embeddings = embed_with_finetuned_model_gpu(self.id, body=request)
93
+ embeddings = orca_api.POST(
94
+ "/gpu/finetuned_embedding_model/{name_or_id}/embedding",
95
+ params={"name_or_id": self.id},
96
+ json=payload,
97
+ timeout=30, # may be slow in case of cold start
98
+ )
77
99
  else:
78
100
  raise ValueError("Invalid embedding model")
79
101
  return embeddings if isinstance(value, list) else embeddings[0]
80
102
 
103
+ @overload
104
+ def evaluate(
105
+ self,
106
+ datasource: Datasource,
107
+ *,
108
+ value_column: str = "value",
109
+ label_column: str,
110
+ score_column: None = None,
111
+ eval_datasource: Datasource | None = None,
112
+ subsample: int | None = None,
113
+ neighbor_count: int = 5,
114
+ batch_size: int = 32,
115
+ weigh_memories: bool = True,
116
+ background: Literal[True],
117
+ ) -> Job[ClassificationMetrics]:
118
+ pass
119
+
120
+ @overload
121
+ def evaluate(
122
+ self,
123
+ datasource: Datasource,
124
+ *,
125
+ value_column: str = "value",
126
+ label_column: str,
127
+ score_column: None = None,
128
+ eval_datasource: Datasource | None = None,
129
+ subsample: int | None = None,
130
+ neighbor_count: int = 5,
131
+ batch_size: int = 32,
132
+ weigh_memories: bool = True,
133
+ background: Literal[False] = False,
134
+ ) -> ClassificationMetrics:
135
+ pass
136
+
137
+ @overload
138
+ def evaluate(
139
+ self,
140
+ datasource: Datasource,
141
+ *,
142
+ value_column: str = "value",
143
+ label_column: None = None,
144
+ score_column: str,
145
+ eval_datasource: Datasource | None = None,
146
+ subsample: int | None = None,
147
+ neighbor_count: int = 5,
148
+ batch_size: int = 32,
149
+ weigh_memories: bool = True,
150
+ background: Literal[True],
151
+ ) -> Job[RegressionMetrics]:
152
+ pass
153
+
154
+ @overload
155
+ def evaluate(
156
+ self,
157
+ datasource: Datasource,
158
+ *,
159
+ value_column: str = "value",
160
+ label_column: None = None,
161
+ score_column: str,
162
+ eval_datasource: Datasource | None = None,
163
+ subsample: int | None = None,
164
+ neighbor_count: int = 5,
165
+ batch_size: int = 32,
166
+ weigh_memories: bool = True,
167
+ background: Literal[False] = False,
168
+ ) -> RegressionMetrics:
169
+ pass
170
+
171
+ def evaluate(
172
+ self,
173
+ datasource: Datasource,
174
+ *,
175
+ value_column: str = "value",
176
+ label_column: str | None = None,
177
+ score_column: str | None = None,
178
+ eval_datasource: Datasource | None = None,
179
+ subsample: int | None = None,
180
+ neighbor_count: int = 5,
181
+ batch_size: int = 32,
182
+ weigh_memories: bool = True,
183
+ background: bool = False,
184
+ ) -> (
185
+ ClassificationMetrics
186
+ | RegressionMetrics
187
+ | Job[ClassificationMetrics]
188
+ | Job[RegressionMetrics]
189
+ | Job[ClassificationMetrics | RegressionMetrics]
190
+ ):
191
+ """
192
+ Evaluate the finetuned embedding model
193
+ """
194
+ payload: EmbeddingEvaluationRequest = {
195
+ "datasource_name_or_id": datasource.id,
196
+ "datasource_label_column": label_column,
197
+ "datasource_value_column": value_column,
198
+ "datasource_score_column": score_column,
199
+ "eval_datasource_name_or_id": eval_datasource.id if eval_datasource is not None else None,
200
+ "subsample": subsample,
201
+ "neighbor_count": neighbor_count,
202
+ "batch_size": batch_size,
203
+ "weigh_memories": weigh_memories,
204
+ }
205
+ if isinstance(self, PretrainedEmbeddingModel):
206
+ response = orca_api.POST(
207
+ "/pretrained_embedding_model/{model_name}/evaluation",
208
+ params={"model_name": self.name},
209
+ json=payload,
210
+ )
211
+ elif isinstance(self, FinetunedEmbeddingModel):
212
+ response = orca_api.POST(
213
+ "/finetuned_embedding_model/{name_or_id}/evaluation",
214
+ params={"name_or_id": self.id},
215
+ json=payload,
216
+ )
217
+ else:
218
+ raise ValueError("Invalid embedding model")
219
+
220
+ def get_result(task_id: str) -> ClassificationMetrics | RegressionMetrics:
221
+ if isinstance(self, PretrainedEmbeddingModel):
222
+ res = orca_api.GET(
223
+ "/pretrained_embedding_model/{model_name}/evaluation/{task_id}",
224
+ params={"model_name": self.name, "task_id": task_id},
225
+ )["result"]
226
+ elif isinstance(self, FinetunedEmbeddingModel):
227
+ res = orca_api.GET(
228
+ "/finetuned_embedding_model/{name_or_id}/evaluation/{task_id}",
229
+ params={"name_or_id": self.id, "task_id": task_id},
230
+ )["result"]
231
+ else:
232
+ raise ValueError("Invalid embedding model")
233
+ assert res is not None
234
+ return RegressionMetrics(**res) if "mse" in res else ClassificationMetrics(**res)
235
+
236
+ job = Job(response["task_id"], lambda: get_result(response["task_id"]))
237
+ return job if background else job.result()
238
+
81
239
 
82
240
  class _ModelDescriptor:
83
241
  """
@@ -126,7 +284,7 @@ class _ModelDescriptor:
126
284
  # Load the model on first access
127
285
  if self.model is None:
128
286
  try:
129
- self.model = PretrainedEmbeddingModel._get(self.name)
287
+ self.model = PretrainedEmbeddingModel._get(cast(PretrainedEmbeddingModelName, self.name))
130
288
  except (KeyError, AttributeError):
131
289
  raise AttributeError(f"No embedding model named {self.name}")
132
290
 
@@ -152,17 +310,27 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
152
310
  - **`GIST_LARGE`**: GIST-Large embedding model from Hugging Face ([avsolatorio/GIST-large-Embedding-v0](https://huggingface.co/avsolatorio/GIST-large-Embedding-v0))
153
311
  - **`MXBAI_LARGE`**: Mixedbread's Large embedding model from Hugging Face ([mixedbread-ai/mxbai-embed-large-v1](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1))
154
312
  - **`QWEN2_1_5B`**: Alibaba's Qwen2-1.5B instruction-tuned embedding model from Hugging Face ([Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct))
313
+ - **`BGE_BASE`**: BAAI's BGE-Base instruction-tuned embedding model from Hugging Face ([BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5))
155
314
 
315
+ **Instruction Support:**
316
+
317
+ Some models support instruction-following for better task-specific embeddings. You can check if a model supports instructions
318
+ using the `supports_instructions` attribute.
156
319
 
157
320
  Examples:
158
321
  >>> PretrainedEmbeddingModel.CDE_SMALL
159
322
  PretrainedEmbeddingModel({name: CDE_SMALL, embedding_dim: 768, max_seq_length: 512})
160
323
 
324
+ >>> # Using instruction with an instruction-supporting model
325
+ >>> model = PretrainedEmbeddingModel.E5_LARGE
326
+ >>> embeddings = model.embed("Hello world", prompt="Represent this sentence for retrieval:")
327
+
161
328
  Attributes:
162
329
  name: Name of the pretrained embedding model
163
330
  embedding_dim: Dimension of the embeddings that are generated by the model
164
331
  max_seq_length: Maximum input length (in tokens not characters) that this model can process. Inputs that are longer will be truncated during the embedding process
165
332
  uses_context: Whether the pretrained embedding model uses context
333
+ supports_instructions: Whether this model supports instruction-following
166
334
  """
167
335
 
168
336
  # Define descriptors for model access with IDE autocomplete
@@ -175,17 +343,21 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
175
343
  GIST_LARGE = _ModelDescriptor("GIST_LARGE")
176
344
  MXBAI_LARGE = _ModelDescriptor("MXBAI_LARGE")
177
345
  QWEN2_1_5B = _ModelDescriptor("QWEN2_1_5B")
346
+ BGE_BASE = _ModelDescriptor("BGE_BASE")
178
347
 
179
- _model_name: PretrainedEmbeddingModelName
348
+ name: PretrainedEmbeddingModelName
180
349
 
181
350
def __init__(self, metadata: PretrainedEmbeddingModelMetadata):
    # for internal use only, do not document
    model_name = metadata["name"]
    self.name = model_name
    super().__init__(
        name=model_name,
        embedding_dim=metadata["embedding_dim"],
        max_seq_length=metadata["max_seq_length"],
        uses_context=metadata["uses_context"],
        # The key is optional in the metadata; a missing key counts as "no instruction support".
        supports_instructions=bool(metadata.get("supports_instructions", False)),
    )
190
362
 
191
363
  def __eq__(self, other) -> bool:
@@ -202,19 +374,24 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
202
374
  Returns:
203
375
  A list of all pretrained embedding models available in the OrcaCloud
204
376
  """
205
- return [cls(metadata) for metadata in list_pretrained_embedding_models()]
377
+ return [cls(metadata) for metadata in orca_api.GET("/pretrained_embedding_model")]
206
378
 
207
379
  _instances: dict[str, PretrainedEmbeddingModel] = {}
208
380
 
209
381
@classmethod
def _get(cls, name: PretrainedEmbeddingModelName) -> PretrainedEmbeddingModel:
    # for internal use only, do not document - we want people to use dot notation to get the model
    # Handles are memoized per class in `_instances`, keyed by the stringified model name,
    # so the metadata endpoint is only queried the first time a given name is requested.
    key = str(name)
    handle = cls._instances.get(key)
    if handle is None:
        handle = cls(
            orca_api.GET(
                "/pretrained_embedding_model/{model_name}",
                params={"model_name": name},
            )
        )
        cls._instances[key] = handle
    return handle
215
392
 
216
393
  @classmethod
217
- def open(cls, name: str) -> PretrainedEmbeddingModel:
394
+ def open(cls, name: PretrainedEmbeddingModelName) -> PretrainedEmbeddingModel:
218
395
  """
219
396
  Open an embedding model by name.
220
397
 
@@ -231,9 +408,9 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
231
408
  >>> model = PretrainedEmbeddingModel.open("GTE_BASE")
232
409
  """
233
410
  try:
234
- # Use getattr to access the descriptor which will initialize the model
235
- return getattr(cls, name)
236
- except AttributeError:
411
+ # Always use the _get method which handles caching properly
412
+ return cls._get(name)
413
+ except (KeyError, AttributeError):
237
414
  raise ValueError(f"Unknown model name: {name}")
238
415
 
239
416
  @classmethod
@@ -247,7 +424,7 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
247
424
  Returns:
248
425
  True if the pretrained embedding model exists, False otherwise
249
426
  """
250
- return name in PretrainedEmbeddingModelName
427
+ return name in get_args(PretrainedEmbeddingModelName)
251
428
 
252
429
  @overload
253
430
  def finetune(
@@ -258,7 +435,7 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
258
435
  eval_datasource: Datasource | None = None,
259
436
  label_column: str = "label",
260
437
  value_column: str = "value",
261
- training_method: EmbeddingFinetuningMethod | str = EmbeddingFinetuningMethod.CLASSIFICATION,
438
+ training_method: EmbeddingFinetuningMethod = "classification",
262
439
  training_args: dict | None = None,
263
440
  if_exists: CreateMode = "error",
264
441
  background: Literal[True],
@@ -274,7 +451,7 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
274
451
  eval_datasource: Datasource | None = None,
275
452
  label_column: str = "label",
276
453
  value_column: str = "value",
277
- training_method: EmbeddingFinetuningMethod | str = EmbeddingFinetuningMethod.CLASSIFICATION,
454
+ training_method: EmbeddingFinetuningMethod = "classification",
278
455
  training_args: dict | None = None,
279
456
  if_exists: CreateMode = "error",
280
457
  background: Literal[False] = False,
@@ -289,7 +466,7 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
289
466
  eval_datasource: Datasource | None = None,
290
467
  label_column: str = "label",
291
468
  value_column: str = "value",
292
- training_method: EmbeddingFinetuningMethod | str = EmbeddingFinetuningMethod.CLASSIFICATION,
469
+ training_method: EmbeddingFinetuningMethod = "classification",
293
470
  training_args: dict | None = None,
294
471
  if_exists: CreateMode = "error",
295
472
  background: bool = False,
@@ -329,32 +506,35 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
329
506
  elif exists and if_exists == "open":
330
507
  existing = FinetunedEmbeddingModel.open(name)
331
508
 
332
- if existing.base_model_name != self._model_name:
509
+ if existing.base_model_name != self.name:
333
510
  raise ValueError(f"Finetuned embedding model '{name}' already exists, but with different base model")
334
511
 
335
512
  return existing
336
513
 
337
514
  from .memoryset import LabeledMemoryset
338
515
 
339
- train_datasource_id = train_datasource.id if isinstance(train_datasource, Datasource) else None
340
- train_memoryset_id = train_datasource.id if isinstance(train_datasource, LabeledMemoryset) else None
341
- assert train_datasource_id is not None or train_memoryset_id is not None
342
- res = create_finetuned_embedding_model(
343
- body=FinetuneEmbeddingModelRequest(
344
- name=name,
345
- base_model=self._model_name,
346
- train_memoryset_id=train_memoryset_id,
347
- train_datasource_id=train_datasource_id,
348
- eval_datasource_id=eval_datasource.id if eval_datasource is not None else None,
349
- label_column=label_column,
350
- value_column=value_column,
351
- training_method=EmbeddingFinetuningMethod(training_method),
352
- training_args=(FinetuneEmbeddingModelRequestTrainingArgs.from_dict(training_args or {})),
353
- ),
516
+ payload: FinetuneEmbeddingModelRequest = {
517
+ "name": name,
518
+ "base_model": self.name,
519
+ "label_column": label_column,
520
+ "value_column": value_column,
521
+ "training_method": training_method,
522
+ "training_args": training_args or {},
523
+ }
524
+ if isinstance(train_datasource, Datasource):
525
+ payload["train_datasource_name_or_id"] = train_datasource.id
526
+ elif isinstance(train_datasource, LabeledMemoryset):
527
+ payload["train_memoryset_name_or_id"] = train_datasource.id
528
+ if eval_datasource is not None:
529
+ payload["eval_datasource_name_or_id"] = eval_datasource.id
530
+
531
+ res = orca_api.POST(
532
+ "/finetuned_embedding_model",
533
+ json=payload,
354
534
  )
355
535
  job = Job(
356
- res.finetuning_task_id,
357
- lambda: FinetunedEmbeddingModel.open(res.id),
536
+ res["finetuning_task_id"],
537
+ lambda: FinetunedEmbeddingModel.open(res["id"]),
358
538
  )
359
539
  return job if background else job.result()
360
540
 
@@ -374,22 +554,27 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
374
554
  """
375
555
 
376
556
  id: str
557
+ name: str
377
558
  created_at: datetime
378
559
  updated_at: datetime
560
+ base_model_name: PretrainedEmbeddingModelName
379
561
  _status: Status
380
562
 
381
563
def __init__(self, metadata: FinetunedEmbeddingModelMetadata):
    # for internal use only, do not document
    model_name = metadata["name"]

    self.id = metadata["id"]
    self.name = model_name
    # Timestamps arrive as ISO-8601 strings and are stored as datetime objects.
    self.created_at = datetime.fromisoformat(metadata["created_at"])
    self.updated_at = datetime.fromisoformat(metadata["updated_at"])
    self.base_model_name = metadata["base_model"]
    self._status = Status(metadata["finetuning_status"])

    # NOTE(review): `self.base_model` is defined outside this view; it presumably resolves the
    # pretrained model from `base_model_name`, so that attribute must be assigned before this call.
    super().__init__(
        name=model_name,
        embedding_dim=metadata["embedding_dim"],
        max_seq_length=metadata["max_seq_length"],
        uses_context=metadata["uses_context"],
        supports_instructions=self.base_model.supports_instructions,
    )
394
579
 
395
580
  def __eq__(self, other) -> bool:
@@ -401,7 +586,7 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
401
586
  f" name: {self.name},\n"
402
587
  f" embedding_dim: {self.embedding_dim},\n"
403
588
  f" max_seq_length: {self.max_seq_length},\n"
404
- f" base_model: PretrainedEmbeddingModel.{self.base_model_name.value}\n"
589
+ f" base_model: PretrainedEmbeddingModel.{self.base_model_name}\n"
405
590
  "})"
406
591
  )
407
592
 
@@ -418,7 +603,7 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
418
603
  Returns:
419
604
  A list of all finetuned embedding model handles in the OrcaCloud
420
605
  """
421
- return [cls(metadata) for metadata in list_finetuned_embedding_models()]
606
+ return [cls(metadata) for metadata in orca_api.GET("/finetuned_embedding_model")]
422
607
 
423
608
  @classmethod
424
609
  def open(cls, name: str) -> FinetunedEmbeddingModel:
@@ -434,7 +619,11 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
434
619
  Raises:
435
620
  LookupError: If the finetuned embedding model does not exist
436
621
  """
437
- return cls(get_finetuned_embedding_model(name))
622
+ metadata = orca_api.GET(
623
+ "/finetuned_embedding_model/{name_or_id}",
624
+ params={"name_or_id": name},
625
+ )
626
+ return cls(metadata)
438
627
 
439
628
  @classmethod
440
629
  def exists(cls, name_or_id: str) -> bool:
@@ -465,7 +654,10 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
465
654
  LookupError: If the finetuned embedding model does not exist and `if_not_exists` is `"error"`
466
655
  """
467
656
  try:
468
- delete_finetuned_embedding_model(name_or_id)
657
+ orca_api.DELETE(
658
+ "/finetuned_embedding_model/{name_or_id}",
659
+ params={"name_or_id": name_or_id},
660
+ )
469
661
  except (LookupError, RuntimeError):
470
662
  if if_not_exists == "error":
471
663
  raise
@@ -1,10 +1,12 @@
1
1
  import logging
2
+ from typing import get_args
2
3
  from uuid import uuid4
3
4
 
4
5
  import pytest
5
6
 
6
7
  from .datasource import Datasource
7
8
  from .embedding_model import (
9
+ ClassificationMetrics,
8
10
  FinetunedEmbeddingModel,
9
11
  PretrainedEmbeddingModel,
10
12
  PretrainedEmbeddingModelName,
@@ -30,16 +32,16 @@ def test_open_pretrained_model_unauthenticated(unauthenticated):
30
32
 
31
33
def test_open_pretrained_model_not_found():
    """Looking up a name that is not a known pretrained model raises LookupError."""
    unknown_name = "INVALID_MODEL"
    with pytest.raises(LookupError):
        PretrainedEmbeddingModel._get(unknown_name)  # type: ignore
34
36
 
35
37
 
36
38
def test_all_pretrained_models():
    """Every model name known to the SDK must be returned by `PretrainedEmbeddingModel.all()`."""
    known_names = get_args(PretrainedEmbeddingModelName)
    models = PretrainedEmbeddingModel.all()
    assert len(models) > 1
    # A count mismatch only warns: the server may list models the SDK's name literal lacks.
    if len(models) != len(known_names):
        logging.warning("Please regenerate the SDK client! Some pretrained model names are not exposed yet.")
    returned_names = {model.name for model in models}
    assert set(known_names).issubset(returned_names)
43
45
 
44
46
 
45
47
  def test_embed_text():
@@ -55,6 +57,13 @@ def test_embed_text_unauthenticated(unauthenticated):
55
57
  PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
56
58
 
57
59
 
60
def test_evaluate_pretrained_model(datasource: Datasource):
    """Evaluating GTE_BASE with a label column yields classification metrics above 0.5 accuracy."""
    model = PretrainedEmbeddingModel.GTE_BASE
    result = model.evaluate(datasource=datasource, label_column="label")
    assert result is not None
    assert isinstance(result, ClassificationMetrics)
    assert result.accuracy > 0.5
65
+
66
+
58
67
@pytest.fixture(scope="session")
def finetuned_model(datasource) -> FinetunedEmbeddingModel:
    """Session-wide finetuned DISTILBERT model named "test_finetuned_model"."""
    model = PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource)
    return model
@@ -83,18 +92,14 @@ def test_finetune_model_with_memoryset(readonly_memoryset: LabeledMemoryset):
83
92
 
84
93
def test_finetune_model_already_exists_error(datasource: Datasource, finetuned_model):
    """Finetuning under the name already used by the `finetuned_model` fixture raises ValueError."""
    taken_name = "test_finetuned_model"
    with pytest.raises(ValueError):
        PretrainedEmbeddingModel.DISTILBERT.finetune(taken_name, datasource)
87
96
 
88
97
 
89
98
  def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_model):
90
99
  with pytest.raises(ValueError):
91
- PretrainedEmbeddingModel.GTE_BASE.finetune(
92
- "test_finetuned_model", datasource, if_exists="open", value_column="text"
93
- )
100
+ PretrainedEmbeddingModel.GTE_BASE.finetune("test_finetuned_model", datasource, if_exists="open")
94
101
 
95
- new_model = PretrainedEmbeddingModel.DISTILBERT.finetune(
96
- "test_finetuned_model", datasource, if_exists="open", value_column="text"
97
- )
102
+ new_model = PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource, if_exists="open")
98
103
  assert new_model is not None
99
104
  assert new_model.name == "test_finetuned_model"
100
105
  assert new_model.base_model == PretrainedEmbeddingModel.DISTILBERT
@@ -105,9 +110,7 @@ def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_
105
110
 
106
111
def test_finetune_model_unauthenticated(unauthenticated, datasource: Datasource):
    """Finetuning without valid credentials fails with an "Invalid API key" ValueError."""
    model_name = "test_finetuned_model_unauthenticated"
    # The model attribute access stays inside the context manager, as in the original test.
    with pytest.raises(ValueError, match="Invalid API key"):
        PretrainedEmbeddingModel.DISTILBERT.finetune(model_name, datasource)
111
114
 
112
115
 
113
116
  def test_use_finetuned_model_in_memoryset(datasource: Datasource, finetuned_model: FinetunedEmbeddingModel):
@@ -166,7 +169,7 @@ def test_drop_finetuned_model(datasource: Datasource):
166
169
 
167
170
def test_drop_finetuned_model_unauthenticated(unauthenticated, datasource: Datasource):
    """With invalid credentials the finetune call is rejected with "Invalid API key"."""
    model_name = "finetuned_model_to_delete"
    # The model attribute access stays inside the context manager, as in the original test.
    with pytest.raises(ValueError, match="Invalid API key"):
        PretrainedEmbeddingModel.DISTILBERT.finetune(model_name, datasource)
170
173
 
171
174
 
172
175
  def test_drop_finetuned_model_not_found():
@@ -179,3 +182,18 @@ def test_drop_finetuned_model_not_found():
179
182
def test_drop_finetuned_model_unauthorized(unauthorized, finetuned_model: FinetunedEmbeddingModel):
    """Dropping the fixture's model under the `unauthorized` fixture raises LookupError."""
    model_id = finetuned_model.id
    with pytest.raises(LookupError):
        FinetunedEmbeddingModel.drop(model_id)
185
+
186
+
187
def test_supports_instructions():
    """GTE_BASE reports no instruction support, while BGE_BASE reports it."""
    plain_model = PretrainedEmbeddingModel.GTE_BASE
    assert not plain_model.supports_instructions

    instructed_model = PretrainedEmbeddingModel.BGE_BASE
    assert instructed_model.supports_instructions
193
+
194
+
195
def test_use_explicit_instruction_prompt():
    """Passing an explicit prompt to an instruction-capable model changes the embedding."""
    model = PretrainedEmbeddingModel.BGE_BASE
    assert model.supports_instructions
    # Named `text` rather than `input` so the builtin `input` is not shadowed.
    text = "Hello world"
    prompted = model.embed(text, prompt="Represent this sentence for sentiment retrieval:")
    unprompted = model.embed(text)
    assert prompted != unprompted